deepequal.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Copyright (c) 2014 Arista Networks, Inc.
  2. // Use of this source code is governed by the Apache License 2.0
  3. // that can be found in the COPYING file.
  4. package test
  5. import (
  6. "bytes"
  7. "math"
  8. "reflect"
  9. "notabug.org/themusicgod1/goarista/areflect"
  10. "notabug.org/themusicgod1/goarista/key"
  11. )
  12. var comparableType = reflect.TypeOf((*key.Comparable)(nil)).Elem()
  13. // DeepEqual is a faster implementation of reflect.DeepEqual that:
  14. // - Has a reflection-free fast-path for all the common types we use.
  15. // - Gives data types the ability to exclude some of their fields from the
  16. // consideration of DeepEqual by tagging them with `deepequal:"ignore"`.
  17. // - Gives data types the ability to define their own comparison method by
  18. // implementing the comparable interface.
  19. // - Supports "composite" (or "complex") keys in maps that are pointers.
  20. func DeepEqual(a, b interface{}) bool {
  21. return deepEqual(a, b, nil)
  22. }
  23. func deepEqual(a, b interface{}, seen map[edge]struct{}) bool {
  24. if a == nil || b == nil {
  25. return a == b
  26. }
  27. switch a := a.(type) {
  28. // Short circuit fast-path for common built-in types.
  29. // Note: the cases are listed by frequency.
  30. case bool:
  31. return a == b
  32. case map[string]interface{}:
  33. v, ok := b.(map[string]interface{})
  34. if !ok || len(a) != len(v) {
  35. return false
  36. }
  37. for key, value := range a {
  38. if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
  39. return false
  40. }
  41. }
  42. return true
  43. case string, uint32, uint64, int32,
  44. uint16, int16, uint8, int8, int64:
  45. return a == b
  46. case *map[string]interface{}:
  47. v, ok := b.(*map[string]interface{})
  48. if !ok || a == nil || v == nil {
  49. return ok && a == v
  50. }
  51. return deepEqual(*a, *v, seen)
  52. case map[interface{}]interface{}:
  53. v, ok := b.(map[interface{}]interface{})
  54. if !ok {
  55. return false
  56. }
  57. // We compare in both directions to catch keys that are in b but not
  58. // in a. It sucks to have to do another O(N^2) for this, but oh well.
  59. return mapEqual(a, v) && mapEqual(v, a)
  60. case float32:
  61. v, ok := b.(float32)
  62. return ok && (a == b || (math.IsNaN(float64(a)) && math.IsNaN(float64(v))))
  63. case float64:
  64. v, ok := b.(float64)
  65. return ok && (a == b || (math.IsNaN(a) && math.IsNaN(v)))
  66. case []string:
  67. v, ok := b.([]string)
  68. if !ok || len(a) != len(v) {
  69. return false
  70. }
  71. for i, s := range a {
  72. if s != v[i] {
  73. return false
  74. }
  75. }
  76. return true
  77. case []byte:
  78. v, ok := b.([]byte)
  79. return ok && bytes.Equal(a, v)
  80. case map[uint64]interface{}:
  81. v, ok := b.(map[uint64]interface{})
  82. if !ok || len(a) != len(v) {
  83. return false
  84. }
  85. for key, value := range a {
  86. if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
  87. return false
  88. }
  89. }
  90. return true
  91. case *map[interface{}]interface{}:
  92. v, ok := b.(*map[interface{}]interface{})
  93. if !ok || a == nil || v == nil {
  94. return ok && a == v
  95. }
  96. return deepEqual(*a, *v, seen)
  97. case key.Comparable:
  98. return a.Equal(b)
  99. case []uint32:
  100. v, ok := b.([]uint32)
  101. if !ok || len(a) != len(v) {
  102. return false
  103. }
  104. for i, s := range a {
  105. if s != v[i] {
  106. return false
  107. }
  108. }
  109. return true
  110. case []uint64:
  111. v, ok := b.([]uint64)
  112. if !ok || len(a) != len(v) {
  113. return false
  114. }
  115. for i, s := range a {
  116. if s != v[i] {
  117. return false
  118. }
  119. }
  120. return true
  121. case []interface{}:
  122. v, ok := b.([]interface{})
  123. if !ok || len(a) != len(v) {
  124. return false
  125. }
  126. for i, s := range a {
  127. if !deepEqual(s, v[i], seen) {
  128. return false
  129. }
  130. }
  131. return true
  132. case *[]string:
  133. v, ok := b.(*[]string)
  134. if !ok || a == nil || v == nil {
  135. return ok && a == v
  136. }
  137. return deepEqual(*a, *v, seen)
  138. case *[]interface{}:
  139. v, ok := b.(*[]interface{})
  140. if !ok || a == nil || v == nil {
  141. return ok && a == v
  142. }
  143. return deepEqual(*a, *v, seen)
  144. default:
  145. // Handle other kinds of non-comparable objects.
  146. return genericDeepEqual(a, b, seen)
  147. }
  148. }
  149. type edge struct {
  150. from uintptr
  151. to uintptr
  152. }
  153. func genericDeepEqual(a, b interface{}, seen map[edge]struct{}) bool {
  154. av := reflect.ValueOf(a)
  155. bv := reflect.ValueOf(b)
  156. if avalid, bvalid := av.IsValid(), bv.IsValid(); !avalid || !bvalid {
  157. return avalid == bvalid
  158. }
  159. if bv.Type() != av.Type() {
  160. return false
  161. }
  162. switch av.Kind() {
  163. case reflect.Ptr:
  164. if av.IsNil() || bv.IsNil() {
  165. return a == b
  166. }
  167. av = av.Elem()
  168. bv = bv.Elem()
  169. if av.CanAddr() && bv.CanAddr() {
  170. e := edge{from: av.UnsafeAddr(), to: bv.UnsafeAddr()}
  171. // Detect and prevent cycles.
  172. if seen == nil {
  173. seen = make(map[edge]struct{})
  174. } else if _, ok := seen[e]; ok {
  175. return true
  176. }
  177. seen[e] = struct{}{}
  178. }
  179. return deepEqual(av.Interface(), bv.Interface(), seen)
  180. case reflect.Slice, reflect.Array:
  181. l := av.Len()
  182. if l != bv.Len() {
  183. return false
  184. }
  185. for i := 0; i < l; i++ {
  186. if !deepEqual(av.Index(i).Interface(), bv.Index(i).Interface(), seen) {
  187. return false
  188. }
  189. }
  190. return true
  191. case reflect.Map:
  192. if av.IsNil() != bv.IsNil() {
  193. return false
  194. }
  195. if av.Len() != bv.Len() {
  196. return false
  197. }
  198. if av.Pointer() == bv.Pointer() {
  199. return true
  200. }
  201. for _, k := range av.MapKeys() {
  202. // Upon finding the first key that's a pointer, we bail out and do
  203. // a O(N^2) comparison.
  204. if kk := k.Kind(); kk == reflect.Ptr || kk == reflect.Interface {
  205. ok, _, _ := complexKeyMapEqual(av, bv, seen)
  206. return ok
  207. }
  208. ea := av.MapIndex(k)
  209. eb := bv.MapIndex(k)
  210. if !eb.IsValid() {
  211. return false
  212. }
  213. if !deepEqual(ea.Interface(), eb.Interface(), seen) {
  214. return false
  215. }
  216. }
  217. return true
  218. case reflect.Struct:
  219. typ := av.Type()
  220. if typ.Implements(comparableType) {
  221. return av.Interface().(key.Comparable).Equal(bv.Interface())
  222. }
  223. for i, n := 0, av.NumField(); i < n; i++ {
  224. if typ.Field(i).Tag.Get("deepequal") == "ignore" {
  225. continue
  226. }
  227. af := areflect.ForceExport(av.Field(i))
  228. bf := areflect.ForceExport(bv.Field(i))
  229. if !deepEqual(af.Interface(), bf.Interface(), seen) {
  230. return false
  231. }
  232. }
  233. return true
  234. default:
  235. // Other the basic types.
  236. return a == b
  237. }
  238. }
  239. // Compares two maps with complex keys (that are pointers). This assumes the
  240. // maps have already been checked to have the same sizes. The cost of this
  241. // function is O(N^2) in the size of the input maps.
  242. //
  243. // The return is to be interpreted this way:
  244. // true, _, _ => av == bv
  245. // false, key, invalid => the given key wasn't found in bv
  246. // false, key, value => the given key had the given value in bv,
  247. // which is different in av
  248. func complexKeyMapEqual(av, bv reflect.Value,
  249. seen map[edge]struct{}) (bool, reflect.Value, reflect.Value) {
  250. for _, ka := range av.MapKeys() {
  251. var eb reflect.Value // The entry in bv with a key equal to ka
  252. for _, kb := range bv.MapKeys() {
  253. if deepEqual(ka.Elem().Interface(), kb.Elem().Interface(), seen) {
  254. // Found the corresponding entry in bv.
  255. eb = bv.MapIndex(kb)
  256. break
  257. }
  258. }
  259. if !eb.IsValid() { // We didn't find a key equal to `ka' in 'bv'.
  260. return false, ka, reflect.Value{}
  261. }
  262. ea := av.MapIndex(ka)
  263. if !deepEqual(ea.Interface(), eb.Interface(), seen) {
  264. return false, ka, eb
  265. }
  266. }
  267. return true, reflect.Value{}, reflect.Value{}
  268. }
  269. // mapEqual does O(N^2) comparisons to check that all the keys present in the
  270. // first map are also present in the second map and have identical values.
  271. func mapEqual(a, b map[interface{}]interface{}) bool {
  272. if len(a) != len(b) {
  273. return false
  274. }
  275. for akey, avalue := range a {
  276. found := false
  277. for bkey, bvalue := range b {
  278. if DeepEqual(akey, bkey) {
  279. if !DeepEqual(avalue, bvalue) {
  280. return false
  281. }
  282. found = true
  283. break
  284. }
  285. }
  286. if !found {
  287. return false
  288. }
  289. }
  290. return true
  291. }