Merge "services/device/internal/suid: Set attr.Dir correctly"
diff --git a/lib/security/blessingroots.go b/lib/security/blessingroots.go
index e5748bf..5c0a7b2 100644
--- a/lib/security/blessingroots.go
+++ b/lib/security/blessingroots.go
@@ -16,6 +16,8 @@
"v.io/x/ref/lib/security/serialization"
)
+var errRootsAddPattern = verror.Register(pkgPath+".errRootsAddPattern", verror.NoRetry, "{1:}{2:} a root cannot be recognized for all blessing names (i.e., the pattern '...')")
+
// blessingRoots implements security.BlessingRoots.
type blessingRoots struct {
persistedData SerializerReaderWriter
@@ -33,6 +35,9 @@
}
func (br *blessingRoots) Add(root security.PublicKey, pattern security.BlessingPattern) error {
+ if pattern == security.AllPrincipals {
+ return verror.New(errRootsAddPattern, nil)
+ }
key, err := stateMapKey(root)
if err != nil {
return err
diff --git a/lib/security/blessingroots_test.go b/lib/security/blessingroots_test.go
index a8ef94f..ad6afd5 100644
--- a/lib/security/blessingroots_test.go
+++ b/lib/security/blessingroots_test.go
@@ -30,6 +30,9 @@
}
func (t *rootsTester) add(br security.BlessingRoots) error {
+ if err := br.Add(t[0], security.AllPrincipals); err == nil {
+ return fmt.Errorf("Add( , %v) succeeded, expected it to fail", security.AllPrincipals)
+ }
testdata := []struct {
root security.PublicKey
pattern security.BlessingPattern
diff --git a/profiles/chrome/chromeinit.go b/profiles/chrome/chromeinit.go
index ccbfc83..0baedb3 100644
--- a/profiles/chrome/chromeinit.go
+++ b/profiles/chrome/chromeinit.go
@@ -26,7 +26,7 @@
func init() {
v23.RegisterProfile(Init)
- rpc.RegisterUnknownProtocol("wsh", websocket.Dial, websocket.Listener)
+ rpc.RegisterUnknownProtocol("wsh", websocket.Dial, websocket.Resolve, websocket.Listener)
commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime)
}
diff --git a/profiles/fake/fake.go b/profiles/fake/fake.go
index ddd6a0d..279f4c4 100644
--- a/profiles/fake/fake.go
+++ b/profiles/fake/fake.go
@@ -32,7 +32,7 @@
func init() {
v23.RegisterProfile(Init)
- rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+ rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
}
func Init(ctx *context.T) (v23.Runtime, *context.T, v23.Shutdown, error) {
diff --git a/profiles/gce/init.go b/profiles/gce/init.go
index d71409d..99651de 100644
--- a/profiles/gce/init.go
+++ b/profiles/gce/init.go
@@ -34,7 +34,7 @@
func init() {
v23.RegisterProfile(Init)
- rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+ rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
}
diff --git a/profiles/genericinit.go b/profiles/genericinit.go
index 484f621..3851947 100644
--- a/profiles/genericinit.go
+++ b/profiles/genericinit.go
@@ -26,7 +26,7 @@
func init() {
v23.RegisterProfile(Init)
- rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+ rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
flags.SetDefaultHostPort(":0")
commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
}
diff --git a/profiles/internal/lib/websocket/conn_nacl.go b/profiles/internal/lib/websocket/conn_nacl.go
index c9467cb..479d38f 100644
--- a/profiles/internal/lib/websocket/conn_nacl.go
+++ b/profiles/internal/lib/websocket/conn_nacl.go
@@ -46,6 +46,10 @@
return WebsocketConn(address, ws), nil
}
+// Resolve is the NaCl resolver for the websocket protocols. It performs no
+// name resolution: address is returned unchanged and the network is always
+// reported as "ws", whatever protocol was requested.
+// NOTE(review): presumably because NaCl has no direct TCP/UDP (and hence DNS)
+// access, as noted for the wsh_nacl protocol registration — confirm.
+func Resolve(protocol, address string) (string, string, error) {
+	return "ws", address, nil
+}
+
func (c *wrappedConn) Read(b []byte) (int, error) {
c.readLock.Lock()
defer c.readLock.Unlock()
diff --git a/profiles/internal/lib/websocket/hybrid.go b/profiles/internal/lib/websocket/hybrid.go
index 9e4e52a..4d8e3a2 100644
--- a/profiles/internal/lib/websocket/hybrid.go
+++ b/profiles/internal/lib/websocket/hybrid.go
@@ -30,6 +30,17 @@
return conn, nil
}
+// HybridResolve is the resolver for the hybrid (tcp + websocket on one port)
+// protocols such as wsh, wsh4 and wsh6. network is mapped to its TCP
+// counterpart via mapWebSocketToTCP and address is resolved on that TCP
+// network. It returns the TCP network name and the resolved host:port.
+// An error is returned if network has no TCP mapping or resolution fails.
+func HybridResolve(network, address string) (string, string, error) {
+	tcp, ok := mapWebSocketToTCP[network]
+	if !ok {
+		// Without this check an unmapped network yields "", which
+		// net.ResolveTCPAddr may accept as "tcp", and "" would then be
+		// returned to the caller as the network.
+		return "", "", net.UnknownNetworkError(network)
+	}
+	tcpAddr, err := net.ResolveTCPAddr(tcp, address)
+	if err != nil {
+		return "", "", err
+	}
+	return tcp, tcpAddr.String(), nil
+}
+
// HybridListener returns a net.Listener that supports both tcp and
// websockets over the same, single, port. A listen address of
// --v23.tcp.protocol=wsh --v23.tcp.address=127.0.0.1:8101 means
diff --git a/profiles/internal/lib/websocket/resolver.go b/profiles/internal/lib/websocket/resolver.go
new file mode 100644
index 0000000..1286cf6
--- /dev/null
+++ b/profiles/internal/lib/websocket/resolver.go
@@ -0,0 +1,21 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !nacl
+
+package websocket
+
+import (
+ "net"
+)
+
+// Resolve is the resolver for the websocket protocols (ws, ws4, ws6). protocol
+// is mapped to its TCP counterpart via mapWebSocketToTCP and address is
+// resolved on that TCP network. The returned network is always "ws"; the
+// returned address is the resolved host:port. An error is returned if
+// protocol has no TCP mapping or resolution fails.
+func Resolve(protocol, address string) (string, string, error) {
+	tcp, ok := mapWebSocketToTCP[protocol]
+	if !ok {
+		// An unmapped protocol would otherwise map to "", which
+		// net.ResolveTCPAddr may silently accept as "tcp".
+		return "", "", net.UnknownNetworkError(protocol)
+	}
+	tcpAddr, err := net.ResolveTCPAddr(tcp, address)
+	if err != nil {
+		return "", "", err
+	}
+	return "ws", tcpAddr.String(), nil
+}
diff --git a/profiles/internal/rpc/full_test.go b/profiles/internal/rpc/full_test.go
index 12445ee..78fb4ee 100644
--- a/profiles/internal/rpc/full_test.go
+++ b/profiles/internal/rpc/full_test.go
@@ -2077,7 +2077,7 @@
}
func init() {
- rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+ rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
security.RegisterCaveatValidator(fakeTimeCaveat, func(_ *context.T, _ security.Call, t int64) error {
if now := clock.Now(); now > t {
return fmt.Errorf("fakeTimeCaveat expired: now=%d > then=%d", now, t)
diff --git a/profiles/internal/rpc/protocols/tcp/init.go b/profiles/internal/rpc/protocols/tcp/init.go
index a6067b9..da259a7 100644
--- a/profiles/internal/rpc/protocols/tcp/init.go
+++ b/profiles/internal/rpc/protocols/tcp/init.go
@@ -16,9 +16,9 @@
)
func init() {
- rpc.RegisterProtocol("tcp", tcpDial, tcpListen, "tcp4", "tcp6")
- rpc.RegisterProtocol("tcp4", tcpDial, tcpListen)
- rpc.RegisterProtocol("tcp6", tcpDial, tcpListen)
+ rpc.RegisterProtocol("tcp", tcpDial, tcpResolve, tcpListen, "tcp4", "tcp6")
+ rpc.RegisterProtocol("tcp4", tcpDial, tcpResolve, tcpListen)
+ rpc.RegisterProtocol("tcp6", tcpDial, tcpResolve, tcpListen)
}
func tcpDial(network, address string, timeout time.Duration) (net.Conn, error) {
@@ -32,6 +32,15 @@
return conn, nil
}
+// tcpResolve is the resolver registered for tcp, tcp4 and tcp6. It resolves
+// address with net.ResolveTCPAddr and reports the network name and string
+// form of the resulting *net.TCPAddr.
+func tcpResolve(network, address string) (string, string, error) {
+	resolved, resolveErr := net.ResolveTCPAddr(network, address)
+	if resolveErr != nil {
+		return "", "", resolveErr
+	}
+	resolvedNetwork, resolvedAddress := resolved.Network(), resolved.String()
+	return resolvedNetwork, resolvedAddress, nil
+}
+
// tcpListen returns a listener that sets KeepAlive on all accepted connections.
func tcpListen(network, laddr string) (net.Listener, error) {
ln, err := net.Listen(network, laddr)
diff --git a/profiles/internal/rpc/protocols/ws/init.go b/profiles/internal/rpc/protocols/ws/init.go
index dde1c5b..6167bbf 100644
--- a/profiles/internal/rpc/protocols/ws/init.go
+++ b/profiles/internal/rpc/protocols/ws/init.go
@@ -12,7 +12,7 @@
func init() {
// ws, ws4, ws6 represent websocket protocol instances.
- rpc.RegisterProtocol("ws", websocket.Dial, websocket.Listener, "ws4", "ws6")
- rpc.RegisterProtocol("ws4", websocket.Dial, websocket.Listener)
- rpc.RegisterProtocol("ws6", websocket.Dial, websocket.Listener)
+ rpc.RegisterProtocol("ws", websocket.Dial, websocket.Resolve, websocket.Listener, "ws4", "ws6")
+ rpc.RegisterProtocol("ws4", websocket.Dial, websocket.Resolve, websocket.Listener)
+ rpc.RegisterProtocol("ws6", websocket.Dial, websocket.Resolve, websocket.Listener)
}
diff --git a/profiles/internal/rpc/protocols/wsh/init.go b/profiles/internal/rpc/protocols/wsh/init.go
index 159cfd2..ef706d4 100644
--- a/profiles/internal/rpc/protocols/wsh/init.go
+++ b/profiles/internal/rpc/protocols/wsh/init.go
@@ -13,7 +13,7 @@
)
func init() {
- rpc.RegisterProtocol("wsh", websocket.HybridDial, websocket.HybridListener, "tcp4", "tcp6", "ws4", "ws6")
- rpc.RegisterProtocol("wsh4", websocket.HybridDial, websocket.HybridListener, "tcp4", "ws4")
- rpc.RegisterProtocol("wsh6", websocket.HybridDial, websocket.HybridListener, "tcp6", "ws6")
+ rpc.RegisterProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener, "tcp4", "tcp6", "ws4", "ws6")
+ rpc.RegisterProtocol("wsh4", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener, "tcp4", "ws4")
+ rpc.RegisterProtocol("wsh6", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener, "tcp6", "ws6")
}
diff --git a/profiles/internal/rpc/protocols/wsh_nacl/init.go b/profiles/internal/rpc/protocols/wsh_nacl/init.go
index 8c72fd3..6540b75 100644
--- a/profiles/internal/rpc/protocols/wsh_nacl/init.go
+++ b/profiles/internal/rpc/protocols/wsh_nacl/init.go
@@ -15,7 +15,7 @@
func init() {
// We limit wsh to ws since in general nacl does not allow direct access
// to TCP/UDP networking.
- rpc.RegisterProtocol("wsh", websocket.Dial, websocket.Listener, "ws4", "ws6")
- rpc.RegisterProtocol("wsh4", websocket.Dial, websocket.Listener, "ws4")
- rpc.RegisterProtocol("wsh6", websocket.Dial, websocket.Listener, "ws6")
+ rpc.RegisterProtocol("wsh", websocket.Dial, websocket.Resolve, websocket.Listener, "ws4", "ws6")
+ rpc.RegisterProtocol("wsh4", websocket.Dial, websocket.Resolve, websocket.Listener, "ws4")
+ rpc.RegisterProtocol("wsh6", websocket.Dial, websocket.Resolve, websocket.Listener, "ws6")
}
diff --git a/profiles/internal/rpc/stream/errors.go b/profiles/internal/rpc/stream/errors.go
index a51ee5e..2a54607 100644
--- a/profiles/internal/rpc/stream/errors.go
+++ b/profiles/internal/rpc/stream/errors.go
@@ -20,14 +20,15 @@
// of their errors are intended to be used as arguments to higher level errors.
var (
// TODO(cnicolaou): rename ErrSecurity to ErrAuth
- ErrSecurity = verror.Register(pkgPath+".errSecurity", verror.NoRetry, "{:3}")
- ErrNotTrusted = verror.Register(pkgPath+".errNotTrusted", verror.NoRetry, "{:3}")
- ErrNetwork = verror.Register(pkgPath+".errNetwork", verror.NoRetry, "{:3}")
- ErrDialFailed = verror.Register(pkgPath+".errDialFailed", verror.NoRetry, "{:3}")
- ErrProxy = verror.Register(pkgPath+".errProxy", verror.NoRetry, "{:3}")
- ErrBadArg = verror.Register(pkgPath+".errBadArg", verror.NoRetry, "{:3}")
- ErrBadState = verror.Register(pkgPath+".errBadState", verror.NoRetry, "{:3}")
- ErrAborted = verror.Register(pkgPath+".errAborted", verror.NoRetry, "{:3}")
+ ErrSecurity = verror.Register(pkgPath+".errSecurity", verror.NoRetry, "{:3}")
+ ErrNotTrusted = verror.Register(pkgPath+".errNotTrusted", verror.NoRetry, "{:3}")
+ ErrNetwork = verror.Register(pkgPath+".errNetwork", verror.NoRetry, "{:3}")
+ ErrDialFailed = verror.Register(pkgPath+".errDialFailed", verror.NoRetry, "{:3}")
+ ErrResolveFailed = verror.Register(pkgPath+".errResolveFailed", verror.NoRetry, "{:3}")
+ ErrProxy = verror.Register(pkgPath+".errProxy", verror.NoRetry, "{:3}")
+ ErrBadArg = verror.Register(pkgPath+".errBadArg", verror.NoRetry, "{:3}")
+ ErrBadState = verror.Register(pkgPath+".errBadState", verror.NoRetry, "{:3}")
+ ErrAborted = verror.Register(pkgPath+".errAborted", verror.NoRetry, "{:3}")
)
// NetError implements net.Error
diff --git a/profiles/internal/rpc/stream/manager/error_test.go b/profiles/internal/rpc/stream/manager/error_test.go
index ff51fe6..545d051 100644
--- a/profiles/internal/rpc/stream/manager/error_test.go
+++ b/profiles/internal/rpc/stream/manager/error_test.go
@@ -89,6 +89,10 @@
return mocknet.DialerWithOpts(opts, network, address, timeout)
}
+// simpleResolver is a pass-through resolver for tests: it echoes network and
+// address back unchanged and never fails.
+func simpleResolver(network, address string) (resolvedNetwork, resolvedAddress string, err error) {
+	return network, address, nil
+}
+
func TestDialErrors(t *testing.T) {
_, shutdown := test.InitForTest()
defer shutdown()
@@ -100,8 +104,9 @@
// bad protocol
ep, _ := inaming.NewEndpoint(naming.FormatEndpoint("x", "127.0.0.1:2"))
_, err := client.Dial(ep, pclient)
- if verror.ErrorID(err) != stream.ErrDialFailed.ID {
- t.Fatalf("wrong error: %s", err)
+ // A bad protocol should result in a Resolve Error.
+ if verror.ErrorID(err) != stream.ErrResolveFailed.ID {
+ t.Errorf("wrong error: %v", err)
}
t.Log(err)
@@ -109,11 +114,11 @@
ep, _ = inaming.NewEndpoint(naming.FormatEndpoint("tcp", "127.0.0.1:2"))
_, err = client.Dial(ep, pclient)
if verror.ErrorID(err) != stream.ErrDialFailed.ID {
- t.Fatalf("wrong error: %s", err)
+ t.Errorf("wrong error: %v", err)
}
t.Log(err)
- rpc.RegisterProtocol("dropData", dropDataDialer, net.Listen)
+ rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, net.Listen)
ln, sep, err := server.Listen("tcp", "127.0.0.1:0", pserver, pserver.BlessingStore().Default())
if err != nil {
@@ -129,7 +134,7 @@
}
_, err = client.Dial(cep, pclient)
if verror.ErrorID(err) != stream.ErrNetwork.ID {
- t.Fatalf("wrong error: %s", err)
+ t.Errorf("wrong error: %v", err)
}
t.Log(err)
}
diff --git a/profiles/internal/rpc/stream/manager/listener.go b/profiles/internal/rpc/stream/manager/listener.go
index 869e50f..3e0ec5d 100644
--- a/profiles/internal/rpc/stream/manager/listener.go
+++ b/profiles/internal/rpc/stream/manager/listener.go
@@ -209,8 +209,8 @@
conn.Close()
return
}
- ln.vifs.Insert(vf)
- ln.manager.vifs.Insert(vf)
+ ln.vifs.Insert(vf, conn.RemoteAddr().Network(), conn.RemoteAddr().String())
+ ln.manager.vifs.Insert(vf, conn.RemoteAddr().Network(), conn.RemoteAddr().String())
ln.vifLoops.Add(1)
vifLoop(vf, ln.q, func() {
@@ -223,6 +223,12 @@
}
}
+// deleteVIF logs (at verbosity 2) that vf is closed and removes it from both
+// this listener's VIF set and the manager's VIF set.
+// NOTE(review): presumably installed as the close callback of VIFs accepted by
+// this listener — confirm at the construction site.
+func (ln *netListener) deleteVIF(vf *vif.VIF) {
+ vlog.VI(2).Infof("VIF %v is closed, removing from cache", vf)
+ ln.vifs.Delete(vf)
+ ln.manager.vifs.Delete(vf)
+}
+
func (ln *netListener) Accept() (stream.Flow, error) {
item, err := ln.q.Get(nil)
switch {
@@ -252,12 +258,6 @@
return nil
}
-func (ln *netListener) deleteVIF(vf *vif.VIF) {
- vlog.VI(2).Infof("VIF %v is closed, removing from cache", vf)
- ln.vifs.Delete(vf)
- ln.manager.vifs.Delete(vf)
-}
-
func (ln *netListener) String() string {
return fmt.Sprintf("%T: (%v, %v)", ln, ln.netLn.Addr().Network(), ln.netLn.Addr())
}
diff --git a/profiles/internal/rpc/stream/manager/manager.go b/profiles/internal/rpc/stream/manager/manager.go
index b5a3807..afa3afa 100644
--- a/profiles/internal/rpc/stream/manager/manager.go
+++ b/profiles/internal/rpc/stream/manager/manager.go
@@ -87,8 +87,8 @@
defer vlog.LogCall()() // AUTO-GENERATED, DO NOT EDIT, MUST BE FIRST STATEMENT
}
-func dial(network, address string, timeout time.Duration) (net.Conn, error) {
- if d, _, _ := rpc.RegisteredProtocol(network); d != nil {
+func dial(d rpc.DialerFunc, network, address string, timeout time.Duration) (net.Conn, error) {
+ if d != nil {
conn, err := d(network, address, timeout)
if err != nil {
return nil, verror.New(stream.ErrDialFailed, nil, err)
@@ -98,6 +98,17 @@
return nil, verror.New(stream.ErrDialFailed, nil, verror.New(errUnknownNetwork, nil, network))
}
+// resolve runs the protocol's registered resolver r on (network, address).
+// FindOrDialVIF uses the result both as the key for the VIF cache and as the
+// target to dial. A nil r (e.g. the protocol is not registered) and any
+// resolver failure are both reported as stream.ErrResolveFailed.
+func resolve(r rpc.ResolverFunc, network, address string) (string, string, error) {
+	if r == nil {
+		return "", "", verror.New(stream.ErrResolveFailed, nil, verror.New(errUnknownNetwork, nil, network))
+	}
+	// rnet/raddr rather than net/addr: a local named net would shadow the
+	// net package imported by this file.
+	rnet, raddr, err := r(network, address)
+	if err != nil {
+		return "", "", verror.New(stream.ErrResolveFailed, nil, err)
+	}
+	return rnet, raddr, nil
+}
+
// FindOrDialVIF returns the network connection (VIF) to the provided address
// from the cache in the manager. If not already present in the cache, a new
// connection will be created using net.Dial.
@@ -111,37 +122,38 @@
}
}
addr := remote.Addr()
- network, address := addr.Network(), addr.String()
- if vf := m.vifs.Find(network, address); vf != nil {
- return vf, nil
- }
- vlog.VI(1).Infof("(%q, %q) not in VIF cache. Dialing", network, address)
- conn, err := dial(network, address, timeout)
- if err != nil {
- return nil, err
- }
+ d, r, _, _ := rpc.RegisteredProtocol(addr.Network())
// (network, address) in the endpoint might not always match up
// with the key used in the vifs. For example:
// - conn, err := net.Dial("tcp", "www.google.com:80")
// fmt.Println(conn.RemoteAddr()) // Might yield the corresponding IP address
// - Similarly, an unspecified IP address (net.IP.IsUnspecified) like "[::]:80"
// might yield "[::1]:80" (loopback interface) in conn.RemoteAddr().
- // Thus, look for VIFs with the resolved address as well.
- resNetwork, resAddress := conn.RemoteAddr().Network(), conn.RemoteAddr().String()
- if vf := m.vifs.BlockingFind(resNetwork, resAddress); vf != nil {
- vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache. Closing newly Dialed connection", network, address, resNetwork, resAddress)
- conn.Close()
+ // Thus, look for VIFs with the resolved address.
+ network, address, err := resolve(r, addr.Network(), addr.String())
+ if err != nil {
+ return nil, err
+ }
+ vf, unblock := m.vifs.BlockingFind(network, address)
+ if vf != nil {
+ vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache.", addr.Network(), addr.String(), network, address)
return vf, nil
}
- defer m.vifs.Unblock(resNetwork, resAddress)
+ defer unblock()
+
+ vlog.VI(1).Infof("(%q, %q) not in VIF cache. Dialing", network, address)
+ conn, err := dial(d, network, address, timeout)
+ if err != nil {
+ return nil, err
+ }
opts = append([]stream.VCOpt{vc.StartTimeout{defaultStartTimeout}}, opts...)
- vf, err := vif.InternalNewDialedVIF(conn, m.rid, principal, nil, m.deleteVIF, opts...)
+ vf, err = vif.InternalNewDialedVIF(conn, m.rid, principal, nil, m.deleteVIF, opts...)
if err != nil {
conn.Close()
return nil, err
}
- m.vifs.Insert(vf)
+ m.vifs.Insert(vf, network, address)
return vf, nil
}
@@ -163,8 +175,13 @@
return nil, verror.NewErrInternal(nil) // Not reached
}
+// deleteVIF logs (at verbosity 2) that vf is closed and removes it from the
+// manager's VIF cache. FindOrDialVIF passes it to vif.InternalNewDialedVIF
+// (presumably as the VIF's onClose hook — confirm).
+func (m *manager) deleteVIF(vf *vif.VIF) {
+ vlog.VI(2).Infof("%p: VIF %v is closed, removing from cache", m, vf)
+ m.vifs.Delete(vf)
+}
+
func listen(protocol, address string) (net.Listener, error) {
- if _, l, _ := rpc.RegisteredProtocol(protocol); l != nil {
+ if _, _, l, _ := rpc.RegisteredProtocol(protocol); l != nil {
ln, err := l(protocol, address)
if err != nil {
return nil, verror.New(stream.ErrNetwork, nil, err)
@@ -241,11 +258,6 @@
return ln, ep, nil
}
-func (m *manager) deleteVIF(vf *vif.VIF) {
- vlog.VI(2).Infof("%p: VIF %v is closed, removing from cache", m, vf)
- m.vifs.Delete(vf)
-}
-
func (m *manager) ShutdownEndpoint(remote naming.Endpoint) {
vifs := m.vifs.List()
total := 0
diff --git a/profiles/internal/rpc/stream/manager/manager_test.go b/profiles/internal/rpc/stream/manager/manager_test.go
index ff74b71..5042471 100644
--- a/profiles/internal/rpc/stream/manager/manager_test.go
+++ b/profiles/internal/rpc/stream/manager/manager_test.go
@@ -628,14 +628,12 @@
// We'd like an endpoint that contains an address that's different than the
// one used for the connection. In practice this is awkward to achieve since
// we don't want to listen on ":0" since that will annoy firewalls. Instead we
- // listen on 127.0.0.1 and we fabricate an endpoint that doesn't contain
- // 127.0.0.1 by using ":0" to create it. This leads to an endpoint such that
- // the address encoded in the endpoint (e.g. "0.0.0.0:55324") is different
- // from the address of the connection (e.g. "127.0.0.1:55324").
+ // create an endpoint with "localhost", which will result in an endpoint that
+ // doesn't contain 127.0.0.1.
_, port, _ := net.SplitHostPort(ep.Addr().String())
nep := &inaming.Endpoint{
Protocol: ep.Addr().Network(),
- Address: net.JoinHostPort("", port),
+ Address: net.JoinHostPort("localhost", port),
RID: ep.RoutingID(),
}
@@ -769,10 +767,13 @@
dialer := func(_, _ string, _ time.Duration) (net.Conn, error) {
return nil, fmt.Errorf("tn.Dial")
}
+ resolver := func(_, _ string) (string, string, error) {
+ return "", "", fmt.Errorf("tn.Resolve")
+ }
listener := func(_, _ string) (net.Listener, error) {
return nil, fmt.Errorf("tn.Listen")
}
- rpc.RegisterProtocol("tn", dialer, listener)
+ rpc.RegisterProtocol("tn", dialer, resolver, listener)
_, _, err := server.Listen("tnx", "127.0.0.1:0", principal, blessings)
if err == nil || !strings.Contains(err.Error(), "unknown network: tnx") {
@@ -789,7 +790,7 @@
return net.Listen("tcp", addr)
}
- if got, want := rpc.RegisterProtocol("tn", dialer, listener), true; got != want {
+ if got, want := rpc.RegisterProtocol("tn", dialer, resolver, listener), true; got != want {
t.Errorf("got %t, want %t", got, want)
}
@@ -799,8 +800,8 @@
}
_, err = client.Dial(ep, testutil.NewPrincipal("client"))
- if err == nil || !strings.Contains(err.Error(), "tn.Dial") {
- t.Fatal("expected error is missing (%v)", err)
+ if err == nil || !strings.Contains(err.Error(), "tn.Resolve") {
+ t.Fatalf("expected error is missing (%v)", err)
}
}
@@ -942,17 +943,9 @@
}
go acceptLoop(ln)
- // We'd like an endpoint that contains an address that's different than the
- // one used for the connection. In practice this is awkward to achieve since
- // we don't want to listen on ":0" since that will annoy firewalls. Instead we
- // listen on 127.0.0.1 and we fabricate an endpoint that doesn't contain
- // 127.0.0.1 by using ":0" to create it. This leads to an endpoint such that
- // the address encoded in the endpoint (e.g. "0.0.0.0:55324") is different
- // from the address of the connection (e.g. "127.0.0.1:55324").
- _, port, _ := net.SplitHostPort(ep.Addr().String())
nep := &inaming.Endpoint{
Protocol: ep.Addr().Network(),
- Address: net.JoinHostPort("", port),
+ Address: ep.Addr().String(),
RID: ep.RoutingID(),
}
diff --git a/profiles/internal/rpc/stream/proxy/proxy.go b/profiles/internal/rpc/stream/proxy/proxy.go
index 1022b10..4959491 100644
--- a/profiles/internal/rpc/stream/proxy/proxy.go
+++ b/profiles/internal/rpc/stream/proxy/proxy.go
@@ -229,7 +229,7 @@
laddr := spec.Addrs[0]
network := laddr.Protocol
address := laddr.Address
- _, listenFn, _ := rpc.RegisteredProtocol(network)
+ _, _, listenFn, _ := rpc.RegisteredProtocol(network)
if listenFn == nil {
return nil, verror.New(stream.ErrProxy, nil, verror.New(errUnknownNetwork, nil, network))
}
diff --git a/profiles/internal/rpc/stream/vif/set.go b/profiles/internal/rpc/stream/vif/set.go
index 497e43b..3032dfc 100644
--- a/profiles/internal/rpc/stream/vif/set.go
+++ b/profiles/internal/rpc/stream/vif/set.go
@@ -20,6 +20,7 @@
mu sync.RWMutex
set map[string][]*VIF // GUARDED_BY(mu)
started map[string]bool // GUARDED_BY(mu)
+ keys map[*VIF]string // GUARDED_BY(mu)
cond *sync.Cond
}
@@ -28,6 +29,7 @@
s := &Set{
set: make(map[string][]*VIF),
started: make(map[string]bool),
+ keys: make(map[*VIF]string),
}
s.cond = sync.NewCond(&s.mu)
return s
@@ -37,36 +39,49 @@
// is identified by the provided (network, address). Returns nil if there is no
// such VIF.
//
-// If BlockingFind returns nil, the caller is required to call Unblock, to avoid deadlock.
-// The network and address in Unblock must be the same as used in the BlockingFind call.
-// During this time, all new BlockingFind calls for this network and address will Block until
-// the corresponding Unblock call is made.
-func (s *Set) BlockingFind(network, address string) *VIF {
- return s.find(network, address, true)
+// The caller must call the returned function when it is finished (usually via
+// defer) to avoid deadlock; until then every other BlockingFind call for the
+// same network and address blocks. The function releases a claim only when
+// BlockingFind returned nil for a distinct connection; otherwise it is a no-op.
+// Call it once: unblock clears the claim unconditionally, so a second call
+// could release a claim taken by a later caller.
+func (s *Set) BlockingFind(network, address string) (*VIF, func()) {
+	noop := func() {}
+	if isNonDistinctConn(network, address) {
+		return nil, noop
+	}
+
+	claim := key(network, address)
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Wait until no other caller holds the claim for this key.
+	for s.started[claim] {
+		s.cond.Wait()
+	}
+
+	// A VIF cached under any protocol compatible with network will do.
+	_, _, _, compatible := rpc.RegisteredProtocol(network)
+	for _, proto := range compatible {
+		if cached := s.set[key(proto, address)]; len(cached) != 0 {
+			return cached[rand.Intn(len(cached))], noop
+		}
+	}
+
+	// Nothing cached: take the claim so concurrent callers wait for us.
+	s.started[claim] = true
+	return nil, func() { s.unblock(network, address) }
}
-// Unblock marks the status of the network, address as no longer started, and
+// unblock marks the status of the network, address as no longer started, and
// broadcasts waiting threads.
-func (s *Set) Unblock(network, address string) {
+func (s *Set) unblock(network, address string) {
s.mu.Lock()
delete(s.started, key(network, address))
s.cond.Broadcast()
s.mu.Unlock()
}
-// Find returns a VIF where the remote end of the underlying network connection
-// is identified by the provided (network, address). Returns nil if there is no
-// such VIF.
-func (s *Set) Find(network, address string) *VIF {
- return s.find(network, address, false)
-}
-
// Insert adds a VIF to the set.
-func (s *Set) Insert(vif *VIF) {
- addr := vif.conn.RemoteAddr()
- k := key(addr.Network(), addr.String())
+func (s *Set) Insert(vif *VIF, network, address string) {
+ k := key(network, address)
s.mu.Lock()
defer s.mu.Unlock()
+ s.keys[vif] = k
vifs := s.set[k]
for _, v := range vifs {
if v == vif {
@@ -74,16 +89,13 @@
}
}
s.set[k] = append(vifs, vif)
- vif.addSet(s)
}
// Delete removes a VIF from the set.
func (s *Set) Delete(vif *VIF) {
- vif.removeSet(s)
- addr := vif.conn.RemoteAddr()
- k := key(addr.Network(), addr.String())
s.mu.Lock()
defer s.mu.Unlock()
+ k := s.keys[vif]
vifs := s.set[k]
for i, v := range vifs {
if v == vif {
@@ -92,6 +104,7 @@
} else {
s.set[k] = append(vifs[:i], vifs[i+1:]...)
}
+ delete(s.keys, vif)
return
}
}
@@ -108,33 +121,6 @@
return l
}
-func (s *Set) find(network, address string, blocking bool) *VIF {
- if isNonDistinctConn(network, address) {
- return nil
- }
-
- k := key(network, address)
-
- s.mu.Lock()
- defer s.mu.Unlock()
-
- for blocking && s.started[k] {
- s.cond.Wait()
- }
-
- _, _, p := rpc.RegisteredProtocol(network)
- for _, n := range p {
- if vifs := s.set[key(n, address)]; len(vifs) > 0 {
- return vifs[rand.Intn(len(vifs))]
- }
- }
-
- if blocking {
- s.started[k] = true
- }
- return nil
-}
-
func key(network, address string) string {
if network == "tcp" || network == "ws" {
host, _, _ := net.SplitHostPort(address)
diff --git a/profiles/internal/rpc/stream/vif/set_test.go b/profiles/internal/rpc/stream/vif/set_test.go
index bc95fa5..e9e597d 100644
--- a/profiles/internal/rpc/stream/vif/set_test.go
+++ b/profiles/internal/rpc/stream/vif/set_test.go
@@ -24,7 +24,10 @@
var supportsIPv6 bool
func init() {
- rpc.RegisterProtocol("unix", net.DialTimeout, net.Listen)
+ simpleResolver := func(network, address string) (string, string, error) {
+ return network, address, nil
+ }
+ rpc.RegisterProtocol("unix", net.DialTimeout, simpleResolver, net.Listen)
// Check whether the platform supports IPv6.
ln, err := net.Listen("tcp6", "[::1]:0")
@@ -35,7 +38,7 @@
}
func newConn(network, address string) (net.Conn, net.Conn, error) {
- dfunc, lfunc, _ := rpc.RegisteredProtocol(network)
+ dfunc, _, lfunc, _ := rpc.RegisteredProtocol(network)
ln, err := lfunc(network, address)
if err != nil {
return nil, nil, err
@@ -94,6 +97,12 @@
return d
}
+// find looks up a VIF for (n, a) via BlockingFind and releases the lookup's
+// claim before returning, so the key is never left claimed by a test.
+func find(set *vif.Set, n, a string) *vif.VIF {
+	vf, release := set.BlockingFind(n, a)
+	defer release()
+	return vf
+}
+
func TestSetBasic(t *testing.T) {
sockdir, err := ioutil.TempDir("", "TestSetBasic")
if err != nil {
@@ -143,23 +152,23 @@
}
a := c.RemoteAddr()
- set.Insert(vf)
+ set.Insert(vf, a.Network(), a.String())
for _, n := range test.compatibles {
- if found := set.Find(n, a.String()); found == nil {
- t.Fatalf("%s: Got nil, but want [%v] on Find(%q, %q))", name, vf, n, a)
+ if found := find(set, n, a.String()); found == nil {
+ t.Fatalf("%s: Got nil, but want [%v] on find(%q, %q))", name, vf, n, a)
}
}
for _, n := range diff(all, test.compatibles) {
- if v := set.Find(n, a.String()); v != nil {
- t.Fatalf("%s: Got [%v], but want nil on Find(%q, %q))", name, v, n, a)
+ if v := find(set, n, a.String()); v != nil {
+ t.Fatalf("%s: Got [%v], but want nil on find(%q, %q))", name, v, n, a)
}
}
set.Delete(vf)
for _, n := range all {
- if v := set.Find(n, a.String()); v != nil {
- t.Fatalf("%s: Got [%v], but want nil on Find(%q, %q))", name, v, n, a)
+ if v := find(set, n, a.String()); v != nil {
+ t.Fatalf("%s: Got [%v], but want nil on find(%q, %q))", name, v, n, a)
}
}
}
@@ -186,17 +195,17 @@
}
set := vif.NewSet()
- set.Insert(vf1)
- if v := set.Find(a1.Network(), a1.String()); v != nil {
- t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a1.Network(), a1)
+ set.Insert(vf1, a1.Network(), a1.String())
+ if v := find(set, a1.Network(), a1.String()); v != nil {
+ t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a1.Network(), a1)
}
if l := set.List(); len(l) != 1 || l[0] != vf1 {
t.Errorf("Unexpected list of VIFs: %v", l)
}
- set.Insert(vf2)
- if v := set.Find(a2.Network(), a2.String()); v != nil {
- t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a2.Network(), a2)
+ set.Insert(vf2, a2.Network(), a2.String())
+ if v := find(set, a2.Network(), a2.String()); v != nil {
+ t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a2.Network(), a2)
}
if l := set.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
t.Errorf("Unexpected list of VIFs: %v", l)
@@ -247,17 +256,17 @@
}
set := vif.NewSet()
- set.Insert(vf1)
- if v := set.Find(a1.Network(), a1.String()); v != nil {
- t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a1.Network(), a1)
+ set.Insert(vf1, a1.Network(), a1.String())
+ if v := find(set, a1.Network(), a1.String()); v != nil {
+ t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a1.Network(), a1)
}
if l := set.List(); len(l) != 1 || l[0] != vf1 {
t.Errorf("Unexpected list of VIFs: %v", l)
}
- set.Insert(vf2)
- if v := set.Find(a2.Network(), a2.String()); v != nil {
- t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a2.Network(), a2)
+ set.Insert(vf2, a2.Network(), a2.String())
+ if v := find(set, a2.Network(), a2.String()); v != nil {
+ t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a2.Network(), a2)
}
if l := set.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
t.Errorf("Unexpected list of VIFs: %v", l)
@@ -275,68 +284,38 @@
func TestSetInsertDelete(t *testing.T) {
c1, s1 := net.Pipe()
- c2, s2 := net.Pipe()
vf1, _, err := newVIF(c1, s1)
if err != nil {
t.Fatal(err)
}
- vf2, _, err := newVIF(c2, s2)
- if err != nil {
- t.Fatal(err)
- }
set1 := vif.NewSet()
- set2 := vif.NewSet()
- set1.Insert(vf1)
+ n1, a1 := c1.RemoteAddr().Network(), c1.RemoteAddr().String()
+ set1.Insert(vf1, n1, a1)
if l := set1.List(); len(l) != 1 || l[0] != vf1 {
t.Errorf("Unexpected list of VIFs: %v", l)
}
- set1.Insert(vf2)
- if l := set1.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
- t.Errorf("Unexpected list of VIFs: %v", l)
- }
-
- set2.Insert(vf1)
- set2.Insert(vf2)
set1.Delete(vf1)
- if l := set1.List(); len(l) != 1 || l[0] != vf2 {
- t.Errorf("Unexpected list of VIFs: %v", l)
- }
- if l := set2.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
- t.Errorf("Unexpected list of VIFs: %v", l)
- }
-
- vf1.Close()
- if l := set1.List(); len(l) != 1 || l[0] != vf2 {
- t.Errorf("Unexpected list of VIFs: %v", l)
- }
- if l := set2.List(); len(l) != 1 || l[0] != vf2 {
- t.Errorf("Unexpected list of VIFs: %v", l)
- }
-
- vf2.Close()
if l := set1.List(); len(l) != 0 {
t.Errorf("Unexpected list of VIFs: %v", l)
}
- if l := set2.List(); len(l) != 0 {
- t.Errorf("Unexpected list of VIFs: %v", l)
- }
}
func TestBlockingFind(t *testing.T) {
network, address := "tcp", "127.0.0.1:1234"
set := vif.NewSet()
- set.BlockingFind(network, address)
+ _, unblock := set.BlockingFind(network, address)
ch := make(chan *vif.VIF, 1)
// set.BlockingFind should block until set.Unblock is called with the corresponding VIF,
// since set.BlockingFind was called earlier.
go func(ch chan *vif.VIF) {
- ch <- set.BlockingFind(network, address)
+ vf, _ := set.BlockingFind(network, address)
+ ch <- vf
}(ch)
// set.BlockingFind for a different network and address should not block.
@@ -351,8 +330,8 @@
if err != nil {
t.Fatal(err)
}
- set.Insert(vf)
- set.Unblock(network, address)
+ set.Insert(vf, network, address)
+ unblock()
// Now the set.BlockingFind should have returned the correct vif.
if cachedVif := <-ch; cachedVif != vf {
diff --git a/profiles/internal/rpc/stream/vif/vif.go b/profiles/internal/rpc/stream/vif/vif.go
index 2342ddd..fb206b7 100644
--- a/profiles/internal/rpc/stream/vif/vif.go
+++ b/profiles/internal/rpc/stream/vif/vif.go
@@ -120,10 +120,6 @@
isClosed bool // GUARDED_BY(isClosedMu)
onClose func(*VIF)
- // All sets that this VIF is in.
- muSets sync.Mutex
- sets []*Set // GUARDED_BY(muSets)
-
// These counters track the number of messages sent and received by
// this VIF.
muMsgCounters sync.Mutex
@@ -345,27 +341,6 @@
return vc, nil
}
-// addSet adds a set to the list of sets this VIF is in. This method is called
-// by Set.Insert().
-func (vif *VIF) addSet(s *Set) {
- vif.muSets.Lock()
- defer vif.muSets.Unlock()
- vif.sets = append(vif.sets, s)
-}
-
-// removeSet removes a set from the list of sets this VIF is in. This method is
-// called by Set.Delete().
-func (vif *VIF) removeSet(s *Set) {
- vif.muSets.Lock()
- defer vif.muSets.Unlock()
- for ix, vs := range vif.sets {
- if vs == s {
- vif.sets = append(vif.sets[:ix], vif.sets[ix+1:]...)
- return
- }
- }
-}
-
// Close closes all VCs (and thereby Flows) over the VIF and then closes the
// underlying network connection after draining all pending writes on those
// VCs.
@@ -378,14 +353,6 @@
vif.isClosed = true
vif.isClosedMu.Unlock()
- vif.muSets.Lock()
- sets := vif.sets
- vif.sets = nil
- vif.muSets.Unlock()
- for _, s := range sets {
- s.Delete(vif)
- }
-
vlog.VI(1).Infof("Closing VIF %s", vif)
// Stop accepting new VCs.
vif.StopAccepting()
diff --git a/profiles/internal/rpc/test/client_test.go b/profiles/internal/rpc/test/client_test.go
index c44a635..9665d0d 100644
--- a/profiles/internal/rpc/test/client_test.go
+++ b/profiles/internal/rpc/test/client_test.go
@@ -411,6 +411,10 @@
return mocknet.DialerWithOpts(opts, network, address, timeout)
}
+func simpleResolver(network, address string) (string, string, error) {
+ return network, address, nil
+}
+
func TestStartCallBadProtocol(t *testing.T) {
ctx, shutdown := newCtx()
defer shutdown()
@@ -423,7 +427,7 @@
logErrors(t, msg, true, false, false, err)
}
- rpc.RegisterProtocol("dropData", dropDataDialer, net.Listen)
+ rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, net.Listen)
// The following test will fail due to a broken connection.
// We need to run mount table and servers with no security to use
diff --git a/profiles/internal/testing/mocks/mocknet/mocknet.go b/profiles/internal/testing/mocks/mocknet/mocknet.go
index 0b73e4b..259563c 100644
--- a/profiles/internal/testing/mocks/mocknet/mocknet.go
+++ b/profiles/internal/testing/mocks/mocknet/mocknet.go
@@ -72,7 +72,7 @@
// dialer := func(network, address string, timeout time.Duration) (net.Conn, error) {
// return mocknet.DialerWithOpts(mocknet.Opts{UnderlyingProtocol:"tcp"}, network, address, timeout)
// }
-// rpc.RegisterProtocol("brkDial", dialer, net.Listen)
+// rpc.RegisterProtocol("brkDial", dialer, resolver, net.Listen)
//
func DialerWithOpts(opts Opts, network, address string, timeout time.Duration) (net.Conn, error) {
protocol := opts.UnderlyingProtocol
diff --git a/profiles/internal/testing/mocks/mocknet/mocknet_test.go b/profiles/internal/testing/mocks/mocknet/mocknet_test.go
index f65ca4b..91c445e 100644
--- a/profiles/internal/testing/mocks/mocknet/mocknet_test.go
+++ b/profiles/internal/testing/mocks/mocknet/mocknet_test.go
@@ -272,7 +272,11 @@
return mocknet.DialerWithOpts(opts, network, address, timeout)
}
- rpc.RegisterProtocol("dropControl", dropControlDialer, net.Listen)
+ simpleResolver := func(network, address string) (string, string, error) {
+ return network, address, nil
+ }
+
+ rpc.RegisterProtocol("dropControl", dropControlDialer, simpleResolver, net.Listen)
server, fn := initServer(t, ctx)
defer fn()
diff --git a/profiles/roaming/roaminginit.go b/profiles/roaming/roaminginit.go
index c6b37f6..bff4202 100644
--- a/profiles/roaming/roaminginit.go
+++ b/profiles/roaming/roaminginit.go
@@ -48,7 +48,7 @@
func init() {
v23.RegisterProfile(Init)
- rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+ rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
}
diff --git a/profiles/static/staticinit.go b/profiles/static/staticinit.go
index dc93eae..c18a028 100644
--- a/profiles/static/staticinit.go
+++ b/profiles/static/staticinit.go
@@ -31,7 +31,7 @@
func init() {
v23.RegisterProfile(Init)
- rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+ rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
}
diff --git a/services/agent/internal/unixfd/unixfd.go b/services/agent/internal/unixfd/unixfd.go
index 6d4328d..6497af7 100644
--- a/services/agent/internal/unixfd/unixfd.go
+++ b/services/agent/internal/unixfd/unixfd.go
@@ -35,7 +35,7 @@
const Network string = "unixfd"
func init() {
- rpc.RegisterProtocol(Network, unixFDConn, unixFDListen)
+ rpc.RegisterProtocol(Network, unixFDConn, unixFDResolve, unixFDListen)
}
// singleConnListener implements net.Listener for an already-connected socket.
@@ -136,6 +136,10 @@
return c.addr
}
+func unixFDResolve(_, address string) (string, string, error) {
+ return Network, address, nil
+}
+
func unixFDListen(protocol, address string) (net.Listener, error) {
conn, err := unixFDConn(protocol, address, 0)
if err != nil {