Merge "veyron/services/mgmt: Get rid of the term 'invoker' (one more pass)"
diff --git a/lib/modules/core/core_test.go b/lib/modules/core/core_test.go
index d8a2a1d..1cc69e2 100644
--- a/lib/modules/core/core_test.go
+++ b/lib/modules/core/core_test.go
@@ -6,6 +6,7 @@
"reflect"
"sort"
"strconv"
+ "strings"
"testing"
"time"
@@ -242,6 +243,7 @@
srvSession.ExpectVar("NAME")
addr := srvSession.ExpectVar("ADDR")
addr = naming.JoinAddressName(addr, "")
+ wsAddr := strings.Replace(addr, "@tcp@", "@ws@", 1)
// Resolve an object
resolver, err := sh.Start(core.ResolveCommand, nil, rootName+"/"+echoName)
@@ -249,11 +251,11 @@
t.Fatalf("unexpected error: %s", err)
}
resolverSession := expect.NewSession(t, resolver.Stdout(), time.Minute)
- if got, want := resolverSession.ExpectVar("RN"), "1"; got != want {
+ if got, want := resolverSession.ExpectVar("RN"), "2"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
- if got, want := resolverSession.ExpectVar("R0"), addr; got != want {
- t.Errorf("got %v, want %v", got, want)
+ if got, want := resolverSession.ExpectVar("R0"), addr; got != want && got != wsAddr {
+ t.Errorf("got %v, want either %v or %v", got, want, wsAddr)
}
if err = resolver.Shutdown(nil, os.Stderr); err != nil {
t.Fatalf("unexpected error: %s", err)
@@ -261,16 +263,17 @@
// Resolve to a mount table using a rooted name.
addr = naming.JoinAddressName(mountAddrs[mtName], "echo")
+ wsAddr = strings.Replace(addr, "@tcp@", "@ws@", 1)
resolver, err = sh.Start(core.ResolveMTCommand, nil, rootName+"/"+echoName)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
resolverSession = expect.NewSession(t, resolver.Stdout(), time.Minute)
- if got, want := resolverSession.ExpectVar("RN"), "1"; got != want {
+ if got, want := resolverSession.ExpectVar("RN"), "2"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
- if got, want := resolverSession.ExpectVar("R0"), addr; got != want {
- t.Fatalf("got %v, want %v", got, want)
+ if got, want := resolverSession.ExpectVar("R0"), addr; got != want && got != wsAddr {
+ t.Fatalf("got %v, want either %v or %v", got, want, wsAddr)
}
if err := resolver.Shutdown(nil, os.Stderr); err != nil {
t.Fatalf("unexpected error: %s", err)
@@ -290,11 +293,11 @@
t.Fatalf("unexpected error: %s", err)
}
resolverSession = expect.NewSession(t, resolver.Stdout(), time.Minute)
- if got, want := resolverSession.ExpectVar("RN"), "1"; got != want {
+ if got, want := resolverSession.ExpectVar("RN"), "2"; got != want {
t.Fatalf("got %v, want %v", got, want)
}
- if got, want := resolverSession.ExpectVar("R0"), addr; got != want {
- t.Fatalf("got %v, want %v", got, want)
+ if got, want := resolverSession.ExpectVar("R0"), addr; got != want && got != wsAddr {
+ t.Fatalf("got %v, want either %v or %v", got, want, wsAddr)
}
if err := resolver.Shutdown(nil, os.Stderr); err != nil {
t.Fatalf("unexpected error: %s", err)
diff --git a/lib/modules/core/mounttable.go b/lib/modules/core/mounttable.go
index 306498f..7d5d7f8 100644
--- a/lib/modules/core/mounttable.go
+++ b/lib/modules/core/mounttable.go
@@ -108,7 +108,7 @@
return nil
}
-type resolver func(ctx context.T, name string, opts ...naming.ResolveOpt) (names []string, err error)
+type resolver func(ctx context.T, name string) (names []string, err error)
func resolve(fn resolver, stdin io.Reader, stdout, stderr io.Writer, env map[string]string, args ...string) error {
if err := checkArgs(args[1:], 1, "<name>"); err != nil {
diff --git a/lib/netconfig/ipaux_other.go b/lib/netconfig/ipaux_other.go
index 2826424..948c0c7 100644
--- a/lib/netconfig/ipaux_other.go
+++ b/lib/netconfig/ipaux_other.go
@@ -49,5 +49,7 @@
}
func GetIPRoutes(defaultOnly bool) []*IPRoute {
- panic("Not yet implemented")
+ // TODO(nlacasse,bprosnitz): Consider implementing? For now return
+ // empty array, since that seems to keep things working.
+ return []*IPRoute{}
}
diff --git a/lib/unixfd/unixfd_test.go b/lib/unixfd/unixfd_test.go
index 152efa8..74a2c65 100644
--- a/lib/unixfd/unixfd_test.go
+++ b/lib/unixfd/unixfd_test.go
@@ -142,7 +142,7 @@
done := make(chan struct{})
buf := make([]byte, 10)
go func() {
- saddr, n, err = ReadConnection(server, buf)
+ saddr, n, readErr = ReadConnection(server, buf)
close(done)
}()
caddr, err := SendConnection(uclient.(*net.UnixConn), []byte("hello"), true)
diff --git a/lib/websocket/conn.go b/lib/websocket/conn.go
new file mode 100644
index 0000000..e6963c4
--- /dev/null
+++ b/lib/websocket/conn.go
@@ -0,0 +1,124 @@
+// +build !nacl
+
+package websocket
+
+import (
+ "fmt"
+ "github.com/gorilla/websocket"
+ "io"
+ "net"
+ "sync"
+ "time"
+)
+
+// WebsocketConn provides a net.Conn interface for a websocket connection.
+// Each Write is sent as one binary websocket message, and Read only accepts
+// binary messages.
+func WebsocketConn(ws *websocket.Conn) net.Conn {
+	return &wrappedConn{ws: ws}
+}
+
+// wrappedConn provides a net.Conn interface to a websocket.
+// The underlying websocket connection needs regular calls to Read to make sure
+// websocket control messages (such as pings) are processed by the websocket
+// library.
+type wrappedConn struct {
+	ws *websocket.Conn
+	// currReader is the reader for the message currently being consumed by
+	// Read, or nil when the next Read has to fetch a new message. It is only
+	// accessed with readLock held.
+	currReader io.Reader
+
+	// The gorilla docs aren't explicit about reading and writing from
+	// different goroutines. They are explicit that only one goroutine can
+	// do a write at any given time and only one goroutine can do a read
+	// at any given time. Based on inspection it seems that using a reader
+	// and a writer simultaneously is safe, but this might change in the
+	// future. We can't share a single lock, because then we couldn't write
+	// while waiting for a message, causing deadlocks where a write is needed
+	// to unblock a read.
+	writeLock sync.Mutex
+	readLock  sync.Mutex
+}
+
+// readFromCurrReader reads from the message currently being consumed. When
+// that message is exhausted it clears currReader, so the next call fetches a
+// new message, and reports the end of the message as a nil error rather than
+// io.EOF. It is called with readLock held.
+func (c *wrappedConn) readFromCurrReader(b []byte) (int, error) {
+	n, err := c.currReader.Read(b)
+	if err == io.EOF {
+		err = nil
+		c.currReader = nil
+	}
+	return n, err
+}
+
+// Read reads data from the current websocket message, moving on to the next
+// message when the current one is exhausted. Only binary messages are
+// accepted; any other message type is reported as an error.
+func (c *wrappedConn) Read(b []byte) (int, error) {
+	// With an empty buffer every read returns 0, nil, so the loop below
+	// would never terminate; io.Reader allows returning 0, nil here.
+	if len(b) == 0 {
+		return 0, nil
+	}
+	c.readLock.Lock()
+	defer c.readLock.Unlock()
+	var n int
+	var err error
+
+	// TODO(bjornick): It would be nice to be able to read multiple messages at
+	// a time in case the first message is not big enough to fill b and another
+	// message is ready.
+	// Loop until we either get data or an error. This exists
+	// mostly to avoid return 0, nil.
+	for n == 0 && err == nil {
+		if c.currReader == nil {
+			t, r, nextErr := c.ws.NextReader()
+			// Check the error first: when NextReader fails the message type
+			// is meaningless, and reporting it would hide the real error.
+			if nextErr != nil {
+				return 0, nextErr
+			}
+			if t != websocket.BinaryMessage {
+				return 0, fmt.Errorf("Unexpected message type %d", t)
+			}
+			c.currReader = r
+		}
+		n, err = c.readFromCurrReader(b)
+	}
+	return n, err
+}
+
+// Write sends the whole of b as a single binary websocket message. Writers
+// are serialized by writeLock. On failure nothing is reported as written.
+func (c *wrappedConn) Write(b []byte) (int, error) {
+	c.writeLock.Lock()
+	defer c.writeLock.Unlock()
+	written := 0
+	err := c.ws.WriteMessage(websocket.BinaryMessage, b)
+	if err == nil {
+		written = len(b)
+	}
+	return written, err
+}
+
+// Close closes the websocket. It holds writeLock so it cannot interleave
+// with an in-progress Write.
+func (c *wrappedConn) Close() error {
+	c.writeLock.Lock()
+	defer c.writeLock.Unlock()
+	err := c.ws.Close()
+	return err
+}
+
+// LocalAddr returns the websocket's local address, tagged with network "ws".
+func (c *wrappedConn) LocalAddr() net.Addr {
+	local := c.ws.LocalAddr()
+	return websocketAddr{s: local.String()}
+}
+
+// RemoteAddr returns the websocket's peer address, tagged with network "ws".
+func (c *wrappedConn) RemoteAddr() net.Addr {
+	remote := c.ws.RemoteAddr()
+	return websocketAddr{s: remote.String()}
+}
+
+// SetDeadline sets the read and then the write deadline. If setting the read
+// deadline fails, the write deadline is not touched.
+func (c *wrappedConn) SetDeadline(t time.Time) error {
+	err := c.SetReadDeadline(t)
+	if err == nil {
+		err = c.SetWriteDeadline(t)
+	}
+	return err
+}
+
+// SetReadDeadline forwards to the underlying websocket.
+func (c *wrappedConn) SetReadDeadline(t time.Time) error {
+	err := c.ws.SetReadDeadline(t)
+	return err
+}
+
+// SetWriteDeadline forwards to the underlying websocket.
+func (c *wrappedConn) SetWriteDeadline(t time.Time) error {
+	err := c.ws.SetWriteDeadline(t)
+	return err
+}
+
+// websocketAddr is a net.Addr that always reports the "ws" network.
+type websocketAddr struct {
+	s string // textual form of the address
+}
+
+// Network implements net.Addr.
+func (a websocketAddr) Network() string { return "ws" }
+
+// String implements net.Addr by returning the stored address text.
+func (a websocketAddr) String() string { return a.s }
diff --git a/lib/websocket/conn_nacl.go b/lib/websocket/conn_nacl.go
new file mode 100644
index 0000000..22d368d
--- /dev/null
+++ b/lib/websocket/conn_nacl.go
@@ -0,0 +1,105 @@
+// +build nacl
+
+package websocket
+
+import (
+ "net"
+ "net/url"
+ "runtime/ppapi"
+ "sync"
+ "time"
+)
+
+// PpapiInstance is the PPAPI instance Dial uses to open websockets. It must
+// be set before Dial is called.
+var PpapiInstance ppapi.Instance
+
+// WebsocketConn wraps a PPAPI websocket as a net.Conn. The given address is
+// reported as both the local and the remote address of the connection.
+func WebsocketConn(address string, ws *ppapi.WebsocketConn) net.Conn {
+	conn := &wrappedConn{ws: ws}
+	conn.address = address
+	return conn
+}
+
+// wrappedConn adapts a PPAPI websocket to net.Conn. The unread remainder of
+// the last received message is kept in currBuffer between Reads.
+type wrappedConn struct {
+	address    string
+	ws         *ppapi.WebsocketConn
+	readLock   sync.Mutex
+	writeLock  sync.Mutex
+	currBuffer []byte
+}
+
+// Dial opens a websocket to ws://address through PpapiInstance.
+func Dial(address string) (net.Conn, error) {
+	inst := PpapiInstance
+	target, err := url.Parse("ws://" + address)
+	if err != nil {
+		return nil, err
+	}
+
+	sock, err := inst.DialWebsocket(target.String())
+	if err != nil {
+		return nil, err
+	}
+	conn := WebsocketConn(address, sock)
+	return conn, nil
+}
+
+// Read copies data from the current message into b, receiving a new message
+// when the previous one has been fully consumed. Bytes that do not fit in b
+// are kept for the next Read.
+func (c *wrappedConn) Read(b []byte) (int, error) {
+	c.readLock.Lock()
+	defer c.readLock.Unlock()
+
+	if len(c.currBuffer) == 0 {
+		msg, err := c.ws.ReceiveMessage()
+		if err != nil {
+			// Propagate the failure. Returning (0, nil) would make callers
+			// retry forever on a broken connection.
+			return 0, err
+		}
+		c.currBuffer = msg
+	}
+
+	n := copy(b, c.currBuffer)
+	c.currBuffer = c.currBuffer[n:]
+	return n, nil
+}
+
+// Write hands b to the websocket's SendMessage, one Write at a time. It
+// reports len(b) bytes written on success and 0 on failure.
+func (c *wrappedConn) Write(b []byte) (int, error) {
+	c.writeLock.Lock()
+	defer c.writeLock.Unlock()
+	written := 0
+	err := c.ws.SendMessage(b)
+	if err == nil {
+		written = len(b)
+	}
+	return written, err
+}
+
+// Close closes the underlying PPAPI websocket (without taking writeLock).
+func (c *wrappedConn) Close() error {
+	err := c.ws.Close()
+	return err
+}
+
+// LocalAddr returns the address the connection was created with. This
+// wrapper only records that one address.
+func (c *wrappedConn) LocalAddr() net.Addr {
+	return websocketAddr{s: c.address}
+}
+
+// RemoteAddr returns the address the connection was created with.
+func (c *wrappedConn) RemoteAddr() net.Addr {
+	return websocketAddr{s: c.address}
+}
+
+// notImplementedError is returned by the deadline setters. The PPAPI
+// websocket wrapper has no deadline support. Returning an error instead of
+// panicking lets net.Conn users that probe deadlines degrade gracefully
+// rather than crash the process.
+type notImplementedError string
+
+func (e notImplementedError) Error() string { return string(e) }
+
+// SetDeadline is not supported and always returns an error.
+func (c *wrappedConn) SetDeadline(t time.Time) error {
+	return notImplementedError("SetDeadline not implemented.")
+}
+
+// SetReadDeadline is not supported and always returns an error.
+func (c *wrappedConn) SetReadDeadline(t time.Time) error {
+	return notImplementedError("SetReadDeadline not implemented.")
+}
+
+// SetWriteDeadline is not supported and always returns an error.
+func (c *wrappedConn) SetWriteDeadline(t time.Time) error {
+	return notImplementedError("SetWriteDeadline not implemented.")
+}
+
+// websocketAddr is a net.Addr that always reports the "ws" network.
+type websocketAddr struct {
+	s string // textual form of the address
+}
+
+// Network implements net.Addr.
+func (a websocketAddr) Network() string { return "ws" }
+
+// String implements net.Addr by returning the stored address text.
+func (a websocketAddr) String() string { return a.s }
diff --git a/lib/websocket/conn_test.go b/lib/websocket/conn_test.go
new file mode 100644
index 0000000..db38cbe
--- /dev/null
+++ b/lib/websocket/conn_test.go
@@ -0,0 +1,111 @@
+// +build !nacl
+package websocket
+
+import (
+ "bytes"
+ "github.com/gorilla/websocket"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+ "time"
+)
+
+// writer writes data, prefixed by a one-byte length header, to c the given
+// number of times and then marks wg done. Write errors are ignored.
+func writer(c net.Conn, data []byte, times int, wg *sync.WaitGroup) {
+	defer wg.Done()
+	frame := make([]byte, 0, 1+len(data))
+	frame = append(frame, byte(len(data)))
+	frame = append(frame, data...)
+	for remaining := times; remaining > 0; remaining-- {
+		c.Write(frame)
+	}
+}
+
+// readMessage reads one message from c: a single length byte followed by
+// that many payload bytes. If reading the payload fails, the partially
+// filled buffer is returned together with the error.
+func readMessage(c net.Conn) ([]byte, error) {
+	var header [1]byte
+	// Keep reading until the one-byte size header has arrived.
+	for got := 0; got != 1; {
+		var err error
+		if got, err = c.Read(header[:]); err != nil {
+			return nil, err
+		}
+	}
+
+	payload := make([]byte, int(header[0]))
+	for filled := 0; filled < len(payload); {
+		nn, err := c.Read(payload[filled:])
+		if err != nil {
+			return payload, err
+		}
+		filled += nn
+	}
+	return payload, nil
+}
+
+// reader consumes messages from c until a read fails. Through t it reports
+// every message that differs from expected, and a mismatch between the
+// number of messages read and totalWrites.
+func reader(t *testing.T, c net.Conn, expected []byte, totalWrites int) {
+	totalReads := 0
+	for {
+		buf, err := readMessage(c)
+		if err != nil {
+			break
+		}
+		totalReads++
+		if !bytes.Equal(buf, expected) {
+			t.Errorf("Unexpected message %v, expected %v", buf, expected)
+		}
+	}
+	if totalReads != totalWrites {
+		t.Errorf("wrong number of messages expected %v, got %v", totalWrites, totalReads)
+	}
+}
+
+// TestMultipleGoRoutines checks that many goroutines writing concurrently to
+// one websocket-backed net.Conn never corrupt each other's messages. The
+// server side must receive exactly numWriters*numWritesPerWriter intact
+// messages.
+func TestMultipleGoRoutines(t *testing.T) {
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("Failed to listen: %v", err)
+	}
+	addr := l.Addr()
+	input := []byte("no races here")
+	const numWriters int = 12
+	const numWritesPerWriter int = 1000
+	const totalWrites int = numWriters * numWritesPerWriter
+	// handlerDone is signalled when the server handler returns. The client
+	// goroutine waits for it before closing the listener (which ends the
+	// test), because using t after the test has returned panics.
+	handlerDone := make(chan struct{}, 1)
+	s := &http.Server{
+		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			defer func() {
+				select {
+				case handlerDone <- struct{}{}:
+				default:
+				}
+			}()
+			if r.Method != "GET" {
+				http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
+				return
+			}
+			ws, err := websocket.Upgrade(w, r, nil, 1024, 1024)
+			if _, ok := err.(websocket.HandshakeError); ok {
+				http.Error(w, "Not a websocket handshake", 400)
+				return
+			} else if err != nil {
+				http.Error(w, "Internal Error", 500)
+				return
+			}
+			reader(t, WebsocketConn(ws), input, totalWrites)
+		}),
+	}
+	// Dial out in another go routine.
+	go func() {
+		// Closing the listener makes s.Serve return, which ends the test;
+		// this must also happen when dialing fails, or Serve blocks forever.
+		defer l.Close()
+		var conn net.Conn
+		var err error
+		// Retry the dial a few times in case the server is not ready yet.
+		for numTries := 0; ; numTries++ {
+			conn, err = Dial(addr.String())
+			if err == nil || numTries >= 5 {
+				break
+			}
+			time.Sleep(time.Second)
+		}
+		if err != nil {
+			// t.Fatalf may only be called from the test goroutine.
+			t.Errorf("failed to connect to server: %v", err)
+			return
+		}
+		var writers sync.WaitGroup
+		writers.Add(numWriters)
+		for i := 0; i < numWriters; i++ {
+			go writer(conn, input, numWritesPerWriter, &writers)
+		}
+		writers.Wait()
+		// Closing conn makes the server-side reader fail and return.
+		conn.Close()
+		<-handlerDone
+	}()
+	s.Serve(l)
+}
diff --git a/lib/websocket/dialer.go b/lib/websocket/dialer.go
new file mode 100644
index 0000000..dea1262
--- /dev/null
+++ b/lib/websocket/dialer.go
@@ -0,0 +1,28 @@
+// +build !nacl
+
+package websocket
+
+import (
+ "github.com/gorilla/websocket"
+ "net"
+ "net/http"
+ "net/url"
+)
+
+// Dial opens a TCP connection to address, performs the websocket handshake
+// for ws://address on it, and returns the websocket wrapped as a net.Conn.
+// The TCP connection is closed if the handshake cannot be completed.
+func Dial(address string) (net.Conn, error) {
+	// Parse first so an invalid address fails before any connection exists.
+	u, err := url.Parse("ws://" + address)
+	if err != nil {
+		return nil, err
+	}
+	conn, err := net.Dial("tcp", address)
+	if err != nil {
+		return nil, err
+	}
+	ws, _, err := websocket.NewClient(conn, u, http.Header{}, 4096, 4096)
+	if err != nil {
+		// NewClient does not take ownership on failure; avoid leaking conn.
+		conn.Close()
+		return nil, err
+	}
+
+	return WebsocketConn(ws), nil
+}
diff --git a/runtimes/google/ipc/client.go b/runtimes/google/ipc/client.go
index 2b15b7c..15d20ef 100644
--- a/runtimes/google/ipc/client.go
+++ b/runtimes/google/ipc/client.go
@@ -291,13 +291,13 @@
// tryCall makes a single attempt at a call, against possibly multiple servers.
func (c *client) tryCall(ctx context.T, name, method string, args []interface{}, opts []ipc.CallOpt) (ipc.Call, verror.E) {
ctx, _ = vtrace.WithNewSpan(ctx, fmt.Sprintf("<client>\"%s\".%s", name, method))
- mtPattern, serverPattern, name := splitObjectName(name)
+ _, serverPattern, name := splitObjectName(name)
// Resolve name unless told not to.
var servers []string
if getNoResolveOpt(opts) {
servers = []string{name}
} else {
- if resolved, err := c.ns.Resolve(ctx, name, naming.RootBlessingPatternOpt(mtPattern)); err != nil {
+ if resolved, err := c.ns.Resolve(ctx, name); err != nil {
return nil, verror.NoExistf("ipc: Resolve(%q) failed: %v", name, err)
} else {
// An empty set of protocols means all protocols...
diff --git a/runtimes/google/ipc/client_test.go b/runtimes/google/ipc/client_test.go
index d5be886..c77320f 100644
--- a/runtimes/google/ipc/client_test.go
+++ b/runtimes/google/ipc/client_test.go
@@ -101,8 +101,8 @@
}
}
- // Verify that there are 101 entries for echoServer in the mount table.
- if got, want := numServers(t, sh, "echoServer"), "101"; got != want {
+ // Verify that there are 102 entries for echoServer in the mount table.
+ if got, want := numServers(t, sh, "echoServer"), "102"; got != want {
vlog.Fatalf("got: %q, want: %q", got, want)
}
@@ -119,7 +119,7 @@
// TODO(cnicolaou,p): figure out why the real entry isn't removed
// from the mount table.
// Verify that there are 100 entries for echoServer in the mount table.
- if got, want := numServers(t, sh, "echoServer"), "101"; got != want {
+ if got, want := numServers(t, sh, "echoServer"), "102"; got != want {
vlog.Fatalf("got: %q, want: %q", got, want)
}
}
diff --git a/runtimes/google/ipc/full_test.go b/runtimes/google/ipc/full_test.go
index 67e14b9..857b2e2 100644
--- a/runtimes/google/ipc/full_test.go
+++ b/runtimes/google/ipc/full_test.go
@@ -26,6 +26,7 @@
"veyron.io/veyron/veyron/lib/netstate"
"veyron.io/veyron/veyron/lib/testutil"
tsecurity "veyron.io/veyron/veyron/lib/testutil/security"
+ "veyron.io/veyron/veyron/lib/websocket"
imanager "veyron.io/veyron/veyron/runtimes/google/ipc/stream/manager"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vc"
"veyron.io/veyron/veyron/runtimes/google/ipc/version"
@@ -35,6 +36,10 @@
vsecurity "veyron.io/veyron/veyron/security"
)
+// init registers websocket.Dial as the dialer for the "ws" protocol, which
+// the *WithWebsocket tests in this file dial.
+func init() {
+	stream.RegisterProtocol("ws", websocket.Dial, nil)
+}
+
var (
errMethod = verror.Abortedf("server returned an error")
clock = new(fakeClock)
@@ -237,6 +242,20 @@
vlog.VI(1).Info("server.Stop DONE")
}
+// resolveWSEndpoint resolves name through ns and returns the first resolved
+// server whose endpoint uses the "ws" protocol.
+func resolveWSEndpoint(ns naming.Namespace, name string) (string, error) {
+	servers, err := ns.Resolve(testContext(), name)
+	if err != nil {
+		return "", err
+	}
+	for i := range servers {
+		if strings.Contains(servers[i], "@ws@") {
+			return servers[i], nil
+		}
+	}
+	return "", fmt.Errorf("No ws endpoint found %v", servers)
+}
+
type bundle struct {
client ipc.Client
server ipc.Server
@@ -440,20 +459,40 @@
}
}
client.Close()
+
}
}
+// websocketMode selects whether testRPC calls the server through its
+// websocket ("ws") endpoint.
+type websocketMode bool
+
+// closeSendMode is the type of testRPC's shouldCloseSend argument.
+type closeSendMode bool
+
+// Named values for testRPC's mode arguments, so call sites read clearly.
+const (
+	useWebsocket websocketMode = true
+	noWebsocket  websocketMode = false
+
+	closeSend   closeSendMode = true
+	noCloseSend closeSendMode = false
+)
+
func TestRPC(t *testing.T) {
- testRPC(t, true)
+ testRPC(t, closeSend, noWebsocket)
+}
+
+func TestRPCWithWebsocket(t *testing.T) {
+ testRPC(t, closeSend, useWebsocket)
}
// TestCloseSendOnFinish tests that Finish informs the server that no more
// inputs will be sent by the client if CloseSend has not already done so.
func TestRPCCloseSendOnFinish(t *testing.T) {
- testRPC(t, false)
+ testRPC(t, noCloseSend, noWebsocket)
}
-func testRPC(t *testing.T, shouldCloseSend bool) {
+// TestRPCCloseSendOnFinishWithWebsocket runs the no-CloseSend RPC test
+// against the server's websocket endpoint.
+func TestRPCCloseSendOnFinishWithWebsocket(t *testing.T) {
+	testRPC(t, noCloseSend, useWebsocket)
+}
+
+func testRPC(t *testing.T, shouldCloseSend closeSendMode, shouldUseWebsocket websocketMode) {
type v []interface{}
type testcase struct {
name string
@@ -494,7 +533,16 @@
pserver.AddToRoots(pclient.BlessingStore().Default())
for _, test := range tests {
vlog.VI(1).Infof("%s client.StartCall", name(test))
- call, err := b.client.StartCall(testContext(), test.name, test.method, test.args)
+ vname := test.name
+ if shouldUseWebsocket {
+ var err error
+ vname, err = resolveWSEndpoint(b.ns, vname)
+ if err != nil && err != test.startErr {
+ t.Errorf(`%s ns.Resolve got error "%v", want "%v"`, name(test), err, test.startErr)
+ continue
+ }
+ }
+ call, err := b.client.StartCall(testContext(), vname, test.method, test.args)
if err != test.startErr {
t.Errorf(`%s client.StartCall got error "%v", want "%v"`, name(test), err, test.startErr)
continue
diff --git a/runtimes/google/ipc/server.go b/runtimes/google/ipc/server.go
index 75447aa..00e0688 100644
--- a/runtimes/google/ipc/server.go
+++ b/runtimes/google/ipc/server.go
@@ -283,6 +283,11 @@
s.active.Done()
}()
s.publisher.AddServer(s.publishEP(iep, s.servesMountTable), s.servesMountTable)
+ if strings.HasPrefix(iep.Protocol, "tcp") {
+ epCopy := *iep
+ epCopy.Protocol = "ws"
+ s.publisher.AddServer(s.publishEP(&epCopy, s.servesMountTable), s.servesMountTable)
+ }
}
if len(listenSpec.Proxy) > 0 {
@@ -321,6 +326,13 @@
s.listeners[ln] = nil
s.Unlock()
s.publisher.AddServer(s.publishEP(iep, s.servesMountTable), s.servesMountTable)
+
+ if strings.HasPrefix(iep.Protocol, "tcp") {
+ epCopy := *iep
+ epCopy.Protocol = "ws"
+ s.publisher.AddServer(s.publishEP(&epCopy, s.servesMountTable), s.servesMountTable)
+ }
+
return iep, ln, nil
}
@@ -351,6 +363,11 @@
// The listener is done, so:
// (1) Unpublish its name
s.publisher.RemoveServer(s.publishEP(iep, s.servesMountTable))
+ if strings.HasPrefix(iep.Protocol, "tcp") {
+ iepCopy := *iep
+ iepCopy.Protocol = "ws"
+ s.publisher.RemoveServer(s.publishEP(&iepCopy, s.servesMountTable))
+ }
}
s.Lock()
diff --git a/runtimes/google/ipc/server_test.go b/runtimes/google/ipc/server_test.go
index 1bb853a..a918f00 100644
--- a/runtimes/google/ipc/server_test.go
+++ b/runtimes/google/ipc/server_test.go
@@ -130,6 +130,10 @@
return addr
}
+// addWSName returns name followed by its websocket variant, which is name
+// with the first "@tcp@" replaced by "@ws@".
+func addWSName(name string) []string {
+	wsName := strings.Replace(name, "@tcp@", "@ws@", 1)
+	return []string{name, wsName}
+}
+
func testProxy(t *testing.T, spec ipc.ListenSpec) {
sm := imanager.InternalNew(naming.FixedRoutingID(0x555555555))
ns := tnaming.NewSimpleNamespace()
@@ -201,9 +205,9 @@
t.Fatalf("unexpected error: %s", err)
}
proxiedEP.RID = naming.FixedRoutingID(0x555555555)
- expectedEndpoints := []string{proxiedEP.String()}
+ expectedEndpoints := addWSName(proxiedEP.String())
if hasLocalListener {
- expectedEndpoints = append(expectedEndpoints, ep.String())
+ expectedEndpoints = append(expectedEndpoints, addWSName(ep.String())...)
}
// Proxy connetions are created asynchronously, so we wait for the
@@ -228,8 +232,12 @@
if hasLocalListener {
// Listen will publish both the local and proxied endpoint with the
// mount table, given that we're trying to test the proxy, we remove
- // the local endpoint from the mount table entry!
- ns.Unmount(testContext(), "mountpoint/server", naming.JoinAddressName(ep.String(), ""))
+ // the local endpoint from the mount table entry! We have to remove both
+ // the tcp and the websocket address.
+ sep := ep.String()
+ wsep := strings.Replace(sep, "@tcp@", "@ws@", 1)
+ ns.Unmount(testContext(), "mountpoint/server", naming.JoinAddressName(sep, ""))
+ ns.Unmount(testContext(), "mountpoint/server", naming.JoinAddressName(wsep, ""))
}
// Proxied endpoint should be published and RPC should succeed (through proxy)
diff --git a/runtimes/google/ipc/stream/manager/manager.go b/runtimes/google/ipc/stream/manager/manager.go
index 8620a61..55e06ee 100644
--- a/runtimes/google/ipc/stream/manager/manager.go
+++ b/runtimes/google/ipc/stream/manager/manager.go
@@ -12,6 +12,7 @@
"veyron.io/veyron/veyron/lib/stats"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/crypto"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vif"
+ "veyron.io/veyron/veyron/runtimes/google/ipc/stream/wslistener"
"veyron.io/veyron/veyron/runtimes/google/ipc/version"
inaming "veyron.io/veyron/veyron/runtimes/google/naming"
@@ -65,7 +66,20 @@
if d, _ := stream.RegisteredProtocol(network); d != nil {
return d(address)
}
- return net.DialTimeout(network, address, timeout)
+ conn, err := net.DialTimeout(network, address, timeout)
+ if err != nil || !strings.HasPrefix(network, "tcp") {
+ return conn, err
+ }
+
+ // For tcp connections we add an extra magic byte so we can differentiate between
+ // raw tcp and websocket on the same port.
+ switch n, err := conn.Write([]byte{wslistener.BinaryMagicByte}); {
+ case err != nil:
+ return nil, err
+ case n != 1:
+ return nil, fmt.Errorf("Unable to write the magic byte")
+ }
+ return conn, nil
}
// FindOrDialVIF returns the network connection (VIF) to the provided address
@@ -187,6 +201,12 @@
closeNetListener(netln)
return nil, nil, errShutDown
}
+
+ // If the protocol is tcp, we add the listener that supports both tcp and websocket
+ // so that javascript can talk to this server.
+ if strings.HasPrefix(protocol, "tcp") {
+ netln = wslistener.NewListener(netln)
+ }
ln := newNetListener(m, netln, opts)
m.listeners[ln] = true
m.muListeners.Unlock()
diff --git a/runtimes/google/ipc/stream/manager/manager_test.go b/runtimes/google/ipc/stream/manager/manager_test.go
index ccddf69..19dd942 100644
--- a/runtimes/google/ipc/stream/manager/manager_test.go
+++ b/runtimes/google/ipc/stream/manager/manager_test.go
@@ -12,6 +12,7 @@
"testing"
"time"
+ "veyron.io/veyron/veyron/lib/websocket"
"veyron.io/veyron/veyron2/ipc/stream"
"veyron.io/veyron/veyron2/naming"
"veyron.io/veyron/veyron2/security"
@@ -38,9 +39,10 @@
// introduces less variance in the behavior of the test.
runtime.GOMAXPROCS(1)
modules.RegisterChild("runServer", "", runServer)
+ stream.RegisterProtocol("ws", websocket.Dial, nil)
}
-func TestSimpleFlow(t *testing.T) {
+func testSimpleFlow(t *testing.T, useWebsocket bool) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -49,6 +51,10 @@
t.Fatal(err)
}
+ if useWebsocket {
+ ep.(*inaming.Endpoint).Protocol = "ws"
+ }
+
data := "the dark knight rises"
var clientVC stream.VC
var clientF1 stream.Flow
@@ -124,6 +130,14 @@
}
}
+// TestSimpleFlow runs testSimpleFlow against the plain tcp endpoint.
+func TestSimpleFlow(t *testing.T) {
+	testSimpleFlow(t, false)
+}
+
+// TestSimpleFlowWS runs testSimpleFlow with the endpoint's protocol switched
+// to "ws".
+func TestSimpleFlowWS(t *testing.T) {
+	testSimpleFlow(t, true)
+}
+
func TestConnectionTimeout(t *testing.T) {
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -145,7 +159,7 @@
}
}
-func TestAuthenticatedByDefault(t *testing.T) {
+func testAuthenticatedByDefault(t *testing.T, useWebsocket bool) {
var (
server = InternalNew(naming.FixedRoutingID(0x55555555))
client = InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -161,6 +175,9 @@
if err != nil {
t.Fatal(err)
}
+ if useWebsocket {
+ ep.(*inaming.Endpoint).Protocol = "ws"
+ }
errs := make(chan error)
@@ -209,6 +226,14 @@
}
}
+// TestAuthenticatedByDefault runs testAuthenticatedByDefault over tcp.
+func TestAuthenticatedByDefault(t *testing.T) {
+	testAuthenticatedByDefault(t, false)
+}
+
+// TestAuthenticatedByDefaultWS runs testAuthenticatedByDefault with the
+// endpoint's protocol switched to "ws".
+func TestAuthenticatedByDefaultWS(t *testing.T) {
+	testAuthenticatedByDefault(t, true)
+}
+
func numListeners(m stream.Manager) int { return len(m.(*manager).listeners) }
func debugString(m stream.Manager) string { return m.(*manager).DebugString() }
func numVIFs(m stream.Manager) int { return len(m.(*manager).vifs.List()) }
@@ -271,6 +296,29 @@
}
}
+// TestCloseListenerWS checks that once a listener is closed, new clients can
+// no longer dial its endpoint over websocket.
+func TestCloseListenerWS(t *testing.T) {
+	server := InternalNew(naming.FixedRoutingID(0x5e97e9))
+	ln, ep, err := server.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Reach the tcp listener through the websocket protocol.
+	ep.(*inaming.Endpoint).Protocol = "ws"
+
+	// Server will just listen for flows and close them.
+	go acceptLoop(ln)
+
+	firstClient := InternalNew(naming.FixedRoutingID(0xc1e41))
+	if _, err = firstClient.Dial(ep); err != nil {
+		t.Fatal(err)
+	}
+	ln.Close()
+
+	secondClient := InternalNew(naming.FixedRoutingID(0xc1e42))
+	if _, err := secondClient.Dial(ep); err == nil {
+		t.Errorf("client.Dial(%q) should have failed", ep)
+	}
+}
+
func TestShutdown(t *testing.T) {
server := InternalNew(naming.FixedRoutingID(0x5e97e9))
ln, _, err := server.Listen("tcp", "127.0.0.1:0")
@@ -316,6 +364,33 @@
}
}
+// TestShutdownEndpointWS checks that after ShutdownEndpoint a VC dialed over
+// websocket can no longer create flows.
+func TestShutdownEndpointWS(t *testing.T) {
+	server := InternalNew(naming.FixedRoutingID(0x55555555))
+	client := InternalNew(naming.FixedRoutingID(0xcccccccc))
+
+	ln, ep, err := server.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Reach the tcp listener through the websocket protocol.
+	ep.(*inaming.Endpoint).Protocol = "ws"
+
+	// Server will just listen for flows and close them.
+	go acceptLoop(ln)
+
+	clientVC, err := client.Dial(ep)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if flow, err := clientVC.Connect(); flow == nil || err != nil {
+		t.Errorf("vc.Connect failed: (%v, %v)", flow, err)
+	}
+
+	client.ShutdownEndpoint(ep)
+	if flow, err := clientVC.Connect(); flow != nil || err == nil {
+		t.Errorf("vc.Connect unexpectedly succeeded: (%v, %v)", flow, err)
+	}
+}
+
/* TLS + resumption + channel bindings is broken: <https://secure-resumption.com/#channelbindings>.
func TestSessionTicketCache(t *testing.T) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
@@ -335,7 +410,7 @@
}
*/
-func TestMultipleVCs(t *testing.T) {
+func testMultipleVCs(t *testing.T, useWebsocket bool) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -348,6 +423,10 @@
if err != nil {
t.Fatal(err)
}
+ if useWebsocket {
+ ep.(*inaming.Endpoint).Protocol = "ws"
+ }
+
read := func(flow stream.Flow, c chan string) {
var buf bytes.Buffer
var tmp [1024]byte
@@ -414,6 +493,14 @@
}
}
+// TestMultipleVCs runs testMultipleVCs over tcp.
+func TestMultipleVCs(t *testing.T) {
+	testMultipleVCs(t, false)
+}
+
+// TestMultipleVCsWS runs testMultipleVCs with the endpoint's protocol
+// switched to "ws".
+func TestMultipleVCsWS(t *testing.T) {
+	testMultipleVCs(t, true)
+}
+
func TestAddressResolution(t *testing.T) {
server := InternalNew(naming.FixedRoutingID(0x55555555))
client := InternalNew(naming.FixedRoutingID(0xcccccccc))
@@ -494,6 +581,46 @@
}
}
+// TestServerRestartDuringClientLifetimeWS checks that a client dialing over
+// websocket fails while the server is down and succeeds again once a new
+// server listens on the same address.
+func TestServerRestartDuringClientLifetimeWS(t *testing.T) {
+	client := InternalNew(naming.FixedRoutingID(0xcccccccc))
+	sh := modules.NewShell(".*")
+	defer sh.Cleanup(nil, nil)
+	h, err := sh.Start("runServer", nil, "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	s := expect.NewSession(t, h.Stdout(), time.Minute)
+	addr := s.ReadLine()
+
+	ep, err := inaming.NewEndpoint(addr)
+	if err != nil {
+		t.Fatalf("inaming.NewEndpoint(%q): %v", addr, err)
+	}
+	ep.Protocol = "ws"
+	if _, err := client.Dial(ep); err != nil {
+		t.Fatal(err)
+	}
+	h.Shutdown(nil, os.Stderr)
+
+	// A new VC cannot be created since the server is dead.
+	if _, err := client.Dial(ep); err == nil {
+		t.Fatal("Expected client.Dial to fail since server is dead")
+	}
+
+	h, err = sh.Start("runServer", nil, addr)
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	s = expect.NewSession(t, h.Stdout(), time.Minute)
+	// Restart the server on the same address as before. The Start error has
+	// already been checked above, so only the address is compared here.
+	if addr2 := s.ReadLine(); addr2 != addr {
+		t.Fatalf("Got %q want %q", addr2, addr)
+	}
+	if _, err := client.Dial(ep); err != nil {
+		t.Fatal(err)
+	}
+}
+
// Needed by modules framework
func TestHelperProcess(t *testing.T) {
modules.DispatchInTest()
diff --git a/runtimes/google/ipc/stream/proxy/proxy.go b/runtimes/google/ipc/stream/proxy/proxy.go
index a0d794e..0da73ce 100644
--- a/runtimes/google/ipc/stream/proxy/proxy.go
+++ b/runtimes/google/ipc/stream/proxy/proxy.go
@@ -11,6 +11,7 @@
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/message"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vc"
"veyron.io/veyron/veyron/runtimes/google/ipc/stream/vif"
+ "veyron.io/veyron/veyron/runtimes/google/ipc/stream/wslistener"
"veyron.io/veyron/veyron/runtimes/google/ipc/version"
"veyron.io/veyron/veyron/runtimes/google/lib/bqueue"
"veyron.io/veyron/veyron/runtimes/google/lib/bqueue/drrqueue"
@@ -135,6 +136,7 @@
if err != nil {
return nil, fmt.Errorf("net.Listen(%q, %q) failed: %v", network, address, err)
}
+ ln = wslistener.NewListener(ln)
if len(pubAddress) == 0 {
pubAddress = ln.Addr().String()
}
diff --git a/runtimes/google/ipc/stream/vif/vif.go b/runtimes/google/ipc/stream/vif/vif.go
index b69cdb4..355ea0b 100644
--- a/runtimes/google/ipc/stream/vif/vif.go
+++ b/runtimes/google/ipc/stream/vif/vif.go
@@ -472,7 +472,7 @@
for {
f, err := hr.Listener.Accept()
if err != nil {
- vlog.VI(2).Infof("Accept failed on VC %v on VIF %v", vc, vif)
+ vlog.VI(2).Infof("Accept failed on VC %v on VIF %v: %v", vc, vif, err)
return
}
if err := acceptor.Put(ConnectorAndFlow{vc, f}); err != nil {
diff --git a/runtimes/google/ipc/stream/wslistener/listener.go b/runtimes/google/ipc/stream/wslistener/listener.go
new file mode 100644
index 0000000..79dc035
--- /dev/null
+++ b/runtimes/google/ipc/stream/wslistener/listener.go
@@ -0,0 +1,214 @@
+// +build !nacl
+
+package wslistener
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "sync"
+
+ vwebsocket "veyron.io/veyron/veyron/lib/websocket"
+ "veyron.io/veyron/veyron/runtimes/google/lib/upcqueue"
+
+ "veyron.io/veyron/veyron2/vlog"
+
+ "github.com/gorilla/websocket"
+)
+
+// errListenerIsClosed is returned by Accept once the listener's queue has
+// been shut down.
+var errListenerIsClosed = errors.New("Listener has been Closed")
+
+const (
+	// BinaryMagicByte is the first byte of a raw (non-websocket) veyron
+	// connection: netAcceptLoop strips it and queues such connections for
+	// Accept directly. We picked 0xFF because it's obviously outside the
+	// range of ASCII, and is completely unused in UTF-8.
+	BinaryMagicByte byte = 0xFF
+
+	// bufferSize is the size of the read buffer used to peek at new
+	// connections, and of the websocket read and write buffers.
+	bufferSize int = 4096
+)
+
+// wsTCPListener is a net.Listener that can handle either raw tcp requests or
+// websocket requests arriving on the same underlying listener. The result of
+// Accept is a net.Conn interface.
+type wsTCPListener struct {
+	// The queue of net.Conn to be returned by Accept.
+	q *upcqueue.T
+
+	// The queue for the http listener when we detect an http request.
+	httpQ *upcqueue.T
+
+	// The underlying listener.
+	netLn    net.Listener
+	wsServer http.Server
+
+	// netLoop tracks netAcceptLoop; wsLoop tracks in-flight HTTP handler
+	// invocations (exactly one Add/Done pair per request).
+	netLoop sync.WaitGroup
+	wsLoop  sync.WaitGroup
+
+	// mu guards closed. Once closed is set no further wsLoop.Add calls are
+	// made, which makes it safe for Close to Wait on wsLoop.
+	mu     sync.Mutex
+	closed bool
+}
+
+// bufferedConn is used to allow us to Peek at the first byte to see if it
+// is the magic byte used by veyron tcp requests. Other than that it behaves
+// like a normal net.Conn.
+type bufferedConn struct {
+	net.Conn
+	// TODO(bjornick): Remove this buffering because we have way too much
+	// buffering anyway. We really only need to buffer the first byte.
+	r *bufio.Reader
+}
+
+func newBufferedConn(c net.Conn) bufferedConn {
+	return bufferedConn{Conn: c, r: bufio.NewReaderSize(c, bufferSize)}
+}
+
+// Peek returns the next n bytes without consuming them.
+func (c *bufferedConn) Peek(n int) ([]byte, error) {
+	return c.r.Peek(n)
+}
+
+// Read reads through the buffer so that peeked bytes are not lost.
+func (c *bufferedConn) Read(p []byte) (int, error) {
+	return c.r.Read(p)
+}
+
+// queueListener is a listener that returns connections that are in q. It is
+// the listener served by the internal HTTP server that performs websocket
+// upgrades.
+type queueListener struct {
+	q *upcqueue.T
+	// ln is needed to implement Close and Addr
+	ln net.Listener
+}
+
+// acceptFromQueue blocks until a net.Conn is available in q and returns it,
+// mapping a shut-down queue to errListenerIsClosed.
+func acceptFromQueue(q *upcqueue.T) (net.Conn, error) {
+	item, err := q.Get(nil)
+	switch {
+	case err == upcqueue.ErrQueueIsClosed:
+		return nil, errListenerIsClosed
+	case err != nil:
+		return nil, fmt.Errorf("Accept failed: %v", err)
+	default:
+		return item.(net.Conn), nil
+	}
+}
+
+func (l *queueListener) Accept() (net.Conn, error) {
+	return acceptFromQueue(l.q)
+}
+
+// Close shuts down q and closes the wrapped listener.
+func (l *queueListener) Close() error {
+	l.q.Shutdown()
+	return l.ln.Close()
+}
+
+func (l *queueListener) Addr() net.Addr {
+	return l.ln.Addr()
+}
+
+// NewListener wraps netLn in a listener that accepts both raw veyron tcp
+// connections (starting with BinaryMagicByte) and websocket connections
+// (anything else, upgraded by an internal HTTP server). The goroutines it
+// starts stop once the returned listener is closed or netLn.Accept fails.
+func NewListener(netLn net.Listener) net.Listener {
+	ln := &wsTCPListener{
+		q:     upcqueue.New(),
+		httpQ: upcqueue.New(),
+		netLn: netLn,
+	}
+	ln.netLoop.Add(1)
+	go ln.netAcceptLoop()
+	httpListener := &queueListener{
+		q:  ln.httpQ,
+		ln: ln,
+	}
+	handler := func(w http.ResponseWriter, r *http.Request) {
+		// Count each request, not each connection: an HTTP connection handed
+		// to wsServer may carry several requests (keep-alive) or none at all,
+		// so pairing one Add per connection with one Done per request could
+		// drive the counter negative (a panic) or leave Close waiting forever.
+		if !ln.beginRequest() {
+			http.Error(w, errListenerIsClosed.Error(), http.StatusServiceUnavailable)
+			return
+		}
+		defer ln.wsLoop.Done()
+		if r.Method != "GET" {
+			http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed)
+			return
+		}
+		ws, err := websocket.Upgrade(w, r, nil, bufferSize, bufferSize)
+		if err != nil {
+			if _, ok := err.(websocket.HandshakeError); ok {
+				http.Error(w, "Not a websocket handshake", http.StatusBadRequest)
+				vlog.Errorf("Rejected a non-websocket request: %v", err)
+				return
+			}
+			http.Error(w, "Internal Error", http.StatusInternalServerError)
+			vlog.Errorf("Failed to upgrade a websocket request: %v", err)
+			return
+		}
+		conn := vwebsocket.WebsocketConn(ws)
+		if err := ln.q.Put(conn); err != nil {
+			vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as Put failed: %v", ws.RemoteAddr(), ws.LocalAddr(), err)
+			ws.Close()
+		}
+	}
+	ln.wsServer = http.Server{
+		Handler: http.HandlerFunc(handler),
+	}
+	go ln.wsServer.Serve(httpListener)
+	return ln
+}
+
+// beginRequest registers an in-flight HTTP request with wsLoop. It returns
+// false, without registering, once Close has started; this guarantees that
+// wsLoop.Add never runs concurrently with the wsLoop.Wait in Close.
+func (ln *wsTCPListener) beginRequest() bool {
+	ln.mu.Lock()
+	defer ln.mu.Unlock()
+	if ln.closed {
+		return false
+	}
+	ln.wsLoop.Add(1)
+	return true
+}
+
+// netAcceptLoop accepts connections from netLn and routes each one either to
+// q (raw veyron tcp, after consuming BinaryMagicByte) or to httpQ (everything
+// else, for the websocket upgrade). It closes the listener when
+// netLn.Accept fails.
+func (ln *wsTCPListener) netAcceptLoop() {
+	// Deferred calls run in reverse order: netLoop.Done must run before
+	// Close, which waits on netLoop.
+	defer ln.Close()
+	defer ln.netLoop.Done()
+	for {
+		conn, err := ln.netLn.Accept()
+		if err != nil {
+			vlog.VI(1).Infof("Exiting netAcceptLoop: net.Listener.Accept() failed on %v with %v", ln.netLn, err)
+			return
+		}
+		vlog.VI(1).Infof("New net.Conn accepted from %s (local address: %s)", conn.RemoteAddr(), conn.LocalAddr())
+		// NOTE(review): Peek blocks until the client sends a byte, so a client
+		// that connects and stays silent stalls this loop and delays every
+		// later Accept. Consider a read deadline or per-connection goroutines.
+		bc := newBufferedConn(conn)
+		magic, err := bc.Peek(1)
+		if err != nil {
+			vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as the magic byte failed to be read: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+			bc.Close()
+			continue
+		}
+
+		vlog.VI(1).Infof("Got a connection from %s (local address: %s)", conn.RemoteAddr(), conn.LocalAddr())
+		// Check to see if it is a regular connection or a http connection.
+		if magic[0] == BinaryMagicByte {
+			// Consume the magic byte so readers only see the veyron stream.
+			if _, err := bc.r.ReadByte(); err != nil {
+				vlog.VI(1).Infof("Shutting down conn from %s (local address: %s), could not read past the magic byte: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+				bc.Close()
+				continue
+			}
+			if err := ln.q.Put(&bc); err != nil {
+				vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as Put failed in netAcceptLoop: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+				bc.Close()
+			}
+			continue
+		}
+
+		if err := ln.httpQ.Put(&bc); err != nil {
+			vlog.VI(1).Infof("Shutting down conn from %s (local address: %s) as Put failed in netAcceptLoop: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
+			bc.Close()
+		}
+	}
+}
+
+// Accept returns the next connection, raw tcp or websocket, from q.
+func (ln *wsTCPListener) Accept() (net.Conn, error) {
+	return acceptFromQueue(ln.q)
+}
+
+// Close stops accepting connections. It closes the underlying listener,
+// waits for netAcceptLoop and all in-flight HTTP handlers to finish, and then
+// shuts down the Accept queue. It always returns nil.
+func (ln *wsTCPListener) Close() error {
+	// Refuse new HTTP handlers before waiting on wsLoop (see beginRequest).
+	ln.mu.Lock()
+	ln.closed = true
+	ln.mu.Unlock()
+
+	addr := ln.netLn.Addr()
+	err := ln.netLn.Close()
+	vlog.VI(1).Infof("Closed net.Listener on (%q, %q): %v", addr.Network(), addr, err)
+	ln.httpQ.Shutdown()
+	ln.netLoop.Wait()
+	ln.wsLoop.Wait()
+	// q has to be shut down only after netAcceptLoop and every HTTP handler
+	// have finished, because both may still be putting newly accepted
+	// connections into it.
+	ln.q.Shutdown()
+	vlog.VI(3).Infof("Close stream.wsTCPListener %s", addr)
+	return nil
+}
+
+// Addr returns the address of the underlying listener.
+func (ln *wsTCPListener) Addr() net.Addr {
+	return ln.netLn.Addr()
+}
diff --git a/runtimes/google/ipc/stream/wslistener/listener_nacl.go b/runtimes/google/ipc/stream/wslistener/listener_nacl.go
new file mode 100644
index 0000000..686dfff
--- /dev/null
+++ b/runtimes/google/ipc/stream/wslistener/listener_nacl.go
@@ -0,0 +1,16 @@
+// +build nacl
+
+package wslistener
+
+import (
+ "net"
+)
+
+// Websocket listeners are not supported in NaCl.
+// This file is needed for compilation only.
+
+// BinaryMagicByte marks a raw (non-websocket) connection. It is kept equal to
+// the non-NaCl definition in listener.go (0xFF) so that the constant has the
+// same value in every build.
+const BinaryMagicByte byte = 0xFF
+
+// NewListener is not supported under NaCl; it always panics.
+func NewListener(netLn net.Listener) net.Listener {
+	panic("Websocket NewListener called in nacl code!")
+}
diff --git a/runtimes/google/naming/namespace/all_test.go b/runtimes/google/naming/namespace/all_test.go
index 163bcc9..e40b1f3 100644
--- a/runtimes/google/naming/namespace/all_test.go
+++ b/runtimes/google/naming/namespace/all_test.go
@@ -3,13 +3,14 @@
import (
"runtime"
"runtime/debug"
+ "strings"
"sync"
"testing"
"time"
"veyron.io/veyron/veyron2"
- "veyron.io/veyron/veyron2/context"
"veyron.io/veyron/veyron2/ipc"
+ "veyron.io/veyron/veyron2/ipc/stream"
"veyron.io/veyron/veyron2/naming"
"veyron.io/veyron/veyron2/options"
"veyron.io/veyron/veyron2/rt"
@@ -20,19 +21,26 @@
"veyron.io/veyron/veyron/lib/glob"
"veyron.io/veyron/veyron/lib/testutil"
+ "veyron.io/veyron/veyron/lib/websocket"
_ "veyron.io/veyron/veyron/profiles"
- "veyron.io/veyron/veyron/runtimes/google/ipc/stream/sectest"
"veyron.io/veyron/veyron/runtimes/google/naming/namespace"
service "veyron.io/veyron/veyron/services/mounttable/lib"
)
-func init() { testutil.Init() }
+// init performs the usual testutil setup and also registers the "ws"
+// protocol with websocket.Dial: servers in these tests publish "@ws@"
+// endpoints alongside their "@tcp@" ones (see addWSName).
+func init() {
+ testutil.Init()
+ stream.RegisterProtocol("ws", websocket.Dial, nil)
+}
func boom(t *testing.T, f string, v ...interface{}) {
t.Logf(f, v...)
t.Fatal(string(debug.Stack()))
}
+// addWSName returns name followed by a copy in which the first "@tcp@" is
+// replaced by "@ws@", i.e. the websocket form of the same endpoint name.
+func addWSName(name string) []string {
+	wsName := strings.Replace(name, "@tcp@", "@ws@", 1)
+	return []string{name, wsName}
+}
+
// N squared but who cares, this is a little test.
// Ignores dups.
func contains(container, contained []string) bool {
@@ -140,28 +148,20 @@
}
}
-func doResolveTest(t *testing.T, fname string, f func(context.T, string, ...naming.ResolveOpt) ([]string, error), ctx context.T, name string, want []string, opts ...naming.ResolveOpt) {
- servers, err := f(ctx, name, opts...)
- if err != nil {
- boom(t, "Failed to %s %s: %s", fname, name, err)
- }
- compare(t, fname, name, servers, want)
-}
-
+// testResolveToMountTable calls ns.ResolveToMountTable on name, aborts the
+// test via boom if that fails, and otherwise checks the returned servers
+// against want with compare.
func testResolveToMountTable(t *testing.T, r veyron2.Runtime, ns naming.Namespace, name string, want ...string) {
- doResolveTest(t, "ResolveToMountTable", ns.ResolveToMountTable, r.NewContext(), name, want)
-}
-
-func testResolveToMountTableWithPattern(t *testing.T, r veyron2.Runtime, ns naming.Namespace, name string, pattern naming.ResolveOpt, want ...string) {
- doResolveTest(t, "ResolveToMountTable", ns.ResolveToMountTable, r.NewContext(), name, want, pattern)
+ servers, err := ns.ResolveToMountTable(r.NewContext(), name)
+ if err != nil {
+ boom(t, "Failed to ResolveToMountTable %q: %s", name, err)
+ }
+ compare(t, "ResolveToMountTable", name, servers, want)
}
+// testResolve calls ns.Resolve on name, aborts the test via boom if that
+// fails, and otherwise checks the returned servers against want with compare.
func testResolve(t *testing.T, r veyron2.Runtime, ns naming.Namespace, name string, want ...string) {
- doResolveTest(t, "Resolve", ns.Resolve, r.NewContext(), name, want)
-}
-
-func testResolveWithPattern(t *testing.T, r veyron2.Runtime, ns naming.Namespace, name string, pattern naming.ResolveOpt, want ...string) {
- doResolveTest(t, "Resolve", ns.Resolve, r.NewContext(), name, want, pattern)
+ servers, err := ns.Resolve(r.NewContext(), name)
+ if err != nil {
+ boom(t, "Failed to Resolve %q: %s", name, err)
+ }
+ compare(t, "Resolve", name, servers, want)
}
func testUnresolve(t *testing.T, r veyron2.Runtime, ns naming.Namespace, name string, want ...string) {
@@ -302,16 +302,16 @@
testResolveToMountTable(t, r, ns, m, rootMT)
// The server registered for each mount point is a mount table
- testResolve(t, r, ns, m, mts[m].name)
+ testResolve(t, r, ns, m, addWSName(mts[m].name)...)
// ResolveToMountTable will walk through to the sub MountTables
mtbar := naming.Join(m, "bar")
subMT := naming.Join(mts[m].name, "bar")
- testResolveToMountTable(t, r, ns, mtbar, subMT)
+ testResolveToMountTable(t, r, ns, mtbar, addWSName(subMT)...)
}
for _, j := range []string{j1MP, j2MP, j3MP} {
- testResolve(t, r, ns, j, jokes[j].name)
+ testResolve(t, r, ns, j, addWSName(jokes[j].name)...)
}
}
@@ -344,7 +344,7 @@
mt2mt := naming.Join(mts[mt2MP].name, "a")
// The mt2/a is served by the mt2 mount table
- testResolveToMountTable(t, r, ns, mt2a, mt2mt)
+ testResolveToMountTable(t, r, ns, mt2a, addWSName(mt2mt)...)
// The server for mt2a is mt3server from the second mount above.
testResolve(t, r, ns, mt2a, mt3Server)
@@ -359,12 +359,14 @@
}
}
+ names := []string{naming.JoinAddressName(mts[mt4MP].name, "a"),
+ naming.JoinAddressName(mts[mt5MP].name, "a")}
+ names = append(names, addWSName(naming.JoinAddressName(mts[mt2MP].name, "a"))...)
// We now have 3 mount tables prepared to serve mt2/a
- testResolveToMountTable(t, r, ns, "mt2/a",
- naming.JoinAddressName(mts[mt2MP].name, "a"),
- naming.JoinAddressName(mts[mt4MP].name, "a"),
- naming.JoinAddressName(mts[mt5MP].name, "a"))
- testResolve(t, r, ns, "mt2", mts[mt2MP].name, mts[mt4MP].name, mts[mt5MP].name)
+ testResolveToMountTable(t, r, ns, "mt2/a", names...)
+ names = []string{mts[mt4MP].name, mts[mt5MP].name}
+ names = append(names, addWSName(mts[mt2MP].name)...)
+ testResolve(t, r, ns, "mt2", names...)
}
// TestNestedMounts tests some more deeply nested mounts
@@ -380,15 +382,15 @@
// Set up some nested mounts and verify resolution.
for _, m := range []string{"mt4/foo", "mt4/foo/bar"} {
- testResolve(t, r, ns, m, mts[m].name)
+ testResolve(t, r, ns, m, addWSName(mts[m].name)...)
}
testResolveToMountTable(t, r, ns, "mt4/foo",
- naming.JoinAddressName(mts[mt4MP].name, "foo"))
+ addWSName(naming.JoinAddressName(mts[mt4MP].name, "foo"))...)
testResolveToMountTable(t, r, ns, "mt4/foo/bar",
- naming.JoinAddressName(mts["mt4/foo"].name, "bar"))
+ addWSName(naming.JoinAddressName(mts["mt4/foo"].name, "bar"))...)
testResolveToMountTable(t, r, ns, "mt4/foo/baz",
- naming.JoinAddressName(mts["mt4/foo"].name, "baz"))
+ addWSName(naming.JoinAddressName(mts["mt4/foo"].name, "baz"))...)
}
// TestServers tests invoking RPCs on simple servers
@@ -402,16 +404,16 @@
// Let's run some non-mount table services
for _, j := range []string{j1MP, j2MP, j3MP} {
- testResolve(t, r, ns, j, jokes[j].name)
+ testResolve(t, r, ns, j, addWSName(jokes[j].name)...)
knockKnock(t, r, j)
globalName := naming.JoinAddressName(mts["mt4"].name, j)
disp := &dispatcher{}
gj := "g_" + j
jokes[gj] = runServer(t, r, disp, globalName)
- testResolve(t, r, ns, "mt4/"+j, jokes[gj].name)
+ testResolve(t, r, ns, "mt4/"+j, addWSName(jokes[gj].name)...)
knockKnock(t, r, "mt4/"+j)
- testResolveToMountTable(t, r, ns, "mt4/"+j, globalName)
- testResolveToMountTable(t, r, ns, "mt4/"+j+"/garbage", globalName+"/garbage")
+ testResolveToMountTable(t, r, ns, "mt4/"+j, addWSName(globalName)...)
+ testResolveToMountTable(t, r, ns, "mt4/"+j+"/garbage", addWSName(globalName+"/garbage")...)
}
}
@@ -557,7 +559,8 @@
boom(t, "Failed to Mount %s: %s", m, err)
}
- testResolve(t, r, ns, "c1", c1.name)
+ // Since c1 was mounted with the Serve call, it will have both the tcp and ws endpoints.
+ testResolve(t, r, ns, "c1", addWSName(c1.name)...)
testResolve(t, r, ns, "c1/c2", c1.name)
testResolve(t, r, ns, "c1/c3", c3.name)
testResolve(t, r, ns, "c1/c3/c4", c1.name)
@@ -620,56 +623,3 @@
t.Errorf("namespace.New should have failed with an unrooted name")
}
}
-
-func bless(blesser, delegate security.Principal, extension string) {
- b, err := blesser.Bless(delegate.PublicKey(), blesser.BlessingStore().Default(), extension, security.UnconstrainedUse())
- if err != nil {
- panic(err)
- }
- delegate.BlessingStore().SetDefault(b)
-}
-
-func TestRootBlessing(t *testing.T) {
- // We need the default runtime for the server-side mounttable code
- // which references rt.R() to create new endpoints
- cr := rt.Init()
- r, _ := rt.New() // We use a different runtime for the client side.
-
- proot := sectest.NewPrincipal("root")
- bless(proot, r.Principal(), "server")
- bless(proot, cr.Principal(), "client")
-
- cr.Principal().AddToRoots(proot.BlessingStore().Default())
- r.Principal().AddToRoots(proot.BlessingStore().Default())
-
- root, mts, _, stopper := createNamespace(t, r)
- defer stopper()
- ns := r.Namespace()
-
- name := naming.Join(root.name, mt2MP)
- // First check with a non-matching blessing pattern.
- _, err := ns.Resolve(r.NewContext(), name, naming.RootBlessingPatternOpt("root/foobar"))
- if !verror.Is(err, verror.NoAccess.ID) {
- t.Errorf("Resolve expected NoAccess error, got %v", err)
- }
- _, err = ns.ResolveToMountTable(r.NewContext(), name, naming.RootBlessingPatternOpt("root/foobar"))
- if !verror.Is(err, verror.NoAccess.ID) {
- t.Errorf("ResolveToMountTable expected NoAccess error, got %v", err)
- }
-
- // Now check a matching pattern.
- testResolveWithPattern(t, r, ns, name, naming.RootBlessingPatternOpt("root/server"), mts[mt2MP].name)
- testResolveToMountTableWithPattern(t, r, ns, name, naming.RootBlessingPatternOpt("root/server"), name)
-
- // After successful lookup it should be cached, so the pattern doesn't matter.
- testResolveWithPattern(t, r, ns, name, naming.RootBlessingPatternOpt("root/foobar"), mts[mt2MP].name)
-
- // Test calling a method.
- jokeName := naming.Join(root.name, mt4MP, j1MP)
- runServer(t, r, &dispatcher{}, naming.Join(mts["mt4"].name, j1MP))
- _, err = r.Client().StartCall(r.NewContext(), "[root/foobar]"+jokeName, "KnockKnock", nil)
- if err == nil {
- t.Errorf("StartCall expected NoAccess error, got %v", err)
- }
- knockKnock(t, r, "[root/server]"+jokeName)
-}
diff --git a/runtimes/google/naming/namespace/resolve.go b/runtimes/google/naming/namespace/resolve.go
index 8ac7773..eb69a3c 100644
--- a/runtimes/google/naming/namespace/resolve.go
+++ b/runtimes/google/naming/namespace/resolve.go
@@ -2,7 +2,6 @@
import (
"errors"
- "fmt"
"runtime"
"veyron.io/veyron/veyron2/context"
@@ -13,17 +12,11 @@
"veyron.io/veyron/veyron2/vlog"
)
-func (ns *namespace) resolveAgainstMountTable(ctx context.T, client ipc.Client, e *naming.MountEntry, pattern string) (*naming.MountEntry, error) {
+func (ns *namespace) resolveAgainstMountTable(ctx context.T, client ipc.Client, e *naming.MountEntry) (*naming.MountEntry, error) {
// Try each server till one answers.
finalErr := errors.New("no servers to resolve query")
for _, s := range e.Servers {
- var pattern_and_name string
name := naming.JoinAddressName(s.Server, e.Name)
- if pattern != "" {
- pattern_and_name = naming.JoinAddressName(s.Server, fmt.Sprintf("[%s]%s", pattern, e.Name))
- } else {
- pattern_and_name = name
- }
// First check the cache.
if ne, err := ns.resolutionCache.lookup(name); err == nil {
vlog.VI(2).Infof("resolveAMT %s from cache -> %v", name, convertServersToStrings(ne.Servers, ne.Name))
@@ -31,7 +24,7 @@
}
// Not in cache, call the real server.
callCtx, _ := ctx.WithTimeout(callTimeout)
- call, err := client.StartCall(callCtx, pattern_and_name, "ResolveStepX", nil, options.NoResolve(true))
+ call, err := client.StartCall(callCtx, name, "ResolveStepX", nil, options.NoResolve(true))
if err != nil {
finalErr = err
vlog.VI(2).Infof("ResolveStep.StartCall %s failed: %s", name, err)
@@ -68,7 +61,7 @@
}
// ResolveX implements veyron2/naming.Namespace.
-func (ns *namespace) ResolveX(ctx context.T, name string, opts ...naming.ResolveOpt) (*naming.MountEntry, error) {
+func (ns *namespace) ResolveX(ctx context.T, name string) (*naming.MountEntry, error) {
defer vlog.LogCall()()
e, _ := ns.rootMountEntry(name)
if vlog.V(2) {
@@ -79,7 +72,6 @@
if len(e.Servers) == 0 {
return nil, verror.Make(naming.ErrNoSuchName, ctx, name)
}
- pattern := getRootPattern(opts)
// Iterate walking through mount table servers.
for remaining := ns.maxResolveDepth; remaining > 0; remaining-- {
vlog.VI(2).Infof("ResolveX(%s) loop %v", name, *e)
@@ -89,7 +81,7 @@
}
var err error
curr := e
- if e, err = ns.resolveAgainstMountTable(ctx, ns.rt.Client(), curr, pattern); err != nil {
+ if e, err = ns.resolveAgainstMountTable(ctx, ns.rt.Client(), curr); err != nil {
// If the name could not be found in the mount table, return an error.
if verror.Is(err, naming.ErrNoSuchNameRoot.ID) {
err = verror.Make(naming.ErrNoSuchName, ctx, name)
@@ -98,26 +90,20 @@
vlog.VI(1).Infof("ResolveX(%s) -> (NoSuchName: %v)", name, curr)
return nil, err
}
- if verror.Is(err, verror.NoAccess.ID) {
- vlog.VI(1).Infof("ResolveX(%s) -> (NoAccess: %v)", name, curr)
- return nil, err
-
- }
// Any other failure (server not found, no ResolveStep
// method, etc.) are a sign that iterative resolution can
// stop.
vlog.VI(1).Infof("ResolveX(%s) -> %v", name, curr)
return curr, nil
}
- pattern = ""
}
return nil, verror.Make(naming.ErrResolutionDepthExceeded, ctx)
}
// Resolve implements veyron2/naming.Namespace.
-func (ns *namespace) Resolve(ctx context.T, name string, opts ...naming.ResolveOpt) ([]string, error) {
+func (ns *namespace) Resolve(ctx context.T, name string) ([]string, error) {
defer vlog.LogCall()()
- e, err := ns.ResolveX(ctx, name, opts...)
+ e, err := ns.ResolveX(ctx, name)
if err != nil {
return nil, err
}
@@ -125,7 +111,7 @@
}
// ResolveToMountTableX implements veyron2/naming.Namespace.
-func (ns *namespace) ResolveToMountTableX(ctx context.T, name string, opts ...naming.ResolveOpt) (*naming.MountEntry, error) {
+func (ns *namespace) ResolveToMountTableX(ctx context.T, name string) (*naming.MountEntry, error) {
defer vlog.LogCall()()
e, _ := ns.rootMountEntry(name)
if vlog.V(2) {
@@ -136,7 +122,6 @@
if len(e.Servers) == 0 {
return nil, verror.Make(naming.ErrNoMountTable, ctx)
}
- pattern := getRootPattern(opts)
last := e
for remaining := ns.maxResolveDepth; remaining > 0; remaining-- {
vlog.VI(2).Infof("ResolveToMountTableX(%s) loop %v", name, e)
@@ -147,7 +132,7 @@
vlog.VI(1).Infof("ResolveToMountTableX(%s) -> %v", name, last)
return last, nil
}
- if e, err = ns.resolveAgainstMountTable(ctx, ns.rt.Client(), e, pattern); err != nil {
+ if e, err = ns.resolveAgainstMountTable(ctx, ns.rt.Client(), e); err != nil {
if verror.Is(err, naming.ErrNoSuchNameRoot.ID) {
vlog.VI(1).Infof("ResolveToMountTableX(%s) -> %v (NoSuchRoot: %v)", name, last, curr)
return last, nil
@@ -172,15 +157,14 @@
return nil, err
}
last = curr
- pattern = ""
}
return nil, verror.Make(naming.ErrResolutionDepthExceeded, ctx)
}
// ResolveToMountTable implements veyron2/naming.Namespace.
-func (ns *namespace) ResolveToMountTable(ctx context.T, name string, opts ...naming.ResolveOpt) ([]string, error) {
+func (ns *namespace) ResolveToMountTable(ctx context.T, name string) ([]string, error) {
defer vlog.LogCall()()
- e, err := ns.ResolveToMountTableX(ctx, name, opts...)
+ e, err := ns.ResolveToMountTableX(ctx, name)
if err != nil {
return nil, err
}
@@ -270,12 +254,3 @@
}
return flushed
}
-
-func getRootPattern(opts []naming.ResolveOpt) string {
- for _, opt := range opts {
- if pattern, ok := opt.(naming.RootBlessingPatternOpt); ok {
- return string(pattern)
- }
- }
- return ""
-}
diff --git a/runtimes/google/testing/mocks/naming/namespace.go b/runtimes/google/testing/mocks/naming/namespace.go
index f6b0ba0..bff47f6 100644
--- a/runtimes/google/testing/mocks/naming/namespace.go
+++ b/runtimes/google/testing/mocks/naming/namespace.go
@@ -57,7 +57,7 @@
return nil
}
-func (ns *namespace) Resolve(ctx context.T, name string, opts ...naming.ResolveOpt) ([]string, error) {
+func (ns *namespace) Resolve(ctx context.T, name string) ([]string, error) {
defer vlog.LogCall()()
if address, _ := naming.SplitAddressName(name); len(address) > 0 {
return []string{name}, nil
@@ -77,7 +77,7 @@
return nil, verror.NoExistf("Resolve name %q not found in %v", name, ns.mounts)
}
-func (ns *namespace) ResolveX(ctx context.T, name string, opts ...naming.ResolveOpt) (*naming.MountEntry, error) {
+func (ns *namespace) ResolveX(ctx context.T, name string) (*naming.MountEntry, error) {
defer vlog.LogCall()()
e := new(naming.MountEntry)
if address, _ := naming.SplitAddressName(name); len(address) > 0 {
@@ -98,14 +98,14 @@
return nil, verror.NoExistf("Resolve name %q not found in %v", name, ns.mounts)
}
-func (ns *namespace) ResolveToMountTableX(ctx context.T, name string, opts ...naming.ResolveOpt) (*naming.MountEntry, error) {
+// ResolveToMountTableX is not implemented by this mock; it always panics.
+func (ns *namespace) ResolveToMountTableX(ctx context.T, name string) (*naming.MountEntry, error) {
defer vlog.LogCall()()
// TODO(mattr): Implement this method for tests that might need it.
panic("ResolveToMountTable not implemented")
return nil, nil
}
-func (ns *namespace) ResolveToMountTable(ctx context.T, name string, opts ...naming.ResolveOpt) ([]string, error) {
+func (ns *namespace) ResolveToMountTable(ctx context.T, name string) ([]string, error) {
defer vlog.LogCall()()
// TODO(mattr): Implement this method for tests that might need it.
panic("ResolveToMountTable not implemented")
diff --git a/security/blessingstore.go b/security/blessingstore.go
index 1a7540d..57ef450 100644
--- a/security/blessingstore.go
+++ b/security/blessingstore.go
@@ -175,18 +175,75 @@
}
}
+// oldState is the blessing store state as laid out by the old serialization
+// format, where blessings were stored directly as security.WireBlessings.
+// deserializeOld decodes into it and converts the result into bs.state.
+// TODO(ataly, ashankar): Get rid of this struct once we have switched all credentials
+// directories to the new serialization format.
+type oldState struct {
+ Store map[security.BlessingPattern]security.WireBlessings
+ Default security.WireBlessings
+}
+
+// tryOldFormat reports whether bs.state, as decoded with the new format,
+// looks like it actually came from old-format data and should be re-read
+// with deserializeOld. That is the case when the store is empty and the
+// default blessings are missing or zero, or when any stored blessings have
+// no certificate chains.
+// TODO(ataly, ashankar): Get rid of this method once we have switched all
+// credentials directories to the new serialization format.
+func (bs *blessingStore) tryOldFormat() bool {
+	if len(bs.state.Store) > 0 {
+		for _, b := range bs.state.Store {
+			if len(b.Value.CertificateChains) == 0 {
+				return true
+			}
+		}
+		return false
+	}
+	def := bs.state.Default
+	if def == nil {
+		return true
+	}
+	var zero security.WireBlessings
+	return reflect.DeepEqual(def.Value, zero)
+}
+
+// deserializeOld reads the persisted state using the old serialization
+// format (oldState) and replaces bs.state's Store and Default with the
+// converted blessings. It is a no-op if nothing has been persisted.
+// TODO(ataly, ashankar): Get rid of this method once we have switched all
+// credentials directories to the new serialization format.
+func (bs *blessingStore) deserializeOld() error {
+	data, signature, err := bs.serializer.Readers()
+	if err != nil {
+		return err
+	}
+	if data == nil && signature == nil {
+		return nil
+	}
+	var old oldState
+	if err := decodeFromStorage(&old, data, signature, bs.signer.PublicKey()); err != nil {
+		return err
+	}
+	// Build a fresh map instead of writing into bs.state.Store: the earlier
+	// new-format decode attempt may have left it nil (writing would panic) or
+	// filled it with misinterpreted entries. The old-format data is
+	// authoritative, just as it is for Default below.
+	store := make(map[security.BlessingPattern]*blessings, len(old.Store))
+	for p, wire := range old.Store {
+		store[p] = &blessings{Value: wire}
+	}
+	bs.state.Store = store
+	bs.state.Default = &blessings{Value: old.Default}
+	return nil
+}
+
+// deserialize loads the persisted state into bs.state. It first decodes with
+// the current serialization format and falls back to deserializeOld when
+// that decode fails or when tryOldFormat says the result looks like
+// old-format data. It is a no-op if nothing has been persisted.
+func (bs *blessingStore) deserialize() error {
+	data, signature, err := bs.serializer.Readers()
+	switch {
+	case err != nil:
+		return err
+	case data == nil && signature == nil:
+		return nil
+	}
+	decodeErr := decodeFromStorage(&bs.state, data, signature, bs.signer.PublicKey())
+	if decodeErr == nil && !bs.tryOldFormat() {
+		return nil
+	}
+	return bs.deserializeOld()
+}
+
// newPersistingBlessingStore returns a security.BlessingStore for a principal
// that is initialized with the persisted data. The returned security.BlessingStore
// also persists any updates to its state.
func newPersistingBlessingStore(serializer SerializerReaderWriter, signer serialization.Signer) (security.BlessingStore, error) {
verifyBlessings := func(wb *blessings, key security.PublicKey) error {
- if wb == nil {
- return nil
- }
if err := wb.Verify(); err != nil {
return err
}
- if b := wb.Blessings(); !reflect.DeepEqual(b.PublicKey(), key) {
+ if b := wb.Blessings(); b != nil && !reflect.DeepEqual(b.PublicKey(), key) {
return fmt.Errorf("read Blessings: %v that are not for provided PublicKey: %v", b, key)
}
return nil
@@ -200,21 +257,25 @@
serializer: serializer,
signer: signer,
}
- data, signature, err := bs.serializer.Readers()
- if err != nil {
+ if err := bs.deserialize(); err != nil {
return nil, err
}
- if data != nil && signature != nil {
- if err := decodeFromStorage(&bs.state, data, signature, bs.signer.PublicKey()); err != nil {
- return nil, err
- }
- }
for _, wb := range bs.state.Store {
if err := verifyBlessings(wb, bs.publicKey); err != nil {
return nil, err
}
}
- if err := verifyBlessings(bs.state.Default, bs.publicKey); err != nil {
+ if bs.state.Default != nil {
+ if err := verifyBlessings(bs.state.Default, bs.publicKey); err != nil {
+ return nil, err
+ }
+ }
+ // Save the blessingstore in the new serialization format. This will ensure
+ // that all credentials directories in the old format will switch to the new
+ // format.
+ // TODO(ataly, ashankar): Get rid of this once we have switched all
+ // credentials directories to the new serialization format.
+ if err := bs.save(); err != nil {
return nil, err
}
return bs, nil
diff --git a/security/principal.go b/security/principal.go
index 364ee96..4ebf6e6 100644
--- a/security/principal.go
+++ b/security/principal.go
@@ -43,17 +43,9 @@
if err := mkDir(dir); err != nil {
return nil, err
}
- roots, err := NewFileSerializer(path.Join(dir, blessingRootsDataFile), path.Join(dir, blessingRootsSigFile))
- if err != nil {
- return nil, err
- }
- store, err := NewFileSerializer(path.Join(dir, blessingStoreDataFile), path.Join(dir, blessingStoreSigFile))
- if err != nil {
- return nil, err
- }
return &PrincipalStateSerializer{
- BlessingRoots: roots,
- BlessingStore: store,
+ BlessingRoots: NewFileSerializer(path.Join(dir, blessingRootsDataFile), path.Join(dir, blessingRootsSigFile)),
+ BlessingStore: NewFileSerializer(path.Join(dir, blessingStoreDataFile), path.Join(dir, blessingStoreSigFile)),
}, nil
}
diff --git a/security/serializer_reader_writer.go b/security/serializer_reader_writer.go
index 633456d..6a7c03c 100644
--- a/security/serializer_reader_writer.go
+++ b/security/serializer_reader_writer.go
@@ -18,48 +18,44 @@
// FileSerializer implements SerializerReaderWriter that persists state to files.
type FileSerializer struct {
- data *os.File
- signature *os.File
-
+	// Paths of the data file and of its signature file. The files themselves
+	// are opened on demand by Readers and Writers.
dataFilePath string
signatureFilePath string
}
// NewFileSerializer creates a FileSerializer with the given data and signature files.
+// No files are opened here; Readers and Writers open (or create) them on
+// every call.
-func NewFileSerializer(dataFilePath, signatureFilePath string) (*FileSerializer, error) {
- data, err := os.Open(dataFilePath)
- if err != nil && !os.IsNotExist(err) {
- return nil, err
- }
- signature, err := os.Open(signatureFilePath)
- if err != nil && !os.IsNotExist(err) {
- return nil, err
- }
+func NewFileSerializer(dataFilePath, signatureFilePath string) *FileSerializer {
return &FileSerializer{
- data: data,
- signature: signature,
dataFilePath: dataFilePath,
signatureFilePath: signatureFilePath,
- }, nil
+ }
}
+// Readers opens the data and signature files for reading. If either file
+// does not exist it returns (nil, nil, nil), closing whichever file it did
+// open, so callers never receive a half-populated pair or a leaked handle.
func (fs *FileSerializer) Readers() (io.ReadCloser, io.ReadCloser, error) {
- if fs.data == nil || fs.signature == nil {
+	data, err := os.Open(fs.dataFilePath)
+	if err != nil && !os.IsNotExist(err) {
+		return nil, nil, err
+	}
+	signature, err := os.Open(fs.signatureFilePath)
+	if err != nil && !os.IsNotExist(err) {
+		if data != nil {
+			data.Close()
+		}
+		return nil, nil, err
+	}
+	if data == nil || signature == nil {
+		// At most one of the files exists: release it rather than leak it.
+		if data != nil {
+			data.Close()
+		}
+		if signature != nil {
+			signature.Close()
+		}
return nil, nil, nil
}
- return fs.data, fs.signature, nil
+	return data, signature, nil
}
+// Writers removes any existing data and signature files and creates fresh
+// ones, returning them for writing. On failure no file handle is left open.
func (fs *FileSerializer) Writers() (io.WriteCloser, io.WriteCloser, error) {
// Remove previous version of the files
os.Remove(fs.dataFilePath)
os.Remove(fs.signatureFilePath)
- var err error
- if fs.data, err = os.Create(fs.dataFilePath); err != nil {
+	data, err := os.Create(fs.dataFilePath)
+	if err != nil {
return nil, nil, err
}
- if fs.signature, err = os.Create(fs.signatureFilePath); err != nil {
+	signature, err := os.Create(fs.signatureFilePath)
+	if err != nil {
+		// Don't leak the data file handle when the signature file can't be created.
+		data.Close()
return nil, nil, err
}
- return fs.data, fs.signature, nil
+	return data, signature, nil
}
diff --git a/security/testdata/blessingstore.sig b/security/testdata/blessingstore.sig
index 0256a55..2cdfb39 100644
--- a/security/testdata/blessingstore.sig
+++ b/security/testdata/blessingstore.sig
Binary files differ
diff --git a/services/identity/auditor/blessing_auditor.go b/services/identity/auditor/blessing_auditor.go
index 28f7b62..b00d28d 100644
--- a/services/identity/auditor/blessing_auditor.go
+++ b/services/identity/auditor/blessing_auditor.go
@@ -2,8 +2,8 @@
import (
"bytes"
+ "database/sql"
"fmt"
- _ "github.com/go-sql-driver/mysql"
"strings"
"time"
@@ -32,8 +32,8 @@
// NewSQLBlessingAuditor returns an auditor for wrapping a principal with, and a BlessingLogReader
// for reading the audits made by that auditor. The config is used to construct the connection
// to the SQL database that the auditor and BlessingLogReader use.
-func NewSQLBlessingAuditor(config SQLConfig) (audit.Auditor, BlessingLogReader, error) {
- db, err := newSQLDatabase(config)
+func NewSQLBlessingAuditor(sqlDB *sql.DB) (audit.Auditor, BlessingLogReader, error) {
+ db, err := newSQLDatabase(sqlDB, "BlessingAudit")
if err != nil {
return nil, nil, fmt.Errorf("failed to create sql db: %v", err)
}
diff --git a/services/identity/auditor/blessing_auditor_test.go b/services/identity/auditor/blessing_auditor_test.go
index 946d976..7970ae3 100644
--- a/services/identity/auditor/blessing_auditor_test.go
+++ b/services/identity/auditor/blessing_auditor_test.go
@@ -76,12 +76,12 @@
if !reflect.DeepEqual(got.Blessings, test.Blessings) {
t.Errorf("got %v, want %v", got.Blessings, test.Blessings)
}
- var extraRoutines bool
+ var extra bool
for _ = range ch {
// Drain the channel to prevent the producer goroutines from being leaked.
- extraRoutines = true
+ extra = true
}
- if extraRoutines {
+ if extra {
t.Errorf("Got more entries that expected for test %+v", test)
}
}
diff --git a/services/identity/auditor/sql_database.go b/services/identity/auditor/sql_database.go
index 405f97f..aeb9af6 100644
--- a/services/identity/auditor/sql_database.go
+++ b/services/identity/auditor/sql_database.go
@@ -4,61 +4,52 @@
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
-
"time"
+
"veyron.io/veyron/veyron2/vlog"
)
-// SQLConfig contains the information to create a connection to a sql database.
-type SQLConfig struct {
- // Database is a driver specific string specifying how to connect to the database.
- Database string `json:"database"`
- Table string `json:"table"`
-}
-
type database interface {
Insert(entry databaseEntry) error
Query(email string) <-chan databaseEntry
}
type databaseEntry struct {
- email, revocationCaveatID string
- caveats, blessings []byte
- timestamp time.Time
- decodeErr error
+ email string
+ caveats, blessings []byte
+ timestamp time.Time
+ decodeErr error
}
-// newSQLDatabase returns a SQL implementation of the database interface.
-// If the table does not exist it creates it.
-func newSQLDatabase(config SQLConfig) (database, error) {
- db, err := sql.Open("mysql", config.Database)
- if err != nil {
- return nil, fmt.Errorf("failed to create database with config(%v): %v", config, err)
- }
- if err := db.Ping(); err != nil {
- return nil, err
- }
- createStmt, err := db.Prepare(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s ( Email NVARCHAR(256), Caveats BLOB, Timestamp DATETIME, RevocationCaveatID NVARCHAR(1000), Blessings BLOB );", config.Table))
+// newSQLDatabase returns a SQL implementation of the database interface backed by table, creating
+// the table if it does not exist. table is formatted into the SQL text unescaped, so it must be trusted.
+func newSQLDatabase(db *sql.DB, table string) (database, error) {
+ createStmt, err := db.Prepare(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s ( Email NVARCHAR(256), Caveats BLOB, Timestamp DATETIME, Blessings BLOB );", table))
if err != nil {
return nil, err
}
if _, err = createStmt.Exec(); err != nil {
return nil, err
}
- insertStmt, err := db.Prepare(fmt.Sprintf("INSERT INTO %s (Email, Caveats, RevocationCaveatID, Timestamp, Blessings) VALUES (?, ?, ?, ?, ?)", config.Table))
+ insertStmt, err := db.Prepare(fmt.Sprintf("INSERT INTO %s (Email, Caveats, Timestamp, Blessings) VALUES (?, ?, ?, ?)", table))
if err != nil {
return nil, err
}
- queryStmt, err := db.Prepare(fmt.Sprintf("SELECT Email, Caveats, RevocationCaveatID, Timestamp, Blessings from %s WHERE Email=?", config.Table))
+ queryStmt, err := db.Prepare(fmt.Sprintf("SELECT Email, Caveats, Timestamp, Blessings FROM %s WHERE Email=?", table))
return sqlDatabase{insertStmt, queryStmt}, err
}
+// sqlDatabase stores audit entries in a table with 4 columns (in schema order):
+// (1) Email = string email of the Blessee.
+// (2) Caveats = vom encoded caveats.
+// (3) Timestamp = time that the blessing happened.
+// (4) Blessings = vom encoded resulting blessings.
type sqlDatabase struct {
insertStmt, queryStmt *sql.Stmt
}
func (s sqlDatabase) Insert(entry databaseEntry) error {
- _, err := s.insertStmt.Exec(entry.email, entry.caveats, entry.revocationCaveatID, entry.timestamp, entry.blessings)
+ _, err := s.insertStmt.Exec(entry.email, entry.caveats, entry.timestamp, entry.blessings)
return err
}
@@ -78,7 +69,7 @@
}
for rows.Next() {
var dbentry databaseEntry
- if err = rows.Scan(&dbentry.email, &dbentry.caveats, &dbentry.revocationCaveatID, &dbentry.timestamp, &dbentry.blessings); err != nil {
+ if err = rows.Scan(&dbentry.email, &dbentry.caveats, &dbentry.timestamp, &dbentry.blessings); err != nil {
vlog.Errorf("scan of row failed %v", err)
dbentry.decodeErr = fmt.Errorf("failed to read sql row, %s", err)
}
diff --git a/services/identity/auditor/sql_database_test.go b/services/identity/auditor/sql_database_test.go
new file mode 100644
index 0000000..bbd2fa6
--- /dev/null
+++ b/services/identity/auditor/sql_database_test.go
@@ -0,0 +1,54 @@
+package auditor
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+)
+
+func TestSQLDatabaseQuery(t *testing.T) {
+ db, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("failed to create new mock database stub: %v", err)
+ }
+ columns := []string{"Email", "Caveats", "Timestamp", "Blessings"}
+ sqlmock.ExpectExec("CREATE TABLE IF NOT EXISTS tableName (.+)").
+ WillReturnResult(sqlmock.NewResult(0, 1))
+ d, err := newSQLDatabase(db, "tableName")
+ if err != nil {
+ t.Fatalf("failed to create SQLDatabase: %v", err)
+ }
+
+ entry := databaseEntry{
+ email: "email",
+ caveats: []byte("caveats"),
+ timestamp: time.Now(),
+ blessings: []byte("blessings"),
+ }
+ sqlmock.ExpectExec("INSERT INTO tableName (.+) VALUES (.+)").
+ WithArgs(entry.email, entry.caveats, entry.timestamp, entry.blessings).
+ WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
+ if err := d.Insert(entry); err != nil {
+ t.Errorf("failed to insert into SQLDatabase: %v", err)
+ }
+
+ // Test the querying.
+ sqlmock.ExpectQuery("SELECT Email, Caveats, Timestamp, Blessings FROM tableName").
+ WithArgs(entry.email).
+ WillReturnRows(sqlmock.NewRows(columns).AddRow(entry.email, entry.caveats, entry.timestamp, entry.blessings))
+ ch := d.Query(entry.email)
+ if res := <-ch; !reflect.DeepEqual(res, entry) {
+ t.Errorf("got %#v, expected %#v", res, entry)
+ }
+
+ var extra bool
+ for _ = range ch {
+ // Drain the channel to prevent the producer goroutines from being leaked.
+ extra = true
+ }
+ if extra {
+ t.Errorf("Got more entries than expected")
+ }
+}
diff --git a/services/identity/blesser/oauth.go b/services/identity/blesser/oauth.go
index 7b59171..c50ed3a 100644
--- a/services/identity/blesser/oauth.go
+++ b/services/identity/blesser/oauth.go
@@ -122,11 +122,7 @@
var caveat security.Caveat
var err error
if b.revocationManager != nil {
- revocationCaveat, err := b.revocationManager.NewCaveat(self.PublicKey(), b.dischargerLocation)
- if err != nil {
- return noblessings, "", err
- }
- caveat, err = security.NewCaveat(revocationCaveat)
+ caveat, err = b.revocationManager.NewCaveat(self.PublicKey(), b.dischargerLocation) // NewCaveat already wraps the revocation caveat in a security.Caveat.
} else {
caveat, err = security.ExpiryCaveat(time.Now().Add(b.duration))
}
diff --git a/services/identity/googleoauth/handler.go b/services/identity/googleoauth/handler.go
index 109f3fd..2ac127b 100644
--- a/services/identity/googleoauth/handler.go
+++ b/services/identity/googleoauth/handler.go
@@ -430,11 +430,7 @@
if h.args.RevocationManager == nil {
return nil, fmt.Errorf("server not configured to support revocation")
}
- tpc, err := h.args.RevocationManager.NewCaveat(h.args.R.Principal().PublicKey(), h.args.DischargerLocation)
- if err != nil {
- return nil, fmt.Errorf("failed to create revocation caveat: %v", err)
- }
- revocation, err := security.NewCaveat(tpc)
+ revocation, err := h.args.RevocationManager.NewCaveat(h.args.R.Principal().PublicKey(), h.args.DischargerLocation) // NewCaveat already wraps the revocation caveat in a security.Caveat.
if err != nil {
return nil, fmt.Errorf("failed to create revocation caveat: %v", err)
}
diff --git a/services/identity/googleoauth/template.go b/services/identity/googleoauth/template.go
index 2bd72bd..958c70e 100644
--- a/services/identity/googleoauth/template.go
+++ b/services/identity/googleoauth/template.go
@@ -171,18 +171,17 @@
<label class="col-sm-2" for="required-caveat">Expiration</label>
<div class="col-sm-10" class="input-group" name="required-caveat">
<div class="radio">
- <div class="input-group">
- <input type="radio" name="requiredCaveat" id="requiredCaveat" value="Expiry" checked>
- <input type="text" name="expiry" id="expiry" value="1h" placeholder="time after which the blessing will expire">
- </div>
- </div>
- <div class="radio">
<label>
- <!-- TODO(suharshs): Re-enable -->
- <input type="radio" name="requiredCaveat" id="requiredCaveat" value="Revocation" disabled>
+ <input type="radio" name="requiredCaveat" id="requiredCaveat" value="Revocation" checked> <!-- Default choice; the server only supports revocation when identityd runs with -sqlconfig. -->
When explicitly revoked
</label>
</div>
+ <div class="radio">
+ <div class="input-group">
+ <input type="radio" name="requiredCaveat" id="requiredCaveat" value="Expiry">
+ <input type="text" name="expiry" id="expiry" value="1h" placeholder="time after which the blessing will expire">
+ </div>
+ </div>
</div>
</div>
<h4 class="form-signin-heading">Additional caveats</h4>
diff --git a/services/identity/identityd/main.go b/services/identity/identityd/main.go
index 3372b0b..0219929 100644
--- a/services/identity/identityd/main.go
+++ b/services/identity/identityd/main.go
@@ -3,7 +3,7 @@
import (
"crypto/rand"
- "encoding/json"
+ "database/sql"
"flag"
"fmt"
"html/template"
@@ -42,18 +42,14 @@
tlsconfig = flag.String("tlsconfig", "", "Comma-separated list of TLS certificate and private key files. This must be provided.")
host = flag.String("host", defaultHost(), "Hostname the HTTP server listens on. This can be the name of the host running the webserver, but if running behind a NAT or load balancer, this should be the host name that clients will connect to. For example, if set to 'x.com', Veyron identities will have the IssuerName set to 'x.com' and clients can expect to find the public key of the signer at 'x.com/pubkey/'.")
- // Flag controlling auditing of Blessing operations.
- auditConfig = flag.String("audit_config", "", "A JSON-encoded file with sql server configuration information for auditing. The file must have an entry for user, host, password, database, and table.")
+ // Flag controlling auditing and revocation of Blessing operations.
+ sqlConfig = flag.String("sqlconfig", "", "Path to file containing go-sql-driver connection string of the following form: [username[:password]@][protocol[(address)]]/dbname")
// Configuration for various Google OAuth-based clients.
googleConfigWeb = flag.String("google_config_web", "", "Path to JSON-encoded OAuth client configuration for the web application that renders the audit log for blessings provided by this provider.")
googleConfigChrome = flag.String("google_config_chrome", "", "Path to the JSON-encoded OAuth client configuration for Chrome browser applications that obtain blessings from this server (via the OAuthBlesser.BlessUsingAccessToken RPC) from this server.")
googleConfigAndroid = flag.String("google_config_android", "", "Path to the JSON-encoded OAuth client configuration for Android applications that obtain blessings from this server (via the OAuthBlesser.BlessUsingAccessToken RPC) from this server.")
googleDomain = flag.String("google_domain", "", "An optional domain name. When set, only email addresses from this domain are allowed to authenticate via Google OAuth")
-
- // Revocation/expiry configuration.
- // TODO(ashankar,ataly,suharshs): Re-enable by default once the move to the new security API is complete?
- revocationDir = flag.String("revocation_dir", "" /*filepath.Join(os.TempDir(), "revocation_dir")*/, "Path where the revocation manager will store caveat and revocation information.")
)
const (
@@ -64,14 +60,32 @@
func main() {
flag.Usage = usage
- p, blessingLogReader := providerPrincipal()
+ flag.Parse()
+
+ var sqlDB *sql.DB
+ var err error
+ if len(*sqlConfig) > 0 {
+ config, err := ioutil.ReadFile(*sqlConfig)
+ if err != nil {
+ vlog.Fatalf("failed to read sql config from %v: %v", *sqlConfig, err)
+ }
+ sqlDB, err = dbFromConfigDatabase(strings.TrimSpace(string(config)))
+ if err != nil {
+ vlog.Fatalf("failed to create sqlDB: %v", err)
+ }
+ }
+
+ p, blessingLogReader := providerPrincipal(sqlDB)
r := rt.Init(options.RuntimePrincipal{p})
defer r.Cleanup()
- // Calling with empty string returns a empty RevocationManager
- revocationManager, err := revocation.NewRevocationManager(*revocationDir)
- if err != nil {
- vlog.Fatalf("Failed to start RevocationManager: %v", err)
+ var revocationManager *revocation.RevocationManager
+ // Only set revocationManager if sqlConfig (and thus sqlDB) is set.
+ if sqlDB != nil {
+ revocationManager, err = revocation.NewRevocationManager(sqlDB)
+ if err != nil {
+ vlog.Fatalf("Failed to start RevocationManager: %v", err)
+ }
}
// Setup handlers
@@ -119,7 +133,7 @@
if len(*googleConfigChrome) > 0 || len(*googleConfigAndroid) > 0 {
args.GoogleServers = appendSuffixTo(published, googleService)
}
- if len(*auditConfig) > 0 && len(*googleConfigWeb) > 0 {
+ if sqlDB != nil && len(*googleConfigWeb) > 0 {
args.ListBlessingsRoute = googleoauth.ListBlessingsRoute
}
if err := tmpl.Execute(w, args); err != nil {
@@ -281,7 +295,7 @@
// providerPrincipal returns the Principal to use for the identity provider (i.e., this program) and
-// the database where audits will be store. If no database exists nil will be returned.
+// the BlessingLogReader for the audits it records. If sqlDB is nil, no auditing is done and the reader is nil.
-func providerPrincipal() (security.Principal, auditor.BlessingLogReader) {
+func providerPrincipal(sqlDB *sql.DB) (security.Principal, auditor.BlessingLogReader) {
// TODO(ashankar): Somewhat silly to have to create a runtime, but oh-well.
r, err := rt.New()
if err != nil {
@@ -289,33 +303,35 @@
}
defer r.Cleanup()
p := r.Principal()
- if len(*auditConfig) == 0 {
+ if sqlDB == nil {
return p, nil
}
- config, err := readSQLConfigFromFile(*auditConfig)
- if err != nil {
- vlog.Fatalf("Failed to read sql config: %v", err)
- }
- auditor, reader, err := auditor.NewSQLBlessingAuditor(config)
+ auditor, reader, err := auditor.NewSQLBlessingAuditor(sqlDB)
if err != nil {
vlog.Fatalf("Failed to create sql auditor from config: %v", err)
}
return audit.NewPrincipal(p, auditor), reader
}
-func readSQLConfigFromFile(file string) (auditor.SQLConfig, error) {
- var config auditor.SQLConfig
- content, err := ioutil.ReadFile(file)
+// dbFromConfigDatabase opens a MySQL database described by the go-sql-driver
+// connection string database and verifies that it is reachable with Ping.
+// parseTime=true is added to the connection parameters (joined with "&" when
+// database already contains a "?" query) so DATETIME columns come back as time.Time.
+func dbFromConfigDatabase(database string) (*sql.DB, error) {
+ sep := "?"
+ if strings.Contains(database, "?") {
+ sep = "&"
+ }
+ db, err := sql.Open("mysql", database+sep+"parseTime=true")
if err != nil {
- return config, err
+ // Do not echo database: it may contain the password and this error is logged by the caller.
+ return nil, fmt.Errorf("failed to create database: %v", err)
}
- if err := json.Unmarshal(content, &config); err != nil {
- return config, err
+ if err := db.Ping(); err != nil {
+ db.Close() // best effort; the Ping error is what the caller needs
+ return nil, err
}
- if len(strings.Split(config.Table, " ")) != 1 || strings.Contains(config.Table, ";") {
- return config, fmt.Errorf("sql config table value must be 1 word long")
- }
- return config, nil
+ return db, nil
}
func httpaddress() string {
diff --git a/services/identity/revocation/revocation_manager.go b/services/identity/revocation/revocation_manager.go
index 698e759..b4b2178 100644
--- a/services/identity/revocation/revocation_manager.go
+++ b/services/identity/revocation/revocation_manager.go
@@ -3,112 +3,90 @@
import (
"crypto/rand"
- "encoding/hex"
+ "database/sql"
"fmt"
- "path/filepath"
- "strconv"
"sync"
"time"
- "veyron.io/veyron/veyron/services/identity/util"
"veyron.io/veyron/veyron2/security"
"veyron.io/veyron/veyron2/vom"
)
// RevocationManager persists information for revocation caveats to provided discharges and allow for future revocations.
-type RevocationManager struct {
- caveatMap *util.DirectoryStore // Map of blessed identity's caveats. ThirdPartyCaveatID -> revocationCaveatID
+type RevocationManager struct{}
+
+// NewRevocationManager returns a RevocationManager that persists information about
+// revocationCaveats in a SQL database and allows for revocation and caveat creation.
+// This function can only be called once because of the use of global variables.
+func NewRevocationManager(sqlDB *sql.DB) (*RevocationManager, error) {
+ revocationLock.Lock()
+ defer revocationLock.Unlock()
+ if revocationDB != nil {
+ return nil, fmt.Errorf("NewRevocationManager can only be called once")
+ }
+ db, err := newSQLDatabase(sqlDB, "RevocationCaveatInfo")
+ if err != nil {
+ return nil, err
+ }
+ // Publish the database only after setup succeeded, so a failure never leaves a broken global behind.
+ revocationDB = db
+ return &RevocationManager{}, nil
}
-var revocationMap *util.DirectoryStore
+var revocationDB database
var revocationLock sync.RWMutex
-// NewCaveat returns a security.ThirdPartyCaveat for which discharges will be
+// NewCaveat returns a security.Caveat constructed with a ThirdPartyCaveat for which discharges will be
// issued iff Revoke has not been called for the returned caveat.
-func (r *RevocationManager) NewCaveat(discharger security.PublicKey, dischargerLocation string) (security.ThirdPartyCaveat, error) {
+func (r *RevocationManager) NewCaveat(discharger security.PublicKey, dischargerLocation string) (security.Caveat, error) {
+ var empty security.Caveat
var revocation [16]byte
if _, err := rand.Read(revocation[:]); err != nil {
- return nil, err
+ return empty, err
}
restriction, err := security.NewCaveat(revocationCaveat(revocation))
if err != nil {
- return nil, err
+ return empty, err
}
cav, err := security.NewPublicKeyCaveat(discharger, dischargerLocation, security.ThirdPartyRequirements{}, restriction)
if err != nil {
- return nil, err
+ return empty, err
}
- if err = r.caveatMap.Put(hex.EncodeToString([]byte(cav.ID())), hex.EncodeToString(revocation[:])); err != nil {
- return nil, err
+ if err = revocationDB.InsertCaveat(cav.ID(), revocation[:]); err != nil {
+ return empty, err
}
- return cav, nil
+ return security.NewCaveat(cav)
}
// Revoke disables discharges from being issued for the provided third-party caveat.
func (r *RevocationManager) Revoke(caveatID string) error {
- token, err := r.caveatMap.Get(hex.EncodeToString([]byte(caveatID)))
- if err != nil {
- return err
- }
- return revocationMap.Put(token, strconv.FormatInt(time.Now().Unix(), 10))
+ return revocationDB.Revoke(caveatID)
}
-// GetRevocationTimestamp returns the timestamp at which a caveat was revoked.
-// If the caveat wasn't revoked returns nil
+// GetRevocationTime returns the time at which the caveat with the given third-party
+// caveat ID was revoked, or nil if it was not revoked or the lookup failed.
func (r *RevocationManager) GetRevocationTime(caveatID string) *time.Time {
- token, err := r.caveatMap.Get(hex.EncodeToString([]byte(caveatID)))
+ timestamp, err := revocationDB.RevocationTime(caveatID)
if err != nil {
return nil
}
- timestamp, err := revocationMap.Get(token)
- if err != nil {
- return nil
- }
- unix_int, err := strconv.ParseInt(timestamp, 10, 64)
- if err != nil {
- return nil
- }
- revocationTime := time.Unix(unix_int, 0)
- return &revocationTime
+ return timestamp
}
type revocationCaveat [16]byte
func (cav revocationCaveat) Validate(security.Context) error {
revocationLock.RLock()
- if revocationMap == nil {
+ if revocationDB == nil {
revocationLock.RUnlock()
return fmt.Errorf("missing call to NewRevocationManager")
}
revocationLock.RUnlock()
- if revocationMap.Exists(hex.EncodeToString(cav[:])) {
+ revoked, err := revocationDB.IsRevoked(cav[:])
+ if revoked {
return fmt.Errorf("revoked")
}
- return nil
-}
-
-// NewRevocationManager returns a RevocationManager that persists information about
-// revocationCaveats and allows for revocation and caveat creation.
-// This function can only be called once because of the use of global variables.
-func NewRevocationManager(dir string) (*RevocationManager, error) {
- revocationLock.Lock()
- defer revocationLock.Unlock()
- if revocationMap != nil {
- return nil, fmt.Errorf("NewRevocationManager can only be called once")
- }
- // If empty string return nil revocationManager
- if len(dir) == 0 {
- return nil, nil
- }
- caveatMap, err := util.NewDirectoryStore(filepath.Join(dir, "caveat_dir"))
- if err != nil {
- return nil, err
- }
- revocationMap, err = util.NewDirectoryStore(filepath.Join(dir, "revocation_dir"))
- if err != nil {
- return nil, err
- }
- return &RevocationManager{caveatMap}, nil
+ return err
}
func init() {
diff --git a/services/identity/revocation/revocation_test.go b/services/identity/revocation/revocation_test.go
new file mode 100644
index 0000000..171d036
--- /dev/null
+++ b/services/identity/revocation/revocation_test.go
@@ -0,0 +1,104 @@
+package revocation
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "veyron.io/veyron/veyron2"
+ "veyron.io/veyron/veyron2/naming"
+ "veyron.io/veyron/veyron2/rt"
+ "veyron.io/veyron/veyron2/security"
+ "veyron.io/veyron/veyron2/vom"
+
+ "veyron.io/veyron/veyron/profiles"
+ services "veyron.io/veyron/veyron/services/security"
+ "veyron.io/veyron/veyron/services/security/discharger"
+)
+
+type mockDatabase struct {
+ tpCavIDToRevCavID map[string][]byte
+ revCavIDToTimestamp map[string]*time.Time
+}
+
+func (m *mockDatabase) InsertCaveat(thirdPartyCaveatID string, revocationCaveatID []byte) error {
+ m.tpCavIDToRevCavID[thirdPartyCaveatID] = revocationCaveatID
+ return nil
+}
+
+func (m *mockDatabase) Revoke(thirdPartyCaveatID string) error {
+ timestamp := time.Now()
+ m.revCavIDToTimestamp[string(m.tpCavIDToRevCavID[thirdPartyCaveatID])] = &timestamp
+ return nil
+}
+
+func (m *mockDatabase) IsRevoked(revocationCaveatID []byte) (bool, error) {
+ _, exists := m.revCavIDToTimestamp[string(revocationCaveatID)]
+ return exists, nil
+}
+
+func (m *mockDatabase) RevocationTime(thirdPartyCaveatID string) (*time.Time, error) {
+ return m.revCavIDToTimestamp[string(m.tpCavIDToRevCavID[thirdPartyCaveatID])], nil
+}
+
+func newRevocationManager(t *testing.T) *RevocationManager {
+ revocationDB = &mockDatabase{make(map[string][]byte), make(map[string]*time.Time)}
+ return &RevocationManager{}
+}
+
+func revokerSetup(t *testing.T) (dischargerKey security.PublicKey, dischargerEndpoint string, revoker *RevocationManager, closeFunc func(), runtime veyron2.Runtime) {
+ r := rt.Init()
+ revokerService := newRevocationManager(t)
+ dischargerServer, err := r.NewServer()
+ if err != nil {
+ t.Fatalf("rt.R().NewServer: %s", err)
+ }
+ dischargerEP, err := dischargerServer.Listen(profiles.LocalListenSpec)
+ if err != nil {
+ t.Fatalf("dischargerServer.Listen failed: %v", err)
+ }
+ dischargerServiceStub := services.DischargerServer(discharger.NewDischarger())
+ if err := dischargerServer.Serve("", dischargerServiceStub, nil); err != nil {
+ t.Fatalf("dischargerServer.Serve revoker: %s", err)
+ }
+ return r.Principal().PublicKey(),
+ naming.JoinAddressName(dischargerEP.String(), ""),
+ revokerService,
+ func() {
+ dischargerServer.Stop()
+ },
+ r
+}
+
+func TestDischargeRevokeDischargeRevokeDischarge(t *testing.T) {
+ dcKey, dc, revoker, closeFunc, r := revokerSetup(t)
+ defer closeFunc()
+
+ discharger := services.DischargerClient(dc)
+ caveat, err := revoker.NewCaveat(dcKey, dc)
+ if err != nil {
+ t.Fatalf("failed to create revocation caveat: %s", err)
+ }
+ var cav security.ThirdPartyCaveat
+ if err := vom.NewDecoder(bytes.NewBuffer(caveat.ValidatorVOM)).Decode(&cav); err != nil {
+ t.Fatalf("failed to create decode tp caveat: %s", err)
+ }
+
+ var impetus security.DischargeImpetus
+
+ if _, err = discharger.Discharge(r.NewContext(), cav, impetus); err != nil {
+ t.Fatalf("failed to get discharge: %s", err)
+ }
+ if err = revoker.Revoke(cav.ID()); err != nil {
+ t.Fatalf("failed to revoke: %s", err)
+ }
+ if discharge, err := discharger.Discharge(r.NewContext(), cav, impetus); err == nil || discharge != nil {
+ t.Fatalf("got a discharge for a revoked caveat: %s", err)
+ }
+ if err = revoker.Revoke(cav.ID()); err != nil {
+ t.Fatalf("failed to revoke again: %s", err)
+ }
+ if discharge, err := discharger.Discharge(r.NewContext(), cav, impetus); err == nil || discharge != nil {
+ t.Fatalf("got a discharge for a doubly revoked caveat: %s", err)
+ }
+}
diff --git a/services/identity/revocation/revoker_test.go b/services/identity/revocation/revoker_test.go
deleted file mode 100644
index 7f5338b..0000000
--- a/services/identity/revocation/revoker_test.go
+++ /dev/null
@@ -1,75 +0,0 @@
-package revocation
-
-import (
- "os"
- "path/filepath"
- "testing"
-
- "veyron.io/veyron/veyron2"
- "veyron.io/veyron/veyron2/naming"
- "veyron.io/veyron/veyron2/rt"
- "veyron.io/veyron/veyron2/security"
-
- "veyron.io/veyron/veyron/profiles"
- services "veyron.io/veyron/veyron/services/security"
- "veyron.io/veyron/veyron/services/security/discharger"
-)
-
-func revokerSetup(t *testing.T) (dischargerKey security.PublicKey, dischargerEndpoint string, revoker *RevocationManager, closeFunc func(), runtime veyron2.Runtime) {
- var dir = filepath.Join(os.TempDir(), "revoker_test_dir")
- r := rt.Init()
- revokerService, err := NewRevocationManager(dir)
- if err != nil {
- t.Fatalf("NewRevocationManager failed: %v", err)
- }
-
- dischargerServer, err := r.NewServer()
- if err != nil {
- t.Fatalf("rt.R().NewServer: %s", err)
- }
- dischargerEP, err := dischargerServer.Listen(profiles.LocalListenSpec)
- if err != nil {
- t.Fatalf("dischargerServer.Listen failed: %v", err)
- }
- dischargerServiceStub := services.DischargerServer(discharger.NewDischarger())
- if err := dischargerServer.Serve("", dischargerServiceStub, nil); err != nil {
- t.Fatalf("dischargerServer.Serve revoker: %s", err)
- }
- return r.Principal().PublicKey(),
- naming.JoinAddressName(dischargerEP.String(), ""),
- revokerService,
- func() {
- defer os.RemoveAll(dir)
- dischargerServer.Stop()
- },
- r
-}
-
-func TestDischargeRevokeDischargeRevokeDischarge(t *testing.T) {
- dcKey, dc, revoker, closeFunc, r := revokerSetup(t)
- defer closeFunc()
-
- discharger := services.DischargerClient(dc)
- cav, err := revoker.NewCaveat(dcKey, dc)
- if err != nil {
- t.Fatalf("failed to create public key caveat: %s", err)
- }
-
- var impetus security.DischargeImpetus
-
- if _, err = discharger.Discharge(r.NewContext(), cav, impetus); err != nil {
- t.Fatalf("failed to get discharge: %s", err)
- }
- if err = revoker.Revoke(cav.ID()); err != nil {
- t.Fatalf("failed to revoke: %s", err)
- }
- if discharge, err := discharger.Discharge(r.NewContext(), cav, impetus); err == nil || discharge != nil {
- t.Fatalf("got a discharge for a revoked caveat: %s", err)
- }
- if err = revoker.Revoke(cav.ID()); err != nil {
- t.Fatalf("failed to revoke again: %s", err)
- }
- if discharge, err := discharger.Discharge(r.NewContext(), cav, impetus); err == nil || discharge != nil {
- t.Fatalf("got a discharge for a doubly revoked caveat: %s", err)
- }
-}
diff --git a/services/identity/revocation/sql_database.go b/services/identity/revocation/sql_database.go
new file mode 100644
index 0000000..cfe5633
--- /dev/null
+++ b/services/identity/revocation/sql_database.go
@@ -0,0 +1,105 @@
+package revocation
+
+import (
+ "database/sql"
+ "encoding/hex"
+ "fmt"
+ "time"
+)
+
+// database persists revocation caveat information for RevocationManager.
+type database interface {
+ InsertCaveat(thirdPartyCaveatID string, revocationCaveatID []byte) error
+ Revoke(thirdPartyCaveatID string) error
+ IsRevoked(revocationCaveatID []byte) (bool, error)
+ RevocationTime(thirdPartyCaveatID string) (*time.Time, error)
+}
+
+// sqlDatabase implements database on top of a SQL table with 3 columns:
+// (1) ThirdPartyCaveatID = string thirdPartyCaveatID (primary key).
+// (2) RevocationCaveatID = hex encoded revocationCaveatID.
+// (3) RevocationTime = time (if any) that the Caveat was revoked; NULL until Revoke is called.
+type sqlDatabase struct {
+ insertCaveatStmt, revokeStmt, isRevokedStmt, revocationTimeStmt *sql.Stmt
+}
+
+// InsertCaveat records a new, not yet revoked caveat.
+func (s *sqlDatabase) InsertCaveat(thirdPartyCaveatID string, revocationCaveatID []byte) error {
+ _, err := s.insertCaveatStmt.Exec(thirdPartyCaveatID, hex.EncodeToString(revocationCaveatID))
+ return err
+}
+
+// Revoke sets the revocation time of thirdPartyCaveatID to now; the number of
+// affected rows is not checked.
+func (s *sqlDatabase) Revoke(thirdPartyCaveatID string) error {
+ _, err := s.revokeStmt.Exec(time.Now(), thirdPartyCaveatID)
+ return err
+}
+
+// IsRevoked reports whether the caveat with the given revocation caveat ID has been revoked.
+func (s *sqlDatabase) IsRevoked(revocationCaveatID []byte) (bool, error) {
+ rows, err := s.isRevokedStmt.Query(hex.EncodeToString(revocationCaveatID))
+ if err != nil {
+ return false, err
+ }
+ // Close the rows so the underlying connection is released back to the pool.
+ defer rows.Close()
+ if rows.Next() {
+ return true, nil
+ }
+ return false, rows.Err()
+}
+
+// RevocationTime returns the time at which thirdPartyCaveatID was revoked, or an
+// error if it has not been revoked (including unknown IDs).
+func (s *sqlDatabase) RevocationTime(thirdPartyCaveatID string) (*time.Time, error) {
+ rows, err := s.revocationTimeStmt.Query(thirdPartyCaveatID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ if rows.Next() {
+ var timestamp time.Time
+ if err := rows.Scan(&timestamp); err != nil {
+ return nil, err
+ }
+ return &timestamp, nil
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return nil, fmt.Errorf("the caveat (%v) was not revoked", thirdPartyCaveatID)
+}
+
+// newSQLDatabase returns a SQL implementation of the database interface backed by table,
+// creating the table if it does not exist. table is formatted into the SQL text unescaped,
+// so it must be a trusted identifier.
+func newSQLDatabase(db *sql.DB, table string) (database, error) {
+ createStmt, err := db.Prepare(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s ( ThirdPartyCaveatID NVARCHAR(255), RevocationCaveatID NVARCHAR(255), RevocationTime DATETIME, PRIMARY KEY (ThirdPartyCaveatID), KEY (RevocationCaveatID) );", table))
+ if err != nil {
+ return nil, err
+ }
+ defer createStmt.Close()
+ if _, err = createStmt.Exec(); err != nil {
+ return nil, err
+ }
+ insertCaveatStmt, err := db.Prepare(fmt.Sprintf("INSERT INTO %s (ThirdPartyCaveatID, RevocationCaveatID, RevocationTime) VALUES (?, ?, NULL)", table))
+ if err != nil {
+ return nil, err
+ }
+ revokeStmt, err := db.Prepare(fmt.Sprintf("UPDATE %s SET RevocationTime=? WHERE ThirdPartyCaveatID=?", table))
+ if err != nil {
+ return nil, err
+ }
+ isRevokedStmt, err := db.Prepare(fmt.Sprintf("SELECT 1 FROM %s WHERE RevocationCaveatID=? AND RevocationTime IS NOT NULL", table))
+ if err != nil {
+ return nil, err
+ }
+ // Unrevoked rows are excluded so they yield "no rows" rather than a NULL
+ // RevocationTime, which cannot be scanned into a time.Time.
+ revocationTimeStmt, err := db.Prepare(fmt.Sprintf("SELECT RevocationTime FROM %s WHERE ThirdPartyCaveatID=? AND RevocationTime IS NOT NULL", table))
+ if err != nil {
+ return nil, err
+ }
+ return &sqlDatabase{insertCaveatStmt, revokeStmt, isRevokedStmt, revocationTimeStmt}, nil
+}
diff --git a/services/identity/revocation/sql_database_test.go b/services/identity/revocation/sql_database_test.go
new file mode 100644
index 0000000..de9171b
--- /dev/null
+++ b/services/identity/revocation/sql_database_test.go
@@ -0,0 +1,73 @@
+package revocation
+
+import (
+ "encoding/hex"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+)
+
+func TestSQLDatabase(t *testing.T) {
+ db, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("failed to create new mock database stub: %v", err)
+ }
+ columns := []string{"ThirdPartyCaveatID", "RevocationCaveatID", "RevocationTime"}
+ sqlmock.ExpectExec("CREATE TABLE IF NOT EXISTS tableName (.+)").
+ WillReturnResult(sqlmock.NewResult(0, 1))
+ d, err := newSQLDatabase(db, "tableName")
+ if err != nil {
+ t.Fatalf("failed to create SQLDatabase: %v", err)
+ }
+
+ tpCavID, revCavID := "tpCavID", []byte("revCavID")
+ tpCavID2, revCavID2 := "tpCavID2", []byte("revCavID2")
+ encRevCavID := hex.EncodeToString(revCavID)
+ encRevCavID2 := hex.EncodeToString(revCavID2)
+ sqlmock.ExpectExec("INSERT INTO tableName (.+) VALUES (.+)").
+ WithArgs(tpCavID, encRevCavID).
+ WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
+ if err := d.InsertCaveat(tpCavID, revCavID); err != nil {
+ t.Errorf("failed to InsertCaveat into SQLDatabase: %v", err)
+ }
+
+ sqlmock.ExpectExec("INSERT INTO tableName (.+) VALUES (.+)").
+ WithArgs(tpCavID2, encRevCavID2).
+ WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
+ if err := d.InsertCaveat(tpCavID2, revCavID2); err != nil {
+ t.Errorf("second InsertCaveat into SQLDatabase failed: %v", err)
+ }
+
+ // Test Revocation
+ sqlmock.ExpectExec("UPDATE tableName SET RevocationTime=.+").
+ WillReturnResult(sqlmock.NewResult(0, 1))
+ if err := d.Revoke(tpCavID); err != nil {
+ t.Errorf("failed to Revoke Caveat: %v", err)
+ }
+
+ // Test IsRevoked returns true.
+ sqlmock.ExpectQuery("SELECT 1 FROM tableName").
+ WithArgs(encRevCavID).
+ WillReturnRows(sqlmock.NewRows(columns).AddRow(1, 1, 1))
+ if revoked, err := d.IsRevoked(revCavID); err != nil || !revoked {
+ t.Errorf("expected revCavID to be revoked: err: (%v)", err)
+ }
+
+ // Test IsRevoked returns false.
+ sqlmock.ExpectQuery("SELECT 1 FROM tableName").
+ WithArgs(encRevCavID2).
+ WillReturnRows(sqlmock.NewRows(columns))
+ if revoked, err := d.IsRevoked(revCavID2); err != nil || revoked {
+ t.Errorf("expected revCavID2 to not be revoked: err: (%v)", err)
+ }
+
+ // Test RevocationTime.
+ revocationTime := time.Now()
+ sqlmock.ExpectQuery("SELECT RevocationTime FROM tableName").
+ WithArgs(tpCavID).
+ WillReturnRows(sqlmock.NewRows([]string{"RevocationTime"}).AddRow(revocationTime))
+ if got, err := d.RevocationTime(tpCavID); err != nil || got == nil || !got.Equal(revocationTime) {
+ t.Errorf("got %v, expected %v: err : %v", got, revocationTime, err)
+ }
+}
diff --git a/services/identity/util/directory_store.go b/services/identity/util/directory_store.go
deleted file mode 100644
index 7c93638..0000000
--- a/services/identity/util/directory_store.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package util
-
-import (
- "fmt"
- "io/ioutil"
- "os"
- "path/filepath"
-)
-
-// DirectoryStore implements a key-value store on a filesystem where data for each key is stored in its own file.
-// TODO(suharshs): When vstore is ready replace this with the veyron store.
-type DirectoryStore struct {
- dir string
-}
-
-func (s DirectoryStore) Exists(key string) bool {
- _, err := os.Stat(s.pathName(key))
- return !os.IsNotExist(err)
-}
-
-func (s DirectoryStore) Put(key, value string) error {
- return ioutil.WriteFile(s.pathName(key), []byte(value), 0600)
-}
-
-func (s DirectoryStore) Get(key string) (string, error) {
- bytes, err := ioutil.ReadFile(s.pathName(key))
- return string(bytes), err
-}
-
-func (s DirectoryStore) pathName(key string) string {
- return filepath.Join(string(s.dir), key)
-}
-
-// NewDirectoryStore returns a key-value store that uses one file per key,
-// and places all data in the provided directory.
-func NewDirectoryStore(dir string) (*DirectoryStore, error) {
- if len(dir) == 0 {
- return nil, fmt.Errorf("must provide non-empty directory name")
- }
- // Make the directory if it doesn't already exist.
- if err := os.MkdirAll(dir, 0700); err != nil {
- return nil, err
- }
-
- return &DirectoryStore{dir}, nil
-}
diff --git a/services/identity/util/sql_config.go b/services/identity/util/sql_config.go
new file mode 100644
index 0000000..f375d63
--- /dev/null
+++ b/services/identity/util/sql_config.go
@@ -0,0 +1,36 @@
+package util
+
+import (
+	"database/sql"
+	"fmt"
+	"strings"
+
+	_ "github.com/go-sql-driver/mysql"
+)
+
+// DBFromConfigDatabase opens a MySQL connection pool for the DSN given in
+// database and pings it to verify the server is reachable.
+//
+// parseTime=true is added to the DSN's parameters so the driver scans DATE
+// and DATETIME columns into time.Time. If the DSN already carries parameters
+// (it contains '?'), parseTime is joined with '&' rather than a second '?',
+// which would produce a malformed DSN.
+//
+// On any error the returned *sql.DB is nil and any pool that was opened has
+// been closed.
+func DBFromConfigDatabase(database string) (*sql.DB, error) {
+	sep := "?"
+	if strings.Contains(database, "?") {
+		sep = "&"
+	}
+	db, err := sql.Open("mysql", database+sep+"parseTime=true")
+	if err != nil {
+		return nil, fmt.Errorf("failed to create database with database(%v): %v", database, err)
+	}
+	if err := db.Ping(); err != nil {
+		// The caller never sees db on this path, so release the pool here.
+		db.Close()
+		return nil, err
+	}
+	return db, nil
+}
diff --git a/services/mgmt/node/impl/util_test.go b/services/mgmt/node/impl/util_test.go
index 4e8a866..62dc931 100644
--- a/services/mgmt/node/impl/util_test.go
+++ b/services/mgmt/node/impl/util_test.go
@@ -8,6 +8,7 @@
"reflect"
"runtime"
"sort"
+ "strings"
"testing"
"time"
@@ -159,10 +160,18 @@
if err != nil {
t.Fatalf("Resolve(%v) failed: %v", name, err)
}
- if want, got := replicas, len(results); want != got {
+
+	// Each replica yields a websocket and a tcp endpoint; keep only tcp ones.
+	filteredResults := []string{}
+	for _, r := range results {
+		if strings.Contains(r, "@tcp") {
+			filteredResults = append(filteredResults, r)
+		}
+	}
+	if want, got := replicas, len(filteredResults); want != got {
t.Fatalf("Resolve(%v) expected %d result(s), got %d instead", name, want, got)
}
- return results
+ return filteredResults
}
// The following set of functions are convenience wrappers around Update and
diff --git a/services/mounttable/mounttabled/test.sh b/services/mounttable/mounttabled/test.sh
index bb65687..2ebdc7f 100755
--- a/services/mounttable/mounttabled/test.sh
+++ b/services/mounttable/mounttabled/test.sh
@@ -37,7 +37,8 @@
|| (cat "${MTLOG}"; shell_test::fail "line ${LINENO}: failed to identify endpoint")
# Get the neighborhood endpoint from the mounttable.
- NHEP=$(${MOUNTTABLE_BIN} glob "${EP}" nh | grep ^nh | cut -d' ' -f2) \
+ NHEP=$(${MOUNTTABLE_BIN} glob "${EP}" nh | grep ^nh | \
+ sed -e 's/ \/@.@ws@[^ ]* (TTL .m..s)//' | cut -d' ' -f2) \
|| (cat "${MTLOG}"; shell_test::fail "line ${LINENO}: failed to identify neighborhood endpoint")
# Mount objects and verify the result.
@@ -47,7 +48,8 @@
|| shell_test::fail "line ${LINENO}: failed to mount www.google.com"
# <mounttable>.Glob('*')
- GOT=$(${MOUNTTABLE_BIN} glob "${EP}" '*' | sed 's/TTL [^)]*/TTL XmXXs/' | sort) \
+ GOT=$(${MOUNTTABLE_BIN} glob "${EP}" '*' | \
+ sed -e 's/ \/@.@ws@[^ ]* (TTL .m..s)//' -e 's/TTL [^)]*/TTL XmXXs/' | sort) \
|| shell_test::fail "line ${LINENO}: failed to run mounttable"
WANT="[${EP}]
google /www.google.com:80 (TTL XmXXs)
diff --git a/tools/mgmt/nodex/acl_impl.go b/tools/mgmt/nodex/acl_impl.go
index 8348f5b..b4efc6a 100644
--- a/tools/mgmt/nodex/acl_impl.go
+++ b/tools/mgmt/nodex/acl_impl.go
@@ -16,7 +16,7 @@
Run: runGet,
Name: "get",
Short: "Get ACLs for the given target.",
- Long: "Get ACLs for the given target with friendly output. Also see getraw.",
+ Long: "Get ACLs for the given target.",
ArgsName: "<node manager name>",
ArgsLong: `
<node manager name> can be a Vanadium name for a node manager,
@@ -37,9 +37,8 @@
func (a byBlessing) Less(i, j int) bool { return a[i].blessing < a[j].blessing }
func runGet(cmd *cmdline.Command, args []string) error {
-
if expected, got := 1, len(args); expected != got {
- return cmd.UsageErrorf("install: incorrect number of arguments, expected %d, got %d", expected, got)
+ return cmd.UsageErrorf("get: incorrect number of arguments, expected %d, got %d", expected, got)
}
vanaName := args[0]
@@ -54,7 +53,6 @@
output = append(output, formattedACLEntry{string(k), "in", objACL.In[k].String()})
}
for k, _ := range objACL.NotIn {
-
output = append(output, formattedACLEntry{string(k), "nin", objACL.NotIn[k].String()})
}
@@ -74,7 +72,7 @@
Name: "acl",
-	Short: "Tool for creating associations between Vanadium blessings and a system account",
+	Short: "Tool for managing node manager ACLs",
Long: `
-The associate tool facilitates managing blessing to system account associations.
+The acl tool manages ACLs on the node manager, installations and instances.
`,
Children: []*cmdline.Command{cmdGet},
}
diff --git a/tools/mgmt/nodex/acl_test.go b/tools/mgmt/nodex/acl_test.go
index fb629b9..ccd6d6b 100644
--- a/tools/mgmt/nodex/acl_test.go
+++ b/tools/mgmt/nodex/acl_test.go
@@ -46,7 +46,7 @@
-		t.Fatalf("%v, ouput: %v, error: %v", err)
+		t.Fatalf("error: %v, output: %v", err, stdout.String())
}
if expected, got := "root/bob/... nin W\nroot/other in R\nroot/self/... in XRWADM", strings.TrimSpace(stdout.String()); got != expected {
- t.Fatalf("Unexpected output from list. Got %q, expected %q", got, expected)
+ t.Fatalf("Unexpected output from get. Got %q, expected %q", got, expected)
}
if got, expected := tape.Play(), []interface{}{"GetACL"}; !reflect.DeepEqual(expected, got) {
t.Errorf("invalid call sequence. Got %v, want %v", got, expected)