ref: Add lame duck support to flow.Conn.
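
The Conn now tracks a Status (Closed, LocalLameDuck, RemoteLameDuck) and
reports status changes on an optional events channel passed to NewDialed
and NewAccepted. EnterLameDuck sends an EnterLameDuck message to the
remote end, which stops dialing new flows on the Conn and replies with
AckLameDuck once all of its outstanding unopened flows have been opened
or closed; receiving the ack sets LocalLameDuck. Dialing a flow on a Conn
that has received a lame duck notification fails immediately.

For illustration only, a minimal sketch of how a caller holding the events
channel might drain a Conn. drainConn is hypothetical and not part of this
change; it assumes it lives in package conn, where context.T, Conn and
StatusUpdate are already in scope:

	func drainConn(ctx *context.T, c *Conn, events <-chan StatusUpdate) {
		c.EnterLameDuck(ctx) // Ask the remote end to stop dialing new flows on c.
		for update := range events {
			if update.Conn != c {
				continue
			}
			if update.Status.LocalLameDuck || update.Status.Closed {
				// The remote end acknowledged our lame duck mode (or
				// the Conn closed); no new flows will arrive on c.
				return
			}
		}
	}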

MultiPart: 2/2

Change-Id: I86a9c885b706e33d9fbf1ae2fd547907c904f033
diff --git a/runtime/internal/flow/conn/conn.go b/runtime/internal/flow/conn/conn.go
index bbe5862..5a6e248 100644
--- a/runtime/internal/flow/conn/conn.go
+++ b/runtime/internal/flow/conn/conn.go
@@ -43,6 +43,24 @@
 	HandleFlow(flow.Flow) error
 }
 
+// Status describes the current state of the Conn.
+type Status struct {
+	Closed bool
+	// LocalLameDuck signifies that we have received acknowledgment from the
+	// remote host of our lame duck mode and therefore should not expect new
+	// flows to arrive on this Conn.
+	LocalLameDuck bool
+	// RemoteLameDuck signifies that we have received a lameduck notification
+	// from the remote host and therefore should not open new flows on this Conn.
+	RemoteLameDuck bool
+}
+
+// StatusUpdate is an event sent on the events channel whenever the Conn's status changes.
+type StatusUpdate struct {
+	Conn   *Conn
+	Status Status
+}
+
 // Conns are a multiplexing encrypted channels that can host Flows.
 type Conn struct {
 	fc                     *flowcontrol.FlowController
@@ -53,6 +71,8 @@
 	closed                 chan struct{}
 	blessingsFlow          *blessingsFlow
 	loopWG                 sync.WaitGroup
+	unopenedFlows          sync.WaitGroup
+	events                 chan<- StatusUpdate
 
 	mu             sync.Mutex
 	handler        FlowHandler
@@ -62,6 +82,7 @@
 	lastUsedTime   time.Time
 	toRelease      map[uint64]uint64
 	borrowing      map[uint64]bool
+	status         Status
 }
 
 // Ensure that *Conn implements flow.ManagedConn.
@@ -73,7 +94,8 @@
 	conn flow.MsgReadWriteCloser,
 	local, remote naming.Endpoint,
 	versions version.RPCVersionRange,
-	handler FlowHandler) (*Conn, error) {
+	handler FlowHandler,
+	events chan<- StatusUpdate) (*Conn, error) {
 	c := &Conn{
 		fc:           flowcontrol.New(defaultBufferSize, mtu),
 		mp:           newMessagePipe(conn),
@@ -87,6 +109,7 @@
 		lastUsedTime: time.Now(),
 		toRelease:    map[uint64]uint64{},
 		borrowing:    map[uint64]bool{},
+		events:       events,
 	}
 	if err := c.dialHandshake(ctx, versions); err != nil {
 		c.Close(ctx, err)
@@ -103,7 +126,8 @@
 	conn flow.MsgReadWriteCloser,
 	local naming.Endpoint,
 	versions version.RPCVersionRange,
-	handler FlowHandler) (*Conn, error) {
+	handler FlowHandler,
+	events chan<- StatusUpdate) (*Conn, error) {
 	c := &Conn{
 		fc:           flowcontrol.New(defaultBufferSize, mtu),
 		mp:           newMessagePipe(conn),
@@ -116,6 +140,7 @@
 		lastUsedTime: time.Now(),
 		toRelease:    map[uint64]uint64{},
 		borrowing:    map[uint64]bool{},
+		events:       events,
 	}
 	if err := c.acceptHandshake(ctx, versions); err != nil {
 		c.Close(ctx, err)
@@ -126,6 +151,16 @@
 	return c, nil
 }
 
+// EnterLameDuck informs the remote end that it should stop dialing new flows on this Conn.
+func (c *Conn) EnterLameDuck(ctx *context.T) {
+	err := c.fc.Run(ctx, "enterlameduck", expressPriority, func(_ int) (int, bool, error) {
+		return 0, true, c.mp.writeMsg(ctx, &message.EnterLameDuck{})
+	})
+	if err != nil {
+		c.Close(ctx, NewErrSend(ctx, "enterlameduck", c.remote.String(), err))
+	}
+}
+
 // Dial dials a new flow on the Conn.
 func (c *Conn) Dial(ctx *context.T, fn flow.BlessingsForPeer) (flow.Flow, error) {
 	if c.rBlessings.IsZero() {
@@ -145,7 +180,7 @@
 	}
 	defer c.mu.Unlock()
 	c.mu.Lock()
-	if c.flows == nil {
+	if c.status.RemoteLameDuck || c.flows == nil {
 		return nil, NewErrConnectionClosed(ctx)
 	}
 	id := c.nextFid
@@ -180,47 +215,63 @@
 // with an error and no more flows will be sent to the FlowHandler.
 func (c *Conn) Closed() <-chan struct{} { return c.closed }
 
+func (c *Conn) Status() Status {
+	c.mu.Lock()
+	status := c.status
+	c.mu.Unlock()
+	return status
+}
+
 // Close shuts down a conn.
 func (c *Conn) Close(ctx *context.T, err error) {
 	c.mu.Lock()
+	c.internalCloseLocked(ctx, err)
+	c.mu.Unlock()
+	<-c.closed
+}
+
+func (c *Conn) internalCloseLocked(ctx *context.T, err error) {
+	ctx.VI(2).Infof("Closing connection: %v", err)
+
 	var flows map[uint64]*flw
 	flows, c.flows = c.flows, nil
 	if c.dischargeTimer != nil {
 		c.dischargeTimer.Stop()
 		c.dischargeTimer = nil
 	}
-	c.mu.Unlock()
-
 	if flows == nil {
 		// This conn is already being torn down.
-		<-c.closed
 		return
 	}
-	c.internalClose(ctx, err, flows)
-}
-
-func (c *Conn) internalClose(ctx *context.T, err error, flows map[uint64]*flw) {
-	ctx.VI(2).Infof("Closing connection: %v", err)
-	if verror.ErrorID(err) != ErrConnClosedRemotely.ID {
-		msg := ""
-		if err != nil {
-			msg = err.Error()
+	go func() {
+		if verror.ErrorID(err) != ErrConnClosedRemotely.ID {
+			msg := ""
+			if err != nil {
+				msg = err.Error()
+			}
+			cerr := c.fc.Run(ctx, "close", expressPriority, func(_ int) (int, bool, error) {
+				return 0, true, c.mp.writeMsg(ctx, &message.TearDown{Message: msg})
+			})
+			if cerr != nil {
+				ctx.Errorf("Error sending tearDown on connection to %s: %v", c.remote, cerr)
+			}
 		}
-		cerr := c.fc.Run(ctx, "close", expressPriority, func(_ int) (int, bool, error) {
-			return 0, true, c.mp.writeMsg(ctx, &message.TearDown{Message: msg})
-		})
-		if cerr != nil {
-			ctx.Errorf("Error sending tearDown on connection to %s: %v", c.remote, cerr)
+		for _, f := range flows {
+			f.close(ctx, NewErrConnectionClosed(ctx))
 		}
-	}
-	for _, f := range flows {
-		f.close(ctx, NewErrConnectionClosed(ctx))
-	}
-	if cerr := c.mp.close(); cerr != nil {
-		ctx.Errorf("Error closing underlying connection for %s: %v", c.remote, cerr)
-	}
-	c.loopWG.Wait()
-	close(c.closed)
+		if cerr := c.mp.close(); cerr != nil {
+			ctx.Errorf("Error closing underlying connection for %s: %v", c.remote, cerr)
+		}
+		c.loopWG.Wait()
+		c.mu.Lock()
+		c.status.Closed = true
+		status := c.status
+		c.mu.Unlock()
+		if c.events != nil {
+			c.events <- StatusUpdate{c, status}
+		}
+		close(c.closed)
+	}()
 }
 
 func (c *Conn) release(ctx *context.T, fid, count uint64) {
@@ -260,12 +311,48 @@
 func (c *Conn) handleMessage(ctx *context.T, m message.Message) error {
 	switch msg := m.(type) {
 	case *message.TearDown:
-		return NewErrConnClosedRemotely(ctx, msg.Message)
+		c.mu.Lock()
+		c.internalCloseLocked(ctx, NewErrConnClosedRemotely(ctx, msg.Message))
+		c.mu.Unlock()
+		return nil
+
+	case *message.EnterLameDuck:
+		c.mu.Lock()
+		c.status.RemoteLameDuck = true
+		status := c.status
+		c.mu.Unlock()
+		if c.events != nil {
+			c.events <- StatusUpdate{c, status}
+		}
+		go func() {
+			// We only want to send the lame duck acknowledgment after all outstanding
+			// OpenFlows are sent.
+			c.unopenedFlows.Wait()
+			err := c.fc.Run(ctx, "lameduck", expressPriority, func(_ int) (int, bool, error) {
+				return 0, true, c.mp.writeMsg(ctx, &message.AckLameDuck{})
+			})
+			if err != nil {
+				c.Close(ctx, NewErrSend(ctx, "lameduck", c.remote.String(), err))
+			}
+		}()
+
+	case *message.AckLameDuck:
+		c.mu.Lock()
+		c.status.LocalLameDuck = true
+		status := c.status
+		c.mu.Unlock()
+		if c.events != nil {
+			c.events <- StatusUpdate{c, status}
+		}
 
 	case *message.OpenFlow:
 		c.mu.Lock()
 		if c.handler == nil {
+			c.mu.Unlock()
 			return NewErrUnexpectedMsg(ctx, "openFlow")
+		} else if c.flows == nil {
+			c.mu.Unlock()
+			return nil // Conn is already being closed.
 		}
 		handler := c.handler
 		f := c.newFlowLocked(ctx, msg.ID, msg.BlessingsKey, msg.DischargeKey, false, true)
@@ -273,6 +360,7 @@
 		c.toRelease[msg.ID] = defaultBufferSize
 		c.borrowing[msg.ID] = true
 		c.mu.Unlock()
+
 		handler.HandleFlow(f)
 		if err := f.q.put(ctx, msg.Payload); err != nil {
 			return err
@@ -299,6 +387,10 @@
 
 	case *message.Data:
 		c.mu.Lock()
+		if c.flows == nil {
+			c.mu.Unlock()
+			return nil // Conn is already being shut down.
+		}
 		f := c.flows[msg.ID]
 		c.mu.Unlock()
 		if f == nil {
@@ -319,6 +411,7 @@
 }
 
 func (c *Conn) readLoop(ctx *context.T) {
+	defer c.loopWG.Done()
 	var err error
 	for {
 		msg, rerr := c.mp.readMsg(ctx)
@@ -330,16 +423,9 @@
 			break
 		}
 	}
-
 	c.mu.Lock()
-	var flows map[uint64]*flw
-	flows, c.flows = c.flows, nil
+	c.internalCloseLocked(ctx, err)
 	c.mu.Unlock()
-
-	c.loopWG.Done()
-	if flows != nil {
-		c.internalClose(ctx, err, flows)
-	}
 }
 
 func (c *Conn) markUsed() {
diff --git a/runtime/internal/flow/conn/conn_test.go b/runtime/internal/flow/conn/conn_test.go
index 9c5ce48..6837486 100644
--- a/runtime/internal/flow/conn/conn_test.go
+++ b/runtime/internal/flow/conn/conn_test.go
@@ -100,14 +100,14 @@
 	q1, q2 := make(chan flow.Flow, 1), make(chan flow.Flow, 1)
 	fh1, fh2 := fh(q1), fh(q2)
 	go func() {
-		d, err := NewDialed(ctx, dmrw, ep, ep, versions, nil)
+		d, err := NewDialed(ctx, dmrw, ep, ep, versions, nil, nil)
 		if err != nil {
 			panic(err)
 		}
 		dch <- d
 	}()
 	go func() {
-		a, err := NewAccepted(ctx, amrw, ep, versions, fh1)
+		a, err := NewAccepted(ctx, amrw, ep, versions, fh1, nil)
 		if err != nil {
 			panic(err)
 		}
diff --git a/runtime/internal/flow/conn/flow.go b/runtime/internal/flow/conn/flow.go
index 3bfd000..04fdfcf 100644
--- a/runtime/internal/flow/conn/flow.go
+++ b/runtime/internal/flow/conn/flow.go
@@ -43,6 +43,9 @@
 		opened: preopen,
 	}
 	f.SetContext(ctx)
+	if !f.opened {
+		c.unopenedFlows.Add(1)
+	}
 	c.flows[id] = f
 	return f
 }
@@ -136,6 +139,7 @@
 				Payload:         d.Payload,
 			})
 			f.opened = true
+			f.conn.unopenedFlows.Done()
 		}
 		return size, done, err
 	})
@@ -262,21 +266,34 @@
 func (f *flw) close(ctx *context.T, err error) {
 	f.q.close(ctx)
 	f.cancel()
-	if eid := verror.ErrorID(err); eid != ErrFlowClosedRemotely.ID &&
-		eid != ErrConnectionClosed.ID {
-		// We want to try to send this message even if ctx is already canceled.
-		ctx, cancel := context.WithRootCancel(ctx)
-		err := f.worker.Run(ctx, func(tokens int) (int, bool, error) {
-			return 0, true, f.conn.mp.writeMsg(ctx, &message.Data{
-				ID:    f.id,
-				Flags: message.CloseFlag,
-			})
-		})
-		if err != nil {
-			ctx.Errorf("Could not send close flow message: %v", err)
+	// We want to try to send this message even if ctx is already canceled.
+	ctx, cancel := context.WithRootCancel(ctx)
+	serr := f.worker.Run(ctx, func(tokens int) (int, bool, error) {
+		f.conn.mu.Lock()
+		delete(f.conn.flows, f.id)
+		connClosed := f.conn.flows == nil
+		f.conn.mu.Unlock()
+
+		if !f.opened {
+			// Closing a flow that was never opened.
+			f.conn.unopenedFlows.Done()
+			return 0, true, nil
+		} else if eid := verror.ErrorID(err); eid == ErrFlowClosedRemotely.ID || connClosed {
+			// Note: If the conn is closed there is no point in trying to send
+			// the flow close message as it will fail.  This is racy with the connection
+			// closing, but there are no ill-effects other than spamming the logs a little
+			// so it's OK.
+			return 0, true, nil
 		}
-		cancel()
+		return 0, true, f.conn.mp.writeMsg(ctx, &message.Data{
+			ID:    f.id,
+			Flags: message.CloseFlag,
+		})
+	})
+	if serr != nil {
+		ctx.Errorf("Could not send close flow message: %v", serr)
 	}
+	cancel()
 }
 
 // Close marks the flow as closed. After Close is called, new data cannot be
diff --git a/runtime/internal/flow/conn/lameduck_test.go b/runtime/internal/flow/conn/lameduck_test.go
new file mode 100644
index 0000000..b2f1def
--- /dev/null
+++ b/runtime/internal/flow/conn/lameduck_test.go
@@ -0,0 +1,110 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package conn
+
+import (
+	"bytes"
+	"fmt"
+	"testing"
+	"time"
+
+	"v.io/v23"
+	"v.io/v23/flow"
+	"v.io/x/ref/test/goroutines"
+)
+
+func TestLameDuck(t *testing.T) {
+	defer goroutines.NoLeaks(t, leakWaitTime)()
+
+	ctx, shutdown := v23.Init()
+	defer shutdown()
+
+	events := make(chan StatusUpdate, 2)
+	dflows, aflows := make(chan flow.Flow, 3), make(chan flow.Flow, 3)
+	dc, ac, _ := setupConnsWithEvents(t, ctx, ctx, dflows, aflows, events)
+
+	go func() {
+		for {
+			select {
+			case f := <-aflows:
+				if got, err := f.ReadMsg(); err != nil {
+					panic(fmt.Sprintf("got %v wanted nil", err))
+				} else if !bytes.Equal(got, []byte("hello")) {
+					panic(fmt.Sprintf("got %q, wanted 'hello'", string(got)))
+				}
+			case <-ac.Closed():
+				return
+			}
+		}
+	}()
+
+	// Dial a flow and write to it (which causes it to open).
+	f1, err := dc.Dial(ctx, testBFP)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := f1.WriteMsg([]byte("hello")); err != nil {
+		t.Fatal(err)
+	}
+	// Dial more flows, but don't write to them yet.
+	f2, err := dc.Dial(ctx, testBFP)
+	if err != nil {
+		t.Fatal(err)
+	}
+	f3, err := dc.Dial(ctx, testBFP)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Now put the accepted conn into lame duck mode.
+	ac.EnterLameDuck(ctx)
+	if e := <-events; e.Conn != dc || e.Status != (Status{false, false, true}) {
+		t.Errorf("Expected RemoteLameDuck on dialer, got %#v (a %p, d %p)", e, ac, dc)
+	}
+
+	// Now we shouldn't be able to dial on dc because the remote end (ac) is in lame duck mode.
+	if _, err := dc.Dial(ctx, testBFP); err == nil {
+		t.Fatalf("expected an error, got nil")
+	}
+
+	// I can't think of a non-flaky way to test for it, but it should
+	// be the case that we don't send the AckLameDuck message until
+	// we write to or close the other flows.  This should catch it sometimes.
+	select {
+	case e := <-events:
+		t.Errorf("Didn't expect any additional events yet, got %#v", e)
+	case <-time.After(time.Millisecond * 100):
+	}
+
+	// Now write or close the other flows.
+	if _, err := f2.WriteMsg([]byte("hello")); err != nil {
+		t.Fatal(err)
+	}
+	f3.Close()
+
+	if e := <-events; e.Conn != ac || e.Status != (Status{false, true, false}) {
+		t.Errorf("Expected LocalLameDuck on acceptor, got %#v (a %p, d %p)", e, ac, dc)
+	}
+
+	// Now put the dialer side into lame duck.
+	dc.EnterLameDuck(ctx)
+	if e := <-events; e.Conn != ac || e.Status != (Status{false, true, true}) {
+		t.Errorf("Expected RemoteLameDuck on acceptor, got %#v (a %p, d %p)", e, ac, dc)
+	}
+	if e := <-events; e.Conn != dc || e.Status != (Status{false, true, true}) {
+		t.Errorf("Expected LocalLameDuck on dialer, got %#v (a %p, d %p)", e, ac, dc)
+	}
+
+	// Now close the accept side.
+	ac.Close(ctx, nil)
+	if e := <-events; e.Status != (Status{true, true, true}) {
+		t.Errorf("Expected Closed, got %#v (a %p, d %p)", e, ac, dc)
+	}
+	if e := <-events; e.Status != (Status{true, true, true}) {
+		t.Errorf("Expected Closed, got %#v (a %p, d %p)", e, ac, dc)
+	}
+	<-dc.Closed()
+	<-ac.Closed()
+}
diff --git a/runtime/internal/flow/conn/util_test.go b/runtime/internal/flow/conn/util_test.go
index dbb3537..b68247b 100644
--- a/runtime/internal/flow/conn/util_test.go
+++ b/runtime/internal/flow/conn/util_test.go
@@ -28,7 +28,16 @@
 	return nil
 }
 
-func setupConns(t *testing.T, dctx, actx *context.T, dflows, aflows chan<- flow.Flow) (dialed, accepted *Conn, _ *flowtest.Wire) {
+func setupConns(t *testing.T,
+	dctx, actx *context.T,
+	dflows, aflows chan<- flow.Flow) (dialed, accepted *Conn, _ *flowtest.Wire) {
+	return setupConnsWithEvents(t, dctx, actx, dflows, aflows, nil)
+}
+
+func setupConnsWithEvents(t *testing.T,
+	dctx, actx *context.T,
+	dflows, aflows chan<- flow.Flow,
+	events chan<- StatusUpdate) (dialed, accepted *Conn, _ *flowtest.Wire) {
 	dmrw, amrw, w := flowtest.NewMRWPair(dctx)
 	versions := version.RPCVersionRange{Min: 3, Max: 5}
 	ep, err := v23.NewEndpoint("localhost:80")
@@ -42,7 +51,7 @@
 		if dflows != nil {
 			handler = fh(dflows)
 		}
-		d, err := NewDialed(dctx, dmrw, ep, ep, versions, handler)
+		d, err := NewDialed(dctx, dmrw, ep, ep, versions, handler, events)
 		if err != nil {
 			panic(err)
 		}
@@ -53,7 +62,7 @@
 		if aflows != nil {
 			handler = fh(aflows)
 		}
-		a, err := NewAccepted(actx, amrw, ep, versions, handler)
+		a, err := NewAccepted(actx, amrw, ep, versions, handler, events)
 		if err != nil {
 			panic(err)
 		}
diff --git a/runtime/internal/flow/manager/conncache_test.go b/runtime/internal/flow/manager/conncache_test.go
index ace0f89..8d02e9a 100644
--- a/runtime/internal/flow/manager/conncache_test.go
+++ b/runtime/internal/flow/manager/conncache_test.go
@@ -286,7 +286,7 @@
 	ach := make(chan *connpackage.Conn)
 	go func() {
 		d, err := connpackage.NewDialed(ctx, dmrw, ep, ep,
-			version.RPCVersionRange{Min: 1, Max: 5}, nil)
+			version.RPCVersionRange{Min: 1, Max: 5}, nil, nil)
 		if err != nil {
 			t.Fatalf("Unexpected error: %v", err)
 		}
@@ -295,7 +295,7 @@
 	fh := fh{t, make(chan struct{})}
 	go func() {
 		a, err := connpackage.NewAccepted(ctx, amrw, ep,
-			version.RPCVersionRange{Min: 1, Max: 5}, fh)
+			version.RPCVersionRange{Min: 1, Max: 5}, fh, nil)
 		if err != nil {
 			t.Fatalf("Unexpected error: %v", err)
 		}
diff --git a/runtime/internal/flow/manager/manager.go b/runtime/internal/flow/manager/manager.go
index bc15eb4..099aaac 100644
--- a/runtime/internal/flow/manager/manager.go
+++ b/runtime/internal/flow/manager/manager.go
@@ -221,7 +221,7 @@
 			local,
 			version.Supported,
 			&flowHandler{q: m.q, cached: cached},
-		)
+			nil)
 		if err != nil {
 			close(cached)
 			flowConn.Close()
@@ -260,7 +260,8 @@
 			f,
 			f.Conn().LocalEndpoint(),
 			version.Supported,
-			&flowHandler{q: h.m.q})
+			&flowHandler{q: h.m.q},
+			nil)
 		if err != nil {
 			h.ctx.Errorf("failed to create accepted conn: %v", err)
 			return
@@ -387,6 +388,7 @@
 			remote,
 			version.Supported,
 			fh,
+			nil,
 		)
 		if err != nil {
 			flowConn.Close()
@@ -414,6 +416,7 @@
 			remote,
 			version.Supported,
 			fh,
+			nil,
 		)
 		if err != nil {
 			proxyConn.Close(ctx, err)