Merge "mocknet: fix a race in V23CloseAtMessage"
diff --git a/runtime/internal/rpc/stream/message/message.go b/runtime/internal/rpc/stream/message/message.go
index 199ba03..5e4b790 100644
--- a/runtime/internal/rpc/stream/message/message.go
+++ b/runtime/internal/rpc/stream/message/message.go
@@ -99,12 +99,12 @@
// level errors and hence {1}{2} is omitted from their format
// strings to avoid repeating these n-times in the final error
// message visible to the user.
- errEmptyMessage = reg(".errEmptyMessage", "message is empty")
- errCorruptedMessage = reg(".errCorruptedMessage", "corrupted message")
- errInvalidMessageType = reg("errInvalidMessageType", "invalid message type {3}")
- errUnrecognizedMessageType = reg("errUrecognizedMessageType", "unrecognized message type {3}")
- errFailedToReadVCHeader = reg(".errFailedToReadVCHeader", "failed to read VC header{:3}")
- errFailedToReadPayload = reg(".errFailedToReadPayload", "failed to read payload of {3} bytes for type {4}{:5}")
+ errEmptyMessage = reg(".errEmptyMessage", "message is empty")
+ errCorruptedMessage = reg(".errCorruptedMessage", "corrupted message")
+ errInvalidMessageType = reg("errInvalidMessageType", "invalid message type {3}")
+ errUnrecognizedMessageType = reg("errUrecognizedMessageType", "unrecognized message type {3}")
+ errFailedToReadMessageHeader = reg(".errFailedToReadMessageHeader", "failed to read message header{:3}")
+ errFailedToReadPayload = reg(".errFailedToReadPayload", "failed to read payload of {3} bytes for type {4}{:5}")
)
// T is the interface implemented by all messages communicated over a VIF.
@@ -127,7 +127,7 @@
func ReadFrom(r *iobuf.Reader, c crypto.ControlCipher) (T, error) {
header, err := r.Read(commonHeaderSizeBytes)
if err != nil {
- return nil, verror.New(errFailedToReadVCHeader, nil, err)
+ return nil, verror.New(errFailedToReadMessageHeader, nil, err)
}
c.Decrypt(header.Contents)
msgType := header.Contents[0]
diff --git a/runtime/internal/testing/mocks/mocknet/mocknet.go b/runtime/internal/testing/mocks/mocknet/mocknet.go
index 1132f51..83f301e 100644
--- a/runtime/internal/testing/mocks/mocknet/mocknet.go
+++ b/runtime/internal/testing/mocks/mocknet/mocknet.go
@@ -12,6 +12,7 @@
"io"
"net"
"sync"
+ "testing/iotest"
"time"
"v.io/v23"
@@ -123,7 +124,7 @@
return &v23Conn{
conn: c,
opts: opts,
- cipher: crypto.NewDisabledControlCipher(&crypto.NullControlCipher{}),
+ cipher: &crypto.NullControlCipher{},
pool: iobuf.NewPool(1024),
}
}
@@ -272,24 +273,40 @@
}
+// Read decodes a single V23 message from the underlying connection and leaves
+// the raw bytes it consumed in b. No more than len(b) bytes are pulled off the
+// connection, so a message longer than b makes message.ReadFrom fail; the
+// number of bytes consumed so far is still returned alongside that error. If
+// opts.V23MessageMatcher selects the decoded message (read == true), the
+// connection is closed and Read returns 0, io.EOF instead of the message.
func (c *v23Conn) Read(b []byte) (n int, err error) {
- n, err = c.conn.Read(b)
- buf := iobuf.NewReader(c.pool, bytes.NewBuffer(b[:n]))
- msg, err := message.ReadFrom(buf, c.cipher)
+ // Never pull more off the connection than b can hold.
+ limited := io.LimitReader(c.conn, int64(len(b)))
+ // consumed appends into b's own backing array; because at most len(b)
+ // bytes ever reach it, it never reallocates and b ends up holding them.
+ consumed := bytes.NewBuffer(b[:0])
+ // Feeding one byte per Read keeps iobuf.Reader from pulling in bytes past
+ // the current message (assumes it only reads what ReadFrom asks for).
+ src := io.TeeReader(iotest.OneByteReader(limited), consumed)
+ msg, err := message.ReadFrom(iobuf.NewReader(c.pool, src), c.cipher)
if err == nil && c.opts.V23MessageMatcher(true, msg) {
c.conn.Close()
return 0, io.EOF
}
- return n, err
+ return consumed.Len(), err
}
+// Write splits b into V23 messages and forwards them to the underlying
+// connection one message at a time. Each decoded message is first offered to
+// opts.V23MessageMatcher (read == false); on a match the connection is closed
+// and Write returns io.EOF together with the number of bytes of b forwarded
+// before that message. If b does not decode as a sequence of complete
+// messages, the decoding error from message.ReadFrom is returned with the
+// number of bytes forwarded so far.
func (c *v23Conn) Write(b []byte) (n int, err error) {
- buf := iobuf.NewReader(c.pool, bytes.NewBuffer(b))
- msg, err := message.ReadFrom(buf, c.cipher)
- if err == nil && c.opts.V23MessageMatcher(false, msg) {
- c.conn.Close()
- return 0, io.EOF
+ rb := bytes.NewBuffer(b)
+ // Handing iobuf.Reader one byte per Read keeps it from buffering bytes past
+ // the message being decoded, so len(b)-rb.Len() is always the offset in b
+ // of the next undecoded message (assumes iobuf.Reader only reads what
+ // ReadFrom asks for; the loop condition below relies on the same thing).
+ r := iobuf.NewReader(c.pool, iotest.OneByteReader(rb))
+ for rb.Len() > 0 {
+ start := len(b) - rb.Len()
+ var msg message.T
+ if msg, err = message.ReadFrom(r, c.cipher); err != nil {
+ return n, err
+ }
+ if c.opts.V23MessageMatcher(false, msg) {
+ c.conn.Close()
+ return n, io.EOF
+ }
+ // Forward the caller's own bytes for this message instead of re-encoding
+ // msg: the wire bytes pass through unchanged and n counts bytes of b, as
+ // the io.Writer contract requires.
+ end := len(b) - rb.Len()
+ var tx int
+ tx, err = c.conn.Write(b[start:end])
+ n += tx
+ if err != nil {
+ return n, err
+ }
}
- return c.conn.Write(b)
+ return n, nil
}
func (c *v23Conn) Close() error {
diff --git a/runtime/internal/testing/mocks/mocknet/mocknet_test.go b/runtime/internal/testing/mocks/mocknet/mocknet_test.go
index 018e3ce..ccf1cf9 100644
--- a/runtime/internal/testing/mocks/mocknet/mocknet_test.go
+++ b/runtime/internal/testing/mocks/mocknet/mocknet_test.go
@@ -5,6 +5,7 @@
package mocknet_test
import (
+ "bytes"
"errors"
"io"
"net"
@@ -21,6 +22,7 @@
"v.io/v23/verror"
_ "v.io/x/ref/runtime/factories/generic"
+ "v.io/x/ref/runtime/internal/rpc/stream/crypto"
"v.io/x/ref/runtime/internal/rpc/stream/message"
"v.io/x/ref/runtime/internal/testing/mocks/mocknet"
"v.io/x/ref/test"
@@ -46,11 +48,10 @@
defer ln.Close()
var rxconn net.Conn
- var rxerr error
var wg sync.WaitGroup
wg.Add(1)
go func() {
- rxconn, rxerr = ln.Accept()
+ rxconn, _ = ln.Accept()
wg.Done()
}()
@@ -119,11 +120,10 @@
defer ln.Close()
var rxconn net.Conn
- var rxerr error
var wg sync.WaitGroup
wg.Add(1)
go func() {
- rxconn, rxerr = ln.Accept()
+ rxconn, _ = ln.Accept()
wg.Done()
}()
@@ -190,11 +190,10 @@
defer ln.Close()
var rxconn net.Conn
- var rxerr error
var wg sync.WaitGroup
wg.Add(1)
go func() {
- rxconn, rxerr = ln.Accept()
+ rxconn, _ = ln.Accept()
wg.Done()
}()
@@ -223,6 +222,108 @@
}
}
+// TestV23Drop checks that V23CloseAtMessage closes a connection exactly at the
+// message chosen by V23MessageMatcher. On the writing side, Write forwards only
+// the messages before the match and returns io.EOF. On the reading side, Read
+// delivers only the messages before the match and then returns io.EOF.
+func TestV23Drop(t *testing.T) {
+ cases := []struct {
+ // txClose/rxClose: 1-based index of the written/read message that
+ // triggers the close; 0 means the connection is never closed.
+ numMsgs, txClose, rxClose int
+ }{
+ {5, 0, 0},
+ {5, 2, 0},
+ {5, 0, 2},
+ {5, 3, 2},
+ {5, 2, 3},
+ }
+
+ for ci, c := range cases {
+ // Each case runs in its own function literal so its deferred Close calls
+ // fire when the case ends, instead of piling up until the test returns
+ // (and so a failed dial still closes the listener and unblocks Accept).
+ func() {
+ var txed, rxed int
+ matcher := func(read bool, msg message.T) bool {
+ if read {
+ rxed++
+ return rxed == c.rxClose
+ }
+ txed++
+ return txed == c.txClose
+ }
+ opts := mocknet.Opts{
+ Mode: mocknet.V23CloseAtMessage,
+ V23MessageMatcher: matcher,
+ }
+
+ ln := newListener(t, opts)
+ defer ln.Close()
+
+ var (
+ rxconn net.Conn
+ acceptErr error
+ wg sync.WaitGroup
+ )
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ rxconn, acceptErr = ln.Accept()
+ }()
+
+ txconn, err := mocknet.DialerWithOpts(opts, "test", ln.Addr().String(), time.Minute)
+ if err != nil {
+ t.Fatalf("case %d: dial: %v", ci, err)
+ }
+ defer txconn.Close()
+ wg.Wait()
+ if acceptErr != nil {
+ t.Fatalf("case %d: accept: %v", ci, acceptErr)
+ }
+ defer rxconn.Close()
+
+ var msgBuf bytes.Buffer
+ for i := 0; i < c.numMsgs; i++ {
+ if err := message.WriteTo(&msgBuf, &message.Data{}, crypto.NullControlCipher{}); err != nil {
+ t.Fatalf("case %d: encoding message %d: %v", ci, i, err)
+ }
+ }
+ perMsgBytes := msgBuf.Len() / c.numMsgs
+
+ n, err := txconn.Write(msgBuf.Bytes())
+ txMsgs := n / perMsgBytes
+ switch {
+ case c.txClose > 0:
+ if got, want := txMsgs, c.txClose-1; got != want {
+ t.Fatalf("case %d: messages written: got %v, want %v", ci, got, want)
+ }
+ if got, want := err, io.EOF; got != want {
+ t.Fatalf("case %d: write error: got %v, want %v", ci, got, want)
+ }
+ default:
+ if got, want := txMsgs, c.numMsgs; got != want {
+ t.Fatalf("case %d: messages written: got %v, want %v", ci, got, want)
+ }
+ if err != nil {
+ t.Fatalf("case %d: write: %v", ci, err)
+ }
+ }
+
+ // rxErr is kept separate from the Write error so that a case with no
+ // forwarded messages is not judged by the writer's io.EOF.
+ rbuf := make([]byte, perMsgBytes*2)
+ var (
+ rxMsgs int
+ rxErr error
+ )
+ for ; rxMsgs < txMsgs; rxMsgs++ {
+ var rn int
+ if rn, rxErr = rxconn.Read(rbuf); rxErr != nil {
+ break
+ }
+ if rn != perMsgBytes {
+ t.Fatalf("case %d: bytes read for message %d: got %v, want %v", ci, rxMsgs, rn, perMsgBytes)
+ }
+ }
+ switch {
+ case c.rxClose > 0 && (c.txClose == 0 || c.txClose > c.rxClose):
+ if got, want := rxMsgs, c.rxClose-1; got != want {
+ t.Fatalf("case %d: messages read: got %v, want %v", ci, got, want)
+ }
+ if got, want := rxErr, io.EOF; got != want {
+ t.Fatalf("case %d: read error: got %v, want %v", ci, got, want)
+ }
+ default:
+ if got, want := rxMsgs, txMsgs; got != want {
+ t.Fatalf("case %d: messages read: got %v, want %v", ci, got, want)
+ }
+ if rxErr != nil {
+ t.Fatalf("case %d: read: %v", ci, rxErr)
+ }
+ }
+ }()
+ }
+}
+
func newCtx() (*context.T, v23.Shutdown) {
ctx, shutdown := test.InitForTest()
v23.GetNamespace(ctx).CacheCtl(naming.DisableCache(true))
@@ -287,7 +388,7 @@
t.Fatal(err)
}
- _, err = v23.GetClient(ctx).StartCall(ctx, dropServer.Name(), "Ping", nil, options.SecurityNone)
+ _, err = v23.GetClient(ctx).StartCall(ctx, dropServer.Name(), "Ping", nil, options.SecurityNone, options.NoRetry{})
if verror.ErrorID(err) != verror.ErrBadProtocol.ID {
t.Fatal(err)
}