Merge "syncbase: sb51: remove "Demo" prefix from tables in demo database."
diff --git a/lib/discovery/advertise.go b/lib/discovery/advertise.go
index 0cf4018..3930f63 100644
--- a/lib/discovery/advertise.go
+++ b/lib/discovery/advertise.go
@@ -7,33 +7,45 @@
import (
"v.io/v23/context"
"v.io/v23/discovery"
- "v.io/v23/security/access"
+ "v.io/v23/security"
"v.io/v23/verror"
)
+// Errors returned by Advertise when the given service cannot be advertised.
var (
- errNoInterfaceName = verror.Register(pkgPath+".errNoInterfaceName", verror.NoRetry, "{1:}{2:} interface name not provided")
- errNoAddresses = verror.Register(pkgPath+".errNoAddress", verror.NoRetry, "{1:}{2:} address not provided")
+ errNoInterfaceName = verror.Register(pkgPath+".errNoInterfaceName", verror.NoRetry, "{1:}{2:} interface name not provided")
+ errNotPackableAttributes = verror.Register(pkgPath+".errNotPackableAttributes", verror.NoRetry, "{1:}{2:} attribute not packable")
+ // NOTE(review): registered as ".errNoAddress", not ".errNoAddresses" like the
+ // variable; left unchanged since renaming the ID may break callers that match
+ // on it — confirm before changing.
+ errNoAddresses = verror.Register(pkgPath+".errNoAddress", verror.NoRetry, "{1:}{2:} address not provided")
+ errNotPackableAddresses = verror.Register(pkgPath+".errNotPackableAddresses", verror.NoRetry, "{1:}{2:} address not packable")
)
// Advertise implements discovery.Advertiser.
//
// TODO(jhahn): Handle ACL.
-func (ds *ds) Advertise(ctx *context.T, service discovery.Service, perms access.Permissions) error {
+func (ds *ds) Advertise(ctx *context.T, service discovery.Service, perms []security.BlessingPattern) error {
if len(service.InterfaceName) == 0 {
return verror.New(errNoInterfaceName, ctx)
}
+ if !IsAttributePackable(service.Attrs) {
+ return verror.New(errNotPackableAttributes, ctx)
+ }
if len(service.Addrs) == 0 {
return verror.New(errNoAddresses, ctx)
}
-
+ if !IsAddressPackable(service.Addrs) {
+ return verror.New(errNotPackableAddresses, ctx)
+ }
if len(service.InstanceUuid) == 0 {
service.InstanceUuid = NewInstanceUUID()
}
- ad := &Advertisement{
+
+ ad := Advertisement{
ServiceUuid: NewServiceUUID(service.InterfaceName),
Service: service,
}
+ if err := encrypt(&ad, perms); err != nil {
+ return err
+ }
+
ctx, cancel := context.WithCancel(ctx)
for _, plugin := range ds.plugins {
err := plugin.Advertise(ctx, ad)
diff --git a/lib/discovery/cipher.go b/lib/discovery/cipher.go
new file mode 100644
index 0000000..715c558
--- /dev/null
+++ b/lib/discovery/cipher.go
@@ -0,0 +1,152 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package discovery
+
+import (
+ "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "strings"
+
+ "golang.org/x/crypto/nacl/secretbox"
+
+ "v.io/v23/security"
+)
+
+// errNoPermission is returned by decrypt when no blessing names are given or
+// none of them is allowed to decrypt an encrypted advertisement.
+var errNoPermission = errors.New("no permission")
+
+// encrypt applies (fake) identity-based encryption to ad so that only
+// principals whose blessings match one of patterns can recover it. An empty or
+// nil patterns slice leaves ad in the clear and marks it NoEncryption.
+//
+// Only ad.Addrs is encrypted for now; all other fields stay readable.
+// TODO(jhahn): Revisit the scope of encryption.
+func encrypt(ad *Advertisement, patterns []security.BlessingPattern) error {
+	if len(patterns) == 0 {
+		ad.EncryptionAlgorithm = NoEncryption
+		return nil
+	}
+
+	key, wrappedKeys, err := newSharedKey(patterns)
+	if err != nil {
+		return err
+	}
+	ad.EncryptionAlgorithm = TestEncryption
+	ad.EncryptionKeys = wrappedKeys
+
+	sealed := make([]string, 0, len(ad.Addrs))
+	for idx := range ad.Addrs {
+		// newSharedKey draws a fresh random key on every call, so using the
+		// address position as the nonce never repeats a nonce under one key.
+		var nonce [24]byte
+		binary.LittleEndian.PutUint64(nonce[:], uint64(idx))
+		sealed = append(sealed, string(secretbox.Seal(nil, []byte(ad.Addrs[idx]), &nonce, key)))
+	}
+	ad.Addrs = sealed
+	return nil
+}
+
+// decrypt reverses encrypt, using names (the caller's blessing names) to
+// recover the shared key. Advertisements marked NoEncryption are returned
+// untouched.
+//
+// It returns errNoPermission if names is empty or no name matches a pattern
+// the advertisement was encrypted for, and a non-nil error if the algorithm is
+// unsupported or any address fails authentication. ad.Addrs is replaced only
+// when every address decrypts; on failure it is left unchanged.
+func decrypt(ad *Advertisement, names []string) error {
+	if ad.EncryptionAlgorithm == NoEncryption {
+		// Not encrypted.
+		return nil
+	}
+	if len(names) == 0 {
+		// No identifiers.
+		return errNoPermission
+	}
+
+	if ad.EncryptionAlgorithm != TestEncryption {
+		return fmt.Errorf("unsupported encryption algorithm %v", ad.EncryptionAlgorithm)
+	}
+	sharedKey, err := decryptSharedKey(ad.EncryptionKeys, names)
+	if err != nil {
+		return err
+	}
+	if sharedKey == nil {
+		return errNoPermission
+	}
+
+	// We only encrypt addresses for now.
+	//
+	// Note that we should not modify the slice element directly here since the
+	// underlying plugins may cache services and the next plugin.Scan() may return
+	// the already decrypted addresses.
+	decrypted := make([]string, len(ad.Addrs))
+	for i, encrypted := range ad.Addrs {
+		// The nonce is the address index, mirroring encrypt.
+		var n [24]byte
+		binary.LittleEndian.PutUint64(n[:], uint64(i))
+		addr, ok := secretbox.Open(nil, []byte(encrypted), &n, sharedKey)
+		if !ok {
+			return errors.New("decryption error")
+		}
+		decrypted[i] = string(addr)
+	}
+	ad.Addrs = decrypted
+	return nil
+}
+
+// newSharedKey generates a random 32-byte shared key and wraps a copy of it for
+// each pattern, returning the key and the wrapped copies in pattern order.
+//
+// TODO(jhahn): Replace this fake with the real IBE. For now each pattern is
+// used directly (truncated or zero-padded to 32 bytes) as a secretbox key with
+// an all-zero nonce, so anyone who knows a pattern can unwrap the shared key.
+func newSharedKey(patterns []security.BlessingPattern) (*[32]byte, []EncryptionKey, error) {
+	shared := new([32]byte)
+	if _, err := rand.Read(shared[:]); err != nil {
+		return nil, nil, err
+	}
+
+	var zeroNonce [24]byte
+	wrapped := make([]EncryptionKey, 0, len(patterns))
+	for _, p := range patterns {
+		var patternKey [32]byte
+		copy(patternKey[:], p)
+		wrapped = append(wrapped, secretbox.Seal(nil, shared[:], &zeroNonce, &patternKey))
+	}
+	return shared, wrapped, nil
+}
+
+// decryptSharedKey tries to unwrap the shared key from keys using every
+// pattern each of names satisfies (see prefixPatterns), in order. It returns
+// the first key that unwraps successfully, an error if that key is not 32
+// bytes long, and (nil, nil) if no name/key combination works.
+func decryptSharedKey(keys []EncryptionKey, names []string) (*[32]byte, error) {
+	// TODO(jhahn): Replace this fake with the real IBE.
+	var zeroNonce [24]byte
+	for _, name := range names {
+		for _, p := range prefixPatterns(name) {
+			var patternKey [32]byte
+			copy(patternKey[:], p)
+			for _, wrapped := range keys {
+				switch plain, ok := secretbox.Open(nil, wrapped, &zeroNonce, &patternKey); {
+				case !ok:
+					// Not sealed for this pattern; try the next key.
+				case len(plain) != 32:
+					return nil, errors.New("shared key decryption error")
+				default:
+					shared := new([32]byte)
+					copy(shared[:], plain)
+					return shared, nil
+				}
+			}
+		}
+	}
+	return nil, nil
+}
+
+// prefixPatterns lists the blessing patterns that name satisfies: name itself,
+// name terminated by security.NoExtension, and then every shorter prefix of
+// name obtained by cutting at security.ChainSeparator, longest first.
+func prefixPatterns(name string) []string {
+	out := []string{name, name + security.ChainSeparator + string(security.NoExtension)}
+	for rest := name; ; {
+		cut := strings.LastIndex(rest, security.ChainSeparator)
+		if cut < 0 {
+			return out
+		}
+		rest = rest[:cut]
+		out = append(out, rest)
+	}
+}
diff --git a/lib/discovery/discovery.go b/lib/discovery/discovery.go
index 10b7377..4f2210b 100644
--- a/lib/discovery/discovery.go
+++ b/lib/discovery/discovery.go
@@ -24,11 +24,27 @@
// The service UUID to advertise.
ServiceUuid uuid.UUID
+ // Type of encryption applied to the advertisement so that it can
+ // only be decoded by authorized principals.
+ EncryptionAlgorithm EncryptionAlgorithm
+ // If the advertisement is encrypted, then the data required to
+ // decrypt it. The format of this data is a function of the algorithm.
+ EncryptionKeys []EncryptionKey
+
// TODO(jhahn): Add proximity.
// TODO(jhahn): Use proximity for Lost.
Lost bool
}
+// EncryptionAlgorithm identifies the scheme, if any, used to encrypt an
+// Advertisement.
+type EncryptionAlgorithm byte
+// EncryptionKey is algorithm-specific key material carried in an
+// Advertisement; under TestEncryption each entry is the shared key sealed for
+// one blessing pattern.
+type EncryptionKey []byte
+
+const (
+ // NoEncryption means the advertisement is sent in the clear.
+ NoEncryption EncryptionAlgorithm = 0
+ // TestEncryption is the placeholder scheme used until real identity-based
+ // encryption is implemented.
+ TestEncryption EncryptionAlgorithm = 1
+ // NOTE(review): presumably reserved for real identity-based encryption;
+ // encrypt does not produce it yet.
+ IbeEncryption EncryptionAlgorithm = 2
+)
+
// TODO(jhahn): Need a better API.
func New(plugins []Plugin) discovery.T {
ds := &ds{plugins: make([]Plugin, len(plugins))}
diff --git a/lib/discovery/discovery_test.go b/lib/discovery/discovery_test.go
index df10892..3875fa2 100644
--- a/lib/discovery/discovery_test.go
+++ b/lib/discovery/discovery_test.go
@@ -11,93 +11,30 @@
"testing"
"time"
+ "v.io/v23"
"v.io/v23/context"
"v.io/v23/discovery"
+ "v.io/v23/security"
ldiscovery "v.io/x/ref/lib/discovery"
"v.io/x/ref/lib/discovery/plugins/mock"
+ _ "v.io/x/ref/runtime/factories/generic"
+ "v.io/x/ref/test"
+ "v.io/x/ref/test/testutil"
)
-func TestBasic(t *testing.T) {
- ds := ldiscovery.New([]ldiscovery.Plugin{mock.New()})
- services := []discovery.Service{
- {
- InstanceUuid: ldiscovery.NewInstanceUUID(),
- InterfaceName: "v.io/v23/a",
- Addrs: []string{"/h1:123/x", "/h2:123/y"},
- },
- {
- InstanceUuid: ldiscovery.NewInstanceUUID(),
- InterfaceName: "v.io/v23/b",
- Addrs: []string{"/h1:123/x", "/h2:123/z"},
- },
- }
- var stops []func()
+// advertise advertises each of services on ds, restricted to perms, under a
+// context derived from ctx. It returns a function that stops all of them by
+// canceling that context. If any Advertise call fails, the context is canceled
+// before returning so it does not leak, and the returned stop func is nil.
+func advertise(ctx *context.T, ds discovery.Advertiser, perms []security.BlessingPattern, services ...discovery.Service) (func(), error) {
+	ctx, stop := context.WithCancel(ctx)
for _, service := range services {
- stop, err := advertise(ds, service)
- if err != nil {
- t.Fatal(err)
- }
- stops = append(stops, stop)
- }
-
- // Make sure all advertisements are discovered.
- if err := scanAndMatch(ds, "v.io/v23/a", services[0]); err != nil {
- t.Error(err)
- }
- if err := scanAndMatch(ds, "v.io/v23/b", services[1]); err != nil {
- t.Error(err)
- }
- if err := scanAndMatch(ds, "", services...); err != nil {
- t.Error(err)
- }
- if err := scanAndMatch(ds, "v.io/v23/c"); err != nil {
- t.Error(err)
- }
-
- // Open a new scan channel and consume expected advertisements first.
- scan, scanStop, err := startScan(ds, "v.io/v23/a")
- if err != nil {
- t.Error(err)
- }
- defer scanStop()
- update := <-scan
- if !matchFound([]discovery.Update{update}, services[0]) {
- t.Errorf("Unexpected scan: %v", update)
- }
-
- // Make sure scan returns the lost advertisement when advertising is stopped.
- stops[0]()
-
- update = <-scan
- if !matchLost([]discovery.Update{update}, services[0]) {
- t.Errorf("Unexpected scan: %v", update)
- }
-
- // Also it shouldn't affect the other.
- if err := scanAndMatch(ds, "v.io/v23/b", services[1]); err != nil {
- t.Error(err)
- }
-
- // Stop advertising the remaining one; Shouldn't discover any service.
- stops[1]()
- if err := scanAndMatch(ds, ""); err != nil {
- t.Error(err)
- }
-}
-
-func advertise(ds discovery.Advertiser, services ...discovery.Service) (func(), error) {
- ctx, cancel := context.RootContext()
- for _, service := range services {
- if err := ds.Advertise(ctx, service, nil); err != nil {
+		if err := ds.Advertise(ctx, service, perms); err != nil {
+			// Cancel the derived context (also stopping any services already
+			// advertised); the caller receives no stop func to do it.
+			stop()
return nil, fmt.Errorf("Advertise failed: %v", err)
}
}
- return cancel, nil
+	return stop, nil
}
-func startScan(ds discovery.Scanner, query string) (<-chan discovery.Update, func(), error) {
- ctx, stop := context.RootContext()
+func startScan(ctx *context.T, ds discovery.Scanner, query string) (<-chan discovery.Update, func(), error) {
+ ctx, stop := context.WithCancel(ctx)
scan, err := ds.Scan(ctx, query)
if err != nil {
return nil, nil, fmt.Errorf("Scan failed: %v", err)
@@ -105,8 +42,8 @@
return scan, stop, err
}
-func scan(ds discovery.Scanner, query string) ([]discovery.Update, error) {
- scan, stop, err := startScan(ds, query)
+func scan(ctx *context.T, ds discovery.Scanner, query string) ([]discovery.Update, error) {
+ scan, stop, err := startScan(ctx, ds, query)
if err != nil {
return nil, err
}
@@ -123,6 +60,25 @@
}
}
+// scanAndMatch repeatedly scans ds for query until the returned updates contain
+// every service in wants. It gives up after 3 seconds and reports the updates
+// from the last scan.
+func scanAndMatch(ctx *context.T, ds discovery.Scanner, query string, wants ...discovery.Service) error {
+	const timeout = 3 * time.Second
+
+	deadline := time.Now().Add(timeout)
+	var updates []discovery.Update
+	for time.Now().Before(deadline) {
+		runtime.Gosched()
+
+		got, err := scan(ctx, ds, query)
+		if err != nil {
+			return err
+		}
+		updates = got
+		if matchFound(updates, wants...) {
+			return nil
+		}
+	}
+	return fmt.Errorf("Match failed; got %v, but wanted %v", updates, wants)
+}
+
func match(updates []discovery.Update, lost bool, wants ...discovery.Service) bool {
for _, want := range wants {
matched := false
@@ -159,21 +115,125 @@
return match(updates, true, wants...)
}
-func scanAndMatch(ds discovery.Scanner, query string, wants ...discovery.Service) error {
- const timeout = 3 * time.Second
+// TestBasic advertises two unrestricted services and checks that scans by
+// interface name, by an empty query and by an unknown interface return the
+// expected services, and that stopping an advertisement is reported as lost.
+func TestBasic(t *testing.T) {
+	ctx, shutdown := test.V23Init()
+	defer shutdown()
- var updates []discovery.Update
- for now := time.Now(); time.Since(now) < timeout; {
- runtime.Gosched()
-
- var err error
- updates, err = scan(ds, query)
- if err != nil {
- return err
- }
- if matchFound(updates, wants...) {
- return nil
- }
+	ds := ldiscovery.New([]ldiscovery.Plugin{mock.New()})
+	services := []discovery.Service{
+		{
+			InstanceUuid:  ldiscovery.NewInstanceUUID(),
+			InterfaceName: "v.io/v23/a",
+			Attrs:         discovery.Attributes{"a1": "v1"},
+			Addrs:         []string{"/h1:123/x", "/h2:123/y"},
+		},
+		{
+			InstanceUuid:  ldiscovery.NewInstanceUUID(),
+			InterfaceName: "v.io/v23/b",
+			Attrs:         discovery.Attributes{"b1": "v1"},
+			Addrs:         []string{"/h1:123/x", "/h2:123/z"},
+		},
}
- return fmt.Errorf("Match failed; got %v, but wanted %v", updates, wants)
+	var stops []func()
+	for _, service := range services {
+		stop, err := advertise(ctx, ds, nil, service)
+		if err != nil {
+			t.Fatal(err)
+		}
+		stops = append(stops, stop)
+	}
+
+	// Make sure all advertisements are discovered.
+	if err := scanAndMatch(ctx, ds, "v.io/v23/a", services[0]); err != nil {
+		t.Error(err)
+	}
+	if err := scanAndMatch(ctx, ds, "v.io/v23/b", services[1]); err != nil {
+		t.Error(err)
+	}
+	if err := scanAndMatch(ctx, ds, "", services...); err != nil {
+		t.Error(err)
+	}
+	if err := scanAndMatch(ctx, ds, "v.io/v23/c"); err != nil {
+		t.Error(err)
+	}
+
+	// Open a new scan channel and consume expected advertisements first.
+	scan, scanStop, err := startScan(ctx, ds, "v.io/v23/a")
+	if err != nil {
+		// Fatal, not Error: startScan returns nil scan and scanStop on failure,
+		// so reading from scan would block forever and deferring scanStop
+		// would panic.
+		t.Fatal(err)
+	}
+	defer scanStop()
+	update := <-scan
+	if !matchFound([]discovery.Update{update}, services[0]) {
+		t.Errorf("Unexpected scan: %v", update)
+	}
+
+	// Make sure scan returns the lost advertisement when advertising is stopped.
+	stops[0]()
+
+	update = <-scan
+	if !matchLost([]discovery.Update{update}, services[0]) {
+		t.Errorf("Unexpected scan: %v", update)
+	}
+
+	// Also it shouldn't affect the other.
+	if err := scanAndMatch(ctx, ds, "v.io/v23/b", services[1]); err != nil {
+		t.Error(err)
+	}
+
+	// Stop advertising the remaining one; Shouldn't discover any service.
+	stops[1]()
+	if err := scanAndMatch(ctx, ds, ""); err != nil {
+		t.Error(err)
+	}
+}
+
+// TestPermission advertises a service restricted to the blessing patterns
+// "v.io/bob" and non-extendable "v.io/alice", and checks which principals can
+// discover it.
+//
+// TODO(jhahn): Add a low level test that ensures the advertisement is unusable
+// by the listener, if encrypted rather than relying on a higher level API.
+func TestPermission(t *testing.T) {
+	ctx, shutdown := test.V23Init()
+	defer shutdown()
+
+	ds := ldiscovery.New([]ldiscovery.Plugin{mock.New()})
+	service := discovery.Service{
+		InstanceUuid:  ldiscovery.NewInstanceUUID(),
+		InterfaceName: "v.io/v23/a",
+		Attrs:         discovery.Attributes{"a1": "v1", "a2": "v2"},
+		Addrs:         []string{"/h1:123/x", "/h2:123/y"},
+	}
+	perms := []security.BlessingPattern{
+		security.BlessingPattern("v.io/bob"),
+		security.BlessingPattern("v.io/alice").MakeNonExtendable(),
+	}
+	stop, err := advertise(ctx, ds, perms, service)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Deferred only after the error check: advertise returns a nil stop func
+	// on failure, and calling a deferred nil func panics.
+	defer stop()
+
+	// withPrincipal returns a copy of ctx using testutil.NewPrincipal(name) as
+	// its principal, failing the test if that cannot be set.
+	withPrincipal := func(name string) *context.T {
+		pctx, err := v23.WithPrincipal(ctx, testutil.NewPrincipal(name))
+		if err != nil {
+			t.Fatalf("WithPrincipal(%q) failed: %v", name, err)
+		}
+		return pctx
+	}
+
+	// Bob and his friend should discover the advertisement.
+	if err := scanAndMatch(withPrincipal("v.io/bob"), ds, "v.io/v23/a", service); err != nil {
+		t.Error(err)
+	}
+	if err := scanAndMatch(withPrincipal("v.io/bob/friend"), ds, "v.io/v23/a", service); err != nil {
+		t.Error(err)
+	}
+
+	// Alice should discover the advertisement, but her friend shouldn't.
+	if err := scanAndMatch(withPrincipal("v.io/alice"), ds, "v.io/v23/a", service); err != nil {
+		t.Error(err)
+	}
+	if err := scanAndMatch(withPrincipal("v.io/alice/friend"), ds, "v.io/v23/a"); err != nil {
+		t.Error(err)
+	}
+
+	// Other people shouldn't discover the advertisement.
+	if err := scanAndMatch(withPrincipal("v.io/carol"), ds, "v.io/v23/a"); err != nil {
+		t.Error(err)
+	}
}
diff --git a/lib/discovery/encoding.go b/lib/discovery/encoding.go
new file mode 100644
index 0000000..0f27651
--- /dev/null
+++ b/lib/discovery/encoding.go
@@ -0,0 +1,96 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package discovery
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+
+ "v.io/v23/discovery"
+)
+
+// TODO(jhahn): Figure out how to overcome the size limit.
+
+// IsAttributePackable reports whether attrs can be serialized safely, with
+// each attribute stored as a "key=value" record.
+//
+// It returns false if any key starts with "_" (that prefix is used for the
+// reserved attributes the discovery plugins add themselves) or contains "="
+// (the key/value separator), or if any record would not fit in 255 bytes,
+// i.e. len(key)+len(value) exceeds 254 (the "=" takes the remaining byte).
+func IsAttributePackable(attrs discovery.Attributes) bool {
+	for k, v := range attrs {
+		if strings.HasPrefix(k, "_") || strings.Contains(k, "=") {
+			return false
+		}
+		if len(k)+len(v) > 254 {
+			return false
+		}
+	}
+	return true
+}
+
+// IsAddressPackable reports whether every address in addrs is at most 250
+// bytes long.
+//
+// The go-mdns-sd package limits each TXT record to 255 bytes, and 5 of those
+// bytes go to the tag, leaving 250 bytes for the address itself.
+func IsAddressPackable(addrs []string) bool {
+	const maxAddrLen = 250
+	for i := range addrs {
+		if len(addrs[i]) > maxAddrLen {
+			return false
+		}
+	}
+	return true
+}
+
+// PackAddresses serializes addrs into one byte slice, writing each address as
+// a one-byte length followed by the address bytes. It panics if an address is
+// longer than 255 bytes, because its length would not fit in one byte.
+func PackAddresses(addrs []string) []byte {
+	var buf bytes.Buffer
+	for _, addr := range addrs {
+		if size := len(addr); size <= 255 {
+			buf.WriteByte(byte(size))
+			buf.WriteString(addr)
+			continue
+		}
+		panic(fmt.Sprintf("too large address %d: %s", len(addr), addr))
+	}
+	return buf.Bytes()
+}
+
+// UnpackAddresses decodes a byte slice produced by PackAddresses back into the
+// list of addresses. It always returns a non-nil slice.
+//
+// The input may come from a remote peer (e.g. a scanned BLE advertisement), so
+// it is not trusted: decoding stops at the first length prefix that runs past
+// the end of data, and only the addresses decoded before it are returned.
+// Malformed input never causes a panic.
+func UnpackAddresses(data []byte) []string {
+	addrs := []string{}
+	for off := 0; off < len(data); {
+		n := int(data[off])
+		off++
+		if n > len(data)-off {
+			// Truncated or malformed entry.
+			break
+		}
+		addrs = append(addrs, string(data[off:off+n]))
+		off += n
+	}
+	return addrs
+}
+
+// PackEncryptionKeys serializes algo and keys into one byte slice: a single
+// algorithm byte followed by each key as a one-byte length and the key bytes.
+// It panics if a key is longer than 255 bytes.
+func PackEncryptionKeys(algo EncryptionAlgorithm, keys []EncryptionKey) []byte {
+	buf := bytes.NewBuffer([]byte{byte(algo)})
+	for _, key := range keys {
+		size := len(key)
+		if size > 255 {
+			panic(fmt.Sprintf("too large key %d", size))
+		}
+		buf.WriteByte(byte(size))
+		buf.Write(key)
+	}
+	return buf.Bytes()
+}
+
+// UnpackEncryptionKeys decodes a byte slice produced by PackEncryptionKeys into
+// the encryption algorithm and its keys. The key slice is always non-nil.
+//
+// The input may come from a remote peer (e.g. a scanned BLE advertisement), so
+// it is not trusted: empty input decodes as NoEncryption with no keys, and
+// decoding stops at the first length prefix that runs past the end of data,
+// returning only the keys decoded before it. Malformed input never panics.
+func UnpackEncryptionKeys(data []byte) (EncryptionAlgorithm, []EncryptionKey) {
+	keys := []EncryptionKey{}
+	if len(data) == 0 {
+		return NoEncryption, keys
+	}
+	algo := EncryptionAlgorithm(data[0])
+	for off := 1; off < len(data); {
+		n := int(data[off])
+		off++
+		if n > len(data)-off {
+			// Truncated or malformed entry.
+			break
+		}
+		// Cap the sub-slice so appending to a returned key cannot overwrite
+		// the bytes that follow it in data.
+		keys = append(keys, EncryptionKey(data[off:off+n:off+n]))
+		off += n
+	}
+	return algo, keys
+}
diff --git a/lib/discovery/encoding_test.go b/lib/discovery/encoding_test.go
new file mode 100644
index 0000000..ca74b57
--- /dev/null
+++ b/lib/discovery/encoding_test.go
@@ -0,0 +1,81 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package discovery
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+
+ "v.io/v23/discovery"
+)
+
+// TestAttributePackable checks IsAttributePackable against reserved key
+// prefixes, the "=" separator, and the key+value length limit.
+func TestAttributePackable(t *testing.T) {
+	cases := []struct {
+		attrs    discovery.Attributes
+		packable bool
+	}{
+		{discovery.Attributes{"k": "v"}, true},
+		{discovery.Attributes{"_k": "v"}, false},
+		{discovery.Attributes{"k=": "v"}, false},
+		{discovery.Attributes{strings.Repeat("k", 100): strings.Repeat("v", 154)}, true},
+		{discovery.Attributes{strings.Repeat("k", 100): strings.Repeat("v", 155)}, false},
+	}
+	for i := range cases {
+		got, want := IsAttributePackable(cases[i].attrs), cases[i].packable
+		if got != want {
+			t.Errorf("[%d]: packable %v, but want %v", i, got, want)
+		}
+	}
+}
+
+// TestAddressPackable checks the 250-byte per-address limit of
+// IsAddressPackable.
+func TestAddressPackable(t *testing.T) {
+	cases := []struct {
+		addrs    []string
+		packable bool
+	}{
+		{[]string{strings.Repeat("a", 250)}, true},
+		{[]string{strings.Repeat("a", 10), strings.Repeat("a", 251)}, false},
+	}
+	for i := range cases {
+		got, want := IsAddressPackable(cases[i].addrs), cases[i].packable
+		if got != want {
+			t.Errorf("[%d]: packable %v, but want %v", i, got, want)
+		}
+	}
+}
+
+// TestPackAddresses checks that UnpackAddresses inverts PackAddresses.
+func TestPackAddresses(t *testing.T) {
+	for _, addrs := range [][]string{
+		{"a12345"},
+		{"a1234", "b5678", "c9012"},
+		{},
+	} {
+		got := UnpackAddresses(PackAddresses(addrs))
+		if !reflect.DeepEqual(addrs, got) {
+			t.Errorf("unpacked to %v, but want %v", got, addrs)
+		}
+	}
+}
+
+// TestPackEncryptionKeys checks that UnpackEncryptionKeys inverts
+// PackEncryptionKeys for every algorithm.
+func TestPackEncryptionKeys(t *testing.T) {
+	type testCase struct {
+		algo EncryptionAlgorithm
+		keys []EncryptionKey
+	}
+	for _, tc := range []testCase{
+		{TestEncryption, []EncryptionKey{EncryptionKey("0123456789")}},
+		{IbeEncryption, []EncryptionKey{EncryptionKey("012345"), EncryptionKey("123456"), EncryptionKey("234567")}},
+		{NoEncryption, []EncryptionKey{}},
+	} {
+		algo, keys := UnpackEncryptionKeys(PackEncryptionKeys(tc.algo, tc.keys))
+		if algo != tc.algo || !reflect.DeepEqual(keys, tc.keys) {
+			t.Errorf("unpacked to (%d, %v), but want (%d, %v)", algo, keys, tc.algo, tc.keys)
+		}
+	}
+}
diff --git a/lib/discovery/plugin.go b/lib/discovery/plugin.go
index 57a74c1..fc420c1 100644
--- a/lib/discovery/plugin.go
+++ b/lib/discovery/plugin.go
@@ -15,7 +15,7 @@
type Plugin interface {
// Advertise advertises the advertisement. Advertising will continue until
// the context is canceled or exceeds its deadline.
- Advertise(ctx *context.T, ad *Advertisement) error
+ Advertise(ctx *context.T, ad Advertisement) error
// Scan scans services that match the service uuid and returns scanned
// advertisements to the channel. A zero-value service uuid means any service.
@@ -23,5 +23,5 @@
// deadline.
//
// TODO(jhahn): Pass a filter on service attributes.
- Scan(ctx *context.T, serviceUuid uuid.UUID, scanCh chan<- *Advertisement) error
+ Scan(ctx *context.T, serviceUuid uuid.UUID, ch chan<- Advertisement) error
}
diff --git a/lib/discovery/plugins/ble/advertisement.go b/lib/discovery/plugins/ble/advertisement.go
index 04c87b9..b64ade1 100644
--- a/lib/discovery/plugins/ble/advertisement.go
+++ b/lib/discovery/plugins/ble/advertisement.go
@@ -5,15 +5,14 @@
package ble
import (
+ "fmt"
"strings"
- "fmt"
- "net/url"
+ "github.com/pborman/uuid"
vdiscovery "v.io/v23/discovery"
- "v.io/x/ref/lib/discovery"
- "github.com/pborman/uuid"
+ "v.io/x/ref/lib/discovery"
)
type bleAdv struct {
@@ -22,28 +21,21 @@
attrs map[string][]byte
}
-var (
- // This uuids are v4 uuid generated out of band. These constants need
+const (
+ // These uuids are v5 uuids generated out of band. These constants need
// to be accessible in all the languages that have a ble implementation
-
- // The attribute uuid for the unique service id
- instanceUUID = "f6445c7f-73fd-4b8d-98d0-c4e02b087844"
-
- // The attribute uuid for the interface name
- interfaceNameUUID = "d4789810-4db0-40d8-9658-92f8e304d578"
-
- addrUUID = "f123fb0e-770f-4e46-b8ad-aee4185ab5a1"
+ //
+ // Attribute UUIDs for the reserved advertisement fields; per its trailing
+ // comment, each value is NewAttributeUUID of a reserved "_" attribute name.
+ instanceUUID = "12db9a9c-1c7c-5560-bc6b-73a115c93413" // NewAttributeUUID("_instanceuuid")
+ interfaceNameUUID = "b2cadfd4-d003-576c-acad-58b8e3a9cbc8" // NewAttributeUUID("_interfacename")
+ addrsUUID = "ad2566b7-59d8-50ae-8885-222f43f65fdc" // NewAttributeUUID("_addrs")
+ encryptionUUID = "6286d80a-adaa-519a-8a06-281a4645a607" // NewAttributeUUID("_encryption")
)
func newAdvertisment(adv discovery.Advertisement) bleAdv {
- cleanAddrs := make([]string, len(adv.Addrs))
- for i, v := range adv.Addrs {
- cleanAddrs[i] = url.QueryEscape(v)
- }
attrs := map[string][]byte{
instanceUUID: adv.InstanceUuid,
interfaceNameUUID: []byte(adv.InterfaceName),
- addrUUID: []byte(strings.Join(cleanAddrs, "&")),
+ addrsUUID: discovery.PackAddresses(adv.Addrs),
+ encryptionUUID: discovery.PackEncryptionKeys(adv.EncryptionAlgorithm, adv.EncryptionKeys),
}
for k, v := range adv.Attrs {
@@ -58,34 +50,31 @@
}
+// toDiscoveryAdvertisement converts a scanned BLE advertisement back into a
+// discovery.Advertisement. The reserved attributes (instance UUID, interface
+// name, addresses, encryption) are decoded into their fields; every other
+// attribute must hold a "key=value" pair and becomes a service attribute.
+// NOTE(review): InstanceUuid starts as a.instanceID and is overwritten by the
+// instanceUUID attribute when present — presumably the same value; confirm.
func (a *bleAdv) toDiscoveryAdvertisement() (*discovery.Advertisement, error) {
- out := &discovery.Advertisement{
+ adv := &discovery.Advertisement{
Service: vdiscovery.Service{
- Attrs: vdiscovery.Attributes{},
- InterfaceName: string(a.attrs[interfaceNameUUID]),
- InstanceUuid: a.instanceID,
+ InstanceUuid: a.instanceID,
+ Attrs: make(vdiscovery.Attributes),
},
ServiceUuid: a.serviceUUID,
}
- out.Addrs = strings.Split(string(a.attrs[addrUUID]), "&")
- var err error
- for i, v := range out.Addrs {
- out.Addrs[i], err = url.QueryUnescape(v)
- if err != nil {
- return nil, err
- }
- }
for k, v := range a.attrs {
- if k == instanceUUID || k == interfaceNameUUID || k == addrUUID {
- continue
+ switch k {
+ case instanceUUID:
+ adv.InstanceUuid = v
+ case interfaceNameUUID:
+ adv.InterfaceName = string(v)
+ case addrsUUID:
+ adv.Addrs = discovery.UnpackAddresses(v)
+ case encryptionUUID:
+ adv.EncryptionAlgorithm, adv.EncryptionKeys = discovery.UnpackEncryptionKeys(v)
+ default:
+ // Any other attribute is a user attribute encoded as "key=value".
+ parts := strings.SplitN(string(v), "=", 2)
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("incorrectly formatted value, %s", v)
+ }
+ adv.Attrs[parts[0]] = parts[1]
}
- parts := strings.SplitN(string(v), "=", 2)
- if len(parts) != 2 {
- return nil, fmt.Errorf("incorrectly formatted value, %s", v)
- }
- out.Attrs[parts[0]] = parts[1]
-
}
-
- return out, nil
+ return adv, nil
}
diff --git a/lib/discovery/plugins/ble/advertisement_test.go b/lib/discovery/plugins/ble/advertisement_test.go
index 5b4025c..3fe95db 100644
--- a/lib/discovery/plugins/ble/advertisement_test.go
+++ b/lib/discovery/plugins/ble/advertisement_test.go
@@ -5,24 +5,29 @@
package ble
import (
- "github.com/pborman/uuid"
"reflect"
"testing"
+
+ "github.com/pborman/uuid"
+
vdiscovery "v.io/v23/discovery"
+
"v.io/x/ref/lib/discovery"
)
func TestConvertingBackAndForth(t *testing.T) {
v23Adv := discovery.Advertisement{
Service: vdiscovery.Service{
- Addrs: []string{"localhost:1000", "example.com:540"},
- InstanceUuid: []byte(uuid.NewUUID()),
- Attrs: map[string]string{
+ InstanceUuid: []byte(discovery.NewInstanceUUID()),
+ Attrs: vdiscovery.Attributes{
"key1": "value1",
"key2": "value2",
},
+ Addrs: []string{"localhost:1000", "example.com:540"},
},
- ServiceUuid: uuid.NewUUID(),
+ ServiceUuid: uuid.NewUUID(),
+ EncryptionAlgorithm: discovery.TestEncryption,
+ EncryptionKeys: []discovery.EncryptionKey{discovery.EncryptionKey("k1"), discovery.EncryptionKey("k2")},
}
adv := newAdvertisment(v23Adv)
diff --git a/lib/discovery/plugins/ble/plugin.go b/lib/discovery/plugins/ble/plugin.go
index 6c7d36b..4b94a18 100644
--- a/lib/discovery/plugins/ble/plugin.go
+++ b/lib/discovery/plugins/ble/plugin.go
@@ -19,15 +19,15 @@
trigger *discovery.Trigger
}
-func (b *blePlugin) Advertise(ctx *context.T, ad *discovery.Advertisement) error {
- b.b.addAdvertisement(newAdvertisment(*ad))
+// Advertise registers ad via b.b.addAdvertisement and arranges (through
+// b.trigger) for it to be removed by instance UUID once ctx is done. It always
+// returns nil.
+func (b *blePlugin) Advertise(ctx *context.T, ad discovery.Advertisement) error {
+ b.b.addAdvertisement(newAdvertisment(ad))
b.trigger.Add(func() {
b.b.removeService(ad.InstanceUuid)
}, ctx.Done())
return nil
}
-func (b *blePlugin) Scan(ctx *context.T, serviceUuid uuid.UUID, scan chan<- *discovery.Advertisement) error {
+func (b *blePlugin) Scan(ctx *context.T, serviceUuid uuid.UUID, scan chan<- discovery.Advertisement) error {
ch, id := b.b.addScanner(serviceUuid)
drain := func() {
for range ch {
@@ -44,7 +44,7 @@
case <-ctx.Done():
break L
case a := <-ch:
- scan <- a
+ scan <- *a
}
}
}()
diff --git a/lib/discovery/plugins/mdns/mdns.go b/lib/discovery/plugins/mdns/mdns.go
index baa7339..b036716 100644
--- a/lib/discovery/plugins/mdns/mdns.go
+++ b/lib/discovery/plugins/mdns/mdns.go
@@ -12,18 +12,12 @@
//
// v23._tcp.local.
// _<printer_service_uuid>._sub._v23._tcp.local.
-//
-// Even though an instance is advertised as two services, both PTR records refer
-// to the same name.
-//
-// _v23._tcp.local. PTR <instance_uuid>.<printer_service_uuid>._v23._tcp.local.
-// _<printer_service_uuid>._sub._v23._tcp.local.
-// PTR <instance_uuid>.<printer_service_uuid>._v23._tcp.local.
package mdns
import (
"encoding/hex"
"fmt"
+ "strconv"
"strings"
"sync"
"time"
@@ -40,14 +34,15 @@
const (
v23ServiceName = "v23"
serviceNameSuffix = "._sub._" + v23ServiceName
- // The host name is in the form of '<instance uuid>.<service uuid>._v23._tcp.local.'.
- // The double dots at the end are for bypassing the host name composition in
- // go-mdns-sd package so that we can use the same host name both in the (subtype)
- // service and v23 service announcements.
- hostNameSuffix = "._v23._tcp.local.."
- attrInterface = "__intf"
- attrAddr = "__addr"
+ // Reserved TXT record keys. The attribute names should not exceed 4 bytes
+ // due to the txt record size limit: together with the "=" separator a
+ // reserved key then uses at most 5 bytes of each 255-byte record. The "_"
+ // prefix keeps them apart from user attributes, which may not start with "_".
+ attrServiceUuid = "_srv"
+ attrInterface = "_itf"
+ attrAddr = "_adr"
+ // TODO(jhahn): Remove attrEncryptionAlgorithm.
+ attrEncryptionAlgorithm = "_xxx"
+ // attrEncryptionKeys is emitted once per encryption key.
+ attrEncryptionKeys = "_key"
)
type plugin struct {
@@ -65,10 +60,12 @@
lastSubscription time.Time
}
-func (p *plugin) Advertise(ctx *context.T, ad *ldiscovery.Advertisement) error {
+func (p *plugin) Advertise(ctx *context.T, ad ldiscovery.Advertisement) error {
serviceName := ad.ServiceUuid.String() + serviceNameSuffix
- hostName := fmt.Sprintf("%x.%s%s", ad.InstanceUuid, ad.ServiceUuid.String(), hostNameSuffix)
- txt, err := createTXTRecords(ad)
+ // We use the instance uuid as the host name so that we can get the instance uuid
+ // from the lost service instance, which has no txt records at all.
+ hostName := hex.EncodeToString(ad.InstanceUuid)
+ txt, err := createTXTRecords(&ad)
if err != nil {
return err
}
@@ -85,14 +82,14 @@
return err
}
stop := func() {
- p.mdns.RemoveService(serviceName, hostName, 0)
- p.mdns.RemoveService(v23ServiceName, hostName, 0)
+ p.mdns.RemoveService(serviceName, hostName, 0, txt...)
+ p.mdns.RemoveService(v23ServiceName, hostName, 0, txt...)
}
p.adStopper.Add(stop, ctx.Done())
return nil
}
-func (p *plugin) Scan(ctx *context.T, serviceUuid uuid.UUID, scanCh chan<- *ldiscovery.Advertisement) error {
+func (p *plugin) Scan(ctx *context.T, serviceUuid uuid.UUID, ch chan<- ldiscovery.Advertisement) error {
var serviceName string
if len(serviceUuid) == 0 {
serviceName = v23ServiceName
@@ -143,7 +140,7 @@
continue
}
select {
- case scanCh <- ad:
+ case ch <- ad:
case <-ctx.Done():
return
}
@@ -155,38 +152,37 @@
+// createTXTRecords builds the "key=value" TXT records announced for ad: the
+// service UUID, the interface name, one record per attribute, one per address,
+// the encryption algorithm, and one per encryption key. The returned error is
+// currently always nil.
func createTXTRecords(ad *ldiscovery.Advertisement) ([]string, error) {
// Prepare a TXT record with attributes and addresses to announce.
//
- // TODO(jhahn): Currently, the record size is limited to 2000 bytes in
- // go-mdns-sd package. Think about how to handle a large TXT record size
- // exceeds the limit.
- txt := make([]string, 0, len(ad.Attrs)+len(ad.Addrs)+1)
+	// TODO(jhahn): Currently, the packet size is limited to 2000 bytes in
+	// go-mdns-sd package. Think about how to handle a large number of TXT
+	// records.
+	//
+	// Size the slice for every record appended below: three fixed records
+	// (service uuid, interface name, encryption algorithm) plus one per
+	// attribute, address and encryption key.
+	txt := make([]string, 0, 3+len(ad.Attrs)+len(ad.Addrs)+len(ad.EncryptionKeys))
+	txt = append(txt, fmt.Sprintf("%s=%s", attrServiceUuid, ad.ServiceUuid))
txt = append(txt, fmt.Sprintf("%s=%s", attrInterface, ad.InterfaceName))
for k, v := range ad.Attrs {
txt = append(txt, fmt.Sprintf("%s=%s", k, v))
}
- for _, addr := range ad.Addrs {
- txt = append(txt, fmt.Sprintf("%s=%s", attrAddr, addr))
+	for _, a := range ad.Addrs {
+		txt = append(txt, fmt.Sprintf("%s=%s", attrAddr, a))
+	}
+	txt = append(txt, fmt.Sprintf("%s=%d", attrEncryptionAlgorithm, ad.EncryptionAlgorithm))
+	for _, k := range ad.EncryptionKeys {
+		txt = append(txt, fmt.Sprintf("%s=%s", attrEncryptionKeys, k))
}
return txt, nil
}
-func decodeAdvertisement(service mdns.ServiceInstance) (*ldiscovery.Advertisement, error) {
- // Note that service.Name would be '<instance uuid>.<service uuid>._v23._tcp.local.' for
- // subtype service discovery and ''<instance uuid>.<service uuid>' for v23 service discovery.
- p := strings.SplitN(service.Name, ".", 3)
- if len(p) < 2 {
- return nil, fmt.Errorf("invalid host name: %s", service.Name)
+// decodeAdvertisement rebuilds an Advertisement from an mDNS service instance.
+// The instance uuid is the hex-encoded first label of service.Name; the service
+// uuid, interface name, addresses and encryption settings are read from TXT
+// records, and any other TXT key is kept as a service attribute.
+func decodeAdvertisement(service mdns.ServiceInstance) (ldiscovery.Advertisement, error) {
+ // Note that service.Name starts with a host name, which is the instance uuid.
+ p := strings.SplitN(service.Name, ".", 2)
+ // SplitN always returns at least one element, so check the host label itself;
+ // hex.DecodeString("") would otherwise yield an empty instance uuid without error.
+ if len(p[0]) == 0 {
+ return ldiscovery.Advertisement{}, fmt.Errorf("invalid host name: %s", service.Name)
}
instanceUuid, err := hex.DecodeString(p[0])
if err != nil {
- return nil, fmt.Errorf("invalid instance uuid in host name: %s", p[0])
- }
- serviceUuid := uuid.Parse(p[1])
- if len(serviceUuid) == 0 {
- return nil, fmt.Errorf("invalid service uuid in host name: %s", p[1])
+ return ldiscovery.Advertisement{}, fmt.Errorf("invalid host name: %v", err)
}
ad := ldiscovery.Advertisement{
- ServiceUuid: serviceUuid,
Service: discovery.Service{
InstanceUuid: instanceUuid,
Attrs: make(discovery.Attributes),
@@ -198,19 +194,26 @@
for _, txt := range rr.Txt {
kv := strings.SplitN(txt, "=", 2)
if len(kv) != 2 {
- return nil, fmt.Errorf("invalid txt record: %s", txt)
+ return ldiscovery.Advertisement{}, fmt.Errorf("invalid txt record: %s", txt)
}
switch k, v := kv[0], kv[1]; k {
+ case attrServiceUuid:
+ // uuid.Parse returns nil on malformed input; reject it rather than
+ // publishing an advertisement with an empty service uuid.
+ ad.ServiceUuid = uuid.Parse(v)
+ if len(ad.ServiceUuid) == 0 {
+ return ldiscovery.Advertisement{}, fmt.Errorf("invalid service uuid: %s", v)
+ }
case attrInterface:
ad.InterfaceName = v
case attrAddr:
ad.Addrs = append(ad.Addrs, v)
+ case attrEncryptionAlgorithm:
+ // Atoi returns 0 on error; without this check a malformed value would
+ // silently decode as EncryptionAlgorithm(0).
+ n, err := strconv.Atoi(v)
+ if err != nil {
+ return ldiscovery.Advertisement{}, fmt.Errorf("invalid encryption algorithm: %v", err)
+ }
+ ad.EncryptionAlgorithm = ldiscovery.EncryptionAlgorithm(n)
+ case attrEncryptionKeys:
+ ad.EncryptionKeys = append(ad.EncryptionKeys, ldiscovery.EncryptionKey(v))
default:
ad.Attrs[k] = v
}
}
}
- return &ad, nil
+ return ad, nil
}
func New(host string) (ldiscovery.Plugin, error) {
@@ -219,8 +222,8 @@
func newWithLoopback(host string, loopback bool) (ldiscovery.Plugin, error) {
if len(host) == 0 {
- // go-mdns-sd reannounce the services periodically only when the host name
- // is set. Use a default one if not given.
+ // go-mdns-sd doesn't answer when the host name is not set.
+ // Assign a default one if not given.
host = "v23()"
}
var v4addr, v6addr string
diff --git a/lib/discovery/plugins/mdns/mdns_test.go b/lib/discovery/plugins/mdns/mdns_test.go
index b6beeae..5b6a7c3 100644
--- a/lib/discovery/plugins/mdns/mdns_test.go
+++ b/lib/discovery/plugins/mdns/mdns_test.go
@@ -17,9 +17,114 @@
"v.io/v23/discovery"
ldiscovery "v.io/x/ref/lib/discovery"
+ _ "v.io/x/ref/runtime/factories/generic"
+ "v.io/x/ref/test"
)
+// encryptionKeys builds the fake key list that the tests pair with
+// ldiscovery.TestEncryption: a single key of the form "key:<hex of key>".
+func encryptionKeys(key []byte) []ldiscovery.EncryptionKey {
+ k := ldiscovery.EncryptionKey(fmt.Sprintf("key:%x", key))
+ return []ldiscovery.EncryptionKey{k}
+}
+
+// advertise advertises service on p under a cancelable child of ctx, using
+// ldiscovery.TestEncryption with keys derived from the instance uuid. The
+// returned function cancels that context, which stops the advertisement.
+func advertise(ctx *context.T, p ldiscovery.Plugin, service discovery.Service) (func(), error) {
+ ctx, stop := context.WithCancel(ctx)
+ ad := ldiscovery.Advertisement{
+ ServiceUuid: ldiscovery.NewServiceUUID(service.InterfaceName),
+ Service: service,
+ EncryptionAlgorithm: ldiscovery.TestEncryption,
+ EncryptionKeys: encryptionKeys(service.InstanceUuid),
+ }
+ if err := p.Advertise(ctx, ad); err != nil {
+ // The caller never receives stop on failure, so release the context here.
+ stop()
+ return nil, fmt.Errorf("Advertise failed: %v", err)
+ }
+ return stop, nil
+}
+
+// startScan starts a scan on p under a cancelable child of ctx. An empty
+// interfaceName scans for every service; otherwise only the service uuid
+// derived from interfaceName is requested. It returns the channel the
+// plugin delivers advertisements on and a function that stops the scan.
+func startScan(ctx *context.T, p ldiscovery.Plugin, interfaceName string) (<-chan ldiscovery.Advertisement, func(), error) {
+ ctx, stop := context.WithCancel(ctx)
+ scan := make(chan ldiscovery.Advertisement)
+ var serviceUuid uuid.UUID
+ if len(interfaceName) > 0 {
+ serviceUuid = ldiscovery.NewServiceUUID(interfaceName)
+ }
+ if err := p.Scan(ctx, serviceUuid, scan); err != nil {
+ // The caller never receives stop on failure, so release the context here.
+ stop()
+ return nil, nil, fmt.Errorf("Scan failed: %v", err)
+ }
+ return scan, stop, nil
+}
+
+// scan runs a fresh scan on p and gathers advertisements until 10ms pass
+// without a new one (the wait restarts after every advertisement), then
+// stops the scan and returns everything collected.
+func scan(ctx *context.T, p ldiscovery.Plugin, interfaceName string) ([]ldiscovery.Advertisement, error) {
+ ch, stop, err := startScan(ctx, p, interfaceName)
+ if err != nil {
+ return nil, err
+ }
+ defer stop()
+
+ const quiet = 10 * time.Millisecond
+ var got []ldiscovery.Advertisement
+ for {
+ select {
+ case <-time.After(quiet):
+ return got, nil
+ case ad := <-ch:
+ got = append(got, ad)
+ }
+ }
+}
+
+// match reports whether ads pairs up one-to-one with wants, leaving nothing
+// over. Advertisements are paired by instance uuid. When lost is true a pair
+// only requires the advertisement to be marked Lost; otherwise it must not be
+// Lost, must carry exactly want as its service, and must use TestEncryption
+// with the keys derived from the instance uuid.
+func match(ads []ldiscovery.Advertisement, lost bool, wants ...discovery.Service) bool {
+ // Work on a private copy: removing matched entries in place would shift the
+ // elements of the caller's slice, which scanAndMatch prints on failure.
+ ads = append([]ldiscovery.Advertisement(nil), ads...)
+ for _, want := range wants {
+ matched := false
+ for i, ad := range ads {
+ if !uuid.Equal(ad.InstanceUuid, want.InstanceUuid) {
+ continue
+ }
+ if lost {
+ matched = ad.Lost
+ } else {
+ matched = !ad.Lost && reflect.DeepEqual(ad.Service, want) && ad.EncryptionAlgorithm == ldiscovery.TestEncryption && reflect.DeepEqual(ad.EncryptionKeys, encryptionKeys(want.InstanceUuid))
+ }
+ if matched {
+ ads = append(ads[:i], ads[i+1:]...)
+ break
+ }
+ }
+ if !matched {
+ return false
+ }
+ }
+ return len(ads) == 0
+}
+
+// matchFound reports whether ads are exactly the live (not Lost)
+// advertisements of wants.
+func matchFound(ads []ldiscovery.Advertisement, wants ...discovery.Service) bool {
+ const lost = false
+ return match(ads, lost, wants...)
+}
+
+// matchLost reports whether ads are exactly Lost advertisements for wants.
+func matchLost(ads []ldiscovery.Advertisement, wants ...discovery.Service) bool {
+ const lost = true
+ return match(ads, lost, wants...)
+}
+
+// scanAndMatch rescans p for interfaceName until the scanned advertisements
+// satisfy matchFound for wants, and fails with the last scan result once
+// three seconds have elapsed.
+func scanAndMatch(ctx *context.T, p ldiscovery.Plugin, interfaceName string, wants ...discovery.Service) error {
+ const timeout = 3 * time.Second
+
+ deadline := time.Now().Add(timeout)
+ var ads []ldiscovery.Advertisement
+ for time.Now().Before(deadline) {
+ runtime.Gosched()
+
+ var err error
+ if ads, err = scan(ctx, p, interfaceName); err != nil {
+ return err
+ }
+ if matchFound(ads, wants...) {
+ return nil
+ }
+ }
+ return fmt.Errorf("Match failed; got %v, but wanted %v", ads, wants)
+}
+
+// TestBasic advertises services from one mDNS plugin and checks that a second
+// plugin discovers them by interface name, stops discovering them once their
+// advertising is stopped, and delivers a Lost update on a live scan.
func TestBasic(t *testing.T) {
+ ctx, shutdown := test.V23Init()
+ defer shutdown()
+
services := []discovery.Service{
{
InstanceUuid: ldiscovery.NewInstanceUUID(),
@@ -61,159 +166,66 @@
if err != nil {
t.Fatalf("New() failed: %v", err)
}
- p2, err := newWithLoopback("m2", true)
- if err != nil {
- t.Fatalf("New() failed: %v", err)
- }
var stops []func()
for _, service := range services {
- stop, err := advertise(p1, service)
+ stop, err := advertise(ctx, p1, service)
if err != nil {
t.Fatal(err)
}
stops = append(stops, stop)
}
+ p2, err := newWithLoopback("m2", true)
+ if err != nil {
+ t.Fatalf("New() failed: %v", err)
+ }
+
// Make sure all advertisements are discovered.
- if err := scanAndMatch(p2, "v.io/x", services[0], services[1]); err != nil {
+ if err := scanAndMatch(ctx, p2, "v.io/x", services[0], services[1]); err != nil {
t.Error(err)
}
- if err := scanAndMatch(p2, "v.io/y", services[2]); err != nil {
+ if err := scanAndMatch(ctx, p2, "v.io/y", services[2]); err != nil {
t.Error(err)
}
- if err := scanAndMatch(p2, "", services...); err != nil {
+ if err := scanAndMatch(ctx, p2, "", services...); err != nil {
t.Error(err)
}
- if err := scanAndMatch(p2, "v.io/z"); err != nil {
+ if err := scanAndMatch(ctx, p2, "v.io/z"); err != nil {
t.Error(err)
}
// Make sure it is not discovered when advertising is stopped.
stops[0]()
- if err := scanAndMatch(p2, "v.io/x", services[1]); err != nil {
+ if err := scanAndMatch(ctx, p2, "v.io/x", services[1]); err != nil {
t.Error(err)
}
- if err := scanAndMatch(p2, "", services[1], services[2]); err != nil {
+ if err := scanAndMatch(ctx, p2, "", services[1], services[2]); err != nil {
t.Error(err)
}
// Open a new scan channel and consume expected advertisements first.
- scan, scanStop, err := startScan(p2, "v.io/y")
+ scan, scanStop, err := startScan(ctx, p2, "v.io/y")
if err != nil {
- t.Error(err)
+ // Must stop here: deferring a nil scanStop panics and receiving from a
+ // nil scan channel blocks forever.
+ t.Fatal(err)
}
defer scanStop()
- ad := *<-scan
+ ad := <-scan
if !matchFound([]ldiscovery.Advertisement{ad}, services[2]) {
- t.Errorf("Unexpected scan: %v", ad)
+ t.Errorf("Unexpected scan: %v, but want %v", ad, services[2])
}
// Make sure scan returns the lost advertisement when advertising is stopped.
stops[2]()
- ad = *<-scan
+ ad = <-scan
if !matchLost([]ldiscovery.Advertisement{ad}, services[2]) {
- t.Errorf("Unexpected scan: %v", ad)
+ t.Errorf("Unexpected scan: %v, but want %v as lost", ad, services[2])
}
// Stop advertising the remaining one; Shouldn't discover anything.
stops[1]()
- if err := scanAndMatch(p2, ""); err != nil {
+ if err := scanAndMatch(ctx, p2, ""); err != nil {
t.Error(err)
}
}
-
-func advertise(p ldiscovery.Plugin, service discovery.Service) (func(), error) {
- ctx, cancel := context.RootContext()
- ad := ldiscovery.Advertisement{
- ServiceUuid: ldiscovery.NewServiceUUID(service.InterfaceName),
- Service: service,
- }
- if err := p.Advertise(ctx, &ad); err != nil {
- return nil, fmt.Errorf("Advertise failed: %v", err)
- }
- return cancel, nil
-}
-
-func startScan(p ldiscovery.Plugin, interfaceName string) (<-chan *ldiscovery.Advertisement, func(), error) {
- ctx, stop := context.RootContext()
- scan := make(chan *ldiscovery.Advertisement)
- var serviceUuid uuid.UUID
- if len(interfaceName) > 0 {
- serviceUuid = ldiscovery.NewServiceUUID(interfaceName)
- }
- if err := p.Scan(ctx, serviceUuid, scan); err != nil {
- return nil, nil, fmt.Errorf("Scan failed: %v", err)
- }
- return scan, stop, nil
-}
-
-func scan(p ldiscovery.Plugin, interfaceName string) ([]ldiscovery.Advertisement, error) {
- scan, stop, err := startScan(p, interfaceName)
- if err != nil {
- return nil, err
- }
- defer stop()
-
- var ads []ldiscovery.Advertisement
- for {
- select {
- case ad := <-scan:
- ads = append(ads, *ad)
- case <-time.After(10 * time.Millisecond):
- return ads, nil
- }
- }
-}
-
-func match(ads []ldiscovery.Advertisement, lost bool, wants ...discovery.Service) bool {
- for _, want := range wants {
- matched := false
- for i, ad := range ads {
- if !uuid.Equal(ad.ServiceUuid, ldiscovery.NewServiceUUID(want.InterfaceName)) {
- continue
- }
- if lost {
- matched = ad.Lost
- } else {
- matched = lost || reflect.DeepEqual(ad.Service, want)
- }
- if matched {
- ads = append(ads[:i], ads[i+1:]...)
- break
- }
- }
- if !matched {
- return false
- }
- }
- return len(ads) == 0
-}
-
-func matchFound(ads []ldiscovery.Advertisement, wants ...discovery.Service) bool {
- return match(ads, false, wants...)
-}
-
-func matchLost(ads []ldiscovery.Advertisement, wants ...discovery.Service) bool {
- return match(ads, true, wants...)
-}
-
-func scanAndMatch(p ldiscovery.Plugin, interfaceName string, wants ...discovery.Service) error {
- const timeout = 1 * time.Second
-
- var ads []ldiscovery.Advertisement
- for now := time.Now(); time.Since(now) < timeout; {
- runtime.Gosched()
-
- var err error
- ads, err = scan(p, interfaceName)
- if err != nil {
- return err
- }
- if matchFound(ads, wants...) {
- return nil
- }
- }
- return fmt.Errorf("Match failed; got %v, but wanted %v", ads, wants)
-}
diff --git a/lib/discovery/plugins/mock/mock.go b/lib/discovery/plugins/mock/mock.go
index a15fc8a..29e55eb 100644
--- a/lib/discovery/plugins/mock/mock.go
+++ b/lib/discovery/plugins/mock/mock.go
@@ -17,12 +17,12 @@
+// plugin is an in-memory discovery.Plugin: advertisements are held in the
+// services map instead of being published through a real transport.
type plugin struct {
mu sync.Mutex
- services map[string][]*discovery.Advertisement // GUARDED_BY(mu)
+ services map[string][]discovery.Advertisement // keyed by string(ServiceUuid); GUARDED_BY(mu)
updated *sync.Cond
}
-func (p *plugin) Advertise(ctx *context.T, ad *discovery.Advertisement) error {
+func (p *plugin) Advertise(ctx *context.T, ad discovery.Advertisement) error {
p.mu.Lock()
key := string(ad.ServiceUuid)
ads := p.services[key]
@@ -52,7 +52,7 @@
return nil
}
-func (p *plugin) Scan(ctx *context.T, serviceUuid uuid.UUID, scanCh chan<- *discovery.Advertisement) error {
+func (p *plugin) Scan(ctx *context.T, serviceUuid uuid.UUID, ch chan<- discovery.Advertisement) error {
rescan := make(chan struct{})
go func() {
for {
@@ -68,10 +68,10 @@
}()
go func() {
- scanned := make(map[string]*discovery.Advertisement)
+ scanned := make(map[string]discovery.Advertisement)
for {
- current := make(map[string]*discovery.Advertisement)
+ current := make(map[string]discovery.Advertisement)
p.mu.Lock()
for key, ads := range p.services {
if len(serviceUuid) > 0 && key != string(serviceUuid) {
@@ -83,7 +83,7 @@
}
p.mu.Unlock()
- changed := make([]*discovery.Advertisement, 0, len(current))
+ changed := make([]discovery.Advertisement, 0, len(current))
for key, ad := range current {
old, ok := scanned[key]
if !ok || !reflect.DeepEqual(old, ad) {
@@ -100,7 +100,7 @@
// Push new changes.
for _, ad := range changed {
select {
- case scanCh <- ad:
+ case ch <- ad:
case <-ctx.Done():
return
}
@@ -121,7 +121,7 @@
+// New returns a discovery.Plugin that keeps advertisements in memory.
func New() discovery.Plugin {
return &plugin{
- services: make(map[string][]*discovery.Advertisement),
+ services: make(map[string][]discovery.Advertisement),
updated: sync.NewCond(&sync.Mutex{}),
}
}
diff --git a/lib/discovery/scan.go b/lib/discovery/scan.go
index 634e8df..94da490 100644
--- a/lib/discovery/scan.go
+++ b/lib/discovery/scan.go
@@ -7,6 +7,7 @@
import (
"github.com/pborman/uuid"
+ "v.io/v23"
"v.io/v23/context"
"v.io/v23/discovery"
)
@@ -19,7 +20,7 @@
serviceUuid = NewServiceUUID(query)
}
// TODO(jhahn): Revisit the buffer size.
- scanCh := make(chan *Advertisement, 10)
+ scanCh := make(chan Advertisement, 10)
ctx, cancel := context.WithCancel(ctx)
for _, plugin := range ds.plugins {
err := plugin.Scan(ctx, serviceUuid, scanCh)
@@ -34,19 +35,48 @@
return updateCh, nil
}
-func doScan(ctx *context.T, scanCh <-chan *Advertisement, updateCh chan<- discovery.Update) {
+func doScan(ctx *context.T, scanCh <-chan Advertisement, updateCh chan<- discovery.Update) {
defer close(updateCh)
+
+ // Get the blessing names belong to the principal.
+ //
+ // TODO(jhahn): It isn't clear that we will always have the blessing required to decrypt
+ // the advertisement as their "default" blessing - indeed it may not even be in the store.
+ // Revisit this issue.
+ principal := v23.GetPrincipal(ctx)
+ var names []string
+ if principal != nil {
+ blessings := principal.BlessingStore().Default()
+ for n, _ := range principal.BlessingsInfo(blessings) {
+ names = append(names, n)
+ }
+ }
+
+ // A plugin may returns a Lost event with clearing all attributes including encryption
+ // keys. Thus, we have to keep what we've found so far so that we can ignore the Lost
+ // events for instances that we ignored due to permission.
+ found := make(map[string]struct{})
for {
select {
case ad := <-scanCh:
- // TODO(jhahn): Merge scanData based on InstanceUuid.
- var update discovery.Update
- if ad.Lost {
- update = discovery.UpdateLost{discovery.Lost{Service: ad.Service}}
- } else {
- update = discovery.UpdateFound{discovery.Found{Service: ad.Service}}
+ if err := decrypt(&ad, names); err != nil {
+ // Couldn't decrypt it. Ignore it.
+ if err != errNoPermission {
+ ctx.Error(err)
+ }
+ continue
}
- updateCh <- update
+ id := string(ad.InstanceUuid)
+ // TODO(jhahn): Merge scanData based on InstanceUuid.
+ if ad.Lost {
+ if _, ok := found[id]; ok {
+ delete(found, id)
+ updateCh <- discovery.UpdateLost{discovery.Lost{Service: ad.Service}}
+ }
+ } else {
+ found[id] = struct{}{}
+ updateCh <- discovery.UpdateFound{discovery.Found{Service: ad.Service}}
+ }
case <-ctx.Done():
return
}
diff --git a/runtime/internal/flow/conn/auth.go b/runtime/internal/flow/conn/auth.go
index dc3ecd4..02bcc8c 100644
--- a/runtime/internal/flow/conn/auth.go
+++ b/runtime/internal/flow/conn/auth.go
@@ -150,7 +150,6 @@
return err
}
}
- var rPublicKey security.PublicKey
if rauth.BlessingsKey != 0 {
var err error
// TODO(mattr): Make sure we cancel out of this at some point.
@@ -158,14 +157,14 @@
if err != nil {
return err
}
- rPublicKey = c.rBlessings.PublicKey()
+ c.rPublicKey = c.rBlessings.PublicKey()
} else {
- rPublicKey = rauth.PublicKey
+ c.rPublicKey = rauth.PublicKey
}
- if rPublicKey == nil {
+ if c.rPublicKey == nil {
return NewErrNoPublicKey(ctx)
}
- if !rauth.ChannelBinding.Verify(rPublicKey, binding) {
+ if !rauth.ChannelBinding.Verify(c.rPublicKey, binding) {
return NewErrInvalidChannelBinding(ctx)
}
return nil
diff --git a/runtime/internal/flow/conn/conn.go b/runtime/internal/flow/conn/conn.go
index 5a6e248..17475a5 100644
--- a/runtime/internal/flow/conn/conn.go
+++ b/runtime/internal/flow/conn/conn.go
@@ -67,6 +67,7 @@
mp *messagePipe
version version.RPCVersion
lBlessings, rBlessings security.Blessings
+ rPublicKey security.PublicKey
local, remote naming.Endpoint
closed chan struct{}
blessingsFlow *blessingsFlow
@@ -361,6 +362,14 @@
c.borrowing[msg.ID] = true
c.mu.Unlock()
+ rBlessings, _, err := c.blessingsFlow.get(ctx, msg.BlessingsKey, msg.DischargeKey)
+ if err != nil {
+ return err
+ }
+ if !reflect.DeepEqual(rBlessings.PublicKey(), c.rPublicKey) {
+ return NewErrBlessingsNotBound(ctx)
+ }
+
handler.HandleFlow(f)
if err := f.q.put(ctx, msg.Payload); err != nil {
return err
diff --git a/runtime/internal/flow/conn/errors.vdl b/runtime/internal/flow/conn/errors.vdl
index 42d092f..42b56bf 100644
--- a/runtime/internal/flow/conn/errors.vdl
+++ b/runtime/internal/flow/conn/errors.vdl
@@ -25,4 +25,5 @@
DialingNonServer() {"en": "You are attempting to dial on a connection with no remote server."}
AcceptorBlessingsMissing() {"en": "The acceptor did not send blessings."}
UpdatingNilFlowHandler() {"en": "nil flowHandler cannot be updated to non-nil value."}
+ BlessingsNotBound() {"en": "blessings not bound to connection remote public key"}
)
diff --git a/runtime/internal/flow/conn/errors.vdl.go b/runtime/internal/flow/conn/errors.vdl.go
index ea7a30e..506f283 100644
--- a/runtime/internal/flow/conn/errors.vdl.go
+++ b/runtime/internal/flow/conn/errors.vdl.go
@@ -29,6 +29,7 @@
ErrDialingNonServer = verror.Register("v.io/x/ref/runtime/internal/flow/conn.DialingNonServer", verror.NoRetry, "{1:}{2:} You are attempting to dial on a connection with no remote server.")
ErrAcceptorBlessingsMissing = verror.Register("v.io/x/ref/runtime/internal/flow/conn.AcceptorBlessingsMissing", verror.NoRetry, "{1:}{2:} The acceptor did not send blessings.")
ErrUpdatingNilFlowHandler = verror.Register("v.io/x/ref/runtime/internal/flow/conn.UpdatingNilFlowHandler", verror.NoRetry, "{1:}{2:} nil flowHandler cannot be updated to non-nil value.")
+ ErrBlessingsNotBound = verror.Register("v.io/x/ref/runtime/internal/flow/conn.BlessingsNotBound", verror.NoRetry, "{1:}{2:} blessings not bound to connection remote public key")
)
func init() {
@@ -46,6 +47,7 @@
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrDialingNonServer.ID), "{1:}{2:} You are attempting to dial on a connection with no remote server.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrAcceptorBlessingsMissing.ID), "{1:}{2:} The acceptor did not send blessings.")
i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrUpdatingNilFlowHandler.ID), "{1:}{2:} nil flowHandler cannot be updated to non-nil value.")
+ i18n.Cat().SetWithBase(i18n.LangID("en"), i18n.MsgID(ErrBlessingsNotBound.ID), "{1:}{2:} blessings not bound to connection remote public key")
}
// NewErrMissingSetupOption returns an error with the ErrMissingSetupOption ID.
@@ -117,3 +119,8 @@
func NewErrUpdatingNilFlowHandler(ctx *context.T) error {
return verror.New(ErrUpdatingNilFlowHandler, ctx)
}
+
+// NewErrBlessingsNotBound returns an error with the ErrBlessingsNotBound ID.
+func NewErrBlessingsNotBound(ctx *context.T) error {
+ return verror.New(ErrBlessingsNotBound, ctx)
+}
diff --git a/runtime/internal/flow/manager/conncache.go b/runtime/internal/flow/manager/conncache.go
index b3975e2..257b681 100644
--- a/runtime/internal/flow/manager/conncache.go
+++ b/runtime/internal/flow/manager/conncache.go
@@ -19,11 +19,12 @@
// Multiple goroutines can invoke methods on the ConnCache simultaneously.
// TODO(suharshs): We should periodically look for closed connections and remove them.
type ConnCache struct {
- mu *sync.Mutex
- addrCache map[string]*connEntry // keyed by (protocol, address, blessingNames)
- ridCache map[naming.RoutingID]*connEntry // keyed by naming.RoutingID
- started map[string]bool // keyed by (protocol, address, blessingNames)
- cond *sync.Cond
+ mu *sync.Mutex // held while the maps below are read or modified (see Close)
+ cond *sync.Cond // created with mu as its Locker
+ addrCache map[string]*connEntry // keyed by (protocol, address, blessingNames)
+ ridCache map[naming.RoutingID]*connEntry // keyed by naming.RoutingID
+ started map[string]bool // keyed by (protocol, address, blessingNames)
+ // unmappedConns is a set (values are always true) of entries that were
+ // displaced from ridCache by a newer entry with the same RoutingID. They
+ // are tracked so Close still closes them and KillConnections can still
+ // choose them for eviction.
+ unmappedConns map[*connEntry]bool
}
type connEntry struct {
@@ -36,11 +37,12 @@
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
return &ConnCache{
- mu: mu,
- addrCache: make(map[string]*connEntry),
- ridCache: make(map[naming.RoutingID]*connEntry),
- started: make(map[string]bool),
- cond: cond,
+ mu: mu,
+ cond: cond,
+ addrCache: make(map[string]*connEntry),
+ ridCache: make(map[naming.RoutingID]*connEntry),
+ started: make(map[string]bool),
+ unmappedConns: make(map[*connEntry]bool),
}
}
@@ -101,6 +103,9 @@
rid: ep.RoutingID(),
addrKey: k,
}
+ if old := c.ridCache[entry.rid]; old != nil {
+ c.unmappedConns[old] = true
+ }
c.addrCache[k] = entry
c.ridCache[entry.rid] = entry
return nil
@@ -117,6 +122,9 @@
conn: conn,
rid: conn.RemoteEndpoint().RoutingID(),
}
+ if old := c.ridCache[entry.rid]; old != nil {
+ c.unmappedConns[old] = true
+ }
c.ridCache[entry.rid] = entry
return nil
}
@@ -126,8 +134,12 @@
defer c.mu.Unlock()
c.mu.Lock()
c.addrCache, c.started = nil, nil
+ err := NewErrCacheClosed(ctx)
for _, d := range c.ridCache {
- d.conn.Close(ctx, NewErrCacheClosed(ctx))
+ d.conn.Close(ctx, err)
+ }
+ for d := range c.unmappedConns {
+ d.conn.Close(ctx, err)
}
}
@@ -157,12 +169,23 @@
}
pq = append(pq, e)
}
+ for d := range c.unmappedConns {
+ if isClosed(d.conn) {
+ delete(c.unmappedConns, d)
+ continue
+ }
+ if d.conn.IsEncapsulated() {
+ continue
+ }
+ pq = append(pq, d)
+ }
sort.Sort(pq)
for i := 0; i < num; i++ {
d := pq[i]
d.conn.Close(ctx, err)
delete(c.addrCache, d.addrKey)
delete(c.ridCache, d.rid)
+ delete(c.unmappedConns, d)
}
return nil
}
diff --git a/runtime/internal/flow/manager/conncache_test.go b/runtime/internal/flow/manager/conncache_test.go
index 8d02e9a..c66afa1 100644
--- a/runtime/internal/flow/manager/conncache_test.go
+++ b/runtime/internal/flow/manager/conncache_test.go
@@ -114,17 +114,27 @@
t.Errorf("got %v, want %v", cachedConn, otherConn)
}
+ // Insert a duplicate conn to ensure that replaced conns still get closed.
+ dupConn := makeConnAndFlow(t, ctx, remote).c
+ if err := c.Insert(dupConn); err != nil {
+ t.Fatal(err)
+ }
+
// Closing the cache should close all the connections in the cache.
// Ensure that the conns are not closed yet.
if isClosed(conn) {
- t.Fatalf("wanted conn to not be closed")
+ t.Fatal("wanted conn to not be closed")
+ }
+ if isClosed(dupConn) {
+ t.Fatal("wanted dupConn to not be closed")
}
if isClosed(otherConn) {
- t.Fatalf("wanted otherConn to not be closed")
+ t.Fatal("wanted otherConn to not be closed")
}
c.Close(ctx)
// Now the connections should be closed.
<-conn.Closed()
+ <-dupConn.Closed()
<-otherConn.Closed()
}