Merge "services/mounttable: Resolve vanadium/issues#583"
diff --git a/services/mounttable/mounttablelib/mounttable.go b/services/mounttable/mounttablelib/mounttable.go
index 1cc9175..666a355 100644
--- a/services/mounttable/mounttablelib/mounttable.go
+++ b/services/mounttable/mounttablelib/mounttable.go
@@ -646,6 +646,11 @@
// globStep is called with n and n.parent locked. Returns with both unlocked.
func (mt *mountTable) globStep(ctx *context.T, call security.Call, n *node, name string, pattern *glob.Glob, ch chan<- naming.GlobReply) {
+ if shouldAbort(ctx) {
+ n.parent.Unlock()
+ n.Unlock()
+ return
+ }
ctx.VI(2).Infof("globStep(%s, %s)", name, pattern)
// Globing is the lowest priority so we give up the cpu often.
@@ -704,6 +709,10 @@
// Recurse through the children.
matcher, suffix := pattern.Head(), pattern.Tail()
for k, c := range children {
+ if shouldAbort(ctx) {
+ n.Unlock()
+ return
+ }
// At this point, n lock is held.
if matcher.Match(k) {
c.Lock()
@@ -880,3 +889,14 @@
}
return creator, nil
}
+
+// shouldAbort reports whether ctx is already done. It never blocks: when
+// ctx.Done() is not ready to be received from, it returns false right away.
+func shouldAbort(ctx *context.T) bool {
+	select {
+	case <-ctx.Done():
+		return true
+	default:
+	}
+	return false
+}
diff --git a/services/mounttable/mounttablelib/mounttable_test.go b/services/mounttable/mounttablelib/mounttable_test.go
index e54e977..9c5b9bf 100644
--- a/services/mounttable/mounttablelib/mounttable_test.go
+++ b/services/mounttable/mounttablelib/mounttable_test.go
@@ -23,6 +23,7 @@
"v.io/v23/rpc"
"v.io/v23/security"
"v.io/v23/security/access"
+ "v.io/v23/services/mounttable"
"v.io/v23/services/stats"
"v.io/v23/vdl"
@@ -435,6 +436,97 @@
}
}
+// fakeServerCall is a minimal stand-in for the server call argument taken by
+// the stub methods this test invokes directly. Security returns a call built
+// from empty CallParams; every other accessor returns a zero value.
+type fakeServerCall struct{}
+
+func (fakeServerCall) Security() security.Call { return security.NewCall(&security.CallParams{}) }
+func (fakeServerCall) Suffix() string { return "" }
+func (fakeServerCall) LocalEndpoint() naming.Endpoint { return nil }
+func (fakeServerCall) RemoteEndpoint() naming.Endpoint { return nil }
+func (fakeServerCall) GrantedBlessings() security.Blessings { return security.Blessings{} }
+func (fakeServerCall) Server() rpc.Server { return nil }
+
+// TestGlobAborts checks that a "..." glob from the root returns every entry
+// when run with a live context, and returns no entries at all when the
+// context was cancelled before the glob was issued.
+func TestGlobAborts(t *testing.T) {
+	ctx, shutdown := test.V23Init()
+	defer shutdown()
+
+	mt, err := mounttablelib.NewMountTableDispatcher(ctx, "", "", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// mount looks up name in the dispatcher and calls Mount on the result.
+	mount := func(name string) error {
+		invoker, _, err := mt.Lookup(name)
+		if err != nil {
+			return fmt.Errorf("Lookup(%q): %v", name, err)
+		}
+		stub, ok := invoker.(mounttable.MountTableServerStub)
+		if !ok {
+			return fmt.Errorf("Lookup(%q) returned %T, not a mounttable.MountTableServerStub", name, invoker)
+		}
+		server := naming.FormatEndpoint("tcp", name)
+		return stub.Mount(ctx, fakeServerCall{}, server, 0, 0)
+	}
+	// Mount 125 entries: 5 "directories" with 25 entries each.
+	for i := 0; i < 5; i++ {
+		for j := 0; j < 25; j++ {
+			if err := mount(fmt.Sprintf("%d/%d", i, j)); err != nil {
+				t.Fatalf("%v (%d, %d)", err, i, j)
+			}
+		}
+	}
+
+	// glob issues a "..." glob at the root and counts the replies received
+	// until the reply channel is closed.
+	glob := func(ctx *context.T) (int, error) {
+		root, _, err := mt.Lookup("")
+		if err != nil {
+			return 0, fmt.Errorf("Lookup(%q): %v", "", err)
+		}
+		globber, ok := root.(rpc.Globber)
+		if !ok {
+			return 0, fmt.Errorf("root invoker is a %T, not an rpc.Globber", root)
+		}
+		state := globber.Globber()
+		if state == nil || state.AllGlobber == nil {
+			return 0, fmt.Errorf("root invoker %T provides no AllGlobber", root)
+		}
+		ch, err := state.AllGlobber.Glob__(ctx, fakeServerCall{}, "...")
+		if err != nil {
+			return 0, err
+		}
+		num := 0
+		for range ch {
+			num++
+		}
+		return num, nil
+	}
+
+	got, err := glob(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if want := 5 + 125 + 1; got != want { // 5 "directories", 125 entries, 1 root entry
+		t.Errorf("Got %d want %d", got, want)
+	}
+
+	// Cancel before globbing: the glob is expected to produce no replies.
+	canceled, cancel := context.WithCancel(ctx)
+	cancel()
+	if got, err = glob(canceled); err != nil {
+		t.Fatal(err)
+	}
+	if got != 0 {
+		t.Errorf("Glob returned entries even though the context was cancelled first (returned %d)", got)
+	}
+}
+
func TestAccessListTemplate(t *testing.T) {
rootCtx, aliceCtx, bobCtx, shutdown := initTest()
defer shutdown()