ref: Add ctx to RegisteredProtocol dial, resolve, and listen functions
This will allow us to implement the Proxy protocol as a RegisteredProtocol.
MultiPart: 1/2
Change-Id: I9e828035cc0c6154b1ea7a9bcf675de01ea1d68b
diff --git a/runtime/internal/flow/manager/manager.go b/runtime/internal/flow/manager/manager.go
index 7e395ef..4bdca655 100644
--- a/runtime/internal/flow/manager/manager.go
+++ b/runtime/internal/flow/manager/manager.go
@@ -230,14 +230,14 @@
if dl, ok := ctx.Deadline(); ok {
timeout = dl.Sub(time.Now())
}
- return d(protocol, address, timeout)
+ return d(ctx, protocol, address, timeout)
}
return nil, NewErrUnknownProtocol(ctx, protocol)
}
func resolve(ctx *context.T, r rpc.ResolverFunc, protocol, address string) (string, string, error) {
if r != nil {
- net, addr, err := r(protocol, address)
+ net, addr, err := r(ctx, protocol, address)
if err != nil {
return "", "", err
}
@@ -248,7 +248,7 @@
func listen(ctx *context.T, protocol, address string) (net.Listener, error) {
if _, _, l, _ := rpc.RegisteredProtocol(protocol); l != nil {
- ln, err := l(protocol, address)
+ ln, err := l(ctx, protocol, address)
if err != nil {
return nil, err
}
diff --git a/runtime/internal/lib/websocket/conn_nacl.go b/runtime/internal/lib/websocket/conn_nacl.go
index 479d38f..ca7e16d 100644
--- a/runtime/internal/lib/websocket/conn_nacl.go
+++ b/runtime/internal/lib/websocket/conn_nacl.go
@@ -12,6 +12,8 @@
"runtime/ppapi"
"sync"
"time"
+
+ "v.io/v23/context"
)
// Ppapi instance which must be set before the Dial is called.
@@ -32,7 +34,7 @@
currBuffer []byte
}
-func Dial(protocol, address string, timeout time.Duration) (net.Conn, error) {
+func Dial(ctx *context.T, protocol, address string, timeout time.Duration) (net.Conn, error) {
inst := PpapiInstance
u, err := url.Parse("ws://" + address)
if err != nil {
@@ -46,7 +48,7 @@
return WebsocketConn(address, ws), nil
}
-func Resolve(protocol, address string) (string, string, error) {
+func Resolve(ctx *context.T, protocol, address string) (string, string, error) {
return "ws", address, nil
}
diff --git a/runtime/internal/lib/websocket/conn_test.go b/runtime/internal/lib/websocket/conn_test.go
index 1415afc..292f616 100644
--- a/runtime/internal/lib/websocket/conn_test.go
+++ b/runtime/internal/lib/websocket/conn_test.go
@@ -15,6 +15,8 @@
"time"
"github.com/gorilla/websocket"
+
+ "v.io/v23/context"
)
func writer(c net.Conn, data []byte, times int, wg *sync.WaitGroup) {
@@ -94,7 +96,8 @@
}
// Dial out in another go routine
go func() {
- conn, err := Dial("tcp", addr.String(), time.Second)
+ ctx, _ := context.RootContext()
+ conn, err := Dial(ctx, "tcp", addr.String(), time.Second)
numTries := 0
for err != nil && numTries < 5 {
numTries++
diff --git a/runtime/internal/lib/websocket/dialer.go b/runtime/internal/lib/websocket/dialer.go
index 8e6e173..f7a3b21 100644
--- a/runtime/internal/lib/websocket/dialer.go
+++ b/runtime/internal/lib/websocket/dialer.go
@@ -15,9 +15,11 @@
"github.com/gorilla/websocket"
"v.io/x/ref/runtime/internal/lib/tcputil"
+
+ "v.io/v23/context"
)
-func Dial(protocol, address string, timeout time.Duration) (net.Conn, error) {
+func Dial(ctx *context.T, protocol, address string, timeout time.Duration) (net.Conn, error) {
var then time.Time
if timeout > 0 {
then = time.Now().Add(timeout)
diff --git a/runtime/internal/lib/websocket/hybrid.go b/runtime/internal/lib/websocket/hybrid.go
index 1b28892..fb92d0a 100644
--- a/runtime/internal/lib/websocket/hybrid.go
+++ b/runtime/internal/lib/websocket/hybrid.go
@@ -9,6 +9,8 @@
"time"
"v.io/x/ref/runtime/internal/lib/tcputil"
+
+ "v.io/v23/context"
)
// TODO(jhahn): Figure out a way for this mapping to be shared.
@@ -18,7 +20,7 @@
// always uses tcp. A client must specifically elect to use websockets by
// calling websocket.Dialer. The returned net.Conn will report 'tcp' as its
// Network.
-func HybridDial(network, address string, timeout time.Duration) (net.Conn, error) {
+func HybridDial(ctx *context.T, network, address string, timeout time.Duration) (net.Conn, error) {
tcp := mapWebSocketToTCP[network]
conn, err := net.DialTimeout(tcp, address, timeout)
if err != nil {
@@ -32,7 +34,7 @@
// HybridResolve performs a DNS resolution on the network, address and always
// returns tcp as its Network.
-func HybridResolve(network, address string) (string, string, error) {
+func HybridResolve(ctx *context.T, network, address string) (string, string, error) {
tcp := mapWebSocketToTCP[network]
tcpAddr, err := net.ResolveTCPAddr(tcp, address)
if err != nil {
@@ -49,6 +51,6 @@
// to decide if it's a websocket protocol or not. These must be 'GET ' for
// websockets, all other protocols must guarantee to not send 'GET ' as the
// first four bytes of the payload.
-func HybridListener(protocol, address string) (net.Listener, error) {
+func HybridListener(ctx *context.T, protocol, address string) (net.Listener, error) {
return listener(protocol, address, true)
}
diff --git a/runtime/internal/lib/websocket/listener.go b/runtime/internal/lib/websocket/listener.go
index 1206e8c..f31f096 100644
--- a/runtime/internal/lib/websocket/listener.go
+++ b/runtime/internal/lib/websocket/listener.go
@@ -18,6 +18,8 @@
"v.io/x/ref/internal/logger"
"v.io/x/ref/runtime/internal/lib/tcputil"
+
+ "v.io/v23/context"
)
var errListenerIsClosed = errors.New("Listener has been Closed")
@@ -53,7 +55,7 @@
return conn, nil
}
-func Listener(protocol, address string) (net.Listener, error) {
+func Listener(ctx *context.T, protocol, address string) (net.Listener, error) {
return listener(protocol, address, false)
}
diff --git a/runtime/internal/lib/websocket/listener_nacl.go b/runtime/internal/lib/websocket/listener_nacl.go
index a2e838e..ebf6255 100644
--- a/runtime/internal/lib/websocket/listener_nacl.go
+++ b/runtime/internal/lib/websocket/listener_nacl.go
@@ -9,6 +9,8 @@
import (
"fmt"
"net"
+
+ "v.io/v23/context"
)
// Websocket listeners are not supported in NaCl.
@@ -17,6 +19,6 @@
return nil, fmt.Errorf("Websocket Listener called in nacl code!")
}
-func Listener(protocol, address string) (net.Listener, error) {
+func Listener(ctx *context.T, protocol, address string) (net.Listener, error) {
return nil, fmt.Errorf("Websocket Listener called in nacl code!")
}
diff --git a/runtime/internal/lib/websocket/listener_test.go b/runtime/internal/lib/websocket/listener_test.go
index 477965a..f19ddd3 100644
--- a/runtime/internal/lib/websocket/listener_test.go
+++ b/runtime/internal/lib/websocket/listener_test.go
@@ -13,10 +13,13 @@
"strings"
"testing"
"time"
+
+ "v.io/v23/context"
)
func TestAcceptsAreNotSerialized(t *testing.T) {
- ln, err := HybridListener("wsh", "127.0.0.1:0")
+ ctx, _ := context.RootContext()
+ ln, err := HybridListener(ctx, "wsh", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
@@ -50,7 +53,7 @@
// blocked on the portscanner.
// (Wait for the portscanner to establish the TCP connection first).
<-portscan
- conn, err := Dial(ln.Addr().Network(), ln.Addr().String(), time.Second)
+ conn, err := Dial(ctx, ln.Addr().Network(), ln.Addr().String(), time.Second)
if err != nil {
t.Fatal(err)
}
@@ -58,7 +61,8 @@
}
func TestNonWebsocketRequest(t *testing.T) {
- ln, err := HybridListener("wsh", "127.0.0.1:0")
+ ctx, _ := context.RootContext()
+ ln, err := HybridListener(ctx, "wsh", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
diff --git a/runtime/internal/lib/websocket/resolver.go b/runtime/internal/lib/websocket/resolver.go
index 1286cf6..5a99c23 100644
--- a/runtime/internal/lib/websocket/resolver.go
+++ b/runtime/internal/lib/websocket/resolver.go
@@ -8,10 +8,12 @@
import (
"net"
+
+ "v.io/v23/context"
)
// Resolve performs a DNS resolution on the provided protocol and address.
-func Resolve(protocol, address string) (string, string, error) {
+func Resolve(ctx *context.T, protocol, address string) (string, string, error) {
tcp := mapWebSocketToTCP[protocol]
tcpAddr, err := net.ResolveTCPAddr(tcp, address)
if err != nil {
diff --git a/runtime/internal/lib/websocket/util_test.go b/runtime/internal/lib/websocket/util_test.go
index a0fbbd9..a7466d1 100644
--- a/runtime/internal/lib/websocket/util_test.go
+++ b/runtime/internal/lib/websocket/util_test.go
@@ -15,6 +15,7 @@
"testing"
"time"
+ "v.io/v23/context"
"v.io/v23/rpc"
)
@@ -27,7 +28,8 @@
}
func newSender(t *testing.T, dialer rpc.DialerFunc, protocol, address string) net.Conn {
- conn, err := dialer(protocol, address, time.Minute)
+ ctx, _ := context.RootContext()
+ conn, err := dialer(ctx, protocol, address, time.Minute)
if err != nil {
t.Fatalf("unexpected error: %s", err)
return nil
diff --git a/runtime/internal/lib/websocket/ws_test.go b/runtime/internal/lib/websocket/ws_test.go
index 3940d52..b368cce 100644
--- a/runtime/internal/lib/websocket/ws_test.go
+++ b/runtime/internal/lib/websocket/ws_test.go
@@ -8,14 +8,17 @@
"net"
"sync"
"testing"
+ "time"
+ "v.io/v23/context"
"v.io/v23/rpc"
"v.io/x/ref/runtime/internal/lib/websocket"
)
func packetTester(t *testing.T, dialer rpc.DialerFunc, listener rpc.ListenerFunc, txProtocol, rxProtocol string) {
- ln, err := listener(rxProtocol, "127.0.0.1:0")
+ ctx, _ := context.RootContext()
+ ln, err := listener(ctx, rxProtocol, "127.0.0.1:0")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@@ -29,7 +32,8 @@
}
func byteTester(t *testing.T, dialer rpc.DialerFunc, listener rpc.ListenerFunc, txProtocol, rxProtocol string) {
- ln, err := listener(rxProtocol, "127.0.0.1:0")
+ ctx, _ := context.RootContext()
+ ln, err := listener(ctx, rxProtocol, "127.0.0.1:0")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@@ -43,6 +47,10 @@
}
+func simpleDial(ctx *context.T, p, a string, timeout time.Duration) (net.Conn, error) {
+ return net.DialTimeout(p, a, timeout)
+}
+
func TestWSToWS(t *testing.T) {
byteTester(t, websocket.Dial, websocket.Listener, "ws", "ws")
packetTester(t, websocket.Dial, websocket.Listener, "ws", "ws")
@@ -59,12 +67,13 @@
}
func TestTCPToWSH(t *testing.T) {
- byteTester(t, net.DialTimeout, websocket.HybridListener, "tcp", "wsh")
- packetTester(t, net.DialTimeout, websocket.HybridListener, "tcp", "wsh")
+ byteTester(t, simpleDial, websocket.HybridListener, "tcp", "wsh")
+ packetTester(t, simpleDial, websocket.HybridListener, "tcp", "wsh")
}
func TestMixed(t *testing.T) {
- ln, err := websocket.HybridListener("wsh", "127.0.0.1:0")
+ ctx, _ := context.RootContext()
+ ln, err := websocket.HybridListener(ctx, "wsh", "127.0.0.1:0")
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
@@ -78,7 +87,7 @@
pwg.Add(4)
go packetTest(websocket.Dial, "ws")
- go packetTest(net.DialTimeout, "tcp")
+ go packetTest(simpleDial, "tcp")
go packetTest(websocket.Dial, "ws")
go packetTest(websocket.HybridDial, "wsh")
pwg.Wait()
@@ -90,7 +99,7 @@
}
bwg.Add(4)
go byteTest(websocket.Dial, "ws")
- go byteTest(net.DialTimeout, "tcp")
+ go byteTest(simpleDial, "tcp")
go byteTest(websocket.Dial, "ws")
go byteTest(websocket.HybridDial, "wsh")
diff --git a/runtime/internal/rpc/protocols/tcp/init.go b/runtime/internal/rpc/protocols/tcp/init.go
index 55897a1..a78117f 100644
--- a/runtime/internal/rpc/protocols/tcp/init.go
+++ b/runtime/internal/rpc/protocols/tcp/init.go
@@ -9,6 +9,7 @@
"net"
"time"
+ "v.io/v23/context"
"v.io/v23/rpc"
"v.io/x/ref/runtime/internal/lib/tcputil"
@@ -20,7 +21,7 @@
rpc.RegisterProtocol("tcp6", tcpDial, tcpResolve, tcpListen)
}
-func tcpDial(network, address string, timeout time.Duration) (net.Conn, error) {
+func tcpDial(ctx *context.T, network, address string, timeout time.Duration) (net.Conn, error) {
conn, err := net.DialTimeout(network, address, timeout)
if err != nil {
return nil, err
@@ -32,7 +33,7 @@
}
// tcpResolve performs a DNS resolution on the provided network and address.
-func tcpResolve(network, address string) (string, string, error) {
+func tcpResolve(ctx *context.T, network, address string) (string, string, error) {
tcpAddr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return "", "", err
@@ -41,7 +42,7 @@
}
// tcpListen returns a listener that sets KeepAlive on all accepted connections.
-func tcpListen(network, laddr string) (net.Listener, error) {
+func tcpListen(ctx *context.T, network, laddr string) (net.Listener, error) {
ln, err := net.Listen(network, laddr)
if err != nil {
return nil, err
diff --git a/runtime/internal/rpc/stream/benchmark/throughput_ws.go b/runtime/internal/rpc/stream/benchmark/throughput_ws.go
index 07babce..b72ed1b 100644
--- a/runtime/internal/rpc/stream/benchmark/throughput_ws.go
+++ b/runtime/internal/rpc/stream/benchmark/throughput_ws.go
@@ -10,13 +10,16 @@
"testing"
"v.io/x/ref/runtime/internal/lib/websocket"
+
+ "v.io/v23/context"
)
// benchmarkWS sets up nConns WS connections and measures throughput.
func benchmarkWS(b *testing.B, nConns int) {
+ ctx, _ := context.RootContext()
rchan := make(chan net.Conn, nConns)
wchan := make(chan net.Conn, nConns)
- ln, err := websocket.Listener("ws", "127.0.0.1:0")
+ ln, err := websocket.Listener(ctx, "ws", "127.0.0.1:0")
if err != nil {
b.Fatalf("websocket.Listener failed: %v", err)
return
@@ -25,7 +28,7 @@
// One goroutine to dial nConns connections.
go func() {
for i := 0; i < nConns; i++ {
- conn, err := websocket.Dial("ws", ln.Addr().String(), 0)
+ conn, err := websocket.Dial(ctx, "ws", ln.Addr().String(), 0)
if err != nil {
b.Fatalf("websocket.Dial(%q, %q) failed: %v", "ws", ln.Addr(), err)
wchan <- nil
diff --git a/runtime/internal/rpc/stream/benchmark/throughput_wsh.go b/runtime/internal/rpc/stream/benchmark/throughput_wsh.go
index f160184..517fb2e 100644
--- a/runtime/internal/rpc/stream/benchmark/throughput_wsh.go
+++ b/runtime/internal/rpc/stream/benchmark/throughput_wsh.go
@@ -10,13 +10,16 @@
"testing"
"v.io/x/ref/runtime/internal/lib/websocket"
+
+ "v.io/v23/context"
)
// benchmarkWS sets up nConns WS connections and measures throughput.
func benchmarkWSH(b *testing.B, protocol string, nConns int) {
+ ctx, _ := context.RootContext()
rchan := make(chan net.Conn, nConns)
wchan := make(chan net.Conn, nConns)
- ln, err := websocket.HybridListener("wsh", "127.0.0.1:0")
+ ln, err := websocket.HybridListener(ctx, "wsh", "127.0.0.1:0")
if err != nil {
b.Fatalf("websocket.HybridListener failed: %v", err)
return
@@ -31,7 +34,7 @@
case "tcp":
conn, err = net.Dial("tcp", ln.Addr().String())
case "ws":
- conn, err = websocket.Dial("ws", ln.Addr().String(), 0)
+ conn, err = websocket.Dial(ctx, "ws", ln.Addr().String(), 0)
}
if err != nil {
b.Fatalf("Dial(%q, %q) failed: %v", protocol, ln.Addr(), err)
diff --git a/runtime/internal/rpc/stream/manager/error_test.go b/runtime/internal/rpc/stream/manager/error_test.go
index 1b72577..78d2fc0 100644
--- a/runtime/internal/rpc/stream/manager/error_test.go
+++ b/runtime/internal/rpc/stream/manager/error_test.go
@@ -10,6 +10,7 @@
"time"
"v.io/v23"
+ "v.io/v23/context"
"v.io/v23/naming"
"v.io/v23/rpc"
"v.io/v23/security"
@@ -79,7 +80,7 @@
}
}
-func dropDataDialer(network, address string, timeout time.Duration) (net.Conn, error) {
+func dropDataDialer(ctx *context.T, network, address string, timeout time.Duration) (net.Conn, error) {
matcher := func(read bool, msg message.T) bool {
switch msg.(type) {
case *message.Setup:
@@ -94,10 +95,14 @@
return mocknet.DialerWithOpts(opts, network, address, timeout)
}
-func simpleResolver(network, address string) (string, string, error) {
+func simpleResolver(ctx *context.T, network, address string) (string, string, error) {
return network, address, nil
}
+func simpleListen(ctx *context.T, network, address string) (net.Listener, error) {
+ return net.Listen(network, address)
+}
+
func TestDialErrors(t *testing.T) {
ctx, shutdown := test.V23Init()
defer shutdown()
@@ -125,7 +130,7 @@
}
t.Log(err)
- rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, net.Listen)
+ rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, simpleListen)
ln, sep, err := server.Listen(sctx, "tcp", "127.0.0.1:0", pserver.BlessingStore().Default())
if err != nil {
diff --git a/runtime/internal/rpc/stream/manager/manager.go b/runtime/internal/rpc/stream/manager/manager.go
index 7b3a7ad..4ddef8c 100644
--- a/runtime/internal/rpc/stream/manager/manager.go
+++ b/runtime/internal/rpc/stream/manager/manager.go
@@ -91,26 +91,26 @@
defer apilog.LogCall(nil)(nil) // gologcop: DO NOT EDIT, MUST BE FIRST STATEMENT
}
-func dial(d rpc.DialerFunc, network, address string, timeout time.Duration) (net.Conn, error) {
+func dial(ctx *context.T, d rpc.DialerFunc, network, address string, timeout time.Duration) (net.Conn, error) {
if d != nil {
- conn, err := d(network, address, timeout)
+ conn, err := d(ctx, network, address, timeout)
if err != nil {
- return nil, verror.New(stream.ErrDialFailed, nil, err)
+ return nil, verror.New(stream.ErrDialFailed, ctx, err)
}
return conn, nil
}
- return nil, verror.New(stream.ErrDialFailed, nil, verror.New(errUnknownNetwork, nil, network))
+ return nil, verror.New(stream.ErrDialFailed, ctx, verror.New(errUnknownNetwork, ctx, network))
}
-func resolve(r rpc.ResolverFunc, network, address string) (string, string, error) {
+func resolve(ctx *context.T, r rpc.ResolverFunc, network, address string) (string, string, error) {
if r != nil {
- net, addr, err := r(network, address)
+ net, addr, err := r(ctx, network, address)
if err != nil {
- return "", "", verror.New(stream.ErrResolveFailed, nil, err)
+ return "", "", verror.New(stream.ErrResolveFailed, ctx, err)
}
return net, addr, nil
}
- return "", "", verror.New(stream.ErrResolveFailed, nil, verror.New(errUnknownNetwork, nil, network))
+ return "", "", verror.New(stream.ErrResolveFailed, ctx, verror.New(errUnknownNetwork, ctx, network))
}
type dialResult struct {
@@ -139,7 +139,7 @@
// - Similarly, an unspecified IP address (net.IP.IsUnspecified) like "[::]:80"
// might yield "[::1]:80" (loopback interface) in conn.RemoteAddr().
// Thus, look for VIFs with the resolved address.
- network, address, err := resolve(r, addr.Network(), addr.String())
+ network, address, err := resolve(ctx, r, addr.Network(), addr.String())
if err != nil {
return nil, err
}
@@ -154,7 +154,7 @@
ch := make(chan *dialResult)
go func() {
- conn, err := dial(d, network, address, timeout)
+ conn, err := dial(ctx, d, network, address, timeout)
ch <- &dialResult{conn, err}
}()
@@ -202,15 +202,15 @@
m.vifs.Delete(vf)
}
-func listen(protocol, address string) (net.Listener, error) {
+func listen(ctx *context.T, protocol, address string) (net.Listener, error) {
if _, _, l, _ := rpc.RegisteredProtocol(protocol); l != nil {
- ln, err := l(protocol, address)
+ ln, err := l(ctx, protocol, address)
if err != nil {
- return nil, verror.New(stream.ErrNetwork, nil, err)
+ return nil, verror.New(stream.ErrNetwork, ctx, err)
}
return ln, nil
}
- return nil, verror.New(stream.ErrBadArg, nil, verror.New(errUnknownNetwork, nil, protocol))
+ return nil, verror.New(stream.ErrBadArg, ctx, verror.New(errUnknownNetwork, ctx, protocol))
}
func (m *manager) Listen(ctx *context.T, protocol, address string, blessings security.Blessings, opts ...stream.ListenerOpt) (stream.Listener, naming.Endpoint, error) {
@@ -243,7 +243,7 @@
}
return m.remoteListen(ctx, ep, opts)
}
- netln, err := listen(protocol, address)
+ netln, err := listen(ctx, protocol, address)
if err != nil {
return nil, nil, err
}
diff --git a/runtime/internal/rpc/stream/manager/manager_test.go b/runtime/internal/rpc/stream/manager/manager_test.go
index 22f6a61..12f7bd5 100644
--- a/runtime/internal/rpc/stream/manager/manager_test.go
+++ b/runtime/internal/rpc/stream/manager/manager_test.go
@@ -21,6 +21,7 @@
"time"
"v.io/v23"
+ "v.io/v23/context"
"v.io/v23/naming"
"v.io/v23/rpc"
"v.io/v23/security"
@@ -847,13 +848,13 @@
blessings := principal.BlessingStore().Default()
ctx, _ = v23.WithPrincipal(ctx, principal)
- dialer := func(_, _ string, _ time.Duration) (net.Conn, error) {
+ dialer := func(_ *context.T, _, _ string, _ time.Duration) (net.Conn, error) {
return nil, fmt.Errorf("tn.Dial")
}
- resolver := func(_, _ string) (string, string, error) {
+ resolver := func(_ *context.T, _, _ string) (string, string, error) {
return "", "", fmt.Errorf("tn.Resolve")
}
- listener := func(_, _ string) (net.Listener, error) {
+ listener := func(_ *context.T, _, _ string) (net.Listener, error) {
return nil, fmt.Errorf("tn.Listen")
}
rpc.RegisterProtocol("tn", dialer, resolver, listener)
@@ -869,7 +870,7 @@
}
// Need a functional listener to test Dial.
- listener = func(_, addr string) (net.Listener, error) {
+ listener = func(_ *context.T, _, addr string) (net.Listener, error) {
return net.Listen("tcp", addr)
}
diff --git a/runtime/internal/rpc/stream/proxy/proxy.go b/runtime/internal/rpc/stream/proxy/proxy.go
index 5d30d3c..ee68f73 100644
--- a/runtime/internal/rpc/stream/proxy/proxy.go
+++ b/runtime/internal/rpc/stream/proxy/proxy.go
@@ -234,7 +234,7 @@
if listenFn == nil {
return nil, verror.New(stream.ErrProxy, nil, verror.New(errUnknownNetwork, nil, network))
}
- ln, err := listenFn(network, address)
+ ln, err := listenFn(ctx, network, address)
if err != nil {
return nil, verror.New(stream.ErrProxy, nil, verror.New(errListenFailed, nil, network, address, err))
}
diff --git a/runtime/internal/rpc/stream/vif/set_test.go b/runtime/internal/rpc/stream/vif/set_test.go
index 5ecf5d0..b79d8d5 100644
--- a/runtime/internal/rpc/stream/vif/set_test.go
+++ b/runtime/internal/rpc/stream/vif/set_test.go
@@ -30,10 +30,16 @@
var supportsIPv6 bool
func init() {
- simpleResolver := func(network, address string) (string, string, error) {
+ simpleResolver := func(ctx *context.T, network, address string) (string, string, error) {
return network, address, nil
}
- rpc.RegisterProtocol("unix", net.DialTimeout, simpleResolver, net.Listen)
+ simpleDial := func(ctx *context.T, p, a string, timeout time.Duration) (net.Conn, error) {
+ return net.DialTimeout(p, a, timeout)
+ }
+ simpleListen := func(ctx *context.T, p, a string) (net.Listener, error) {
+ return net.Listen(p, a)
+ }
+ rpc.RegisterProtocol("unix", simpleDial, simpleResolver, simpleListen)
// Check whether the platform supports IPv6.
ln, err := net.Listen("tcp6", "[::1]:0")
@@ -45,7 +51,7 @@
func newConn(network, address string) (net.Conn, net.Conn, error) {
dfunc, _, lfunc, _ := rpc.RegisteredProtocol(network)
- ln, err := lfunc(network, address)
+ ln, err := lfunc(nil, network, address)
if err != nil {
return nil, nil, err
}
@@ -61,7 +67,7 @@
done <- conn
}()
- conn, err := dfunc(ln.Addr().Network(), ln.Addr().String(), 1*time.Second)
+ conn, err := dfunc(nil, ln.Addr().Network(), ln.Addr().String(), 1*time.Second)
if err != nil {
return nil, nil, err
}
diff --git a/runtime/internal/rpc/test/client_test.go b/runtime/internal/rpc/test/client_test.go
index 85db6ea..4bcdaf3 100644
--- a/runtime/internal/rpc/test/client_test.go
+++ b/runtime/internal/rpc/test/client_test.go
@@ -377,7 +377,7 @@
logErr("timeout to server", err)
}
-func dropDataDialer(network, address string, timeout time.Duration) (net.Conn, error) {
+func dropDataDialer(ctx *context.T, network, address string, timeout time.Duration) (net.Conn, error) {
matcher := func(read bool, msg message.T) bool {
// Drop and close the connection when reading the first data message.
if _, ok := msg.(*message.Data); ok && read {
@@ -392,7 +392,7 @@
return mocknet.DialerWithOpts(opts, network, address, timeout)
}
-func simpleResolver(network, address string) (string, string, error) {
+func simpleResolver(ctx *context.T, network, address string) (string, string, error) {
return network, address, nil
}
@@ -407,7 +407,11 @@
logErrors(t, msg, true, false, false, err)
}
- rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, net.Listen)
+ simpleListen := func(ctx *context.T, protocol, address string) (net.Listener, error) {
+ return net.Listen(protocol, address)
+ }
+
+ rpc.RegisterProtocol("dropData", dropDataDialer, simpleResolver, simpleListen)
// The following test will fail due to a broken connection.
// We need to run mount table and servers with no security to use
diff --git a/runtime/internal/testing/mocks/mocknet/mocknet_test.go b/runtime/internal/testing/mocks/mocknet/mocknet_test.go
index 5839ec8..f545b56 100644
--- a/runtime/internal/testing/mocks/mocknet/mocknet_test.go
+++ b/runtime/internal/testing/mocks/mocknet/mocknet_test.go
@@ -355,7 +355,7 @@
return true
}
- dropControlDialer := func(network, address string, timeout time.Duration) (net.Conn, error) {
+ dropControlDialer := func(ctx *context.T, network, address string, timeout time.Duration) (net.Conn, error) {
opts := mocknet.Opts{
Mode: mocknet.V23CloseAtMessage,
V23MessageMatcher: matcher,
@@ -363,11 +363,15 @@
return mocknet.DialerWithOpts(opts, network, address, timeout)
}
- simpleResolver := func(network, address string) (string, string, error) {
+ simpleResolver := func(ctx *context.T, network, address string) (string, string, error) {
return network, address, nil
}
- rpc.RegisterProtocol("dropControl", dropControlDialer, simpleResolver, net.Listen)
+ simpleListen := func(ctx *context.T, network, address string) (net.Listener, error) {
+ return net.Listen(network, address)
+ }
+
+ rpc.RegisterProtocol("dropControl", dropControlDialer, simpleResolver, simpleListen)
server, fn := initServer(t, ctx)
defer fn()
diff --git a/services/agent/internal/unixfd/unixfd.go b/services/agent/internal/unixfd/unixfd.go
index 6497af7..025ccaf 100644
--- a/services/agent/internal/unixfd/unixfd.go
+++ b/services/agent/internal/unixfd/unixfd.go
@@ -17,6 +17,7 @@
"time"
"unsafe"
+ "v.io/v23/context"
"v.io/v23/rpc"
"v.io/v23/verror"
)
@@ -84,7 +85,7 @@
return l.addr
}
-func unixFDConn(protocol, address string, timeout time.Duration) (net.Conn, error) {
+func unixFDConn(ctx *context.T, protocol, address string, timeout time.Duration) (net.Conn, error) {
// TODO(cnicolaou): have this respect the timeout. Possibly have a helper
// function that can be generally used for this, but in practice, I think
// it'll be cleaner to use the underlying protocol's deadline support of it
@@ -136,12 +137,12 @@
return c.addr
}
-func unixFDResolve(_, address string) (string, string, error) {
+func unixFDResolve(ctx *context.T, _, address string) (string, string, error) {
return Network, address, nil
}
-func unixFDListen(protocol, address string) (net.Listener, error) {
- conn, err := unixFDConn(protocol, address, 0)
+func unixFDListen(ctx *context.T, protocol, address string) (net.Listener, error) {
+ conn, err := unixFDConn(ctx, protocol, address, 0)
if err != nil {
return nil, err
}
diff --git a/services/agent/internal/unixfd/unixfd_test.go b/services/agent/internal/unixfd/unixfd_test.go
index 10630ea..06f7010 100644
--- a/services/agent/internal/unixfd/unixfd_test.go
+++ b/services/agent/internal/unixfd/unixfd_test.go
@@ -10,19 +10,23 @@
"net"
"reflect"
"testing"
+
+ "v.io/v23/context"
)
type nothing struct{}
func dial(fd *fileDescriptor) (net.Conn, net.Addr, error) {
addr := fd.releaseAddr()
- conn, err := unixFDConn(Network, addr.String(), 0)
+ ctx, _ := context.RootContext()
+ conn, err := unixFDConn(ctx, Network, addr.String(), 0)
return conn, addr, err
}
func listen(fd *fileDescriptor) (net.Listener, net.Addr, error) {
addr := fd.releaseAddr()
- l, err := unixFDListen(Network, addr.String())
+ ctx, _ := context.RootContext()
+ l, err := unixFDListen(ctx, Network, addr.String())
return l, addr, err
}
@@ -169,11 +173,12 @@
t.Fatalf("unexpected data %q", data)
}
- a, err := unixFDConn(Network, caddr.String(), 0)
+ ctx, _ := context.RootContext()
+ a, err := unixFDConn(ctx, Network, caddr.String(), 0)
if err != nil {
t.Fatalf("dial %v: %v", caddr, err)
}
- b, err := unixFDConn(Network, saddr.String(), 0)
+ b, err := unixFDConn(ctx, Network, saddr.String(), 0)
if err != nil {
t.Fatalf("dial %v: %v", saddr, err)
}