veyron/runtimes/google/ipc: Restore the check for nil context
to before the context is used.
Change-Id: I84f39d0d4f55f45da4e5eba84f38a0ecc4fecdb1
diff --git a/runtimes/google/ipc/client.go b/runtimes/google/ipc/client.go
index b2bd087..8f7659d 100644
--- a/runtimes/google/ipc/client.go
+++ b/runtimes/google/ipc/client.go
@@ -179,6 +179,24 @@
func (c *client) StartCall(ctx context.T, name, method string, args []interface{}, opts ...ipc.CallOpt) (ipc.Call, error) {
defer vlog.LogCall()()
+ return c.startCall(ctx, name, method, args, opts)
+}
+
+func getNoResolveOpt(opts []ipc.CallOpt) bool {
+ for _, o := range opts {
+ if r, ok := o.(options.NoResolve); ok {
+ return bool(r)
+ }
+ }
+ return false
+}
+
+// startCall ensures StartCall always returns verror.E.
+func (c *client) startCall(ctx context.T, name, method string, args []interface{}, opts []ipc.CallOpt) (ipc.Call, verror.E) {
+ if ctx == nil {
+ return nil, verror.BadArgf("ipc: %s.%s called with nil context", name, method)
+ }
+
// Context specified deadline.
deadline, hasDeadline := ctx.Deadline()
if !hasDeadline {
@@ -196,7 +214,7 @@
break
}
}
- call, err := c.startCall(ctx, name, method, args, opts...)
+ call, err := c.tryCall(ctx, name, method, args, opts)
if err == nil {
return call, nil
}
@@ -208,20 +226,8 @@
return nil, lastErr
}
-func getNoResolveOpt(opts []ipc.CallOpt) bool {
- for _, o := range opts {
- if r, ok := o.(options.NoResolve); ok {
- return bool(r)
- }
- }
- return false
-}
-
-// startCall ensures StartCall always returns verror.E.
-func (c *client) startCall(ctx context.T, name, method string, args []interface{}, opts ...ipc.CallOpt) (ipc.Call, verror.E) {
- if ctx == nil {
- return nil, verror.BadArgf("ipc: %s.%s called with nil context", name, method)
- }
+// tryCall makes a single attempt at a call.
+func (c *client) tryCall(ctx context.T, name, method string, args []interface{}, opts []ipc.CallOpt) (ipc.Call, verror.E) {
ctx, _ = vtrace.WithNewSpan(ctx, fmt.Sprintf("Client Call: %s.%s", name, method))
// Resolve name unless told not to.
var servers []string
diff --git a/runtimes/google/ipc/full_test.go b/runtimes/google/ipc/full_test.go
index 0f0f3a9..4383169 100644
--- a/runtimes/google/ipc/full_test.go
+++ b/runtimes/google/ipc/full_test.go
@@ -1079,6 +1079,23 @@
}
}
+func TestCallWithNilContext(t *testing.T) {
+ sm := imanager.InternalNew(naming.FixedRoutingID(0x66666666))
+ defer sm.Shutdown()
+ ns := tnaming.NewSimpleNamespace()
+ client, err := InternalNewClient(sm, ns, options.VCSecurityNone)
+ if err != nil {
+ t.Fatalf("InternalNewClient failed: %v", err)
+ }
+ call, err := client.StartCall(nil, "foo", "bar", []interface{}{})
+ if call != nil {
+ t.Errorf("Expected nil interface got: %#v", call)
+ }
+ if !verror.Is(err, verror.BadArg) {
+ t.Errorf("Expected a BadArg error, got: %s", err.Error())
+ }
+}
+
func init() {
testutil.Init()
vom.Register(fakeTimeCaveat(0))