blob: 10630ea2249bb9fc05216b79d234b94af9d30d80 [file] [log] [blame] [edit]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unixfd
import (
"bytes"
"io"
"net"
"reflect"
"testing"
)
// nothing is an empty payload type; the tests in this file send it over
// channels purely as a synchronization signal.
type nothing struct{}
// dial takes the address released by fd and opens a connection to it via
// unixFDConn on the package's Network. The released address is returned
// alongside the connection so callers can compare it with what the
// connection itself reports.
func dial(fd *fileDescriptor) (conn net.Conn, addr net.Addr, err error) {
	addr = fd.releaseAddr()
	conn, err = unixFDConn(Network, addr.String(), 0)
	return conn, addr, err
}
// listen takes the address released by fd and creates a listener for it via
// unixFDListen on the package's Network. The released address is returned
// together with the listener.
func listen(fd *fileDescriptor) (ln net.Listener, addr net.Addr, err error) {
	addr = fd.releaseAddr()
	ln, err = unixFDListen(Network, addr.String())
	return ln, addr, err
}
// testWrite sends data over c with a single Write call. It records a test
// error (without aborting the test) if the write fails or if fewer bytes
// than len(data) were written.
func testWrite(t *testing.T, c net.Conn, data string) {
	payload := []byte(data)
	written, err := c.Write(payload)
	switch {
	case err != nil:
		t.Errorf("Write: %v", err)
	case written != len(data):
		t.Errorf("Wrote %d bytes, expected %d", written, len(data))
	}
}
// testRead performs a single Read on c and records a test error (without
// aborting the test) if the read fails or if the bytes received differ from
// expected.
func testRead(t *testing.T, c net.Conn, expected string) {
	want := []byte(expected)
	// The buffer is two bytes larger than the expected message so that a
	// longer-than-expected read is visible as a length mismatch.
	scratch := make([]byte, len(want)+2)
	count, err := c.Read(scratch)
	if err != nil {
		t.Errorf("Read: %v", err)
		return
	}
	got := scratch[:count]
	if count == len(want) && bytes.Equal(got, want) {
		return
	}
	t.Errorf("got %q, expected %q", got, expected)
}
// TestDial dials both ends of a socketpair, checks that data flows in both
// directions between the two resulting connections, and checks that each
// connection reports the address it was dialed with as both its local and
// its remote address.
func TestDial(t *testing.T) {
	local, remote, err := socketpair()
	if err != nil {
		t.Fatalf("socketpair: %v", err)
	}
	a, aAddr, err := dial(local)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	b, bAddr, err := dial(remote)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}

	testWrite(t, a, "TEST1")
	testRead(t, b, "TEST1")
	testWrite(t, b, "TEST2")
	testRead(t, a, "TEST2")

	if got := a.LocalAddr(); !reflect.DeepEqual(got, aAddr) {
		t.Errorf("Invalid address %v, expected %v", got, aAddr)
	}
	if got := a.RemoteAddr(); !reflect.DeepEqual(got, aAddr) {
		t.Errorf("Invalid address %v, expected %v", got, aAddr)
	}
	// Report b's own addresses on failure, so that the message shows the
	// value that was actually compared.
	if got := b.LocalAddr(); !reflect.DeepEqual(got, bAddr) {
		t.Errorf("Invalid address %v, expected %v", got, bAddr)
	}
	if got := b.RemoteAddr(); !reflect.DeepEqual(got, bAddr) {
		t.Errorf("Invalid address %v, expected %v", got, bAddr)
	}
}
// TestListen dials one end of a socketpair and listens on the other. It
// checks that the first Accept yields a usable connection, that an Accept
// blocked in another goroutine returns io.EOF once the listener is closed,
// that Accept fails after Close, and that a second Close reports an error.
func TestListen(t *testing.T) {
	local, remote, err := socketpair()
	if err != nil {
		t.Fatalf("socketpair: %v", err)
	}
	a, _, err := dial(local)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	l, _, err := listen(remote)
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	b, err := l.Accept()
	if err != nil {
		t.Fatalf("accept: %v", err)
	}

	start := make(chan nothing)
	done := make(chan nothing)
	go func() {
		defer close(done)
		<-start
		// Use Errorf rather than Fatalf: FailNow may only be called from
		// the goroutine running the test function.
		if _, err := l.Accept(); err != io.EOF {
			t.Errorf("accept: expected EOF, got %v", err)
		}
	}()
	// block until the goroutine starts running
	start <- nothing{}

	testWrite(t, a, "LISTEN")
	testRead(t, b, "LISTEN")

	if err := l.Close(); err != nil {
		t.Fatalf("close: %v", err)
	}
	// Wait for the background Accept to observe the close.
	<-done

	// After closed, accept should fail immediately
	if _, err := l.Accept(); err == nil {
		t.Fatalf("Accept succeeded after close")
	}
	if err := l.Close(); err == nil {
		t.Fatalf("Close succeeded twice")
	}
}
// TestSendConnection passes a connection across a Socketpair: one side calls
// SendConnection with the payload "hello" while a goroutine on the other
// side calls ReadConnection. It then dials the two addresses returned by
// those calls and checks that data flows in both directions between them.
func TestSendConnection(t *testing.T) {
	server, client, err := Socketpair()
	if err != nil {
		t.Fatalf("Socketpair: %v", err)
	}
	// net.FileConn works on a duplicate of the descriptor, so the original
	// file is still ours to close.
	defer client.Close()

	uclient, err := net.FileConn(client)
	if err != nil {
		t.Fatalf("FileConn: %v", err)
	}
	defer uclient.Close()

	// Check the assertion so an unexpected type fails the test instead of
	// panicking.
	unixClient, ok := uclient.(*net.UnixConn)
	if !ok {
		t.Fatalf("FileConn returned %T, expected *net.UnixConn", uclient)
	}

	// These are written by the reader goroutine and only read after done
	// has been closed.
	var (
		readErr error
		n       int
		saddr   net.Addr
	)
	done := make(chan struct{})
	buf := make([]byte, 10)
	go func() {
		defer close(done)
		var ack func()
		saddr, n, ack, readErr = ReadConnection(server, buf)
		if ack != nil {
			ack()
		}
	}()

	caddr, err := SendConnection(unixClient, []byte("hello"))
	if err != nil {
		t.Fatalf("SendConnection: %v", err)
	}
	<-done

	if readErr != nil {
		t.Fatalf("ReadConnection: %v", readErr)
	}
	if saddr == nil {
		t.Fatalf("ReadConnection returned nil, %d", n)
	}
	data := buf[0:n]
	if !bytes.Equal([]byte("hello"), data) {
		t.Fatalf("unexpected data %q", data)
	}

	a, err := unixFDConn(Network, caddr.String(), 0)
	if err != nil {
		t.Fatalf("dial %v: %v", caddr, err)
	}
	b, err := unixFDConn(Network, saddr.String(), 0)
	if err != nil {
		t.Fatalf("dial %v: %v", saddr, err)
	}

	testWrite(t, a, "TEST1")
	testRead(t, b, "TEST1")
	testWrite(t, b, "TEST2")
	testRead(t, a, "TEST2")
}