feat: add ip-version param

This commit is contained in:
Skyxim
2022-08-28 13:41:19 +08:00
parent 42e489e199
commit 99effb051b
20 changed files with 398 additions and 216 deletions

View File

@ -5,8 +5,10 @@ import (
"errors"
"fmt"
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
"net"
"net/netip"
"strings"
"sync"
)
@ -16,6 +18,8 @@ var (
actualDualStackDialContext = dualStackDialContext
tcpConcurrent = false
DisableIPv6 = false
ErrorInvalidedNetworkStack = errors.New("invalided network stack")
ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel")
)
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
@ -32,13 +36,23 @@ func DialContext(ctx context.Context, network, address string, options ...Option
o(opt)
}
if opt.network == 4 || opt.network == 6 {
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
network = fmt.Sprintf("%s%d", network, opt.network)
}
switch network {
case "tcp4", "tcp6", "udp4", "udp6":
return actualSingleDialContext(ctx, network, address, opt)
case "tcp", "udp":
return actualDualStackDialContext(ctx, network, address, opt)
default:
return nil, errors.New("network invalid")
return nil, ErrorInvalidedNetworkStack
}
}
@ -56,10 +70,6 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio
o(cfg)
}
if DisableIPv6 {
network = "udp4"
}
lc := &net.ListenConfig{}
if cfg.interfaceName != "" {
addr, err := bindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
@ -108,7 +118,7 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
}
if DisableIPv6 && destination.Is6() {
return nil, fmt.Errorf("IPv6 is diabled, dialer cancel")
return nil, ErrorDisableIPv6
}
return dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port))
@ -230,29 +240,49 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr
ip netip.Addr
net.Conn
error
resolved bool
isPrimary bool
done bool
}
preferCount := atomic.NewInt32(0)
results := make(chan dialResult)
tcpRacer := func(ctx context.Context, ip netip.Addr) {
result := dialResult{ip: ip}
result := dialResult{ip: ip, done: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
result.Conn.Close()
_ = result.Conn.Close()
}
}
}()
v := "4"
if ip.Is6() {
v = "6"
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
result.Conn, result.error = dialContext(ctx, network+v, ip, port, opt)
if ip.Is6() {
network += "6"
if opt.prefer != 4 {
result.isPrimary = true
}
}
if ip.Is4() {
network += "4"
if opt.prefer != 6 {
result.isPrimary = true
}
}
if result.isPrimary {
preferCount.Add(1)
}
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
}
for _, ip := range ips {
@ -260,13 +290,28 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr
}
connCount := len(ips)
var fallback dialResult
for i := 0; i < connCount; i++ {
select {
case res := <-results:
if res.error == nil {
return res.Conn, nil
if res.isPrimary {
return res.Conn, nil
} else {
fallback = res
}
} else {
if res.isPrimary {
preferCount.Add(-1)
if preferCount.Load() == 0 && fallback.done {
return fallback.Conn, nil
}
}
}
case <-ctx.Done():
if fallback.done {
return fallback.Conn, nil
}
break
}
}
@ -303,25 +348,45 @@ func singleDialContext(ctx context.Context, network string, address string, opt
}
func concurrentSingleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) {
switch network {
case "tcp4", "udp4":
return concurrentIPv4DialContext(ctx, network, address, opt)
default:
return concurrentIPv6DialContext(ctx, network, address, opt)
}
}
func concurrentIPv4DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
switch network {
case "tcp4", "udp4":
if !opt.direct {
ips, err = resolver.ResolveAllIPv4ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv4(host)
}
default:
if !opt.direct {
ips, err = resolver.ResolveAllIPv6ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv6(host)
}
if !opt.direct {
ips, err = resolver.ResolveAllIPv4ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv4(host)
}
if err != nil {
return nil, err
}
return concurrentDialContext(ctx, network, ips, port, opt)
}
func concurrentIPv6DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
if !opt.direct {
ips, err = resolver.ResolveAllIPv6ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv6(host)
}
if err != nil {

View File

@ -1,6 +1,8 @@
package dialer
import "go.uber.org/atomic"
import (
"go.uber.org/atomic"
)
var (
DefaultOptions []Option
@ -13,6 +15,8 @@ type option struct {
addrReuse bool
routingMark int
direct bool
network int
prefer int
}
type Option func(opt *option)
@ -40,3 +44,25 @@ func WithDirect() Option {
opt.direct = true
}
}
func WithPreferIPv4() Option {
return func(opt *option) {
opt.prefer = 4
}
}
func WithPreferIPv6() Option {
return func(opt *option) {
opt.prefer = 6
}
}
func WithOnlySingleStack(isIPv4 bool) Option {
return func(opt *option) {
if isIPv4 {
opt.network = 4
} else {
opt.network = 6
}
}
}