Merge "veyron/lib/modules: allow Shutdown to return an error status and avoid races."
diff --git a/lib/modules/core/core_test.go b/lib/modules/core/core_test.go
index 8a6614b..fc104a9 100644
--- a/lib/modules/core/core_test.go
+++ b/lib/modules/core/core_test.go
@@ -40,7 +40,7 @@
s := expect.NewSession(t, root.Stdout(), time.Second)
s.ExpectVar("MT_NAME")
s.ExpectVar("PID")
- root.Stdin().Close()
+ root.CloseStdin()
s.Expect("done")
}
diff --git a/lib/modules/exec.go b/lib/modules/exec.go
index 279e6dc..5da6499 100644
--- a/lib/modules/exec.go
+++ b/lib/modules/exec.go
@@ -24,6 +24,7 @@
cmd *exec.Cmd
entryPoint string
handle *vexec.ParentHandle
+ sh *Shell
stderr *os.File
stdout *bufio.Reader
stdin io.WriteCloser
@@ -79,12 +80,18 @@
return eh.stderr
}
-func (eh *execHandle) Stdin() io.WriteCloser {
+// Stdin returns the writer connected to the subprocess's stdin. Callers
+// should use CloseStdin rather than closing the subprocess's stdin directly.
+func (eh *execHandle) Stdin() io.Writer {
eh.mu.Lock()
defer eh.mu.Unlock()
return eh.stdin
}
+// CloseStdin closes the subprocess's stdin, signalling the end of its input.
+func (eh *execHandle) CloseStdin() {
+ eh.mu.Lock()
+ defer eh.mu.Unlock()
+ eh.stdin.Close()
+}
+
// mergeOSEnv returns a slice contained the merged set of environment
// variables from the OS environment and those in this Shell, preferring
// values in the Shell environment over those found in the OS environment.
@@ -126,6 +133,7 @@
func (eh *execHandle) start(sh *Shell, args ...string) (Handle, error) {
eh.mu.Lock()
defer eh.mu.Unlock()
+ eh.sh = sh
newargs := append(testFlags(), args...)
cmd := exec.Command(os.Args[0], newargs...)
cmd.Env = append(sh.mergeOSEnvSlice(), eh.entryPoint)
@@ -156,26 +164,28 @@
return eh, err
}
-func (eh *execHandle) Shutdown(output io.Writer) {
+// Shutdown closes the subprocess's stdin, waits for it to exit and then, if
+// output is non-nil, copies the stderr it captured to output, one line at a
+// time. The stderr file is closed and removed, the handle is forgotten by
+// its Shell, and the error from waiting on the subprocess is returned.
+func (eh *execHandle) Shutdown(output io.Writer) error {
eh.mu.Lock()
defer eh.mu.Unlock()
eh.stdin.Close()
+ defer eh.sh.forget(eh)
+ // Wait for the subprocess to exit before reading its stderr file;
+ // reading first could miss output it writes while it is shutting down.
+ waitErr := eh.cmd.Wait()
if eh.stderr != nil {
defer func() {
eh.stderr.Close()
os.Remove(eh.stderr.Name())
}()
if output == nil {
- return
+ return waitErr
}
if _, err := eh.stderr.Seek(0, 0); err != nil {
- return
+ return waitErr
}
scanner := bufio.NewScanner(eh.stderr)
for scanner.Scan() {
fmt.Fprintf(output, "%s\n", scanner.Text())
}
}
+ return waitErr
}
const shellEntryPoint = "VEYRON_SHELL_HELPER_PROCESS_ENTRY_POINT"
diff --git a/lib/modules/func.go b/lib/modules/func.go
index e5fd73b..5a7c149 100644
--- a/lib/modules/func.go
+++ b/lib/modules/func.go
@@ -17,6 +17,9 @@
main Main
stdin, stderr, stdout pipe
bufferedStdout *bufio.Reader
+ err error
+ sh *Shell
+ wg sync.WaitGroup
}
func newFunctionHandle(main Main) command {
@@ -35,15 +38,23 @@
return fh.stderr.r
}
-func (fh *functionHandle) Stdin() io.WriteCloser {
+// Stdin returns the write end of the pipe whose read end is passed to the
+// function as its stdin. Callers should use CloseStdin to close it.
+func (fh *functionHandle) Stdin() io.Writer {
fh.mu.Lock()
defer fh.mu.Unlock()
return fh.stdin.w
}
+// CloseStdin closes the write end of the function's stdin pipe so that the
+// function sees EOF on its stdin.
+//
+// The *os.File itself is closed, rather than calling syscall.Close on its
+// descriptor, so that the file is marked closed: a later close of the same
+// file via its Fd() then uses an invalid descriptor instead of closing an
+// unrelated file that has since been given the same descriptor number.
+func (fh *functionHandle) CloseStdin() {
+ fh.mu.Lock()
+ w := fh.stdin.w
+ fh.mu.Unlock()
+ w.Close()
+}
+
func (fh *functionHandle) start(sh *Shell, args ...string) (Handle, error) {
fh.mu.Lock()
defer fh.mu.Unlock()
+ fh.sh = sh
for _, p := range []*pipe{&fh.stdin, &fh.stdout, &fh.stderr} {
var err error
if p.r, p.w, err = os.Pipe(); err != nil {
@@ -51,30 +62,43 @@
}
}
fh.bufferedStdout = bufio.NewReader(fh.stdout.r)
+ fh.wg.Add(1)
+
go func() {
err := fh.main(fh.stdin.r, fh.stdout.w, fh.stderr.w, sh.mergeOSEnv(), args...)
if err != nil {
fmt.Fprintf(fh.stderr.w, "%s\n", err)
}
- // See the comment below in Shutdown.
+ // The function has returned: close its ends of the pipes so that
+ // readers of its stdout and stderr (including Shutdown, which drains
+ // stderr) see EOF. The close system call is used because another
+ // goroutine may still have a read outstanding on these files, which
+ // would otherwise trigger a failure with go test -race.
+ // fh.mu is deliberately not held here: the pipe fields are set before
+ // this goroutine starts, and Shutdown holds fh.mu while it waits for
+ // the EOF that closing stderr.w delivers.
syscall.Close(int(fh.stdin.r.Fd()))
syscall.Close(int(fh.stdout.w.Fd()))
syscall.Close(int(fh.stderr.w.Fd()))
+ fh.mu.Lock()
+ fh.err = err
+ fh.mu.Unlock()
+ fh.wg.Done()
}()
return fh, nil
}
-func (fh *functionHandle) Shutdown(output io.Writer) {
+// Shutdown closes the function's stdin, copies its stderr to output (when
+// output is non-nil), waits for the function to return, closes the parent's
+// ends of its stdout and stderr pipes, stops the Shell tracking this handle
+// and returns the error returned by the function.
+func (fh *functionHandle) Shutdown(output io.Writer) error {
fh.mu.Lock()
- defer fh.mu.Unlock()
- scanner := bufio.NewScanner(fh.stderr.r)
- for scanner.Scan() {
- fmt.Fprintf(output, "%s\n", scanner.Text())
- }
- // We close these files using the Close system call since there
- // may be an oustanding read on them that would otherwise trigger
- // a test failure with go test -race
syscall.Close(int(fh.stdin.w.Fd()))
- syscall.Close(int(fh.stdout.r.Fd()))
- syscall.Close(int(fh.stderr.r.Fd()))
+ if output != nil {
+ scanner := bufio.NewScanner(fh.stderr.r)
+ for scanner.Scan() {
+ fmt.Fprintf(output, "%s\n", scanner.Text())
+ }
+ }
+ fh.mu.Unlock()
+
+ fh.wg.Wait()
+
+ fh.mu.Lock()
+ err := fh.err
+ // The function has returned, so release the parent's ends of its stdout
+ // and stderr pipes. As in start, the close system call is used because
+ // a caller may still have a read outstanding on stdout.
+ syscall.Close(int(fh.stdout.r.Fd()))
+ syscall.Close(int(fh.stderr.r.Fd()))
+ fh.sh.forget(fh)
+ fh.mu.Unlock()
+ return err
}
diff --git a/lib/modules/modules_internal_test.go b/lib/modules/modules_internal_test.go
new file mode 100644
index 0000000..349869b
--- /dev/null
+++ b/lib/modules/modules_internal_test.go
@@ -0,0 +1,55 @@
+package modules
+
+import (
+ "fmt"
+ "io"
+ "path/filepath"
+ "runtime"
+ "testing"
+)
+
+// Register Echo so that it can also be run as a subprocess via
+// AddSubprocess("echos", ...).
+func init() {
+ RegisterChild("echos", Echo)
+}
+
+// Echo writes each of its arguments, one per line, to stdout. It returns an
+// error if no arguments are given.
+func Echo(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
+ if len(args) == 0 {
+ return fmt.Errorf("no args")
+ }
+ for _, a := range args {
+ // Write to the supplied stdout, not os.Stdout: when run as a function
+ // command, stdout is the handle's pipe rather than the process's stdout.
+ fmt.Fprintln(stdout, a)
+ }
+ return nil
+}
+
+// assertNumHandles reports a test error, attributed to the calling line, if
+// sh is not tracking exactly n handles.
+func assertNumHandles(t *testing.T, sh *Shell, n int) {
+ sh.mu.Lock()
+ got := len(sh.handles)
+ sh.mu.Unlock()
+ if want := n; got != want {
+ // Skip one frame (this function) so the location is the caller's.
+ _, file, line, _ := runtime.Caller(1)
+ t.Errorf("%s:%d: got %d, want %d", filepath.Base(file), line, got, want)
+ }
+}
+
+// TestState checks that the Shell tracks started handles and forgets them
+// when they are shut down individually or via Cleanup.
+func TestState(t *testing.T) {
+ sh := NewShell()
+ sh.AddSubprocess("echonotregistered", "[args]*")
+ sh.AddSubprocess("echos", "[args]*")
+ sh.AddFunction("echof", Echo, "[args]*")
+
+ assertNumHandles(t, sh, 0)
+ _, _ = sh.Start("echonotregistered") // won't start.
+ hs, err := sh.Start("echos", "a")
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ hf, err := sh.Start("echof", "b")
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+
+ assertNumHandles(t, sh, 2)
+ for i, h := range []Handle{hs, hf} {
+ if got := h.Shutdown(nil); got != nil {
+ t.Errorf("%d: got %v, want nil", i, got)
+ }
+ }
+ assertNumHandles(t, sh, 0)
+ if _, err := sh.Start("echos", "a", "b"); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if _, err := sh.Start("echof", "c"); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ assertNumHandles(t, sh, 2)
+ sh.Cleanup(nil)
+ assertNumHandles(t, sh, 0)
+}
diff --git a/lib/modules/modules_test.go b/lib/modules/modules_test.go
index 7414358..808a413 100644
--- a/lib/modules/modules_test.go
+++ b/lib/modules/modules_test.go
@@ -13,6 +13,7 @@
+// Register child entry points so that subprocess commands (AddSubprocess)
+// can run them by name.
func init() {
modules.RegisterChild("envtest", PrintEnv)
+ modules.RegisterChild("errortest", ErrorMain)
}
func PrintEnv(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
@@ -28,6 +29,10 @@
return nil
}
+// ErrorMain is a command that always fails with the same error, ignoring its
+// input, environment and arguments.
+func ErrorMain(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
+ failure := fmt.Errorf("an error")
+ return failure
+}
+
func waitForInput(scanner *bufio.Scanner) bool {
ch := make(chan struct{})
go func(ch chan<- struct{}) {
@@ -58,7 +63,7 @@
if got, want := scanner.Text(), key+"="+val; got != want {
t.Errorf("got %q, want %q", got, want)
}
- h.Stdin().Close()
+ h.CloseStdin()
if !waitForInput(scanner) {
t.Errorf("timeout")
return
@@ -66,6 +71,9 @@
if got, want := scanner.Text(), "done"; got != want {
t.Errorf("got %q, want %q", got, want)
}
+ if err := h.Shutdown(nil); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
}
func TestChild(t *testing.T) {
@@ -84,6 +92,30 @@
testCommand(t, sh, "envtest", key, val)
}
+// TestErrorChild checks that Shutdown reports a failing subprocess's exit status.
+func TestErrorChild(t *testing.T) {
+ sh := modules.NewShell()
+ sh.AddSubprocess("errortest", "")
+ h, err := sh.Start("errortest")
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ want := "exit status 1"
+ got := h.Shutdown(nil)
+ if got != nil && got.Error() == want {
+ return
+ }
+ t.Errorf("got %q, want %q", got, want)
+}
+
+// TestErrorFunc checks that Shutdown returns the error returned by a
+// function command.
+func TestErrorFunc(t *testing.T) {
+ sh := modules.NewShell()
+ sh.AddFunction("errortest", ErrorMain, "")
+ h, err := sh.Start("errortest")
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ // A nil error must also fail the test, as it does in TestErrorChild.
+ if got, want := h.Shutdown(nil), "an error"; got == nil || got.Error() != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+}
+
func TestHelperProcess(t *testing.T) {
if !modules.IsTestHelperProcess() {
return
diff --git a/lib/modules/shell.go b/lib/modules/shell.go
index 3105b4b..d6ee1bc 100644
--- a/lib/modules/shell.go
+++ b/lib/modules/shell.go
@@ -27,7 +27,8 @@
// should wait for their stdin stream to be closed before exiting. The
// caller can then coordinate with any command by writing to that stdin
// stream and reading responses from the stdout stream, and it can close
-// stdin when it's ready for the command to exit.
+// stdin, using the CloseStdin method on the command's handle, when it is
+// ready for the command to exit.
//
// The signature of the function that implements the command is the
// same for both types of command and is defined by the Main function type.
@@ -122,9 +123,7 @@
// Start starts the specified command, it returns a Handle which can be used
// for interacting with that command. The Shell tracks all of the Handles
-// that it creates so that it can shut them down when asked to. If any
-// application calls Shutdown on a handle directly, it must call the Forget
-// method on the Shell instance hosting that Handle to avoid storage leaks.
+// that it creates so that it can shut them down when asked to.
func (sh *Shell) Start(command string, args ...string) (Handle, error) {
sh.mu.Lock()
cmd := sh.cmds[command]
@@ -144,8 +143,8 @@
return h, nil
}
-// Forget tells the Shell to stop tracking the supplied Handle.
-func (sh *Shell) Forget(h Handle) {
+// forget tells the Shell to stop tracking the supplied Handle.
+func (sh *Shell) forget(h Handle) {
sh.mu.Lock()
delete(sh.handles, h)
sh.mu.Unlock()
@@ -200,11 +199,15 @@
// then any such output is lost.
func (sh *Shell) Cleanup(output io.Writer) {
sh.mu.Lock()
- defer sh.mu.Unlock()
- for k, _ := range sh.handles {
- k.Shutdown(output)
- }
+ // Take ownership of the tracked handles and release sh.mu before shutting
+ // them down: each handle's Shutdown calls back into the Shell to forget
+ // itself, which needs sh.mu.
+ handles := sh.handles
sh.handles = make(map[Handle]struct{})
+ sh.mu.Unlock()
+ for h := range handles {
+ h.Shutdown(output)
+ }
}
// Handle represents a running command.
@@ -220,15 +223,17 @@
// convention is for commands to wait for stdin to be closed before
// they exit, thus the caller should close stdin when it wants the
// command to exit cleanly.
- Stdin() io.WriteCloser
+ Stdin() io.Writer
- // Shutdown closes the Stdin for the command. It is primarily intended
- // for being called by the Shell, if other application code calls it
- // then it should use the Shell's Forget method to have the Shell stop
- // tracking the handle. Any buffered stderr output from the command will
- // be written to the supplied io.Writer. If the io.Writer is nil then
- // any such output is lost.
- Shutdown(io.Writer)
+ // CloseStdin closes the command's stdin in a manner that avoids a
+ // data race with any concurrent readers of it.
+ CloseStdin()
+
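+ // Shutdown closes the command's stdin, waits for the command to exit,
+ // and writes any stderr output captured from the command to the
+ // supplied io.Writer; if the io.Writer is nil that output is discarded.
+ // It returns the command's error: the error returned by its Main
+ // function when run in-process, or the exit-status error of a subprocess.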
+ Shutdown(io.Writer) error
}
// command is used to abstract the implementations of inprocess and subprocess