gosh: Add Cmd.CombinedOutput, and simplify stdout/stderr ordering

The Cmd.CombinedOutput method works like the exec.CombinedOutput
method; it returns the combined stdout/stderr, in the order they
were written.

The simplification to stdout/stderr ordering is to always make
writes to stdout and stderr writers synchronous, regardless of
whether we have the same underlying writers capturing both of
them.  This also has the nice property of ensuring that multiple
writers that capture both stdout and stderr will always see the
results in the same order.

Change-Id: Ib2b15092d3bd9d1566d98487965336c1d1c67aca
diff --git a/gosh/.api b/gosh/.api
index 2b75989..e99f9da 100644
--- a/gosh/.api
+++ b/gosh/.api
@@ -14,6 +14,7 @@
 pkg gosh, method (*Cmd) AwaitReady()
 pkg gosh, method (*Cmd) AwaitVars(...string) map[string]string
 pkg gosh, method (*Cmd) Clone() *Cmd
+pkg gosh, method (*Cmd) CombinedOutput() string
 pkg gosh, method (*Cmd) Kill()
 pkg gosh, method (*Cmd) Pid() int
 pkg gosh, method (*Cmd) Run()
diff --git a/gosh/cmd.go b/gosh/cmd.go
index effbac9..e80eff0 100644
--- a/gosh/cmd.go
+++ b/gosh/cmd.go
@@ -204,6 +204,15 @@
 	return stdout, stderr
 }
 
+// CombinedOutput calls Start followed by Wait, then returns the command's
+// combined stdout and stderr.
+func (c *Cmd) CombinedOutput() string {
+	c.sh.Ok()
+	res, err := c.combinedOutput()
+	c.handleError(err)
+	return res
+}
+
 // Pid returns the command's PID, or -1 if the command has not been started.
 func (c *Cmd) Pid() int {
 	if !c.started {
@@ -264,10 +273,6 @@
 	}
 }
 
-func (c *Cmd) addWriter(writers *[]io.Writer, w io.Writer) {
-	*writers = append(*writers, w)
-}
-
 func (c *Cmd) closeClosers() {
 	// If the same WriteCloser was passed to both AddStdoutWriter and
 	// AddStderrWriter, we should only close it once.
@@ -337,31 +342,58 @@
 	return len(p), nil
 }
 
-func (c *Cmd) makeMultiWriter(stdout bool, t string) (io.Writer, error) {
-	std, writers := os.Stderr, &c.stderrWriters
-	if stdout {
-		std, writers = os.Stdout, &c.stdoutWriters
-		c.addWriter(writers, &recvWriter{c: c})
-	}
+type lockedWriter struct {
+	mu *sync.Mutex
+	w  io.Writer
+}
+
+func (w lockedWriter) Write(p []byte) (int, error) {
+	w.mu.Lock()
+	n, err := w.w.Write(p)
+	w.mu.Unlock()
+	return n, err
+}
+
+func (c *Cmd) makeStdoutStderr() (io.Writer, io.Writer, error) {
+	c.stdoutWriters = append(c.stdoutWriters, &recvWriter{c: c})
 	if c.PropagateOutput {
-		c.addWriter(writers, std)
-		// Don't add std to c.closers, since we don't want to close os.Stdout or
-		// os.Stderr for the entire address space when c exits.
+		c.stdoutWriters = append(c.stdoutWriters, os.Stdout)
+		c.stderrWriters = append(c.stderrWriters, os.Stderr)
 	}
 	if c.OutputDir != "" {
-		suffix := "stderr"
-		if stdout {
-			suffix = "stdout"
+		t := time.Now().Format("20060102.150405.000000")
+		name := filepath.Join(c.OutputDir, filepath.Base(c.Path)+"."+t)
+		const flags = os.O_WRONLY | os.O_CREATE | os.O_EXCL
+		switch file, err := os.OpenFile(name+".stdout", flags, 0600); {
+		case err != nil:
+			return nil, nil, err
+		default:
+			c.stdoutWriters = append(c.stdoutWriters, file)
+			c.closers = append(c.closers, file)
 		}
-		name := filepath.Join(c.OutputDir, filepath.Base(c.Path)+"."+t+"."+suffix)
-		file, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
-		if err != nil {
-			return nil, err
+		switch file, err := os.OpenFile(name+".stderr", flags, 0600); {
+		case err != nil:
+			return nil, nil, err
+		default:
+			c.stderrWriters = append(c.stderrWriters, file)
+			c.closers = append(c.closers, file)
 		}
-		c.addWriter(writers, file)
-		c.closers = append(c.closers, file)
 	}
-	return io.MultiWriter(*writers...), nil
+	switch hasOut, hasErr := len(c.stdoutWriters) > 0, len(c.stderrWriters) > 0; {
+	case hasOut && hasErr:
+		// Make writes synchronous between stdout and stderr. This ensures all
+		// writers that capture both will see the same ordering, and don't need to
+		// worry about concurrent writes.
+		sharedMu := &sync.Mutex{}
+		stdout := lockedWriter{sharedMu, io.MultiWriter(c.stdoutWriters...)}
+		stderr := lockedWriter{sharedMu, io.MultiWriter(c.stderrWriters...)}
+		return stdout, stderr, nil
+	case hasOut:
+		return io.MultiWriter(c.stdoutWriters...), nil, nil
+	case hasErr:
+		return nil, io.MultiWriter(c.stderrWriters...), nil
+	}
+	return nil, nil, nil
 }
 
 func (c *Cmd) clone() (*Cmd, error) {
@@ -399,7 +431,7 @@
 		return nil, errAlreadyCalledStart
 	}
 	p := NewBufferedPipe()
-	c.addWriter(&c.stdoutWriters, p)
+	c.stdoutWriters = append(c.stdoutWriters, p)
 	c.closers = append(c.closers, p)
 	return p, nil
 }
@@ -409,7 +441,7 @@
 		return nil, errAlreadyCalledStart
 	}
 	p := NewBufferedPipe()
-	c.addWriter(&c.stderrWriters, p)
+	c.stderrWriters = append(c.stderrWriters, p)
 	c.closers = append(c.closers, p)
 	return p, nil
 }
@@ -423,7 +455,7 @@
 	case wc == os.Stderr:
 		return errCloseStderr
 	}
-	c.addWriter(&c.stdoutWriters, wc)
+	c.stdoutWriters = append(c.stdoutWriters, wc)
 	c.closers = append(c.closers, wc)
 	return nil
 }
@@ -437,7 +469,7 @@
 	case wc == os.Stderr:
 		return errCloseStderr
 	}
-	c.addWriter(&c.stderrWriters, wc)
+	c.stderrWriters = append(c.stderrWriters, wc)
 	c.closers = append(c.closers, wc)
 	return nil
 }
@@ -445,17 +477,6 @@
 // TODO(sadovsky): Maybe wrap every child process with a "supervisor" process
 // that calls WatchParent().
 
-type threadSafeWriter struct {
-	mu sync.Mutex
-	w  io.Writer
-}
-
-func (w *threadSafeWriter) Write(p []byte) (int, error) {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	return w.w.Write(p)
-}
-
 func (c *Cmd) start() error {
 	if c.calledStart {
 		return errAlreadyCalledStart
@@ -468,29 +489,6 @@
 	if c.sh.calledCleanup {
 		return errAlreadyCalledCleanup
 	}
-	// Wrap Writers in threadSafeWriters as needed so that if the same WriteCloser
-	// was passed to both AddStdoutWriter and AddStderrWriter, the writes are
-	// serialized.
-	isStdoutWriter := map[io.Writer]bool{}
-	for _, w := range c.stdoutWriters {
-		isStdoutWriter[w] = true
-	}
-	safe := map[io.Writer]*threadSafeWriter{}
-	for i, w := range c.stderrWriters {
-		if isStdoutWriter[w] {
-			if safe[w] == nil {
-				safe[w] = &threadSafeWriter{w: w}
-			}
-			c.stderrWriters[i] = safe[w]
-		}
-	}
-	if len(safe) > 0 {
-		for i, w := range c.stdoutWriters {
-			if s := safe[w]; s != nil {
-				c.stdoutWriters[i] = s
-			}
-		}
-	}
 	// Configure the command.
 	c.c.Path = c.Path
 	c.c.Env = mapToSlice(c.Vars)
@@ -501,12 +499,8 @@
 		}
 		c.c.Stdin = strings.NewReader(c.Stdin)
 	}
-	t := time.Now().Format("20060102.150405.000000")
 	var err error
-	if c.c.Stdout, err = c.makeMultiWriter(true, t); err != nil {
-		return err
-	}
-	if c.c.Stderr, err = c.makeMultiWriter(false, t); err != nil {
+	if c.c.Stdout, c.c.Stderr, err = c.makeStdoutStderr(); err != nil {
 		return err
 	}
 	// Start the command.
@@ -652,7 +646,7 @@
 		return "", errAlreadyCalledStart
 	}
 	var stdout bytes.Buffer
-	c.addWriter(&c.stdoutWriters, &stdout)
+	c.stdoutWriters = append(c.stdoutWriters, &stdout)
 	err := c.run()
 	return stdout.String(), err
 }
@@ -662,8 +656,19 @@
 		return "", "", errAlreadyCalledStart
 	}
 	var stdout, stderr bytes.Buffer
-	c.addWriter(&c.stdoutWriters, &stdout)
-	c.addWriter(&c.stderrWriters, &stderr)
+	c.stdoutWriters = append(c.stdoutWriters, &stdout)
+	c.stderrWriters = append(c.stderrWriters, &stderr)
 	err := c.run()
 	return stdout.String(), stderr.String(), err
 }
+
+func (c *Cmd) combinedOutput() (string, error) {
+	if c.calledStart {
+		return "", errAlreadyCalledStart
+	}
+	var output bytes.Buffer
+	c.stdoutWriters = append(c.stdoutWriters, &output)
+	c.stderrWriters = append(c.stderrWriters, &output)
+	err := c.run()
+	return output.String(), err
+}
diff --git a/gosh/shell_test.go b/gosh/shell_test.go
index b1c9c12..73e5da0 100644
--- a/gosh/shell_test.go
+++ b/gosh/shell_test.go
@@ -474,23 +474,45 @@
 	sh.Err = nil
 }
 
-// Demonstrates how to capture combined stdout and stderr, a la
-// exec.Cmd.CombinedOutput.
-func TestCombinedStdoutStderr(t *testing.T) {
+func TestCombinedOutput(t *testing.T) {
 	sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
 	defer sh.Cleanup()
 
 	c := sh.Fn(writeFn, true, true)
 	buf := &bytes.Buffer{}
-	// Note, we must share a single NopWriteCloser so that Cmd detects that we
-	// passed the same WriteCloser to both AddStdoutWriter and AddStderrWriter.
-	wc := gosh.NopWriteCloser(buf)
-	c.AddStdoutWriter(wc)
-	c.AddStderrWriter(wc)
-	c.Run()
+	c.AddStdoutWriter(gosh.NopWriteCloser(buf))
+	c.AddStderrWriter(gosh.NopWriteCloser(buf))
+	output := c.CombinedOutput()
 	// Note, we can't assume any particular ordering of stdout and stderr, so we
 	// simply check the length of the combined output.
-	eq(t, len(buf.String()), 4)
+	eq(t, len(output), 4)
+	// The ordering must be the same, regardless of how we captured the combined
+	// output.
+	eq(t, output, buf.String())
+}
+
+func TestOutputDir(t *testing.T) {
+	sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
+	defer sh.Cleanup()
+
+	dir := sh.MakeTempDir()
+	c := sh.Fn(writeFn, true, true)
+	c.OutputDir = dir
+	c.Run()
+
+	matches, err := filepath.Glob(filepath.Join(dir, "*.stdout"))
+	ok(t, err)
+	eq(t, len(matches), 1)
+	stdout, err := ioutil.ReadFile(matches[0])
+	ok(t, err)
+	eq(t, string(stdout), "AA")
+
+	matches, err = filepath.Glob(filepath.Join(dir, "*.stderr"))
+	ok(t, err)
+	eq(t, len(matches), 1)
+	stderr, err := ioutil.ReadFile(matches[0])
+	ok(t, err)
+	eq(t, string(stderr), "BB")
 }
 
 type countingWriteCloser struct {