ref: Add resolver to rpc.RegisterProtocol.

When dialing some of our protocols we get a net.Conn with a different
(protocol,address) pair than what we asked for. E.g. anything that
does DNS resolution will convert the host portion of the address into
an IP address.

This is annoying, because many times we want to use the (protocol,address)
pair as a cache key; i.e. it's used as a VIF cache key. Since the resolved
(protocol,address) is only available after calling the dialer, we
currently waste resources by actually dialing the address, only to
discover the resolved address is already in our cache, and closing the
dialed connection.

This change introduces a Resolver func to rpc.RegisterProtocol that can
be used to perform the DNS lookup without creating a real connection.
For protocols that do not have a resolution step, it is just the identity
function.

closes vanadium/issues#431

MultiPart: 2/2

Change-Id: If9c0de686ea89598fb424fa0f09de91c320a1161
diff --git a/profiles/chrome/chromeinit.go b/profiles/chrome/chromeinit.go
index ccbfc83..0baedb3 100644
--- a/profiles/chrome/chromeinit.go
+++ b/profiles/chrome/chromeinit.go
@@ -26,7 +26,7 @@
 
 func init() {
 	v23.RegisterProfile(Init)
-	rpc.RegisterUnknownProtocol("wsh", websocket.Dial, websocket.Listener)
+	rpc.RegisterUnknownProtocol("wsh", websocket.Dial, websocket.Resolve, websocket.Listener)
 	commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime)
 }
 
diff --git a/profiles/fake/fake.go b/profiles/fake/fake.go
index ddd6a0d..279f4c4 100644
--- a/profiles/fake/fake.go
+++ b/profiles/fake/fake.go
@@ -32,7 +32,7 @@
 
 func init() {
 	v23.RegisterProfile(Init)
-	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
 }
 
 func Init(ctx *context.T) (v23.Runtime, *context.T, v23.Shutdown, error) {
diff --git a/profiles/gce/init.go b/profiles/gce/init.go
index d71409d..99651de 100644
--- a/profiles/gce/init.go
+++ b/profiles/gce/init.go
@@ -34,7 +34,7 @@
 
 func init() {
 	v23.RegisterProfile(Init)
-	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
 	commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
 }
 
diff --git a/profiles/genericinit.go b/profiles/genericinit.go
index 484f621..3851947 100644
--- a/profiles/genericinit.go
+++ b/profiles/genericinit.go
@@ -26,7 +26,7 @@
 
 func init() {
 	v23.RegisterProfile(Init)
-	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
 	flags.SetDefaultHostPort(":0")
 	commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
 }
diff --git a/profiles/internal/lib/websocket/conn_nacl.go b/profiles/internal/lib/websocket/conn_nacl.go
index c9467cb..479d38f 100644
--- a/profiles/internal/lib/websocket/conn_nacl.go
+++ b/profiles/internal/lib/websocket/conn_nacl.go
@@ -46,6 +46,10 @@
 	return WebsocketConn(address, ws), nil
 }
 
+func Resolve(protocol, address string) (string, string, error) {
+	return "ws", address, nil
+}
+
 func (c *wrappedConn) Read(b []byte) (int, error) {
 	c.readLock.Lock()
 	defer c.readLock.Unlock()
diff --git a/profiles/internal/lib/websocket/hybrid.go b/profiles/internal/lib/websocket/hybrid.go
index 9e4e52a..4d8e3a2 100644
--- a/profiles/internal/lib/websocket/hybrid.go
+++ b/profiles/internal/lib/websocket/hybrid.go
@@ -30,6 +30,17 @@
 	return conn, nil
 }
 
+// HybridResolve performs a DNS resolution on the network, address and always
+// returns tcp as its Network.
+func HybridResolve(network, address string) (string, string, error) {
+	tcp := mapWebSocketToTCP[network]
+	tcpAddr, err := net.ResolveTCPAddr(tcp, address)
+	if err != nil {
+		return "", "", err
+	}
+	return tcp, tcpAddr.String(), nil
+}
+
 // HybridListener returns a net.Listener that supports both tcp and
 // websockets over the same, single, port. A listen address of
 // --v23.tcp.protocol=wsh --v23.tcp.address=127.0.0.1:8101 means
diff --git a/profiles/internal/lib/websocket/resolver.go b/profiles/internal/lib/websocket/resolver.go
new file mode 100644
index 0000000..1286cf6
--- /dev/null
+++ b/profiles/internal/lib/websocket/resolver.go
@@ -0,0 +1,21 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !nacl
+
+package websocket
+
+import (
+	"net"
+)
+
+// Resolve performs a DNS resolution on the provided protocol and address.
+func Resolve(protocol, address string) (string, string, error) {
+	tcp := mapWebSocketToTCP[protocol]
+	tcpAddr, err := net.ResolveTCPAddr(tcp, address)
+	if err != nil {
+		return "", "", err
+	}
+	return "ws", tcpAddr.String(), nil
+}
diff --git a/profiles/internal/rpc/full_test.go b/profiles/internal/rpc/full_test.go
index 12445ee..78fb4ee 100644
--- a/profiles/internal/rpc/full_test.go
+++ b/profiles/internal/rpc/full_test.go
@@ -2077,7 +2077,7 @@
 }
 
 func init() {
-	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
 	security.RegisterCaveatValidator(fakeTimeCaveat, func(_ *context.T, _ security.Call, t int64) error {
 		if now := clock.Now(); now > t {
 			return fmt.Errorf("fakeTimeCaveat expired: now=%d > then=%d", now, t)
diff --git a/profiles/internal/rpc/protocols/tcp/init.go b/profiles/internal/rpc/protocols/tcp/init.go
index a6067b9..da259a7 100644
--- a/profiles/internal/rpc/protocols/tcp/init.go
+++ b/profiles/internal/rpc/protocols/tcp/init.go
@@ -16,9 +16,9 @@
 )
 
 func init() {
-	rpc.RegisterProtocol("tcp", tcpDial, tcpListen, "tcp4", "tcp6")
-	rpc.RegisterProtocol("tcp4", tcpDial, tcpListen)
-	rpc.RegisterProtocol("tcp6", tcpDial, tcpListen)
+	rpc.RegisterProtocol("tcp", tcpDial, tcpResolve, tcpListen, "tcp4", "tcp6")
+	rpc.RegisterProtocol("tcp4", tcpDial, tcpResolve, tcpListen)
+	rpc.RegisterProtocol("tcp6", tcpDial, tcpResolve, tcpListen)
 }
 
 func tcpDial(network, address string, timeout time.Duration) (net.Conn, error) {
@@ -32,6 +32,15 @@
 	return conn, nil
 }
 
+// tcpResolve performs a DNS resolution on the provided network and address.
+func tcpResolve(network, address string) (string, string, error) {
+	tcpAddr, err := net.ResolveTCPAddr(network, address)
+	if err != nil {
+		return "", "", err
+	}
+	return tcpAddr.Network(), tcpAddr.String(), nil
+}
+
 // tcpListen returns a listener that sets KeepAlive on all accepted connections.
 func tcpListen(network, laddr string) (net.Listener, error) {
 	ln, err := net.Listen(network, laddr)
diff --git a/profiles/internal/rpc/protocols/ws/init.go b/profiles/internal/rpc/protocols/ws/init.go
index dde1c5b..6167bbf 100644
--- a/profiles/internal/rpc/protocols/ws/init.go
+++ b/profiles/internal/rpc/protocols/ws/init.go
@@ -12,7 +12,7 @@
 
 func init() {
 	// ws, ws4, ws6 represent websocket protocol instances.
-	rpc.RegisterProtocol("ws", websocket.Dial, websocket.Listener, "ws4", "ws6")
-	rpc.RegisterProtocol("ws4", websocket.Dial, websocket.Listener)
-	rpc.RegisterProtocol("ws6", websocket.Dial, websocket.Listener)
+	rpc.RegisterProtocol("ws", websocket.Dial, websocket.Resolve, websocket.Listener, "ws4", "ws6")
+	rpc.RegisterProtocol("ws4", websocket.Dial, websocket.Resolve, websocket.Listener)
+	rpc.RegisterProtocol("ws6", websocket.Dial, websocket.Resolve, websocket.Listener)
 }
diff --git a/profiles/internal/rpc/protocols/wsh/init.go b/profiles/internal/rpc/protocols/wsh/init.go
index 159cfd2..ef706d4 100644
--- a/profiles/internal/rpc/protocols/wsh/init.go
+++ b/profiles/internal/rpc/protocols/wsh/init.go
@@ -13,7 +13,7 @@
 )
 
 func init() {
-	rpc.RegisterProtocol("wsh", websocket.HybridDial, websocket.HybridListener, "tcp4", "tcp6", "ws4", "ws6")
-	rpc.RegisterProtocol("wsh4", websocket.HybridDial, websocket.HybridListener, "tcp4", "ws4")
-	rpc.RegisterProtocol("wsh6", websocket.HybridDial, websocket.HybridListener, "tcp6", "ws6")
+	rpc.RegisterProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener, "tcp4", "tcp6", "ws4", "ws6")
+	rpc.RegisterProtocol("wsh4", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener, "tcp4", "ws4")
+	rpc.RegisterProtocol("wsh6", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener, "tcp6", "ws6")
 }
diff --git a/profiles/internal/rpc/protocols/wsh_nacl/init.go b/profiles/internal/rpc/protocols/wsh_nacl/init.go
index 8c72fd3..6540b75 100644
--- a/profiles/internal/rpc/protocols/wsh_nacl/init.go
+++ b/profiles/internal/rpc/protocols/wsh_nacl/init.go
@@ -15,7 +15,7 @@
 func init() {
 	// We limit wsh to ws since in general nacl does not allow direct access
 	// to TCP/UDP networking.
-	rpc.RegisterProtocol("wsh", websocket.Dial, websocket.Listener, "ws4", "ws6")
-	rpc.RegisterProtocol("wsh4", websocket.Dial, websocket.Listener, "ws4")
-	rpc.RegisterProtocol("wsh6", websocket.Dial, websocket.Listener, "ws6")
+	rpc.RegisterProtocol("wsh", websocket.Dial, websocket.Resolve, websocket.Listener, "ws4", "ws6")
+	rpc.RegisterProtocol("wsh4", websocket.Dial, websocket.Resolve, websocket.Listener, "ws4")
+	rpc.RegisterProtocol("wsh6", websocket.Dial, websocket.Resolve, websocket.Listener, "ws6")
 }
diff --git a/profiles/internal/rpc/stream/errors.go b/profiles/internal/rpc/stream/errors.go
index 4aae55c..c5744f2 100644
--- a/profiles/internal/rpc/stream/errors.go
+++ b/profiles/internal/rpc/stream/errors.go
@@ -19,14 +19,15 @@
 // of their errors are intended to be used as arguments to higher level errors.
 var (
 	// TODO(cnicolaou): rename ErrSecurity to ErrAuth
-	ErrSecurity   = verror.Register(pkgPath+".errSecurity", verror.NoRetry, "{:3}")
-	ErrNotTrusted = verror.Register(pkgPath+".errNotTrusted", verror.NoRetry, "{:3}")
-	ErrNetwork    = verror.Register(pkgPath+".errNetwork", verror.NoRetry, "{:3}")
-	ErrDialFailed = verror.Register(pkgPath+".errDialFailed", verror.NoRetry, "{:3}")
-	ErrProxy      = verror.Register(pkgPath+".errProxy", verror.NoRetry, "{:3}")
-	ErrBadArg     = verror.Register(pkgPath+".errBadArg", verror.NoRetry, "{:3}")
-	ErrBadState   = verror.Register(pkgPath+".errBadState", verror.NoRetry, "{:3}")
-	ErrAborted    = verror.Register(pkgPath+".errAborted", verror.NoRetry, "{:3}")
+	ErrSecurity      = verror.Register(pkgPath+".errSecurity", verror.NoRetry, "{:3}")
+	ErrNotTrusted    = verror.Register(pkgPath+".errNotTrusted", verror.NoRetry, "{:3}")
+	ErrNetwork       = verror.Register(pkgPath+".errNetwork", verror.NoRetry, "{:3}")
+	ErrDialFailed    = verror.Register(pkgPath+".errDialFailed", verror.NoRetry, "{:3}")
+	ErrResolveFailed = verror.Register(pkgPath+".errResolveFailed", verror.NoRetry, "{:3}")
+	ErrProxy         = verror.Register(pkgPath+".errProxy", verror.NoRetry, "{:3}")
+	ErrBadArg        = verror.Register(pkgPath+".errBadArg", verror.NoRetry, "{:3}")
+	ErrBadState      = verror.Register(pkgPath+".errBadState", verror.NoRetry, "{:3}")
+	ErrAborted       = verror.Register(pkgPath+".errAborted", verror.NoRetry, "{:3}")
 )
 
 // NetError implements net.Error
diff --git a/profiles/internal/rpc/stream/manager/error_test.go b/profiles/internal/rpc/stream/manager/error_test.go
index ff51fe6..545d051 100644
--- a/profiles/internal/rpc/stream/manager/error_test.go
+++ b/profiles/internal/rpc/stream/manager/error_test.go
@@ -89,6 +89,10 @@
 	return mocknet.DialerWithOpts(opts, network, address, timeout)
 }
 
+func simpleResolver(network, address string) (string, string, error) {
+	return network, address, nil
+}
+
 func TestDialErrors(t *testing.T) {
 	_, shutdown := test.InitForTest()
 	defer shutdown()
@@ -100,8 +104,9 @@
 	// bad protocol
 	ep, _ := inaming.NewEndpoint(naming.FormatEndpoint("x", "127.0.0.1:2"))
 	_, err := client.Dial(ep, pclient)
-	if verror.ErrorID(err) != stream.ErrDialFailed.ID {
-		t.Fatalf("wrong error: %s", err)
+	// A bad protocol should result in a Resolve Error.
+	if verror.ErrorID(err) != stream.ErrResolveFailed.ID {
+		t.Errorf("wrong error: %v", err)
 	}
 	t.Log(err)
 
@@ -109,11 +114,11 @@
 	ep, _ = inaming.NewEndpoint(naming.FormatEndpoint("tcp", "127.0.0.1:2"))
 	_, err = client.Dial(ep, pclient)
 	if verror.ErrorID(err) != stream.ErrDialFailed.ID {
-		t.Fatalf("wrong error: %s", err)
+		t.Errorf("wrong error: %v", err)
 	}
 	t.Log(err)
 
-	rpc.RegisterProtocol("dropData", dropDataDialer, net.Listen)
+	rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, net.Listen)
 
 	ln, sep, err := server.Listen("tcp", "127.0.0.1:0", pserver, pserver.BlessingStore().Default())
 	if err != nil {
@@ -129,7 +134,7 @@
 	}
 	_, err = client.Dial(cep, pclient)
 	if verror.ErrorID(err) != stream.ErrNetwork.ID {
-		t.Fatalf("wrong error: %s", err)
+		t.Errorf("wrong error: %v", err)
 	}
 	t.Log(err)
 }
diff --git a/profiles/internal/rpc/stream/manager/listener.go b/profiles/internal/rpc/stream/manager/listener.go
index 869e50f..3e0ec5d 100644
--- a/profiles/internal/rpc/stream/manager/listener.go
+++ b/profiles/internal/rpc/stream/manager/listener.go
@@ -209,8 +209,8 @@
 				conn.Close()
 				return
 			}
-			ln.vifs.Insert(vf)
-			ln.manager.vifs.Insert(vf)
+			ln.vifs.Insert(vf, conn.RemoteAddr().Network(), conn.RemoteAddr().String())
+			ln.manager.vifs.Insert(vf, conn.RemoteAddr().Network(), conn.RemoteAddr().String())
 
 			ln.vifLoops.Add(1)
 			vifLoop(vf, ln.q, func() {
@@ -223,6 +223,12 @@
 	}
 }
 
+func (ln *netListener) deleteVIF(vf *vif.VIF) {
+	vlog.VI(2).Infof("VIF %v is closed, removing from cache", vf)
+	ln.vifs.Delete(vf)
+	ln.manager.vifs.Delete(vf)
+}
+
 func (ln *netListener) Accept() (stream.Flow, error) {
 	item, err := ln.q.Get(nil)
 	switch {
@@ -252,12 +258,6 @@
 	return nil
 }
 
-func (ln *netListener) deleteVIF(vf *vif.VIF) {
-	vlog.VI(2).Infof("VIF %v is closed, removing from cache", vf)
-	ln.vifs.Delete(vf)
-	ln.manager.vifs.Delete(vf)
-}
-
 func (ln *netListener) String() string {
 	return fmt.Sprintf("%T: (%v, %v)", ln, ln.netLn.Addr().Network(), ln.netLn.Addr())
 }
diff --git a/profiles/internal/rpc/stream/manager/manager.go b/profiles/internal/rpc/stream/manager/manager.go
index 85c081f..ffbab60 100644
--- a/profiles/internal/rpc/stream/manager/manager.go
+++ b/profiles/internal/rpc/stream/manager/manager.go
@@ -85,8 +85,8 @@
 func (DialTimeout) RPCStreamVCOpt() {}
 func (DialTimeout) RPCClientOpt()   {}
 
-func dial(network, address string, timeout time.Duration) (net.Conn, error) {
-	if d, _, _ := rpc.RegisteredProtocol(network); d != nil {
+func dial(d rpc.DialerFunc, network, address string, timeout time.Duration) (net.Conn, error) {
+	if d != nil {
 		conn, err := d(network, address, timeout)
 		if err != nil {
 			return nil, verror.New(stream.ErrDialFailed, nil, err)
@@ -96,6 +96,17 @@
 	return nil, verror.New(stream.ErrDialFailed, nil, verror.New(errUnknownNetwork, nil, network))
 }
 
+func resolve(r rpc.ResolverFunc, network, address string) (string, string, error) {
+	if r != nil {
+		net, addr, err := r(network, address)
+		if err != nil {
+			return "", "", verror.New(stream.ErrResolveFailed, nil, err)
+		}
+		return net, addr, nil
+	}
+	return "", "", verror.New(stream.ErrResolveFailed, nil, verror.New(errUnknownNetwork, nil, network))
+}
+
 // FindOrDialVIF returns the network connection (VIF) to the provided address
 // from the cache in the manager. If not already present in the cache, a new
 // connection will be created using net.Dial.
@@ -109,37 +120,38 @@
 		}
 	}
 	addr := remote.Addr()
-	network, address := addr.Network(), addr.String()
-	if vf := m.vifs.Find(network, address); vf != nil {
-		return vf, nil
-	}
-	vlog.VI(1).Infof("(%q, %q) not in VIF cache. Dialing", network, address)
-	conn, err := dial(network, address, timeout)
-	if err != nil {
-		return nil, err
-	}
+	d, r, _, _ := rpc.RegisteredProtocol(addr.Network())
 	// (network, address) in the endpoint might not always match up
 	// with the key used in the vifs. For example:
 	// - conn, err := net.Dial("tcp", "www.google.com:80")
 	//   fmt.Println(conn.RemoteAddr()) // Might yield the corresponding IP address
 	// - Similarly, an unspecified IP address (net.IP.IsUnspecified) like "[::]:80"
 	//   might yield "[::1]:80" (loopback interface) in conn.RemoteAddr().
-	// Thus, look for VIFs with the resolved address as well.
-	resNetwork, resAddress := conn.RemoteAddr().Network(), conn.RemoteAddr().String()
-	if vf := m.vifs.BlockingFind(resNetwork, resAddress); vf != nil {
-		vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache. Closing newly Dialed connection", network, address, resNetwork, resAddress)
-		conn.Close()
+	// Thus, look for VIFs with the resolved address.
+	network, address, err := resolve(r, addr.Network(), addr.String())
+	if err != nil {
+		return nil, err
+	}
+	vf, unblock := m.vifs.BlockingFind(network, address)
+	if vf != nil {
+		vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache.", addr.Network(), addr.String(), network, address)
 		return vf, nil
 	}
-	defer m.vifs.Unblock(resNetwork, resAddress)
+	defer unblock()
+
+	vlog.VI(1).Infof("(%q, %q) not in VIF cache. Dialing", network, address)
+	conn, err := dial(d, network, address, timeout)
+	if err != nil {
+		return nil, err
+	}
 
 	opts = append([]stream.VCOpt{vc.StartTimeout{defaultStartTimeout}}, opts...)
-	vf, err := vif.InternalNewDialedVIF(conn, m.rid, principal, nil, m.deleteVIF, opts...)
+	vf, err = vif.InternalNewDialedVIF(conn, m.rid, principal, nil, m.deleteVIF, opts...)
 	if err != nil {
 		conn.Close()
 		return nil, err
 	}
-	m.vifs.Insert(vf)
+	m.vifs.Insert(vf, network, address)
 	return vf, nil
 }
 
@@ -161,8 +173,13 @@
 	return nil, verror.NewErrInternal(nil) // Not reached
 }
 
+func (m *manager) deleteVIF(vf *vif.VIF) {
+	vlog.VI(2).Infof("%p: VIF %v is closed, removing from cache", m, vf)
+	m.vifs.Delete(vf)
+}
+
 func listen(protocol, address string) (net.Listener, error) {
-	if _, l, _ := rpc.RegisteredProtocol(protocol); l != nil {
+	if _, _, l, _ := rpc.RegisteredProtocol(protocol); l != nil {
 		ln, err := l(protocol, address)
 		if err != nil {
 			return nil, verror.New(stream.ErrNetwork, nil, err)
@@ -239,11 +256,6 @@
 	return ln, ep, nil
 }
 
-func (m *manager) deleteVIF(vf *vif.VIF) {
-	vlog.VI(2).Infof("%p: VIF %v is closed, removing from cache", m, vf)
-	m.vifs.Delete(vf)
-}
-
 func (m *manager) ShutdownEndpoint(remote naming.Endpoint) {
 	vifs := m.vifs.List()
 	total := 0
diff --git a/profiles/internal/rpc/stream/manager/manager_test.go b/profiles/internal/rpc/stream/manager/manager_test.go
index ff74b71..5042471 100644
--- a/profiles/internal/rpc/stream/manager/manager_test.go
+++ b/profiles/internal/rpc/stream/manager/manager_test.go
@@ -628,14 +628,12 @@
 	// We'd like an endpoint that contains an address that's different than the
 	// one used for the connection. In practice this is awkward to achieve since
 	// we don't want to listen on ":0" since that will annoy firewalls. Instead we
-	// listen on 127.0.0.1 and we fabricate an endpoint that doesn't contain
-	// 127.0.0.1 by using ":0" to create it. This leads to an endpoint such that
-	// the address encoded in the endpoint (e.g. "0.0.0.0:55324") is different
-	// from the address of the connection (e.g. "127.0.0.1:55324").
+	// create a endpoint with "localhost", which will result in an endpoint that
+	// doesn't contain 127.0.0.1.
 	_, port, _ := net.SplitHostPort(ep.Addr().String())
 	nep := &inaming.Endpoint{
 		Protocol: ep.Addr().Network(),
-		Address:  net.JoinHostPort("", port),
+		Address:  net.JoinHostPort("localhost", port),
 		RID:      ep.RoutingID(),
 	}
 
@@ -769,10 +767,13 @@
 	dialer := func(_, _ string, _ time.Duration) (net.Conn, error) {
 		return nil, fmt.Errorf("tn.Dial")
 	}
+	resolver := func(_, _ string) (string, string, error) {
+		return "", "", fmt.Errorf("tn.Resolve")
+	}
 	listener := func(_, _ string) (net.Listener, error) {
 		return nil, fmt.Errorf("tn.Listen")
 	}
-	rpc.RegisterProtocol("tn", dialer, listener)
+	rpc.RegisterProtocol("tn", dialer, resolver, listener)
 
 	_, _, err := server.Listen("tnx", "127.0.0.1:0", principal, blessings)
 	if err == nil || !strings.Contains(err.Error(), "unknown network: tnx") {
@@ -789,7 +790,7 @@
 		return net.Listen("tcp", addr)
 	}
 
-	if got, want := rpc.RegisterProtocol("tn", dialer, listener), true; got != want {
+	if got, want := rpc.RegisterProtocol("tn", dialer, resolver, listener), true; got != want {
 		t.Errorf("got %t, want %t", got, want)
 	}
 
@@ -799,8 +800,8 @@
 	}
 
 	_, err = client.Dial(ep, testutil.NewPrincipal("client"))
-	if err == nil || !strings.Contains(err.Error(), "tn.Dial") {
-		t.Fatal("expected error is missing (%v)", err)
+	if err == nil || !strings.Contains(err.Error(), "tn.Resolve") {
+		t.Fatalf("expected error is missing (%v)", err)
 	}
 }
 
@@ -942,17 +943,9 @@
 	}
 	go acceptLoop(ln)
 
-	// We'd like an endpoint that contains an address that's different than the
-	// one used for the connection. In practice this is awkward to achieve since
-	// we don't want to listen on ":0" since that will annoy firewalls. Instead we
-	// listen on 127.0.0.1 and we fabricate an endpoint that doesn't contain
-	// 127.0.0.1 by using ":0" to create it. This leads to an endpoint such that
-	// the address encoded in the endpoint (e.g. "0.0.0.0:55324") is different
-	// from the address of the connection (e.g. "127.0.0.1:55324").
-	_, port, _ := net.SplitHostPort(ep.Addr().String())
 	nep := &inaming.Endpoint{
 		Protocol: ep.Addr().Network(),
-		Address:  net.JoinHostPort("", port),
+		Address:  ep.Addr().String(),
 		RID:      ep.RoutingID(),
 	}
 
diff --git a/profiles/internal/rpc/stream/proxy/proxy.go b/profiles/internal/rpc/stream/proxy/proxy.go
index 1022b10..4959491 100644
--- a/profiles/internal/rpc/stream/proxy/proxy.go
+++ b/profiles/internal/rpc/stream/proxy/proxy.go
@@ -229,7 +229,7 @@
 	laddr := spec.Addrs[0]
 	network := laddr.Protocol
 	address := laddr.Address
-	_, listenFn, _ := rpc.RegisteredProtocol(network)
+	_, _, listenFn, _ := rpc.RegisteredProtocol(network)
 	if listenFn == nil {
 		return nil, verror.New(stream.ErrProxy, nil, verror.New(errUnknownNetwork, nil, network))
 	}
diff --git a/profiles/internal/rpc/stream/vif/set.go b/profiles/internal/rpc/stream/vif/set.go
index 497e43b..3032dfc 100644
--- a/profiles/internal/rpc/stream/vif/set.go
+++ b/profiles/internal/rpc/stream/vif/set.go
@@ -20,6 +20,7 @@
 	mu      sync.RWMutex
 	set     map[string][]*VIF // GUARDED_BY(mu)
 	started map[string]bool   // GUARDED_BY(mu)
+	keys    map[*VIF]string   // GUARDED_BY(mu)
 	cond    *sync.Cond
 }
 
@@ -28,6 +29,7 @@
 	s := &Set{
 		set:     make(map[string][]*VIF),
 		started: make(map[string]bool),
+		keys:    make(map[*VIF]string),
 	}
 	s.cond = sync.NewCond(&s.mu)
 	return s
@@ -37,36 +39,49 @@
 // is identified by the provided (network, address). Returns nil if there is no
 // such VIF.
 //
-// If BlockingFind returns nil, the caller is required to call Unblock, to avoid deadlock.
-// The network and address in Unblock must be the same as used in the BlockingFind call.
-// During this time, all new BlockingFind calls for this network and address will Block until
-// the corresponding Unblock call is made.
-func (s *Set) BlockingFind(network, address string) *VIF {
-	return s.find(network, address, true)
+// The caller is required to call the returned unblock function, to avoid deadlock.
+// Until the returned function is called, all new BlockingFind calls for this
+// network and address will block.
+func (s *Set) BlockingFind(network, address string) (*VIF, func()) {
+	if isNonDistinctConn(network, address) {
+		return nil, func() {}
+	}
+
+	k := key(network, address)
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for s.started[k] {
+		s.cond.Wait()
+	}
+
+	_, _, _, p := rpc.RegisteredProtocol(network)
+	for _, n := range p {
+		if vifs := s.set[key(n, address)]; len(vifs) > 0 {
+			return vifs[rand.Intn(len(vifs))], func() {}
+		}
+	}
+
+	s.started[k] = true
+	return nil, func() { s.unblock(network, address) }
 }
 
-// Unblock marks the status of the network, address as no longer started, and
+// unblock marks the status of the network, address as no longer started, and
 // broadcasts waiting threads.
-func (s *Set) Unblock(network, address string) {
+func (s *Set) unblock(network, address string) {
 	s.mu.Lock()
 	delete(s.started, key(network, address))
 	s.cond.Broadcast()
 	s.mu.Unlock()
 }
 
-// Find returns a VIF where the remote end of the underlying network connection
-// is identified by the provided (network, address). Returns nil if there is no
-// such VIF.
-func (s *Set) Find(network, address string) *VIF {
-	return s.find(network, address, false)
-}
-
 // Insert adds a VIF to the set.
-func (s *Set) Insert(vif *VIF) {
-	addr := vif.conn.RemoteAddr()
-	k := key(addr.Network(), addr.String())
+func (s *Set) Insert(vif *VIF, network, address string) {
+	k := key(network, address)
 	s.mu.Lock()
 	defer s.mu.Unlock()
+	s.keys[vif] = k
 	vifs := s.set[k]
 	for _, v := range vifs {
 		if v == vif {
@@ -74,16 +89,13 @@
 		}
 	}
 	s.set[k] = append(vifs, vif)
-	vif.addSet(s)
 }
 
 // Delete removes a VIF from the set.
 func (s *Set) Delete(vif *VIF) {
-	vif.removeSet(s)
-	addr := vif.conn.RemoteAddr()
-	k := key(addr.Network(), addr.String())
 	s.mu.Lock()
 	defer s.mu.Unlock()
+	k := s.keys[vif]
 	vifs := s.set[k]
 	for i, v := range vifs {
 		if v == vif {
@@ -92,6 +104,7 @@
 			} else {
 				s.set[k] = append(vifs[:i], vifs[i+1:]...)
 			}
+			delete(s.keys, vif)
 			return
 		}
 	}
@@ -108,33 +121,6 @@
 	return l
 }
 
-func (s *Set) find(network, address string, blocking bool) *VIF {
-	if isNonDistinctConn(network, address) {
-		return nil
-	}
-
-	k := key(network, address)
-
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	for blocking && s.started[k] {
-		s.cond.Wait()
-	}
-
-	_, _, p := rpc.RegisteredProtocol(network)
-	for _, n := range p {
-		if vifs := s.set[key(n, address)]; len(vifs) > 0 {
-			return vifs[rand.Intn(len(vifs))]
-		}
-	}
-
-	if blocking {
-		s.started[k] = true
-	}
-	return nil
-}
-
 func key(network, address string) string {
 	if network == "tcp" || network == "ws" {
 		host, _, _ := net.SplitHostPort(address)
diff --git a/profiles/internal/rpc/stream/vif/set_test.go b/profiles/internal/rpc/stream/vif/set_test.go
index bc95fa5..e9e597d 100644
--- a/profiles/internal/rpc/stream/vif/set_test.go
+++ b/profiles/internal/rpc/stream/vif/set_test.go
@@ -24,7 +24,10 @@
 var supportsIPv6 bool
 
 func init() {
-	rpc.RegisterProtocol("unix", net.DialTimeout, net.Listen)
+	simpleResolver := func(network, address string) (string, string, error) {
+		return network, address, nil
+	}
+	rpc.RegisterProtocol("unix", net.DialTimeout, simpleResolver, net.Listen)
 
 	// Check whether the platform supports IPv6.
 	ln, err := net.Listen("tcp6", "[::1]:0")
@@ -35,7 +38,7 @@
 }
 
 func newConn(network, address string) (net.Conn, net.Conn, error) {
-	dfunc, lfunc, _ := rpc.RegisteredProtocol(network)
+	dfunc, _, lfunc, _ := rpc.RegisteredProtocol(network)
 	ln, err := lfunc(network, address)
 	if err != nil {
 		return nil, nil, err
@@ -94,6 +97,12 @@
 	return d
 }
 
+func find(set *vif.Set, n, a string) *vif.VIF {
+	found, unblock := set.BlockingFind(n, a)
+	unblock()
+	return found
+}
+
 func TestSetBasic(t *testing.T) {
 	sockdir, err := ioutil.TempDir("", "TestSetBasic")
 	if err != nil {
@@ -143,23 +152,23 @@
 		}
 		a := c.RemoteAddr()
 
-		set.Insert(vf)
+		set.Insert(vf, a.Network(), a.String())
 		for _, n := range test.compatibles {
-			if found := set.Find(n, a.String()); found == nil {
-				t.Fatalf("%s: Got nil, but want [%v] on Find(%q, %q))", name, vf, n, a)
+			if found := find(set, n, a.String()); found == nil {
+				t.Fatalf("%s: Got nil, but want [%v] on find(%q, %q))", name, vf, n, a)
 			}
 		}
 
 		for _, n := range diff(all, test.compatibles) {
-			if v := set.Find(n, a.String()); v != nil {
-				t.Fatalf("%s: Got [%v], but want nil on Find(%q, %q))", name, v, n, a)
+			if v := find(set, n, a.String()); v != nil {
+				t.Fatalf("%s: Got [%v], but want nil on find(%q, %q))", name, v, n, a)
 			}
 		}
 
 		set.Delete(vf)
 		for _, n := range all {
-			if v := set.Find(n, a.String()); v != nil {
-				t.Fatalf("%s: Got [%v], but want nil on Find(%q, %q))", name, v, n, a)
+			if v := find(set, n, a.String()); v != nil {
+				t.Fatalf("%s: Got [%v], but want nil on find(%q, %q))", name, v, n, a)
 			}
 		}
 	}
@@ -186,17 +195,17 @@
 	}
 
 	set := vif.NewSet()
-	set.Insert(vf1)
-	if v := set.Find(a1.Network(), a1.String()); v != nil {
-		t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a1.Network(), a1)
+	set.Insert(vf1, a1.Network(), a1.String())
+	if v := find(set, a1.Network(), a1.String()); v != nil {
+		t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a1.Network(), a1)
 	}
 	if l := set.List(); len(l) != 1 || l[0] != vf1 {
 		t.Errorf("Unexpected list of VIFs: %v", l)
 	}
 
-	set.Insert(vf2)
-	if v := set.Find(a2.Network(), a2.String()); v != nil {
-		t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a2.Network(), a2)
+	set.Insert(vf2, a2.Network(), a2.String())
+	if v := find(set, a2.Network(), a2.String()); v != nil {
+		t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a2.Network(), a2)
 	}
 	if l := set.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
 		t.Errorf("Unexpected list of VIFs: %v", l)
@@ -247,17 +256,17 @@
 	}
 
 	set := vif.NewSet()
-	set.Insert(vf1)
-	if v := set.Find(a1.Network(), a1.String()); v != nil {
-		t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a1.Network(), a1)
+	set.Insert(vf1, a1.Network(), a1.String())
+	if v := find(set, a1.Network(), a1.String()); v != nil {
+		t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a1.Network(), a1)
 	}
 	if l := set.List(); len(l) != 1 || l[0] != vf1 {
 		t.Errorf("Unexpected list of VIFs: %v", l)
 	}
 
-	set.Insert(vf2)
-	if v := set.Find(a2.Network(), a2.String()); v != nil {
-		t.Fatalf("Got [%v], but want nil on Find(%q, %q))", v, a2.Network(), a2)
+	set.Insert(vf2, a2.Network(), a2.String())
+	if v := find(set, a2.Network(), a2.String()); v != nil {
+		t.Fatalf("Got [%v], but want nil on find(%q, %q))", v, a2.Network(), a2)
 	}
 	if l := set.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
 		t.Errorf("Unexpected list of VIFs: %v", l)
@@ -275,68 +284,38 @@
 
 func TestSetInsertDelete(t *testing.T) {
 	c1, s1 := net.Pipe()
-	c2, s2 := net.Pipe()
 	vf1, _, err := newVIF(c1, s1)
 	if err != nil {
 		t.Fatal(err)
 	}
-	vf2, _, err := newVIF(c2, s2)
-	if err != nil {
-		t.Fatal(err)
-	}
 
 	set1 := vif.NewSet()
-	set2 := vif.NewSet()
 
-	set1.Insert(vf1)
+	n1, a1 := c1.RemoteAddr().Network(), c1.RemoteAddr().String()
+	set1.Insert(vf1, n1, a1)
 	if l := set1.List(); len(l) != 1 || l[0] != vf1 {
 		t.Errorf("Unexpected list of VIFs: %v", l)
 	}
-	set1.Insert(vf2)
-	if l := set1.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
-		t.Errorf("Unexpected list of VIFs: %v", l)
-	}
-
-	set2.Insert(vf1)
-	set2.Insert(vf2)
 
 	set1.Delete(vf1)
-	if l := set1.List(); len(l) != 1 || l[0] != vf2 {
-		t.Errorf("Unexpected list of VIFs: %v", l)
-	}
-	if l := set2.List(); len(l) != 2 || l[0] != vf1 || l[1] != vf2 {
-		t.Errorf("Unexpected list of VIFs: %v", l)
-	}
-
-	vf1.Close()
-	if l := set1.List(); len(l) != 1 || l[0] != vf2 {
-		t.Errorf("Unexpected list of VIFs: %v", l)
-	}
-	if l := set2.List(); len(l) != 1 || l[0] != vf2 {
-		t.Errorf("Unexpected list of VIFs: %v", l)
-	}
-
-	vf2.Close()
 	if l := set1.List(); len(l) != 0 {
 		t.Errorf("Unexpected list of VIFs: %v", l)
 	}
-	if l := set2.List(); len(l) != 0 {
-		t.Errorf("Unexpected list of VIFs: %v", l)
-	}
 }
 
 func TestBlockingFind(t *testing.T) {
 	network, address := "tcp", "127.0.0.1:1234"
 	set := vif.NewSet()
 
-	set.BlockingFind(network, address)
+	_, unblock := set.BlockingFind(network, address)
 
 	ch := make(chan *vif.VIF, 1)
 
 	// set.BlockingFind should block until set.Unblock is called with the corresponding VIF,
 	// since set.BlockingFind was called earlier.
 	go func(ch chan *vif.VIF) {
-		ch <- set.BlockingFind(network, address)
+		vf, _ := set.BlockingFind(network, address)
+		ch <- vf
 	}(ch)
 
 	// set.BlockingFind for a different network and address should not block.
@@ -351,8 +330,8 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	set.Insert(vf)
-	set.Unblock(network, address)
+	set.Insert(vf, network, address)
+	unblock()
 
 	// Now the set.BlockingFind should have returned the correct vif.
 	if cachedVif := <-ch; cachedVif != vf {
diff --git a/profiles/internal/rpc/stream/vif/vif.go b/profiles/internal/rpc/stream/vif/vif.go
index 2342ddd..fb206b7 100644
--- a/profiles/internal/rpc/stream/vif/vif.go
+++ b/profiles/internal/rpc/stream/vif/vif.go
@@ -120,10 +120,6 @@
 	isClosed   bool // GUARDED_BY(isClosedMu)
 	onClose    func(*VIF)
 
-	// All sets that this VIF is in.
-	muSets sync.Mutex
-	sets   []*Set // GUARDED_BY(muSets)
-
 	// These counters track the number of messages sent and received by
 	// this VIF.
 	muMsgCounters sync.Mutex
@@ -345,27 +341,6 @@
 	return vc, nil
 }
 
-// addSet adds a set to the list of sets this VIF is in. This method is called
-// by Set.Insert().
-func (vif *VIF) addSet(s *Set) {
-	vif.muSets.Lock()
-	defer vif.muSets.Unlock()
-	vif.sets = append(vif.sets, s)
-}
-
-// removeSet removes a set from the list of sets this VIF is in. This method is
-// called by Set.Delete().
-func (vif *VIF) removeSet(s *Set) {
-	vif.muSets.Lock()
-	defer vif.muSets.Unlock()
-	for ix, vs := range vif.sets {
-		if vs == s {
-			vif.sets = append(vif.sets[:ix], vif.sets[ix+1:]...)
-			return
-		}
-	}
-}
-
 // Close closes all VCs (and thereby Flows) over the VIF and then closes the
 // underlying network connection after draining all pending writes on those
 // VCs.
@@ -378,14 +353,6 @@
 	vif.isClosed = true
 	vif.isClosedMu.Unlock()
 
-	vif.muSets.Lock()
-	sets := vif.sets
-	vif.sets = nil
-	vif.muSets.Unlock()
-	for _, s := range sets {
-		s.Delete(vif)
-	}
-
 	vlog.VI(1).Infof("Closing VIF %s", vif)
 	// Stop accepting new VCs.
 	vif.StopAccepting()
diff --git a/profiles/internal/rpc/test/client_test.go b/profiles/internal/rpc/test/client_test.go
index c44a635..9665d0d 100644
--- a/profiles/internal/rpc/test/client_test.go
+++ b/profiles/internal/rpc/test/client_test.go
@@ -411,6 +411,10 @@
 	return mocknet.DialerWithOpts(opts, network, address, timeout)
 }
 
+func simpleResolver(network, address string) (string, string, error) {
+	return network, address, nil
+}
+
 func TestStartCallBadProtocol(t *testing.T) {
 	ctx, shutdown := newCtx()
 	defer shutdown()
@@ -423,7 +427,7 @@
 		logErrors(t, msg, true, false, false, err)
 	}
 
-	rpc.RegisterProtocol("dropData", dropDataDialer, net.Listen)
+	rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, net.Listen)
 
 	// The following test will fail due to a broken connection.
 	// We need to run mount table and servers with no security to use
diff --git a/profiles/internal/testing/mocks/mocknet/mocknet.go b/profiles/internal/testing/mocks/mocknet/mocknet.go
index 0b73e4b..259563c 100644
--- a/profiles/internal/testing/mocks/mocknet/mocknet.go
+++ b/profiles/internal/testing/mocks/mocknet/mocknet.go
@@ -72,7 +72,7 @@
 //  dialer := func(network, address string, timeout time.Duration) (net.Conn, error) {
 //	    return mocknet.DialerWithOpts(mocknet.Opts{UnderlyingProtocol:"tcp"}, network, address, timeout)
 //  }
-// rpc.RegisterProtocol("brkDial", dialer, net.Listen)
+// rpc.RegisterProtocol("brkDial", dialer, resolver, net.Listen)
 //
 func DialerWithOpts(opts Opts, network, address string, timeout time.Duration) (net.Conn, error) {
 	protocol := opts.UnderlyingProtocol
diff --git a/profiles/internal/testing/mocks/mocknet/mocknet_test.go b/profiles/internal/testing/mocks/mocknet/mocknet_test.go
index f65ca4b..91c445e 100644
--- a/profiles/internal/testing/mocks/mocknet/mocknet_test.go
+++ b/profiles/internal/testing/mocks/mocknet/mocknet_test.go
@@ -272,7 +272,11 @@
 		return mocknet.DialerWithOpts(opts, network, address, timeout)
 	}
 
-	rpc.RegisterProtocol("dropControl", dropControlDialer, net.Listen)
+	simpleResolver := func(network, address string) (string, string, error) {
+		return network, address, nil
+	}
+
+	rpc.RegisterProtocol("dropControl", dropControlDialer, simpleResolver, net.Listen)
 
 	server, fn := initServer(t, ctx)
 	defer fn()
diff --git a/profiles/roaming/roaminginit.go b/profiles/roaming/roaminginit.go
index c6b37f6..bff4202 100644
--- a/profiles/roaming/roaminginit.go
+++ b/profiles/roaming/roaminginit.go
@@ -48,7 +48,7 @@
 
 func init() {
 	v23.RegisterProfile(Init)
-	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
 	commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
 }
 
diff --git a/profiles/static/staticinit.go b/profiles/static/staticinit.go
index dc93eae..c18a028 100644
--- a/profiles/static/staticinit.go
+++ b/profiles/static/staticinit.go
@@ -31,7 +31,7 @@
 
 func init() {
 	v23.RegisterProfile(Init)
-	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridListener)
+	rpc.RegisterUnknownProtocol("wsh", websocket.HybridDial, websocket.HybridResolve, websocket.HybridListener)
 	commonFlags = flags.CreateAndRegister(flag.CommandLine, flags.Runtime, flags.Listen)
 }
 
diff --git a/services/agent/internal/unixfd/unixfd.go b/services/agent/internal/unixfd/unixfd.go
index 6d4328d..6497af7 100644
--- a/services/agent/internal/unixfd/unixfd.go
+++ b/services/agent/internal/unixfd/unixfd.go
@@ -35,7 +35,7 @@
 const Network string = "unixfd"
 
 func init() {
-	rpc.RegisterProtocol(Network, unixFDConn, unixFDListen)
+	rpc.RegisterProtocol(Network, unixFDConn, unixFDResolve, unixFDListen)
 }
 
 // singleConnListener implements net.Listener for an already-connected socket.
@@ -136,6 +136,10 @@
 	return c.addr
 }
 
+func unixFDResolve(_, address string) (string, string, error) {
+	return Network, address, nil
+}
+
 func unixFDListen(protocol, address string) (net.Listener, error) {
 	conn, err := unixFDConn(protocol, address, 0)
 	if err != nil {