veyron/lib/expect: simple library for making assertions about expected input.
Change-Id: I44fbfb12aec74705f103329000fc79cb51c4900e
diff --git a/lib/expect/expect.go b/lib/expect/expect.go
new file mode 100644
index 0000000..a748a9d
--- /dev/null
+++ b/lib/expect/expect.go
@@ -0,0 +1,231 @@
+// Package expect provides support for testing the contents from a buffered
+// input stream. It supports literal and pattern based matching. It is
+// line oriented; all of the methods (expect ReadAll) strip trailing newlines
+// from their return values. It places a timeout on all its operations.
+// It will generally be used to read from the stdout stream of subprocesses
+// in tests and other situations and to make 'assertions'
+// about what is to be read.
+//
+// A Session type is used to store state, in particular error state, across
+// consecutive invocations of its method set. If a particular method call
+// encounters an error then subsequent calls on that Session will have no
+// effect. This allows for a series of assertions to be made, one per line,
+// and for errors to be checked at the end. In addition Session is designed
+// to be easily used with the testing package; passing a testing.T instance
+// to NewSession allows it to set errors directly and hence tests will pass or
+// fail according to whether the expect assertions are met or not.
+//
+// Care is taken to ensure that the file and line number of the first
+// failed assertion in the session are recorded in the error stored in
+// the Session.
+//
+// Examples
+//
+// func TestSomething(t *testing.T) {
+// buf := []byte{}
+// buffer := bytes.NewBuffer(buf)
+// buffer.WriteString("foo\n")
+// buffer.WriteString("bar\n")
+// buffer.WriteString("baz\n")
+// s := expect.New(t, bufio.NewReader(buffer), time.Second)
+// s.Expect("foo")
+// s.Expect("bars)
+// if got, want := s.ReadLine(), "baz"; got != want {
+// t.Errorf("got %v, want %v", got, want)
+// }
+// }
+//
+package expect
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strings"
+ "time"
+)
+
+var (
+ Timeout = errors.New("timeout")
+)
+
+// Session represents the state of an expect session.
+type Session struct {
+ input *bufio.Reader
+ timeout time.Duration
+ t Testing
+ err error
+}
+
+type Testing interface {
+ Error(args ...interface{})
+}
+
+// NewSession creates a new Session. The parameter t may be safely be nil.
+func NewSession(t Testing, input *bufio.Reader, timeout time.Duration) *Session {
+ return &Session{t: t, timeout: timeout, input: input}
+}
+
+// Failed returns true if an error has been encountered by a prior call.
+func (s *Session) Failed() bool {
+ return s.err != nil
+}
+
+// Error returns the error code (possibly nil) currently stored in the Session.
+func (s *Session) Error() error {
+ return s.err
+}
+
+// ReportError calls Testing.Error to report any error currently stored
+// in the Session.
+func (s *Session) ReportError() {
+ if s.err != nil && s.t != nil {
+ s.t.Error(s.err)
+ }
+}
+
+// error must always be called from a public function that is called
+// directly by an external user, otherwise the file:line info will
+// be incorrect.
+func (s *Session) error(err error) error {
+ _, file, line, _ := runtime.Caller(2)
+ s.err = fmt.Errorf("%s:%d: %s", filepath.Base(file), line, err)
+ s.ReportError()
+ return s.err
+}
+
+type reader func(r *bufio.Reader) (string, error)
+
+func readAll(r *bufio.Reader) (string, error) {
+ all := ""
+ for {
+ buf := make([]byte, 4096*4)
+ n, err := r.Read(buf)
+ all += string(buf[:n])
+ if err != nil {
+ if err == io.EOF {
+ return all, nil
+ }
+ return all, err
+ }
+ }
+}
+
+func readLine(r *bufio.Reader) (string, error) {
+ return r.ReadString('\n')
+}
+
+func (s *Session) read(f reader) (string, error) {
+ ch := make(chan string, 1)
+ ech := make(chan error, 1)
+ go func(fn reader, io *bufio.Reader) {
+ s, err := fn(io)
+ if err != nil {
+ ech <- err
+ return
+ }
+ ch <- s
+ }(f, s.input)
+ select {
+ case err := <-ech:
+ return "", err
+ case m := <-ch:
+ return m, nil
+ case <-time.After(s.timeout):
+ return "", Timeout
+ }
+}
+
+// Expect asserts that the next line in the input matches the supplied string.
+func (s *Session) Expect(expected string) {
+ if s.Failed() {
+ return
+ }
+ line, err := s.read(readLine)
+ if err != nil {
+ s.error(err)
+ return
+ }
+ line = strings.TrimRight(line, "\n")
+ if line != expected {
+ s.error(fmt.Errorf("got %q, want %q", line, expected))
+ }
+ return
+}
+
+func (s *Session) expectRE(pattern string, n int) (string, [][]string, error) {
+ if s.Failed() {
+ return "", nil, s.err
+ }
+ re, err := regexp.Compile(pattern)
+ if err != nil {
+ return "", nil, err
+ }
+ line, err := s.read(readLine)
+ if err != nil {
+ return "", nil, err
+ }
+ line = strings.TrimRight(line, "\n")
+ return line, re.FindAllStringSubmatch(line, n), err
+}
+
+// ExpectRE asserts that the next line in the input matches the pattern using
+// regexp.MustCompile(pattern,n).FindAllStringSubmatch.
+func (s *Session) ExpectRE(pattern string, n int) [][]string {
+ if s.Failed() {
+ return [][]string{}
+ }
+ l, m, err := s.expectRE(pattern, n)
+ if err != nil {
+ s.error(err)
+ return [][]string{}
+ }
+ if len(m) == 0 {
+ s.error(fmt.Errorf("%q found no match in %q", pattern, l))
+ }
+ return m
+}
+
+// ExpectVar asserts that the next line in the input matches the pattern
+// <name>=<value> and returns <value>.
+func (s *Session) ExpectVar(name string) string {
+ if s.Failed() {
+ return ""
+ }
+ l, m, err := s.expectRE(name+"=(.*)", 1)
+ if err != nil {
+ s.error(err)
+ return ""
+ }
+ if len(m) != 1 || len(m[0]) != 2 {
+ s.error(fmt.Errorf("failed to find value for %q in %q", name, l))
+ return ""
+ }
+ return m[0][1]
+}
+
+// ReadLine reads the next line, if any, from the input stream. It will set
+// the error state to io.EOF if it has read past the end of the stream.
+func (s *Session) ReadLine() string {
+ if s.Failed() {
+ return ""
+ }
+ l, err := s.read(readLine)
+ if err != nil {
+ s.error(err)
+ }
+ return strings.TrimRight(l, "\n")
+}
+
+// ReadAll reads all remaining input on the stream. Unlike all of the other
+// methods it does not strip newlines from the input.
+func (s *Session) ReadAll() (string, error) {
+ if s.Failed() {
+ return "", s.err
+ }
+ return s.read(readAll)
+}
diff --git a/lib/expect/expect_test.go b/lib/expect/expect_test.go
new file mode 100644
index 0000000..1d4906b
--- /dev/null
+++ b/lib/expect/expect_test.go
@@ -0,0 +1,90 @@
+package expect_test
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+
+ "veyron/lib/expect"
+)
+
+func TestSimple(t *testing.T) {
+ buf := []byte{}
+ buffer := bytes.NewBuffer(buf)
+ buffer.WriteString("bar\n")
+ buffer.WriteString("baz\n")
+ buffer.WriteString("oops\n")
+ s := expect.NewSession(nil, bufio.NewReader(buffer), time.Minute)
+ s.Expect("bar")
+ s.Expect("baz")
+ if err := s.Error(); err != nil {
+ t.Error(err)
+ }
+ // This will fail the test.
+ s.Expect("not oops")
+ if err := s.Error(); err == nil {
+ t.Error("unexpected success")
+ } else {
+ t.Log(s.Error())
+ }
+}
+
+func TestExpectRE(t *testing.T) {
+ buf := []byte{}
+ buffer := bytes.NewBuffer(buf)
+ buffer.WriteString("bar=baz\n")
+ buffer.WriteString("aaa\n")
+ buffer.WriteString("bbb\n")
+ s := expect.NewSession(nil, bufio.NewReader(buffer), time.Minute)
+ if got, want := s.ExpectVar("bar"), "baz"; got != want {
+ t.Errorf("got %v, want %v", got, want)
+ }
+ s.ExpectRE("zzz|aaa", -1)
+ if err := s.Error(); err != nil {
+ t.Error(err)
+ }
+ if got, want := s.ExpectRE("(.*)", -1), [][]string{{"bbb", "bbb"}}; !reflect.DeepEqual(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+ if got, want := s.ExpectRE("(.*", -1), [][]string{{"bbb", "bbb"}}; !reflect.DeepEqual(got, want) {
+ // this will have failed the test also.
+ if err := s.Error(); err == nil || !strings.Contains(err.Error(), "error parsing regexp") {
+ t.Errorf("missing or wrong error: %v", s.Error())
+ }
+ }
+}
+
+func TestRead(t *testing.T) {
+ buf := []byte{}
+ buffer := bytes.NewBuffer(buf)
+ lines := []string{"some words", "bar=baz", "more words"}
+ for _, l := range lines {
+ buffer.WriteString(l + "\n")
+ }
+ s := expect.NewSession(nil, bufio.NewReader(buffer), time.Minute)
+ for _, l := range lines {
+ if got, want := s.ReadLine(), l; got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+ }
+ if s.Failed() {
+ t.Errorf("unexpected error: %s", s.Error())
+ }
+ want := ""
+ for i := 0; i < 100; i++ {
+ m := fmt.Sprintf("%d\n", i)
+ buffer.WriteString(m)
+ want += m
+ }
+ got, err := s.ReadAll()
+ if err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+}