Cosmos Nicolaou | 3c50ac4 | 2014-12-23 07:40:19 -0800 | [diff] [blame] | 1 | package websocket_test |
| 2 | |
| 3 | import ( |
| 4 | "encoding/gob" |
| 5 | "fmt" |
| 6 | "hash/crc64" |
| 7 | "io" |
| 8 | "math/rand" |
| 9 | "net" |
| 10 | "sync" |
| 11 | "testing" |
| 12 | "time" |
| 13 | |
Jiri Simsa | 764efb7 | 2014-12-25 20:57:03 -0800 | [diff] [blame] | 14 | "v.io/core/veyron2/ipc/stream" |
Cosmos Nicolaou | 3c50ac4 | 2014-12-23 07:40:19 -0800 | [diff] [blame] | 15 | |
Jiri Simsa | 764efb7 | 2014-12-25 20:57:03 -0800 | [diff] [blame] | 16 | "v.io/core/veyron/lib/testutil" |
Cosmos Nicolaou | 3c50ac4 | 2014-12-23 07:40:19 -0800 | [diff] [blame] | 17 | ) |
| 18 | |
| 19 | var crcTable *crc64.Table |
| 20 | |
| 21 | func init() { |
| 22 | testutil.Init() |
| 23 | crcTable = crc64.MakeTable(crc64.ISO) |
| 24 | } |
| 25 | |
| 26 | func newSender(t *testing.T, dialer stream.DialerFunc, protocol, address string) net.Conn { |
| 27 | conn, err := dialer(protocol, address, time.Minute) |
| 28 | if err != nil { |
| 29 | t.Fatalf("unexpected error: %s", err) |
| 30 | return nil |
| 31 | } |
| 32 | return conn |
| 33 | } |
| 34 | |
// checkProtocols verifies that both endpoints of conn report the network
// expected for the dialed protocol tx: "ws" for "ws", and "tcp" for both
// "wsh" and "tcp". Any other tx value is compared against the empty string.
// It returns a descriptive error on the first mismatch and nil otherwise.
func checkProtocols(conn net.Conn, tx string) error {
	var want string
	switch tx {
	case "ws":
		want = "ws"
	case "wsh", "tcp":
		want = "tcp"
	}
	if got := conn.LocalAddr().Network(); got != want {
		return fmt.Errorf("wrong local protocol: got %q, want %q", got, want)
	}
	// Can't tell that the remote protocol is really 'wsh'
	if got := conn.RemoteAddr().Network(); got != want {
		return fmt.Errorf("wrong remote protocol: got %q, want %q", got, want)
	}
	return nil
}
| 48 | |
// packet is the unit of data gob-encoded by packetSender and decoded by
// packetReceiver. Its fields are exported so that encoding/gob transmits them.
type packet struct {
	// Data is the payload.
	Data []byte
	// Size is the payload length recorded by the sender; checkPacket compares
	// it against len(Data).
	Size int
	// CRC64 is the checksum of Data recorded by the sender.
	CRC64 uint64
}
| 54 | |
| 55 | func createPacket() *packet { |
| 56 | p := &packet{} |
| 57 | p.Size = rand.Intn(4 * 1024) |
| 58 | p.Data = make([]byte, p.Size) |
| 59 | for i := 0; i < p.Size; i++ { |
| 60 | p.Data[i] = byte(rand.Int() & 0xff) |
| 61 | } |
| 62 | p.CRC64 = crc64.Checksum([]byte(p.Data), crcTable) |
| 63 | return p |
| 64 | } |
| 65 | |
| 66 | func checkPacket(p *packet) error { |
| 67 | if got, want := len(p.Data), p.Size; got != want { |
| 68 | return fmt.Errorf("wrong sizes: got %d, want %d", got, want) |
| 69 | } |
| 70 | crc := crc64.Checksum(p.Data, crcTable) |
| 71 | if got, want := crc, p.CRC64; got != want { |
| 72 | return fmt.Errorf("wrong crc: got %d, want %d", got, want) |
| 73 | } |
| 74 | return nil |
| 75 | } |
| 76 | |
// backChannel carries results from a receiver goroutine back to the sender
// that is driving the test. The receivers in this file close the channels
// they were given once they finish.
type backChannel struct {
	// crcChan delivers the CRC of each packet seen by packetReceiver.
	crcChan chan uint64
	// byteChan delivers each chunk of bytes read by byteReceiver.
	byteChan chan []byte
	// errChan delivers errors detected on the receiving side.
	errChan chan error
}
| 82 | |
| 83 | type bcTable struct { |
| 84 | ready *sync.Cond |
| 85 | sync.Mutex |
| 86 | bc map[string]*backChannel |
| 87 | } |
| 88 | |
| 89 | var globalBCTable bcTable |
| 90 | |
| 91 | func init() { |
| 92 | globalBCTable.ready = sync.NewCond(&globalBCTable) |
| 93 | globalBCTable.bc = make(map[string]*backChannel) |
| 94 | } |
| 95 | |
| 96 | func (bt *bcTable) waitfor(key string) *backChannel { |
| 97 | bt.Lock() |
| 98 | defer bt.Unlock() |
| 99 | for { |
| 100 | bc := bt.bc[key] |
| 101 | if bc != nil { |
| 102 | delete(bt.bc, key) |
| 103 | return bc |
| 104 | } |
| 105 | bt.ready.Wait() |
| 106 | } |
| 107 | } |
| 108 | |
| 109 | func (bt *bcTable) add(key string, bc *backChannel) { |
| 110 | bt.Lock() |
| 111 | bt.bc[key] = bc |
| 112 | bt.Unlock() |
| 113 | bt.ready.Broadcast() |
| 114 | } |
| 115 | |
| 116 | func packetReceiver(t *testing.T, ln net.Listener, bc *backChannel) { |
| 117 | conn, err := ln.Accept() |
| 118 | if err != nil { |
| 119 | close(bc.crcChan) |
| 120 | close(bc.errChan) |
| 121 | return |
| 122 | } |
| 123 | |
| 124 | globalBCTable.add(conn.RemoteAddr().String(), bc) |
| 125 | |
| 126 | defer conn.Close() |
| 127 | dec := gob.NewDecoder(conn) |
| 128 | rxed := 0 |
| 129 | for { |
| 130 | var p packet |
| 131 | err := dec.Decode(&p) |
| 132 | if err != nil { |
| 133 | if err != io.EOF { |
| 134 | bc.errChan <- fmt.Errorf("unexpected error: %s", err) |
| 135 | } |
| 136 | close(bc.crcChan) |
| 137 | close(bc.errChan) |
| 138 | return |
| 139 | } |
| 140 | if err := checkPacket(&p); err != nil { |
| 141 | bc.errChan <- fmt.Errorf("unexpected error: %s", err) |
| 142 | } |
| 143 | bc.crcChan <- p.CRC64 |
| 144 | rxed++ |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | func packetSender(t *testing.T, nPackets int, conn net.Conn) { |
| 149 | txCRCs := make([]uint64, nPackets) |
| 150 | enc := gob.NewEncoder(conn) |
| 151 | for i := 0; i < nPackets; i++ { |
| 152 | p := createPacket() |
| 153 | txCRCs[i] = p.CRC64 |
| 154 | if err := enc.Encode(p); err != nil { |
| 155 | t.Fatalf("unexpected error: %s", err) |
| 156 | } |
| 157 | } |
| 158 | conn.Close() // Close the connection so that the receiver quits. |
| 159 | |
| 160 | bc := globalBCTable.waitfor(conn.LocalAddr().String()) |
| 161 | for err := range bc.errChan { |
| 162 | if err != nil { |
| 163 | t.Fatalf(err.Error()) |
| 164 | } |
| 165 | } |
| 166 | |
| 167 | rxed := 0 |
| 168 | for rxCRC := range bc.crcChan { |
| 169 | if got, want := rxCRC, txCRCs[rxed]; got != want { |
| 170 | t.Errorf("%s -> %s: packet %d: mismatched CRCs: got %d, want %d", conn.LocalAddr().String(), conn.RemoteAddr().String(), rxed, got, want) |
| 171 | } |
| 172 | rxed++ |
| 173 | } |
| 174 | if got, want := rxed, nPackets; got != want { |
| 175 | t.Fatalf("%s -> %s: got %d, want %d", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want) |
| 176 | } |
| 177 | } |
| 178 | |
| 179 | func packetRunner(t *testing.T, ln net.Listener, dialer stream.DialerFunc, protocol, address string) { |
| 180 | nPackets := 100 |
| 181 | go packetReceiver(t, ln, &backChannel{ |
| 182 | crcChan: make(chan uint64, nPackets), |
| 183 | errChan: make(chan error, nPackets), |
| 184 | }) |
| 185 | |
| 186 | conn := newSender(t, dialer, protocol, address) |
| 187 | if err := checkProtocols(conn, protocol); err != nil { |
| 188 | t.Fatalf(err.Error()) |
| 189 | } |
| 190 | packetSender(t, nPackets, conn) |
| 191 | } |
| 192 | |
| 193 | func byteReceiver(t *testing.T, ln net.Listener, bc *backChannel) { |
| 194 | conn, err := ln.Accept() |
| 195 | if err != nil { |
| 196 | close(bc.byteChan) |
| 197 | close(bc.errChan) |
| 198 | return |
| 199 | } |
| 200 | globalBCTable.add(conn.RemoteAddr().String(), bc) |
| 201 | |
| 202 | defer conn.Close() |
| 203 | rxed := 0 |
| 204 | for { |
| 205 | buf := make([]byte, rxed+1) |
| 206 | n, err := conn.Read(buf) |
| 207 | if err != nil { |
| 208 | if err != io.EOF { |
| 209 | bc.errChan <- fmt.Errorf("unexpected error: %s", err) |
| 210 | } |
| 211 | close(bc.byteChan) |
| 212 | close(bc.errChan) |
| 213 | return |
| 214 | } |
| 215 | if got, want := n, len(buf[:n]); got != want { |
| 216 | bc.errChan <- fmt.Errorf("%s -> %s: got %d bytes, expected %d", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want) |
| 217 | } |
| 218 | if got, want := buf[0], byte(0xff); got != want { |
| 219 | bc.errChan <- fmt.Errorf("%s -> %s: got %x, want %x", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want) |
| 220 | } |
| 221 | bc.byteChan <- buf[:n] |
| 222 | rxed++ |
| 223 | } |
| 224 | } |
| 225 | |
| 226 | func byteSender(t *testing.T, nIterations int, conn net.Conn) { |
| 227 | txBytes := make([][]byte, nIterations+1) |
| 228 | for i := 0; i < nIterations; i++ { |
| 229 | p := make([]byte, i+1) |
| 230 | p[0] = 0xff |
| 231 | for j := 1; j <= i; j++ { |
| 232 | p[j] = byte(64 + i) // start at ASCII A |
| 233 | } |
| 234 | txBytes[i] = p |
| 235 | n, err := conn.Write(p) |
| 236 | if err != nil { |
| 237 | t.Fatalf("unexpected error: %s", err) |
| 238 | } |
| 239 | if got, want := n, i+1; got != want { |
| 240 | t.Fatalf("wrote %d, not %d bytes", got, want) |
| 241 | } |
| 242 | } |
| 243 | conn.Close() |
| 244 | |
| 245 | bc := globalBCTable.waitfor(conn.LocalAddr().String()) |
| 246 | |
| 247 | for err := range bc.errChan { |
| 248 | if err != nil { |
| 249 | t.Fatalf(err.Error()) |
| 250 | } |
| 251 | } |
| 252 | |
| 253 | addr := fmt.Sprintf("%s -> %s", conn.LocalAddr().String(), conn.RemoteAddr().String()) |
| 254 | rxed := 0 |
| 255 | for rxBytes := range bc.byteChan { |
| 256 | if got, want := len(rxBytes), rxed+1; got != want { |
| 257 | t.Fatalf("%s: got %d, want %d bytes", addr, got, want) |
| 258 | } |
| 259 | if got, want := rxBytes[0], byte(0xff); got != want { |
| 260 | t.Fatalf("%s: got %x, want %x", addr, got, want) |
| 261 | } |
| 262 | for i := 0; i < len(rxBytes); i++ { |
| 263 | if got, want := rxBytes[i], txBytes[rxed][i]; got != want { |
| 264 | t.Fatalf("%s: got %c, want %c", addr, got, want) |
| 265 | } |
| 266 | } |
| 267 | rxed++ |
| 268 | } |
| 269 | if got, want := rxed, nIterations; got != want { |
| 270 | t.Fatalf("%s: got %d, want %d", addr, got, want) |
| 271 | } |
| 272 | } |
| 273 | |
| 274 | func byteRunner(t *testing.T, ln net.Listener, dialer stream.DialerFunc, protocol, address string) { |
| 275 | nIterations := 10 |
| 276 | go byteReceiver(t, ln, &backChannel{ |
| 277 | byteChan: make(chan []byte, nIterations), |
| 278 | errChan: make(chan error, nIterations), |
| 279 | }) |
| 280 | |
| 281 | conn := newSender(t, dialer, protocol, address) |
| 282 | defer conn.Close() |
| 283 | if err := checkProtocols(conn, protocol); err != nil { |
| 284 | t.Fatalf(err.Error()) |
| 285 | } |
| 286 | byteSender(t, nIterations, conn) |
| 287 | } |