Merge "mounttablelib: restrict name element lengths to 512 characters."
diff --git a/lib/security/prepare_discharges.go b/lib/security/prepare_discharges.go
new file mode 100644
index 0000000..8cdca3e
--- /dev/null
+++ b/lib/security/prepare_discharges.go
@@ -0,0 +1,171 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package security
+
+import (
+ "sync"
+ "time"
+
+ "v.io/v23"
+ "v.io/v23/context"
+ "v.io/v23/security"
+ "v.io/v23/vdl"
+ "v.io/v23/vtrace"
+)
+
+// If this key is attached to the context, we will not fetch discharges.
+// This prevents us from fetching discharges while we are already in the
+// process of fetching discharges, which would otherwise loop forever.
+type skipDischargesKey struct{}
+
+// PrepareDischarges retrieves the caveat discharges required for using blessings
+// at a server. The discharges are either found in the dischargeCache, in the call
+// options, or requested from the discharge issuer indicated on the caveat.
+// Note that requesting a discharge is an RPC call, so one copy of this
+// function must be able to successfully terminate while another is blocked.
+func PrepareDischarges(
+ ctx *context.T,
+ blessings security.Blessings,
+ impetus security.DischargeImpetus,
+ expiryBuffer time.Duration) map[string]security.Discharge {
+ tpCavs := blessings.ThirdPartyCaveats()
+ if skip, _ := ctx.Value(skipDischargesKey{}).(bool); skip || len(tpCavs) == 0 {
+ return nil
+ }
+ ctx = context.WithValue(ctx, skipDischargesKey{}, true)
+
+ // Build filtered lists of the third-party caveats and their impetuses.
+ var caveats []security.Caveat
+ var filteredImpetuses []security.DischargeImpetus
+ for _, cav := range tpCavs {
+ // It shouldn't happen, but drop any non-third-party
+ // caveats just in case.
+ if tp := cav.ThirdPartyDetails(); tp != nil {
+ caveats = append(caveats, cav)
+ filteredImpetuses = append(filteredImpetuses, filteredImpetus(tp.Requirements(), impetus))
+ }
+ }
+ bstore := v23.GetPrincipal(ctx).BlessingStore()
+ // Gather discharges from cache.
+ discharges, rem := discharges(bstore, caveats, impetus)
+ if rem > 0 {
+ // Fetch discharges for caveats for which no discharges were
+ // found in the cache.
+ if ctx != nil {
+ var span vtrace.Span
+ ctx, span = vtrace.WithNewSpan(ctx, "Fetching Discharges")
+ defer span.Finish()
+ }
+ fetchDischarges(ctx, caveats, filteredImpetuses, discharges, expiryBuffer)
+ }
+ ret := make(map[string]security.Discharge, len(discharges))
+ for _, d := range discharges {
+ if d.ID() != "" {
+ ret[d.ID()] = d
+ }
+ }
+ return ret
+}
+
+func discharges(bs security.BlessingStore, caveats []security.Caveat, imp security.DischargeImpetus) (out []security.Discharge, rem int) {
+ out = make([]security.Discharge, len(caveats))
+ for i := range caveats {
+ out[i] = bs.Discharge(caveats[i], imp)
+ if out[i].ID() == "" {
+ rem++
+ }
+ }
+ return
+}
+
+// fetchDischarges fills in 'out' by fetching discharges for the caveats from the
+// appropriate discharge service. Since there may be dependencies among the
+// caveats, fetchDischarges keeps retrying until either all discharges have been
+// fetched or no new discharges are fetched.
+// REQUIRES: len(caveats) == len(out)
+// REQUIRES: caveats[i].ThirdPartyDetails() != nil for 0 <= i < len(caveats)
+func fetchDischarges(
+ ctx *context.T,
+ caveats []security.Caveat,
+ impetuses []security.DischargeImpetus,
+ out []security.Discharge,
+ expiryBuffer time.Duration) {
+ bstore := v23.GetPrincipal(ctx).BlessingStore()
+ var wg sync.WaitGroup
+ for {
+ type fetched struct {
+ idx int
+ discharge security.Discharge
+ caveat security.Caveat
+ impetus security.DischargeImpetus
+ }
+ discharges := make(chan fetched, len(caveats))
+ want := 0
+ for i := range caveats {
+ if !shouldFetchDischarge(out[i], expiryBuffer) {
+ continue
+ }
+ want++
+ wg.Add(1)
+ go func(i int, ctx *context.T, cav security.Caveat) {
+ defer wg.Done()
+ tp := cav.ThirdPartyDetails()
+ var dis security.Discharge
+ ctx.VI(3).Infof("Fetching discharge for %v", tp)
+ if err := v23.GetClient(ctx).Call(ctx, tp.Location(), "Discharge",
+ []interface{}{cav, impetuses[i]}, []interface{}{&dis}); err != nil {
+ ctx.VI(3).Infof("Discharge fetch for %v failed: %v", tp, err)
+ return
+ }
+ discharges <- fetched{i, dis, caveats[i], impetuses[i]}
+ }(i, ctx, caveats[i])
+ }
+ wg.Wait()
+ close(discharges)
+ var got int
+ for fetched := range discharges {
+ bstore.CacheDischarge(fetched.discharge, fetched.caveat, fetched.impetus)
+ out[fetched.idx] = fetched.discharge
+ got++
+ }
+ if want > 0 {
+ ctx.VI(3).Infof("fetchDischarges: got %d of %d discharge(s) (total %d caveats)", got, want, len(caveats))
+ }
+ if got == 0 || got == want {
+ return
+ }
+ }
+}
+
+// filteredImpetus returns a copy of 'before' after removing any values that are not required as per 'r'.
+func filteredImpetus(r security.ThirdPartyRequirements, before security.DischargeImpetus) (after security.DischargeImpetus) {
+ if r.ReportServer && len(before.Server) > 0 {
+ after.Server = make([]security.BlessingPattern, len(before.Server))
+ for i := range before.Server {
+ after.Server[i] = before.Server[i]
+ }
+ }
+ if r.ReportMethod {
+ after.Method = before.Method
+ }
+ if r.ReportArguments && len(before.Arguments) > 0 {
+ after.Arguments = make([]*vdl.Value, len(before.Arguments))
+ for i := range before.Arguments {
+ after.Arguments[i] = vdl.CopyValue(before.Arguments[i])
+ }
+ }
+ return
+}
+
+func shouldFetchDischarge(dis security.Discharge, expiryBuffer time.Duration) bool {
+ if dis.ID() == "" {
+ return true
+ }
+ expiry := dis.Expiry()
+ if expiry.IsZero() {
+ return false
+ }
+ return expiry.Before(time.Now().Add(expiryBuffer))
+}
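
As a usage illustration of the API added above, here is a minimal sketch of a caller that prefetches discharges before making a call; the package clause, wrapper name, and one-minute expiry buffer are illustrative assumptions, not part of this change:

    package example // hypothetical caller, for illustration only

    import (
        "time"

        "v.io/v23/context"
        "v.io/v23/security"
        securitylib "v.io/x/ref/lib/security"
    )

    // prefetchDischarges gathers the discharges needed to present 'blessings'
    // to a server, re-fetching any cached discharge that expires within the
    // next minute (the expiry buffer passed to PrepareDischarges).
    func prefetchDischarges(ctx *context.T, blessings security.Blessings) map[string]security.Discharge {
        return securitylib.PrepareDischarges(ctx, blessings, security.DischargeImpetus{}, time.Minute)
    }

The returned map is keyed by discharge ID; caveats whose discharges could not be fetched are simply absent from it.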
diff --git a/lib/security/prepare_discharges_test.go b/lib/security/prepare_discharges_test.go
new file mode 100644
index 0000000..292c58e
--- /dev/null
+++ b/lib/security/prepare_discharges_test.go
@@ -0,0 +1,137 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package security_test
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "v.io/v23"
+ "v.io/v23/context"
+ "v.io/v23/rpc"
+ "v.io/v23/security"
+ securitylib "v.io/x/ref/lib/security"
+ "v.io/x/ref/lib/xrpc"
+ _ "v.io/x/ref/runtime/factories/generic"
+ "v.io/x/ref/test"
+ "v.io/x/ref/test/testutil"
+)
+
+func init() {
+ test.Init()
+}
+
+type expiryDischarger struct {
+ called bool
+}
+
+func (ed *expiryDischarger) Discharge(ctx *context.T, call rpc.StreamServerCall, cav security.Caveat, _ security.DischargeImpetus) (security.Discharge, error) {
+ tp := cav.ThirdPartyDetails()
+ if tp == nil {
+ return security.Discharge{}, fmt.Errorf("discharger: %v does not represent a third-party caveat", cav)
+ }
+ if err := tp.Dischargeable(ctx, call.Security()); err != nil {
+ return security.Discharge{}, fmt.Errorf("third-party caveat %v cannot be discharged for this context: %v", cav, err)
+ }
+ expDur := 10 * time.Millisecond
+ if ed.called {
+ expDur = time.Second
+ }
+ expiry, err := security.NewExpiryCaveat(time.Now().Add(expDur))
+ if err != nil {
+ return security.Discharge{}, fmt.Errorf("failed to create an expiration on the discharge: %v", err)
+ }
+ d, err := call.Security().LocalPrincipal().MintDischarge(cav, expiry)
+ if err != nil {
+ return security.Discharge{}, err
+ }
+ ctx.Infof("got discharge on sever %#v", d)
+ ed.called = true
+ return d, nil
+}
+
+func TestPrepareDischarges(t *testing.T) {
+ ctx, shutdown := test.V23Init()
+ defer shutdown()
+
+ pclient := testutil.NewPrincipal("client")
+ cctx, err := v23.WithPrincipal(ctx, pclient)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pdischarger := testutil.NewPrincipal("discharger")
+ dctx, err := v23.WithPrincipal(ctx, pdischarger)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pclient.AddToRoots(pdischarger.BlessingStore().Default())
+ pclient.AddToRoots(v23.GetPrincipal(ctx).BlessingStore().Default())
+ pdischarger.AddToRoots(pclient.BlessingStore().Default())
+ pdischarger.AddToRoots(v23.GetPrincipal(ctx).BlessingStore().Default())
+
+ expcav, err := security.NewExpiryCaveat(time.Now().Add(time.Hour))
+ if err != nil {
+ t.Fatal(err)
+ }
+ tpcav, err := security.NewPublicKeyCaveat(
+ pdischarger.PublicKey(),
+ "discharger",
+ security.ThirdPartyRequirements{},
+ expcav)
+ if err != nil {
+ t.Fatal(err)
+ }
+ cbless, err := pclient.BlessSelf("clientcaveats", tpcav)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tpid := tpcav.ThirdPartyDetails().ID()
+
+ v23.GetPrincipal(dctx)
+ _, err = xrpc.NewServer(dctx,
+ "discharger",
+ &expiryDischarger{},
+ security.AllowEveryone())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Fetch discharges for tpcav.
+ buffer := 100 * time.Millisecond
+ discharges := securitylib.PrepareDischarges(cctx, cbless,
+ security.DischargeImpetus{}, buffer)
+ if len(discharges) != 1 {
+ t.Errorf("Got %d discharges, expected 1.", len(discharges))
+ }
+ dis, has := discharges[tpid]
+ if !has {
+ t.Errorf("Got %#v, Expected discharge for %s", discharges, tpid)
+ }
+ // The first discharge is minted with a short expiry; wait for it to
+ // actually expire before fetching again.
+ expiry := dis.Expiry()
+ select {
+ case <-time.After(expiry.Sub(time.Now())):
+ case <-time.After(time.Second):
+ t.Fatalf("discharge didn't expire within a second")
+ }
+
+ // Prepare discharges again to get fresh ones.
+ discharges = securitylib.PrepareDischarges(cctx, cbless,
+ security.DischargeImpetus{}, buffer)
+ if len(discharges) != 1 {
+ t.Errorf("Got %d discharges, expected 1.", len(discharges))
+ }
+ dis, has = discharges[tpid]
+ if !has {
+ t.Errorf("Got %#v, Expected discharge for %s", discharges, tpid)
+ }
+ now := time.Now()
+ if expiry = dis.Expiry(); expiry.Before(now) {
+ t.Fatalf("discharge has expired %v, but should be fresh", dis)
+ }
+}
diff --git a/runtime/factories/fake/runtime.go b/runtime/factories/fake/runtime.go
index f31bd9a..1abdae7 100644
--- a/runtime/factories/fake/runtime.go
+++ b/runtime/factories/fake/runtime.go
@@ -12,8 +12,8 @@
"v.io/v23/security"
"v.io/x/ref/internal/logger"
"v.io/x/ref/lib/apilog"
- vsecurity "v.io/x/ref/lib/security"
tnaming "v.io/x/ref/runtime/internal/testing/mocks/naming"
+ "v.io/x/ref/test/testutil"
)
type contextKey int
@@ -30,11 +30,7 @@
}
func new(ctx *context.T) (*Runtime, *context.T, v23.Shutdown, error) {
- p, err := vsecurity.NewPrincipal()
- if err != nil {
- return nil, nil, func() {}, err
- }
- ctx = context.WithValue(ctx, principalKey, p)
+ ctx = context.WithValue(ctx, principalKey, testutil.NewPrincipal("fake"))
ctx = context.WithLogger(ctx, logger.Global())
return &Runtime{ns: tnaming.NewSimpleNamespace()}, ctx, func() {}, nil
}
diff --git a/runtime/internal/flow/conn/auth.go b/runtime/internal/flow/conn/auth.go
new file mode 100644
index 0000000..eff876d
--- /dev/null
+++ b/runtime/internal/flow/conn/auth.go
@@ -0,0 +1,277 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package conn
+
+import (
+ "crypto/rand"
+ "reflect"
+ "sync"
+
+ "golang.org/x/crypto/nacl/box"
+ "v.io/v23"
+ "v.io/v23/context"
+ "v.io/v23/flow"
+ "v.io/v23/rpc/version"
+ "v.io/v23/security"
+ "v.io/v23/vom"
+)
+
+func (c *Conn) dialHandshake(ctx *context.T, versions version.RPCVersionRange) error {
+ binding, err := c.setup(ctx, versions)
+ if err != nil {
+ return err
+ }
+ c.blessingsFlow = newBlessingsFlow(ctx, c.newFlowLocked(ctx, blessingsFlowID, 0, 0, true, true), true)
+ if err = c.readRemoteAuth(ctx, binding); err != nil {
+ return err
+ }
+ if c.rBlessings.IsZero() {
+ return NewErrAcceptorBlessingsMissing(ctx)
+ }
+ signedBinding, err := v23.GetPrincipal(ctx).Sign(binding)
+ if err != nil {
+ return err
+ }
+ lAuth := &auth{
+ channelBinding: signedBinding,
+ }
+ // We only send our blessings if we are a server in addition to being a client.
+ // If we are a pure client, we only send our public key.
+ if c.handler != nil {
+ bkey, dkey, err := c.blessingsFlow.put(ctx, c.lBlessings, c.lDischarges)
+ if err != nil {
+ return err
+ }
+ lAuth.bkey, lAuth.dkey = bkey, dkey
+ } else {
+ lAuth.publicKey = c.lBlessings.PublicKey()
+ }
+ if err = c.mp.writeMsg(ctx, lAuth); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *Conn) acceptHandshake(ctx *context.T, versions version.RPCVersionRange) error {
+ binding, err := c.setup(ctx, versions)
+ if err != nil {
+ return err
+ }
+ c.blessingsFlow = newBlessingsFlow(ctx, c.newFlowLocked(ctx, blessingsFlowID, 0, 0, true, true), false)
+ signedBinding, err := v23.GetPrincipal(ctx).Sign(binding)
+ if err != nil {
+ return err
+ }
+ bkey, dkey, err := c.blessingsFlow.put(ctx, c.lBlessings, c.lDischarges)
+ if err != nil {
+ return err
+ }
+ err = c.mp.writeMsg(ctx, &auth{
+ bkey: bkey,
+ dkey: dkey,
+ channelBinding: signedBinding,
+ })
+ if err != nil {
+ return err
+ }
+ return c.readRemoteAuth(ctx, binding)
+}
+
+func (c *Conn) setup(ctx *context.T, versions version.RPCVersionRange) ([]byte, error) {
+ pk, sk, err := box.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, err
+ }
+ lSetup := &setup{
+ versions: versions,
+ peerLocalEndpoint: c.local,
+ peerNaClPublicKey: pk,
+ }
+ if c.remote != nil {
+ lSetup.peerRemoteEndpoint = c.remote
+ }
+ ch := make(chan error)
+ go func() {
+ ch <- c.mp.writeMsg(ctx, lSetup)
+ }()
+ msg, err := c.mp.readMsg(ctx)
+ if err != nil {
+ return nil, NewErrRecv(ctx, "unknown", err)
+ }
+ rSetup, valid := msg.(*setup)
+ if !valid {
+ return nil, NewErrUnexpectedMsg(ctx, reflect.TypeOf(msg).String())
+ }
+ if err := <-ch; err != nil {
+ return nil, NewErrSend(ctx, "setup", c.remote.String(), err)
+ }
+ if c.version, err = version.CommonVersion(ctx, lSetup.versions, rSetup.versions); err != nil {
+ return nil, err
+ }
+ // TODO(mattr): Decide which endpoints to actually keep, the ones we know locally
+ // or what the remote side thinks.
+ if rSetup.peerRemoteEndpoint != nil {
+ c.local = rSetup.peerRemoteEndpoint
+ }
+ if rSetup.peerLocalEndpoint != nil {
+ c.remote = rSetup.peerLocalEndpoint
+ }
+ if rSetup.peerNaClPublicKey == nil {
+ return nil, NewErrMissingSetupOption(ctx, peerNaClPublicKeyOption)
+ }
+ return c.mp.setupEncryption(ctx, pk, sk, rSetup.peerNaClPublicKey), nil
+}
+
+func (c *Conn) readRemoteAuth(ctx *context.T, binding []byte) error {
+ var rauth *auth
+ for {
+ msg, err := c.mp.readMsg(ctx)
+ if err != nil {
+ return NewErrRecv(ctx, c.remote.String(), err)
+ }
+ if rauth, _ = msg.(*auth); rauth != nil {
+ break
+ }
+ if err = c.handleMessage(ctx, msg); err != nil {
+ return err
+ }
+ }
+ var rPublicKey security.PublicKey
+ if rauth.bkey != 0 {
+ var err error
+ // TODO(mattr): Make sure we cancel out of this at some point.
+ c.rBlessings, c.rDischarges, err = c.blessingsFlow.get(ctx, rauth.bkey, rauth.dkey)
+ if err != nil {
+ return err
+ }
+ rPublicKey = c.rBlessings.PublicKey()
+ } else {
+ rPublicKey = rauth.publicKey
+ }
+ if rPublicKey == nil {
+ return NewErrNoPublicKey(ctx)
+ }
+ if !rauth.channelBinding.Verify(rPublicKey, binding) {
+ return NewErrInvalidChannelBinding(ctx)
+ }
+ return nil
+}
+
+type blessingsFlow struct {
+ enc *vom.Encoder
+ dec *vom.Decoder
+
+ mu sync.Mutex
+ cond *sync.Cond
+ closed bool
+ nextKey uint64
+ byUID map[string]*Blessings
+ byBKey map[uint64]*Blessings
+}
+
+func newBlessingsFlow(ctx *context.T, f flow.Flow, dialed bool) *blessingsFlow {
+ b := &blessingsFlow{
+ enc: vom.NewEncoder(f),
+ dec: vom.NewDecoder(f),
+ nextKey: 1,
+ byUID: make(map[string]*Blessings),
+ byBKey: make(map[uint64]*Blessings),
+ }
+ b.cond = sync.NewCond(&b.mu)
+ if !dialed {
+ b.nextKey++
+ }
+ go b.readLoop(ctx)
+ return b
+}
+
+func (b *blessingsFlow) put(ctx *context.T, blessings security.Blessings, discharges map[string]security.Discharge) (bkey, dkey uint64, err error) {
+ defer b.mu.Unlock()
+ b.mu.Lock()
+ buid := string(blessings.UniqueID())
+ element, has := b.byUID[buid]
+ if has && equalDischarges(discharges, element.Discharges) {
+ return element.BKey, element.DKey, nil
+ }
+ defer b.cond.Broadcast()
+ if has {
+ element.Discharges = dischargeList(discharges)
+ element.DKey = b.nextKey
+ b.nextKey += 2
+ return element.BKey, element.DKey, b.enc.Encode(Blessings{
+ Discharges: element.Discharges,
+ DKey: element.DKey,
+ })
+ }
+ element = &Blessings{
+ Blessings: blessings,
+ Discharges: dischargeList(discharges),
+ BKey: b.nextKey,
+ }
+ b.nextKey += 2
+ if len(discharges) > 0 {
+ element.DKey = b.nextKey
+ b.nextKey += 2
+ }
+ b.byUID[buid] = element
+ b.byBKey[element.BKey] = element
+ return element.BKey, element.DKey, b.enc.Encode(element)
+}
+
+func (b *blessingsFlow) get(ctx *context.T, bkey, dkey uint64) (security.Blessings, map[string]security.Discharge, error) {
+ defer b.mu.Unlock()
+ b.mu.Lock()
+ for !b.closed {
+ element, has := b.byBKey[bkey]
+ if has && element.DKey == dkey {
+ return element.Blessings, dischargeMap(element.Discharges), nil
+ }
+ b.cond.Wait()
+ }
+ return security.Blessings{}, nil, NewErrBlessingsFlowClosed(ctx)
+}
+
+func (b *blessingsFlow) readLoop(ctx *context.T) {
+ for {
+ var received Blessings
+ err := b.dec.Decode(&received)
+ b.mu.Lock()
+ if err != nil {
+ b.closed = true
+ b.mu.Unlock()
+ return
+ }
+ b.byUID[string(received.Blessings.UniqueID())] = &received
+ b.byBKey[received.BKey] = &received
+ b.cond.Broadcast()
+ b.mu.Unlock()
+ }
+}
+
+func dischargeList(in map[string]security.Discharge) []security.Discharge {
+ out := make([]security.Discharge, 0, len(in))
+ for _, d := range in {
+ out = append(out, d)
+ }
+ return out
+}
+func dischargeMap(in []security.Discharge) map[string]security.Discharge {
+ out := make(map[string]security.Discharge, len(in))
+ for _, d := range in {
+ out[d.ID()] = d
+ }
+ return out
+}
+func equalDischarges(m map[string]security.Discharge, s []security.Discharge) bool {
+ if len(m) != len(s) {
+ return false
+ }
+ for _, d := range s {
+ if !d.Equivalent(m[d.ID()]) {
+ return false
+ }
+ }
+ return true
+}
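
A note on the key scheme used by blessingsFlow above: the dialer starts nextKey at 1 while the acceptor bumps it to 2, and both sides advance by 2 on every allocation, so the two ends hand out bkey/dkey values from disjoint (odd/even) ranges without coordination; get then blocks until the peer's readLoop has delivered the corresponding Blessings record. A small standalone sketch of that invariant (the helper name is hypothetical, not part of the patch):

    // keyParity mirrors the allocation pattern of blessingsFlow: a side that
    // dialed hands out 1, 3, 5, ... while the accepting side hands out
    // 2, 4, 6, ..., so keys chosen independently by the two ends never collide.
    func keyParity(dialed bool, n int) []uint64 {
        next := uint64(1)
        if !dialed {
            next++
        }
        keys := make([]uint64, 0, n)
        for i := 0; i < n; i++ {
            keys = append(keys, next)
            next += 2
        }
        return keys
    }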
diff --git a/runtime/internal/flow/conn/auth_test.go b/runtime/internal/flow/conn/auth_test.go
new file mode 100644
index 0000000..df8fee6
--- /dev/null
+++ b/runtime/internal/flow/conn/auth_test.go
@@ -0,0 +1,124 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package conn
+
+import (
+ "testing"
+
+ "v.io/v23"
+ "v.io/v23/context"
+ "v.io/v23/flow"
+ "v.io/v23/security"
+ "v.io/v23/verror"
+ _ "v.io/x/ref/runtime/factories/fake"
+ "v.io/x/ref/test/testutil"
+)
+
+func checkBlessings(t *testing.T, df, af flow.Flow, db, ab security.Blessings) {
+ msg, err := af.ReadMsg()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(msg) != "hello" {
+ t.Fatalf("Got %s, wanted hello", string(msg))
+ }
+ if !af.LocalBlessings().Equivalent(ab) {
+ t.Errorf("got: %#v wanted %#v", af.LocalBlessings(), ab)
+ }
+ if !af.RemoteBlessings().Equivalent(db) {
+ t.Errorf("got: %#v wanted %#v", af.RemoteBlessings(), db)
+ }
+ if !df.LocalBlessings().Equivalent(db) {
+ t.Errorf("got: %#v wanted %#v", df.LocalBlessings(), db)
+ }
+ if !df.RemoteBlessings().Equivalent(ab) {
+ t.Errorf("got: %#v wanted %#v", df.RemoteBlessings(), ab)
+ }
+}
+func dialFlow(t *testing.T, ctx *context.T, dc *Conn, b security.Blessings) flow.Flow {
+ df, err := dc.Dial(ctx, makeBFP(b))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = df.WriteMsg([]byte("hello")); err != nil {
+ t.Fatal(err)
+ }
+ return df
+}
+
+func TestUnidirectional(t *testing.T) {
+ dctx, shutdown := v23.Init()
+ defer shutdown()
+ actx, err := v23.WithPrincipal(dctx, testutil.NewPrincipal("acceptor"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ aflows := make(chan flow.Flow, 2)
+ dc, ac, _ := setupConns(t, dctx, actx, nil, aflows)
+
+ df1 := dialFlow(t, dctx, dc, v23.GetPrincipal(dctx).BlessingStore().Default())
+ af1 := <-aflows
+ checkBlessings(t, df1, af1,
+ v23.GetPrincipal(dctx).BlessingStore().Default(),
+ v23.GetPrincipal(actx).BlessingStore().Default())
+
+ db2, err := v23.GetPrincipal(dctx).BlessSelf("other")
+ if err != nil {
+ t.Fatal(err)
+ }
+ df2 := dialFlow(t, dctx, dc, db2)
+ af2 := <-aflows
+ checkBlessings(t, df2, af2, db2,
+ v23.GetPrincipal(actx).BlessingStore().Default())
+
+ // We should not be able to dial in the other direction, because that flow
+ // manager is not willing to accept flows.
+ _, err = ac.Dial(actx, testBFP)
+ if verror.ErrorID(err) != ErrDialingNonServer.ID {
+ t.Errorf("got %v, wanted ErrDialingNonServer", err)
+ }
+}
+
+func TestBidirectional(t *testing.T) {
+ dctx, shutdown := v23.Init()
+ defer shutdown()
+ actx, err := v23.WithPrincipal(dctx, testutil.NewPrincipal("acceptor"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ dflows := make(chan flow.Flow, 2)
+ aflows := make(chan flow.Flow, 2)
+ dc, ac, _ := setupConns(t, dctx, actx, dflows, aflows)
+
+ df1 := dialFlow(t, dctx, dc, v23.GetPrincipal(dctx).BlessingStore().Default())
+ af1 := <-aflows
+ checkBlessings(t, df1, af1,
+ v23.GetPrincipal(dctx).BlessingStore().Default(),
+ v23.GetPrincipal(actx).BlessingStore().Default())
+
+ db2, err := v23.GetPrincipal(dctx).BlessSelf("other")
+ if err != nil {
+ t.Fatal(err)
+ }
+ df2 := dialFlow(t, dctx, dc, db2)
+ af2 := <-aflows
+ checkBlessings(t, df2, af2, db2,
+ v23.GetPrincipal(actx).BlessingStore().Default())
+
+ af3 := dialFlow(t, actx, ac, v23.GetPrincipal(actx).BlessingStore().Default())
+ df3 := <-dflows
+ checkBlessings(t, af3, df3,
+ v23.GetPrincipal(actx).BlessingStore().Default(),
+ v23.GetPrincipal(dctx).BlessingStore().Default())
+
+ ab2, err := v23.GetPrincipal(actx).BlessSelf("aother")
+ if err != nil {
+ t.Fatal(err)
+ }
+ af4 := dialFlow(t, actx, ac, ab2)
+ df4 := <-dflows
+ checkBlessings(t, af4, df4, ab2,
+ v23.GetPrincipal(dctx).BlessingStore().Default())
+}
diff --git a/runtime/internal/flow/conn/close_test.go b/runtime/internal/flow/conn/close_test.go
index 6dc0987..fd73c50 100644
--- a/runtime/internal/flow/conn/close_test.go
+++ b/runtime/internal/flow/conn/close_test.go
@@ -55,10 +55,10 @@
d.Close(ctx, fmt.Errorf("Closing randomly."))
<-d.Closed()
<-a.Closed()
- if _, err := d.Dial(ctx); err == nil {
+ if _, err := d.Dial(ctx, testBFP); err == nil {
t.Errorf("Nil error dialing on dialer")
}
- if _, err := a.Dial(ctx); err == nil {
+ if _, err := a.Dial(ctx, testBFP); err == nil {
t.Errorf("Nil error dialing on acceptor")
}
}
diff --git a/runtime/internal/flow/conn/conn.go b/runtime/internal/flow/conn/conn.go
index 8cd4b6a..321abde 100644
--- a/runtime/internal/flow/conn/conn.go
+++ b/runtime/internal/flow/conn/conn.go
@@ -15,7 +15,6 @@
"v.io/v23/rpc/version"
"v.io/v23/security"
"v.io/v23/verror"
-
"v.io/x/ref/runtime/internal/flow/flowcontrol"
)
@@ -23,9 +22,14 @@
// Each flow on a given conn will have a unique number.
type flowID uint64
+const (
+ invalidFlowID = flowID(iota)
+ blessingsFlowID
+ reservedFlows = 10
+)
+
const mtu = 1 << 16
const defaultBufferSize = 1 << 20
-const reservedFlows = 10
const (
expressPriority = iota
@@ -45,22 +49,24 @@
}
// Conns are multiplexed encrypted channels that can host Flows.
+// TODO(mattr): track and clean up all spawned goroutines.
type Conn struct {
- fc *flowcontrol.FlowController
- mp *messagePipe
- handler FlowHandler
- versions version.RPCVersionRange
- acceptorBlessings security.Blessings
- dialerPublicKey security.PublicKey
- local, remote naming.Endpoint
- closed chan struct{}
+ fc *flowcontrol.FlowController
+ mp *messagePipe
+ handler FlowHandler
+ version version.RPCVersion
+ lBlessings, rBlessings security.Blessings
+ rDischarges, lDischarges map[string]security.Discharge
+ local, remote naming.Endpoint
+ closed chan struct{}
+ blessingsFlow *blessingsFlow
mu sync.Mutex
nextFid flowID
flows map[flowID]*flw
}
-// Ensure that *Conn implements flow.Conn
+// Ensure that *Conn implements flow.Conn.
var _ flow.Conn = &Conn{}
// NewDialed dials a new Conn on the given conn.
@@ -69,20 +75,21 @@
conn MsgReadWriteCloser,
local, remote naming.Endpoint,
versions version.RPCVersionRange,
- handler FlowHandler,
- fn flow.BlessingsForPeer) (*Conn, error) {
- principal := v23.GetPrincipal(ctx)
+ handler FlowHandler) (*Conn, error) {
c := &Conn{
- fc: flowcontrol.New(defaultBufferSize, mtu),
- mp: newMessagePipe(conn),
- handler: handler,
- versions: versions,
- dialerPublicKey: principal.PublicKey(),
- local: local,
- remote: remote,
- closed: make(chan struct{}),
- nextFid: reservedFlows,
- flows: map[flowID]*flw{},
+ fc: flowcontrol.New(defaultBufferSize, mtu),
+ mp: newMessagePipe(conn),
+ handler: handler,
+ lBlessings: v23.GetPrincipal(ctx).BlessingStore().Default(),
+ local: local,
+ remote: remote,
+ closed: make(chan struct{}),
+ nextFid: reservedFlows,
+ flows: map[flowID]*flw{},
+ }
+ if err := c.dialHandshake(ctx, versions); err != nil {
+ c.Close(ctx, err)
+ return nil, err
}
go c.readLoop(ctx)
return c, nil
@@ -93,27 +100,39 @@
ctx *context.T,
conn MsgReadWriteCloser,
local naming.Endpoint,
- lBlessings security.Blessings,
versions version.RPCVersionRange,
handler FlowHandler) (*Conn, error) {
c := &Conn{
- fc: flowcontrol.New(defaultBufferSize, mtu),
- mp: newMessagePipe(conn),
- handler: handler,
- versions: versions,
- acceptorBlessings: lBlessings,
- local: local,
- remote: local, // TODO(mattr): Get the real remote endpoint.
- closed: make(chan struct{}),
- nextFid: reservedFlows + 1,
- flows: map[flowID]*flw{},
+ fc: flowcontrol.New(defaultBufferSize, mtu),
+ mp: newMessagePipe(conn),
+ handler: handler,
+ lBlessings: v23.GetPrincipal(ctx).BlessingStore().Default(),
+ local: local,
+ closed: make(chan struct{}),
+ nextFid: reservedFlows + 1,
+ flows: map[flowID]*flw{},
+ }
+ if err := c.acceptHandshake(ctx, versions); err != nil {
+ c.Close(ctx, err)
+ return nil, err
}
go c.readLoop(ctx)
return c, nil
}
// Dial dials a new flow on the Conn.
-func (c *Conn) Dial(ctx *context.T) (flow.Flow, error) {
+func (c *Conn) Dial(ctx *context.T, fn flow.BlessingsForPeer) (flow.Flow, error) {
+ if c.rBlessings.IsZero() {
+ return nil, NewErrDialingNonServer(ctx)
+ }
+ blessings, err := fn(ctx, c.local, c.remote, c.rBlessings, c.rDischarges)
+ if err != nil {
+ return nil, err
+ }
+ bkey, dkey, err := c.blessingsFlow.put(ctx, blessings, nil)
+ if err != nil {
+ return nil, err
+ }
defer c.mu.Unlock()
c.mu.Lock()
if c.flows == nil {
@@ -121,7 +140,7 @@
}
id := c.nextFid
c.nextFid++
- return c.newFlowLocked(ctx, id), nil
+ return c.newFlowLocked(ctx, id, bkey, dkey, true, false), nil
}
// LocalEndpoint returns the local vanadium Endpoint
@@ -130,16 +149,6 @@
// RemoteEndpoint returns the remote vanadium Endpoint
func (c *Conn) RemoteEndpoint() naming.Endpoint { return c.remote }
-// DialerPublicKey returns the public key presented by the dialer during authentication.
-func (c *Conn) DialerPublicKey() security.PublicKey { return c.dialerPublicKey }
-
-// AcceptorBlessings returns the blessings presented by the acceptor during authentication.
-func (c *Conn) AcceptorBlessings() security.Blessings { return c.acceptorBlessings }
-
-// AcceptorDischarges returns the discharges presented by the acceptor during authentication.
-// Discharges are organized in a map keyed by the discharge-identifier.
-func (c *Conn) AcceptorDischarges() map[string]security.Discharge { return nil }
-
// Closed returns a channel that will be closed after the Conn is shutdown.
// After this channel is closed it is guaranteed that all Dial calls will fail
// with an error and no more flows will be sent to the FlowHandler.
@@ -156,10 +165,7 @@
// We've already torn this conn down.
return
}
- ferr := err
- if verror.ErrorID(err) == ErrConnClosedRemotely.ID {
- ferr = NewErrFlowClosedRemotely(ctx)
- } else {
+ if verror.ErrorID(err) != ErrConnClosedRemotely.ID {
message := ""
if err != nil {
message = err.Error()
@@ -172,7 +178,7 @@
}
}
for _, f := range flows {
- f.close(ctx, ferr)
+ f.close(ctx, NewErrConnectionClosed(ctx))
}
if cerr := c.mp.close(); cerr != nil {
ctx.Errorf("Error closing underlying connection for %s: %v", c.remote, cerr)
@@ -206,77 +212,83 @@
}
}
-func (c *Conn) readLoop(ctx *context.T) {
- var terr error
- defer c.Close(ctx, terr)
+func (c *Conn) handleMessage(ctx *context.T, x message) error {
+ switch msg := x.(type) {
+ case *tearDown:
+ return NewErrConnClosedRemotely(ctx, msg.Message)
- for {
- x, err := c.mp.readMsg(ctx)
- if err != nil {
- c.Close(ctx, NewErrRecv(ctx, c.remote.String(), err))
- return
+ case *openFlow:
+ if c.handler == nil {
+ return NewErrUnexpectedMsg(ctx, "openFlow")
+ }
+ c.mu.Lock()
+ f := c.newFlowLocked(ctx, msg.id, msg.bkey, msg.dkey, false, false)
+ c.mu.Unlock()
+ c.handler.HandleFlow(f)
+
+ case *release:
+ release := make([]flowcontrol.Release, 0, len(msg.counters))
+ c.mu.Lock()
+ for fid, val := range msg.counters {
+ if f := c.flows[fid]; f != nil {
+ release = append(release, flowcontrol.Release{
+ Worker: f.worker,
+ Tokens: int(val),
+ })
+ }
+ }
+ c.mu.Unlock()
+ if err := c.fc.Release(ctx, release); err != nil {
+ return err
}
- switch msg := x.(type) {
- case *tearDown:
- terr = NewErrConnClosedRemotely(ctx, msg.Message)
- return
+ case *data:
+ c.mu.Lock()
+ f := c.flows[msg.id]
+ c.mu.Unlock()
+ if f == nil {
+ ctx.Infof("Ignoring data message for unknown flow on connection to %s: %d", c.remote, msg.id)
+ return nil
+ }
+ if err := f.q.put(ctx, msg.payload); err != nil {
+ return err
+ }
+ if msg.flags&closeFlag != 0 {
+ f.close(ctx, NewErrFlowClosedRemotely(f.ctx))
+ }
- case *openFlow:
- c.mu.Lock()
- f := c.newFlowLocked(ctx, msg.id)
- c.mu.Unlock()
- c.handler.HandleFlow(f)
+ case *unencryptedData:
+ c.mu.Lock()
+ f := c.flows[msg.id]
+ c.mu.Unlock()
+ if f == nil {
+ ctx.Infof("Ignoring data message for unknown flow: %d", msg.id)
+ return nil
+ }
+ if err := f.q.put(ctx, msg.payload); err != nil {
+ return err
+ }
+ if msg.flags&closeFlag != 0 {
+ f.close(ctx, NewErrFlowClosedRemotely(f.ctx))
+ }
- case *release:
- release := make([]flowcontrol.Release, 0, len(msg.counters))
- c.mu.Lock()
- for fid, val := range msg.counters {
- if f := c.flows[fid]; f != nil {
- release = append(release, flowcontrol.Release{
- Worker: f.worker,
- Tokens: int(val),
- })
- }
- }
- c.mu.Unlock()
- if terr = c.fc.Release(ctx, release); terr != nil {
- return
- }
+ default:
+ return NewErrUnexpectedMsg(ctx, reflect.TypeOf(msg).String())
+ }
+ return nil
+}
- case *data:
- c.mu.Lock()
- f := c.flows[msg.id]
- c.mu.Unlock()
- if f == nil {
- ctx.Infof("Ignoring data message for unknown flow on connection to %s: %d", c.remote, msg.id)
- continue
- }
- if terr = f.q.put(ctx, msg.payload); terr != nil {
- return
- }
- if msg.flags&closeFlag != 0 {
- f.close(ctx, NewErrFlowClosedRemotely(f.ctx))
- }
-
- case *unencryptedData:
- c.mu.Lock()
- f := c.flows[msg.id]
- c.mu.Unlock()
- if f == nil {
- ctx.Infof("Ignoring data message for unknown flow: %d", msg.id)
- continue
- }
- if terr = f.q.put(ctx, msg.payload); terr != nil {
- return
- }
- if msg.flags&closeFlag != 0 {
- f.close(ctx, NewErrFlowClosedRemotely(f.ctx))
- }
-
- default:
- terr = NewErrUnexpectedMsg(ctx, reflect.TypeOf(msg).Name())
- return
+func (c *Conn) readLoop(ctx *context.T) {
+ var err error
+ for {
+ msg, rerr := c.mp.readMsg(ctx)
+ if rerr != nil {
+ err = NewErrRecv(ctx, c.remote.String(), rerr)
+ break
+ }
+ if err = c.handleMessage(ctx, msg); err != nil {
+ break
}
}
+ c.Close(ctx, err)
}
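
Dial now takes a flow.BlessingsForPeer callback so the dialer can choose which blessings to present on each flow, given the blessings and discharges the acceptor presented during the handshake. Below is a minimal callback sketch with the signature inferred from the call site in Dial above; the helper name is hypothetical (the tests' makeBFP helper presumably does something similar):

    // fixedBFP returns a callback that always presents 'b', ignoring the
    // remote end's blessings and discharges.
    // Assumes imports of v.io/v23/context, v.io/v23/flow, v.io/v23/naming,
    // and v.io/v23/security.
    func fixedBFP(b security.Blessings) flow.BlessingsForPeer {
        return func(ctx *context.T, local, remote naming.Endpoint,
            remoteBlessings security.Blessings,
            remoteDischarges map[string]security.Discharge) (security.Blessings, error) {
            return b, nil
        }
    }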
diff --git a/runtime/internal/flow/conn/conn_test.go b/runtime/internal/flow/conn/conn_test.go
index 7fefba6..49bd59e 100644
--- a/runtime/internal/flow/conn/conn_test.go
+++ b/runtime/internal/flow/conn/conn_test.go
@@ -21,7 +21,6 @@
func init() {
test.Init()
-
randData = make([]byte, 2*defaultBufferSize)
if _, err := rand.Read(randData); err != nil {
panic("Could not read random data.")
@@ -68,15 +67,6 @@
<-af.Closed()
}
-func TestDial(t *testing.T) {
- ctx, shutdown := v23.Init()
- defer shutdown()
- for _, dialerDials := range []bool{true, false} {
- df, flows := setupFlow(t, ctx, ctx, dialerDials)
- testWrite(t, ctx, []byte("hello world"), df, flows)
- }
-}
-
func TestLargeWrite(t *testing.T) {
ctx, shutdown := v23.Init()
defer shutdown()
diff --git a/runtime/internal/flow/conn/conncache_test.go b/runtime/internal/flow/conn/conncache_test.go
index a3c8b80..2686b0b 100644
--- a/runtime/internal/flow/conn/conncache_test.go
+++ b/runtime/internal/flow/conn/conncache_test.go
@@ -255,10 +255,23 @@
}
func makeConn(t *testing.T, ctx *context.T, ep naming.Endpoint) *Conn {
- d, _, _ := newMRWPair(ctx)
- c, err := NewDialed(ctx, d, ep, ep, version.RPCVersionRange{Min: 1, Max: 5}, nil, nil)
- if err != nil {
- t.Fatalf("Could not create conn: %v", err)
- }
- return c
+ dmrw, amrw, _ := newMRWPair(ctx)
+ dch := make(chan *Conn)
+ ach := make(chan *Conn)
+ go func() {
+ d, err := NewDialed(ctx, dmrw, ep, ep, version.RPCVersionRange{1, 5}, nil)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ dch <- d
+ }()
+ go func() {
+ a, err := NewAccepted(ctx, amrw, ep, version.RPCVersionRange{1, 5}, nil)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ ach <- a
+ }()
+ <-dch
+ return <-ach
}
diff --git a/runtime/internal/flow/conn/errors.vdl b/runtime/internal/flow/conn/errors.vdl
index 0e0e821..40e5c0f 100644
--- a/runtime/internal/flow/conn/errors.vdl
+++ b/runtime/internal/flow/conn/errors.vdl
@@ -10,13 +10,15 @@
// since all of their errors are intended to be used as arguments to higher level errors.
// TODO(suharshs,toddw): Allow skipping of {1}{2} in vdl generated errors.
error (
- InvalidMsg(typ byte, size, field int64) {
- "en": "message of type{:typ} and size{:size} failed decoding at field{:field}."}
- InvalidControlMsg(cmd byte, size, field int64) {
- "en": "control message of cmd{:cmd} and size{:size} failed decoding at field{:field}."}
+ InvalidMsg(typ byte, size, field uint64) {"en": "message of type{:typ} and size{:size} failed decoding at field{:field}."}
+ InvalidControlMsg(cmd byte, size, field uint64, err error) {"en": "control message of cmd {cmd} and size {size} failed decoding at field {field}{:err}."}
+ InvalidSetupOption(option, field uint64) {
+ "en": "setup option{:option} failed decoding at field{:field}."}
+ MissingSetupOption(option uint64) {
+ "en": "missing required setup option{:option}."}
+ UnknownSetupOption(option uint64) {"en": "unknown setup option{:option}."}
UnknownMsg(typ byte) {"en":"unknown message type{:typ}."}
UnknownControlMsg(cmd byte) {"en": "unknown control command{:cmd}."}
-
UnexpectedMsg(typ string) {"en": "unexpected message type{:typ}."}
ConnectionClosed() {"en": "connection closed."}
ConnKilledToFreeResources() {"en": "Connection killed to free resources."}
@@ -26,4 +28,9 @@
Recv(src string, err error) {"en": "error reading from {src}{:err}"}
CacheClosed() {"en":"cache is closed"}
CounterOverflow() {"en": "A remote process has sent more data than allowed."}
+ BlessingsFlowClosed() {"en": "The blessings flow was closed."}
+ InvalidChannelBinding() {"en": "The channel binding was invalid."}
+ NoPublicKey() {"en": "No public key was received from the remote end."}
+ DialingNonServer() {"en": "You are attempting to dial on a connection with no remote server."}
+ AcceptorBlessingsMissing() {"en": "The acceptor did not send blessings."}
)
diff --git a/runtime/internal/flow/conn/errors.vdl.go b/runtime/internal/flow/conn/errors.vdl.go
index 2db4372..ed75b03 100644
--- a/runtime/internal/flow/conn/errors.vdl.go
+++ b/runtime/internal/flow/conn/errors.vdl.go
@@ -16,7 +16,10 @@
var (
ErrInvalidMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidMsg", verror.NoRetry, "{1:}{2:} message of type{:3} and size{:4} failed decoding at field{:5}.")
- ErrInvalidControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidControlMsg", verror.NoRetry, "{1:}{2:} control message of cmd{:3} and size{:4} failed decoding at field{:5}.")
+ ErrInvalidControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidControlMsg", verror.NoRetry, "{1:}{2:} control message of cmd {3} and size {4} failed decoding at field {5}{:6}.")
+ ErrInvalidSetupOption = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidSetupOption", verror.NoRetry, "{1:}{2:} setup option{:3} failed decoding at field{:4}.")
+ ErrMissingSetupOption = verror.Register("v.io/x/ref/runtime/internal/flow/conn.MissingSetupOption", verror.NoRetry, "{1:}{2:} missing required setup option{:3}.")
+ ErrUnknownSetupOption = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownSetupOption", verror.NoRetry, "{1:}{2:} unknown setup option{:3}.")
ErrUnknownMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownMsg", verror.NoRetry, "{1:}{2:} unknown message type{:3}.")
ErrUnknownControlMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnknownControlMsg", verror.NoRetry, "{1:}{2:} unknown control command{:3}.")
ErrUnexpectedMsg = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UnexpectedMsg", verror.NoRetry, "{1:}{2:} unexpected message type{:3}.")
@@ -28,11 +31,19 @@
ErrRecv = verror.Register("v.io/x/ref/runtime/internal/flow/conn.Recv", verror.NoRetry, "{1:}{2:} error reading from {3}{:4}")
ErrCacheClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CacheClosed", verror.NoRetry, "{1:}{2:} cache is closed")
ErrCounterOverflow = verror.Register("v.io/x/ref/runtime/internal/flow/conn.CounterOverflow", verror.NoRetry, "{1:}{2:} A remote process has sent more data than allowed.")
+ ErrBlessingsFlowClosed = verror.Register("v.io/x/ref/runtime/internal/flow/conn.BlessingsFlowClosed", verror.NoRetry, "{1:}{2:} The blessings flow was closed.")
+ ErrInvalidChannelBinding = verror.Register("v.io/x/ref/runtime/internal/flow/conn.InvalidChannelBinding", verror.NoRetry, "{1:}{2:} The channel binding was invalid.")
+ ErrNoPublicKey = verror.Register("v.io/x/ref/runtime/internal/flow/conn.NoPublicKey", verror.NoRetry, "{1:}{2:} No public key was received from the remote end.")
+ ErrDialingNonServer = verror.Register("v.io/x/ref/runtime/internal/flow/conn.DialingNonServer", verror.NoRetry, "{1:}{2:} You are attempting to dial on a connection with no remote server.")
+ ErrAcceptorBlessingsMissing = verror.Register("v.io/x/ref/runtime/internal/flow/conn.AcceptorBlessingsMissing", verror.NoRetry, "{1:}{2:} The acceptor did not send blessings.")
)
func init() {
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrInvalidMsg.ID), "{1:}{2:} message of type{:3} and size{:4} failed decoding at field{:5}.")
- i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrInvalidControlMsg.ID), "{1:}{2:} control message of cmd{:3} and size{:4} failed decoding at field{:5}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrInvalidControlMsg.ID), "{1:}{2:} control message of cmd {3} and size {4} failed decoding at field {5}{:6}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrInvalidSetupOption.ID), "{1:}{2:} setup option{:3} failed decoding at field{:4}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrMissingSetupOption.ID), "{1:}{2:} missing required setup option{:3}.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnknownSetupOption.ID), "{1:}{2:} unknown setup option{:3}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnknownMsg.ID), "{1:}{2:} unknown message type{:3}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnknownControlMsg.ID), "{1:}{2:} unknown control command{:3}.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUnexpectedMsg.ID), "{1:}{2:} unexpected message type{:3}.")
@@ -44,16 +55,36 @@
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrRecv.ID), "{1:}{2:} error reading from {3}{:4}")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrCacheClosed.ID), "{1:}{2:} cache is closed")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrCounterOverflow.ID), "{1:}{2:} A remote process has sent more data than allowed.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrBlessingsFlowClosed.ID), "{1:}{2:} The blessings flow was closed.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrInvalidChannelBinding.ID), "{1:}{2:} The channel binding was invalid.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrNoPublicKey.ID), "{1:}{2:} No public key was received by the remote end.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrDialingNonServer.ID), "{1:}{2:} You are attempting to dial on a connection with no remote server.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrAcceptorBlessingsMissing.ID), "{1:}{2:} The acceptor did not send blessings.")
}
// NewErrInvalidMsg returns an error with the ErrInvalidMsg ID.
-func NewErrInvalidMsg(ctx *context.T, typ byte, size int64, field int64) error {
+func NewErrInvalidMsg(ctx *context.T, typ byte, size uint64, field uint64) error {
return verror.New(ErrInvalidMsg, ctx, typ, size, field)
}
// NewErrInvalidControlMsg returns an error with the ErrInvalidControlMsg ID.
-func NewErrInvalidControlMsg(ctx *context.T, cmd byte, size int64, field int64) error {
- return verror.New(ErrInvalidControlMsg, ctx, cmd, size, field)
+func NewErrInvalidControlMsg(ctx *context.T, cmd byte, size uint64, field uint64, err error) error {
+ return verror.New(ErrInvalidControlMsg, ctx, cmd, size, field, err)
+}
+
+// NewErrInvalidSetupOption returns an error with the ErrInvalidSetupOption ID.
+func NewErrInvalidSetupOption(ctx *context.T, option uint64, field uint64) error {
+ return verror.New(ErrInvalidSetupOption, ctx, option, field)
+}
+
+// NewErrMissingSetupOption returns an error with the ErrMissingSetupOption ID.
+func NewErrMissingSetupOption(ctx *context.T, option uint64) error {
+ return verror.New(ErrMissingSetupOption, ctx, option)
+}
+
+// NewErrUnknownSetupOption returns an error with the ErrUnknownSetupOption ID.
+func NewErrUnknownSetupOption(ctx *context.T, option uint64) error {
+ return verror.New(ErrUnknownSetupOption, ctx, option)
}
// NewErrUnknownMsg returns an error with the ErrUnknownMsg ID.
@@ -110,3 +141,28 @@
func NewErrCounterOverflow(ctx *context.T) error {
return verror.New(ErrCounterOverflow, ctx)
}
+
+// NewErrBlessingsFlowClosed returns an error with the ErrBlessingsFlowClosed ID.
+func NewErrBlessingsFlowClosed(ctx *context.T) error {
+ return verror.New(ErrBlessingsFlowClosed, ctx)
+}
+
+// NewErrInvalidChannelBinding returns an error with the ErrInvalidChannelBinding ID.
+func NewErrInvalidChannelBinding(ctx *context.T) error {
+ return verror.New(ErrInvalidChannelBinding, ctx)
+}
+
+// NewErrNoPublicKey returns an error with the ErrNoPublicKey ID.
+func NewErrNoPublicKey(ctx *context.T) error {
+ return verror.New(ErrNoPublicKey, ctx)
+}
+
+// NewErrDialingNonServer returns an error with the ErrDialingNonServer ID.
+func NewErrDialingNonServer(ctx *context.T) error {
+ return verror.New(ErrDialingNonServer, ctx)
+}
+
+// NewErrAcceptorBlessingsMissing returns an error with the ErrAcceptorBlessingsMissing ID.
+func NewErrAcceptorBlessingsMissing(ctx *context.T) error {
+ return verror.New(ErrAcceptorBlessingsMissing, ctx)
+}
diff --git a/runtime/internal/flow/conn/flow.go b/runtime/internal/flow/conn/flow.go
index f75e10f..a8f832f 100644
--- a/runtime/internal/flow/conn/flow.go
+++ b/runtime/internal/flow/conn/flow.go
@@ -13,35 +13,36 @@
)
type flw struct {
- id flowID
- ctx *context.T
- cancel context.CancelFunc
- conn *Conn
- worker *flowcontrol.Worker
- opened bool
- q *readq
- dialerBlessings security.Blessings
- dialerDischarges map[string]security.Discharge
+ id flowID
+ dialed bool
+ ctx *context.T
+ cancel context.CancelFunc
+ conn *Conn
+ worker *flowcontrol.Worker
+ opened bool
+ q *readq
+ bkey, dkey uint64
}
+// Ensure that *flw implements flow.Flow.
var _ flow.Flow = &flw{}
-func (c *Conn) newFlowLocked(ctx *context.T, id flowID) *flw {
+func (c *Conn) newFlowLocked(ctx *context.T, id flowID, bkey, dkey uint64, dialed, preopen bool) *flw {
f := &flw{
id: id,
+ dialed: dialed,
conn: c,
worker: c.fc.NewWorker(flowPriority),
q: newReadQ(),
+ bkey: bkey,
+ dkey: dkey,
+ opened: preopen,
}
f.SetContext(ctx)
c.flows[id] = f
return f
}
-func (f *flw) dialed() bool {
- return f.id%2 == 0
-}
-
// Implement io.Reader.
// Read and ReadMsg should not be called concurrently with themselves
// or each other.
@@ -91,6 +92,8 @@
err := f.conn.mp.writeMsg(f.ctx, &openFlow{
id: f.id,
initialCounters: defaultBufferSize,
+ bkey: f.bkey,
+ dkey: f.dkey,
})
if err != nil {
return 0, false, err
@@ -173,19 +176,27 @@
// LocalBlessings returns the blessings presented by the local end of the flow
// during authentication.
func (f *flw) LocalBlessings() security.Blessings {
- if f.dialed() {
- return f.dialerBlessings
+ if f.dialed {
+ blessings, _, err := f.conn.blessingsFlow.get(f.ctx, f.bkey, f.dkey)
+ if err != nil {
+ f.conn.Close(f.ctx, err)
+ }
+ return blessings
}
- return f.conn.AcceptorBlessings()
+ return f.conn.lBlessings
}
// RemoteBlessings returns the blessings presented by the remote end of the
// flow during authentication.
func (f *flw) RemoteBlessings() security.Blessings {
- if f.dialed() {
- return f.conn.AcceptorBlessings()
+ if !f.dialed {
+ blessings, _, err := f.conn.blessingsFlow.get(f.ctx, f.bkey, f.dkey)
+ if err != nil {
+ f.conn.Close(f.ctx, err)
+ }
+ return blessings
}
- return f.dialerBlessings
+ return f.conn.rBlessings
}
// LocalDischarges returns the discharges presented by the local end of the
@@ -193,10 +204,14 @@
//
// Discharges are organized in a map keyed by the discharge-identifier.
func (f *flw) LocalDischarges() map[string]security.Discharge {
- if f.dialed() {
- return f.dialerDischarges
+ if f.dialed {
+ _, discharges, err := f.conn.blessingsFlow.get(f.ctx, f.bkey, f.dkey)
+ if err != nil {
+ f.conn.Close(f.ctx, err)
+ }
+ return discharges
}
- return f.conn.AcceptorDischarges()
+ return f.conn.lDischarges
}
// RemoteDischarges returns the discharges presented by the remote end of the
@@ -204,10 +219,14 @@
//
// Discharges are organized in a map keyed by the discharge-identifier.
func (f *flw) RemoteDischarges() map[string]security.Discharge {
- if f.dialed() {
- return f.conn.AcceptorDischarges()
+ if !f.dialed {
+ _, discharges, err := f.conn.blessingsFlow.get(f.ctx, f.bkey, f.dkey)
+ if err != nil {
+ f.conn.Close(f.ctx, err)
+ }
+ return discharges
}
- return f.dialerDischarges
+ return f.conn.rDischarges
}
// Conn returns the connection the flow is multiplexed on.
@@ -229,7 +248,8 @@
func (f *flw) close(ctx *context.T, err error) {
f.q.close(ctx)
f.cancel()
- if verror.ErrorID(err) != ErrFlowClosedRemotely.ID {
+ if eid := verror.ErrorID(err); eid != ErrFlowClosedRemotely.ID &&
+ eid != ErrConnectionClosed.ID {
// We want to try to send this message even if ctx is already canceled.
ctx, cancel := context.WithRootCancel(ctx)
err := f.worker.Run(ctx, func(tokens int) (int, bool, error) {
diff --git a/runtime/internal/flow/conn/message.go b/runtime/internal/flow/conn/message.go
index 6cea426..7202223 100644
--- a/runtime/internal/flow/conn/message.go
+++ b/runtime/internal/flow/conn/message.go
@@ -5,9 +5,12 @@
package conn
import (
+ "v.io/v23"
"v.io/v23/context"
"v.io/v23/naming"
"v.io/v23/rpc/version"
+ "v.io/v23/security"
+ "v.io/x/ref/runtime/internal/rpc/stream/crypto"
)
// TODO(mattr): Link to protocol doc.
@@ -30,6 +33,7 @@
invalidCmd = iota
setupCmd
tearDownCmd
+ authCmd
openFlowCmd
releaseCmd
)
@@ -45,7 +49,6 @@
// data flags.
const (
closeFlag = 1 << iota
- metadataFlag
)
// random consts.
@@ -57,15 +60,48 @@
// and encryption options for connection.
type setup struct {
versions version.RPCVersionRange
- PeerNaClPublicKey *[32]byte
- PeerRemoteEndpoint naming.Endpoint
- PeerLocalEndpoint naming.Endpoint
+ peerNaClPublicKey *[32]byte
+ peerRemoteEndpoint naming.Endpoint
+ peerLocalEndpoint naming.Endpoint
+}
+
+func writeSetupOption(option uint64, payload, buf []byte) []byte {
+ buf = writeVarUint64(option, buf)
+ buf = writeVarUint64(uint64(len(payload)), buf)
+ return append(buf, payload...)
+}
+func readSetupOption(ctx *context.T, orig []byte) (
+ option uint64, payload, data []byte, err error) {
+ var valid bool
+ if option, data, valid = readVarUint64(ctx, orig); !valid {
+ err = NewErrInvalidSetupOption(ctx, invalidOption, 0)
+ return
+ }
+ var size uint64
+ if size, data, valid = readVarUint64(ctx, data); !valid || uint64(len(data)) < size {
+ err = NewErrInvalidSetupOption(ctx, option, 1)
+ return
+ }
+ payload, data = data[:size], data[size:]
+ return
}
func (m *setup) write(ctx *context.T, p *messagePipe) error {
p.controlBuf = writeVarUint64(uint64(m.versions.Min), p.controlBuf[:0])
p.controlBuf = writeVarUint64(uint64(m.versions.Max), p.controlBuf)
- return p.write([][]byte{{controlType}}, [][]byte{{setupCmd}, p.controlBuf})
+ if m.peerNaClPublicKey != nil {
+ p.controlBuf = writeSetupOption(peerNaClPublicKeyOption,
+ m.peerNaClPublicKey[:], p.controlBuf)
+ }
+ if m.peerRemoteEndpoint != nil {
+ p.controlBuf = writeSetupOption(peerRemoteEndpointOption,
+ []byte(m.peerRemoteEndpoint.String()), p.controlBuf)
+ }
+ if m.peerLocalEndpoint != nil {
+ p.controlBuf = writeSetupOption(peerLocalEndpointOption,
+ []byte(m.peerLocalEndpoint.String()), p.controlBuf)
+ }
+ return p.write(ctx, [][]byte{{controlType, setupCmd}, p.controlBuf})
}
func (m *setup) read(ctx *context.T, orig []byte) error {
var (
@@ -74,13 +110,37 @@
v uint64
)
if v, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidControlMsg(ctx, setupCmd, int64(len(orig)), 0)
+ return NewErrInvalidControlMsg(ctx, setupCmd, uint64(len(orig)), 0, nil)
}
m.versions.Min = version.RPCVersion(v)
if v, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidControlMsg(ctx, setupCmd, int64(len(orig)), 1)
+ return NewErrInvalidControlMsg(ctx, setupCmd, uint64(len(orig)), 1, nil)
}
m.versions.Max = version.RPCVersion(v)
+ for field := uint64(2); len(data) > 0; field++ {
+ var (
+ payload []byte
+ option uint64
+ err error
+ )
+ if option, payload, data, err = readSetupOption(ctx, data); err != nil {
+ return NewErrInvalidControlMsg(ctx, setupCmd, uint64(len(orig)), field, err)
+ }
+ switch option {
+ case peerNaClPublicKeyOption:
+ m.peerNaClPublicKey = new([32]byte)
+ copy(m.peerNaClPublicKey[:], payload)
+ case peerRemoteEndpointOption:
+ m.peerRemoteEndpoint, err = v23.NewEndpoint(string(payload))
+ case peerLocalEndpointOption:
+ m.peerLocalEndpoint, err = v23.NewEndpoint(string(payload))
+ default:
+ err = NewErrUnknownSetupOption(ctx, option)
+ }
+ if err != nil {
+ return NewErrInvalidControlMsg(ctx, setupCmd, uint64(len(orig)), field, err)
+ }
+ }
return nil
}
@@ -90,7 +150,7 @@
}
func (m *tearDown) write(ctx *context.T, p *messagePipe) error {
- return p.write([][]byte{{controlType}}, [][]byte{{tearDownCmd}, []byte(m.Message)})
+ return p.write(ctx, [][]byte{{controlType, tearDownCmd}, []byte(m.Message)})
}
func (m *tearDown) read(ctx *context.T, data []byte) error {
if len(data) > 0 {
@@ -99,16 +159,89 @@
return nil
}
+// auth is used to complete the auth handshake.
+type auth struct {
+ bkey, dkey uint64
+ channelBinding security.Signature
+ publicKey security.PublicKey
+}
+
+func (m *auth) write(ctx *context.T, p *messagePipe) error {
+ p.controlBuf = writeVarUint64(m.bkey, p.controlBuf[:0])
+ p.controlBuf = writeVarUint64(m.dkey, p.controlBuf)
+ s := m.channelBinding
+ p.controlBuf = writeVarUint64(uint64(len(s.Purpose)), p.controlBuf)
+ p.controlBuf = append(p.controlBuf, s.Purpose...)
+ p.controlBuf = writeVarUint64(uint64(len(s.Hash)), p.controlBuf)
+ p.controlBuf = append(p.controlBuf, []byte(s.Hash)...)
+ p.controlBuf = writeVarUint64(uint64(len(s.R)), p.controlBuf)
+ p.controlBuf = append(p.controlBuf, s.R...)
+ p.controlBuf = writeVarUint64(uint64(len(s.S)), p.controlBuf)
+ p.controlBuf = append(p.controlBuf, s.S...)
+ if m.publicKey != nil {
+ pk, err := m.publicKey.MarshalBinary()
+ if err != nil {
+ return err
+ }
+ p.controlBuf = append(p.controlBuf, pk...)
+ }
+ return p.write(ctx, [][]byte{{controlType, authCmd}, p.controlBuf})
+}
+func (m *auth) read(ctx *context.T, orig []byte) error {
+ var data []byte
+ var valid bool
+ var l uint64
+ if m.bkey, data, valid = readVarUint64(ctx, orig); !valid {
+ return NewErrInvalidControlMsg(ctx, authCmd, uint64(len(orig)), 0, nil)
+ }
+ if m.dkey, data, valid = readVarUint64(ctx, data); !valid {
+ return NewErrInvalidControlMsg(ctx, authCmd, uint64(len(orig)), 1, nil)
+ }
+ if l, data, valid = readVarUint64(ctx, data); !valid || uint64(len(data)) < l {
+ return NewErrInvalidControlMsg(ctx, authCmd, uint64(len(orig)), 2, nil)
+ }
+ if l > 0 {
+ m.channelBinding.Purpose, data = data[:l], data[l:]
+ }
+ if l, data, valid = readVarUint64(ctx, data); !valid || uint64(len(data)) < l {
+ return NewErrInvalidControlMsg(ctx, authCmd, uint64(len(orig)), 3, nil)
+ }
+ if l > 0 {
+ m.channelBinding.Hash, data = security.Hash(data[:l]), data[l:]
+ }
+ if l, data, valid = readVarUint64(ctx, data); !valid || uint64(len(data)) < l {
+ return NewErrInvalidControlMsg(ctx, authCmd, uint64(len(orig)), 4, nil)
+ }
+ if l > 0 {
+ m.channelBinding.R, data = data[:l], data[l:]
+ }
+ if l, data, valid = readVarUint64(ctx, data); !valid || uint64(len(data)) < l {
+ return NewErrInvalidControlMsg(ctx, authCmd, uint64(len(orig)), 5, nil)
+ }
+ if l > 0 {
+ m.channelBinding.S, data = data[:l], data[l:]
+ }
+ if len(data) > 0 {
+ var err error
+ m.publicKey, err = security.UnmarshalPublicKey(data)
+ return err
+ }
+ return nil
+}
+
// openFlow is sent at the beginning of every new flow.
type openFlow struct {
id flowID
initialCounters uint64
+ bkey, dkey uint64
}
func (m *openFlow) write(ctx *context.T, p *messagePipe) error {
p.controlBuf = writeVarUint64(uint64(m.id), p.controlBuf[:0])
p.controlBuf = writeVarUint64(m.initialCounters, p.controlBuf)
- return p.write([][]byte{{controlType}}, [][]byte{{openFlowCmd}, p.controlBuf})
+ p.controlBuf = writeVarUint64(m.bkey, p.controlBuf)
+ p.controlBuf = writeVarUint64(m.dkey, p.controlBuf)
+ return p.write(ctx, [][]byte{{controlType, openFlowCmd}, p.controlBuf})
}
func (m *openFlow) read(ctx *context.T, orig []byte) error {
var (
@@ -117,11 +250,17 @@
v uint64
)
if v, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidControlMsg(ctx, openFlowCmd, int64(len(orig)), 0)
+ return NewErrInvalidControlMsg(ctx, openFlowCmd, uint64(len(orig)), 0, nil)
}
m.id = flowID(v)
if m.initialCounters, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidControlMsg(ctx, openFlowCmd, int64(len(orig)), 1)
+ return NewErrInvalidControlMsg(ctx, openFlowCmd, uint64(len(orig)), 1, nil)
+ }
+ if m.bkey, data, valid = readVarUint64(ctx, data); !valid {
+ return NewErrInvalidControlMsg(ctx, openFlowCmd, uint64(len(orig)), 2, nil)
+ }
+ if m.dkey, data, valid = readVarUint64(ctx, data); !valid {
+ return NewErrInvalidControlMsg(ctx, openFlowCmd, uint64(len(orig)), 3, nil)
}
return nil
}
@@ -138,14 +277,14 @@
p.controlBuf = writeVarUint64(uint64(fid), p.controlBuf)
p.controlBuf = writeVarUint64(val, p.controlBuf)
}
- return p.write([][]byte{{controlType}}, [][]byte{{releaseCmd}, p.controlBuf})
+ return p.write(ctx, [][]byte{{controlType, releaseCmd}, p.controlBuf})
}
func (m *release) read(ctx *context.T, orig []byte) error {
var (
data = orig
valid bool
fid, val uint64
- n int64
+ n uint64
)
if len(data) == 0 {
return nil
@@ -153,10 +292,10 @@
m.counters = map[flowID]uint64{}
for len(data) > 0 {
if fid, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidControlMsg(ctx, releaseCmd, int64(len(orig)), n)
+ return NewErrInvalidControlMsg(ctx, releaseCmd, uint64(len(orig)), n, nil)
}
if val, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidControlMsg(ctx, releaseCmd, int64(len(orig)), n+1)
+ return NewErrInvalidControlMsg(ctx, releaseCmd, uint64(len(orig)), n+1, nil)
}
m.counters[flowID(fid)] = val
n += 2
@@ -172,10 +311,9 @@
}
func (m *data) write(ctx *context.T, p *messagePipe) error {
- p.dataBuf = writeVarUint64(uint64(m.id), p.dataBuf[:0])
- p.dataBuf = writeVarUint64(m.flags, p.dataBuf)
- encrypted := append([][]byte{p.dataBuf}, m.payload...)
- return p.write([][]byte{{dataType}}, encrypted)
+ p.controlBuf = writeVarUint64(uint64(m.id), p.controlBuf[:0])
+ p.controlBuf = writeVarUint64(m.flags, p.controlBuf)
+ return p.write(ctx, append([][]byte{{dataType}, p.controlBuf}, m.payload...))
}
func (m *data) read(ctx *context.T, orig []byte) error {
var (
@@ -184,11 +322,11 @@
v uint64
)
if v, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidMsg(ctx, dataType, int64(len(orig)), 0)
+ return NewErrInvalidMsg(ctx, dataType, uint64(len(orig)), 0)
}
m.id = flowID(v)
if m.flags, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidMsg(ctx, dataType, int64(len(orig)), 1)
+ return NewErrInvalidMsg(ctx, dataType, uint64(len(orig)), 1)
}
if len(data) > 0 {
m.payload = [][]byte{data}
@@ -204,16 +342,13 @@
}
func (m *unencryptedData) write(ctx *context.T, p *messagePipe) error {
- p.dataBuf = writeVarUint64(uint64(m.id), p.dataBuf[:0])
- p.dataBuf = writeVarUint64(m.flags, p.dataBuf)
- // re-use the controlBuf for the data size.
- size := uint64(0)
- for _, b := range m.payload {
- size += uint64(len(b))
+ p.controlBuf = writeVarUint64(uint64(m.id), p.controlBuf[:0])
+ p.controlBuf = writeVarUint64(m.flags, p.controlBuf)
+ if err := p.write(ctx, [][]byte{{unencryptedDataType}, p.controlBuf}); err != nil {
+ return err
}
- p.controlBuf = writeVarUint64(size, p.controlBuf[:0])
- unencrypted := append([][]byte{[]byte{unencryptedDataType}, p.controlBuf}, m.payload...)
- return p.write(unencrypted, [][]byte{p.dataBuf})
+ _, err := p.rw.WriteMsg(m.payload...)
+ return err
}
func (m *unencryptedData) read(ctx *context.T, orig []byte) error {
var (
@@ -222,49 +357,67 @@
v uint64
)
if v, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidMsg(ctx, unencryptedDataType, int64(len(orig)), 0)
- }
- plen := int(v)
- if plen > len(data) {
- return NewErrInvalidMsg(ctx, unencryptedDataType, int64(len(orig)), 1)
- }
- if plen > 0 {
- m.payload, data = [][]byte{data[:plen]}, data[plen:]
- }
- if v, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidMsg(ctx, unencryptedDataType, int64(len(orig)), 2)
+ return NewErrInvalidMsg(ctx, unencryptedDataType, uint64(len(orig)), 2)
}
m.id = flowID(v)
if m.flags, data, valid = readVarUint64(ctx, data); !valid {
- return NewErrInvalidMsg(ctx, unencryptedDataType, int64(len(orig)), 3)
+ return NewErrInvalidMsg(ctx, unencryptedDataType, uint64(len(orig)), 3)
}
return nil
}
+// TODO(mattr): Consider cleaning up the ControlCipher library to
+// eliminate extraneous functionality and reduce copying.
type messagePipe struct {
rw MsgReadWriteCloser
+ cipher crypto.ControlCipher
controlBuf []byte
- dataBuf []byte
- outBuf [][]byte
+ encBuf []byte
}
func newMessagePipe(rw MsgReadWriteCloser) *messagePipe {
return &messagePipe{
rw: rw,
controlBuf: make([]byte, 256),
- dataBuf: make([]byte, 2*maxVarUint64Size),
- outBuf: make([][]byte, 5),
+ encBuf: make([]byte, mtu),
+ cipher: &crypto.NullControlCipher{},
}
}
+func (p *messagePipe) setupEncryption(ctx *context.T, pk, sk, opk *[32]byte) []byte {
+ p.cipher = crypto.NewControlCipherRPC11(
+ (*crypto.BoxKey)(pk),
+ (*crypto.BoxKey)(sk),
+ (*crypto.BoxKey)(opk))
+ return p.cipher.ChannelBinding()
+}
+
func (p *messagePipe) close() error {
return p.rw.Close()
}
-func (p *messagePipe) write(unencrypted [][]byte, encrypted [][]byte) error {
- p.outBuf = append(p.outBuf[:0], unencrypted...)
- p.outBuf = append(p.outBuf, encrypted...)
- _, err := p.rw.WriteMsg(p.outBuf...)
+func (p *messagePipe) write(ctx *context.T, encrypted [][]byte) error {
+ // TODO(mattr): Because of the API of the underlying crypto library,
+ // an enormous amount of copying happens here.
+ // TODO(mattr): We allocate many buffers here to hold potentially
+ // many copies of the data. The maximum memory usage per Conn is probably
+ // quite high. We should try to reduce it.
+ needed := p.cipher.MACSize()
+ for _, b := range encrypted {
+ needed += len(b)
+ }
+ if cap(p.encBuf) < needed {
+ p.encBuf = make([]byte, needed)
+ }
+ p.encBuf = p.encBuf[:0]
+ for _, b := range encrypted {
+ p.encBuf = append(p.encBuf, b...)
+ }
+ p.encBuf = p.encBuf[:needed]
+ if err := p.cipher.Seal(p.encBuf); err != nil {
+ return err
+ }
+ _, err := p.rw.WriteMsg(p.encBuf)
return err
}
@@ -274,39 +427,56 @@
func (p *messagePipe) readMsg(ctx *context.T) (message, error) {
msg, err := p.rw.ReadMsg()
- if len(msg) == 0 {
- return nil, NewErrInvalidMsg(ctx, invalidType, 0, 0)
- }
if err != nil {
return nil, err
}
+ minSize := 2 + p.cipher.MACSize()
+ if len(msg) < minSize || !p.cipher.Open(msg) {
+ return nil, NewErrInvalidMsg(ctx, invalidType, 0, 0)
+ }
+ logmsg := msg
+ if len(msg) > 128 {
+ logmsg = logmsg[:128]
+ }
+ msgType, msg := msg[0], msg[1:len(msg)-p.cipher.MACSize()]
var m message
- switch msg[0] {
+ switch msgType {
case controlType:
- if len(msg) == 1 {
- return nil, NewErrInvalidControlMsg(ctx, invalidCmd, 0, 1)
- }
- msg = msg[1:]
- switch msg[0] {
+ var msgCmd byte
+ msgCmd, msg = msg[0], msg[1:]
+ switch msgCmd {
case setupCmd:
m = &setup{}
case tearDownCmd:
m = &tearDown{}
+ case authCmd:
+ m = &auth{}
case openFlowCmd:
m = &openFlow{}
case releaseCmd:
m = &release{}
default:
- return nil, NewErrUnknownControlMsg(ctx, msg[0])
+ return nil, NewErrUnknownControlMsg(ctx, msgCmd)
}
case dataType:
m = &data{}
case unencryptedDataType:
- m = &unencryptedData{}
+ ud := &unencryptedData{}
+ payload, err := p.rw.ReadMsg()
+ if err != nil {
+ return nil, err
+ }
+ if len(payload) > 0 {
+ ud.payload = [][]byte{payload}
+ }
+ m = ud
default:
- return nil, NewErrUnknownMsg(ctx, msg[0])
+ return nil, NewErrUnknownMsg(ctx, msgType)
}
- return m, m.read(ctx, msg[1:])
+ if err = m.read(ctx, msg); err == nil {
+ ctx.VI(2).Infof("Read low-level message: %#v", m)
+ }
+ return m, err
}
func readVarUint64(ctx *context.T, data []byte) (uint64, []byte, bool) {
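
Every message above follows the same framing pattern: a type byte (plus a command byte for control messages), a run of variable-length unsigned integers, optional trailing bytes, and, once setupEncryption has run, a MAC appended by cipher.Seal. The sketch below illustrates just the varint round trip for an openFlow-shaped message using the standard library's encoding/binary; the package's own writeVarUint64/readVarUint64 may use a different variable-length encoding, so treat this as illustrative only.

package main

import (
	"encoding/binary"
	"fmt"
)

// encodeOpenFlow mirrors the shape of openFlow.write above: append the
// varint-encoded id, initialCounters, bkey and dkey to buf.
// NOTE: encoding/binary's uvarint format is used purely for illustration.
func encodeOpenFlow(buf []byte, id, counters, bkey, dkey uint64) []byte {
	for _, v := range []uint64{id, counters, bkey, dkey} {
		buf = binary.AppendUvarint(buf, v)
	}
	return buf
}

// decodeOpenFlow mirrors openFlow.read: consume the four fields in order and
// report which field was malformed (the index passed to the error above).
func decodeOpenFlow(data []byte) ([4]uint64, error) {
	var fields [4]uint64
	for i := range fields {
		v, n := binary.Uvarint(data)
		if n <= 0 {
			return fields, fmt.Errorf("invalid openFlow message at field %d", i)
		}
		fields[i], data = v, data[n:]
	}
	return fields, nil
}

func main() {
	frame := encodeOpenFlow(nil, 23, 1<<20, 42, 55)
	fmt.Println(decodeOpenFlow(frame)) // [23 1048576 42 55] <nil>
}
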
diff --git a/runtime/internal/flow/conn/message_test.go b/runtime/internal/flow/conn/message_test.go
index 2414c1c..81ceb64 100644
--- a/runtime/internal/flow/conn/message_test.go
+++ b/runtime/internal/flow/conn/message_test.go
@@ -48,9 +48,7 @@
}
}
-func testMessages(t *testing.T, cases []message) {
- ctx, shutdown := v23.Init()
- defer shutdown()
+func testMessages(t *testing.T, ctx *context.T, cases []message) {
w, r, _ := newMRWPair(ctx)
wp, rp := newMessagePipe(w), newMessagePipe(r)
for _, want := range cases {
@@ -73,28 +71,69 @@
}
func TestSetup(t *testing.T) {
- testMessages(t, []message{
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ ep1, err := v23.NewEndpoint(
+ "@5@tcp@foo.com:1234@00112233445566778899aabbccddeeff@m@v.io/foo")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ep2, err := v23.NewEndpoint(
+ "@5@tcp@bar.com:1234@00112233445566778899aabbccddeeff@m@v.io/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ testMessages(t, ctx, []message{
&setup{versions: version.RPCVersionRange{Min: 3, Max: 5}},
+ &setup{
+ versions: version.RPCVersionRange{Min: 3, Max: 5},
+ peerNaClPublicKey: &[32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+ 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
+ peerRemoteEndpoint: ep1,
+ peerLocalEndpoint: ep2,
+ },
&setup{},
})
}
func TestTearDown(t *testing.T) {
- testMessages(t, []message{
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ testMessages(t, ctx, []message{
&tearDown{Message: "foobar"},
&tearDown{},
})
}
+func TestAuth(t *testing.T) {
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ p := v23.GetPrincipal(ctx)
+ sig, err := p.Sign([]byte("message"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ testMessages(t, ctx, []message{
+ &auth{bkey: 1, dkey: 5, channelBinding: sig, publicKey: p.PublicKey()},
+ &auth{bkey: 1, dkey: 5, channelBinding: sig},
+ &auth{channelBinding: sig, publicKey: p.PublicKey()},
+ &auth{},
+ })
+}
+
func TestOpenFlow(t *testing.T) {
- testMessages(t, []message{
- &openFlow{id: 23, initialCounters: 1 << 20},
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ testMessages(t, ctx, []message{
+ &openFlow{id: 23, initialCounters: 1 << 20, bkey: 42, dkey: 55},
&openFlow{},
})
}
func TestAddReceiveBuffers(t *testing.T) {
- testMessages(t, []message{
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ testMessages(t, ctx, []message{
&release{},
&release{counters: map[flowID]uint64{
4: 233,
@@ -104,14 +143,18 @@
}
func TestData(t *testing.T) {
- testMessages(t, []message{
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ testMessages(t, ctx, []message{
&data{id: 1123, flags: 232, payload: [][]byte{[]byte("fake payload")}},
&data{},
})
}
func TestUnencryptedData(t *testing.T) {
- testMessages(t, []message{
+ ctx, shutdown := v23.Init()
+ defer shutdown()
+ testMessages(t, ctx, []message{
&unencryptedData{id: 1123, flags: 232, payload: [][]byte{[]byte("fake payload")}},
&unencryptedData{},
})
diff --git a/runtime/internal/flow/conn/types.vdl b/runtime/internal/flow/conn/types.vdl
new file mode 100644
index 0000000..de776fa
--- /dev/null
+++ b/runtime/internal/flow/conn/types.vdl
@@ -0,0 +1,18 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package conn
+
+import "v.io/v23/security"
+
+// Blessings is used to transport blessings and their discharges
+// between the two ends of a Conn. Since these objects can be large,
+// we try not to send them more than once; whenever we send new
+// blessings or discharges, we associate them with an integer key
+// (BKey and DKey) and thereafter refer to them by that key.
+type Blessings struct {
+ Blessings security.WireBlessings
+ Discharges []security.WireDischarge
+ BKey, DKey uint64
+}
\ No newline at end of file
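
A rough sketch of the bookkeeping this type implies on the sending side, assuming a hypothetical cache (not part of this change) that hands out a fresh BKey the first time a set of blessings is transmitted and reuses the key afterwards:

package main

import "fmt"

// blessingsCache is a hypothetical sketch of the bookkeeping implied by the
// Blessings wire type above: the first time a set of blessings is sent it is
// assigned a fresh BKey and transmitted in full; later openFlow messages can
// carry only the key. The real Conn's cache is not part of this diff.
type blessingsCache struct {
	nextKey uint64
	keys    map[string]uint64 // keyed by a stable encoding of the blessings
}

// keyFor returns the existing key for enc, or assigns a new one and reports
// that the full blessings still need to be sent.
func (c *blessingsCache) keyFor(enc string) (bkey uint64, mustSend bool) {
	if k, ok := c.keys[enc]; ok {
		return k, false
	}
	c.nextKey++
	c.keys[enc] = c.nextKey
	return c.nextKey, true
}

func main() {
	c := &blessingsCache{keys: map[string]uint64{}}
	k1, send1 := c.keyFor("alice/phone") // hypothetical encoded blessings
	k2, send2 := c.keyFor("alice/phone")
	fmt.Println(k1, send1, k2, send2) // 1 true 1 false
}
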
diff --git a/runtime/internal/flow/conn/types.vdl.go b/runtime/internal/flow/conn/types.vdl.go
new file mode 100644
index 0000000..dc4bcd9
--- /dev/null
+++ b/runtime/internal/flow/conn/types.vdl.go
@@ -0,0 +1,37 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file was auto-generated by the vanadium vdl tool.
+// Source: types.vdl
+
+package conn
+
+import (
+ // VDL system imports
+ "v.io/v23/vdl"
+
+ // VDL user imports
+ "v.io/v23/security"
+)
+
+// Blessings is used to transport blessings and their discharges
+// between the two ends of a Conn. Since these objects can be large,
+// we try not to send them more than once; whenever we send new
+// blessings or discharges, we associate them with an integer key
+// (BKey and DKey) and thereafter refer to them by that key.
+type Blessings struct {
+ Blessings security.Blessings
+ Discharges []security.Discharge
+ BKey uint64
+ DKey uint64
+}
+
+func (Blessings) __VDLReflect(struct {
+ Name string `vdl:"v.io/x/ref/runtime/internal/flow/conn.Blessings"`
+}) {
+}
+
+func init() {
+ vdl.Register((*Blessings)(nil))
+}
diff --git a/runtime/internal/flow/conn/util_test.go b/runtime/internal/flow/conn/util_test.go
index 5694830..140515a 100644
--- a/runtime/internal/flow/conn/util_test.go
+++ b/runtime/internal/flow/conn/util_test.go
@@ -12,8 +12,10 @@
"v.io/v23"
"v.io/v23/context"
"v.io/v23/flow"
+ "v.io/v23/naming"
"v.io/v23/rpc/version"
"v.io/v23/security"
+ "v.io/x/ref/internal/logger"
)
type wire struct {
@@ -56,6 +58,11 @@
for _, d := range data {
buf = append(buf, d...)
}
+ logbuf := buf
+ if len(buf) > 128 {
+ logbuf = buf[:128]
+ }
+ logger.Global().VI(2).Infof("Writing %d bytes to the wire: %#v", len(buf), logbuf)
defer f.wire.mu.Unlock()
f.wire.mu.Lock()
for f.peer.in != nil && !f.wire.closed {
@@ -79,6 +86,11 @@
}
buf, f.in = f.in, nil
f.wire.c.Broadcast()
+ logbuf := buf
+ if len(buf) > 128 {
+ logbuf = buf[:128]
+ }
+ logger.Global().VI(2).Infof("Reading %d bytes from the wire: %#v", len(buf), logbuf)
return buf, nil
}
func (f *mRW) Close() error {
@@ -103,15 +115,31 @@
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
- d, err := NewDialed(dctx, dmrw, ep, ep, versions, fh(dflows), nil)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- a, err := NewAccepted(actx, amrw, ep, security.Blessings{}, versions, fh(aflows))
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- return d, a, w
+ dch := make(chan *Conn)
+ ach := make(chan *Conn)
+ go func() {
+ var handler FlowHandler
+ if dflows != nil {
+ handler = fh(dflows)
+ }
+ d, err := NewDialed(dctx, dmrw, ep, ep, versions, handler)
+ if err != nil {
+ panic(err)
+ }
+ dch <- d
+ }()
+ go func() {
+ var handler FlowHandler
+ if aflows != nil {
+ handler = fh(aflows)
+ }
+ a, err := NewAccepted(actx, amrw, ep, versions, handler)
+ if err != nil {
+ panic(err)
+ }
+ ach <- a
+ }()
+ return <-dch, <-ach, w
}
func setupFlow(t *testing.T, dctx, actx *context.T, dialFromDialer bool) (dialed flow.Flow, accepted <-chan flow.Flow) {
@@ -129,9 +157,29 @@
dialed = make([]flow.Flow, n)
for i := 0; i < n; i++ {
var err error
- if dialed[i], err = d.Dial(dctx); err != nil {
+ if dialed[i], err = d.Dial(dctx, testBFP); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
}
return dialed, aflows
}
+
+func testBFP(
+ ctx *context.T,
+ localEndpoint, remoteEndpoint naming.Endpoint,
+ remoteBlessings security.Blessings,
+ remoteDischarges map[string]security.Discharge,
+) (security.Blessings, error) {
+ return v23.GetPrincipal(ctx).BlessingStore().Default(), nil
+}
+
+func makeBFP(in security.Blessings) flow.BlessingsForPeer {
+ return func(
+ ctx *context.T,
+ localEndpoint, remoteEndpoint naming.Endpoint,
+ remoteBlessings security.Blessings,
+ remoteDischarges map[string]security.Discharge,
+ ) (security.Blessings, error) {
+ return in, nil
+ }
+}
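
For completeness, a hedged sketch of a variant of makeBFP that a test might use when it wants the dialer to fall back to its default blessings. It relies only on the imports already present in this file and on security.Blessings.IsZero; the helper name makeBFPWithDefault is hypothetical.

// makeBFPWithDefault is a hypothetical variant of makeBFP above: it returns
// the supplied blessings when they are set and otherwise falls back to the
// dialing principal's default blessings, matching testBFP's behavior.
func makeBFPWithDefault(in security.Blessings) flow.BlessingsForPeer {
	return func(
		ctx *context.T,
		localEndpoint, remoteEndpoint naming.Endpoint,
		remoteBlessings security.Blessings,
		remoteDischarges map[string]security.Discharge,
	) (security.Blessings, error) {
		if in.IsZero() {
			return v23.GetPrincipal(ctx).BlessingStore().Default(), nil
		}
		return in, nil
	}
}
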
diff --git a/runtime/internal/flow/manager/manager.go b/runtime/internal/flow/manager/manager.go
index 4bdca655..ac26aa1 100644
--- a/runtime/internal/flow/manager/manager.go
+++ b/runtime/internal/flow/manager/manager.go
@@ -11,7 +11,6 @@
"syscall"
"time"
- "v.io/v23"
"v.io/v23/context"
"v.io/v23/flow"
"v.io/v23/naming"
@@ -90,7 +89,6 @@
ctx,
&framer{ReadWriteCloser: netConn},
local,
- v23.GetPrincipal(ctx).BlessingStore().Default(),
version.Supported,
&flowHandler{q: m.q, closed: m.closed},
)
@@ -194,6 +192,9 @@
if err != nil {
return nil, flow.NewErrDialFailed(ctx, err)
}
+ // TODO(mattr): We should only pass a flowHandler to NewDialed if there
+ // is a server attached to this flow manager. Perhaps we can signal
+ // "serving flow manager" by passing a 0 RID to non-serving flow managers?
c, err = conn.NewDialed(
ctx,
&framer{ReadWriteCloser: netConn}, // TODO(suharshs): Don't frame if the net.Conn already has framing in its protocol.
@@ -201,7 +202,6 @@
remote,
version.Supported,
&flowHandler{q: m.q, closed: m.closed},
- fn,
)
if err != nil {
return nil, flow.NewErrDialFailed(ctx, err)
@@ -210,7 +210,7 @@
return nil, flow.NewErrBadState(ctx, err)
}
}
- f, err := c.Dial(ctx)
+ f, err := c.Dial(ctx, fn)
if err != nil {
return nil, flow.NewErrDialFailed(ctx, err)
}
diff --git a/runtime/internal/flow/manager/manager_test.go b/runtime/internal/flow/manager/manager_test.go
index 9fe2d50..77d8c57 100644
--- a/runtime/internal/flow/manager/manager_test.go
+++ b/runtime/internal/flow/manager/manager_test.go
@@ -36,7 +36,14 @@
t.Fatal(err)
}
- bFn := func(*context.T, security.Call) (security.Blessings, error) { return p.BlessingStore().Default(), nil }
+ bFn := func(
+ ctx *context.T,
+ localEndpoint, remoteEndpoint naming.Endpoint,
+ remoteBlessings security.Blessings,
+ remoteDischarges map[string]security.Discharge,
+ ) (security.Blessings, error) {
+ return p.BlessingStore().Default(), nil
+ }
eps := m.ListeningEndpoints()
if len(eps) == 0 {
t.Fatalf("no endpoints listened on")
@@ -70,7 +77,14 @@
t.Fatal(err)
}
- bFn := func(*context.T, security.Call) (security.Blessings, error) { return p.BlessingStore().Default(), nil }
+ bFn := func(
+ ctx *context.T,
+ localEndpoint, remoteEndpoint naming.Endpoint,
+ remoteBlessings security.Blessings,
+ remoteDischarges map[string]security.Discharge,
+ ) (security.Blessings, error) {
+ return p.BlessingStore().Default(), nil
+ }
eps := am.ListeningEndpoints()
if len(eps) == 0 {
t.Fatalf("no endpoints listened on")
diff --git a/runtime/internal/naming/endpoint.go b/runtime/internal/naming/endpoint.go
index f354743..9f32a92 100644
--- a/runtime/internal/naming/endpoint.go
+++ b/runtime/internal/naming/endpoint.go
@@ -20,6 +20,7 @@
separator = "@"
suffix = "@@"
blessingsSeparator = ","
+ routeSeparator = ","
)
var (
@@ -27,6 +28,8 @@
hostportEP = regexp.MustCompile("^(?:\\((.*)\\)@)?([^@]+)$")
)
+// TODO(suharshs): Remove endpoint version 5 after the transition to 6 is complete.
+
// Network is the string returned by naming.Endpoint.Network implementations
// defined in this package.
const Network = "v23"
@@ -36,6 +39,7 @@
Protocol string
Address string
RID naming.RoutingID
+ RouteList []string
Blessings []string
IsMountTable bool
IsLeaf bool
@@ -67,6 +71,8 @@
}
switch version {
+ case 6:
+ err = ep.parseV6(parts)
case 5:
err = ep.parseV5(parts)
default:
@@ -141,10 +147,58 @@
return nil
}
+func (ep *Endpoint) parseV6(parts []string) error {
+ if len(parts) < 6 {
+ return errInvalidEndpointString
+ }
+
+ ep.Protocol = parts[1]
+ if len(ep.Protocol) == 0 {
+ ep.Protocol = naming.UnknownProtocol
+ }
+
+ var ok bool
+ if ep.Address, ok = naming.Unescape(parts[2]); !ok {
+ return fmt.Errorf("invalid address: bad escape %s", parts[2])
+ }
+ if len(ep.Address) == 0 {
+ ep.Address = net.JoinHostPort("", "0")
+ }
+
+ if len(parts[3]) > 0 {
+ ep.RouteList = strings.Split(parts[3], routeSeparator)
+ for i := range ep.RouteList {
+ if ep.RouteList[i], ok = naming.Unescape(ep.RouteList[i]); !ok {
+ return fmt.Errorf("invalid route: bad escape %s", ep.RouteList[i])
+ }
+ }
+ }
+
+ if err := ep.RID.FromString(parts[4]); err != nil {
+ return fmt.Errorf("invalid routing id: %v", err)
+ }
+
+ var err error
+ if ep.IsMountTable, ep.IsLeaf, err = parseMountTableFlag(parts[5]); err != nil {
+ return fmt.Errorf("invalid mount table flag: %v", err)
+ }
+ // Join the remaining parts and re-split, since blessings may contain '@'.
+ if str := strings.Join(parts[6:], separator); len(str) > 0 {
+ ep.Blessings = strings.Split(str, blessingsSeparator)
+ }
+ return nil
+}
+
func (ep *Endpoint) RoutingID() naming.RoutingID {
//nologcall
return ep.RID
}
+
+func (ep *Endpoint) Routes() []string {
+ //nologcall
+ return ep.RouteList
+}
+
func (ep *Endpoint) Network() string {
//nologcall
return Network
@@ -159,8 +213,6 @@
func (ep *Endpoint) VersionedString(version int) string {
// nologcall
switch version {
- default:
- return ep.VersionedString(defaultVersion)
case 5:
mt := "s"
switch {
@@ -172,6 +224,24 @@
blessings := strings.Join(ep.Blessings, blessingsSeparator)
return fmt.Sprintf("@5@%s@%s@%s@%s@%s@@",
ep.Protocol, naming.Escape(ep.Address, "@"), ep.RID, mt, blessings)
+ case 6:
+ mt := "s"
+ switch {
+ case ep.IsLeaf:
+ mt = "l"
+ case ep.IsMountTable:
+ mt = "m"
+ }
+ blessings := strings.Join(ep.Blessings, blessingsSeparator)
+ escaped := make([]string, len(ep.RouteList))
+ for i := range ep.RouteList {
+ escaped[i] = naming.Escape(ep.RouteList[i], routeSeparator)
+ }
+ routes := strings.Join(escaped, routeSeparator)
+ return fmt.Sprintf("@6@%s@%s@%s@%s@%s@%s@@",
+ ep.Protocol, naming.Escape(ep.Address, "@"), routes, ep.RID, mt, blessings)
+ default:
+ return ep.VersionedString(defaultVersion)
}
}
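
The resulting version-6 wire form is @6@protocol@address@route,route,...@routingid@mountflag@blessings@@, with commas inside individual routes escaped (see the v6g test case below). A naive, self-contained sketch of that field layout follows; the real parser also unescapes fields and re-joins the trailing parts because blessings may themselves contain '@'.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Field layout of a version-6 endpoint, per parseV6/VersionedString above.
	// This example has three routes and no blessings; the naive split below
	// does not handle escaping or blessings containing '@'.
	ep := "@6@tcp@batman.com:2345@1,2,3@0000000000000000000000000000ba77@m@@@"
	parts := strings.Split(strings.Trim(ep, "@"), "@")
	fmt.Println("version: ", parts[0])                     // 6
	fmt.Println("protocol:", parts[1])                     // tcp
	fmt.Println("address: ", parts[2])                     // batman.com:2345
	fmt.Println("routes:  ", strings.Split(parts[3], ",")) // [1 2 3]
	fmt.Println("rid:     ", parts[4])                     // 0000000000000000000000000000ba77
	fmt.Println("mt flag: ", parts[5])                     // m
}
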
diff --git a/runtime/internal/naming/endpoint_test.go b/runtime/internal/naming/endpoint_test.go
index f1d6d56..f699440 100644
--- a/runtime/internal/naming/endpoint_test.go
+++ b/runtime/internal/naming/endpoint_test.go
@@ -12,7 +12,7 @@
"v.io/v23/naming"
)
-func TestEndpoint(t *testing.T) {
+func TestEndpointV5(t *testing.T) {
defver := defaultVersion
defer func() {
defaultVersion = defver
@@ -56,7 +56,6 @@
// Blessings that look similar to other parts of the endpoint.
Blessings: []string{"@@", "@s", "@m"},
}
-
testcasesA := []struct {
endpoint naming.Endpoint
address string
@@ -71,7 +70,6 @@
t.Errorf("unexpected address %q, not %q", addr.String(), test.address)
}
}
-
// Test v5 endpoints.
testcasesC := []struct {
Endpoint naming.Endpoint
@@ -85,6 +83,106 @@
{v5e, "@5@tcp@batman.com:2345@0000000000000000000000000000ba77@m@dev.v.io/foo@bar.com,dev.v.io/bar@bar.com/delegate@@", 5},
{v5f, "@5@tcp@batman.com:2345@0000000000000000000000000000ba77@m@@@,@s,@m@@", 5},
}
+ for i, test := range testcasesC {
+ if got, want := test.Endpoint.VersionedString(test.Version), test.String; got != want {
+ t.Errorf("Test %d: Got %q want %q for endpoint (v%d): %#v", i, got, want, test.Version, test.Endpoint)
+ }
+ ep, err := NewEndpoint(test.String)
+ if err != nil {
+ t.Errorf("Test %d: NewEndpoint(%q) failed with %v", i, test.String, err)
+ continue
+ }
+ if !reflect.DeepEqual(ep, test.Endpoint) {
+ t.Errorf("Test %d: Got endpoint %#v, want %#v for string %q", i, ep, test.Endpoint, test.String)
+ }
+ }
+}
+
+func TestEndpoint(t *testing.T) {
+ defver := defaultVersion
+ defer func() {
+ defaultVersion = defver
+ }()
+ defaultVersion = 6
+ v6a := &Endpoint{
+ Protocol: naming.UnknownProtocol,
+ Address: "batman.com:1234",
+ RID: naming.FixedRoutingID(0xdabbad00),
+ IsMountTable: true,
+ }
+ v6b := &Endpoint{
+ Protocol: naming.UnknownProtocol,
+ Address: "batman.com:2345",
+ RID: naming.FixedRoutingID(0xdabbad00),
+ IsMountTable: false,
+ }
+ v6c := &Endpoint{
+ Protocol: "tcp",
+ Address: "batman.com:2345",
+ RID: naming.FixedRoutingID(0x0),
+ IsMountTable: false,
+ }
+ v6d := &Endpoint{
+ Protocol: "ws6",
+ Address: "batman.com:2345",
+ RID: naming.FixedRoutingID(0x0),
+ IsMountTable: false,
+ }
+ v6e := &Endpoint{
+ Protocol: "tcp",
+ Address: "batman.com:2345",
+ RID: naming.FixedRoutingID(0xba77),
+ RouteList: []string{"1"},
+ IsMountTable: true,
+ Blessings: []string{"dev.v.io/foo@bar.com", "dev.v.io/bar@bar.com/delegate"},
+ }
+ v6f := &Endpoint{
+ Protocol: "tcp",
+ Address: "batman.com:2345",
+ RouteList: []string{"1", "2", "3"},
+ RID: naming.FixedRoutingID(0xba77),
+ IsMountTable: true,
+ // Blessings that look similar to other parts of the endpoint.
+ Blessings: []string{"@@", "@s", "@m"},
+ }
+ v6g := &Endpoint{
+ Protocol: "tcp",
+ Address: "batman.com:2345",
+ // Routes that have commas should be escaped correctly
+ RouteList: []string{"a,b", ",ab", "ab,"},
+ RID: naming.FixedRoutingID(0xba77),
+ IsMountTable: true,
+ }
+
+ testcasesA := []struct {
+ endpoint naming.Endpoint
+ address string
+ }{
+ {v6a, "batman.com:1234"},
+ {v6b, "batman.com:2345"},
+ {v6c, "batman.com:2345"},
+ }
+ for _, test := range testcasesA {
+ addr := test.endpoint.Addr()
+ if addr.String() != test.address {
+ t.Errorf("unexpected address %q, not %q", addr.String(), test.address)
+ }
+ }
+
+ // Test v6 endpoints.
+ testcasesC := []struct {
+ Endpoint naming.Endpoint
+ String string
+ Version int
+ }{
+ {v6a, "@6@@batman.com:1234@@000000000000000000000000dabbad00@m@@@", 6},
+ {v6b, "@6@@batman.com:2345@@000000000000000000000000dabbad00@s@@@", 6},
+ {v6c, "@6@tcp@batman.com:2345@@00000000000000000000000000000000@s@@@", 6},
+ {v6d, "@6@ws6@batman.com:2345@@00000000000000000000000000000000@s@@@", 6},
+ {v6e, "@6@tcp@batman.com:2345@1@0000000000000000000000000000ba77@m@dev.v.io/foo@bar.com,dev.v.io/bar@bar.com/delegate@@", 6},
+ {v6f, "@6@tcp@batman.com:2345@1,2,3@0000000000000000000000000000ba77@m@@@,@s,@m@@", 6},
+ {v6g, "@6@tcp@batman.com:2345@a%2Cb,%2Cab,ab%2C@0000000000000000000000000000ba77@m@@@", 6},
+ }
for i, test := range testcasesC {
if got, want := test.Endpoint.VersionedString(test.Version), test.String; got != want {
@@ -131,13 +229,14 @@
defer func() {
defaultVersion = defver
}()
+ defaultVersion = 6
testcases := []endpointTest{
- {"localhost:10", "@5@@localhost:10@00000000000000000000000000000000@m@@@", nil},
- {"localhost:", "@5@@localhost:@00000000000000000000000000000000@m@@@", nil},
+ {"localhost:10", "@6@@localhost:10@@00000000000000000000000000000000@m@@@", nil},
+ {"localhost:", "@6@@localhost:@@00000000000000000000000000000000@m@@@", nil},
{"localhost", "", errInvalidEndpointString},
- {"(dev.v.io/service/mounttabled)@ns.dev.v.io:8101", "@5@@ns.dev.v.io:8101@00000000000000000000000000000000@m@dev.v.io/service/mounttabled@@", nil},
- {"(dev.v.io/users/foo@bar.com)@ns.dev.v.io:8101", "@5@@ns.dev.v.io:8101@00000000000000000000000000000000@m@dev.v.io/users/foo@bar.com@@", nil},
- {"(@1@tcp)@ns.dev.v.io:8101", "@5@@ns.dev.v.io:8101@00000000000000000000000000000000@m@@1@tcp@@", nil},
+ {"(dev.v.io/service/mounttabled)@ns.dev.v.io:8101", "@6@@ns.dev.v.io:8101@@00000000000000000000000000000000@m@dev.v.io/service/mounttabled@@", nil},
+ {"(dev.v.io/users/foo@bar.com)@ns.dev.v.io:8101", "@6@@ns.dev.v.io:8101@@00000000000000000000000000000000@m@dev.v.io/users/foo@bar.com@@", nil},
+ {"(@1@tcp)@ns.dev.v.io:8101", "@6@@ns.dev.v.io:8101@@00000000000000000000000000000000@m@@1@tcp@@", nil},
}
runEndpointTests(t, testcases)
}
diff --git a/runtime/internal/rpc/reserved.go b/runtime/internal/rpc/reserved.go
index 35adc61..e47f13b 100644
--- a/runtime/internal/rpc/reserved.go
+++ b/runtime/internal/rpc/reserved.go
@@ -268,7 +268,7 @@
continue
}
gs := invoker.Globber()
- if gs == nil || (gs.AllGlobber == nil && gs.ChildrenGlobber == nil && gs.AllGlobberX == nil && gs.ChildrenGlobberX == nil) {
+ if gs == nil || (gs.AllGlobber == nil && gs.ChildrenGlobber == nil) {
if state.glob.Len() == 0 {
subcall.Send(naming.GlobReplyEntry{
Value: naming.MountEntry{Name: state.name, IsLeaf: true},
@@ -280,12 +280,6 @@
}
continue
}
- if gs.AllGlobberX != nil {
- gs.AllGlobber = gs.AllGlobberX
- }
- if gs.ChildrenGlobberX != nil {
- gs.ChildrenGlobber = gs.ChildrenGlobberX
- }
if gs.AllGlobber != nil {
ctx.VI(3).Infof("rpc Glob: %q implements AllGlobber", suffix)
send := func(reply naming.GlobReply) error {
diff --git a/runtime/internal/rpc/stream/vc/vc_test.go b/runtime/internal/rpc/stream/vc/vc_test.go
index 7e12aa6..452801a 100644
--- a/runtime/internal/rpc/stream/vc/vc_test.go
+++ b/runtime/internal/rpc/stream/vc/vc_test.go
@@ -716,6 +716,7 @@
func (e endpoint) String() string { return naming.RoutingID(e).String() }
func (e endpoint) Name() string { return naming.JoinAddressName(e.String(), "") }
func (e endpoint) RoutingID() naming.RoutingID { return naming.RoutingID(e) }
+func (e endpoint) Routes() []string { return nil }
func (e endpoint) Addr() net.Addr { return nil }
func (e endpoint) ServesMountTable() bool { return false }
func (e endpoint) ServesLeaf() bool { return false }
diff --git a/services/device/claimable/claimable_v23_test.go b/services/device/claimable/claimable_v23_test.go
new file mode 100644
index 0000000..ce8da3a
--- /dev/null
+++ b/services/device/claimable/claimable_v23_test.go
@@ -0,0 +1,116 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main_test
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ "v.io/v23/security"
+ lsecurity "v.io/x/ref/lib/security"
+ "v.io/x/ref/test/modules"
+ "v.io/x/ref/test/v23tests"
+)
+
+//go:generate v23 test generate
+
+func V23TestClaimableServer(t *v23tests.T) {
+ workdir, err := ioutil.TempDir("", "claimable-test-")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir failed: %v", err)
+ }
+ defer os.RemoveAll(workdir)
+
+ permsDir := filepath.Join(workdir, "perms")
+
+ serverCreds, err := detachedCredentials(t, "server")
+ if err != nil {
+ t.Fatalf("Failed to create server credentials: %v", err)
+ }
+ legitClientCreds, err := t.Shell().NewChildCredentials("legit")
+ if err != nil {
+ t.Fatalf("Failed to create legit credentials: %v", err)
+ }
+ badClientCreds1, err := t.Shell().NewCustomCredentials()
+ if err != nil {
+ t.Fatalf("Failed to create bad credentials: %v", err)
+ }
+ badClientCreds2, err := t.Shell().NewChildCredentials("other-guy")
+ if err != nil {
+ t.Fatalf("Failed to create bad credentials: %v", err)
+ }
+
+ serverBin := t.BuildV23Pkg("v.io/x/ref/services/device/claimable")
+ serverBin = serverBin.WithStartOpts(serverBin.StartOpts().WithCustomCredentials(serverCreds))
+
+ server := serverBin.Start(
+ "--v23.tcp.address=127.0.0.1:0",
+ "--perms-dir="+permsDir,
+ "--blessing-root="+blessingRoots(t, legitClientCreds.Principal()),
+ "--v23.permissions.literal={\"Admin\":{\"In\":[\"root/legit\"]}}",
+ )
+ addr := server.ExpectVar("NAME")
+
+ clientBin := t.BuildV23Pkg("v.io/x/ref/services/device/device")
+
+ testcases := []struct {
+ creds *modules.CustomCredentials
+ success bool
+ permsExist bool
+ }{
+ {badClientCreds1, false, false},
+ {badClientCreds2, false, false},
+ {legitClientCreds, true, true},
+ }
+
+ for _, tc := range testcases {
+ clientBin = clientBin.WithStartOpts(clientBin.StartOpts().WithCustomCredentials(tc.creds))
+ client := clientBin.Start("claim", addr, "my-device")
+ if err := client.Wait(nil, nil); (err == nil) != tc.success {
+ t.Errorf("Unexpected exit value. Expected success=%v, got err=%v", tc.success, err)
+ }
+ if _, err := os.Stat(permsDir); (err == nil) != tc.permsExist {
+ t.Errorf("Unexpected permsDir state. Got %v, expected %v", err == nil, tc.permsExist)
+ }
+ }
+ // Server should exit cleanly after the successful Claim.
+ if err := server.ExpectEOF(); err != nil {
+ t.Errorf("Expected server to exit cleanly, got %v", err)
+ }
+}
+
+func detachedCredentials(t *v23tests.T, name string) (*modules.CustomCredentials, error) {
+ creds, err := t.Shell().NewCustomCredentials()
+ if err != nil {
+ return nil, err
+ }
+ return creds, lsecurity.InitDefaultBlessings(creds.Principal(), name)
+}
+
+func blessingRoots(t *v23tests.T, p security.Principal) string {
+ pk, ok := p.Roots().Dump()["root"]
+ if !ok || len(pk) == 0 {
+ t.Fatalf("Failed to find root blessing")
+ }
+ der, err := pk[0].MarshalBinary()
+ if err != nil {
+ t.Fatalf("MarshalPublicKey failed: %v", err)
+ }
+ rootInfo := struct {
+ Names []string `json:"names"`
+ PublicKey string `json:"publicKey"`
+ }{
+ Names: []string{"root"},
+ PublicKey: base64.URLEncoding.EncodeToString(der),
+ }
+ out, err := json.Marshal(rootInfo)
+ if err != nil {
+ t.Fatalf("json.Marshal failed: %v", err)
+ }
+ return string(out)
+}
diff --git a/services/device/claimable/doc.go b/services/device/claimable/doc.go
new file mode 100644
index 0000000..dcb6397
--- /dev/null
+++ b/services/device/claimable/doc.go
@@ -0,0 +1,74 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file was auto-generated via go generate.
+// DO NOT UPDATE MANUALLY
+
+/*
+Claimable is a server that implements the Claimable interface from
+v.io/v23/services/device. It exits immediately if the device is already claimed.
+Otherwise, it keeps running until a successful Claim() request is received.
+
+It uses -v23.permissions.* to authorize the Claim request.
+
+Usage:
+ claimable [flags]
+
+The claimable flags are:
+ -blessing-root=
+ The blessing root to trust, JSON-encoded, e.g. from
+ https://v.io/auth/blessing-root
+ -perms-dir=
+ The directory where permissions will be stored.
+
+The global flags are:
+ -alsologtostderr=true
+ log to standard error as well as files
+ -log_backtrace_at=:0
+ when logging hits line file:N, emit a stack trace
+ -log_dir=
+ if non-empty, write log files to this directory
+ -logtostderr=false
+ log to standard error instead of files
+ -max_stack_buf_size=4292608
+ max size in bytes of the buffer to use for logging stack traces
+ -metadata=<just specify -metadata to activate>
+ Displays metadata for the program and exits.
+ -stderrthreshold=2
+ logs at or above this threshold go to stderr
+ -v=0
+ log level for V logs
+ -v23.credentials=
+ directory to use for storing security credentials
+ -v23.i18n-catalogue=
+ i18n catalogue files to load, comma separated
+ -v23.namespace.root=[/(dev.v.io/role/vprod/service/mounttabled)@ns.dev.v.io:8101]
+ local namespace root; can be repeated to provide multiple roots
+ -v23.permissions.file=map[]
+ specify a perms file as <name>:<permsfile>
+ -v23.permissions.literal=
+ explicitly specify the runtime perms as a JSON-encoded access.Permissions.
+ Overrides all --v23.permissions.file flags.
+ -v23.proxy=
+ object name of proxy service to use to export services across network
+ boundaries
+ -v23.tcp.address=
+ address to listen on
+ -v23.tcp.protocol=wsh
+ protocol to listen with
+ -v23.vtrace.cache-size=1024
+ The number of vtrace traces to store in memory.
+ -v23.vtrace.collect-regexp=
+ Spans and annotations that match this regular expression will trigger trace
+ collection.
+ -v23.vtrace.dump-on-shutdown=true
+ If true, dump all stored traces on runtime shutdown.
+ -v23.vtrace.sample-rate=0
+ Rate (from 0.0 to 1.0) to sample vtrace traces.
+ -vmodule=
+ comma-separated list of pattern=N settings for filename-filtered logging
+ -vpath=
+ comma-separated list of pattern=N settings for file pathname-filtered logging
+*/
+package main
diff --git a/services/device/claimable/main.go b/services/device/claimable/main.go
new file mode 100644
index 0000000..383988c
--- /dev/null
+++ b/services/device/claimable/main.go
@@ -0,0 +1,106 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The following enables go generate to generate the doc.go file.
+//go:generate go run $V23_ROOT/release/go/src/v.io/x/lib/cmdline/testdata/gendoc.go . -help
+
+package main
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ "v.io/v23"
+ "v.io/v23/context"
+ "v.io/v23/security"
+ "v.io/x/lib/cmdline"
+ "v.io/x/ref/lib/security/securityflag"
+ "v.io/x/ref/lib/signals"
+ "v.io/x/ref/lib/v23cmd"
+ _ "v.io/x/ref/runtime/factories/generic"
+ "v.io/x/ref/services/device/internal/claim"
+ "v.io/x/ref/services/identity"
+)
+
+var (
+ permsDir string
+ blessingRoot string
+)
+
+func runServer(ctx *context.T, _ *cmdline.Env, _ []string) error {
+ if blessingRoot != "" {
+ addRoot(ctx, blessingRoot)
+ }
+
+ auth := securityflag.NewAuthorizerOrDie()
+ claimable, claimed := claim.NewClaimableDispatcher(ctx, permsDir, "", auth)
+ if claimable == nil {
+ return errors.New("device is already claimed")
+ }
+
+ server, err := v23.NewServer(ctx)
+ if err != nil {
+ return err
+ }
+ if _, err := server.Listen(v23.GetListenSpec(ctx)); err != nil {
+ return err
+ }
+ if err := server.ServeDispatcher("", claimable); err != nil {
+ return err
+ }
+
+ status := server.Status()
+ ctx.Infof("Listening on: %v", status.Endpoints)
+ if len(status.Endpoints) > 0 {
+ fmt.Printf("NAME=%s\n", status.Endpoints[0].Name())
+ }
+ select {
+ case <-claimed:
+ return nil
+ case s := <-signals.ShutdownOnSignals(ctx):
+ return fmt.Errorf("received signal %v", s)
+ }
+}
+
+func addRoot(ctx *context.T, jRoot string) {
+ var bRoot identity.BlessingRootResponse
+ if err := json.Unmarshal([]byte(jRoot), &bRoot); err != nil {
+ ctx.Fatalf("unable to unmarshal the json blessing root: %v", err)
+ }
+ decodedKey, err := base64.URLEncoding.DecodeString(bRoot.PublicKey)
+ if err != nil {
+ ctx.Fatalf("unable to decode public key: %v", err)
+ }
+ key, err := security.UnmarshalPublicKey(decodedKey)
+ if err != nil {
+ ctx.Fatalf("unable to unmarshal the public key: %v", err)
+ }
+ roots := v23.GetPrincipal(ctx).Roots()
+ for _, name := range bRoot.Names {
+ if err := roots.Add(key, security.BlessingPattern(name)); err != nil {
+ ctx.Fatalf("unable to add root: %v", err)
+ }
+ }
+}
+
+func main() {
+ rootCmd := &cmdline.Command{
+ Name: "claimable",
+ Short: "Run claimable server",
+ Long: `
+Claimable is a server that implements the Claimable interface from
+v.io/v23/services/device. It exits immediately if the device is already
+claimed. Otherwise, it keeps running until a successful Claim() request
+is received.
+
+It uses -v23.permissions.* to authorize the Claim request.
+`,
+ Runner: v23cmd.RunnerFunc(runServer),
+ }
+ rootCmd.Flags.StringVar(&permsDir, "perms-dir", "", "The directory where permissions will be stored.")
+ rootCmd.Flags.StringVar(&blessingRoot, "blessing-root", "", "The blessing root to trust, JSON-encoded, e.g. from https://v.io/auth/blessing-root")
+ cmdline.Main(rootCmd)
+}
diff --git a/services/device/claimable/v23_test.go b/services/device/claimable/v23_test.go
new file mode 100644
index 0000000..ae373a9
--- /dev/null
+++ b/services/device/claimable/v23_test.go
@@ -0,0 +1,30 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file was auto-generated via go generate.
+// DO NOT UPDATE MANUALLY
+
+package main_test
+
+import (
+ "os"
+ "testing"
+
+ "v.io/x/ref/test"
+ "v.io/x/ref/test/modules"
+ "v.io/x/ref/test/v23tests"
+)
+
+func TestMain(m *testing.M) {
+ test.Init()
+ modules.DispatchAndExitIfChild()
+ cleanup := v23tests.UseSharedBinDir()
+ r := m.Run()
+ cleanup()
+ os.Exit(r)
+}
+
+func TestV23ClaimableServer(t *testing.T) {
+ v23tests.RunTest(t, V23TestClaimableServer)
+}
diff --git a/services/device/deviced/internal/impl/dispatcher.go b/services/device/deviced/internal/impl/dispatcher.go
index d27dbf2..e54838c 100644
--- a/services/device/deviced/internal/impl/dispatcher.go
+++ b/services/device/deviced/internal/impl/dispatcher.go
@@ -84,26 +84,6 @@
errNewAgentFailed = verror.Register(pkgPath+".errNewAgentFailed", verror.NoRetry, "{1:}{2:} NewAgent() failed{:_}")
)
-// NewClaimableDispatcher returns an rpc.Dispatcher that allows the device to
-// be Claimed if it hasn't been already and a channel that will be closed once
-// the device has been claimed.
-//
-// It returns (nil, nil) if the device is no longer claimable.
-func NewClaimableDispatcher(ctx *context.T, config *config.State, pairingToken string) (rpc.Dispatcher, <-chan struct{}) {
- var (
- permsDir = PermsDir(config)
- permsStore = pathperms.NewPathStore(ctx)
- )
- if _, _, err := permsStore.Get(permsDir); !os.IsNotExist(err) {
- return nil, nil
- }
- // The device is claimable only if Claim hasn't been called before. The
- // existence of the Permissions file is an indication of a successful prior
- // call to Claim.
- notify := make(chan struct{})
- return &claimable{token: pairingToken, permsStore: permsStore, permsDir: permsDir, notify: notify}, notify
-}
-
// NewDispatcher is the device manager dispatcher factory. It returns a new
// dispatcher as well as a shutdown function, to be called when the dispatcher
// is no longer needed.
diff --git a/services/device/deviced/internal/starter/starter.go b/services/device/deviced/internal/starter/starter.go
index cf9967f..afd6e1e 100644
--- a/services/device/deviced/internal/starter/starter.go
+++ b/services/device/deviced/internal/starter/starter.go
@@ -25,6 +25,7 @@
"v.io/x/ref/services/debug/debuglib"
"v.io/x/ref/services/device/deviced/internal/impl"
"v.io/x/ref/services/device/deviced/internal/versioning"
+ "v.io/x/ref/services/device/internal/claim"
"v.io/x/ref/services/device/internal/config"
"v.io/x/ref/services/internal/pathperms"
"v.io/x/ref/services/mounttable/mounttablelib"
@@ -114,7 +115,7 @@
// claimable service and wait for it to be claimed.
// Once a device is claimed, close any previously running servers and
// start a new mounttable and device service.
- claimable, claimed := impl.NewClaimableDispatcher(ctx, args.Device.ConfigState, args.Device.PairingToken)
+ claimable, claimed := claim.NewClaimableDispatcher(ctx, impl.PermsDir(args.Device.ConfigState), args.Device.PairingToken, security.AllowEveryone())
if claimable == nil {
// Device has already been claimed, bypass claimable service
// stage.
diff --git a/services/device/deviced/internal/impl/claim.go b/services/device/internal/claim/claim.go
similarity index 77%
rename from services/device/deviced/internal/impl/claim.go
rename to services/device/internal/claim/claim.go
index 0959409..55202d8 100644
--- a/services/device/deviced/internal/impl/claim.go
+++ b/services/device/internal/claim/claim.go
@@ -2,10 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package impl
+package claim
import (
"crypto/subtle"
+ "os"
"sync"
"v.io/v23"
@@ -13,11 +14,29 @@
"v.io/v23/rpc"
"v.io/v23/security"
"v.io/v23/security/access"
+ "v.io/v23/services/device"
"v.io/v23/verror"
"v.io/x/ref/services/device/internal/errors"
"v.io/x/ref/services/internal/pathperms"
)
+// NewClaimableDispatcher returns an rpc.Dispatcher that allows the device to
+// be Claimed if it hasn't been already and a channel that will be closed once
+// the device has been claimed.
+//
+// It returns (nil, nil) if the device is no longer claimable.
+func NewClaimableDispatcher(ctx *context.T, permsDir, pairingToken string, auth security.Authorizer) (rpc.Dispatcher, <-chan struct{}) {
+ permsStore := pathperms.NewPathStore(ctx)
+ if _, _, err := permsStore.Get(permsDir); !os.IsNotExist(err) {
+ // The device is claimable only if Claim hasn't been called before. The
+ // existence of the Permissions file is an indication of a successful prior
+ // call to Claim.
+ return nil, nil
+ }
+ notify := make(chan struct{})
+ return &claimable{token: pairingToken, permsStore: permsStore, permsDir: permsDir, notify: notify, auth: auth}, notify
+}
+
// claimable implements the device.Claimable RPC interface and the
// rpc.Dispatcher and security.Authorizer to serve it.
//
@@ -27,6 +46,7 @@
permsStore *pathperms.PathStore
permsDir string
notify chan struct{} // GUARDED_BY(mu)
+ auth security.Authorizer
// Lock used to ensure that a successful claim can happen at most once.
// This is done by allowing only a single goroutine to execute the
@@ -101,11 +121,5 @@
if suffix != "" && suffix != "device" {
return nil, nil, verror.New(errors.ErrUnclaimedDevice, nil)
}
- return c, c, nil
-}
-
-func (c *claimable) Authorize(*context.T, security.Call) error {
- // Claim is open to all. The Claim method implementation
- // allows at most one successful call.
- return nil
+ return device.ClaimableServer(c), c.auth, nil
}