veyron/services/store: A transaction may be used in only the session that created it.
For now, we identify a session by the local and remote principals.
Change-Id: I5d58bc3c10a1cf918bbdda88c095f34e3d0981c3
diff --git a/services/store/server/object.go b/services/store/server/object.go
index 377a57f..96c099f 100644
--- a/services/store/server/object.go
+++ b/services/store/server/object.go
@@ -87,9 +87,9 @@
// Exists returns true iff the Entry has a value.
func (o *object) Exists(ctx ipc.ServerContext, tid store.TransactionID) (bool, error) {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return false, errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return false, err
}
return o.obj.Exists(ctx.RemoteID(), t)
}
@@ -98,9 +98,9 @@
// most recent mutation of the entry in the Transaction, or from the
// Transaction's snapshot if there is no mutation.
func (o *object) Get(ctx ipc.ServerContext, tid store.TransactionID) (store.Entry, error) {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return nullEntry, errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return nullEntry, err
}
entry, err := o.obj.Get(ctx.RemoteID(), t)
if err != nil {
@@ -111,9 +111,9 @@
// Put modifies the value of the Object.
func (o *object) Put(ctx ipc.ServerContext, tid store.TransactionID, val vdl.Any) (store.Stat, error) {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return nullStat, errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return nullStat, err
}
stat, err := o.obj.Put(ctx.RemoteID(), t, interface{}(val))
if err != nil {
@@ -124,9 +124,9 @@
// Remove removes the Object.
func (o *object) Remove(ctx ipc.ServerContext, tid store.TransactionID) error {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return err
}
return o.obj.Remove(ctx.RemoteID(), t)
}
@@ -135,9 +135,9 @@
// replication groups. Attributes are associated with the value, not the
// path.
func (o *object) SetAttr(ctx ipc.ServerContext, tid store.TransactionID, attrs []vdl.Any) error {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return err
}
typedAttrs, err := attrsFromAnyData(attrs)
if err != nil {
@@ -148,9 +148,9 @@
// Stat returns entry info.
func (o *object) Stat(ctx ipc.ServerContext, tid store.TransactionID) (store.Stat, error) {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return nullStat, errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return nullStat, err
}
stat, err := o.obj.Stat(ctx.RemoteID(), t)
if err != nil {
@@ -161,9 +161,9 @@
// Query returns a sequence of objects that match the given query.
func (o *object) Query(ctx ipc.ServerContext, tid store.TransactionID, q query.Query, stream store.ObjectServiceQueryStream) error {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return err
}
it, err := o.obj.Query(ctx.RemoteID(), t, q)
if err != nil {
@@ -195,9 +195,9 @@
// Glob streams a series of names that match the given pattern.
func (o *object) GlobT(ctx ipc.ServerContext, tid store.TransactionID, pattern string, stream store.ObjectServiceGlobTStream) error {
- t, ok := o.server.findTransaction(tid)
- if !ok {
- return errTransactionDoesNotExist
+ t, err := o.server.findTransaction(ctx, tid)
+ if err != nil {
+ return err
}
it, err := o.obj.Glob(ctx.RemoteID(), t, pattern)
if err != nil {
diff --git a/services/store/server/server.go b/services/store/server/server.go
index 504d560..ae6bfd8 100644
--- a/services/store/server/server.go
+++ b/services/store/server/server.go
@@ -3,6 +3,7 @@
import (
"errors"
+ "reflect"
"sync"
"time"
@@ -34,6 +35,8 @@
errTransactionAlreadyExists = errors.New("transaction already exists")
errTransactionDoesNotExist = errors.New("transaction does not exist")
+ // Transaction exists, but may not be used by the caller.
+ errPermissionDenied = errors.New("permission denied")
)
// Server stores the dictionary of all media items. It has a scanner.Scanner
@@ -57,9 +60,23 @@
watcher service.Watcher
}
+// transactionContext defines the context in which a transaction is used. A
+// transaction may be used only in the context that created it.
+// transactionContext weakly identifies a session by the local and remote
+// principals involved in the RPC.
+// TODO(tilaks): Use the local and remote addresses to identify the session.
+// Does a session with a mobile device break if the remote address changes?
+type transactionContext interface {
+ // LocalID returns the PublicID of the principal at the local end of the request.
+ LocalID() security.PublicID
+ // RemoteID returns the PublicID of the principal at the remote end of the request.
+ RemoteID() security.PublicID
+}
+
type transaction struct {
- trans service.Transaction
- expires time.Time
+ trans service.Transaction
+ expires time.Time
+ creatorCtx transactionContext
}
// ServerConfig provides the parameters needed to construct a Server.
@@ -111,21 +128,47 @@
}
// findTransaction returns the transaction for the TransactionID.
-func (s *Server) findTransaction(id store.TransactionID) (service.Transaction, bool) {
+func (s *Server) findTransaction(ctx transactionContext, id store.TransactionID) (service.Transaction, error) {
s.mutex.RLock()
defer s.mutex.RUnlock()
- return s.findTransactionLocked(id)
+ return s.findTransactionLocked(ctx, id)
}
-func (s *Server) findTransactionLocked(id store.TransactionID) (service.Transaction, bool) {
+func (s *Server) findTransactionLocked(ctx transactionContext, id store.TransactionID) (service.Transaction, error) {
if id == nullTransactionID {
- return nil, true
+ return nil, nil
}
info, ok := s.transactions[id]
if !ok {
- return nil, false
+ return nil, errTransactionDoesNotExist
}
- return info.trans, true
+ // A transaction may be used only by the session (and therefore client)
+ // that created it.
+ if !info.matchesContext(ctx) {
+ return nil, errPermissionDenied
+ }
+ return info.trans, nil
+}
+
+func (t *transaction) matchesContext(ctx transactionContext) bool {
+ creatorCtx := t.creatorCtx
+ return membersEqual(creatorCtx.LocalID().Names(), ctx.LocalID().Names()) &&
+ membersEqual(creatorCtx.RemoteID().Names(), ctx.RemoteID().Names())
+}
+
+// membersEquals checks whether two slices of strings have the same set of
+// members, regardless of order.
+func membersEqual(slice1, slice2 []string) bool {
+ set1 := make(map[string]bool, len(slice1))
+ for _, s := range slice1 {
+ set1[s] = true
+ }
+ set2 := make(map[string]bool, len(slice2))
+ for _, s := range slice2 {
+ set2[s] = true
+ }
+ // DeepEqual tests keys for == equality, which is sufficient for strings.
+ return reflect.DeepEqual(set1, set2)
}
// gcLoop drops transactions that have expired.
@@ -151,7 +194,7 @@
}
// CreateTransaction creates a transaction.
-func (s *Server) CreateTransaction(_ ipc.ServerContext, id store.TransactionID, opts []vdl.Any) error {
+func (s *Server) CreateTransaction(ctx ipc.ServerContext, id store.TransactionID, opts []vdl.Any) error {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -160,8 +203,9 @@
return errTransactionAlreadyExists
}
info = &transaction{
- trans: memstore.NewTransaction(),
- expires: time.Now().Add(transactionMaxLifetime),
+ trans: memstore.NewTransaction(),
+ expires: time.Now().Add(transactionMaxLifetime),
+ creatorCtx: ctx,
}
s.transactions[id] = info
return nil
@@ -170,31 +214,37 @@
// Commit commits the changes in the transaction to the store. The
// operation is atomic, so all mutations are performed, or none. Returns an
// error if the transaction aborted.
-func (s *Server) Commit(_ ipc.ServerContext, id store.TransactionID) error {
+func (s *Server) Commit(ctx ipc.ServerContext, id store.TransactionID) error {
s.mutex.Lock()
defer s.mutex.Unlock()
- t, ok := s.findTransactionLocked(id)
- if !ok {
+ t, err := s.findTransactionLocked(ctx, id)
+ if err != nil {
+ return err
+ }
+ if t == nil {
return errTransactionDoesNotExist
}
- err := t.Commit()
+ err = t.Commit()
delete(s.transactions, id)
return err
}
// Abort discards a transaction.
-func (s *Server) Abort(_ ipc.ServerContext, id store.TransactionID) error {
+func (s *Server) Abort(ctx ipc.ServerContext, id store.TransactionID) error {
s.mutex.Lock()
defer s.mutex.Unlock()
- t, ok := s.transactions[id]
- if !ok {
+ t, err := s.findTransactionLocked(ctx, id)
+ if err != nil {
+ return err
+ }
+ if t == nil {
return errTransactionDoesNotExist
}
- t.trans.Abort()
+ err = t.Abort()
delete(s.transactions, id)
- return nil
+ return err
}
// Watch returns a stream of changes.
diff --git a/services/store/server/server_test.go b/services/store/server/server_test.go
index 5e56f73..fa87c83 100644
--- a/services/store/server/server_test.go
+++ b/services/store/server/server_test.go
@@ -20,65 +20,69 @@
)
var (
- rootPublicID security.PublicID = security.FakePublicID("root")
- rootName = fmt.Sprintf("%s", rootPublicID)
+ rootPublicID security.PublicID = security.FakePublicID("root")
+ rootName = fmt.Sprintf("%s", rootPublicID)
+ blessedPublicId security.PublicID = security.FakePublicID("root/blessed")
nextTransactionID store.TransactionID = 1
- rootCtx ipc.ServerContext = &rootContext{}
+ rootCtx ipc.ServerContext = &testContext{rootPublicID}
+ blessedCtx ipc.ServerContext = &testContext{blessedPublicId}
)
-type rootContext struct{}
+type testContext struct {
+ id security.PublicID
+}
-func (*rootContext) Server() ipc.Server {
+func (*testContext) Server() ipc.Server {
return nil
}
-func (*rootContext) Method() string {
+func (*testContext) Method() string {
return ""
}
-func (*rootContext) Name() string {
+func (*testContext) Name() string {
return ""
}
-func (*rootContext) Suffix() string {
+func (*testContext) Suffix() string {
return ""
}
-func (*rootContext) Label() (l security.Label) {
+func (*testContext) Label() (l security.Label) {
return
}
-func (*rootContext) CaveatDischarges() security.CaveatDischargeMap {
+func (*testContext) CaveatDischarges() security.CaveatDischargeMap {
return nil
}
-func (*rootContext) LocalID() security.PublicID {
- return rootPublicID
+func (ctx *testContext) LocalID() security.PublicID {
+ return ctx.id
}
-func (*rootContext) RemoteID() security.PublicID {
- return rootPublicID
+func (ctx *testContext) RemoteID() security.PublicID {
+ return ctx.id
}
-func (*rootContext) LocalAddr() net.Addr {
+func (*testContext) LocalAddr() net.Addr {
return nil
}
-func (*rootContext) RemoteAddr() net.Addr {
+func (*testContext) RemoteAddr() net.Addr {
return nil
}
-func (*rootContext) Deadline() (t time.Time) {
+func (*testContext) Deadline() (t time.Time) {
return
}
-func (rootContext) IsClosed() bool {
+func (testContext) IsClosed() bool {
return false
}
-func (rootContext) Closed() <-chan struct{} {
+func (testContext) Closed() <-chan struct{} {
return nil
}
@@ -197,18 +201,18 @@
o := s.lookupObject("/")
value := newValue()
tr1 := newTransaction()
- if err := s.CreateTransaction(nil, tr1, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr1, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if _, err := o.Put(rootCtx, tr1, value); err != nil {
t.Errorf("Unexpected error: %s", err)
}
- if err := s.Commit(nil, tr1); err != nil {
+ if err := s.Commit(rootCtx, tr1); err != nil {
t.Errorf("Unexpected error: %s", err)
}
tr2 := newTransaction()
- if err := s.CreateTransaction(nil, tr2, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr2, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if ok, err := o.Exists(rootCtx, tr2); !ok || err != nil {
@@ -217,7 +221,7 @@
if _, err := o.Get(rootCtx, tr2); err != nil {
t.Errorf("Object should exist: %s", err)
}
- if err := s.Abort(nil, tr2); err != nil {
+ if err := s.Abort(rootCtx, tr2); err != nil {
t.Errorf("Unexpected error: %s", err)
}
}
@@ -231,7 +235,7 @@
{
// Check that the object does not exist.
tr := newTransaction()
- if err := s.CreateTransaction(nil, tr, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if ok, err := o.Exists(rootCtx, tr); ok || err != nil {
@@ -245,7 +249,7 @@
{
// Add the object.
tr1 := newTransaction()
- if err := s.CreateTransaction(nil, tr1, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr1, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if _, err := o.Put(rootCtx, tr1, value); err != nil {
@@ -260,7 +264,7 @@
// Transactions are isolated.
tr2 := newTransaction()
- if err := s.CreateTransaction(nil, tr2, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr2, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if ok, err := o.Exists(rootCtx, tr2); ok || err != nil {
@@ -271,7 +275,7 @@
}
// Apply tr1.
- if err := s.Commit(nil, tr1); err != nil {
+ if err := s.Commit(rootCtx, tr1); err != nil {
t.Errorf("Unexpected error: %s", err)
}
@@ -285,7 +289,7 @@
// tr3 observes the commit.
tr3 := newTransaction()
- if err := s.CreateTransaction(nil, tr3, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr3, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if ok, err := o.Exists(rootCtx, tr3); !ok || err != nil {
@@ -299,7 +303,7 @@
{
// Remove the object.
tr1 := newTransaction()
- if err := s.CreateTransaction(nil, tr1, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr1, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if err := o.Remove(rootCtx, tr1); err != nil {
@@ -314,7 +318,7 @@
// The removal is isolated.
tr2 := newTransaction()
- if err := s.CreateTransaction(nil, tr2, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr2, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if ok, err := o.Exists(rootCtx, tr2); !ok || err != nil {
@@ -325,7 +329,7 @@
}
// Apply tr1.
- if err := s.Commit(nil, tr1); err != nil {
+ if err := s.Commit(rootCtx, tr1); err != nil {
t.Errorf("Unexpected error: %s", err)
}
@@ -341,7 +345,7 @@
{
// Check that the object does not exist.
tr1 := newTransaction()
- if err := s.CreateTransaction(nil, tr1, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr1, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
if ok, err := o.Exists(rootCtx, tr1); ok || err != nil {
@@ -353,6 +357,19 @@
}
}
+func TestNilTransaction(t *testing.T) {
+ s, c := newServer()
+ defer c()
+
+ if err := s.Commit(rootCtx, nullTransactionID); err != errTransactionDoesNotExist {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ if err := s.Abort(rootCtx, nullTransactionID); err != errTransactionDoesNotExist {
+ t.Errorf("Unexpected error: %v", err)
+ }
+}
+
func TestWatch(t *testing.T) {
s, c := newServer()
defer c()
@@ -364,7 +381,7 @@
// Before the watch request has been made, commit a transaction that puts /.
{
tr := newTransaction()
- if err := s.CreateTransaction(nil, tr, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
o := s.lookupObject(path1)
@@ -373,7 +390,7 @@
t.Errorf("Unexpected error: %s", err)
}
id1 = st.ID
- if err := s.Commit(nil, tr); err != nil {
+ if err := s.Commit(rootCtx, tr); err != nil {
t.Errorf("Unexpected error: %s", err)
}
}
@@ -401,7 +418,7 @@
// Commit a second transaction that puts /a.
{
tr := newTransaction()
- if err := s.CreateTransaction(nil, tr, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
o := s.lookupObject(path2)
@@ -410,7 +427,7 @@
t.Errorf("Unexpected error: %s", err)
}
id2 = st.ID
- if err := s.Commit(nil, tr); err != nil {
+ if err := s.Commit(rootCtx, tr); err != nil {
t.Errorf("Unexpected error: %s", err)
}
}
@@ -455,7 +472,7 @@
// Before the watch request has been made, commit a transaction that puts /.
{
tr := newTransaction()
- if err := s.CreateTransaction(nil, tr, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
o := s.lookupObject(path1)
@@ -464,7 +481,7 @@
t.Errorf("Unexpected error: %s", err)
}
id1 = st.ID
- if err := s.Commit(nil, tr); err != nil {
+ if err := s.Commit(rootCtx, tr); err != nil {
t.Errorf("Unexpected error: %s", err)
}
}
@@ -492,7 +509,7 @@
// Commit a second transaction that puts /a.
{
tr := newTransaction()
- if err := s.CreateTransaction(nil, tr, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
o := s.lookupObject(path2)
@@ -501,7 +518,7 @@
t.Errorf("Unexpected error: %s", err)
}
id2 = st.ID
- if err := s.Commit(nil, tr); err != nil {
+ if err := s.Commit(rootCtx, tr); err != nil {
t.Errorf("Unexpected error: %s", err)
}
}
@@ -532,14 +549,14 @@
// Commit a third transaction that removes /a.
{
tr := newTransaction()
- if err := s.CreateTransaction(nil, tr, nil); err != nil {
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != nil {
t.Errorf("Unexpected error: %s", err)
}
o := s.lookupObject("/a")
if err := o.Remove(rootCtx, tr); err != nil {
t.Errorf("Unexpected error: %s", err)
}
- if err := s.Commit(nil, tr); err != nil {
+ if err := s.Commit(rootCtx, tr); err != nil {
t.Errorf("Unexpected error: %s", err)
}
}
@@ -572,3 +589,68 @@
expectDoesNotExist(t, changes, id2)
}
}
+
+func TestTransactionSecurity(t *testing.T) {
+ s, c := newServer()
+ defer c()
+
+ // Create a root.
+ o := s.lookupObject("/")
+ value := newValue()
+
+ // Create a transaction in the root's session.
+ tr := newTransaction()
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != nil {
+ t.Errorf("Unexpected error: %s", err)
+ }
+ // Check that the transaction cannot be created or accessed by the blessee.
+ if err := s.CreateTransaction(blessedCtx, tr, nil); err != errTransactionAlreadyExists {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if _, err := o.Exists(blessedCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if _, err := o.Get(blessedCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if _, err := o.Put(blessedCtx, tr, value); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if err := o.Remove(blessedCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if err := s.Abort(blessedCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if err := s.Commit(blessedCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ // Create a transaction in the blessee's session.
+ tr = newTransaction()
+ if err := s.CreateTransaction(blessedCtx, tr, nil); err != nil {
+ t.Errorf("Unexpected error: %s", err)
+ }
+ // Check that the transaction cannot be created or accessed by the root.
+ if err := s.CreateTransaction(rootCtx, tr, nil); err != errTransactionAlreadyExists {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if _, err := o.Exists(rootCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if _, err := o.Get(rootCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if _, err := o.Put(rootCtx, tr, value); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if err := o.Remove(rootCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if err := s.Abort(rootCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if err := s.Commit(rootCtx, tr); err != errPermissionDenied {
+ t.Errorf("Unexpected error: %v", err)
+ }
+}