ref/lib/security: Add a method to PrepareDischarges.

Change-Id: I4475b039c2252526e00f0ba406254d0950822da6
diff --git a/lib/security/prepare_discharges.go b/lib/security/prepare_discharges.go
new file mode 100644
index 0000000..8cdca3e
--- /dev/null
+++ b/lib/security/prepare_discharges.go
@@ -0,0 +1,171 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package security
+
+import (
+	"sync"
+	"time"
+
+	"v.io/v23"
+	"v.io/v23/context"
+	"v.io/v23/security"
+	"v.io/v23/vdl"
+	"v.io/v23/vtrace"
+)
+
+// If this is attached to the context, we will not fetch discharges.
+// We use this to prevent ourselves from fetching discharges in the
+// process of fetching discharges, thus creating an infinite loop.
+type skipDischargesKey struct{}
+
+// PrepareDischarges retrieves the caveat discharges required for using blessings
+// at server. The discharges are either found in the dischargeCache, in the call
+// options, or requested from the discharge issuer indicated on the caveat.
+// Note that requesting a discharge is an rpc call, so one copy of this
+// function must be able to successfully terminate while another is blocked.
+func PrepareDischarges(
+	ctx *context.T,
+	blessings security.Blessings,
+	impetus security.DischargeImpetus,
+	expiryBuffer time.Duration) map[string]security.Discharge {
+	tpCavs := blessings.ThirdPartyCaveats()
+	if skip, _ := ctx.Value(skipDischargesKey{}).(bool); skip || len(tpCavs) == 0 {
+		return nil
+	}
+	ctx = context.WithValue(ctx, skipDischargesKey{}, true)
+
+	// Make a copy since this copy will be mutated.
+	var caveats []security.Caveat
+	var filteredImpetuses []security.DischargeImpetus
+	for _, cav := range tpCavs {
+		// It shouldn't happen, but in case there are non-third-party
+		// caveats, drop them.
+		if tp := cav.ThirdPartyDetails(); tp != nil {
+			caveats = append(caveats, cav)
+			filteredImpetuses = append(filteredImpetuses, filteredImpetus(tp.Requirements(), impetus))
+		}
+	}
+	bstore := v23.GetPrincipal(ctx).BlessingStore()
+	// Gather discharges from cache.
+	discharges, rem := discharges(bstore, caveats, impetus)
+	if rem > 0 {
+		// Fetch discharges for caveats for which no discharges were
+		// found in the cache.
+		if ctx != nil {
+			var span vtrace.Span
+			ctx, span = vtrace.WithNewSpan(ctx, "Fetching Discharges")
+			defer span.Finish()
+		}
+		fetchDischarges(ctx, caveats, filteredImpetuses, discharges, expiryBuffer)
+	}
+	ret := make(map[string]security.Discharge, len(discharges))
+	for _, d := range discharges {
+		if d.ID() != "" {
+			ret[d.ID()] = d
+		}
+	}
+	return ret
+}
+
+func discharges(bs security.BlessingStore, caveats []security.Caveat, imp security.DischargeImpetus) (out []security.Discharge, rem int) {
+	out = make([]security.Discharge, len(caveats))
+	for i := range caveats {
+		out[i] = bs.Discharge(caveats[i], imp)
+		if out[i].ID() == "" {
+			rem++
+		}
+	}
+	return
+}
+
+// fetchDischarges fills out by fetching discharges for caveats from the
+// appropriate discharge service. Since there may be dependencies in the
+// caveats, fetchDischarges keeps retrying until either all discharges can be
+// fetched or no new discharges are fetched.
+// REQUIRES: len(caveats) == len(out)
+// REQUIRES: caveats[i].ThirdPartyDetails() != nil for 0 <= i < len(caveats)
+func fetchDischarges(
+	ctx *context.T,
+	caveats []security.Caveat,
+	impetuses []security.DischargeImpetus,
+	out []security.Discharge,
+	expiryBuffer time.Duration) {
+	bstore := v23.GetPrincipal(ctx).BlessingStore()
+	var wg sync.WaitGroup
+	for {
+		type fetched struct {
+			idx       int
+			discharge security.Discharge
+			caveat    security.Caveat
+			impetus   security.DischargeImpetus
+		}
+		discharges := make(chan fetched, len(caveats))
+		want := 0
+		for i := range caveats {
+			if !shouldFetchDischarge(out[i], expiryBuffer) {
+				continue
+			}
+			want++
+			wg.Add(1)
+			go func(i int, ctx *context.T, cav security.Caveat) {
+				defer wg.Done()
+				tp := cav.ThirdPartyDetails()
+				var dis security.Discharge
+				ctx.VI(3).Infof("Fetching discharge for %v", tp)
+				if err := v23.GetClient(ctx).Call(ctx, tp.Location(), "Discharge",
+					[]interface{}{cav, impetuses[i]}, []interface{}{&dis}); err != nil {
+					ctx.VI(3).Infof("Discharge fetch for %v failed: %v", tp, err)
+					return
+				}
+				discharges <- fetched{i, dis, caveats[i], impetuses[i]}
+			}(i, ctx, caveats[i])
+		}
+		wg.Wait()
+		close(discharges)
+		var got int
+		for fetched := range discharges {
+			bstore.CacheDischarge(fetched.discharge, fetched.caveat, fetched.impetus)
+			out[fetched.idx] = fetched.discharge
+			got++
+		}
+		if want > 0 {
+			ctx.VI(3).Infof("fetchDischarges: got %d of %d discharge(s) (total %d caveats)", got, want, len(caveats))
+		}
+		if got == 0 || got == want {
+			return
+		}
+	}
+}
+
+// filteredImpetus returns a copy of 'before' after removing any values that are not required as per 'r'.
+func filteredImpetus(r security.ThirdPartyRequirements, before security.DischargeImpetus) (after security.DischargeImpetus) {
+	if r.ReportServer && len(before.Server) > 0 {
+		after.Server = make([]security.BlessingPattern, len(before.Server))
+		for i := range before.Server {
+			after.Server[i] = before.Server[i]
+		}
+	}
+	if r.ReportMethod {
+		after.Method = before.Method
+	}
+	if r.ReportArguments && len(before.Arguments) > 0 {
+		after.Arguments = make([]*vdl.Value, len(before.Arguments))
+		for i := range before.Arguments {
+			after.Arguments[i] = vdl.CopyValue(before.Arguments[i])
+		}
+	}
+	return
+}
+
+func shouldFetchDischarge(dis security.Discharge, expiryBuffer time.Duration) bool {
+	if dis.ID() == "" {
+		return true
+	}
+	expiry := dis.Expiry()
+	if expiry.IsZero() {
+		return false
+	}
+	return expiry.Before(time.Now().Add(expiryBuffer))
+}
diff --git a/lib/security/prepare_discharges_test.go b/lib/security/prepare_discharges_test.go
new file mode 100644
index 0000000..292c58e
--- /dev/null
+++ b/lib/security/prepare_discharges_test.go
@@ -0,0 +1,137 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package security_test
+
+import (
+	"fmt"
+	"testing"
+	"time"
+
+	"v.io/v23"
+	"v.io/v23/context"
+	"v.io/v23/rpc"
+	"v.io/v23/security"
+	securitylib "v.io/x/ref/lib/security"
+	"v.io/x/ref/lib/xrpc"
+	_ "v.io/x/ref/runtime/factories/generic"
+	"v.io/x/ref/test"
+	"v.io/x/ref/test/testutil"
+)
+
+func init() {
+	test.Init()
+}
+
+type expiryDischarger struct {
+	called bool
+}
+
+func (ed *expiryDischarger) Discharge(ctx *context.T, call rpc.StreamServerCall, cav security.Caveat, _ security.DischargeImpetus) (security.Discharge, error) {
+	tp := cav.ThirdPartyDetails()
+	if tp == nil {
+		return security.Discharge{}, fmt.Errorf("discharger: %v does not represent a third-party caveat", cav)
+	}
+	if err := tp.Dischargeable(ctx, call.Security()); err != nil {
+		return security.Discharge{}, fmt.Errorf("third-party caveat %v cannot be discharged for this context: %v", cav, err)
+	}
+	expDur := 10 * time.Millisecond
+	if ed.called {
+		expDur = time.Second
+	}
+	expiry, err := security.NewExpiryCaveat(time.Now().Add(expDur))
+	if err != nil {
+		return security.Discharge{}, fmt.Errorf("failed to create an expiration on the discharge: %v", err)
+	}
+	d, err := call.Security().LocalPrincipal().MintDischarge(cav, expiry)
+	if err != nil {
+		return security.Discharge{}, err
+	}
+	ctx.Infof("got discharge on sever %#v", d)
+	ed.called = true
+	return d, nil
+}
+
+func TestPrepareDischarges(t *testing.T) {
+	ctx, shutdown := test.V23Init()
+	defer shutdown()
+
+	pclient := testutil.NewPrincipal("client")
+	cctx, err := v23.WithPrincipal(ctx, pclient)
+	if err != nil {
+		t.Fatal(err)
+	}
+	pdischarger := testutil.NewPrincipal("discharger")
+	dctx, err := v23.WithPrincipal(ctx, pdischarger)
+	if err != nil {
+		t.Fatal(err)
+	}
+	pclient.AddToRoots(pdischarger.BlessingStore().Default())
+	pclient.AddToRoots(v23.GetPrincipal(ctx).BlessingStore().Default())
+	pdischarger.AddToRoots(pclient.BlessingStore().Default())
+	pdischarger.AddToRoots(v23.GetPrincipal(ctx).BlessingStore().Default())
+
+	expcav, err := security.NewExpiryCaveat(time.Now().Add(time.Hour))
+	if err != nil {
+		t.Fatal(err)
+	}
+	tpcav, err := security.NewPublicKeyCaveat(
+		pdischarger.PublicKey(),
+		"discharger",
+		security.ThirdPartyRequirements{},
+		expcav)
+	if err != nil {
+		t.Fatal(err)
+	}
+	cbless, err := pclient.BlessSelf("clientcaveats", tpcav)
+	if err != nil {
+		t.Fatal(err)
+	}
+	tpid := tpcav.ThirdPartyDetails().ID()
+
+	v23.GetPrincipal(dctx)
+	_, err = xrpc.NewServer(dctx,
+		"discharger",
+		&expiryDischarger{},
+		security.AllowEveryone())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Fetch discharges for tpcav.
+	buffer := 100 * time.Millisecond
+	discharges := securitylib.PrepareDischarges(cctx, cbless,
+		security.DischargeImpetus{}, buffer)
+	if len(discharges) != 1 {
+		t.Errorf("Got %d discharges, expected 1.", len(discharges))
+	}
+	dis, has := discharges[tpid]
+	if !has {
+		t.Errorf("Got %#v, Expected discharge for %s", discharges, tpid)
+	}
+	// Check that the discharges is not yet expired, but is expired after 100 milliseconds.
+	expiry := dis.Expiry()
+	// The discharge should expire.
+	select {
+	case <-time.After(time.Now().Sub(expiry)):
+		break
+	case <-time.After(time.Second):
+		t.Fatalf("discharge didn't expire within a second")
+	}
+
+	// Preparing Discharges again to get fresh discharges.
+	discharges = securitylib.PrepareDischarges(cctx, cbless,
+		security.DischargeImpetus{}, buffer)
+	if len(discharges) != 1 {
+		t.Errorf("Got %d discharges, expected 1.", len(discharges))
+	}
+	dis, has = discharges[tpid]
+	if !has {
+		t.Errorf("Got %#v, Expected discharge for %s", discharges, tpid)
+	}
+	now := time.Now()
+	if expiry = dis.Expiry(); expiry.Before(now) {
+		t.Fatalf("discharge has expired %v, but should be fresh", dis)
+	}
+}