blob: d319973d0e50592234ba2a44345cd32589d9b41c [file] [log] [blame]
// Copyright 2016 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package grpc
import (
"bytes"
_ "encoding/binary"
"errors"
"fmt"
"io"
_ "io/ioutil"
"log"
"net"
"runtime/debug"
"sync"
"time"
"golang.org/x/crypto/nacl/box"
)
// doEncrypt selects conn's wire format at compile time: when true, every
// frame payload is sealed/opened with NaCl box; when false, frames carry the
// plaintext unchanged.
const doEncrypt bool = true
// Implements net.Conn. Encrypts messages using keys negotioated via NaCl.
// TODO: blessings
// TODO: discharges and caveats
// TODO: ensuer that all 5 of these are set. Constructor?
// TODO: do I need to check the message size as in box_coipher.Seal?
// TODO: make all these keys not pointers?
// TODO: should I have locks here?
// TODO: what's all this salsa stuff?
type conn struct {
rawConn net.Conn
publicKey *[32]byte
secretKey *[32]byte
sharedKey *[32]byte
binding []byte
nonce [24]byte // TODO: do I need a second nonce? Does this need to be random?
counter uint64 // Why is this so weirdly handled in advanceNonce?
mu sync.Mutex
}
// TODO: what if bytesRead < 3? = 3?
func (c *conn) Read(b []byte) (n int, err error) {
// c.mu.Lock()
// defer c.mu.Unlock()
if !doEncrypt {
log.Printf("Beginning to Read.\n")
frame := [3]byte{}
frameCopied, err := c.rawConn.Read(frame[:])
if frameCopied != 3 {
return frameCopied, errors.New("Did not copy 3 frame bytes.")
}
msgSize := read3ByteUint(frame)
resBuf := make([]byte, msgSize)
bytesRead, err := c.rawConn.Read(resBuf)
if err != nil {
log.Fatal(err)
}
copy(b, resBuf)
log.Printf("Read %d bytes: %v\n", bytesRead, resBuf)
log.Printf("Succeeded in reading.\n\n")
return bytesRead, nil
// return len(b), nil
}
log.Printf("Beginning to Read.\n")
frame := [3]byte{}
frameCopied, err := c.rawConn.Read(frame[:])
if frameCopied != 3 {
return frameCopied, errors.New("Did not copy 3 frame bytes.")
}
msgSize := read3ByteUint(frame)
resBuf := make([]byte, msgSize) // TODO better (dynamic) size or way of reading?
bytesRead, err := c.rawConn.Read(resBuf)
// log.Printf("Read %d bytes: %v\n", bytesRead, resBuf)
if err != nil {
// log.Printf("Failed to read.\n")
log.Fatal(err)
return bytesRead, err
}
// tmp := make([]byte, 0, bytesRead-box.Overhead) // TODO: is this enough? Also, why do we need both of tmp and out?
out, ok := box.OpenAfterPrecomputation(nil, resBuf, c.currentNonce(), c.sharedKey)
log.Printf("%d bytes after opening: %v\n", len(out), out)
// log.Printf("%d bytes after opening: %v\n", len(out), string(out))
c.advanceNonce() // TODO: defer? Is there harm in advancing often than necessary (i.e. on error)?
if !ok {
// log.Printf("Failed to decrypt.\n")
//log.Fatal(err)
return 0, err
}
copy(b, out)
// should we return bytesRead or len(b)?
// log.Printf("cap(b): %d, len(b) %d, \ncap(out): %d, len(out): %d\n", cap(b), len(b), cap(out), len(out))
log.Printf("Succeeded in reading.\n\n")
// return bytesRead, nil
// return len(b), nil
return len(out), nil
}
// Write sends b to the peer as a single frame and returns len(b) on success.
//
// The frame is a 3-byte length header (see write3ByteUint) followed by the
// payload. When doEncrypt is true the payload is b sealed with c.sharedKey and
// the current nonce, and the header counts len(b)+box.Overhead bytes.
// Otherwise the header counts len(b) and b follows verbatim. Header and
// payload are assembled into one buffer and written together, so this method
// never splits a frame across two writes.
//
// If len(b) (plus box.Overhead when encrypting) exceeds maxPacketSize, the
// error from write3ByteUint is returned and nothing is written. On any write
// failure Write returns 0 and the error instead of terminating the process.
// A partially written frame cannot be decoded by the peer, so no part of b
// should be considered delivered.
func (c *conn) Write(b []byte) (n int, err error) {
	size := len(b)
	if doEncrypt {
		size += box.Overhead
	}
	// Validate the size before allocating a buffer of that size.
	var header [3]byte
	if err = write3ByteUint(header[:], size); err != nil {
		return 0, err
	}
	frame := make([]byte, 0, len(header)+size)
	frame = append(frame, header[:]...)
	if doEncrypt {
		// SealAfterPrecomputation appends the sealed box after the header.
		frame = box.SealAfterPrecomputation(frame, b, c.currentNonce(), c.sharedKey)
		c.advanceNonce()
	} else {
		frame = append(frame, b...)
	}
	written, err := io.Copy(c.rawConn, bytes.NewReader(frame))
	if err != nil {
		return 0, err
	}
	if written != int64(len(frame)) {
		return 0, fmt.Errorf("Did not write entire message. Expected to write %d bytes but wrote %d.", len(frame), written)
	}
	return len(b), nil
}
// maxPacketSize is the largest value representable in the 3-byte frame
// header, and therefore the largest frame payload conn can carry.
const maxPacketSize = 0xffffff

// write3ByteUint encodes n into dst[0:3] as a big-endian 24-bit value after
// inverting it within 24 bits (the stored value is maxPacketSize - n, which
// for in-range n equals n XOR 0xffffff). dst must have length at least 3.
//
// When n is negative or exceeds maxPacketSize it returns an error and leaves
// dst untouched.
func write3ByteUint(dst []byte, n int) error {
	if n < 0 || n > maxPacketSize {
		return errors.New("TOOO BIG")
	}
	encoded := n ^ maxPacketSize
	dst[0] = byte(encoded >> 16)
	dst[1] = byte(encoded >> 8)
	dst[2] = byte(encoded)
	return nil
}

// read3ByteUint decodes a header produced by write3ByteUint: it assembles the
// big-endian 24-bit value in src and undoes the 24-bit inversion.
func read3ByteUint(src [3]byte) int {
	encoded := int(src[0])<<16 + int(src[1])<<8 + int(src[2])
	return encoded ^ maxPacketSize
}
// TODO: consider embedding net.Conn instead of hand-writing the delegating
// methods below.

// Close logs a message and then closes the underlying connection, returning
// whatever error rawConn.Close reports.
func (c *conn) Close() error {
	log.Printf("SOMEONE IS CLOSING THE CONNECTION\n")
	raw := c.rawConn
	return raw.Close()
}
// LocalAddr reports the local address of the underlying connection.
func (c *conn) LocalAddr() net.Addr {
	raw := c.rawConn
	return raw.LocalAddr()
}
// RemoteAddr reports the peer address of the underlying connection.
func (c *conn) RemoteAddr() net.Addr {
	raw := c.rawConn
	return raw.RemoteAddr()
}
// SetDeadline forwards the combined read/write deadline t to the underlying
// connection.
func (c *conn) SetDeadline(t time.Time) error {
	raw := c.rawConn
	return raw.SetDeadline(t)
}
// SetReadDeadline forwards the read deadline t to the underlying connection.
func (c *conn) SetReadDeadline(t time.Time) error {
	raw := c.rawConn
	return raw.SetReadDeadline(t)
}
// SetWriteDeadline forwards the write deadline t to the underlying connection.
func (c *conn) SetWriteDeadline(t time.Time) error {
	raw := c.rawConn
	return raw.SetWriteDeadline(t)
}
// currentNonce returns a pointer to a snapshot of the nonce that Read and
// Write pass to the NaCl open/seal calls.
//
// The copy is taken while holding c.mu. Returning &c.nonce directly, as
// before, let callers read the field after the lock had been released, so the
// lock protected nothing. With a snapshot, a later (possibly concurrent)
// update of c.nonce cannot change the value a caller is already using.
//
// TODO: consider merging currentNonce and advanceNonce into one call that
// returns the nonce and advances it atomically under c.mu.
func (c *conn) currentNonce() *[24]byte {
	c.mu.Lock()
	defer c.mu.Unlock()
	snapshot := c.nonce
	return &snapshot
}
// advanceNonce is called by Read and Write right after each open/seal and is
// meant to move c.nonce to a fresh value. It is currently disabled and does
// nothing, so every frame, in both directions, uses the same nonce.
//
// SECURITY(review): NaCl box requires that a nonce is never reused with the
// same key. Reuse leaks the XOR of plaintexts and weakens the authenticator.
// Simply re-enabling a single counter shared by Read and Write is not enough.
// Both directions would advance it, so peers stay in sync only under strict
// request/response alternation. And unless the handshake (not visible here)
// gives each side a distinct starting nonce, both peers would seal their first
// frames under the same value. A fix needs separate send and receive nonces
// that differ by direction.
//
// Previous draft, kept for reference:
//
//	c.counter++
//	binary.LittleEndian.PutUint64(c.nonce[:], c.counter)
func (c *conn) advanceNonce() {
	// Intentionally a no-op until per-direction nonces are implemented.
}