blob: 3517bb256ed9fd5f94b34de70c7341b1e6dfef1b [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

Jiri Simsa5293dcb2014-05-10 09:56:38 -07005package reflectutil
6
7import (
8 "reflect"
9 "sort"
10)
11
12// AreComparable is a helper to call AreComparableTypes.
13func AreComparable(a, b interface{}) bool {
14 return AreComparableTypes(reflect.TypeOf(a), reflect.TypeOf(b))
15}
16
// AreComparableTypes returns true iff a and b are comparable types: bools,
// strings and numbers, and composites using arrays, slices, structs or
// pointers. Numbers of different kinds within one family (unsigned integers,
// signed integers, floats, complex numbers) are comparable with each other.
// Structs must have the same number of fields, with matching names and
// package paths in order, and pairwise comparable field types. A nil type is
// never comparable.
func AreComparableTypes(a, b reflect.Type) bool {
	return areComparable(a, b, make(map[[2]reflect.Type]bool))
}

// areComparable implements AreComparableTypes.
//
// seen holds every (a, b) pair whose check has already started. Revisiting a
// pair means the walk has looped back through a recursive type, so the pair
// is optimistically treated as comparable. That is sound because every false
// result is returned straight up to the top-level call, so a pair left in
// seen can never hide an incomparable one.
//
// The key must be the pair rather than a alone. The same type on one side can
// meet different types on the other side at different positions, e.g.
// struct{X, Y []int} vs struct{X []int; Y []map[string]int}.
func areComparable(a, b reflect.Type, seen map[[2]reflect.Type]bool) bool {
	if a == nil || b == nil {
		return false // e.g. reflect.TypeOf(nil); a nil Type has no Kind.
	}
	if a.Kind() != b.Kind() {
		// Different kinds are only comparable as numbers of the same family.
		return isUint(a) && isUint(b) || isInt(a) && isInt(b) ||
			isFloat(a) && isFloat(b) || isComplex(a) && isComplex(b)
	}

	// Deal with cyclic types.
	key := [2]reflect.Type{a, b}
	if seen[key] {
		return true
	}
	seen[key] = true

	switch a.Kind() {
	case reflect.Bool, reflect.String,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
		return true
	case reflect.Array, reflect.Slice, reflect.Ptr:
		return areComparable(a.Elem(), b.Elem(), seen)
	case reflect.Struct:
		if a.NumField() != b.NumField() {
			return false
		}
		for fx := 0; fx < a.NumField(); fx++ {
			af, bf := a.Field(fx), b.Field(fx)
			if af.Name != bf.Name || af.PkgPath != bf.PkgPath {
				return false
			}
			if !areComparable(af.Type, bf.Type, seen) {
				return false
			}
		}
		return true
	default:
		// Unhandled: Map, Interface, Chan, Func, UnsafePointer
		return false
	}
}

// isUint reports whether rt has an unsigned integer kind (including uintptr).
func isUint(rt reflect.Type) bool {
	switch rt.Kind() {
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	}
	return false
}

// isInt reports whether rt has a signed integer kind.
func isInt(rt reflect.Type) bool {
	switch rt.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	}
	return false
}

// isFloat reports whether rt has a floating-point kind.
func isFloat(rt reflect.Type) bool {
	switch rt.Kind() {
	case reflect.Float32, reflect.Float64:
		return true
	}
	return false
}

// isComplex reports whether rt has a complex kind.
func isComplex(rt reflect.Type) bool {
	switch rt.Kind() {
	case reflect.Complex64, reflect.Complex128:
		return true
	}
	return false
}
98
99// Less is a helper to call LessValues.
100func Less(a, b interface{}) bool {
101 return LessValues(reflect.ValueOf(a), reflect.ValueOf(b))
102}
103
104// LessValues returns true iff a and b are comparable and a < b. If a and b are
105// incomparable an arbitrary value is returned. Cyclic values are not handled;
106// if a and b are cyclic and equal, this will infinite loop. Arrays, slices and
107// structs use lexicographic ordering, and complex numbers compare real before
108// imaginary.
109func LessValues(a, b reflect.Value) bool {
110 if a.Kind() != b.Kind() {
111 return false // Different kinds are incomparable.
112 }
113 switch a.Kind() {
114 case reflect.Bool:
115 return lessBool(a.Bool(), b.Bool())
116 case reflect.String:
117 return a.String() < b.String()
118 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
119 return a.Uint() < b.Uint()
120 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
121 return a.Int() < b.Int()
122 case reflect.Float32, reflect.Float64:
123 return a.Float() < b.Float()
124 case reflect.Complex64, reflect.Complex128:
125 return lessComplex(a.Complex(), b.Complex())
126 case reflect.Array:
127 return compareArray(a, b) == -1
128 case reflect.Slice:
129 return compareSlice(a, b) == -1
130 case reflect.Struct:
131 return compareStruct(a, b) == -1
132 case reflect.Ptr:
133 if a.IsNil() || b.IsNil() {
134 return a.IsNil() && !b.IsNil() // nil is less than non-nil.
135 }
136 return LessValues(a.Elem(), b.Elem())
137 default:
138 return false
139 }
140}
141
// lessBool orders booleans with false before true.
func lessBool(a, b bool) bool {
	if a {
		return false // true is never less than anything.
	}
	return b
}
145
// lessComplex orders complex numbers lexicographically: the real parts decide
// unless they are equal, in which case the imaginary parts break the tie.
func lessComplex(a, b complex128) bool {
	ar, br := real(a), real(b)
	if ar != br {
		return ar < br
	}
	return imag(a) < imag(b)
}
153
154// Compare is a helper to call CompareValues.
155func Compare(a, b interface{}) int {
156 return CompareValues(reflect.ValueOf(a), reflect.ValueOf(b))
157}
158
159// CompareValues returns an integer comparing two values. If a and b are
160// comparable, the result is 0 if a == b, -1 if a < b and +1 if a > b. If a and
161// b are incomparable an arbitrary value is returned. Arrays, slices and
162// structs use lexicographic ordering, and complex numbers compare real before
163// imaginary.
164func CompareValues(a, b reflect.Value) int {
165 if a.Kind() != b.Kind() {
166 return 0 // Different kinds are incomparable.
167 }
168 switch a.Kind() {
169 case reflect.Array:
170 return compareArray(a, b)
171 case reflect.Slice:
172 return compareSlice(a, b)
173 case reflect.Struct:
174 return compareStruct(a, b)
175 case reflect.Ptr:
176 if a.IsNil() || b.IsNil() {
177 if a.IsNil() && !b.IsNil() {
178 return -1
179 }
180 if !a.IsNil() && b.IsNil() {
181 return +1
182 }
183 return 0
184 }
185 return CompareValues(a.Elem(), b.Elem())
186 }
187 if LessValues(a, b) {
188 return -1 // a < b
189 }
190 if LessValues(b, a) {
191 return +1 // a > b
192 }
193 return 0 // a == b, or incomparable.
194}
195
196func compareArray(a, b reflect.Value) int {
197 // Return lexicographic ordering of the array elements.
198 for ix := 0; ix < a.Len(); ix++ {
199 if c := CompareValues(a.Index(ix), b.Index(ix)); c != 0 {
200 return c
201 }
202 }
203 return 0
204}
205
206func compareSlice(a, b reflect.Value) int {
207 // Return lexicographic ordering of the slice elements.
208 for ix := 0; ix < a.Len() && ix < b.Len(); ix++ {
209 if c := CompareValues(a.Index(ix), b.Index(ix)); c != 0 {
210 return c
211 }
212 }
213 // Equal prefixes, shorter comes before longer.
214 if a.Len() < b.Len() {
215 return -1
216 }
217 if a.Len() > b.Len() {
218 return +1
219 }
220 return 0
221}
222
223func compareStruct(a, b reflect.Value) int {
224 // Return lexicographic ordering of the struct fields.
225 for ix := 0; ix < a.NumField(); ix++ {
226 if c := CompareValues(a.Field(ix), b.Field(ix)); c != 0 {
227 return c
228 }
229 }
230 return 0
231}
232
233// TrySortValues sorts a slice of reflect.Value if the value kind is supported.
234// Supported kinds are bools, strings and numbers, and composites using arrays,
235// slices, structs or pointers. Arrays, slices and structs use lexicographic
236// ordering, and complex numbers compare real before imaginary. If the values
237// in the slice aren't comparable or supported, the resulting ordering is
238// arbitrary.
239func TrySortValues(v []reflect.Value) []reflect.Value {
240 if len(v) <= 1 {
241 return v
242 }
243 switch v[0].Kind() {
244 case reflect.Bool:
245 sort.Sort(rvBools{v})
246 case reflect.String:
247 sort.Sort(rvStrings{v})
248 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
249 sort.Sort(rvUints{v})
250 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
251 sort.Sort(rvInts{v})
252 case reflect.Float32, reflect.Float64:
253 sort.Sort(rvFloats{v})
254 case reflect.Complex64, reflect.Complex128:
255 sort.Sort(rvComplexes{v})
256 case reflect.Array:
257 sort.Sort(rvArrays{v})
258 case reflect.Slice:
259 sort.Sort(rvSlices{v})
260 case reflect.Struct:
261 sort.Sort(rvStructs{v})
262 case reflect.Ptr:
263 sort.Sort(rvPtrs{v})
264 }
265 return v
266}
267
// Sorting helpers, heavily inspired by similar code in text/template.

// rvs is a []reflect.Value that supplies the Len and Swap methods of
// sort.Interface. Each wrapper type below embeds rvs and adds a Less method
// specialized for one family of kinds, which completes sort.Interface.
type rvs []reflect.Value

// Len returns the number of values.
func (s rvs) Len() int {
	return len(s)
}

// Swap exchanges the values at indices i and j.
func (s rvs) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}

// Per-kind-family views of rvs, each ordered by its own Less method.
type (
	rvBools     struct{ rvs }
	rvStrings   struct{ rvs }
	rvUints     struct{ rvs }
	rvInts      struct{ rvs }
	rvFloats    struct{ rvs }
	rvComplexes struct{ rvs }
	rvArrays    struct{ rvs }
	rvSlices    struct{ rvs }
	rvStructs   struct{ rvs }
	rvPtrs      struct{ rvs }
)
285
286func (x rvBools) Less(i, j int) bool {
287 return lessBool(x.rvs[i].Bool(), x.rvs[j].Bool())
288}
289func (x rvStrings) Less(i, j int) bool {
290 return x.rvs[i].String() < x.rvs[j].String()
291}
292func (x rvUints) Less(i, j int) bool {
293 return x.rvs[i].Uint() < x.rvs[j].Uint()
294}
295func (x rvInts) Less(i, j int) bool {
296 return x.rvs[i].Int() < x.rvs[j].Int()
297}
298func (x rvFloats) Less(i, j int) bool {
299 return x.rvs[i].Float() < x.rvs[j].Float()
300}
301func (x rvComplexes) Less(i, j int) bool {
302 return lessComplex(x.rvs[i].Complex(), x.rvs[j].Complex())
303}
304func (x rvArrays) Less(i, j int) bool {
305 return compareArray(x.rvs[i], x.rvs[j]) == -1
306}
307func (x rvSlices) Less(i, j int) bool {
308 return compareSlice(x.rvs[i], x.rvs[j]) == -1
309}
310func (x rvStructs) Less(i, j int) bool {
311 return compareStruct(x.rvs[i], x.rvs[j]) == -1
312}
313func (x rvPtrs) Less(i, j int) bool {
314 return LessValues(x.rvs[i], x.rvs[j])
315}