blob: ccf1cf96e5e09bd094054c48c5551655b7602a00 [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mocknet_test
import (
"bytes"
"errors"
"io"
"net"
"reflect"
"sync"
"testing"
"time"
"v.io/v23"
"v.io/v23/context"
"v.io/v23/naming"
"v.io/v23/options"
"v.io/v23/rpc"
"v.io/v23/verror"
_ "v.io/x/ref/runtime/factories/generic"
"v.io/x/ref/runtime/internal/rpc/stream/crypto"
"v.io/x/ref/runtime/internal/rpc/stream/message"
"v.io/x/ref/runtime/internal/testing/mocks/mocknet"
"v.io/x/ref/test"
)
//go:generate v23 test generate
// newListener creates a mocknet listener configured with opts on an
// ephemeral loopback port ("127.0.0.1:0"). Any error from mocknet aborts the
// calling test via t.Fatal.
func newListener(t *testing.T, opts mocknet.Opts) net.Listener {
	listener, err := mocknet.ListenerWithOpts(opts, "test", "127.0.0.1:0")
	if err == nil {
		return listener
	}
	t.Fatal(err)
	return nil // not reached: t.Fatal stops the calling goroutine.
}
// TestTrace exercises mocknet's Trace mode: it sends three messages over a
// dialed connection, reads each one back on the accepted connection, and then
// checks that both the Tx and Rx trace channels contain the size of every
// message followed by a trailing -1 once the receiving connection is closed.
func TestTrace(t *testing.T) {
	opts := mocknet.Opts{
		Mode: mocknet.Trace,
		Tx:   make(chan int, 100),
		Rx:   make(chan int, 100),
	}
	ln := newListener(t, opts)
	defer ln.Close()

	// Accept in the background. The error is carried back to the test
	// goroutine because t.Fatal must only be called from there.
	var (
		rxconn    net.Conn
		acceptErr error
		wg        sync.WaitGroup
	)
	wg.Add(1)
	go func() {
		defer wg.Done()
		rxconn, acceptErr = ln.Accept()
	}()

	txconn, err := mocknet.DialerWithOpts(opts, "test", ln.Addr().String(), time.Minute)
	if err != nil {
		t.Fatal(err)
	}
	wg.Wait()
	if acceptErr != nil {
		t.Fatalf("accept failed: %v", acceptErr)
	}
	// NOTE(review): txconn is intentionally not closed here; closing it may
	// emit additional trace entries (or send on the channels closed below).
	// Confirm against mocknet before adding a Close.

	// rw writes s on the dialed side and expects to read it back verbatim
	// on the accepted side.
	rw := func(s string) {
		b := make([]byte, len(s))
		if _, err := txconn.Write([]byte(s)); err != nil {
			t.Fatalf("write %q: %v", s, err)
		}
		if _, err := rxconn.Read(b); err != nil {
			t.Fatalf("read %q: %v", s, err)
		}
		if got, want := string(b), s; got != want {
			t.Fatalf("got %v, want %v", got, want)
		}
	}

	msgs := []string{"hello", " ", "world"}
	sizes := make([]int, 0, len(msgs)+1)
	for _, s := range msgs {
		rw(s)
		sizes = append(sizes, len(s))
	}

	rxconn.Close()
	// Closing the trace channels lets drain terminate once they are empty.
	close(opts.Tx)
	close(opts.Rx)
	// The test expects a -1 entry on each channel after the close above.
	sizes = append(sizes, -1)

	// drain collects everything buffered in ch; ch must already be closed.
	drain := func(ch chan int) []int {
		r := []int{}
		for v := range ch {
			r = append(r, v)
		}
		return r
	}
	if got, want := drain(opts.Rx), sizes; !reflect.DeepEqual(got, want) {
		t.Fatalf("got %v, want %v", got, want)
	}
	if got, want := drain(opts.Tx), sizes; !reflect.DeepEqual(got, want) {
		t.Fatalf("got %v, want %v", got, want)
	}
}
// TestClose exercises mocknet's Close mode. Each case configures byte
// thresholds (TxCloseAt/RxCloseAt), writes the tx messages one at a time,
// reads each back, and checks the data received per message, the total number
// of bytes written and read, and the text of any read error.
func TestClose(t *testing.T) {
	cases := []struct {
		txClose, rxClose int
		tx               []string
		rx               []string
		err              error
	}{
		{6, 10, []string{"hello", "world"}, []string{"hello", "w"}, io.EOF},
		{5, 10, []string{"hello", "world"}, []string{"hello", ""}, io.EOF},
		{8, 6, []string{"hello", "world"}, []string{"hello", "w"}, io.EOF},
		{8, 5, []string{"hello", "world"}, []string{"hello", ""}, errors.New("use of closed network connection")},
	}
	for ci, c := range cases {
		// Each case runs in its own function so that the deferred Close
		// calls fire at the end of the iteration instead of piling up until
		// the whole test returns.
		func() {
			opts := mocknet.Opts{
				Mode:      mocknet.Close,
				TxCloseAt: c.txClose,
				RxCloseAt: c.rxClose,
			}
			ln := newListener(t, opts)
			defer ln.Close()

			// Accept in the background; the error is checked on the test
			// goroutine because t.Fatal must only be called from there.
			var (
				rxconn    net.Conn
				acceptErr error
				wg        sync.WaitGroup
			)
			wg.Add(1)
			go func() {
				defer wg.Done()
				rxconn, acceptErr = ln.Accept()
			}()

			txconn, err := mocknet.DialerWithOpts(opts, "test", ln.Addr().String(), time.Minute)
			if err != nil {
				t.Fatal(err)
			}
			// Best-effort cleanup: the connection may already have been
			// closed by the mock, in which case the error is irrelevant.
			defer txconn.Close()
			wg.Wait()
			if acceptErr != nil {
				t.Fatalf("%d: accept failed: %v", ci, acceptErr)
			}
			defer rxconn.Close()

			// rw writes s and reads back at most len(s) bytes, returning the
			// bytes written, bytes read, data read and the read error. The
			// write error is deliberately ignored: hitting the close
			// threshold mid-write is the behavior under test.
			rw := func(s string) (int, int, string, error) {
				b := make([]byte, len(s))
				tx, _ := txconn.Write([]byte(s))
				rx, err := rxconn.Read(b)
				return tx, rx, string(b[:rx]), err
			}

			var txBytes, rxBytes int
			for i, m := range c.tx {
				tx, rx, rxed, err := rw(m)
				if got, want := rxed, c.rx[i]; got != want {
					t.Fatalf("%d: got %q, want %q", ci, got, want)
				}
				txBytes += tx
				rxBytes += rx
				// A read error is optional, but when one occurs it must be
				// the one the case expects.
				if err != nil {
					if got, want := err.Error(), c.err.Error(); got != want {
						t.Fatalf("%d: got %v, want %v", ci, got, want)
					}
				}
			}
			if got, want := txBytes, c.txClose; got != want {
				t.Fatalf("%d: got %v, want %v", ci, got, want)
			}
			// The receiver cannot see more bytes than were transmitted.
			rxWant := c.rxClose
			if rxWant > c.txClose {
				rxWant = c.txClose
			}
			if got, want := rxBytes, rxWant; got != want {
				t.Fatalf("%d: got %v, want %v", ci, got, want)
			}
		}()
	}
}
// TestDrop exercises mocknet's Drop mode. For each case it installs a
// TxDropAfter callback returning a fixed value, writes every tx message,
// reads into a buffer sized to the expected rx data, and checks both the
// data received and that the byte count reported by Write equals the count
// returned by Read.
func TestDrop(t *testing.T) {
	cases := []struct {
		txDropAfter int
		tx          []string
		rx          []string
	}{
		{6, []string{"hello", "world"}, []string{"hello", "w"}},
		{2, []string{"hello", "world"}, []string{"he", "wo"}},
		{0, []string{"hello", "world"}, []string{"", ""}},
	}
	for ci, c := range cases {
		// Each case runs in its own function so that the deferred Close
		// calls fire at the end of the iteration instead of at test exit.
		func() {
			// Copy the value so the callback does not capture the loop
			// variable, which is shared across iterations before Go 1.22.
			dropAfter := c.txDropAfter
			opts := mocknet.Opts{
				Mode:        mocknet.Drop,
				TxDropAfter: func() int { return dropAfter },
			}
			ln := newListener(t, opts)
			defer ln.Close()

			// Accept in the background; the error is checked on the test
			// goroutine because t.Fatal must only be called from there.
			var (
				rxconn    net.Conn
				acceptErr error
				wg        sync.WaitGroup
			)
			wg.Add(1)
			go func() {
				defer wg.Done()
				rxconn, acceptErr = ln.Accept()
			}()

			txconn, err := mocknet.DialerWithOpts(opts, "test", ln.Addr().String(), time.Minute)
			if err != nil {
				t.Fatal(err)
			}
			defer txconn.Close()
			wg.Wait()
			if acceptErr != nil {
				t.Fatalf("%d: accept failed: %v", ci, acceptErr)
			}
			defer rxconn.Close()

			// rw writes s and reads into an l-byte buffer, returning the
			// bytes written, bytes read, data read and the read error. The
			// write error is deliberately ignored, as in the other mode
			// tests in this file.
			rw := func(s string, l int) (int, int, string, error) {
				b := make([]byte, l)
				tx, _ := txconn.Write([]byte(s))
				rx, err := rxconn.Read(b)
				return tx, rx, string(b[:rx]), err
			}

			for i, m := range c.tx {
				tx, rx, rxed, _ := rw(m, len(c.rx[i]))
				if got, want := rxed, c.rx[i]; got != want {
					t.Fatalf("%d: got %q, want %q", ci, got, want)
				}
				if tx != rx {
					t.Fatalf("%d: tx %d, rx %d", ci, tx, rx)
				}
			}
		}()
	}
}
// TestV23Drop exercises mocknet's V23CloseAtMessage mode. For each case it
// installs a matcher that fires on the txClose-th written message and the
// rxClose-th read message (0 means never), writes numMsgs encoded Data
// messages in a single Write, reads them back one at a time, and checks how
// many messages made it through in each direction and which error, if any,
// was reported.
func TestV23Drop(t *testing.T) {
	cases := []struct {
		numMsgs, txClose, rxClose int
	}{
		{5, 0, 0},
		{5, 2, 0},
		{5, 0, 2},
		{5, 3, 2},
		{5, 2, 3},
	}
	for ci, c := range cases {
		// Each case runs in its own function so that the deferred Close
		// calls fire at the end of the iteration instead of at test exit.
		func() {
			// Copy the thresholds so the matcher does not capture the loop
			// variable, which is shared across iterations before Go 1.22.
			txCloseAt, rxCloseAt := c.txClose, c.rxClose
			var txed, rxed int
			// matcher counts messages per direction and reports a match on
			// the configured message number.
			matcher := func(read bool, msg message.T) bool {
				if read {
					rxed++
					return rxed == rxCloseAt
				}
				txed++
				return txed == txCloseAt
			}
			opts := mocknet.Opts{
				Mode:              mocknet.V23CloseAtMessage,
				V23MessageMatcher: matcher,
			}
			ln := newListener(t, opts)
			defer ln.Close()

			// Accept in the background; the error is checked on the test
			// goroutine because t.Fatal must only be called from there.
			var (
				rxconn    net.Conn
				acceptErr error
				wg        sync.WaitGroup
			)
			wg.Add(1)
			go func() {
				defer wg.Done()
				rxconn, acceptErr = ln.Accept()
			}()

			txconn, err := mocknet.DialerWithOpts(opts, "test", ln.Addr().String(), time.Minute)
			if err != nil {
				t.Fatal(err)
			}
			defer txconn.Close()
			wg.Wait()
			if acceptErr != nil {
				t.Fatalf("%d: accept failed: %v", ci, acceptErr)
			}
			defer rxconn.Close()

			// Encode numMsgs identical messages back to back so that the
			// size of a single message can be derived from the total.
			var msgBuf bytes.Buffer
			for i := 0; i < c.numMsgs; i++ {
				if err := message.WriteTo(&msgBuf, &message.Data{}, crypto.NullControlCipher{}); err != nil {
					t.Fatal(err)
				}
			}
			perMsgBytes := msgBuf.Len() / c.numMsgs

			n, txErr := txconn.Write(msgBuf.Bytes())
			txMsgs := n / perMsgBytes
			switch {
			case c.txClose > 0:
				// The matched message itself is not expected to be sent.
				if got, want := txMsgs, c.txClose-1; got != want {
					t.Fatalf("%d: got %v, want %v", ci, got, want)
				}
				if got, want := txErr, io.EOF; got != want {
					t.Fatalf("%d: got %v, want %v", ci, got, want)
				}
			default:
				if got, want := txMsgs, c.numMsgs; got != want {
					t.Fatalf("%d: got %v, want %v", ci, got, want)
				}
				if txErr != nil {
					t.Fatalf("%d: %v\n", ci, txErr)
				}
			}

			// The read error is tracked separately from the write error so
			// that a stale tx error can never leak into the rx assertions
			// when no read is attempted.
			var (
				rxMsgs int
				rxErr  error
			)
			// The buffer is larger than one message; each Read is still
			// expected to return exactly one message.
			buf := make([]byte, perMsgBytes*2)
			for ; rxMsgs < txMsgs; rxMsgs++ {
				var n int
				n, rxErr = rxconn.Read(buf)
				if rxErr != nil {
					break
				}
				if got, want := n, perMsgBytes; got != want {
					t.Fatalf("%d: got %v, want %v", ci, got, want)
				}
			}
			switch {
			case c.rxClose > 0 && (c.txClose == 0 || c.txClose > c.rxClose):
				// The read side closes before the sender ran out of
				// messages.
				if got, want := rxMsgs, c.rxClose-1; got != want {
					t.Fatalf("%d: got %v, want %v", ci, got, want)
				}
				if got, want := rxErr, io.EOF; got != want {
					t.Fatalf("%d: got %v, want %v", ci, got, want)
				}
			default:
				if got, want := rxMsgs, txMsgs; got != want {
					t.Fatalf("%d: got %v, want %v", ci, got, want)
				}
				if rxErr != nil {
					t.Fatalf("%d: %v\n", ci, rxErr)
				}
			}
		}()
	}
}
// newCtx initializes a test context via test.InitForTest and turns off
// caching on that context's namespace. It returns the context together with
// the shutdown function the caller must invoke when finished.
func newCtx() (*context.T, v23.Shutdown) {
	ctx, shutdown := test.InitForTest()
	ns := v23.GetNamespace(ctx)
	ns.CacheCtl(naming.DisableCache(true))
	return ctx, shutdown
}
// simple is a minimal RPC service used by the tests in this file.
type simple struct{}

// Ping always answers "pong" with a nil error; the call argument is unused.
func (*simple) Ping(call rpc.ServerCall) (string, error) {
	const reply = "pong"
	return reply, nil
}
// initServer starts an RPC server with security disabled that serves the
// simple service at the empty name, listening on the listen spec taken from
// ctx. It returns the name of the first endpoint and a function that stops
// the server. Any setup failure aborts the calling test.
func initServer(t *testing.T, ctx *context.T) (string, func()) {
	server, err := v23.NewServer(ctx, options.SecurityNone)
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	eps, err := server.Listen(v23.GetListenSpec(ctx))
	if err != nil {
		server.Stop()
		t.Fatalf("unexpected error: %s", err)
	}
	// Guard the eps[0] access below so a misconfigured listen spec produces
	// a test failure instead of an index-out-of-range panic.
	if len(eps) == 0 {
		server.Stop()
		t.Fatalf("unexpected error: server.Listen returned no endpoints")
	}
	if err := server.Serve("", &simple{}, nil); err != nil {
		server.Stop()
		t.Fatalf("unexpected error: %s", err)
	}
	return eps[0].Name(), func() { server.Stop() }
}
// TestV23Control registers a "dropControl" protocol whose dialer wraps
// connections in mocknet's V23CloseAtMessage mode with a matcher that fires
// on the first non-Data message. It then starts a call to a test server
// through that protocol and expects the call to fail with ErrBadProtocol.
func TestV23Control(t *testing.T) {
	ctx, shutdown := newCtx()
	defer shutdown()

	// matcher lets Data messages through and matches every other message
	// type, so the first control message triggers the mock's close behavior.
	matcher := func(_ bool, msg message.T) bool {
		_, isData := msg.(*message.Data)
		return !isData
	}

	dropControlDialer := func(network, address string, timeout time.Duration) (net.Conn, error) {
		opts := mocknet.Opts{
			Mode:              mocknet.V23CloseAtMessage,
			V23MessageMatcher: matcher,
		}
		return mocknet.DialerWithOpts(opts, network, address, timeout)
	}

	// simpleResolver hands the network and address back unchanged.
	simpleResolver := func(network, address string) (string, string, error) {
		return network, address, nil
	}

	rpc.RegisterProtocol("dropControl", dropControlDialer, simpleResolver, net.Listen)

	server, fn := initServer(t, ctx)
	defer fn()

	// Rewrite the server's endpoint so the client dials it through the
	// "dropControl" protocol registered above.
	addr, _ := naming.SplitAddressName(server)
	dropServer, err := mocknet.RewriteEndpointProtocol(addr, "dropControl")
	if err != nil {
		t.Fatal(err)
	}

	_, err = v23.GetClient(ctx).StartCall(ctx, dropServer.Name(), "Ping", nil, options.SecurityNone, options.NoRetry{})
	// Report both IDs: a bare t.Fatal(err) would print only "<nil>" if the
	// call unexpectedly succeeded.
	if got, want := verror.ErrorID(err), verror.ErrBadProtocol.ID; got != want {
		t.Fatalf("StartCall: got error %v (ID %v), want ID %v", err, got, want)
	}
}