ref: Introduce proxylisten to flow.Manager.

MultiPart: 2/2

Change-Id: I04166147e2b07702c970e80082722a4133406034
diff --git a/runtime/internal/flow/manager/manager.go b/runtime/internal/flow/manager/manager.go
index 60fa43e..5849a5c 100644
--- a/runtime/internal/flow/manager/manager.go
+++ b/runtime/internal/flow/manager/manager.go
@@ -40,20 +40,18 @@
 
 	mu              *sync.Mutex
 	listenEndpoints []naming.Endpoint
-	proxyEndpoints  map[string][]naming.Endpoint // keyed by proxy address
 	listeners       []flow.Listener
 	wg              sync.WaitGroup
 }
 
 func New(ctx *context.T, rid naming.RoutingID) flow.Manager {
 	m := &manager{
-		rid:            rid,
-		closed:         make(chan struct{}),
-		q:              upcqueue.New(),
-		cache:          NewConnCache(),
-		mu:             &sync.Mutex{},
-		proxyEndpoints: make(map[string][]naming.Endpoint),
-		listeners:      []flow.Listener{},
+		rid:       rid,
+		closed:    make(chan struct{}),
+		q:         upcqueue.New(),
+		cache:     NewConnCache(),
+		mu:        &sync.Mutex{},
+		listeners: []flow.Listener{},
 	}
 	go func() {
 		ticker := time.NewTicker(reapCacheInterval)
@@ -91,13 +89,6 @@
 	if err := m.validateContext(ctx); err != nil {
 		return err
 	}
-	if protocol == inaming.Network {
-		return m.proxyListen(ctx, address)
-	}
-	return m.listen(ctx, protocol, address)
-}
-
-func (m *manager) listen(ctx *context.T, protocol, address string) error {
 	ln, err := listen(ctx, protocol, address)
 	if err != nil {
 		return flow.NewErrNetwork(ctx, err)
@@ -121,18 +112,26 @@
 	return nil
 }
 
-func (m *manager) proxyListen(ctx *context.T, address string) error {
-	ep, err := inaming.NewEndpoint(address)
-	if err != nil {
-		return flow.NewErrBadArg(ctx, err)
+// ProxyListen causes the Manager to accept flows from the specified endpoint.
+// The endpoint must correspond to a vanadium proxy.
+//
+// update gets passed the complete set of endpoints for the proxy every time it
+// is called.
+//
+// The flow.Manager associated with ctx must be the receiver of the method,
+// otherwise an error is returned.
+func (m *manager) ProxyListen(ctx *context.T, ep naming.Endpoint, update func([]naming.Endpoint)) error {
+	if err := m.validateContext(ctx); err != nil {
+		return err
 	}
 	m.wg.Add(1)
-	go m.connectToProxy(ctx, address, ep)
+	go m.connectToProxy(ctx, ep, update)
 	return nil
 }
 
-func (m *manager) connectToProxy(ctx *context.T, address string, ep naming.Endpoint) {
+func (m *manager) connectToProxy(ctx *context.T, ep naming.Endpoint, update func([]naming.Endpoint)) {
 	defer m.wg.Done()
+	var eps []naming.Endpoint
 	for delay := reconnectDelay; ; delay *= 2 {
 		time.Sleep(delay - reconnectDelay)
 		select {
@@ -154,21 +153,17 @@
 			ctx.Error(err)
 			continue
 		}
-		eps, err := m.readProxyResponse(ctx, f)
+		eps, err = m.readProxyResponse(ctx, f)
 		if err != nil {
 			ctx.Error(err)
 			continue
 		}
-		m.mu.Lock()
-		m.proxyEndpoints[address] = eps
-		m.mu.Unlock()
+		update(eps)
 		select {
 		case <-ctx.Done():
 			return
 		case <-f.Closed():
-			m.mu.Lock()
-			delete(m.proxyEndpoints, address)
-			m.mu.Unlock()
+			update(nil)
 			delay = reconnectDelay
 		}
 	}
@@ -279,7 +274,8 @@
 }
 
 // ListeningEndpoints returns the endpoints that the Manager has explicitly
-// listened on. The Manager will accept new flows on these endpoints.
+// called Listen on. The Manager will accept new flows on these endpoints.
+// Proxied endpoints are not returned.
 // If the Manager is not listening on any endpoints, an endpoint with the
 // Manager's RoutingID will be returned for use in bidirectional RPC.
 // Returned endpoints all have the Manager's unique RoutingID.
@@ -287,9 +283,6 @@
 	m.mu.Lock()
 	ret := make([]naming.Endpoint, len(m.listenEndpoints))
 	copy(ret, m.listenEndpoints)
-	for _, peps := range m.proxyEndpoints {
-		ret = append(ret, peps...)
-	}
 	m.mu.Unlock()
 	if len(ret) == 0 {
 		ret = append(ret, &inaming.Endpoint{RID: m.rid})
diff --git a/runtime/internal/rpc/xserver.go b/runtime/internal/rpc/xserver.go
index e5c7f6f..110f359 100644
--- a/runtime/internal/rpc/xserver.go
+++ b/runtime/internal/rpc/xserver.go
@@ -51,9 +51,13 @@
 	dhcpState         *dhcpState          // dhcpState, nil if not using dhcp
 	principal         security.Principal
 	blessings         security.Blessings
-	protoEndpoints    []*inaming.Endpoint
-	chosenEndpoints   []*inaming.Endpoint
 	typeCache         *typeCache
+	addressChooser    rpc.AddressChooser
+
+	mu              sync.Mutex
+	chosenEndpoints map[string]*inaming.Endpoint            // endpoints chosen by the addressChooser for publishing.
+	protoEndpoints  map[string]*inaming.Endpoint            // endpoints that act as "template" endpoints.
+	proxyEndpoints  map[string]map[string]*inaming.Endpoint // keyed by ep.String()
 
 	disp               rpc.Dispatcher // dispatcher to serve RPCs
 	dispReserved       rpc.Dispatcher // dispatcher for reserved methods
@@ -107,6 +111,7 @@
 		settingsName:      settingsName,
 		disp:              dispatcher,
 		typeCache:         newTypeCache(),
+		proxyEndpoints:    make(map[string]map[string]*inaming.Endpoint),
 	}
 	ipNets, err := ipNetworks()
 	if err != nil {
@@ -139,9 +144,11 @@
 		return nil, err
 	}
 	if len(name) > 0 {
-		for _, ep := range s.chosenEndpoints {
-			s.publisher.AddServer(ep.String())
+		s.mu.Lock()
+		for k, _ := range s.chosenEndpoints {
+			s.publisher.AddServer(k)
 		}
+		s.mu.Unlock()
 		s.publisher.AddName(name, s.servesMountTable, s.isLeaf)
 		vtrace.GetSpan(s.ctx).Annotate("Serving under name: " + name)
 	}
@@ -150,9 +157,11 @@
 
 func (s *xserver) Status() rpc.ServerStatus {
 	ret := rpc.ServerStatus{}
+	s.mu.Lock()
 	for _, e := range s.chosenEndpoints {
 		ret.Endpoints = append(ret.Endpoints, e)
 	}
+	s.mu.Unlock()
 	return ret
 }
 
@@ -175,12 +184,12 @@
 }
 
 // resolveToEndpoint resolves an object name or address to an endpoint.
-func (s *xserver) resolveToEndpoint(address string) (string, error) {
+func (s *xserver) resolveToEndpoint(address string) (naming.Endpoint, error) {
 	var resolved *naming.MountEntry
 	var err error
 	if s.ns != nil {
 		if resolved, err = s.ns.Resolve(s.ctx, address); err != nil {
-			return "", err
+			return nil, err
 		}
 	} else {
 		// Fake a namespace resolution
@@ -190,7 +199,7 @@
 	}
 	// An empty set of protocols means all protocols...
 	if resolved.Servers, err = filterAndOrderServers(resolved.Servers, s.preferredProtocols, s.ipNets); err != nil {
-		return "", err
+		return nil, err
 	}
 	for _, n := range resolved.Names() {
 		address, suffix := naming.SplitAddressName(n)
@@ -198,10 +207,10 @@
 			continue
 		}
 		if ep, err := inaming.NewEndpoint(address); err == nil {
-			return ep.String(), nil
+			return ep, nil
 		}
 	}
-	return "", verror.New(errFailedToResolveToEndpoint, s.ctx, address)
+	return nil, verror.New(errFailedToResolveToEndpoint, s.ctx, address)
 }
 
 // createEndpoints creates appropriate inaming.Endpoint instances for
@@ -239,17 +248,62 @@
 	return ieps, port, unspecified, nil
 }
 
+func (s *xserver) update(pep naming.Endpoint) func([]naming.Endpoint) {
+	return func(leps []naming.Endpoint) {
+		chosenEps := make(map[string]*inaming.Endpoint)
+		pkey := pep.String()
+		for _, ep := range leps {
+			eps, _, _, _ := s.createEndpoints(ep, s.addressChooser)
+			for _, cep := range eps {
+				chosenEps[cep.String()] = cep
+			}
+			// TODO(suharshs): do protoEndpoints need to be handled here?
+		}
+
+		// Endpoints to add and remove.
+		s.mu.Lock()
+		oldEps := s.proxyEndpoints[pkey]
+		s.proxyEndpoints[pkey] = chosenEps
+		rmEps := setDiff(oldEps, chosenEps)
+		addEps := setDiff(chosenEps, oldEps)
+		for k := range rmEps {
+			delete(s.chosenEndpoints, k)
+		}
+		for k, ep := range addEps {
+			s.chosenEndpoints[k] = ep
+		}
+		s.mu.Unlock()
+
+		for k := range rmEps {
+			s.publisher.RemoveServer(k)
+		}
+		for k := range addEps {
+			s.publisher.AddServer(k)
+		}
+	}
+}
+
+// setDiff returns the endpoints in a that are not in b.
+func setDiff(a, b map[string]*inaming.Endpoint) map[string]*inaming.Endpoint {
+	ret := make(map[string]*inaming.Endpoint)
+	for k, ep := range a {
+		if _, ok := b[k]; !ok {
+			ret[k] = ep
+		}
+	}
+	return ret
+}
+
 func (s *xserver) listen(ctx *context.T, listenSpec rpc.ListenSpec) error {
 	s.Lock()
 	defer s.Unlock()
 	var lastErr error
-	var ep string
 	if len(listenSpec.Proxy) > 0 {
-		ep, lastErr = s.resolveToEndpoint(listenSpec.Proxy)
-		if lastErr != nil {
+		var ep naming.Endpoint
+		if ep, lastErr = s.resolveToEndpoint(listenSpec.Proxy); lastErr != nil {
 			s.ctx.VI(2).Infof("resolveToEndpoint(%q) failed: %v", listenSpec.Proxy, lastErr)
 		} else {
-			lastErr = s.flowMgr.Listen(ctx, inaming.Network, ep)
+			lastErr = s.flowMgr.ProxyListen(ctx, ep, s.update(ep))
 			if lastErr != nil {
 				s.ctx.VI(2).Infof("Listen(%q, %q, ...) failed: %v", inaming.Network, ep, lastErr)
 			}
@@ -269,15 +323,24 @@
 		return verror.New(verror.ErrBadArg, s.ctx, verror.New(errNoListeners, s.ctx, lastErr))
 	}
 
+	s.addressChooser = listenSpec.AddressChooser
 	roaming := false
+	chosenEps := make(map[string]*inaming.Endpoint)
+	protoEps := make(map[string]*inaming.Endpoint)
 	for _, ep := range leps {
 		eps, _, eproaming, eperr := s.createEndpoints(ep, listenSpec.AddressChooser)
-		s.chosenEndpoints = append(s.chosenEndpoints, eps...)
+		for _, cep := range eps {
+			chosenEps[cep.String()] = cep
+		}
 		if eproaming && eperr == nil {
-			s.protoEndpoints = append(s.protoEndpoints, ep.(*inaming.Endpoint))
+			protoEps[ep.String()] = ep.(*inaming.Endpoint)
 			roaming = true
 		}
 	}
+	s.mu.Lock()
+	s.chosenEndpoints = chosenEps
+	s.protoEndpoints = protoEps
+	s.mu.Unlock()
 
 	if roaming && s.dhcpState == nil && s.settingsPublisher != nil {
 		// TODO(mattr): Support roaming.
diff --git a/services/xproxyd/proxy_test.go b/services/xproxyd/proxy_test.go
index 960e533..34130c5 100644
--- a/services/xproxyd/proxy_test.go
+++ b/services/xproxyd/proxy_test.go
@@ -7,6 +7,7 @@
 import (
 	"bufio"
 	"fmt"
+	"os"
 	"strings"
 	"sync"
 	"testing"
@@ -29,6 +30,85 @@
 	pollTime     = 50 * time.Millisecond
 )
 
+type testService struct{}
+
+func (t *testService) Echo(ctx *context.T, call rpc.ServerCall, arg string) (string, error) {
+	return "response:" + arg, nil
+}
+
+func TestProxyRPC(t *testing.T) {
+	if os.Getenv("V23_RPC_TRANSITION_STATE") != "xservers" {
+		t.Skip("Test only runs under 'V23_RPC_TRANSITION_STATE==xservers'")
+	}
+	defer goroutines.NoLeaks(t, leakWaitTime)()
+	ctx, shutdown := v23.Init()
+	defer shutdown()
+
+	// Start the proxy.
+	pep := startProxy(t, ctx, address{"tcp", "127.0.0.1:0"})
+
+	// Start the server listening through the proxy.
+	ctx = v23.WithListenSpec(ctx, rpc.ListenSpec{Proxy: pep.Name()})
+	_, s, err := v23.WithNewServer(ctx, "", &testService{}, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Wait for the server to finish listening through the proxy.
+	eps := s.Status().Endpoints
+	for ; len(eps) < 2 || eps[1].Addr().Network() == ""; eps = s.Status().Endpoints {
+		time.Sleep(pollTime)
+	}
+
+	var got string
+	if err := v23.GetClient(ctx).Call(ctx, eps[1].Name(), "Echo", []interface{}{"hello"}, []interface{}{&got}); err != nil {
+		t.Fatal(err)
+	}
+	if want := "response:hello"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+func TestMultipleProxyRPC(t *testing.T) {
+	if os.Getenv("V23_RPC_TRANSITION_STATE") != "xservers" {
+		t.Skip("Test only runs under 'V23_RPC_TRANSITION_STATE==xservers'")
+	}
+	defer goroutines.NoLeaks(t, leakWaitTime)()
+	kp := newKillProtocol()
+	flow.RegisterProtocol("kill", kp)
+	ctx, shutdown := v23.Init()
+	defer shutdown()
+
+	// Start the proxies.
+	pep := startProxy(t, ctx, address{"kill", "127.0.0.1:0"})
+	p2ep := startProxy(t, ctx, address{"v23", pep.String()}, address{"kill", "127.0.0.1:0"})
+
+	// Start the server listening through the proxy.
+	ctx = v23.WithListenSpec(ctx, rpc.ListenSpec{Proxy: p2ep.Name()})
+	_, s, err := v23.WithNewServer(ctx, "", &testService{}, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Create a new flow manager for the client.
+	cctx, _, err := v23.ExperimentalWithNewFlowManager(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Wait for the server to finish listening through the proxy.
+	eps := s.Status().Endpoints
+	for ; len(eps) == 0 || eps[0].Addr().Network() == ""; eps = s.Status().Endpoints {
+		time.Sleep(pollTime)
+	}
+
+	var got string
+	if err := v23.GetClient(cctx).Call(ctx, eps[0].Name(), "Echo", []interface{}{"hello"}, []interface{}{&got}); err != nil {
+		t.Fatal(err)
+	}
+	if want := "response:hello"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+// TODO(suharshs): Remove the below tests when the transition is complete.
 func TestSingleProxy(t *testing.T) {
 	defer goroutines.NoLeaks(t, leakWaitTime)()
 	kp := newKillProtocol()
@@ -46,15 +126,20 @@
 
 	pep := startProxy(t, pctx, address{"kill", "127.0.0.1:0"})
 
-	if err := am.Listen(actx, "v23", pep.String()); err != nil {
+	done := make(chan struct{})
+	update := func(eps []naming.Endpoint) {
+		if len(eps) > 0 {
+			if err := testEndToEndConnection(t, dctx, actx, dm, am, eps[0]); err != nil {
+				t.Error(err)
+			}
+			close(done)
+		}
+	}
+
+	if err := am.ProxyListen(actx, pep, update); err != nil {
 		t.Fatal(err)
 	}
-
-	for am.ListeningEndpoints()[0].Addr().Network() == "" {
-		time.Sleep(pollTime)
-	}
-
-	testEndToEndConnections(t, dctx, actx, dm, am, kp)
+	<-done
 }
 
 func TestMultipleProxies(t *testing.T) {
@@ -78,34 +163,34 @@
 
 	p3ep := startProxy(t, pctx, address{"v23", p2ep.String()}, address{"kill", "127.0.0.1:0"})
 
-	if err := am.Listen(actx, "v23", p3ep.String()); err != nil {
+	ch := make(chan struct{})
+	var allEps []naming.Endpoint
+	idx := 0
+	update := func(eps []naming.Endpoint) {
+		// TODO(suharshs): Fix this test once we have the proxy send update messages to the
+		// server when it reconnects to a proxy.
+		if len(eps) == 3 {
+			allEps = eps
+		}
+		if len(eps) > 0 {
+			if err := testEndToEndConnection(t, dctx, actx, dm, am, allEps[idx]); err != nil {
+				t.Error(err)
+			}
+			idx++
+			ch <- struct{}{}
+		}
+	}
+
+	if err := am.ProxyListen(actx, p3ep, update); err != nil {
 		t.Fatal(err)
 	}
 
-	// Wait for am.Listen to get 3 endpoints.
-	for len(am.ListeningEndpoints()) != 3 {
-		time.Sleep(pollTime)
-	}
-
-	testEndToEndConnections(t, dctx, actx, dm, am, kp)
-}
-
-func testEndToEndConnections(t *testing.T, dctx, actx *context.T, dm, am flow.Manager, kp *killProtocol) {
-	aeps := am.ListeningEndpoints()
-	if len(aeps) == 0 {
-		t.Fatal("acceptor not listening on any endpoints")
-	}
-	for _, aep := range aeps {
-		// Kill the connections, connections should still eventually succeed.
+	<-ch
+	// Test the other two endpoints.
+	for i := 0; i < 2; i++ {
+		// Kill the connections to test reconnection.
 		kp.KillConnections()
-		for {
-			if err := testEndToEndConnection(t, dctx, actx, dm, am, aep); err != nil {
-				t.Log(err)
-				time.Sleep(pollTime)
-				continue
-			}
-			break
-		}
+		<-ch
 	}
 }