Merge "veyron/lib/modules: allow Shutdown to return an error status and avoid races."
diff --git a/lib/modules/core/core_test.go b/lib/modules/core/core_test.go
index 8a6614b..fc104a9 100644
--- a/lib/modules/core/core_test.go
+++ b/lib/modules/core/core_test.go
@@ -40,7 +40,7 @@
 	s := expect.NewSession(t, root.Stdout(), time.Second)
 	s.ExpectVar("MT_NAME")
 	s.ExpectVar("PID")
-	root.Stdin().Close()
+	root.CloseStdin()
 	s.Expect("done")
 }

diff --git a/lib/modules/exec.go b/lib/modules/exec.go
index 279e6dc..5da6499 100644
--- a/lib/modules/exec.go
+++ b/lib/modules/exec.go
@@ -24,6 +24,7 @@
 	cmd        *exec.Cmd
 	entryPoint string
 	handle     *vexec.ParentHandle
+	sh         *Shell
 	stderr     *os.File
 	stdout     *bufio.Reader
 	stdin      io.WriteCloser
@@ -79,12 +80,20 @@
 	return eh.stderr
 }

-func (eh *execHandle) Stdin() io.WriteCloser {
+func (eh *execHandle) Stdin() io.Writer {
 	eh.mu.Lock()
 	defer eh.mu.Unlock()
 	return eh.stdin
 }

+// CloseStdin closes the command's stdin which, by convention, tells the
+// command to exit.
+func (eh *execHandle) CloseStdin() {
+	eh.mu.Lock()
+	defer eh.mu.Unlock()
+	eh.stdin.Close()
+}
+
 // mergeOSEnv returns a slice contained the merged set of environment
 // variables from the OS environment and those in this Shell, preferring
 // values in the Shell environment over those found in the OS environment.
@@ -126,6 +135,7 @@
 func (eh *execHandle) start(sh *Shell, args ...string) (Handle, error) {
 	eh.mu.Lock()
 	defer eh.mu.Unlock()
+	eh.sh = sh
 	newargs := append(testFlags(), args...)
 	cmd := exec.Command(os.Args[0], newargs...)
 	cmd.Env = append(sh.mergeOSEnvSlice(), eh.entryPoint)
@@ -156,26 +166,35 @@
 	return eh, err
 }

-func (eh *execHandle) Shutdown(output io.Writer) {
+// Shutdown closes the command's stdin, waits for the subprocess to exit
+// and then, if output is non-nil, copies the captured stderr to output.
+// It stops the Shell tracking this handle and returns the error from
+// (*exec.Cmd).Wait, such as a non-zero exit status.
+func (eh *execHandle) Shutdown(output io.Writer) error {
 	eh.mu.Lock()
 	defer eh.mu.Unlock()
 	eh.stdin.Close()
+	defer eh.sh.forget(eh)
+	// Wait for the subprocess to exit before reading its stderr so that
+	// output written after stdin was closed is not lost.
+	err := eh.cmd.Wait()
 	if eh.stderr != nil {
 		defer func() {
 			eh.stderr.Close()
 			os.Remove(eh.stderr.Name())
 		}()
 		if output == nil {
-			return
+			return err
 		}
-		if _, err := eh.stderr.Seek(0, 0); err != nil {
-			return
+		if _, serr := eh.stderr.Seek(0, 0); serr != nil {
+			return err
 		}
 		scanner := bufio.NewScanner(eh.stderr)
 		for scanner.Scan() {
 			fmt.Fprintf(output, "%s\n", scanner.Text())
 		}
 	}
+	return err
 }

 const shellEntryPoint = "VEYRON_SHELL_HELPER_PROCESS_ENTRY_POINT"
diff --git a/lib/modules/func.go b/lib/modules/func.go
index e5fd73b..5a7c149 100644
--- a/lib/modules/func.go
+++ b/lib/modules/func.go
@@ -17,6 +17,10 @@
 	main                  Main
 	stdin, stderr, stdout pipe
 	bufferedStdout        *bufio.Reader
+	err                   error          // returned by main; read after wg.Wait
+	sh                    *Shell         // the Shell that started this command
+	stdinClosed           bool           // set once stdin.w has been closed
+	wg                    sync.WaitGroup // done when the goroutine running main exits
 }

 func newFunctionHandle(main Main) command {
@@ -35,15 +39,37 @@
 	return fh.stderr.r
 }

-func (fh *functionHandle) Stdin() io.WriteCloser {
+func (fh *functionHandle) Stdin() io.Writer {
 	fh.mu.Lock()
 	defer fh.mu.Unlock()
 	return fh.stdin.w
 }

+// CloseStdin closes the caller's end of the command's stdin, which by
+// convention tells the command to exit. It may be called more than once,
+// and before Shutdown; only the first close of the pipe takes effect.
+func (fh *functionHandle) CloseStdin() {
+	fh.mu.Lock()
+	defer fh.mu.Unlock()
+	fh.closeStdinLocked()
+}
+
+// closeStdinLocked closes stdin.w at most once; fh.mu must be held.
+// syscall.Close is used, rather than (*os.File).Close, since there may
+// be an outstanding operation on the file that would otherwise trigger
+// a test failure with go test -race.
+func (fh *functionHandle) closeStdinLocked() {
+	if fh.stdinClosed {
+		return
+	}
+	fh.stdinClosed = true
+	syscall.Close(int(fh.stdin.w.Fd()))
+}
+
 func (fh *functionHandle) start(sh *Shell, args ...string) (Handle, error) {
 	fh.mu.Lock()
 	defer fh.mu.Unlock()
+	fh.sh = sh
 	for _, p := range []*pipe{&fh.stdin, &fh.stdout, &fh.stderr} {
 		var err error
 		if p.r, p.w, err = os.Pipe(); err != nil {
@@ -51,30 +77,56 @@
 		}
 	}
 	fh.bufferedStdout = bufio.NewReader(fh.stdout.r)
+	fh.wg.Add(1)
+
 	go func() {
 		err := fh.main(fh.stdin.r, fh.stdout.w, fh.stderr.w, sh.mergeOSEnv(), args...)
 		if err != nil {
 			fmt.Fprintf(fh.stderr.w, "%s\n", err)
 		}
-		// See the comment below in Shutdown.
+		// Close main's ends of the pipes so that readers of stdout and
+		// stderr see EOF; the caller's ends are closed by CloseStdin and
+		// Shutdown. syscall.Close is used since there may be an
+		// outstanding read on these files that would otherwise trigger
+		// a test failure with go test -race.
 		syscall.Close(int(fh.stdin.r.Fd()))
 		syscall.Close(int(fh.stdout.w.Fd()))
 		syscall.Close(int(fh.stderr.w.Fd()))
+		fh.mu.Lock()
+		fh.err = err
+		fh.mu.Unlock()
+		fh.wg.Done()
 	}()
 	return fh, nil
 }

-func (fh *functionHandle) Shutdown(output io.Writer) {
+// Shutdown closes the command's stdin, drains its stderr (copying it to
+// output unless output is nil) and waits for main to return. It then
+// closes the caller's ends of the stdout and stderr pipes, stops the
+// Shell tracking this handle and returns the error returned by main.
+func (fh *functionHandle) Shutdown(output io.Writer) error {
 	fh.mu.Lock()
-	defer fh.mu.Unlock()
+	fh.closeStdinLocked()
+	fh.mu.Unlock()
+
+	// stderr.w is closed once main returns, so this loop terminates. It
+	// runs without fh.mu held and drains stderr even when output is nil
+	// so that main is not left blocked writing to a full pipe.
 	scanner := bufio.NewScanner(fh.stderr.r)
 	for scanner.Scan() {
-		fmt.Fprintf(output, "%s\n", scanner.Text())
+		if output != nil {
+			fmt.Fprintf(output, "%s\n", scanner.Text())
+		}
 	}
-	// We close these files using the Close system call since there
-	// may be an oustanding read on them that would otherwise trigger
-	// a test failure with go test -race
-	syscall.Close(int(fh.stdin.w.Fd()))
+	fh.wg.Wait()
+
+	fh.mu.Lock()
+	// NOTE(review): the *os.File values still hold these descriptor
+	// numbers, so their finalizers may close them again once collected.
 	syscall.Close(int(fh.stdout.r.Fd()))
 	syscall.Close(int(fh.stderr.r.Fd()))
+	err := fh.err
+	fh.mu.Unlock()
+	fh.sh.forget(fh)
+	return err
 }
diff --git a/lib/modules/modules_internal_test.go b/lib/modules/modules_internal_test.go
new file mode 100644
index 0000000..349869b
--- /dev/null
+++ b/lib/modules/modules_internal_test.go
@@ -0,0 +1,72 @@
+package modules
+
+import (
+	"fmt"
+	"io"
+	"path/filepath"
+	"runtime"
+	"testing"
+)
+
+func init() {
+	RegisterChild("echos", Echo)
+}
+
+// Echo writes each of its arguments, one per line, to stdout and
+// returns an error if it is given no arguments.
+func Echo(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
+	if len(args) == 0 {
+		return fmt.Errorf("no args")
+	}
+	for _, a := range args {
+		fmt.Fprintln(stdout, a)
+	}
+	return nil
+}
+
+// assertNumHandles reports an error, attributed to the caller's file
+// and line, unless sh is tracking exactly n handles.
+func assertNumHandles(t *testing.T, sh *Shell, n int) {
+	sh.mu.Lock()
+	got := len(sh.handles)
+	sh.mu.Unlock()
+	if got != n {
+		// Caller(1) is the function that called assertNumHandles.
+		_, file, line, _ := runtime.Caller(1)
+		t.Errorf("%s:%d: got %d, want %d", filepath.Base(file), line, got, n)
+	}
+}
+
+func TestState(t *testing.T) {
+	sh := NewShell()
+	sh.AddSubprocess("echonotregistered", "[args]*")
+	sh.AddSubprocess("echos", "[args]*")
+	sh.AddFunction("echof", Echo, "[args]*")
+
+	// start starts command and aborts the test if that fails.
+	start := func(command string, args ...string) Handle {
+		h, err := sh.Start(command, args...)
+		if err != nil {
+			t.Fatalf("Start(%q): unexpected error: %s", command, err)
+		}
+		return h
+	}
+
+	assertNumHandles(t, sh, 0)
+	_, _ = sh.Start("echonotregistered") // won't start.
+	hs := start("echos", "a")
+	hf := start("echof", "b")
+
+	assertNumHandles(t, sh, 2)
+	for i, h := range []Handle{hs, hf} {
+		if got := h.Shutdown(nil); got != nil {
+			t.Errorf("%d: got %v, want nil", i, got)
+		}
+	}
+	assertNumHandles(t, sh, 0)
+	start("echos", "a", "b")
+	start("echof", "c")
+	assertNumHandles(t, sh, 2)
+	sh.Cleanup(nil)
+	assertNumHandles(t, sh, 0)
+}
diff --git a/lib/modules/modules_test.go b/lib/modules/modules_test.go
index 7414358..808a413 100644
--- a/lib/modules/modules_test.go
+++ b/lib/modules/modules_test.go
@@ -13,6 +13,7 @@

 func init() {
 	modules.RegisterChild("envtest", PrintEnv)
+	modules.RegisterChild("errortest", ErrorMain)
 }

 func PrintEnv(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
@@ -28,6 +29,10 @@
 	return nil
 }

+func ErrorMain(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
+	return fmt.Errorf("an error")
+}
+
 func waitForInput(scanner *bufio.Scanner) bool {
 	ch := make(chan struct{})
 	go func(ch chan<- struct{}) {
@@ -58,7 +63,7 @@
 	if got, want := scanner.Text(), key+"="+val; got != want {
 		t.Errorf("got %q, want %q", got, want)
 	}
-	h.Stdin().Close()
+	h.CloseStdin()
 	if !waitForInput(scanner) {
 		t.Errorf("timeout")
 		return
@@ -66,6 +71,9 @@
 	if got, want := scanner.Text(), "done"; got != want {
 		t.Errorf("got %q, want %q", got, want)
 	}
+	if err := h.Shutdown(nil); err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
 }

 func TestChild(t *testing.T) {
@@ -84,6 +92,30 @@
 	testCommand(t, sh, "envtest", key, val)
 }

+func TestErrorChild(t *testing.T) {
+	sh := modules.NewShell()
+	sh.AddSubprocess("errortest", "")
+	h, err := sh.Start("errortest")
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	if got, want := h.Shutdown(nil), "exit status 1"; got == nil || got.Error() != want {
+		t.Errorf("got %v, want %q", got, want)
+	}
+}
+
+func TestErrorFunc(t *testing.T) {
+	sh := modules.NewShell()
+	sh.AddFunction("errortest", ErrorMain, "")
+	h, err := sh.Start("errortest")
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	if got, want := h.Shutdown(nil), "an error"; got == nil || got.Error() != want {
+		t.Errorf("got %v, want %q", got, want)
+	}
+}
+
 func TestHelperProcess(t *testing.T) {
 	if !modules.IsTestHelperProcess() {
 		return
diff --git a/lib/modules/shell.go b/lib/modules/shell.go
index 3105b4b..d6ee1bc 100644
--- a/lib/modules/shell.go
+++ b/lib/modules/shell.go
@@ -27,7 +27,8 @@
 // should wait for their stdin stream to be closed before exiting. The
 // caller can then coordinate with any command by writing to that stdin
 // stream and reading responses from the stdout stream, and it can close
-// stdin when it's ready for the command to exit.
+// stdin, using the CloseStdin method on the command's Handle, when it's
+// ready for the command to exit.
 //
 // The signature of the function that implements the command is the
 // same for both types of command and is defined by the Main function type.
@@ -122,9 +123,8 @@

 // Start starts the specified command, it returns a Handle which can be used
 // for interacting with that command. The Shell tracks all of the Handles
-// that it creates so that it can shut them down when asked to. If any
-// application calls Shutdown on a handle directly, it must call the Forget
-// method on the Shell instance hosting that Handle to avoid storage leaks.
+// that it creates so that it can shut them down when asked to. Calling
+// Shutdown on a Handle directly also stops the Shell from tracking it.
 func (sh *Shell) Start(command string, args ...string) (Handle, error) {
 	sh.mu.Lock()
 	cmd := sh.cmds[command]
@@ -144,8 +144,8 @@
 	return h, nil
 }

-// Forget tells the Shell to stop tracking the supplied Handle.
-func (sh *Shell) Forget(h Handle) {
+// forget tells the Shell to stop tracking the supplied Handle.
+func (sh *Shell) forget(h Handle) {
 	sh.mu.Lock()
 	delete(sh.handles, h)
 	sh.mu.Unlock()
@@ -200,11 +200,15 @@
 // then any such output is lost.
 func (sh *Shell) Cleanup(output io.Writer) {
 	sh.mu.Lock()
-	defer sh.mu.Unlock()
-	for k, _ := range sh.handles {
-		k.Shutdown(output)
-	}
+	// Take ownership of the tracked handles and release sh.mu before
+	// shutting them down, since each Shutdown calls back into the Shell
+	// (via forget) to stop tracking itself.
+	handles := sh.handles
 	sh.handles = make(map[Handle]struct{})
+	sh.mu.Unlock()
+	for h := range handles {
+		h.Shutdown(output)
+	}
 }

 // Handle represents a running command.
@@ -220,15 +224,18 @@
 	// convention is for commands to wait for stdin to be closed before
-	// they exit, thus the caller should close stdin when it wants the
+	// they exit, thus the caller should call CloseStdin when it wants the
 	// command to exit cleanly.
-	Stdin() io.WriteCloser
+	Stdin() io.Writer

-	// Shutdown closes the Stdin for the command. It is primarily intended
-	// for being called by the Shell, if other application code calls it
-	// then it should use the Shell's Forget method to have the Shell stop
-	// tracking the handle. Any buffered stderr output from the command will
-	// be written to the supplied io.Writer. If the io.Writer is nil then
-	// any such output is lost.
-	Shutdown(io.Writer)
+	// CloseStdin closes the command's stdin in a manner that avoids a
+	// data race with any concurrent use of it.
+	CloseStdin()
+
+	// Shutdown closes the command's stdin, waits for the command to
+	// finish and writes any stderr output it produced to the supplied
+	// io.Writer; if the io.Writer is nil then any such output is lost.
+	// It also stops the Shell tracking this Handle and returns any
+	// error returned by the command.
+	Shutdown(io.Writer) error
 }

 // command is used to abstract the implementations of inprocess and subprocess