services/mounttable/btmtd: Keep track of mounted servers

Add a counter for the number of mounted servers, per user. Each server
is charged to the creator of the node where it is mounted.

Since we're keeping track of the mounted servers, we can't use
bigtable's garbage collection to delete expired servers. Instead, server
expiration is now done in gc().

Change-Id: Ia1b1b59ff68d1bf933828c975bc6ca7053c92bea
diff --git a/services/mounttable/btmtd/internal/bt.go b/services/mounttable/btmtd/internal/bt.go
index de89f98..64ec5e7 100644
--- a/services/mounttable/btmtd/internal/bt.go
+++ b/services/mounttable/btmtd/internal/bt.go
@@ -5,8 +5,6 @@
 package internal
 
 import (
-	"bytes"
-	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
@@ -146,7 +144,7 @@
 	}{
 		{b.counterTableName(), metadataFamily, bigtable.MaxVersionsPolicy(1)},
 		{b.nodeTableName(), metadataFamily, bigtable.MaxVersionsPolicy(1)},
-		{b.nodeTableName(), serversFamily, bigtable.UnionPolicy(bigtable.MaxVersionsPolicy(1), bigtable.MaxAgePolicy(time.Second))},
+		{b.nodeTableName(), serversFamily, bigtable.MaxVersionsPolicy(1)},
 		{b.nodeTableName(), childrenFamily, bigtable.MaxVersionsPolicy(1)},
 	}
 	for _, f := range families {
@@ -255,6 +253,28 @@
 	return count, nil
 }
 
+func (b *BigTable) Counters(ctx *context.T) (map[string]int64, error) {
+	bctx, cancel := btctx(ctx)
+	defer cancel()
+
+	counters := make(map[string]int64)
+	if err := b.counterTbl.ReadRows(bctx, bigtable.InfiniteRange(""),
+		func(row bigtable.Row) bool {
+			c, err := decodeCounterValue(ctx, row)
+			if err != nil {
+				ctx.Errorf("decodeCounterValue: %v", err)
+				return false
+			}
+			counters[row.Key()] = c
+			return true
+		},
+		bigtable.RowFilter(bigtable.LatestNFilter(1)),
+	); err != nil {
+		return nil, err
+	}
+	return counters, nil
+}
+
 func getTokenSource(ctx netcontext.Context, scope, keyFile string) (oauth2.TokenSource, error) {
 	if len(keyFile) == 0 {
 		return google.DefaultTokenSource(ctx, scope)
@@ -324,27 +344,5 @@
 	if err := b.apply(ctx, rowKey(name), mut); err != nil {
 		return err
 	}
-	return b.incrementCreatorNodeCount(ctx, creator, 1)
-}
-
-func (b *BigTable) incrementCreatorNodeCount(ctx *context.T, creator string, delta int64) error {
-	bctx, cancel := btctx(ctx)
-	defer cancel()
-
-	key := "num-nodes-per-user:" + creator
-	m := bigtable.NewReadModifyWrite()
-	m.Increment(metadataFamily, "c", delta)
-	row, err := b.counterTbl.ApplyReadModifyWrite(bctx, key, m)
-	if err != nil {
-		return err
-	}
-	if len(row[metadataFamily]) == 1 {
-		var c int64
-		b := row[metadataFamily][0].Value
-		if err := binary.Read(bytes.NewReader(b), binary.BigEndian, &c); err != nil {
-			return err
-		}
-		ctx.Infof("Counter %s = %d", key, c)
-	}
-	return nil
+	return incrementCreatorNodeCount(ctx, b, creator, 1)
 }
diff --git a/services/mounttable/btmtd/internal/counters.go b/services/mounttable/btmtd/internal/counters.go
new file mode 100644
index 0000000..0bff4c1
--- /dev/null
+++ b/services/mounttable/btmtd/internal/counters.go
@@ -0,0 +1,47 @@
+// Copyright 2016 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package internal
+
+import (
+	"bytes"
+	"encoding/binary"
+
+	"google.golang.org/cloud/bigtable"
+
+	"v.io/v23/context"
+	"v.io/v23/verror"
+)
+
+func incrementCounter(ctx *context.T, bt *BigTable, name string, delta int64) (int64, error) {
+	bctx, cancel := btctx(ctx)
+	defer cancel()
+
+	m := bigtable.NewReadModifyWrite()
+	m.Increment(metadataFamily, "c", delta)
+	row, err := bt.counterTbl.ApplyReadModifyWrite(bctx, name, m)
+	if err != nil {
+		return 0, err
+	}
+	return decodeCounterValue(ctx, row)
+}
+
+func decodeCounterValue(ctx *context.T, row bigtable.Row) (c int64, err error) {
+	if len(row[metadataFamily]) != 1 {
+		return 0, verror.NewErrInternal(ctx)
+	}
+	b := row[metadataFamily][0].Value
+	err = binary.Read(bytes.NewReader(b), binary.BigEndian, &c)
+	return
+}
+
+func incrementCreatorNodeCount(ctx *context.T, bt *BigTable, creator string, delta int64) error {
+	_, err := incrementCounter(ctx, bt, "num-nodes-per-user:"+creator, delta)
+	return err
+}
+
+func incrementCreatorServerCount(ctx *context.T, bt *BigTable, creator string, delta int64) error {
+	_, err := incrementCounter(ctx, bt, "num-servers-per-user:"+creator, delta)
+	return err
+}
diff --git a/services/mounttable/btmtd/internal/mounttable_test.go b/services/mounttable/btmtd/internal/mounttable_test.go
index e591686..aed1b76 100644
--- a/services/mounttable/btmtd/internal/mounttable_test.go
+++ b/services/mounttable/btmtd/internal/mounttable_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"errors"
+	"fmt"
 	"io"
 	"reflect"
 	"runtime/debug"
@@ -713,6 +714,124 @@
 	checkMatch(t, []string{"a1/b1"}, doGlob(t, rootCtx, estr, "", "*/b1/..."))
 }
 
+func TestNodeCounters(t *testing.T) {
+	rootCtx, shutdown := test.V23InitWithMounttable()
+	defer shutdown()
+
+	stop, estr, bt, clock := newMT(t, "", rootCtx)
+	defer stop()
+
+	nodeCount := func() int64 {
+		counters, err := bt.Counters(rootCtx)
+		if err != nil {
+			t.Fatalf("bt.Counters failed: %v", err)
+		}
+		return counters["num-nodes-per-user:test-blessing"]
+	}
+	serverCount := func() int64 {
+		counters, err := bt.Counters(rootCtx)
+		if err != nil {
+			t.Fatalf("bt.Counters failed: %v", err)
+		}
+		return counters["num-servers-per-user:test-blessing"]
+	}
+
+	// Test flat tree
+	for i := 1; i <= 10; i++ {
+		name := fmt.Sprintf("node%d", i)
+		addr := naming.JoinAddressName(estr, name)
+		doMount(t, rootCtx, estr, name, addr, true)
+		if expected, got := int64(i), nodeCount(); got != expected {
+			t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+		}
+		if expected, got := int64(i), serverCount(); got != expected {
+			t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+		}
+	}
+	for i := 1; i <= 10; i++ {
+		name := fmt.Sprintf("node%d", i)
+		if i%2 == 0 {
+			doUnmount(t, rootCtx, estr, name, "", true)
+		} else {
+			doDeleteSubtree(t, rootCtx, estr, name, true)
+		}
+		if expected, got := int64(10-i), nodeCount(); got != expected {
+			t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+		}
+		if expected, got := int64(10-i), serverCount(); got != expected {
+			t.Errorf("Unexpected number of server. Got %d, expected %d", got, expected)
+		}
+	}
+
+	// Test deep tree
+	doMount(t, rootCtx, estr, "1/2/3/4/5/6/7/8/9a/10", naming.JoinAddressName(estr, ""), true)
+	doMount(t, rootCtx, estr, "1/2/3/4/5/6/7/8/9b/11", naming.JoinAddressName(estr, ""), true)
+	if expected, got := int64(12), nodeCount(); got != expected {
+		t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+	}
+	if expected, got := int64(2), serverCount(); got != expected {
+		t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+	}
+	doDeleteSubtree(t, rootCtx, estr, "1/2/3/4/5", true)
+	if expected, got := int64(0), nodeCount(); got != expected {
+		t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+	}
+	if expected, got := int64(0), serverCount(); got != expected {
+		t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+	}
+
+	// Test multiple servers per node
+	for i := 1; i <= 5; i++ {
+		server := naming.JoinAddressName(estr, fmt.Sprintf("addr%d", i))
+		doMount(t, rootCtx, estr, "node1", server, true)
+		doMount(t, rootCtx, estr, "node2", server, true)
+		if expected, got := int64(2), nodeCount(); got != expected {
+			t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+		}
+		if expected, got := int64(2*i), serverCount(); got != expected {
+			t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+		}
+	}
+	doUnmount(t, rootCtx, estr, "node1", "", true)
+	if expected, got := int64(1), nodeCount(); got != expected {
+		t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+	}
+	if expected, got := int64(5), serverCount(); got != expected {
+		t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+	}
+	for i := 1; i <= 5; i++ {
+		server := naming.JoinAddressName(estr, fmt.Sprintf("addr%d", i))
+		doUnmount(t, rootCtx, estr, "node2", server, true)
+		expectedNodes := int64(1)
+		if i == 5 {
+			expectedNodes = 0
+		}
+		if expected, got := expectedNodes, nodeCount(); got != expected {
+			t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+		}
+		if expected, got := int64(5-i), serverCount(); got != expected {
+			t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+		}
+	}
+
+	// Test expired mounts
+	doMount(t, rootCtx, estr, "1/2/3/4/5", naming.JoinAddressName(estr, ""), true)
+	if expected, got := int64(5), nodeCount(); got != expected {
+		t.Errorf("Unexpected number of nodes. Got %d, expected %d", got, expected)
+	}
+	if expected, got := int64(1), serverCount(); got != expected {
+		t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+	}
+
+	clock.AdvanceTime(time.Duration(ttlSecs+4) * time.Second)
+	if _, err := resolve(rootCtx, naming.JoinAddressName(estr, "1/2/3/4/5")); err == nil {
+		t.Errorf("Expected failure. Got success")
+	}
+	if expected, got := int64(0), serverCount(); got != expected {
+		t.Errorf("Unexpected number of servers. Got %d, expected %d", got, expected)
+	}
+}
+
 func TestIntermediateNodesCreatedFromConfig(t *testing.T) {
 	rootCtx, _, _, shutdown := initTest()
 	defer shutdown()
diff --git a/services/mounttable/btmtd/internal/node.go b/services/mounttable/btmtd/internal/node.go
index f07fd9e..3c98e38 100644
--- a/services/mounttable/btmtd/internal/node.go
+++ b/services/mounttable/btmtd/internal/node.go
@@ -28,16 +28,17 @@
 )
 
 type mtNode struct {
-	bt           *BigTable
-	name         string
-	sticky       bool
-	creationTime bigtable.Timestamp
-	permissions  access.Permissions
-	version      string
-	creator      string
-	mountFlags   mtFlags
-	servers      []naming.MountedServer
-	children     []string
+	bt             *BigTable
+	name           string
+	sticky         bool
+	creationTime   bigtable.Timestamp
+	permissions    access.Permissions
+	version        string
+	creator        string
+	mountFlags     mtFlags
+	servers        []naming.MountedServer
+	expiredServers []string
+	children       []string
 }
 
 type mtFlags struct {
@@ -98,7 +99,9 @@
 	n.servers = make([]naming.MountedServer, 0, len(row[serversFamily]))
 	for _, i := range row[serversFamily] {
 		deadline := i.Timestamp.Time()
+		server := i.Column[2:]
 		if deadline.Before(clock.Now()) {
+			n.expiredServers = append(n.expiredServers, server)
 			continue
 		}
 		if err := json.Unmarshal(i.Value, &n.mountFlags); err != nil {
@@ -106,7 +109,7 @@
 			return nil
 		}
 		n.servers = append(n.servers, naming.MountedServer{
-			Server:   i.Column[2:],
+			Server:   server,
 			Deadline: vdltime.Deadline{deadline},
 		})
 	}
@@ -150,9 +153,16 @@
 }
 
 func (n *mtNode) mount(ctx *context.T, server string, deadline time.Time, flags naming.MountFlag) error {
+	delta := int64(1)
 	mut := bigtable.NewMutation()
-	if flags&naming.Replace != 0 {
-		mut.DeleteCellsInFamily(serversFamily)
+	for _, s := range n.servers {
+		// Mount replaces an already mounted server with the same name,
+		// or all servers if the Replace flag is set.
+		if s.Server != server && flags&naming.Replace == 0 {
+			continue
+		}
+		delta--
+		mut.DeleteCellsInColumn(serversFamily, s.Server)
 	}
 	f := mtFlags{
 		MT:   flags&naming.MT != 0,
@@ -166,26 +176,23 @@
 	if err := n.mutate(ctx, mut, false); err != nil {
 		return err
 	}
-	if flags&naming.Replace != 0 {
-		n.servers = nil
-	}
-	return nil
+	return incrementCreatorServerCount(ctx, n.bt, n.creator, delta)
 }
 
 func (n *mtNode) unmount(ctx *context.T, server string) error {
+	delta := int64(0)
 	mut := bigtable.NewMutation()
-	if server == "" {
-		// HACK ALERT
-		// The bttest server doesn't support DeleteCellsInFamily
-		if !n.bt.testMode {
-			mut.DeleteCellsInFamily(serversFamily)
-		} else {
-			for _, s := range n.servers {
-				mut.DeleteCellsInColumn(serversFamily, s.Server)
-			}
+	for _, s := range n.servers {
+		// Unmount removes the specified server, or all servers if
+		// server == "".
+		if server != "" && s.Server != server {
+			continue
 		}
-	} else {
-		mut.DeleteCellsInColumn(serversFamily, server)
+		delta--
+		mut.DeleteCellsInColumn(serversFamily, s.Server)
+	}
+	if delta == 0 {
+		return nil
 	}
 	if err := n.mutate(ctx, mut, false); err != nil {
 		return err
@@ -193,16 +200,42 @@
 	if n, err := getNode(ctx, n.bt, n.name); err == nil {
 		n.gc(ctx)
 	}
-	return nil
+	return incrementCreatorServerCount(ctx, n.bt, n.creator, delta)
 }
 
-func (n *mtNode) gc(ctx *context.T) (deletedAtLeastOne bool, err error) {
-	for n != nil && n.name != "" && !n.sticky && len(n.children) == 0 && len(n.servers) == 0 {
+func (n *mtNode) gc(ctx *context.T) (deletedSomething bool, err error) {
+	for n != nil && n.name != "" {
+		if len(n.expiredServers) > 0 {
+			mut := bigtable.NewMutation()
+			for _, s := range n.expiredServers {
+				mut.DeleteCellsInColumn(serversFamily, s)
+			}
+			if err = n.mutate(ctx, mut, false); err != nil {
+				break
+			}
+			delta := -int64(len(n.expiredServers))
+			if err = incrementCreatorServerCount(ctx, n.bt, n.creator, delta); err != nil {
+				// TODO(rthellend): Since counters are stored in different rows,
+				// there is no way to update them atomically, e.g. if the server
+				// dies here, or if incrementCreatorServerCount returns an error,
+				// the server counter will be off.
+				// The same thing could happen everywhere the counters are updated.
+				// If/when we start using these counters for quota enforcement, we
+				// should also come up with a way to make sure the counters aren't
+				// too far off.
+				break
+			}
+			deletedSomething = true
+			break
+		}
+		if n.sticky || len(n.children) > 0 || len(n.servers) > 0 {
+			break
+		}
 		if err = n.delete(ctx, false); err != nil {
 			break
 		}
 		ctx.Infof("Deleted empty node %q", n.name)
-		deletedAtLeastOne = true
+		deletedSomething = true
 		parent := path.Dir(n.name)
 		if parent == "." {
 			break
@@ -278,7 +311,10 @@
 	if err := n.bt.apply(longCtx, rowKey(parent), mut); err != nil {
 		return err
 	}
-	return n.bt.incrementCreatorNodeCount(ctx, n.creator, -1)
+	if err := incrementCreatorServerCount(ctx, n.bt, n.creator, -int64(len(n.servers))); err != nil {
+		return err
+	}
+	return incrementCreatorNodeCount(ctx, n.bt, n.creator, -1)
 }
 
 func (n *mtNode) setPermissions(ctx *context.T, perms access.Permissions) error {