TBR: gosh: revert Shyam's roll-back of Todd's WriteTo/ReadFrom change:
Revert "lib: Revert https://vanadium-review.googlesource.com/#/c/19347/"
A dependent cl will fix the issue with the implementation of WriteTo
that was causing lib/signals_test to hang.
Change-Id: I94a96e266cd46125b031e095a00fd8a0b1e82286
diff --git a/gosh/buffered_pipe.go b/gosh/buffered_pipe.go
index d9e05fd..6bc9d30 100644
--- a/gosh/buffered_pipe.go
+++ b/gosh/buffered_pipe.go
@@ -16,6 +16,12 @@
closed bool
}
+var (
+ // Make sure the signatures are right, so that io.Copy can be faster.
+ _ io.WriterTo = (*bufferedPipe)(nil)
+ _ io.ReaderFrom = (*bufferedPipe)(nil)
+)
+
// newBufferedPipe returns a new thread-safe pipe backed by an unbounded
// in-memory buffer. Writes on the pipe never block; reads on the pipe block
// until data is available.
@@ -24,7 +30,7 @@
}
// Read reads from the pipe.
-func (p *bufferedPipe) Read(d []byte) (n int, err error) {
+func (p *bufferedPipe) Read(d []byte) (int, error) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
for {
@@ -39,8 +45,25 @@
}
}
+// WriteTo implements the io.WriterTo method; it is the fast version of Read
+// used by io.Copy.
+func (p *bufferedPipe) WriteTo(w io.Writer) (int64, error) {
+ p.cond.L.Lock()
+ defer p.cond.L.Unlock()
+ for {
+ // Read any remaining data before checking whether the pipe is closed.
+ if p.buf.Len() > 0 {
+ return p.buf.WriteTo(w)
+ }
+ if p.closed {
+ return 0, io.EOF
+ }
+ p.cond.Wait()
+ }
+}
+
// Write writes to the pipe.
-func (p *bufferedPipe) Write(d []byte) (n int, err error) {
+func (p *bufferedPipe) Write(d []byte) (int, error) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
if p.closed {
@@ -50,6 +73,18 @@
return p.buf.Write(d)
}
+// ReadFrom implements the io.ReaderFrom method; it is the fast version of Write
+// used by io.Copy.
+func (p *bufferedPipe) ReadFrom(r io.Reader) (int64, error) {
+ p.cond.L.Lock()
+ defer p.cond.L.Unlock()
+ if p.closed {
+ return 0, io.ErrClosedPipe
+ }
+ defer p.cond.Signal()
+ return p.buf.ReadFrom(r)
+}
+
// Close closes the pipe.
func (p *bufferedPipe) Close() error {
p.cond.L.Lock()
diff --git a/gosh/buffered_pipe_test.go b/gosh/buffered_pipe_test.go
index 1dcca26..a4fef76 100644
--- a/gosh/buffered_pipe_test.go
+++ b/gosh/buffered_pipe_test.go
@@ -5,23 +5,57 @@
package gosh
import (
+ "bytes"
+ "io"
"io/ioutil"
+ "strings"
"testing"
)
-func TestReadAfterClose(t *testing.T) {
+func TestReadWriteAfterClose(t *testing.T) {
p := newBufferedPipe()
- if _, err := p.Write([]byte("foo")); err != nil {
- t.Errorf("write failed: %v", err)
+ if n, err := p.Write([]byte("foo")); n != 3 || err != nil {
+ t.Errorf("write got (%v,%v) want (3,nil)", n, err)
+ }
+ if n, err := p.Write([]byte("barbaz")); n != 6 || err != nil {
+ t.Errorf("write got (%v,%v) want (6,nil)", n, err)
}
if err := p.Close(); err != nil {
t.Errorf("close failed: %v", err)
}
- b, err := ioutil.ReadAll(p)
- if err != nil {
- t.Errorf("read failed: %v", err)
+ // Read after close returns all data terminated by EOF.
+ if b, err := ioutil.ReadAll(p); string(b) != "foobarbaz" || err != nil {
+ t.Errorf("read got (%s,%v) want (foobarbaz,nil)", b, err)
}
- if got, want := string(b), "foo"; got != want {
- t.Errorf("got %s, want %s", got, want)
+ // Write after close fails.
+ n, err := p.Write([]byte("already closed"))
+ if got, want := n, 0; got != want {
+ t.Errorf("write after close got n %v, want %v", got, want)
+ }
+ if got, want := err, io.ErrClosedPipe; got != want {
+ t.Errorf("write after close got error %v, want %v", got, want)
+ }
+}
+
+func TestReadFromWriteTo(t *testing.T) {
+ p, buf := newBufferedPipe(), new(bytes.Buffer)
+ if n, err := p.(io.ReaderFrom).ReadFrom(strings.NewReader("foobarbaz")); n != 9 || err != nil {
+ t.Errorf("readfrom got (%v,%v) want (9,nil)", n, err)
+ }
+ if n, err := p.(io.WriterTo).WriteTo(buf); n != 9 || err != nil {
+ t.Errorf("writeto got (%v,%v) want (9,nil)", n, err)
+ }
+ if got, want := buf.String(), "foobarbaz"; got != want {
+ t.Errorf("writeto got %v want %v", got, want)
+ }
+ buf.Reset()
+ if n, err := p.(io.ReaderFrom).ReadFrom(strings.NewReader("foobarbaz")); n != 9 || err != nil {
+ t.Errorf("readfrom got (%v,%v) want (9,nil)", n, err)
+ }
+ if err := p.Close(); err != nil {
+ t.Errorf("close failed: %v", err)
+ }
+ if n, err := p.(io.WriterTo).WriteTo(buf); n != 9 || err != nil {
+ t.Errorf("writeto got (%v,%v) want (9,nil)", n, err)
}
}