v.io/x/lib/simplemr: add the ability to cancel a running MR.
Change-Id: I97a46102d467a4a5f2f661332d48ab66d61b88ae
diff --git a/simplemr/.api b/simplemr/.api
index 16b6f34..a3ce5c5 100644
--- a/simplemr/.api
+++ b/simplemr/.api
@@ -1,6 +1,9 @@
pkg simplemr, method (*Identity) Map(*MR, string, interface{}) error
pkg simplemr, method (*Identity) Reduce(*MR, string, []interface{}) error
+pkg simplemr, method (*MR) Cancel()
+pkg simplemr, method (*MR) CancelCh() <-chan struct{}
pkg simplemr, method (*MR) Error() error
+pkg simplemr, method (*MR) IsCancelled() bool
pkg simplemr, method (*MR) MapOut(string, ...interface{})
pkg simplemr, method (*MR) ReduceOut(string, ...interface{})
pkg simplemr, method (*MR) Run(<-chan *Record, chan<- *Record, Mapper, Reducer) error
@@ -15,3 +18,4 @@
pkg simplemr, type Record struct, Values []interface{}
pkg simplemr, type Reducer interface { Reduce }
pkg simplemr, type Reducer interface, Reduce(*MR, string, []interface{}) error
+pkg simplemr, var ErrMRCancelled error
diff --git a/simplemr/mr.go b/simplemr/mr.go
index 09d0347..586a241 100644
--- a/simplemr/mr.go
+++ b/simplemr/mr.go
@@ -13,6 +13,7 @@
package simplemr
import (
+ "errors"
"fmt"
"runtime"
"sort"
@@ -20,6 +21,8 @@
"time"
)
+var ErrMRCancelled = errors.New("MR cancelled")
+
// Mapper is in the interface that must be implemented by all mappers.
type Mapper interface {
// Map is called by the framework for every key, value pair read
@@ -74,10 +77,14 @@
// MR represents the Map Reduction.
type MR struct {
- input <-chan *Record
- output chan<- *Record
- err error
- data *store
+ input <-chan *Record
+ output chan<- *Record
+ cancel chan struct{}
+ cancelled bool
+ cancelled_mu sync.RWMutex // guards cancelled
+ err error
+ data *store
+
// The number of conccurent mappers to use. A value of 0 instructs
// the implementation to use an appropriate number, such as the number
// of available CPUs.
@@ -106,6 +113,33 @@
mr.output <- &Record{key, values}
}
+// CancelCh returns a channel that will be closed when the Cancel
+// method is called. It should only be called by a mapper or reducer.
+func (mr *MR) CancelCh() <-chan struct{} {
+ return mr.cancel
+}
+
+// Cancel closes the channel intended to be used for monitoring
+// cancellation requests. If Cancel is called before any reducers
+// have been run then no reducers will be run. It can only be called
+// after mr.Run has been called, generally by a mapper or a reducer.
+func (mr *MR) Cancel() {
+ mr.cancelled_mu.Lock()
+ defer mr.cancelled_mu.Unlock()
+ if mr.cancelled {
+ return
+ }
+ close(mr.cancel)
+ mr.cancelled = true
+}
+
+// IsCancelled returns true if this MR has been cancelled.
+func (mr *MR) IsCancelled() bool {
+ mr.cancelled_mu.RLock()
+ defer mr.cancelled_mu.RUnlock()
+ return mr.cancelled
+}
+
func (mr *MR) runMapper(ch chan error, mapper Mapper) {
for {
rec := <-mr.input
@@ -139,6 +173,8 @@
if done == mr.NumMappers {
return nil
}
+ case <-mr.cancel:
+ return ErrMRCancelled
case <-timeout:
return fmt.Errorf("timed out mappers after %s", mr.Timeout)
}
@@ -172,6 +208,7 @@
// Run may only be called once per MR receiver.
func (mr *MR) Run(input <-chan *Record, output chan<- *Record, mapper Mapper, reducer Reducer) error {
mr.input, mr.output, mr.data = input, output, newStore()
+ mr.cancel = make(chan struct{})
if mr.NumMappers == 0 {
// TODO(cnicolaou,toddw): consider using a new goroutine
// for every input record rather than fixing concurrency like
@@ -187,6 +224,12 @@
if mr.err = mr.runMappers(mapper, timeout); mr.err != nil {
return mr.err
}
+ if mr.IsCancelled() {
+ return ErrMRCancelled
+ }
mr.err = mr.runReducers(reducer, timeout)
+ if mr.IsCancelled() {
+ return ErrMRCancelled
+ }
return mr.err
}
diff --git a/simplemr/mr_test.go b/simplemr/mr_test.go
index 00ed1f8..baf5077 100644
--- a/simplemr/mr_test.go
+++ b/simplemr/mr_test.go
@@ -5,7 +5,9 @@
package simplemr_test
import (
+ "errors"
"fmt"
+ "math/rand"
"runtime"
"strings"
"testing"
@@ -203,3 +205,62 @@
t.Fatal(err)
}
}
+
+type cancelMR struct{ cancelMapper bool }
+
+var (
+ errMapperCancelled = errors.New("mapper cancelled")
+ errReducerCancelled = errors.New("reducer cancelled")
+)
+
+func cancelEg(mr *simplemr.MR) error {
+ delay := rand.Int63n(1000) * int64(time.Millisecond)
+ select {
+ case <-mr.CancelCh():
+ return nil
+ case <-time.After(time.Duration(delay)):
+ mr.Cancel()
+ return nil
+ case <-time.After(time.Hour):
+ }
+ return fmt.Errorf("timeout")
+}
+
+func (c *cancelMR) Map(mr *simplemr.MR, key string, val interface{}) error {
+ if c.cancelMapper {
+ return cancelEg(mr)
+ }
+ mr.MapOut(key, val)
+ return nil
+}
+
+func (c *cancelMR) Reduce(mr *simplemr.MR, key string, values []interface{}) error {
+ if !c.cancelMapper {
+ return cancelEg(mr)
+ }
+ panic("should never get here")
+ return nil
+}
+
+func testCancel(t *testing.T, mapper bool) {
+ mrt := &simplemr.MR{}
+ in, out := newChans(10)
+ cancel := &cancelMR{true}
+ genInput := func() {
+ in <- &simplemr.Record{"d1", []interface{}{d1, d2, d3}}
+ in <- &simplemr.Record{"d2", []interface{}{d1, d2, d3}}
+ close(in)
+ }
+ go genInput()
+ if got, want := mrt.Run(in, out, cancel, cancel), simplemr.ErrMRCancelled; got != want {
+ t.Fatalf("got %v, want %v", got, want)
+ }
+}
+
+func TestCancelMappers(t *testing.T) {
+ testCancel(t, true)
+}
+
+func TestCancelReducers(t *testing.T) {
+ testCancel(t, false)
+}