lib: Change gosh.Cmd.AddStdoutWriter to take an io.WriteCloser
The rationale behind this change is detailed here:
https://github.com/vanadium/issues/issues/1031
The basic idea is that the previous gosh behavior wrt
Cmd.AddStdoutWriter (and AddStderrWriter) was a bit weird. We
took an io.Writer argument w, and if w happened to implement
io.Closer, we'd auto-close w when the process finished. The
semantics of Close is largely implementation-dependent, which
made the gosh usage a bit scary. In addition we special-cased
os.Stdout and os.Stderr, to prevent closing those when a single
cmd finished.
This change makes things more explicit. We always take an
io.WriteCloser as an argument, which we will auto-close when the
process finishes. We also remove the os.Stdout and os.Stderr
special-cases, and add gosh.NopWriteCloser instead.
MultiPart: 1/2
Change-Id: I77d04a1bc90f1b07fe4d0f8815a963f4fb73739e
diff --git a/gosh/.api b/gosh/.api
index f20a7a5..2b75989 100644
--- a/gosh/.api
+++ b/gosh/.api
@@ -3,13 +3,14 @@
pkg gosh, func MaybeWatchParent()
pkg gosh, func NewBufferedPipe() io.ReadWriteCloser
pkg gosh, func NewShell(Opts) *Shell
+pkg gosh, func NopWriteCloser(io.Writer) io.WriteCloser
pkg gosh, func Register(string, interface{}) *Fn
pkg gosh, func Run(func() int) int
pkg gosh, func SendReady()
pkg gosh, func SendVars(map[string]string)
pkg gosh, func WatchParent()
-pkg gosh, method (*Cmd) AddStderrWriter(io.Writer)
-pkg gosh, method (*Cmd) AddStdoutWriter(io.Writer)
+pkg gosh, method (*Cmd) AddStderrWriter(io.WriteCloser)
+pkg gosh, method (*Cmd) AddStdoutWriter(io.WriteCloser)
pkg gosh, method (*Cmd) AwaitReady()
pkg gosh, method (*Cmd) AwaitVars(...string) map[string]string
pkg gosh, method (*Cmd) Clone() *Cmd
diff --git a/gosh/cmd.go b/gosh/cmd.go
index 38ca2a0..f113530 100644
--- a/gosh/cmd.go
+++ b/gosh/cmd.go
@@ -21,6 +21,8 @@
var (
errAlreadyCalledStart = errors.New("gosh: already called Cmd.Start")
errAlreadyCalledWait = errors.New("gosh: already called Cmd.Wait")
+ errCloseStdout = errors.New("gosh: use NopWriteCloser(os.Stdout) to prevent stdout from being closed")
+ errCloseStderr = errors.New("gosh: use NopWriteCloser(os.Stderr) to prevent stderr from being closed")
errDidNotCallStart = errors.New("gosh: did not call Cmd.Start")
errProcessExited = errors.New("gosh: process exited")
)
@@ -104,19 +106,25 @@
}
// AddStdoutWriter configures this Cmd to tee the child's stdout to the given
-// Writer. If this Writer is a Closer and is not os.Stdout or os.Stderr, it will
-// be closed when the process exits.
-func (c *Cmd) AddStdoutWriter(w io.Writer) {
+// wc, which will be closed when the process exits.
+//
+// Use NopWriteCloser to extend an io.Writer to io.WriteCloser, or to prevent an
+// existing io.WriteCloser from being closed. It is an error to pass in
+// os.Stdout or os.Stderr, since they shouldn't be closed.
+func (c *Cmd) AddStdoutWriter(wc io.WriteCloser) {
c.sh.Ok()
- c.handleError(c.addStdoutWriter(w))
+ c.handleError(c.addStdoutWriter(wc))
}
// AddStderrWriter configures this Cmd to tee the child's stderr to the given
-// Writer. If this Writer is a Closer and is not os.Stdout or os.Stderr, it will
-// be closed when the process exits.
-func (c *Cmd) AddStderrWriter(w io.Writer) {
+// wc, which will be closed when the process exits.
+//
+// Use NopWriteCloser to extend an io.Writer to io.WriteCloser, or to prevent an
+// existing io.WriteCloser from being closed. It is an error to pass in
+// os.Stdout or os.Stderr, since they shouldn't be closed.
+func (c *Cmd) AddStderrWriter(wc io.WriteCloser) {
c.sh.Ok()
- c.handleError(c.addStderrWriter(w))
+ c.handleError(c.addStderrWriter(wc))
}
// Start starts the command.
@@ -250,15 +258,6 @@
func (c *Cmd) addWriter(writers *[]io.Writer, w io.Writer) {
*writers = append(*writers, w)
- // Check for os.Stdout and os.Stderr so that we don't close these when a
- // single command exits. This technique isn't foolproof (since, for example,
- // os.Stdout may be wrapped in another WriteCloser), but in practice it should
- // be adequate.
- if w != os.Stdout && w != os.Stderr {
- if wc, ok := w.(io.Closer); ok {
- c.closers = append(c.closers, wc)
- }
- }
}
func (c *Cmd) closeClosers() {
@@ -324,20 +323,20 @@
return len(p), nil
}
-func (c *Cmd) initMultiWriter(f *os.File, t string) (io.Writer, error) {
- var writers *[]io.Writer
- if f == os.Stdout {
- writers = &c.stdoutWriters
+func (c *Cmd) makeMultiWriter(stdout bool, t string) (io.Writer, error) {
+ std, writers := os.Stderr, &c.stderrWriters
+ if stdout {
+ std, writers = os.Stdout, &c.stdoutWriters
c.addWriter(writers, &recvWriter{c: c})
- } else {
- writers = &c.stderrWriters
}
if c.PropagateOutput {
- c.addWriter(writers, f)
+ c.addWriter(writers, std)
+ // Don't add std to c.closers, since we don't want to close os.Stdout or
+ // os.Stderr for the entire address space when c finishes.
}
if c.OutputDir != "" {
suffix := "stderr"
- if f == os.Stdout {
+ if stdout {
suffix = "stdout"
}
name := filepath.Join(c.OutputDir, filepath.Base(c.Path)+"."+t+"."+suffix)
@@ -346,6 +345,7 @@
return nil, err
}
c.addWriter(writers, file)
+ c.closers = append(c.closers, file)
}
return io.MultiWriter(*writers...), nil
}
@@ -386,6 +386,7 @@
}
p := NewBufferedPipe()
c.addWriter(&c.stdoutWriters, p)
+ c.closers = append(c.closers, p)
return p, nil
}
@@ -395,22 +396,35 @@
}
p := NewBufferedPipe()
c.addWriter(&c.stderrWriters, p)
+ c.closers = append(c.closers, p)
return p, nil
}
-func (c *Cmd) addStdoutWriter(w io.Writer) error {
- if c.calledStart {
+func (c *Cmd) addStdoutWriter(wc io.WriteCloser) error {
+ switch {
+ case c.calledStart:
return errAlreadyCalledStart
+ case wc == os.Stdout:
+ return errCloseStdout
+ case wc == os.Stderr:
+ return errCloseStderr
}
- c.addWriter(&c.stdoutWriters, w)
+ c.addWriter(&c.stdoutWriters, wc)
+ c.closers = append(c.closers, wc)
return nil
}
-func (c *Cmd) addStderrWriter(w io.Writer) error {
- if c.calledStart {
+func (c *Cmd) addStderrWriter(wc io.WriteCloser) error {
+ switch {
+ case c.calledStart:
return errAlreadyCalledStart
+ case wc == os.Stdout:
+ return errCloseStdout
+ case wc == os.Stderr:
+ return errCloseStderr
}
- c.addWriter(&c.stderrWriters, w)
+ c.addWriter(&c.stderrWriters, wc)
+ c.closers = append(c.closers, wc)
return nil
}
@@ -441,10 +455,10 @@
}
t := time.Now().Format("20060102.150405.000000")
var err error
- if c.c.Stdout, err = c.initMultiWriter(os.Stdout, t); err != nil {
+ if c.c.Stdout, err = c.makeMultiWriter(true, t); err != nil {
return err
}
- if c.c.Stderr, err = c.initMultiWriter(os.Stderr, t); err != nil {
+ if c.c.Stderr, err = c.makeMultiWriter(false, t); err != nil {
return err
}
// Start the command.
diff --git a/gosh/shell.go b/gosh/shell.go
index 79992cd..192b795 100644
--- a/gosh/shell.go
+++ b/gosh/shell.go
@@ -15,6 +15,7 @@
import (
"errors"
"fmt"
+ "io"
"io/ioutil"
"log"
"math/rand"
@@ -565,3 +566,15 @@
MaybeRunFnAndExit()
return run()
}
+
+// NopWriteCloser returns a WriteCloser with a no-op Close method wrapping the
+// provided Writer w.
+func NopWriteCloser(w io.Writer) io.WriteCloser {
+ return nopWriteCloser{w}
+}
+
+type nopWriteCloser struct {
+ io.Writer
+}
+
+func (nopWriteCloser) Close() error { return nil }
diff --git a/gosh/shell_test.go b/gosh/shell_test.go
index fc78858..c96b406 100644
--- a/gosh/shell_test.go
+++ b/gosh/shell_test.go
@@ -434,15 +434,15 @@
defer sh.Cleanup()
c := sh.Fn(writeFn, true, true)
- c.AddStdoutWriter(os.Stdout)
- c.AddStderrWriter(os.Stderr)
+ c.AddStdoutWriter(gosh.NopWriteCloser(os.Stdout))
+ c.AddStderrWriter(gosh.NopWriteCloser(os.Stderr))
c.Run()
fmt.Fprint(os.Stdout, " stdout done")
fmt.Fprint(os.Stderr, " stderr done")
})
-// Tests that it's safe to add os.Stdout and os.Stderr as writers.
+// Tests that it's safe to add wrapped os.Stdout and os.Stderr as writers.
func TestAddWriters(t *testing.T) {
sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
defer sh.Cleanup()
@@ -452,6 +452,23 @@
eq(t, stderr, "BB stderr done")
}
+// Tests that adding non-wrapped os.Stdout or os.Stderr fails.
+func TestAddWritersUnwrappedStdoutStderr(t *testing.T) {
+ sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
+ defer sh.Cleanup()
+
+ for _, addFn := range []func(*gosh.Cmd, io.WriteCloser){(*gosh.Cmd).AddStdoutWriter, (*gosh.Cmd).AddStderrWriter} {
+ for _, std := range []io.WriteCloser{os.Stdout, os.Stderr} {
+ c := sh.Fn(writeMoreFn)
+ sh.Opts.Fatalf = nil
+ addFn(c, std)
+ nok(t, sh.Err)
+ sh.Err = nil
+ sh.Opts.Fatalf = makeFatalf(t)
+ }
+ }
+}
+
// Tests piping from one Cmd's stdout/stderr to another's stdin. It should be
// possible to wait on just the last Cmd.
func TestPiping(t *testing.T) {
diff --git a/textutil/.api b/textutil/.api
index fbf83f1..b57fbcb 100644
--- a/textutil/.api
+++ b/textutil/.api
@@ -5,7 +5,7 @@
pkg textutil, func FlushRuneChunk(RuneChunkDecoder, func(rune) error) error
pkg textutil, func NewLineWriter(io.Writer, int, RuneChunkDecoder, RuneEncoder) *LineWriter
pkg textutil, func NewUTF8LineWriter(io.Writer, int) *LineWriter
-pkg textutil, func PrefixLineWriter(io.Writer, string) WriteFlushCloser
+pkg textutil, func PrefixLineWriter(io.Writer, string) WriteFlusher
pkg textutil, func PrefixWriter(io.Writer, string) io.Writer
pkg textutil, func TerminalSize() (int, int, error)
pkg textutil, func WriteRuneChunk(RuneChunkDecoder, func(rune) error, []byte) (int, error)
@@ -27,10 +27,6 @@
pkg textutil, type RuneEncoder interface, Encode(rune, *bytes.Buffer)
pkg textutil, type UTF8ChunkDecoder struct
pkg textutil, type UTF8Encoder struct
-pkg textutil, type WriteFlushCloser interface { Close, Flush, Write }
-pkg textutil, type WriteFlushCloser interface, Close() error
-pkg textutil, type WriteFlushCloser interface, Flush() error
-pkg textutil, type WriteFlushCloser interface, Write([]byte) (int, error)
pkg textutil, type WriteFlusher interface { Flush, Write }
pkg textutil, type WriteFlusher interface, Flush() error
pkg textutil, type WriteFlusher interface, Write([]byte) (int, error)
diff --git a/textutil/writer.go b/textutil/writer.go
index 8b7e930..72b5386 100644
--- a/textutil/writer.go
+++ b/textutil/writer.go
@@ -16,20 +16,11 @@
// immediately outputs the buffered data. Flush must be called after the last
// call to Write, and may be called an arbitrary number of times before the last
// Write.
-//
-// If the type is also a Closer, Close implies a Flush call.
type WriteFlusher interface {
io.Writer
Flush() error
}
-// WriteFlushCloser is the interface that groups the basic Write, Flush and
-// Close methods.
-type WriteFlushCloser interface {
- WriteFlusher
- io.Closer
-}
-
// PrefixWriter returns an io.Writer that wraps w, where the prefix is written
// out immediately before the first non-empty Write call.
func PrefixWriter(w io.Writer, prefix string) io.Writer {
@@ -49,12 +40,17 @@
return w.w.Write(data)
}
-// PrefixLineWriter returns a WriteFlushCloser that wraps w. Any occurrence of
-// EOL (\f, \n, \r, \v, LineSeparator or ParagraphSeparator) causes the
-// preceeding line to be written to w, with the given prefix. Data without EOL
-// is buffered until the next EOL, or Flush or Close call. A single Write call
-// may result in zero or more Write calls on the underlying writer.
-func PrefixLineWriter(w io.Writer, prefix string) WriteFlushCloser {
+// PrefixLineWriter returns a WriteFlusher that wraps w. Any occurrence of EOL
+// (\f, \n, \r, \v, LineSeparator or ParagraphSeparator) causes the preceeding
+// line to be written to w, with the given prefix. Data without EOL is buffered
+// until the next EOL or Flush call.
+//
+// A single Write call on the returned WriteFlusher may result in zero or more
+// Write calls on the underlying w.
+//
+// If w implements WriteFlusher, each Flush call on the returned WriteFlusher
+// results in exactly one Flush call on the underlying w.
+func PrefixLineWriter(w io.Writer, prefix string) WriteFlusher {
return &prefixLineWriter{w, []byte(prefix), nil}
}
@@ -94,7 +90,14 @@
return totalLen, nil
}
-func (w *prefixLineWriter) Flush() error {
+func (w *prefixLineWriter) Flush() (e error) {
+ defer func() {
+ if f, ok := w.w.(WriteFlusher); ok {
+ if err := f.Flush(); err != nil && e == nil {
+ e = err
+ }
+ }
+ }()
if len(w.buf) > 0 {
if _, err := w.w.Write(w.prefix); err != nil {
return err
@@ -104,22 +107,9 @@
}
w.buf = w.buf[:0]
}
- if f, ok := w.w.(WriteFlusher); ok {
- return f.Flush()
- }
return nil
}
-func (w *prefixLineWriter) Close() error {
- firstErr := w.Flush()
- if c, ok := w.w.(io.Closer); ok {
- if err := c.Close(); firstErr == nil {
- firstErr = err
- }
- }
- return firstErr
-}
-
// ByteReplaceWriter returns an io.Writer that wraps w, where all occurrences of
// the old byte are replaced with the new string on Write calls.
func ByteReplaceWriter(w io.Writer, old byte, new string) io.Writer {
diff --git a/textutil/writer_test.go b/textutil/writer_test.go
index cb8f705..8aba8ab 100644
--- a/textutil/writer_test.go
+++ b/textutil/writer_test.go
@@ -6,6 +6,7 @@
import (
"bytes"
+ "errors"
"fmt"
"strings"
"testing"
@@ -141,20 +142,115 @@
}
}
-type fakeWriteFlushCloser struct{ flushed, closed bool }
+var (
+ err1 = errors.New("error 1")
+ err2 = errors.New("error 2")
+)
-func (f *fakeWriteFlushCloser) Write(p []byte) (int, error) { return len(p), nil }
-func (f *fakeWriteFlushCloser) Flush() error { f.flushed = true; return nil }
-func (f *fakeWriteFlushCloser) Close() error { f.closed = true; return nil }
+type fakeWriteFlusher struct {
+ writeErr error
+ flushErr error
+ flushed bool
+}
-func TestPrefixLineWriterCloseFlush(t *testing.T) {
- var fake fakeWriteFlushCloser
- w := PrefixLineWriter(&fake, "")
- if w.Flush(); !fake.flushed {
+func (f *fakeWriteFlusher) Write(p []byte) (int, error) {
+ return len(p), f.writeErr
+}
+
+func (f *fakeWriteFlusher) Flush() error {
+ f.flushed = true
+ return f.flushErr
+}
+
+func TestPrefixLineWriter_Flush(t *testing.T) {
+ fake := &fakeWriteFlusher{}
+ w := PrefixLineWriter(fake, "prefix")
+ if err := w.Flush(); err != nil {
+ t.Errorf("Flush got error %v, want nil", err)
+ }
+ if !fake.flushed {
t.Errorf("Flush not propagated")
}
- if w.Close(); !fake.closed {
- t.Errorf("Close not propagated")
+}
+
+func TestPrefixLineWriter_FlushError(t *testing.T) {
+ fake := &fakeWriteFlusher{flushErr: err1}
+ w := PrefixLineWriter(fake, "prefix")
+ if err := w.Flush(); err != err1 {
+ t.Errorf("Flush got error %v, want %v", err, err1)
+ }
+ if !fake.flushed {
+ t.Errorf("Flush not propagated")
+ }
+}
+
+func TestPrefixLineWriter_WriteFlush(t *testing.T) {
+ fake := &fakeWriteFlusher{}
+ w := PrefixLineWriter(fake, "prefix")
+ if n, err := w.Write([]byte("abc")); n != 3 || err != nil {
+ t.Errorf("Write got (%v,%v), want (3,nil)", n, err)
+ }
+ if err := w.Flush(); err != nil {
+ t.Errorf("Flush got error %v, want nil", err)
+ }
+ if !fake.flushed {
+ t.Errorf("Flush not propagated")
+ }
+}
+
+func TestPrefixLineWriter_WriteFlushError(t *testing.T) {
+ fake := &fakeWriteFlusher{flushErr: err1}
+ w := PrefixLineWriter(fake, "prefix")
+ if n, err := w.Write([]byte("abc")); n != 3 || err != nil {
+ t.Errorf("Write got (%v,%v), want (3,nil)", n, err)
+ }
+ if err := w.Flush(); err != err1 {
+ t.Errorf("Flush got error %v, want %v", err, err1)
+ }
+ if !fake.flushed {
+ t.Errorf("Flush not propagated")
+ }
+}
+
+func TestPrefixLineWriter_WriteErrorFlush(t *testing.T) {
+ fake := &fakeWriteFlusher{writeErr: err1}
+ w := PrefixLineWriter(fake, "prefix")
+ if n, err := w.Write([]byte("abc")); n != 3 || err != nil {
+ t.Errorf("Write got (%v,%v), want (3,nil)", n, err)
+ }
+ if err := w.Flush(); err != err1 {
+ t.Errorf("Flush got error %v, want %v", err, err1)
+ }
+ if !fake.flushed {
+ t.Errorf("Flush not propagated")
+ }
+}
+
+func TestPrefixLineWriter_WriteErrorFlushError(t *testing.T) {
+ fake := &fakeWriteFlusher{writeErr: err1, flushErr: err2}
+ w := PrefixLineWriter(fake, "prefix")
+ if n, err := w.Write([]byte("abc")); n != 3 || err != nil {
+ t.Errorf("Write got (%v,%v), want (3,nil)", n, err)
+ }
+ if err := w.Flush(); err != err1 {
+ t.Errorf("Flush got error %v, want %v", err, err1)
+ }
+ if !fake.flushed {
+ t.Errorf("Flush not propagated")
+ }
+}
+
+func TestPrefixLineWriter_EOLWriteErrorFlushError(t *testing.T) {
+ fake := &fakeWriteFlusher{writeErr: err1, flushErr: err2}
+ w := PrefixLineWriter(fake, "prefix")
+ if n, err := w.Write([]byte("ab\n")); n != 0 || err != err1 {
+ t.Errorf("Write got (%v,%v), want (0,%v)", n, err, err1)
+ }
+ if err := w.Flush(); err != err2 {
+ t.Errorf("Flush got error %v, want %v", err, err2)
+ }
+ if !fake.flushed {
+ t.Errorf("Flush not propagated")
}
}