Merge "rpc/stream/vif: Concurrent mgr.FindOrDialVIFs should share VIF cache."
diff --git a/profiles/internal/rpc/stream/manager/manager.go b/profiles/internal/rpc/stream/manager/manager.go
index 16e951a..85c081f 100644
--- a/profiles/internal/rpc/stream/manager/manager.go
+++ b/profiles/internal/rpc/stream/manager/manager.go
@@ -125,11 +125,13 @@
 	// - Similarly, an unspecified IP address (net.IP.IsUnspecified) like "[::]:80"
 	//   might yield "[::1]:80" (loopback interface) in conn.RemoteAddr().
 	// Thus, look for VIFs with the resolved address as well.
-	if vf := m.vifs.Find(conn.RemoteAddr().Network(), conn.RemoteAddr().String()); vf != nil {
-		vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache. Closing newly Dialed connection", network, address, conn.RemoteAddr().Network(), conn.RemoteAddr())
+	resNetwork, resAddress := conn.RemoteAddr().Network(), conn.RemoteAddr().String()
+	if vf := m.vifs.BlockingFind(resNetwork, resAddress); vf != nil { // Waits out any other caller's pending claim on this address.
+		vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache. Closing newly Dialed connection", network, address, resNetwork, resAddress)
 		conn.Close()
 		return vf, nil
 	}
+	defer m.vifs.Unblock(resNetwork, resAddress) // Required after a nil BlockingFind; runs after Insert below and on the error return.
 
 	opts = append([]stream.VCOpt{vc.StartTimeout{defaultStartTimeout}}, opts...)
 	vf, err := vif.InternalNewDialedVIF(conn, m.rid, principal, nil, m.deleteVIF, opts...)
@@ -137,14 +139,6 @@
 		conn.Close()
 		return nil, err
 	}
-	// TODO(ashankar): If two goroutines are simultaneously invoking
-	// manager.Dial, it is possible that two VIFs are inserted into m.vifs
-	// for the same remote network address. This is normally not a problem,
-	// but can be troublesome if the remote endpoint corresponds to a
-	// proxy, since the proxy requires a single network connection per
-	// routing id. Figure out a way to handle this cleanly. One option is
-	// to have only a single VIF per remote network address - have to think
-	// that through.
 	m.vifs.Insert(vf)
 	return vf, nil
 }
diff --git a/profiles/internal/rpc/stream/manager/manager_test.go b/profiles/internal/rpc/stream/manager/manager_test.go
index 5ae22f3..ff74b71 100644
--- a/profiles/internal/rpc/stream/manager/manager_test.go
+++ b/profiles/internal/rpc/stream/manager/manager_test.go
@@ -924,3 +924,53 @@
 	t.Logf("Server FD limit:%d", nfiles)
 	t.Logf("Client connection attempts: %d", nattempts)
 }
+
+func TestConcurrentDials(t *testing.T) {
+	// Concurrent Dials to the same network, address should only result in one VIF.
+	server := InternalNew(naming.FixedRoutingID(0x55555555))
+	client := InternalNew(naming.FixedRoutingID(0xcccccccc))
+	principal := testutil.NewPrincipal("test")
+
+	// Using "tcp4" instead of "tcp" because the latter can end up with IPv6
+	// addresses and our Google Compute Engine integration test machines cannot
+	// resolve IPv6 addresses.
+	// As of April 2014, https://developers.google.com/compute/docs/networking
+	// said that IPv6 is not yet supported.
+	ln, ep, err := server.Listen("tcp4", "127.0.0.1:0", principal, principal.BlessingStore().Default())
+	if err != nil {
+		t.Fatal(err)
+	}
+	go acceptLoop(ln)
+
+	// We'd like an endpoint that contains an address that's different than the
+	// one used for the connection. In practice this is awkward to achieve since
+	// we don't want to listen on ":0" since that will annoy firewalls. Instead we
+	// listen on 127.0.0.1 and we fabricate an endpoint that doesn't contain
+	// 127.0.0.1 by using ":0" to create it. This leads to an endpoint such that
+	// the address encoded in the endpoint (e.g. "0.0.0.0:55324") is different
+	// from the address of the connection (e.g. "127.0.0.1:55324").
+	_, port, _ := net.SplitHostPort(ep.Addr().String())
+	nep := &inaming.Endpoint{
+		Protocol: ep.Addr().Network(),
+		Address:  net.JoinHostPort("", port),
+		RID:      ep.RoutingID(),
+	}
+
+	// Dial multiple VCs
+	errCh := make(chan error, 10) // Buffered so no goroutine blocks on send if we t.Fatal early.
+	for i := 0; i < 10; i++ {
+		go func() {
+			_, err := client.Dial(nep, testutil.NewPrincipal("client")) // Goroutine-local err: sharing the outer one is a data race.
+			errCh <- err
+		}()
+	}
+	for i := 0; i < 10; i++ {
+		if err := <-errCh; err != nil {
+			t.Fatal(err)
+		}
+	}
+	// They should all be on the same VIF.
+	if n := numVIFs(client); n != 1 {
+		t.Errorf("Client has %d VIFs, want 1\n%v", n, debugString(client))
+	}
+}
diff --git a/profiles/internal/rpc/stream/vif/set.go b/profiles/internal/rpc/stream/vif/set.go
index c6629f0..b966735 100644
--- a/profiles/internal/rpc/stream/vif/set.go
+++ b/profiles/internal/rpc/stream/vif/set.go
@@ -17,46 +17,53 @@
 // connection.  Multiple goroutines can invoke methods on the Set
 // simultaneously.
 type Set struct {
-	mu  sync.RWMutex
-	set map[string][]*VIF
+	mu      sync.RWMutex
+	set     map[string][]*VIF // GUARDED_BY(mu)
+	started map[string]bool   // GUARDED_BY(mu); keys claimed by a nil BlockingFind
+	cond    *sync.Cond        // Broadcast by Unblock; its Locker is &mu
 }
 
 // NewSet returns a new Set of VIFs.
 func NewSet() *Set {
-	return &Set{set: make(map[string][]*VIF)}
+	s := &Set{
+		set:     make(map[string][]*VIF),
+		started: make(map[string]bool),
+	}
+	s.cond = sync.NewCond(&s.mu)
+	return s
+}
+
+// BlockingFind returns a VIF where the remote end of the underlying network connection
+// is identified by the provided (network, address). Returns nil if there is no
+// such VIF.
+//
+// If BlockingFind returns nil, the caller is required to call Unblock, to avoid deadlock.
+// The network and address in Unblock must be the same as used in the BlockingFind call.
+// During this time, all new BlockingFind calls for this network and address will block until
+// the corresponding Unblock call is made.
+func (s *Set) BlockingFind(network, address string) *VIF {
+	return s.find(network, address, true)
+}
+
+// Unblock releases the claim on (network, address) taken by a BlockingFind
+// call that returned nil, and wakes all goroutines waiting in BlockingFind so
+// they re-check the set. Insert any newly created VIF before calling Unblock
+// so that the woken callers find it.
+func (s *Set) Unblock(network, address string) {
+	s.mu.Lock()
+	delete(s.started, key(network, address))
+	s.cond.Broadcast()
+	s.mu.Unlock()
 }
 
 // Find returns a VIF where the remote end of the underlying network connection
 // is identified by the provided (network, address). Returns nil if there is no
 // such VIF.
 //
 // If there are multiple VIFs established to the same remote network address,
 // Find will randomly return one of them.
 func (s *Set) Find(network, address string) *VIF {
-	if len(address) == 0 ||
-		(network == "pipe" && address == "pipe") ||
-		(runtime.GOOS == "linux" && network == "unix" && address == "@") { // autobind
-		// Some network connections (like those created with net.Pipe or Unix sockets)
-		// do not end up with distinct net.Addrs on distinct net.Conns. For those cases,
-		// avoid the cache collisions by disabling cache lookups for them.
-		return nil
-	}
-
-	var keys []string
-	_, _, p := rpc.RegisteredProtocol(network)
-	for _, n := range p {
-		keys = append(keys, key(n, address))
-	}
-
-	s.mu.RLock()
-	defer s.mu.RUnlock()
-	for _, k := range keys {
-		vifs := s.set[k]
-		if len(vifs) > 0 {
-			return vifs[rand.Intn(len(vifs))]
-		}
-	}
-	return nil
+	return s.find(network, address, false)
 }
 
 // Insert adds a VIF to the set
@@ -106,6 +113,53 @@
 	return l
 }
 
+// find implements Find (blocking == false) and BlockingFind (blocking == true).
+func (s *Set) find(network, address string, blocking bool) *VIF {
+	if isNonDistinctConn(network, address) {
+		return nil
+	}
+
+	// Candidate cache keys, one per protocol registered for network. They are
+	// computed before locking so rpc.RegisteredProtocol is not called under s.mu.
+	var keys []string
+	_, _, p := rpc.RegisteredProtocol(network)
+	for _, n := range p {
+		keys = append(keys, key(n, address))
+	}
+
+	if !blocking {
+		// A plain lookup only reads s.set, so concurrent Finds share the lock.
+		s.mu.RLock()
+		defer s.mu.RUnlock()
+		return s.pickLocked(keys)
+	}
+
+	k := key(network, address)
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	// Wait while another caller has claimed k and not yet called Unblock.
+	for s.started[k] {
+		s.cond.Wait()
+	}
+	if vf := s.pickLocked(keys); vf != nil {
+		return vf
+	}
+	// No VIF yet: claim k. The caller is now required to call Unblock.
+	s.started[k] = true
+	return nil
+}
+
+// pickLocked returns a randomly chosen VIF stored under the first of keys that
+// has any, or nil if none does. s.mu must be held (for reading or writing).
+func (s *Set) pickLocked(keys []string) *VIF {
+	for _, k := range keys {
+		if vifs := s.set[k]; len(vifs) > 0 {
+			return vifs[rand.Intn(len(vifs))]
+		}
+	}
+	return nil
+}
+
 func key(network, address string) string {
 	if network == "tcp" || network == "ws" {
 		host, _, _ := net.SplitHostPort(address)
@@ -121,3 +175,12 @@
 	}
 	return network + ":" + address
 }
+
+// isNonDistinctConn reports whether (network, address) cannot serve as a cache
+// key because distinct net.Conns may share it: an empty address, net.Pipe
+// connections, and autobound ("@") Unix sockets on Linux.
+func isNonDistinctConn(network, address string) bool {
+	return len(address) == 0 ||
+		(network == "pipe" && address == "pipe") ||
+		(runtime.GOOS == "linux" && network == "unix" && address == "@")
+}
diff --git a/profiles/internal/rpc/stream/vif/set_test.go b/profiles/internal/rpc/stream/vif/set_test.go
index 6ab1634..bc95fa5 100644
--- a/profiles/internal/rpc/stream/vif/set_test.go
+++ b/profiles/internal/rpc/stream/vif/set_test.go
@@ -324,3 +324,38 @@
 		t.Errorf("Unexpected list of VIFs: %v", l)
 	}
 }
+
+func TestBlockingFind(t *testing.T) {
+	network, address := "tcp", "127.0.0.1:1234"
+	set := vif.NewSet()
+
+	set.BlockingFind(network, address) // The set is empty, so this returns nil and claims (network, address).
+
+	ch := make(chan *vif.VIF, 1)
+
+	// This set.BlockingFind must not return until set.Unblock(network, address)
+	// has been called, since the set.BlockingFind above claimed the same pair.
+	go func(ch chan *vif.VIF) {
+		ch <- set.BlockingFind(network, address)
+	}(ch)
+
+	// set.BlockingFind for a different network and address should not block.
+	set.BlockingFind("network", "address")
+
+	// Create and insert the VIF, then release the claim so the goroutine wakes up.
+	c, s, err := newConn(network, address)
+	if err != nil {
+		t.Fatal(err)
+	}
+	vf, _, err := newVIF(c, s)
+	if err != nil {
+		t.Fatal(err)
+	}
+	set.Insert(vf)
+	set.Unblock(network, address)
+
+	// The waiting set.BlockingFind should now have returned the inserted VIF.
+	if cachedVif := <-ch; cachedVif != vf {
+		t.Errorf("got %v, want %v", cachedVif, vf)
+	}
+}