veyron/runtimes/google/ipc: When listening on tcp, also allow websocket
requests to the same port.
Change-Id: I719f583f52920b7238118ed28e7d576166952248
diff --git a/lib/modules/core/core_test.go b/lib/modules/core/core_test.go
index d8a2a1d..1cc69e2 100644
--- a/lib/modules/core/core_test.go
+++ b/lib/modules/core/core_test.go
@@ -6,6 +6,7 @@
"reflect"
"sort"
"strconv"
+ "strings"
"testing"
"time"
@@ -242,6 +243,7 @@
srvSession.ExpectVar("NAME")
addr := srvSession.ExpectVar("ADDR")
addr = naming.JoinAddressName(addr, "")
+ wsAddr := strings.Replace(addr, "@tcp@", "@ws@", 1)
// Resolve an object
resolver, err := sh.Start(core.ResolveCommand, nil, rootName+"/"+echoName)
@@ -249,11 +251,11 @@
t.Fatalf("unexpected error: %s", err)
}
resolverSession := expect.NewSession(t, resolver.Stdout(), time.Minute)
- if got, want := resolverSession.ExpectVar("RN"), "1"; got != want {
+ if got, want := resolverSession.ExpectVar("RN"), "2"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
- if got, want := resolverSession.ExpectVar("R0"), addr; got != want {
- t.Errorf("got %v, want %v", got, want)
+ if got, want := resolverSession.ExpectVar("R0"), addr; got != want && got != wsAddr {
+ t.Errorf("got %v, want either %v or %v", got, want, wsAddr)
}
if err = resolver.Shutdown(nil, os.Stderr); err != nil {
t.Fatalf("unexpected error: %s", err)
@@ -261,16 +263,17 @@
// Resolve to a mount table using a rooted name.
addr = naming.JoinAddressName(mountAddrs[mtName], "echo")
+ wsAddr = strings.Replace(addr, "@tcp@", "@ws@", 1)
resolver, err = sh.Start(core.ResolveMTCommand, nil, rootName+"/"+echoName)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
resolverSession = expect.NewSession(t, resolver.Stdout(), time.Minute)
- if got, want := resolverSession.ExpectVar("RN"), "1"; got != want {
+ if got, want := resolverSession.ExpectVar("RN"), "2"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
- if got, want := resolverSession.ExpectVar("R0"), addr; got != want {
- t.Fatalf("got %v, want %v", got, want)
+ if got, want := resolverSession.ExpectVar("R0"), addr; got != want && got != wsAddr {
+ t.Fatalf("got %v, want either %v or %v", got, want, wsAddr)
}
if err := resolver.Shutdown(nil, os.Stderr); err != nil {
t.Fatalf("unexpected error: %s", err)
@@ -290,11 +293,11 @@
t.Fatalf("unexpected error: %s", err)
}
resolverSession = expect.NewSession(t, resolver.Stdout(), time.Minute)
- if got, want := resolverSession.ExpectVar("RN"), "1"; got != want {
+ if got, want := resolverSession.ExpectVar("RN"), "2"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
- if got, want := resolverSession.ExpectVar("R0"), addr; got != want {
- t.Fatalf("got %v, want %v", got, want)
+ if got, want := resolverSession.ExpectVar("R0"), addr; got != want && got != wsAddr {
+ t.Fatalf("got %v, want either %v or %v", got, want, wsAddr)
}
if err := resolver.Shutdown(nil, os.Stderr); err != nil {
t.Fatalf("unexpected error: %s", err)
diff --git a/lib/websocket/conn.go b/lib/websocket/conn.go
new file mode 100644
index 0000000..574b992
--- /dev/null
+++ b/lib/websocket/conn.go
@@ -0,0 +1,125 @@
+// +build !nacl
+package websocket
+
+import (
+ "fmt"
+ "github.com/gorilla/websocket"
+ "io"
+ "net"
+ "sync"
+
+ "time"
+)
+
+// WebsocketConn provides a net.Conn interface for a websocket connection.
+func WebsocketConn(ws *websocket.Conn) net.Conn {
+ return &wrappedConn{ws: ws}
+}
+
+// wrappedConn provides a net.Conn interface to a websocket.
+// The underlying websocket connection needs regular calls to Read to make sure
+// websocket control messages (such as pings) are processed by the websocket
+// library.
+type wrappedConn struct {
+ ws *websocket.Conn
+ currReader io.Reader
+
+ // The gorilla docs aren't explicit about reading and writing from
+ // different goroutines. It is explicit that only one goroutine can
+ // do a write at any given time and only one goroutine can do a read
+ // at any given time. Based on inspection it seems that using a reader
+ // and writer simultaneously is safe, but this might change with
+ // future changes. We can't actually share the lock, because this means
+ // that we can't write while we are waiting for a message, causing some
+ // deadlocks where a write is need to unblock a read.
+ writeLock sync.Mutex
+ readLock sync.Mutex
+}
+
+func (c *wrappedConn) readFromCurrReader(b []byte) (int, error) {
+ n, err := c.currReader.Read(b)
+ if err == io.EOF {
+ err = nil
+ c.currReader = nil
+ }
+ return n, err
+
+}
+
+func (c *wrappedConn) Read(b []byte) (int, error) {
+ c.readLock.Lock()
+ defer c.readLock.Unlock()
+ var n int
+ var err error
+
+ // TODO(bjornick): It would be nice to be able to read multiple messages at
+ // a time in case the first message is not big enough to fill b and another
+ // message is ready.
+ // Loop until we either get data or an error. This exists
+ // mostly to avoid return 0, nil.
+ for n == 0 && err == nil {
+ if c.currReader == nil {
+ t, r, err := c.ws.NextReader()
+
+ if t != websocket.BinaryMessage {
+ return 0, fmt.Errorf("Unexpected message type %d", t)
+ }
+ if err != nil {
+ return 0, err
+ }
+ c.currReader = r
+ }
+ n, err = c.readFromCurrReader(b)
+ }
+ return n, err
+}
+
+func (c *wrappedConn) Write(b []byte) (int, error) {
+ c.writeLock.Lock()
+ defer c.writeLock.Unlock()
+ if err := c.ws.WriteMessage(websocket.BinaryMessage, b); err != nil {
+ return 0, err
+ }
+ return len(b), nil
+}
+
+func (c *wrappedConn) Close() error {
+ c.writeLock.Lock()
+ defer c.writeLock.Unlock()
+ return c.ws.Close()
+}
+
+func (c *wrappedConn) LocalAddr() net.Addr {
+ return websocketAddr{s: c.ws.LocalAddr().String()}
+}
+
+func (c *wrappedConn) RemoteAddr() net.Addr {
+ return websocketAddr{s: c.ws.RemoteAddr().String()}
+}
+
+func (c *wrappedConn) SetDeadline(t time.Time) error {
+ if err := c.SetReadDeadline(t); err != nil {
+ return err
+ }
+ return c.SetWriteDeadline(t)
+}
+
+func (c *wrappedConn) SetReadDeadline(t time.Time) error {
+ return c.ws.SetReadDeadline(t)
+}
+
+func (c *wrappedConn) SetWriteDeadline(t time.Time) error {
+ return c.ws.SetWriteDeadline(t)
+}
+
+type websocketAddr struct {
+ s string
+}
+
+func (websocketAddr) Network() string {
+ return "ws"
+}
+
+func (w websocketAddr) String() string {
+ return w.s
+}
diff --git a/lib/websocket/conn_test.go b/lib/websocket/conn_test.go
new file mode 100644
index 0000000..db38cbe
--- /dev/null
+++ b/lib/websocket/conn_test.go
@@ -0,0 +1,111 @@
+// +build !nacl
+package websocket
+
+import (
+ "bytes"
+ "github.com/gorilla/websocket"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+ "time"
+)
+
+func writer(c net.Conn, data []byte, times int, wg *sync.WaitGroup) {
+ defer wg.Done()
+ b := []byte{byte(len(data))}
+ b = append(b, data...)
+ for i := 0; i < times; i++ {
+ c.Write(b)
+ }
+}
+
+func readMessage(c net.Conn) ([]byte, error) {
+ var length [1]byte
+ // Read the size
+ for {
+ n, err := c.Read(length[:])
+ if err != nil {
+ return nil, err
+ }
+ if n == 1 {
+ break
+ }
+ }
+ size := int(length[0])
+ buf := make([]byte, size)
+ n := 0
+ for n < size {
+ nn, err := c.Read(buf[n:])
+ if err != nil {
+ return buf, err
+ }
+ n += nn
+ }
+
+ return buf, nil
+}
+
+func reader(t *testing.T, c net.Conn, expected []byte, totalWrites int) {
+ totalReads := 0
+ for buf, err := readMessage(c); err == nil; buf, err = readMessage(c) {
+ totalReads++
+ if !bytes.Equal(buf, expected) {
+ t.Errorf("Unexpected message %v, expected %v", buf, expected)
+ }
+ }
+ if totalReads != totalWrites {
+ t.Errorf("wrong number of messages expected %v, got %v", totalWrites, totalReads)
+ }
+}
+
+func TestMultipleGoRoutines(t *testing.T) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Failed to listen: %v", err)
+ }
+ addr := l.Addr()
+ input := []byte("no races here")
+ const numWriters int = 12
+ const numWritesPerWriter int = 1000
+ const totalWrites int = numWriters * numWritesPerWriter
+ s := &http.Server{
+ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "GET" {
+ http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
+ return
+ }
+ ws, err := websocket.Upgrade(w, r, nil, 1024, 1024)
+ if _, ok := err.(websocket.HandshakeError); ok {
+ http.Error(w, "Not a websocket handshake", 400)
+ return
+ } else if err != nil {
+ http.Error(w, "Internal Error", 500)
+ return
+ }
+ reader(t, WebsocketConn(ws), input, totalWrites)
+ }),
+ }
+ // Dial out in another go routine
+ go func() {
+ conn, err := Dial(addr.String())
+ numTries := 0
+ for err != nil && numTries < 5 {
+ numTries++
+ time.Sleep(time.Second)
+ }
+
+ if err != nil {
+ t.Fatalf("failed to connect to server: %v", err)
+ }
+ var writers sync.WaitGroup
+ writers.Add(numWriters)
+ for i := 0; i < numWriters; i++ {
+ go writer(conn, input, numWritesPerWriter, &writers)
+ }
+ writers.Wait()
+ conn.Close()
+ l.Close()
+ }()
+ s.Serve(l)
+}
diff --git a/lib/websocket/dialer.go b/lib/websocket/dialer.go
new file mode 100644
index 0000000..8b12662
--- /dev/null
+++ b/lib/websocket/dialer.go
@@ -0,0 +1,27 @@
+// +build !nacl
+package websocket
+
+import (
+ "github.com/gorilla/websocket"
+ "net"
+ "net/http"
+ "net/url"
+)
+
+func Dial(address string) (net.Conn, error) {
+ conn, err := net.Dial("tcp", address)
+ if err != nil {
+ return nil, err
+ }
+ u, err := url.Parse("ws://" + address)
+
+ if err != nil {
+ return nil, err
+ }
+ ws, _, err := websocket.NewClient(conn, u, http.Header{}, 4096, 4096)
+ if err != nil {
+ return nil, err
+ }
+
+ return WebsocketConn(ws), nil
+}
diff --git a/runtimes/google/ipc/client_test.go b/runtimes/google/ipc/client_test.go
index d5be886..c77320f 100644
--- a/runtimes/google/ipc/client_test.go
+++ b/runtimes/google/ipc/client_test.go
@@ -101,8 +101,8 @@
}
}
- // Verify that there are 101 entries for echoServer in the mount table.
- if got, want := numServers(t, sh, "echoServer"), "101"; got != want {
+ // Verify that there are 102 entries for echoServer in the mount table.
+ if got, want := numServers(t, sh, "echoServer"), "102"; got != want {
vlog.Fatalf("got: %q, want: %q", got, want)
}
@@ -119,7 +119,7 @@
// TODO(cnicolaou,p): figure out why the real entry isn't removed
// from the mount table.
// Verify that there are 100 entries for echoServer in the mount table.
- if got, want := numServers(t, sh, "echoServer"), "101"; got != want {
+ if got, want := numServers(t, sh, "echoServer"), "102"; got != want {
vlog.Fatalf("got: %q, want: %q", got, want)
}
}
diff --git a/runtimes/google/ipc/full_test.go b/runtimes/google/ipc/full_test.go
index 67e14b9..857b2e2 100644
--- a/runtimes/google/ipc/full_test.go
+++ b/runtimes/google/ipc/full_test.go
@@ -26,6 +26,7 @@
"veyron.io/veyron/veyron/lib/netstate"
"veyron.io/veyron/veyron/lib/testutil"
tsecurity "veyron.io/veyron/veyron/lib/testutil/security"
+ "veyron.io/veyron/veyron/lib/websocket"
imanager "veyron.io/veyron/veyron/runtimes/google/ipc/stream/manager"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vc"
"veyron.io/veyron/veyron/runtimes/google/ipc/version"
@@ -35,6 +36,10 @@
vsecurity "veyron.io/veyron/veyron/security"
)
+func init() {
+ stream.RegisterProtocol("ws", websocket.Dial, nil)
+}
+
var (
errMethod = verror.Abortedf("server returned an error")
clock = new(fakeClock)
@@ -237,6 +242,20 @@
vlog.VI(1).Info("server.Stop DONE")
}
+func resolveWSEndpoint(ns naming.Namespace, name string) (string, error) {
+ // Find the ws endpoint and use that.
+ servers, err := ns.Resolve(testContext(), name)
+ if err != nil {
+ return "", err
+ }
+ for _, s := range servers {
+ if strings.Index(s, "@ws@") != -1 {
+ return s, nil
+ }
+ }
+ return "", fmt.Errorf("No ws endpoint found %v", servers)
+}
+
type bundle struct {
client ipc.Client
server ipc.Server
@@ -440,20 +459,40 @@
}
}
client.Close()
+
}
}
+type websocketMode bool
+type closeSendMode bool
+
+const (
+ useWebsocket websocketMode = true
+ noWebsocket websocketMode = false
+
+ closeSend closeSendMode = true
+ noCloseSend closeSendMode = false
+)
+
func TestRPC(t *testing.T) {
- testRPC(t, true)
+ testRPC(t, closeSend, noWebsocket)
+}
+
+func TestRPCWithWebsocket(t *testing.T) {
+ testRPC(t, closeSend, useWebsocket)
}
// TestCloseSendOnFinish tests that Finish informs the server that no more
// inputs will be sent by the client if CloseSend has not already done so.
func TestRPCCloseSendOnFinish(t *testing.T) {
- testRPC(t, false)
+ testRPC(t, noCloseSend, noWebsocket)
}
-func testRPC(t *testing.T, shouldCloseSend bool) {
+func TestRPCCloseSendOnFinishWithWebsocket(t *testing.T) {
+ testRPC(t, noCloseSend, useWebsocket)
+}
+
+func testRPC(t *testing.T, shouldCloseSend closeSendMode, shouldUseWebsocket websocketMode) {
type v []interface{}
type testcase struct {
name string
@@ -494,7 +533,16 @@
pserver.AddToRoots(pclient.BlessingStore().Default())
for _, test := range tests {
vlog.VI(1).Infof("%s client.StartCall", name(test))
- call, err := b.client.StartCall(testContext(), test.name, test.method, test.args)
+ vname := test.name
+ if shouldUseWebsocket {
+ var err error
+ vname, err = resolveWSEndpoint(b.ns, vname)
+ if err != nil && err != test.startErr {
+ t.Errorf(`%s ns.Resolve got error "%v", want "%v"`, name(test), err, test.startErr)
+ continue
+ }
+ }
+ call, err := b.client.StartCall(testContext(), vname, test.method, test.args)
if err != test.startErr {
t.Errorf(`%s client.StartCall got error "%v", want "%v"`, name(test), err, test.startErr)
continue
diff --git a/runtimes/google/ipc/server.go b/runtimes/google/ipc/server.go
index 75447aa..00e0688 100644
--- a/runtimes/google/ipc/server.go
+++ b/runtimes/google/ipc/server.go
@@ -283,6 +283,11 @@
s.active.Done()
}()
s.publisher.AddServer(s.publishEP(iep, s.servesMountTable), s.servesMountTable)
+ if strings.HasPrefix(iep.Protocol, "tcp") {
+ epCopy := *iep
+ epCopy.Protocol = "ws"
+ s.publisher.AddServer(s.publishEP(&epCopy, s.servesMountTable), s.servesMountTable)
+ }
}
if len(listenSpec.Proxy) > 0 {
@@ -321,6 +326,13 @@
s.listeners[ln] = nil
s.Unlock()
s.publisher.AddServer(s.publishEP(iep, s.servesMountTable), s.servesMountTable)
+
+ if strings.HasPrefix(iep.Protocol, "tcp") {
+ epCopy := *iep
+ epCopy.Protocol = "ws"
+ s.publisher.AddServer(s.publishEP(&epCopy, s.servesMountTable), s.servesMountTable)
+ }
+
return iep, ln, nil
}
@@ -351,6 +363,11 @@
// The listener is done, so:
// (1) Unpublish its name
s.publisher.RemoveServer(s.publishEP(iep, s.servesMountTable))
+ if strings.HasPrefix(iep.Protocol, "tcp") {
+ iepCopy := *iep
+ iepCopy.Protocol = "ws"
+ s.publisher.RemoveServer(s.publishEP(&iepCopy, s.servesMountTable))
+ }
}
s.Lock()
diff --git a/runtimes/google/ipc/server_test.go b/runtimes/google/ipc/server_test.go
index 1bb853a..a918f00 100644
--- a/runtimes/google/ipc/server_test.go
+++ b/runtimes/google/ipc/server_test.go
@@ -130,6 +130,10 @@
return addr
}
+func addWSName(name string) []string {
+ return []string{name, strings.Replace(name, "@tcp@", "@ws@", 1)}
+}
+
func testProxy(t *testing.T, spec ipc.ListenSpec) {
sm := imanager.InternalNew(naming.FixedRoutingID(0x555555555))
ns := tnaming.NewSimpleNamespace()
@@ -201,9 +205,9 @@
t.Fatalf("unexpected error: %s", err)
}
proxiedEP.RID = naming.FixedRoutingID(0x555555555)
- expectedEndpoints := []string{proxiedEP.String()}
+ expectedEndpoints := addWSName(proxiedEP.String())
if hasLocalListener {
- expectedEndpoints = append(expectedEndpoints, ep.String())
+ expectedEndpoints = append(expectedEndpoints, addWSName(ep.String())...)
}
// Proxy connetions are created asynchronously, so we wait for the
@@ -228,8 +232,12 @@
if hasLocalListener {
// Listen will publish both the local and proxied endpoint with the
// mount table, given that we're trying to test the proxy, we remove
- // the local endpoint from the mount table entry!
- ns.Unmount(testContext(), "mountpoint/server", naming.JoinAddressName(ep.String(), ""))
+ // the local endpoint from the mount table entry! We have to remove both
+ // the tcp and the websocket address.
+ sep := ep.String()
+ wsep := strings.Replace(sep, "@tcp@", "@ws@", 1)
+ ns.Unmount(testContext(), "mountpoint/server", naming.JoinAddressName(sep, ""))
+ ns.Unmount(testContext(), "mountpoint/server", naming.JoinAddressName(wsep, ""))
}
// Proxied endpoint should be published and RPC should succeed (through proxy)
diff --git a/runtimes/google/ipc/stream/manager/manager.go b/runtimes/google/ipc/stream/manager/manager.go
index 8620a61..55e06ee 100644
--- a/runtimes/google/ipc/stream/manager/manager.go
+++ b/runtimes/google/ipc/stream/manager/manager.go
@@ -12,6 +12,7 @@
"veyron.io/veyron/veyron/lib/stats"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/crypto"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vif"
+ "veyron.io/veyron/veyron/runtimes/google/ipc/stream/wslistener"
"veyron.io/veyron/veyron/runtimes/google/ipc/version"
inaming "veyron.io/veyron/veyron/runtimes/google/naming"
@@ -65,7 +66,20 @@
if d, _ := stream.RegisteredProtocol(network); d != nil {
return d(address)
}
- return net.DialTimeout(network, address, timeout)
+ conn, err := net.DialTimeout(network, address, timeout)
+ if err != nil || !strings.HasPrefix(network, "tcp") {
+ return conn, err
+ }
+
+ // For tcp connections we add an extra magic byte so we can differentiate between
+ // raw tcp and websocket on the same port.
+ switch n, err := conn.Write([]byte{wslistener.BinaryMagicByte}); {
+ case err != nil:
+ return nil, err
+ case n != 1:
+ return nil, fmt.Errorf("Unable to write the magic byte")
+ }
+ return conn, nil
}
// FindOrDialVIF returns the network connection (VIF) to the provided address
@@ -187,6 +201,12 @@
closeNetListener(netln)
return nil, nil, errShutDown
}
+
+ // If the protocol is tcp, we add the listener that supports both tcp and websocket
+ // so that javascript can talk to this server.
+ if strings.HasPrefix(protocol, "tcp") {
+ netln = wslistener.NewListener(netln)
+ }
ln := newNetListener(m, netln, opts)
m.listeners[ln] = true
m.muListeners.Unlock()
diff --git a/runtimes/google/ipc/stream/manager/manager_test.go b/runtimes/google/ipc/stream/manager/manager_test.go
index ccddf69..19dd942 100644
--- a/runtimes/google/ipc/stream/manager/manager_test.go
+++ b/runtimes/google/ipc/stream/manager/manager_test.go
@@ -12,6 +12,7 @@
"testing"
"time"
+ "veyron.io/veyron/veyron/lib/websocket"
"veyron.io/veyron/veyron2/ipc/stream"
"veyron.io/veyron/veyron2/naming"
"veyron.io/veyron/veyron2/security"
@@ -38,9 +39,10 @@
// introduces less variance in the behavior of the test.
runtime.GOMAXPROCS(1)
modules.RegisterChild("runServer", "", runServer)
+ stream.RegisterProtocol("ws", websocket.Dial, nil)
}
-func TestSimpleFlow(t *testing.T) {
+func testSimpleFlow(t *testing.T, useWebsocket bool) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -49,6 +51,10 @@
t.Fatal(err)
}
+ if useWebsocket {
+ ep.(*inaming.Endpoint).Protocol = "ws"
+ }
+
data := "the dark knight rises"
var clientVC stream.VC
var clientF1 stream.Flow
@@ -124,6 +130,14 @@
}
}
+func TestSimpleFlow(t *testing.T) {
+ testSimpleFlow(t, false)
+}
+
+func TestSimpleFlowWS(t *testing.T) {
+ testSimpleFlow(t, true)
+}
+
func TestConnectionTimeout(t *testing.T) {
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -145,7 +159,7 @@
}
}
-func TestAuthenticatedByDefault(t *testing.T) {
+func testAuthenticatedByDefault(t *testing.T, useWebsocket bool) {
var (
server = InternalNew(naming.FixedRoutingID(0x55555555))
client = InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -161,6 +175,9 @@
if err != nil {
t.Fatal(err)
}
+ if useWebsocket {
+ ep.(*inaming.Endpoint).Protocol = "ws"
+ }
errs := make(chan error)
@@ -209,6 +226,14 @@
}
}
+func TestAuthenticatedByDefault(t *testing.T) {
+ testAuthenticatedByDefault(t, false)
+}
+
+func TestAuthenticatedByDefaultWS(t *testing.T) {
+ testAuthenticatedByDefault(t, true)
+}
+
func numListeners(m stream.Manager) int { return len(m.(*manager).listeners) }
func debugString(m stream.Manager) string { return m.(*manager).DebugString() }
func numVIFs(m stream.Manager) int { return len(m.(*manager).vifs.List()) }
@@ -271,6 +296,29 @@
}
}
+func TestCloseListenerWS(t *testing.T) {
+ server := InternalNew(naming.FixedRoutingID(0x5e97e9))
+
+ ln, ep, err := server.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ep.(*inaming.Endpoint).Protocol = "ws"
+
+ // Server will just listen for flows and close them.
+ go acceptLoop(ln)
+ client := InternalNew(naming.FixedRoutingID(0xc1e41))
+ if _, err = client.Dial(ep); err != nil {
+ t.Fatal(err)
+ }
+ ln.Close()
+ client = InternalNew(naming.FixedRoutingID(0xc1e42))
+ if _, err := client.Dial(ep); err == nil {
+ t.Errorf("client.Dial(%q) should have failed", ep)
+ }
+}
+
func TestShutdown(t *testing.T) {
server := InternalNew(naming.FixedRoutingID(0x5e97e9))
ln, _, err := server.Listen("tcp", "127.0.0.1:0")
@@ -316,6 +364,33 @@
}
}
+func TestShutdownEndpointWS(t *testing.T) {
+ server := InternalNew(naming.FixedRoutingID(0x55555555))
+ client := InternalNew(naming.FixedRoutingID(0xcccccccc))
+
+ ln, ep, err := server.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ep.(*inaming.Endpoint).Protocol = "ws"
+
+ // Server will just listen for flows and close them.
+ go acceptLoop(ln)
+
+ vc, err := client.Dial(ep)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if f, err := vc.Connect(); f == nil || err != nil {
+ t.Errorf("vc.Connect failed: (%v, %v)", f, err)
+ }
+ client.ShutdownEndpoint(ep)
+ if f, err := vc.Connect(); f != nil || err == nil {
+ t.Errorf("vc.Connect unexpectedly succeeded: (%v, %v)", f, err)
+ }
+}
+
/* TLS + resumption + channel bindings is broken: <https://secure-resumption.com/#channelbindings>.
func TestSessionTicketCache(t *testing.T) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
@@ -335,7 +410,7 @@
}
*/
-func TestMultipleVCs(t *testing.T) {
+func testMultipleVCs(t *testing.T, useWebsocket bool) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -348,6 +423,10 @@
if err != nil {
t.Fatal(err)
}
+ if useWebsocket {
+ ep.(*inaming.Endpoint).Protocol = "ws"
+ }
+
read := func(flow stream.Flow, c chan string) {
var buf bytes.Buffer
var tmp [1024]byte
@@ -414,6 +493,14 @@
}
}
+func TestMultipleVCs(t *testing.T) {
+ testMultipleVCs(t, false)
+}
+
+func TestMultipleVCsWS(t *testing.T) {
+ testMultipleVCs(t, true)
+}
+
func TestAddressResolution(t *testing.T) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -494,6 +581,46 @@
}
}
+func TestServerRestartDuringClientLifetimeWS(t *testing.T) {
+ client := InternalNew(naming.FixedRoutingID(0xcccccccc))
+ sh := modules.NewShell(".*")
+ defer sh.Cleanup(nil, nil)
+ h, err := sh.Start("runServer", nil, "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ s := expect.NewSession(t, h.Stdout(), time.Minute)
+ addr := s.ReadLine()
+
+ ep, err := inaming.NewEndpoint(addr)
+ if err != nil {
+ t.Fatalf("inaming.NewEndpoint(%q): %v", addr, err)
+ }
+ ep.Protocol = "ws"
+ if _, err := client.Dial(ep); err != nil {
+ t.Fatal(err)
+ }
+ h.Shutdown(nil, os.Stderr)
+
+ // A new VC cannot be created since the server is dead
+ if _, err := client.Dial(ep); err == nil {
+ t.Fatal("Expected client.Dial to fail since server is dead")
+ }
+
+ h, err = sh.Start("runServer", nil, addr)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ s = expect.NewSession(t, h.Stdout(), time.Minute)
+ // Restarting the server, listening on the same address as before
+ if addr2 := s.ReadLine(); addr2 != addr || err != nil {
+ t.Fatalf("Got (%q, %v) want (%q, nil)", addr2, err, addr)
+ }
+ if _, err := client.Dial(ep); err != nil {
+ t.Fatal(err)
+ }
+}
+
// Needed by modules framework
func TestHelperProcess(t *testing.T) {
modules.DispatchInTest()
diff --git a/runtimes/google/ipc/stream/proxy/proxy.go b/runtimes/google/ipc/stream/proxy/proxy.go
index a0d794e..0da73ce 100644
--- a/runtimes/google/ipc/stream/proxy/proxy.go
+++ b/runtimes/google/ipc/stream/proxy/proxy.go
@@ -11,6 +11,7 @@
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/message"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vc"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vif"
+ "veyron.io/veyron/veyron/runtimes/google/ipc/stream/wslistener"
"veyron.io/veyron/veyron/runtimes/google/ipc/version"
"veyron.io/veyron/veyron/runtimes/google/lib/bqueue"
"veyron.io/veyron/veyron/runtimes/google/lib/bqueue/drrqueue"
@@ -135,6 +136,7 @@
if err != nil {
return nil, fmt.Errorf("net.Listen(%q, %q) failed: %v", network, address, err)
}
+ ln = wslistener.NewListener(ln)
if len(pubAddress) == 0 {
pubAddress = ln.Addr().String()
}
diff --git a/runtimes/google/ipc/stream/vif/vif.go b/runtimes/google/ipc/stream/vif/vif.go
index b69cdb4..355ea0b 100644
--- a/runtimes/google/ipc/stream/vif/vif.go
+++ b/runtimes/google/ipc/stream/vif/vif.go
@@ -472,7 +472,7 @@
for {
f, err := hr.Listener.Accept()
if err != nil {
- vlog.VI(2).Infof("Accept failed on VC %v on VIF %v", vc, vif)
+ vlog.VI(2).Infof("Accept failed on VC %v on VIF %v: %v", vc, vif, err)
return
}
if err := acceptor.Put(ConnectorAndFlow{vc, f}); err != nil {
diff --git a/runtimes/google/ipc/stream/wslistener/listener.go b/runtimes/google/ipc/stream/wslistener/listener.go
new file mode 100644
index 0000000..dfa2bc0
--- /dev/null
+++ b/runtimes/google/ipc/stream/wslistener/listener.go
@@ -0,0 +1,212 @@
+package wslistener
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "sync"
+
+ vwebsocket "veyron.io/veyron/veyron/lib/websocket"
+ "veyron.io/veyron/veyron/runtimes/google/lib/upcqueue"
+
+ "veyron.io/veyron/veyron2/vlog"
+
+ "github.com/gorilla/websocket"
+)
+
+var errListenerIsClosed = errors.New("Listener has been Closed")
+
+// We picked 0xFF because it's obviously outside the range of ASCII,
+// and is completely unused in UTF-8.
+const BinaryMagicByte byte = 0xFF
+
+const bufferSize int = 4096
+
+// A listener that is able to handle either raw tcp request or websocket requests.
+// The result of Accept is is a net.Conn interface.
+type wsTCPListener struct {
+ // The queue of net.Conn to be returned by Accept.
+ q *upcqueue.T
+
+ // The queue for the http listener when we detect an http request.
+ httpQ *upcqueue.T
+
+ // The underlying listener.
+ netLn net.Listener
+ wsServer http.Server
+
+ netLoop sync.WaitGroup
+ wsLoop sync.WaitGroup
+}
+
+// bufferedConn is used to allow us to Peek at the first byte to see if it
+// is the magic byte used by veyron tcp requests. Other than that it behaves
+// like a normal net.Conn.
+type bufferedConn struct {
+ net.Conn
+ // TODO(bjornick): Remove this buffering because we have way too much
+ // buffering anyway. We really only need to buffer the first byte.
+ r *bufio.Reader
+}
+
+func newBufferedConn(c net.Conn) bufferedConn {
+ return bufferedConn{Conn: c, r: bufio.NewReaderSize(c, bufferSize)}
+}
+
+func (c *bufferedConn) Peek(n int) ([]byte, error) {
+ return c.r.Peek(n)
+}
+
+func (c *bufferedConn) Read(p []byte) (int, error) {
+ return c.r.Read(p)
+}
+
+// queueListener is a listener that returns connections that are in q.
+type queueListener struct {
+ q *upcqueue.T
+ // ln is needed to implement Close and Addr
+ ln net.Listener
+}
+
+func (l *queueListener) Accept() (net.Conn, error) {
+ item, err := l.q.Get(nil)
+ switch {
+ case err == upcqueue.ErrQueueIsClosed:
+ return nil, errListenerIsClosed
+ case err != nil:
+ return nil, fmt.Errorf("Accept failed: %v", err)
+ default:
+ return item.(net.Conn), nil
+ }
+}
+
+func (l *queueListener) Close() error {
+ l.q.Shutdown()
+ return l.ln.Close()
+}
+
+func (l *queueListener) Addr() net.Addr {
+ return l.ln.Addr()
+}
+
+func NewListener(netLn net.Listener) net.Listener {
+ ln := &wsTCPListener{
+ q: upcqueue.New(),
+ httpQ: upcqueue.New(),
+ netLn: netLn,
+ }
+ ln.netLoop.Add(1)
+ go ln.netAcceptLoop()
+ httpListener := &queueListener{
+ q: ln.httpQ,
+ ln: ln,
+ }
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ defer ln.wsLoop.Done()
+ if r.Method != "GET" {
+ http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
+ return
+ }
+ ws, err := websocket.Upgrade(w, r, nil, bufferSize, bufferSize)
+ if _, ok := err.(websocket.HandshakeError); ok {
+ http.Error(w, "Not a websocket handshake", 400)
+ vlog.Errorf("Rejected a non-websocket request: %v", err)
+ return
+ } else if err != nil {
+ http.Error(w, "Internal Error", 500)
+ vlog.Errorf("Rejected a non-websocket request: %v", err)
+ return
+ }
+ conn := vwebsocket.WebsocketConn(ws)
+ if err := ln.q.Put(conn); err != nil {
+ vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as Put failed: %v", ws.RemoteAddr(), ws.LocalAddr(), err)
+ ws.Close()
+ return
+ }
+
+ }
+ ln.wsServer = http.Server{
+ Handler: http.HandlerFunc(handler),
+ }
+ go ln.wsServer.Serve(httpListener)
+ return ln
+}
+
+func (ln *wsTCPListener) netAcceptLoop() {
+ defer ln.Close()
+ defer ln.netLoop.Done()
+ for {
+ conn, err := ln.netLn.Accept()
+ if err != nil {
+ vlog.VI(1).Infof("Exiting netAcceptLoop: net.Listener.Accept() failed on %v with %v", ln.netLn, err)
+ return
+ }
+ vlog.VI(1).Infof("New net.Conn accepted from %s (local address: %s)", conn.RemoteAddr(), conn.LocalAddr())
+ bc := newBufferedConn(conn)
+ magic, err := bc.Peek(1)
+ if err != nil {
+ vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as the magic byte failed to be read: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+ bc.Close()
+ continue
+ }
+
+ vlog.VI(1).Info("Got a connection from %s (local address: %s)", conn.RemoteAddr(), conn.LocalAddr())
+ // Check to see if it is a regular connection or a http connection.
+ if magic[0] == BinaryMagicByte {
+ if _, err := bc.r.ReadByte(); err != nil {
+ vlog.VI(1).Infof("Shutting down conn from %s (local address: %s), could read past the magic byte: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+ bc.Close()
+ continue
+ }
+ if err := ln.q.Put(&bc); err != nil {
+ vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as Put failed in vifLoop: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+ bc.Close()
+ continue
+ }
+ continue
+ }
+
+ ln.wsLoop.Add(1)
+ if err := ln.httpQ.Put(&bc); err != nil {
+ ln.wsLoop.Done()
+ vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as Put failed in vifLoop: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+ bc.Close()
+ continue
+ }
+ }
+}
+
+func (ln *wsTCPListener) Accept() (net.Conn, error) {
+ item, err := ln.q.Get(nil)
+ switch {
+ case err == upcqueue.ErrQueueIsClosed:
+ return nil, errListenerIsClosed
+ case err != nil:
+ return nil, fmt.Errorf("Accept failed: %v", err)
+ default:
+ return item.(net.Conn), nil
+ }
+}
+
+func (ln *wsTCPListener) Close() error {
+ addr := ln.netLn.Addr()
+ err := ln.netLn.Close()
+ vlog.VI(1).Infof("Closed net.Listener on (%q, %q): %v", addr.Network(), addr, err)
+ ln.httpQ.Shutdown()
+ ln.netLoop.Wait()
+ ln.wsLoop.Wait()
+ // q has to be shutdown after the netAcceptLoop finishes because that loop
+ // could be in the process of accepting a websocket connection. The ordering
+ // relative to wsLoop is not really relevant because the wsLoop counter wil
+ // decrement every time there a websocket connection has been handled and does
+ // not block on gets from q.
+ ln.q.Shutdown()
+ vlog.VI(3).Infof("Close stream.wsTCPListener %s", ln)
+ return nil
+}
+
+func (ln *wsTCPListener) Addr() net.Addr {
+ return ln.netLn.Addr()
+}
diff --git a/runtimes/google/naming/namespace/all_test.go b/runtimes/google/naming/namespace/all_test.go
index 4192a68..e40b1f3 100644
--- a/runtimes/google/naming/namespace/all_test.go
+++ b/runtimes/google/naming/namespace/all_test.go
@@ -3,12 +3,14 @@
import (
"runtime"
"runtime/debug"
+ "strings"
"sync"
"testing"
"time"
"veyron.io/veyron/veyron2"
"veyron.io/veyron/veyron2/ipc"
+ "veyron.io/veyron/veyron2/ipc/stream"
"veyron.io/veyron/veyron2/naming"
"veyron.io/veyron/veyron2/options"
"veyron.io/veyron/veyron2/rt"
@@ -19,18 +21,26 @@
"veyron.io/veyron/veyron/lib/glob"
"veyron.io/veyron/veyron/lib/testutil"
+ "veyron.io/veyron/veyron/lib/websocket"
_ "veyron.io/veyron/veyron/profiles"
"veyron.io/veyron/veyron/runtimes/google/naming/namespace"
service "veyron.io/veyron/veyron/services/mounttable/lib"
)
-func init() { testutil.Init() }
+func init() {
+ testutil.Init()
+ stream.RegisterProtocol("ws", websocket.Dial, nil)
+}
func boom(t *testing.T, f string, v ...interface{}) {
t.Logf(f, v...)
t.Fatal(string(debug.Stack()))
}
+func addWSName(name string) []string {
+ return []string{name, strings.Replace(name, "@tcp@", "@ws@", 1)}
+}
+
// N squared but who cares, this is a little test.
// Ignores dups.
func contains(container, contained []string) bool {
@@ -292,16 +302,16 @@
testResolveToMountTable(t, r, ns, m, rootMT)
// The server registered for each mount point is a mount table
- testResolve(t, r, ns, m, mts[m].name)
+ testResolve(t, r, ns, m, addWSName(mts[m].name)...)
// ResolveToMountTable will walk through to the sub MountTables
mtbar := naming.Join(m, "bar")
subMT := naming.Join(mts[m].name, "bar")
- testResolveToMountTable(t, r, ns, mtbar, subMT)
+ testResolveToMountTable(t, r, ns, mtbar, addWSName(subMT)...)
}
for _, j := range []string{j1MP, j2MP, j3MP} {
- testResolve(t, r, ns, j, jokes[j].name)
+ testResolve(t, r, ns, j, addWSName(jokes[j].name)...)
}
}
@@ -334,7 +344,7 @@
mt2mt := naming.Join(mts[mt2MP].name, "a")
// The mt2/a is served by the mt2 mount table
- testResolveToMountTable(t, r, ns, mt2a, mt2mt)
+ testResolveToMountTable(t, r, ns, mt2a, addWSName(mt2mt)...)
// The server for mt2a is mt3server from the second mount above.
testResolve(t, r, ns, mt2a, mt3Server)
@@ -349,12 +359,14 @@
}
}
+ names := []string{naming.JoinAddressName(mts[mt4MP].name, "a"),
+ naming.JoinAddressName(mts[mt5MP].name, "a")}
+ names = append(names, addWSName(naming.JoinAddressName(mts[mt2MP].name, "a"))...)
// We now have 3 mount tables prepared to serve mt2/a
- testResolveToMountTable(t, r, ns, "mt2/a",
- naming.JoinAddressName(mts[mt2MP].name, "a"),
- naming.JoinAddressName(mts[mt4MP].name, "a"),
- naming.JoinAddressName(mts[mt5MP].name, "a"))
- testResolve(t, r, ns, "mt2", mts[mt2MP].name, mts[mt4MP].name, mts[mt5MP].name)
+ testResolveToMountTable(t, r, ns, "mt2/a", names...)
+ names = []string{mts[mt4MP].name, mts[mt5MP].name}
+ names = append(names, addWSName(mts[mt2MP].name)...)
+ testResolve(t, r, ns, "mt2", names...)
}
// TestNestedMounts tests some more deeply nested mounts
@@ -370,15 +382,15 @@
// Set up some nested mounts and verify resolution.
for _, m := range []string{"mt4/foo", "mt4/foo/bar"} {
- testResolve(t, r, ns, m, mts[m].name)
+ testResolve(t, r, ns, m, addWSName(mts[m].name)...)
}
testResolveToMountTable(t, r, ns, "mt4/foo",
- naming.JoinAddressName(mts[mt4MP].name, "foo"))
+ addWSName(naming.JoinAddressName(mts[mt4MP].name, "foo"))...)
testResolveToMountTable(t, r, ns, "mt4/foo/bar",
- naming.JoinAddressName(mts["mt4/foo"].name, "bar"))
+ addWSName(naming.JoinAddressName(mts["mt4/foo"].name, "bar"))...)
testResolveToMountTable(t, r, ns, "mt4/foo/baz",
- naming.JoinAddressName(mts["mt4/foo"].name, "baz"))
+ addWSName(naming.JoinAddressName(mts["mt4/foo"].name, "baz"))...)
}
// TestServers tests invoking RPCs on simple servers
@@ -392,16 +404,16 @@
// Let's run some non-mount table services
for _, j := range []string{j1MP, j2MP, j3MP} {
- testResolve(t, r, ns, j, jokes[j].name)
+ testResolve(t, r, ns, j, addWSName(jokes[j].name)...)
knockKnock(t, r, j)
globalName := naming.JoinAddressName(mts["mt4"].name, j)
disp := &dispatcher{}
gj := "g_" + j
jokes[gj] = runServer(t, r, disp, globalName)
- testResolve(t, r, ns, "mt4/"+j, jokes[gj].name)
+ testResolve(t, r, ns, "mt4/"+j, addWSName(jokes[gj].name)...)
knockKnock(t, r, "mt4/"+j)
- testResolveToMountTable(t, r, ns, "mt4/"+j, globalName)
- testResolveToMountTable(t, r, ns, "mt4/"+j+"/garbage", globalName+"/garbage")
+ testResolveToMountTable(t, r, ns, "mt4/"+j, addWSName(globalName)...)
+ testResolveToMountTable(t, r, ns, "mt4/"+j+"/garbage", addWSName(globalName+"/garbage")...)
}
}
@@ -547,7 +559,8 @@
boom(t, "Failed to Mount %s: %s", m, err)
}
- testResolve(t, r, ns, "c1", c1.name)
+ // Since c1 was mounted with the Serve call, it will have both the tcp and ws endpoints.
+ testResolve(t, r, ns, "c1", addWSName(c1.name)...)
testResolve(t, r, ns, "c1/c2", c1.name)
testResolve(t, r, ns, "c1/c3", c3.name)
testResolve(t, r, ns, "c1/c3/c4", c1.name)
diff --git a/services/mgmt/node/impl/util_test.go b/services/mgmt/node/impl/util_test.go
index 4e8a866..62dc931 100644
--- a/services/mgmt/node/impl/util_test.go
+++ b/services/mgmt/node/impl/util_test.go
@@ -8,6 +8,7 @@
"reflect"
"runtime"
"sort"
+ "strings"
"testing"
"time"
@@ -159,10 +160,18 @@
if err != nil {
t.Fatalf("Resolve(%v) failed: %v", name, err)
}
- if want, got := replicas, len(results); want != got {
+
+ filteredResults := []string{}
+ for _, r := range results {
+ if strings.Index(r, "@tcp") != -1 {
+ filteredResults = append(filteredResults, r)
+ }
+ }
+ // We are going to get a websocket and a tcp endpoint for each replica.
+ if want, got := replicas, len(filteredResults); want != got {
t.Fatalf("Resolve(%v) expected %d result(s), got %d instead", name, want, got)
}
- return results
+ return filteredResults
}
// The following set of functions are convenience wrappers around Update and