Merge "mocknet: fix a race in V23CloseAtMessage"
diff --git a/runtime/internal/rpc/stream/message/message.go b/runtime/internal/rpc/stream/message/message.go
index 199ba03..5e4b790 100644
--- a/runtime/internal/rpc/stream/message/message.go
+++ b/runtime/internal/rpc/stream/message/message.go
@@ -99,12 +99,12 @@
 	// level errors and hence {1}{2} is omitted from their format
 	// strings to avoid repeating these n-times in the final error
 	// message visible to the user.
-	errEmptyMessage            = reg(".errEmptyMessage", "message is empty")
-	errCorruptedMessage        = reg(".errCorruptedMessage", "corrupted message")
-	errInvalidMessageType      = reg("errInvalidMessageType", "invalid message type {3}")
-	errUnrecognizedMessageType = reg("errUrecognizedMessageType", "unrecognized message type {3}")
-	errFailedToReadVCHeader    = reg(".errFailedToReadVCHeader", "failed to read VC header{:3}")
-	errFailedToReadPayload     = reg(".errFailedToReadPayload", "failed to read payload of {3} bytes for type {4}{:5}")
+	errEmptyMessage              = reg(".errEmptyMessage", "message is empty")
+	errCorruptedMessage          = reg(".errCorruptedMessage", "corrupted message")
+	errInvalidMessageType        = reg("errInvalidMessageType", "invalid message type {3}")
+	errUnrecognizedMessageType   = reg("errUrecognizedMessageType", "unrecognized message type {3}")
+	errFailedToReadMessageHeader = reg(".errFailedToReadMessageHeader", "failed to read message header{:3}")
+	errFailedToReadPayload       = reg(".errFailedToReadPayload", "failed to read payload of {3} bytes for type {4}{:5}")
 )
 
 // T is the interface implemented by all messages communicated over a VIF.
@@ -127,7 +127,7 @@
 func ReadFrom(r *iobuf.Reader, c crypto.ControlCipher) (T, error) {
 	header, err := r.Read(commonHeaderSizeBytes)
 	if err != nil {
-		return nil, verror.New(errFailedToReadVCHeader, nil, err)
+		return nil, verror.New(errFailedToReadMessageHeader, nil, err)
 	}
 	c.Decrypt(header.Contents)
 	msgType := header.Contents[0]
diff --git a/runtime/internal/testing/mocks/mocknet/mocknet.go b/runtime/internal/testing/mocks/mocknet/mocknet.go
index 1132f51..83f301e 100644
--- a/runtime/internal/testing/mocks/mocknet/mocknet.go
+++ b/runtime/internal/testing/mocks/mocknet/mocknet.go
@@ -12,6 +12,7 @@
 	"io"
 	"net"
 	"sync"
+	"testing/iotest"
 	"time"
 
 	"v.io/v23"
@@ -123,7 +124,7 @@
 		return &v23Conn{
 			conn:   c,
 			opts:   opts,
-			cipher: crypto.NewDisabledControlCipher(&crypto.NullControlCipher{}),
+			cipher: &crypto.NullControlCipher{},
 			pool:   iobuf.NewPool(1024),
 		}
 	}
@@ -272,24 +273,40 @@
 }
 
 func (c *v23Conn) Read(b []byte) (n int, err error) {
-	n, err = c.conn.Read(b)
-	buf := iobuf.NewReader(c.pool, bytes.NewBuffer(b[:n]))
-	msg, err := message.ReadFrom(buf, c.cipher)
+	rb := bytes.NewBuffer(b[:0])
+	r := iobuf.NewReader(c.pool, io.TeeReader(iotest.OneByteReader(io.LimitReader(c.conn, int64(len(b)))), rb))
+	msg, err := message.ReadFrom(r, c.cipher)
 	if err == nil && c.opts.V23MessageMatcher(true, msg) {
 		c.conn.Close()
 		return 0, io.EOF
 	}
-	return n, err
+	return rb.Len(), err
 }
 
 func (c *v23Conn) Write(b []byte) (n int, err error) {
-	buf := iobuf.NewReader(c.pool, bytes.NewBuffer(b))
-	msg, err := message.ReadFrom(buf, c.cipher)
-	if err == nil && c.opts.V23MessageMatcher(false, msg) {
-		c.conn.Close()
-		return 0, io.EOF
+	rb := bytes.NewBuffer(b)
+	r := iobuf.NewReader(c.pool, iotest.OneByteReader(rb))
+	for rb.Len() > 0 {
+		msg, err := message.ReadFrom(r, c.cipher)
+		if err != nil {
+			return n, err
+		}
+		if c.opts.V23MessageMatcher(false, msg) {
+			c.conn.Close()
+			return n, io.EOF
+		}
+		var wb bytes.Buffer
+		err = message.WriteTo(&wb, msg, c.cipher)
+		if err != nil {
+			return n, err
+		}
+		tx, err := c.conn.Write(wb.Bytes())
+		n += tx
+		if err != nil {
+			return n, err
+		}
 	}
-	return c.conn.Write(b)
+	return n, nil
 }
 
 func (c *v23Conn) Close() error {
diff --git a/runtime/internal/testing/mocks/mocknet/mocknet_test.go b/runtime/internal/testing/mocks/mocknet/mocknet_test.go
index 018e3ce..ccf1cf9 100644
--- a/runtime/internal/testing/mocks/mocknet/mocknet_test.go
+++ b/runtime/internal/testing/mocks/mocknet/mocknet_test.go
@@ -5,6 +5,7 @@
 package mocknet_test
 
 import (
+	"bytes"
 	"errors"
 	"io"
 	"net"
@@ -21,6 +22,7 @@
 	"v.io/v23/verror"
 
 	_ "v.io/x/ref/runtime/factories/generic"
+	"v.io/x/ref/runtime/internal/rpc/stream/crypto"
 	"v.io/x/ref/runtime/internal/rpc/stream/message"
 	"v.io/x/ref/runtime/internal/testing/mocks/mocknet"
 	"v.io/x/ref/test"
@@ -46,11 +48,10 @@
 	defer ln.Close()
 
 	var rxconn net.Conn
-	var rxerr error
 	var wg sync.WaitGroup
 	wg.Add(1)
 	go func() {
-		rxconn, rxerr = ln.Accept()
+		rxconn, _ = ln.Accept()
 		wg.Done()
 	}()
 
@@ -119,11 +120,10 @@
 		defer ln.Close()
 
 		var rxconn net.Conn
-		var rxerr error
 		var wg sync.WaitGroup
 		wg.Add(1)
 		go func() {
-			rxconn, rxerr = ln.Accept()
+			rxconn, _ = ln.Accept()
 			wg.Done()
 		}()
 
@@ -190,11 +190,10 @@
 		defer ln.Close()
 
 		var rxconn net.Conn
-		var rxerr error
 		var wg sync.WaitGroup
 		wg.Add(1)
 		go func() {
-			rxconn, rxerr = ln.Accept()
+			rxconn, _ = ln.Accept()
 			wg.Done()
 		}()
 
@@ -223,6 +222,108 @@
 	}
 }
 
+// TestV23Drop sends numMsgs encoded messages through a V23CloseAtMessage
+// connection pair whose matcher closes the connection on the txClose'th write
+// or the rxClose'th read (0 means never) and checks how many are delivered.
+func TestV23Drop(t *testing.T) {
+	cases := []struct{ numMsgs, txClose, rxClose int }{
+		{5, 0, 0},
+		{5, 2, 0},
+		{5, 0, 2},
+		{5, 3, 2},
+		{5, 2, 3},
+	}
+
+	for ci, c := range cases {
+		var txed, rxed int
+		matcher := func(read bool, msg message.T) bool {
+			if read {
+				rxed++
+				return rxed == c.rxClose
+			}
+			txed++
+			return txed == c.txClose
+		}
+		opts := mocknet.Opts{Mode: mocknet.V23CloseAtMessage, V23MessageMatcher: matcher}
+
+		ln := newListener(t, opts)
+		defer ln.Close()
+
+		var rxconn net.Conn
+		var acceptErr error
+		var wg sync.WaitGroup
+		wg.Add(1)
+		go func() {
+			rxconn, acceptErr = ln.Accept()
+			wg.Done()
+		}()
+
+		txconn, err := mocknet.DialerWithOpts(opts, "test", ln.Addr().String(), time.Minute)
+		if err != nil {
+			t.Fatal(err)
+		}
+		wg.Wait()
+		if acceptErr != nil { // fail here rather than panic on a nil rxconn below
+			t.Fatalf("%d: accept: %v", ci, acceptErr)
+		}
+
+		var msgBuf bytes.Buffer
+		for i := 0; i < c.numMsgs; i++ {
+			if err := message.WriteTo(&msgBuf, &message.Data{}, crypto.NullControlCipher{}); err != nil {
+				t.Fatal(err)
+			}
+		}
+		perMsgBytes := msgBuf.Len() / c.numMsgs
+
+		n, err := txconn.Write(msgBuf.Bytes())
+		txMsgs := n / perMsgBytes
+		switch {
+		case c.txClose > 0:
+			if got, want := txMsgs, c.txClose-1; got != want {
+				t.Fatalf("%d: got %v, want %v", ci, got, want)
+			}
+			if got, want := err, io.EOF; got != want {
+				t.Fatalf("%d: got %v, want %v", ci, got, want)
+			}
+		default:
+			if got, want := txMsgs, c.numMsgs; got != want {
+				t.Fatalf("%d: got %v, want %v", ci, got, want)
+			}
+			if err != nil {
+				t.Fatalf("%d: %v\n", ci, err)
+			}
+		}
+
+		var rxMsgs int
+		var rxErr error // separate from err so a tx-side io.EOF cannot leak into the rx checks
+		for ; rxMsgs < txMsgs; rxMsgs++ {
+			n, rxErr = rxconn.Read(make([]byte, perMsgBytes*2))
+			if rxErr != nil {
+				break
+			}
+			if got, want := n, perMsgBytes; got != want {
+				t.Fatalf("%d: got %v, want %v", ci, got, want)
+			}
+		}
+		switch {
+		case c.rxClose > 0 && (c.txClose == 0 || c.txClose > c.rxClose):
+			if got, want := rxMsgs, c.rxClose-1; got != want {
+				t.Fatalf("%d: got %v, want %v", ci, got, want)
+			}
+			if got, want := rxErr, io.EOF; got != want {
+				t.Fatalf("%d: got %v, want %v", ci, got, want)
+			}
+		default:
+			if got, want := rxMsgs, txMsgs; got != want {
+				t.Fatalf("%d: got %v, want %v", ci, got, want)
+			}
+			if rxErr != nil {
+				t.Fatalf("%d: %v\n", ci, rxErr)
+			}
+		}
+	}
+}
+
 func newCtx() (*context.T, v23.Shutdown) {
 	ctx, shutdown := test.InitForTest()
 	v23.GetNamespace(ctx).CacheCtl(naming.DisableCache(true))
@@ -287,7 +388,7 @@
 		t.Fatal(err)
 	}
 
-	_, err = v23.GetClient(ctx).StartCall(ctx, dropServer.Name(), "Ping", nil, options.SecurityNone)
+	_, err = v23.GetClient(ctx).StartCall(ctx, dropServer.Name(), "Ping", nil, options.SecurityNone, options.NoRetry{})
 	if verror.ErrorID(err) != verror.ErrBadProtocol.ID {
 		t.Fatal(err)
 	}