gosh: log head and tail of failed cmd stdout and stderr

addresses https://v.io/i/1127

Change-Id: I3395e04ab1a69d38b1c5e5c09e80a0c4acc4a74e
diff --git a/gosh/buffered_pipe_test.go b/gosh/buffered_pipe_test.go
index 0ba942f..26da848 100644
--- a/gosh/buffered_pipe_test.go
+++ b/gosh/buffered_pipe_test.go
@@ -12,7 +12,7 @@
 	"testing"
 )
 
-func TestReadWriteAfterClose(t *testing.T) {
+func TestBufferedPipeReadWriteAfterClose(t *testing.T) {
 	p := newBufferedPipe()
 	if n, err := p.Write([]byte("foo")); n != 3 || err != nil {
 		t.Errorf("write got (%v, %v), want (3, <nil>)", n, err)
@@ -37,7 +37,7 @@
 	}
 }
 
-func TestReadFromWriteTo(t *testing.T) {
+func TestBufferedPipeReadFromWriteTo(t *testing.T) {
 	p, buf := newBufferedPipe(), new(bytes.Buffer)
 	if n, err := p.(io.ReaderFrom).ReadFrom(strings.NewReader("foobarbaz")); n != 9 || err != nil {
 		t.Errorf("ReadFrom got (%v, %v), want (9, <nil>)", n, err)
@@ -62,7 +62,7 @@
 	}
 }
 
-func TestWriteToMany(t *testing.T) {
+func TestBufferedPipeWriteToMany(t *testing.T) {
 	p := newBufferedPipe()
 	pR, pW := io.Pipe()
 	nCh, errCh := make(chan int64, 1), make(chan error, 1)
diff --git a/gosh/cmd.go b/gosh/cmd.go
index 23347ed..25ea76a 100644
--- a/gosh/cmd.go
+++ b/gosh/cmd.go
@@ -13,6 +13,7 @@
 	"os"
 	"os/exec"
 	"path/filepath"
+	"strings"
 	"sync"
 	"syscall"
 	"time"
@@ -77,6 +78,8 @@
 	stdinDoneChan     chan error
 	started           bool // protected by sh.cleanupMu
 	exited            bool // protected by cond.L
+	stdoutHeadTail    *headTail
+	stderrHeadTail    *headTail
 	stdoutWriters     []io.Writer
 	stderrWriters     []io.Writer
 	afterStartClosers []io.Closer
@@ -235,16 +238,20 @@
 ////////////////////////////////////////
 // Internals
 
+const headTailCapacity = 1 << 15 // bytes kept at each end (head and tail) of a command's stdout and of its stderr
+
 func newCmdInternal(sh *Shell, vars map[string]string, path string, args []string) (*Cmd, error) {
 	c := &Cmd{
-		Path:     path,
-		Vars:     vars,
-		Args:     append([]string{path}, args...),
-		sh:       sh,
-		c:        &exec.Cmd{},
-		cond:     sync.NewCond(&sync.Mutex{}),
-		waitChan: make(chan error, 1),
-		recvVars: map[string]string{},
+		Path:           path,
+		Vars:           vars,
+		Args:           append([]string{path}, args...),
+		sh:             sh,
+		c:              &exec.Cmd{},
+		cond:           sync.NewCond(&sync.Mutex{}),
+		waitChan:       make(chan error, 1),
+		stdoutHeadTail: newHeadTail(headTailCapacity),
+		stderrHeadTail: newHeadTail(headTailCapacity),
+		recvVars:       map[string]string{},
 	}
 	// Protect against concurrent signal-triggered Shell.cleanup().
 	sh.cleanupMu.Lock()
@@ -268,13 +275,13 @@
 	return newCmdInternal(sh, vars, name, args)
 }
 
+func isExitError(err error) bool {
+	_, ok := err.(*exec.ExitError)
+	return ok
+}
+
 func (c *Cmd) errorIsOk(err error) bool {
-	if c.ExitErrorIsOk {
-		if _, ok := err.(*exec.ExitError); ok {
-			return true
-		}
-	}
-	return err == nil
+	return err == nil || c.ExitErrorIsOk && isExitError(err)
 }
 
 // An explanation of closed pipe errors. Consider the pipeline "yes | head -1",
@@ -302,6 +309,8 @@
 // might also want to add code to InitChildMain to exit the program with 141 if
 // it receives SIGPIPE.
 
+var sep = strings.Repeat("-", 40) // rule printed under the STDOUT and STDERR headings when a failed command is logged
+
 func (c *Cmd) handleError(err error) {
 	if c.IgnoreClosedPipeError && isClosedPipeError(err) {
 		err = nil
@@ -310,6 +319,11 @@
 	if c.errorIsOk(err) {
 		err = nil
 	}
+	if isExitError(err) && !c.sh.ContinueOnError {
+		c.sh.tb.Logf("gosh: command failed: %s\n", strings.Join(c.Args, " "))
+		c.sh.tb.Logf("\nSTDOUT\n%s\n%s\n", sep, c.stdoutHeadTail.String())
+		c.sh.tb.Logf("\nSTDERR\n%s\n%s\n", sep, c.stderrHeadTail.String())
+	}
 	c.sh.HandleErrorWithSkip(err, 3)
 }
 
@@ -371,6 +385,8 @@
 
 func (c *Cmd) makeStdoutStderr() (io.Writer, io.Writer, error) {
 	c.stderrWriters = append(c.stderrWriters, &recvWriter{c: c})
+	c.stdoutWriters = append(c.stdoutWriters, c.stdoutHeadTail)
+	c.stderrWriters = append(c.stderrWriters, c.stderrHeadTail)
 	if c.PropagateOutput {
 		c.stdoutWriters = append(c.stdoutWriters, os.Stdout)
 		c.stderrWriters = append(c.stderrWriters, os.Stderr)
@@ -758,3 +774,55 @@
 	err := c.run()
 	return output.String(), err
 }
+
+////////////////////////////////////////
+// Head-and-tail buffer
+
+// headTail is a writer that keeps only the first and last 'capacity' bytes written to it.
+type headTail struct {
+	head     []byte      // first len(head) bytes written; sized once by newHeadTail
+	tail     *ringBuffer // most recent bytes written after head filled up; nil until then
+	nWritten int         // total number of bytes passed to Write, including any that were dropped
+}
+
+func newHeadTail(capacity int) *headTail {
+	return &headTail{head: make([]byte, capacity)} // tail is created lazily, by the first Write that overflows head
+}
+
+// Write records p, filling head first and spilling the rest into tail; it always returns (len(p), nil).
+func (b *headTail) Write(p []byte) (int, error) {
+	nHead := len(b.head) - b.nWritten // room left in head; clamped below to the range [0, len(p)]
+	if nHead > len(p) {
+		nHead = len(p)
+	} else if nHead < 0 {
+		nHead = 0
+	}
+	if nHead > 0 {
+		copy(b.head[b.nWritten:], p[:nHead])
+	}
+	// Anything that did not fit in head goes to tail, which keeps only the newest bytes.
+	if len(p) > nHead {
+		if b.tail == nil {
+			b.tail = newRingBuffer(len(b.head))
+		}
+		b.tail.Append(p[nHead:])
+	}
+	b.nWritten += len(p)
+	return len(p), nil
+}
+
+// String returns "[ empty ]" if nothing was written, else head then tail, with a "skipping" marker if bytes were dropped.
+func (b *headTail) String() string {
+	if b.nWritten == 0 {
+		return "[ empty ]"
+	}
+	if b.tail == nil { // nothing has spilled past head, so only its first nWritten bytes are valid
+		return string(b.head[:b.nWritten])
+	}
+	tail := b.tail.String()
+	skipped := b.nWritten - 2*len(b.head) // bytes that fit in neither head nor tail
+	if skipped <= 0 {
+		return fmt.Sprintf("%s%s", b.head, tail)
+	}
+	return fmt.Sprintf("%s\n[ ... skipping %d bytes ... ]\n%s", b.head, skipped, tail)
+}
diff --git a/gosh/ring_buffer.go b/gosh/ring_buffer.go
new file mode 100644
index 0000000..d678e05
--- /dev/null
+++ b/gosh/ring_buffer.go
@@ -0,0 +1,51 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gosh
+
+type ringBuffer struct {
+	buf   []byte // fixed-size backing storage, allocated by newRingBuffer
+	start int    // index in buf of the oldest retained byte
+	len   int    // number of valid bytes in buf; never more than len(buf)
+}
+
+// newRingBuffer returns an empty fixed-size buffer that retains only the last
+// 'capacity' bytes appended to it; with a capacity of 0 it retains nothing.
+func newRingBuffer(capacity int) *ringBuffer {
+	return &ringBuffer{buf: make([]byte, capacity)}
+}
+
+// Append copies p into the buffer, overwriting the oldest bytes once it is full.
+func (b *ringBuffer) Append(p []byte) {
+	if len(b.buf) == 0 { // zero-capacity buffer retains nothing; also avoids the % 0 below
+		return
+	}
+	if len(p) >= len(b.buf) { // p alone fills the buffer: keep just its last len(b.buf) bytes
+		copy(b.buf, p[len(p)-len(b.buf):])
+		b.start = 0
+		b.len = len(b.buf)
+		return
+	}
+	// Copy p into b.buf just past the newest byte, wrapping to the front if needed.
+	end := (b.start + b.len) % len(b.buf)
+	n := copy(b.buf[end:], p)
+	if n < len(p) {
+		copy(b.buf, p[n:])
+	}
+	// On overflow, move b.start past the overwritten bytes and clamp b.len.
+	b.len += len(p)
+	if b.len > len(b.buf) {
+		b.start = (b.start + b.len) % len(b.buf)
+		b.len = len(b.buf)
+	}
+}
+
+// String returns the retained bytes, oldest first, as a string.
+func (b *ringBuffer) String() string {
+	if b.start == 0 { // contents are not rotated: the valid bytes are exactly buf[:len]
+		return string(b.buf[:b.len])
+	}
+	// INVARIANT: If b.start > 0, b.len == len(b.buf).
+	return string(b.buf[b.start:]) + string(b.buf[:b.start])
+}
diff --git a/gosh/ring_buffer_test.go b/gosh/ring_buffer_test.go
new file mode 100644
index 0000000..c9aa6a8
--- /dev/null
+++ b/gosh/ring_buffer_test.go
@@ -0,0 +1,127 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gosh
+
+import (
+	"testing"
+)
+
+func TestRingBufferBasic(t *testing.T) {
+	b := newRingBuffer(5)
+	if got, want := b.String(), ""; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append([]byte("foo"))
+	if got, want := b.String(), "foo"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append([]byte("bar"))
+	if got, want := b.String(), "oobar"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	// Append an empty string.
+	b.Append([]byte(""))
+	if got, want := b.String(), "oobar"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// This time, appending a string that puts us right at the cap.
+	b = newRingBuffer(3)
+	b.Append([]byte("foo"))
+	if got, want := b.String(), "foo"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append([]byte("bar"))
+	if got, want := b.String(), "bar"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// This time, appending a string that's much bigger than the buffer.
+	b = newRingBuffer(2)
+	b.Append([]byte("012345678"))
+	if got, want := b.String(), "78"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append([]byte("0123456789"))
+	if got, want := b.String(), "89"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append([]byte("0"))
+	if got, want := b.String(), "90"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// This time, with a size-1 buffer.
+	b = newRingBuffer(1)
+	b.Append([]byte("f"))
+	if got, want := b.String(), "f"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append([]byte("o"))
+	if got, want := b.String(), "o"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append([]byte("bar"))
+	if got, want := b.String(), "r"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// This time, with a size-0 buffer.
+	b = newRingBuffer(0)
+	b.Append([]byte("f"))
+	if got, want := b.String(), ""; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+func TestRingBufferCopiesBytes(t *testing.T) {
+	foo, bar := []byte("foo"), []byte("bar")
+	b := newRingBuffer(5)
+	b.Append(foo)
+	foo[2] = 'z' // mutate the caller's slice after Append; the buffer must still hold the original bytes
+	if got, want := b.String(), "foo"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append(bar)
+	bar[2] = 'z'
+	if got, want := b.String(), "oobar"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// This time, appending a string that puts us right at the cap.
+	foo, bar = []byte("foo"), []byte("bar")
+	b = newRingBuffer(3)
+	b.Append(foo)
+	foo[2] = 'z'
+	if got, want := b.String(), "foo"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	b.Append(bar)
+	bar[2] = 'z'
+	if got, want := b.String(), "bar"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+func TestRingBufferStress(t *testing.T) {
+	const s = "0123456789"
+	for strLen := 0; strLen <= len(s); strLen++ {
+		for bufCap := 0; bufCap <= 2*len(s); bufCap++ {
+			b := newRingBuffer(bufCap)
+			all := "" // everything appended so far
+			for i := 0; i < 2*len(s); i++ {
+				b.Append([]byte(s[:strLen]))
+				all += s[:strLen]
+			}
+			start := len(all) - bufCap // the buffer should hold only the last bufCap bytes of all
+			if start < 0 {
+				start = 0
+			}
+			if got, want := b.String(), all[start:]; got != want {
+				t.Errorf("got %v, want %v", got, want)
+			}
+		}
+	}
+}
diff --git a/gosh/shell.go b/gosh/shell.go
index 3bebf8c..3c9df0b 100644
--- a/gosh/shell.go
+++ b/gosh/shell.go
@@ -324,8 +324,9 @@
 }
 
 func (sh *Shell) wait() error {
-	// Note: It is illegal to call newCmdInternal concurrently with Shell.wait, so
-	// we need not hold cleanupMu when accessing sh.cmds below.
+	// Note: It is illegal to call newCmdInternal (which mutates sh.cmds)
+	// concurrently with Shell.wait, so we need not hold cleanupMu when accessing
+	// sh.cmds below.
 	var res error
 	for _, c := range sh.cmds {
 		if !c.started || c.calledWait {
diff --git a/gosh/shell_test.go b/gosh/shell_test.go
index 5771c99..cdfd8da 100644
--- a/gosh/shell_test.go
+++ b/gosh/shell_test.go
@@ -137,6 +137,13 @@
 	buf           *bytes.Buffer
 }
 
+func (tb *customTB) Reset() {
+	tb.calledFailNow = false
+	if tb.buf != nil { // buf may be unset; only clear it when one is present
+		tb.buf.Reset()
+	}
+}
+
 func (tb *customTB) FailNow() {
 	tb.calledFailNow = true
 }
@@ -895,7 +902,7 @@
 	defer sh.Cleanup()
 
 	// Call HandleError, then check that the stack trace and error got logged.
-	tb.buf.Reset()
+	tb.Reset()
 	sh.HandleError(fakeError)
 	_, file, line, _ := runtime.Caller(0)
 	got, wantSuffix := tb.buf.String(), fmt.Sprintf("%s:%d: %v\n", filepath.Base(file), line-1, fakeError)
@@ -910,7 +917,7 @@
 	// Same as above, but with ContinueOnError set to true. Only the error should
 	// get logged.
 	sh.ContinueOnError = true
-	tb.buf.Reset()
+	tb.Reset()
 	sh.HandleError(fakeError)
 	_, file, line, _ = runtime.Caller(0)
 	got, want := tb.buf.String(), fmt.Sprintf("%s:%d: %v\n", filepath.Base(file), line-1, fakeError)
@@ -918,7 +925,7 @@
 	sh.Err = nil
 
 	// Same as above, but calling HandleErrorWithSkip, with skip set to 1.
-	tb.buf.Reset()
+	tb.Reset()
 	sh.HandleErrorWithSkip(fakeError, 1)
 	_, file, line, _ = runtime.Caller(0)
 	got, want = tb.buf.String(), fmt.Sprintf("%s:%d: %v\n", filepath.Base(file), line-1, fakeError)
@@ -926,7 +933,7 @@
 	sh.Err = nil
 
 	// Same as above, but with skip set to 2.
-	tb.buf.Reset()
+	tb.Reset()
 	sh.HandleErrorWithSkip(fakeError, 2)
 	_, file, line, _ = runtime.Caller(1)
 	got, want = tb.buf.String(), fmt.Sprintf("%s:%d: %v\n", filepath.Base(file), line, fakeError)
@@ -934,6 +941,94 @@
 	sh.Err = nil
 }
 
+var cmdFailureFunc = gosh.RegisterFunc("cmdFailureFunc", func(nStdout, nStderr int) error {
+	if _, err := os.Stdout.Write([]byte(strings.Repeat("A", nStdout))); err != nil {
+		return err
+	}
+	if _, err := os.Stderr.Write([]byte(strings.Repeat("B", nStderr))); err != nil {
+		return err
+	}
+	time.Sleep(time.Second) // NOTE(review): presumably gives the parent time to drain the output pipes before exit; confirm
+	return fakeError
+})
+
+// Tests that when a command fails, we log the head and tail of its stdout and
+// stderr.
+func TestCmdFailureLoggingEnabled(t *testing.T) {
+	tb := &customTB{t: t, buf: &bytes.Buffer{}}
+	sh := gosh.NewShell(tb)
+	defer sh.Cleanup()
+
+	const k = 1 << 15 // same value as headTailCapacity in cmd.go; the expectations below depend on it
+
+	// Note: When a FuncCmd fails, InitMain calls log.Fatal(err), which writes err
+	// to stderr. In several places below, our expected stderr must accommodate
+	// this logged fakeError string.
+	cmdFailureLoggingTestCases := []struct {
+		nStdout    int
+		nStderr    int
+		wantStdout string
+		wantStderr string
+	}{
+		{0, 0, "[ empty ]", ""},
+		{1, 1, "A", "B"},
+		{k, k, strings.Repeat("A", k), strings.Repeat("B", k)},
+		{k + 1, k + 1, strings.Repeat("A", k+1), strings.Repeat("B", k+1)},
+		// Stderr includes fakeError.
+		{2 * k, 2 * k, strings.Repeat("A", 2*k), strings.Repeat("B", k) + "\n[ ... skipping "},
+		// Stderr includes fakeError.
+		{2*k + 1, 2*k + 1, strings.Repeat("A", k) + "\n[ ... skipping 1 bytes ... ]\n" + strings.Repeat("A", k), strings.Repeat("B", k) + "\n[ ... skipping "},
+	}
+
+	sep := strings.Repeat("-", 40)
+	for _, tc := range cmdFailureLoggingTestCases {
+		tb.Reset()
+		sh.FuncCmd(cmdFailureFunc, tc.nStdout, tc.nStderr).Run()
+		got := tb.buf.String()
+		wantStdout := fmt.Sprintf("\nSTDOUT\n%s\n%s\n", sep, tc.wantStdout)
+		if !strings.Contains(got, wantStdout) {
+			t.Fatalf("got %v, want substring %v", got, wantStdout)
+		}
+		// Stderr includes fakeError.
+		wantStderr := fmt.Sprintf("\nSTDERR\n%s\n%s", sep, tc.wantStderr)
+		if !strings.Contains(got, wantStderr) {
+			t.Fatalf("got %v, want substring %v", got, wantStderr)
+		}
+		sh.Err = nil
+	}
+}
+
+// Tests that we don't log command failures when ExitErrorIsOk or
+// ContinueOnError is set.
+func TestCmdFailureLoggingDisabled(t *testing.T) {
+	tb := &customTB{t: t, buf: &bytes.Buffer{}}
+	sh := gosh.NewShell(tb)
+	defer sh.Cleanup()
+
+	// If ExitErrorIsOk is set and the command fails, we shouldn't log anything.
+	tb.Reset()
+	c := sh.FuncCmd(exitFunc, 1)
+	c.ExitErrorIsOk = true
+	c.Run()
+	eq(t, tb.calledFailNow, false)
+	eq(t, tb.buf.String(), "")
+
+	// If ContinueOnError is set and the command fails, we should log the exit
+	// status but not the command stderr.
+	tb.Reset()
+	c = sh.FuncCmd(exitFunc, 1)
+	sh.ContinueOnError = true
+	c.Run()
+	eq(t, tb.calledFailNow, false)
+	got := tb.buf.String()
+	if !strings.Contains(got, "exit status 1") {
+		t.Fatalf("missing error: %s", got)
+	}
+	if strings.Contains(got, "STDERR") {
+		t.Fatalf("should not log stderr: %s", got)
+	}
+}
+
 func TestMain(m *testing.M) {
 	gosh.InitMain()
 	os.Exit(m.Run())