Merge "xproxy: Add reconnection to proxy logic."
diff --git a/runtime/internal/flow/manager/manager.go b/runtime/internal/flow/manager/manager.go
index f6432bc..c951916 100644
--- a/runtime/internal/flow/manager/manager.go
+++ b/runtime/internal/flow/manager/manager.go
@@ -34,18 +34,20 @@
 
 	mu              *sync.Mutex
 	listenEndpoints []naming.Endpoint
+	proxyEndpoints  map[string][]naming.Endpoint // keyed by proxy address
 	listeners       []flow.Listener
 	wg              sync.WaitGroup
 }
 
 func New(ctx *context.T, rid naming.RoutingID) flow.Manager {
 	m := &manager{
-		rid:       rid,
-		closed:    make(chan struct{}),
-		q:         upcqueue.New(),
-		cache:     NewConnCache(),
-		mu:        &sync.Mutex{},
-		listeners: []flow.Listener{},
+		rid:            rid,
+		closed:         make(chan struct{}),
+		q:              upcqueue.New(),
+		cache:          NewConnCache(),
+		mu:             &sync.Mutex{},
+		proxyEndpoints: make(map[string][]naming.Endpoint),
+		listeners:      []flow.Listener{},
 	}
 	go func() {
 		select {
@@ -72,28 +74,16 @@
 // The flow.Manager associated with ctx must be the receiver of the method,
 // otherwise an error is returned.
 func (m *manager) Listen(ctx *context.T, protocol, address string) error {
-	var (
-		eps []naming.Endpoint
-		err error
-	)
 	if protocol == inaming.Network {
-		eps, err = m.proxyListen(ctx, address)
-	} else {
-		eps, err = m.listen(ctx, protocol, address)
+		return m.proxyListen(ctx, address) // proxy endpoints are recorded later, by connectToProxy
 	}
-	if err != nil {
-		return err
-	}
-	m.mu.Lock()
-	m.listenEndpoints = append(m.listenEndpoints, eps...)
-	m.mu.Unlock()
-	return nil
+	return m.listen(ctx, protocol, address) // listen appends the local endpoint to m.listenEndpoints itself
 }
 
-func (m *manager) listen(ctx *context.T, protocol, address string) ([]naming.Endpoint, error) {
+func (m *manager) listen(ctx *context.T, protocol, address string) error {
 	ln, err := listen(ctx, protocol, address)
 	if err != nil {
-		return nil, flow.NewErrNetwork(ctx, err)
+		return flow.NewErrNetwork(ctx, err)
 	}
 	local := &inaming.Endpoint{
 		Protocol: protocol,
@@ -102,33 +92,70 @@
 	}
 	m.mu.Lock()
 	if m.listeners == nil {
-		return nil, flow.NewErrBadState(ctx, NewErrManagerClosed(ctx))
+		return flow.NewErrBadState(ctx, NewErrManagerClosed(ctx))
 	}
 	m.listeners = append(m.listeners, ln)
 	m.mu.Unlock()
 	m.wg.Add(1)
 	go m.lnAcceptLoop(ctx, ln, local)
-	return []naming.Endpoint{local}, nil
+	m.mu.Lock()
+	m.listenEndpoints = append(m.listenEndpoints, local)
+	m.mu.Unlock()
+	return nil
 }
 
-func (m *manager) proxyListen(ctx *context.T, address string) ([]naming.Endpoint, error) {
+func (m *manager) proxyListen(ctx *context.T, address string) error {
 	ep, err := inaming.NewEndpoint(address)
 	if err != nil {
-		return nil, flow.NewErrBadArg(ctx, err)
+		return flow.NewErrBadArg(ctx, err)
 	}
-	f, err := m.internalDial(ctx, ep, proxyBlessingsForPeer{}.run, &proxyFlowHandler{ctx: ctx, m: m})
-	if err != nil {
-		return nil, flow.NewErrNetwork(ctx, err)
-	}
-	w, err := message.Append(ctx, &message.ProxyServerRequest{}, nil)
-	if err != nil {
-		return nil, flow.NewErrBadArg(ctx, err)
-	}
-	if _, err := f.WriteMsg(w); err != nil {
-		return nil, flow.NewErrBadArg(ctx, err)
-	}
+	m.wg.Add(1) // balanced by the deferred m.wg.Done in connectToProxy
+	go m.connectToProxy(ctx, address, ep)
+	return nil // dial and handshake failures are only logged by connectToProxy, never returned here
+}
 
-	return m.readProxyResponse(ctx, f)
+// connectToProxy keeps a flow open to the proxy at ep, publishing the endpoints it returns in
+// m.proxyEndpoints[address] while connected and redialing with backoff until ctx is done.
+func (m *manager) connectToProxy(ctx *context.T, address string, ep naming.Endpoint) {
+	defer m.wg.Done()
+	reconnectDelay := 50 * time.Millisecond
+	for delay := reconnectDelay; ; delay *= 2 {
+		select { // wait out the backoff, but let cancellation interrupt it
+		case <-ctx.Done():
+			return
+		case <-time.After(delay - reconnectDelay):
+		}
+		f, err := m.internalDial(ctx, ep, proxyBlessingsForPeer{}.run, &proxyFlowHandler{ctx: ctx, m: m})
+		if err != nil {
+			ctx.Error(err)
+			continue
+		}
+		w, err := message.Append(ctx, &message.ProxyServerRequest{}, nil)
+		if err == nil {
+			_, err = f.WriteMsg(w)
+		}
+		var eps []naming.Endpoint
+		if err == nil {
+			eps, err = m.readProxyResponse(ctx, f)
+		}
+		if err != nil {
+			f.Close() // release the half-set-up flow before redialing
+			ctx.Error(err)
+			continue
+		}
+		m.mu.Lock()
+		m.proxyEndpoints[address] = eps
+		m.mu.Unlock()
+		select {
+		case <-ctx.Done():
+			return
+		case <-f.Closed():
+			m.mu.Lock()
+			delete(m.proxyEndpoints, address)
+			m.mu.Unlock()
+			delay = reconnectDelay
+		}
+	}
 }
 
 func (m *manager) readProxyResponse(ctx *context.T, f flow.Flow) ([]naming.Endpoint, error) {
@@ -236,6 +263,9 @@
 	m.mu.Lock()
 	ret := make([]naming.Endpoint, len(m.listenEndpoints))
 	copy(ret, m.listenEndpoints)
+	for _, peps := range m.proxyEndpoints {
+		ret = append(ret, peps...)
+	}
 	m.mu.Unlock()
 	if len(ret) == 0 {
 		ret = append(ret, &inaming.Endpoint{RID: m.rid})
@@ -334,6 +364,7 @@
 			fh,
 		)
 		if err != nil {
+			flowConn.Close()
 			if verror.ErrorID(err) == message.ErrWrongProtocol.ID {
 				return nil, err
 			}
@@ -350,16 +381,17 @@
 
 	// If we are dialing out to a Proxy, we need to dial a conn on this flow, and
 	// return a flow on that corresponding conn.
-	if remote.RoutingID() != c.RemoteEndpoint().RoutingID() {
+	if proxyConn := c; remote.RoutingID() != proxyConn.RemoteEndpoint().RoutingID() {
 		c, err = conn.NewDialed(
 			ctx,
 			f,
-			c.LocalEndpoint(),
+			proxyConn.LocalEndpoint(),
 			remote,
 			version.Supported,
 			fh,
 		)
 		if err != nil {
+			proxyConn.Close(ctx, err)
 			if verror.ErrorID(err) == message.ErrWrongProtocol.ID {
 				return nil, err
 			}
@@ -370,6 +402,7 @@
 		}
 		f, err = c.Dial(ctx, fn)
 		if err != nil {
+			proxyConn.Close(ctx, err)
 			return nil, flow.NewErrDialFailed(ctx, err)
 		}
 	}
diff --git a/services/xproxyd/proxy_test.go b/services/xproxyd/proxy_test.go
index ebeb262..bb85a2b 100644
--- a/services/xproxyd/proxy_test.go
+++ b/services/xproxyd/proxy_test.go
@@ -6,7 +6,9 @@
 
 import (
 	"bufio"
+	"fmt"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -22,10 +24,15 @@
 	"v.io/v23/security"
 )
 
-const leakWaitTime = 100 * time.Millisecond
+const (
+	leakWaitTime = 100 * time.Millisecond
+	pollTime     = 50 * time.Millisecond
+)
 
-func TestProxiedConnection(t *testing.T) {
+func TestSingleProxy(t *testing.T) {
 	defer goroutines.NoLeaks(t, leakWaitTime)()
+	kp := newKillProtocol()
+	flow.RegisterProtocol("kill", kp)
 	pctx, shutdown := v23.Init()
 	defer shutdown()
 	actx, am, err := v23.ExperimentalWithNewFlowManager(pctx)
@@ -37,16 +44,23 @@
 		t.Fatal(err)
 	}
 
-	pep := startProxy(t, pctx, address{"tcp", "127.0.0.1:0"})
+	pep := startProxy(t, pctx, address{"kill", "127.0.0.1:0"})
 
 	if err := am.Listen(actx, "v23", pep.String()); err != nil {
 		t.Fatal(err)
 	}
-	testEndToEndConnections(t, dctx, actx, dm, am)
+
+	for am.ListeningEndpoints()[0].Addr().Network() == "" {
+		time.Sleep(pollTime)
+	}
+
+	testEndToEndConnections(t, dctx, actx, dm, am, kp)
 }
 
 func TestMultipleProxies(t *testing.T) {
 	defer goroutines.NoLeaks(t, leakWaitTime)()
+	kp := newKillProtocol()
+	flow.RegisterProtocol("kill", kp)
 	pctx, shutdown := v23.Init()
 	defer shutdown()
 	actx, am, err := v23.ExperimentalWithNewFlowManager(pctx)
@@ -58,59 +72,79 @@
 		t.Fatal(err)
 	}
 
-	pep := startProxy(t, pctx, address{"tcp", "127.0.0.1:0"})
+	pep := startProxy(t, pctx, address{"kill", "127.0.0.1:0"})
 
-	p2ep := startProxy(t, pctx, address{"v23", pep.String()}, address{"tcp", "127.0.0.1:0"})
+	p2ep := startProxy(t, pctx, address{"v23", pep.String()}, address{"kill", "127.0.0.1:0"})
 
-	p3ep := startProxy(t, pctx, address{"v23", p2ep.String()}, address{"tcp", "127.0.0.1:0"})
+	p3ep := startProxy(t, pctx, address{"v23", p2ep.String()}, address{"kill", "127.0.0.1:0"})
 
 	if err := am.Listen(actx, "v23", p3ep.String()); err != nil {
 		t.Fatal(err)
 	}
-	testEndToEndConnections(t, dctx, actx, dm, am)
+
+	// Wait for am.Listen to get 3 endpoints.
+	for len(am.ListeningEndpoints()) != 3 {
+		time.Sleep(pollTime)
+	}
+
+	testEndToEndConnections(t, dctx, actx, dm, am, kp)
 }
 
-func testEndToEndConnections(t *testing.T, dctx, actx *context.T, dm, am flow.Manager) {
+func testEndToEndConnections(t *testing.T, dctx, actx *context.T, dm, am flow.Manager, kp *killProtocol) {
 	aeps := am.ListeningEndpoints()
 	if len(aeps) == 0 {
 		t.Fatal("acceptor not listening on any endpoints")
 	}
 	for _, aep := range aeps {
-		testEndToEndConnection(t, dctx, actx, dm, am, aep)
+		// Close every conn dialed through kp; the retry loop below should still succeed eventually.
+		kp.KillConnections()
+		for { // NOTE: retries without a timeout, so a connection that never recovers hangs the test
+			if err := testEndToEndConnection(t, dctx, actx, dm, am, aep); err != nil {
+				t.Log(err)
+				time.Sleep(pollTime)
+				continue
+			}
+			break
+		}
 	}
 }
 
-func testEndToEndConnection(t *testing.T, dctx, actx *context.T, dm, am flow.Manager, aep naming.Endpoint) {
+func testEndToEndConnection(t *testing.T, dctx, actx *context.T, dm, am flow.Manager, aep naming.Endpoint) error {
 	// The dialing flow.Manager dials a flow to the accepting flow.Manager.
 	want := "Do you read me?"
 	df, err := dm.Dial(dctx, aep, bfp)
 	if err != nil {
-		t.Fatal(err)
+		return err // returned instead of t.Fatal so testEndToEndConnections can log it and retry
 	}
 	// We write before accepting to ensure that the openFlow message is sent.
-	writeLine(df, want)
+	if err := writeLine(df, want); err != nil {
+		return err
+	}
 	af, err := am.Accept(actx)
 	if err != nil {
-		t.Fatal(err)
+		return err
 	}
 	got, err := readLine(af)
 	if err != nil {
-		t.Fatal(err)
+		return err
 	}
 	if got != want {
-		t.Errorf("got %v, want %v", got, want)
+		return fmt.Errorf("got %v, want %v", got, want)
 	}
 
 	// Writes in the opposite direction should work as well.
 	want = "I read you loud and clear."
-	writeLine(af, want)
+	if err := writeLine(af, want); err != nil {
+		return err
+	}
 	got, err = readLine(df)
 	if err != nil {
-		t.Fatal(err)
+		return err
 	}
 	if got != want {
-		t.Errorf("got %v, want %v", got, want)
+		return fmt.Errorf("got %v, want %v", got, want)
 	}
+	return nil
 }
 
 // TODO(suharshs): Add test for bidirectional RPC.
@@ -141,20 +175,69 @@
 
 func startProxy(t *testing.T, ctx *context.T, addrs ...address) naming.Endpoint {
 	var ls rpc.ListenSpec
+	hasProxies := false // set when any address uses the "v23" protocol, i.e. names an upstream proxy
 	for _, addr := range addrs {
 		ls.Addrs = append(ls.Addrs, addr)
+		if addr.Protocol == "v23" {
+			hasProxies = true
+		}
 	}
 	ctx = v23.WithListenSpec(ctx, ls)
 	proxy, _, err := xproxyd.New(ctx)
 	if err != nil {
 		t.Fatal(err)
 	}
+	// xproxyd.New connects to upstream proxies in a goroutine, so poll until one has answered.
+	if hasProxies {
+		for len(proxy.MultipleProxyEndpoints()) == 0 {
+			time.Sleep(pollTime)
+		}
+	}
 	peps := proxy.ListeningEndpoints()
 	for _, pep := range peps {
-		if pep.Addr().Network() == "tcp" {
+		if pep.Addr().Network() == "tcp" || pep.Addr().Network() == "kill" {
 			return pep
 		}
 	}
 	t.Fatal("Proxy not listening on network address.")
 	return nil
 }
+
+type killProtocol struct { // test-only flow protocol that delegates to "tcp" and can sever dialed conns
+	protocol flow.Protocol // delegate; newKillProtocol looks up the registered "tcp" protocol
+	mu       sync.Mutex    // guards conns
+	conns    []flow.Conn   // conns returned by Dial since the last KillConnections
+}
+
+func newKillProtocol() *killProtocol {
+	p, _ := flow.RegisteredProtocol("tcp") // second result is discarded; p is used unchecked
+	return &killProtocol{protocol: p}
+}
+
+func (p *killProtocol) KillConnections() { // closes and forgets every conn recorded by Dial
+	p.mu.Lock()
+	for _, c := range p.conns {
+		c.Close()
+	}
+	p.conns = nil
+	p.mu.Unlock()
+}
+
+func (p *killProtocol) Dial(ctx *context.T, protocol, address string, timeout time.Duration) (flow.Conn, error) {
+	c, err := p.protocol.Dial(ctx, "tcp", address, timeout) // the protocol argument is ignored; always "tcp"
+	if err != nil {
+		return nil, err
+	}
+	p.mu.Lock()
+	p.conns = append(p.conns, c) // remember the conn so KillConnections can close it
+	p.mu.Unlock()
+	return c, nil
+}
+
+func (p *killProtocol) Listen(ctx *context.T, protocol, address string) (flow.Listener, error) {
+	return p.protocol.Listen(ctx, "tcp", address) // listeners are not tracked, only dialed conns
+}
+
+func (p *killProtocol) Resolve(ctx *context.T, protocol, address string) (string, string, error) {
+	return p.protocol.Resolve(ctx, "tcp", address) // straight delegation to the wrapped protocol
+}
diff --git a/services/xproxyd/proxyd.go b/services/xproxyd/proxyd.go
index d05a561..83225d2 100644
--- a/services/xproxyd/proxyd.go
+++ b/services/xproxyd/proxyd.go
@@ -5,9 +5,9 @@
 package xproxyd
 
 import (
-	"fmt"
 	"io"
 	"sync"
+	"time"
 
 	"v.io/v23"
 	"v.io/v23/context"
@@ -21,7 +21,7 @@
 type proxy struct {
 	m              flow.Manager
 	mu             sync.Mutex
-	proxyEndpoints []naming.Endpoint
+	proxyEndpoints map[string][]naming.Endpoint // endpoints reported by each upstream proxy, keyed by proxy address; accessed under mu
 }
 
 func New(ctx *context.T) (*proxy, *context.T, error) {
@@ -30,7 +30,8 @@
 		return nil, nil, err
 	}
 	p := &proxy{
-		m: mgr,
+		m:              mgr,
+		proxyEndpoints: make(map[string][]naming.Endpoint),
 	}
 	for _, addr := range v23.GetListenSpec(ctx).Addrs {
 		if addr.Protocol == "v23" {
@@ -38,25 +39,7 @@
 			if err != nil {
 				return nil, nil, err
 			}
-			f, err := p.m.Dial(ctx, ep, proxyBlessingsForPeer{}.run)
-			if err != nil {
-				return nil, nil, err
-			}
-			// Send a byte telling the acceptor that we are a proxy.
-			if err := writeMessage(ctx, &message.MultiProxyRequest{}, f); err != nil {
-				return nil, nil, err
-			}
-			msg, err := readMessage(ctx, f)
-			if err != nil {
-				return nil, nil, err
-			}
-			m, ok := msg.(*message.ProxyResponse)
-			if !ok {
-				return nil, nil, NewErrUnexpectedMessage(ctx, fmt.Sprintf("%t", m))
-			}
-			p.mu.Lock()
-			p.proxyEndpoints = append(p.proxyEndpoints, m.Endpoints...)
-			p.mu.Unlock()
+			go p.connectToProxy(ctx, addr.Address, ep)
 		} else if err := p.m.Listen(ctx, addr.Protocol, addr.Address); err != nil {
 			return nil, nil, err
 		}
@@ -69,6 +52,16 @@
 	return p.m.ListeningEndpoints()
 }
 
+// MultipleProxyEndpoints returns the endpoints currently recorded for every upstream proxy.
+func (p *proxy) MultipleProxyEndpoints() (all []naming.Endpoint) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	for _, peps := range p.proxyEndpoints {
+		all = append(all, peps...)
+	}
+	return all
+}
+
 func (p *proxy) listenLoop(ctx *context.T) {
 	for {
 		f, err := p.m.Accept(ctx)
@@ -99,6 +92,7 @@
 func (p *proxy) startRouting(ctx *context.T, f flow.Flow, m *message.Setup) error {
 	fout, err := p.dialNextHop(ctx, f, m)
 	if err != nil {
+		f.Close()
 		return err
 	}
 	go p.forwardLoop(ctx, f, fout)
@@ -108,10 +102,9 @@
 
 func (p *proxy) forwardLoop(ctx *context.T, fin, fout flow.Flow) {
 	for {
-		_, err := io.Copy(fin, fout)
-		if err == io.EOF {
-			return
-		} else if err != nil {
+		if _, err := io.Copy(fin, fout); err != nil {
+			fin.Close()
+			fout.Close()
 			ctx.Errorf("f.Read failed: %v", err)
 			return
 		}
@@ -124,7 +117,10 @@
 		ep  naming.Endpoint
 		err error
 	)
-	if routes := m.PeerRemoteEndpoint.Routes(); len(routes) > 0 {
+	if ep, err = removeNetworkAddress(m.PeerRemoteEndpoint); err != nil {
+		return nil, err
+	}
+	if routes := ep.Routes(); len(routes) > 0 {
 		if err := rid.FromString(routes[0]); err != nil {
 			return nil, err
 		}
@@ -133,15 +129,13 @@
 		// TODO(suharshs): Make sure that the routingID from the route belongs to a
 		// connection that is stored in the manager's cache. (i.e. a Server has connected
 		// with the routingID before)
-		if ep, err = setEndpointRoutingID(m.PeerRemoteEndpoint, rid); err != nil {
+		if ep, err = setEndpointRoutingID(ep, rid); err != nil {
 			return nil, err
 		}
 		// Remove the read route from the setup message endpoint.
 		if m.PeerRemoteEndpoint, err = setEndpointRoutes(m.PeerRemoteEndpoint, routes[1:]); err != nil {
 			return nil, err
 		}
-	} else {
-		ep = m.PeerRemoteEndpoint
 	}
 	fout, err := p.m.Dial(ctx, ep, proxyBlessingsForPeer{}.run)
 	if err != nil {
@@ -175,7 +169,10 @@
 
 func (p *proxy) returnEndpoints(ctx *context.T, rid naming.RoutingID, route string) ([]naming.Endpoint, error) {
 	p.mu.Lock()
-	eps := append(p.m.ListeningEndpoints(), p.proxyEndpoints...)
+	eps := p.m.ListeningEndpoints()
+	for _, peps := range p.proxyEndpoints {
+		eps = append(eps, peps...)
+	}
 	p.mu.Unlock()
 	if len(eps) == 0 {
 		return nil, NewErrNotListening(ctx)
@@ -201,3 +198,42 @@
 	}
 	return eps, nil
 }
+
+// connectToProxy keeps a flow open to the upstream proxy at ep, recording the endpoints it returns
+// in p.proxyEndpoints[address] while connected and redialing with backoff until ctx is done.
+func (p *proxy) connectToProxy(ctx *context.T, address string, ep naming.Endpoint) {
+	reconnectDelay := 50 * time.Millisecond
+	for delay := reconnectDelay; ; delay *= 2 {
+		select { // wait out the backoff, but let cancellation interrupt it
+		case <-ctx.Done():
+			return
+		case <-time.After(delay - reconnectDelay):
+		}
+		f, err := p.m.Dial(ctx, ep, proxyBlessingsForPeer{}.run)
+		if err != nil {
+			ctx.Error(err)
+			continue
+		}
+		var eps []naming.Endpoint
+		if err = writeMessage(ctx, &message.MultiProxyRequest{}, f); err == nil { // announce we are a proxy
+			eps, err = readProxyResponse(ctx, f)
+		}
+		if err != nil {
+			f.Close() // release the half-set-up flow before redialing
+			ctx.Error(err)
+			continue
+		}
+		p.mu.Lock()
+		p.proxyEndpoints[address] = eps
+		p.mu.Unlock()
+		select {
+		case <-ctx.Done():
+			return
+		case <-f.Closed():
+			p.mu.Lock()
+			delete(p.proxyEndpoints, address)
+			p.mu.Unlock()
+			delay = reconnectDelay
+		}
+	}
+}
diff --git a/services/xproxyd/util.go b/services/xproxyd/util.go
index 4ccad31..e14e23e 100644
--- a/services/xproxyd/util.go
+++ b/services/xproxyd/util.go
@@ -5,6 +5,8 @@
 package xproxyd
 
 import (
+	"fmt"
+
 	"v.io/v23"
 	"v.io/v23/context"
 	"v.io/v23/flow"
@@ -13,6 +15,16 @@
 	"v.io/v23/security"
 )
 
+// removeNetworkAddress returns a copy of ep formatted with an empty network and address, keeping
+// the routes, blessing names, routing ID and mountable options returned by getEndpointParts.
+func removeNetworkAddress(ep naming.Endpoint) (naming.Endpoint, error) {
+	_, _, routes, rid, bnames, mountable := getEndpointParts(ep)
+	// NOTE(review): appending to routes may reuse getEndpointParts' backing array — confirm.
+	opts := append(routes, bnames...)
+	opts = append(opts, rid, mountable)
+	return v23.NewEndpoint(naming.FormatEndpoint("", "", opts...))
+}
+
 // setEndpointRoutingID returns a copy of ep with RoutingId changed to rid.
 func setEndpointRoutingID(ep naming.Endpoint, rid naming.RoutingID) (naming.Endpoint, error) {
 	network, address, routes, _, bnames, mountable := getEndpointParts(ep)
@@ -87,3 +99,15 @@
 	}
 	return message.Read(ctx, b)
 }
+
+// readProxyResponse reads one message from f and returns the endpoints of its ProxyResponse.
+func readProxyResponse(ctx *context.T, f flow.Flow) ([]naming.Endpoint, error) {
+	msg, err := readMessage(ctx, f)
+	if err != nil {
+		return nil, err
+	}
+	if res, ok := msg.(*message.ProxyResponse); ok {
+		return res.Endpoints, nil
+	}
+	return nil, NewErrUnexpectedMessage(ctx, fmt.Sprintf("%T", msg)) // %T names the type; %t is the bool verb
+}