veyron/security/agent/server: Improve agent shutdown.
Make sure NotifyWhenChanged shuts down when the agent's context is cancelled.
Also exit read loops when the agent's context is cancelled.

Change-Id: I19197ef7d7d6d4c1f07d26c31f63e3b8e9216608
diff --git a/security/agent/agent_test.go b/security/agent/agent_test.go
index 8eb1f3b..9237f50 100644
--- a/security/agent/agent_test.go
+++ b/security/agent/agent_test.go
@@ -1,12 +1,17 @@
 package agent_test
 
 import (
+	"fmt"
+	"io"
+	"io/ioutil"
 	"os"
 	"reflect"
 	"syscall"
 	"testing"
 	"time"
 
+	"v.io/core/veyron/lib/expect"
+	"v.io/core/veyron/lib/modules"
 	"v.io/core/veyron/lib/testutil"
 	tsecurity "v.io/core/veyron/lib/testutil/security"
 	_ "v.io/core/veyron/profiles"
@@ -18,6 +23,24 @@
 	"v.io/core/veyron2/security"
 )
 
+func init() {
+	modules.RegisterChild("hang", "", getPrincipalAndHang) // "hang" is the name TestAgentShutdown passes to sh.Start.
+}
+
+func TestHelperProcess(t *testing.T) {
+	modules.DispatchInTest() // NOTE(review): presumably runs the registered child function when this test binary is re-executed as a child; confirm against lib/modules.
+}
+
+func getPrincipalAndHang(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error { // Registered as the "hang" child: prints the default blessing, then blocks on stdin.
+	ctx, shutdown := testutil.InitForTest()
+	defer shutdown()
+
+	p := veyron2.GetPrincipal(ctx)
+	fmt.Fprintf(stdout, "DEFAULT_BLESSING=%s\n", p.BlessingStore().Default()) // Read back by the parent test via ExpectVar("DEFAULT_BLESSING").
+	ioutil.ReadAll(stdin) // Blocks until stdin reaches EOF or a read error; the data and error are discarded.
+	return nil
+}
+
 func newAgent(ctx *context.T, sock *os.File, cached bool) (security.Principal, error) {
 	fd, err := syscall.Dup(int(sock.Fd()))
 	if err != nil {
@@ -121,6 +144,37 @@
 	}
 }
 
+func TestAgentShutdown(t *testing.T) { // Checks that shutdown() returns after a child has obtained its principal (child presumably still blocked on stdin).
+	ctx, shutdown := testutil.InitForTest()
+
+	// Creating the shell starts an agent.
+	sh, err := modules.NewShell(ctx, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// The child process will connect to the agent.
+	h, err := sh.Start("hang", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	s := expect.NewSession(t, h.Stdout(), time.Minute)
+	fmt.Fprintf(os.Stderr, "reading var...\n")
+	s.ExpectVar("DEFAULT_BLESSING")
+	fmt.Fprintf(os.Stderr, "read\n")
+	if err := s.Error(); err != nil {
+		t.Fatalf("failed to read input: %s", err)
+	}
+	fmt.Fprintf(os.Stderr, "shutting down...\n")
+	// This should not hang.
+	shutdown()
+	fmt.Fprintf(os.Stderr, "shut down\n")
+
+	fmt.Fprintf(os.Stderr, "cleanup...\n")
+	sh.Cleanup(os.Stdout, os.Stderr)
+	fmt.Fprintf(os.Stderr, "cleanup done\n")
+}
+
 var message = []byte("bugs bunny")
 
 func runSignBenchmark(b *testing.B, p security.Principal) {
diff --git a/security/agent/server/server.go b/security/agent/server/server.go
index 28baa1e..b76b845 100644
--- a/security/agent/server/server.go
+++ b/security/agent/server/server.go
@@ -44,6 +44,7 @@
 	id        int
 	w         *watchers
 	principal security.Principal
+	ctx       *context.T
 }
 
 type keyData struct {
@@ -103,6 +104,14 @@
 }
 
 func (a keymgr) readDMConns(conn *net.UnixConn) {
+	donech := a.ctx.Done()
+	if donech != nil {
+		go func() {
+			// Shut down our read loop if the context is cancelled
+			<-donech
+			conn.Close()
+		}()
+	}
 	defer conn.Close()
 	var buf keyHandle
 	for {
@@ -110,8 +119,14 @@
 		if err == io.EOF {
 			return
 		} else if err != nil {
-			vlog.Infof("Error accepting connection: %v", err)
-			continue
+			// We ignore read errors, unless the context is cancelled.
+			select {
+			case <-donech:
+				return
+			default:
+				vlog.Infof("Error accepting connection: %v", err)
+				continue
+			}
 		}
 		ack()
 		var data *keyData
@@ -200,6 +215,14 @@
 }
 
 func startAgent(ctx *context.T, conn *net.UnixConn, w *watchers, principal security.Principal) error {
+	donech := ctx.Done()
+	if donech != nil {
+		go func() {
+			// Interrupt the read loop if the context is cancelled.
+			<-donech
+			conn.Close()
+		}()
+	}
 	go func() {
 		buf := make([]byte, 1)
 		for {
@@ -207,6 +230,15 @@
 			if err == io.EOF {
 				conn.Close()
 				return
+			} else if err != nil {
+				// We ignore read errors, unless the context is cancelled.
+				select {
+				case <-donech:
+					return
+				default:
+					vlog.Infof("Error accepting connection: %v", err)
+					continue
+				}
 			}
 			if clientAddr != nil {
 				// VCSecurityNone is safe since we're using anonymous unix sockets.
@@ -226,7 +258,7 @@
 				}
 				spec := ipc.ListenSpec{Addrs: ipc.ListenAddrs(a)}
 				if _, err = s.Listen(spec); err == nil {
-					agent := &agentd{w.newID(), w, principal}
+					agent := &agentd{w.newID(), w, principal, ctx}
 					serverAgent := AgentServer(agent)
 					err = s.Serve("", serverAgent, nil)
 				}
@@ -442,6 +474,8 @@
 	defer a.w.unregister(a.id, ch)
 	for {
 		select {
+		case <-a.ctx.Done():
+			return nil
 		case <-ctx.Context().Done():
 			return nil
 		case _, ok := <-ch:
diff --git a/security/agent/v23_test.go b/security/agent/v23_test.go
index ba81e42..2464c23 100644
--- a/security/agent/v23_test.go
+++ b/security/agent/v23_test.go
@@ -2,14 +2,27 @@
 // DO NOT UPDATE MANUALLY
 package agent_test
 
+import "fmt"
 import "testing"
 import "os"
 
+import "v.io/core/veyron/lib/modules"
 import "v.io/core/veyron/lib/testutil"
 import "v.io/core/veyron/lib/testutil/v23tests"
 
+func init() {
+	modules.RegisterChild("getPrincipalAndHang", ``, getPrincipalAndHang)
+}
+
 func TestMain(m *testing.M) {
 	testutil.Init()
+	if modules.IsModulesProcess() {
+		if err := modules.Dispatch(); err != nil {
+			fmt.Fprintf(os.Stderr, "modules.Dispatch failed: %v\n", err)
+			os.Exit(1)
+		}
+		return
+	}
 	cleanup := v23tests.UseSharedBinDir()
 	r := m.Run()
 	cleanup()