veyron/examples/tunnel: Fix race in IO

When the shell command exits, stdout and stderr are closed, which means
that both stdout2outchan() and stderr2outchan() will report an error.
The problem happens when there is still data to be read from one file
descriptor when the error is detected on the other, which would drop the
unread data.

This change makes sure that all the queues are properly drained when an
error is reported, and adds new tests for common patterns.

Change-Id: Iaa636e768a6c890155cbdc9ed44e3a925cbe676c
diff --git a/examples/tunnel/tunnel.vdl b/examples/tunnel/tunnel.vdl
index 88eb2bf..3098e0e 100644
--- a/examples/tunnel/tunnel.vdl
+++ b/examples/tunnel/tunnel.vdl
@@ -20,21 +20,23 @@
 }
 
 type ShellOpts struct {
-	UsePty       bool      // Whether to open a pseudo-terminal
-	Environment  []string  // Environment variables to pass to the remote shell.
-	Rows, Cols   uint32    // Window size.
+  UsePty      bool      // Whether to open a pseudo-terminal
+  Environment []string  // Environment variables to pass to the remote shell.
+  Rows, Cols  uint32    // Window size.
 }
 
 type ClientShellPacket struct {
-        // Bytes going to the shell's stdin.
-        Stdin       []byte
-        // A dynamic update of the window size. The default value of 0 means no-change.
-        Rows, Cols  uint32
+  // Bytes going to the shell's stdin.
+  Stdin      []byte
+  // Indicates that stdin should be closed.
+  EOF        bool
+  // A dynamic update of the window size. The default value of 0 means no-change.
+  Rows, Cols uint32
 }
 
 type ServerShellPacket struct {
-        // Bytes coming from the shell's stdout.
-        Stdout      []byte
-        // Bytes coming from the shell's stderr.
-        Stderr      []byte
+  // Bytes coming from the shell's stdout.
+  Stdout []byte
+  // Bytes coming from the shell's stderr.
+  Stderr []byte
 }
diff --git a/examples/tunnel/tunnel.vdl.go b/examples/tunnel/tunnel.vdl.go
index 2d7f6ff..b8b33c6 100644
--- a/examples/tunnel/tunnel.vdl.go
+++ b/examples/tunnel/tunnel.vdl.go
@@ -26,6 +26,8 @@
 type ClientShellPacket struct {
 	// Bytes going to the shell's stdin.
 	Stdin []byte
+	// Indicates that stdin should be closed.
+	EOF bool
 	// A dynamic update of the window size. The default value of 0 means no-change.
 	Rows uint32
 	Cols uint32
@@ -674,6 +676,7 @@
 		_gen_wiretype.StructType{
 			[]_gen_wiretype.FieldType{
 				_gen_wiretype.FieldType{Type: 0x43, Name: "Stdin"},
+				_gen_wiretype.FieldType{Type: 0x2, Name: "EOF"},
 				_gen_wiretype.FieldType{Type: 0x34, Name: "Rows"},
 				_gen_wiretype.FieldType{Type: 0x34, Name: "Cols"},
 			},
diff --git a/examples/tunnel/tunneld/impl/iomanager.go b/examples/tunnel/tunneld/impl/iomanager.go
index 265cc3b..a8c6409 100644
--- a/examples/tunnel/tunneld/impl/iomanager.go
+++ b/examples/tunnel/tunneld/impl/iomanager.go
@@ -3,12 +3,13 @@
 import (
 	"fmt"
 	"io"
+	"sync"
 
 	"veyron/examples/tunnel"
 	"veyron2/vlog"
 )
 
-func runIOManager(stdin io.Writer, stdout, stderr io.Reader, ptyFd uintptr, stream tunnel.TunnelServiceShellStream) error {
+func runIOManager(stdin io.WriteCloser, stdout, stderr io.Reader, ptyFd uintptr, stream tunnel.TunnelServiceShellStream) error {
 	m := ioManager{stdin: stdin, stdout: stdout, stderr: stderr, ptyFd: ptyFd, stream: stream}
 	return m.run()
 }
@@ -16,105 +17,149 @@
 // ioManager manages the forwarding of all the data between the shell and the
 // stream.
 type ioManager struct {
-	stdin          io.Writer
+	stdin          io.WriteCloser
 	stdout, stderr io.Reader
 	ptyFd          uintptr
 	stream         tunnel.TunnelServiceShellStream
 
-	// done receives any error from chan2stream, user2stream, or
-	// stream2user.
-	done chan error
-	// outchan is used to serialize the output to the stream. This is
-	// needed because stream.Send is not thread-safe.
-	outchan chan tunnel.ServerShellPacket
-	// closed is closed when run() exits.
-	closed chan struct{}
+	// streamError receives errors coming from stream operations.
+	streamError chan error
+	// stdioError receives errors coming from stdio operations.
+	stdioError chan error
 }
 
 func (m *ioManager) run() error {
-	// done receives any error from chan2stream, stdout2stream, or
-	// stream2stdin.
-	m.done = make(chan error, 3)
-	// outchan is used to serialize the output to the stream.
-	// chan2stream() receives data sent by stdout2outchan() and
-	// stderr2outchan() and sends it to the stream.
-	m.outchan = make(chan tunnel.ServerShellPacket)
-	m.closed = make(chan struct{})
-	defer close(m.closed)
-	go m.chan2stream()
+	m.streamError = make(chan error, 1)
+	m.stdioError = make(chan error, 1)
+
+	var pendingShellOutput sync.WaitGroup
+	pendingShellOutput.Add(1)
+	var pendingStreamInput sync.WaitGroup
+	pendingStreamInput.Add(1)
 
 	// Forward data between the shell's stdio and the stream.
-	go m.stdout2outchan()
-	if m.stderr != nil {
-		go m.stderr2outchan()
-	}
-	go m.stream2stdin()
+	go func() {
+		defer pendingShellOutput.Done()
+		// outchan is used to serialize the output to the stream.
+		// chan2stream() receives data sent by stdout2outchan() and
+		// stderr2outchan() and sends it to the stream.
+		outchan := make(chan tunnel.ServerShellPacket)
+		var wgStream sync.WaitGroup
+		wgStream.Add(1)
+		go m.chan2stream(outchan, &wgStream)
+		var wgStdio sync.WaitGroup
+		wgStdio.Add(1)
+		go m.stdout2outchan(outchan, &wgStdio)
+		if m.stderr != nil {
+			wgStdio.Add(1)
+			go m.stderr2outchan(outchan, &wgStdio)
+		}
+		// When both stdout2outchan and stderr2outchan are done, close
+		// outchan to signal chan2stream to exit.
+		wgStdio.Wait()
+		close(outchan)
+		wgStream.Wait()
+	}()
+	go m.stream2stdin(&pendingStreamInput)
 
 	// Block until something reports an error.
-	return <-m.done
+	//
+	// If there is any stream error, we assume that both ends of the stream
+	// have an error, e.g. if stream.Reader.Advance fails then
+	// stream.Sender.Send will fail. We process any remaining input from
+	// the stream and then return.
+	//
+	// If there is any stdio error, we assume all 3 io channels will fail
+	// (if stdout.Read fails then stdin.Write and stderr.Read will also
+	// fail). We process is remaining output from the shell and then
+	// return.
+	select {
+	case err := <-m.streamError:
+		// Process remaining input from the stream before exiting.
+		vlog.VI(2).Infof("run stream error: %v", err)
+		pendingStreamInput.Wait()
+		return err
+	case err := <-m.stdioError:
+		// Process remaining output from the shell before exiting.
+		vlog.VI(2).Infof("run stdio error: %v", err)
+		pendingShellOutput.Wait()
+		return err
+	}
+}
+
+func (m *ioManager) sendStreamError(err error) {
+	select {
+	case m.streamError <- err:
+	default:
+	}
+}
+
+func (m *ioManager) sendStdioError(err error) {
+	select {
+	case m.stdioError <- err:
+	default:
+	}
 }
 
 // chan2stream receives ServerShellPacket from outchan and sends it to stream.
-func (m *ioManager) chan2stream() {
+func (m *ioManager) chan2stream(outchan <-chan tunnel.ServerShellPacket, wg *sync.WaitGroup) {
+	defer wg.Done()
 	sender := m.stream.SendStream()
-	for packet := range m.outchan {
+	for packet := range outchan {
+		vlog.VI(3).Infof("chan2stream packet: %+v", packet)
 		if err := sender.Send(packet); err != nil {
-			m.done <- err
-			return
+			vlog.VI(2).Infof("chan2stream: %v", err)
+			m.sendStreamError(err)
 		}
 	}
-	m.done <- io.EOF
-}
-
-func (m *ioManager) sendOnOutchan(p tunnel.ServerShellPacket) bool {
-	select {
-	case m.outchan <- p:
-		return true
-	case <-m.closed:
-		return false
-	}
 }
 
 // stdout2stream reads data from the shell's stdout and sends it to the outchan.
-func (m *ioManager) stdout2outchan() {
+func (m *ioManager) stdout2outchan(outchan chan<- tunnel.ServerShellPacket, wg *sync.WaitGroup) {
+	defer wg.Done()
 	for {
 		buf := make([]byte, 2048)
 		n, err := m.stdout.Read(buf[:])
 		if err != nil {
 			vlog.VI(2).Infof("stdout2outchan: %v", err)
-			m.done <- err
+			m.sendStdioError(err)
 			return
 		}
-		if !m.sendOnOutchan(tunnel.ServerShellPacket{Stdout: buf[:n]}) {
-			return
-		}
+		outchan <- tunnel.ServerShellPacket{Stdout: buf[:n]}
 	}
 }
 
 // stderr2stream reads data from the shell's stderr and sends it to the outchan.
-func (m *ioManager) stderr2outchan() {
+func (m *ioManager) stderr2outchan(outchan chan<- tunnel.ServerShellPacket, wg *sync.WaitGroup) {
+	defer wg.Done()
 	for {
 		buf := make([]byte, 2048)
 		n, err := m.stderr.Read(buf[:])
 		if err != nil {
 			vlog.VI(2).Infof("stderr2outchan: %v", err)
-			m.done <- err
+			m.sendStdioError(err)
 			return
 		}
-		if !m.sendOnOutchan(tunnel.ServerShellPacket{Stderr: buf[:n]}) {
-			return
-		}
+		outchan <- tunnel.ServerShellPacket{Stderr: buf[:n]}
 	}
 }
 
 // stream2stdin reads data from the stream and sends it to the shell's stdin.
-func (m *ioManager) stream2stdin() {
+func (m *ioManager) stream2stdin(wg *sync.WaitGroup) {
+	defer wg.Done()
 	rStream := m.stream.RecvStream()
 	for rStream.Advance() {
 		packet := rStream.Value()
+		vlog.VI(3).Infof("stream2stdin packet: %+v", packet)
 		if len(packet.Stdin) > 0 {
 			if n, err := m.stdin.Write(packet.Stdin); n != len(packet.Stdin) || err != nil {
-				m.done <- fmt.Errorf("stdin.Write returned (%d, %v) want (%d, nil)", n, err, len(packet.Stdin))
+				m.sendStdioError(fmt.Errorf("stdin.Write returned (%d, %v) want (%d, nil)", n, err, len(packet.Stdin)))
+				return
+			}
+		}
+		if packet.EOF {
+			if err := m.stdin.Close(); err != nil {
+				m.sendStdioError(fmt.Errorf("stdin.Close: %v", err))
 				return
 			}
 		}
@@ -129,5 +174,8 @@
 	}
 
 	vlog.VI(2).Infof("stream2stdin: %v", err)
-	m.done <- err
+	m.sendStreamError(err)
+	if err := m.stdin.Close(); err != nil {
+		m.sendStdioError(fmt.Errorf("stdin.Close: %v", err))
+	}
 }
diff --git a/examples/tunnel/tunneld/test.sh b/examples/tunnel/tunneld/test.sh
index 7b225b5..3e36ae0 100755
--- a/examples/tunnel/tunneld/test.sh
+++ b/examples/tunnel/tunneld/test.sh
@@ -101,6 +101,24 @@
     fail "line ${LINENO}: unexpected output. Got ${got}, want ${want}"
   fi
 
+  # Send input to remote command.
+  echo "HELLO SERVER" | ./vsh --logtostderr --v=1 "${ep}" "cat > ${workdir}/hello.txt" > "${vshlog}" 2>&1
+  got=$(cat "${workdir}/hello.txt")
+  want="HELLO SERVER"
+
+  if [[ "${got}" != "${want}" ]]; then
+    dumplogs "${vshlog}" "${tunlog}" "${mtlog}"
+    fail "line ${LINENO}: unexpected output. Got ${got}, want ${want}"
+  fi
+
+  got=$(echo "ECHO" | ./vsh --logtostderr --v=1 "${ep}" cat 2>"${vshlog}")
+  want="ECHO"
+
+  if [[ "${got}" != "${want}" ]]; then
+    dumplogs "${vshlog}" "${tunlog}" "${mtlog}"
+    fail "line ${LINENO}: unexpected output. Got ${got}, want ${want}"
+  fi
+
   # Verify that all the published names are there.
   got=$(./mounttable glob "${NAMESPACE_ROOT}" 'tunnel/*/*' |    \
         sed -e 's/TTL .m..s/TTL XmXXs/'                     \
diff --git a/examples/tunnel/vsh/iomanager.go b/examples/tunnel/vsh/iomanager.go
index 3389809..bc778d6 100644
--- a/examples/tunnel/vsh/iomanager.go
+++ b/examples/tunnel/vsh/iomanager.go
@@ -5,6 +5,7 @@
 	"io"
 	"os"
 	"os/signal"
+	"sync"
 	"syscall"
 
 	"veyron/examples/tunnel"
@@ -24,103 +25,148 @@
 	stdout, stderr io.Writer
 	stream         tunnel.TunnelShellCall
 
-	// done receives any error from chan2stream, user2outchan, or
-	// stream2user.
-	done chan error
-	// outchan is used to serialize the output to the stream. This is
-	// needed because stream.Send is not thread-safe.
-	outchan chan tunnel.ClientShellPacket
-	// closed is closed when run() exits
-	closed chan struct{}
+	// streamError receives errors coming from stream operations.
+	streamError chan error
+	// stdioError receives errors coming from stdio operations.
+	stdioError chan error
 }
 
 func (m *ioManager) run() error {
-	m.done = make(chan error, 3)
-	// outchan is used to serialize the output to the stream.
-	// chan2stream() receives data sent by handleWindowResize() and
-	// user2outchan() and sends it to the stream.
-	m.outchan = make(chan tunnel.ClientShellPacket)
-	m.closed = make(chan struct{})
-	defer close(m.closed)
-	go m.chan2stream()
-	// When the terminal window is resized, we receive a SIGWINCH. Then we
-	// send the new window size to the server.
-	winch := make(chan os.Signal, 1)
-	signal.Notify(winch, syscall.SIGWINCH)
-	defer signal.Stop(winch)
-	go m.handleWindowResize(winch)
+	m.streamError = make(chan error, 1)
+	m.stdioError = make(chan error, 1)
+
+	var pendingUserInput sync.WaitGroup
+	pendingUserInput.Add(1)
+	var pendingStreamOutput sync.WaitGroup
+	pendingStreamOutput.Add(1)
+
 	// Forward data between the user and the remote shell.
-	go m.user2outchan()
-	go m.stream2user()
+	go func() {
+		defer pendingUserInput.Done()
+		// outchan is used to serialize the output to the stream.
+		// chan2stream() receives data sent by handleWindowResize() and
+		// user2outchan() and sends it to the stream.
+		outchan := make(chan tunnel.ClientShellPacket)
+		var wgStream sync.WaitGroup
+		wgStream.Add(1)
+		go m.chan2stream(outchan, &wgStream)
+
+		// When the terminal window is resized, we receive a SIGWINCH. Then we
+		// send the new window size to the server.
+		winch := make(chan os.Signal, 1)
+		signal.Notify(winch, syscall.SIGWINCH)
+
+		var wgUser sync.WaitGroup
+		wgUser.Add(2)
+		go func() {
+			m.user2outchan(outchan, &wgUser)
+			signal.Stop(winch)
+			close(winch)
+		}()
+		go m.handleWindowResize(winch, outchan, &wgUser)
+		// When both user2outchan and handleWindowResize are done,
+		// close outchan to signal chan2stream to exit.
+		wgUser.Wait()
+		close(outchan)
+		wgStream.Wait()
+	}()
+	go m.stream2user(&pendingStreamOutput)
 	// Block until something reports an error.
-	return <-m.done
+	select {
+	case err := <-m.streamError:
+		// When we receive an error from the stream, wait for any
+		// remaining stream output to be sent to the user before
+		// exiting.
+		vlog.VI(2).Infof("run stream error: %v", err)
+		pendingStreamOutput.Wait()
+		return err
+	case err := <-m.stdioError:
+		// When we receive an error from the user, wait for any
+		// remaining input from the user to be sent to the stream
+		// before exiting.
+		vlog.VI(2).Infof("run stdio error: %v", err)
+		pendingUserInput.Wait()
+		return err
+	}
+}
+
+func (m *ioManager) sendStreamError(err error) {
+	select {
+	case m.streamError <- err:
+	default:
+	}
+}
+
+func (m *ioManager) sendStdioError(err error) {
+	select {
+	case m.stdioError <- err:
+	default:
+	}
 }
 
 // chan2stream receives ClientShellPacket from outchan and sends it to stream.
-func (m *ioManager) chan2stream() {
+func (m *ioManager) chan2stream(outchan <-chan tunnel.ClientShellPacket, wg *sync.WaitGroup) {
+	defer wg.Done()
 	sender := m.stream.SendStream()
-	for packet := range m.outchan {
+	for packet := range outchan {
+		vlog.VI(3).Infof("chan2stream packet: %+v", packet)
 		if err := sender.Send(packet); err != nil {
-			m.done <- err
-			return
+			vlog.VI(2).Infof("chan2stream: %v", err)
+			m.sendStreamError(err)
 		}
 	}
-	m.done <- io.EOF
+	m.sendStreamError(io.EOF)
 }
 
-func (m *ioManager) sendOnOutchan(p tunnel.ClientShellPacket) bool {
-	select {
-	case m.outchan <- p:
-		return true
-	case <-m.closed:
-		return false
-	}
-}
-
-func (m *ioManager) handleWindowResize(winch chan os.Signal) {
+func (m *ioManager) handleWindowResize(winch <-chan os.Signal, outchan chan<- tunnel.ClientShellPacket, wg *sync.WaitGroup) {
+	defer wg.Done()
 	for _ = range winch {
 		ws, err := lib.GetWindowSize()
 		if err != nil {
 			vlog.Infof("GetWindowSize failed: %v", err)
 			continue
 		}
-		if !m.sendOnOutchan(tunnel.ClientShellPacket{Rows: uint32(ws.Row), Cols: uint32(ws.Col)}) {
-			return
-		}
+		outchan <- tunnel.ClientShellPacket{Rows: uint32(ws.Row), Cols: uint32(ws.Col)}
 	}
 }
 
 // user2stream reads input from stdin and sends it to the outchan.
-func (m *ioManager) user2outchan() {
+func (m *ioManager) user2outchan(outchan chan<- tunnel.ClientShellPacket, wg *sync.WaitGroup) {
+	defer wg.Done()
 	for {
 		buf := make([]byte, 2048)
 		n, err := m.stdin.Read(buf[:])
+		if err == io.EOF {
+			vlog.VI(2).Infof("user2outchan: EOF, closing stdin")
+			outchan <- tunnel.ClientShellPacket{EOF: true}
+			return
+		}
 		if err != nil {
-			vlog.VI(2).Infof("user2stream: %v", err)
-			m.done <- err
+			vlog.VI(2).Infof("user2outchan: %v", err)
+			m.sendStdioError(err)
 			return
 		}
-		if !m.sendOnOutchan(tunnel.ClientShellPacket{Stdin: buf[:n]}) {
-			return
-		}
+		outchan <- tunnel.ClientShellPacket{Stdin: buf[:n]}
 	}
 }
 
 // stream2user reads data from the stream and sends it to either stdout or stderr.
-func (m *ioManager) stream2user() {
+func (m *ioManager) stream2user(wg *sync.WaitGroup) {
+	defer wg.Done()
 	rStream := m.stream.RecvStream()
 	for rStream.Advance() {
 		packet := rStream.Value()
+		vlog.VI(3).Infof("stream2user packet: %+v", packet)
 
 		if len(packet.Stdout) > 0 {
 			if n, err := m.stdout.Write(packet.Stdout); n != len(packet.Stdout) || err != nil {
-				m.done <- fmt.Errorf("stdout.Write returned (%d, %v) want (%d, nil)", n, err, len(packet.Stdout))
+				m.sendStdioError(fmt.Errorf("stdout.Write returned (%d, %v) want (%d, nil)", n, err, len(packet.Stdout)))
 				return
 			}
 		}
 		if len(packet.Stderr) > 0 {
 			if n, err := m.stderr.Write(packet.Stderr); n != len(packet.Stderr) || err != nil {
-				m.done <- fmt.Errorf("stderr.Write returned (%d, %v) want (%d, nil)", n, err, len(packet.Stderr))
+				m.sendStdioError(fmt.Errorf("stderr.Write returned (%d, %v) want (%d, nil)", n, err, len(packet.Stderr)))
 				return
 			}
 		}
@@ -130,5 +176,5 @@
 		err = io.EOF
 	}
 	vlog.VI(2).Infof("stream2user: %v", err)
-	m.done <- err
+	m.sendStreamError(err)
 }
diff --git a/examples/tunnel/vsh/main.go b/examples/tunnel/vsh/main.go
index a460012..43f0586 100644
--- a/examples/tunnel/vsh/main.go
+++ b/examples/tunnel/vsh/main.go
@@ -100,8 +100,10 @@
 		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
 		return 1
 	}
-	saved := lib.EnterRawTerminalMode()
-	defer lib.RestoreTerminalSettings(saved)
+	if opts.UsePty {
+		saved := lib.EnterRawTerminalMode()
+		defer lib.RestoreTerminalSettings(saved)
+	}
 	runIOManager(os.Stdin, os.Stdout, os.Stderr, stream)
 
 	exitMsg := fmt.Sprintf("Connection to %s closed.", oname)