chore: Update dependencies
This commit is contained in:
@ -1,100 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"github.com/Dreamacro/clash/transport/hysteria/utils"
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"net"
|
||||
)
|
||||
|
||||
const entryCacheSize = 1024
|
||||
|
||||
type Engine struct {
|
||||
DefaultAction Action
|
||||
Entries []Entry
|
||||
Cache *lru.ARCCache
|
||||
ResolveIPAddr func(string) (*net.IPAddr, error)
|
||||
GeoIPReader *geoip2.Reader
|
||||
}
|
||||
|
||||
type cacheKey struct {
|
||||
Host string
|
||||
Port uint16
|
||||
IsUDP bool
|
||||
}
|
||||
|
||||
type cacheValue struct {
|
||||
Action Action
|
||||
Arg string
|
||||
}
|
||||
|
||||
// action, arg, isDomain, resolvedIP, error
|
||||
func (e *Engine) ResolveAndMatch(host string, port uint16, isUDP bool) (Action, string, bool, *net.IPAddr, error) {
|
||||
ip, zone := utils.ParseIPZone(host)
|
||||
if ip == nil {
|
||||
// Domain
|
||||
ipAddr, err := e.ResolveIPAddr(host)
|
||||
if v, ok := e.Cache.Get(cacheKey{host, port, isUDP}); ok {
|
||||
// Cache hit
|
||||
ce := v.(cacheValue)
|
||||
return ce.Action, ce.Arg, true, ipAddr, err
|
||||
}
|
||||
for _, entry := range e.Entries {
|
||||
mReq := MatchRequest{
|
||||
Domain: host,
|
||||
Port: port,
|
||||
DB: e.GeoIPReader,
|
||||
}
|
||||
if ipAddr != nil {
|
||||
mReq.IP = ipAddr.IP
|
||||
}
|
||||
if isUDP {
|
||||
mReq.Protocol = ProtocolUDP
|
||||
} else {
|
||||
mReq.Protocol = ProtocolTCP
|
||||
}
|
||||
if entry.Match(mReq) {
|
||||
e.Cache.Add(cacheKey{host, port, isUDP},
|
||||
cacheValue{entry.Action, entry.ActionArg})
|
||||
return entry.Action, entry.ActionArg, true, ipAddr, err
|
||||
}
|
||||
}
|
||||
e.Cache.Add(cacheKey{host, port, isUDP}, cacheValue{e.DefaultAction, ""})
|
||||
return e.DefaultAction, "", true, ipAddr, err
|
||||
} else {
|
||||
// IP
|
||||
if v, ok := e.Cache.Get(cacheKey{ip.String(), port, isUDP}); ok {
|
||||
// Cache hit
|
||||
ce := v.(cacheValue)
|
||||
return ce.Action, ce.Arg, false, &net.IPAddr{
|
||||
IP: ip,
|
||||
Zone: zone,
|
||||
}, nil
|
||||
}
|
||||
for _, entry := range e.Entries {
|
||||
mReq := MatchRequest{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
DB: e.GeoIPReader,
|
||||
}
|
||||
if isUDP {
|
||||
mReq.Protocol = ProtocolUDP
|
||||
} else {
|
||||
mReq.Protocol = ProtocolTCP
|
||||
}
|
||||
if entry.Match(mReq) {
|
||||
e.Cache.Add(cacheKey{ip.String(), port, isUDP},
|
||||
cacheValue{entry.Action, entry.ActionArg})
|
||||
return entry.Action, entry.ActionArg, false, &net.IPAddr{
|
||||
IP: ip,
|
||||
Zone: zone,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
e.Cache.Add(cacheKey{ip.String(), port, isUDP}, cacheValue{e.DefaultAction, ""})
|
||||
return e.DefaultAction, "", false, &net.IPAddr{
|
||||
IP: ip,
|
||||
Zone: zone,
|
||||
}, nil
|
||||
}
|
||||
}
|
@ -1,154 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEngine_ResolveAndMatch(t *testing.T) {
|
||||
cache, _ := lru.NewARC(16)
|
||||
e := &Engine{
|
||||
DefaultAction: ActionDirect,
|
||||
Entries: []Entry{
|
||||
{
|
||||
Action: ActionProxy,
|
||||
ActionArg: "",
|
||||
Matcher: &domainMatcher{
|
||||
matcherBase: matcherBase{
|
||||
Protocol: ProtocolTCP,
|
||||
Port: 443,
|
||||
},
|
||||
Domain: "google.com",
|
||||
Suffix: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: ActionHijack,
|
||||
ActionArg: "good.org",
|
||||
Matcher: &domainMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Domain: "evil.corp",
|
||||
Suffix: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: ActionProxy,
|
||||
ActionArg: "",
|
||||
Matcher: &netMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Net: &net.IPNet{
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
Mask: net.CIDRMask(8, 32),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: ActionBlock,
|
||||
ActionArg: "",
|
||||
Matcher: &allMatcher{},
|
||||
},
|
||||
},
|
||||
Cache: cache,
|
||||
ResolveIPAddr: func(s string) (*net.IPAddr, error) {
|
||||
if strings.Contains(s, "evil.corp") {
|
||||
return nil, errors.New("resolve error")
|
||||
}
|
||||
return net.ResolveIPAddr("ip", s)
|
||||
},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
port uint16
|
||||
isUDP bool
|
||||
wantAction Action
|
||||
wantArg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "domain proxy",
|
||||
host: "google.com",
|
||||
port: 443,
|
||||
isUDP: false,
|
||||
wantAction: ActionProxy,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "domain block",
|
||||
host: "google.com",
|
||||
port: 80,
|
||||
isUDP: false,
|
||||
wantAction: ActionBlock,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "domain suffix 1",
|
||||
host: "evil.corp",
|
||||
port: 8899,
|
||||
isUDP: true,
|
||||
wantAction: ActionHijack,
|
||||
wantArg: "good.org",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "domain suffix 2",
|
||||
host: "notevil.corp",
|
||||
port: 22,
|
||||
isUDP: false,
|
||||
wantAction: ActionBlock,
|
||||
wantArg: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "domain suffix 3",
|
||||
host: "im.real.evil.corp",
|
||||
port: 443,
|
||||
isUDP: true,
|
||||
wantAction: ActionHijack,
|
||||
wantArg: "good.org",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ip match",
|
||||
host: "10.2.3.4",
|
||||
port: 80,
|
||||
isUDP: false,
|
||||
wantAction: ActionProxy,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "ip mismatch",
|
||||
host: "100.5.6.0",
|
||||
port: 1234,
|
||||
isUDP: false,
|
||||
wantAction: ActionBlock,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "domain proxy cache",
|
||||
host: "google.com",
|
||||
port: 443,
|
||||
isUDP: false,
|
||||
wantAction: ActionProxy,
|
||||
wantArg: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotAction, gotArg, _, _, err := e.ResolveAndMatch(tt.host, tt.port, tt.isUDP)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotAction != tt.wantAction {
|
||||
t.Errorf("ResolveAndMatch() gotAction = %v, wantAction %v", gotAction, tt.wantAction)
|
||||
}
|
||||
if gotArg != tt.wantArg {
|
||||
t.Errorf("ResolveAndMatch() gotArg = %v, wantAction %v", gotArg, tt.wantArg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,331 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Action byte
|
||||
type Protocol byte
|
||||
|
||||
const (
|
||||
ActionDirect = Action(iota)
|
||||
ActionProxy
|
||||
ActionBlock
|
||||
ActionHijack
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolAll = Protocol(iota)
|
||||
ProtocolTCP
|
||||
ProtocolUDP
|
||||
)
|
||||
|
||||
var protocolPortAliases = map[string]string{
|
||||
"echo": "*/7",
|
||||
"ftp-data": "*/20",
|
||||
"ftp": "*/21",
|
||||
"ssh": "*/22",
|
||||
"telnet": "*/23",
|
||||
"domain": "*/53",
|
||||
"dns": "*/53",
|
||||
"http": "*/80",
|
||||
"sftp": "*/115",
|
||||
"ntp": "*/123",
|
||||
"https": "*/443",
|
||||
"quic": "udp/443",
|
||||
"socks": "*/1080",
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
Action Action
|
||||
ActionArg string
|
||||
Matcher Matcher
|
||||
}
|
||||
|
||||
type MatchRequest struct {
|
||||
IP net.IP
|
||||
Domain string
|
||||
|
||||
Protocol Protocol
|
||||
Port uint16
|
||||
|
||||
DB *geoip2.Reader
|
||||
}
|
||||
|
||||
type Matcher interface {
|
||||
Match(MatchRequest) bool
|
||||
}
|
||||
|
||||
type matcherBase struct {
|
||||
Protocol Protocol
|
||||
Port uint16 // 0 for all ports
|
||||
}
|
||||
|
||||
func (m *matcherBase) MatchProtocolPort(p Protocol, port uint16) bool {
|
||||
return (m.Protocol == ProtocolAll || m.Protocol == p) && (m.Port == 0 || m.Port == port)
|
||||
}
|
||||
|
||||
func parseProtocolPort(s string) (Protocol, uint16, error) {
|
||||
if protocolPortAliases[s] != "" {
|
||||
s = protocolPortAliases[s]
|
||||
}
|
||||
if len(s) == 0 || s == "*" {
|
||||
return ProtocolAll, 0, nil
|
||||
}
|
||||
parts := strings.Split(s, "/")
|
||||
if len(parts) != 2 {
|
||||
return ProtocolAll, 0, errors.New("invalid protocol/port syntax")
|
||||
}
|
||||
protocol := ProtocolAll
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
protocol = ProtocolTCP
|
||||
case "udp":
|
||||
protocol = ProtocolUDP
|
||||
case "*":
|
||||
protocol = ProtocolAll
|
||||
default:
|
||||
return ProtocolAll, 0, errors.New("invalid protocol")
|
||||
}
|
||||
if parts[1] == "*" {
|
||||
return protocol, 0, nil
|
||||
}
|
||||
port, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return ProtocolAll, 0, errors.New("invalid port")
|
||||
}
|
||||
return protocol, uint16(port), nil
|
||||
}
|
||||
|
||||
type netMatcher struct {
|
||||
matcherBase
|
||||
Net *net.IPNet
|
||||
}
|
||||
|
||||
func (m *netMatcher) Match(r MatchRequest) bool {
|
||||
if r.IP == nil {
|
||||
return false
|
||||
}
|
||||
return m.Net.Contains(r.IP) && m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
type domainMatcher struct {
|
||||
matcherBase
|
||||
Domain string
|
||||
Suffix bool
|
||||
}
|
||||
|
||||
func (m *domainMatcher) Match(r MatchRequest) bool {
|
||||
if len(r.Domain) == 0 {
|
||||
return false
|
||||
}
|
||||
domain := strings.ToLower(r.Domain)
|
||||
return (m.Domain == domain || (m.Suffix && strings.HasSuffix(domain, "."+m.Domain))) &&
|
||||
m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
type countryMatcher struct {
|
||||
matcherBase
|
||||
Country string // ISO 3166-1 alpha-2 country code, upper case
|
||||
}
|
||||
|
||||
func (m *countryMatcher) Match(r MatchRequest) bool {
|
||||
if r.IP == nil || r.DB == nil {
|
||||
return false
|
||||
}
|
||||
c, err := r.DB.Country(r.IP)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return c.Country.IsoCode == m.Country && m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
type allMatcher struct {
|
||||
matcherBase
|
||||
}
|
||||
|
||||
func (m *allMatcher) Match(r MatchRequest) bool {
|
||||
return m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
func (e Entry) Match(r MatchRequest) bool {
|
||||
return e.Matcher.Match(r)
|
||||
}
|
||||
|
||||
func ParseEntry(s string) (Entry, error) {
|
||||
fields := strings.Fields(s)
|
||||
if len(fields) < 2 {
|
||||
return Entry{}, fmt.Errorf("expected at least 2 fields, got %d", len(fields))
|
||||
}
|
||||
e := Entry{}
|
||||
action := fields[0]
|
||||
conds := fields[1:]
|
||||
switch strings.ToLower(action) {
|
||||
case "direct":
|
||||
e.Action = ActionDirect
|
||||
case "proxy":
|
||||
e.Action = ActionProxy
|
||||
case "block":
|
||||
e.Action = ActionBlock
|
||||
case "hijack":
|
||||
if len(conds) < 2 {
|
||||
return Entry{}, fmt.Errorf("hijack requires at least 3 fields, got %d", len(fields))
|
||||
}
|
||||
e.Action = ActionHijack
|
||||
e.ActionArg = conds[len(conds)-1]
|
||||
conds = conds[:len(conds)-1]
|
||||
default:
|
||||
return Entry{}, fmt.Errorf("invalid action %s", fields[0])
|
||||
}
|
||||
m, err := condsToMatcher(conds)
|
||||
if err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
e.Matcher = m
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func condsToMatcher(conds []string) (Matcher, error) {
|
||||
if len(conds) < 1 {
|
||||
return nil, errors.New("no condition specified")
|
||||
}
|
||||
typ, args := conds[0], conds[1:]
|
||||
switch strings.ToLower(typ) {
|
||||
case "domain":
|
||||
// domain <domain> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for domain: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &domainMatcher{
|
||||
matcherBase: mb,
|
||||
Domain: args[0],
|
||||
Suffix: false,
|
||||
}, nil
|
||||
case "domain-suffix":
|
||||
// domain-suffix <domain> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for domain-suffix: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &domainMatcher{
|
||||
matcherBase: mb,
|
||||
Domain: args[0],
|
||||
Suffix: true,
|
||||
}, nil
|
||||
case "cidr":
|
||||
// cidr <cidr> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for cidr: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
_, ipNet, err := net.ParseCIDR(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &netMatcher{
|
||||
matcherBase: mb,
|
||||
Net: ipNet,
|
||||
}, nil
|
||||
case "ip":
|
||||
// ip <ip> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for ip: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
ip := net.ParseIP(args[0])
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid ip: %s", args[0])
|
||||
}
|
||||
var ipNet *net.IPNet
|
||||
if ip.To4() != nil {
|
||||
ipNet = &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(32, 32),
|
||||
}
|
||||
} else {
|
||||
ipNet = &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(128, 128),
|
||||
}
|
||||
}
|
||||
return &netMatcher{
|
||||
matcherBase: mb,
|
||||
Net: ipNet,
|
||||
}, nil
|
||||
case "country":
|
||||
// country <country> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for country: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &countryMatcher{
|
||||
matcherBase: mb,
|
||||
Country: strings.ToUpper(args[0]),
|
||||
}, nil
|
||||
case "all":
|
||||
// all <optional: protocol/port>
|
||||
if len(args) > 1 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for all: %d, expected 0 or 1", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 1 {
|
||||
protocol, port, err := parseProtocolPort(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &allMatcher{
|
||||
matcherBase: mb,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid condition type: %s", typ)
|
||||
}
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseEntry(t *testing.T) {
|
||||
_, ok3net, _ := net.ParseCIDR("8.8.8.0/24")
|
||||
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Entry
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "empty", args: args{""}, want: Entry{}, wantErr: true},
|
||||
{name: "ok 1", args: args{"direct domain-suffix google.com"},
|
||||
want: Entry{ActionDirect, "", &domainMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Domain: "google.com",
|
||||
Suffix: true,
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "ok 2", args: args{"proxy domain shithole"},
|
||||
want: Entry{ActionProxy, "", &domainMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Domain: "shithole",
|
||||
Suffix: false,
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "ok 3", args: args{"block cidr 8.8.8.0/24 */53"},
|
||||
want: Entry{ActionBlock, "", &netMatcher{
|
||||
matcherBase: matcherBase{ProtocolAll, 53},
|
||||
Net: ok3net,
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "ok 4", args: args{"hijack all udp/* udpblackhole.net"},
|
||||
want: Entry{ActionHijack, "udpblackhole.net", &allMatcher{
|
||||
matcherBase: matcherBase{ProtocolUDP, 0},
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "err 1", args: args{"what the heck"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 2", args: args{"proxy sucks ass"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 3", args: args{"block ip 999.999.999.999"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 4", args: args{"hijack domain google.com"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 5", args: args{"hijack domain google.com bing.com 123"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseEntry(tt.args.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseEntry() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseEntry() got = %v, wantAction %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user