v.io/x/jni: performance optimization for Java native code
In particular:
- moves Java intensive code to Java, to reduce the # of Go->Java calls.
- doesn't detach Java environment from threads (an expensive operation).
Change-Id: I168ff9fb6787d69bcd93a42d17b2501737ae40b7
diff --git a/impl/google/rpc/dispatcher.go b/impl/google/rpc/dispatcher.go
index b2130ff..27312ac 100644
--- a/impl/google/rpc/dispatcher.go
+++ b/impl/google/rpc/dispatcher.go
@@ -10,9 +10,10 @@
"fmt"
"runtime"
+ "v.io/v23/rpc"
"v.io/v23/security"
+
jutil "v.io/x/jni/util"
- jsecurity "v.io/x/jni/v23/security"
)
// #include "jni.h"
@@ -45,40 +46,25 @@
env, freeFunc := jutil.GetEnv()
defer freeFunc()
- // Call Java dispatcher's lookup() method.
- serviceObjectWithAuthorizerSign := jutil.ClassSign("io.v.v23.rpc.ServiceObjectWithAuthorizer")
- jObj, err := jutil.CallObjectMethod(env, d.jDispatcher, "lookup", []jutil.Sign{jutil.StringSign}, serviceObjectWithAuthorizerSign, suffix)
+ dispatcherSign := jutil.ClassSign("io.v.v23.rpc.Dispatcher")
+ result, err := jutil.CallStaticLongArrayMethod(env, jUtilClass, "lookup", []jutil.Sign{dispatcherSign, jutil.StringSign}, d.jDispatcher, suffix)
if err != nil {
return nil, nil, fmt.Errorf("error invoking Java dispatcher's lookup() method: %v", err)
}
- if jObj == nil {
+ if result == nil {
// Lookup returned null, which means that the dispatcher isn't handling the object -
// this is not an error.
return nil, nil, nil
}
-
- // Extract the Java service object and Authorizer.
- jServiceObj, err := jutil.CallObjectMethod(env, jObj, "getServiceObject", nil, jutil.ObjectSign)
- if err != nil {
- return nil, nil, err
+ if len(result) != 2 {
+ return nil, nil, fmt.Errorf("lookup returned %d elems, want 2", len(result))
}
- if jServiceObj == nil {
- return nil, nil, fmt.Errorf("null service object returned by Java's ServiceObjectWithAuthorizer")
+ invoker := *(*rpc.Invoker)(jutil.Ptr(result[0]))
+ jutil.GoUnref(jutil.Ptr(result[0]))
+ authorizer := security.Authorizer(nil)
+ if result[1] != 0 {
+ authorizer = *(*security.Authorizer)(jutil.Ptr(result[1]))
+ jutil.GoUnref(jutil.Ptr(result[1]))
}
- authSign := jutil.ClassSign("io.v.v23.security.Authorizer")
- jAuth, err := jutil.CallObjectMethod(env, jObj, "getAuthorizer", nil, authSign)
- if err != nil {
- return nil, nil, err
- }
-
- // Create Go Invoker and Authorizer.
- i, err := goInvoker((*C.JNIEnv)(env), C.jobject(jServiceObj))
- if err != nil {
- return nil, nil, err
- }
- a, err := jsecurity.GoAuthorizer(env, jAuth)
- if err != nil {
- return nil, nil, err
- }
- return i, a, nil
+ return invoker, authorizer, nil
}
diff --git a/impl/google/rpc/invoker.go b/impl/google/rpc/invoker.go
index 8b4fd10..b802a99 100644
--- a/impl/google/rpc/invoker.go
+++ b/impl/google/rpc/invoker.go
@@ -16,6 +16,7 @@
"v.io/v23/rpc"
"v.io/v23/vdl"
"v.io/v23/vdlroot/signature"
+ "v.io/v23/vom"
jchannel "v.io/x/jni/impl/google/channel"
jutil "v.io/x/jni/util"
@@ -69,14 +70,17 @@
value := new(vdl.Value)
argptrs[i] = &value
}
- // Get the method tags.
- jTags, err := jutil.CallObjectMethod(env, i.jInvoker, "getMethodTags", []jutil.Sign{jutil.StringSign}, jutil.ArraySign(jutil.VdlValueSign), jutil.CamelCase(method))
+ jVomTags, err := jutil.CallStaticObjectMethod(env, jUtilClass, "getMethodTags", []jutil.Sign{invokerSign, jutil.StringSign}, jutil.ArraySign(jutil.ByteArraySign), i.jInvoker, jutil.CamelCase(method))
if err != nil {
return nil, nil, err
}
- tags, err = jutil.GoVDLValueArray(env, jTags)
- if err != nil {
- return nil, nil, err
+ vomTags := jutil.GoByteArrayArray(env, jVomTags)
+ tags = make([]*vdl.Value, len(vomTags))
+ for i, vomTag := range vomTags {
+ var err error
+ if tags[i], err = jutil.VomDecodeToValue(vomTag); err != nil {
+ return nil, nil, err
+ }
}
return
}
@@ -94,18 +98,28 @@
if err != nil {
return nil, err
}
- // Convert Go arguments to Java.
- jArgs, err := i.prepareArgs(env, method, argptrs)
+ vomArgs := make([][]byte, len(argptrs))
+ for i, argptr := range argptrs {
+ arg := interface{}(jutil.DerefOrDie(argptr))
+ var err error
+ if vomArgs[i], err = vom.Encode(arg); err != nil {
+ return nil, err
+ }
+ }
+ jVomArgs := jutil.JByteArrayArray(env, vomArgs)
+ jVomResults, err := jutil.CallStaticObjectMethod(env, jUtilClass, "invoke", []jutil.Sign{invokerSign, contextSign, streamServerCallSign, jutil.StringSign, jutil.ArraySign(jutil.ByteArraySign)}, jutil.ArraySign(jutil.ByteArraySign), i.jInvoker, jContext, jStreamServerCall, jutil.CamelCase(method), jVomArgs)
if err != nil {
return nil, err
}
- // Invoke the method.
- resultarr, err := jutil.CallObjectArrayMethod(env, i.jInvoker, "invoke", []jutil.Sign{contextSign, streamServerCallSign, jutil.StringSign, jutil.ArraySign(jutil.ObjectSign)}, jutil.ObjectSign, jContext, jStreamServerCall, jutil.CamelCase(method), jArgs)
- if err != nil {
- return nil, err
+ vomResults := jutil.GoByteArrayArray(env, jVomResults)
+ results = make([]interface{}, len(vomResults))
+ for i, vomResult := range vomResults {
+ var err error
+ if results[i], err = jutil.VomDecodeToValue(vomResult); err != nil {
+ return nil, err
+ }
}
- // Convert Java results into Go.
- return i.prepareResults(env, method, resultarr)
+ return
}
func (i *invoker) Signature(ctx *context.T, call rpc.ServerCall) ([]signature.Interface, error) {
diff --git a/impl/google/rpc/jni.go b/impl/google/rpc/jni.go
index c4ec851..03e85d8 100644
--- a/impl/google/rpc/jni.go
+++ b/impl/google/rpc/jni.go
@@ -12,6 +12,7 @@
"net"
"unsafe"
+ "v.io/v23/options"
"v.io/v23/rpc"
"v.io/v23/vdl"
"v.io/v23/vom"
@@ -27,6 +28,7 @@
var (
contextSign = jutil.ClassSign("io.v.v23.context.VContext")
+ invokerSign = jutil.ClassSign("io.v.v23.rpc.Invoker")
serverCallSign = jutil.ClassSign("io.v.v23.rpc.ServerCall")
streamServerCallSign = jutil.ClassSign("io.v.v23.rpc.StreamServerCall")
streamSign = jutil.ClassSign("io.v.impl.google.rpc.Stream")
@@ -49,6 +51,8 @@
jServerCallClass C.jclass
// Global reference for io.v.impl.google.rpc.Stream class.
jStreamClass C.jclass
+ // Global reference for io.v.impl.google.rpc.Util class.
+ jUtilClass C.jclass
// Global reference for io.v.v23.rpc.Invoker class.
jInvokerClass C.jclass
// Global reference for io.v.v23.rpc.ListenSpec class.
@@ -129,6 +133,11 @@
return err
}
jStreamClass = C.jclass(class)
+ class, err = jutil.JFindClass(jEnv, "io/v/impl/google/rpc/Util")
+ if err != nil {
+ return err
+ }
+ jUtilClass = C.jclass(class)
class, err = jutil.JFindClass(jEnv, "io/v/v23/rpc/Invoker")
if err != nil {
return err
@@ -330,7 +339,7 @@
}
//export Java_io_v_impl_google_rpc_Client_nativeStartCall
-func Java_io_v_impl_google_rpc_Client_nativeStartCall(env *C.JNIEnv, jClient C.jobject, goPtr C.jlong, jContext C.jobject, jName C.jstring, jMethod C.jstring, jVomArgs C.jobjectArray, jOptions C.jobject) C.jobject {
+func Java_io_v_impl_google_rpc_Client_nativeStartCall(env *C.JNIEnv, jClient C.jobject, goPtr C.jlong, jContext C.jobject, jName C.jstring, jMethod C.jstring, jVomArgs C.jobjectArray, jSkipServerEndpointAuthorization C.jboolean) C.jobject {
name := jutil.GoString(env, jName)
method := jutil.GoString(env, jMethod)
context, err := jcontext.GoContext(env, jContext)
@@ -348,9 +357,13 @@
return nil
}
}
+ var opts []rpc.CallOpt
+ if jSkipServerEndpointAuthorization == C.JNI_TRUE {
+ opts = append(opts, options.SkipServerEndpointAuthorization{})
+ }
// Invoke StartCall
- call, err := (*(*rpc.Client)(jutil.Ptr(goPtr))).StartCall(context, name, method, args)
+ call, err := (*(*rpc.Client)(jutil.Ptr(goPtr))).StartCall(context, name, method, args, opts...)
if err != nil {
jutil.JThrowV(env, err)
return nil
@@ -526,3 +539,25 @@
func Java_io_v_impl_google_rpc_AddressChooser_nativeFinalize(env *C.JNIEnv, jAddressChooser C.jobject, goPtr C.jlong) {
jutil.GoUnref(jutil.Ptr(goPtr))
}
+
+//export Java_io_v_impl_google_rpc_Util_nativeGoInvoker
+func Java_io_v_impl_google_rpc_Util_nativeGoInvoker(env *C.JNIEnv, jUtil C.jclass, jServiceObject C.jobject) C.jlong {
+ invoker, err := goInvoker(env, jServiceObject)
+ if err != nil {
+ jutil.JThrowV(env, err)
+ return C.jlong(0)
+ }
+ jutil.GoRef(&invoker) // Un-refed when the Go invoker is returned to the Go runtime
+ return C.jlong(jutil.PtrValue(&invoker))
+}
+
+//export Java_io_v_impl_google_rpc_Util_nativeGoAuthorizer
+func Java_io_v_impl_google_rpc_Util_nativeGoAuthorizer(env *C.JNIEnv, jUtil C.jclass, jAuthorizer C.jobject) C.jlong {
+ auth, err := jsecurity.GoAuthorizer(env, jAuthorizer)
+ if err != nil {
+ jutil.JThrowV(env, err)
+ return C.jlong(0)
+ }
+ jutil.GoRef(&auth) // Un-refed when the Go authorizer is returned to the Go runtime
+ return C.jlong(jutil.PtrValue(&auth))
+}
diff --git a/util/call.go b/util/call.go
index db66800..b0bd07a 100644
--- a/util/call.go
+++ b/util/call.go
@@ -430,6 +430,15 @@
return GoByteArray(env, jArr), nil
}
+// CallStaticLongArrayMethod calls a static Java method that returns a array of long.
+func CallStaticLongArrayMethod(env interface{}, class interface{}, name string, argSigns []Sign, args ...interface{}) ([]int64, error) {
+ jArr, err := CallStaticObjectMethod(env, class, name, argSigns, ArraySign(LongSign), args...)
+ if err != nil {
+ return nil, err
+ }
+ return GoLongArray(env, jArr), nil
+}
+
// CallStaticIntMethod calls a static Java method that returns an int.
func CallStaticIntMethod(env interface{}, class interface{}, name string, argSigns []Sign, args ...interface{}) (int, error) {
jenv, jclass, jmid, jvalArray, freeFunc, err := setupStaticMethodCall(env, class, name, argSigns, IntSign, args...)
diff --git a/util/jni_wrapper.c b/util/jni_wrapper.c
index f7a72db..ce3f50d 100644
--- a/util/jni_wrapper.c
+++ b/util/jni_wrapper.c
@@ -106,10 +106,18 @@
return (*env)->GetByteArrayElements(env, array, isCopy);
}
+jlong* GetLongArrayElements(JNIEnv* env, jlongArray array, jboolean *isCopy) {
+ return (*env)->GetLongArrayElements(env, array, isCopy);
+}
+
void ReleaseByteArrayElements(JNIEnv* env, jbyteArray array, jbyte* elems, jint mode) {
(*env)->ReleaseByteArrayElements(env, array, elems, mode);
}
+void ReleaseLongArrayElements(JNIEnv* env, jlongArray array, jlong* elems, jint mode) {
+ (*env)->ReleaseLongArrayElements(env, array, elems, mode);
+}
+
void SetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start, jsize len, const jbyte* data) {
(*env)->SetByteArrayRegion(env, array, start, len, data);
}
@@ -162,6 +170,10 @@
return (*jvm)->AttachCurrentThread(jvm, (void**) env, args);
}
+jint AttachCurrentThreadAsDaemon(JavaVM* jvm, JNIEnv** env, void* args) {
+ return (*jvm)->AttachCurrentThreadAsDaemon(jvm, (void**) env, args);
+}
+
jint DetachCurrentThread(JavaVM* jvm) {
return (*jvm)->DetachCurrentThread(jvm);
}
diff --git a/util/jni_wrapper.h b/util/jni_wrapper.h
index 449cf52..bd43ac4 100644
--- a/util/jni_wrapper.h
+++ b/util/jni_wrapper.h
@@ -64,8 +64,9 @@
// Sets an element of an Object array.
void SetObjectArrayElement(JNIEnv* env, jobjectArray array, jsize index, jobject obj);
-// Returns the contents of a provided Java byte array as a primitive array.
+// Returns the contents of a provided Java byte/long array as a primitive array.
jbyte* GetByteArrayElements(JNIEnv* env, jbyteArray array, jboolean *isCopy);
+jlong* GetLongArrayElements(JNIEnv* env, jlongArray array, jboolean *isCopy);
// Informs the VM that the native code no longer needs access to elems.
// If necessary, this function copies back all changes made to elems to the
@@ -74,6 +75,7 @@
// JNI_COMMIT copy back the content but do not free the elems buffer
// JNI_ABORT free the buffer without copying back the possible changes
void ReleaseByteArrayElements(JNIEnv* env, jbyteArray array, jbyte* elems, jint mode);
+void ReleaseLongArrayElements(JNIEnv* env, jlongArray array, jlong* elems, jint mode);
// Copies the data from a primitive array into the Java array.
void SetByteArrayRegion(JNIEnv* env, jbyteArray array, jsize start, jsize len, const jbyte* data);
@@ -118,6 +120,10 @@
// Attaches the current thread to a Java VM.
jint AttachCurrentThread(JavaVM* jvm, JNIEnv** env, void* args);
+// Attaches the current thread as a daemon to a Java VM. This means that the
+// Java VM will not wait for this thread to complete before exiting the program.
+jint AttachCurrentThreadAsDaemon(JavaVM* jvm, JNIEnv** env, void* args);
+
// Detaches the current thread from a Java VM.
jint DetachCurrentThread(JavaVM* jvm);
diff --git a/util/util.go b/util/util.go
index 9f470e5..0d2431d 100644
--- a/util/util.go
+++ b/util/util.go
@@ -195,15 +195,12 @@
// to *C.JNIEnv, the above scenario can never occur.
runtime.LockOSThread()
var env *C.JNIEnv
- if C.GetEnv(jVM, &env, C.JNI_VERSION_1_6) == C.JNI_OK {
- return unsafe.Pointer(env), func() {
- runtime.UnlockOSThread()
- }
+ if C.GetEnv(jVM, &env, C.JNI_VERSION_1_6) != C.JNI_OK {
+ // Couldn't get env - attach the thread. Note that we never detach
+ // the thread so the next call to GetEnv on this thread will succeed.
+ C.AttachCurrentThreadAsDaemon(jVM, &env, nil)
}
- // Couldn't get env, attach the thread.
- C.AttachCurrentThread(jVM, &env, nil)
return unsafe.Pointer(env), func() {
- C.DetachCurrentThread(jVM)
runtime.UnlockOSThread()
}
}
@@ -584,6 +581,28 @@
return
}
+// GoLongArray converts the provided Java long array into a Go int64 slice.
+// NOTE: Because CGO creates package-local types and because this method may be
+// invoked from a different package, Java types are passed in an empty interface
+// and then cast into their package local types.
+func GoLongArray(jEnv, jArr interface{}) (ret []int64) {
+ env := getEnv(jEnv)
+ arr := getLongArray(jArr)
+ if arr == nil {
+ return
+ }
+ length := int(C.GetArrayLength(env, C.jarray(arr)))
+ ret = make([]int64, length)
+ elems := C.GetLongArrayElements(env, arr, nil)
+ defer C.ReleaseLongArrayElements(env, arr, elems, C.JNI_ABORT)
+ ptr := elems
+ for i := 0; i < length; i++ {
+ ret[i] = int64(*ptr)
+ ptr = (*C.jlong)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + unsafe.Sizeof(*ptr)))
+ }
+ return
+}
+
// JByteArrayArray converts the provided [][]byte value into a Java array of
// byte arrays.
// NOTE: Because CGO creates package-local types and because this method may be
@@ -793,7 +812,7 @@
defer C.free(unsafe.Pointer(cSignature))
mid := C.GetStaticMethodID(env, class, cName, cSignature)
if err := JExceptionMsg(env); err != nil || mid == C.jmethodID(nil) {
- return nil, fmt.Errorf("couldn't find method %s with a given signature.", name)
+ return nil, fmt.Errorf("couldn't find method %s with a given signature: %s", name, signature)
}
return unsafe.Pointer(mid), nil
}
@@ -876,6 +895,9 @@
func getByteArray(jByteArray interface{}) C.jbyteArray {
return C.jbyteArray(unsafe.Pointer(PtrValue(jByteArray)))
}
+func getLongArray(jLongArray interface{}) C.jlongArray {
+ return C.jlongArray(unsafe.Pointer(PtrValue(jLongArray)))
+}
func getObject(jObj interface{}) C.jobject {
return C.jobject(unsafe.Pointer(PtrValue(jObj)))
}