v.io/x/jni: make java argument handling more sane

Change-Id: I1053a088301d5a1c3039ed3957c439b63445ba39
diff --git a/impl/google/rpc/util.go b/impl/google/rpc/util.go
index 8b7c892..6060a2f 100644
--- a/impl/google/rpc/util.go
+++ b/impl/google/rpc/util.go
@@ -240,13 +240,12 @@
 	if err != nil {
 		return nil, err
 	}
-	jProxy := jutil.JString(jEnv, spec.Proxy)
 	jChooser, err := JavaAddressChooser(jEnv, spec.AddressChooser)
 	if err != nil {
 		return nil, err
 	}
 	addressSign := jutil.ClassSign("io.v.v23.rpc.ListenSpec$Address")
-	jSpec, err := jutil.NewObject(jEnv, jListenSpecClass, []jutil.Sign{jutil.ArraySign(addressSign), jutil.StringSign, addressChooserSign}, jAddrs, jProxy, jChooser)
+	jSpec, err := jutil.NewObject(jEnv, jListenSpecClass, []jutil.Sign{jutil.ArraySign(addressSign), jutil.StringSign, addressChooserSign}, jAddrs, spec.Proxy, jChooser)
 	if err != nil {
 		return nil, err
 	}
diff --git a/util/call.go b/util/call.go
index 4497d95..735fb28 100644
--- a/util/call.go
+++ b/util/call.go
@@ -8,8 +8,6 @@
 
 import (
 	"fmt"
-	"reflect"
-	"time"
 	"unsafe"
 )
 
@@ -26,116 +24,6 @@
 //
 import "C"
 
-// jArg converts a Go argument to a Java argument.  It uses the provided sign to
-// validate that the argument is of a compatible type.
-func jArg(env *C.JNIEnv, v interface{}, sign Sign) (unsafe.Pointer, bool) {
-	rv := reflect.ValueOf(v)
-	if !rv.IsValid() { // nil value
-		jv := C.jobject(nil)
-		return unsafe.Pointer(&jv), true
-	}
-	if rv.Type() == reflect.TypeOf(time.Time{}) {
-		if sign != DateTimeSign {
-			return unsafe.Pointer(nil), false
-		}
-		jv, err := JTime(env, rv.Interface().(time.Time))
-		if err != nil {
-			return unsafe.Pointer(nil), false
-		}
-		return unsafe.Pointer(&jv), true
-	}
-	if rv.Type() == reflect.TypeOf(time.Duration(0)) {
-		if sign != DurationSign {
-			return unsafe.Pointer(nil), false
-		}
-		jv, err := JDuration(env, rv.Interface().(time.Duration))
-		if err != nil {
-			return unsafe.Pointer(nil), false
-		}
-		return unsafe.Pointer(&jv), true
-	}
-	if rv.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
-		if sign != VExceptionSign {
-			return unsafe.Pointer(nil), false
-		}
-		jv, err := JVException(env, rv.Interface().(error))
-		if err != nil {
-			return unsafe.Pointer(nil), false
-		}
-		return unsafe.Pointer(&jv), true
-	}
-	if rv.Kind() == reflect.Ptr || rv.Kind() == reflect.UnsafePointer {
-		rv = reflect.ValueOf(rv.Pointer()) // Convert the pointer's address to a uintptr
-	}
-	switch rv.Kind() {
-	case reflect.Bool:
-		if sign != BoolSign {
-			return unsafe.Pointer(nil), false
-		}
-		jv := C.jboolean(C.JNI_FALSE)
-		if rv.Bool() {
-			jv = C.JNI_TRUE
-		}
-		return unsafe.Pointer(&jv), true
-	case reflect.Int, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Uint, reflect.Uint32, reflect.Uint16, reflect.Uint8:
-		if !isSignOneOf(sign, []Sign{ByteSign, ShortSign, IntSign, LongSign}) {
-			return unsafe.Pointer(nil), false
-		}
-		jv := C.jint(rv.Int())
-		return unsafe.Pointer(&jv), true
-	case reflect.Int64:
-		if sign != LongSign {
-			return unsafe.Pointer(nil), false
-		}
-		jv := C.jlong(rv.Int())
-		return unsafe.Pointer(&jv), true
-	case reflect.Uint64:
-		if sign != LongSign {
-			return unsafe.Pointer(nil), false
-		}
-		jv := C.jlong(rv.Uint())
-		return unsafe.Pointer(&jv), true
-	case reflect.Uintptr:
-		if isSignOneOf(sign, []Sign{ByteSign, BoolSign, CharSign, ShortSign, IntSign, FloatSign, DoubleSign}) {
-			return unsafe.Pointer(nil), false
-		}
-		jv := C.jlong(rv.Uint())
-		return unsafe.Pointer(&jv), true
-	case reflect.String:
-		if sign != StringSign {
-			return unsafe.Pointer(nil), false
-		}
-		// JString allocates the strings locally, so they are freed automatically when we return to Java.
-		jv := JString(env, rv.String())
-		if jv == nil {
-			return unsafe.Pointer(nil), false
-		}
-		return unsafe.Pointer(&jv), true
-	case reflect.Slice, reflect.Array:
-		switch rv.Type().Elem().Kind() {
-		case reflect.Uint8:
-			if sign != ArraySign(ByteSign) {
-				return unsafe.Pointer(nil), false
-			}
-			bs := rv.Interface().([]byte)
-			jv := JByteArray(env, bs)
-			return unsafe.Pointer(&jv), true
-		case reflect.String:
-			if sign != ArraySign(StringSign) {
-				return unsafe.Pointer(nil), false
-			}
-			// TODO(bprosnitz) We should handle objects by calling jArg recursively. We need a way to get the sign of the target type or treat it as an Object for non-string types.
-			strs := rv.Interface().([]string)
-			jv := JStringArray(env, strs)
-			return unsafe.Pointer(&jv), true
-		default:
-			return unsafe.Pointer(nil), false
-		}
-	default:
-		return unsafe.Pointer(nil), false
-	}
-}
-
 // jArgArray converts a slice of Go args to an array of Java args.  It uses the provided slice of
 // Signs to validate that the arguments are of compatible types.
 func jArgArray(env *C.JNIEnv, args []interface{}, argSigns []Sign) (jArr *C.jvalue, free func(), err error) {
@@ -145,11 +33,11 @@
 	jvalueArr := C.allocJValueArray(C.int(len(args)))
 	for i, arg := range args {
 		sign := argSigns[i]
-		jValPtr, ok := jArg(env, arg, sign)
+		jVal, ok := jValue(env, arg, sign)
 		if !ok {
 			return (*C.jvalue)(nil), nil, fmt.Errorf("couldn't get Java value for argument #%d [%v] of expected type %v", i, arg, sign)
 		}
-		C.setJValueArrayElement(jvalueArr, C.int(i), *(*C.jvalue)(jValPtr))
+		C.setJValueArrayElement(jvalueArr, C.int(i), jVal)
 	}
 	freeFunc := func() {
 		C.free(unsafe.Pointer(jvalueArr))
@@ -195,7 +83,11 @@
 	}
 	jmid = C.jmethodID(id)
 	jvalArray, freeFunc, err = jArgArray(jenv, args, argSigns)
+	if err != nil {
+		err = fmt.Errorf("error creating arguments for method %s: %v", name, err)
+	}
 	return
+
 }
 
 // setupStaticMethodCall performs the shared preparation operations between
@@ -211,6 +103,9 @@
 	}
 	jmid = C.jmethodID(id)
 	jvalArray, freeFunc, err = jArgArray(jenv, args, argSigns)
+	if err != nil {
+		err = fmt.Errorf("error creating arguments for method %s: %v", name, err)
+	}
 	return
 }
 
@@ -306,11 +201,11 @@
 		if err != nil {
 			return nil, err
 		}
-		jValue, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
+		jVal, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
 		if err != nil {
 			return nil, err
 		}
-		ret[jKey] = jValue
+		ret[jKey] = jVal
 	}
 	return ret, nil
 }
@@ -347,11 +242,11 @@
 		if err != nil {
 			return nil, err
 		}
-		jValue, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
+		jVal, err := CallObjectMethod(env, jEntry, "getValue", nil, ObjectSign)
 		if err != nil {
 			return nil, err
 		}
-		ret[jKey] = append(ret[jKey], jValue)
+		ret[jKey] = append(ret[jKey], jVal)
 	}
 	return ret, nil
 }
diff --git a/util/value.go b/util/value.go
new file mode 100644
index 0000000..ab067d3
--- /dev/null
+++ b/util/value.go
@@ -0,0 +1,251 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build java android
+
+package util
+
+import (
+	"reflect"
+	"time"
+	"unsafe"
+)
+
+// #include "jni_wrapper.h"
+//
+// jvalue jBoolValue(jboolean val) {
+//   jvalue ret = { .z = val };
+//   return ret;
+// }
+// jvalue jByteValue(jbyte val) {
+//   jvalue ret = { .b = val };
+//   return ret;
+// }
+// jvalue jCharValue(jchar val) {
+//   jvalue ret = { .c = val };
+//   return ret;
+// }
+// jvalue jShortValue(jshort val) {
+//   jvalue ret = { .s = val };
+//   return ret;
+// }
+// jvalue jIntValue(jint val) {
+//   jvalue ret = { .i = val };
+//   return ret;
+// }
+// jvalue jLongValue(jlong val) {
+//   jvalue ret = { .j = val };
+//   return ret;
+// }
+// jvalue jFloatValue(jfloat val) {
+//   jvalue ret = { .f = val };
+//   return ret;
+// }
+// jvalue jDoubleValue(jdouble val) {
+//   jvalue ret = { .d = val };
+//   return ret;
+// }
+// jvalue jObjectValue(jobject val) {
+//   jvalue ret = { .l = val };
+//   return ret;
+// }
+import "C"
+
+var errJValue = C.jObjectValue(nil)
+
+// jValue converts a Go value into a Java value with the given sign.
+func jValue(env *C.JNIEnv, v interface{}, sign Sign) (C.jvalue, bool) {
+	switch sign {
+	case BoolSign:
+		return jBoolValue(v)
+	case ByteSign:
+		return jByteValue(v)
+	case CharSign:
+		return jCharValue(v)
+	case ShortSign:
+		return jShortValue(v)
+	case IntSign:
+		return jIntValue(v)
+	case LongSign:
+		return jLongValue(v)
+	case StringSign:
+		return jStringValue(env, v)
+	case DateTimeSign:
+		return jDateTimeValue(env, v)
+	case DurationSign:
+		return jDurationValue(env, v)
+	case VExceptionSign:
+		return jVExceptionValue(env, v)
+	case ArraySign(ByteSign):
+		return jByteArrayValue(env, v)
+	case ArraySign(StringSign):
+		return jStringArrayValue(env, v)
+	default:
+		return jObjectValue(v)
+	}
+}
+
+func jBoolValue(v interface{}) (C.jvalue, bool) {
+	val, ok := v.(bool)
+	if !ok {
+		return errJValue, false
+	}
+	jBool := C.jboolean(C.JNI_FALSE)
+	if val {
+		jBool = C.jboolean(C.JNI_TRUE)
+	}
+	return C.jBoolValue(jBool), true
+}
+
+func jByteValue(v interface{}) (C.jvalue, bool) {
+	val, ok := intValue(v)
+	if !ok {
+		return errJValue, false
+	}
+	return C.jByteValue(C.jbyte(val)), true
+}
+
+func jCharValue(v interface{}) (C.jvalue, bool) {
+	val, ok := intValue(v)
+	if !ok {
+		return errJValue, false
+	}
+	return C.jCharValue(C.jchar(val)), true
+}
+
+func jShortValue(v interface{}) (C.jvalue, bool) {
+	val, ok := intValue(v)
+	if !ok {
+		return errJValue, false
+	}
+	return C.jShortValue(C.jshort(val)), true
+}
+
+func jIntValue(v interface{}) (C.jvalue, bool) {
+	val, ok := intValue(v)
+	if !ok {
+		return errJValue, false
+	}
+	return C.jIntValue(C.jint(val)), true
+}
+
+func jLongValue(v interface{}) (C.jvalue, bool) {
+	val, ok := intValue(v)
+	if !ok {
+		return errJValue, false
+	}
+	return C.jLongValue(C.jlong(val)), true
+}
+
+func jStringValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+	str, ok := v.(string)
+	if !ok {
+		return errJValue, false
+	}
+	return jObjectValue(JString(env, str))
+}
+
+func jDateTimeValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+	t, ok := v.(time.Time)
+	if !ok {
+		return errJValue, false
+	}
+	jTime, err := JTime(env, t)
+	if err != nil {
+		return errJValue, false
+	}
+	return jObjectValue(jTime)
+}
+
+func jDurationValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+	d, ok := v.(time.Duration)
+	if !ok {
+		return errJValue, false
+	}
+	jDuration, err := JDuration(env, d)
+	if err != nil {
+		return errJValue, false
+	}
+	return jObjectValue(jDuration)
+}
+
+func jVExceptionValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+	err, ok := v.(error)
+	if !ok {
+		return errJValue, false
+	}
+	jVException, err := JVException(env, err)
+	if err != nil {
+		return errJValue, false
+	}
+	return jObjectValue(jVException)
+}
+
+func jByteArrayValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+	arr, ok := v.([]byte)
+	if !ok {
+		return errJValue, false
+	}
+	return jObjectValue(JByteArray(env, arr))
+}
+
+func jStringArrayValue(env *C.JNIEnv, v interface{}) (C.jvalue, bool) {
+	arr, ok := v.([]string)
+	if !ok {
+		return errJValue, false
+	}
+	return jObjectValue(JStringArray(env, arr))
+}
+
+func jObjectValue(v interface{}) (C.jvalue, bool) {
+	rv := reflect.ValueOf(v)
+	if !rv.IsValid() { // nil value
+		return C.jObjectValue(nil), true
+	}
+	ptr, ok := ptrValue(v)
+	if !ok {
+		return errJValue, false
+	}
+	// TODO(spetrovic): figure out a way to not use unsafe.Pointer here.
+	return C.jObjectValue(C.jobject(unsafe.Pointer(ptr))), true
+}
+
+func intValue(v interface{}) (int64, bool) {
+	switch val := v.(type) {
+	case int64:
+		return val, true
+	case int:
+		return int64(val), true
+	case int32:
+		return int64(val), true
+	case int16:
+		return int64(val), true
+	case int8:
+		return int64(val), true
+	case uint64:
+		return int64(val), true
+	case uint:
+		return int64(val), true
+	case uint32:
+		return int64(val), true
+	case uint16:
+		return int64(val), true
+	case uint8:
+		return int64(val), true
+	default:
+		return 0, false
+	}
+}
+
+func ptrValue(v interface{}) (uintptr, bool) {
+	rv := reflect.ValueOf(v)
+	switch rv.Kind() {
+	case reflect.Ptr, reflect.UnsafePointer:
+		return rv.Pointer(), true
+	case reflect.Uintptr:
+		return uintptr(rv.Uint()), true
+	default:
+		return 0, false
+	}
+}
diff --git a/v23/security/jni.go b/v23/security/jni.go
index 5dce22b..08ddff0 100644
--- a/v23/security/jni.go
+++ b/v23/security/jni.go
@@ -279,7 +279,7 @@
 		jutil.JThrowV(env, err)
 		return nil
 	}
-	jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, &principal, jSigner, C.jobject(nil), C.jobject(nil))
+	jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, int64(jutil.PtrValue(&principal)), jSigner, C.jobject(nil), C.jobject(nil))
 	if err != nil {
 		jutil.JThrowV(env, err)
 		return nil
@@ -310,7 +310,7 @@
 		jutil.JThrowV(env, err)
 		return nil
 	}
-	jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, &principal, jSigner, jStore, jRoots)
+	jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, int64(jutil.PtrValue(&principal)), jSigner, jStore, jRoots)
 	if err != nil {
 		jutil.JThrowV(env, err)
 		return nil
@@ -356,7 +356,7 @@
 		jutil.JThrowV(env, err)
 		return nil
 	}
-	jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, &principal, jSigner, C.jobject(nil), C.jobject(nil))
+	jPrincipal, err := jutil.NewObject(env, jVPrincipalImplClass, []jutil.Sign{jutil.LongSign, signerSign, blessingStoreSign, blessingRootsSign}, int64(jutil.PtrValue(&principal)), jSigner, C.jobject(nil), C.jobject(nil))
 	if err != nil {
 		jutil.JThrowV(env, err)
 		return nil