veyron/services/store/memstore/query: Add sample() function.
The sample() function randomly samples the stream of results using reservoir
sampling. sample() takes a single integer argument specifying the number
of samples to return.
I plan to use this function in the Store codelab to randomly select
a fortune.
Change-Id: I574a0fef77cff4fe602c1b300c80c5be33318fef
diff --git a/services/store/memstore/query/eval.go b/services/store/memstore/query/eval.go
index 783fbd8..1b2b9a2 100644
--- a/services/store/memstore/query/eval.go
+++ b/services/store/memstore/query/eval.go
@@ -1,6 +1,7 @@
package query
import (
+ "crypto/rand"
"fmt"
"math/big"
"reflect"
@@ -8,7 +9,6 @@
"sort"
"strings"
"sync"
- "veyron2/vlog"
vsync "veyron/runtimes/google/lib/sync"
"veyron/services/store/memstore/state"
@@ -21,6 +21,7 @@
"veyron2/services/store"
"veyron2/storage"
"veyron2/vdl/vdlutil"
+ "veyron2/vlog"
)
// maxChannelSize is the maximum size of the channels used for concurrent
@@ -622,22 +623,24 @@
switch p.FuncName {
case "sort":
if src.singleResult() {
- panic(fmt.Errorf("found aggregate function at %v, sort expects multiple results"))
+ panic(fmt.Errorf("%v: sort expects multiple inputs not a single input", p.Pos))
}
return &funcSortEvaluator{
src: convertPipeline(p.Src),
args: args,
pos: p.Pos,
}
+ case "sample":
+ return convertSampleFunc(src, args, p.Pos)
default:
panic(fmt.Errorf("unknown function %s at Pos %v", p.FuncName, p.Pos))
}
}
type funcSortEvaluator struct {
- // src produces intermediate results that will be transformed by func.
+ // src produces intermediate results that will be sorted.
src evaluator
- // args is the list of arguments passed to the function.
+ // args is the list of arguments passed to sort().
args []expr
// pos specifies where in the query string this component started.
pos parse.Pos
@@ -715,6 +718,77 @@
return a.results[i].Name < a.results[j].Name
}
+// funcSampleEvaluator is an evaluator that uses reservior sampling to
+// filter results to a desired number.
+type funcSampleEvaluator struct {
+ // src produces intermediate results that will be transformed by func.
+ src evaluator
+ // numSamples is the number of samples to send to the output.
+ numSamples int64
+ // pos specifies where in the query string this component started.
+ pos parse.Pos
+}
+
+func convertSampleFunc(src evaluator, args []expr, pos parse.Pos) evaluator {
+ if src.singleResult() {
+ panic(fmt.Errorf("%v: sample expects multiple inputs not a single input", pos))
+ }
+ if len(args) != 1 {
+ panic(fmt.Errorf("%v: sample expects exactly one integer argument specifying the number of results to include in the sample", pos))
+ }
+ n, ok := args[0].(*exprInt)
+ if !ok {
+ panic(fmt.Errorf("%v: sample expects exactly one integer argument specifying the number of results to include in the sample", pos))
+ }
+ return &funcSampleEvaluator{src, n.i.Int64(), pos}
+}
+
+// eval implements the evaluator method.
+func (e *funcSampleEvaluator) eval(c *context) {
+ defer c.cleanup.Done()
+ defer close(c.out)
+ srcOut := startSource(c, e.src)
+
+ reservoir := make([]*store.QueryResult, e.numSamples)
+ i := int64(0)
+ for result := range srcOut {
+ if i < e.numSamples {
+ // Fill the reservoir.
+ reservoir[i] = result
+ } else {
+ // Sample with decreasing probability.
+ bigJ, err := rand.Int(rand.Reader, big.NewInt(i+1))
+ if err != nil {
+ c.evalIt.setErrorf("error while sampling: %v", err)
+ return
+ }
+ j := bigJ.Int64()
+ if j < e.numSamples {
+ reservoir[j] = result
+ }
+ }
+ i++
+ }
+ for _, result := range reservoir {
+ if !c.emit(result) {
+ return
+ }
+ }
+}
+
+// singleResult implements the evaluator method.
+func (e *funcSampleEvaluator) singleResult() bool {
+ // During construction, we tested that e.src is not singleResult.
+ return false
+}
+
+// name implements the evaluator method.
+func (e *funcSampleEvaluator) name() string {
+ // A sampled resultset is still the same as the original resultset, so it
+ // should have the same name.
+ return e.src.name()
+}
+
// predicate determines whether an intermediate query result should be
// filtered out.
type predicate interface {
diff --git a/services/store/memstore/query/eval_test.go b/services/store/memstore/query/eval_test.go
index 6c28c11..e74de10 100644
--- a/services/store/memstore/query/eval_test.go
+++ b/services/store/memstore/query/eval_test.go
@@ -148,6 +148,9 @@
names := map[string]bool{}
for it.Next() {
result := it.Get()
+ if _, ok := names[result.Name]; ok {
+ t.Errorf("query: %s, duplicate results for %s", test.query, result.Name)
+ }
names[result.Name] = true
}
if it.Err() != nil {
@@ -168,6 +171,54 @@
}
}
+func TestSample(t *testing.T) {
+ st := populate(t)
+
+ type testCase struct {
+ query string
+ expectedNumNames int
+ }
+
+ tests := []testCase{
+ {"teams/* | type team | sample(1)", 1},
+ {"teams/* | type team | sample(2)", 2},
+ {"teams/* | type team | sample(3)", 3},
+ {"teams/* | type team | sample(4)", 3}, // Can't sample more values than exist.
+ }
+
+ for _, test := range tests {
+ it := Eval(st.Snapshot(), rootPublicID, storage.ParsePath(""), query.Query{test.query})
+ names := make(map[string]struct{})
+ for it.Next() {
+ result := it.Get()
+ if _, ok := names[result.Name]; ok {
+ t.Errorf("query: %s, duplicate results for %s", test.query, result.Name)
+ }
+ names[result.Name] = struct{}{}
+ }
+ if it.Err() != nil {
+ t.Errorf("query: %s, Error during eval: %v", test.query, it.Err())
+ continue
+ }
+ if len(names) != test.expectedNumNames {
+ t.Errorf("query: %s, Wrong number of names. got %v, wanted %v", test.query, names, test.expectedNumNames)
+ continue
+ }
+ possibleNames := map[string]struct{}{
+ "teams/cardinals": struct{}{},
+ "teams/sharks": struct{}{},
+ "teams/bears": struct{}{},
+ }
+ for name, _ := range names {
+ if _, ok := possibleNames[name]; !ok {
+ t.Errorf("Did not find '%s' in %v", name, possibleNames)
+ }
+ }
+ // Ensure that all the goroutines are cleaned up.
+ it.(*evalIterator).wait()
+ }
+}
+
func TestSorting(t *testing.T) {
st := populate(t)
sn := st.MutableSnapshot()
@@ -524,6 +575,9 @@
{"teams/* | ?Name > 'foo'", "could not look up name 'Name' relative to 'teams': not found"},
{"'teams/cardinals' | {myname: Name, myloc: Location} | ? Name == 'foo'", "name 'Name' was not selected from 'teams/cardinals', found: [myloc, myname]"},
{"teams/* | type team | sort(Name) | ?-Name > 'foo'", "cannot negate value of type string for teams/bears"},
+ {"teams/* | type team | sample(2, 3)", "1:21: sample expects exactly one integer argument specifying the number of results to include in the sample"},
+ {"teams/* | type team | sample(2.0)", "1:21: sample expects exactly one integer argument specifying the number of results to include in the sample"},
+ {"teams/* | type team | sample(-1)", "1:21: sample expects exactly one integer argument specifying the number of results to include in the sample"},
// TODO(kash): Selection with conflicting names.
// TODO(kash): Trying to sort an aggregate. "... | avg | sort()"