veyron/services/store/memstore/query: Fix race conditions.
This change makes evalIterator thread-safe. The problem with the
old code was that Next() might not see something on the error
channel since select is non-deterministic. it.err might not be
set even though there is a pending error on errc.
Change-Id: I3e1651a5859a35f5d1c6b683cfd123cc60342fd2
diff --git a/services/store/memstore/query/eval.go b/services/store/memstore/query/eval.go
index ab71a4e..7061a67 100644
--- a/services/store/memstore/query/eval.go
+++ b/services/store/memstore/query/eval.go
@@ -34,8 +34,7 @@
// if err := it.Err(); err != nil {
// ...
// }
-// Iterator is thread-compatible. In particular, Abort must be called from
-// the same thread that is iterating.
+// Iterator is thread-safe.
type Iterator interface {
// Next advances the iterator. It must be called before calling Get.
// Returns true if there is a value to be retrieved with Get. Returns
@@ -49,27 +48,32 @@
// idempotent.
Err() error
- // Abort stops query evaluation early. The client must call Abort unless
- // iteration goes to completion (i.e. Next returns false). It is not
- // idempotent and must be called from the same thread doing the iteration.
+ // Abort stops query evaluation early. MThe client must call Abort unless
+ // iteration goes to completion (i.e. Next returns false). It is
+ // idempotent and can be called from any thread.
Abort()
}
// evalIterator implements Iterator.
type evalIterator struct {
+ // mu guards 'result', 'err', and the closing of 'abort'.
+ mu sync.Mutex
// result is what Get will return. It will be nil if there are no more
- // query results.
+ // query results. Guarded by mu.
result *store.QueryResult
+ // err is the first error encountered during query evaluation.
+ // Guarded by mu.
+ err error
+ // abort is used as the signal to query evaluation to terminate early.
+ // evaluator implementations will test for abort closing. The close()
+ // call is guarded by mu.
+ abort chan bool
+
// results is the output of the top-level evaluator for the query.
results <-chan *store.QueryResult
- // abort is used as the signal to query evaluation to terminate early.
- // evaluator implementations will test for abort closing.
- abort chan bool
// errc is the path that evaluator implementations use to pass errors
// to evalIterator. Any error will abort query evaluation.
errc chan error
- // err is the first error encountered during query evaluation.
- err error
// cleanup is used for testing to ensure that no goroutines are leaked.
cleanup sync.WaitGroup
}
@@ -80,41 +84,61 @@
return false
}
select {
- case it.err = <-it.errc:
- it.Abort()
- return false
case result, ok := <-it.results:
if !ok {
return false
}
+ it.mu.Lock()
+ defer it.mu.Unlock()
// TODO(kash): Need to watch out for fields of type channel and pull them
// out of line.
it.result = result
return true
- case _, ok := <-it.abort:
- if !ok {
- return false
- }
- panic("Unexpected value received from it.abort")
+ case <-it.abort:
+ return false
}
}
// Get implements the Iterator method.
func (it *evalIterator) Get() *store.QueryResult {
+ it.mu.Lock()
+ defer it.mu.Unlock()
return it.result
}
// Abort implements the Iterator method.
func (it *evalIterator) Abort() {
- close(it.abort)
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ select {
+ case <-it.abort:
+ // Already closed.
+ default:
+ close(it.abort)
+ }
it.result = nil
}
// Err implements the Iterator method.
func (it *evalIterator) Err() error {
+ it.mu.Lock()
+ defer it.mu.Unlock()
return it.err
}
+// handleErrors watches for errors on it.errc, calling it.Abort when it finds
+// one. It should run in a goroutine.
+func (it *evalIterator) handleErrors() {
+ select {
+ case <-it.abort:
+ case err := <-it.errc:
+ it.mu.Lock()
+ it.err = err
+ it.mu.Unlock()
+ it.Abort()
+ }
+}
+
// wait blocks until all children goroutines are finished. This is useful in
// tests to ensure that an abort cleans up correctly.
func (it *evalIterator) wait() {
@@ -158,6 +182,7 @@
abort: make(chan bool),
errc: make(chan error),
}
+ go it.handleErrors()
it.cleanup.Add(1)
go evaluator.eval(&context{
sn: sn,