diff --git a/treap.go b/treap.go index f298ba0..9269c0f 100644 --- a/treap.go +++ b/treap.go @@ -189,6 +189,8 @@ func (t *Table) UnionMutable(other *Table) { t.root6 = t.root6.union(other.root6, true, false) } +// union two treaps. +// flag overwrite isn't public but needed as input for rec-descent calls, see below when trepa are swapped. func (n *node) union(b *node, overwrite bool, immutable bool) *node { // recursion stop condition if n == nil { @@ -631,14 +633,25 @@ func cmpRR(a, b netip.Prefix) int { return aLast.Compare(bLast) } -// ipTooBig returns true if ip is greater than prefix last address. -func ipTooBig(ip netip.Addr, p netip.Prefix) bool { - _, pLast := extnetip.Range(p) - return ip.Compare(pLast) > 0 +// ipTooBig returns true if ip is greater than prefix last ip address. +// +// false true +// | | +// V V +// +// ------- other --------> +func ipTooBig(ip netip.Addr, other netip.Prefix) bool { + _, pLastIP := extnetip.Range(other) + return ip.Compare(pLastIP) > 0 } -// pfxTooBig returns true if k last address is greater than p last address. -func pfxTooBig(k netip.Prefix, p netip.Prefix) bool { - _, ip := extnetip.Range(k) - return ipTooBig(ip, p) +// pfxTooBig returns true if prefix last address is greater than other last ip address. +// +// ------------ pfx --------------> true +// ------ pfx ----> false +// +// ------- other --------> +func pfxTooBig(pfx netip.Prefix, other netip.Prefix) bool { + _, pfxLastIP := extnetip.Range(pfx) + return ipTooBig(pfxLastIP, other) } diff --git a/treap_test.go b/treap_test.go index 261d5e3..93ae4b0 100644 --- a/treap_test.go +++ b/treap_test.go @@ -2,7 +2,6 @@ package cidrtree_test import ( "fmt" - mrand "math/rand" "net/netip" "reflect" "strings" @@ -524,31 +523,50 @@ func TestLookupCIDR(t *testing.T) { func TestUnion(t *testing.T) { t.Parallel() rtbl := new(cidrtree.Table) + rtbl2 := new(cidrtree.Table) for _, route := range routes { rtbl.InsertMutable(route.cidr, route.nextHop) + rtbl2.InsertMutable(route.cidr, route.nextHop) } - clone := rtbl.Clone() - if !reflect.DeepEqual(rtbl, clone) { - t.Fatal("Clone isn't deep equal to original table.") + rtbl.UnionMutable(rtbl2) + if rtbl.String() != asTopoStr { + t.Errorf("Fprint()\nwant:\n%sgot:\n%s", asTopoStr, rtbl.String()) } - probe := routes[mrand.Intn(len(routes))] - rtbl2 := new(cidrtree.Table).Insert(probe.cidr, "overwrite value") - - // overwrite value for this cidr - rtbl.UnionMutable(rtbl2) + rtbl3 := rtbl.Union(rtbl2) + if rtbl3.String() != asTopoStr { + t.Errorf("Fprint()\nwant:\n%sgot:\n%s", asTopoStr, rtbl.String()) + } +} - if reflect.DeepEqual(rtbl, clone) { - t.Fatal("union with overwrite must not deep equal to original table.") +func TestUnionDupe(t *testing.T) { + t.Parallel() + rtbl1 := new(cidrtree.Table) + rtbl2 := new(cidrtree.Table) + for _, cidr := range shuffleFullTable(100_000) { + rtbl1.InsertMutable(cidr, 1) + // dupe cidr with different value + rtbl2.InsertMutable(cidr, 2) } + // both tables have identical CIDRs but with different values + // overwrite all values with value=2 + rtbl1.UnionMutable(rtbl2) - _, value, ok := rtbl.LookupCIDR(probe.cidr) - if !ok { - t.Errorf("LookupCIDR(%v), expect %v, got %v", probe.cidr, true, ok) + var wrongValue bool + + // callback as closure + cb := func(pfx netip.Prefix, val any) bool { + if v, ok := val.(int); ok && v != 2 { + wrongValue = true + return false + } + return true } - if value != "overwrite value" { - t.Errorf("UnionMutable with duplicate, expect %q, got %q", "overwrite value", value) + + rtbl1.Walk(cb) + if wrongValue { + t.Error("Union with duplicate CIDRs didn't overwrite") } }