gosh: make "await" methods return immediately on process exit
Change-Id: If2ae3b119e2d0abdceac89ad9a225c911988e2c2
diff --git a/gosh/cmd.go b/gosh/cmd.go
index 8678d24..38ca2a0 100644
--- a/gosh/cmd.go
+++ b/gosh/cmd.go
@@ -22,6 +22,7 @@
errAlreadyCalledStart = errors.New("gosh: already called Cmd.Start")
errAlreadyCalledWait = errors.New("gosh: already called Cmd.Wait")
errDidNotCallStart = errors.New("gosh: did not call Cmd.Start")
+ errProcessExited = errors.New("gosh: process exited")
)
// Cmd represents a command. Not thread-safe.
@@ -50,17 +51,15 @@
stdinWriteCloser io.WriteCloser // from exec.Cmd.StdinPipe
calledStart bool
calledWait bool
+ cond *sync.Cond
waitChan chan error
started bool // protected by sh.cleanupMu
- exitedMu sync.Mutex
- exited bool // protected by exitedMu
+ exited bool // protected by cond.L
stdoutWriters []io.Writer
stderrWriters []io.Writer
closers []io.Closer
- condReady *sync.Cond
- recvReady bool // protected by condReady.L
- condVars *sync.Cond
- recvVars map[string]string // protected by condVars.L
+ recvReady bool // protected by cond.L
+ recvVars map[string]string // protected by cond.L
}
// Clone returns a new Cmd with a copy of this Cmd's configuration.
@@ -72,9 +71,11 @@
}
// StdinPipe returns a thread-safe WriteCloser backed by a buffered pipe for the
-// command's stdin. The returned WriteCloser will be closed when the process
-// exits. Must be called before Start. It is safe to call StdinPipe multiple
-// times; calls after the first return the pipe created by the first call.
+// command's stdin. The returned pipe will be closed when the process exits, but
+// may also be closed earlier by the caller, e.g. if the command does not exit
+// until its stdin is closed. Must be called before Start. It is safe to call
+// StdinPipe multiple times; calls after the first return the pipe created by
+// the first call.
func (c *Cmd) StdinPipe() io.WriteCloser {
c.sh.Ok()
res, err := c.stdinPipe()
@@ -200,15 +201,14 @@
func newCmdInternal(sh *Shell, vars map[string]string, path string, args []string) (*Cmd, error) {
c := &Cmd{
- Path: path,
- Vars: vars,
- Args: args,
- sh: sh,
- c: &exec.Cmd{},
- waitChan: make(chan error, 1),
- condReady: sync.NewCond(&sync.Mutex{}),
- condVars: sync.NewCond(&sync.Mutex{}),
- recvVars: map[string]string{},
+ Path: path,
+ Vars: vars,
+ Args: args,
+ sh: sh,
+ c: &exec.Cmd{},
+ cond: sync.NewCond(&sync.Mutex{}),
+ waitChan: make(chan error, 1),
+ recvVars: map[string]string{},
}
// Protect against concurrent signal-triggered Shell.cleanup().
sh.cleanupMu.Lock()
@@ -271,8 +271,8 @@
if !c.started {
return false
}
- c.exitedMu.Lock()
- defer c.exitedMu.Unlock()
+ c.cond.L.Lock()
+ defer c.cond.L.Unlock()
return !c.exited
}
@@ -294,15 +294,15 @@
}
switch m.Type {
case typeReady:
- w.c.condReady.L.Lock()
+ w.c.cond.L.Lock()
w.c.recvReady = true
- w.c.condReady.Signal()
- w.c.condReady.L.Unlock()
+ w.c.cond.Signal()
+ w.c.cond.L.Unlock()
case typeVars:
- w.c.condVars.L.Lock()
+ w.c.cond.L.Lock()
w.c.recvVars = mergeMaps(w.c.recvVars, m.Vars)
- w.c.condVars.Signal()
- w.c.condVars.L.Unlock()
+ w.c.cond.Signal()
+ w.c.cond.L.Unlock()
default:
return 0, fmt.Errorf("unknown message type: %q", m.Type)
}
@@ -448,13 +448,16 @@
return err
}
// Start the command.
- err = c.c.Start()
- if err != nil {
- c.exitedMu.Lock()
+ onExit := func(err error) {
+ c.cond.L.Lock()
c.exited = true
- c.exitedMu.Unlock()
+ c.cond.Signal()
+ c.cond.L.Unlock()
c.closeClosers()
- c.waitChan <- errors.New("gosh: start failed")
+ c.waitChan <- err
+ }
+ if err = c.c.Start(); err != nil {
+ onExit(errors.New("gosh: start failed"))
return err
}
c.started = true
@@ -463,18 +466,12 @@
// ensures that the child process is reaped once it exits. Note, gosh.Cmd.wait
// blocks on waitChan.
go func() {
- err := c.c.Wait()
- c.exitedMu.Lock()
- c.exited = true
- c.exitedMu.Unlock()
- c.closeClosers()
- c.waitChan <- err
+ onExit(c.c.Wait())
}()
return nil
}
-// TODO(sadovsky): Make it so Cmd.{awaitReady,awaitVars} return an error if/when
-// we detect that the process has exited. Also, maybe add optional timeouts for
+// TODO(sadovsky): Maybe add optional timeouts for
// Cmd.{awaitReady,awaitVars,wait}.
func (c *Cmd) awaitReady() error {
@@ -483,12 +480,15 @@
} else if c.calledWait {
return errAlreadyCalledWait
}
- // http://golang.org/pkg/sync/#Cond.Wait
- c.condReady.L.Lock()
- for !c.recvReady {
- c.condReady.Wait()
+ c.cond.L.Lock()
+ defer c.cond.L.Unlock()
+ for !c.exited && !c.recvReady {
+ c.cond.Wait()
}
- c.condReady.L.Unlock()
+ // Return nil error if both conditions triggered simultaneously.
+ if !c.recvReady {
+ return errProcessExited
+ }
return nil
}
@@ -510,14 +510,17 @@
}
}
}
- // http://golang.org/pkg/sync/#Cond.Wait
- c.condVars.L.Lock()
+ c.cond.L.Lock()
+ defer c.cond.L.Unlock()
updateRes()
- for len(res) < len(wantKeys) {
- c.condVars.Wait()
+ for !c.exited && len(res) < len(wantKeys) {
+ c.cond.Wait()
updateRes()
}
- c.condVars.L.Unlock()
+ // Return nil error if both conditions triggered simultaneously.
+ if len(res) < len(wantKeys) {
+ return nil, errProcessExited
+ }
return res, nil
}
diff --git a/gosh/internal/gosh_example/main.go b/gosh/internal/gosh_example/main.go
index 67fa15c..bde8f86 100644
--- a/gosh/internal/gosh_example/main.go
+++ b/gosh/internal/gosh_example/main.go
@@ -30,8 +30,8 @@
}
var (
- getFn = gosh.Register("get", lib.Get)
- serveFn = gosh.Register("serve", lib.Serve)
+ getFn = gosh.Register("getFn", lib.Get)
+ serveFn = gosh.Register("serveFn", lib.Serve)
)
func ExampleFns() {
diff --git a/gosh/shell_test.go b/gosh/shell_test.go
index 8c25904..fc78858 100644
--- a/gosh/shell_test.go
+++ b/gosh/shell_test.go
@@ -82,30 +82,30 @@
// Simplified versions of various Unix commands.
var (
- catFn = gosh.Register("cat", func() {
+ catFn = gosh.Register("catFn", func() {
io.Copy(os.Stdout, os.Stdin)
})
- echoFn = gosh.Register("echo", func() {
+ echoFn = gosh.Register("echoFn", func() {
fmt.Println(os.Args[1])
})
- readFn = gosh.Register("read", func() {
+ readFn = gosh.Register("readFn", func() {
bufio.NewReader(os.Stdin).ReadString('\n')
})
)
// Functions with parameters.
var (
- exitFn = gosh.Register("exit", func(code int) {
+ exitFn = gosh.Register("exitFn", func(code int) {
os.Exit(code)
})
- sleepFn = gosh.Register("sleep", func(d time.Duration, code int) {
+ sleepFn = gosh.Register("sleepFn", func(d time.Duration, code int) {
time.Sleep(d)
os.Exit(code)
})
- printFn = gosh.Register("print", func(v ...interface{}) {
+ printFn = gosh.Register("printFn", func(v ...interface{}) {
fmt.Print(v...)
})
- printfFn = gosh.Register("printf", func(format string, v ...interface{}) {
+ printfFn = gosh.Register("printfFn", func(format string, v ...interface{}) {
fmt.Printf(format, v...)
})
)
@@ -200,8 +200,8 @@
}
var (
- getFn = gosh.Register("get", lib.Get)
- serveFn = gosh.Register("serve", lib.Serve)
+ getFn = gosh.Register("getFn", lib.Get)
+ serveFn = gosh.Register("serveFn", lib.Serve)
)
func TestFns(t *testing.T) {
@@ -230,14 +230,14 @@
// Functions designed for TestRegistry.
var (
- printIntsFn = gosh.Register("printInts", func(v ...int) {
+ printIntsFn = gosh.Register("printIntsFn", func(v ...int) {
var vi []interface{}
for _, x := range v {
vi = append(vi, x)
}
fmt.Print(vi...)
})
- printfIntsFn = gosh.Register("printfInts", func(format string, v ...int) {
+ printfIntsFn = gosh.Register("printfIntsFn", func(format string, v ...int) {
var vi []interface{}
for _, x := range v {
vi = append(vi, x)
@@ -246,6 +246,28 @@
})
)
+// Tests that Await{Ready,Vars} return immediately when the process exits.
+func TestAwaitProcessExit(t *testing.T) {
+ sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
+ defer sh.Cleanup()
+
+ c := sh.Fn(exitFn, 0)
+ c.Start()
+ sh.Opts.Fatalf = nil
+ c.AwaitReady()
+ nok(t, sh.Err)
+ sh.Err = nil
+ sh.Opts.Fatalf = makeFatalf(t)
+
+ c = sh.Fn(exitFn, 0)
+ c.Start()
+ sh.Opts.Fatalf = nil
+ c.AwaitVars("foo")
+ nok(t, sh.Err)
+ sh.Err = nil
+ sh.Opts.Fatalf = makeFatalf(t)
+}
+
// Tests function signature-checking and execution.
func TestRegistry(t *testing.T) {
sh := gosh.NewShell(gosh.Opts{Fatalf: makeFatalf(t), Logf: t.Logf})
@@ -351,7 +373,7 @@
nok(t, sh.Err)
}
-var writeFn = gosh.Register("write", func(stdout, stderr bool) error {
+var writeFn = gosh.Register("writeFn", func(stdout, stderr bool) error {
if stdout {
if _, err := os.Stdout.Write([]byte("A")); err != nil {
return err
@@ -407,7 +429,7 @@
eq(t, toString(t, stderrPipe), "BB")
}
-var writeMoreFn = gosh.Register("writeMore", func() {
+var writeMoreFn = gosh.Register("writeMoreFn", func() {
sh := gosh.NewShell(gosh.Opts{})
defer sh.Cleanup()