// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package websocket_test

import (
	"encoding/gob"
	"fmt"
	"hash/crc64"
	"io"
	"math/rand"
	"net"
	"sync"
	"testing"
	"time"

	"v.io/v23/context"
	"v.io/v23/rpc"
)

//go:generate jiri test generate

// crcTable is the CRC-64 lookup table, built from the ISO polynomial, that
// is passed to every crc64.Checksum call in this file.
var crcTable = crc64.MakeTable(crc64.ISO)

// newSender dials address over protocol using dialer, with a one minute
// timeout and a fresh root context, and returns the resulting connection.
// A dial failure aborts the test via t.Fatalf.
func newSender(t *testing.T, dialer rpc.DialerFunc, protocol, address string) net.Conn {
	ctx, _ := context.RootContext()
	conn, err := dialer(ctx, protocol, address, time.Minute)
	if err == nil {
		return conn
	}
	t.Fatalf("unexpected error: %s", err)
	return nil // Not reached: Fatalf stops the calling goroutine.
}

// checkProtocols verifies that both endpoints of conn report the network
// name expected for a connection dialed with protocol tx:
//
//	"ws"  -> "ws"
//	"wsh" -> "tcp"
//	"tcp" -> "tcp"
//
// It returns a descriptive error on the first mismatch, and also when tx is
// not one of the protocols listed above (rather than comparing against an
// empty expected value).
func checkProtocols(conn net.Conn, tx string) error {
	var want string
	switch tx {
	case "ws":
		want = "ws"
	case "wsh", "tcp":
		want = "tcp"
	default:
		return fmt.Errorf("unknown protocol %q", tx)
	}
	if got := conn.LocalAddr().Network(); got != want {
		return fmt.Errorf("wrong local protocol: got %q, want %q", got, want)
	}
	// Can't tell that the remote protocol is really 'wsh'
	if got := conn.RemoteAddr().Network(); got != want {
		return fmt.Errorf("wrong remote protocol: got %q, want %q", got, want)
	}
	return nil
}

// packet is the unit of data gob-encoded by packetSender and decoded by
// packetReceiver. Size and CRC64 let checkPacket validate Data on receipt.
type packet struct {
	Data  []byte // payload; filled with random bytes by createPacket
	Size  int    // length of Data as recorded by createPacket
	CRC64 uint64 // crc64 checksum of Data computed with crcTable
}

// createPacket returns a packet whose payload is between 0 and 4095 random
// bytes long, with Size and CRC64 filled in to match that payload.
func createPacket() *packet {
	size := rand.Intn(4 * 1024)
	data := make([]byte, size)
	for i := range data {
		data[i] = byte(rand.Int() & 0xff)
	}
	return &packet{
		Data:  data,
		Size:  size,
		CRC64: crc64.Checksum(data, crcTable),
	}
}

// checkPacket validates p against its own metadata: the payload length must
// equal p.Size and the payload's CRC-64 (using crcTable) must equal p.CRC64.
// It returns an error describing the first check that fails, or nil.
func checkPacket(p *packet) error {
	if len(p.Data) != p.Size {
		return fmt.Errorf("wrong sizes: got %d, want %d", len(p.Data), p.Size)
	}
	if sum := crc64.Checksum(p.Data, crcTable); sum != p.CRC64 {
		return fmt.Errorf("wrong crc: got %d, want %d", sum, p.CRC64)
	}
	return nil
}

// backChannel carries results from a receiver goroutine back to the sender
// that is verifying them. The receiver closes the channels it uses when it
// is done; packetReceiver uses crcChan and errChan, byteReceiver uses
// byteChan and errChan.
type backChannel struct {
	crcChan  chan uint64 // CRC64 of each packet decoded by packetReceiver
	byteChan chan []byte // bytes returned by each Read in byteReceiver
	errChan  chan error  // errors encountered by the receiver
}

// bcTable is a rendezvous table through which a receiver hands its
// backChannel to the matching sender. Receivers call add with the accepted
// connection's remote address; senders call waitfor with their connection's
// local address.
type bcTable struct {
	// ready is broadcast by add and waited on by waitfor.
	ready *sync.Cond
	// The embedded Mutex guards bc; init also uses it as ready's Locker.
	sync.Mutex
	// bc holds the backChannels that have been added but not yet claimed.
	bc map[string]*backChannel
}

// globalBCTable is the single rendezvous table shared by all senders and
// receivers in this file.
var globalBCTable bcTable

// init allocates the table's map and ties its condition variable to the
// table's own embedded mutex.
func init() {
	globalBCTable.bc = map[string]*backChannel{}
	globalBCTable.ready = sync.NewCond(&globalBCTable.Mutex)
}

// waitfor blocks until a non-nil backChannel is registered under key, then
// removes that entry from the table and returns it. There is no timeout: if
// nothing is ever added for key the call never returns.
func (bt *bcTable) waitfor(key string) *backChannel {
	bt.Lock()
	defer bt.Unlock()
	// Wait releases the lock while sleeping and re-acquires it before
	// returning, so the map is only ever read with the lock held.
	for bt.bc[key] == nil {
		bt.ready.Wait()
	}
	found := bt.bc[key]
	delete(bt.bc, key)
	return found
}

// add registers bc under key and then wakes every goroutine blocked in
// waitfor so that each can re-check for its own key.
func (bt *bcTable) add(key string, bc *backChannel) {
	func() {
		bt.Lock()
		defer bt.Unlock()
		bt.bc[key] = bc
	}()
	// The broadcast is issued after the lock has been released.
	bt.ready.Broadcast()
}

// packetReceiver accepts one connection from ln, registers bc in
// globalBCTable under the connection's remote address, and then gob-decodes
// packets until the stream ends.
//
// For every decoded packet its CRC64 is sent on bc.crcChan; a packet that
// fails checkPacket additionally produces an error on bc.errChan. A decode
// error other than io.EOF is reported on bc.errChan. In all cases
// bc.crcChan and bc.errChan are closed before returning.
//
// If Accept fails the channels are closed but bc is never registered, so a
// sender blocked in waitfor will not be released by this goroutine.
//
// t is currently unused.
func packetReceiver(t *testing.T, ln net.Listener, bc *backChannel) {
	// finish tells the sender that no further results will arrive.
	finish := func() {
		close(bc.crcChan)
		close(bc.errChan)
	}

	conn, acceptErr := ln.Accept()
	if acceptErr != nil {
		finish()
		return
	}
	globalBCTable.add(conn.RemoteAddr().String(), bc)
	defer conn.Close()

	decoder := gob.NewDecoder(conn)
	for {
		var pkt packet
		if decodeErr := decoder.Decode(&pkt); decodeErr != nil {
			// io.EOF is the normal end of stream once the sender closes.
			if decodeErr != io.EOF {
				bc.errChan <- fmt.Errorf("unexpected error: %s", decodeErr)
			}
			finish()
			return
		}
		if checkErr := checkPacket(&pkt); checkErr != nil {
			bc.errChan <- fmt.Errorf("unexpected error: %s", checkErr)
		}
		bc.crcChan <- pkt.CRC64
	}
}

// packetSender gob-encodes nPackets random packets onto conn, closes conn so
// that the receiver sees end-of-stream, and then verifies what the receiver
// reported through the backChannel registered under conn's local address.
//
// The test is failed if encoding fails, if the receiver reported any error,
// if a received CRC differs from the one sent at the same position, or if
// the number of packets received is not nPackets.
func packetSender(t *testing.T, nPackets int, conn net.Conn) {
	txCRCs := make([]uint64, nPackets)
	enc := gob.NewEncoder(conn)
	for i := range txCRCs {
		p := createPacket()
		txCRCs[i] = p.CRC64
		if err := enc.Encode(p); err != nil {
			// Close before bailing out so the receiver is not left
			// blocked on a connection nobody will write to again.
			conn.Close()
			t.Fatalf("unexpected error: %s", err)
		}
	}
	conn.Close() // Close the connection so that the receiver quits.

	local, remote := conn.LocalAddr().String(), conn.RemoteAddr().String()

	bc := globalBCTable.waitfor(local)
	// errChan is closed by the receiver when it finishes, so this loop
	// also waits for the receiver to be done.
	for err := range bc.errChan {
		if err != nil {
			// Not Fatalf: the error text must not be interpreted as a
			// format string.
			t.Fatal(err)
		}
	}

	rxed := 0
	for rxCRC := range bc.crcChan {
		if rxed >= nPackets {
			t.Fatalf("%s -> %s: received more than %d packets", local, remote, nPackets)
		}
		if got, want := rxCRC, txCRCs[rxed]; got != want {
			t.Errorf("%s -> %s: packet %d: mismatched CRCs: got %d, want %d", local, remote, rxed, got, want)
		}
		rxed++
	}
	if got, want := rxed, nPackets; got != want {
		t.Fatalf("%s -> %s: got %d, want %d", local, remote, got, want)
	}
}

// packetRunner runs one packet round trip: it starts packetReceiver on ln in
// a goroutine, dials address with dialer, checks the connection's reported
// protocols, and then hands the connection to packetSender to send and
// verify 100 packets.
func packetRunner(t *testing.T, ln net.Listener, dialer rpc.DialerFunc, protocol, address string) {
	nPackets := 100
	// Both channels are buffered to nPackets so the receiver can queue its
	// results while the sender is still writing.
	go packetReceiver(t, ln, &backChannel{
		crcChan: make(chan uint64, nPackets),
		errChan: make(chan error, nPackets),
	})

	conn := newSender(t, dialer, protocol, address)
	// Mirror byteRunner: make sure the connection is released even when the
	// protocol check below aborts the test before packetSender closes it.
	defer conn.Close()
	if err := checkProtocols(conn, protocol); err != nil {
		// Not Fatalf: the error text must not be interpreted as a format
		// string.
		t.Fatal(err)
	}
	packetSender(t, nPackets, conn)
}

// byteReceiver accepts one connection from ln, registers bc in
// globalBCTable under the connection's remote address, and then reads from
// the connection until Read returns an error.
//
// The k-th read (counting from zero) uses a buffer of k+1 bytes and is
// expected to fill it completely with a message whose first byte is 0xff;
// violations are reported on bc.errChan. The bytes actually read are always
// forwarded on bc.byteChan. A read error other than io.EOF is reported on
// bc.errChan. In all cases bc.byteChan and bc.errChan are closed before
// returning.
//
// If Accept fails the channels are closed but bc is never registered, so a
// sender blocked in waitfor will not be released by this goroutine.
//
// t is currently unused.
func byteReceiver(t *testing.T, ln net.Listener, bc *backChannel) {
	// finish tells the sender that no further results will arrive.
	finish := func() {
		close(bc.byteChan)
		close(bc.errChan)
	}

	conn, err := ln.Accept()
	if err != nil {
		finish()
		return
	}
	globalBCTable.add(conn.RemoteAddr().String(), bc)

	defer conn.Close()
	for rxed := 0; ; rxed++ {
		buf := make([]byte, rxed+1)
		n, err := conn.Read(buf)
		if err != nil {
			// io.EOF is the normal end of stream once the sender closes.
			if err != io.EOF {
				bc.errChan <- fmt.Errorf("unexpected error: %s", err)
			}
			finish()
			return
		}
		// Compare against the full buffer length; comparing n with
		// len(buf[:n]) could never fail.
		if got, want := n, len(buf); got != want {
			bc.errChan <- fmt.Errorf("%s -> %s: got %d bytes, expected %d", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want)
		}
		if got, want := buf[0], byte(0xff); got != want {
			bc.errChan <- fmt.Errorf("%s -> %s: got %x, want %x", conn.LocalAddr().String(), conn.RemoteAddr().String(), got, want)
		}
		bc.byteChan <- buf[:n]
	}
}

// byteSender writes nIterations messages of increasing length (1, 2, ...,
// nIterations bytes) to conn, closes conn so that the receiver sees
// end-of-stream, and then verifies what the receiver reported through the
// backChannel registered under conn's local address.
//
// Message i (counting from zero) is i+1 bytes long: a leading 0xff followed
// by i copies of byte(64+i). The test is failed if a write fails or is
// short, if the receiver reported any error, if any received message differs
// from the one sent at the same position, or if the number of messages
// received is not nIterations.
func byteSender(t *testing.T, nIterations int, conn net.Conn) {
	txBytes := make([][]byte, nIterations)
	for i := range txBytes {
		p := make([]byte, i+1)
		p[0] = 0xff
		for j := 1; j <= i; j++ {
			p[j] = byte(64 + i) // start at ASCII A
		}
		txBytes[i] = p
		n, err := conn.Write(p)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if got, want := n, i+1; got != want {
			t.Fatalf("wrote %d, not %d bytes", got, want)
		}
	}
	conn.Close()

	bc := globalBCTable.waitfor(conn.LocalAddr().String())

	// errChan is closed by the receiver when it finishes, so this loop
	// also waits for the receiver to be done.
	for err := range bc.errChan {
		if err != nil {
			// Not Fatalf: the error text must not be interpreted as a
			// format string.
			t.Fatal(err)
		}
	}

	addr := fmt.Sprintf("%s -> %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
	rxed := 0
	for rxBytes := range bc.byteChan {
		if rxed >= nIterations {
			t.Fatalf("%s: received more than %d messages", addr, nIterations)
		}
		if got, want := len(rxBytes), rxed+1; got != want {
			t.Fatalf("%s: got %d, want %d bytes", addr, got, want)
		}
		if got, want := rxBytes[0], byte(0xff); got != want {
			t.Fatalf("%s: got %x, want %x", addr, got, want)
		}
		for i := 0; i < len(rxBytes); i++ {
			if got, want := rxBytes[i], txBytes[rxed][i]; got != want {
				t.Fatalf("%s: got %c, want %c", addr, got, want)
			}
		}
		rxed++
	}
	if got, want := rxed, nIterations; got != want {
		t.Fatalf("%s: got %d, want %d", addr, got, want)
	}
}

// byteRunner runs one raw-byte round trip: it starts byteReceiver on ln in a
// goroutine, dials address with dialer, checks the connection's reported
// protocols, and then hands the connection to byteSender to send and verify
// 10 messages.
func byteRunner(t *testing.T, ln net.Listener, dialer rpc.DialerFunc, protocol, address string) {
	nIterations := 10
	// Both channels are buffered to nIterations so the receiver can queue
	// its results while the sender is still writing.
	go byteReceiver(t, ln, &backChannel{
		byteChan: make(chan []byte, nIterations),
		errChan:  make(chan error, nIterations),
	})

	conn := newSender(t, dialer, protocol, address)
	defer conn.Close()
	if err := checkProtocols(conn, protocol); err != nil {
		// Not Fatalf: the error text must not be interpreted as a format
		// string.
		t.Fatal(err)
	}
	byteSender(t, nIterations, conn)
}
