Merge "discovery: change Advertise() to return done channel"
diff --git a/lib/discovery/advertise.go b/lib/discovery/advertise.go
index 1d8a97e..c53d3c4 100644
--- a/lib/discovery/advertise.go
+++ b/lib/discovery/advertise.go
@@ -12,22 +12,23 @@
)
var (
- errNoInterfaceName = verror.Register(pkgPath+".errNoInterfaceName", verror.NoRetry, "{1:}{2:} interface name not provided")
- errNotPackableAttributes = verror.Register(pkgPath+".errNotPackableAttributes", verror.NoRetry, "{1:}{2:} attribute not packable")
- errNoAddresses = verror.Register(pkgPath+".errNoAddress", verror.NoRetry, "{1:}{2:} address not provided")
- errNotPackableAddresses = verror.Register(pkgPath+".errNotPackableAddresses", verror.NoRetry, "{1:}{2:} address not packable")
+ errAlreadyBeingAdvertised = verror.Register(pkgPath+".errAlreadyBeingAdvertised", verror.NoRetry, "{1:}{2:} already being advertised")
+ errNoInterfaceName = verror.Register(pkgPath+".errNoInterfaceName", verror.NoRetry, "{1:}{2:} interface name not provided")
+ errNotPackableAttributes = verror.Register(pkgPath+".errNotPackableAttributes", verror.NoRetry, "{1:}{2:} attribute not packable")
+ errNoAddresses = verror.Register(pkgPath+".errNoAddress", verror.NoRetry, "{1:}{2:} address not provided")
+ errNotPackableAddresses = verror.Register(pkgPath+".errNotPackableAddresses", verror.NoRetry, "{1:}{2:} address not packable")
)
// Advertise implements discovery.Advertiser.
-func (ds *ds) Advertise(ctx *context.T, service discovery.Service, visibility []security.BlessingPattern) error {
+func (ds *ds) Advertise(ctx *context.T, service discovery.Service, visibility []security.BlessingPattern) (<-chan struct{}, error) {
if len(service.InterfaceName) == 0 {
- return verror.New(errNoInterfaceName, ctx)
+ return nil, verror.New(errNoInterfaceName, ctx)
}
if len(service.Addrs) == 0 {
- return verror.New(errNoAddresses, ctx)
+ return nil, verror.New(errNoAddresses, ctx)
}
if err := validateAttributes(service.Attrs); err != nil {
- return err
+ return nil, err
}
if len(service.InstanceUuid) == 0 {
@@ -39,21 +40,49 @@
Service: service,
}
if err := encrypt(&ad, visibility); err != nil {
- return err
+ return nil, err
}
- adId := string(ad.Service.InstanceUuid)
- ctx, cancel, err := ds.addTask(ctx, adId)
+ ctx, cancel, err := ds.addTask(ctx)
if err != nil {
- return err
+ return nil, err
}
- barrier := NewBarrier(func() { ds.removeTask(ctx, adId) })
+ id := string(ad.Service.InstanceUuid)
+ if !ds.addAd(id) {
+ cancel()
+ ds.removeTask(ctx)
+ return nil, verror.New(errAlreadyBeingAdvertised, ctx)
+ }
+
+ done := make(chan struct{})
+ barrier := NewBarrier(func() {
+ ds.removeAd(id)
+ ds.removeTask(ctx)
+ close(done)
+ })
for _, plugin := range ds.plugins {
if err := plugin.Advertise(ctx, ad, barrier.Add()); err != nil {
cancel()
- return err
+ return nil, err
}
}
- return nil
+ return done, nil
+}
+
+// addAd records id in ds.ads as currently being advertised while holding
+// ds.mu. It reports false, leaving the set untouched, when id is already
+// present, and true once id has been inserted.
+func (ds *ds) addAd(id string) bool {
+	ds.mu.Lock()
+	defer ds.mu.Unlock()
+	if _, dup := ds.ads[id]; dup {
+		return false
+	}
+	ds.ads[id] = struct{}{}
+	return true
+}
+
+// removeAd drops id from ds.ads while holding ds.mu. Removing an id that is
+// not in the set is a no-op.
+func (ds *ds) removeAd(id string) {
+	ds.mu.Lock()
+	defer ds.mu.Unlock()
+	delete(ds.ads, id)
}
diff --git a/lib/discovery/discovery.go b/lib/discovery/discovery.go
index cee3bb7..78597b7 100644
--- a/lib/discovery/discovery.go
+++ b/lib/discovery/discovery.go
@@ -16,20 +16,19 @@
var (
- errClosed = verror.Register(pkgPath+".errClosed", verror.NoRetry, "{1:}{2:} closed")
- errAlreadyBeingAdvertised = verror.Register(pkgPath+".errAlreadyBeingAdvertised", verror.NoRetry, "{1:}{2:} already being advertised")
+ errDiscoveryClosed = verror.Register(pkgPath+".errDiscoveryClosed", verror.NoRetry, "{1:}{2:} discovery closed")
)
// ds is an implementation of discovery.T.
type ds struct {
plugins []Plugin
- mu sync.Mutex
- closed bool // GUARDED_BY(mu)
- tasks map[*context.T]func() // GUARDED_BY(mu)
- advertising map[string]struct{} // GUARDED_BY(mu)
+ mu sync.Mutex
+ closed bool // GUARDED_BY(mu)
+ tasks map[*context.T]func() // GUARDED_BY(mu)
+ wg sync.WaitGroup
- wg sync.WaitGroup
+ ads map[string]struct{} // GUARDED_BY(mu)
}
func (ds *ds) Close() {
@@ -46,18 +45,11 @@
ds.wg.Wait()
}
-func (ds *ds) addTask(ctx *context.T, adId string) (*context.T, func(), error) {
+func (ds *ds) addTask(ctx *context.T) (*context.T, func(), error) {
ds.mu.Lock()
if ds.closed {
ds.mu.Unlock()
- return nil, nil, verror.New(errClosed, ctx)
- }
- if len(adId) > 0 {
- if _, exist := ds.advertising[adId]; exist {
- ds.mu.Unlock()
- return nil, nil, verror.New(errAlreadyBeingAdvertised, ctx)
- }
- ds.advertising[adId] = struct{}{}
+ return nil, nil, verror.New(errDiscoveryClosed, ctx)
}
ctx, cancel := context.WithCancel(ctx)
ds.tasks[ctx] = cancel
@@ -66,27 +58,26 @@
return ctx, cancel, nil
}
-func (ds *ds) removeTask(ctx *context.T, adId string) {
+func (ds *ds) removeTask(ctx *context.T) {
ds.mu.Lock()
- if len(adId) > 0 {
- delete(ds.advertising, adId)
- }
- _, exist := ds.tasks[ctx]
- delete(ds.tasks, ctx)
- ds.mu.Unlock()
- if exist {
+ if _, exist := ds.tasks[ctx]; exist {
+ delete(ds.tasks, ctx)
ds.wg.Done()
}
+ ds.mu.Unlock()
}
// NewWithPlugins returns a new Discovery instance initialized with the given plugins.
//
// Mostly for internal use. Consider using factory.New instead.
func NewWithPlugins(plugins []Plugin) discovery.T {
+ if len(plugins) == 0 {
+ panic("no plugins")
+ }
ds := &ds{
- plugins: make([]Plugin, len(plugins)),
- tasks: make(map[*context.T]func()),
- advertising: make(map[string]struct{}),
+ plugins: make([]Plugin, len(plugins)),
+ tasks: make(map[*context.T]func()),
+ ads: make(map[string]struct{}),
}
copy(ds.plugins, plugins)
return ds
diff --git a/lib/discovery/discovery_test.go b/lib/discovery/discovery_test.go
index 34b079f..28984cd 100644
--- a/lib/discovery/discovery_test.go
+++ b/lib/discovery/discovery_test.go
@@ -9,6 +9,7 @@
"fmt"
"reflect"
"runtime"
+ "sync"
"testing"
"time"
@@ -25,11 +26,21 @@
)
func advertise(ctx *context.T, ds discovery.Advertiser, perms []security.BlessingPattern, services ...discovery.Service) (func(), error) {
- ctx, stop := context.WithCancel(ctx)
+ var wg sync.WaitGroup
+ tr := idiscovery.NewTrigger()
+ ctx, cancel := context.WithCancel(ctx)
for _, service := range services {
- if err := ds.Advertise(ctx, service, perms); err != nil {
+ wg.Add(1)
+ done, err := ds.Advertise(ctx, service, perms)
+ if err != nil {
+ cancel()
return nil, fmt.Errorf("Advertise failed: %v", err)
}
+ tr.Add(wg.Done, done)
+ }
+ stop := func() {
+ cancel()
+ wg.Wait()
}
return stop, nil
}
diff --git a/lib/discovery/factory/lazy.go b/lib/discovery/factory/lazy.go
index 05af440..5622562 100644
--- a/lib/discovery/factory/lazy.go
+++ b/lib/discovery/factory/lazy.go
@@ -27,10 +27,10 @@
derr error
}
-func (l *lazyFactory) Advertise(ctx *context.T, service discovery.Service, visibility []security.BlessingPattern) error {
+func (l *lazyFactory) Advertise(ctx *context.T, service discovery.Service, visibility []security.BlessingPattern) (<-chan struct{}, error) {
l.once.Do(l.init)
if l.derr != nil {
- return l.derr
+ return nil, l.derr
}
return l.d.Advertise(ctx, service, visibility)
}
diff --git a/lib/discovery/factory/lazy_test.go b/lib/discovery/factory/lazy_test.go
index f97cb9f..003ec45 100644
--- a/lib/discovery/factory/lazy_test.go
+++ b/lib/discovery/factory/lazy_test.go
@@ -27,9 +27,9 @@
return m, nil
}
-func (m *mock) Advertise(_ *context.T, _ discovery.Service, _ []security.BlessingPattern) error {
+func (m *mock) Advertise(_ *context.T, _ discovery.Service, _ []security.BlessingPattern) (<-chan struct{}, error) {
m.numAdvertises++
- return nil
+ return nil, nil
}
func (m *mock) Scan(_ *context.T, _ string) (<-chan discovery.Update, error) {
@@ -91,7 +91,7 @@
}
// Closed already; Shouldn't initialize it again.
- if err := d.Advertise(nil, discovery.Service{}, nil); err != errClosed {
+ if _, err := d.Advertise(nil, discovery.Service{}, nil); err != errClosed {
t.Errorf("expected an error %v, but got %v", errClosed, err)
}
if err := m.check(0, 0, 0, 0); err != nil {
@@ -111,7 +111,7 @@
m := mock{initErr: errInit}
d := newLazyFactory(m.init)
- if err := d.Advertise(nil, discovery.Service{}, nil); err != errInit {
+ if _, err := d.Advertise(nil, discovery.Service{}, nil); err != errInit {
t.Errorf("expected an error %v, but got %v", errInit, err)
}
if err := m.check(1, 0, 0, 0); err != nil {
diff --git a/lib/discovery/scan.go b/lib/discovery/scan.go
index 729f223..1f5953f 100644
--- a/lib/discovery/scan.go
+++ b/lib/discovery/scan.go
@@ -19,7 +19,7 @@
serviceUuid = NewServiceUUID(query)
}
- ctx, cancel, err := ds.addTask(ctx, "")
+ ctx, cancel, err := ds.addTask(ctx)
if err != nil {
return nil, err
}
@@ -28,7 +28,7 @@
scanCh := make(chan Advertisement, 10)
barrier := NewBarrier(func() {
close(scanCh)
- ds.removeTask(ctx, "")
+ ds.removeTask(ctx)
})
for _, plugin := range ds.plugins {
if err := plugin.Scan(ctx, serviceUuid, scanCh, barrier.Add()); err != nil {
diff --git a/lib/discovery/trigger_test.go b/lib/discovery/trigger_test.go
index da2db00..dc5a2b0 100644
--- a/lib/discovery/trigger_test.go
+++ b/lib/discovery/trigger_test.go
@@ -42,4 +42,11 @@
if got, want := <-done, 0; got != want {
t.Errorf("Trigger failed; got %v, but wanted %v", got, want)
}
+
+ // Make sure the callback is triggered even when it is added with a closed channel.
+ close(c0)
+ tr.Add(f0, c0)
+ if got, want := <-done, 0; got != want {
+ t.Errorf("Trigger failed; got %v, but wanted %v", got, want)
+ }
}
diff --git a/services/discovery/service.go b/services/discovery/service.go
index 5e75c94..079e4e6 100644
--- a/services/discovery/service.go
+++ b/services/discovery/service.go
@@ -37,7 +37,8 @@
func (s *impl) RegisterService(ctx *context.T, call rpc.ServerCall, service discovery.Service, visibility []security.BlessingPattern) (sdiscovery.ServiceHandle, error) {
ctx, cancel := context.WithCancel(s.ctx)
- if err := s.d.Advertise(ctx, service, visibility); err != nil {
+ done, err := s.d.Advertise(ctx, service, visibility)
+ if err != nil {
cancel()
return 0, err
}
@@ -57,7 +58,10 @@
break
}
}
- s.handles[handle] = cancel
+ s.handles[handle] = func() {
+ cancel()
+ <-done
+ }
s.lastHandle = handle
s.mu.Unlock()
return handle, nil
@@ -65,11 +69,11 @@
func (s *impl) UnregisterService(ctx *context.T, call rpc.ServerCall, handle sdiscovery.ServiceHandle) error {
s.mu.Lock()
- cancel := s.handles[handle]
+ stop := s.handles[handle]
delete(s.handles, handle)
s.mu.Unlock()
- if cancel != nil {
- cancel()
+ if stop != nil {
+ stop()
}
return nil
}
diff --git a/services/syncbase/vsync/sync.go b/services/syncbase/vsync/sync.go
index c2ccf69..e203fdd 100644
--- a/services/syncbase/vsync/sync.go
+++ b/services/syncbase/vsync/sync.go
@@ -352,7 +352,7 @@
}
// Duplicate calls to advertise will return an error.
- err := advertiser.Advertise(ctx, sbService, nil)
+ _, err := advertiser.Advertise(ctx, sbService, nil)
if err == nil {
s.advCancel = stop
}