veyron/runtimes/google/ipc: Fix race in choosing an error to return when the flow closes.
Change-Id: I068e59ff2b4d873db6705cca75e8f6f2c3bca463
diff --git a/runtimes/google/ipc/full_test.go b/runtimes/google/ipc/full_test.go
index 76d940e..1174d02 100644
--- a/runtimes/google/ipc/full_test.go
+++ b/runtimes/google/ipc/full_test.go
@@ -810,29 +810,26 @@
type cancelTestServer struct {
started chan struct{}
cancelled chan struct{}
+ t *testing.T
}
-func newCancelTestServer() *cancelTestServer {
+func newCancelTestServer(t *testing.T) *cancelTestServer {
return &cancelTestServer{
started: make(chan struct{}),
cancelled: make(chan struct{}),
+ t: t,
}
}
func (s *cancelTestServer) CancelStreamReader(call ipc.ServerCall) error {
close(s.started)
- for {
- var b []byte
- if err := call.Recv(&b); err != nil && err != io.EOF {
- return err
- }
- select {
- case <-call.Done():
- close(s.cancelled)
- return nil
- default:
- }
+ var b []byte
+ if err := call.Recv(&b); err != io.EOF {
+ s.t.Errorf("Got error %v, want io.EOF", err)
}
+ <-call.Done()
+ close(s.cancelled)
+ return nil
}
// CancelStreamIgnorer doesn't read from it's input stream so all it's
@@ -840,15 +837,9 @@
// even when the stream is stalled.
func (s *cancelTestServer) CancelStreamIgnorer(call ipc.ServerCall) error {
close(s.started)
- for {
- time.Sleep(time.Millisecond)
- select {
- case <-call.Done():
- close(s.cancelled)
- return nil
- default:
- }
- }
+ <-call.Done()
+ close(s.cancelled)
+ return nil
}
func waitForCancel(t *testing.T, ts *cancelTestServer, call ipc.Call) {
@@ -859,7 +850,7 @@
// TestCancel tests cancellation while the server is reading from a stream.
func TestCancel(t *testing.T) {
- ts := newCancelTestServer()
+ ts := newCancelTestServer(t)
b := createBundle(t, clientID, serverID, ts)
defer b.cleanup(t)
@@ -867,19 +858,13 @@
if err != nil {
t.Fatalf("Start call failed: %v", err)
}
- for i := 0; i <= 10; i++ {
- b := []byte{1, 2, 3}
- if err := call.Send(b); err != nil {
- t.Errorf("clientCall.Send error %q", err)
- }
- }
waitForCancel(t, ts, call)
}
// TestCancelWithFullBuffers tests that even if the writer has filled the buffers and
// the server is not reading that the cancel message gets through.
func TestCancelWithFullBuffers(t *testing.T) {
- ts := newCancelTestServer()
+ ts := newCancelTestServer(t)
b := createBundle(t, clientID, serverID, ts)
defer b.cleanup(t)
diff --git a/runtimes/google/ipc/server.go b/runtimes/google/ipc/server.go
index 57d90dc..3188aeb 100644
--- a/runtimes/google/ipc/server.go
+++ b/runtimes/google/ipc/server.go
@@ -630,11 +630,6 @@
func (fs *flowServer) serve() error {
defer fs.flow.Close()
- // Here we remove the contexts channel as a deadline to the flow.
- // We do this to ensure clients get a consistent error when they read/write
- // after the flow is closed. Otherwise there is a race between the
- // context cancellation and the flow being closed.
- defer fs.flow.SetDeadline(nil)
results, err := fs.processRequest()
@@ -719,8 +714,16 @@
// Ensure that the context gets cancelled if the flow is closed
// due to a network error, or client cancellation.
go func() {
- <-fs.flow.Closed()
- cancel()
+ select {
+ case <-fs.flow.Closed():
+ // Here we remove the contexts channel as a deadline to the flow.
+ // We do this to ensure clients get a consistent error when they read/write
+ // after the flow is closed. Since the flow is already closed, it doesn't
+ // matter that the context is also cancelled.
+ fs.flow.SetDeadline(nil)
+ cancel()
+ case <-fs.Done():
+ }
}()
// If additional credentials are provided, make them available in the context