From c5b54a2d4f993339636889865f383b1badeb8c9b Mon Sep 17 00:00:00 2001 From: Karl Gaissmaier Date: Wed, 10 Jan 2024 11:53:58 +0100 Subject: [PATCH] now with generic value --- README.md | 31 +++-- bench_test.go | 10 +- debug.go | 8 +- example_test.go | 123 ++++++++++---------- stringify.go | 24 ++-- treap.go | 290 +++++++++++++++++++++++------------------------ treap_test.go | 128 ++++++++++----------- whitebox_test.go | 20 ++-- 8 files changed, 318 insertions(+), 316 deletions(-) diff --git a/README.md b/README.md index 840aa32..ac5772b 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ## !!! ATTENTION -API currently not stable! +API is currently not stable! ## Overview @@ -16,8 +16,7 @@ API currently not stable! The implementation is based on treaps, which have been augmented here for CIDRs. Treaps are randomized, self-balancing binary search trees. Due to the nature of treaps the lookups (readers) and the update (writer) can be easily decoupled. This is the perfect fit for a software router or firewall. -This package is a specialization of the more generic [interval package] of the same author, -but explicit for CIDRs. It has a narrow focus with a specialized API for IP routing tables. +This package is a specialization of the more generic [interval package] of the same author, but explicit for CIDRs. It has a narrow focus with a specialized API for IP routing tables. [interval package]: https://github.com/gaissmai/interval @@ -25,23 +24,23 @@ but explicit for CIDRs. It has a narrow focus with a specialized API for IP rout ```go import "github.com/gaissmai/cidrtree" - type Table struct { // Has unexported fields. } + type Table[V any] struct { // Has unexported fields. } Table is an IPv4 and IPv6 routing table. The zero value is ready to use. - func (t Table) Lookup(ip netip.Addr) (lpm netip.Prefix, value any, ok bool) - func (t Table) LookupPrefix(pfx netip.Prefix) (lpm netip.Prefix, value any, ok bool) + func (t Table[V]) Lookup(ip netip.Addr) (lpm netip.Prefix, value V, ok bool) + func (t Table[V]) LookupPrefix(pfx netip.Prefix) (lpm netip.Prefix, value V, ok bool) - func (t *Table) Insert(pfx netip.Prefix, val any) - func (t *Table) Delete(pfx netip.Prefix) bool - func (t *Table) Union(other Table) + func (t *Table[V]) Insert(pfx netip.Prefix, value V) + func (t *Table[V]) Delete(pfx netip.Prefix) bool + func (t *Table[V]) Union(other Table[V]) - func (t Table) InsertImmutable(pfx netip.Prefix, val any) *Table - func (t Table) DeleteImmutable(pfx netip.Prefix) (*Table, bool) - func (t Table) UnionImmutable(other Table) *Table - func (t Table) Clone() *Table + func (t Table[V]) InsertImmutable(pfx netip.Prefix, value V) *Table[V] + func (t Table[V]) DeleteImmutable(pfx netip.Prefix) (*Table[V], bool) + func (t Table[V]) UnionImmutable(other Table[V]) *Table[V] + func (t Table[V]) Clone() *Table[V] - func (t Table) String() string - func (t Table) Fprint(w io.Writer) error + func (t Table[V]) String() string + func (t Table[V]) Fprint(w io.Writer) error - func (t Table) Walk(cb func(pfx netip.Prefix, val any) bool) + func (t Table[V]) Walk(cb func(pfx netip.Prefix, value V) bool) ``` diff --git a/bench_test.go b/bench_test.go index e80c92d..27a8e79 100644 --- a/bench_test.go +++ b/bench_test.go @@ -25,7 +25,7 @@ var intMap = map[int]string{ func BenchmarkLookup(b *testing.B) { for k := 1; k <= 100_000; k *= 10 { - rt := new(cidrtree.Table) + rt := new(cidrtree.Table[any]) cidrs := shuffleFullTable(k) for _, cidr := range cidrs { rt.Insert(cidr, nil) @@ -45,7 +45,7 @@ func BenchmarkLookup(b *testing.B) { func BenchmarkLookupPrefix(b *testing.B) { for k := 1; k <= 100_000; k *= 10 { - rt := new(cidrtree.Table) + rt := new(cidrtree.Table[any]) cidrs := shuffleFullTable(k) for _, cidr := range cidrs { rt.Insert(cidr, nil) @@ -64,7 +64,7 @@ func BenchmarkLookupPrefix(b *testing.B) { func BenchmarkClone(b *testing.B) { for k := 1; k <= 100_000; k *= 10 { - rt := new(cidrtree.Table) + rt := new(cidrtree.Table[any]) for _, cidr := range shuffleFullTable(k) { rt.Insert(cidr, nil) } @@ -80,7 +80,7 @@ func BenchmarkClone(b *testing.B) { func BenchmarkInsert(b *testing.B) { for k := 1; k <= 100_000; k *= 10 { - rt := new(cidrtree.Table) + rt := new(cidrtree.Table[any]) cidrs := shuffleFullTable(k) for _, cidr := range cidrs { rt.Insert(cidr, nil) @@ -99,7 +99,7 @@ func BenchmarkInsert(b *testing.B) { func BenchmarkDelete(b *testing.B) { for k := 1; k <= 100_000; k *= 10 { - rt := new(cidrtree.Table) + rt := new(cidrtree.Table[any]) cidrs := shuffleFullTable(k) for _, cidr := range cidrs { rt.Insert(cidr, nil) diff --git a/debug.go b/debug.go index 2a397fa..9432252 100644 --- a/debug.go +++ b/debug.go @@ -10,7 +10,7 @@ import ( // fprintBST writes a horizontal tree diagram of the binary search tree (BST) to w. // // Note: This is for debugging purposes only. -func (t Table) fprintBST(w io.Writer) error { +func (t Table[V]) fprintBST(w io.Writer) error { if t.root4 != nil { if _, err := fmt.Fprint(w, "R "); err != nil { return err @@ -33,7 +33,7 @@ func (t Table) fprintBST(w io.Writer) error { } // fprintBST recursive helper. -func (n *node) fprintBST(w io.Writer, pad string) error { +func (n *node[V]) fprintBST(w io.Writer, pad string) error { // stringify this node _, err := fmt.Fprintf(w, "%v [prio:%.4g] [subtree maxUpper: %v]\n", n.cidr, float64(n.prio)/math.MaxUint64, n.maxUpper.cidr) if err != nil { @@ -80,7 +80,7 @@ func (n *node) fprintBST(w io.Writer, pad string) error { // If the skip function is not nil, a true return value defines which nodes must be skipped in the statistics. // // Note: This is for debugging and testing purposes only during development. -func (t Table) statistics(skip func(netip.Prefix, any, int) bool) (size int, maxDepth int, average, deviation float64) { +func (t Table[V]) statistics(skip func(netip.Prefix, any, int) bool) (size int, maxDepth int, average, deviation float64) { // key is depth, value is the sum of nodes with this depth depths := make(map[int]int) @@ -120,7 +120,7 @@ func (t Table) statistics(skip func(netip.Prefix, any, int) bool) (size int, max } // walkWithDepth in ascending prefix order. -func (n *node) walkWithDepth(cb func(netip.Prefix, any, int) bool, depth int) bool { +func (n *node[V]) walkWithDepth(cb func(netip.Prefix, any, int) bool, depth int) bool { if n == nil { return true } diff --git a/example_test.go b/example_test.go index 10d02ca..29d8311 100644 --- a/example_test.go +++ b/example_test.go @@ -8,73 +8,76 @@ import ( "github.com/gaissmai/cidrtree" ) -func addr(s string) netip.Addr { +func mustAddr(s string) netip.Addr { return netip.MustParseAddr(s) } -func prfx(s string) netip.Prefix { +func mustPfx(s string) netip.Prefix { return netip.MustParsePrefix(s) } -var input = []netip.Prefix{ - prfx("fe80::/10"), - prfx("172.16.0.0/12"), - prfx("10.0.0.0/24"), - prfx("::1/128"), - prfx("192.168.0.0/16"), - prfx("10.0.0.0/8"), - prfx("::/0"), - prfx("10.0.1.0/24"), - prfx("169.254.0.0/16"), - prfx("2000::/3"), - prfx("2001:db8::/32"), - prfx("127.0.0.0/8"), - prfx("127.0.0.1/32"), - prfx("192.168.1.0/24"), +var input = []struct { + cidr netip.Prefix + nextHop netip.Addr +}{ + {mustPfx("fe80::/10"), mustAddr("::1%lo")}, + {mustPfx("172.16.0.0/12"), mustAddr("8.8.8.8")}, + {mustPfx("10.0.0.0/24"), mustAddr("8.8.8.8")}, + {mustPfx("::1/128"), mustAddr("::1%eth0")}, + {mustPfx("192.168.0.0/16"), mustAddr("9.9.9.9")}, + {mustPfx("10.0.0.0/8"), mustAddr("9.9.9.9")}, + {mustPfx("::/0"), mustAddr("2001:db8::1")}, + {mustPfx("10.0.1.0/24"), mustAddr("10.0.0.0")}, + {mustPfx("169.254.0.0/16"), mustAddr("10.0.0.0")}, + {mustPfx("2000::/3"), mustAddr("2001:db8::1")}, + {mustPfx("2001:db8::/32"), mustAddr("2001:db8::1")}, + {mustPfx("127.0.0.0/8"), mustAddr("127.0.0.1")}, + {mustPfx("127.0.0.1/32"), mustAddr("127.0.0.1")}, + {mustPfx("192.168.1.0/24"), mustAddr("127.0.0.1")}, } func ExampleTable_Lookup() { - rtbl := new(cidrtree.Table) - for _, cidr := range input { - rtbl.Insert(cidr, nil) + rtbl := new(cidrtree.Table[netip.Addr]) + for _, item := range input { + rtbl.Insert(item.cidr, item.nextHop) } rtbl.Fprint(os.Stdout) fmt.Println() - ip := addr("42.0.0.0") + ip := mustAddr("42.0.0.0") lpm, value, ok := rtbl.Lookup(ip) - fmt.Printf("Lookup: %-20v lpm: %-15v value: %v, ok: %v\n", ip, lpm, value, ok) + fmt.Printf("Lookup: %-20v lpm: %-15v value: %11v, ok: %v\n", ip, lpm, value, ok) - ip = addr("10.0.1.17") + ip = mustAddr("10.0.1.17") lpm, value, ok = rtbl.Lookup(ip) - fmt.Printf("Lookup: %-20v lpm: %-15v value: %v, ok: %v\n", ip, lpm, value, ok) + fmt.Printf("Lookup: %-20v lpm: %-15v value: %11v, ok: %v\n", ip, lpm, value, ok) - ip = addr("2001:7c0:3100:1::111") + ip = mustAddr("2001:7c0:3100:1::111") lpm, value, ok = rtbl.Lookup(ip) - fmt.Printf("Lookup: %-20v lpm: %-15v value: %v, ok: %v\n", ip, lpm, value, ok) + fmt.Printf("Lookup: %-20v lpm: %-15v value: %11v, ok: %v\n", ip, lpm, value, ok) // Output: // ▼ - // ├─ 10.0.0.0/8 () - // │ ├─ 10.0.0.0/24 () - // │ └─ 10.0.1.0/24 () - // ├─ 127.0.0.0/8 () - // │ └─ 127.0.0.1/32 () - // ├─ 169.254.0.0/16 () - // ├─ 172.16.0.0/12 () - // └─ 192.168.0.0/16 () - // └─ 192.168.1.0/24 () + // ├─ 10.0.0.0/8 (9.9.9.9) + // │ ├─ 10.0.0.0/24 (8.8.8.8) + // │ └─ 10.0.1.0/24 (10.0.0.0) + // ├─ 127.0.0.0/8 (127.0.0.1) + // │ └─ 127.0.0.1/32 (127.0.0.1) + // ├─ 169.254.0.0/16 (10.0.0.0) + // ├─ 172.16.0.0/12 (8.8.8.8) + // └─ 192.168.0.0/16 (9.9.9.9) + // └─ 192.168.1.0/24 (127.0.0.1) // ▼ - // └─ ::/0 () - // ├─ ::1/128 () - // ├─ 2000::/3 () - // │ └─ 2001:db8::/32 () - // └─ fe80::/10 () + // └─ ::/0 (2001:db8::1) + // ├─ ::1/128 (::1%eth0) + // ├─ 2000::/3 (2001:db8::1) + // │ └─ 2001:db8::/32 (2001:db8::1) + // └─ fe80::/10 (::1%lo) // - // Lookup: 42.0.0.0 lpm: invalid Prefix value: , ok: false - // Lookup: 10.0.1.17 lpm: 10.0.1.0/24 value: , ok: true - // Lookup: 2001:7c0:3100:1::111 lpm: 2000::/3 value: , ok: true + // Lookup: 42.0.0.0 lpm: invalid Prefix value: invalid IP, ok: false + // Lookup: 10.0.1.17 lpm: 10.0.1.0/24 value: 10.0.0.0, ok: true + // Lookup: 2001:7c0:3100:1::111 lpm: 2000::/3 value: 2001:db8::1, ok: true } func ExampleTable_Walk() { @@ -83,25 +86,25 @@ func ExampleTable_Walk() { return true } - rtbl := new(cidrtree.Table) - for _, cidr := range input { - rtbl.Insert(cidr, nil) + rtbl := new(cidrtree.Table[any]) + for _, item := range input { + rtbl.Insert(item.cidr, item.nextHop) } rtbl.Walk(cb) // Output: - // 10.0.0.0/8 () - // 10.0.0.0/24 () - // 10.0.1.0/24 () - // 127.0.0.0/8 () - // 127.0.0.1/32 () - // 169.254.0.0/16 () - // 172.16.0.0/12 () - // 192.168.0.0/16 () - // 192.168.1.0/24 () - // ::/0 () - // ::1/128 () - // 2000::/3 () - // 2001:db8::/32 () - // fe80::/10 () + // 10.0.0.0/8 (9.9.9.9) + // 10.0.0.0/24 (8.8.8.8) + // 10.0.1.0/24 (10.0.0.0) + // 127.0.0.0/8 (127.0.0.1) + // 127.0.0.1/32 (127.0.0.1) + // 169.254.0.0/16 (10.0.0.0) + // 172.16.0.0/12 (8.8.8.8) + // 192.168.0.0/16 (9.9.9.9) + // 192.168.1.0/24 (127.0.0.1) + // ::/0 (2001:db8::1) + // ::1/128 (::1%eth0) + // 2000::/3 (2001:db8::1) + // 2001:db8::/32 (2001:db8::1) + // fe80::/10 (::1%lo) } diff --git a/stringify.go b/stringify.go index 5361935..e3cd394 100644 --- a/stringify.go +++ b/stringify.go @@ -7,7 +7,7 @@ import ( ) // String returns a hierarchical tree diagram of the ordered CIDRs as string, just a wrapper for [Tree.Fprint]. -func (t Table) String() string { +func (t Table[V]) String() string { w := new(strings.Builder) _ = t.Fprint(w) return w.String() @@ -17,7 +17,7 @@ func (t Table) String() string { // // The order from top to bottom is in ascending order of the start address // and the subtree structure is determined by the CIDRs coverage. -func (t Table) Fprint(w io.Writer) error { +func (t Table[V]) Fprint(w io.Writer) error { if err := t.root4.fprint(w); err != nil { return err } @@ -27,16 +27,16 @@ func (t Table) Fprint(w io.Writer) error { return nil } -func (n *node) fprint(w io.Writer) error { +func (n *node[V]) fprint(w io.Writer) error { if n == nil { return nil } // pcm = parent-child-mapping - var pcm parentChildsMap + var pcm parentChildsMap[V] // init map - pcm.pcMap = make(map[*node][]*node) + pcm.pcMap = make(map[*node[V]][]*node[V]) pcm = n.buildParentChildsMap(pcm) @@ -50,11 +50,11 @@ func (n *node) fprint(w io.Writer) error { } // start recursion with root and empty padding - var root *node + var root *node[V] return root.walkAndStringify(w, pcm, "") } -func (n *node) walkAndStringify(w io.Writer, pcm parentChildsMap, pad string) error { +func (n *node[V]) walkAndStringify(w io.Writer, pcm parentChildsMap[V], pad string) error { // the prefix (pad + glyphe) is already printed on the line on upper level if n != nil { if _, err := fmt.Fprintf(w, "%v (%v)\n", n.cidr, n.value); err != nil { @@ -92,13 +92,13 @@ func (n *node) walkAndStringify(w io.Writer, pcm parentChildsMap, pad string) er // parentChildsMap, needed for hierarchical tree printing, this is not BST printing! // // CIDR tree, parent->childs relation printed. A parent CIDR covers a child CIDR. -type parentChildsMap struct { - pcMap map[*node][]*node // parent -> []child map - stack []*node // just needed for the algo +type parentChildsMap[T any] struct { + pcMap map[*node[T]][]*node[T] // parent -> []child map + stack []*node[T] // just needed for the algo } // buildParentChildsMap, in-order traversal -func (n *node) buildParentChildsMap(pcm parentChildsMap) parentChildsMap { +func (n *node[V]) buildParentChildsMap(pcm parentChildsMap[V]) parentChildsMap[V] { if n == nil { return pcm } @@ -114,7 +114,7 @@ func (n *node) buildParentChildsMap(pcm parentChildsMap) parentChildsMap { } // pcmForNode, find parent in stack, remove cidrs from stack, put this cidr on stack. -func (n *node) pcmForNode(pcm parentChildsMap) parentChildsMap { +func (n *node[V]) pcmForNode(pcm parentChildsMap[V]) parentChildsMap[V] { // if this cidr is covered by a prev cidr on stack for j := len(pcm.stack) - 1; j >= 0; j-- { that := pcm.stack[j] diff --git a/treap.go b/treap.go index cc9bab2..4f6a0bc 100644 --- a/treap.go +++ b/treap.go @@ -16,51 +16,165 @@ import ( ) // Table is an IPv4 and IPv6 routing table. The zero value is ready to use. -type Table struct { +type Table[V any] struct { // make a treap for every IP version, the bits of the prefix are part of the weighted priority - root4 *node - root6 *node + root4 *node[V] + root6 *node[V] } // node is the recursive data structure of the treap. -type node struct { - maxUpper *node // augment the treap, see also recalc() - left *node - right *node - value any +type node[V any] struct { + maxUpper *node[V] // augment the treap, see also recalc() + left *node[V] + right *node[V] + value V cidr netip.Prefix prio uint64 } -// Insert adds pfx to the table with value val, changing the original table. -// If pfx is already present in the table, its value is set to val. -func (t *Table) Insert(pfx netip.Prefix, val any) { +// Lookup returns the longest-prefix-match (lpm) for given ip. +// If the ip isn't covered by any CIDR, the zero value and false is returned. +// +// Lookup does not allocate memory. +func (t Table[V]) Lookup(ip netip.Addr) (lpm netip.Prefix, value V, ok bool) { + if ip.Is4() { + // don't return the depth + lpm, value, ok, _ = t.root4.lpmIP(ip, 0) + return + } + // don't return the depth + lpm, value, ok, _ = t.root6.lpmIP(ip, 0) + return +} + +// LookupPrefix returns the longest-prefix-match (lpm) for given prefix. +// If the prefix isn't equal or covered by any CIDR in the table, the zero value and false is returned. +// +// LookupPrefix does not allocate memory. +func (t Table[V]) LookupPrefix(pfx netip.Prefix) (lpm netip.Prefix, value V, ok bool) { + pfx = pfx.Masked() // always canonicalize! + + if pfx.Addr().Is4() { + // don't return the depth + lpm, value, ok, _ = t.root4.lpmCIDR(pfx, 0) + return + } + // don't return the depth + lpm, value, ok, _ = t.root6.lpmCIDR(pfx, 0) + return +} + +// Insert adds pfx to the routing table with value of generic type V. +// If pfx is already present in the table, its value is set to the new value. +func (t *Table[V]) Insert(pfx netip.Prefix, value V) { pfx = pfx.Masked() // always canonicalize! if pfx.Addr().Is4() { - t.root4 = t.root4.insert(makeNode(pfx, val), false) + t.root4 = t.root4.insert(makeNode(pfx, value), false) return } - t.root6 = t.root6.insert(makeNode(pfx, val), false) + t.root6 = t.root6.insert(makeNode(pfx, value), false) } -// InsertImmutable adds pfx to the table with value val, returning a new table. -// If pfx is already present in the table, its value is set to val. -func (t Table) InsertImmutable(pfx netip.Prefix, val any) *Table { +// InsertImmutable adds pfx to the table with value of generic type V, returning a new table. +// If pfx is already present in the table, its value is set to the new value. +func (t Table[V]) InsertImmutable(pfx netip.Prefix, value V) *Table[V] { pfx = pfx.Masked() // always canonicalize! if pfx.Addr().Is4() { - t.root4 = t.root4.insert(makeNode(pfx, val), true) + t.root4 = t.root4.insert(makeNode(pfx, value), true) return &t } - t.root6 = t.root6.insert(makeNode(pfx, val), true) + t.root6 = t.root6.insert(makeNode(pfx, value), true) + return &t +} + +// Delete removes the prefix from table, returns true if it exists, false otherwise. +func (t *Table[V]) Delete(pfx netip.Prefix) bool { + pfx = pfx.Masked() // always canonicalize! + + is4 := pfx.Addr().Is4() + + n := t.root6 + if is4 { + n = t.root4 + } + + // split/join is set to mutable + l, m, r := n.split(pfx, false) + n = l.join(r, false) + + if is4 { + t.root4 = n + } else { + t.root6 = n + } + + return m != nil +} + +// DeleteImmutable removes the prefix if it exists, returns the new table and true, false if not found. +func (t Table[V]) DeleteImmutable(pfx netip.Prefix) (*Table[V], bool) { + pfx = pfx.Masked() // always canonicalize! + + is4 := pfx.Addr().Is4() + + n := t.root6 + if is4 { + n = t.root4 + } + + // split/join is set to immutable + l, m, r := n.split(pfx, true) + n = l.join(r, true) + + if is4 { + t.root4 = n + } else { + t.root6 = n + } + + ok := m != nil + return &t, ok +} + +// Clone, deep cloning of the routing table. +func (t Table[V]) Clone() *Table[V] { + t.root4 = t.root4.clone() + t.root6 = t.root6.clone() + return &t +} + +// Union combines two tables, changing the receiver table. +// If there are duplicate entries, the value is taken from the other table. +func (t *Table[V]) Union(other Table[V]) { + t.root4 = t.root4.union(other.root4, true, false) + t.root6 = t.root6.union(other.root6, true, false) +} + +// UnionImmutable combines any two tables immutable and returns the combined table. +// If there are duplicate entries, the value is taken from the other table. +func (t Table[V]) UnionImmutable(other Table[V]) *Table[V] { + t.root4 = t.root4.union(other.root4, true, true) + t.root6 = t.root6.union(other.root6, true, true) return &t } +// Walk iterates the cidrtree in ascending order. +// The callback function is called with the prefix and value of the respective node and the depth in the tree. +// If callback returns `false`, the iteration is aborted. +func (t Table[V]) Walk(cb func(pfx netip.Prefix, value V) bool) { + if !t.root4.walk(cb) { + return + } + + t.root6.walk(cb) +} + // insert into treap, changing nodes are copied, new treap is returned, // old treap is modified if immutable is false. // If node is already present in the table, its value is set to val. -func (n *node) insert(m *node, immutable bool) *node { +func (n *node[V]) insert(m *node[V], immutable bool) *node[V] { if n == nil { // recursion stop condition return m @@ -129,73 +243,9 @@ func (n *node) insert(m *node, immutable bool) *node { return n } -// DeleteImmutable removes the prefix if it exists, returns the new table and true, false if not found. -func (t Table) DeleteImmutable(pfx netip.Prefix) (*Table, bool) { - pfx = pfx.Masked() // always canonicalize! - - is4 := pfx.Addr().Is4() - - n := t.root6 - if is4 { - n = t.root4 - } - - // split/join must be immutable - l, m, r := n.split(pfx, true) - n = l.join(r, true) - - if is4 { - t.root4 = n - } else { - t.root6 = n - } - - ok := m != nil - return &t, ok -} - -// Delete removes the prefix from table, returns true if it exists, false otherwise. -func (t *Table) Delete(pfx netip.Prefix) bool { - pfx = pfx.Masked() // always canonicalize! - - is4 := pfx.Addr().Is4() - - n := t.root6 - if is4 { - n = t.root4 - } - - // split/join is mutable - l, m, r := n.split(pfx, false) - n = l.join(r, false) - - if is4 { - t.root4 = n - } else { - t.root6 = n - } - - return m != nil -} - -// UnionImmutable combines any two tables immutable and returns the combined table. -// If there are duplicate entries, the value is taken from the other table. -func (t Table) UnionImmutable(other Table) *Table { - t.root4 = t.root4.union(other.root4, true, true) - t.root6 = t.root6.union(other.root6, true, true) - return &t -} - -// Union combines two tables, changing the receiver table. -// If there are duplicate entries, the value is taken from the other table. -func (t *Table) Union(other Table) { - t.root4 = t.root4.union(other.root4, true, false) - t.root6 = t.root6.union(other.root6, true, false) -} - // union two treaps. // flag overwrite isn't public but needed as input for rec-descent calls, see below when trepa are swapped. -func (n *node) union(b *node, overwrite bool, immutable bool) *node { +func (n *node[V]) union(b *node[V], overwrite bool, immutable bool) *node[V] { // recursion stop condition if n == nil { return b @@ -234,19 +284,8 @@ func (n *node) union(b *node, overwrite bool, immutable bool) *node { return n } -// Walk iterates the cidrtree in ascending order. -// The callback function is called with the prefix and value of the respective node and the depth in the tree. -// If callback returns `false`, the iteration is aborted. -func (t Table) Walk(cb func(pfx netip.Prefix, val any) bool) { - if !t.root4.walk(cb) { - return - } - - t.root6.walk(cb) -} - // walk tree in ascending prefix order. -func (n *node) walk(cb func(netip.Prefix, any) bool) bool { +func (n *node[V]) walk(cb func(netip.Prefix, V) bool) bool { if n == nil { return true } @@ -269,23 +308,8 @@ func (n *node) walk(cb func(netip.Prefix, any) bool) bool { return true } -// Lookup returns the longest-prefix-match (lpm) for given ip. -// If the ip isn't covered by any CIDR, the zero value and false is returned. -// -// Lookup does not allocate memory. -func (t Table) Lookup(ip netip.Addr) (lpm netip.Prefix, value any, ok bool) { - if ip.Is4() { - // don't return the depth - lpm, value, ok, _ = t.root4.lpmIP(ip, 0) - return - } - // don't return the depth - lpm, value, ok, _ = t.root6.lpmIP(ip, 0) - return -} - // lpmIP rec-descent -func (n *node) lpmIP(ip netip.Addr, depth int) (lpm netip.Prefix, value any, ok bool, atDepth int) { +func (n *node[V]) lpmIP(ip netip.Addr, depth int) (lpm netip.Prefix, value V, ok bool, atDepth int) { for { // recursion stop condition if n == nil { @@ -322,25 +346,8 @@ func (n *node) lpmIP(ip netip.Addr, depth int) (lpm netip.Prefix, value any, ok return n.left.lpmIP(ip, depth+1) } -// LookupPrefix returns the longest-prefix-match (lpm) for given prefix. -// If the prefix isn't equal or covered by any CIDR in the table, the zero value and false is returned. -// -// LookupPrefix does not allocate memory. -func (t Table) LookupPrefix(pfx netip.Prefix) (lpm netip.Prefix, value any, ok bool) { - pfx = pfx.Masked() // always canonicalize! - - if pfx.Addr().Is4() { - // don't return the depth - lpm, value, ok, _ = t.root4.lpmCIDR(pfx, 0) - return - } - // don't return the depth - lpm, value, ok, _ = t.root6.lpmCIDR(pfx, 0) - return -} - // lpmCIDR rec-descent -func (n *node) lpmCIDR(pfx netip.Prefix, depth int) (lpm netip.Prefix, value any, ok bool, atDepth int) { +func (n *node[V]) lpmCIDR(pfx netip.Prefix, depth int) (lpm netip.Prefix, value V, ok bool, atDepth int) { for { // recursion stop condition if n == nil { @@ -392,14 +399,7 @@ func (n *node) lpmCIDR(pfx netip.Prefix, depth int) (lpm netip.Prefix, value any return n.left.lpmCIDR(pfx, depth+1) } -// Clone, deep cloning of the routing table. -func (t Table) Clone() *Table { - t.root4 = t.root4.clone() - t.root6 = t.root6.clone() - return &t -} - -func (n *node) clone() *node { +func (n *node[V]) clone() *node[V] { if n == nil { return n } @@ -421,7 +421,7 @@ func (n *node) clone() *node { // and greater-than the provided cidr (BST key). The resulting nodes are // properly formed treaps or nil. // If the split must be immutable, first copy concerned nodes. -func (n *node) split(cidr netip.Prefix, immutable bool) (left, mid, right *node) { +func (n *node[V]) split(cidr netip.Prefix, immutable bool) (left, mid, right *node[V]) { // recursion stop condition if n == nil { return nil, nil, nil @@ -472,7 +472,7 @@ func (n *node) split(cidr netip.Prefix, immutable bool) (left, mid, right *node) // join combines two disjunct treaps. All nodes in treap n have keys <= that of treap m // for this algorithm to work correctly. If the join must be immutable, first copy concerned nodes. -func (n *node) join(m *node, immutable bool) *node { +func (n *node[V]) join(m *node[V], immutable bool) *node[V] { // recursion stop condition if n == nil { return m @@ -511,17 +511,17 @@ func (n *node) join(m *node, immutable bool) *node { // ########################################################### // makeNode, create new node with cidr. -func makeNode(pfx netip.Prefix, val any) *node { - n := new(node) +func makeNode[V any](pfx netip.Prefix, value V) *node[V] { + n := new(node[V]) n.cidr = pfx.Masked() // always store the prefix in normalized form - n.value = val + n.value = value n.prio = mrand.Uint64() n.recalc() // init the augmented field with recalc return n } // copyNode, make a shallow copy of the pointers and the cidr. -func (n *node) copyNode() *node { +func (n *node[V]) copyNode() *node[V] { c := *n return &c } @@ -529,7 +529,7 @@ func (n *node) copyNode() *node { // recalc the augmented fields in treap node after each creation/modification // with values in descendants. // Only one level deeper must be considered. The treap datastructure is very easy to augment. -func (n *node) recalc() { +func (n *node[V]) recalc() { if n == nil { return } diff --git a/treap_test.go b/treap_test.go index 4fca344..1dc09f0 100644 --- a/treap_test.go +++ b/treap_test.go @@ -44,7 +44,7 @@ var routes = makeRoutes(routesStr) func makeRoutes(rs []routeStr) []route { var routes []route for _, s := range rs { - routes = append(routes, route{prfx(s.cidr), addr(s.nextHop)}) + routes = append(routes, route{mustPfx(s.cidr), mustAddr(s.nextHop)}) } return routes } @@ -92,7 +92,7 @@ func TestZeroValue(t *testing.T) { var zeroIP netip.Addr var zeroCIDR netip.Prefix - var zeroTable cidrtree.Table + var zeroTable cidrtree.Table[any] if zeroTable.String() != "" { t.Errorf("String() = %v, want \"\"", "") @@ -129,7 +129,7 @@ func TestZeroValue(t *testing.T) { func TestInsertImmutable(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl = rtbl.InsertImmutable(route.cidr, route.nextHop) @@ -142,7 +142,7 @@ func TestInsertImmutable(t *testing.T) { func TestDupInsert(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) @@ -183,7 +183,7 @@ func TestDupInsert(t *testing.T) { func TestInsert(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) @@ -197,7 +197,7 @@ func TestInsert(t *testing.T) { func TestImmutable(t *testing.T) { t.Parallel() - rtbl1 := new(cidrtree.Table) + rtbl1 := new(cidrtree.Table[any]) for _, route := range routes { rtbl1.Insert(route.cidr, route.nextHop) } @@ -235,7 +235,7 @@ func TestImmutable(t *testing.T) { } func TestMutable(t *testing.T) { - rtbl1 := new(cidrtree.Table) + rtbl1 := new(cidrtree.Table[any]) for _, route := range routes { rtbl1.Insert(route.cidr, route.nextHop) } @@ -253,13 +253,13 @@ func TestMutable(t *testing.T) { } // reset table1, table2 - rtbl1 = new(cidrtree.Table) + rtbl1 = new(cidrtree.Table[any]) for _, route := range routes { rtbl1.Insert(route.cidr, route.nextHop) } rtbl2 = rtbl1.Clone() - probe = route{cidr: prfx("1.2.3.4/17")} + probe = route{cidr: mustPfx("1.2.3.4/17")} rtbl1.Insert(probe.cidr, probe.nextHop) if reflect.DeepEqual(rtbl1, rtbl2) { @@ -274,7 +274,7 @@ func TestMutable(t *testing.T) { func TestDeleteImmutable(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) } @@ -299,7 +299,7 @@ func TestDeleteImmutable(t *testing.T) { func TestDelete(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) } @@ -322,7 +322,7 @@ func TestDelete(t *testing.T) { func TestLookupIP(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) } @@ -334,39 +334,39 @@ func TestLookupIP(t *testing.T) { wantOK bool }{ { - ip: addr("10.0.1.17"), - want: prfx("10.0.1.0/24"), - want2: addr("203.0.113.0"), + ip: mustAddr("10.0.1.17"), + want: mustPfx("10.0.1.0/24"), + want2: mustAddr("203.0.113.0"), wantOK: true, }, { - ip: addr("10.2.3.4"), - want: prfx("10.0.0.0/8"), - want2: addr("203.0.113.0"), + ip: mustAddr("10.2.3.4"), + want: mustPfx("10.0.0.0/8"), + want2: mustAddr("203.0.113.0"), wantOK: true, }, { - ip: addr("12.0.0.0"), + ip: mustAddr("12.0.0.0"), want: netip.Prefix{}, want2: netip.Addr{}, wantOK: false, }, { - ip: addr("127.0.0.255"), - want: prfx("127.0.0.0/8"), - want2: addr("203.0.113.0"), + ip: mustAddr("127.0.0.255"), + want: mustPfx("127.0.0.0/8"), + want2: mustAddr("203.0.113.0"), wantOK: true, }, { - ip: addr("::2"), - want: prfx("::/0"), - want2: addr("2001:db8::1"), + ip: mustAddr("::2"), + want: mustPfx("::/0"), + want2: mustAddr("2001:db8::1"), wantOK: true, }, { - ip: addr("2001:db8:affe:cafe::dead:beef"), - want: prfx("2001:db8::/32"), - want2: addr("2001:db8::1"), + ip: mustAddr("2001:db8:affe:cafe::dead:beef"), + want: mustPfx("2001:db8::/32"), + want2: mustAddr("2001:db8::1"), wantOK: true, }, } @@ -377,12 +377,12 @@ func TestLookupIP(t *testing.T) { } } - prefix := prfx("10.0.0.0/8") + prefix := mustPfx("10.0.0.0/8") if ok := rtbl.Delete(prefix); !ok { t.Errorf("Delete(%v) = %v, want %v", prefix, ok, true) } - ip := addr("1.2.3.4") + ip := mustAddr("1.2.3.4") want := netip.Prefix{} want2 := any(nil) @@ -390,12 +390,12 @@ func TestLookupIP(t *testing.T) { t.Errorf("Lookup(%v) = %v, %v, %v, want %v, %v, %v", ip, got, got2, ok, want, want2, false) } - prefix = prfx("::/0") + prefix = mustPfx("::/0") if ok := rtbl.Delete(prefix); !ok { t.Errorf("Delete(%v) = %v, want %v", prefix, ok, true) } - ip = addr("::2") + ip = mustAddr("::2") want = netip.Prefix{} want2 = any(nil) @@ -406,7 +406,7 @@ func TestLookupIP(t *testing.T) { // ########################################## tc := shuffleFullTable(100_000) - rtbl2 := new(cidrtree.Table) + rtbl2 := new(cidrtree.Table[any]) for _, cidr := range tc { rtbl2.Insert(cidr, nil) } @@ -428,7 +428,7 @@ func TestLookupIP(t *testing.T) { func TestLookupCIDR(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) } @@ -440,39 +440,39 @@ func TestLookupCIDR(t *testing.T) { wantOK bool }{ { - cidr: prfx("10.0.1.0/29"), - wantCIDR: prfx("10.0.1.0/24"), - wantValue: addr("203.0.113.0"), + cidr: mustPfx("10.0.1.0/29"), + wantCIDR: mustPfx("10.0.1.0/24"), + wantValue: mustAddr("203.0.113.0"), wantOK: true, }, { - cidr: prfx("10.2.0.0/16"), - wantCIDR: prfx("10.0.0.0/8"), - wantValue: addr("203.0.113.0"), + cidr: mustPfx("10.2.0.0/16"), + wantCIDR: mustPfx("10.0.0.0/8"), + wantValue: mustAddr("203.0.113.0"), wantOK: true, }, { - cidr: prfx("12.0.0.0/8"), + cidr: mustPfx("12.0.0.0/8"), wantCIDR: netip.Prefix{}, wantValue: netip.Addr{}, wantOK: false, }, { - cidr: prfx("127.0.0.2/32"), - wantCIDR: prfx("127.0.0.0/8"), - wantValue: addr("203.0.113.0"), + cidr: mustPfx("127.0.0.2/32"), + wantCIDR: mustPfx("127.0.0.0/8"), + wantValue: mustAddr("203.0.113.0"), wantOK: true, }, { - cidr: prfx("::2/96"), - wantCIDR: prfx("::/0"), - wantValue: addr("2001:db8::1"), + cidr: mustPfx("::2/96"), + wantCIDR: mustPfx("::/0"), + wantValue: mustAddr("2001:db8::1"), wantOK: true, }, { - cidr: prfx("2001:db8:affe:cafe:dead:beef::/96"), - wantCIDR: prfx("2001:db8::/32"), - wantValue: addr("2001:db8::1"), + cidr: mustPfx("2001:db8:affe:cafe:dead:beef::/96"), + wantCIDR: mustPfx("2001:db8::/32"), + wantValue: mustAddr("2001:db8::1"), wantOK: true, }, } @@ -483,12 +483,12 @@ func TestLookupCIDR(t *testing.T) { } } - prefix := prfx("10.0.0.0/8") + prefix := mustPfx("10.0.0.0/8") if ok := rtbl.Delete(prefix); !ok { t.Errorf("Delete(%v) = %v, want %v", prefix, ok, true) } - cidr := prfx("10.2.0.0/16") + cidr := mustPfx("10.2.0.0/16") wantCIDR := netip.Prefix{} wantValue := any(nil) @@ -496,12 +496,12 @@ func TestLookupCIDR(t *testing.T) { t.Errorf("LookupCIDR(%v) = %v, %v, %v, want %v, %v, %v", cidr, got, got2, ok, wantCIDR, wantValue, false) } - prefix = prfx("::/0") + prefix = mustPfx("::/0") if ok := rtbl.Delete(prefix); !ok { t.Errorf("Delete(%v) = %v, want %v", prefix, ok, true) } - cidr = prfx("::2/96") + cidr = mustPfx("::2/96") wantCIDR = netip.Prefix{} wantValue = any(nil) @@ -513,7 +513,7 @@ func TestLookupCIDR(t *testing.T) { tc := shuffleFullTable(100_000) - rtbl2 := new(cidrtree.Table) + rtbl2 := new(cidrtree.Table[any]) for _, cidr := range tc { rtbl2.Insert(cidr, nil) } @@ -526,8 +526,8 @@ func TestLookupCIDR(t *testing.T) { func TestUnion(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) - rtbl2 := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) + rtbl2 := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) rtbl2.Insert(route.cidr, route.nextHop) @@ -539,7 +539,7 @@ func TestUnion(t *testing.T) { } clone := rtbl.Clone() - rtbl.Union(cidrtree.Table{}) + rtbl.Union(cidrtree.Table[any]{}) if !reflect.DeepEqual(rtbl, clone) { t.Fatal("UnionMutable with zero value changed original") } @@ -552,8 +552,8 @@ func TestUnion(t *testing.T) { func TestUnionDupe(t *testing.T) { t.Parallel() - rtbl1 := new(cidrtree.Table) - rtbl2 := new(cidrtree.Table) + rtbl1 := new(cidrtree.Table[any]) + rtbl2 := new(cidrtree.Table[any]) for _, cidr := range shuffleFullTable(100_000) { rtbl1.Insert(cidr, 1) // dupe cidr with different value @@ -582,7 +582,7 @@ func TestUnionDupe(t *testing.T) { func TestFprint(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) } @@ -599,7 +599,7 @@ func TestFprint(t *testing.T) { func TestWalk(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) } @@ -618,7 +618,7 @@ func TestWalk(t *testing.T) { func TestWalkStartStop(t *testing.T) { t.Parallel() - rtbl := new(cidrtree.Table) + rtbl := new(cidrtree.Table[any]) for _, route := range routes { rtbl.Insert(route.cidr, route.nextHop) } @@ -629,7 +629,7 @@ func TestWalkStartStop(t *testing.T) { // skip return true } - if pfx == prfx("fc00::/7") { + if pfx == mustPfx("fc00::/7") { // stop return false } diff --git a/whitebox_test.go b/whitebox_test.go index c2f0644..df6eae3 100644 --- a/whitebox_test.go +++ b/whitebox_test.go @@ -12,8 +12,8 @@ import ( "testing" ) -func TestFprintBST(t *testing.T) { - rtbl := new(Table) +func TestFprintBSTVerbose(t *testing.T) { + rtbl := new(Table[any]) for i := 1; i <= 48; i++ { rtbl.Insert(randPfx4(), nil) rtbl.Insert(randPfx6(), nil) @@ -37,9 +37,9 @@ func TestFprintBST(t *testing.T) { t.Log(w.String()) } -func TestStatisticsRandom(t *testing.T) { +func TestStatisticsRandomVerbose(t *testing.T) { for i := 10; i <= 100_000; i *= 10 { - rtbl := new(Table) + rtbl := new(Table[any]) for c := 0; c <= i; c++ { rtbl.Insert(randPfx(), nil) } @@ -56,8 +56,8 @@ func TestStatisticsRandom(t *testing.T) { } } -func TestStatisticsFullTable(t *testing.T) { - rtbl := new(Table) +func TestStatisticsFullTableVerbose(t *testing.T) { + rtbl := new(Table[any]) for _, cidr := range fullTable { rtbl.Insert(cidr, nil) } @@ -72,7 +72,7 @@ func TestStatisticsFullTable(t *testing.T) { t.Logf("FullTable: size: %10d, maxDepth: %4d, average: %3.2f, deviation: %3.2f", size, maxDepth, average, deviation) } -func TestLPMRandom(t *testing.T) { +func TestLPMRandomVerbose(t *testing.T) { var size int var depth int var maxDepth int @@ -80,7 +80,7 @@ func TestLPMRandom(t *testing.T) { var lpm netip.Prefix for i := 10; i <= 100_000; i *= 10 { - rtbl := new(Table) + rtbl := new(Table[any]) for c := 0; c <= i; c++ { rtbl.Insert(randPfx(), nil) } @@ -97,7 +97,7 @@ func TestLPMRandom(t *testing.T) { } } -func TestLPMFullTableWithDefaultRoutes(t *testing.T) { +func TestLPMFullTableWithDefaultRoutesVerbose(t *testing.T) { var size int var depth int var maxDepth int @@ -106,7 +106,7 @@ func TestLPMFullTableWithDefaultRoutes(t *testing.T) { var addr netip.Addr var lpm netip.Prefix - rtbl := new(Table) + rtbl := new(Table[any]) for _, cidr := range fullTable { rtbl.Insert(cidr, nil) }