ref/lib/security: Add a method to PrepareDischarges.
Change-Id: I4475b039c2252526e00f0ba406254d0950822da6
diff --git a/lib/security/prepare_discharges.go b/lib/security/prepare_discharges.go
new file mode 100644
index 0000000..8cdca3e
--- /dev/null
+++ b/lib/security/prepare_discharges.go
@@ -0,0 +1,171 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package security
+
+import (
+ "sync"
+ "time"
+
+ "v.io/v23"
+ "v.io/v23/context"
+ "v.io/v23/security"
+ "v.io/v23/vdl"
+ "v.io/v23/vtrace"
+)
+
+// If this is attached to the context, we will not fetch discharges.
+// We use this to prevent ourselves from fetching discharges in the
+// process of fetching discharges, thus creating an infinite loop.
+type skipDischargesKey struct{}
+
+// PrepareDischarges retrieves the caveat discharges required for using blessings
+// at server. The discharges are either found in the dischargeCache, in the call
+// options, or requested from the discharge issuer indicated on the caveat.
+// Note that requesting a discharge is an rpc call, so one copy of this
+// function must be able to successfully terminate while another is blocked.
+func PrepareDischarges(
+ ctx *context.T,
+ blessings security.Blessings,
+ impetus security.DischargeImpetus,
+ expiryBuffer time.Duration) map[string]security.Discharge {
+ tpCavs := blessings.ThirdPartyCaveats()
+ if skip, _ := ctx.Value(skipDischargesKey{}).(bool); skip || len(tpCavs) == 0 {
+ return nil
+ }
+ ctx = context.WithValue(ctx, skipDischargesKey{}, true)
+
+ // Make a copy since this copy will be mutated.
+ var caveats []security.Caveat
+ var filteredImpetuses []security.DischargeImpetus
+ for _, cav := range tpCavs {
+ // It shouldn't happen, but in case there are non-third-party
+ // caveats, drop them.
+ if tp := cav.ThirdPartyDetails(); tp != nil {
+ caveats = append(caveats, cav)
+ filteredImpetuses = append(filteredImpetuses, filteredImpetus(tp.Requirements(), impetus))
+ }
+ }
+ bstore := v23.GetPrincipal(ctx).BlessingStore()
+ // Gather discharges from cache.
+ discharges, rem := discharges(bstore, caveats, impetus)
+ if rem > 0 {
+ // Fetch discharges for caveats for which no discharges were
+ // found in the cache.
+ if ctx != nil {
+ var span vtrace.Span
+ ctx, span = vtrace.WithNewSpan(ctx, "Fetching Discharges")
+ defer span.Finish()
+ }
+ fetchDischarges(ctx, caveats, filteredImpetuses, discharges, expiryBuffer)
+ }
+ ret := make(map[string]security.Discharge, len(discharges))
+ for _, d := range discharges {
+ if d.ID() != "" {
+ ret[d.ID()] = d
+ }
+ }
+ return ret
+}
+
+func discharges(bs security.BlessingStore, caveats []security.Caveat, imp security.DischargeImpetus) (out []security.Discharge, rem int) {
+ out = make([]security.Discharge, len(caveats))
+ for i := range caveats {
+ out[i] = bs.Discharge(caveats[i], imp)
+ if out[i].ID() == "" {
+ rem++
+ }
+ }
+ return
+}
+
+// fetchDischarges fills out by fetching discharges for caveats from the
+// appropriate discharge service. Since there may be dependencies in the
+// caveats, fetchDischarges keeps retrying until either all discharges can be
+// fetched or no new discharges are fetched.
+// REQUIRES: len(caveats) == len(out)
+// REQUIRES: caveats[i].ThirdPartyDetails() != nil for 0 <= i < len(caveats)
+func fetchDischarges(
+ ctx *context.T,
+ caveats []security.Caveat,
+ impetuses []security.DischargeImpetus,
+ out []security.Discharge,
+ expiryBuffer time.Duration) {
+ bstore := v23.GetPrincipal(ctx).BlessingStore()
+ var wg sync.WaitGroup
+ for {
+ type fetched struct {
+ idx int
+ discharge security.Discharge
+ caveat security.Caveat
+ impetus security.DischargeImpetus
+ }
+ discharges := make(chan fetched, len(caveats))
+ want := 0
+ for i := range caveats {
+ if !shouldFetchDischarge(out[i], expiryBuffer) {
+ continue
+ }
+ want++
+ wg.Add(1)
+ go func(i int, ctx *context.T, cav security.Caveat) {
+ defer wg.Done()
+ tp := cav.ThirdPartyDetails()
+ var dis security.Discharge
+ ctx.VI(3).Infof("Fetching discharge for %v", tp)
+ if err := v23.GetClient(ctx).Call(ctx, tp.Location(), "Discharge",
+ []interface{}{cav, impetuses[i]}, []interface{}{&dis}); err != nil {
+ ctx.VI(3).Infof("Discharge fetch for %v failed: %v", tp, err)
+ return
+ }
+ discharges <- fetched{i, dis, caveats[i], impetuses[i]}
+ }(i, ctx, caveats[i])
+ }
+ wg.Wait()
+ close(discharges)
+ var got int
+ for fetched := range discharges {
+ bstore.CacheDischarge(fetched.discharge, fetched.caveat, fetched.impetus)
+ out[fetched.idx] = fetched.discharge
+ got++
+ }
+ if want > 0 {
+ ctx.VI(3).Infof("fetchDischarges: got %d of %d discharge(s) (total %d caveats)", got, want, len(caveats))
+ }
+ if got == 0 || got == want {
+ return
+ }
+ }
+}
+
+// filteredImpetus returns a copy of 'before' after removing any values that are not required as per 'r'.
+func filteredImpetus(r security.ThirdPartyRequirements, before security.DischargeImpetus) (after security.DischargeImpetus) {
+ if r.ReportServer && len(before.Server) > 0 {
+ after.Server = make([]security.BlessingPattern, len(before.Server))
+ for i := range before.Server {
+ after.Server[i] = before.Server[i]
+ }
+ }
+ if r.ReportMethod {
+ after.Method = before.Method
+ }
+ if r.ReportArguments && len(before.Arguments) > 0 {
+ after.Arguments = make([]*vdl.Value, len(before.Arguments))
+ for i := range before.Arguments {
+ after.Arguments[i] = vdl.CopyValue(before.Arguments[i])
+ }
+ }
+ return
+}
+
+func shouldFetchDischarge(dis security.Discharge, expiryBuffer time.Duration) bool {
+ if dis.ID() == "" {
+ return true
+ }
+ expiry := dis.Expiry()
+ if expiry.IsZero() {
+ return false
+ }
+ return expiry.Before(time.Now().Add(expiryBuffer))
+}
diff --git a/lib/security/prepare_discharges_test.go b/lib/security/prepare_discharges_test.go
new file mode 100644
index 0000000..292c58e
--- /dev/null
+++ b/lib/security/prepare_discharges_test.go
@@ -0,0 +1,137 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package security_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "v.io/v23"
+ "v.io/v23/context"
+ "v.io/v23/rpc"
+ "v.io/v23/security"
+ securitylib "v.io/x/ref/lib/security"
+ "v.io/x/ref/lib/xrpc"
+ _ "v.io/x/ref/runtime/factories/generic"
+ "v.io/x/ref/test"
+ "v.io/x/ref/test/testutil"
+)
+
+func init() {
+ test.Init()
+}
+
+type expiryDischarger struct {
+ called bool
+}
+
+func (ed *expiryDischarger) Discharge(ctx *context.T, call rpc.StreamServerCall, cav security.Caveat, _ security.DischargeImpetus) (security.Discharge, error) {
+ tp := cav.ThirdPartyDetails()
+ if tp == nil {
+ return security.Discharge{}, fmt.Errorf("discharger: %v does not represent a third-party caveat", cav)
+ }
+ if err := tp.Dischargeable(ctx, call.Security()); err != nil {
+ return security.Discharge{}, fmt.Errorf("third-party caveat %v cannot be discharged for this context: %v", cav, err)
+ }
+ expDur := 10 * time.Millisecond
+ if ed.called {
+ expDur = time.Second
+ }
+ expiry, err := security.NewExpiryCaveat(time.Now().Add(expDur))
+ if err != nil {
+ return security.Discharge{}, fmt.Errorf("failed to create an expiration on the discharge: %v", err)
+ }
+ d, err := call.Security().LocalPrincipal().MintDischarge(cav, expiry)
+ if err != nil {
+ return security.Discharge{}, err
+ }
+ ctx.Infof("got discharge on sever %#v", d)
+ ed.called = true
+ return d, nil
+}
+
+func TestPrepareDischarges(t *testing.T) {
+ ctx, shutdown := test.V23Init()
+ defer shutdown()
+
+ pclient := testutil.NewPrincipal("client")
+ cctx, err := v23.WithPrincipal(ctx, pclient)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pdischarger := testutil.NewPrincipal("discharger")
+ dctx, err := v23.WithPrincipal(ctx, pdischarger)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pclient.AddToRoots(pdischarger.BlessingStore().Default())
+ pclient.AddToRoots(v23.GetPrincipal(ctx).BlessingStore().Default())
+ pdischarger.AddToRoots(pclient.BlessingStore().Default())
+ pdischarger.AddToRoots(v23.GetPrincipal(ctx).BlessingStore().Default())
+
+ expcav, err := security.NewExpiryCaveat(time.Now().Add(time.Hour))
+ if err != nil {
+ t.Fatal(err)
+ }
+ tpcav, err := security.NewPublicKeyCaveat(
+ pdischarger.PublicKey(),
+ "discharger",
+ security.ThirdPartyRequirements{},
+ expcav)
+ if err != nil {
+ t.Fatal(err)
+ }
+ cbless, err := pclient.BlessSelf("clientcaveats", tpcav)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tpid := tpcav.ThirdPartyDetails().ID()
+
+ v23.GetPrincipal(dctx)
+ _, err = xrpc.NewServer(dctx,
+ "discharger",
+ &expiryDischarger{},
+ security.AllowEveryone())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Fetch discharges for tpcav.
+ buffer := 100 * time.Millisecond
+ discharges := securitylib.PrepareDischarges(cctx, cbless,
+ security.DischargeImpetus{}, buffer)
+ if len(discharges) != 1 {
+ t.Errorf("Got %d discharges, expected 1.", len(discharges))
+ }
+ dis, has := discharges[tpid]
+ if !has {
+ t.Errorf("Got %#v, Expected discharge for %s", discharges, tpid)
+ }
+ // Check that the discharges is not yet expired, but is expired after 100 milliseconds.
+ expiry := dis.Expiry()
+ // The discharge should expire.
+ select {
+ case <-time.After(time.Now().Sub(expiry)):
+ break
+ case <-time.After(time.Second):
+ t.Fatalf("discharge didn't expire within a second")
+ }
+
+ // Preparing Discharges again to get fresh discharges.
+ discharges = securitylib.PrepareDischarges(cctx, cbless,
+ security.DischargeImpetus{}, buffer)
+ if len(discharges) != 1 {
+ t.Errorf("Got %d discharges, expected 1.", len(discharges))
+ }
+ dis, has = discharges[tpid]
+ if !has {
+ t.Errorf("Got %#v, Expected discharge for %s", discharges, tpid)
+ }
+ now := time.Now()
+ if expiry = dis.Expiry(); expiry.Before(now) {
+ t.Fatalf("discharge has expired %v, but should be fresh", dis)
+ }
+}