veyron/runtimes/google/ipc: Add NoDischarges option to prevent discharge client
infinite recursion of fetching discharges.
More described here: https://docs.google.com/document/d/1vKGrbLWZ6QyTbJHvlm0N0vVvSo8MMwlcf3WvAmFY4fA/edit
Change-Id: Ic69112c412de483b6f071734fc75d4456148b087
diff --git a/runtimes/google/ipc/client.go b/runtimes/google/ipc/client.go
index 41405ce..e1cc0ae 100644
--- a/runtimes/google/ipc/client.go
+++ b/runtimes/google/ipc/client.go
@@ -129,7 +129,7 @@
return c, nil
}
-func (c *client) createFlow(ctx context.T, ep naming.Endpoint) (stream.Flow, verror.E) {
+func (c *client) createFlow(ctx context.T, ep naming.Endpoint, noDischarges bool) (stream.Flow, verror.E) {
c.vcMapMu.Lock()
defer c.vcMapMu.Unlock()
if c.vcMap == nil {
@@ -151,6 +151,9 @@
vcOpts := make([]stream.VCOpt, len(c.vcOpts))
copy(vcOpts, c.vcOpts)
c.vcMapMu.Unlock()
+ if noDischarges {
+ vcOpts = append(vcOpts, vc.NoDischarges{})
+ }
vc, err := sm.Dial(ep, vcOpts...)
c.vcMapMu.Lock()
if err != nil {
@@ -183,7 +186,7 @@
// a flow to the endpoint, returning the parsed suffix.
// The server name passed in should be a rooted name, of the form "/ep/suffix" or
// "/ep//suffix", or just "/ep".
-func (c *client) connectFlow(ctx context.T, server string) (stream.Flow, string, verror.E) {
+func (c *client) connectFlow(ctx context.T, server string, noDischarges bool) (stream.Flow, string, verror.E) {
address, suffix := naming.SplitAddressName(server)
if len(address) == 0 {
return nil, "", verror.Make(errNonRootedName, ctx, server)
@@ -195,7 +198,7 @@
if err = version.CheckCompatibility(ep); err != nil {
return nil, "", verror.Make(errIncompatibleEndpoint, ctx, ep)
}
- flow, verr := c.createFlow(ctx, ep)
+ flow, verr := c.createFlow(ctx, ep, noDischarges)
if verr != nil {
return nil, "", verr
}
@@ -247,6 +250,15 @@
return false
}
+func shouldNotFetchDischarges(opts []ipc.CallOpt) bool {
+ for _, o := range opts {
+ if _, ok := o.(vc.NoDischarges); ok {
+ return true
+ }
+ }
+ return false
+}
+
func mkDischargeImpetus(serverBlessings []string, method string, args []interface{}) security.DischargeImpetus {
var impetus security.DischargeImpetus
if len(serverBlessings) > 0 {
@@ -312,10 +324,10 @@
}
// TODO(cnicolaou): implement real, configurable load balancing.
-func (c *client) tryServer(ctx context.T, index int, server string, ch chan<- *serverStatus) {
+func (c *client) tryServer(ctx context.T, index int, server string, ch chan<- *serverStatus, noDischarges bool) {
status := &serverStatus{index: index}
var err verror.E
- if status.flow, status.suffix, err = c.connectFlow(ctx, server); err != nil {
+ if status.flow, status.suffix, err = c.connectFlow(ctx, server, noDischarges); err != nil {
vlog.VI(2).Infof("ipc: err: %s", err)
status.err = err
status.flow = nil
@@ -326,6 +338,7 @@
// tryCall makes a single attempt at a call, against possibly multiple servers.
func (c *client) tryCall(ctx context.T, name, method string, args []interface{}, opts []ipc.CallOpt) (ipc.Call, verror.E) {
mtPattern, serverPattern, name := splitObjectName(name)
+ noDischarges := shouldNotFetchDischarges(opts)
// Resolve name unless told not to.
var servers []string
if getNoResolveOpt(opts) {
@@ -366,7 +379,7 @@
responses := make([]*serverStatus, attempts)
ch := make(chan *serverStatus, attempts)
for i, server := range servers {
- go c.tryServer(ctx, i, server, ch)
+ go c.tryServer(ctx, i, server, ch, noDischarges)
}
delay := time.Duration(ipc.NoTimeout)
@@ -451,6 +464,9 @@
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
timeout = deadline.Sub(time.Now())
}
+ if noDischarges {
+ fc.dc = nil
+ }
if verr := fc.start(r.suffix, method, args, timeout, grantedB); verr != nil {
return nil, verr
}
diff --git a/runtimes/google/ipc/discharges.go b/runtimes/google/ipc/discharges.go
index 40459fd..53ce00a 100644
--- a/runtimes/google/ipc/discharges.go
+++ b/runtimes/google/ipc/discharges.go
@@ -92,7 +92,7 @@
go func(i int, cav security.ThirdPartyCaveat) {
defer wg.Done()
vlog.VI(3).Infof("Fetching discharge for %v", cav)
- call, err := d.c.StartCall(ctx, cav.Location(), "Discharge", []interface{}{cav, filteredImpetus(cav.Requirements(), impetus)})
+ call, err := d.c.StartCall(ctx, cav.Location(), "Discharge", []interface{}{cav, filteredImpetus(cav.Requirements(), impetus)}, vc.NoDischarges{})
if err != nil {
vlog.VI(3).Infof("Discharge fetch for %v failed: %v", cav, err)
return
diff --git a/runtimes/google/ipc/full_test.go b/runtimes/google/ipc/full_test.go
index d58264a..61163f4 100644
--- a/runtimes/google/ipc/full_test.go
+++ b/runtimes/google/ipc/full_test.go
@@ -1228,7 +1228,7 @@
t.Fatalf("InternalNewClient failed: %v", err)
}
// When using VCSecurityNone, all authorization checks should be skipped, so
- // unauthorized methods shoudl be callable.
+ // unauthorized methods should be callable.
call, err := client.StartCall(testContext(), "mp/server", "Unauthorized", nil)
if err != nil {
t.Fatalf("client.StartCall failed: %v", err)
@@ -1372,6 +1372,172 @@
}
}
+type mockDischarger struct {
+ called bool
+}
+
+func (m *mockDischarger) Discharge(ctx ipc.ServerContext, caveatAny vdlutil.Any, _ security.DischargeImpetus) (vdlutil.Any, error) {
+ m.called = true
+ caveat, ok := caveatAny.(security.ThirdPartyCaveat)
+ if !ok {
+ return nil, fmt.Errorf("type %T does not implement security.ThirdPartyCaveat", caveatAny)
+ }
+ return ctx.LocalPrincipal().MintDischarge(caveat, security.UnconstrainedUse())
+}
+
+func TestNoDischargesOpt(t *testing.T) {
+ var (
+ pdischarger = tsecurity.NewPrincipal("discharger")
+ pserver = tsecurity.NewPrincipal("server")
+ pclient = tsecurity.NewPrincipal("client")
+ )
+ // Make the client recognize all server blessings
+ if err := pclient.AddToRoots(pserver.BlessingStore().Default()); err != nil {
+ t.Fatal(err)
+ }
+ if err := pclient.AddToRoots(pdischarger.BlessingStore().Default()); err != nil {
+ t.Fatal(err)
+ }
+
+ // Bless the client with a ThirdPartyCaveat.
+ tpcav := mkThirdPartyCaveat(pdischarger.PublicKey(), "mountpoint/discharger", mkCaveat(security.ExpiryCaveat(time.Now().Add(time.Hour))))
+ blessings, err := pserver.Bless(pclient.PublicKey(), pserver.BlessingStore().Default(), "tpcav", tpcav)
+ if err != nil {
+ t.Fatalf("failed to create Blessings: %v", err)
+ }
+ if _, err = pclient.BlessingStore().Set(blessings, "server"); err != nil {
+ t.Fatalf("failed to set blessings: %v", err)
+ }
+
+ ns := tnaming.NewSimpleNamespace()
+ runServer := func(name string, obj interface{}, principal security.Principal) stream.Manager {
+ rid, err := naming.NewRoutingID()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sm := imanager.InternalNew(rid)
+ server, err := InternalNewServer(testContext(), sm, ns, nil, vc.LocalPrincipal{principal})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := server.Listen(listenSpec); err != nil {
+ t.Fatal(err)
+ }
+ if err := server.Serve(name, obj, acceptAllAuthorizer{}); err != nil {
+ t.Fatal(err)
+ }
+ return sm
+ }
+
+ // Setup the disharger and test server.
+ discharger := &mockDischarger{}
+ defer runServer("mountpoint/discharger", discharger, pdischarger).Shutdown()
+ defer runServer("mountpoint/testServer", &testServer{}, pserver).Shutdown()
+
+ runClient := func(noDischarges bool) {
+ rid, err := naming.NewRoutingID()
+ if err != nil {
+ t.Fatal(err)
+ }
+ smc := imanager.InternalNew(rid)
+ defer smc.Shutdown()
+ dc, err := InternalNewDischargeClient(smc, ns, testContext())
+ if err != nil {
+ t.Fatal(err)
+ }
+ client, err := InternalNewClient(smc, ns, vc.LocalPrincipal{pclient}, dc)
+ if err != nil {
+ t.Fatalf("failed to create client: %v", err)
+ }
+ defer client.Close()
+ var opts []ipc.CallOpt
+ if noDischarges {
+ opts = append(opts, vc.NoDischarges{})
+ }
+ if _, err = client.StartCall(testContext(), "mountpoint/testServer", "Closure", nil, opts...); err != nil {
+ t.Fatalf("failed to StartCall: %v", err)
+ }
+ }
+
+ // Test that when the NoDischarges option is set, mockDischarger does not get called.
+ if runClient(true); discharger.called {
+ t.Errorf("did not expect discharger to be called")
+ }
+ discharger.called = false
+ // Test that when the Nodischarges option is not set, mockDischarger does get called.
+ if runClient(false); !discharger.called {
+ t.Errorf("expected discharger to be called")
+ }
+}
+
+func TestNoImplicitDischargeFetching(t *testing.T) {
+ // This test ensures that discharge clients only fetch discharges for the specified tp caveats and not its own.
+ var (
+ pdischarger1 = tsecurity.NewPrincipal("discharger1")
+ pdischarger2 = tsecurity.NewPrincipal("discharger2")
+ pdischargeClient = tsecurity.NewPrincipal("dischargeClient")
+ )
+
+ // Bless the client with a ThirdPartyCaveat from discharger1.
+ tpcav1 := mkThirdPartyCaveat(pdischarger1.PublicKey(), "mountpoint/discharger1", mkCaveat(security.ExpiryCaveat(time.Now().Add(time.Hour))))
+ blessings, err := pdischarger1.Bless(pdischargeClient.PublicKey(), pdischarger1.BlessingStore().Default(), "tpcav1", tpcav1)
+ if err != nil {
+ t.Fatalf("failed to create Blessings: %v", err)
+ }
+ if err = pdischargeClient.BlessingStore().SetDefault(blessings); err != nil {
+ t.Fatalf("failed to set blessings: %v", err)
+ }
+
+ ns := tnaming.NewSimpleNamespace()
+ runServer := func(name string, obj interface{}, principal security.Principal) stream.Manager {
+ rid, err := naming.NewRoutingID()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sm := imanager.InternalNew(rid)
+ server, err := InternalNewServer(testContext(), sm, ns, nil, vc.LocalPrincipal{principal})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := server.Listen(listenSpec); err != nil {
+ t.Fatal(err)
+ }
+ if err := server.Serve(name, obj, acceptAllAuthorizer{}); err != nil {
+ t.Fatal(err)
+ }
+ return sm
+ }
+
+ // Setup the disharger and test server.
+ discharger1 := &mockDischarger{}
+ discharger2 := &mockDischarger{}
+ defer runServer("mountpoint/discharger1", discharger1, pdischarger1).Shutdown()
+ defer runServer("mountpoint/discharger2", discharger2, pdischarger2).Shutdown()
+
+ rid, err := naming.NewRoutingID()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sm := imanager.InternalNew(rid)
+ dc, err := InternalNewDischargeClient(sm, ns, testContext(), vc.LocalPrincipal{pdischargeClient})
+ if err != nil {
+ t.Fatal(err)
+ }
+ tpcav2, err := security.NewPublicKeyCaveat(pdischarger2.PublicKey(), "mountpoint/discharger2", security.ThirdPartyRequirements{}, mkCaveat(security.ExpiryCaveat(time.Now().Add(time.Hour))))
+ if err != nil {
+ t.Error(err)
+ }
+ _ = dc.PrepareDischarges([]security.ThirdPartyCaveat{tpcav2}, security.DischargeImpetus{})
+
+ // Ensure that discharger1 was not called and discharger2 was called.
+ if discharger1.called {
+ t.Errorf("discharge for caveat on discharge client should not have been fetched.")
+ }
+ if !discharger2.called {
+ t.Errorf("discharge for caveat passed to PrepareDischarges should have been fetched.")
+ }
+}
+
func init() {
testutil.Init()
vom.Register(fakeTimeCaveat(0))
diff --git a/runtimes/google/ipc/stream/vc/vc.go b/runtimes/google/ipc/stream/vc/vc.go
index f5783cb..3a91f8b 100644
--- a/runtimes/google/ipc/stream/vc/vc.go
+++ b/runtimes/google/ipc/stream/vc/vc.go
@@ -62,6 +62,12 @@
version version.IPCVersion
}
+// NoDischarges specifies that the RPC call should not fetch discharges.
+type NoDischarges struct{}
+
+func (NoDischarges) IPCCallOpt() {}
+func (NoDischarges) IPCStreamVCOpt() {}
+
var _ stream.VC = (*VC)(nil)
// Helper is the interface for functionality required by the stream.VC
@@ -361,6 +367,7 @@
tlsSessionCache crypto.TLSClientSessionCache
securityLevel options.VCSecurityLevel
dischargeClient DischargeClient
+ noDischarges bool
)
for _, o := range opts {
switch v := o.(type) {
@@ -372,8 +379,14 @@
securityLevel = v
case crypto.TLSClientSessionCache:
tlsSessionCache = v
+ case NoDischarges:
+ noDischarges = true
}
}
+ // If noDischarge is provided, disable the dischargeClient.
+ if noDischarges {
+ dischargeClient = nil
+ }
switch securityLevel {
case options.VCSecurityConfidential:
if principal == nil {
diff --git a/runtimes/google/ipc/stream/vif/auth.go b/runtimes/google/ipc/stream/vif/auth.go
index 6a05496..5d3adc6 100644
--- a/runtimes/google/ipc/stream/vif/auth.go
+++ b/runtimes/google/ipc/stream/vif/auth.go
@@ -231,6 +231,7 @@
// list.
func clientAuthOptions(lopts []stream.VCOpt) (principal security.Principal, dischargeClient vc.DischargeClient, err error) {
var securityLevel options.VCSecurityLevel
+ var noDischarges bool
for _, o := range lopts {
switch v := o.(type) {
case vc.DischargeClient:
@@ -239,8 +240,13 @@
principal = v.Principal
case options.VCSecurityLevel:
securityLevel = v
+ case vc.NoDischarges:
+ noDischarges = true
}
}
+ if noDischarges {
+ dischargeClient = nil
+ }
switch securityLevel {
case options.VCSecurityConfidential:
if principal == nil {