v.io/x/jni: support for executor changes

MultiPart: 2/2
Change-Id: I850942127456dfefe7afc6e9513d52162911ed15
diff --git a/impl/google/channel/jni.go b/impl/google/channel/jni.go
index 81ede12..852fbf1 100644
--- a/impl/google/channel/jni.go
+++ b/impl/google/channel/jni.go
@@ -16,6 +16,7 @@
 import "C"
 
 var (
+	contextSign = jutil.ClassSign("io.v.v23.context.VContext")
 	// Global reference for io.v.impl.google.channel.InputChannelImpl class.
 	jInputChannelImplClass jutil.Class
 	// Global reference for io.v.impl.google.channel.OutputChannelImpl class.
diff --git a/impl/google/channel/util.go b/impl/google/channel/util.go
index 486307a..c307ae6 100644
--- a/impl/google/channel/util.go
+++ b/impl/google/channel/util.go
@@ -7,7 +7,10 @@
 package channel
 
 import (
+	"v.io/v23/context"
+
 	jutil "v.io/x/jni/util"
+	jcontext "v.io/x/jni/v23/context"
 )
 
 // #include "jni.h"
@@ -19,8 +22,12 @@
 //
 // The recv function must return verror.ErrEndOfFile when there are no more elements
 // to receive.
-func JavaInputChannel(env jutil.Env, recv func() (jutil.Object, error)) (jutil.Object, error) {
-	jInputChannel, err := jutil.NewObject(env, jInputChannelImplClass, []jutil.Sign{jutil.LongSign}, int64(jutil.PtrValue(&recv)))
+func JavaInputChannel(env jutil.Env, ctx *context.T, ctxCancel func(), recv func() (jutil.Object, error)) (jutil.Object, error) {
+	jContext, err := jcontext.JavaContext(env, ctx, ctxCancel)
+	if err != nil {
+		return jutil.NullObject, err
+	}
+	jInputChannel, err := jutil.NewObject(env, jInputChannelImplClass, []jutil.Sign{contextSign, jutil.LongSign}, jContext, int64(jutil.PtrValue(&recv)))
 	if err != nil {
 		return jutil.NullObject, err
 	}
@@ -30,8 +37,12 @@
 
 // JavaOutputChannel creates a new Java OutputChannel object given the provided Go convert, send
 // and close functions. Send is invoked with the result of convert, which must be non-blocking.
-func JavaOutputChannel(env jutil.Env, convert func(jutil.Object) (interface{}, error), send func(interface{}) error, close func() error) (jutil.Object, error) {
-	jOutputChannel, err := jutil.NewObject(env, jOutputChannelImplClass, []jutil.Sign{jutil.LongSign, jutil.LongSign, jutil.LongSign}, int64(jutil.PtrValue(&convert)), int64(jutil.PtrValue(&send)), int64(jutil.PtrValue(&close)))
+func JavaOutputChannel(env jutil.Env, ctx *context.T, ctxCancel func(), convert func(jutil.Object) (interface{}, error), send func(interface{}) error, close func() error) (jutil.Object, error) {
+	jContext, err := jcontext.JavaContext(env, ctx, ctxCancel)
+	if err != nil {
+		return jutil.NullObject, err
+	}
+	jOutputChannel, err := jutil.NewObject(env, jOutputChannelImplClass, []jutil.Sign{contextSign, jutil.LongSign, jutil.LongSign, jutil.LongSign}, jContext, int64(jutil.PtrValue(&convert)), int64(jutil.PtrValue(&send)), int64(jutil.PtrValue(&close)))
 	if err != nil {
 		return jutil.NullObject, err
 	}
diff --git a/impl/google/discovery/jni.go b/impl/google/discovery/jni.go
index 0a6d9d5..6a0b4d2 100644
--- a/impl/google/discovery/jni.go
+++ b/impl/google/discovery/jni.go
@@ -186,7 +186,7 @@
 //export Java_io_v_impl_google_lib_discovery_VDiscoveryImpl_nativeScan
 func Java_io_v_impl_google_lib_discovery_VDiscoveryImpl_nativeScan(jenv *C.JNIEnv, jDiscovery C.jobject, goDiscoveryPtr C.jlong, jContext C.jobject, jQuery C.jstring) C.jobject {
 	env := jutil.Env(uintptr(unsafe.Pointer(jenv)))
-	ctx, _, err := jcontext.GoContext(env, jutil.Object(uintptr(unsafe.Pointer(jContext))))
+	ctx, cancel, err := jcontext.GoContext(env, jutil.Object(uintptr(unsafe.Pointer(jContext))))
 	if err != nil {
 		jutil.JThrowV(env, err)
 		return nil
@@ -201,7 +201,7 @@
 		scanChannel, scanError = ds.Scan(ctx, query)
 		close(scanDone)
 	}()
-	jChannel, err := jchannel.JavaInputChannel(env, func() (jutil.Object, error) {
+	jChannel, err := jchannel.JavaInputChannel(env, ctx, cancel, func() (jutil.Object, error) {
 		// A few blocking calls below - don't call GetEnv() before they complete.
 		<-scanDone
 		if scanError != nil {
diff --git a/impl/google/namespace/jni.go b/impl/google/namespace/jni.go
index e6ad0f1..a12d963 100644
--- a/impl/google/namespace/jni.go
+++ b/impl/google/namespace/jni.go
@@ -58,8 +58,8 @@
 	return nil
 }
 
-func globArgs(env jutil.Env, jContext C.jobject, jPattern C.jstring, jOptions C.jobject) (context *context.T, pattern string, opts []naming.NamespaceOpt, err error) {
-	context, _, err = jcontext.GoContext(env, jutil.Object(uintptr(unsafe.Pointer(jContext))))
+func globArgs(env jutil.Env, jContext C.jobject, jPattern C.jstring, jOptions C.jobject) (context *context.T, cancel func(), pattern string, opts []naming.NamespaceOpt, err error) {
+	context, cancel, err = jcontext.GoContext(env, jutil.Object(uintptr(unsafe.Pointer(jContext))))
 	if err != nil {
 		return
 	}
@@ -75,7 +75,7 @@
 func Java_io_v_impl_google_namespace_NamespaceImpl_nativeGlob(jenv *C.JNIEnv, jNamespaceClass C.jclass, goNamespacePtr C.jlong, jContext C.jobject, jPattern C.jstring, jOptions C.jobject) C.jobject {
 	env := jutil.Env(uintptr(unsafe.Pointer(jenv)))
 	n := *(*namespace.T)(jutil.NativePtr(goNamespacePtr))
-	context, pattern, opts, err := globArgs(env, jContext, jPattern, jOptions)
+	ctx, cancel, pattern, opts, err := globArgs(env, jContext, jPattern, jOptions)
 	if err != nil {
 		jutil.JThrowV(env, err)
 		return nil
@@ -84,10 +84,10 @@
 	var globError error
 	globDone := make(chan bool)
 	go func() {
-		globChannel, globError = n.Glob(context, pattern, opts...)
+		globChannel, globError = n.Glob(ctx, pattern, opts...)
 		close(globDone)
 	}()
-	jChannel, err := jchannel.JavaInputChannel(env, func() (jutil.Object, error) {
+	jChannel, err := jchannel.JavaInputChannel(env, ctx, cancel, func() (jutil.Object, error) {
 		// A few blocking calls below - don't call GetEnv() before they complete.
 		<-globDone
 		if globError != nil {
@@ -95,7 +95,7 @@
 		}
 		globReply, ok := <-globChannel
 		if !ok {
-			return jutil.NullObject, verror.NewErrEndOfFile(context)
+			return jutil.NullObject, verror.NewErrEndOfFile(ctx)
 		}
 		env, freeFunc := jutil.GetEnv()
 		defer freeFunc()
diff --git a/impl/google/rpc/invoker.go b/impl/google/rpc/invoker.go
index e488c9f..098a341 100644
--- a/impl/google/rpc/invoker.go
+++ b/impl/google/rpc/invoker.go
@@ -94,7 +94,7 @@
 		freeFunc()
 		return nil, err
 	}
-	jStreamServerCall, err := javaStreamServerCall(env, call)
+	jStreamServerCall, err := javaStreamServerCall(env, jContext, call)
 	if err != nil {
 		freeFunc()
 		return nil, err
@@ -222,7 +222,7 @@
 	close := func() error {
 		return nil
 	}
-	jOutputChannel, err := jchannel.JavaOutputChannel(env, convert, send, close)
+	jOutputChannel, err := jchannel.JavaOutputChannel(env, ctx, nil, convert, send, close)
 	if err != nil {
 		return err
 	}
diff --git a/impl/google/rpc/jni.go b/impl/google/rpc/jni.go
index e5e5c03..d9ae2ef 100644
--- a/impl/google/rpc/jni.go
+++ b/impl/google/rpc/jni.go
@@ -242,7 +242,7 @@
 	return args, nil
 }
 
-func doStartCall(context *context.T, name, method string, skipServerAuth bool, goPtr C.jlong, args []interface{}) (jutil.Object, error) {
+func doStartCall(ctx *context.T, cancel func(), name, method string, skipServerAuth bool, goPtr C.jlong, args []interface{}) (jutil.Object, error) {
 	var opts []rpc.CallOpt
 	if skipServerAuth {
 		opts = append(opts,
@@ -250,13 +250,17 @@
 			options.ServerAuthorizer{security.AllowEveryone()})
 	}
 	// Invoke StartCall
-	call, err := (*(*rpc.Client)(jutil.NativePtr(goPtr))).StartCall(context, name, method, args, opts...)
+	call, err := (*(*rpc.Client)(jutil.NativePtr(goPtr))).StartCall(ctx, name, method, args, opts...)
 	if err != nil {
 		return jutil.NullObject, err
 	}
 	env, freeFunc := jutil.GetEnv()
 	defer freeFunc()
-	jCall, err := javaCall(env, call)
+	jContext, err := jcontext.JavaContext(env, ctx, cancel)
+	if err != nil {
+		return jutil.NullObject, err
+	}
+	jCall, err := javaCall(env, jContext, call)
 	if err != nil {
 		return jutil.NullObject, err
 	}
@@ -271,7 +275,7 @@
 	name := jutil.GoString(env, jutil.Object(uintptr(unsafe.Pointer(jName))))
 	method := jutil.GoString(env, jutil.Object(uintptr(unsafe.Pointer(jMethod))))
 	jCallback := jutil.Object(uintptr(unsafe.Pointer(jCallbackObj)))
-	context, _, err := jcontext.GoContext(env, jutil.Object(uintptr(unsafe.Pointer(jContext))))
+	ctx, cancel, err := jcontext.GoContext(env, jutil.Object(uintptr(unsafe.Pointer(jContext))))
 	if err != nil {
 		jutil.CallbackOnFailure(env, jCallback, err)
 		return
@@ -283,7 +287,7 @@
 	}
 	skipServerAuth := jSkipServerAuth == C.JNI_TRUE
 	jutil.DoAsyncCall(env, jCallback, func() (jutil.Object, error) {
-		return doStartCall(context, name, method, skipServerAuth, goPtr, args)
+		return doStartCall(ctx, cancel, name, method, skipServerAuth, goPtr, args)
 	})
 }
 
diff --git a/impl/google/rpc/util.go b/impl/google/rpc/util.go
index b87812a..c534de2 100644
--- a/impl/google/rpc/util.go
+++ b/impl/google/rpc/util.go
@@ -12,6 +12,7 @@
 	"runtime"
 
 	"v.io/v23/rpc"
+
 	jutil "v.io/x/jni/util"
 )
 
@@ -46,11 +47,11 @@
 
 // javaStreamServerCall converts the provided Go serverCall into a Java StreamServerCall
 // object.
-func javaStreamServerCall(env jutil.Env, call rpc.StreamServerCall) (jutil.Object, error) {
+func javaStreamServerCall(env jutil.Env, jContext jutil.Object, call rpc.StreamServerCall) (jutil.Object, error) {
 	if call == nil {
 		return jutil.NullObject, fmt.Errorf("Go StreamServerCall value cannot be nil")
 	}
-	jStream, err := javaStream(env, call)
+	jStream, err := javaStream(env, jContext, call)
 	if err != nil {
 		return jutil.NullObject, err
 	}
@@ -68,15 +69,15 @@
 }
 
 // javaCall converts the provided Go Call value into a Java Call object.
-func javaCall(env jutil.Env, call rpc.ClientCall) (jutil.Object, error) {
+func javaCall(env jutil.Env, jContext jutil.Object, call rpc.ClientCall) (jutil.Object, error) {
 	if call == nil {
 		return jutil.NullObject, fmt.Errorf("Go Call value cannot be nil")
 	}
-	jStream, err := javaStream(env, call)
+	jStream, err := javaStream(env, jContext, call)
 	if err != nil {
 		return jutil.NullObject, err
 	}
-	jCall, err := jutil.NewObject(env, jClientCallImplClass, []jutil.Sign{jutil.LongSign, streamSign}, int64(jutil.PtrValue(&call)), jStream)
+	jCall, err := jutil.NewObject(env, jClientCallImplClass, []jutil.Sign{contextSign, jutil.LongSign, streamSign}, jContext, int64(jutil.PtrValue(&call)), jStream)
 	if err != nil {
 		return jutil.NullObject, err
 	}
@@ -85,8 +86,8 @@
 }
 
 // javaStream converts the provided Go stream into a Java Stream object.
-func javaStream(env jutil.Env, stream rpc.Stream) (jutil.Object, error) {
-	jStream, err := jutil.NewObject(env, jStreamImplClass, []jutil.Sign{jutil.LongSign}, int64(jutil.PtrValue(&stream)))
+func javaStream(env jutil.Env, jContext jutil.Object, stream rpc.Stream) (jutil.Object, error) {
+	jStream, err := jutil.NewObject(env, jStreamImplClass, []jutil.Sign{contextSign, jutil.LongSign}, jContext, int64(jutil.PtrValue(&stream)))
 	if err != nil {
 		return jutil.NullObject, err
 	}
@@ -175,11 +176,7 @@
 	default:
 		return jutil.NullObject, fmt.Errorf("Unrecognized state: %d", state)
 	}
-	jState, err := jutil.CallStaticObjectMethod(env, jServerStateClass, "valueOf", []jutil.Sign{jutil.StringSign}, serverStateSign, name)
-	if err != nil {
-		return jutil.NullObject, err
-	}
-	return jState, nil
+	return jutil.CallStaticObjectMethod(env, jServerStateClass, "valueOf", []jutil.Sign{jutil.StringSign}, serverStateSign, name)
 }
 
 // JavaMountStatus converts the provided rpc.MountStatus value into a Java
diff --git a/impl/google/rt/jni.go b/impl/google/rt/jni.go
index 90c1855..ee6eeb2 100644
--- a/impl/google/rt/jni.go
+++ b/impl/google/rt/jni.go
@@ -61,16 +61,22 @@
 }
 
 //export Java_io_v_impl_google_rt_VRuntimeImpl_nativeShutdown
-func Java_io_v_impl_google_rt_VRuntimeImpl_nativeShutdown(jenv *C.JNIEnv, jRuntime C.jclass, jContext C.jobject) {
+func Java_io_v_impl_google_rt_VRuntimeImpl_nativeShutdown(jenv *C.JNIEnv, jRuntime C.jclass, jContext C.jobject, jCallbackObj C.jobject) {
 	env := jutil.Env(uintptr(unsafe.Pointer(jenv)))
+	jCallback := jutil.Object(uintptr(unsafe.Pointer(jCallbackObj)))
 	ctx, _, err := jcontext.GoContext(env, jutil.Object(uintptr(unsafe.Pointer(jContext))))
 	if err != nil {
 		jutil.JThrowV(env, err)
 	}
 	value := ctx.Value(shutdownKey{})
-	if shutdownFunc, ok := value.(v23.Shutdown); ok {
-		shutdownFunc()
+	shutdownFunc, ok := value.(v23.Shutdown)
+	if !ok {
+		panic("shutdown function not found")
 	}
+	jutil.DoAsyncCall(env, jCallback, func() (jutil.Object, error) {
+		shutdownFunc()
+		return jutil.NullObject, nil
+	})
 }
 
 //export Java_io_v_impl_google_rt_VRuntimeImpl_nativeWithNewClient
diff --git a/v23/context/jni.go b/v23/context/jni.go
index 7dbd4e1..633ae95 100644
--- a/v23/context/jni.go
+++ b/v23/context/jni.go
@@ -19,9 +19,12 @@
 import "C"
 
 var (
-	classSign = jutil.ClassSign("java.lang.Class")
+	classSign      = jutil.ClassSign("java.lang.Class")
+	doneReasonSign = jutil.ClassSign("io.v.v23.context.VContext$DoneReason")
 	// Global reference for io.v.v23.context.VContext class.
 	jVContextClass jutil.Class
+	// Global reference for io.v.v23.context.VContext$DoneReason
+	jDoneReasonClass jutil.Class
 )
 
 // Init initializes the JNI code with the given Java environment. This method
@@ -35,6 +38,10 @@
 	if err != nil {
 		return err
 	}
+	jDoneReasonClass, err = jutil.JFindClass(env, "io/v/v23/context/VContext$DoneReason")
+	if err != nil {
+		return err
+	}
 	return nil
 }
 
@@ -78,18 +85,27 @@
 	return C.jobject(unsafe.Pointer(jDeadline))
 }
 
-//export Java_io_v_v23_context_VContext_nativeDone
-func Java_io_v_v23_context_VContext_nativeDone(jenv *C.JNIEnv, jVContext C.jobject, goPtr C.jlong, jCallbackObj C.jobject) {
+//export Java_io_v_v23_context_VContext_nativeOnDone
+func Java_io_v_v23_context_VContext_nativeOnDone(jenv *C.JNIEnv, jVContext C.jobject, goPtr C.jlong, jCallbackObj C.jobject) {
 	env := jutil.Env(uintptr(unsafe.Pointer(jenv)))
 	jCallback := jutil.Object(uintptr(unsafe.Pointer(jCallbackObj)))
-	c := (*(*context.T)(jutil.NativePtr(goPtr))).Done()
+	ctx := (*(*context.T)(jutil.NativePtr(goPtr)))
+	c := ctx.Done()
 	if c == nil {
 		jutil.CallbackOnFailure(env, jCallback, errors.New("Context isn't cancelable"))
 		return
 	}
 	jutil.DoAsyncCall(env, jCallback, func() (jutil.Object, error) {
 		<-c
-		return jutil.NullObject, nil
+		env, freeFunc := jutil.GetEnv()
+		defer freeFunc()
+		jReason, err := JavaContextDoneReason(env, ctx.Err())
+		if err != nil {
+			return jutil.NullObject, err
+		}
+		// Must grab a global reference as we free up the env and all local references that come along
+		// with it.
+		return jutil.NewGlobalRef(env, jReason), nil // Un-refed in DoAsyncCall
 	})
 }
 
diff --git a/v23/context/util.go b/v23/context/util.go
index c2d93e4..1b762f2 100644
--- a/v23/context/util.go
+++ b/v23/context/util.go
@@ -111,3 +111,18 @@
 	}
 	return val.jObj, nil
 }
+
+// JavaContextDoneReason return the Java DoneReason given the Go error returned
+// by ctx.Error().
+func JavaContextDoneReason(env jutil.Env, err error) (jutil.Object, error) {
+	var name string
+	switch err {
+	case context.Canceled:
+		name = "CANCELED"
+	case context.DeadlineExceeded:
+		name = "DEADLINE_EXCEEDED"
+	default:
+		return jutil.NullObject, fmt.Errorf("Unrecognized context done reason: %v", err)
+	}
+	return jutil.CallStaticObjectMethod(env, jDoneReasonClass, "valueOf", []jutil.Sign{jutil.StringSign}, doneReasonSign, name)
+}