services/mounttable: Count expired mounts
With this change, the num-mounted-servers counter is properly
decremented when expired mounts are removed.
Closes https://github.com/veyron/release-issues/issues/1896
Change-Id: Ia7364118eb0fd3d220831a36e53cd794f369ae02
diff --git a/services/mounttable/mounttablelib/mounttable.go b/services/mounttable/mounttablelib/mounttable.go
index c61f093..7ba7f4e 100644
--- a/services/mounttable/mounttablelib/mounttable.go
+++ b/services/mounttable/mounttablelib/mounttable.go
@@ -173,11 +173,15 @@
}
// isActive returns true if a mount has unexpired servers attached.
-func (m *mount) isActive() bool {
+func (m *mount) isActive(mt *mountTable) bool {
if m == nil {
return false
}
- return m.servers.removeExpired() > 0
+ numLeft, numRemoved := m.servers.removeExpired()
+ if numRemoved > 0 {
+ mt.serverCounter.Incr(int64(-numRemoved))
+ }
+ return numLeft > 0
}
// satisfies returns no error if the ctx + n.vPerms satisfies the associated one of the required Tags.
@@ -283,7 +287,7 @@
}
}
// If we hit another mount table, we're done.
- if cur.mount.isActive() {
+ if cur.mount.isActive(mt) {
return cur, elems[i:], nil
}
// Walk the children looking for a match.
@@ -373,7 +377,7 @@
n.Unlock()
return nil, nil, err
}
- if !n.mount.isActive() {
+ if !n.mount.isActive(mt) {
removed := n.removeUseless(mt)
n.parent.Unlock()
n.Unlock()
@@ -515,7 +519,7 @@
//
// We assume both n and n.parent are locked.
func (n *node) removeUseless(mt *mountTable) bool {
- if len(n.children) > 0 || n.mount.isActive() || n.explicitPermissions {
+ if len(n.children) > 0 || n.mount.isActive(mt) || n.explicitPermissions {
return false
}
for k, c := range n.parent.children {
diff --git a/services/mounttable/mounttablelib/mounttable_test.go b/services/mounttable/mounttablelib/mounttable_test.go
index 06d12f2..5704ef1 100644
--- a/services/mounttable/mounttablelib/mounttable_test.go
+++ b/services/mounttable/mounttablelib/mounttable_test.go
@@ -629,6 +629,9 @@
rootCtx, shutdown := test.InitForTest()
defer shutdown()
+ ft := NewFakeTimeClock()
+ setServerListClock(ft)
+
server, estr := newMT(t, "", "", rootCtx)
defer server.Stop()
@@ -723,6 +726,16 @@
if expected, got := int64(1), serverCount(t, rootCtx, estr); got != expected {
t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
}
+
+ // Test expired mounts
+ // "1/2/3/4/5" is still mounted from earlier.
+ ft.advance(time.Duration(ttlSecs+4) * time.Second)
+ if _, err := resolve(rootCtx, naming.JoinAddressName(estr, "1/2/3/4/5")); err == nil {
+ t.Errorf("Expected failure. Got success")
+ }
+ if expected, got := int64(0), serverCount(t, rootCtx, estr); got != expected {
+ t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+ }
}
func initTest() (rootCtx *context.T, aliceCtx *context.T, bobCtx *context.T, shutdown v23.Shutdown) {
diff --git a/services/mounttable/mounttablelib/serverlist.go b/services/mounttable/mounttablelib/serverlist.go
index 0738a94..123dc9d 100644
--- a/services/mounttable/mounttablelib/serverlist.go
+++ b/services/mounttable/mounttablelib/serverlist.go
@@ -98,20 +98,22 @@
}
// removeExpired removes any expired servers.
-func (sl *serverList) removeExpired() int {
+func (sl *serverList) removeExpired() (int, int) {
sl.Lock()
defer sl.Unlock()
now := slc.now()
var next *list.Element
+ removed := 0
for e := sl.l.Front(); e != nil; e = next {
s := e.Value.(*server)
next = e.Next()
if now.After(s.expires) {
sl.l.Remove(e)
+ removed++
}
}
- return sl.l.Len()
+ return sl.l.Len(), removed
}
// copyToSlice returns the contents of the list as a slice of MountedServer.
diff --git a/services/mounttable/mounttablelib/serverlist_test.go b/services/mounttable/mounttablelib/serverlist_test.go
index b17ffd0..9b5a484 100644
--- a/services/mounttable/mounttablelib/serverlist_test.go
+++ b/services/mounttable/mounttablelib/serverlist_test.go
@@ -50,7 +50,7 @@
// Test timing out entries.
ft.advance(6 * time.Second)
- if sl.removeExpired() != len(eps)-2 {
+ if numLeft, _ := sl.removeExpired(); numLeft != len(eps)-2 {
t.Fatalf("got %d, want %d", sl.len(), len(eps)-2)
}