store/leveldb: handle concurrent transactions

Allow multiple transactions to be open at once by switching to optimistic
concurrency control. Each transaction records the sequence number it started
at plus its read set (keys and scan ranges); Commit() validates that read set
against the keys written by transactions committed since then, and fails with
ErrConcurrentTransaction on overlap. Written keys are tracked in an in-memory
table (a dummy trie for now) that is only populated while transactions are
live. Put and Delete no longer funnel through RunInTransaction, and
RunInTransaction now retries on ErrConcurrentTransaction.

Change-Id: I7d060a64ec1ca6d47dae7d250ed9938feb0d8e6a
diff --git a/services/syncbase/store/leveldb/db.go b/services/syncbase/store/leveldb/db.go
index e5e32cb..e163833 100644
--- a/services/syncbase/store/leveldb/db.go
+++ b/services/syncbase/store/leveldb/db.go
@@ -11,6 +11,8 @@
 // #include "syncbase_leveldb.h"
 import "C"
 import (
+	"container/list"
+	"fmt"
 	"sync"
 	"unsafe"
 
@@ -28,9 +30,20 @@
 	readOptions  *C.leveldb_readoptions_t
 	writeOptions *C.leveldb_writeoptions_t
 	err          error
-	// Used to prevent concurrent transactions.
-	// TODO(rogulenko): improve concurrency.
+
+	// TODO(rogulenko): decide whether we need to make a defensive copy of
+	// keys/values used by transactions.
+	// txmu protects transaction-related variables below. It is also held during
+	// commit.
+	// txmu must always be acquired before mu.
 	txmu sync.Mutex
+	// txEvents is a queue of create/commit transaction events.
+	txEvents         *list.List
+	txSequenceNumber uint64
+	// txTable is a set of keys written by recent transactions. This set
+	// includes all write sets of transactions committed after the oldest living
+	// transaction.
+	txTable *trie
 }
 
 var _ store.Store = (*db)(nil)
@@ -58,6 +71,8 @@
 		cDb:          cDb,
 		readOptions:  readOptions,
 		writeOptions: C.leveldb_writeoptions_create(),
+		txEvents:     list.New(),
+		txTable:      newTrie(),
 	}, nil
 }
 
@@ -107,28 +122,30 @@
 
 // Put implements the store.StoreWriter interface.
 func (d *db) Put(key, value []byte) error {
-	// TODO(rogulenko): improve performance.
-	return store.RunInTransaction(d, func(st store.StoreReadWriter) error {
-		return st.Put(key, value)
-	})
+	write := writeOp{
+		t:     putOp,
+		key:   key,
+		value: value,
+	}
+	return d.write([]writeOp{write}, d.writeOptions)
 }
 
 // Delete implements the store.StoreWriter interface.
 func (d *db) Delete(key []byte) error {
-	// TODO(rogulenko): improve performance.
-	return store.RunInTransaction(d, func(st store.StoreReadWriter) error {
-		return st.Delete(key)
-	})
+	write := writeOp{
+		t:   deleteOp,
+		key: key,
+	}
+	return d.write([]writeOp{write}, d.writeOptions)
 }
 
 // NewTransaction implements the store.Store interface.
 func (d *db) NewTransaction() store.Transaction {
-	// txmu is held until the transaction is successfully committed or aborted.
 	d.txmu.Lock()
+	defer d.txmu.Unlock()
 	d.mu.RLock()
 	defer d.mu.RUnlock()
 	if d.err != nil {
-		d.txmu.Unlock()
 		return &store.InvalidTransaction{d.err}
 	}
 	return newTransaction(d, d.node)
@@ -144,6 +161,65 @@
 	return newSnapshot(d, d.node)
 }
 
+// write writes a batch and adds all written keys to txTable.
+// TODO(rogulenko): remove this method.
+func (d *db) write(batch []writeOp, cOpts *C.leveldb_writeoptions_t) error {
+	d.txmu.Lock()
+	defer d.txmu.Unlock()
+	return d.writeLocked(batch, cOpts)
+}
+
+// writeLocked is like write(), but it assumes txmu is held.
+func (d *db) writeLocked(batch []writeOp, cOpts *C.leveldb_writeoptions_t) error {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if d.err != nil {
+		return d.err
+	}
+	cBatch := C.leveldb_writebatch_create()
+	defer C.leveldb_writebatch_destroy(cBatch)
+	for _, write := range batch {
+		switch write.t {
+		case putOp:
+			cKey, cKeyLen := cSlice(write.key)
+			cVal, cValLen := cSlice(write.value)
+			C.leveldb_writebatch_put(cBatch, cKey, cKeyLen, cVal, cValLen)
+		case deleteOp:
+			cKey, cKeyLen := cSlice(write.key)
+			C.leveldb_writebatch_delete(cBatch, cKey, cKeyLen)
+		default:
+			panic(fmt.Sprintf("unknown write operation type: %v", write.t))
+		}
+	}
+	var cError *C.char
+	C.leveldb_write(d.cDb, cOpts, cBatch, &cError)
+	if err := goError(cError); err != nil {
+		return err
+	}
+	if d.txEvents.Len() == 0 {
+		return nil
+	}
+	d.trackBatch(batch)
+	return nil
+}
+
+// trackBatch writes the batch to txTable and adds a commit event to txEvents.
+func (d *db) trackBatch(batch []writeOp) {
+	// TODO(rogulenko): do GC.
+	d.txSequenceNumber++
+	seq := d.txSequenceNumber
+	var keys [][]byte
+	for _, write := range batch {
+		d.txTable.add(write.key, seq)
+		keys = append(keys, write.key)
+	}
+	tx := &committedTransaction{
+		seq:   seq,
+		batch: keys,
+	}
+	d.txEvents.PushBack(tx)
+}
+
 // getWithOpts returns the value for the given key.
 // cOpts may contain a pointer to a snapshot.
 func (d *db) getWithOpts(key, valbuf []byte, cOpts *C.leveldb_readoptions_t) ([]byte, error) {
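For intuition, the bookkeeping done by writeLocked and trackBatch can be modeled
in a few lines of standalone Go. This is a simplified sketch, not the syncbase
types: it drops the commit events pushed onto txEvents and the pending GC, and
keeps only the rule that written keys are recorded (under a fresh sequence
number) while at least one transaction is live.

    package main

    import (
        "container/list"
        "fmt"
    )

    // tracker is a stand-in for the db fields used by writeLocked/trackBatch.
    type tracker struct {
        liveTxs *list.List        // open transactions, like db.txEvents
        seq     uint64            // like db.txSequenceNumber
        written map[string]uint64 // key -> seq of last write, like the dummy trie
    }

    func newTracker() *tracker {
        return &tracker{liveTxs: list.New(), written: make(map[string]uint64)}
    }

    // commitBatch records a batch's keys under a fresh sequence number, but only
    // if some transaction is live -- otherwise nothing can conflict with it.
    func (tr *tracker) commitBatch(keys []string) {
        if tr.liveTxs.Len() == 0 {
            return
        }
        tr.seq++
        for _, k := range keys {
            tr.written[k] = tr.seq
        }
    }

    func main() {
        tr := newTracker()
        tr.commitBatch([]string{"a"}) // not recorded: no transaction is live
        e := tr.liveTxs.PushFront("tx1")
        tr.commitBatch([]string{"b"}) // recorded: tx1 might have read "b"
        tr.liveTxs.Remove(e)
        fmt.Println(tr.written) // map[b:1]
    }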
diff --git a/services/syncbase/store/leveldb/db_test.go b/services/syncbase/store/leveldb/db_test.go
index 65594f5..cf91b0d 100644
--- a/services/syncbase/store/leveldb/db_test.go
+++ b/services/syncbase/store/leveldb/db_test.go
@@ -42,6 +42,10 @@
 	runTest(t, test.RunReadWriteRandomTest)
 }
 
+func TestConcurrentTransactions(t *testing.T) {
+	runTest(t, test.RunConcurrentTransactionsTest)
+}
+
 func TestTransactionState(t *testing.T) {
 	runTest(t, test.RunTransactionStateTest)
 }
diff --git a/services/syncbase/store/leveldb/transaction.go b/services/syncbase/store/leveldb/transaction.go
index 6dfa250..e7a1b8a 100644
--- a/services/syncbase/store/leveldb/transaction.go
+++ b/services/syncbase/store/leveldb/transaction.go
@@ -7,12 +7,41 @@
 // #include "leveldb/c.h"
 import "C"
 import (
+	"container/list"
 	"sync"
 
 	"v.io/syncbase/x/ref/services/syncbase/store"
 	"v.io/v23/verror"
 )
 
+type scanRange struct {
+	start, limit []byte
+}
+
+type readSet struct {
+	keys   [][]byte
+	ranges []scanRange
+}
+
+type writeType int
+
+const (
+	putOp writeType = iota
+	deleteOp
+)
+
+type writeOp struct {
+	t     writeType
+	key   []byte
+	value []byte
+}
+
+// committedTransaction is only used as an element of db.txEvents.
+type committedTransaction struct {
+	seq   uint64
+	batch [][]byte
+}
+
 // transaction is a wrapper around LevelDB WriteBatch that implements
 // the store.Transaction interface.
 type transaction struct {
@@ -20,8 +49,11 @@
 	mu       sync.Mutex
 	node     *store.ResourceNode
 	d        *db
+	seq      uint64
+	event    *list.Element // pointer to element of db.txEvents
 	snapshot store.Snapshot
-	batch    *C.leveldb_writebatch_t
+	reads    readSet
+	writes   []writeOp
 	cOpts    *C.leveldb_writeoptions_t
 	err      error
 }
@@ -35,9 +67,10 @@
 		node:     node,
 		d:        d,
 		snapshot: snapshot,
-		batch:    C.leveldb_writebatch_create(),
+		seq:      d.txSequenceNumber,
 		cOpts:    d.writeOptions,
 	}
+	tx.event = d.txEvents.PushFront(tx)
 	parent.AddChild(tx.node, func() {
 		tx.Abort()
 	})
@@ -47,16 +80,28 @@
 // close frees allocated C objects and releases acquired locks.
 // Assumes mu is held.
 func (tx *transaction) close() {
-	tx.d.txmu.Unlock()
+	tx.removeEvent()
 	tx.node.Close()
-	C.leveldb_writebatch_destroy(tx.batch)
-	tx.batch = nil
 	if tx.cOpts != tx.d.writeOptions {
 		C.leveldb_writeoptions_destroy(tx.cOpts)
 	}
 	tx.cOpts = nil
 }
 
+// removeEvent removes this transaction from the db.txEvents queue.
+// Assumes mu is held.
+func (tx *transaction) removeEvent() {
+	// tx.event is nil if the transaction was already committed, since
+	// Commit() calls removeEvent() explicitly.
+	if tx.event == nil {
+		return
+	}
+	tx.d.txmu.Lock()
+	tx.d.txEvents.Remove(tx.event)
+	tx.d.txmu.Unlock()
+	tx.event = nil
+}
+
 // Get implements the store.StoreReader interface.
 func (tx *transaction) Get(key, valbuf []byte) ([]byte, error) {
 	tx.mu.Lock()
@@ -64,6 +109,7 @@
 	if tx.err != nil {
 		return valbuf, store.WrapError(tx.err)
 	}
+	tx.reads.keys = append(tx.reads.keys, key)
 	return tx.snapshot.Get(key, valbuf)
 }
 
@@ -74,6 +120,10 @@
 	if tx.err != nil {
 		return &store.InvalidStream{tx.err}
 	}
+	tx.reads.ranges = append(tx.reads.ranges, scanRange{
+		start: start,
+		limit: limit,
+	})
 	return tx.snapshot.Scan(start, limit)
 }
 
@@ -84,9 +134,11 @@
 	if tx.err != nil {
 		return store.WrapError(tx.err)
 	}
-	cKey, cKeyLen := cSlice(key)
-	cVal, cValLen := cSlice(value)
-	C.leveldb_writebatch_put(tx.batch, cKey, cKeyLen, cVal, cValLen)
+	tx.writes = append(tx.writes, writeOp{
+		t:     putOp,
+		key:   key,
+		value: value,
+	})
 	return nil
 }
 
@@ -97,11 +149,31 @@
 	if tx.err != nil {
 		return store.WrapError(tx.err)
 	}
-	cKey, cKeyLen := cSlice(key)
-	C.leveldb_writebatch_delete(tx.batch, cKey, cKeyLen)
+	tx.writes = append(tx.writes, writeOp{
+		t:   deleteOp,
+		key: key,
+	})
 	return nil
 }
 
+// validateReadSet returns true iff the read set of this transaction has not
+// been invalidated by other transactions.
+// Assumes tx.d.txmu is held.
+func (tx *transaction) validateReadSet() bool {
+	for _, key := range tx.reads.keys {
+		if tx.d.txTable.get(key) > tx.seq {
+			return false
+		}
+	}
+	for _, r := range tx.reads.ranges {
+		if tx.d.txTable.rangeMax(r.start, r.limit) > tx.seq {
+			return false
+		}
+	}
+	return true
+}
+
 // Commit implements the store.Transaction interface.
 func (tx *transaction) Commit() error {
 	tx.mu.Lock()
@@ -109,19 +181,23 @@
 	if tx.err != nil {
 		return store.WrapError(tx.err)
 	}
-	tx.d.mu.Lock()
-	defer tx.d.mu.Unlock()
-	var cError *C.char
-	C.leveldb_write(tx.d.cDb, tx.cOpts, tx.batch, &cError)
-	if err := goError(cError); err != nil {
-		// Once Commit() has failed with store.ErrConcurrentTransaction, subsequent
-		// ops on the transaction will fail with the following error.
+	// Explicitly remove this transaction from the event queue. If this was the
+	// only active transaction, the event queue becomes empty and writeLocked will
+	// not add the write set of this transaction to the txTable.
+	tx.removeEvent()
+	defer tx.close()
+	tx.d.txmu.Lock()
+	defer tx.d.txmu.Unlock()
+	if !tx.validateReadSet() {
+		return store.NewErrConcurrentTransaction(nil)
+	}
+	if err := tx.d.writeLocked(tx.writes, tx.cOpts); err != nil {
+		// Once Commit() has failed, subsequent ops on the transaction will fail
+		// with the following error.
 		tx.err = verror.New(verror.ErrBadState, nil, "already attempted to commit transaction")
-		tx.close()
 		return err
 	}
 	tx.err = verror.New(verror.ErrBadState, nil, "committed transaction")
-	tx.close()
 	return nil
 }
 
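Commit is now a standard optimistic concurrency check: a transaction remembers
the sequence number it started at plus everything it read, and is allowed to
commit only if none of those reads were overwritten under a later sequence
number. A standalone sketch of validateReadSet's logic, using a plain map and
strings in place of the trie and byte slices (illustrative only):

    package main

    import "fmt"

    // writeTable stands in for the trie: key -> seq of the last committed write.
    type writeTable map[string]uint64

    // scanRange is a half-open range [start, limit); an empty limit is unbounded.
    type scanRange struct{ start, limit string }

    type txn struct {
        startSeq   uint64      // seq at which the transaction's snapshot was taken
        readKeys   []string    // keys passed to Get
        readRanges []scanRange // ranges passed to Scan
    }

    // validate reports whether the transaction's reads are still current.
    func (t txn) validate(w writeTable) bool {
        for _, k := range t.readKeys {
            if w[k] > t.startSeq {
                return false
            }
        }
        for _, r := range t.readRanges {
            for k, seq := range w {
                if k >= r.start && (r.limit == "" || k < r.limit) && seq > t.startSeq {
                    return false
                }
            }
        }
        return true
    }

    func main() {
        w := writeTable{}
        tx := txn{startSeq: 0, readKeys: []string{"a"}}
        w["a"] = 1 // a concurrent transaction commits a write to "a" with seq 1
        fmt.Println(tx.validate(w)) // false: the read of "a" is stale
        tx2 := txn{startSeq: 1, readKeys: []string{"a"}}
        fmt.Println(tx2.validate(w)) // true: tx2 started after that write
    }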
diff --git a/services/syncbase/store/leveldb/trie.go b/services/syncbase/store/leveldb/trie.go
new file mode 100644
index 0000000..78a6982
--- /dev/null
+++ b/services/syncbase/store/leveldb/trie.go
@@ -0,0 +1,75 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package leveldb
+
+import (
+	"fmt"
+)
+
+// trie is an in-memory data structure that keeps track of recently-written
+// keys, and exposes an interface for asking when a key or key range was most
+// recently written to. It is used to check whether the read set of a
+// transaction pending commit is still valid. The transaction can be committed
+// iff its read set is valid.
+// TODO(rogulenko): replace this dummy implementation with an actual trie.
+type trie struct {
+	seqs map[string]uint64
+}
+
+func newTrie() *trie {
+	return &trie{
+		seqs: make(map[string]uint64),
+	}
+}
+
+// add updates the given key to the given seq, which must be at least the
+// key's current seq (if one exists). Seqs of subsequent calls must be in
+// ascending order.
+func (t *trie) add(key []byte, seq uint64) {
+	keystr := string(key)
+	if oldSeq, ok := t.seqs[keystr]; ok && seq < oldSeq {
+		panic(fmt.Sprintf("seq for key %q should be at least %d, but got %d", key, oldSeq, seq))
+	}
+	t.seqs[keystr] = seq
+}
+
+// remove reverts the effect of add(key, seq).
+// Seqs of subsequent calls must be in ascending order.
+func (t *trie) remove(key []byte, seq uint64) {
+	keystr := string(key)
+	oldSeq, ok := t.seqs[keystr]
+	if !ok {
+		panic(fmt.Sprintf("key %q was not found", key))
+	}
+	if oldSeq > seq {
+		return
+	} else if oldSeq == seq {
+		delete(t.seqs, keystr)
+	} else {
+		panic(fmt.Sprintf("seq for key %q is too big: got %v, want %v", keystr, seq, oldSeq))
+	}
+}
+
+// get returns the seq associated with the given key.
+func (t *trie) get(key []byte) uint64 {
+	keystr := string(key)
+	if seq, ok := t.seqs[keystr]; ok {
+		return seq
+	}
+	return 0
+}
+
+// rangeMax returns the max seq associated with keys in range
+// [start, limit). Empty limit means no limit.
+func (t *trie) rangeMax(start, limit []byte) uint64 {
+	var result uint64
+	s, e := string(start), string(limit)
+	for key, seq := range t.seqs {
+		if key >= s && (e == "" || key < e) && seq > result {
+			result = seq
+		}
+	}
+	return result
+}
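The dummy trie's contract is easy to pin down with a small in-package test. The
test below is a hypothetical sketch (it assumes it lives in package leveldb
next to trie.go, with "testing" imported) and exercises only the methods
defined above:

    func TestTrieSketch(t *testing.T) {
        tr := newTrie()
        tr.add([]byte("car"), 1)
        tr.add([]byte("cat"), 2)
        if got := tr.get([]byte("cat")); got != 2 {
            t.Errorf("get(cat): got %d, want 2", got)
        }
        if got := tr.get([]byte("dog")); got != 0 {
            t.Errorf("get(dog): got %d, want 0 for an unseen key", got)
        }
        // Both "car" and "cat" fall in ["ca", "cb"), so the max seq there is 2.
        if got := tr.rangeMax([]byte("ca"), []byte("cb")); got != 2 {
            t.Errorf("rangeMax(ca, cb): got %d, want 2", got)
        }
        // An empty limit makes the range unbounded on the right.
        if got := tr.rangeMax([]byte("a"), nil); got != 2 {
            t.Errorf("rangeMax(a, nil): got %d, want 2", got)
        }
        tr.remove([]byte("cat"), 2)
        if got := tr.get([]byte("cat")); got != 0 {
            t.Errorf("get(cat) after remove: got %d, want 0", got)
        }
    }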
diff --git a/services/syncbase/store/memstore/store_test.go b/services/syncbase/store/memstore/store_test.go
index 1b11ac3..fef07b1 100644
--- a/services/syncbase/store/memstore/store_test.go
+++ b/services/syncbase/store/memstore/store_test.go
@@ -45,9 +45,7 @@
 }
 
 func TestTransactionsWithGet(t *testing.T) {
-	// TODO(sadovsky): Enable this test once we've added a retry loop to
-	// RunInTransaction. Without that, concurrency makes the test fail.
-	// runTest(t, test.RunTransactionsWithGetTest)
+	runTest(t, test.RunTransactionsWithGetTest)
 }
 
 func runTest(t *testing.T, f func(t *testing.T, st store.Store)) {
diff --git a/services/syncbase/store/test/store.go b/services/syncbase/store/test/store.go
index 0a081fd..784c545 100644
--- a/services/syncbase/store/test/store.go
+++ b/services/syncbase/store/test/store.go
@@ -213,11 +213,10 @@
 	var streams []store.Stream
 	var snapshots []store.Snapshot
 	var transactions []store.Transaction
-	// TODO(rogulenko): make multiple transactions.
-	tx := st.NewTransaction()
 	for i := 0; i < 10; i++ {
 		streams = append(streams, st.Scan([]byte("a"), []byte("z")))
 		snapshot := st.NewSnapshot()
+		tx := st.NewTransaction()
 		for j := 0; j < 10; j++ {
 			streams = append(streams, snapshot.Scan([]byte("a"), []byte("z")))
 			streams = append(streams, tx.Scan([]byte("a"), []byte("z")))
diff --git a/services/syncbase/store/test/transaction.go b/services/syncbase/store/test/transaction.go
index 70dac8f..6aa5861 100644
--- a/services/syncbase/store/test/transaction.go
+++ b/services/syncbase/store/test/transaction.go
@@ -5,6 +5,7 @@
 package test
 
 import (
+	"bytes"
 	"fmt"
 	"math/rand"
 	"strconv"
@@ -61,6 +62,61 @@
 	}
 }
 
+// RunConcurrentTransactionsTest verifies that concurrent transactions
+// invalidate each other as expected.
+func RunConcurrentTransactionsTest(t *testing.T, st store.Store) {
+	st.Put([]byte("a"), []byte("0"))
+	st.Put([]byte("b"), []byte("0"))
+	st.Put([]byte("c"), []byte("0"))
+	// Test that conflicting Gets cause the second commit to fail.
+	txA := st.NewTransaction()
+	txB := st.NewTransaction()
+	txA.Get([]byte("a"), nil)
+	txB.Get([]byte("a"), nil)
+	txA.Put([]byte("a"), []byte("a"))
+	txB.Put([]byte("a"), []byte("b"))
+	if err := txA.Commit(); err != nil {
+		t.Fatalf("can't commit the transaction: %v", err)
+	}
+	if err := txB.Commit(); verror.ErrorID(err) != store.ErrConcurrentTransaction.ID {
+		t.Fatalf("unexpected commit error: %v", err)
+	}
+	if value, _ := st.Get([]byte("a"), nil); !bytes.Equal(value, []byte("a")) {
+		t.Fatalf("unexpected value: got %q, want %q", value, "a")
+	}
+	// Test that conflicting Scans cause the second commit to fail.
+	txA = st.NewTransaction()
+	txB = st.NewTransaction()
+	txA.Scan([]byte("a"), []byte("z"))
+	txB.Scan([]byte("a"), []byte("z"))
+	txA.Put([]byte("aa"), []byte("a"))
+	txB.Put([]byte("bb"), []byte("b"))
+	if err := txA.Commit(); err != nil {
+		t.Fatalf("can't commit the transaction: %v", err)
+	}
+	if err := txB.Commit(); verror.ErrorID(err) != store.ErrConcurrentTransaction.ID {
+		t.Fatalf("unexpected commit error: %v", err)
+	}
+	if value, _ := st.Get([]byte("aa"), nil); !bytes.Equal(value, []byte("a")) {
+		t.Fatalf("unexpected value: got %q, want %q", value, "a")
+	}
+	// Test that non-conflicting Gets and Scans commit successfully.
+	txA = st.NewTransaction()
+	txB = st.NewTransaction()
+	txA.Scan([]byte("a"), []byte("b"))
+	txB.Scan([]byte("b"), []byte("c"))
+	txA.Get([]byte("c"), nil)
+	txB.Get([]byte("c"), nil)
+	txA.Put([]byte("a"), []byte("a"))
+	txB.Put([]byte("b"), []byte("b"))
+	if err := txA.Commit(); err != nil {
+		t.Fatalf("can't commit the transaction: %v", err)
+	}
+	if err := txB.Commit(); err != nil {
+		t.Fatalf("can't commit the transaction: %v", err)
+	}
+}
+
 // RunTransactionsWithGetTest tests transactions that use Put and Get
 // operations.
 // NOTE: consider setting GOMAXPROCS to something greater than 1.
@@ -78,8 +134,6 @@
 	}
 	var wg sync.WaitGroup
 	wg.Add(k)
-	// TODO(sadovsky): This configuration creates huge resource contention.
-	// Perhaps we should add some random sleep's to reduce the contention.
 	for i := 0; i < k; i++ {
 		go func() {
 			rnd := rand.New(rand.NewSource(239017 * int64(i)))
diff --git a/services/syncbase/store/util.go b/services/syncbase/store/util.go
index 447c305..3d0867a 100644
--- a/services/syncbase/store/util.go
+++ b/services/syncbase/store/util.go
@@ -8,22 +8,25 @@
 	"v.io/v23/verror"
 )
 
-// TODO(sadovsky): Add retry loop.
 func RunInTransaction(st Store, fn func(st StoreReadWriter) error) error {
-	tx := st.NewTransaction()
-	if err := fn(tx); err != nil {
-		tx.Abort()
+	// TODO(rogulenko): We should eventually give up with
+	// ErrConcurrentTransaction.
+	// TODO(rogulenko): Fail on RPC errors.
+	for {
+		tx := st.NewTransaction()
+		if err := fn(tx); err != nil {
+			tx.Abort()
+			return err
+		}
+		err := tx.Commit()
+		if err == nil {
+			return nil
+		}
+		if verror.ErrorID(err) == ErrConcurrentTransaction.ID {
+			continue
+		}
 		return err
 	}
-	if err := tx.Commit(); err != nil {
-		// TODO(sadovsky): Commit() can fail for a number of reasons, e.g. RPC
-		// failure or ErrConcurrentTransaction. Depending on the cause of failure,
-		// it may be desirable to retry the Commit() and/or to call Abort(). For
-		// now, we always abort on a failed commit.
-		tx.Abort()
-		return err
-	}
-	return nil
 }
 
 // CopyBytes copies elements from a source slice into a destination slice.
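With the retry loop in place, callers can express read-modify-write operations
directly on top of RunInTransaction and let it absorb ErrConcurrentTransaction.
A hypothetical helper (the package and function names are illustrative; it
assumes StoreReadWriter exposes the Get/Put methods shown in transaction.go):

    package counterutil

    import (
        "strconv"

        "v.io/syncbase/x/ref/services/syncbase/store"
    )

    // Increment adds 1 to the decimal integer stored under key. If a concurrent
    // transaction writes key between the Get and the Commit, the commit fails
    // with ErrConcurrentTransaction and RunInTransaction re-runs the closure.
    func Increment(st store.Store, key []byte) error {
        return store.RunInTransaction(st, func(tx store.StoreReadWriter) error {
            val, err := tx.Get(key, nil)
            if err != nil {
                return err
            }
            n, err := strconv.Atoi(string(val))
            if err != nil {
                return err
            }
            return tx.Put(key, []byte(strconv.Itoa(n+1)))
        })
    }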