blob: 6ff5d303cc582fc25d6f3fb07ceba8d3b38c2052 [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reflectutil
import (
"fmt"
"reflect"
)
// Equal is similar to reflect.DeepEqual, except that it also
// considers the sharing structure for pointers. When reflect.DeepEqual
// encounters pointers it just compares the dereferenced values; we also keep
// track of the pointers themselves and require that if a pointer appears
// multiple places in a, it appears in the same places in b.
func DeepEqual(a, b interface{}, options *DeepEqualOpts) bool {
return deepEqual(reflect.ValueOf(a), reflect.ValueOf(b), &orderInfo{}, options)
}
// TODO(bprosnitz) Implement the debuggable deep equal option.
// TODO(bprosnitz) Add an option to turn pointer sharing on/off
// DeepEqualOpts represents the options configuration for DeepEqual.
type DeepEqualOpts struct {
SliceEqNilEmpty bool
}
// orderInfo tracks pointer ordering information. As we encounter new pointers
// in our a and b values we maintain their ordering information in the slices.
// We use slices rather than maps for efficiency; we typically have a small
// number of pointers and sequential lookup is fast.
type orderInfo struct {
orderA, orderB []uintptr
}
func lookupPtr(items []uintptr, target uintptr) (int, bool) { // (index, seen)
for index, item := range items {
if item == target {
return index, true
}
}
return -1, false
}
// sharingEqual returns equal=true iff the sharing structure between a and b is
// the same, and returns seen=true iff we've seen either a or b before.
func (info *orderInfo) sharingEqual(a, b uintptr) (bool, bool) { // (equal, seen)
indexA, seenA := lookupPtr(info.orderA, a)
indexB, seenB := lookupPtr(info.orderB, b)
if seenA || seenB {
return seenA == seenB && indexA == indexB, seenA || seenB
}
// Neither type has been seen - add to our order slices and return.
info.orderA = append(info.orderA, a)
info.orderB = append(info.orderB, b)
return true, false
}
func deepEqual(a, b reflect.Value, info *orderInfo, options *DeepEqualOpts) bool {
// We only consider sharing via explicit pointers, and ignore sharing via
// slices, maps or pointers to internal data.
if !a.IsValid() || !b.IsValid() {
return a.IsValid() == b.IsValid()
}
if a.Type() != b.Type() {
return false
}
switch a.Kind() {
case reflect.Ptr:
if a.IsNil() || b.IsNil() {
return a.IsNil() == b.IsNil()
}
equal, seen := info.sharingEqual(a.Pointer(), b.Pointer())
if !equal {
return false // a and b are not equal if their sharing isn't equal.
}
if seen {
// Skip the deepEqual call if we've already seen the pointers and they're
// equal, otherwise we'll have an infinite loop for cyclic values.
return true
}
return deepEqual(a.Elem(), b.Elem(), info, options)
case reflect.Array:
if a.Len() != b.Len() {
return false
}
for ix := 0; ix < a.Len(); ix++ {
if !deepEqual(a.Index(ix), b.Index(ix), info, options) {
return false
}
}
return true
case reflect.Slice:
if !options.SliceEqNilEmpty {
if a.IsNil() || b.IsNil() {
return a.IsNil() == b.IsNil()
}
}
if a.Len() != b.Len() {
return false
}
for ix := 0; ix < a.Len(); ix++ {
if !deepEqual(a.Index(ix), b.Index(ix), info, options) {
return false
}
}
return true
case reflect.Map:
if a.IsNil() || b.IsNil() {
return a.IsNil() == b.IsNil()
}
if a.Len() != b.Len() {
return false
}
for _, key := range a.MapKeys() {
if !deepEqual(a.MapIndex(key), b.MapIndex(key), info, options) {
return false
}
}
return true
case reflect.Struct:
for fx := 0; fx < a.NumField(); fx++ {
if !deepEqual(a.Field(fx), b.Field(fx), info, options) {
return false
}
}
return true
case reflect.Interface:
if a.IsNil() || b.IsNil() {
return a.IsNil() == b.IsNil()
}
return deepEqual(a.Elem(), b.Elem(), info, options)
// Ideally we would add a default clause here that would just return
// a.Interface() == b.Interface(), but that panics if we're dealing with
// unexported fields. Instead we check each primitive type.
case reflect.Bool:
return a.Bool() == b.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() == b.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() == b.Uint()
case reflect.Float32, reflect.Float64:
return a.Float() == b.Float()
case reflect.Complex64, reflect.Complex128:
return a.Complex() == b.Complex()
case reflect.String:
return a.String() == b.String()
case reflect.UnsafePointer:
return a.Pointer() == b.Pointer()
default:
panic(fmt.Errorf("SharingDeepEqual unhandled kind %v type %q", a.Kind(), a.Type()))
}
}