v.io/x/jni: make java argument handling more sane
Change-Id: I1053a088301d5a1c3039ed3957c439b63445ba39
diff --git a/impl/google/rpc/util.go b/impl/google/rpc/util.go
index 8b7c892..6060a2f 100644
--- a/impl/google/rpc/util.go
+++ b/impl/google/rpc/util.go
@@ -240,13 +240,12 @@
if err != nil {
return nil, err
}
- jProxy := jutil.JString(jEnv, spec.Proxy)
jChooser, err := JavaAddressChooser(jEnv, spec.AddressChooser)
if err != nil {
return nil, err
}
addressSign := jutil.ClassSign("io.v.v23.rpc.ListenSpec$Address")
- jSpec, err := jutil.NewObject(jEnv, jListenSpecClass, []jutil.Sign{jutil.ArraySign(addressSign), jutil.StringSign, addressChooserSign}, jAddrs, jProxy, jChooser)
+ jSpec, err := jutil.NewObject(jEnv, jListenSpecClass, []jutil.Sign{jutil.ArraySign(addressSign), jutil.StringSign, addressChooserSign}, jAddrs, spec.Proxy, jChooser)
if err != nil {
return nil, err
}
diff --git a/util/call.go b/util/call.go
index 4497d95..735fb28 100644
--- a/util/call.go
+++ b/util/call.go
@@ -8,8 +8,6 @@
import (
"fmt"
- "reflect"
- "time"
"unsafe"
)
@@ -26,116 +24,6 @@
//
import "C"
-// jArg converts a Go argument to a Java argument. It uses the provided sign to
-// validate that the argument is of a compatible type.
-func jArg(env *C.JNIEnv, v interface{}, sign Sign) (unsafe.Pointer, bool) {
- rv := reflect.ValueOf(v)
- if !rv.IsValid() { // nil value
- jv := C.jobject(nil)
- return unsafe.Pointer(&jv), true
- }
- if rv.Type() == reflect.TypeOf(time.Time{}) {
- if sign != DateTimeSign {
- return unsafe.Pointer(nil), false
- }
- jv, err := JTime(env, rv.Interface().(time.Time))
- if err != nil {
- return unsafe.Pointer(nil), false
- }
- return unsafe.Pointer(&jv), true
- }
- if rv.Type() == reflect.TypeOf(time.Duration(0)) {
- if sign != DurationSign {
- return unsafe.Pointer(nil), false
- }
- jv, err := JDuration(env, rv.Interface().(time.Duration))
- if err != nil {
- return unsafe.Pointer(nil), false
- }
- return unsafe.Pointer(&jv), true
- }
- if rv.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
- if sign != VExceptionSign {
- return unsafe.Pointer(nil), false
- }
- jv, err := JVException(env, rv.Interface().(error))
- if err != nil {
- return unsafe.Pointer(nil), false
- }
- return unsafe.Pointer(&jv), true
- }
- if rv.Kind() == reflect.Ptr || rv.Kind() == reflect.UnsafePointer {
- rv = reflect.ValueOf(rv.Pointer()) // Convert the pointer's address to a uintptr
- }
- switch rv.Kind() {
- case reflect.Bool:
- if sign != BoolSign {
- return unsafe.Pointer(nil), false
- }
- jv := C.jboolean(C.JNI_FALSE)
- if rv.Bool() {
- jv = C.JNI_TRUE
- }
- return unsafe.Pointer(&jv), true
- case reflect.Int, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Uint, reflect.Uint32, reflect.Uint16, reflect.Uint8:
- if !isSignOneOf(sign, []Sign{ByteSign, ShortSign, IntSign, LongSign}) {
- return unsafe.Pointer(nil), false
- }
- jv := C.jint(rv.Int())
- return unsafe.Pointer(&jv), true
- case reflect.Int64:
- if sign != LongSign {
- return unsafe.Pointer(nil), false
- }
- jv := C.jlong(rv.Int())
- return unsafe.Pointer(&jv), true
- case reflect.Uint64:
- if sign != LongSign {
- return unsafe.Pointer(nil), false
- }
- jv := C.jlong(rv.Uint())
- return unsafe.Pointer(&jv), true
- case reflect.Uintptr:
- if isSignOneOf(sign, []Sign{ByteSign, BoolSign, CharSign, ShortSign, IntSign, FloatSign, DoubleSign}) {
- return unsafe.Pointer(nil), false
- }
- jv := C.jlong(rv.Uint())
- return unsafe.Pointer(&jv), true
- case reflect.String:
- if sign != StringSign {
- return unsafe.Pointer(nil), false
- }
- // JString allocates the strings locally, so they are freed automatically when we return to Java.
- jv := JString(env, rv.String())
- if jv == nil {
- return unsafe.Pointer(nil), false
- }
- return unsafe.Pointer(&jv), true
- case reflect.Slice, reflect.Array:
- switch rv.Type().Elem().Kind() {
- case reflect.Uint8:
- if sign != ArraySign(ByteSign) {
- return unsafe.Pointer(nil), false
- }
- bs := rv.Interface().([]byte)
- jv := JByteArray(env, bs)
- return unsafe.Pointer(&jv), true
- case reflect.String:
- if sign != ArraySign(StringSign) {
- return unsafe.Pointer(nil), false
- }
- // TODO(bprosnitz) We should handle objects by calling jArg recursively. We need a way to get the sign of the target type or treat it as an Object for non-string types.
- strs := rv.Interface().([]string)
- jv := JStringArray(env, strs)
- return unsafe.Pointer(&jv), true
- default:
- return unsafe.Pointer(nil), false
- }
- default:
- return unsafe.Pointer(nil), false
- }
-}
-
// jArgArray converts a slice of Go args to an array of Java args. It uses the provided slice of
// Signs to validate that the arguments are of compatible types.
func jArgArray(env *C.JNIEnv, args []interface{}, argSigns []Sign) (jArr *C.jvalue, free func(), err error) {
@@ -145,11 +33,11 @@
jvalueArr := C.allocJValueArray(C.int(len(args)))
for i, arg := range args {
sign := argSigns[i]
- jValPtr, ok := jArg(env, arg, sign)
+ jVal, ok := jValue(env, arg, sign)
if !ok {
return (*C.jvalue)(nil), nil, fmt.Errorf("couldn't get Java value for argument #%d [%v] of expected type %v", i, arg, sign)
}
- C.setJValueArrayElement(jvalueArr, C.int(i), *(*C.jvalue)(jValPtr))
+ C.setJValueArrayElement(jvalueArr, C.int(i), jVal)
}
freeFunc := func() {
C.free(unsafe.Pointer(jvalueArr))
@@ -195,7 +83,11 @@
}
jmid = C.jmethodID(id)
jvalArray, freeFunc, err = jArgArray(jenv, args, argSigns)
+ if err != nil {
+ err = fmt.Errorf("error creating arguments for method %s: %v", name, err)
+ }
return
+
}
// setupStaticMethodCall performs the shared preparation operations between
@@ -211,6 +103,9 @@
}
jmid = C.jmethodID(id)
jvalArray, freeFunc, err = jArgArray(jenv, args, argSigns)
+ if err != nil {
+ err = fmt.Errorf("error creating arguments for method %s: %v", name, err)
+ }
return
}
@@ -306,11 +201,11 @@
if err != nil {
return nil, err
}
- jValue, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
+ jVal, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
if err != nil {
return nil, err
}
- ret[jKey] = jValue
+ ret[jKey] = jVal
}
return ret, nil
}
@@ -347,11 +242,11 @@
if err != nil {
return nil, err
}
- jValue, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
+ jVal, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
if err != nil {
return nil, err
}
- ret[jKey] = append(ret[jKey], jValue)
+ ret[jKey] = append(ret[jKey], jVal)
}
return ret, nil
}
diff --git a/util/value.go b/util/value.go
new file mode 100644
index 0000000..ab067d3
--- /dev/null
+++ b/util/value.go
@@ -0,0 +1,251 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build java android
+
+package util
+
+import (
+ "reflect"
+ "time"
+ "unsafe"
+)
+
+// #include "jni_wrapper.h"
+//
+// jvalue jBoolValue(jboolean val) {
+// jvalue ret = { .z = val };
+// return ret;
+// }
+// jvalue jByteValue(jbyte val) {
+// jvalue ret = { .b = val };
+// return ret;
+// }
+// jvalue jCharValue(jchar val) {
+// jvalue ret = { .c = val };
+// return ret;
+// }
+// jvalue jShortValue(jshort val) {
+// jvalue ret = { .s = val };
+// return ret;
+// }
+// jvalue jIntValue(jint val) {
+// jvalue ret = { .i = val };
+// return ret;
+// }
+// jvalue jLongValue(jlong val) {
+// jvalue ret = { .j = val };
+// return ret;
+// }
+// jvalue jFloatValue(jfloat val) {
+// jvalue ret = { .f = val };
+// return ret;
+// }
+// jvalue jDoubleValue(jdouble val) {
+// jvalue ret = { .d = val };
+// return ret;
+// }
+// jvalue jObjectValue(jobject val) {
+// jvalue ret = { .l = val };
+// return ret;
+// }
+import "C"
+
+var errJValue = C.jObjectValue(nil)
+
+// jValue converts a Go value into a Java value with the given sign.
+func jValue(env *C.JNIEnv, v interface{}, sign Sign) (C.jvalue, bool) {
+ switch sign {
+ case BoolSign:
+ return jBoolValue(v)
+ case ByteSign:
+ return jByteValue(v)
+ case CharSign:
+ return jCharValue(v)
+ case ShortSign:
+ return jShortValue(v)
+ case IntSign:
+ return jIntValue(v)
+ case LongSign:
+ return jLongValue(v)
+ case StringSign:
+ return jStringValue(env, v)
+ case DateTimeSign:
+ return jDateTimeValue(env, v)
+ case DurationSign:
+ return jDurationValue(env, v)
+ case VExceptionSign:
+ return jVExceptionValue(env, v)
+ case ArraySign(ByteSign):
+ return jByteArrayValue(env, v)
+ case ArraySign(StringSign):
+ return jStringArrayValue(env, v)
+ default:
+ return jObjectValue(v)
+ }
+}
+
+func jBoolValue(v interface{}) (C.jvalue, bool) {
+ val, ok := v.(bool)
+ if !ok {
+ return errJValue, false
+ }
+ jBool := C.jboolean(C.JNI_FALSE)
+ if val {
+ jBool = C.jboolean(C.JNI_TRUE)
+ }
+ return C.jBoolValue(jBool), true
+}
+
+func jByteValue(v interface{}) (C.jvalue, bool) {
+ val, ok := intValue(v)
+ if !ok {
+ return errJValue, false
+ }
+ return C.jByteValue(C.jbyte(val)), true
+}
+
+func jCharValue(v interface{}) (C.jvalue, bool) {
+ val, ok := intValue(v)
+ if !ok {
+ return errJValue, false
+ }
+ return C.jCharValue(C.jchar(val)), true
+}
+
+func jShortValue(v interface{}) (C.jvalue, bool) {
+ val, ok := intValue(v)
+ if !ok {
+ return errJValue, false
+ }
+ return C.jShortValue(C.jshort(val)), true
+}
+
+func jIntValue(v interface{}) (C.jvalue, bool) {
+ val, ok := intValue(v)
+ if !ok {
+ return errJValue, false
+ }
+ return C.jIntValue(C.jint(val)), true
+}
+
+func jLongValue(v interface{}) (C.jvalue, bool) {
+ val, ok := intValue(v)
+ if !ok {
+ return errJValue, false
+ }
+ return C.jLongValue(C.jlong(val)), true
+}
+
+func jStringValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+ str, ok := v.(string)
+ if !ok {
+ return errJValue, false
+ }
+ return jObjectValue(JString(env, str))
+}
+
+func jDateTimeValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+ t, ok := v.(time.Time)
+ if !ok {
+ return errJValue, false
+ }
+ jTime, err := JTime(env, t)
+ if err != nil {
+ return errJValue, false
+ }
+ return jObjectValue(jTime)
+}
+
+func jDurationValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+ d, ok := v.(time.Duration)
+ if !ok {
+ return errJValue, false
+ }
+ jDuration, err := JDuration(env, d)
+ if err != nil {
+ return errJValue, false
+ }
+ return jObjectValue(jDuration)
+}
+
+func jVExceptionValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+ err, ok := v.(error)
+ if !ok {
+ return errJValue, false
+ }
+ jVException, err := JVException(env, err)
+ if err != nil {
+ return errJValue, false
+ }
+ return jObjectValue(jVException)
+}
+
+func jByteArrayValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+ arr, ok := v.([]byte)
+ if !ok {
+ return errJValue, false
+ }
+ return jObjectValue(JByteArray(env, arr))
+}
+
+func jStringArrayValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+ arr, ok := v.([]string)
+ if !ok {
+ return errJValue, false
+ }
+ return jObjectValue(JStringArray(env, arr))
+}
+
+func jObjectValue(v interface{}) (C.jvalue, bool) {
+ rv := reflect.ValueOf(v)
+ if !rv.IsValid() { // nil value
+ return C.jObjectValue(nil), true
+ }
+ ptr, ok := ptrValue(v)
+ if !ok {
+ return errJValue, false
+ }
+ // TODO(spetrovic): figure out a way to not use unsafe.Pointer here.
+ return C.jObjectValue(C.jobject(unsafe.Pointer(ptr))), true
+}
+
+func intValue(v interface{}) (int64, bool) {
+ switch val := v.(type) {
+ case int64:
+ return val, true
+ case int:
+ return int64(val), true
+ case int32:
+ return int64(val), true
+ case int16:
+ return int64(val), true
+ case int8:
+ return int64(val), true
+ case uint64:
+ return int64(val), true
+ case uint:
+ return int64(val), true
+ case uint32:
+ return int64(val), true
+ case uint16:
+ return int64(val), true
+ case uint8:
+ return int64(val), true
+ default:
+ return 0, false
+ }
+}
+
+func ptrValue(v interface{}) (uintptr, bool) {
+ rv := reflect.ValueOf(v)
+ switch rv.Kind() {
+ case reflect.Ptr, reflect.UnsafePointer:
+ return rv.Pointer(), true
+ case reflect.Uintptr:
+ return uintptr(rv.Uint()), true
+ default:
+ return 0, false
+ }
+}
diff --git a/v23/security/jni.go b/v23/security/jni.go
index 5dce22b..08ddff0 100644
--- a/v23/security/jni.go
+++ b/v23/security/jni.go
@@ -279,7 +279,7 @@
jutil.JThrowV(env, err)
return nil
}
- jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, &principal, jSigner, C.jobject(nil), C.jobject(nil))
+ jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, int64(jutil.PtrValue(&principal)), jSigner, C.jobject(nil), C.jobject(nil))
if err != nil {
jutil.JThrowV(env, err)
return nil
@@ -310,7 +310,7 @@
jutil.JThrowV(env, err)
return nil
}
- jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, &principal, jSigner, jStore, jRoots)
+ jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, int64(jutil.PtrValue(&principal)), jSigner, jStore, jRoots)
if err != nil {
jutil.JThrowV(env, err)
return nil
@@ -356,7 +356,7 @@
jutil.JThrowV(env, err)
return nil
}
- jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, &principal, jSigner, C.jobject(nil), C.jobject(nil))
+ jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, int64(jutil.PtrValue(&principal)), jSigner, C.jobject(nil), C.jobject(nil))
if err != nil {
jutil.JThrowV(env, err)
return nil