vine: Setting not reachable behavior causes existing connections
to be killed.

Change-Id: I6e2aa8f7384e726b8d1b4ddd41a519dd99945147
diff --git a/runtime/protocols/vine/vine.go b/runtime/protocols/vine/vine.go
index a729e1d..736ad20 100644
--- a/runtime/protocols/vine/vine.go
+++ b/runtime/protocols/vine/vine.go
@@ -24,7 +24,10 @@
 )
 
 func init() {
-	v := &vine{behaviors: make(map[ConnKey]ConnBehavior)}
+	v := &vine{
+		behaviors: make(map[ConnKey]ConnBehavior),
+		conns:     make(map[ConnKey]map[*conn]bool),
+	}
 	flow.RegisterProtocol("vine", v)
 }
 
@@ -58,6 +61,9 @@
 	// If a ConnKey isn't in the map, the connection will be created under normal
 	// network characteristics.
 	behaviors map[ConnKey]ConnBehavior
+	// conns stores all the vine connections. Sets of *conns are keyed by their
+	// corresponding ConnKey
+	conns map[ConnKey]map[*conn]bool
 }
 
 // SetBehaviors sets the policy that the accepting vine service's process
@@ -67,9 +73,20 @@
 //   client.SetBehaviors(map[ConnKey]ConnBehavior{ConnKey{"foo", "bar"}, ConnBehavior{Reachable: false}})
 // will cause all vine protocol dial calls from "foo" to "bar" to fail.
 func (v *vine) SetBehaviors(ctx *context.T, call rpc.ServerCall, behaviors map[ConnKey]ConnBehavior) error {
+	var toKill []flow.Conn
 	v.mu.Lock()
 	v.behaviors = behaviors
+	for key, behavior := range behaviors {
+		if !behavior.Reachable {
+			for conn := range v.conns[key] {
+				toKill = append(toKill, conn)
+			}
+		}
+	}
 	v.mu.Unlock()
+	for _, conn := range toKill {
+		conn.Close()
+	}
 	return nil
 }
 
@@ -102,10 +119,14 @@
 	if err := sendLocalTag(ctx, c); err != nil {
 		return nil, err
 	}
-	return &conn{
+	conn := &conn{
 		base: c,
 		addr: addr(createDialingAddress(laddr.Network(), laddr.String(), localTag)),
-	}, nil
+		key:  key,
+		vine: v,
+	}
+	v.insertConn(conn)
+	return conn, nil
 }
 
 // Resolve returns the resolved protocol and addresses. For example,
@@ -151,9 +172,34 @@
 	}, nil
 }
 
+func (v *vine) insertConn(c *conn) {
+	key := c.key
+	v.mu.Lock()
+	if m, ok := v.conns[key]; !ok {
+		v.conns[key] = make(map[*conn]bool)
+		v.conns[key][c] = true
+	} else {
+		m[c] = true
+	}
+	v.mu.Unlock()
+}
+
+func (v *vine) removeConn(c *conn) {
+	key := c.key
+	v.mu.Lock()
+	if m, ok := v.conns[key]; ok {
+		if _, ok := m[c]; ok {
+			delete(m, c)
+		}
+	}
+	v.mu.Unlock()
+}
+
 type conn struct {
 	base flow.Conn
 	addr addr
+	key  ConnKey
+	vine *vine
 }
 
 // WriteMsg wraps the base flow.Conn's WriteMsg method to allow injection of
@@ -169,6 +215,7 @@
 }
 
 func (c *conn) Close() error {
+	c.vine.removeConn(c)
 	return c.base.Close()
 }
 
@@ -199,7 +246,14 @@
 	if ok && !behavior.Reachable {
 		return nil, NewErrCantAcceptFromTag(ctx, remoteTag)
 	}
-	return &conn{base: c, addr: l.addr}, nil
+	conn := &conn{
+		base: c,
+		addr: l.addr,
+		key:  key,
+		vine: l.vine,
+	}
+	l.vine.insertConn(conn)
+	return conn, nil
 }
 
 func (l *listener) Addr() net.Addr {
diff --git a/runtime/protocols/vine/vine_test.go b/runtime/protocols/vine/vine_test.go
index 122bb4d..0132075 100644
--- a/runtime/protocols/vine/vine_test.go
+++ b/runtime/protocols/vine/vine_test.go
@@ -49,11 +49,6 @@
 	if err := client.Call(ctx, "reachable", "Foo", nil, nil); err != nil {
 		t.Error(err)
 	}
-	// We create a new client to avoid using cached connections.
-	ctx, client, err = v23.WithNewClient(ctx)
-	if err != nil {
-		t.Error(err)
-	}
 	if err := client.Call(ctx, "unreachable", "Foo", nil, nil); err != nil {
 		t.Error(err)
 	}
@@ -67,24 +62,23 @@
 	}); err != nil {
 		t.Error(err)
 	}
-	// We create a new client to avoid using cached connections.
-	ctx, client, err = v23.WithNewClient(ctx)
-	if err != nil {
-		t.Error(err)
-	}
-	// The call to reachable should succeed
+	// The call to reachable should succeed since the cached connection still exists.
 	if err := client.Call(ctx, "reachable", "Foo", nil, nil); err != nil {
 		t.Error(err)
 	}
-	// We create a new client to avoid using cached connections.
-	ctx, client, err = v23.WithNewClient(ctx)
-	if err != nil {
-		t.Error(err)
-	}
-	// but the call to unreachable should fail.
+	// the call to unreachable should fail, since the cached connection should be closed
+	// and the new attempt to create a connection fails as well.
 	if err := client.Call(ctx, "unreachable", "Foo", nil, nil, options.NoRetry{}); err == nil {
 		t.Errorf("wanted call to fail")
 	}
+	// Create new clients to avoid using cached connections.
+	if ctx, _, err = v23.WithNewClient(ctx); err != nil {
+		t.Error(err)
+	}
+	// Now, a call to reachable should still work even without a cached connection.
+	if err := client.Call(ctx, "reachable", "Foo", nil, nil); err != nil {
+		t.Error(err)
+	}
 }
 
 func TestIncomingReachable(t *testing.T) {
@@ -96,6 +90,9 @@
 		t.Fatal(err)
 	}
 	denyCtx := vine.WithLocalTag(ctx, "denyClient")
+	if denyCtx, _, err = v23.WithNewClient(denyCtx); err != nil {
+		t.Fatal(err)
+	}
 
 	sctx := vine.WithLocalTag(ctx, "server")
 	sctx, cancel := context.WithCancel(sctx)
@@ -113,15 +110,7 @@
 	if err := v23.GetClient(ctx).Call(ctx, "server", "Foo", nil, nil); err != nil {
 		t.Error(err)
 	}
-	if err := v23.GetClient(denyCtx).Call(ctx, "server", "Foo", nil, nil); err != nil {
-		t.Error(err)
-	}
-
-	// Create new clients to avoid using cached connections.
-	if ctx, _, err = v23.WithNewClient(ctx); err != nil {
-		t.Error(err)
-	}
-	if denyCtx, _, err = v23.WithNewClient(denyCtx); err != nil {
+	if err := v23.GetClient(denyCtx).Call(denyCtx, "server", "Foo", nil, nil); err != nil {
 		t.Error(err)
 	}
 
@@ -135,14 +124,23 @@
 		t.Error(err)
 	}
 
-	// Now, the call from client to server should work.
+	// Now, the call from client to server should work, since the connection is still cached.
 	if err := v23.GetClient(ctx).Call(ctx, "server", "Foo", nil, nil); err != nil {
 		t.Error(err)
 	}
-	// but the call from denyclient to server should fail.
+	// but the call from denyclient to server should fail, since the cached connection
+	// should be closed and the new call should also fail.
 	if err := v23.GetClient(denyCtx).Call(denyCtx, "server", "Foo", nil, nil, options.NoRetry{}); err == nil {
 		t.Errorf("wanted call to fail")
 	}
+	// Create new clients to avoid using cached connections.
+	if ctx, _, err = v23.WithNewClient(ctx); err != nil {
+		t.Error(err)
+	}
+	// Now, a call with "client" should still work even without a cached connection.
+	if err := v23.GetClient(ctx).Call(ctx, "server", "Foo", nil, nil); err != nil {
+		t.Error(err)
+	}
 }
 
 type testService struct{}