// Source blob: 9df1c4f24b5cb24d72a1db33ef0bb4328d4e4aad
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"encoding/gob"
"fmt"
"hash/crc64"
"io"
"math/rand"
"net"
"sync"
"testing"
"time"
"v.io/v23/context"
"v.io/v23/rpc"
)
//go:generate jiri test generate
// crcTable is the CRC-64 lookup table (ISO polynomial) used to checksum
// packet payloads when they are created and again when they are verified.
var crcTable = crc64.MakeTable(crc64.ISO)
// newSender dials address over protocol using dialer and returns the
// resulting connection. A dial error fails the test immediately via
// t.Fatalf, so the function only ever returns a usable connection.
//
// time.Minute is passed as the dialer's final argument.
func newSender(t *testing.T, dialer rpc.DialerFunc, protocol, address string) net.Conn {
	// The second value returned by RootContext is deliberately dropped
	// (presumably a cancel function - confirm against v23/context); the
	// context is left alive because the returned connection may outlive
	// this call.
	ctx, _ := context.RootContext()
	conn, err := dialer(ctx, protocol, address, time.Minute)
	if err != nil {
		// Fatalf stops the calling goroutine, so nothing after it runs.
		t.Fatalf("unexpected error: %s", err)
	}
	return conn
}
// checkProtocols reports an error if either end of conn advertises a network
// other than the one expected for the transport tx. The expectations are:
// "ws" surfaces as "ws", while "wsh" and "tcp" both surface as "tcp". Any
// other tx is compared against the empty string.
func checkProtocols(conn net.Conn, tx string) error {
	var expected string
	switch tx {
	case "ws":
		expected = "ws"
	case "wsh", "tcp":
		expected = "tcp"
	}

	if local := conn.LocalAddr().Network(); local != expected {
		return fmt.Errorf("wrong local protocol: got %q, want %q", local, expected)
	}
	// Can't tell that the remote protocol is really 'wsh'
	if remote := conn.RemoteAddr().Network(); remote != expected {
		return fmt.Errorf("wrong remote protocol: got %q, want %q", remote, expected)
	}
	return nil
}
// packet is the message exchanged in the packet tests: packetSender
// gob-encodes values of this type and packetReceiver decodes them.
type packet struct {
// Data is the payload; createPacket fills it with random bytes.
Data []byte
// Size is the payload length recorded by the creator; checkPacket compares
// it against len(Data).
Size int
// CRC64 is the checksum of Data computed with crcTable.
CRC64 uint64
}
// createPacket returns a packet carrying between 0 and 4095 random bytes,
// with Size set to the payload length and CRC64 set to the payload's
// checksum under crcTable.
func createPacket() *packet {
	size := rand.Intn(4 * 1024)
	payload := make([]byte, size)
	for i := range payload {
		payload[i] = byte(rand.Int() & 0xff)
	}
	return &packet{
		Data:  payload,
		Size:  size,
		CRC64: crc64.Checksum(payload, crcTable),
	}
}
// checkPacket validates p's internal consistency: the payload length must
// equal the recorded Size, and the payload's checksum under crcTable must
// equal the recorded CRC64. It returns nil when both hold.
func checkPacket(p *packet) error {
	if n := len(p.Data); n != p.Size {
		return fmt.Errorf("wrong sizes: got %d, want %d", n, p.Size)
	}
	if sum := crc64.Checksum(p.Data, crcTable); sum != p.CRC64 {
		return fmt.Errorf("wrong crc: got %d, want %d", sum, p.CRC64)
	}
	return nil
}
// backChannel carries a receiver goroutine's observations back to the
// sender so that the sender can verify them. The receiver closes the
// channels it uses once it stops reading from its connection.
type backChannel struct {
// crcChan receives the CRC64 of every packet decoded by packetReceiver.
crcChan chan uint64
// byteChan receives every chunk read by byteReceiver.
byteChan chan []byte
// errChan receives errors detected on the receiving side.
errChan chan error
}
// bcTable is a rendezvous table through which a receiver hands its
// backChannel to the matching sender: add publishes an entry under a key and
// waitfor blocks until an entry for that key appears, then removes it.
type bcTable struct {
// ready is broadcast by add each time an entry is stored.
ready *sync.Cond
// Mutex guards bc; for globalBCTable it is also the Locker behind ready.
sync.Mutex
// bc maps a key (the receivers and senders use a connection address
// string) to the backChannel published under it.
bc map[string]*backChannel
}
// globalBCTable is the process-wide table that receivers use to publish
// their backChannel and senders use to look it up.
var globalBCTable = bcTable{bc: make(map[string]*backChannel)}

// init wires up the condition variable. It cannot be part of the variable's
// initializer because it needs the address of globalBCTable itself, whose
// embedded mutex serves as the condition's Locker.
func init() {
	globalBCTable.ready = sync.NewCond(&globalBCTable)
}
// waitfor blocks until a non-nil backChannel has been stored under key, then
// removes that entry from the table and returns it. The table lock is held
// except while waiting on the condition variable.
func (bt *bcTable) waitfor(key string) *backChannel {
	bt.Lock()
	defer bt.Unlock()

	// Re-check after every wake-up: a broadcast may be for a different key.
	for bt.bc[key] == nil {
		bt.ready.Wait()
	}
	found := bt.bc[key]
	delete(bt.bc, key)
	return found
}
// add stores bc under key and then wakes every goroutine blocked in waitfor
// so each can re-check whether its key is now present.
func (bt *bcTable) add(key string, bc *backChannel) {
	func() {
		bt.Lock()
		defer bt.Unlock()
		bt.bc[key] = bc
	}()
	// The broadcast is issued after the lock has been released.
	bt.ready.Broadcast()
}
// packetReceiver accepts a single connection from ln and gob-decodes packets
// from it until the stream ends.
//
// Once a connection is accepted, bc is published in globalBCTable under the
// connection's remote address so the sender can find it. For every decoded
// packet the receiver validates it with checkPacket (reporting a failure on
// bc.errChan) and forwards its CRC64 on bc.crcChan. A decode error other
// than io.EOF is reported on bc.errChan. In all cases bc.crcChan and
// bc.errChan are closed before returning, which is how the sender learns
// that the receiver is done.
//
// If Accept fails, the channels are closed but bc is never published.
func packetReceiver(t *testing.T, ln net.Listener, bc *backChannel) {
	// finish signals completion to the sender by closing both channels.
	finish := func() {
		close(bc.crcChan)
		close(bc.errChan)
	}

	conn, err := ln.Accept()
	if err != nil {
		finish()
		return
	}
	globalBCTable.add(conn.RemoteAddr().String(), bc)
	defer conn.Close()

	dec := gob.NewDecoder(conn)
	for {
		var p packet
		if err := dec.Decode(&p); err != nil {
			// io.EOF is the normal end of stream (the sender closed
			// the connection); anything else is a failure.
			if err != io.EOF {
				bc.errChan <- fmt.Errorf("unexpected error: %s", err)
			}
			finish()
			return
		}
		if err := checkPacket(&p); err != nil {
			bc.errChan <- fmt.Errorf("unexpected error: %s", err)
		}
		// The CRC is forwarded even for a packet that failed validation.
		bc.crcChan <- p.CRC64
	}
}
// packetSender gob-encodes nPackets random packets onto conn, closes conn,
// and then verifies what the receiver observed.
//
// The receiver's backChannel is looked up in globalBCTable under conn's
// local address. Any error reported on its errChan fails the test. The CRCs
// arriving on crcChan must match, in order, the CRCs of the packets that
// were sent, and exactly nPackets of them must arrive.
func packetSender(t *testing.T, nPackets int, conn net.Conn) {
	txCRCs := make([]uint64, nPackets)
	enc := gob.NewEncoder(conn)
	for i := 0; i < nPackets; i++ {
		p := createPacket()
		txCRCs[i] = p.CRC64
		if err := enc.Encode(p); err != nil {
			// Close before bailing out so the receiver's Decode
			// returns instead of blocking on a dead test.
			conn.Close()
			t.Fatalf("unexpected error: %s", err)
		}
	}
	conn.Close() // Close the connection so that the receiver quits.

	bc := globalBCTable.waitfor(conn.LocalAddr().String())
	// errChan is closed by the receiver when it finishes, ending this loop.
	for err := range bc.errChan {
		if err != nil {
			// Fatal, not Fatalf: the error text is data, not a
			// format string, and may contain '%'.
			t.Fatal(err)
		}
	}

	rxed := 0
	for rxCRC := range bc.crcChan {
		if got, want := rxCRC, txCRCs[rxed]; got != want {
			t.Errorf("%s -> %s: packet %d: mismatched CRCs: got %d, want %d", conn.LocalAddr().String(), conn.RemoteAddr().String(), rxed, got, want)
		}
		rxed++
	}
	if got, want := rxed, nPackets; got != want {
		t.Fatalf("%s -> %s: got %d, want %d", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want)
	}
}
// packetRunner runs one packet round trip: it starts packetReceiver on ln in
// a goroutine, dials address with dialer, checks that the connection reports
// the networks expected for protocol, and then sends 100 packets with
// packetSender. The receiver's channels are buffered to hold 100 entries
// each.
func packetRunner(t *testing.T, ln net.Listener, dialer rpc.DialerFunc, protocol, address string) {
	nPackets := 100
	go packetReceiver(t, ln, &backChannel{
		crcChan: make(chan uint64, nPackets),
		errChan: make(chan error, nPackets),
	})
	conn := newSender(t, dialer, protocol, address)
	// Mirrors byteRunner: guarantees the connection is closed even when
	// the test aborts before packetSender gets to close it.
	defer conn.Close()
	if err := checkProtocols(conn, protocol); err != nil {
		// Fatal, not Fatalf: the error text is not a format string.
		t.Fatal(err)
	}
	packetSender(t, nPackets, conn)
}
// byteReceiver accepts a single connection from ln and reads from it until
// the stream ends, using a buffer of k bytes for the k-th read (1, 2, 3, ...).
//
// Once a connection is accepted, bc is published in globalBCTable under the
// connection's remote address. Each read is expected to fill its buffer and
// to start with the byte 0xff; violations are reported on bc.errChan. The
// bytes actually read are forwarded on bc.byteChan either way. A read error
// other than io.EOF is reported on bc.errChan. In all cases bc.byteChan and
// bc.errChan are closed before returning.
//
// If Accept fails, the channels are closed but bc is never published.
func byteReceiver(t *testing.T, ln net.Listener, bc *backChannel) {
	// finish signals completion to the sender by closing both channels.
	finish := func() {
		close(bc.byteChan)
		close(bc.errChan)
	}

	conn, err := ln.Accept()
	if err != nil {
		finish()
		return
	}
	globalBCTable.add(conn.RemoteAddr().String(), bc)
	defer conn.Close()

	rxed := 0
	for {
		// A fresh buffer per read: the slice is handed to the sender
		// through byteChan and must not be overwritten afterwards.
		buf := make([]byte, rxed+1)
		n, err := conn.Read(buf)
		if err != nil {
			if err != io.EOF {
				bc.errChan <- fmt.Errorf("unexpected error: %s", err)
			}
			finish()
			return
		}
		// Compare against the full buffer size; comparing against
		// len(buf[:n]) would always succeed.
		if got, want := n, len(buf); got != want {
			bc.errChan <- fmt.Errorf("%s -> %s: got %d bytes, expected %d", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want)
		}
		if got, want := buf[0], byte(0xff); got != want {
			bc.errChan <- fmt.Errorf("%s -> %s: got %x, want %x", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want)
		}
		bc.byteChan <- buf[:n]
		rxed++
	}
}
// byteSender writes nIterations messages of increasing length (1, 2, 3, ...
// bytes) to conn, closes conn, and then verifies what the receiver observed.
//
// Message i (0-based) is i+1 bytes long: a leading 0xff followed by i copies
// of byte(64+i). The receiver's backChannel is looked up in globalBCTable
// under conn's local address. Any error reported on its errChan fails the
// test. The chunks arriving on byteChan must match the written messages in
// order, length and content, and exactly nIterations of them must arrive.
func byteSender(t *testing.T, nIterations int, conn net.Conn) {
	txBytes := make([][]byte, nIterations)
	for i := 0; i < nIterations; i++ {
		p := make([]byte, i+1)
		p[0] = 0xff
		for j := 1; j <= i; j++ {
			p[j] = byte(64 + i) // 'A' for the 2-byte message, 'B' for the next, ...
		}
		txBytes[i] = p
		n, err := conn.Write(p)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if got, want := n, i+1; got != want {
			t.Fatalf("wrote %d, not %d bytes", got, want)
		}
	}
	conn.Close() // Close the connection so that the receiver quits.

	bc := globalBCTable.waitfor(conn.LocalAddr().String())
	// errChan is closed by the receiver when it finishes, ending this loop.
	for err := range bc.errChan {
		if err != nil {
			// Fatal, not Fatalf: the error text embeds addresses and
			// is data, not a format string.
			t.Fatal(err)
		}
	}

	addr := fmt.Sprintf("%s -> %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
	rxed := 0
	for rxBytes := range bc.byteChan {
		// The length check comes first; it also guarantees that the
		// indexing below stays within the matching sent message.
		if got, want := len(rxBytes), rxed+1; got != want {
			t.Fatalf("%s: got %d, want %d bytes", addr, got, want)
		}
		if got, want := rxBytes[0], byte(0xff); got != want {
			t.Fatalf("%s: got %x, want %x", addr, got, want)
		}
		for i := 0; i < len(rxBytes); i++ {
			if got, want := rxBytes[i], txBytes[rxed][i]; got != want {
				t.Fatalf("%s: got %c, want %c", addr, got, want)
			}
		}
		rxed++
	}
	if got, want := rxed, nIterations; got != want {
		t.Fatalf("%s: got %d, want %d", addr, got, want)
	}
}
// byteRunner runs one raw-byte round trip: it starts byteReceiver on ln in a
// goroutine, dials address with dialer, checks that the connection reports
// the networks expected for protocol, and then writes 10 messages with
// byteSender. The receiver's channels are buffered to hold 10 entries each.
func byteRunner(t *testing.T, ln net.Listener, dialer rpc.DialerFunc, protocol, address string) {
	nIterations := 10
	go byteReceiver(t, ln, &backChannel{
		byteChan: make(chan []byte, nIterations),
		errChan:  make(chan error, nIterations),
	})
	conn := newSender(t, dialer, protocol, address)
	// Covers the early-exit paths; byteSender also closes conn itself.
	defer conn.Close()
	if err := checkProtocols(conn, protocol); err != nil {
		// Fatal, not Fatalf: the error text is not a format string.
		t.Fatal(err)
	}
	byteSender(t, nIterations, conn)
}