Refactor: DomainTrie use generics
This commit is contained in:
@ -17,8 +17,8 @@ var ErrInvalidDomain = errors.New("invalid domain")
|
||||
|
||||
// DomainTrie contains the main logic for adding and searching nodes for domain segments.
|
||||
// support wildcard domain (e.g *.google.com)
|
||||
type DomainTrie struct {
|
||||
root *Node
|
||||
type DomainTrie[T comparable] struct {
|
||||
root *Node[T]
|
||||
}
|
||||
|
||||
func ValidAndSplitDomain(domain string) ([]string, bool) {
|
||||
@ -51,7 +51,7 @@ func ValidAndSplitDomain(domain string) ([]string, bool) {
|
||||
// 3. subdomain.*.example.com
|
||||
// 4. .example.com
|
||||
// 5. +.example.com
|
||||
func (t *DomainTrie) Insert(domain string, data any) error {
|
||||
func (t *DomainTrie[T]) Insert(domain string, data T) error {
|
||||
parts, valid := ValidAndSplitDomain(domain)
|
||||
if !valid {
|
||||
return ErrInvalidDomain
|
||||
@ -68,13 +68,13 @@ func (t *DomainTrie) Insert(domain string, data any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DomainTrie) insert(parts []string, data any) {
|
||||
func (t *DomainTrie[T]) insert(parts []string, data T) {
|
||||
node := t.root
|
||||
// reverse storage domain part to save space
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
part := parts[i]
|
||||
if !node.hasChild(part) {
|
||||
node.addChild(part, newNode(nil))
|
||||
node.addChild(part, newNode(getZero[T]()))
|
||||
}
|
||||
|
||||
node = node.getChild(part)
|
||||
@ -88,7 +88,7 @@ func (t *DomainTrie) insert(parts []string, data any) {
|
||||
// 1. static part
|
||||
// 2. wildcard domain
|
||||
// 2. dot wildcard domain
|
||||
func (t *DomainTrie) Search(domain string) *Node {
|
||||
func (t *DomainTrie[T]) Search(domain string) *Node[T] {
|
||||
parts, valid := ValidAndSplitDomain(domain)
|
||||
if !valid || parts[0] == "" {
|
||||
return nil
|
||||
@ -96,26 +96,26 @@ func (t *DomainTrie) Search(domain string) *Node {
|
||||
|
||||
n := t.search(t.root, parts)
|
||||
|
||||
if n == nil || n.Data == nil {
|
||||
if n == nil || n.Data == getZero[T]() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (t *DomainTrie) search(node *Node, parts []string) *Node {
|
||||
func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
|
||||
if len(parts) == 0 {
|
||||
return node
|
||||
}
|
||||
|
||||
if c := node.getChild(parts[len(parts)-1]); c != nil {
|
||||
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil {
|
||||
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() {
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
if c := node.getChild(wildcard); c != nil {
|
||||
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil {
|
||||
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() {
|
||||
return n
|
||||
}
|
||||
}
|
||||
@ -124,6 +124,6 @@ func (t *DomainTrie) search(node *Node, parts []string) *Node {
|
||||
}
|
||||
|
||||
// New returns a new, empty Trie.
|
||||
func New() *DomainTrie {
|
||||
return &DomainTrie{root: newNode(nil)}
|
||||
func New[T comparable]() *DomainTrie[T] {
|
||||
return &DomainTrie[T]{root: newNode[T](getZero[T]())}
|
||||
}
|
||||
|
@ -1,16 +1,16 @@
|
||||
package trie
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var localIP = net.IP{127, 0, 0, 1}
|
||||
var localIP = netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
|
||||
func TestTrie_Basic(t *testing.T) {
|
||||
tree := New()
|
||||
tree := New[netip.Addr]()
|
||||
domains := []string{
|
||||
"example.com",
|
||||
"google.com",
|
||||
@ -23,7 +23,7 @@ func TestTrie_Basic(t *testing.T) {
|
||||
|
||||
node := tree.Search("example.com")
|
||||
assert.NotNil(t, node)
|
||||
assert.True(t, node.Data.(net.IP).Equal(localIP))
|
||||
assert.True(t, node.Data == localIP)
|
||||
assert.NotNil(t, tree.Insert("", localIP))
|
||||
assert.Nil(t, tree.Search(""))
|
||||
assert.NotNil(t, tree.Search("localhost"))
|
||||
@ -31,7 +31,7 @@ func TestTrie_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTrie_Wildcard(t *testing.T) {
|
||||
tree := New()
|
||||
tree := New[netip.Addr]()
|
||||
domains := []string{
|
||||
"*.example.com",
|
||||
"sub.*.example.com",
|
||||
@ -64,7 +64,7 @@ func TestTrie_Wildcard(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTrie_Priority(t *testing.T) {
|
||||
tree := New()
|
||||
tree := New[int]()
|
||||
domains := []string{
|
||||
".dev",
|
||||
"example.dev",
|
||||
@ -79,18 +79,18 @@ func TestTrie_Priority(t *testing.T) {
|
||||
}
|
||||
|
||||
for idx, domain := range domains {
|
||||
tree.Insert(domain, idx)
|
||||
tree.Insert(domain, idx+1)
|
||||
}
|
||||
|
||||
assertFn("test.dev", 0)
|
||||
assertFn("foo.bar.dev", 0)
|
||||
assertFn("example.dev", 1)
|
||||
assertFn("foo.example.dev", 2)
|
||||
assertFn("test.example.dev", 3)
|
||||
assertFn("test.dev", 1)
|
||||
assertFn("foo.bar.dev", 1)
|
||||
assertFn("example.dev", 2)
|
||||
assertFn("foo.example.dev", 3)
|
||||
assertFn("test.example.dev", 4)
|
||||
}
|
||||
|
||||
func TestTrie_Boundary(t *testing.T) {
|
||||
tree := New()
|
||||
tree := New[netip.Addr]()
|
||||
tree.Insert("*.dev", localIP)
|
||||
|
||||
assert.NotNil(t, tree.Insert(".", localIP))
|
||||
@ -99,7 +99,7 @@ func TestTrie_Boundary(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTrie_WildcardBoundary(t *testing.T) {
|
||||
tree := New()
|
||||
tree := New[netip.Addr]()
|
||||
tree.Insert("+.*", localIP)
|
||||
tree.Insert("stun.*.*.*", localIP)
|
||||
|
||||
|
@ -1,34 +1,31 @@
|
||||
package trie
|
||||
|
||||
// Node is the trie's node
|
||||
type Node struct {
|
||||
children map[string]*Node
|
||||
Data any
|
||||
type Node[T comparable] struct {
|
||||
children map[string]*Node[T]
|
||||
Data T
|
||||
}
|
||||
|
||||
func (n *Node) getChild(s string) *Node {
|
||||
if n.children == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Node[T]) getChild(s string) *Node[T] {
|
||||
return n.children[s]
|
||||
}
|
||||
|
||||
func (n *Node) hasChild(s string) bool {
|
||||
func (n *Node[T]) hasChild(s string) bool {
|
||||
return n.getChild(s) != nil
|
||||
}
|
||||
|
||||
func (n *Node) addChild(s string, child *Node) {
|
||||
if n.children == nil {
|
||||
n.children = map[string]*Node{}
|
||||
}
|
||||
|
||||
func (n *Node[T]) addChild(s string, child *Node[T]) {
|
||||
n.children[s] = child
|
||||
}
|
||||
|
||||
func newNode(data any) *Node {
|
||||
return &Node{
|
||||
func newNode[T comparable](data T) *Node[T] {
|
||||
return &Node[T]{
|
||||
Data: data,
|
||||
children: nil,
|
||||
children: map[string]*Node[T]{},
|
||||
}
|
||||
}
|
||||
|
||||
func getZero[T comparable]() T {
|
||||
var result T
|
||||
return result
|
||||
}
|
||||
|
Reference in New Issue
Block a user