veyron/lib/bluetooth: Make the bluetooth net.Conn and net.Listener
implementations amenable to concurrent method invocations.

This is required to adhere to the net.Conn interface.

Prior to this commit, the net.Conn and net.Listener implementations in
the bluetooth package were subject to races where a read/write/accept
may be called on a file descriptor that has been closed and reassigned
to a different file/socket.

This commit introduces a "thread-safe" file descriptor wrapper in the
bluetooth.fd type. The basic idea is to replace one file descriptor with
two, the second being an eventfd used to signal a pending Close
operation. A mutex and the select system call are then used to ensure
that operations on the file descriptors cannot happen after they have
been closed.

Change-Id: Ic57161cf4f2b893138beba9cfefba36a4141aaec
diff --git a/lib/bluetooth/bluetooth_linux.go b/lib/bluetooth/bluetooth_linux.go
index ed7e378..2a3ba83 100644
--- a/lib/bluetooth/bluetooth_linux.go
+++ b/lib/bluetooth/bluetooth_linux.go
@@ -128,11 +128,7 @@
 		syscall.Close(int(socket))
 		return nil, fmt.Errorf("listen error: %v", err)
 	}
-
-	return &listener{
-		localAddr: local,
-		socket:    int(socket),
-	}, nil
+	return newListener(int(socket), local)
 }
 
 // Dial creates a new RFCOMM connection with the remote address, specified in
@@ -182,11 +178,7 @@
 		defer C.free(unsafe.Pointer(es))
 		return nil, fmt.Errorf("dial error: error connecting to remote address: %s, error: %s", remoteAddr, C.GoString(es))
 	}
-	return &conn{
-		fd:         int(socket),
-		localAddr:  &local,
-		remoteAddr: remote,
-	}, nil
+	return newConn(int(socket), &local, remote)
 }
 
 // Device is a struct representing an opened Bluetooth device.  It consists of
diff --git a/lib/bluetooth/conn.go b/lib/bluetooth/conn.go
index 6b50941..6574313 100644
--- a/lib/bluetooth/conn.go
+++ b/lib/bluetooth/conn.go
@@ -1,3 +1,5 @@
+// +build linux
+
 package bluetooth
 
 import (
@@ -9,56 +11,34 @@
 
 // conn represents one RFCOMM connection between two bluetooth devices.
 // It implements the net.Conn interface.
-//
-// TODO(ashankar,spetrovic): net.Conn implementations are supposed to be safe
-// for concurrent method invocations. This implementation is not. Fix.
 type conn struct {
-	fd                    int
-	localAddr, remoteAddr *addr
+	fd                    *fd
+	localAddr, remoteAddr net.Addr
 	readDeadline          time.Time
 	writeDeadline         time.Time
 }
 
+func newConn(sockfd int, local, remote net.Addr) (net.Conn, error) {
+	fd, err := newFD(sockfd)
+	if err != nil {
+		syscall.Close(sockfd)
+		return nil, err
+	}
+	return &conn{fd: fd, localAddr: local, remoteAddr: remote}, nil
+}
+
 func (c *conn) String() string {
 	return fmt.Sprintf("Bluetooth (%s) <--> (%s)", c.localAddr, c.remoteAddr)
 }
 
-// helper method for Read and Write that ensures:
-// - the returned 'n' is always >= 0, as per guidelines for the io.Reader and
-//   io.Writer interfaces.
-func (c *conn) rw(n int, err error) (int, error) {
-	if n < 0 {
-		n = 0
-	}
-	return n, err
-}
-
-// Implements the net.Conn interface.
-func (c *conn) Read(p []byte) (n int, err error) {
-	return c.rw(syscall.Read(c.fd, p))
-}
-
-// Implements the net.Conn interface.
-func (c *conn) Write(p []byte) (n int, err error) {
-	return c.rw(syscall.Write(c.fd, p))
-}
-
-// Implements the net.Conn interface.
-func (c *conn) Close() error {
-	return syscall.Close(c.fd)
-}
-
-// Implements the net.Conn interface.
-func (c *conn) LocalAddr() net.Addr {
-	return c.localAddr
-}
-
-// Implements the net.Conn interface.
-func (c *conn) RemoteAddr() net.Addr {
-	return c.remoteAddr
-}
-
-// Implements the net.Conn interface.
+// net.Conn interface methods
+func (c *conn) Read(p []byte) (n int, err error)   { return c.fd.Read(p) }
+func (c *conn) Write(p []byte) (n int, err error)  { return c.fd.Write(p) }
+func (c *conn) Close() error                       { return c.fd.Close() }
+func (c *conn) LocalAddr() net.Addr                { return c.localAddr }
+func (c *conn) RemoteAddr() net.Addr               { return c.remoteAddr }
+func (c *conn) SetReadDeadline(t time.Time) error  { return c.setSockoptTimeval(t, syscall.SO_RCVTIMEO) }
+func (c *conn) SetWriteDeadline(t time.Time) error { return c.setSockoptTimeval(t, syscall.SO_SNDTIMEO) }
 func (c *conn) SetDeadline(t time.Time) error {
 	if err := c.SetReadDeadline(t); err != nil {
 		return err
@@ -69,18 +49,14 @@
 	return nil
 }
 
-// Implements the net.Conn interface.
-func (c *conn) SetReadDeadline(t time.Time) error {
-	if timeout := getTimeout(t); timeout != nil {
-		return syscall.SetsockoptTimeval(c.fd, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, timeout)
+func (c *conn) setSockoptTimeval(t time.Time, opt int) error {
+	fd, err := c.fd.Reference()
+	if err != nil {
+		return err
 	}
-	return nil
-}
-
-// Implements the net.Conn interface.
-func (c *conn) SetWriteDeadline(t time.Time) error {
+	defer c.fd.ReleaseReference()
 	if timeout := getTimeout(t); timeout != nil {
-		return syscall.SetsockoptTimeval(c.fd, syscall.SOL_SOCKET, syscall.SO_SNDTIMEO, timeout)
+		return syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET, opt, timeout)
 	}
 	return nil
 }
diff --git a/lib/bluetooth/conn_test.go b/lib/bluetooth/conn_test.go
new file mode 100644
index 0000000..8aaa2f1
--- /dev/null
+++ b/lib/bluetooth/conn_test.go
@@ -0,0 +1,87 @@
+// +build linux
+
+package bluetooth
+
+import (
+	"runtime"
+	"syscall"
+	"testing"
+)
+
+// TestConnConcurrency attempts to test that methods on the *conn type are
+// friendly to concurrent invocation. Unable to figure out a clean way to do
+// this, the author has resorted to just firing up a bunch of goroutines and
+// hoping that failures will manifest often.
+func TestConnConcurrency(t *testing.T) {
+	const (
+		// These numbers were tuned to make the test fail "often"
+		// without the accompanying change to conn.go in the commit
+		// that added this test on the machine that the author was
+		// using at the time.
+		nConcurrentReaders = 30
+		nConcurrentClosers = 10
+	)
+	mp := runtime.GOMAXPROCS(nConcurrentReaders)
+	defer runtime.GOMAXPROCS(mp)
+
+	pipe := func() (rfd, wfd int) {
+		var fds [2]int
+		if err := syscall.Pipe(fds[:]); err != nil {
+			t.Fatal(err)
+		}
+		return fds[0], fds[1]
+	}
+	rfd, wfd := pipe()
+	rConn, _ := newConn(rfd, nil, nil)
+	wConn, _ := newConn(wfd, nil, nil)
+	const (
+		bugs  = "bugs bunny"
+		daffy = "daffy duck"
+	)
+	rchan := make(chan string)
+	// Write a bunch of times
+	for i := 0; i < nConcurrentReaders; i++ {
+		go wConn.Write([]byte(bugs))
+	}
+	read := func() {
+		buf := make([]byte, len(bugs))
+		if n, err := rConn.Read(buf); err == nil {
+			rchan <- string(buf[:n])
+			return
+		}
+		rchan <- ""
+	}
+	// Fire up half the readers before Close
+	for i := 0; i < nConcurrentReaders; i += 2 {
+		go read()
+	}
+	// Fire up the closers (and attempt to reassign the file descriptors to
+	// something new).
+	for i := 0; i < nConcurrentClosers; i++ {
+		go func() {
+			rConn.Close()
+			// Create new FDs, which may re-use the closed file descriptors
+			// and write something else to them.
+			rfd, wfd := pipe()
+			syscall.Write(wfd, []byte(daffy))
+			syscall.Close(wfd)
+			syscall.Close(rfd)
+		}()
+	}
+	// And then the remaining readers
+	for i := 1; i < nConcurrentReaders; i += 2 {
+		go read()
+	}
+	// Now read from the channel, should either see full bugs bunnies or empty strings.
+	nEmpty := 0
+	for i := 0; i < nConcurrentReaders; i++ {
+		got := <-rchan
+		switch {
+		case len(got) == 0:
+			nEmpty++
+		case got != bugs:
+			t.Errorf("Read %q, wanted %q or empty string", got, bugs)
+		}
+	}
+	t.Logf("Read returned non-empty %d/%d times", (nConcurrentReaders - nEmpty), nConcurrentReaders)
+}
diff --git a/lib/bluetooth/fd.go b/lib/bluetooth/fd.go
new file mode 100644
index 0000000..54766f3
--- /dev/null
+++ b/lib/bluetooth/fd.go
@@ -0,0 +1,217 @@
+package bluetooth
+
+// #include <stddef.h>
+// #include <sys/eventfd.h>
+// #include <sys/select.h>
+//
+// int add_to_eventfd(int fd) {
+//	uint64_t val = 1;
+//	return write(fd, &val, 8);
+// }
+//
+// int wait(int eventfd, int readfd, int writefd) {
+//	fd_set readfds, writefds;
+//	FD_ZERO(&readfds);
+//	FD_ZERO(&writefds);
+//	fd_set* writefdsp = NULL;
+//
+//	FD_SET(eventfd, &readfds);
+//	int nfds = eventfd + 1;
+//
+//	if (readfd >= 0) {
+//		FD_SET(readfd, &readfds);
+//		if (readfd >= nfds) {
+//			nfds = readfd + 1;
+//		}
+//	}
+//	if (writefd >= 0) {
+//		FD_SET(writefd, &writefds);
+//		if (writefd >= nfds) {
+//			nfds = writefd + 1;
+//		}
+//		writefdsp = &writefds;
+//	}
+//	// TODO(ashankar): Should EINTR be handled by a retry?
+//	// See "Select Law" section of "man 2 select_tut".
+//	int nready = select(nfds, &readfds, writefdsp, NULL, NULL);
+//	return nready >= 0 && ((readfd >= 0 && FD_ISSET(readfd, &readfds)) || (writefd >= 0 && FD_ISSET(writefd, &writefds)));
+// }
+import "C"
+
+import (
+	"fmt"
+	"io"
+	"sync"
+	"syscall"
+)
+
+// An fd enables concurrent invocations of Read, Write and Close on a file
+// descriptor.
+//
+// It ensures that read, write and close operations do not conflict and thereby
+// avoids races between file descriptors being closed and re-used while a
+// read/write is being initiated.
+//
+// This is achieved by using an eventfd to signal the intention to close a
+// descriptor and a select over the eventfd and the file descriptor being
+// protected.
+type fd struct {
+	mu              sync.Mutex
+	datafd, eventfd C.int
+	closing         bool       // Whether Close has been or is being invoked.
+	done            *sync.Cond // Signaled when no Read or Writes are pending.
+	refcnt          int
+}
+
+// newFD creates an fd object providing read, write and close operations
+// over datafd that are not hostile to concurrent invocations.
+func newFD(datafd int) (*fd, error) {
+	eventfd, err := C.eventfd(0, C.EFD_CLOEXEC)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create eventfd: %v", err)
+	}
+	ret := &fd{datafd: C.int(datafd), eventfd: eventfd}
+	ret.done = sync.NewCond(&ret.mu)
+	return ret, nil
+}
+
+func (fd *fd) Read(p []byte) (int, error) {
+	e, d, err := fd.prepare()
+	if err != nil {
+		return 0, err
+	}
+	defer fd.finish()
+	if err := wait(e, d, -1); err != nil {
+		return 0, err
+	}
+	return fd.rw(syscall.Read(int(fd.datafd), p))
+}
+
+func (fd *fd) Write(p []byte) (int, error) {
+	e, d, err := fd.prepare()
+	if err != nil {
+		return 0, err
+	}
+	defer fd.finish()
+	if err := wait(e, -1, d); err != nil {
+		return 0, err
+	}
+	return fd.rw(syscall.Write(int(fd.datafd), p))
+}
+
+// RunWhenReadable invokes f(file descriptor) when the file descriptor is ready
+// to be read. It returns an error if the file descriptor has been closed
+// either before or while this method is being invoked.
+//
+// f must NOT close readfd.
+func (fd *fd) RunWhenReadable(f func(readfd int)) error {
+	e, d, err := fd.prepare()
+	if err != nil {
+		return err
+	}
+	defer fd.finish()
+	if err := wait(e, d, -1); err != nil {
+		return err
+	}
+	f(int(d))
+	return nil
+}
+
+// Reference returns the underlying file descriptor and ensures that calls to
+// Close will block until ReleaseReference has been called.
+//
+// Clients must NOT close the returned file descriptor.
+func (fd *fd) Reference() (int, error) {
+	fd.mu.Lock()
+	defer fd.mu.Unlock()
+	if fd.closing {
+		return -1, fmt.Errorf("closing")
+	}
+	if fd.datafd < 0 {
+		return -1, fmt.Errorf("closed")
+	}
+	fd.refcnt++
+	return int(fd.datafd), nil
+}
+
+// ReleaseReference gives back a reference to the file descriptor obtained by
+// a call to Reference, thereby unblocking any pending Close operations.
+func (fd *fd) ReleaseReference() { fd.finish() }
+
+// rw is a helper for Read and Write that maps a zero-byte, error-free result
+// to io.EOF and clamps the returned 'n' to be >= 0, as per the guidelines for
+// the io.Reader and io.Writer interfaces.
+func (fd *fd) rw(n int, err error) (int, error) {
+	if n == 0 && err == nil {
+		err = io.EOF
+	}
+	if n < 0 {
+		n = 0
+	}
+	return n, err
+}
+
+func (fd *fd) prepare() (eventfd, datafd C.int, err error) {
+	fd.mu.Lock()
+	defer fd.mu.Unlock()
+	if fd.closing {
+		return 0, 0, fmt.Errorf("closing")
+	}
+	fd.refcnt++
+	// returned file descriptors are guaranteed to be
+	// valid till refcnt is reduced by at least 1, since
+	// Close waits for the refcnt to go down to zero before
+	// closing these file descriptors.
+	return fd.eventfd, fd.datafd, nil
+}
+
+func wait(eventfd, readfd, writefd C.int) error {
+	ok, err := C.wait(eventfd, readfd, writefd)
+	if err != nil {
+		return err
+	}
+	if ok <= 0 {
+		return fmt.Errorf("closing")
+	}
+	return nil
+}
+
+func (fd *fd) finish() {
+	fd.mu.Lock()
+	fd.refcnt--
+	if fd.closing && fd.refcnt == 0 {
+		fd.done.Broadcast()
+	}
+	fd.mu.Unlock()
+}
+
+func (fd *fd) Close() error {
+	fd.mu.Lock()
+	defer fd.mu.Unlock()
+	if !fd.closing {
+		// Send an "event" to notify of closures.
+		if _, err := C.add_to_eventfd(fd.eventfd); err != nil {
+			return fmt.Errorf("failed to notify closure on eventfd: %v", err)
+		}
+		// Prevent any new Read/Write/RunWhenReadable calls from starting.
+		fd.closing = true
+	}
+	for fd.refcnt > 0 {
+		fd.done.Wait()
+	}
+	// No concurrent Read/Write/RunWhenReadable calls are using the file
+	// descriptors now. 0 is a valid descriptor; only -1 marks "closed".
+	if fd.eventfd >= 0 {
+		if err := syscall.Close(int(fd.eventfd)); err != nil {
+			return fmt.Errorf("failed to close eventfd: %v", err)
+		}
+		fd.eventfd = -1
+	}
+	if fd.datafd >= 0 {
+		if err := syscall.Close(int(fd.datafd)); err != nil {
+			return fmt.Errorf("failed to close underlying socket/filedescriptor: %v", err)
+		}
+		fd.datafd = -1
+	}
+	return nil
+}
diff --git a/lib/bluetooth/fd_test.go b/lib/bluetooth/fd_test.go
new file mode 100644
index 0000000..9b7b662
--- /dev/null
+++ b/lib/bluetooth/fd_test.go
@@ -0,0 +1,151 @@
+// +build linux
+
+package bluetooth
+
+import (
+	"fmt"
+	"io"
+	"sort"
+	"syscall"
+	"testing"
+	"time"
+)
+
+// mkfds returns two *fds, one on which Read can be called and one on which
+// Write can be called by using the pipe system call. This pipe is a cheap
+// approximation of a file descriptor backed by a network socket that the fd type
+// is really intended for.
+func mkfds() (readfd, writefd *fd, err error) {
+	var fds [2]int
+	if err = syscall.Pipe(fds[:]); err != nil {
+		err = fmt.Errorf("syscall.Pipe failed: %v", err)
+		return
+	}
+	if readfd, err = newFD(fds[0]); err != nil {
+		err = fmt.Errorf("newFD failed for readfd: %v", err)
+		return
+	}
+	if writefd, err = newFD(fds[1]); err != nil {
+		err = fmt.Errorf("newFD failed for writefd: %v", err)
+		return
+	}
+	return
+}
+
+// canClose calls fd.Close and returns true if fd.Close returns.
+// It returns false if fd.Close blocks.
+// This function uses time to guess whether fd.Close is blocked or
+// not, and is thus not the most accurate implementation. The author
+// welcomes advice on restructuring this function or tests involving
+// it to make the testing deterministically accurate.
+func canClose(fd *fd) bool {
+	c := make(chan error)
+	go func() {
+		c <- fd.Close()
+	}()
+	select {
+	case <-c:
+		return true
+	case <-time.After(time.Millisecond):
+		return false
+	}
+}
+
+func TestFDBasic(t *testing.T) {
+	rfd, wfd, err := mkfds()
+	if err != nil {
+		t.Fatal(err)
+	}
+	const batman = "batman"
+	if n, err := wfd.Write([]byte(batman)); n != 6 || err != nil {
+		t.Errorf("Got (%d, %v) want (6, nil)", n, err)
+	}
+	var read [1024]byte
+	if n, err := rfd.Read(read[:]); n != 6 || err != nil || string(read[:n]) != string(batman) {
+		t.Errorf("Got (%d, %v) = %q, want (6, nil) = %q", n, err, read[:n], batman)
+	}
+	if err := rfd.Close(); err != nil {
+		t.Error(err)
+	}
+	if err := wfd.Close(); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestFDReference(t *testing.T) {
+	fd, _, err := mkfds()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := fd.Reference(); err != nil {
+		t.Fatal(err)
+	}
+	if canClose(fd) {
+		t.Errorf("Should not be able to close fd since there is an outstanding reference")
+	}
+	fd.ReleaseReference()
+	if !canClose(fd) {
+		t.Errorf("Should be able to close fd since there are no outstanding references")
+	}
+}
+
+func TestFDReadEOF(t *testing.T) {
+	rfd, wfd, err := mkfds()
+	if err != nil {
+		t.Fatal(err)
+	}
+	const (
+		bugs  = "bugs"
+		bunny = "bunny"
+	)
+	if n, err := wfd.Write([]byte(bugs)); n != len(bugs) || err != nil {
+		t.Fatalf("Got (%d, %v) want (%d, nil)", n, err, len(bugs))
+	}
+	if n, err := wfd.Write([]byte(bunny)); n != len(bunny) || err != nil {
+		t.Fatalf("Got (%d, %v) want (%d, nil)", n, err, len(bunny))
+	}
+	if err := wfd.Close(); err != nil {
+		t.Fatal(err)
+	}
+	var read [1024]byte
+	if n, err := rfd.Read(read[:]); n != len(bugs)+len(bunny) || err != nil {
+		t.Errorf("Got (%d, %v) = %q, want (%d, nil) = %q", n, err, read[:n], len(bugs)+len(bunny), "bugsbunny")
+	}
+	if n, err := rfd.Read(read[:]); n != 0 || err != io.EOF {
+		t.Errorf("Got (%d, %v) = %q, want (0, EOF)", n, err, read[:n])
+	}
+}
+
+func TestFDReadLessThanReady(t *testing.T) {
+	rfd, wfd, err := mkfds()
+	if err != nil {
+		t.Fatal(err)
+	}
+	const nbytes = 20
+	rchan := make(chan int, nbytes)
+	written := make([]byte, nbytes)
+	for i := 1; i <= nbytes; i++ {
+		written[i-1] = byte(i)
+		go func() {
+			var buf [1]byte
+			rfd.Read(buf[:])
+			rchan <- int(buf[0])
+		}()
+	}
+	if n, err := wfd.Write(written); n != nbytes || err != nil {
+		t.Fatalf("Got (%d, %v), want (%d, nil)", n, err, nbytes)
+	}
+	if err := wfd.Close(); err != nil {
+		t.Fatal(err)
+	}
+	read := make([]int, nbytes)
+	for i := 0; i < nbytes; i++ {
+		read[i] = <-rchan
+	}
+	sort.Ints(read)
+	for i, v := range read {
+		if i != v-1 {
+			t.Fatalf("Got %v, wanted it to be sorted", read)
+		}
+	}
+}
diff --git a/lib/bluetooth/listener.go b/lib/bluetooth/listener.go
index da383ad..a375a8f 100644
--- a/lib/bluetooth/listener.go
+++ b/lib/bluetooth/listener.go
@@ -5,7 +5,6 @@
 import (
 	"fmt"
 	"net"
-	"syscall"
 	"unsafe"
 )
 
@@ -20,17 +19,40 @@
 // listener waits for incoming RFCOMM connections on the provided socket.
 // It implements the net.Listener interface.
 type listener struct {
-	localAddr *addr
-	socket    int
+	fd         *fd
+	acceptChan chan (acceptResult)
+	localAddr  net.Addr
+}
+
+type acceptResult struct {
+	conn net.Conn
+	err  error
+}
+
+func newListener(sockfd int, addr net.Addr) (net.Listener, error) {
+	fd, err := newFD(sockfd)
+	if err != nil {
+		return nil, err
+	}
+	return &listener{fd: fd, acceptChan: make(chan acceptResult), localAddr: addr}, nil
 }
 
 // Implements the net.Listener interface.
 func (l *listener) Accept() (net.Conn, error) {
+	go l.fd.RunWhenReadable(l.accept)
+	r := <-l.acceptChan
+	return r.conn, r.err
+}
+
+func (l *listener) accept(sockfd int) {
 	var fd C.int
 	var remoteMAC *C.char
-	if es := C.bt_accept(C.int(l.socket), &fd, &remoteMAC); es != nil {
+	var result acceptResult
+	defer func() { l.acceptChan <- result }()
+	if es := C.bt_accept(C.int(sockfd), &fd, &remoteMAC); es != nil {
 		defer C.free(unsafe.Pointer(es))
-		return nil, fmt.Errorf("error accepting connection on %s, socket: %d, error: %s", l.localAddr, l.socket, C.GoString(es))
+		result.err = fmt.Errorf("error accepting connection on %s, socket: %d, error: %s", l.localAddr, sockfd, C.GoString(es))
+		return
 	}
 	defer C.free(unsafe.Pointer(remoteMAC))
 
@@ -38,21 +60,17 @@
 	var remote addr
 	var err error
 	if remote.mac, err = net.ParseMAC(C.GoString(remoteMAC)); err != nil {
-		return nil, fmt.Errorf("invalid remote MAC address: %s, err: %s", C.GoString(remoteMAC), err)
+		result.err = fmt.Errorf("invalid remote MAC address: %s, err: %s", C.GoString(remoteMAC), err)
+		return
 	}
 	// There's no way to get accurate remote channel number, so use 0.
 	remote.channel = 0
-
-	return &conn{
-		fd:         int(fd),
-		localAddr:  l.localAddr,
-		remoteAddr: &remote,
-	}, nil
+	result.conn, result.err = newConn(int(fd), l.localAddr, &remote)
 }
 
 // Implements the net.Listener interface.
 func (l *listener) Close() error {
-	return syscall.Close(l.socket)
+	return l.fd.Close()
 }
 
 // Implements the net.Listener interface.