veyron/security/agent/server: Improve agent shutdown.
Make sure NotifyWhenChanged shuts down when the agent's context is cancelled.
Also exit read loops when the agent's context is cancelled.
Change-Id: I19197ef7d7d6d4c1f07d26c31f63e3b8e9216608
diff --git a/security/agent/agent_test.go b/security/agent/agent_test.go
index 8eb1f3b..9237f50 100644
--- a/security/agent/agent_test.go
+++ b/security/agent/agent_test.go
@@ -1,12 +1,17 @@
package agent_test
import (
+ "fmt"
+ "io"
+ "io/ioutil"
"os"
"reflect"
"syscall"
"testing"
"time"
+ "v.io/core/veyron/lib/expect"
+ "v.io/core/veyron/lib/modules"
"v.io/core/veyron/lib/testutil"
tsecurity "v.io/core/veyron/lib/testutil/security"
_ "v.io/core/veyron/profiles"
@@ -18,6 +23,24 @@
"v.io/core/veyron2/security"
)
+func init() {
+ modules.RegisterChild("hang", "", getPrincipalAndHang)
+}
+
+func TestHelperProcess(t *testing.T) {
+ modules.DispatchInTest()
+}
+
+func getPrincipalAndHang(stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
+ ctx, shutdown := testutil.InitForTest()
+ defer shutdown()
+
+ p := veyron2.GetPrincipal(ctx)
+ fmt.Fprintf(stdout, "DEFAULT_BLESSING=%s\n", p.BlessingStore().Default())
+ ioutil.ReadAll(stdin)
+ return nil
+}
+
func newAgent(ctx *context.T, sock *os.File, cached bool) (security.Principal, error) {
fd, err := syscall.Dup(int(sock.Fd()))
if err != nil {
@@ -121,6 +144,37 @@
}
}
+func TestAgentShutdown(t *testing.T) {
+ ctx, shutdown := testutil.InitForTest()
+
+ // This starts an agent
+ sh, err := modules.NewShell(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // The child process will connect to the agent
+ h, err := sh.Start("hang", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ s := expect.NewSession(t, h.Stdout(), time.Minute)
+ fmt.Fprintf(os.Stderr, "reading var...\n")
+ s.ExpectVar("DEFAULT_BLESSING")
+ fmt.Fprintf(os.Stderr, "read\n")
+ if err := s.Error(); err != nil {
+ t.Fatalf("failed to read input: %s", err)
+ }
+ fmt.Fprintf(os.Stderr, "shutting down...\n")
+ // This shouldn't not hang
+ shutdown()
+ fmt.Fprintf(os.Stderr, "shut down\n")
+
+ fmt.Fprintf(os.Stderr, "cleanup...\n")
+ sh.Cleanup(os.Stdout, os.Stderr)
+ fmt.Fprintf(os.Stderr, "cleanup done\n")
+}
+
var message = []byte("bugs bunny")
func runSignBenchmark(b *testing.B, p security.Principal) {
diff --git a/security/agent/server/server.go b/security/agent/server/server.go
index 28baa1e..b76b845 100644
--- a/security/agent/server/server.go
+++ b/security/agent/server/server.go
@@ -44,6 +44,7 @@
id int
w *watchers
principal security.Principal
+ ctx *context.T
}
type keyData struct {
@@ -103,6 +104,14 @@
}
func (a keymgr) readDMConns(conn *net.UnixConn) {
+ donech := a.ctx.Done()
+ if donech != nil {
+ go func() {
+ // Shut down our read loop if the context is cancelled
+ <-donech
+ conn.Close()
+ }()
+ }
defer conn.Close()
var buf keyHandle
for {
@@ -110,8 +119,14 @@
if err == io.EOF {
return
} else if err != nil {
- vlog.Infof("Error accepting connection: %v", err)
- continue
+ // We ignore read errors, unless the context is cancelled.
+ select {
+ case <-donech:
+ return
+ default:
+ vlog.Infof("Error accepting connection: %v", err)
+ continue
+ }
}
ack()
var data *keyData
@@ -200,6 +215,14 @@
}
func startAgent(ctx *context.T, conn *net.UnixConn, w *watchers, principal security.Principal) error {
+ donech := ctx.Done()
+ if donech != nil {
+ go func() {
+ // Interrupt the read loop if the context is cancelled.
+ <-donech
+ conn.Close()
+ }()
+ }
go func() {
buf := make([]byte, 1)
for {
@@ -207,6 +230,15 @@
if err == io.EOF {
conn.Close()
return
+ } else if err != nil {
+ // We ignore read errors, unless the context is cancelled.
+ select {
+ case <-donech:
+ return
+ default:
+ vlog.Infof("Error accepting connection: %v", err)
+ continue
+ }
}
if clientAddr != nil {
// VCSecurityNone is safe since we're using anonymous unix sockets.
@@ -226,7 +258,7 @@
}
spec := ipc.ListenSpec{Addrs: ipc.ListenAddrs(a)}
if _, err = s.Listen(spec); err == nil {
- agent := &agentd{w.newID(), w, principal}
+ agent := &agentd{w.newID(), w, principal, ctx}
serverAgent := AgentServer(agent)
err = s.Serve("", serverAgent, nil)
}
@@ -442,6 +474,8 @@
defer a.w.unregister(a.id, ch)
for {
select {
+ case <-a.ctx.Done():
+ return nil
case <-ctx.Context().Done():
return nil
case _, ok := <-ch:
diff --git a/security/agent/v23_test.go b/security/agent/v23_test.go
index ba81e42..2464c23 100644
--- a/security/agent/v23_test.go
+++ b/security/agent/v23_test.go
@@ -2,14 +2,27 @@
// DO NOT UPDATE MANUALLY
package agent_test
+import "fmt"
import "testing"
import "os"
+import "v.io/core/veyron/lib/modules"
import "v.io/core/veyron/lib/testutil"
import "v.io/core/veyron/lib/testutil/v23tests"
+func init() {
+ modules.RegisterChild("getPrincipalAndHang", ``, getPrincipalAndHang)
+}
+
func TestMain(m *testing.M) {
testutil.Init()
+ if modules.IsModulesProcess() {
+ if err := modules.Dispatch(); err != nil {
+ fmt.Fprintf(os.Stderr, "modules.Dispatch failed: %v\n", err)
+ os.Exit(1)
+ }
+ return
+ }
cleanup := v23tests.UseSharedBinDir()
r := m.Run()
cleanup()