gosh: make "await" methods return immediately on process exit

Change-Id: If2ae3b119e2d0abdceac89ad9a225c911988e2c2
diff --git a/gosh/cmd.go b/gosh/cmd.go
index 8678d24..38ca2a0 100644
--- a/gosh/cmd.go
+++ b/gosh/cmd.go
@@ -22,6 +22,7 @@
 	errAlreadyCalledStart = errors.New("gosh: already called Cmd.Start")
 	errAlreadyCalledWait  = errors.New("gosh: already called Cmd.Wait")
 	errDidNotCallStart    = errors.New("gosh: did not call Cmd.Start")
+	errProcessExited      = errors.New("gosh: process exited")
 )
 
 // Cmd represents a command. Not thread-safe.
@@ -50,17 +51,15 @@
 	stdinWriteCloser io.WriteCloser // from exec.Cmd.StdinPipe
 	calledStart      bool
 	calledWait       bool
+	cond             *sync.Cond
 	waitChan         chan error
 	started          bool // protected by sh.cleanupMu
-	exitedMu         sync.Mutex
-	exited           bool // protected by exitedMu
+	exited           bool // protected by cond.L
 	stdoutWriters    []io.Writer
 	stderrWriters    []io.Writer
 	closers          []io.Closer
-	condReady        *sync.Cond
-	recvReady        bool // protected by condReady.L
-	condVars         *sync.Cond
-	recvVars         map[string]string // protected by condVars.L
+	recvReady        bool              // protected by cond.L
+	recvVars         map[string]string // protected by cond.L
 }
 
 // Clone returns a new Cmd with a copy of this Cmd's configuration.
@@ -72,9 +71,11 @@
 }
 
 // StdinPipe returns a thread-safe WriteCloser backed by a buffered pipe for the
-// command's stdin. The returned WriteCloser will be closed when the process
-// exits. Must be called before Start. It is safe to call StdinPipe multiple
-// times; calls after the first return the pipe created by the first call.
+// command's stdin. The returned pipe will be closed when the process exits, but
+// may also be closed earlier by the caller, e.g. if the command does not exit
+// until its stdin is closed. Must be called before Start. It is safe to call
+// StdinPipe multiple times; calls after the first return the pipe created by
+// the first call.
 func (c *Cmd) StdinPipe() io.WriteCloser {
 	c.sh.Ok()
 	res, err := c.stdinPipe()
@@ -200,15 +201,14 @@
 
 func newCmdInternal(sh *Shell, vars map[string]string, path string, args []string) (*Cmd, error) {
 	c := &Cmd{
-		Path:      path,
-		Vars:      vars,
-		Args:      args,
-		sh:        sh,
-		c:         &exec.Cmd{},
-		waitChan:  make(chan error, 1),
-		condReady: sync.NewCond(&sync.Mutex{}),
-		condVars:  sync.NewCond(&sync.Mutex{}),
-		recvVars:  map[string]string{},
+		Path:     path,
+		Vars:     vars,
+		Args:     args,
+		sh:       sh,
+		c:        &exec.Cmd{},
+		cond:     sync.NewCond(&sync.Mutex{}),
+		waitChan: make(chan error, 1),
+		recvVars: map[string]string{},
 	}
 	// Protect against concurrent signal-triggered Shell.cleanup().
 	sh.cleanupMu.Lock()
@@ -271,8 +271,8 @@
 	if !c.started {
 		return false
 	}
-	c.exitedMu.Lock()
-	defer c.exitedMu.Unlock()
+	c.cond.L.Lock()
+	defer c.cond.L.Unlock()
 	return !c.exited
 }
 
@@ -294,15 +294,15 @@
 				}
 				switch m.Type {
 				case typeReady:
-					w.c.condReady.L.Lock()
+					w.c.cond.L.Lock()
 					w.c.recvReady = true
-					w.c.condReady.Signal()
-					w.c.condReady.L.Unlock()
+					w.c.cond.Signal()
+					w.c.cond.L.Unlock()
 				case typeVars:
-					w.c.condVars.L.Lock()
+					w.c.cond.L.Lock()
 					w.c.recvVars = mergeMaps(w.c.recvVars, m.Vars)
-					w.c.condVars.Signal()
-					w.c.condVars.L.Unlock()
+					w.c.cond.Signal()
+					w.c.cond.L.Unlock()
 				default:
 					return 0, fmt.Errorf("unknown message type: %q", m.Type)
 				}
@@ -448,13 +448,16 @@
 		return err
 	}
 	// Start the command.
-	err = c.c.Start()
-	if err != nil {
-		c.exitedMu.Lock()
+	onExit := func(err error) {
+		c.cond.L.Lock()
 		c.exited = true
-		c.exitedMu.Unlock()
+		c.cond.Signal()
+		c.cond.L.Unlock()
 		c.closeClosers()
-		c.waitChan <- errors.New("gosh: start failed")
+		c.waitChan <- err
+	}
+	if err = c.c.Start(); err != nil {
+		onExit(errors.New("gosh: start failed"))
 		return err
 	}
 	c.started = true
@@ -463,18 +466,12 @@
 	// ensures that the child process is reaped once it exits. Note, gosh.Cmd.wait
 	// blocks on waitChan.
 	go func() {
-		err := c.c.Wait()
-		c.exitedMu.Lock()
-		c.exited = true
-		c.exitedMu.Unlock()
-		c.closeClosers()
-		c.waitChan <- err
+		onExit(c.c.Wait())
 	}()
 	return nil
 }
 
-// TODO(sadovsky): Make it so Cmd.{awaitReady,awaitVars} return an error if/when
-// we detect that the process has exited. Also, maybe add optional timeouts for
+// TODO(sadovsky): Maybe add optional timeouts for
 // Cmd.{awaitReady,awaitVars,wait}.
 
 func (c *Cmd) awaitReady() error {
@@ -483,12 +480,15 @@
 	} else if c.calledWait {
 		return errAlreadyCalledWait
 	}
-	// http://golang.org/pkg/sync/#Cond.Wait
-	c.condReady.L.Lock()
-	for !c.recvReady {
-		c.condReady.Wait()
+	c.cond.L.Lock()
+	defer c.cond.L.Unlock()
+	for !c.exited && !c.recvReady {
+		c.cond.Wait()
 	}
-	c.condReady.L.Unlock()
+	// Return nil error if both conditions triggered simultaneously.
+	if !c.recvReady {
+		return errProcessExited
+	}
 	return nil
 }
 
@@ -510,14 +510,17 @@
 			}
 		}
 	}
-	// http://golang.org/pkg/sync/#Cond.Wait
-	c.condVars.L.Lock()
+	c.cond.L.Lock()
+	defer c.cond.L.Unlock()
 	updateRes()
-	for len(res) < len(wantKeys) {
-		c.condVars.Wait()
+	for !c.exited && len(res) < len(wantKeys) {
+		c.cond.Wait()
 		updateRes()
 	}
-	c.condVars.L.Unlock()
+	// Return nil error if both conditions triggered simultaneously.
+	if len(res) < len(wantKeys) {
+		return nil, errProcessExited
+	}
 	return res, nil
 }
 
diff --git a/gosh/internal/gosh_example/main.go b/gosh/internal/gosh_example/main.go
index 67fa15c..bde8f86 100644
--- a/gosh/internal/gosh_example/main.go
+++ b/gosh/internal/gosh_example/main.go
@@ -30,8 +30,8 @@
 }
 
 var (
-	getFn   = gosh.Register("get", lib.Get)
-	serveFn = gosh.Register("serve", lib.Serve)
+	getFn   = gosh.Register("getFn", lib.Get)
+	serveFn = gosh.Register("serveFn", lib.Serve)
 )
 
 func ExampleFns() {
diff --git a/gosh/shell_test.go b/gosh/shell_test.go
index 8c25904..fc78858 100644
--- a/gosh/shell_test.go
+++ b/gosh/shell_test.go
@@ -82,30 +82,30 @@
 
 // Simplified versions of various Unix commands.
 var (
-	catFn = gosh.Register("cat", func() {
+	catFn = gosh.Register("catFn", func() {
 		io.Copy(os.Stdout, os.Stdin)
 	})
-	echoFn = gosh.Register("echo", func() {
+	echoFn = gosh.Register("echoFn", func() {
 		fmt.Println(os.Args[1])
 	})
-	readFn = gosh.Register("read", func() {
+	readFn = gosh.Register("readFn", func() {
 		bufio.NewReader(os.Stdin).ReadString('\n')
 	})
 )
 
 // Functions with parameters.
 var (
-	exitFn = gosh.Register("exit", func(code int) {
+	exitFn = gosh.Register("exitFn", func(code int) {
 		os.Exit(code)
 	})
-	sleepFn = gosh.Register("sleep", func(d time.Duration, code int) {
+	sleepFn = gosh.Register("sleepFn", func(d time.Duration, code int) {
 		time.Sleep(d)
 		os.Exit(code)
 	})
-	printFn = gosh.Register("print", func(v ...interface{}) {
+	printFn = gosh.Register("printFn", func(v ...interface{}) {
 		fmt.Print(v...)
 	})
-	printfFn = gosh.Register("printf", func(format string, v ...interface{}) {
+	printfFn = gosh.Register("printfFn", func(format string, v ...interface{}) {
 		fmt.Printf(format, v...)
 	})
 )
@@ -200,8 +200,8 @@
 }
 
 var (
-	getFn   = gosh.Register("get", lib.Get)
-	serveFn = gosh.Register("serve", lib.Serve)
+	getFn   = gosh.Register("getFn", lib.Get)
+	serveFn = gosh.Register("serveFn", lib.Serve)
 )
 
 func TestFns(t *testing.T) {
@@ -230,14 +230,14 @@
 
 // Functions designed for TestRegistry.
 var (
-	printIntsFn = gosh.Register("printInts", func(v ...int) {
+	printIntsFn = gosh.Register("printIntsFn", func(v ...int) {
 		var vi []interface{}
 		for _, x := range v {
 			vi = append(vi, x)
 		}
 		fmt.Print(vi...)
 	})
-	printfIntsFn = gosh.Register("printfInts", func(format string, v ...int) {
+	printfIntsFn = gosh.Register("printfIntsFn", func(format string, v ...int) {
 		var vi []interface{}
 		for _, x := range v {
 			vi = append(vi, x)
@@ -246,6 +246,28 @@
 	})
 )
 
+// Tests that Await{Ready,Vars} return immediately when the process exits.
+func TestAwaitProcessExit(t *testing.T) {
+	sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
+	defer sh.Cleanup()
+
+	c := sh.Fn(exitFn, 0)
+	c.Start()
+	sh.Opts.Fatalf = nil
+	c.AwaitReady()
+	nok(t, sh.Err)
+	sh.Err = nil
+	sh.Opts.Fatalf = makeFatalf(t)
+
+	c = sh.Fn(exitFn, 0)
+	c.Start()
+	sh.Opts.Fatalf = nil
+	c.AwaitVars("foo")
+	nok(t, sh.Err)
+	sh.Err = nil
+	sh.Opts.Fatalf = makeFatalf(t)
+}
+
 // Tests function signature-checking and execution.
 func TestRegistry(t *testing.T) {
 	sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
@@ -351,7 +373,7 @@
 	nok(t, sh.Err)
 }
 
-var writeFn = gosh.Register("write", func(stdout, stderr bool) error {
+var writeFn = gosh.Register("writeFn", func(stdout, stderr bool) error {
 	if stdout {
 		if _, err := os.Stdout.Write([]byte("A")); err != nil {
 			return err
@@ -407,7 +429,7 @@
 	eq(t, toString(t, stderrPipe), "BB")
 }
 
-var writeMoreFn = gosh.Register("writeMore", func() {
+var writeMoreFn = gosh.Register("writeMoreFn", func() {
 	sh := gosh.NewShell(gosh.Opts{})
 	defer sh.Cleanup()