From 769e43e334d98b11c4e02ac2f4875f42f082219e Mon Sep 17 00:00:00 2001
From: holisticode <holistic.computing@gmail.com>
Date: Wed, 30 Jan 2019 15:03:08 -0500
Subject: [PATCH] swarm: GetPeerSubscriptions RPC (#18972)

(cherry picked from commit 43e1b7b124d2bcfba98fbe54972a35c022d85bf2)
---
 swarm/network/stream/stream.go               |  24 ++
 swarm/network/stream/streamer_test.go        | 233 ++++++++++++++++++-
 swarm/network/stream/testing/snapshot_4.json |   1 +
 3 files changed, 257 insertions(+), 1 deletion(-)
 create mode 100644 swarm/network/stream/testing/snapshot_4.json

diff --git a/swarm/network/stream/stream.go b/swarm/network/stream/stream.go
index ee4f57c1a3..90da862bd3 100644
--- a/swarm/network/stream/stream.go
+++ b/swarm/network/stream/stream.go
@@ -935,3 +935,27 @@ func (api *API) SubscribeStream(peerId enode.ID, s Stream, history *Range, prior
 func (api *API) UnsubscribeStream(peerId enode.ID, s Stream) error {
 	return api.streamer.Unsubscribe(peerId, s)
 }
+
+/*
+GetPeerSubscriptions is a API function which allows to query a peer for stream subscriptions it has.
+It can be called via RPC.
+It returns a map of node IDs with an array of string representations of Stream objects.
+*/
+func (api *API) GetPeerSubscriptions() map[string][]string {
+	//create the empty map
+	pstreams := make(map[string][]string)
+	//iterate all streamer peers
+	for id, p := range api.streamer.peers {
+		var streams []string
+		//every peer has a map of stream servers
+		//every stream server represents a subscription
+		for s := range p.servers {
+			//append the string representation of the stream
+			//to the list for this peer
+			streams = append(streams, s.String())
+		}
+		//set the array of stream servers to the map
+		pstreams[id.String()] = streams
+	}
+	return pstreams
+}
diff --git a/swarm/network/stream/streamer_test.go b/swarm/network/stream/streamer_test.go
index a41235e073..b83521f060 100644
--- a/swarm/network/stream/streamer_test.go
+++ b/swarm/network/stream/streamer_test.go
@@ -21,15 +21,23 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"os"
 	"strconv"
+	"strings"
+	"sync"
 	"testing"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/node"
 	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
 	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
 	"github.com/ethereum/go-ethereum/swarm/network"
+	"github.com/ethereum/go-ethereum/swarm/network/simulation"
+	"github.com/ethereum/go-ethereum/swarm/state"
+	"github.com/ethereum/go-ethereum/swarm/storage"
 	"golang.org/x/crypto/sha3"
 )
 
@@ -1105,7 +1113,6 @@ func TestRequestPeerSubscriptions(t *testing.T) {
 			}
 		}
 	}
-
 	// print some output
 	for p, subs := range fakeSubscriptions {
 		log.Debug(fmt.Sprintf("Peer %s has the following fake subscriptions: ", p))
@@ -1114,3 +1121,227 @@ func TestRequestPeerSubscriptions(t *testing.T) {
 		}
 	}
 }
+
+// TestGetSubscriptions is a unit test for the api.GetPeerSubscriptions() function
+func TestGetSubscriptions(t *testing.T) {
+	// create an amount of dummy peers
+	testPeerCount := 8
+	// every peer will have this amount of dummy servers
+	testServerCount := 4
+	// the peerMap which will store this data for the registry
+	peerMap := make(map[enode.ID]*Peer)
+	// create the registry
+	r := &Registry{}
+	api := NewAPI(r)
+	// call once, at this point should be empty
+	regs := api.GetPeerSubscriptions()
+	if len(regs) != 0 {
+		t.Fatal("Expected subscription count to be 0, but it is not")
+	}
+
+	// now create a number of dummy servers for each node
+	for i := 0; i < testPeerCount; i++ {
+		addr := network.RandomAddr()
+		id := addr.ID()
+		p := &Peer{}
+		p.servers = make(map[Stream]*server)
+		for k := 0; k < testServerCount; k++ {
+			s := Stream{
+				Name: strconv.Itoa(k),
+				Key:  "",
+				Live: false,
+			}
+			p.servers[s] = &server{}
+		}
+		peerMap[id] = p
+	}
+	r.peers = peerMap
+
+	// call the subscriptions again
+	regs = api.GetPeerSubscriptions()
+	// count how many (fake) subscriptions there are
+	cnt := 0
+	for _, reg := range regs {
+		for range reg {
+			cnt++
+		}
+	}
+	// check expected value
+	expectedCount := testPeerCount * testServerCount
+	if cnt != expectedCount {
+		t.Fatalf("Expected %d subscriptions, but got %d", expectedCount, cnt)
+	}
+}
+
+/*
+TestGetSubscriptionsRPC sets up a simulation network of `nodeCount` nodes,
+starts the simulation, waits for SyncUpdateDelay in order to kick off
+stream registration, then tests that there are subscriptions.
+*/
+func TestGetSubscriptionsRPC(t *testing.T) {
+	// arbitrarily set to 4
+	nodeCount := 4
+	// run with more nodes if `longrunning` flag is set
+	if *longrunning {
+		nodeCount = 64
+	}
+	// set the syncUpdateDelay for sync registrations to start
+	syncUpdateDelay := 200 * time.Millisecond
+	// holds the msg code for SubscribeMsg
+	var subscribeMsgCode uint64
+	var ok bool
+	var expectedMsgCount = 0
+
+	// this channel signalizes that the expected amount of subscriptiosn is done
+	allSubscriptionsDone := make(chan struct{})
+	lock := sync.RWMutex{}
+	// after the test, we need to reset the subscriptionFunc to the default
+	defer func() { subscriptionFunc = doRequestSubscription }()
+
+	// we use this subscriptionFunc for this test: just increases count and calls the actual subscription
+	subscriptionFunc = func(r *Registry, p *network.Peer, bin uint8, subs map[enode.ID]map[Stream]struct{}) bool {
+		lock.Lock()
+		expectedMsgCount++
+		lock.Unlock()
+		doRequestSubscription(r, p, bin, subs)
+		return true
+	}
+	// create a standard sim
+	sim := simulation.New(map[string]simulation.ServiceFunc{
+		"streamer": func(ctx *adapters.ServiceContext, bucket *sync.Map) (s node.Service, cleanup func(), err error) {
+			n := ctx.Config.Node()
+			addr := network.NewAddr(n)
+			store, datadir, err := createTestLocalStorageForID(n.ID(), addr)
+			if err != nil {
+				return nil, nil, err
+			}
+			localStore := store.(*storage.LocalStore)
+			netStore, err := storage.NewNetStore(localStore, nil)
+			if err != nil {
+				return nil, nil, err
+			}
+			kad := network.NewKademlia(addr.Over(), network.NewKadParams())
+			delivery := NewDelivery(kad, netStore)
+			netStore.NewNetFetcherFunc = network.NewFetcherFactory(dummyRequestFromPeers, true).New
+			// configure so that sync registrations actually happen
+			r := NewRegistry(addr.ID(), delivery, netStore, state.NewInmemoryStore(), &RegistryOptions{
+				Retrieval:       RetrievalEnabled,
+				Syncing:         SyncingAutoSubscribe, //enable sync registrations
+				SyncUpdateDelay: syncUpdateDelay,
+			}, nil)
+			// get the SubscribeMsg code
+			subscribeMsgCode, ok = r.GetSpec().GetCode(SubscribeMsg{})
+			if !ok {
+				t.Fatal("Message code for SubscribeMsg not found")
+			}
+
+			cleanup = func() {
+				os.RemoveAll(datadir)
+				netStore.Close()
+				r.Close()
+			}
+
+			return r, cleanup, nil
+
+		},
+	})
+	defer sim.Close()
+
+	ctx, cancelSimRun := context.WithTimeout(context.Background(), 1*time.Minute)
+	defer cancelSimRun()
+
+	// upload a snapshot
+	err := sim.UploadSnapshot(fmt.Sprintf("testing/snapshot_%d.json", nodeCount))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// setup the filter for SubscribeMsg
+	msgs := sim.PeerEvents(
+		context.Background(),
+		sim.NodeIDs(),
+		simulation.NewPeerEventsFilter().ReceivedMessages().Protocol("stream").MsgCode(subscribeMsgCode),
+	)
+
+	// strategy: listen to all SubscribeMsg events; after every event we wait
+	// if after `waitDuration` no more messages are being received, we assume the
+	// subscription phase has terminated!
+
+	// the loop in this go routine will either wait for new message events
+	// or times out after 1 second, which signals that we are not receiving
+	// any new subscriptions any more
+	go func() {
+		//for long running sims, waiting 1 sec will not be enough
+		waitDuration := time.Duration(nodeCount/16) * time.Second
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case m := <-msgs: // just reset the loop
+				if m.Error != nil {
+					log.Error("stream message", "err", m.Error)
+					continue
+				}
+				log.Trace("stream message", "node", m.NodeID, "peer", m.PeerID)
+			case <-time.After(waitDuration):
+				// one second passed, don't assume more subscriptions
+				allSubscriptionsDone <- struct{}{}
+				log.Info("All subscriptions received")
+				return
+
+			}
+		}
+	}()
+
+	//run the simulation
+	result := sim.Run(ctx, func(ctx context.Context, sim *simulation.Simulation) error {
+		log.Info("Simulation running")
+		nodes := sim.Net.Nodes
+
+		//wait until all subscriptions are done
+		select {
+		case <-allSubscriptionsDone:
+		case <-ctx.Done():
+			t.Fatal("Context timed out")
+		}
+
+		log.Debug("Expected message count: ", "expectedMsgCount", expectedMsgCount)
+		//now iterate again, this time we call each node via RPC to get its subscriptions
+		realCount := 0
+		for _, node := range nodes {
+			//create rpc client
+			client, err := node.Client()
+			if err != nil {
+				t.Fatalf("create node 1 rpc client fail: %v", err)
+			}
+
+			//ask it for subscriptions
+			pstreams := make(map[string][]string)
+			err = client.Call(&pstreams, "stream_getPeerSubscriptions")
+			if err != nil {
+				t.Fatal(err)
+			}
+			//length of the subscriptions can not be smaller than number of peers
+			log.Debug("node subscriptions:", "node", node.String())
+			for p, ps := range pstreams {
+				log.Debug("... with: ", "peer", p)
+				for _, s := range ps {
+					log.Debug(".......", "stream", s)
+					// each node also has subscriptions to RETRIEVE_REQUEST streams,
+					// we need to ignore those, we are only counting SYNC streams
+					if !strings.HasPrefix(s, "RETRIEVE_REQUEST") {
+						realCount++
+					}
+				}
+			}
+		}
+		// every node is mutually subscribed to each other, so the actual count is half of it
+		if realCount/2 != expectedMsgCount {
+			return fmt.Errorf("Real subscriptions and expected amount don't match; real: %d, expected: %d", realCount/2, expectedMsgCount)
+		}
+		return nil
+	})
+	if result.Error != nil {
+		t.Fatal(result.Error)
+	}
+}
diff --git a/swarm/network/stream/testing/snapshot_4.json b/swarm/network/stream/testing/snapshot_4.json
new file mode 100644
index 0000000000..a64f31375a
--- /dev/null
+++ b/swarm/network/stream/testing/snapshot_4.json
@@ -0,0 +1 @@
+{"nodes":[{"node":{"config":{"id":"73d6ad4a75069dced660fa4cb98143ee5573df7cb15d9a295acf1655e9683384","private_key":"e567b7d9c554e5102cdc99b6523bace02dbb8951415c8816d82ba2d2e97fa23b","name":"node01","services":["bzz","pss"],"enable_msg_events":false,"port":0},"up":true}},{"node":{"config":{"id":"6e8da86abb894ab35044c8c455147225df96cab498da067a118f1fb9a417f9e3","private_key":"c7526db70acd02f36d3b201ef3e1d85e38c52bee6931453213dbc5edec4d0976","name":"node02","services":["bzz","pss"],"enable_msg_events":false,"port":0},"up":true}},{"node":{"config":{"id":"8a1eb78ff13df318e7f8116dffee98cd7d9905650fa53f16766b754a63f387ac","private_key":"61b5728f59bc43080c3b8eb0458fb30d7723e2747355b6dc980f35f3ed431199","name":"node03","services":["bzz","pss"],"enable_msg_events":false,"port":0},"up":true}},{"node":{"config":{"id":"d7768334f79d626adb433f44b703a818555e3331056036ef3f8d1282586bf044","private_key":"075b07c29ceac4ffa2a114afd67b21dfc438126bc169bf7c154be6d81d86ed38","name":"node04","services":["bzz","pss"],"enable_msg_events":false,"port":0},"up":true}}],"conns":[{"one":"6e8da86abb894ab35044c8c455147225df96cab498da067a118f1fb9a417f9e3","other":"8a1eb78ff13df318e7f8116dffee98cd7d9905650fa53f16766b754a63f387ac","up":true},{"one":"73d6ad4a75069dced660fa4cb98143ee5573df7cb15d9a295acf1655e9683384","other":"6e8da86abb894ab35044c8c455147225df96cab498da067a118f1fb9a417f9e3","up":true},{"one":"8a1eb78ff13df318e7f8116dffee98cd7d9905650fa53f16766b754a63f387ac","other":"d7768334f79d626adb433f44b703a818555e3331056036ef3f8d1282586bf044","up":true}]}
\ No newline at end of file