blob: ab97dd3a5324c71333c076a58688e5fec61fc0e3 [file] [log] [blame]
package testwriter
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"sync"
"time"
"veyron.io/veyron/veyron/services/wsprd/lib"
)
// TestHarness is the minimal failure-reporting surface CheckResponses needs.
// *testing.T satisfies it, as does any fake that wants to capture failures.
type TestHarness interface {
	// Errorf reports a failure described by the Printf-style format and args.
	Errorf(format string, args ...interface{})
}
// Response is one message captured by Writer: the response type handed to
// Send together with its payload.
type Response struct {
// Type is the response type that was passed to Send.
Type lib.ResponseType
// Message is the payload. For responses stored by Writer.Send this is the
// value obtained by decoding the JSON encoding of the original message,
// not the original Go value itself.
Message interface{}
}
// Writer records the responses and the error it is given so that a test can
// inspect them afterwards (see CheckResponses).
type Writer struct {
// The embedded mutex is held by Send while it appends to Stream and by
// WaitForMessage while it changes notifier.
sync.Mutex
// Stream holds the responses recorded by Send, in the order appended.
Stream []Response
// err is the value most recently passed to Error.
err error
// If this channel is set, Send signals on it after appending a response
// to Stream. WaitForMessage installs it while it is waiting.
notifier chan bool
}
// Send records a response with the given type and message at the end of
// w.Stream.
//
// The response is serialized to JSON and deserialized again before it is
// stored, so that the stream can be compared with reflect.DeepEqual even when
// messages contain non-exported structs. If encoding or decoding fails the
// error is returned and nothing is appended.
//
// If a notifier channel is installed, Send signals it after appending.
func (w *Writer) Send(responseType lib.ResponseType, msg interface{}) error {
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(Response{Type: responseType, Message: msg}); err != nil {
		return err
	}
	var r Response
	if err := json.NewDecoder(&buf).Decode(&r); err != nil {
		return err
	}

	w.Lock()
	defer w.Unlock()
	w.Stream = append(w.Stream, r)
	if w.notifier != nil {
		// Never block while holding the lock: a waiter needs the same lock to
		// re-read the stream length, so a blocking send on a full channel
		// would deadlock. If the send cannot proceed because a signal is
		// already pending, the waiter will wake up and re-check anyway.
		select {
		case w.notifier <- true:
		default:
		}
	}
	return nil
}
// Error records err as the error reported to this writer, replacing any
// previously recorded one. The write happens under the Writer's mutex so it
// does not race with other goroutines that access the Writer under the lock.
func (w *Writer) Error(err error) {
	w.Lock()
	defer w.Unlock()
	w.err = err
}
// streamLength reports how many responses are currently in w.Stream, reading
// the length while holding the Writer's mutex.
func (w *Writer) streamLength() int {
	w.Lock()
	count := len(w.Stream)
	w.Unlock()
	return count
}
// WaitForMessage blocks until w.Stream holds at least n messages.
//
// It returns nil as soon as the stream is long enough. While the stream is
// still short, it waits for a signal on the notifier channel and re-checks;
// if one second passes without a signal it returns a "timed out" error.
//
// The notifier is removed again on every return path, so it is never left
// installed once this call has returned.
func (w *Writer) WaitForMessage(n int) error {
	// Check the length and install the notifier in one critical section so
	// the fast path costs a single lock acquisition.
	w.Lock()
	if len(w.Stream) >= n {
		w.Unlock()
		return nil
	}
	// Capacity 1: one pending signal is enough, because the loop below
	// re-reads the stream length after every wake-up.
	notifier := make(chan bool, 1)
	w.notifier = notifier
	w.Unlock()

	defer func() {
		w.Lock()
		w.notifier = nil
		w.Unlock()
	}()

	for w.streamLength() < n {
		select {
		case <-notifier:
			// Something may have been appended; loop to re-check the length.
		case <-time.After(time.Second):
			return fmt.Errorf("timed out")
		}
	}
	return nil
}
// CheckResponses verifies what w has recorded against expectations and
// reports every mismatch through t.Errorf.
//
// expectedStream is compared with w.Stream and err with the writer's recorded
// error, both using reflect.DeepEqual. The two checks are independent, so a
// single call can report up to two failures.
func CheckResponses(w *Writer, expectedStream []Response, err error, t TestHarness) {
	// Snapshot the recorded state under the Writer's mutex, which Send holds
	// while appending, so a Send still in flight cannot race with the reads.
	w.Lock()
	gotStream, gotErr := w.Stream, w.err
	w.Unlock()

	if !reflect.DeepEqual(expectedStream, gotStream) {
		t.Errorf("streams don't match: expected %v, got %v", expectedStream, gotStream)
	}
	if !reflect.DeepEqual(err, gotErr) {
		t.Errorf("unexpected error, got: %v, expected: %v", gotErr, err)
	}
}