runtime/internal/flow/conn: Shutdown Conns.
While testing this change I fixed several bugs in reading from and
writing to flows.
I still need to handle cancellation and non-trivial flow-closing
scenarios.
Change-Id: Ied6f5e9f4d80fa89bc1d638c7e8d1c9ca23d8b5c
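
The shutdown contract in miniature, sketched as if inside this package's
tests (newMRWPair and the fake runtime factory are assumed; the endpoint
address is illustrative):

    ctx, shutdown := v23.Init()
    defer shutdown()
    dmrw, amrw, _ := newMRWPair(ctx)
    ep, _ := v23.NewEndpoint("localhost:80")
    vs := version.RPCVersionRange{Min: 3, Max: 5}
    d, _ := NewDialed(ctx, dmrw, ep, ep, vs, nil, nil)
    a, _ := NewAccepted(ctx, amrw, ep, security.Blessings{}, vs, nil)

    d.Close(ctx, nil) // closes flows, sends tearDown, closes the wire
    <-d.Closed()      // guaranteed: all later Dial calls fail
    <-a.Closed()      // the acceptor observes the tearDown and shuts down too
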
diff --git a/runtime/factories/fake/naming.go b/runtime/factories/fake/naming.go
index b4a14ed..e7bb344 100644
--- a/runtime/factories/fake/naming.go
+++ b/runtime/factories/fake/naming.go
@@ -9,11 +9,12 @@
"v.io/v23/namespace"
"v.io/v23/naming"
"v.io/x/ref/lib/apilog"
+ inaming "v.io/x/ref/runtime/internal/naming"
)
func (r *Runtime) NewEndpoint(ep string) (naming.Endpoint, error) {
defer apilog.LogCallf(nil, "ep=%.10s...", ep)(nil, "") // gologcop: DO NOT EDIT, MUST BE FIRST STATEMENT
- panic("unimplemented")
+ return inaming.NewEndpoint(ep)
}
func (r *Runtime) WithNewNamespace(ctx *context.T, roots ...string) (*context.T, namespace.T, error) {
defer apilog.LogCallf(ctx, "roots...=%v", roots)(ctx, "") // gologcop: DO NOT EDIT, MUST BE FIRST STATEMENT
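
With the fake factory now delegating to inaming, tests that import it for
side effects can mint real endpoints instead of panicking; a sketch (the
address is illustrative):

    ctx, shutdown := v23.Init()
    defer shutdown()
    ep, err := v23.NewEndpoint("localhost:80") // parses via inaming.NewEndpoint
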
diff --git a/runtime/internal/flow/conn/close_test.go b/runtime/internal/flow/conn/close_test.go
new file mode 100644
index 0000000..7bc4af0
--- /dev/null
+++ b/runtime/internal/flow/conn/close_test.go
@@ -0,0 +1,96 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package conn
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "v.io/v23"
+ _ "v.io/x/ref/runtime/factories/fake"
+)
+
+func TestRemoteDialerClose(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ d, a, w := setupConns(t, ctx, nil, nil)
+ d.Close(ctx, fmt.Errorf("Closing randomly."))
+ <-d.Closed()
+ <-a.Closed()
+ if !w.isClosed() {
+ t.Errorf("The connection should be closed")
+ }
+}
+
+func TestRemoteAcceptorClose(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ d, a, w := setupConns(t, ctx, nil, nil)
+ a.Close(ctx, fmt.Errorf("Closing randomly."))
+ <-a.Closed()
+ <-d.Closed()
+ if !w.isClosed() {
+ t.Errorf("The connection should be closed")
+ }
+}
+
+func TestUnderlyingConnectionClosed(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ d, a, w := setupConns(t, ctx, nil, nil)
+ w.close()
+ <-a.Closed()
+ <-d.Closed()
+}
+
+func TestDialAfterConnClose(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ d, a, _ := setupConns(t, ctx, nil, nil)
+
+ d.Close(ctx, fmt.Errorf("Closing randomly."))
+ <-d.Closed()
+ <-a.Closed()
+ if _, err := d.Dial(ctx); err == nil {
+ t.Errorf("Nil error dialing on dialer")
+ }
+ if _, err := a.Dial(ctx); err == nil {
+ t.Errorf("Nil error dialing on acceptor")
+ }
+}
+
+func TestReadWriteAfterConnClose(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ for _, dialerDials := range []bool{true, false} {
+ df, flows := setupFlow(t, ctx, dialerDials)
+ if _, err := df.WriteMsg([]byte("hello")); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ af := <-flows
+ if got, err := af.ReadMsg(); err != nil {
+ t.Fatalf("read failed: %v", err)
+ } else if !bytes.Equal(got, []byte("hello")) {
+ t.Errorf("got %s want %s", string(got), "hello")
+ }
+ if _, err := df.WriteMsg([]byte("there")); err != nil {
+ t.Fatalf("second write failed: %v", err)
+ }
+ df.(*flw).conn.Close(ctx, nil)
+ <-af.Conn().Closed()
+ if got, err := af.ReadMsg(); err != nil {
+ t.Fatalf("read failed: %v", err)
+ } else if !bytes.Equal(got, []byte("there")) {
+ t.Errorf("got %s want %s", string(got), "there")
+ }
+ if _, err := df.WriteMsg([]byte("fail")); err == nil {
+ t.Errorf("nil error for write after close.")
+ }
+ if _, err := af.ReadMsg(); err == nil {
+ t.Fatalf("nil error for read after close.")
+ }
+ }
+}
diff --git a/runtime/internal/flow/conn/conn.go b/runtime/internal/flow/conn/conn.go
index 17a5a0a..4e9079c 100644
--- a/runtime/internal/flow/conn/conn.go
+++ b/runtime/internal/flow/conn/conn.go
@@ -5,6 +5,7 @@
package conn
import (
+ "reflect"
"sync"
"v.io/v23"
@@ -31,6 +32,11 @@
tearDownPriority
)
+// MsgReadWriteCloser is a flow.MsgReadWriter that can also be closed.
+type MsgReadWriteCloser interface {
+ flow.MsgReadWriter
+ Close() error
+}
+
// FlowHandlers process accepted flows.
type FlowHandler interface {
// HandleFlow processes an accepted flow.
@@ -59,7 +65,7 @@
// NewDialed dials a new Conn on the given conn.
func NewDialed(
ctx *context.T,
- conn flow.MsgReadWriter,
+ conn MsgReadWriteCloser,
local, remote naming.Endpoint,
versions version.RPCVersionRange,
handler FlowHandler,
@@ -73,6 +79,7 @@
dialerPublicKey: principal.PublicKey(),
local: local,
remote: remote,
+ closed: make(chan struct{}),
nextFid: reservedFlows,
flows: map[flowID]*flw{},
}
@@ -83,7 +90,7 @@
// NewAccepted accepts a new Conn on the given conn.
func NewAccepted(
ctx *context.T,
- conn flow.MsgReadWriter,
+ conn MsgReadWriteCloser,
local naming.Endpoint,
lBlessings security.Blessings,
versions version.RPCVersionRange,
@@ -95,6 +102,8 @@
versions: versions,
acceptorBlessings: lBlessings,
local: local,
+ remote: local, // TODO(mattr): Get the real remote endpoint.
+ closed: make(chan struct{}),
nextFid: reservedFlows + 1,
flows: map[flowID]*flw{},
}
@@ -106,22 +115,14 @@
func (c *Conn) Dial(ctx *context.T) (flow.Flow, error) {
defer c.mu.Unlock()
c.mu.Lock()
-
+ if c.flows == nil {
+ return nil, NewErrConnectionClosed(ctx)
+ }
id := c.nextFid
c.nextFid++
-
return c.newFlowLocked(ctx, id), nil
}
-// Closed returns a channel that will be closed after the Conn is shutdown.
-// After this channel is closed it is guaranteed that all Dial calls will fail
-// with an error and no more flows will be sent to the FlowHandler.
-func (c *Conn) Closed() <-chan struct{} { return c.closed }
-
-// Close marks the Conn as closed. All Dial calls will fail with an error and
-// no more flows will be sent to the FlowHandler.
-func (c *Conn) Close() { close(c.closed) }
-
// LocalEndpoint returns the local vanadium Endpoint
func (c *Conn) LocalEndpoint() naming.Endpoint { return c.local }
@@ -138,34 +139,82 @@
// Discharges are organized in a map keyed by the discharge-identifier.
func (c *Conn) AcceptorDischarges() map[string]security.Discharge { return nil }
+// Closed returns a channel that will be closed after the Conn is shutdown.
+// After this channel is closed it is guaranteed that all Dial calls will fail
+// with an error and no more flows will be sent to the FlowHandler.
+func (c *Conn) Closed() <-chan struct{} { return c.closed }
+
+// Close shuts down the Conn: it closes all flows, sends a tearDown
+// message to the remote end, closes the underlying connection, and
+// causes the read loop to exit. Close is idempotent.
+func (c *Conn) Close(ctx *context.T, err error) {
+ c.mu.Lock()
+ var flows map[flowID]*flw
+ flows, c.flows = c.flows, nil
+ c.mu.Unlock()
+ if flows == nil {
+ // We've already torn this conn down.
+ return
+ }
+ for _, f := range flows {
+ f.close(err)
+ }
+ err = c.fc.Run(ctx, expressPriority, func(_ int) (int, bool, error) {
+ return 0, true, c.mp.writeMsg(ctx, &tearDown{Err: err})
+ })
+ if err != nil {
+ ctx.Errorf("Error sending tearDown on connection to %s: %v", c.remote, err)
+ }
+ if err = c.mp.close(); err != nil {
+ ctx.Errorf("Error closing underlying connection for %s: %v", c.remote, err)
+ }
+ close(c.closed)
+}
+
+// release gathers the pending read credits from all flows and sends them
+// to the remote end in a single addRecieveBuffers message.
+func (c *Conn) release(ctx *context.T) {
+ counts := map[flowID]uint64{}
+ c.mu.Lock()
+ for fid, f := range c.flows {
+ if release := f.q.release(); release > 0 {
+ counts[fid] = uint64(release)
+ }
+ }
+ c.mu.Unlock()
+ if len(counts) == 0 {
+ return
+ }
+
+ err := c.fc.Run(ctx, expressPriority, func(_ int) (int, bool, error) {
+ err := c.mp.writeMsg(ctx, &addRecieveBuffers{
+ counters: counts,
+ })
+ return 0, true, err
+ })
+ if err != nil {
+ c.Close(ctx, NewErrSend(ctx, "addRecieveBuffers", c.remote.String(), err))
+ }
+}
+
func (c *Conn) readLoop(ctx *context.T) {
+ var terr error
+ // The closure is required: a plain defer would evaluate terr now,
+ // while it is still nil, and drop any teardown error.
+ defer func() { c.Close(ctx, terr) }()
+
for {
x, err := c.mp.readMsg(ctx)
if err != nil {
- ctx.Errorf("Error reading from connection to %s: %v", c.remote, err)
- // TODO(mattr): tear down the conn.
+ c.Close(ctx, NewErrRecv(ctx, c.remote.String(), err))
+ return
}
switch msg := x.(type) {
case *tearDown:
- // TODO(mattr): tear down the conn.
+ terr = msg.Err
+ return
case *openFlow:
c.mu.Lock()
f := c.newFlowLocked(ctx, msg.id)
c.mu.Unlock()
-
c.handler.HandleFlow(f)
- err := c.fc.Run(ctx, expressPriority, func(_ int) (int, bool, error) {
- err := c.mp.writeMsg(ctx, &addRecieveBuffers{
- counters: map[flowID]uint64{msg.id: defaultBufferSize},
- })
- return 0, true, err
- })
- if err != nil {
- // TODO(mattr): Maybe in this case we should close the conn.
- ctx.Errorf("Error sending counters on connection to %s: %v", c.remote, err)
- }
case *addRecieveBuffers:
release := make([]flowcontrol.Release, 0, len(msg.counters))
@@ -179,8 +228,8 @@
}
}
c.mu.Unlock()
- if err := c.fc.Release(release); err != nil {
- ctx.Errorf("Error releasing counters from connection to %s: %v", c.remote, err)
+ if terr = c.fc.Release(ctx, release); terr != nil {
+ return
}
case *data:
@@ -188,31 +237,34 @@
f := c.flows[msg.id]
c.mu.Unlock()
if f == nil {
- ctx.Errorf("Ignoring data message for unknown flow on connection to %s: %d", c.remote, msg.id)
+ ctx.Infof("Ignoring data message for unknown flow on connection to %s: %d", c.remote, msg.id)
continue
}
- if err := f.q.Put(msg.payload); err != nil {
- ctx.Errorf("Ignoring data message for closed flow on connection to %s: %d", c.remote, msg.id)
+ if terr = f.q.put(ctx, msg.payload); terr != nil {
+ return
}
- // TODO(mattr): perhaps close the flow.
- // TODO(mattr): check if the q is full.
+ if msg.flags&closeFlag != 0 {
+ f.close(nil)
+ }
case *unencryptedData:
c.mu.Lock()
f := c.flows[msg.id]
c.mu.Unlock()
if f == nil {
- ctx.Errorf("Ignoring data message for unknown flow: %d", msg.id)
+ ctx.Infof("Ignoring data message for unknown flow: %d", msg.id)
continue
}
- if err := f.q.Put(msg.payload); err != nil {
- ctx.Errorf("Ignoring data message for closed flow: %d", msg.id)
+ if terr = f.q.put(ctx, msg.payload); terr != nil {
+ return
}
- // TODO(mattr): perhaps close the flow.
- // TODO(mattr): check if the q is full.
+ if msg.flags&closeFlag != 0 {
+ f.close(nil)
+ }
default:
- // TODO(mattr): tearDown the conn.
+ // String, not Name: msg is a pointer type, and Name() on a
+ // pointer type is the empty string.
+ terr = NewErrUnexpectedMsg(ctx, reflect.TypeOf(msg).String())
+ return
}
}
}
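
readLoop records the teardown reason in terr and hands it to the deferred
Close; the closure form of that defer is load-bearing, because a plain defer
evaluates its arguments immediately. A standalone illustration of the
difference:

    package main

    import "fmt"

    func main() {
        var terr error
        // Arguments of a plain defer are evaluated now: this always prints <nil>.
        defer fmt.Println("plain defer sees:", terr)
        // A closure reads terr when it runs, at function return.
        defer func() { fmt.Println("closure sees:", terr) }()
        terr = fmt.Errorf("tearDown")
    }
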
diff --git a/runtime/internal/flow/conn/conn_test.go b/runtime/internal/flow/conn/conn_test.go
index 4fb158b..2c3ca40 100644
--- a/runtime/internal/flow/conn/conn_test.go
+++ b/runtime/internal/flow/conn/conn_test.go
@@ -5,70 +5,74 @@
package conn
import (
+ "bytes"
+ "crypto/rand"
+ "io"
"testing"
"v.io/v23"
"v.io/v23/context"
"v.io/v23/flow"
- "v.io/v23/rpc/version"
- "v.io/v23/security"
_ "v.io/x/ref/runtime/factories/fake"
"v.io/x/ref/test"
)
+var randData []byte
+
func init() {
test.Init()
+
+ randData = make([]byte, 2*defaultBufferSize)
+ if _, err := rand.Read(randData); err != nil {
+ panic("Could not read random data.")
+ }
}
-func setupConns(t *testing.T, ctx *context.T, dflows, aflows chan<- flow.Flow) (dialed, accepted *Conn) {
- dmrw, amrw := newMRWPair(ctx)
- versions := version.RPCVersionRange{Min: 3, Max: 5}
- d, err := NewDialed(ctx, dmrw, nil, nil, versions, fh(dflows), nil)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- a, err := NewAccepted(ctx, amrw, nil, security.Blessings{}, versions, fh(aflows))
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- return d, a
-}
+func testWrite(t *testing.T, ctx *context.T, want []byte, df flow.Flow, flows <-chan flow.Flow) {
+ finished := make(chan struct{})
+ go func(x []byte) {
+ mid := len(x) / 2
+ wrote, err := df.WriteMsgAndClose(x[:mid], x[mid:])
+ if err != nil {
+ // Fatalf may only be called from the test goroutine; report the
+ // error and bail out of this writer goroutine instead.
+ t.Errorf("Unexpected error for write: %v", err)
+ close(finished)
+ return
+ }
+ if wrote != len(x) {
+ t.Errorf("got %d want %d", wrote, len(x))
+ }
+ close(finished)
+ }(want)
-func testWrite(t *testing.T, ctx *context.T, dialer *Conn, flows <-chan flow.Flow) {
- df, err := dialer.Dial(ctx)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- want := "hello world"
- df.WriteMsgAndClose([]byte(want[:5]), []byte(want[5:]))
af := <-flows
- msg, err := af.ReadMsg()
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
+ for len(want) > 0 {
+ got, err := af.ReadMsg()
+ if err != nil && err != io.EOF {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ if !bytes.Equal(got, want[:len(got)]) {
+ t.Fatalf("Got: %s want %s", got, want)
+ }
+ want = want[len(got):]
+ if err == io.EOF {
+ break // the leftover check below reports any short read
+ }
}
- if got := string(msg); got != want {
- t.Errorf("Got: %s want %s", got, want)
+ if len(want) != 0 {
+ t.Errorf("got %d leftover bytes, expected 0.", len(want))
+ }
+ <-finished
+ <-df.Closed()
+ <-af.Closed()
+}
+
+func TestDial(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ for _, dialerDials := range []bool{true, false} {
+ df, flows := setupFlow(t, ctx, dialerDials)
+ testWrite(t, ctx, []byte("hello world"), df, flows)
}
}
-func TestDailerDialsFlow(t *testing.T) {
+func TestLargeWrite(t *testing.T) {
ctx, shutdown := v23.Init()
defer shutdown()
- aflows := make(chan flow.Flow, 1)
- d, _ := setupConns(t, ctx, nil, aflows)
- testWrite(t, ctx, d, aflows)
+ df, flows := setupFlow(t, ctx, true)
+ testWrite(t, ctx, randData, df, flows)
}
-
-func TestAcceptorDialsFlow(t *testing.T) {
- ctx, shutdown := v23.Init()
- defer shutdown()
- dflows := make(chan flow.Flow, 1)
- _, a := setupConns(t, ctx, dflows, nil)
- testWrite(t, ctx, a, dflows)
-}
-
-// TODO(mattr): List of tests to write
-// 1. multiple writes
-// 2. interleave writemsg and write
-// 3. interleave read and readmsg
-// 4. multiple reads
diff --git a/runtime/internal/flow/conn/conncache.go b/runtime/internal/flow/conn/conncache.go
index 34b42c1..b435d9c 100644
--- a/runtime/internal/flow/conn/conncache.go
+++ b/runtime/internal/flow/conn/conncache.go
@@ -8,6 +8,7 @@
"strings"
"sync"
+ "v.io/v23/context"
"v.io/v23/naming"
)
@@ -112,13 +113,13 @@
}
// Close marks the ConnCache as closed and closes all Conns in the cache.
-func (c *ConnCache) Close() {
+func (c *ConnCache) Close(ctx *context.T) {
defer c.mu.Unlock()
c.mu.Lock()
c.addrCache, c.ridCache, c.started = nil, nil, nil
d := c.head.next
for d != c.head {
- d.conn.Close()
+ d.conn.Close(ctx, nil)
d = d.next
}
c.head = nil
@@ -129,7 +130,7 @@
// If num is greater than the number of connections in the cache, all cached
// connections will be closed and removed.
// KillConnections returns an error iff the cache is closed.
-func (c *ConnCache) KillConnections(num int) error {
+func (c *ConnCache) KillConnections(ctx *context.T, num int) error {
defer c.mu.Unlock()
c.mu.Lock()
if c.addrCache == nil {
@@ -140,7 +141,7 @@
if d == c.head {
break
}
- d.conn.Close()
+ d.conn.Close(ctx, nil)
delete(c.addrCache, d.addrKey)
delete(c.ridCache, d.rid)
prev := d.prev
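
With the context threaded through, cache shutdown reads as follows (sketch;
the conns and the kill count are illustrative):

    c := NewConnCache()
    // ... Insert conns ...
    if err := c.KillConnections(ctx, 3); err != nil {
        // non-nil only if the cache is already closed
    }
    c.Close(ctx) // closes every remaining Conn; each Closed() channel fires
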
diff --git a/runtime/internal/flow/conn/conncache_test.go b/runtime/internal/flow/conn/conncache_test.go
index 0f9bcaf..a3c8b80 100644
--- a/runtime/internal/flow/conn/conncache_test.go
+++ b/runtime/internal/flow/conn/conncache_test.go
@@ -8,12 +8,18 @@
"strconv"
"testing"
+ "v.io/v23"
+ "v.io/v23/context"
"v.io/v23/naming"
+ "v.io/v23/rpc/version"
inaming "v.io/x/ref/runtime/internal/naming"
)
func TestCache(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+
c := NewConnCache()
remote := &inaming.Endpoint{
Protocol: "tcp",
@@ -21,10 +27,7 @@
RID: naming.FixedRoutingID(0x5555),
Blessings: []string{"A", "B", "C"},
}
- conn := &Conn{
- remote: remote,
- closed: make(chan struct{}),
- }
+ conn := makeConn(t, ctx, remote)
if err := c.Insert(conn); err != nil {
t.Fatal(err)
}
@@ -63,10 +66,8 @@
Address: "other",
Blessings: []string{"other"},
}
- otherConn := &Conn{
- remote: otherEP,
- closed: make(chan struct{}),
- }
+ otherConn := makeConn(t, ctx, otherEP)
+
// Looking up a not yet inserted endpoint should fail.
if got, err := c.ReservedFind(otherEP.Protocol, otherEP.Address, otherEP.Blessings); err != nil || got != nil {
t.Errorf("got %v, want <nil>, err: %v", got, err)
@@ -99,26 +100,25 @@
if isClosed(otherConn) {
t.Fatalf("wanted otherConn to not be closed")
}
- c.Close()
+ c.Close(ctx)
// Now the connections should be closed.
- if !isClosed(conn) {
- t.Errorf("wanted conn to be closed")
- }
- if !isClosed(otherConn) {
- t.Errorf("wanted otherConn to be closed")
- }
+ <-conn.Closed()
+ <-otherConn.Closed()
}
func TestLRU(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+
// Ensure that the least recently inserted conns are killed by KillConnections.
c := NewConnCache()
- conns := nConns(10)
+ conns := nConns(t, ctx, 10)
for _, conn := range conns {
if err := c.Insert(conn); err != nil {
t.Fatal(err)
}
}
- if err := c.KillConnections(3); err != nil {
+ if err := c.KillConnections(ctx, 3); err != nil {
t.Fatal(err)
}
if !cacheSizeMatches(c) {
@@ -135,9 +135,7 @@
}
}
for _, conn := range conns[:3] {
- if !isClosed(conn) {
- t.Errorf("conn %v should have been closed", conn)
- }
+ <-conn.Closed()
if isInCache(t, c, conn) {
t.Errorf("conn %v should not be in cache", conn)
}
@@ -145,7 +143,7 @@
// Ensure that ReservedFind marks conns as more recently used.
c = NewConnCache()
- conns = nConns(10)
+ conns = nConns(t, ctx, 10)
for _, conn := range conns {
if err := c.Insert(conn); err != nil {
t.Fatal(err)
@@ -157,7 +155,7 @@
}
c.Unreserve(conn.remote.Addr().Network(), conn.remote.Addr().String(), conn.remote.BlessingNames())
}
- if err := c.KillConnections(3); err != nil {
+ if err := c.KillConnections(ctx, 3); err != nil {
t.Fatal(err)
}
if !cacheSizeMatches(c) {
@@ -174,9 +172,7 @@
}
}
for _, conn := range conns[7:] {
- if !isClosed(conn) {
- t.Errorf("conn %v should have been closed", conn)
- }
+ <-conn.Closed()
if isInCache(t, c, conn) {
t.Errorf("conn %v should not be in cache", conn)
}
@@ -184,7 +180,7 @@
// Ensure that FindWithRoutingID marks conns as more recently used.
c = NewConnCache()
- conns = nConns(10)
+ conns = nConns(t, ctx, 10)
for _, conn := range conns {
if err := c.Insert(conn); err != nil {
t.Fatal(err)
@@ -195,7 +191,7 @@
t.Errorf("got %v, want %v, err: %v", got, conn, err)
}
}
- if err := c.KillConnections(3); err != nil {
+ if err := c.KillConnections(ctx, 3); err != nil {
t.Fatal(err)
}
if !cacheSizeMatches(c) {
@@ -212,9 +208,7 @@
}
}
for _, conn := range conns[7:] {
- if !isClosed(conn) {
- t.Errorf("conn %v should have been closed", conn)
- }
+ <-conn.Closed()
if isInCache(t, c, conn) {
t.Errorf("conn %v should not be in cache", conn)
}
@@ -249,16 +243,22 @@
return size
}
-func nConns(n int) []*Conn {
+func nConns(t *testing.T, ctx *context.T, n int) []*Conn {
conns := make([]*Conn, n)
for i := 0; i < n; i++ {
- conns[i] = &Conn{
- remote: &inaming.Endpoint{
- Protocol: strconv.Itoa(i),
- RID: naming.FixedRoutingID(uint64(i)),
- },
- closed: make(chan struct{}),
- }
+ conns[i] = makeConn(t, ctx, &inaming.Endpoint{
+ Protocol: strconv.Itoa(i),
+ RID: naming.FixedRoutingID(uint64(i)),
+ })
}
return conns
}
+
+func makeConn(t *testing.T, ctx *context.T, ep naming.Endpoint) *Conn {
+ d, _, _ := newMRWPair(ctx)
+ c, err := NewDialed(ctx, d, ep, ep, version.RPCVersionRange{Min: 1, Max: 5}, nil, nil)
+ if err != nil {
+ t.Fatalf("Could not create conn: %v", err)
+ }
+ return c
+}
diff --git a/runtime/internal/flow/conn/errors.vdl b/runtime/internal/flow/conn/errors.vdl
index eb0d216..98938d4 100644
--- a/runtime/internal/flow/conn/errors.vdl
+++ b/runtime/internal/flow/conn/errors.vdl
@@ -11,10 +11,16 @@
// TODO(suharshs,toddw): Allow skipping of {1}{2} in vdl generated errors.
error (
InvalidMsg(typ byte, size, field int64) {
- "en":"message of type{:typ} and size{:size} failed decoding at field{:field}."}
+ "en": "message of type{:typ} and size{:size} failed decoding at field{:field}."}
InvalidControlMsg(cmd byte, size, field int64) {
- "en":"control message of cmd{:cmd} and size{:size} failed decoding at field{:field}."}
+ "en": "control message of cmd{:cmd} and size{:size} failed decoding at field{:field}."}
UnknownMsg(typ byte) {"en":"unknown message type{:typ}."}
- UnknownControlMsg(cmd byte) {"en":"unknown control command{:cmd}."}
+ UnknownControlMsg(cmd byte) {"en": "unknown control command{:cmd}."}
+
+ UnexpectedMsg(typ string) {"en": "unexpected message type{:typ}."}
+ ConnectionClosed() {"en": "connection closed."}
+ Send(typ, dest string, err error) {"en": "failure sending {typ} message to {dest}{:err}."}
+ Recv(src string, err error) {"en": "error reading from {src}{:err}"}
CacheClosed() {"en":"cache is closed"}
+ CounterOverflow() {"en": "the remote process sent more data than allowed."}
)
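
Callers can match the new errors by ID through the generated IDActions; a
sketch, assuming the standard v.io/v23/verror import:

    if _, err := d.Dial(ctx); err != nil {
        if verror.ErrorID(err) == ErrConnectionClosed.ID {
            // the Conn was shut down; a new one must be dialed
        }
    }
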
diff --git a/runtime/internal/flow/conn/errors.vdl.go b/runtime/internal/flow/conn/errors.vdl.go
index de2eda9..e26dc9c 100644
--- a/runtime/internal/flow/conn/errors.vdl.go
+++ b/runtime/internal/flow/conn/errors.vdl.go
@@ -19,7 +19,12 @@
ErrInvalidControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidControlMsg", verror.NoRetry, "{1:}{2:} control message of cmd{:3} and size{:4} failed decoding at field{:5}.")
ErrUnknownMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownMsg", verror.NoRetry, "{1:}{2:} unknown message type{:3}.")
ErrUnknownControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownControlMsg", verror.NoRetry, "{1:}{2:} unknown control command{:3}.")
+ ErrUnexpectedMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnexpectedMsg", verror.NoRetry, "{1:}{2:} unexpected message type{:3}.")
+ ErrConnectionClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.ConnectionClosed", verror.NoRetry, "{1:}{2:} connection closed.")
+ ErrSend = verror.Register("v.io/x/ref/runtime/internal/flow/conn.Send", verror.NoRetry, "{1:}{2:} failure sending {3} message to {4}{:5}.")
+ ErrRecv = verror.Register("v.io/x/ref/runtime/internal/flow/conn.Recv", verror.NoRetry, "{1:}{2:} error reading from {3}{:4}")
ErrCacheClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CacheClosed", verror.NoRetry, "{1:}{2:} cache is closed")
+ ErrCounterOverflow = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CounterOverflow", verror.NoRetry, "{1:}{2:} the remote process sent more data than allowed.")
)
func init() {
@@ -27,7 +32,12 @@
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrInvalidControlMsg.ID), "{1:}{2:} control message of cmd{:3} and size{:4} failed decoding at field{:5}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnknownMsg.ID), "{1:}{2:} unknown message type{:3}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnknownControlMsg.ID), "{1:}{2:} unknown control command{:3}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnexpectedMsg.ID), "{1:}{2:} unexpected message type{:3}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrConnectionClosed.ID), "{1:}{2:} connection closed.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrSend.ID), "{1:}{2:} failure sending {3} message to {4}{:5}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrRecv.ID), "{1:}{2:} error reading from {3}{:4}")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrCacheClosed.ID), "{1:}{2:} cache is closed")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrCounterOverflow.ID), "{1:}{2:} A remote process has sent more data than allowed.")
}
// NewErrInvalidMsg returns an error with the ErrInvalidMsg ID.
@@ -50,7 +60,32 @@
return verror.New(ErrUnknownControlMsg, ctx, cmd)
}
+// NewErrUnexpectedMsg returns an error with the ErrUnexpectedMsg ID.
+func NewErrUnexpectedMsg(ctx *context.T, typ string) error {
+ return verror.New(ErrUnexpectedMsg, ctx, typ)
+}
+
+// NewErrConnectionClosed returns an error with the ErrConnectionClosed ID.
+func NewErrConnectionClosed(ctx *context.T) error {
+ return verror.New(ErrConnectionClosed, ctx)
+}
+
+// NewErrSend returns an error with the ErrSend ID.
+func NewErrSend(ctx *context.T, typ string, dest string, err error) error {
+ return verror.New(ErrSend, ctx, typ, dest, err)
+}
+
+// NewErrRecv returns an error with the ErrRecv ID.
+func NewErrRecv(ctx *context.T, src string, err error) error {
+ return verror.New(ErrRecv, ctx, src, err)
+}
+
// NewErrCacheClosed returns an error with the ErrCacheClosed ID.
func NewErrCacheClosed(ctx *context.T) error {
return verror.New(ErrCacheClosed, ctx)
}
+
+// NewErrCounterOverflow returns an error with the ErrCounterOverflow ID.
+func NewErrCounterOverflow(ctx *context.T) error {
+ return verror.New(ErrCounterOverflow, ctx)
+}
diff --git a/runtime/internal/flow/conn/flow.go b/runtime/internal/flow/conn/flow.go
index 6ea0246..804f472 100644
--- a/runtime/internal/flow/conn/flow.go
+++ b/runtime/internal/flow/conn/flow.go
@@ -10,18 +10,16 @@
"v.io/v23/security"
"v.io/x/ref/runtime/internal/flow/flowcontrol"
- "v.io/x/ref/runtime/internal/lib/upcqueue"
)
type flw struct {
id flowID
ctx *context.T
+ cancel context.CancelFunc
conn *Conn
- closed chan struct{}
worker *flowcontrol.Worker
opened bool
- q *upcqueue.T
- readBufs [][]byte
+ q *readq
dialerBlessings security.Blessings
dialerDischarges map[string]security.Discharge
}
@@ -31,12 +29,11 @@
func (c *Conn) newFlowLocked(ctx *context.T, id flowID) *flw {
f := &flw{
id: id,
- ctx: ctx,
conn: c,
- closed: make(chan struct{}),
worker: c.fc.NewWorker(flowPriority),
- q: upcqueue.New(),
+ q: newReadQ(),
}
+ f.SetContext(ctx)
c.flows[id] = f
return f
}
@@ -49,19 +46,10 @@
// Read and ReadMsg should not be called concurrently with themselves
// or each other.
func (f *flw) Read(p []byte) (n int, err error) {
- for {
- for len(f.readBufs) > 0 && len(f.readBufs[0]) == 0 {
- f.readBufs = f.readBufs[1:]
- }
- if len(f.readBufs) > 0 {
- break
- }
- var msg interface{}
- msg, err = f.q.Get(f.ctx.Done())
- f.readBufs = msg.([][]byte)
+ var release bool
+ if n, release, err = f.q.read(f.ctx, p); release {
+ f.conn.release(f.ctx)
}
- n = copy(p, f.readBufs[0])
- f.readBufs[0] = f.readBufs[0][n:]
return
}
@@ -70,19 +58,14 @@
// Read and ReadMsg should not be called concurrently with themselves
// or each other.
func (f *flw) ReadMsg() (buf []byte, err error) {
- for {
- for len(f.readBufs) > 0 {
- buf, f.readBufs = f.readBufs[0], f.readBufs[1:]
- if len(buf) > 0 {
- return buf, nil
- }
- }
- bufs, err := f.q.Get(f.ctx.Done())
- if err != nil {
- return nil, err
- }
- f.readBufs = bufs.([][]byte)
+ var release bool
+ // TODO(mattr): Currently we only release counters when some flow
+ // reads. We may need to do it more or less often. A new flow's
+ // initial window (toRelease in newReadQ) is included in the next
+ // batch of released counters.
+ if buf, release, err = f.q.get(f.ctx); release {
+ f.conn.release(f.ctx)
}
+ return
}
// Implement io.Writer.
@@ -95,10 +78,7 @@
func (f *flw) writeMsg(alsoClose bool, parts ...[]byte) (int, error) {
sent := 0
var left []byte
-
- f.ctx.VI(3).Infof("trying to write: %d.", f.id)
err := f.worker.Run(f.ctx, func(tokens int) (int, bool, error) {
- f.ctx.VI(3).Infof("writing: %d.", f.id)
if !f.opened {
// TODO(mattr): we should be able to send multiple messages
// in a single writeMsg call.
@@ -116,6 +96,7 @@
if len(left) > 0 {
size += len(left)
bufs = append(bufs, left)
+ left = nil
}
for size <= tokens && len(parts) > 0 {
bufs = append(bufs, parts[0])
@@ -128,6 +109,7 @@
take := len(last) - (size - tokens)
bufs[lidx] = last[:take]
left = last[take:]
+ size = tokens // the full token allotment was consumed; report it all
}
d := &data{
id: f.id,
@@ -140,6 +122,9 @@
sent += size
return size, done, f.conn.mp.writeMsg(f.ctx, d)
})
+ if alsoClose || err != nil {
+ f.close(err)
+ }
return sent, err
}
@@ -163,6 +148,7 @@
// SetContext sets the context associated with the flow. Typically this is
// used to set state that is only available after the flow is connected, such
// as a more restricted flow timeout, or the language of the request.
+// Calling SetContext may invalidate values previously returned from Closed.
//
// The flow.Manager associated with ctx must be the same flow.Manager that the
// flow was dialed or accepted from, otherwise an error is returned.
@@ -171,7 +157,10 @@
// TODO(mattr): update v23/flow documentation.
// SetContext may not be called concurrently with other methods.
func (f *flw) SetContext(ctx *context.T) error {
- f.ctx = ctx
+ if f.cancel != nil {
+ f.cancel()
+ }
+ f.ctx, f.cancel = context.WithCancel(ctx)
return nil
}
@@ -220,8 +209,21 @@
return f.conn
}
-// Closed returns a channel that remains open until the flow has been closed or
-// the ctx to the Dial or Accept call used to create the flow has been cancelled.
+// Closed returns a channel that remains open until the flow has been closed remotely
+// or the context attached to the flow has been canceled.
+//
+// Note that after the returned channel is closed, new writes will fail,
+// but previously queued data may still be read. No new data will be
+// queued.
+// TODO(mattr): update v23/flow docs.
func (f *flw) Closed() <-chan struct{} {
- return f.closed
+ return f.ctx.Done()
+}
+
+// close makes the flow done: no new data will be queued and writes will
+// fail, but already queued data may still be read.
+func (f *flw) close(err error) {
+ f.q.close(f.ctx)
+ f.cancel()
+
+ // TODO(mattr): maybe send a final close data message.
+ // TODO(mattr): save the error to hand out later.
}
diff --git a/runtime/internal/flow/conn/message.go b/runtime/internal/flow/conn/message.go
index c4f40eb..d30c52e 100644
--- a/runtime/internal/flow/conn/message.go
+++ b/runtime/internal/flow/conn/message.go
@@ -8,7 +8,6 @@
"errors"
"v.io/v23/context"
- "v.io/v23/flow"
"v.io/v23/naming"
"v.io/v23/rpc/version"
)
@@ -93,10 +92,16 @@
}
func (m *tearDown) write(ctx *context.T, p *messagePipe) error {
- return p.write([][]byte{{controlType}}, [][]byte{{tearDownCmd}, []byte(m.Err.Error())})
+ var errBytes []byte
+ if m.Err != nil {
+ errBytes = []byte(m.Err.Error())
+ }
+ return p.write([][]byte{{controlType}}, [][]byte{{tearDownCmd}, errBytes})
}
func (m *tearDown) read(ctx *context.T, data []byte) error {
- m.Err = errors.New(string(data))
+ if len(data) > 0 {
+ m.Err = errors.New(string(data))
+ }
return nil
}
@@ -148,6 +153,9 @@
fid, val uint64
n int64
)
+ if len(data) == 0 {
+ return nil
+ }
m.counters = map[flowID]uint64{}
for len(data) > 0 {
if fid, data, valid = readVarUint64(ctx, data); !valid {
@@ -188,7 +196,9 @@
if m.flags, data, valid = readVarUint64(ctx, data); !valid {
return NewErrInvalidMsg(ctx, dataType, int64(len(orig)), 1)
}
- m.payload = [][]byte{data}
+ if len(data) > 0 {
+ m.payload = [][]byte{data}
+ }
return nil
}
@@ -224,7 +234,9 @@
if plen > len(data) {
return NewErrInvalidMsg(ctx, unencryptedDataType, int64(len(orig)), 1)
}
- m.payload, data = [][]byte{data[:plen]}, data[plen:]
+ if plen > 0 {
+ m.payload, data = [][]byte{data[:plen]}, data[plen:]
+ }
if v, data, valid = readVarUint64(ctx, data); !valid {
return NewErrInvalidMsg(ctx, unencryptedDataType, int64(len(orig)), 2)
}
@@ -236,13 +248,13 @@
}
type messagePipe struct {
- rw flow.MsgReadWriter
+ rw MsgReadWriteCloser
controlBuf []byte
dataBuf []byte
outBuf [][]byte
}
-func newMessagePipe(rw flow.MsgReadWriter) *messagePipe {
+func newMessagePipe(rw MsgReadWriteCloser) *messagePipe {
return &messagePipe{
rw: rw,
controlBuf: make([]byte, 256),
@@ -251,6 +263,10 @@
}
}
+func (p *messagePipe) close() error {
+ return p.rw.Close()
+}
+
func (p *messagePipe) write(unencrypted [][]byte, encrypted [][]byte) error {
p.outBuf = append(p.outBuf[:0], unencrypted...)
p.outBuf = append(p.outBuf, encrypted...)
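
The nil-safe tearDown encoding means an empty payload round-trips to a nil
Err; a sketch using the package's in-memory pipe from the tests:

    w, r, _ := newMRWPair(ctx)
    wp, rp := newMessagePipe(w), newMessagePipe(r)
    go wp.writeMsg(ctx, &tearDown{}) // nil Err: no error bytes on the wire
    if m, err := rp.readMsg(ctx); err == nil {
        td := m.(*tearDown) // td.Err == nil, not errors.New("")
        _ = td
    }
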
diff --git a/runtime/internal/flow/conn/message_test.go b/runtime/internal/flow/conn/message_test.go
index caf0d7e..416ba3d 100644
--- a/runtime/internal/flow/conn/message_test.go
+++ b/runtime/internal/flow/conn/message_test.go
@@ -52,7 +52,7 @@
func testMessages(t *testing.T, cases []message) {
ctx, shutdown := v23.Init()
defer shutdown()
- w, r := newMRWPair(ctx)
+ w, r, _ := newMRWPair(ctx)
wp, rp := newMessagePipe(w), newMessagePipe(r)
for _, want := range cases {
ch := make(chan struct{})
@@ -76,24 +76,27 @@
func TestSetup(t *testing.T) {
testMessages(t, []message{
&setup{versions: version.RPCVersionRange{Min: 3, Max: 5}},
+ &setup{},
})
}
func TestTearDown(t *testing.T) {
testMessages(t, []message{
&tearDown{Err: errors.New("foobar")},
+ &tearDown{},
})
}
func TestOpenFlow(t *testing.T) {
testMessages(t, []message{
&openFlow{id: 23, initialCounters: 1 << 20},
+ &openFlow{},
})
}
func TestAddReceiveBuffers(t *testing.T) {
testMessages(t, []message{
- &addRecieveBuffers{counters: map[flowID]uint64{}},
+ &addRecieveBuffers{},
&addRecieveBuffers{counters: map[flowID]uint64{
4: 233,
9: 423242,
@@ -104,11 +107,13 @@
func TestData(t *testing.T) {
testMessages(t, []message{
&data{id: 1123, flags: 232, payload: [][]byte{[]byte("fake payload")}},
+ &data{},
})
}
func TestUnencryptedData(t *testing.T) {
testMessages(t, []message{
&unencryptedData{id: 1123, flags: 232, payload: [][]byte{[]byte("fake payload")}},
+ &unencryptedData{},
})
}
diff --git a/runtime/internal/flow/conn/readq.go b/runtime/internal/flow/conn/readq.go
new file mode 100644
index 0000000..a2c09e6
--- /dev/null
+++ b/runtime/internal/flow/conn/readq.go
@@ -0,0 +1,152 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package conn
+
+import (
+ "io"
+ "sync"
+
+ "v.io/v23/context"
+)
+
+type readq struct {
+ mu sync.Mutex
+ bufs [][]byte // circular queue of buffers; [b, e) holds unread data
+ b, e int // e == -1 marks the queue closed
+
+ size int // total unread bytes queued
+ toRelease int // bytes consumed since the last counter release
+ notify chan struct{}
+}
+
+const initialReadqBufferSize = 10
+
+func newReadQ() *readq {
+ return &readq{
+ bufs: make([][]byte, initialReadqBufferSize),
+ notify: make(chan struct{}, 1),
+ toRelease: defaultBufferSize,
+ }
+}
+
+func (r *readq) put(ctx *context.T, bufs [][]byte) error {
+ if len(bufs) == 0 {
+ return nil
+ }
+ l := 0
+ for _, b := range bufs {
+ l += len(b)
+ }
+ if l == 0 {
+ return nil
+ }
+
+ defer r.mu.Unlock()
+ r.mu.Lock()
+ if r.e == -1 {
+ // The flow has already closed. Simply drop the data.
+ return nil
+ }
+ newSize := l + r.size
+ if newSize > defaultBufferSize {
+ return NewErrCounterOverflow(ctx)
+ }
+ // Always reserve: reserveLocked is a no-op when the ring has room,
+ // while growing only when r.e == r.b would let a burst of small
+ // buffers wrap past unread entries.
+ r.reserveLocked(len(bufs))
+ for _, b := range bufs {
+ r.bufs[r.e] = b
+ r.e = (r.e + 1) % len(r.bufs)
+ }
+ if r.size == 0 {
+ select {
+ case r.notify <- struct{}{}:
+ default:
+ }
+ }
+ r.size = newSize
+ return nil
+}
+
+func (r *readq) read(ctx *context.T, data []byte) (n int, release bool, err error) {
+ defer r.mu.Unlock()
+ r.mu.Lock()
+ if err := r.waitLocked(ctx); err != nil {
+ return 0, false, err
+ }
+ buf := r.bufs[r.b]
+ n = copy(data, buf)
+ buf = buf[n:]
+ if len(buf) > 0 {
+ r.bufs[r.b] = buf
+ } else {
+ r.b = (r.b + 1) % len(r.bufs)
+ }
+ r.size -= n
+ r.toRelease += n
+ return n, r.toRelease > defaultBufferSize/2, nil
+}
+
+func (r *readq) get(ctx *context.T) (out []byte, release bool, err error) {
+ defer r.mu.Unlock()
+ r.mu.Lock()
+ if err := r.waitLocked(ctx); err != nil {
+ return nil, false, err
+ }
+ out = r.bufs[r.b]
+ r.b = (r.b + 1) % len(r.bufs)
+ r.size -= len(out)
+ r.toRelease += len(out)
+ return out, r.toRelease > defaultBufferSize/2, nil
+}
+
+func (r *readq) waitLocked(ctx *context.T) (err error) {
+ for r.size == 0 && err == nil {
+ r.mu.Unlock()
+ select {
+ case _, ok := <-r.notify:
+ if !ok {
+ err = io.EOF
+ }
+ case <-ctx.Done():
+ err = ctx.Err()
+ }
+ r.mu.Lock()
+ // Re-check under the lock: if data arrived while we were waiting,
+ // deliver it before reporting an error. (Reading r.size inside the
+ // select would race with put.)
+ if r.size > 0 {
+ err = nil
+ }
+ }
+ return
+}
+
+func (r *readq) close(ctx *context.T) {
+ r.mu.Lock()
+ if r.e != -1 {
+ r.e = -1
+ r.toRelease = 0
+ close(r.notify)
+ }
+ r.mu.Unlock()
+}
+
+func (r *readq) reserveLocked(n int) {
+ needed := n + r.e - r.b
+ if r.e < r.b {
+ needed += len(r.bufs)
+ }
+ if needed < len(r.bufs) {
+ return
+ }
+ nb := make([][]byte, 2*needed)
+ // Copy the valid window [b, e) to the front of the new ring; the
+ // wrapped and unwrapped cases copy different slices.
+ var copied int
+ if r.e >= r.b {
+ copied = copy(nb, r.bufs[r.b:r.e])
+ } else {
+ copied = copy(nb, r.bufs[r.b:])
+ copied += copy(nb[copied:], r.bufs[:r.e])
+ }
+ r.bufs, r.b, r.e = nb, 0, copied
+}
+
+func (r *readq) release() (out int) {
+ r.mu.Lock()
+ out, r.toRelease = r.toRelease, 0
+ r.mu.Unlock()
+ return out
+}
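
A fresh readq owes the remote end its entire initial window (toRelease
starts at defaultBufferSize), so the first successful read already suggests
a counter release; a sketch (mkBufs is the test helper defined in
readq_test.go below):

    q := newReadQ()
    _ = q.put(ctx, mkBufs("hello"))
    buf := make([]byte, 5)
    if _, release, _ := q.read(ctx, buf); release {
        // toRelease exceeded defaultBufferSize/2; Conn.release batches
        // q.release() into a single addRecieveBuffers message.
    }
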
diff --git a/runtime/internal/flow/conn/readq_test.go b/runtime/internal/flow/conn/readq_test.go
new file mode 100644
index 0000000..865b246
--- /dev/null
+++ b/runtime/internal/flow/conn/readq_test.go
@@ -0,0 +1,109 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package conn
+
+import (
+ "io"
+ "testing"
+
+ "v.io/v23"
+)
+
+func mkBufs(in ...string) [][]byte {
+ out := make([][]byte, len(in))
+ for i, s := range in {
+ out[i] = []byte(s)
+ }
+ return out
+}
+
+func TestReadqRead(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+
+ r := newReadQ()
+ r.put(ctx, mkBufs("one", "two"))
+ r.put(ctx, mkBufs("thre", "reallong"))
+ r.close(ctx)
+
+ read := make([]byte, 4)
+ want := []string{"one", "two", "thre", "real", "long"}
+ for _, w := range want {
+ n, _, err := r.read(ctx, read)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got := string(read[:n]); got != w {
+ t.Errorf("got: %s, want %s", got, w)
+ }
+ }
+ if _, _, err := r.read(ctx, read); err != io.EOF {
+ t.Errorf("expected EOF got %v", err)
+ }
+}
+
+func TestReadqGet(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+
+ r := newReadQ()
+ r.put(ctx, mkBufs("one", "two"))
+ r.put(ctx, mkBufs("thre", "reallong"))
+ r.close(ctx)
+
+ want := []string{"one", "two", "thre", "reallong"}
+ for _, w := range want {
+ out, _, err := r.get(ctx)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got := string(out); got != w {
+ t.Errorf("got: %s, want %s", got, w)
+ }
+ }
+ if _, _, err := r.get(ctx); err != io.EOF {
+ t.Errorf("expected EOF got %v", err)
+ }
+}
+
+func TestReadqMixed(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+
+ r := newReadQ()
+ r.put(ctx, mkBufs("one", "two"))
+ r.put(ctx, mkBufs("thre", "reallong"))
+ r.close(ctx)
+
+ want := []string{"one", "two", "thre", "real", "long"}
+ for i, w := range want {
+ var (
+ err error
+ got string
+ n int
+ out []byte
+ read = make([]byte, 4)
+ )
+ if i%2 == 0 {
+ out, _, err = r.get(ctx)
+ got = string(out)
+ } else {
+ n, _, err = r.read(ctx, read)
+ got = string(read[:n])
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != w {
+ t.Errorf("got: %s, want %s", got, w)
+ }
+ }
+ if _, _, err := r.get(ctx); err != io.EOF {
+ t.Errorf("expected EOF got %v", err)
+ }
+ if _, _, err := r.read(ctx, nil); err != io.EOF {
+ t.Errorf("expected EOF got %v", err)
+ }
+}
diff --git a/runtime/internal/flow/conn/util_test.go b/runtime/internal/flow/conn/util_test.go
index 5123e3e..18cc267 100644
--- a/runtime/internal/flow/conn/util_test.go
+++ b/runtime/internal/flow/conn/util_test.go
@@ -5,21 +5,50 @@
package conn
import (
+ "io"
+ "sync"
+ "testing"
+
+ "v.io/v23"
"v.io/v23/context"
"v.io/v23/flow"
+ "v.io/v23/rpc/version"
+ "v.io/v23/security"
)
-type mRW struct {
- recieve <-chan []byte
- send chan<- []byte
- ctx *context.T
+type wire struct {
+ ctx *context.T
+ mu sync.Mutex
+ c *sync.Cond
+ closed bool
}
-func newMRWPair(ctx *context.T) (flow.MsgReadWriter, flow.MsgReadWriter) {
- ac, bc := make(chan []byte), make(chan []byte)
- a := &mRW{recieve: ac, send: bc, ctx: ctx}
- b := &mRW{recieve: bc, send: ac, ctx: ctx}
- return a, b
+func (w *wire) close() {
+ w.mu.Lock()
+ w.closed = true
+ w.c.Broadcast()
+ w.mu.Unlock()
+}
+
+func (w *wire) isClosed() bool {
+ w.mu.Lock()
+ c := w.closed
+ w.mu.Unlock()
+ return c
+}
+
+type mRW struct {
+ wire *wire
+ in []byte
+ peer *mRW
+}
+
+func newMRWPair(ctx *context.T) (MsgReadWriteCloser, MsgReadWriteCloser, *wire) {
+ w := &wire{ctx: ctx}
+ w.c = sync.NewCond(&w.mu)
+ a, b := &mRW{wire: w}, &mRW{wire: w}
+ a.peer, b.peer = b, a
+ return a, b, w
}
func (f *mRW) WriteMsg(data ...[]byte) (int, error) {
@@ -27,15 +56,35 @@
for _, d := range data {
buf = append(buf, d...)
}
- f.send <- buf
- f.ctx.VI(5).Infof("Wrote: %v", buf)
+ defer f.wire.mu.Unlock()
+ f.wire.mu.Lock()
+ for f.peer.in != nil && !f.wire.closed {
+ f.wire.c.Wait()
+ }
+ if f.wire.closed {
+ return 0, io.EOF
+ }
+ f.peer.in = buf
+ f.wire.c.Broadcast()
return len(buf), nil
}
func (f *mRW) ReadMsg() (buf []byte, err error) {
- buf = <-f.recieve
- f.ctx.VI(5).Infof("Read: %v", buf)
+ defer f.wire.mu.Unlock()
+ f.wire.mu.Lock()
+ for f.in == nil && !f.wire.closed {
+ f.wire.c.Wait()
+ }
+ if f.wire.closed {
+ return nil, io.EOF
+ }
+ buf, f.in = f.in, nil
+ f.wire.c.Broadcast()
return buf, nil
}
+func (f *mRW) Close() error {
+ f.wire.close()
+ return nil
+}
type fh chan<- flow.Flow
@@ -46,3 +95,35 @@
fh <- f
return nil
}
+
+func setupConns(t *testing.T, ctx *context.T, dflows, aflows chan<- flow.Flow) (dialed, accepted *Conn, _ *wire) {
+ dmrw, amrw, w := newMRWPair(ctx)
+ versions := version.RPCVersionRange{Min: 3, Max: 5}
+ ep, err := v23.NewEndpoint("localhost:80")
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ d, err := NewDialed(ctx, dmrw, ep, ep, versions, fh(dflows), nil)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ a, err := NewAccepted(ctx, amrw, ep, security.Blessings{}, versions, fh(aflows))
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ return d, a, w
+}
+
+func setupFlow(t *testing.T, ctx *context.T, dialFromDialer bool) (dialed flow.Flow, accepted <-chan flow.Flow) {
+ dflows, aflows := make(chan flow.Flow, 1), make(chan flow.Flow, 1)
+ d, a, _ := setupConns(t, ctx, dflows, aflows)
+ if !dialFromDialer {
+ d, a = a, d
+ aflows, dflows = dflows, aflows
+ }
+ df, err := d.Dial(ctx)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ return df, aflows
+}
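
The new in-memory pair is a synchronous single-slot handoff: WriteMsg blocks
until the peer's slot is empty, ReadMsg blocks until it is full, and closing
the wire wakes both ends with io.EOF. Usage (sketch):

    a, b, w := newMRWPair(ctx)
    go a.WriteMsg([]byte("ping")) // parks until b reads
    buf, _ := b.ReadMsg()         // buf == []byte("ping")
    w.close()
    _, err := b.ReadMsg() // err == io.EOF
    _, _ = buf, err
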
diff --git a/runtime/internal/flow/flowcontrol/flowcontrol.go b/runtime/internal/flow/flowcontrol/flowcontrol.go
index 1a47ada..f744641 100644
--- a/runtime/internal/flow/flowcontrol/flowcontrol.go
+++ b/runtime/internal/flow/flowcontrol/flowcontrol.go
@@ -144,7 +144,7 @@
return err
}
-func (w *Worker) releaseLocked(tokens int) {
+func (w *Worker) releaseLocked(ctx *context.T, tokens int) {
if w.counters == nil {
return
}
@@ -164,9 +164,9 @@
// Release releases tokens to this worker.
// Workers will first repay any debts to the flow controllers shared pool
// and use any surplus in subsequent calls to Run.
-func (w *Worker) Release(tokens int) {
+func (w *Worker) Release(ctx *context.T, tokens int) {
w.fc.mu.Lock()
- w.releaseLocked(tokens)
+ w.releaseLocked(ctx, tokens)
next := w.fc.nextWorkerLocked()
w.fc.mu.Unlock()
if next != nil {
@@ -223,13 +223,13 @@
// Release releases to many Workers atomically. It is conceptually
the same as calling release on each worker independently.
-func (fc *FlowController) Release(to []Release) error {
+func (fc *FlowController) Release(ctx *context.T, to []Release) error {
fc.mu.Lock()
for _, t := range to {
if t.Worker.fc != fc {
- return verror.New(ErrWrongFlowController, nil)
+ return verror.New(ErrWrongFlowController, ctx)
}
- t.Worker.releaseLocked(t.Tokens)
+ t.Worker.releaseLocked(ctx, t.Tokens)
}
next := fc.nextWorkerLocked()
fc.mu.Unlock()
diff --git a/runtime/internal/flow/flowcontrol/flowcontrol_test.go b/runtime/internal/flow/flowcontrol/flowcontrol_test.go
index 7388b8f..ce42db5 100644
--- a/runtime/internal/flow/flowcontrol/flowcontrol_test.go
+++ b/runtime/internal/flow/flowcontrol/flowcontrol_test.go
@@ -46,7 +46,7 @@
for i := 0; i < workers; i++ {
go func(idx int) {
el := fc.NewWorker(0)
- go el.Release(messages * 5) // Try to make races happen
+ go el.Release(ctx, messages*5) // Try to make races happen
j := 0
el.Run(ctx, func(tokens int) (used int, done bool, err error) {
msgs[idx] = append(msgs[idx], []byte(fmt.Sprintf("%d-%d,", idx, j))...)
@@ -91,7 +91,7 @@
work <- w
return t, false, nil
})
- w.Release(mtu)
+ w.Release(ctx, mtu)
<-work
return w
}
@@ -104,7 +104,7 @@
// Release to all the flows at once and ensure the writes
// happen in the correct order.
- fc.Release([]Release{{w0, 2 * mtu}, {w1a, 2 * mtu}, {w1b, 3 * mtu}, {w1c, 0}, {w2, mtu}})
+ fc.Release(ctx, []Release{{w0, 2 * mtu}, {w1a, 2 * mtu}, {w1b, 3 * mtu}, {w1c, 0}, {w2, mtu}})
expect(t, work, w0, w0, w1a, w1b, w1a, w1b, w1b, w2)
}
@@ -137,11 +137,11 @@
w1 := worker(1)
// Now Release to w0 which shouldn't allow it to run since it's just repaying, but
// should allow w1 to run on the returned shared counters.
- w0.Release(2 * mtu)
+ w0.Release(ctx, 2*mtu)
expect(t, work, w1, w1)
// Releasing again will allow w0 to run.
- w0.Release(mtu)
+ w0.Release(ctx, mtu)
expect(t, work, w0)
}
@@ -250,7 +250,7 @@
for i := 0; i < workers; i++ {
go func(idx int) {
w := fc.NewWorker(0)
- w.Release(len(testdata))
+ w.Release(ctx, len(testdata))
t := testdata
err := w.Run(ctx, func(tokens int) (used int, done bool, err error) {
towrite := min(tokens, len(t))
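
Release is conceptually unchanged; it now just carries the context. A
sketch, where fc is an existing *FlowController, w1 and w2 are its Workers,
and mtu is a token count (all assumed):

    if err := fc.Release(ctx, []Release{{w1, 2 * mtu}, {w2, mtu}}); err != nil {
        // returned iff some worker belongs to a different FlowController
    }
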
diff --git a/runtime/internal/flow/manager/framer.go b/runtime/internal/flow/manager/framer.go
index 6e2a991..4b08264 100644
--- a/runtime/internal/flow/manager/framer.go
+++ b/runtime/internal/flow/manager/framer.go
@@ -13,7 +13,7 @@
// framer is a wrapper of io.ReadWriter that adds framing to a net.Conn
// and implements flow.MsgReadWriter.
type framer struct {
- io.ReadWriter
+ io.ReadWriteCloser
buf []byte
}
diff --git a/runtime/internal/flow/manager/framer_test.go b/runtime/internal/flow/manager/framer_test.go
index 7710f7d..b863fe0 100644
--- a/runtime/internal/flow/manager/framer_test.go
+++ b/runtime/internal/flow/manager/framer_test.go
@@ -9,9 +9,16 @@
"testing"
)
+type readWriteCloser struct {
+ bytes.Buffer
+}
+
+func (*readWriteCloser) Close() error {
+ return nil
+}
+
func TestFramer(t *testing.T) {
- b := &bytes.Buffer{}
- f := &framer{ReadWriter: b}
+ f := &framer{ReadWriteCloser: &readWriteCloser{}}
bufs := [][]byte{[]byte("read "), []byte("this "), []byte("please.")}
want := []byte("read this please.")
l := len(want)
diff --git a/runtime/internal/flow/manager/manager.go b/runtime/internal/flow/manager/manager.go
index d5873d0..7e395ef 100644
--- a/runtime/internal/flow/manager/manager.go
+++ b/runtime/internal/flow/manager/manager.go
@@ -72,7 +72,7 @@
netConn, err := netLn.Accept()
for tokill := 1; isTemporaryError(err); tokill *= 2 {
if isTooManyOpenFiles(err) {
- if err := m.cache.KillConnections(tokill); err != nil {
+ if err := m.cache.KillConnections(ctx, tokill); err != nil {
ctx.VI(2).Infof("failed to kill connections: %v", err)
continue
}
@@ -88,7 +88,7 @@
}
_, err = conn.NewAccepted(
ctx,
- &framer{ReadWriter: netConn},
+ &framer{ReadWriteCloser: netConn},
local,
v23.GetPrincipal(ctx).BlessingStore().Default(),
version.Supported,
@@ -196,7 +196,7 @@
}
c, err = conn.NewDialed(
ctx,
- &framer{ReadWriter: netConn}, // TODO(suharshs): Don't frame if the net.Conn already has framing in its protocol.
+ &framer{ReadWriteCloser: netConn}, // TODO(suharshs): Don't frame if the net.Conn already has framing in its protocol.
localEndpoint(netConn, m.rid),
remote,
version.Supported,