Merge "rpc/stream/vif: Concurrent mgr.FindOrDialVIFs should share VIF cache."
diff --git a/profiles/internal/rpc/stream/manager/manager.go b/profiles/internal/rpc/stream/manager/manager.go
index 16e951a..85c081f 100644
--- a/profiles/internal/rpc/stream/manager/manager.go
+++ b/profiles/internal/rpc/stream/manager/manager.go
@@ -125,11 +125,13 @@
// - Similarly, an unspecified IP address (net.IP.IsUnspecified) like "[::]:80"
// might yield "[::1]:80" (loopback interface) in conn.RemoteAddr().
// Thus, look for VIFs with the resolved address as well.
- if vf := m.vifs.Find(conn.RemoteAddr().Network(), conn.RemoteAddr().String()); vf != nil {
- vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache. Closing newly Dialed connection", network, address, conn.RemoteAddr().Network(), conn.RemoteAddr())
+ resNetwork, resAddress := conn.RemoteAddr().Network(), conn.RemoteAddr().String()
+ if vf := m.vifs.BlockingFind(resNetwork, resAddress); vf != nil {
+ vlog.VI(1).Infof("(%q, %q) resolved to (%q, %q) which exists in the VIF cache. Closing newly Dialed connection", network, address, resNetwork, resAddress)
conn.Close()
return vf, nil
}
+ defer m.vifs.Unblock(resNetwork, resAddress)
opts = append([]stream.VCOpt{vc.StartTimeout{defaultStartTimeout}}, opts...)
vf, err := vif.InternalNewDialedVIF(conn, m.rid, principal, nil, m.deleteVIF, opts...)
@@ -137,14 +139,6 @@
conn.Close()
return nil, err
}
- // TODO(ashankar): If two goroutines are simultaneously invoking
- // manager.Dial, it is possible that two VIFs are inserted into m.vifs
- // for the same remote network address. This is normally not a problem,
- // but can be troublesome if the remote endpoint corresponds to a
- // proxy, since the proxy requires a single network connection per
- // routing id. Figure out a way to handle this cleanly. One option is
- // to have only a single VIF per remote network address - have to think
- // that through.
m.vifs.Insert(vf)
return vf, nil
}
diff --git a/profiles/internal/rpc/stream/manager/manager_test.go b/profiles/internal/rpc/stream/manager/manager_test.go
index 5ae22f3..ff74b71 100644
--- a/profiles/internal/rpc/stream/manager/manager_test.go
+++ b/profiles/internal/rpc/stream/manager/manager_test.go
@@ -924,3 +924,53 @@
t.Logf("Server FD limit:%d", nfiles)
t.Logf("Client connection attempts: %d", nattempts)
}
+
+func TestConcurrentDials(t *testing.T) {
+ // Concurrent Dials to the same network, address should only result in one VIF.
+ server := InternalNew(naming.FixedRoutingID(0x55555555))
+ client := InternalNew(naming.FixedRoutingID(0xcccccccc))
+ principal := testutil.NewPrincipal("test")
+
+ // Using "tcp4" instead of "tcp" because the latter can end up with IPv6
+ // addresses and our Google Compute Engine integration test machines cannot
+ // resolve IPv6 addresses.
+ // As of April 2014, https://developers.google.com/compute/docs/networking
+ // said that IPv6 is not yet supported.
+ ln, ep, err := server.Listen("tcp4", "127.0.0.1:0", principal, principal.BlessingStore().Default())
+ if err != nil {
+ t.Fatal(err)
+ }
+ go acceptLoop(ln)
+
+ // We'd like an endpoint that contains an address that's different than the
+ // one used for the connection. In practice this is awkward to achieve since
+ // we don't want to listen on ":0" since that will annoy firewalls. Instead we
+ // listen on 127.0.0.1 and we fabricate an endpoint that doesn't contain
+ // 127.0.0.1 by using ":0" to create it. This leads to an endpoint such that
+ // the address encoded in the endpoint (e.g. "0.0.0.0:55324") is different
+ // from the address of the connection (e.g. "127.0.0.1:55324").
+ _, port, _ := net.SplitHostPort(ep.Addr().String())
+ nep := &inaming.Endpoint{
+ Protocol: ep.Addr().Network(),
+ Address: net.JoinHostPort("", port),
+ RID: ep.RoutingID(),
+ }
+
+ // Dial multiple VCs concurrently. Each goroutine keeps its own err: sharing the outer err would be a data race.
+ errCh := make(chan error, 10)
+ for i := 0; i < 10; i++ {
+ go func() {
+ _, err := client.Dial(nep, testutil.NewPrincipal("client"))
+ errCh <- err
+ }()
+ }
+ for i := 0; i < 10; i++ {
+ if err := <-errCh; err != nil {
+ t.Fatal(err)
+ }
+ }
+ // They should all be on the same VIF.
+ if n := numVIFs(client); n != 1 {
+ t.Errorf("Client has %d VIFs, want 1\n%v", n, debugString(client))
+ }
+}
diff --git a/profiles/internal/rpc/stream/vif/set.go b/profiles/internal/rpc/stream/vif/set.go
index c6629f0..b966735 100644
--- a/profiles/internal/rpc/stream/vif/set.go
+++ b/profiles/internal/rpc/stream/vif/set.go
@@ -17,46 +17,47 @@
// connection. Multiple goroutines can invoke methods on the Set
// simultaneously.
type Set struct {
- mu sync.RWMutex
- set map[string][]*VIF
+ mu sync.RWMutex
+ set map[string][]*VIF // GUARDED_BY(mu)
+ started map[string]bool // GUARDED_BY(mu); keys marked by a BlockingFind miss until Unblock deletes them
+ cond *sync.Cond // created over mu in NewSet; BlockingFind waits on it, Unblock broadcasts
}
// NewSet returns a new Set of VIFs.
func NewSet() *Set {
- return &Set{set: make(map[string][]*VIF)}
+ s := &Set{
+ set: make(map[string][]*VIF),
+ started: make(map[string]bool),
+ }
+ s.cond = sync.NewCond(&s.mu)
+ return s
+}
+
+// BlockingFind returns a VIF where the remote end of the underlying network connection
+// is identified by the provided (network, address). Returns nil if there is no
+// such VIF.
+//
+// If BlockingFind returns nil, the caller is required to call Unblock, to avoid deadlock.
+// The network and address in Unblock must be the same as used in the BlockingFind call.
+// Until then, every other BlockingFind call for the same network and address blocks;
+// once Unblock is called, the waiters re-check the cache and proceed.
+func (s *Set) BlockingFind(network, address string) *VIF {
+ return s.find(network, address, true)
+}
+
+// Unblock clears the mark left by a BlockingFind(network, address) miss and wakes all goroutines waiting in BlockingFind.
+func (s *Set) Unblock(network, address string) {
+ s.mu.Lock()
+ delete(s.started, key(network, address))
+ s.cond.Broadcast()
+ s.mu.Unlock()
 }
// Find returns a VIF where the remote end of the underlying network connection
// is identified by the provided (network, address). Returns nil if there is no
// such VIF.
-//
-// If there are multiple VIFs established to the same remote network address,
-// Find will randomly return one of them.
func (s *Set) Find(network, address string) *VIF {
- if len(address) == 0 ||
- (network == "pipe" && address == "pipe") ||
- (runtime.GOOS == "linux" && network == "unix" && address == "@") { // autobind
- // Some network connections (like those created with net.Pipe or Unix sockets)
- // do not end up with distinct net.Addrs on distinct net.Conns. For those cases,
- // avoid the cache collisions by disabling cache lookups for them.
- return nil
- }
-
- var keys []string
- _, _, p := rpc.RegisteredProtocol(network)
- for _, n := range p {
- keys = append(keys, key(n, address))
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
- for _, k := range keys {
- vifs := s.set[k]
- if len(vifs) > 0 {
- return vifs[rand.Intn(len(vifs))]
- }
- }
- return nil
+ return s.find(network, address, false)
}
// Insert adds a VIF to the set
@@ -106,6 +107,33 @@
return l
}
+func (s *Set) find(network, address string, blocking bool) *VIF {
+ if isNonDistinctConn(network, address) {
+ return nil
+ }
+ // k is the entry in s.started that BlockingFind sets and Unblock deletes for this (network, address).
+ k := key(network, address)
+ // s.cond was created over s.mu, so cond.Wait below releases the lock while this goroutine is parked.
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ // A blocking caller waits for as long as another caller has k marked as started.
+ for blocking && s.started[k] {
+ s.cond.Wait()
+ }
+ // Look for a cached VIF under each name returned by rpc.RegisteredProtocol(network); pick one at random if several exist.
+ _, _, p := rpc.RegisteredProtocol(network)
+ for _, n := range p {
+ if vifs := s.set[key(n, address)]; len(vifs) > 0 {
+ return vifs[rand.Intn(len(vifs))]
+ }
+ }
+ // Cache miss: a blocking caller marks k as started, and it stays marked until Unblock is called.
+ if blocking {
+ s.started[k] = true
+ }
+ return nil
+}
+
func key(network, address string) string {
if network == "tcp" || network == "ws" {
host, _, _ := net.SplitHostPort(address)
@@ -121,3 +149,11 @@
}
return network + ":" + address
}
+
+// isNonDistinctConn reports whether (network, address) is a case (empty address, net.Pipe, or an autobound
+// Unix socket on Linux) where distinct net.Conns do not end up with distinct net.Addrs.
+func isNonDistinctConn(network, address string) bool {
+ return len(address) == 0 ||
+ (network == "pipe" && address == "pipe") ||
+ (runtime.GOOS == "linux" && network == "unix" && address == "@")
+}
diff --git a/profiles/internal/rpc/stream/vif/set_test.go b/profiles/internal/rpc/stream/vif/set_test.go
index 6ab1634..bc95fa5 100644
--- a/profiles/internal/rpc/stream/vif/set_test.go
+++ b/profiles/internal/rpc/stream/vif/set_test.go
@@ -324,3 +324,38 @@
t.Errorf("Unexpected list of VIFs: %v", l)
}
}
+
+func TestBlockingFind(t *testing.T) {
+ network, address := "tcp", "127.0.0.1:1234"
+ set := vif.NewSet()
+
+ set.BlockingFind(network, address)
+
+ ch := make(chan *vif.VIF, 1)
+
+ // This set.BlockingFind should block until set.Unblock(network, address) is called,
+ // since the set.BlockingFind call above missed and left the address marked as started.
+ go func(ch chan *vif.VIF) {
+ ch <- set.BlockingFind(network, address)
+ }(ch)
+
+ // set.BlockingFind for a different network and address should not block.
+ set.BlockingFind("network", "address")
+
+ // Create and insert the VIF.
+ c, s, err := newConn(network, address)
+ if err != nil {
+ t.Fatal(err)
+ }
+ vf, _, err := newVIF(c, s)
+ if err != nil {
+ t.Fatal(err)
+ }
+ set.Insert(vf)
+ set.Unblock(network, address)
+
+ // Now the set.BlockingFind should have returned the correct vif.
+ if cachedVif := <-ch; cachedVif != vf {
+ t.Errorf("got %v, want %v", cachedVif, vf)
+ }
+}