feat: add ip-version param

This commit is contained in:
Skyxim
2022-08-28 13:41:19 +08:00
parent 42e489e199
commit 99effb051b
20 changed files with 398 additions and 216 deletions

View File

@ -5,8 +5,10 @@ import (
"errors"
"fmt"
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
"net"
"net/netip"
"strings"
"sync"
)
@ -16,6 +18,8 @@ var (
actualDualStackDialContext = dualStackDialContext
tcpConcurrent = false
DisableIPv6 = false
ErrorInvalidedNetworkStack = errors.New("invalided network stack")
ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel")
)
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
@ -32,13 +36,23 @@ func DialContext(ctx context.Context, network, address string, options ...Option
o(opt)
}
if opt.network == 4 || opt.network == 6 {
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
network = fmt.Sprintf("%s%d", network, opt.network)
}
switch network {
case "tcp4", "tcp6", "udp4", "udp6":
return actualSingleDialContext(ctx, network, address, opt)
case "tcp", "udp":
return actualDualStackDialContext(ctx, network, address, opt)
default:
return nil, errors.New("network invalid")
return nil, ErrorInvalidedNetworkStack
}
}
@ -56,10 +70,6 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio
o(cfg)
}
if DisableIPv6 {
network = "udp4"
}
lc := &net.ListenConfig{}
if cfg.interfaceName != "" {
addr, err := bindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
@ -108,7 +118,7 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
}
if DisableIPv6 && destination.Is6() {
return nil, fmt.Errorf("IPv6 is diabled, dialer cancel")
return nil, ErrorDisableIPv6
}
return dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port))
@ -230,29 +240,49 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr
ip netip.Addr
net.Conn
error
resolved bool
isPrimary bool
done bool
}
preferCount := atomic.NewInt32(0)
results := make(chan dialResult)
tcpRacer := func(ctx context.Context, ip netip.Addr) {
result := dialResult{ip: ip}
result := dialResult{ip: ip, done: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
result.Conn.Close()
_ = result.Conn.Close()
}
}
}()
v := "4"
if ip.Is6() {
v = "6"
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
result.Conn, result.error = dialContext(ctx, network+v, ip, port, opt)
if ip.Is6() {
network += "6"
if opt.prefer != 4 {
result.isPrimary = true
}
}
if ip.Is4() {
network += "4"
if opt.prefer != 6 {
result.isPrimary = true
}
}
if result.isPrimary {
preferCount.Add(1)
}
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
}
for _, ip := range ips {
@ -260,13 +290,28 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr
}
connCount := len(ips)
var fallback dialResult
for i := 0; i < connCount; i++ {
select {
case res := <-results:
if res.error == nil {
return res.Conn, nil
if res.isPrimary {
return res.Conn, nil
} else {
fallback = res
}
} else {
if res.isPrimary {
preferCount.Add(-1)
if preferCount.Load() == 0 && fallback.done {
return fallback.Conn, nil
}
}
}
case <-ctx.Done():
if fallback.done {
return fallback.Conn, nil
}
break
}
}
@ -303,25 +348,45 @@ func singleDialContext(ctx context.Context, network string, address string, opt
}
func concurrentSingleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) {
switch network {
case "tcp4", "udp4":
return concurrentIPv4DialContext(ctx, network, address, opt)
default:
return concurrentIPv6DialContext(ctx, network, address, opt)
}
}
func concurrentIPv4DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
switch network {
case "tcp4", "udp4":
if !opt.direct {
ips, err = resolver.ResolveAllIPv4ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv4(host)
}
default:
if !opt.direct {
ips, err = resolver.ResolveAllIPv6ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv6(host)
}
if !opt.direct {
ips, err = resolver.ResolveAllIPv4ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv4(host)
}
if err != nil {
return nil, err
}
return concurrentDialContext(ctx, network, ips, port, opt)
}
func concurrentIPv6DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
if !opt.direct {
ips, err = resolver.ResolveAllIPv6ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv6(host)
}
if err != nil {

View File

@ -1,6 +1,8 @@
package dialer
import "go.uber.org/atomic"
import (
"go.uber.org/atomic"
)
var (
DefaultOptions []Option
@ -13,6 +15,8 @@ type option struct {
addrReuse bool
routingMark int
direct bool
network int
prefer int
}
type Option func(opt *option)
@ -40,3 +44,25 @@ func WithDirect() Option {
opt.direct = true
}
}
func WithPreferIPv4() Option {
return func(opt *option) {
opt.prefer = 4
}
}
func WithPreferIPv6() Option {
return func(opt *option) {
opt.prefer = 6
}
}
func WithOnlySingleStack(isIPv4 bool) Option {
return func(opt *option) {
if isIPv4 {
opt.network = 4
} else {
opt.network = 6
}
}
}

View File

@ -41,7 +41,6 @@ type Resolver interface {
ResolveIPv4(host string) (ip netip.Addr, err error)
ResolveIPv6(host string) (ip netip.Addr, err error)
ResolveAllIP(host string) (ip []netip.Addr, err error)
ResolveAllIPPrimaryIPv4(host string) (ips []netip.Addr, err error)
ResolveAllIPv4(host string) (ips []netip.Addr, err error)
ResolveAllIPv6(host string) (ips []netip.Addr, err error)
}
@ -74,10 +73,10 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) {
// ResolveIPWithResolver same as ResolveIP, but with a resolver
func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) {
if ips, err := ResolveAllIPPrimaryIPv4WithResolver(host, r); err == nil {
return ips[rand.Intn(len(ips))], nil
if ip, err := ResolveIPv4WithResolver(host, r); err == nil {
return ip, nil
} else {
return netip.Addr{}, err
return ResolveIPv6WithResolver(host, r)
}
}
@ -95,7 +94,6 @@ func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) {
return ip, nil
}
}
return ResolveIPv4(host)
}
@ -108,7 +106,6 @@ func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) {
return ip, nil
}
}
return ResolveIPv6(host)
}
@ -121,7 +118,6 @@ func ResolveProxyServerHost(host string) (netip.Addr, error) {
return ip, err
}
}
return ResolveIP(host)
}
@ -160,7 +156,6 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
return []netip.Addr{netip.AddrFrom16(*(*[16]byte)(ipAddrs[rand.Intn(len(ipAddrs))]))}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
@ -200,7 +195,6 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
return []netip.Addr{netip.AddrFrom4(*(*[4]byte)(ip))}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
@ -232,39 +226,6 @@ func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) {
return []netip.Addr{nnip.IpToAddr(ipAddr.IP)}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
func ResolveAllIPPrimaryIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
return []netip.Addr{node.Data}, nil
}
ip, err := netip.ParseAddr(host)
if err == nil {
return []netip.Addr{ip}, nil
}
if r != nil {
if DisableIPv6 {
return r.ResolveAllIPv4(host)
}
return r.ResolveAllIPPrimaryIPv4(host)
} else if DisableIPv6 {
return ResolveAllIPv4(host)
}
if DefaultResolver == nil {
ipAddr, err := net.ResolveIPAddr("ip", host)
if err != nil {
return []netip.Addr{}, err
}
return []netip.Addr{nnip.IpToAddr(ipAddr.IP)}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
@ -284,7 +245,6 @@ func ResolveAllIPv6ProxyServerHost(host string) ([]netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveAllIPv6WithResolver(host, ProxyServerHostResolver)
}
return ResolveAllIPv6(host)
}
@ -292,7 +252,6 @@ func ResolveAllIPv4ProxyServerHost(host string) ([]netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveAllIPv4WithResolver(host, ProxyServerHostResolver)
}
return ResolveAllIPv4(host)
}
@ -300,6 +259,5 @@ func ResolveAllIPProxyServerHost(host string) ([]netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveAllIPWithResolver(host, ProxyServerHostResolver)
}
return ResolveAllIP(host)
}