// Copyright 2016 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package grpc

import (
	"bytes"
	_ "encoding/binary"
	"errors"
	"fmt"
	"io"
	_ "io/ioutil"
	"log"
	"net"
	"runtime/debug"
	"sync"
	"time"

	"golang.org/x/crypto/nacl/box"
)

const doEncrypt = true

// Implements net.Conn. Encrypts messages using keys negotioated via NaCl.
// TODO: blessings
// TODO: discharges and caveats
// TODO: ensuer that all 5 of these are set. Constructor?
// TODO: do I need to check the message size as in box_coipher.Seal?
// TODO: make all these keys not pointers?
// TODO: should I have locks here?
// TODO: what's all this salsa stuff?
type conn struct {
	rawConn   net.Conn
	publicKey *[32]byte
	secretKey *[32]byte
	sharedKey *[32]byte
	binding   []byte
	nonce     [24]byte // TODO: do I need a second nonce? Does this need to be random?
	counter   uint64   // Why is this so weirdly handled in advanceNonce?
	mu        sync.Mutex
}

// TODO: what if bytesRead < 3? = 3?
func (c *conn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !doEncrypt {
		log.Printf("Beginning to Read.\n")
		frame := [3]byte{}
		frameCopied, err := c.rawConn.Read(frame[:])
		if frameCopied != 3 {
			return -1, errors.New("Did not copy 3 frame bytes.")
		}
		msgSize := read3ByteUint(frame)

		resBuf := make([]byte, msgSize)
		bytesRead, err := c.rawConn.Read(resBuf)
		if err != nil {
			log.Fatal(err)
		}
		copy(b, resBuf)
		log.Printf("Read %d bytes: %v\n", bytesRead, resBuf)
		log.Printf("Succeeded in reading.\n\n")
		return bytesRead, nil
		// return len(b), nil
	}

	log.Printf("Beginning to Read.\n")

	frame := [3]byte{}
	frameCopied, err := c.rawConn.Read(frame[:])
	if frameCopied != 3 {
		return -1, errors.New("Did not copy 3 frame bytes.")
	}
	msgSize := read3ByteUint(frame)

	resBuf := make([]byte, msgSize) // TODO better (dynamic) size or way of reading?
	bytesRead, err := c.rawConn.Read(resBuf)
	// log.Printf("Read %d bytes: %v\n", bytesRead, resBuf)
	if err != nil {
		// log.Printf("Failed to read.\n")
		log.Fatal(err)
		return bytesRead, err
	}

	// tmp := make([]byte, 0, bytesRead-box.Overhead) // TODO: is this enough? Also, why do we need both of tmp and out?
	out, ok := box.OpenAfterPrecomputation(nil, resBuf, c.currentNonce(), c.sharedKey)
	log.Printf("%d bytes after opening: %v\n", len(out), out)
	// log.Printf("%d bytes after opening: %v\n", len(out), string(out))
	c.advanceNonce() // TODO: defer? Is there harm in advancing often than necessary (i.e. on error)?
	if !ok {
		// log.Printf("Failed to decrypt.\n")
		//log.Fatal(err)
		return 0, err
	}
	copy(b, out)
	// should we return bytesRead or len(b)?
	// log.Printf("cap(b): %d, len(b) %d, \ncap(out): %d, len(out): %d\n", cap(b), len(b), cap(out), len(out))
	log.Printf("Succeeded in reading.\n\n")
	// return bytesRead, nil
	// return len(b), nil
	return len(out), nil
}

// TODO: all this casting is gross
func (c *conn) Write(b []byte) (n int, err error) {
	// log.Printf("Write called:\n%v\n", string(debug.Stack()))
	c.mu.Lock()
	defer c.mu.Unlock()

	if !doEncrypt {
		log.Printf("Beginning to write.\n")
		tmp := make([]byte, 3)
		err = write3ByteUint(tmp, len(b))
		if err != nil {
			return -1, err
		}
		bytesCopied, err := io.Copy(c.rawConn, bytes.NewReader(tmp))
		if err != nil {
			log.Fatal(err)
		}
		bytesCopied, err = io.Copy(c.rawConn, bytes.NewReader(b))
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("Wrote %d bytes: %v\n", bytesCopied, b)
		log.Printf("Succeeded in writing!\n\n")
		// // return int(bytesCopied), nil
		return len(b), nil
	}

	log.Printf("Beginning to write.\n")
	log.Printf("%d bytes before sealing: %v\n", len(b), b)
	tmp := make([]byte, 3, 3+len(b)+box.Overhead) // TODO: is this enough? Also, why do we need both of tmp and out?
	err = write3ByteUint(tmp[:3], len(b)+box.Overhead)
	if err != nil {
		return -1, err
	}
	out := box.SealAfterPrecomputation(tmp, b, c.currentNonce(), c.sharedKey)
	// log.Printf("tmp: %v", tmp)
	// log.Printf("out: %v", out)
	c.advanceNonce()
	bytesCopied, err := io.Copy(c.rawConn, bytes.NewReader(out))
	// log.Printf("Wrote %d bytes.\n", bytesCopied)
	if err != nil {
		log.Printf("Failed to copy to rawConn.\n")
		log.Fatal(err)
		return int(bytesCopied), err
	}
	if bytesCopied != int64(len(out)) {
		errMsg := fmt.Sprintf("Did not write entire message. Expected to write %d bytes but wrote %d.", len(out), bytesCopied)
		log.Printf("Did not write entire message. Expected to write %d bytes but wrote %d.", len(out), bytesCopied)
		log.Printf("Failed to write.\n")
		log.Fatal(err)
		return -1, errors.New(errMsg)
	}
	// log.Printf("cap(b): %d, len(b) %d, cap(out): %d, len(out): %d\n", cap(b), len(b), cap(out), len(out))
	log.Printf("Succeeded in writing!\n\n")
	// return int(bytesCopied), nil
	return len(b), nil
}

// TODO: understand this stuff
const maxPacketSize = 0xffffff

func write3ByteUint(dst []byte, n int) error {
	if n > maxPacketSize || n < 0 {
		// return NewErrLargerThan3ByteUInt(nil)
		return errors.New("TOOO BIG")
	}
	n = maxPacketSize - n
	dst[0] = byte((n & 0xff0000) >> 16)
	dst[1] = byte((n & 0x00ff00) >> 8)
	dst[2] = byte(n & 0x0000ff)
	return nil
}

func read3ByteUint(src [3]byte) int {
	return maxPacketSize - (int(src[0])<<16 | int(src[1])<<8 | int(src[2]))
}

// TODO: cover these up with an interface?
// TODO: remove all these useless mutex locks and unlocks
func (c *conn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	debug.PrintStack()
	// log.Fatalf("DON'T CLOSE THE CONNETION\n")
	return c.rawConn.Close()
}

func (c *conn) LocalAddr() net.Addr {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.rawConn.LocalAddr()
}

func (c *conn) RemoteAddr() net.Addr {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.rawConn.RemoteAddr()
}

func (c *conn) SetDeadline(t time.Time) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.rawConn.SetDeadline(t)
}

func (c *conn) SetReadDeadline(t time.Time) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.rawConn.SetReadDeadline(t)
}

func (c *conn) SetWriteDeadline(t time.Time) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.rawConn.SetWriteDeadline(t)
}

// TODO: should I just collapse current and advance?
func (c *conn) currentNonce() *[24]byte {
	return &c.nonce
}

// TODO: re enable.
func (c *conn) advanceNonce() {

	// c.counter++
	// binary.LittleEndian.PutUint64(c.nonce[:], c.counter)
}
