runtime/internal/flow/conn: Implement flow cancelation.
Change-Id: Icbcaf895dadc913785258221d393b9a78440276b
diff --git a/runtime/internal/flow/conn/close_test.go b/runtime/internal/flow/conn/close_test.go
index 7bc4af0..6dc0987 100644
--- a/runtime/internal/flow/conn/close_test.go
+++ b/runtime/internal/flow/conn/close_test.go
@@ -10,13 +10,14 @@
"testing"
"v.io/v23"
+ "v.io/v23/context"
_ "v.io/x/ref/runtime/factories/fake"
)
func TestRemoteDialerClose(t *testing.T) {
ctx, shutdown := v23.Init()
defer shutdown()
- d, a, w := setupConns(t, ctx, nil, nil)
+ d, a, w := setupConns(t, ctx, ctx, nil, nil)
d.Close(ctx, fmt.Errorf("Closing randomly."))
<-d.Closed()
<-a.Closed()
@@ -28,7 +29,7 @@
func TestRemoteAcceptorClose(t *testing.T) {
ctx, shutdown := v23.Init()
defer shutdown()
- d, a, w := setupConns(t, ctx, nil, nil)
+ d, a, w := setupConns(t, ctx, ctx, nil, nil)
a.Close(ctx, fmt.Errorf("Closing randomly."))
<-a.Closed()
<-d.Closed()
@@ -40,7 +41,7 @@
func TestUnderlyingConnectionClosed(t *testing.T) {
ctx, shutdown := v23.Init()
defer shutdown()
- d, a, w := setupConns(t, ctx, nil, nil)
+ d, a, w := setupConns(t, ctx, ctx, nil, nil)
w.close()
<-a.Closed()
<-d.Closed()
@@ -49,7 +50,7 @@
func TestDialAfterConnClose(t *testing.T) {
ctx, shutdown := v23.Init()
defer shutdown()
- d, a, _ := setupConns(t, ctx, nil, nil)
+ d, a, _ := setupConns(t, ctx, ctx, nil, nil)
d.Close(ctx, fmt.Errorf("Closing randomly."))
<-d.Closed()
@@ -66,7 +67,7 @@
ctx, shutdown := v23.Init()
defer shutdown()
for _, dialerDials := range []bool{true, false} {
- df, flows := setupFlow(t, ctx, dialerDials)
+ df, flows := setupFlow(t, ctx, ctx, dialerDials)
if _, err := df.WriteMsg([]byte("hello")); err != nil {
t.Fatalf("write failed: %v", err)
}
@@ -79,7 +80,7 @@
if _, err := df.WriteMsg([]byte("there")); err != nil {
t.Fatalf("second write failed: %v", err)
}
- df.(*flw).conn.Close(ctx, nil)
+ df.(*flw).conn.Close(ctx, fmt.Errorf("Closing randomly."))
<-af.Conn().Closed()
if got, err := af.ReadMsg(); err != nil {
t.Fatalf("read failed: %v", err)
@@ -94,3 +95,49 @@
}
}
}
+
+func TestFlowCancelOnWrite(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ dctx, cancel := context.WithCancel(ctx)
+ df, accept := setupFlow(t, dctx, ctx, true)
+ done := make(chan struct{})
+ go func() {
+ if _, err := df.WriteMsg([]byte("hello")); err != nil {
+ t.Fatalf("could not write flow: %v", err)
+ }
+ for {
+ if _, err := df.WriteMsg([]byte("hello")); err == context.Canceled {
+ break
+ } else if err != nil {
+ t.Fatalf("unexpected error waiting for cancel: %v", err)
+ }
+ }
+ close(done)
+ }()
+ af := <-accept
+ cancel()
+ <-done
+ <-af.Closed()
+}
+
+func TestFlowCancelOnRead(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ dctx, cancel := context.WithCancel(ctx)
+ df, accept := setupFlow(t, dctx, ctx, true)
+ done := make(chan struct{})
+ go func() {
+ if _, err := df.WriteMsg([]byte("hello")); err != nil {
+ t.Fatalf("could not write flow: %v", err)
+ }
+ if _, err := df.ReadMsg(); err != context.Canceled {
+ t.Fatalf("unexpected error waiting for cancel: %v", err)
+ }
+ close(done)
+ }()
+ af := <-accept
+ cancel()
+ <-done
+ <-af.Closed()
+}
diff --git a/runtime/internal/flow/conn/conn.go b/runtime/internal/flow/conn/conn.go
index ebe4535..8cd4b6a 100644
--- a/runtime/internal/flow/conn/conn.go
+++ b/runtime/internal/flow/conn/conn.go
@@ -14,6 +14,7 @@
"v.io/v23/naming"
"v.io/v23/rpc/version"
"v.io/v23/security"
+ "v.io/v23/verror"
"v.io/x/ref/runtime/internal/flow/flowcontrol"
)
@@ -155,18 +156,29 @@
// We've already torn this conn down.
return
}
+ ferr := err
+ if verror.ErrorID(err) == ErrConnClosedRemotely.ID {
+ ferr = NewErrFlowClosedRemotely(ctx)
+ } else {
+ message := ""
+ if err != nil {
+ message = err.Error()
+ }
+ cerr := c.fc.Run(ctx, expressPriority, func(_ int) (int, bool, error) {
+ return 0, true, c.mp.writeMsg(ctx, &tearDown{Message: message})
+ })
+ if cerr != nil {
+ ctx.Errorf("Error sending tearDown on connection to %s: %v", c.remote, cerr)
+ }
+ }
for _, f := range flows {
- f.close(err)
+ f.close(ctx, ferr)
}
- err = c.fc.Run(ctx, expressPriority, func(_ int) (int, bool, error) {
- return 0, true, c.mp.writeMsg(ctx, &tearDown{Err: err})
- })
- if err != nil {
- ctx.Errorf("Error sending tearDown on connection to %s: %v", c.remote, err)
+ if cerr := c.mp.close(); cerr != nil {
+ ctx.Errorf("Error closing underlying connection for %s: %v", c.remote, cerr)
}
- if err = c.mp.close(); err != nil {
- ctx.Errorf("Error closing underlying connection for %s: %v", c.remote, err)
- }
+
+ // TODO(mattr): ensure the readLoop is finished before closing this.
close(c.closed)
}
@@ -207,7 +219,7 @@
switch msg := x.(type) {
case *tearDown:
- terr = msg.Err
+ terr = NewErrConnClosedRemotely(ctx, msg.Message)
return
case *openFlow:
@@ -244,7 +256,7 @@
return
}
if msg.flags&closeFlag != 0 {
- f.close(nil)
+ f.close(ctx, NewErrFlowClosedRemotely(f.ctx))
}
case *unencryptedData:
@@ -259,7 +271,7 @@
return
}
if msg.flags&closeFlag != 0 {
- f.close(nil)
+ f.close(ctx, NewErrFlowClosedRemotely(f.ctx))
}
default:
diff --git a/runtime/internal/flow/conn/conn_test.go b/runtime/internal/flow/conn/conn_test.go
index 3ff2767..7fefba6 100644
--- a/runtime/internal/flow/conn/conn_test.go
+++ b/runtime/internal/flow/conn/conn_test.go
@@ -72,7 +72,7 @@
ctx, shutdown := v23.Init()
defer shutdown()
for _, dialerDials := range []bool{true, false} {
- df, flows := setupFlow(t, ctx, dialerDials)
+ df, flows := setupFlow(t, ctx, ctx, dialerDials)
testWrite(t, ctx, []byte("hello world"), df, flows)
}
}
@@ -80,6 +80,6 @@
func TestLargeWrite(t *testing.T) {
ctx, shutdown := v23.Init()
defer shutdown()
- df, flows := setupFlow(t, ctx, true)
+ df, flows := setupFlow(t, ctx, ctx, true)
testWrite(t, ctx, randData, df, flows)
}
diff --git a/runtime/internal/flow/conn/conncache.go b/runtime/internal/flow/conn/conncache.go
index b435d9c..ae77f97 100644
--- a/runtime/internal/flow/conn/conncache.go
+++ b/runtime/internal/flow/conn/conncache.go
@@ -119,7 +119,7 @@
c.addrCache, c.ridCache, c.started = nil, nil, nil
d := c.head.next
for d != c.head {
- d.conn.Close(ctx, nil)
+ d.conn.Close(ctx, NewErrCacheClosed(ctx))
d = d.next
}
c.head = nil
@@ -134,14 +134,15 @@
defer c.mu.Unlock()
c.mu.Lock()
if c.addrCache == nil {
- return NewErrCacheClosed(nil)
+ return NewErrCacheClosed(ctx)
}
d := c.head.prev
+ err := NewErrConnKilledToFreeResources(ctx)
for i := 0; i < num; i++ {
if d == c.head {
break
}
- d.conn.Close(ctx, nil)
+ d.conn.Close(ctx, err)
delete(c.addrCache, d.addrKey)
delete(c.ridCache, d.rid)
prev := d.prev
diff --git a/runtime/internal/flow/conn/errors.vdl b/runtime/internal/flow/conn/errors.vdl
index 98938d4..0e0e821 100644
--- a/runtime/internal/flow/conn/errors.vdl
+++ b/runtime/internal/flow/conn/errors.vdl
@@ -19,6 +19,9 @@
UnexpectedMsg(typ string) {"en": "unexpected message type{:typ}."}
ConnectionClosed() {"en": "connection closed."}
+ ConnKilledToFreeResources() {"en": "Connection killed to free resources."}
+ ConnClosedRemotely(msg string) {"en": "connection closed remotely{:msg}."}
+ FlowClosedRemotely() {"en": "flow closed remotely."}
Send(typ, dest string, err error) {"en": "failure sending {typ} message to {dest}{:err}."}
Recv(src string, err error) {"en": "error reading from {src}{:err}"}
CacheClosed() {"en":"cache is closed"}
diff --git a/runtime/internal/flow/conn/errors.vdl.go b/runtime/internal/flow/conn/errors.vdl.go
index e26dc9c..2db4372 100644
--- a/runtime/internal/flow/conn/errors.vdl.go
+++ b/runtime/internal/flow/conn/errors.vdl.go
@@ -15,16 +15,19 @@
)
var (
- ErrInvalidMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidMsg", verror.NoRetry, "{1:}{2:} message of type{:3} and size{:4} failed decoding at field{:5}.")
- ErrInvalidControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidControlMsg", verror.NoRetry, "{1:}{2:} control message of cmd{:3} and size{:4} failed decoding at field{:5}.")
- ErrUnknownMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownMsg", verror.NoRetry, "{1:}{2:} unknown message type{:3}.")
- ErrUnknownControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownControlMsg", verror.NoRetry, "{1:}{2:} unknown control command{:3}.")
- ErrUnexpectedMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnexpectedMsg", verror.NoRetry, "{1:}{2:} unexpected message type{:3}.")
- ErrConnectionClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.ConnectionClosed", verror.NoRetry, "{1:}{2:} connection closed.")
- ErrSend = verror.Register("v.io/x/ref/runtime/internal/flow/conn.Send", verror.NoRetry, "{1:}{2:} failure sending {3} message to {4}{:5}.")
- ErrRecv = verror.Register("v.io/x/ref/runtime/internal/flow/conn.Recv", verror.NoRetry, "{1:}{2:} error reading from {3}{:4}")
- ErrCacheClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CacheClosed", verror.NoRetry, "{1:}{2:} cache is closed")
- ErrCounterOverflow = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CounterOverflow", verror.NoRetry, "{1:}{2:} A remote process has sent more data than allowed.")
+ ErrInvalidMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidMsg", verror.NoRetry, "{1:}{2:} message of type{:3} and size{:4} failed decoding at field{:5}.")
+ ErrInvalidControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidControlMsg", verror.NoRetry, "{1:}{2:} control message of cmd{:3} and size{:4} failed decoding at field{:5}.")
+ ErrUnknownMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownMsg", verror.NoRetry, "{1:}{2:} unknown message type{:3}.")
+ ErrUnknownControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownControlMsg", verror.NoRetry, "{1:}{2:} unknown control command{:3}.")
+ ErrUnexpectedMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnexpectedMsg", verror.NoRetry, "{1:}{2:} unexpected message type{:3}.")
+ ErrConnectionClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.ConnectionClosed", verror.NoRetry, "{1:}{2:} connection closed.")
+ ErrConnKilledToFreeResources = verror.Register("v.io/x/ref/runtime/internal/flow/conn.ConnKilledToFreeResources", verror.NoRetry, "{1:}{2:} Connection killed to free resources.")
+ ErrConnClosedRemotely = verror.Register("v.io/x/ref/runtime/internal/flow/conn.ConnClosedRemotely", verror.NoRetry, "{1:}{2:} connection closed remotely{:3}.")
+ ErrFlowClosedRemotely = verror.Register("v.io/x/ref/runtime/internal/flow/conn.FlowClosedRemotely", verror.NoRetry, "{1:}{2:} flow closed remotely.")
+ ErrSend = verror.Register("v.io/x/ref/runtime/internal/flow/conn.Send", verror.NoRetry, "{1:}{2:} failure sending {3} message to {4}{:5}.")
+ ErrRecv = verror.Register("v.io/x/ref/runtime/internal/flow/conn.Recv", verror.NoRetry, "{1:}{2:} error reading from {3}{:4}")
+ ErrCacheClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CacheClosed", verror.NoRetry, "{1:}{2:} cache is closed")
+ ErrCounterOverflow = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CounterOverflow", verror.NoRetry, "{1:}{2:} A remote process has sent more data than allowed.")
)
func init() {
@@ -34,6 +37,9 @@
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnknownControlMsg.ID), "{1:}{2:} unknown control command{:3}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnexpectedMsg.ID), "{1:}{2:} unexpected message type{:3}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrConnectionClosed.ID), "{1:}{2:} connection closed.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrConnKilledToFreeResources.ID), "{1:}{2:} Connection killed to free resources.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrConnClosedRemotely.ID), "{1:}{2:} connection closed remotely{:3}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrFlowClosedRemotely.ID), "{1:}{2:} flow closed remotely.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrSend.ID), "{1:}{2:} failure sending {3} message to {4}{:5}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrRecv.ID), "{1:}{2:} error reading from {3}{:4}")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrCacheClosed.ID), "{1:}{2:} cache is closed")
@@ -70,6 +76,21 @@
return verror.New(ErrConnectionClosed, ctx)
}
+// NewErrConnKilledToFreeResources returns an error with the ErrConnKilledToFreeResources ID.
+func NewErrConnKilledToFreeResources(ctx *context.T) error {
+ return verror.New(ErrConnKilledToFreeResources, ctx)
+}
+
+// NewErrConnClosedRemotely returns an error with the ErrConnClosedRemotely ID.
+func NewErrConnClosedRemotely(ctx *context.T, msg string) error {
+ return verror.New(ErrConnClosedRemotely, ctx, msg)
+}
+
+// NewErrFlowClosedRemotely returns an error with the ErrFlowClosedRemotely ID.
+func NewErrFlowClosedRemotely(ctx *context.T) error {
+ return verror.New(ErrFlowClosedRemotely, ctx)
+}
+
// NewErrSend returns an error with the ErrSend ID.
func NewErrSend(ctx *context.T, typ string, dest string, err error) error {
return verror.New(ErrSend, ctx, typ, dest, err)
diff --git a/runtime/internal/flow/conn/flow.go b/runtime/internal/flow/conn/flow.go
index 804f472..f75e10f 100644
--- a/runtime/internal/flow/conn/flow.go
+++ b/runtime/internal/flow/conn/flow.go
@@ -8,7 +8,7 @@
"v.io/v23/context"
"v.io/v23/flow"
"v.io/v23/security"
-
+ "v.io/v23/verror"
"v.io/x/ref/runtime/internal/flow/flowcontrol"
)
@@ -50,6 +50,9 @@
if n, release, err = f.q.read(f.ctx, p); release {
f.conn.release(f.ctx)
}
+ if err != nil {
+ f.close(f.ctx, err)
+ }
return
}
@@ -65,6 +68,9 @@
if buf, release, err = f.q.get(f.ctx); release {
f.conn.release(f.ctx)
}
+ if err != nil {
+ f.close(f.ctx, err)
+ }
return
}
@@ -123,7 +129,7 @@
return size, done, f.conn.mp.writeMsg(f.ctx, d)
})
if alsoClose || err != nil {
- f.close(err)
+ f.close(f.ctx, err)
}
return sent, err
}
@@ -220,10 +226,18 @@
return f.ctx.Done()
}
-func (f *flw) close(err error) {
- f.q.close(f.ctx)
+func (f *flw) close(ctx *context.T, err error) {
+ f.q.close(ctx)
f.cancel()
-
- // TODO(mattr): maybe send a final close data message.
- // TODO(mattr): save the error to hand out later.
+ if verror.ErrorID(err) != ErrFlowClosedRemotely.ID {
+ // We want to try to send this message even if ctx is already canceled.
+ ctx, cancel := context.WithRootCancel(ctx)
+ err := f.worker.Run(ctx, func(tokens int) (int, bool, error) {
+ return 0, true, f.conn.mp.writeMsg(ctx, &data{id: f.id, flags: closeFlag})
+ })
+ if err != nil {
+ ctx.Errorf("Could not send close flow message: %v", err)
+ }
+ cancel()
+ }
}
diff --git a/runtime/internal/flow/conn/message.go b/runtime/internal/flow/conn/message.go
index e1bc1f3..6cea426 100644
--- a/runtime/internal/flow/conn/message.go
+++ b/runtime/internal/flow/conn/message.go
@@ -5,8 +5,6 @@
package conn
import (
- "errors"
-
"v.io/v23/context"
"v.io/v23/naming"
"v.io/v23/rpc/version"
@@ -88,19 +86,15 @@
// tearDown is sent over the wire before a connection is closed.
type tearDown struct {
- Err error
+ Message string
}
func (m *tearDown) write(ctx *context.T, p *messagePipe) error {
- var errBytes []byte
- if m.Err != nil {
- errBytes = []byte(m.Err.Error())
- }
- return p.write([][]byte{{controlType}}, [][]byte{{tearDownCmd}, errBytes})
+ return p.write([][]byte{{controlType}}, [][]byte{{tearDownCmd}, []byte(m.Message)})
}
func (m *tearDown) read(ctx *context.T, data []byte) error {
if len(data) > 0 {
- m.Err = errors.New(string(data))
+ m.Message = string(data)
}
return nil
}
diff --git a/runtime/internal/flow/conn/message_test.go b/runtime/internal/flow/conn/message_test.go
index ab715ae..2414c1c 100644
--- a/runtime/internal/flow/conn/message_test.go
+++ b/runtime/internal/flow/conn/message_test.go
@@ -5,7 +5,6 @@
package conn
import (
- "errors"
"reflect"
"testing"
@@ -82,7 +81,7 @@
func TestTearDown(t *testing.T) {
testMessages(t, []message{
- &tearDown{Err: errors.New("foobar")},
+ &tearDown{Message: "foobar"},
&tearDown{},
})
}
diff --git a/runtime/internal/flow/conn/util_test.go b/runtime/internal/flow/conn/util_test.go
index 18cc267..5694830 100644
--- a/runtime/internal/flow/conn/util_test.go
+++ b/runtime/internal/flow/conn/util_test.go
@@ -96,34 +96,42 @@
return nil
}
-func setupConns(t *testing.T, ctx *context.T, dflows, aflows chan<- flow.Flow) (dialed, accepted *Conn, _ *wire) {
- dmrw, amrw, w := newMRWPair(ctx)
+func setupConns(t *testing.T, dctx, actx *context.T, dflows, aflows chan<- flow.Flow) (dialed, accepted *Conn, _ *wire) {
+ dmrw, amrw, w := newMRWPair(dctx)
versions := version.RPCVersionRange{Min: 3, Max: 5}
ep, err := v23.NewEndpoint("localhost:80")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
- d, err := NewDialed(ctx, dmrw, ep, ep, versions, fh(dflows), nil)
+ d, err := NewDialed(dctx, dmrw, ep, ep, versions, fh(dflows), nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
- a, err := NewAccepted(ctx, amrw, ep, security.Blessings{}, versions, fh(aflows))
+ a, err := NewAccepted(actx, amrw, ep, security.Blessings{}, versions, fh(aflows))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
return d, a, w
}
-func setupFlow(t *testing.T, ctx *context.T, dialFromDialer bool) (dialed flow.Flow, accepted <-chan flow.Flow) {
- dflows, aflows := make(chan flow.Flow, 1), make(chan flow.Flow, 1)
- d, a, _ := setupConns(t, ctx, dflows, aflows)
+func setupFlow(t *testing.T, dctx, actx *context.T, dialFromDialer bool) (dialed flow.Flow, accepted <-chan flow.Flow) {
+ d, accepted := setupFlows(t, dctx, actx, dialFromDialer, 1)
+ return d[0], accepted
+}
+
+func setupFlows(t *testing.T, dctx, actx *context.T, dialFromDialer bool, n int) (dialed []flow.Flow, accepted <-chan flow.Flow) {
+ dflows, aflows := make(chan flow.Flow, n), make(chan flow.Flow, n)
+ d, a, _ := setupConns(t, dctx, actx, dflows, aflows)
if !dialFromDialer {
d, a = a, d
aflows, dflows = dflows, aflows
}
- df, err := d.Dial(ctx)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
+ dialed = make([]flow.Flow, n)
+ for i := 0; i < n; i++ {
+ var err error
+ if dialed[i], err = d.Dial(dctx); err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
}
- return df, aflows
+ return dialed, aflows
}
diff --git a/runtime/internal/flow/flowcontrol/flowcontrol.go b/runtime/internal/flow/flowcontrol/flowcontrol.go
index 2b6e64b..8221895 100644
--- a/runtime/internal/flow/flowcontrol/flowcontrol.go
+++ b/runtime/internal/flow/flowcontrol/flowcontrol.go
@@ -70,6 +70,15 @@
for {
next := w.fc.nextWorkerLocked()
+ if w.fc.writing == w {
+ // We're already scheduled to write, but we should bail
+ // out if we're canceled.
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ default:
+ }
+ }
for w.fc.writing != w && err == nil {
w.fc.mu.Unlock()
if next != nil {