Compare commits

..

7 Commits

Author SHA1 Message Date
10d2d14938 Merge branch 'Beta' into Meta
# Conflicts:
#	rules/provider/classical_strategy.go
2022-07-02 10:41:41 +08:00
691cf1d8d6 Merge pull request #94 from bash99/Meta
Update README.md
2022-06-15 19:15:51 +08:00
d1decb8e58 Update README.md
add permissions for systemctl services
clash-dashboard change to updated one
2022-06-15 14:00:05 +08:00
7d04904109 fix: leak dns when domain in hosts list 2022-06-11 18:51:26 +08:00
a5acd3aa97 refactor: clear linkname,reduce cycle dependencies,transport init geosite function 2022-06-11 18:51:22 +08:00
eea9a12560 fix: 规则匹配默认策略组返回错误 2022-06-09 14:18:35 +08:00
0a4570b55c fix: group filter touch provider 2022-06-09 14:18:29 +08:00
274 changed files with 7905 additions and 16412 deletions

View File

@ -5,14 +5,12 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: Setup Go
uses: actions/setup-go@v3
- name: Set up Go
uses: actions/setup-go@v1
with:
go-version: '1.19'
check-latest: true
cache: true
go-version: 1.18
- name: Check out code
uses: actions/checkout@v1
- name: Build
run: make all
- name: Release

View File

@ -12,15 +12,26 @@ jobs:
Build:
runs-on: ubuntu-latest
steps:
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: Setup Go
uses: actions/setup-go@v3
- name: Cache go module
uses: actions/cache@v2
with:
go-version: '1.19'
check-latest: true
cache: true
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Test
if: ${{github.ref_name=='Beta'}}

View File

@ -7,16 +7,24 @@ jobs:
Build:
runs-on: ubuntu-latest
steps:
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: Setup Go
uses: actions/setup-go@v3
- name: Cache go module
uses: actions/cache@v2
with:
go-version: '1.19'
check-latest: true
cache: true
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Test
run: |
go test ./...

View File

@ -16,11 +16,11 @@ RUN go mod download &&\
FROM alpine:latest
LABEL org.opencontainers.image.source="https://github.com/MetaCubeX/Clash.Meta"
RUN apk add --no-cache ca-certificates tzdata iptables
RUN apk add --no-cache ca-certificates tzdata
VOLUME ["/root/.config/clash/"]
COPY --from=builder /clash-config/ /root/.config/clash/
COPY --from=builder /clash /clash
RUN chmod +x /clash
ENTRYPOINT [ "/clash" ]
ENTRYPOINT [ "/clash" ]

View File

@ -12,7 +12,7 @@ VERSION=$(shell git rev-parse --short HEAD)
endif
BUILDTIME=$(shell date -u)
GOBUILD=CGO_ENABLED=0 go build -tags with_gvisor -trimpath -ldflags '-X "github.com/Dreamacro/clash/constant.Version=$(VERSION)" \
GOBUILD=CGO_ENABLED=0 go build -trimpath -ldflags '-X "github.com/Dreamacro/clash/constant.Version=$(VERSION)" \
-X "github.com/Dreamacro/clash/constant.BuildTime=$(BUILDTIME)" \
-w -s -buildid='
@ -147,11 +147,3 @@ lint:
clean:
rm $(BINDIR)/*
CLANG ?= clang-14
CFLAGS := -O2 -g -Wall -Werror $(CFLAGS)
ebpf: export BPF_CLANG := $(CLANG)
ebpf: export BPF_CFLAGS := $(CFLAGS)
ebpf:
cd component/ebpf/ && go generate ./...

View File

@ -212,21 +212,6 @@ proxies:
grpc-service-name: grpcname
```
Support outbound transport protocol `Wireguard`
```yaml
proxies:
- name: "wg"
type: wireguard
server: 162.159.192.1
port: 2480
ip: 172.16.0.2
ipv6: fd01:5ca1:ab1e:80fa:ab85:6eea:213f:f4a5
private-key: eCtXsJZ27+4PbhDkHnB923tkUn2Gj59wZw5wFA75MnU=
public-key: Cr8hWlKvtDt7nrvf+f0brNQQzabAqrjfBvas9pmowjo=
udp: true
```
### IPTABLES configuration
Work on Linux OS who's supported `iptables`
@ -301,7 +286,6 @@ the [GitHub Wiki](https://github.com/Dreamacro/clash/wiki/use-clash-as-a-library
## Credits
* [Dreamacro/clash](https://github.com/Dreamacro/clash)
* [SagerNet/sing-box](https://github.com/SagerNet/sing-box)
* [riobard/go-shadowsocks2](https://github.com/riobard/go-shadowsocks2)
* [v2ray/v2ray-core](https://github.com/v2ray/v2ray-core)
* [WireGuard/wireguard-go](https://github.com/WireGuard/wireguard-go)

View File

@ -11,6 +11,7 @@ import (
"net/http"
"net/netip"
"net/url"
"strings"
"time"
"go.uber.org/atomic"
@ -39,6 +40,11 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
// DialContext implements C.ProxyAdapter
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata, opts...)
wasCancel := false
if err != nil {
wasCancel = strings.Contains(err.Error(), "operation was canceled")
}
p.alive.Store(err == nil || wasCancel)
return conn, err
}
@ -52,6 +58,7 @@ func (p *Proxy) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
// ListenPacketContext implements C.ProxyAdapter
func (p *Proxy) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...)
p.alive.Store(err == nil)
return pc, err
}
@ -198,9 +205,10 @@ func urlToMetadata(rawURL string) (addr C.Metadata, err error) {
}
addr = C.Metadata{
Host: u.Hostname(),
DstIP: netip.Addr{},
DstPort: port,
AddrType: C.AtypDomainName,
Host: u.Hostname(),
DstIP: netip.Addr{},
DstPort: port,
}
return
}

View File

@ -17,9 +17,5 @@ func NewHTTP(target socks5.Addr, source net.Addr, conn net.Conn) *context.ConnCo
metadata.SrcIP = ip
metadata.SrcPort = port
}
if ip, port, err := parseAddr(conn.LocalAddr().String()); err == nil {
metadata.InIP = ip
metadata.InPort = port
}
return context.NewConnContext(conn, metadata)
}

View File

@ -16,9 +16,5 @@ func NewHTTPS(request *http.Request, conn net.Conn) *context.ConnContext {
metadata.SrcIP = ip
metadata.SrcPort = port
}
if ip, port, err := parseAddr(conn.LocalAddr().String()); err == nil {
metadata.InIP = ip
metadata.InPort = port
}
return context.NewConnContext(conn, metadata)
}

View File

@ -25,12 +25,6 @@ func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type) *PacketAda
metadata.SrcIP = ip
metadata.SrcPort = port
}
if p, ok := packet.(C.UDPPacketInAddr); ok {
if ip, port, err := parseAddr(p.InAddr().String()); err == nil {
metadata.InIP = ip
metadata.InPort = port
}
}
return &PacketAdapter{
UDPPacket: packet,

View File

@ -22,14 +22,6 @@ func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *context.ConnCo
metadata.SrcPort = port
}
}
localAddr := conn.LocalAddr()
// Filter when net.Addr interface is nil
if localAddr != nil {
if ip, port, err := parseAddr(localAddr.String()); err == nil {
metadata.InIP = ip
metadata.InPort = port
}
}
return context.NewConnContext(conn, metadata)
}
@ -40,12 +32,17 @@ func NewInner(conn net.Conn, dst string, host string) *context.ConnContext {
metadata.Type = C.INNER
metadata.DNSMode = C.DNSMapping
metadata.Host = host
metadata.AddrType = C.AtypDomainName
metadata.Process = C.ClashName
if h, port, err := net.SplitHostPort(dst); err == nil {
metadata.DstPort = port
if host == "" {
if ip, err := netip.ParseAddr(h); err == nil {
metadata.DstIP = ip
metadata.AddrType = C.AtypIPv4
if ip.Is6() {
metadata.AddrType = C.AtypIPv6
}
}
}
}

View File

@ -1,19 +1,21 @@
package inbound
import (
"github.com/Dreamacro/clash/common/nnip"
"net"
"net/http"
"net/netip"
"strconv"
"strings"
"github.com/Dreamacro/clash/common/nnip"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)
func parseSocksAddr(target socks5.Addr) *C.Metadata {
metadata := &C.Metadata{}
metadata := &C.Metadata{
AddrType: int(target[0]),
}
switch target[0] {
case socks5.AtypDomainName:
@ -24,8 +26,7 @@ func parseSocksAddr(target socks5.Addr) *C.Metadata {
metadata.DstIP = nnip.IpToAddr(net.IP(target[1 : 1+net.IPv4len]))
metadata.DstPort = strconv.Itoa((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1]))
case socks5.AtypIPv6:
ip6, _ := netip.AddrFromSlice(target[1 : 1+net.IPv6len])
metadata.DstIP = ip6.Unmap()
metadata.DstIP = nnip.IpToAddr(net.IP(target[1 : 1+net.IPv6len]))
metadata.DstPort = strconv.Itoa((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1]))
}
@ -43,14 +44,21 @@ func parseHTTPAddr(request *http.Request) *C.Metadata {
host = strings.TrimRight(host, ".")
metadata := &C.Metadata{
NetWork: C.TCP,
Host: host,
DstIP: netip.Addr{},
DstPort: port,
NetWork: C.TCP,
AddrType: C.AtypDomainName,
Host: host,
DstIP: netip.Addr{},
DstPort: port,
}
ip, err := netip.ParseAddr(host)
if err == nil {
switch {
case ip.Is6():
metadata.AddrType = C.AtypIPv6
default:
metadata.AddrType = C.AtypIPv4
}
metadata.DstIP = ip
}

View File

@ -13,14 +13,13 @@ import (
)
type Base struct {
name string
addr string
iface string
tp C.AdapterType
udp bool
rmark int
id string
prefer C.DNSPrefer
name string
addr string
iface string
tp C.AdapterType
udp bool
rmark int
id string
}
// Name implements C.ProxyAdapter
@ -90,7 +89,7 @@ func (b *Base) Addr() string {
}
// Unwrap implements C.ProxyAdapter
func (b *Base) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
func (b *Base) Unwrap(metadata *C.Metadata) C.Proxy {
return nil
}
@ -104,25 +103,12 @@ func (b *Base) DialOptions(opts ...dialer.Option) []dialer.Option {
opts = append(opts, dialer.WithRoutingMark(b.rmark))
}
switch b.prefer {
case C.IPv4Only:
opts = append(opts, dialer.WithOnlySingleStack(true))
case C.IPv6Only:
opts = append(opts, dialer.WithOnlySingleStack(false))
case C.IPv4Prefer:
opts = append(opts, dialer.WithPreferIPv4())
case C.IPv6Prefer:
opts = append(opts, dialer.WithPreferIPv6())
default:
}
return opts
}
type BasicOption struct {
Interface string `proxy:"interface-name,omitempty" group:"interface-name,omitempty"`
RoutingMark int `proxy:"routing-mark,omitempty" group:"routing-mark,omitempty"`
IPVersion string `proxy:"ip-version,omitempty" group:"ip-version,omitempty"`
}
type BaseOption struct {
@ -132,18 +118,16 @@ type BaseOption struct {
UDP bool
Interface string
RoutingMark int
Prefer C.DNSPrefer
}
func NewBase(opt BaseOption) *Base {
return &Base{
name: opt.Name,
addr: opt.Addr,
tp: opt.Type,
udp: opt.UDP,
iface: opt.Interface,
rmark: opt.RoutingMark,
prefer: opt.Prefer,
name: opt.Name,
addr: opt.Addr,
tp: opt.Type,
udp: opt.UDP,
iface: opt.Interface,
rmark: opt.RoutingMark,
}
}

View File

@ -40,10 +40,9 @@ type directPacketConn struct {
func NewDirect() *Direct {
return &Direct{
Base: &Base{
name: "DIRECT",
tp: C.Direct,
udp: true,
prefer: C.DualStack,
name: "DIRECT",
tp: C.Direct,
udp: true,
},
}
}
@ -51,10 +50,9 @@ func NewDirect() *Direct {
func NewCompatible() *Direct {
return &Direct{
Base: &Base{
name: "COMPATIBLE",
tp: C.Compatible,
udp: true,
prefer: C.DualStack,
name: "COMPATIBLE",
tp: C.Compatible,
udp: true,
},
}
}

View File

@ -7,7 +7,6 @@ import (
"encoding/base64"
"errors"
"fmt"
tlsC "github.com/Dreamacro/clash/component/tls"
"io"
"net"
"net/http"
@ -36,7 +35,6 @@ type HttpOption struct {
TLS bool `proxy:"tls,omitempty"`
SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
Headers map[string]string `proxy:"headers,omitempty"`
}
@ -128,41 +126,30 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error {
return fmt.Errorf("can not connect remote err code: %d", resp.StatusCode)
}
func NewHttp(option HttpOption) (*Http, error) {
func NewHttp(option HttpOption) *Http {
var tlsConfig *tls.Config
if option.TLS {
sni := option.Server
if option.SNI != "" {
sni = option.SNI
}
if len(option.Fingerprint) == 0 {
tlsConfig = tlsC.GetGlobalFingerprintTLCConfig(&tls.Config{
InsecureSkipVerify: option.SkipCertVerify,
ServerName: sni,
})
} else {
var err error
if tlsConfig, err = tlsC.GetSpecifiedFingerprintTLSConfig(&tls.Config{
InsecureSkipVerify: option.SkipCertVerify,
ServerName: sni,
}, option.Fingerprint); err != nil {
return nil, err
}
tlsConfig = &tls.Config{
InsecureSkipVerify: option.SkipCertVerify,
ServerName: sni,
}
}
return &Http{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Http,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Http,
iface: option.Interface,
rmark: option.RoutingMark,
},
user: option.UserName,
pass: option.Password,
tlsConfig: tlsConfig,
option: &option,
}, nil
}
}

View File

@ -2,31 +2,26 @@ package outbound
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"encoding/pem"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"os"
"regexp"
"strconv"
"time"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
M "github.com/sagernet/sing/common/metadata"
"github.com/Dreamacro/clash/component/dialer"
tlsC "github.com/Dreamacro/clash/component/tls"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
hyCongestion "github.com/Dreamacro/clash/transport/hysteria/congestion"
"github.com/Dreamacro/clash/transport/hysteria/core"
"github.com/Dreamacro/clash/transport/hysteria/obfs"
"github.com/Dreamacro/clash/transport/hysteria/pmtud_fix"
"github.com/Dreamacro/clash/transport/hysteria/transport"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/pmtud_fix"
"github.com/tobyxdd/hysteria/pkg/transport"
)
const (
@ -51,15 +46,11 @@ type Hysteria struct {
func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
hdc := hyDialerWithContext{
ctx: context.Background(),
ctx: ctx,
hyDialer: func() (net.PacketConn, error) {
return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...)
},
remoteAddr: func(addr string) (net.Addr, error) {
return resolveUDPAddrWithPrefer("udp", addr, h.prefer)
},
}
tcpConn, err := h.client.DialTCP(metadata.RemoteAddress(), &hdc)
if err != nil {
return nil, err
@ -70,13 +61,10 @@ func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts .
func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
hdc := hyDialerWithContext{
ctx: context.Background(),
ctx: ctx,
hyDialer: func() (net.PacketConn, error) {
return dialer.ListenPacket(ctx, "udp", "", h.Base.DialOptions(opts...)...)
},
remoteAddr: func(addr string) (net.Addr, error) {
return resolveUDPAddrWithPrefer("udp", addr, h.prefer)
},
}
udpConn, err := h.client.DialUDP(&hdc)
if err != nil {
@ -87,27 +75,22 @@ func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata
type HysteriaOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Protocol string `proxy:"protocol,omitempty"`
ObfsProtocol string `proxy:"obfs-protocol,omitempty"` // compatible with Stash
Up string `proxy:"up"`
UpSpeed int `proxy:"up-speed,omitempty"` // compatible with Stash
Down string `proxy:"down"`
DownSpeed int `proxy:"down-speed,omitempty"` // compatible with Stash
Auth string `proxy:"auth,omitempty"`
AuthString string `proxy:"auth_str,omitempty"`
Obfs string `proxy:"obfs,omitempty"`
SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
ALPN []string `proxy:"alpn,omitempty"`
CustomCA string `proxy:"ca,omitempty"`
CustomCAString string `proxy:"ca_str,omitempty"`
ReceiveWindowConn int `proxy:"recv_window_conn,omitempty"`
ReceiveWindow int `proxy:"recv_window,omitempty"`
DisableMTUDiscovery bool `proxy:"disable_mtu_discovery,omitempty"`
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Protocol string `proxy:"protocol,omitempty"`
Up string `proxy:"up"`
Down string `proxy:"down"`
AuthString string `proxy:"auth_str,omitempty"`
Obfs string `proxy:"obfs,omitempty"`
SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
ALPN string `proxy:"alpn,omitempty"`
CustomCA string `proxy:"ca,omitempty"`
CustomCAString string `proxy:"ca_str,omitempty"`
ReceiveWindowConn int `proxy:"recv_window_conn,omitempty"`
ReceiveWindow int `proxy:"recv_window,omitempty"`
DisableMTUDiscovery bool `proxy:"disable_mtu_discovery,omitempty"`
}
func (c *HysteriaOption) Speed() (uint64, uint64, error) {
@ -137,73 +120,51 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
if option.SNI != "" {
serverName = option.SNI
}
tlsConfig := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: option.SkipCertVerify,
MinVersion: tls.VersionTLS13,
}
var bs []byte
var err error
if len(option.CustomCA) > 0 {
bs, err = os.ReadFile(option.CustomCA)
if err != nil {
return nil, fmt.Errorf("hysteria %s load ca error: %w", addr, err)
}
} else if option.CustomCAString != "" {
bs = []byte(option.CustomCAString)
}
if len(bs) > 0 {
block, _ := pem.Decode(bs)
if block == nil {
return nil, fmt.Errorf("CA cert is not PEM")
}
fpBytes := sha256.Sum256(block.Bytes)
if len(option.Fingerprint) == 0 {
option.Fingerprint = hex.EncodeToString(fpBytes[:])
}
}
if len(option.Fingerprint) != 0 {
var err error
tlsConfig, err = tlsC.GetSpecifiedFingerprintTLSConfig(tlsConfig, option.Fingerprint)
if err != nil {
return nil, err
}
} else {
tlsConfig = tlsC.GetGlobalFingerprintTLCConfig(tlsConfig)
}
if len(option.ALPN) > 0 {
tlsConfig.NextProtos = option.ALPN
tlsConfig.NextProtos = []string{option.ALPN}
} else {
tlsConfig.NextProtos = []string{DefaultALPN}
}
if len(option.CustomCA) > 0 {
bs, err := ioutil.ReadFile(option.CustomCA)
if err != nil {
return nil, fmt.Errorf("hysteria %s load ca error: %w", addr, err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(bs) {
return nil, fmt.Errorf("hysteria %s failed to parse ca_str", addr)
}
tlsConfig.RootCAs = cp
} else if option.CustomCAString != "" {
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM([]byte(option.CustomCAString)) {
return nil, fmt.Errorf("hysteria %s failed to parse ca_str", addr)
}
tlsConfig.RootCAs = cp
}
quicConfig := &quic.Config{
InitialStreamReceiveWindow: uint64(option.ReceiveWindowConn),
MaxStreamReceiveWindow: uint64(option.ReceiveWindowConn),
InitialConnectionReceiveWindow: uint64(option.ReceiveWindow),
MaxConnectionReceiveWindow: uint64(option.ReceiveWindow),
KeepAlivePeriod: 10 * time.Second,
KeepAlive: true,
DisablePathMTUDiscovery: option.DisableMTUDiscovery,
EnableDatagrams: true,
}
if option.ObfsProtocol != "" {
option.Protocol = option.ObfsProtocol
}
if option.Protocol == "" {
option.Protocol = DefaultProtocol
}
if option.ReceiveWindowConn == 0 {
quicConfig.InitialStreamReceiveWindow = DefaultStreamReceiveWindow / 10
quicConfig.InitialStreamReceiveWindow = DefaultStreamReceiveWindow
quicConfig.MaxStreamReceiveWindow = DefaultStreamReceiveWindow
}
if option.ReceiveWindow == 0 {
quicConfig.InitialConnectionReceiveWindow = DefaultConnectionReceiveWindow / 10
quicConfig.InitialConnectionReceiveWindow = DefaultConnectionReceiveWindow
quicConfig.MaxConnectionReceiveWindow = DefaultConnectionReceiveWindow
}
if !quicConfig.DisablePathMTUDiscovery && pmtud_fix.DisablePathMTUDiscovery {
@ -211,12 +172,6 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
}
var auth = []byte(option.AuthString)
if option.Auth != "" {
auth, err = base64.StdEncoding.DecodeString(option.Auth)
if err != nil {
return nil, err
}
}
var obfuscator obfs.Obfuscator
if len(option.Obfs) > 0 {
obfuscator = obfs.NewXPlusObfuscator([]byte(option.Obfs))
@ -226,12 +181,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
if err != nil {
return nil, err
}
if option.UpSpeed != 0 {
up = uint64(option.UpSpeed * mbpsToBps)
}
if option.DownSpeed != 0 {
down = uint64(option.DownSpeed * mbpsToBps)
}
client, err := core.NewClient(
addr, option.Protocol, auth, tlsConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
@ -242,13 +192,12 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
}
return &Hysteria{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Hysteria,
udp: true,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: addr,
tp: C.Hysteria,
udp: true,
iface: option.Interface,
rmark: option.RoutingMark,
},
client: client,
}, nil
@ -314,9 +263,8 @@ func (c *hyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
}
type hyDialerWithContext struct {
hyDialer func() (net.PacketConn, error)
ctx context.Context
remoteAddr func(host string) (net.Addr, error)
hyDialer func() (net.PacketConn, error)
ctx context.Context
}
func (h *hyDialerWithContext) ListenPacket() (net.PacketConn, error) {
@ -326,7 +274,3 @@ func (h *hyDialerWithContext) ListenPacket() (net.PacketConn, error) {
func (h *hyDialerWithContext) Context() context.Context {
return h.ctx
}
func (h *hyDialerWithContext) RemoteAddr(host string) (net.Addr, error) {
return h.remoteAddr(host)
}

View File

@ -27,10 +27,9 @@ func (r *Reject) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
func NewReject() *Reject {
return &Reject{
Base: &Base{
name: "REJECT",
tp: C.Reject,
udp: true,
prefer: C.DualStack,
name: "REJECT",
tp: C.Reject,
udp: true,
},
}
}
@ -38,10 +37,9 @@ func NewReject() *Reject {
func NewPass() *Reject {
return &Reject{
Base: &Base{
name: "PASS",
tp: C.Pass,
udp: true,
prefer: C.DualStack,
name: "PASS",
tp: C.Pass,
udp: true,
},
}
}

View File

@ -7,6 +7,7 @@ import (
"net"
"strconv"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/common/structure"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
@ -15,11 +16,16 @@ import (
v2rayObfs "github.com/Dreamacro/clash/transport/v2ray-plugin"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing-shadowsocks/shadowimpl"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/uot"
)
func init() {
buf.DefaultAllocator = pool.DefaultAllocator
}
type ShadowSocks struct {
*Base
method shadowsocks.Method
@ -54,7 +60,6 @@ type v2rayObfsOption struct {
Host string `obfs:"host,omitempty"`
Path string `obfs:"path,omitempty"`
TLS bool `obfs:"tls,omitempty"`
Fingerprint string `obfs:"fingerprint,omitempty"`
Headers map[string]string `obfs:"headers,omitempty"`
SkipCertVerify bool `obfs:"skip-cert-verify,omitempty"`
Mux bool `obfs:"mux,omitempty"`
@ -76,7 +81,8 @@ func (ss *ShadowSocks) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e
}
}
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
return ss.method.DialConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443"))
metadata.Host = uot.UOTMagicAddress
metadata.DstPort = "443"
}
return ss.method.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
}
@ -109,7 +115,7 @@ func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Meta
return nil, err
}
addr, err := resolveUDPAddrWithPrefer("udp", ss.addr, ss.prefer)
addr, err := resolveUDPAddr("udp", ss.addr)
if err != nil {
pc.Close()
return nil, err
@ -179,13 +185,12 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
return &ShadowSocks{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Shadowsocks,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: addr,
tp: C.Shadowsocks,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
},
method: method,

View File

@ -79,7 +79,7 @@ func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Me
return nil, err
}
addr, err := resolveUDPAddrWithPrefer("udp", ssr.addr, ssr.prefer)
addr, err := resolveUDPAddr("udp", ssr.addr)
if err != nil {
pc.Close()
return nil, err
@ -143,13 +143,12 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
return &ShadowSocksR{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.ShadowsocksR,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: addr,
tp: C.ShadowsocksR,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
},
cipher: coreCiph,
obfs: obfs,

View File

@ -152,13 +152,12 @@ func NewSnell(option SnellOption) (*Snell, error) {
s := &Snell{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Snell,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: addr,
tp: C.Snell,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
},
psk: psk,
obfsOption: obfsOption,

View File

@ -5,7 +5,6 @@ import (
"crypto/tls"
"errors"
"fmt"
tlsC "github.com/Dreamacro/clash/component/tls"
"io"
"net"
"strconv"
@ -34,7 +33,6 @@ type Socks5Option struct {
TLS bool `proxy:"tls,omitempty"`
UDP bool `proxy:"udp,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
}
// StreamConn implements C.ProxyAdapter
@ -140,40 +138,30 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
return newPacketConn(&socksPacketConn{PacketConn: pc, rAddr: bindUDPAddr, tcpConn: c}, ss), nil
}
func NewSocks5(option Socks5Option) (*Socks5, error) {
func NewSocks5(option Socks5Option) *Socks5 {
var tlsConfig *tls.Config
if option.TLS {
tlsConfig = &tls.Config{
InsecureSkipVerify: option.SkipCertVerify,
ServerName: option.Server,
}
if len(option.Fingerprint) == 0 {
tlsConfig = tlsC.GetGlobalFingerprintTLCConfig(tlsConfig)
} else {
var err error
if tlsConfig, err = tlsC.GetSpecifiedFingerprintTLSConfig(tlsConfig, option.Fingerprint); err != nil {
return nil, err
}
}
}
return &Socks5{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Socks5,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Socks5,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
},
user: option.UserName,
pass: option.Password,
tls: option.TLS,
skipCertVerify: option.SkipCertVerify,
tlsConfig: tlsConfig,
}, nil
}
}
type socksPacketConn struct {

View File

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"fmt"
tlsC "github.com/Dreamacro/clash/component/tls"
"net"
"net/http"
"strconv"
@ -36,7 +35,6 @@ type TrojanOption struct {
ALPN []string `proxy:"alpn,omitempty"`
SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
UDP bool `proxy:"udp,omitempty"`
Network string `proxy:"network,omitempty"`
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
@ -190,7 +188,6 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
ServerName: option.Server,
SkipCertVerify: option.SkipCertVerify,
FlowShow: option.FlowShow,
Fingerprint: option.Fingerprint,
}
if option.Network != "ws" && len(option.Flow) >= 16 {
@ -209,13 +206,12 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
t := &Trojan{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Trojan,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: addr,
tp: C.Trojan,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
},
instance: trojan.New(tOption),
option: &option,
@ -238,15 +234,6 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
ServerName: tOption.ServerName,
}
if len(option.Fingerprint) == 0 {
tlsConfig = tlsC.GetGlobalFingerprintTLCConfig(tlsConfig)
} else {
var err error
if tlsConfig, err = tlsC.GetSpecifiedFingerprintTLSConfig(tlsConfig, option.Fingerprint); err != nil {
return nil, err
}
}
if t.option.Flow != "" {
t.transport = gun.NewHTTP2XTLSClient(dialFn, tlsConfig)
} else {

View File

@ -5,7 +5,6 @@ import (
"crypto/tls"
xtls "github.com/xtls/go"
"net"
"net/netip"
"strconv"
"sync"
"time"
@ -44,11 +43,10 @@ func getClientXSessionCache() xtls.ClientSessionCache {
func serializesSocksAddr(metadata *C.Metadata) []byte {
var buf [][]byte
addrType := metadata.AddrType()
aType := uint8(addrType)
aType := uint8(metadata.AddrType)
p, _ := strconv.ParseUint(metadata.DstPort, 10, 16)
port := []byte{uint8(p >> 8), uint8(p & 0xff)}
switch addrType {
switch metadata.AddrType {
case socks5.AtypDomainName:
lenM := uint8(len(metadata.Host))
host := []byte(metadata.Host)
@ -76,63 +74,6 @@ func resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port))
}
func resolveUDPAddrWithPrefer(network, address string, prefer C.DNSPrefer) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ip netip.Addr
switch prefer {
case C.IPv4Only:
ip, err = resolver.ResolveIPv4ProxyServerHost(host)
case C.IPv6Only:
ip, err = resolver.ResolveIPv6ProxyServerHost(host)
case C.IPv6Prefer:
var ips []netip.Addr
ips, err = resolver.ResolveAllIPProxyServerHost(host)
var fallback netip.Addr
if err == nil {
for _, addr := range ips {
if addr.Is6() {
ip = addr
break
} else {
if !fallback.IsValid() {
fallback = addr
}
}
}
ip = fallback
}
default:
// C.IPv4Prefer, C.DualStack and other
var ips []netip.Addr
ips, err = resolver.ResolveAllIPProxyServerHost(host)
var fallback netip.Addr
if err == nil {
for _, addr := range ips {
if addr.Is4() {
ip = addr
break
} else {
if !fallback.IsValid() {
fallback = addr
}
}
}
if !ip.IsValid() && fallback.IsValid() {
ip = fallback
}
}
}
if err != nil {
return nil, err
}
return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port))
}
func safeConnClose(c net.Conn, err error) {
if err != nil {
_ = c.Close()

View File

@ -6,19 +6,17 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/Dreamacro/clash/common/convert"
"io"
"net"
"net/http"
"strconv"
"sync"
"github.com/Dreamacro/clash/common/convert"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/resolver"
tlsC "github.com/Dreamacro/clash/component/tls"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/gun"
"github.com/Dreamacro/clash/transport/socks5"
"github.com/Dreamacro/clash/transport/vless"
"github.com/Dreamacro/clash/transport/vmess"
)
@ -57,7 +55,6 @@ type VlessOption struct {
WSPath string `proxy:"ws-path,omitempty"`
WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
ServerName string `proxy:"servername,omitempty"`
}
@ -83,19 +80,12 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
}
if v.option.TLS {
wsOpts.TLS = true
tlsConfig := &tls.Config{
wsOpts.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: host,
InsecureSkipVerify: v.option.SkipCertVerify,
NextProtos: []string{"http/1.1"},
}
if len(v.option.Fingerprint) == 0 {
wsOpts.TLSConfig = tlsC.GetGlobalFingerprintTLCConfig(tlsConfig)
} else {
wsOpts.TLSConfig, err = tlsC.GetSpecifiedFingerprintTLSConfig(tlsConfig, v.option.Fingerprint)
}
if v.option.ServerName != "" {
wsOpts.TLSConfig.ServerName = v.option.ServerName
} else if host := wsOpts.Headers.Get("Host"); host != "" {
@ -162,7 +152,6 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error)
xtlsOpts := vless.XTLSConfig{
Host: host,
SkipCertVerify: v.option.SkipCertVerify,
FingerPrint: v.option.Fingerprint,
}
if isH2 {
@ -179,7 +168,6 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error)
tlsOpts := vmess.TLSConfig{
Host: host,
SkipCertVerify: v.option.SkipCertVerify,
FingerPrint: v.option.Fingerprint,
}
if isH2 {
@ -281,16 +269,16 @@ func (v *Vless) SupportUOT() bool {
func parseVlessAddr(metadata *C.Metadata) *vless.DstAddr {
var addrType byte
var addr []byte
switch metadata.AddrType() {
case socks5.AtypIPv4:
switch metadata.AddrType {
case C.AtypIPv4:
addrType = vless.AtypIPv4
addr = make([]byte, net.IPv4len)
copy(addr[:], metadata.DstIP.AsSlice())
case socks5.AtypIPv6:
case C.AtypIPv6:
addrType = vless.AtypIPv6
addr = make([]byte, net.IPv6len)
copy(addr[:], metadata.DstIP.AsSlice())
case socks5.AtypDomainName:
case C.AtypDomainName:
addrType = vless.AtypDomainName
addr = make([]byte, len(metadata.Host)+1)
addr[0] = byte(len(metadata.Host))
@ -419,12 +407,11 @@ func NewVless(option VlessOption) (*Vless, error) {
v := &Vless{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vless,
udp: option.UDP,
iface: option.Interface,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vless,
udp: option.UDP,
iface: option.Interface,
},
client: client,
option: &option,
@ -449,10 +436,10 @@ func NewVless(option VlessOption) (*Vless, error) {
ServiceName: v.option.GrpcOpts.GrpcServiceName,
Host: v.option.ServerName,
}
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(&tls.Config{
tlsConfig := &tls.Config{
InsecureSkipVerify: v.option.SkipCertVerify,
ServerName: v.option.ServerName,
})
}
if v.option.ServerName == "" {
host, _, _ := net.SplitHostPort(v.addr)

View File

@ -5,19 +5,19 @@ import (
"crypto/tls"
"errors"
"fmt"
tlsC "github.com/Dreamacro/clash/component/tls"
vmess "github.com/sagernet/sing-vmess"
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/Dreamacro/clash/common/convert"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/gun"
clashVMess "github.com/Dreamacro/clash/transport/vmess"
"github.com/sagernet/sing-vmess"
"github.com/sagernet/sing-vmess/packetaddr"
M "github.com/sagernet/sing/common/metadata"
)
@ -45,16 +45,12 @@ type VmessOption struct {
Network string `proxy:"network,omitempty"`
TLS bool `proxy:"tls,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
ServerName string `proxy:"servername,omitempty"`
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
WSOpts WSOptions `proxy:"ws-opts,omitempty"`
PacketAddr bool `proxy:"packet-addr,omitempty"`
XUDP bool `proxy:"xudp,omitempty"`
PacketEncoding string `proxy:"packet-encoding,omitempty"`
GlobalPadding bool `proxy:"global-padding,omitempty"`
AuthenticatedLength bool `proxy:"authenticated-length,omitempty"`
}
@ -104,26 +100,21 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
if v.option.TLS {
wsOpts.TLS = true
tlsConfig := &tls.Config{
wsOpts.TLSConfig = &tls.Config{
ServerName: host,
InsecureSkipVerify: v.option.SkipCertVerify,
NextProtos: []string{"http/1.1"},
}
if len(v.option.Fingerprint) == 0 {
wsOpts.TLSConfig = tlsC.GetGlobalFingerprintTLCConfig(tlsConfig)
} else {
var err error
if wsOpts.TLSConfig, err = tlsC.GetSpecifiedFingerprintTLSConfig(tlsConfig, v.option.Fingerprint); err != nil {
return nil, err
}
}
if v.option.ServerName != "" {
wsOpts.TLSConfig.ServerName = v.option.ServerName
} else if host := wsOpts.Headers.Get("Host"); host != "" {
wsOpts.TLSConfig.ServerName = host
}
} else {
if host := wsOpts.Headers.Get("Host"); host == "" {
wsOpts.Headers.Set("Host", convert.RandHost())
convert.SetUserAgent(wsOpts.Headers)
}
}
c, err = clashVMess.StreamWebsocketConn(c, wsOpts)
case "http":
@ -200,11 +191,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
return nil, err
}
if metadata.NetWork == C.UDP {
if v.option.XUDP {
return v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
} else {
return v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
}
return v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
} else {
return v.client.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
}
@ -251,8 +238,6 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
}
if v.option.PacketAddr {
_metadata := *metadata // make a copy
metadata = &_metadata
metadata.Host = packetaddr.SeqPacketMagicAddress
metadata.DstPort = "443"
}
@ -266,11 +251,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
}
defer safeConnClose(c, err)
if v.option.XUDP {
c, err = v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
} else {
c, err = v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
}
c, err = v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
} else {
c, err = dialer.DialContext(ctx, "tcp", v.addr, v.Base.DialOptions(opts...)...)
if err != nil {
@ -287,7 +268,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
}
if v.option.PacketAddr {
return newPacketConn(&threadSafePacketConn{PacketConn: packetaddr.NewBindConn(c)}, v), nil
return newPacketConn(&threadSafePacketConn{PacketConn: packetaddr.NewBindClient(c)}, v), nil
} else if pc, ok := c.(net.PacketConn); ok {
return newPacketConn(&threadSafePacketConn{PacketConn: pc}, v), nil
}
@ -297,7 +278,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
// ListenPacketOnStreamConn implements C.ProxyAdapter
func (v *Vmess) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
if v.option.PacketAddr {
return newPacketConn(&threadSafePacketConn{PacketConn: packetaddr.NewBindConn(c)}, v), nil
return newPacketConn(&threadSafePacketConn{PacketConn: packetaddr.NewBindClient(c)}, v), nil
} else if pc, ok := c.(net.PacketConn); ok {
return newPacketConn(&threadSafePacketConn{PacketConn: pc}, v), nil
}
@ -312,9 +293,6 @@ func (v *Vmess) SupportUOT() bool {
func NewVmess(option VmessOption) (*Vmess, error) {
security := strings.ToLower(option.Cipher)
var options []vmess.ClientOption
if option.GlobalPadding {
options = append(options, vmess.ClientWithGlobalPadding())
}
if option.AuthenticatedLength {
options = append(options, vmess.ClientWithAuthenticatedLength())
}
@ -323,16 +301,6 @@ func NewVmess(option VmessOption) (*Vmess, error) {
return nil, err
}
switch option.PacketEncoding {
case "packetaddr", "packet":
option.PacketAddr = true
case "xudp":
option.XUDP = true
}
if option.XUDP {
option.PacketAddr = false
}
switch option.Network {
case "h2", "grpc":
if !option.TLS {
@ -342,13 +310,12 @@ func NewVmess(option VmessOption) (*Vmess, error) {
v := &Vmess{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vmess,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vmess,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
},
client: client,
option: &option,

View File

@ -1,328 +0,0 @@
package outbound
import (
"context"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/listener/sing"
wireguard "github.com/metacubex/sing-wireguard"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/device"
)
type WireGuard struct {
*Base
bind *wireguard.ClientBind
device *device.Device
tunDevice wireguard.Device
dialer *wgDialer
startOnce sync.Once
}
type WireGuardOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Ip string `proxy:"ip,omitempty"`
Ipv6 string `proxy:"ipv6,omitempty"`
PrivateKey string `proxy:"private-key"`
PublicKey string `proxy:"public-key"`
PreSharedKey string `proxy:"pre-shared-key,omitempty"`
Reserved []int `proxy:"reserved,omitempty"`
Workers int `proxy:"workers,omitempty"`
MTU int `proxy:"mtu,omitempty"`
UDP bool `proxy:"udp,omitempty"`
}
type wgDialer struct {
options []dialer.Option
}
func (d *wgDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return dialer.DialContext(ctx, network, destination.String(), d.options...)
}
func (d *wgDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return dialer.ListenPacket(ctx, "udp", "", d.options...)
}
func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
outbound := &WireGuard{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.WireGuard,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
},
dialer: &wgDialer{},
}
runtime.SetFinalizer(outbound, closeWireGuard)
var reserved [3]uint8
if len(option.Reserved) > 0 {
if len(option.Reserved) != 3 {
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(option.Reserved))
}
reserved[0] = uint8(option.Reserved[0])
reserved[1] = uint8(option.Reserved[1])
reserved[2] = uint8(option.Reserved[2])
}
peerAddr := M.ParseSocksaddr(option.Server)
peerAddr.Port = uint16(option.Port)
outbound.bind = wireguard.NewClientBind(context.Background(), outbound.dialer, peerAddr, reserved)
localPrefixes := make([]netip.Prefix, 0, 2)
if len(option.Ip) > 0 {
if !strings.Contains(option.Ip, "/") {
option.Ip = option.Ip + "/32"
}
if prefix, err := netip.ParsePrefix(option.Ip); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, E.Cause(err, "ip address parse error")
}
}
if len(option.Ipv6) > 0 {
if !strings.Contains(option.Ipv6, "/") {
option.Ipv6 = option.Ipv6 + "/128"
}
if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, E.Cause(err, "ipv6 address parse error")
}
}
if len(localPrefixes) == 0 {
return nil, E.New("missing local address")
}
var privateKey, peerPublicKey, preSharedKey string
{
bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey = hex.EncodeToString(bytes)
}
{
bytes, err := base64.StdEncoding.DecodeString(option.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode peer public key")
}
peerPublicKey = hex.EncodeToString(bytes)
}
if option.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(option.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key")
}
preSharedKey = hex.EncodeToString(bytes)
}
ipcConf := "private_key=" + privateKey
ipcConf += "\npublic_key=" + peerPublicKey
ipcConf += "\nendpoint=" + peerAddr.String()
if preSharedKey != "" {
ipcConf += "\npreshared_key=" + preSharedKey
}
var has4, has6 bool
for _, address := range localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "\nallowed_ip=0.0.0.0/0"
}
if has6 {
ipcConf += "\nallowed_ip=::/0"
}
mtu := option.MTU
if mtu == 0 {
mtu = 1408
}
var err error
outbound.tunDevice, err = wireguard.NewStackDevice(localPrefixes, uint32(mtu))
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
outbound.device = device.NewDevice(outbound.tunDevice, outbound.bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) {
sing.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
},
Errorf: func(format string, args ...interface{}) {
sing.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
},
}, option.Workers)
if debug.Enabled {
sing.Logger.Trace("created wireguard ipc conf: \n", ipcConf)
}
err = outbound.device.IpcSet(ipcConf)
if err != nil {
return nil, E.Cause(err, "setup wireguard")
}
//err = outbound.tunDevice.Start()
return outbound, nil
}
func closeWireGuard(w *WireGuard) {
if w.device != nil {
w.device.Close()
}
_ = common.Close(w.tunDevice)
}
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
w.dialer.options = opts
var conn net.Conn
w.startOnce.Do(func() {
err = w.tunDevice.Start()
})
if err != nil {
return nil, err
}
if !metadata.Resolved() {
var addrs []netip.Addr
addrs, err = resolver.ResolveAllIP(metadata.Host)
if err != nil {
return nil, err
}
conn, err = N.DialSerial(ctx, w.tunDevice, "tcp", M.ParseSocksaddr(metadata.RemoteAddress()), addrs)
} else {
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.ParseSocksaddr(metadata.Pure().RemoteAddress()))
}
if err != nil {
return nil, err
}
return NewConn(&wgConn{conn, w}, w), nil
}
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
w.dialer.options = opts
var pc net.PacketConn
w.startOnce.Do(func() {
err = w.tunDevice.Start()
})
if err != nil {
return nil, err
}
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
}
pc, err = w.tunDevice.ListenPacket(ctx, M.ParseSocksaddr(metadata.Pure().RemoteAddress()))
if err != nil {
return nil, err
}
return newPacketConn(&wgPacketConn{pc, w}, w), nil
}
type wgConn struct {
conn net.Conn
wg *WireGuard
}
func (c *wgConn) Read(b []byte) (n int, err error) {
defer runtime.KeepAlive(c.wg)
return c.conn.Read(b)
}
func (c *wgConn) Write(b []byte) (n int, err error) {
defer runtime.KeepAlive(c.wg)
return c.conn.Write(b)
}
func (c *wgConn) Close() error {
defer runtime.KeepAlive(c.wg)
return c.conn.Close()
}
func (c *wgConn) LocalAddr() net.Addr {
defer runtime.KeepAlive(c.wg)
return c.conn.LocalAddr()
}
func (c *wgConn) RemoteAddr() net.Addr {
defer runtime.KeepAlive(c.wg)
return c.conn.RemoteAddr()
}
func (c *wgConn) SetDeadline(t time.Time) error {
defer runtime.KeepAlive(c.wg)
return c.conn.SetDeadline(t)
}
func (c *wgConn) SetReadDeadline(t time.Time) error {
defer runtime.KeepAlive(c.wg)
return c.conn.SetReadDeadline(t)
}
func (c *wgConn) SetWriteDeadline(t time.Time) error {
defer runtime.KeepAlive(c.wg)
return c.conn.SetWriteDeadline(t)
}
type wgPacketConn struct {
pc net.PacketConn
wg *WireGuard
}
func (pc *wgPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
defer runtime.KeepAlive(pc.wg)
return pc.pc.ReadFrom(p)
}
func (pc *wgPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
defer runtime.KeepAlive(pc.wg)
return pc.pc.WriteTo(p, addr)
}
func (pc *wgPacketConn) Close() error {
defer runtime.KeepAlive(pc.wg)
return pc.pc.Close()
}
func (pc *wgPacketConn) LocalAddr() net.Addr {
defer runtime.KeepAlive(pc.wg)
return pc.pc.LocalAddr()
}
func (pc *wgPacketConn) SetDeadline(t time.Time) error {
defer runtime.KeepAlive(pc.wg)
return pc.pc.SetDeadline(t)
}
func (pc *wgPacketConn) SetReadDeadline(t time.Time) error {
defer runtime.KeepAlive(pc.wg)
return pc.pc.SetReadDeadline(t)
}
func (pc *wgPacketConn) SetWriteDeadline(t time.Time) error {
defer runtime.KeepAlive(pc.wg)
return pc.pc.SetWriteDeadline(t)
}

View File

@ -31,7 +31,7 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
c.AppendToChains(f)
f.onDialSuccess()
} else {
f.onDialFailed(proxy.Type(), err)
f.onDialFailed()
}
return c, err
@ -72,30 +72,24 @@ func (f *Fallback) MarshalJSON() ([]byte, error) {
}
// Unwrap implements C.ProxyAdapter
func (f *Fallback) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
proxy := f.findAliveProxy(touch)
func (f *Fallback) Unwrap(metadata *C.Metadata) C.Proxy {
proxy := f.findAliveProxy(true)
return proxy
}
func (f *Fallback) findAliveProxy(touch bool) C.Proxy {
proxies := f.GetProxies(touch)
for _, proxy := range proxies {
if len(f.selected) == 0 {
if proxy.Alive() {
return proxy
}
} else {
if proxy.Name() == f.selected {
if proxy.Alive() {
return proxy
} else {
f.selected = ""
}
}
al := proxies[0]
for i := len(proxies) - 1; i > -1; i-- {
proxy := proxies[i]
if proxy.Name() == f.selected && proxy.Alive() {
return proxy
}
if proxy.Alive() {
al = proxy
}
}
return proxies[0]
return al
}
func (f *Fallback) Set(name string) error {
@ -131,7 +125,6 @@ func NewFallback(option *GroupCommonOption, providers []provider.ProxyProvider)
RoutingMark: option.RoutingMark,
},
option.Filter,
option.ExcludeFilter,
providers,
}),
disableUDP: option.DisableUDP,

View File

@ -11,156 +11,93 @@ import (
"github.com/Dreamacro/clash/tunnel"
"github.com/dlclark/regexp2"
"go.uber.org/atomic"
"strings"
"sync"
"time"
)
type GroupBase struct {
*outbound.Base
filterRegs []*regexp2.Regexp
excludeFilterReg *regexp2.Regexp
providers []provider.ProxyProvider
failedTestMux sync.Mutex
failedTimes int
failedTime time.Time
failedTesting *atomic.Bool
proxies [][]C.Proxy
versions []atomic.Uint32
filter *regexp2.Regexp
providers []provider.ProxyProvider
versions sync.Map // map[string]uint
proxies sync.Map // map[string][]C.Proxy
failedTestMux sync.Mutex
failedTimes int
failedTime time.Time
failedTesting *atomic.Bool
}
type GroupBaseOption struct {
outbound.BaseOption
filter string
excludeFilter string
providers []provider.ProxyProvider
filter string
providers []provider.ProxyProvider
}
func NewGroupBase(opt GroupBaseOption) *GroupBase {
var excludeFilterReg *regexp2.Regexp
if opt.excludeFilter != "" {
excludeFilterReg = regexp2.MustCompile(opt.excludeFilter, 0)
}
var filterRegs []*regexp2.Regexp
var filter *regexp2.Regexp = nil
if opt.filter != "" {
for _, filter := range strings.Split(opt.filter, "`") {
filterReg := regexp2.MustCompile(filter, 0)
filterRegs = append(filterRegs, filterReg)
}
filter = regexp2.MustCompile(opt.filter, 0)
}
gb := &GroupBase{
Base: outbound.NewBase(opt.BaseOption),
filterRegs: filterRegs,
excludeFilterReg: excludeFilterReg,
providers: opt.providers,
failedTesting: atomic.NewBool(false),
}
gb.proxies = make([][]C.Proxy, len(opt.providers))
gb.versions = make([]atomic.Uint32, len(opt.providers))
return gb
}
func (gb *GroupBase) Touch() {
for _, pd := range gb.providers {
pd.Touch()
return &GroupBase{
Base: outbound.NewBase(opt.BaseOption),
filter: filter,
providers: opt.providers,
failedTesting: atomic.NewBool(false),
}
}
func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
var proxies []C.Proxy
if len(gb.filterRegs) == 0 {
if gb.filter == nil {
var proxies []C.Proxy
for _, pd := range gb.providers {
if touch {
pd.Touch()
}
proxies = append(proxies, pd.Proxies()...)
}
} else {
for i, pd := range gb.providers {
if touch {
pd.Touch()
}
if pd.VehicleType() == types.Compatible {
gb.versions[i].Store(pd.Version())
gb.proxies[i] = pd.Proxies()
continue
}
version := gb.versions[i].Load()
if version != pd.Version() && gb.versions[i].CompareAndSwap(version, pd.Version()) {
var (
proxies []C.Proxy
newProxies []C.Proxy
)
proxies = pd.Proxies()
proxiesSet := map[string]struct{}{}
for _, filterReg := range gb.filterRegs {
for _, p := range proxies {
name := p.Name()
if mat, _ := filterReg.FindStringMatch(name); mat != nil {
if _, ok := proxiesSet[name]; !ok {
proxiesSet[name] = struct{}{}
newProxies = append(newProxies, p)
}
}
}
}
gb.proxies[i] = newProxies
}
}
for _, p := range gb.proxies {
proxies = append(proxies, p...)
if len(proxies) == 0 {
return append(proxies, tunnel.Proxies()["COMPATIBLE"])
}
return proxies
}
for _, pd := range gb.providers {
if touch {
pd.Touch()
}
if pd.VehicleType() == types.Compatible {
gb.proxies.Store(pd.Name(), pd.Proxies())
gb.versions.Store(pd.Name(), pd.Version())
continue
}
if version, ok := gb.versions.Load(pd.Name()); !ok || version != pd.Version() {
var (
proxies []C.Proxy
newProxies []C.Proxy
)
proxies = pd.Proxies()
for _, p := range proxies {
if mat, _ := gb.filter.FindStringMatch(p.Name()); mat != nil {
newProxies = append(newProxies, p)
}
}
gb.proxies.Store(pd.Name(), newProxies)
gb.versions.Store(pd.Name(), pd.Version())
}
}
var proxies []C.Proxy
gb.proxies.Range(func(key, value any) bool {
proxies = append(proxies, value.([]C.Proxy)...)
return true
})
if len(proxies) == 0 {
return append(proxies, tunnel.Proxies()["COMPATIBLE"])
}
if len(gb.providers) > 1 && len(gb.filterRegs) > 1 {
var newProxies []C.Proxy
proxiesSet := map[string]struct{}{}
for _, filterReg := range gb.filterRegs {
for _, p := range proxies {
name := p.Name()
if mat, _ := filterReg.FindStringMatch(name); mat != nil {
if _, ok := proxiesSet[name]; !ok {
proxiesSet[name] = struct{}{}
newProxies = append(newProxies, p)
}
}
}
}
for _, p := range proxies { // add not matched proxies at the end
name := p.Name()
if _, ok := proxiesSet[name]; !ok {
proxiesSet[name] = struct{}{}
newProxies = append(newProxies, p)
}
}
proxies = newProxies
}
if gb.excludeFilterReg != nil {
var newProxies []C.Proxy
for _, p := range proxies {
name := p.Name()
if mat, _ := gb.excludeFilterReg.FindStringMatch(name); mat != nil {
continue
}
newProxies = append(newProxies, p)
}
proxies = newProxies
}
return proxies
}
@ -192,13 +129,8 @@ func (gb *GroupBase) URLTest(ctx context.Context, url string) (map[string]uint16
}
}
func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error) {
if adapterType == C.Direct || adapterType == C.Compatible || adapterType == C.Reject || adapterType == C.Pass {
return
}
if strings.Contains(err.Error(), "connection refused") {
go gb.healthCheck()
func (gb *GroupBase) onDialFailed() {
if gb.failedTesting.Load() {
return
}
@ -212,40 +144,31 @@ func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error) {
gb.failedTime = time.Now()
} else {
if time.Since(gb.failedTime) > gb.failedTimeoutInterval() {
gb.failedTimes = 0
return
}
log.Debugln("ProxyGroup: %s failed count: %d", gb.Name(), gb.failedTimes)
if gb.failedTimes >= gb.maxFailedTimes() {
gb.failedTesting.Store(true)
log.Warnln("because %s failed multiple times, active health check", gb.Name())
gb.healthCheck()
wg := sync.WaitGroup{}
for _, proxyProvider := range gb.providers {
wg.Add(1)
proxyProvider := proxyProvider
go func() {
defer wg.Done()
proxyProvider.HealthCheck()
}()
}
wg.Wait()
gb.failedTesting.Store(false)
gb.failedTimes = 0
}
}
}()
}
func (gb *GroupBase) healthCheck() {
if gb.failedTesting.Load() {
return
}
gb.failedTesting.Store(true)
wg := sync.WaitGroup{}
for _, proxyProvider := range gb.providers {
wg.Add(1)
proxyProvider := proxyProvider
go func() {
defer wg.Done()
proxyProvider.HealthCheck()
}()
}
wg.Wait()
gb.failedTesting.Store(false)
gb.failedTimes = 0
}
func (gb *GroupBase) failedIntervalTime() int64 {
return 5 * time.Second.Milliseconds()
}

View File

@ -29,8 +29,10 @@ type LoadBalance struct {
var errStrategy = errors.New("unsupported strategy")
func parseStrategy(config map[string]any) string {
if strategy, ok := config["strategy"].(string); ok {
return strategy
if elm, ok := config["strategy"]; ok {
if strategy, ok := elm.(string); ok {
return strategy
}
}
return "consistent-hashing"
}
@ -82,17 +84,17 @@ func jumpHash(key uint64, buckets int32) int32 {
// DialContext implements C.ProxyAdapter
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
proxy := lb.Unwrap(metadata, true)
defer func() {
if err == nil {
c.AppendToChains(lb)
lb.onDialSuccess()
} else {
lb.onDialFailed(proxy.Type(), err)
lb.onDialFailed()
}
}()
proxy := lb.Unwrap(metadata)
c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
return
}
@ -105,7 +107,7 @@ func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Meta
}
}()
proxy := lb.Unwrap(metadata, true)
proxy := lb.Unwrap(metadata)
return proxy.ListenPacketContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
}
@ -143,13 +145,6 @@ func strategyConsistentHashing() strategyFn {
}
}
// when availability is poor, traverse the entire list to get the available nodes
for _, proxy := range proxies {
if proxy.Alive() {
return proxy
}
}
return proxies[0]
}
}
@ -190,8 +185,8 @@ func strategyStickySessions() strategyFn {
}
// Unwrap implements C.ProxyAdapter
func (lb *LoadBalance) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
proxies := lb.GetProxies(touch)
func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy {
proxies := lb.GetProxies(true)
return lb.strategyFn(proxies, metadata)
}
@ -228,7 +223,6 @@ func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvide
RoutingMark: option.RoutingMark,
},
option.Filter,
option.ExcludeFilter,
providers,
}),
strategyFn: strategyFn,

View File

@ -21,16 +21,15 @@ var (
type GroupCommonOption struct {
outbound.BasicOption
Name string `group:"name"`
Type string `group:"type"`
Proxies []string `group:"proxies,omitempty"`
Use []string `group:"use,omitempty"`
URL string `group:"url,omitempty"`
Interval int `group:"interval,omitempty"`
Lazy bool `group:"lazy,omitempty"`
DisableUDP bool `group:"disable-udp,omitempty"`
Filter string `group:"filter,omitempty"`
ExcludeFilter string `group:"exclude-filter,omitempty"`
Name string `group:"name"`
Type string `group:"type"`
Proxies []string `group:"proxies,omitempty"`
Use []string `group:"use,omitempty"`
URL string `group:"url,omitempty"`
Interval int `group:"interval,omitempty"`
Lazy bool `group:"lazy,omitempty"`
DisableUDP bool `group:"disable-udp,omitempty"`
Filter string `group:"filter,omitempty"`
}
func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, providersMap map[string]types.ProxyProvider) (C.ProxyAdapter, error) {

View File

@ -153,11 +153,11 @@ func (r *Relay) proxies(metadata *C.Metadata, touch bool) ([]C.Proxy, []C.Proxy)
for n, proxy := range rawProxies {
proxies = append(proxies, proxy)
chainProxies = append(chainProxies, proxy)
subproxy := proxy.Unwrap(metadata, touch)
subproxy := proxy.Unwrap(metadata)
for subproxy != nil {
chainProxies = append(chainProxies, subproxy)
proxies[n] = subproxy
subproxy = subproxy.Unwrap(metadata, touch)
subproxy = subproxy.Unwrap(metadata)
}
}
@ -185,7 +185,6 @@ func NewRelay(option *GroupCommonOption, providers []provider.ProxyProvider) *Re
RoutingMark: option.RoutingMark,
},
"",
"",
providers,
}),
}

View File

@ -74,8 +74,8 @@ func (s *Selector) Set(name string) error {
}
// Unwrap implements C.ProxyAdapter
func (s *Selector) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
return s.selectedProxy(touch)
func (s *Selector) Unwrap(*C.Metadata) C.Proxy {
return s.selectedProxy(true)
}
func (s *Selector) selectedProxy(touch bool) C.Proxy {
@ -99,7 +99,6 @@ func NewSelector(option *GroupCommonOption, providers []provider.ProxyProvider)
RoutingMark: option.RoutingMark,
},
option.Filter,
option.ExcludeFilter,
providers,
}),
selected: "COMPATIBLE",

View File

@ -34,13 +34,12 @@ func (u *URLTest) Now() string {
// DialContext implements C.ProxyAdapter
func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
proxy := u.fast(true)
c, err = proxy.DialContext(ctx, metadata, u.Base.DialOptions(opts...)...)
c, err = u.fast(true).DialContext(ctx, metadata, u.Base.DialOptions(opts...)...)
if err == nil {
c.AppendToChains(u)
u.onDialSuccess()
} else {
u.onDialFailed(proxy.Type(), err)
u.onDialFailed()
}
return c, err
}
@ -56,12 +55,12 @@ func (u *URLTest) ListenPacketContext(ctx context.Context, metadata *C.Metadata,
}
// Unwrap implements C.ProxyAdapter
func (u *URLTest) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
return u.fast(touch)
func (u *URLTest) Unwrap(*C.Metadata) C.Proxy {
return u.fast(true)
}
func (u *URLTest) fast(touch bool) C.Proxy {
elm, _, shared := u.fastSingle.Do(func() (C.Proxy, error) {
elm, _, _ := u.fastSingle.Do(func() (C.Proxy, error) {
proxies := u.GetProxies(touch)
fast := proxies[0]
min := fast.LastDelay()
@ -90,9 +89,6 @@ func (u *URLTest) fast(touch bool) C.Proxy {
return u.fastNode, nil
})
if shared && touch { // a shared fastSingle.Do() may cause providers untouched, so we touch them again
u.Touch()
}
return elm
}
@ -143,7 +139,6 @@ func NewURLTest(option *GroupCommonOption, providers []provider.ProxyProvider, o
},
option.Filter,
option.ExcludeFilter,
providers,
}),
fastSingle: singledo.NewSingle[C.Proxy](time.Second * 10),

View File

@ -16,19 +16,32 @@ func addrToMetadata(rawAddress string) (addr *C.Metadata, err error) {
return
}
if ip, err := netip.ParseAddr(host); err != nil {
ip, err := netip.ParseAddr(host)
if err != nil {
addr = &C.Metadata{
Host: host,
DstPort: port,
AddrType: C.AtypDomainName,
Host: host,
DstIP: netip.Addr{},
DstPort: port,
}
} else {
err = nil
return
} else if ip.Is4() {
addr = &C.Metadata{
Host: "",
DstIP: ip,
DstPort: port,
AddrType: C.AtypIPv4,
Host: "",
DstIP: ip,
DstPort: port,
}
return
}
addr = &C.Metadata{
AddrType: C.AtypIPv6,
Host: "",
DstIP: ip,
DstPort: port,
}
return
}

View File

@ -40,14 +40,14 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
if err != nil {
break
}
proxy, err = outbound.NewSocks5(*socksOption)
proxy = outbound.NewSocks5(*socksOption)
case "http":
httpOption := &outbound.HttpOption{}
err = decoder.Decode(mapping, httpOption)
if err != nil {
break
}
proxy, err = outbound.NewHttp(*httpOption)
proxy = outbound.NewHttp(*httpOption)
case "vmess":
vmessOption := &outbound.VmessOption{
HTTPOpts: outbound.HTTPOptions{
@ -88,13 +88,6 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
break
}
proxy, err = outbound.NewHysteria(*hyOption)
case "wireguard":
hyOption := &outbound.WireGuardOption{}
err = decoder.Decode(mapping, hyOption)
if err != nil {
break
}
proxy, err = outbound.NewWireGuard(*hyOption)
default:
return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
}

View File

@ -1,4 +1,4 @@
package resource
package provider
import (
"bytes"
@ -16,34 +16,29 @@ var (
dirMode os.FileMode = 0o755
)
type Parser[V any] func([]byte) (V, error)
type parser[V any] func([]byte) (V, error)
type Fetcher[V any] struct {
resourceType string
name string
vehicle types.Vehicle
UpdatedAt *time.Time
ticker *time.Ticker
done chan struct{}
hash [16]byte
parser Parser[V]
interval time.Duration
OnUpdate func(V)
type fetcher[V any] struct {
name string
vehicle types.Vehicle
updatedAt *time.Time
ticker *time.Ticker
done chan struct{}
hash [16]byte
parser parser[V]
interval time.Duration
onUpdate func(V)
}
func (f *Fetcher[V]) Name() string {
func (f *fetcher[V]) Name() string {
return f.name
}
func (f *Fetcher[V]) Vehicle() types.Vehicle {
return f.vehicle
}
func (f *Fetcher[V]) VehicleType() types.VehicleType {
func (f *fetcher[V]) VehicleType() types.VehicleType {
return f.vehicle.Type()
}
func (f *Fetcher[V]) Initial() (V, error) {
func (f *fetcher[V]) Initial() (V, error) {
var (
buf []byte
err error
@ -54,7 +49,7 @@ func (f *Fetcher[V]) Initial() (V, error) {
if stat, fErr := os.Stat(f.vehicle.Path()); fErr == nil {
buf, err = os.ReadFile(f.vehicle.Path())
modTime := stat.ModTime()
f.UpdatedAt = &modTime
f.updatedAt = &modTime
isLocal = true
if f.interval != 0 && modTime.Add(f.interval).Before(time.Now()) {
log.Infoln("[Provider] %s not updated for a long time, force refresh", f.Name())
@ -68,11 +63,11 @@ func (f *Fetcher[V]) Initial() (V, error) {
return getZero[V](), err
}
var contents V
var proxies V
if forceUpdate {
var forceBuf []byte
if forceBuf, err = f.vehicle.Read(); err == nil {
if contents, err = f.parser(forceBuf); err == nil {
if proxies, err = f.parser(forceBuf); err == nil {
isLocal = false
buf = forceBuf
}
@ -80,7 +75,7 @@ func (f *Fetcher[V]) Initial() (V, error) {
}
if err != nil || !forceUpdate {
contents, err = f.parser(buf)
proxies, err = f.parser(buf)
}
if err != nil {
@ -94,7 +89,7 @@ func (f *Fetcher[V]) Initial() (V, error) {
return getZero[V](), err
}
contents, err = f.parser(buf)
proxies, err = f.parser(buf)
if err != nil {
return getZero[V](), err
}
@ -110,15 +105,15 @@ func (f *Fetcher[V]) Initial() (V, error) {
f.hash = md5.Sum(buf)
// pull contents automatically
// pull proxies automatically
if f.ticker != nil {
go f.pullLoop()
}
return contents, nil
return proxies, nil
}
func (f *Fetcher[V]) Update() (V, bool, error) {
func (f *fetcher[V]) Update() (V, bool, error) {
buf, err := f.vehicle.Read()
if err != nil {
return getZero[V](), false, err
@ -127,12 +122,12 @@ func (f *Fetcher[V]) Update() (V, bool, error) {
now := time.Now()
hash := md5.Sum(buf)
if bytes.Equal(f.hash[:], hash[:]) {
f.UpdatedAt = &now
_ = os.Chtimes(f.vehicle.Path(), now, now)
f.updatedAt = &now
os.Chtimes(f.vehicle.Path(), now, now)
return getZero[V](), true, nil
}
contents, err := f.parser(buf)
proxies, err := f.parser(buf)
if err != nil {
return getZero[V](), false, err
}
@ -143,20 +138,20 @@ func (f *Fetcher[V]) Update() (V, bool, error) {
}
}
f.UpdatedAt = &now
f.updatedAt = &now
f.hash = hash
return contents, false, nil
return proxies, false, nil
}
func (f *Fetcher[V]) Destroy() error {
func (f *fetcher[V]) Destroy() error {
if f.ticker != nil {
f.done <- struct{}{}
}
return nil
}
func (f *Fetcher[V]) pullLoop() {
func (f *fetcher[V]) pullLoop() {
for {
select {
case <-f.ticker.C:
@ -167,13 +162,13 @@ func (f *Fetcher[V]) pullLoop() {
}
if same {
log.Debugln("[Provider] %s's content doesn't change", f.Name())
log.Debugln("[Provider] %s's proxies doesn't change", f.Name())
continue
}
log.Infoln("[Provider] %s's content update", f.Name())
if f.OnUpdate != nil {
f.OnUpdate(elm)
log.Infoln("[Provider] %s's proxies update", f.Name())
if f.onUpdate != nil {
f.onUpdate(elm)
}
case <-f.done:
f.ticker.Stop()
@ -194,19 +189,19 @@ func safeWrite(path string, buf []byte) error {
return os.WriteFile(path, buf, fileMode)
}
func NewFetcher[V any](name string, interval time.Duration, vehicle types.Vehicle, parser Parser[V], onUpdate func(V)) *Fetcher[V] {
func newFetcher[V any](name string, interval time.Duration, vehicle types.Vehicle, parser parser[V], onUpdate func(V)) *fetcher[V] {
var ticker *time.Ticker
if interval != 0 {
ticker = time.NewTicker(interval)
}
return &Fetcher[V]{
return &fetcher[V]{
name: name,
ticker: ticker,
vehicle: vehicle,
parser: parser,
done: make(chan struct{}, 1),
OnUpdate: onUpdate,
onUpdate: onUpdate,
interval: interval,
}
}

View File

@ -2,14 +2,12 @@ package provider
import (
"context"
"github.com/Dreamacro/clash/common/singledo"
"time"
"github.com/Dreamacro/clash/common/batch"
"github.com/Dreamacro/clash/common/singledo"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"github.com/gofrs/uuid"
"go.uber.org/atomic"
)
@ -37,13 +35,16 @@ func (hc *HealthCheck) process() {
go func() {
time.Sleep(30 * time.Second)
hc.lazyCheck()
hc.check()
}()
for {
select {
case <-ticker.C:
hc.lazyCheck()
now := time.Now().Unix()
if !hc.lazy || now-hc.lastTouch.Load() < int64(hc.interval) {
hc.check()
}
case <-hc.done:
ticker.Stop()
return
@ -51,17 +52,6 @@ func (hc *HealthCheck) process() {
}
}
func (hc *HealthCheck) lazyCheck() bool {
now := time.Now().Unix()
if !hc.lazy || now-hc.lastTouch.Load() < int64(hc.interval) {
hc.check()
return true
} else {
log.Debugln("Skip once health check because we are lazy")
return false
}
}
func (hc *HealthCheck) setProxy(proxies []C.Proxy) {
hc.proxies = proxies
}
@ -76,26 +66,18 @@ func (hc *HealthCheck) touch() {
func (hc *HealthCheck) check() {
_, _, _ = hc.singleDo.Do(func() (struct{}, error) {
id := ""
if uid, err := uuid.NewV4(); err == nil {
id = uid.String()
}
log.Debugln("Start New Health Checking {%s}", id)
b, _ := batch.New[bool](context.Background(), batch.WithConcurrencyNum[bool](10))
for _, proxy := range hc.proxies {
p := proxy
b.Go(p.Name(), func() (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout)
defer cancel()
log.Debugln("Health Checking %s {%s}", p.Name(), id)
_, _ = p.URLTest(ctx, hc.url)
log.Debugln("Health Checked %s : %t %d ms {%s}", p.Name(), p.Alive(), p.LastDelay(), id)
return false, nil
})
}
b.Wait()
log.Debugln("Finish A Health Checking {%s}", id)
return struct{}{}, nil
})
}

View File

@ -3,7 +3,6 @@ package provider
import (
"errors"
"fmt"
"github.com/Dreamacro/clash/component/resource"
"time"
"github.com/Dreamacro/clash/common/structure"
@ -21,13 +20,12 @@ type healthCheckSchema struct {
}
type proxyProviderSchema struct {
Type string `provider:"type"`
Path string `provider:"path"`
URL string `provider:"url,omitempty"`
Interval int `provider:"interval,omitempty"`
Filter string `provider:"filter,omitempty"`
ExcludeFilter string `provider:"exclude-filter,omitempty"`
HealthCheck healthCheckSchema `provider:"health-check,omitempty"`
Type string `provider:"type"`
Path string `provider:"path"`
URL string `provider:"url,omitempty"`
Interval int `provider:"interval,omitempty"`
Filter string `provider:"filter,omitempty"`
HealthCheck healthCheckSchema `provider:"health-check,omitempty"`
}
func ParseProxyProvider(name string, mapping map[string]any) (types.ProxyProvider, error) {
@ -53,15 +51,14 @@ func ParseProxyProvider(name string, mapping map[string]any) (types.ProxyProvide
var vehicle types.Vehicle
switch schema.Type {
case "file":
vehicle = resource.NewFileVehicle(path)
vehicle = NewFileVehicle(path)
case "http":
vehicle = resource.NewHTTPVehicle(schema.URL, path)
vehicle = NewHTTPVehicle(schema.URL, path)
default:
return nil, fmt.Errorf("%w: %s", errVehicleType, schema.Type)
}
interval := time.Duration(uint(schema.Interval)) * time.Second
filter := schema.Filter
excludeFilter := schema.ExcludeFilter
return NewProxySetProvider(name, interval, filter, excludeFilter, vehicle, hc)
return NewProxySetProvider(name, interval, filter, vehicle, hc)
}

View File

@ -1,18 +1,13 @@
package provider
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/Dreamacro/clash/common/convert"
clashHttp "github.com/Dreamacro/clash/component/http"
"github.com/Dreamacro/clash/component/resource"
"github.com/Dreamacro/clash/log"
"github.com/dlclark/regexp2"
"net/http"
"math"
"runtime"
"strings"
"time"
"github.com/Dreamacro/clash/adapter"
@ -36,30 +31,28 @@ type ProxySetProvider struct {
}
type proxySetProvider struct {
*resource.Fetcher[[]C.Proxy]
proxies []C.Proxy
healthCheck *HealthCheck
version uint32
subscriptionInfo *SubscriptionInfo
*fetcher[[]C.Proxy]
proxies []C.Proxy
healthCheck *HealthCheck
version uint
}
func (pp *proxySetProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"name": pp.Name(),
"type": pp.Type().String(),
"vehicleType": pp.VehicleType().String(),
"proxies": pp.Proxies(),
"updatedAt": pp.UpdatedAt,
"subscriptionInfo": pp.subscriptionInfo,
"name": pp.Name(),
"type": pp.Type().String(),
"vehicleType": pp.VehicleType().String(),
"proxies": pp.Proxies(),
"updatedAt": pp.updatedAt,
})
}
func (pp *proxySetProvider) Version() uint32 {
func (pp *proxySetProvider) Version() uint {
return pp.version
}
func (pp *proxySetProvider) Name() string {
return pp.Fetcher.Name()
return pp.name
}
func (pp *proxySetProvider) HealthCheck() {
@ -67,19 +60,19 @@ func (pp *proxySetProvider) HealthCheck() {
}
func (pp *proxySetProvider) Update() error {
elm, same, err := pp.Fetcher.Update()
elm, same, err := pp.fetcher.Update()
if err == nil && !same {
pp.OnUpdate(elm)
pp.onUpdate(elm)
}
return err
}
func (pp *proxySetProvider) Initial() error {
elm, err := pp.Fetcher.Initial()
elm, err := pp.fetcher.Initial()
if err != nil {
return err
}
pp.OnUpdate(elm)
pp.onUpdate(elm)
return nil
}
@ -99,61 +92,19 @@ func (pp *proxySetProvider) setProxies(proxies []C.Proxy) {
pp.proxies = proxies
pp.healthCheck.setProxy(proxies)
if pp.healthCheck.auto() {
defer func() { go pp.healthCheck.lazyCheck() }()
defer func() { go pp.healthCheck.check() }()
}
}
func (pp *proxySetProvider) getSubscriptionInfo() {
if pp.VehicleType() != types.HTTP {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := clashHttp.HttpRequest(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
http.MethodGet, http.Header{"User-Agent": {"clash"}}, nil)
if err != nil {
return
}
defer resp.Body.Close()
userInfoStr := strings.TrimSpace(resp.Header.Get("subscription-userinfo"))
if userInfoStr == "" {
resp2, err := clashHttp.HttpRequest(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
http.MethodGet, http.Header{"User-Agent": {"Quantumultx"}}, nil)
if err != nil {
return
}
defer resp2.Body.Close()
userInfoStr = strings.TrimSpace(resp2.Header.Get("subscription-userinfo"))
if userInfoStr == "" {
return
}
}
pp.subscriptionInfo, err = NewSubscriptionInfo(userInfoStr)
if err != nil {
log.Warnln("[Provider] get subscription-userinfo: %e", err)
}
}()
}
func stopProxyProvider(pd *ProxySetProvider) {
pd.healthCheck.close()
_ = pd.Fetcher.Destroy()
_ = pd.fetcher.Destroy()
}
func NewProxySetProvider(name string, interval time.Duration, filter string, excludeFilter string, vehicle types.Vehicle, hc *HealthCheck) (*ProxySetProvider, error) {
excludeFilterReg, err := regexp2.Compile(excludeFilter, 0)
func NewProxySetProvider(name string, interval time.Duration, filter string, vehicle types.Vehicle, hc *HealthCheck) (*ProxySetProvider, error) {
filterReg, err := regexp2.Compile(filter, 0)
if err != nil {
return nil, fmt.Errorf("invalid excludeFilter regex: %w", err)
}
var filterRegs []*regexp2.Regexp
for _, filter := range strings.Split(filter, "`") {
filterReg, err := regexp2.Compile(filter, 0)
if err != nil {
return nil, fmt.Errorf("invalid filter regex: %w", err)
}
filterRegs = append(filterRegs, filterReg)
return nil, fmt.Errorf("invalid filter regex: %w", err)
}
if hc.auto() {
@ -165,10 +116,9 @@ func NewProxySetProvider(name string, interval time.Duration, filter string, exc
healthCheck: hc,
}
fetcher := resource.NewFetcher[[]C.Proxy](name, interval, vehicle, proxiesParseAndFilter(filter, excludeFilter, filterRegs, excludeFilterReg), proxiesOnUpdate(pd))
pd.Fetcher = fetcher
fetcher := newFetcher[[]C.Proxy](name, interval, vehicle, proxiesParseAndFilter(filter, filterReg), proxiesOnUpdate(pd))
pd.fetcher = fetcher
pd.getSubscriptionInfo()
wrapper := &ProxySetProvider{pd}
runtime.SetFinalizer(wrapper, stopProxyProvider)
return wrapper, nil
@ -183,7 +133,7 @@ type compatibleProvider struct {
name string
healthCheck *HealthCheck
proxies []C.Proxy
version uint32
version uint
}
func (cp *compatibleProvider) MarshalJSON() ([]byte, error) {
@ -195,7 +145,7 @@ func (cp *compatibleProvider) MarshalJSON() ([]byte, error) {
})
}
func (cp *compatibleProvider) Version() uint32 {
func (cp *compatibleProvider) Version() uint {
return cp.version
}
@ -258,12 +208,15 @@ func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*Co
func proxiesOnUpdate(pd *proxySetProvider) func([]C.Proxy) {
return func(elm []C.Proxy) {
pd.setProxies(elm)
pd.version += 1
pd.getSubscriptionInfo()
if pd.version == math.MaxUint {
pd.version = 0
} else {
pd.version++
}
}
}
func proxiesParseAndFilter(filter string, excludeFilter string, filterRegs []*regexp2.Regexp, excludeFilterReg *regexp2.Regexp) resource.Parser[[]C.Proxy] {
func proxiesParseAndFilter(filter string, filterReg *regexp2.Regexp) parser[[]C.Proxy] {
return func(buf []byte) ([]C.Proxy, error) {
schema := &ProxySchema{}
@ -280,37 +233,17 @@ func proxiesParseAndFilter(filter string, excludeFilter string, filterRegs []*re
}
proxies := []C.Proxy{}
proxiesSet := map[string]struct{}{}
for _, filterReg := range filterRegs {
for idx, mapping := range schema.Proxies {
mName, ok := mapping["name"]
if !ok {
continue
}
name, ok := mName.(string)
if !ok {
continue
}
if len(excludeFilter) > 0 {
if mat, _ := excludeFilterReg.FindStringMatch(name); mat != nil {
continue
}
}
if len(filter) > 0 {
if mat, _ := filterReg.FindStringMatch(name); mat == nil {
continue
}
}
if _, ok := proxiesSet[name]; ok {
continue
}
proxy, err := adapter.ParseProxy(mapping)
if err != nil {
return nil, fmt.Errorf("proxy %d error: %w", idx, err)
}
proxiesSet[name] = struct{}{}
proxies = append(proxies, proxy)
for idx, mapping := range schema.Proxies {
name, ok := mapping["name"]
mat, _ := filterReg.FindStringMatch(name.(string))
if ok && len(filter) > 0 && mat == nil {
continue
}
proxy, err := adapter.ParseProxy(mapping)
if err != nil {
return nil, fmt.Errorf("proxy %d error: %w", idx, err)
}
proxies = append(proxies, proxy)
}
if len(proxies) == 0 {

View File

@ -1,57 +0,0 @@
package provider
import (
"github.com/dlclark/regexp2"
"strconv"
"strings"
)
type SubscriptionInfo struct {
Upload int64
Download int64
Total int64
Expire int64
}
func NewSubscriptionInfo(str string) (si *SubscriptionInfo, err error) {
si = &SubscriptionInfo{}
str = strings.ToLower(str)
reTraffic := regexp2.MustCompile("upload=(\\d+); download=(\\d+); total=(\\d+)", 0)
reExpire := regexp2.MustCompile("expire=(\\d+)", 0)
match, err := reTraffic.FindStringMatch(str)
if err != nil || match == nil {
return nil, err
}
group := match.Groups()
si.Upload, err = str2uint64(group[1].String())
if err != nil {
return nil, err
}
si.Download, err = str2uint64(group[2].String())
if err != nil {
return nil, err
}
si.Total, err = str2uint64(group[3].String())
if err != nil {
return nil, err
}
match, _ = reExpire.FindStringMatch(str)
if match != nil {
group = match.Groups()
si.Expire, err = str2uint64(group[1].String())
if err != nil {
return nil, err
}
}
return
}
func str2uint64(str string) (int64, error) {
i, err := strconv.ParseInt(str, 10, 64)
return i, err
}

View File

@ -1,8 +1,8 @@
package resource
package provider
import (
"context"
clashHttp "github.com/Dreamacro/clash/component/http"
netHttp "github.com/Dreamacro/clash/component/http"
types "github.com/Dreamacro/clash/constant/provider"
"io"
"net/http"
@ -35,10 +35,6 @@ type HTTPVehicle struct {
path string
}
func (h *HTTPVehicle) Url() string {
return h.url
}
func (h *HTTPVehicle) Type() types.VehicleType {
return types.HTTP
}
@ -50,7 +46,7 @@ func (h *HTTPVehicle) Path() string {
func (h *HTTPVehicle) Read() ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel()
resp, err := clashHttp.HttpRequest(ctx, h.url, http.MethodGet, nil, nil)
resp, err := netHttp.HttpRequest(ctx, h.url, http.MethodGet, nil, nil)
if err != nil {
return nil, err
}

View File

@ -1,45 +0,0 @@
package convert
import (
"encoding/base64"
"strings"
)
var (
encRaw = base64.RawStdEncoding
enc = base64.StdEncoding
)
// DecodeBase64 try to decode content from the given bytes,
// which can be in base64.RawStdEncoding, base64.StdEncoding or just plaintext.
func DecodeBase64(buf []byte) []byte {
result, err := tryDecodeBase64(buf)
if err != nil {
return buf
}
return result
}
func tryDecodeBase64(buf []byte) ([]byte, error) {
dBuf := make([]byte, encRaw.DecodedLen(len(buf)))
n, err := encRaw.Decode(dBuf, buf)
if err != nil {
n, err = enc.Decode(dBuf, buf)
if err != nil {
return nil, err
}
}
return dBuf[:n], nil
}
func urlSafe(data string) string {
return strings.NewReplacer("+", "-", "/", "_").Replace(data)
}
func decodeUrlSafe(data string) string {
dcBuf, err := base64.RawURLEncoding.DecodeString(data)
if err != nil {
return ""
}
return string(dcBuf)
}

View File

@ -2,6 +2,7 @@ package convert
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
@ -9,6 +10,21 @@ import (
"strings"
)
var encRaw = base64.RawStdEncoding
var enc = base64.StdEncoding
func DecodeBase64(buf []byte) []byte {
dBuf := make([]byte, encRaw.DecodedLen(len(buf)))
n, err := encRaw.Decode(dBuf, buf)
if err != nil {
n, err = enc.Decode(dBuf, buf)
if err != nil {
return buf
}
}
return dBuf[:n]
}
// ConvertsV2Ray convert V2Ray subscribe proxies data to clash proxies config
func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
data := DecodeBase64(buf)
@ -47,7 +63,7 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
hysteria["port"] = urlHysteria.Port()
hysteria["sni"] = query.Get("peer")
hysteria["obfs"] = query.Get("obfs")
hysteria["alpn"] = []string{query.Get("alpn")}
hysteria["alpn"] = query.Get("alpn")
hysteria["auth_str"] = query.Get("auth")
hysteria["protocol"] = query.Get("protocol")
up := query.Get("up")
@ -114,45 +130,102 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
proxies = append(proxies, trojan)
case "vless":
urlVLess, err := url.Parse(line)
urlVless, err := url.Parse(line)
if err != nil {
continue
}
query := urlVLess.Query()
query := urlVless.Query()
name := uniqueName(names, urlVless.Fragment)
vless := make(map[string]any, 20)
handleVShareLink(names, urlVLess, scheme, vless)
if flow := query.Get("flow"); flow != "" {
vless["flow"] = strings.ToLower(flow)
vless["name"] = name
vless["type"] = scheme
vless["server"] = urlVless.Hostname()
vless["port"] = urlVless.Port()
vless["uuid"] = urlVless.User.Username()
vless["udp"] = true
vless["skip-cert-verify"] = false
vless["tls"] = false
tls := strings.ToLower(query.Get("security"))
if strings.Contains(tls, "tls") {
vless["tls"] = true
}
sni := query.Get("sni")
if sni != "" {
vless["servername"] = sni
}
flow := strings.ToLower(query.Get("flow"))
if flow != "" {
vless["flow"] = flow
}
network := strings.ToLower(query.Get("type"))
fakeType := strings.ToLower(query.Get("headerType"))
if fakeType == "http" {
network = "http"
} else if network == "http" {
network = "h2"
}
vless["network"] = network
switch network {
case "tcp":
if fakeType != "none" {
headers := make(map[string]any)
httpOpts := make(map[string]any)
httpOpts["path"] = []string{"/"}
if query.Get("host") != "" {
headers["Host"] = []string{query.Get("host")}
}
if query.Get("method") != "" {
httpOpts["method"] = query.Get("method")
}
if query.Get("path") != "" {
httpOpts["path"] = []string{query.Get("path")}
}
httpOpts["headers"] = headers
vless["http-opts"] = httpOpts
}
case "http":
headers := make(map[string]any)
h2Opts := make(map[string]any)
h2Opts["path"] = []string{"/"}
if query.Get("path") != "" {
h2Opts["path"] = []string{query.Get("path")}
}
if query.Get("host") != "" {
h2Opts["host"] = []string{query.Get("host")}
}
h2Opts["headers"] = headers
vless["h2-opts"] = h2Opts
case "ws":
headers := make(map[string]any)
wsOpts := make(map[string]any)
headers["User-Agent"] = RandUserAgent()
headers["Host"] = query.Get("host")
wsOpts["path"] = query.Get("path")
wsOpts["headers"] = headers
vless["ws-opts"] = wsOpts
case "grpc":
grpcOpts := make(map[string]any)
grpcOpts["grpc-service-name"] = query.Get("serviceName")
vless["grpc-opts"] = grpcOpts
}
proxies = append(proxies, vless)
case "vmess":
// V2RayN-styled share link
// https://github.com/2dust/v2rayN/wiki/%E5%88%86%E4%BA%AB%E9%93%BE%E6%8E%A5%E6%A0%BC%E5%BC%8F%E8%AF%B4%E6%98%8E(ver-2)
dcBuf, err := tryDecodeBase64([]byte(body))
dcBuf, err := encRaw.DecodeString(body)
if err != nil {
// Xray VMessAEAD share link
urlVMess, err := url.Parse(line)
if err != nil {
continue
}
query := urlVMess.Query()
vmess := make(map[string]any, 20)
handleVShareLink(names, urlVMess, scheme, vmess)
vmess["alterId"] = 0
vmess["cipher"] = "auto"
if encryption := query.Get("encryption"); encryption != "" {
vmess["cipher"] = encryption
}
if packetEncoding := query.Get("packetEncoding"); packetEncoding != "" {
switch packetEncoding {
case "packet":
vmess["packet-addr"] = true
case "xudp":
vmess["xudp"] = true
}
}
proxies = append(proxies, vmess)
continue
}
@ -171,21 +244,18 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
vmess["server"] = values["add"]
vmess["port"] = values["port"]
vmess["uuid"] = values["id"]
if alterId, ok := values["aid"]; ok {
vmess["alterId"] = alterId
} else {
vmess["alterId"] = 0
}
vmess["alterId"] = values["aid"]
vmess["cipher"] = "auto"
vmess["udp"] = true
vmess["tls"] = false
vmess["skip-cert-verify"] = false
vmess["cipher"] = "auto"
if cipher, ok := values["scy"]; ok && cipher != "" {
vmess["cipher"] = cipher
if values["cipher"] != nil && values["cipher"] != "" {
vmess["cipher"] = values["cipher"]
}
if sni, ok := values["sni"]; ok && sni != "" {
sni := values["sni"]
if sni != "" {
vmess["servername"] = sni
}
@ -198,7 +268,7 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
vmess["network"] = network
tls := strings.ToLower(values["tls"].(string))
if strings.HasSuffix(tls, "tls") {
if strings.Contains(tls, "tls") {
vmess["tls"] = true
}
@ -206,12 +276,12 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
case "http":
headers := make(map[string]any)
httpOpts := make(map[string]any)
if host, ok := values["host"]; ok && host != "" {
headers["Host"] = []string{host.(string)}
if values["host"] != "" && values["host"] != nil {
headers["Host"] = []string{values["host"].(string)}
}
httpOpts["path"] = []string{"/"}
if path, ok := values["path"]; ok && path != "" {
httpOpts["path"] = []string{path.(string)}
if values["path"] != "" && values["path"] != nil {
httpOpts["path"] = []string{values["path"].(string)}
}
httpOpts["headers"] = headers
@ -220,8 +290,8 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
case "h2":
headers := make(map[string]any)
h2Opts := make(map[string]any)
if host, ok := values["host"]; ok && host != "" {
headers["Host"] = []string{host.(string)}
if values["host"] != "" && values["host"] != nil {
headers["Host"] = []string{values["host"].(string)}
}
h2Opts["path"] = values["path"]
@ -233,11 +303,11 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
headers := make(map[string]any)
wsOpts := make(map[string]any)
wsOpts["path"] = []string{"/"}
if host, ok := values["host"]; ok && host != "" {
headers["Host"] = host.(string)
if values["host"] != "" && values["host"] != nil {
headers["Host"] = values["host"].(string)
}
if path, ok := values["path"]; ok && path != "" {
wsOpts["path"] = path.(string)
if values["path"] != "" && values["path"] != nil {
wsOpts["path"] = values["path"].(string)
}
wsOpts["headers"] = headers
vmess["ws-opts"] = wsOpts
@ -287,7 +357,7 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
}
}
ss := make(map[string]any, 10)
ss := make(map[string]any, 20)
ss["name"] = name
ss["type"] = scheme
@ -297,9 +367,6 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
ss["password"] = password
query := urlSS.Query()
ss["udp"] = true
if query.Get("udp-over-tcp") == "true" || query.Get("uot") == "1" {
ss["udp-over-tcp"] = true
}
if strings.Contains(query.Get("plugin"), "obfs") {
obfsParams := strings.Split(query.Get("plugin"), ";")
ss["plugin"] = "obfs"
@ -377,6 +444,18 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
return proxies, nil
}
func urlSafe(data string) string {
return strings.ReplaceAll(strings.ReplaceAll(data, "+", "-"), "/", "_")
}
func decodeUrlSafe(data string) string {
dcBuf, err := base64.RawURLEncoding.DecodeString(data)
if err != nil {
return ""
}
return string(dcBuf)
}
func uniqueName(names map[string]int, name string) string {
if index, ok := names[name]; ok {
index++

View File

@ -1,89 +0,0 @@
package convert
import (
"net/url"
"strings"
)
func handleVShareLink(names map[string]int, url *url.URL, scheme string, proxy map[string]any) {
// Xray VMessAEAD / VLESS share link standard
// https://github.com/XTLS/Xray-core/discussions/716
query := url.Query()
proxy["name"] = uniqueName(names, url.Fragment)
proxy["type"] = scheme
proxy["server"] = url.Hostname()
proxy["port"] = url.Port()
proxy["uuid"] = url.User.Username()
proxy["udp"] = true
proxy["skip-cert-verify"] = false
proxy["tls"] = false
tls := strings.ToLower(query.Get("security"))
if strings.HasSuffix(tls, "tls") {
proxy["tls"] = true
}
if sni := query.Get("sni"); sni != "" {
proxy["servername"] = sni
}
network := strings.ToLower(query.Get("type"))
if network == "" {
network = "tcp"
}
fakeType := strings.ToLower(query.Get("headerType"))
if fakeType == "http" {
network = "http"
} else if network == "http" {
network = "h2"
}
proxy["network"] = network
switch network {
case "tcp":
if fakeType != "none" {
headers := make(map[string]any)
httpOpts := make(map[string]any)
httpOpts["path"] = []string{"/"}
if host := query.Get("host"); host != "" {
headers["Host"] = []string{host}
}
if method := query.Get("method"); method != "" {
httpOpts["method"] = method
}
if path := query.Get("path"); path != "" {
httpOpts["path"] = []string{path}
}
httpOpts["headers"] = headers
proxy["http-opts"] = httpOpts
}
case "http":
headers := make(map[string]any)
h2Opts := make(map[string]any)
h2Opts["path"] = []string{"/"}
if path := query.Get("path"); path != "" {
h2Opts["path"] = []string{path}
}
if host := query.Get("host"); host != "" {
h2Opts["host"] = []string{host}
}
h2Opts["headers"] = headers
proxy["h2-opts"] = h2Opts
case "ws":
headers := make(map[string]any)
wsOpts := make(map[string]any)
headers["User-Agent"] = RandUserAgent()
headers["Host"] = query.Get("host")
wsOpts["path"] = query.Get("path")
wsOpts["headers"] = headers
proxy["ws-opts"] = wsOpts
case "grpc":
grpcOpts := make(map[string]any)
grpcOpts["grpc-service-name"] = query.Get("serviceName")
proxy["grpc-opts"] = grpcOpts
}
}

View File

@ -4,6 +4,8 @@ import (
"io"
"net"
"time"
"github.com/Dreamacro/clash/common/pool"
)
// Relay copies between left and right bidirectionally.
@ -11,14 +13,18 @@ func Relay(leftConn, rightConn net.Conn) {
ch := make(chan error)
go func() {
buf := pool.Get(pool.RelayBufferSize)
// Wrapping to avoid using *net.TCPConn.(ReadFrom)
// See also https://github.com/Dreamacro/clash/pull/1209
_, err := io.Copy(WriteOnlyWriter{Writer: leftConn}, ReadOnlyReader{Reader: rightConn})
_, err := io.CopyBuffer(WriteOnlyWriter{Writer: leftConn}, ReadOnlyReader{Reader: rightConn}, buf)
pool.Put(buf)
leftConn.SetReadDeadline(time.Now())
ch <- err
}()
_, _ = io.Copy(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn})
buf := pool.Get(pool.RelayBufferSize)
io.CopyBuffer(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn}, buf)
pool.Put(buf)
rightConn.SetReadDeadline(time.Now())
<-ch
}

View File

@ -8,7 +8,7 @@ import (
"sync"
)
var defaultAllocator = NewAllocator()
var DefaultAllocator = NewAllocator()
// Allocator for incoming frames, optimized to prevent overwriting after zeroing
type Allocator struct {

View File

@ -13,9 +13,9 @@ const (
)
func Get(size int) []byte {
return defaultAllocator.Get(size)
return DefaultAllocator.Get(size)
}
func Put(buf []byte) error {
return defaultAllocator.Put(buf)
return DefaultAllocator.Put(buf)
}

View File

@ -1,7 +0,0 @@
package pool
import "github.com/sagernet/sing/common/buf"
func init() {
buf.DefaultAllocator = defaultAllocator
}

View File

@ -5,10 +5,8 @@ import (
"errors"
"fmt"
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
"net"
"net/netip"
"strings"
"sync"
)
@ -18,8 +16,6 @@ var (
actualDualStackDialContext = dualStackDialContext
tcpConcurrent = false
DisableIPv6 = false
ErrorInvalidedNetworkStack = errors.New("invalided network stack")
ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel")
)
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
@ -36,23 +32,13 @@ func DialContext(ctx context.Context, network, address string, options ...Option
o(opt)
}
if opt.network == 4 || opt.network == 6 {
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
network = fmt.Sprintf("%s%d", network, opt.network)
}
switch network {
case "tcp4", "tcp6", "udp4", "udp6":
return actualSingleDialContext(ctx, network, address, opt)
case "tcp", "udp":
return actualDualStackDialContext(ctx, network, address, opt)
default:
return nil, ErrorInvalidedNetworkStack
return nil, errors.New("network invalid")
}
}
@ -70,6 +56,10 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio
o(cfg)
}
if DisableIPv6 {
network = "udp4"
}
lc := &net.ListenConfig{}
if cfg.interfaceName != "" {
addr, err := bindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
@ -118,7 +108,7 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
}
if DisableIPv6 && destination.Is6() {
return nil, ErrorDisableIPv6
return nil, fmt.Errorf("IPv6 is diabled, dialer cancel")
}
return dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port))
@ -208,7 +198,7 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
}
}
return nil, errors.New("dual stack tcp shake hands failed")
return nil, errors.New("never touched")
}
func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
@ -218,16 +208,13 @@ func concurrentDualStackDialContext(ctx context.Context, network, address string
}
var ips []netip.Addr
if opt.direct {
ips, err = resolver.ResolveAllIP(host)
} else {
ips, err = resolver.ResolveAllIPProxyServerHost(host)
}
if err != nil {
return nil, err
}
return concurrentDialContext(ctx, network, ips, port, opt)
}
@ -239,49 +226,29 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr
ip netip.Addr
net.Conn
error
isPrimary bool
done bool
resolved bool
}
preferCount := atomic.NewInt32(0)
results := make(chan dialResult)
tcpRacer := func(ctx context.Context, ip netip.Addr) {
result := dialResult{ip: ip, done: true}
result := dialResult{ip: ip}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
_ = result.Conn.Close()
result.Conn.Close()
}
}
}()
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
v := "4"
if ip.Is6() {
network += "6"
if opt.prefer != 4 {
result.isPrimary = true
}
v = "6"
}
if ip.Is4() {
network += "4"
if opt.prefer != 6 {
result.isPrimary = true
}
}
if result.isPrimary {
preferCount.Add(1)
}
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
result.Conn, result.error = dialContext(ctx, network+v, ip, port, opt)
}
for _, ip := range ips {
@ -289,48 +256,17 @@ func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr
}
connCount := len(ips)
var fallback dialResult
var primaryError error
for i := 0; i < connCount; i++ {
select {
case res := <-results:
if res.error == nil {
if res.isPrimary {
return res.Conn, nil
} else {
if !fallback.done || fallback.error != nil {
fallback = res
}
}
} else {
if res.isPrimary {
primaryError = res.error
preferCount.Add(-1)
if preferCount.Load() == 0 && fallback.done && fallback.error == nil {
return fallback.Conn, nil
}
}
return res.Conn, nil
}
case <-ctx.Done():
if fallback.done && fallback.error == nil {
return fallback.Conn, nil
}
break
}
}
if fallback.done && fallback.error == nil {
return fallback.Conn, nil
}
if primaryError != nil {
return nil, primaryError
}
if fallback.error != nil {
return nil, fallback.error
}
return nil, fmt.Errorf("all ips %v tcp shake hands failed", ips)
}
@ -363,45 +299,25 @@ func singleDialContext(ctx context.Context, network string, address string, opt
}
func concurrentSingleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
switch network {
case "tcp4", "udp4":
return concurrentIPv4DialContext(ctx, network, address, opt)
if !opt.direct {
ips, err = resolver.ResolveAllIPv4ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv4(host)
}
default:
return concurrentIPv6DialContext(ctx, network, address, opt)
}
}
func concurrentIPv4DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
if !opt.direct {
ips, err = resolver.ResolveAllIPv4ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv4(host)
}
if err != nil {
return nil, err
}
return concurrentDialContext(ctx, network, ips, port, opt)
}
func concurrentIPv6DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
if !opt.direct {
ips, err = resolver.ResolveAllIPv6ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv6(host)
if !opt.direct {
ips, err = resolver.ResolveAllIPv6ProxyServerHost(host)
} else {
ips, err = resolver.ResolveAllIPv6(host)
}
}
if err != nil {

View File

@ -1,8 +1,6 @@
package dialer
import (
"go.uber.org/atomic"
)
import "go.uber.org/atomic"
var (
DefaultOptions []Option
@ -15,8 +13,6 @@ type option struct {
addrReuse bool
routingMark int
direct bool
network int
prefer int
}
type Option func(opt *option)
@ -44,25 +40,3 @@ func WithDirect() Option {
opt.direct = true
}
}
func WithPreferIPv4() Option {
return func(opt *option) {
opt.prefer = 4
}
}
func WithPreferIPv6() Option {
return func(opt *option) {
opt.prefer = 6
}
}
func WithOnlySingleStack(isIPv4 bool) Option {
return func(opt *option) {
if isIPv4 {
opt.network = 4
} else {
opt.network = 6
}
}
}

View File

@ -1,99 +0,0 @@
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
#ifndef __BPF_ENDIAN__
#define __BPF_ENDIAN__
/*
* Isolate byte #n and put it into byte #m, for __u##b type.
* E.g., moving byte #6 (nnnnnnnn) into byte #1 (mmmmmmmm) for __u64:
* 1) xxxxxxxx nnnnnnnn xxxxxxxx xxxxxxxx xxxxxxxx xxxxxxxx mmmmmmmm xxxxxxxx
* 2) nnnnnnnn xxxxxxxx xxxxxxxx xxxxxxxx xxxxxxxx mmmmmmmm xxxxxxxx 00000000
* 3) 00000000 00000000 00000000 00000000 00000000 00000000 00000000 nnnnnnnn
* 4) 00000000 00000000 00000000 00000000 00000000 00000000 nnnnnnnn 00000000
*/
#define ___bpf_mvb(x, b, n, m) ((__u##b)(x) << (b-(n+1)*8) >> (b-8) << (m*8))
#define ___bpf_swab16(x) ((__u16)( \
___bpf_mvb(x, 16, 0, 1) | \
___bpf_mvb(x, 16, 1, 0)))
#define ___bpf_swab32(x) ((__u32)( \
___bpf_mvb(x, 32, 0, 3) | \
___bpf_mvb(x, 32, 1, 2) | \
___bpf_mvb(x, 32, 2, 1) | \
___bpf_mvb(x, 32, 3, 0)))
#define ___bpf_swab64(x) ((__u64)( \
___bpf_mvb(x, 64, 0, 7) | \
___bpf_mvb(x, 64, 1, 6) | \
___bpf_mvb(x, 64, 2, 5) | \
___bpf_mvb(x, 64, 3, 4) | \
___bpf_mvb(x, 64, 4, 3) | \
___bpf_mvb(x, 64, 5, 2) | \
___bpf_mvb(x, 64, 6, 1) | \
___bpf_mvb(x, 64, 7, 0)))
/* LLVM's BPF target selects the endianness of the CPU
* it compiles on, or the user specifies (bpfel/bpfeb),
* respectively. The used __BYTE_ORDER__ is defined by
* the compiler, we cannot rely on __BYTE_ORDER from
* libc headers, since it doesn't reflect the actual
* requested byte order.
*
* Note, LLVM's BPF target has different __builtin_bswapX()
* semantics. It does map to BPF_ALU | BPF_END | BPF_TO_BE
* in bpfel and bpfeb case, which means below, that we map
* to cpu_to_be16(). We could use it unconditionally in BPF
* case, but better not rely on it, so that this header here
* can be used from application and BPF program side, which
* use different targets.
*/
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
# define __bpf_ntohs(x) __builtin_bswap16(x)
# define __bpf_htons(x) __builtin_bswap16(x)
# define __bpf_constant_ntohs(x) ___bpf_swab16(x)
# define __bpf_constant_htons(x) ___bpf_swab16(x)
# define __bpf_ntohl(x) __builtin_bswap32(x)
# define __bpf_htonl(x) __builtin_bswap32(x)
# define __bpf_constant_ntohl(x) ___bpf_swab32(x)
# define __bpf_constant_htonl(x) ___bpf_swab32(x)
# define __bpf_be64_to_cpu(x) __builtin_bswap64(x)
# define __bpf_cpu_to_be64(x) __builtin_bswap64(x)
# define __bpf_constant_be64_to_cpu(x) ___bpf_swab64(x)
# define __bpf_constant_cpu_to_be64(x) ___bpf_swab64(x)
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
# define __bpf_ntohs(x) (x)
# define __bpf_htons(x) (x)
# define __bpf_constant_ntohs(x) (x)
# define __bpf_constant_htons(x) (x)
# define __bpf_ntohl(x) (x)
# define __bpf_htonl(x) (x)
# define __bpf_constant_ntohl(x) (x)
# define __bpf_constant_htonl(x) (x)
# define __bpf_be64_to_cpu(x) (x)
# define __bpf_cpu_to_be64(x) (x)
# define __bpf_constant_be64_to_cpu(x) (x)
# define __bpf_constant_cpu_to_be64(x) (x)
#else
# error "Fix your compiler's __BYTE_ORDER__?!"
#endif
#define bpf_htons(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_htons(x) : __bpf_htons(x))
#define bpf_ntohs(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_ntohs(x) : __bpf_ntohs(x))
#define bpf_htonl(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_htonl(x) : __bpf_htonl(x))
#define bpf_ntohl(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_ntohl(x) : __bpf_ntohl(x))
#define bpf_cpu_to_be64(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_cpu_to_be64(x) : __bpf_cpu_to_be64(x))
#define bpf_be64_to_cpu(x) \
(__builtin_constant_p(x) ? \
__bpf_constant_be64_to_cpu(x) : __bpf_be64_to_cpu(x))
#endif /* __BPF_ENDIAN__ */

File diff suppressed because it is too large Load Diff

View File

@ -1,262 +0,0 @@
/* SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) */
#ifndef __BPF_HELPERS__
#define __BPF_HELPERS__
/*
* Note that bpf programs need to include either
* vmlinux.h (auto-generated from BTF) or linux/types.h
* in advance since bpf_helper_defs.h uses such types
* as __u64.
*/
#include "bpf_helper_defs.h"
#define __uint(name, val) int (*name)[val]
#define __type(name, val) typeof(val) *name
#define __array(name, val) typeof(val) *name[]
/*
* Helper macro to place programs, maps, license in
* different sections in elf_bpf file. Section names
* are interpreted by libbpf depending on the context (BPF programs, BPF maps,
* extern variables, etc).
* To allow use of SEC() with externs (e.g., for extern .maps declarations),
* make sure __attribute__((unused)) doesn't trigger compilation warning.
*/
#define SEC(name) \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wignored-attributes\"") \
__attribute__((section(name), used)) \
_Pragma("GCC diagnostic pop") \
/* Avoid 'linux/stddef.h' definition of '__always_inline'. */
#undef __always_inline
#define __always_inline inline __attribute__((always_inline))
#ifndef __noinline
#define __noinline __attribute__((noinline))
#endif
#ifndef __weak
#define __weak __attribute__((weak))
#endif
/*
* Use __hidden attribute to mark a non-static BPF subprogram effectively
* static for BPF verifier's verification algorithm purposes, allowing more
* extensive and permissive BPF verification process, taking into account
* subprogram's caller context.
*/
#define __hidden __attribute__((visibility("hidden")))
/* When utilizing vmlinux.h with BPF CO-RE, user BPF programs can't include
* any system-level headers (such as stddef.h, linux/version.h, etc), and
* commonly-used macros like NULL and KERNEL_VERSION aren't available through
* vmlinux.h. This just adds unnecessary hurdles and forces users to re-define
* them on their own. So as a convenience, provide such definitions here.
*/
#ifndef NULL
#define NULL ((void *)0)
#endif
#ifndef KERNEL_VERSION
#define KERNEL_VERSION(a, b, c) (((a) << 16) + ((b) << 8) + ((c) > 255 ? 255 : (c)))
#endif
/*
* Helper macros to manipulate data structures
*/
#ifndef offsetof
#define offsetof(TYPE, MEMBER) ((unsigned long)&((TYPE *)0)->MEMBER)
#endif
#ifndef container_of
#define container_of(ptr, type, member) \
({ \
void *__mptr = (void *)(ptr); \
((type *)(__mptr - offsetof(type, member))); \
})
#endif
/*
* Helper macro to throw a compilation error if __bpf_unreachable() gets
* built into the resulting code. This works given BPF back end does not
* implement __builtin_trap(). This is useful to assert that certain paths
* of the program code are never used and hence eliminated by the compiler.
*
* For example, consider a switch statement that covers known cases used by
* the program. __bpf_unreachable() can then reside in the default case. If
* the program gets extended such that a case is not covered in the switch
* statement, then it will throw a build error due to the default case not
* being compiled out.
*/
#ifndef __bpf_unreachable
# define __bpf_unreachable() __builtin_trap()
#endif
/*
* Helper function to perform a tail call with a constant/immediate map slot.
*/
#if __clang_major__ >= 8 && defined(__bpf__)
static __always_inline void
bpf_tail_call_static(void *ctx, const void *map, const __u32 slot)
{
if (!__builtin_constant_p(slot))
__bpf_unreachable();
/*
* Provide a hard guarantee that LLVM won't optimize setting r2 (map
* pointer) and r3 (constant map index) from _different paths_ ending
* up at the _same_ call insn as otherwise we won't be able to use the
* jmpq/nopl retpoline-free patching by the x86-64 JIT in the kernel
* given they mismatch. See also d2e4c1e6c294 ("bpf: Constant map key
* tracking for prog array pokes") for details on verifier tracking.
*
* Note on clobber list: we need to stay in-line with BPF calling
* convention, so even if we don't end up using r0, r4, r5, we need
* to mark them as clobber so that LLVM doesn't end up using them
* before / after the call.
*/
asm volatile("r1 = %[ctx]\n\t"
"r2 = %[map]\n\t"
"r3 = %[slot]\n\t"
"call 12"
:: [ctx]"r"(ctx), [map]"r"(map), [slot]"i"(slot)
: "r0", "r1", "r2", "r3", "r4", "r5");
}
#endif
/*
* Helper structure used by eBPF C program
* to describe BPF map attributes to libbpf loader
*/
struct bpf_map_def {
unsigned int type;
unsigned int key_size;
unsigned int value_size;
unsigned int max_entries;
unsigned int map_flags;
};
enum libbpf_pin_type {
LIBBPF_PIN_NONE,
/* PIN_BY_NAME: pin maps by name (in /sys/fs/bpf by default) */
LIBBPF_PIN_BY_NAME,
};
enum libbpf_tristate {
TRI_NO = 0,
TRI_YES = 1,
TRI_MODULE = 2,
};
#define __kconfig __attribute__((section(".kconfig")))
#define __ksym __attribute__((section(".ksyms")))
#ifndef ___bpf_concat
#define ___bpf_concat(a, b) a ## b
#endif
#ifndef ___bpf_apply
#define ___bpf_apply(fn, n) ___bpf_concat(fn, n)
#endif
#ifndef ___bpf_nth
#define ___bpf_nth(_, _1, _2, _3, _4, _5, _6, _7, _8, _9, _a, _b, _c, N, ...) N
#endif
#ifndef ___bpf_narg
#define ___bpf_narg(...) \
___bpf_nth(_, ##__VA_ARGS__, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
#endif
#define ___bpf_fill0(arr, p, x) do {} while (0)
#define ___bpf_fill1(arr, p, x) arr[p] = x
#define ___bpf_fill2(arr, p, x, args...) arr[p] = x; ___bpf_fill1(arr, p + 1, args)
#define ___bpf_fill3(arr, p, x, args...) arr[p] = x; ___bpf_fill2(arr, p + 1, args)
#define ___bpf_fill4(arr, p, x, args...) arr[p] = x; ___bpf_fill3(arr, p + 1, args)
#define ___bpf_fill5(arr, p, x, args...) arr[p] = x; ___bpf_fill4(arr, p + 1, args)
#define ___bpf_fill6(arr, p, x, args...) arr[p] = x; ___bpf_fill5(arr, p + 1, args)
#define ___bpf_fill7(arr, p, x, args...) arr[p] = x; ___bpf_fill6(arr, p + 1, args)
#define ___bpf_fill8(arr, p, x, args...) arr[p] = x; ___bpf_fill7(arr, p + 1, args)
#define ___bpf_fill9(arr, p, x, args...) arr[p] = x; ___bpf_fill8(arr, p + 1, args)
#define ___bpf_fill10(arr, p, x, args...) arr[p] = x; ___bpf_fill9(arr, p + 1, args)
#define ___bpf_fill11(arr, p, x, args...) arr[p] = x; ___bpf_fill10(arr, p + 1, args)
#define ___bpf_fill12(arr, p, x, args...) arr[p] = x; ___bpf_fill11(arr, p + 1, args)
#define ___bpf_fill(arr, args...) \
___bpf_apply(___bpf_fill, ___bpf_narg(args))(arr, 0, args)
/*
* BPF_SEQ_PRINTF to wrap bpf_seq_printf to-be-printed values
* in a structure.
*/
#define BPF_SEQ_PRINTF(seq, fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_seq_printf(seq, ___fmt, sizeof(___fmt), \
___param, sizeof(___param)); \
})
/*
* BPF_SNPRINTF wraps the bpf_snprintf helper with variadic arguments instead of
* an array of u64.
*/
#define BPF_SNPRINTF(out, out_size, fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_snprintf(out, out_size, ___fmt, \
___param, sizeof(___param)); \
})
#ifdef BPF_NO_GLOBAL_DATA
#define BPF_PRINTK_FMT_MOD
#else
#define BPF_PRINTK_FMT_MOD static const
#endif
#define __bpf_printk(fmt, ...) \
({ \
BPF_PRINTK_FMT_MOD char ____fmt[] = fmt; \
bpf_trace_printk(____fmt, sizeof(____fmt), \
##__VA_ARGS__); \
})
/*
* __bpf_vprintk wraps the bpf_trace_vprintk helper with variadic arguments
* instead of an array of u64.
*/
#define __bpf_vprintk(fmt, args...) \
({ \
static const char ___fmt[] = fmt; \
unsigned long long ___param[___bpf_narg(args)]; \
\
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wint-conversion\"") \
___bpf_fill(___param, args); \
_Pragma("GCC diagnostic pop") \
\
bpf_trace_vprintk(___fmt, sizeof(___fmt), \
___param, sizeof(___param)); \
})
/* Use __bpf_printk when bpf_printk call has 3 or fewer fmt args
* Otherwise use __bpf_vprintk
*/
#define ___bpf_pick_printk(...) \
___bpf_nth(_, ##__VA_ARGS__, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
__bpf_vprintk, __bpf_vprintk, __bpf_vprintk, __bpf_vprintk, \
__bpf_vprintk, __bpf_vprintk, __bpf_printk /*3*/, __bpf_printk /*2*/,\
__bpf_printk /*1*/, __bpf_printk /*0*/)
/* Helper macro to print out debug messages */
#define bpf_printk(fmt, args...) ___bpf_pick_printk(args)(fmt, ##args)
#endif

View File

@ -1,342 +0,0 @@
#include <stdint.h>
#include <stdbool.h>
//#include <linux/types.h>
#include <linux/bpf.h>
#include <linux/if_ether.h>
//#include <linux/if_packet.h>
//#include <linux/if_vlan.h>
#include <linux/ip.h>
#include <linux/in.h>
#include <linux/tcp.h>
//#include <linux/udp.h>
#include <linux/pkt_cls.h>
#include "bpf_endian.h"
#include "bpf_helpers.h"
#define IP_CSUM_OFF (ETH_HLEN + offsetof(struct iphdr, check))
#define IP_DST_OFF (ETH_HLEN + offsetof(struct iphdr, daddr))
#define IP_SRC_OFF (ETH_HLEN + offsetof(struct iphdr, saddr))
#define IP_PROTO_OFF (ETH_HLEN + offsetof(struct iphdr, protocol))
#define TCP_CSUM_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct tcphdr, check))
#define TCP_SRC_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct tcphdr, source))
#define TCP_DST_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct tcphdr, dest))
//#define UDP_CSUM_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct udphdr, check))
//#define UDP_SRC_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct udphdr, source))
//#define UDP_DST_OFF (ETH_HLEN + sizeof(struct iphdr) + offsetof(struct udphdr, dest))
#define IS_PSEUDO 0x10
struct origin_info {
__be32 ip;
__be16 port;
__u16 pad;
};
struct origin_info *origin_info_unused __attribute__((unused));
struct redir_info {
__be32 sip;
__be32 dip;
__be16 sport;
__be16 dport;
};
struct redir_info *redir_info_unused __attribute__((unused));
struct {
__uint(type, BPF_MAP_TYPE_LRU_HASH);
__type(key, struct redir_info);
__type(value, struct origin_info);
__uint(max_entries, 65535);
__uint(pinning, LIBBPF_PIN_BY_NAME);
} pair_original_dst_map SEC(".maps");
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, __u32);
__uint(max_entries, 3);
__uint(pinning, LIBBPF_PIN_BY_NAME);
} redir_params_map SEC(".maps");
static __always_inline int rewrite_ip(struct __sk_buff *skb, __be32 new_ip, bool is_dest) {
int ret, off = 0, flags = IS_PSEUDO;
__be32 old_ip;
if (is_dest)
ret = bpf_skb_load_bytes(skb, IP_DST_OFF, &old_ip, 4);
else
ret = bpf_skb_load_bytes(skb, IP_SRC_OFF, &old_ip, 4);
if (ret < 0) {
return ret;
}
off = TCP_CSUM_OFF;
// __u8 proto;
//
// ret = bpf_skb_load_bytes(skb, IP_PROTO_OFF, &proto, 1);
// if (ret < 0) {
// return BPF_DROP;
// }
//
// switch (proto) {
// case IPPROTO_TCP:
// off = TCP_CSUM_OFF;
// break;
//
// case IPPROTO_UDP:
// off = UDP_CSUM_OFF;
// flags |= BPF_F_MARK_MANGLED_0;
// break;
//
// case IPPROTO_ICMPV6:
// off = offsetof(struct icmp6hdr, icmp6_cksum);
// break;
// }
//
// if (off) {
ret = bpf_l4_csum_replace(skb, off, old_ip, new_ip, flags | sizeof(new_ip));
if (ret < 0) {
return ret;
}
// }
ret = bpf_l3_csum_replace(skb, IP_CSUM_OFF, old_ip, new_ip, sizeof(new_ip));
if (ret < 0) {
return ret;
}
if (is_dest)
ret = bpf_skb_store_bytes(skb, IP_DST_OFF, &new_ip, sizeof(new_ip), 0);
else
ret = bpf_skb_store_bytes(skb, IP_SRC_OFF, &new_ip, sizeof(new_ip), 0);
if (ret < 0) {
return ret;
}
return 1;
}
static __always_inline int rewrite_port(struct __sk_buff *skb, __be16 new_port, bool is_dest) {
int ret, off = 0;
__be16 old_port;
if (is_dest)
ret = bpf_skb_load_bytes(skb, TCP_DST_OFF, &old_port, 2);
else
ret = bpf_skb_load_bytes(skb, TCP_SRC_OFF, &old_port, 2);
if (ret < 0) {
return ret;
}
off = TCP_CSUM_OFF;
ret = bpf_l4_csum_replace(skb, off, old_port, new_port, sizeof(new_port));
if (ret < 0) {
return ret;
}
if (is_dest)
ret = bpf_skb_store_bytes(skb, TCP_DST_OFF, &new_port, sizeof(new_port), 0);
else
ret = bpf_skb_store_bytes(skb, TCP_SRC_OFF, &new_port, sizeof(new_port), 0);
if (ret < 0) {
return ret;
}
return 1;
}
static __always_inline bool is_lan_ip(__be32 addr) {
if (addr == 0xffffffff)
return true;
__u8 fist = (__u8)(addr & 0xff);
if (fist == 127 || fist == 10)
return true;
__u8 second = (__u8)((addr >> 8) & 0xff);
if (fist == 172 && second >= 16 && second <= 31)
return true;
if (fist == 192 && second == 168)
return true;
return false;
}
SEC("tc_clash_auto_redir_ingress")
int tc_redir_ingress_func(struct __sk_buff *skb) {
void *data = (void *)(long)skb->data;
void *data_end = (void *)(long)skb->data_end;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return TC_ACT_OK;
if (eth->h_proto != bpf_htons(ETH_P_IP))
return TC_ACT_OK;
struct iphdr *iph = (struct iphdr *)(eth + 1);
if ((void *)(iph + 1) > data_end)
return TC_ACT_OK;
__u32 key = 0, *route_index, *redir_ip, *redir_port;
route_index = bpf_map_lookup_elem(&redir_params_map, &key);
if (!route_index)
return TC_ACT_OK;
if (iph->protocol == IPPROTO_ICMP && *route_index != 0)
return bpf_redirect(*route_index, 0);
if (iph->protocol != IPPROTO_TCP)
return TC_ACT_OK;
struct tcphdr *tcph = (struct tcphdr *)(iph + 1);
if ((void *)(tcph + 1) > data_end)
return TC_ACT_SHOT;
key = 1;
redir_ip = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_ip)
return TC_ACT_OK;
key = 2;
redir_port = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_port)
return TC_ACT_OK;
__be32 new_ip = bpf_htonl(*redir_ip);
__be16 new_port = bpf_htonl(*redir_port) >> 16;
__be32 old_ip = iph->daddr;
__be16 old_port = tcph->dest;
if (old_ip == new_ip || is_lan_ip(old_ip) || bpf_ntohs(old_port) == 53) {
return TC_ACT_OK;
}
struct redir_info p_key = {
.sip = iph->saddr,
.sport = tcph->source,
.dip = new_ip,
.dport = new_port,
};
if (tcph->syn && !tcph->ack) {
struct origin_info origin = {
.ip = old_ip,
.port = old_port,
};
bpf_map_update_elem(&pair_original_dst_map, &p_key, &origin, BPF_NOEXIST);
if (rewrite_ip(skb, new_ip, true) < 0) {
return TC_ACT_SHOT;
}
if (rewrite_port(skb, new_port, true) < 0) {
return TC_ACT_SHOT;
}
} else {
struct origin_info *origin = bpf_map_lookup_elem(&pair_original_dst_map, &p_key);
if (!origin) {
return TC_ACT_OK;
}
if (rewrite_ip(skb, new_ip, true) < 0) {
return TC_ACT_SHOT;
}
if (rewrite_port(skb, new_port, true) < 0) {
return TC_ACT_SHOT;
}
}
return TC_ACT_OK;
}
SEC("tc_clash_auto_redir_egress")
int tc_redir_egress_func(struct __sk_buff *skb) {
void *data = (void *)(long)skb->data;
void *data_end = (void *)(long)skb->data_end;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return TC_ACT_OK;
if (eth->h_proto != bpf_htons(ETH_P_IP))
return TC_ACT_OK;
__u32 key = 0, *redir_ip, *redir_port; // *clash_mark
// clash_mark = bpf_map_lookup_elem(&redir_params_map, &key);
// if (clash_mark && *clash_mark != 0 && *clash_mark == skb->mark)
// return TC_ACT_OK;
struct iphdr *iph = (struct iphdr *)(eth + 1);
if ((void *)(iph + 1) > data_end)
return TC_ACT_OK;
if (iph->protocol != IPPROTO_TCP)
return TC_ACT_OK;
struct tcphdr *tcph = (struct tcphdr *)(iph + 1);
if ((void *)(tcph + 1) > data_end)
return TC_ACT_SHOT;
key = 1;
redir_ip = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_ip)
return TC_ACT_OK;
key = 2;
redir_port = bpf_map_lookup_elem(&redir_params_map, &key);
if (!redir_port)
return TC_ACT_OK;
__be32 new_ip = bpf_htonl(*redir_ip);
__be16 new_port = bpf_htonl(*redir_port) >> 16;
__be32 old_ip = iph->saddr;
__be16 old_port = tcph->source;
if (old_ip != new_ip || old_port != new_port) {
return TC_ACT_OK;
}
struct redir_info p_key = {
.sip = iph->daddr,
.sport = tcph->dest,
.dip = iph->saddr,
.dport = tcph->source,
};
struct origin_info *origin = bpf_map_lookup_elem(&pair_original_dst_map, &p_key);
if (!origin) {
return TC_ACT_OK;
}
if (tcph->fin && tcph->ack) {
bpf_map_delete_elem(&pair_original_dst_map, &p_key);
}
if (rewrite_ip(skb, origin->ip, false) < 0) {
return TC_ACT_SHOT;
}
if (rewrite_port(skb, origin->port, false) < 0) {
return TC_ACT_SHOT;
}
return TC_ACT_OK;
}
char _license[] SEC("license") = "GPL";

View File

@ -1,103 +0,0 @@
#include <stdbool.h>
#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/in.h>
//#include <linux/tcp.h>
//#include <linux/udp.h>
#include <linux/pkt_cls.h>
#include "bpf_endian.h"
#include "bpf_helpers.h"
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, __u32);
__uint(max_entries, 2);
__uint(pinning, LIBBPF_PIN_BY_NAME);
} tc_params_map SEC(".maps");
static __always_inline bool is_lan_ip(__be32 addr) {
if (addr == 0xffffffff)
return true;
__u8 fist = (__u8)(addr & 0xff);
if (fist == 127 || fist == 10)
return true;
__u8 second = (__u8)((addr >> 8) & 0xff);
if (fist == 172 && second >= 16 && second <= 31)
return true;
if (fist == 192 && second == 168)
return true;
return false;
}
SEC("tc_clash_redirect_to_tun")
int tc_tun_func(struct __sk_buff *skb) {
void *data = (void *)(long)skb->data;
void *data_end = (void *)(long)skb->data_end;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return TC_ACT_OK;
if (eth->h_proto == bpf_htons(ETH_P_ARP))
return TC_ACT_OK;
__u32 key = 0, *clash_mark, *tun_ifindex;
clash_mark = bpf_map_lookup_elem(&tc_params_map, &key);
if (!clash_mark)
return TC_ACT_OK;
if (skb->mark == *clash_mark)
return TC_ACT_OK;
if (eth->h_proto == bpf_htons(ETH_P_IP)) {
struct iphdr *iph = (struct iphdr *)(eth + 1);
if ((void *)(iph + 1) > data_end)
return TC_ACT_OK;
if (iph->protocol == IPPROTO_ICMP)
return TC_ACT_OK;
__be32 daddr = iph->daddr;
if (is_lan_ip(daddr))
return TC_ACT_OK;
// if (iph->protocol == IPPROTO_TCP) {
// struct tcphdr *tcph = (struct tcphdr *)(iph + 1);
// if ((void *)(tcph + 1) > data_end)
// return TC_ACT_OK;
//
// __u16 source = bpf_ntohs(tcph->source);
// if (source == 22 || source == 80 || source == 443 || source == 8080 || source == 8443 || source == 9090 || (source >= 7890 && source <= 7895))
// return TC_ACT_OK;
// } else if (iph->protocol == IPPROTO_UDP) {
// struct udphdr *udph = (struct udphdr *)(iph + 1);
// if ((void *)(udph + 1) > data_end)
// return TC_ACT_OK;
//
// __u16 source = bpf_ntohs(udph->source);
// if (source == 53 || (source >= 135 && source <= 139))
// return TC_ACT_OK;
// }
}
key = 1;
tun_ifindex = bpf_map_lookup_elem(&tc_params_map, &key);
if (!tun_ifindex)
return TC_ACT_OK;
//return bpf_redirect(*tun_ifindex, BPF_F_INGRESS); // __bpf_rx_skb
return bpf_redirect(*tun_ifindex, 0); // __bpf_tx_skb / __dev_xmit_skb
}
char _license[] SEC("license") = "GPL";

View File

@ -1,13 +0,0 @@
package byteorder
import (
"net"
)
// NetIPv4ToHost32 converts an net.IP to a uint32 in host byte order. ip
// must be a IPv4 address, otherwise the function will panic.
func NetIPv4ToHost32(ip net.IP) uint32 {
ipv4 := ip.To4()
_ = ipv4[3] // Assert length of ipv4.
return Native.Uint32(ipv4)
}

View File

@ -1,12 +0,0 @@
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
package byteorder
import "encoding/binary"
var Native binary.ByteOrder = binary.BigEndian
func HostToNetwork16(u uint16) uint16 { return u }
func HostToNetwork32(u uint32) uint32 { return u }
func NetworkToHost16(u uint16) uint16 { return u }
func NetworkToHost32(u uint32) uint32 { return u }

View File

@ -1,15 +0,0 @@
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
package byteorder
import (
"encoding/binary"
"math/bits"
)
var Native binary.ByteOrder = binary.LittleEndian
func HostToNetwork16(u uint16) uint16 { return bits.ReverseBytes16(u) }
func HostToNetwork32(u uint32) uint32 { return bits.ReverseBytes32(u) }
func NetworkToHost16(u uint16) uint16 { return bits.ReverseBytes16(u) }
func NetworkToHost32(u uint32) uint32 { return bits.ReverseBytes32(u) }

View File

@ -1,33 +0,0 @@
package ebpf
import (
"net/netip"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)
type TcEBpfProgram struct {
pros []C.EBpf
rawNICs []string
}
func (t *TcEBpfProgram) RawNICs() []string {
return t.rawNICs
}
func (t *TcEBpfProgram) Close() {
for _, p := range t.pros {
p.Close()
}
}
func (t *TcEBpfProgram) Lookup(srcAddrPort netip.AddrPort) (addr socks5.Addr, err error) {
for _, p := range t.pros {
addr, err = p.Lookup(srcAddrPort)
if err == nil {
return
}
}
return
}

View File

@ -1,137 +0,0 @@
//go:build !android
package ebpf
import (
"fmt"
"net/netip"
"github.com/Dreamacro/clash/common/cmd"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/ebpf/redir"
"github.com/Dreamacro/clash/component/ebpf/tc"
C "github.com/Dreamacro/clash/constant"
"github.com/sagernet/netlink"
)
func GetAutoDetectInterface() (string, error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
if err != nil {
return "", err
}
for _, route := range routes {
if route.Dst == nil {
lk, err := netlink.LinkByIndex(route.LinkIndex)
if err != nil {
return "", err
}
if lk.Type() == "tuntap" {
continue
}
return lk.Attrs().Name, nil
}
}
return "", fmt.Errorf("interface not found")
}
// NewTcEBpfProgram new redirect to tun ebpf program
func NewTcEBpfProgram(ifaceNames []string, tunName string) (*TcEBpfProgram, error) {
tunIface, err := netlink.LinkByName(tunName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", tunName, err)
}
tunIndex := uint32(tunIface.Attrs().Index)
dialer.DefaultRoutingMark.Store(C.ClashTrafficMark)
ifMark := uint32(dialer.DefaultRoutingMark.Load())
var pros []C.EBpf
for _, ifaceName := range ifaceNames {
iface, err := netlink.LinkByName(ifaceName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", ifaceName, err)
}
if iface.Attrs().OperState != netlink.OperUp {
return nil, fmt.Errorf("network iface %q is down", ifaceName)
}
attrs := iface.Attrs()
index := attrs.Index
tcPro := tc.NewEBpfTc(ifaceName, index, ifMark, tunIndex)
if err = tcPro.Start(); err != nil {
return nil, err
}
pros = append(pros, tcPro)
}
systemSetting(ifaceNames...)
return &TcEBpfProgram{pros: pros, rawNICs: ifaceNames}, nil
}
// NewRedirEBpfProgram new auto redirect ebpf program
func NewRedirEBpfProgram(ifaceNames []string, redirPort uint16, defaultRouteInterfaceName string) (*TcEBpfProgram, error) {
defaultRouteInterface, err := netlink.LinkByName(defaultRouteInterfaceName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", defaultRouteInterfaceName, err)
}
defaultRouteIndex := uint32(defaultRouteInterface.Attrs().Index)
var pros []C.EBpf
for _, ifaceName := range ifaceNames {
iface, err := netlink.LinkByName(ifaceName)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q: %w", ifaceName, err)
}
attrs := iface.Attrs()
index := attrs.Index
addrs, err := netlink.AddrList(iface, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("lookup network iface %q address: %w", ifaceName, err)
}
if len(addrs) == 0 {
return nil, fmt.Errorf("network iface %q does not contain any ipv4 addresses", ifaceName)
}
address, _ := netip.AddrFromSlice(addrs[0].IP)
redirAddrPort := netip.AddrPortFrom(address, redirPort)
redirPro := redir.NewEBpfRedirect(ifaceName, index, 0, defaultRouteIndex, redirAddrPort)
if err = redirPro.Start(); err != nil {
return nil, err
}
pros = append(pros, redirPro)
}
systemSetting(ifaceNames...)
return &TcEBpfProgram{pros: pros, rawNICs: ifaceNames}, nil
}
func systemSetting(ifaceNames ...string) {
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.ip_forward=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.forwarding=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.accept_local=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.accept_redirects=1")
_, _ = cmd.ExecCmd("sysctl -w net.ipv4.conf.all.rp_filter=0")
for _, ifaceName := range ifaceNames {
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.forwarding=1", ifaceName))
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.accept_local=1", ifaceName))
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.accept_redirects=1", ifaceName))
_, _ = cmd.ExecCmd(fmt.Sprintf("sysctl -w net.ipv4.conf.%s.rp_filter=0", ifaceName))
}
}

View File

@ -1,21 +0,0 @@
//go:build !linux || android
package ebpf
import (
"fmt"
)
// NewTcEBpfProgram new ebpf tc program
func NewTcEBpfProgram(_ []string, _ string) (*TcEBpfProgram, error) {
return nil, fmt.Errorf("system not supported")
}
// NewRedirEBpfProgram new ebpf redirect program
func NewRedirEBpfProgram(_ []string, _ uint16, _ string) (*TcEBpfProgram, error) {
return nil, fmt.Errorf("system not supported")
}
func GetAutoDetectInterface() (string, error) {
return "", fmt.Errorf("system not supported")
}

View File

@ -1,216 +0,0 @@
//go:build linux
package redir
import (
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"os"
"path/filepath"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/rlimit"
"github.com/sagernet/netlink"
"golang.org/x/sys/unix"
"github.com/Dreamacro/clash/component/ebpf/byteorder"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS bpf ../bpf/redir.c
const (
mapKey1 uint32 = 0
mapKey2 uint32 = 1
mapKey3 uint32 = 2
)
type EBpfRedirect struct {
objs io.Closer
originMap *ebpf.Map
qdisc netlink.Qdisc
filter netlink.Filter
filterEgress netlink.Filter
ifName string
ifIndex int
ifMark uint32
rtIndex uint32
redirIp uint32
redirPort uint16
bpfPath string
}
func NewEBpfRedirect(ifName string, ifIndex int, ifMark uint32, routeIndex uint32, redirAddrPort netip.AddrPort) *EBpfRedirect {
return &EBpfRedirect{
ifName: ifName,
ifIndex: ifIndex,
ifMark: ifMark,
rtIndex: routeIndex,
redirIp: binary.BigEndian.Uint32(redirAddrPort.Addr().AsSlice()),
redirPort: redirAddrPort.Port(),
}
}
func (e *EBpfRedirect) Start() error {
if err := rlimit.RemoveMemlock(); err != nil {
return fmt.Errorf("remove memory lock: %w", err)
}
e.bpfPath = filepath.Join(C.BpfFSPath, e.ifName)
if err := os.MkdirAll(e.bpfPath, os.ModePerm); err != nil {
return fmt.Errorf("failed to create bpf fs subpath: %w", err)
}
var objs bpfObjects
if err := loadBpfObjects(&objs, &ebpf.CollectionOptions{
Maps: ebpf.MapOptions{
PinPath: e.bpfPath,
},
}); err != nil {
e.Close()
return fmt.Errorf("loading objects: %w", err)
}
e.objs = &objs
e.originMap = objs.bpfMaps.PairOriginalDstMap
if err := objs.bpfMaps.RedirParamsMap.Update(mapKey1, e.rtIndex, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
if err := objs.bpfMaps.RedirParamsMap.Update(mapKey2, e.redirIp, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
if err := objs.bpfMaps.RedirParamsMap.Update(mapKey3, uint32(e.redirPort), ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
attrs := netlink.QdiscAttrs{
LinkIndex: e.ifIndex,
Handle: netlink.MakeHandle(0xffff, 0),
Parent: netlink.HANDLE_CLSACT,
}
qdisc := &netlink.GenericQdisc{
QdiscAttrs: attrs,
QdiscType: "clsact",
}
e.qdisc = qdisc
if err := netlink.QdiscAdd(qdisc); err != nil {
if os.IsExist(err) {
_ = netlink.QdiscDel(qdisc)
err = netlink.QdiscAdd(qdisc)
}
if err != nil {
e.Close()
return fmt.Errorf("cannot add clsact qdisc: %w", err)
}
}
filterAttrs := netlink.FilterAttrs{
LinkIndex: e.ifIndex,
Parent: netlink.HANDLE_MIN_INGRESS,
Handle: netlink.MakeHandle(0, 1),
Protocol: unix.ETH_P_IP,
Priority: 0,
}
filter := &netlink.BpfFilter{
FilterAttrs: filterAttrs,
Fd: objs.bpfPrograms.TcRedirIngressFunc.FD(),
Name: "clash-redir-ingress-" + e.ifName,
DirectAction: true,
}
if err := netlink.FilterAdd(filter); err != nil {
e.Close()
return fmt.Errorf("cannot attach ebpf object to filter ingress: %w", err)
}
e.filter = filter
filterAttrsEgress := netlink.FilterAttrs{
LinkIndex: e.ifIndex,
Parent: netlink.HANDLE_MIN_EGRESS,
Handle: netlink.MakeHandle(0, 1),
Protocol: unix.ETH_P_IP,
Priority: 0,
}
filterEgress := &netlink.BpfFilter{
FilterAttrs: filterAttrsEgress,
Fd: objs.bpfPrograms.TcRedirEgressFunc.FD(),
Name: "clash-redir-egress-" + e.ifName,
DirectAction: true,
}
if err := netlink.FilterAdd(filterEgress); err != nil {
e.Close()
return fmt.Errorf("cannot attach ebpf object to filter egress: %w", err)
}
e.filterEgress = filterEgress
return nil
}
func (e *EBpfRedirect) Close() {
if e.filter != nil {
_ = netlink.FilterDel(e.filter)
}
if e.filterEgress != nil {
_ = netlink.FilterDel(e.filterEgress)
}
if e.qdisc != nil {
_ = netlink.QdiscDel(e.qdisc)
}
if e.objs != nil {
_ = e.objs.Close()
}
_ = os.Remove(filepath.Join(e.bpfPath, "redir_params_map"))
_ = os.Remove(filepath.Join(e.bpfPath, "pair_original_dst_map"))
}
func (e *EBpfRedirect) Lookup(srcAddrPort netip.AddrPort) (socks5.Addr, error) {
rAddr := srcAddrPort.Addr().Unmap()
if rAddr.Is6() {
return nil, fmt.Errorf("remote address is ipv6")
}
srcIp := binary.BigEndian.Uint32(rAddr.AsSlice())
scrPort := srcAddrPort.Port()
key := bpfRedirInfo{
Sip: byteorder.HostToNetwork32(srcIp),
Sport: byteorder.HostToNetwork16(scrPort),
Dip: byteorder.HostToNetwork32(e.redirIp),
Dport: byteorder.HostToNetwork16(e.redirPort),
}
origin := bpfOriginInfo{}
err := e.originMap.Lookup(key, &origin)
if err != nil {
return nil, err
}
addr := make([]byte, net.IPv4len+3)
addr[0] = socks5.AtypIPv4
binary.BigEndian.PutUint32(addr[1:1+net.IPv4len], byteorder.NetworkToHost32(origin.Ip)) // big end
binary.BigEndian.PutUint16(addr[1+net.IPv4len:3+net.IPv4len], byteorder.NetworkToHost16(origin.Port)) // big end
return addr, nil
}

View File

@ -1,138 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package redir
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfOriginInfo struct {
Ip uint32
Port uint16
Pad uint16
}
type bpfRedirInfo struct {
Sip uint32
Dip uint32
Sport uint16
Dport uint16
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcRedirEgressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_ingress_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PairOriginalDstMap *ebpf.MapSpec `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.MapSpec `ebpf:"redir_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PairOriginalDstMap *ebpf.Map `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.Map `ebpf:"redir_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PairOriginalDstMap,
m.RedirParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcRedirEgressFunc *ebpf.Program `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.Program `ebpf:"tc_redir_ingress_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcRedirEgressFunc,
p.TcRedirIngressFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View File

@ -1,138 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
package redir
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
type bpfOriginInfo struct {
Ip uint32
Port uint16
Pad uint16
}
type bpfRedirInfo struct {
Sip uint32
Dip uint32
Sport uint16
Dport uint16
}
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcRedirEgressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.ProgramSpec `ebpf:"tc_redir_ingress_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
PairOriginalDstMap *ebpf.MapSpec `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.MapSpec `ebpf:"redir_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
PairOriginalDstMap *ebpf.Map `ebpf:"pair_original_dst_map"`
RedirParamsMap *ebpf.Map `ebpf:"redir_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.PairOriginalDstMap,
m.RedirParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcRedirEgressFunc *ebpf.Program `ebpf:"tc_redir_egress_func"`
TcRedirIngressFunc *ebpf.Program `ebpf:"tc_redir_ingress_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcRedirEgressFunc,
p.TcRedirIngressFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//go:embed bpf_bpfel.o
var _BpfBytes []byte

Binary file not shown.

View File

@ -1,119 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build arm64be || armbe || mips || mips64 || mips64p32 || ppc64 || s390 || s390x || sparc || sparc64
// +build arm64be armbe mips mips64 mips64p32 ppc64 s390 s390x sparc sparc64
package tc
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcTunFunc *ebpf.ProgramSpec `ebpf:"tc_tun_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
TcParamsMap *ebpf.MapSpec `ebpf:"tc_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
TcParamsMap *ebpf.Map `ebpf:"tc_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.TcParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcTunFunc *ebpf.Program `ebpf:"tc_tun_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcTunFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//go:embed bpf_bpfeb.o
var _BpfBytes []byte

Binary file not shown.

View File

@ -1,119 +0,0 @@
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64 || amd64p32 || arm || arm64 || mips64le || mips64p32le || mipsle || ppc64le || riscv64
// +build 386 amd64 amd64p32 arm arm64 mips64le mips64p32le mipsle ppc64le riscv64
package tc
import (
"bytes"
_ "embed"
"fmt"
"io"
"github.com/cilium/ebpf"
)
// loadBpf returns the embedded CollectionSpec for bpf.
func loadBpf() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_BpfBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load bpf: %w", err)
}
return spec, err
}
// loadBpfObjects loads bpf and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *bpfObjects
// *bpfPrograms
// *bpfMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := loadBpf()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// bpfSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfSpecs struct {
bpfProgramSpecs
bpfMapSpecs
}
// bpfSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfProgramSpecs struct {
TcTunFunc *ebpf.ProgramSpec `ebpf:"tc_tun_func"`
}
// bpfMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type bpfMapSpecs struct {
TcParamsMap *ebpf.MapSpec `ebpf:"tc_params_map"`
}
// bpfObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfObjects struct {
bpfPrograms
bpfMaps
}
func (o *bpfObjects) Close() error {
return _BpfClose(
&o.bpfPrograms,
&o.bpfMaps,
)
}
// bpfMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfMaps struct {
TcParamsMap *ebpf.Map `ebpf:"tc_params_map"`
}
func (m *bpfMaps) Close() error {
return _BpfClose(
m.TcParamsMap,
)
}
// bpfPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
type bpfPrograms struct {
TcTunFunc *ebpf.Program `ebpf:"tc_tun_func"`
}
func (p *bpfPrograms) Close() error {
return _BpfClose(
p.TcTunFunc,
)
}
func _BpfClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// Do not access this directly.
//go:embed bpf_bpfel.o
var _BpfBytes []byte

Binary file not shown.

View File

@ -1,147 +0,0 @@
//go:build linux
package tc
import (
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/rlimit"
"github.com/sagernet/netlink"
"golang.org/x/sys/unix"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS bpf ../bpf/tc.c
const (
mapKey1 uint32 = 0
mapKey2 uint32 = 1
)
type EBpfTC struct {
objs io.Closer
qdisc netlink.Qdisc
filter netlink.Filter
ifName string
ifIndex int
ifMark uint32
tunIfIndex uint32
bpfPath string
}
func NewEBpfTc(ifName string, ifIndex int, ifMark uint32, tunIfIndex uint32) *EBpfTC {
return &EBpfTC{
ifName: ifName,
ifIndex: ifIndex,
ifMark: ifMark,
tunIfIndex: tunIfIndex,
}
}
func (e *EBpfTC) Start() error {
if err := rlimit.RemoveMemlock(); err != nil {
return fmt.Errorf("remove memory lock: %w", err)
}
e.bpfPath = filepath.Join(C.BpfFSPath, e.ifName)
if err := os.MkdirAll(e.bpfPath, os.ModePerm); err != nil {
return fmt.Errorf("failed to create bpf fs subpath: %w", err)
}
var objs bpfObjects
if err := loadBpfObjects(&objs, &ebpf.CollectionOptions{
Maps: ebpf.MapOptions{
PinPath: e.bpfPath,
},
}); err != nil {
e.Close()
return fmt.Errorf("loading objects: %w", err)
}
e.objs = &objs
if err := objs.bpfMaps.TcParamsMap.Update(mapKey1, e.ifMark, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
if err := objs.bpfMaps.TcParamsMap.Update(mapKey2, e.tunIfIndex, ebpf.UpdateAny); err != nil {
e.Close()
return fmt.Errorf("storing objects: %w", err)
}
attrs := netlink.QdiscAttrs{
LinkIndex: e.ifIndex,
Handle: netlink.MakeHandle(0xffff, 0),
Parent: netlink.HANDLE_CLSACT,
}
qdisc := &netlink.GenericQdisc{
QdiscAttrs: attrs,
QdiscType: "clsact",
}
e.qdisc = qdisc
if err := netlink.QdiscAdd(qdisc); err != nil {
if os.IsExist(err) {
_ = netlink.QdiscDel(qdisc)
err = netlink.QdiscAdd(qdisc)
}
if err != nil {
e.Close()
return fmt.Errorf("cannot add clsact qdisc: %w", err)
}
}
filterAttrs := netlink.FilterAttrs{
LinkIndex: e.ifIndex,
Parent: netlink.HANDLE_MIN_EGRESS,
Handle: netlink.MakeHandle(0, 1),
Protocol: unix.ETH_P_ALL,
Priority: 1,
}
filter := &netlink.BpfFilter{
FilterAttrs: filterAttrs,
Fd: objs.bpfPrograms.TcTunFunc.FD(),
Name: "clash-tc-" + e.ifName,
DirectAction: true,
}
if err := netlink.FilterAdd(filter); err != nil {
e.Close()
return fmt.Errorf("cannot attach ebpf object to filter: %w", err)
}
e.filter = filter
return nil
}
func (e *EBpfTC) Close() {
if e.filter != nil {
_ = netlink.FilterDel(e.filter)
}
if e.qdisc != nil {
_ = netlink.QdiscDel(e.qdisc)
}
if e.objs != nil {
_ = e.objs.Close()
}
_ = os.Remove(filepath.Join(e.bpfPath, "tc_params_map"))
}
func (e *EBpfTC) Lookup(_ netip.AddrPort) (socks5.Addr, error) {
return nil, fmt.Errorf("not supported")
}

View File

@ -34,7 +34,7 @@ type Pool struct {
offset netip.Addr
cycle bool
mux sync.Mutex
host *trie.DomainTrie[struct{}]
host *trie.DomainTrie[bool]
ipnet *netip.Prefix
store store
}
@ -150,7 +150,7 @@ func (p *Pool) restoreState() {
type Options struct {
IPNet *netip.Prefix
Host *trie.DomainTrie[struct{}]
Host *trie.DomainTrie[bool]
// Size sets the maximum number of entries in memory
// and does not work if Persistence is true
@ -166,7 +166,7 @@ func New(options Options) (*Pool, error) {
var (
hostAddr = options.IPNet.Masked().Addr()
gateway = hostAddr.Next()
first = gateway.Next().Next().Next() // default start with 198.18.0.4
first = gateway.Next().Next()
last = nnip.UnMasked(*options.IPNet)
)

View File

@ -62,16 +62,16 @@ func TestPool_Basic(t *testing.T) {
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, pool.Lookup("foo.com") == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 5}))
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, pool.Lookup("foo.com") == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
assert.True(t, pool.Gateway() == netip.AddrFrom4([4]byte{192, 168, 0, 1}))
assert.True(t, pool.Broadcast() == netip.AddrFrom4([4]byte{192, 168, 0, 15}))
assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 5})))
assert.False(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 6})))
assert.True(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 4})))
assert.False(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 5})))
assert.False(t, pool.Exist(netip.MustParseAddr("::1")))
}
}
@ -90,16 +90,16 @@ func TestPool_BasicV6(t *testing.T) {
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
assert.True(t, first == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.True(t, pool.Lookup("foo.com") == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.True(t, last == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805"))
assert.True(t, first == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8803"))
assert.True(t, pool.Lookup("foo.com") == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8803"))
assert.True(t, last == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
assert.True(t, pool.Gateway() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8801"))
assert.True(t, pool.Broadcast() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8bff"))
assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805")))
assert.False(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8806")))
assert.True(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804")))
assert.False(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805")))
assert.False(t, pool.Exist(netip.MustParseAddr("127.0.0.1")))
}
}
@ -116,7 +116,7 @@ func TestPool_CycleUsed(t *testing.T) {
for _, pool := range pools {
foo := pool.Lookup("foo.com")
bar := pool.Lookup("bar.com")
for i := 0; i < 9; i++ {
for i := 0; i < 10; i++ {
pool.Lookup(fmt.Sprintf("%d.com", i))
}
baz := pool.Lookup("baz.com")
@ -128,8 +128,8 @@ func TestPool_CycleUsed(t *testing.T) {
func TestPool_Skip(t *testing.T) {
ipnet := netip.MustParsePrefix("192.168.0.1/29")
tree := trie.New[struct{}]()
tree.Insert("example.com", struct{}{})
tree := trie.New[bool]()
tree.Insert("example.com", true)
pools, tempfile, err := createPools(Options{
IPNet: &ipnet,
Size: 10,
@ -198,8 +198,8 @@ func TestPool_Clone(t *testing.T) {
first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com")
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 5}))
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
newPool, _ := New(Options{
IPNet: &ipnet,

View File

@ -2,8 +2,8 @@ package http
import (
"context"
"github.com/Dreamacro/clash/component/tls"
"github.com/Dreamacro/clash/listener/inner"
"github.com/Dreamacro/clash/log"
"io"
"net"
"net/http"
@ -52,10 +52,10 @@ func HttpRequest(ctx context.Context, url, method string, header map[string][]st
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
log.Infoln(urlRes.String())
conn := inner.HandleTcp(address, urlRes.Hostname())
return conn, nil
},
TLSClientConfig: tls.GetDefaultTLSConfig(),
}
client := http.Client{Transport: transport}

View File

@ -2,6 +2,9 @@ package process
import (
"errors"
"github.com/Dreamacro/clash/common/nnip"
C "github.com/Dreamacro/clash/constant"
"net"
"net/netip"
)
@ -9,6 +12,8 @@ var (
ErrInvalidNetwork = errors.New("invalid network")
ErrPlatformNotSupport = errors.New("not support on this platform")
ErrNotFound = errors.New("process not found")
enableFindProcess = true
)
const (
@ -16,6 +21,10 @@ const (
UDP = "udp"
)
func EnableFindProcess(e bool) {
enableFindProcess = e
}
func FindProcessName(network string, srcIP netip.Addr, srcPort int) (int32, string, error) {
return findProcessName(network, srcIP, srcPort)
}
@ -27,3 +36,51 @@ func FindUid(network string, srcIP netip.Addr, srcPort int) (int32, error) {
}
return uid, nil
}
func ShouldFindProcess(metadata *C.Metadata) bool {
if !enableFindProcess ||
metadata.Process != "" ||
metadata.ProcessPath != "" {
return false
}
for _, ip := range localIPs {
if ip == metadata.SrcIP {
return true
}
}
return false
}
func AppendLocalIPs(ip ...netip.Addr) {
localIPs = append(ip, localIPs...)
}
func getLocalIPs() []netip.Addr {
ips := []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()}
netInterfaces, err := net.Interfaces()
if err != nil {
ips = append(ips, netip.AddrFrom4([4]byte{127, 0, 0, 1}), nnip.IpToAddr(net.IPv6loopback))
return ips
}
for i := 0; i < len(netInterfaces); i++ {
if (netInterfaces[i].Flags & net.FlagUp) != 0 {
adds, _ := netInterfaces[i].Addrs()
for _, address := range adds {
if ipNet, ok := address.(*net.IPNet); ok {
ips = append(ips, nnip.IpToAddr(ipNet.IP))
}
}
}
}
return ips
}
var localIPs []netip.Addr
func init() {
localIPs = getLocalIPs()
}

View File

@ -3,11 +3,11 @@ package process
import (
"encoding/binary"
"net/netip"
"strconv"
"strings"
"syscall"
"unsafe"
"github.com/Dreamacro/clash/common/nnip"
"golang.org/x/sys/unix"
)
@ -17,22 +17,6 @@ const (
proccallnumpidinfo = 0x2
)
var structSize = func() int {
value, _ := syscall.Sysctl("kern.osrelease")
major, _, _ := strings.Cut(value, ".")
n, _ := strconv.ParseInt(major, 10, 64)
switch true {
case n >= 22:
return 408
default:
// from darwin-xnu/bsd/netinet/in_pcblist.c:get_pcblist_n
// size/offset are round up (aligned) to 8 bytes in darwin
// rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) +
// 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n))
return 384
}
}()
func resolveSocketByNetlink(network string, ip netip.Addr, srcPort int) (int32, int32, error) {
return 0, 0, ErrPlatformNotSupport
}
@ -56,13 +40,16 @@ func findProcessName(network string, ip netip.Addr, port int) (int32, string, er
}
buf := []byte(value)
itemSize := structSize
// from darwin-xnu/bsd/netinet/in_pcblist.c:get_pcblist_n
// size/offset are round up (aligned) to 8 bytes in darwin
// rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) +
// 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n))
itemSize := 384
if network == TCP {
// rup8(sizeof(xtcpcb_n))
itemSize += 208
}
var fallbackUDPProcess string
// skip the first xinpgen(24 bytes) block
for i := 24; i+itemSize <= len(buf); i += itemSize {
// offset of xinpcb_n and xsocket_n
@ -76,37 +63,26 @@ func findProcessName(network string, ip netip.Addr, port int) (int32, string, er
// xinpcb_n.inp_vflag
flag := buf[inp+44]
var (
srcIP netip.Addr
srcIsIPv4 bool
)
var srcIP netip.Addr
switch {
case flag&0x1 > 0 && isIPv4:
// ipv4
srcIP, _ = netip.AddrFromSlice(buf[inp+76 : inp+80])
srcIsIPv4 = true
srcIP = nnip.IpToAddr(buf[inp+76 : inp+80])
case flag&0x2 > 0 && !isIPv4:
// ipv6
srcIP, _ = netip.AddrFromSlice(buf[inp+64 : inp+80])
srcIP = nnip.IpToAddr(buf[inp+64 : inp+80])
default:
continue
}
if ip == srcIP {
// xsocket_n.so_last_pid
pid := readNativeUint32(buf[so+68 : so+72])
pp, err := getExecPathFromPID(pid)
return -1, pp, err
if ip != srcIP && (network == TCP || !srcIP.IsUnspecified()) {
continue
}
// udp packet connection may be not equal with srcIP
if network == UDP && srcIP.IsUnspecified() && isIPv4 == srcIsIPv4 {
fallbackUDPProcess, _ = getExecPathFromPID(readNativeUint32(buf[so+68 : so+72]))
}
}
if network == UDP && fallbackUDPProcess != "" {
return -1, fallbackUDPProcess, nil
// xsocket_n.so_last_pid
pid := readNativeUint32(buf[so+68 : so+72])
pp, err := getExecPathFromPID(pid)
return -1, pp, err
}
return -1, "", ErrNotFound

View File

@ -39,7 +39,6 @@ func findProcessName(network string, ip netip.Addr, srcPort int) (int32, string,
if err != nil {
return -1, "", err
}
pp, err := resolveProcessNameByProcSearch(inode, uid)
return uid, pp, err
}
@ -111,7 +110,7 @@ func resolveSocketByNetlink(network string, ip netip.Addr, srcPort int) (int32,
return 0, 0, fmt.Errorf("netlink message: NLMSG_ERROR")
}
inode, uid := unpackSocketDiagResponse(&messages[0])
inode, uid := unpackSocketDiagResponse(&message)
if inode < 0 || uid < 0 {
return 0, 0, fmt.Errorf("invalid inode(%d) or uid(%d)", inode, uid)
}
@ -198,6 +197,7 @@ func resolveProcessNameByProcSearch(inode, uid int32) (string, error) {
if err != nil {
continue
}
if runtime.GOOS == "android" {
if bytes.Equal(buffer[:n], socket) {
cmdline, err := os.ReadFile(path.Join(processPath, "cmdline"))

View File

@ -41,6 +41,7 @@ type Resolver interface {
ResolveIPv4(host string) (ip netip.Addr, err error)
ResolveIPv6(host string) (ip netip.Addr, err error)
ResolveAllIP(host string) (ip []netip.Addr, err error)
ResolveAllIPPrimaryIPv4(host string) (ips []netip.Addr, err error)
ResolveAllIPv4(host string) (ips []netip.Addr, err error)
ResolveAllIPv6(host string) (ips []netip.Addr, err error)
}
@ -54,7 +55,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
if ips, err := ResolveAllIPv4WithResolver(host, r); err == nil {
return ips[rand.Intn(len(ips))], nil
} else {
return netip.Addr{}, err
return netip.Addr{}, nil
}
}
@ -73,10 +74,10 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) {
// ResolveIPWithResolver same as ResolveIP, but with a resolver
func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) {
if ip, err := ResolveIPv4WithResolver(host, r); err == nil {
return ip, nil
if ips, err := ResolveAllIPPrimaryIPv4WithResolver(host, r); err == nil {
return ips[rand.Intn(len(ips))], nil
} else {
return ResolveIPv6WithResolver(host, r)
return netip.Addr{}, err
}
}
@ -94,6 +95,7 @@ func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) {
return ip, nil
}
}
return ResolveIPv4(host)
}
@ -106,6 +108,7 @@ func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) {
return ip, nil
}
}
return ResolveIPv6(host)
}
@ -118,6 +121,7 @@ func ResolveProxyServerHost(host string) (netip.Addr, error) {
return ip, err
}
}
return ResolveIP(host)
}
@ -127,7 +131,7 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
}
if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data(); ip.Is6() {
if ip := node.Data; ip.Is6() {
return []netip.Addr{ip}, nil
}
}
@ -154,23 +158,16 @@ func ResolveAllIPv6WithResolver(host string, r Resolver) ([]netip.Addr, error) {
return []netip.Addr{}, ErrIPNotFound
}
addrs := make([]netip.Addr, 0, len(ipAddrs))
for _, ipAddr := range ipAddrs {
addrs = append(addrs, nnip.IpToAddr(ipAddr))
}
rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})
return addrs, nil
return []netip.Addr{netip.AddrFrom16(*(*[16]byte)(ipAddrs[rand.Intn(len(ipAddrs))]))}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data(); ip.Is4() {
return []netip.Addr{node.Data()}, nil
if ip := node.Data; ip.Is4() {
return []netip.Addr{node.Data}, nil
}
}
@ -196,27 +193,20 @@ func ResolveAllIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
return []netip.Addr{}, ErrIPNotFound
}
addrs := make([]netip.Addr, 0, len(ipAddrs))
for _, ipAddr := range ipAddrs {
addrs = append(addrs, nnip.IpToAddr(ipAddr))
ip := ipAddrs[rand.Intn(len(ipAddrs))].To4()
if ip == nil {
return []netip.Addr{}, ErrIPVersion
}
rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})
return addrs, nil
return []netip.Addr{netip.AddrFrom4(*(*[4]byte)(ip))}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
return []netip.Addr{node.Data()}, nil
}
ip, err := netip.ParseAddr(host)
if err == nil {
return []netip.Addr{ip}, nil
return []netip.Addr{node.Data}, nil
}
if r != nil {
@ -229,23 +219,52 @@ func ResolveAllIPWithResolver(host string, r Resolver) ([]netip.Addr, error) {
return ResolveAllIPv4(host)
}
ip, err := netip.ParseAddr(host)
if err == nil {
return []netip.Addr{ip}, nil
}
if DefaultResolver == nil {
ipAddrs, err := net.DefaultResolver.LookupIP(context.Background(), "ip", host)
ipAddr, err := net.ResolveIPAddr("ip", host)
if err != nil {
return []netip.Addr{}, err
} else if len(ipAddrs) == 0 {
return []netip.Addr{}, ErrIPNotFound
}
addrs := make([]netip.Addr, 0, len(ipAddrs))
for _, ipAddr := range ipAddrs {
addrs = append(addrs, nnip.IpToAddr(ipAddr))
}
rand.Shuffle(len(addrs), func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})
return addrs, nil
return []netip.Addr{nnip.IpToAddr(ipAddr.IP)}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
func ResolveAllIPPrimaryIPv4WithResolver(host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
return []netip.Addr{node.Data}, nil
}
if r != nil {
if DisableIPv6 {
return r.ResolveAllIPv4(host)
}
return r.ResolveAllIPPrimaryIPv4(host)
} else if DisableIPv6 {
return ResolveAllIPv4(host)
}
ip, err := netip.ParseAddr(host)
if err == nil {
return []netip.Addr{ip}, nil
}
if DefaultResolver == nil {
ipAddr, err := net.ResolveIPAddr("ip", host)
if err != nil {
return []netip.Addr{}, err
}
return []netip.Addr{nnip.IpToAddr(ipAddr.IP)}, nil
}
return []netip.Addr{}, ErrIPNotFound
}
@ -265,6 +284,7 @@ func ResolveAllIPv6ProxyServerHost(host string) ([]netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveAllIPv6WithResolver(host, ProxyServerHostResolver)
}
return ResolveAllIPv6(host)
}
@ -272,6 +292,7 @@ func ResolveAllIPv4ProxyServerHost(host string) ([]netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveAllIPv4WithResolver(host, ProxyServerHostResolver)
}
return ResolveAllIPv4(host)
}
@ -279,5 +300,6 @@ func ResolveAllIPProxyServerHost(host string) ([]netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveAllIPWithResolver(host, ProxyServerHostResolver)
}
return ResolveAllIP(host)
}

View File

@ -2,19 +2,18 @@ package sniffer
import (
"errors"
"fmt"
"github.com/Dreamacro/clash/constant/sniffer"
"net"
"net/netip"
"strconv"
"sync"
"time"
"github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/component/trie"
CN "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/sniffer"
"github.com/Dreamacro/clash/log"
)
@ -24,30 +23,27 @@ var (
ErrNoClue = errors.New("not enough information for making a decision")
)
var Dispatcher *SnifferDispatcher
var Dispatcher SnifferDispatcher
type SnifferDispatcher struct {
enable bool
type (
SnifferDispatcher struct {
enable bool
sniffers []sniffer.Sniffer
sniffers []sniffer.Sniffer
forceDomain *trie.DomainTrie[struct{}]
skipSNI *trie.DomainTrie[struct{}]
portRanges *[]utils.Range[uint16]
skipList *cache.LruCache[string, uint8]
rwMux sync.RWMutex
forceDnsMapping bool
parsePureIp bool
}
foreDomain *trie.DomainTrie[bool]
skipSNI *trie.DomainTrie[bool]
portRanges *[]utils.Range[uint16]
}
)
func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
bufConn, ok := conn.(*N.BufferedConn)
bufConn, ok := conn.(*CN.BufferedConn)
if !ok {
return
}
if (metadata.Host == "" && sd.parsePureIp) || sd.forceDomain.Search(metadata.Host) != nil || (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) {
if metadata.Host == "" || sd.foreDomain.Search(metadata.Host) != nil {
port, err := strconv.ParseUint(metadata.DstPort, 10, 16)
if err != nil {
log.Debugln("[Sniffer] Dst port is error")
@ -66,17 +62,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
return
}
sd.rwMux.RLock()
dst := fmt.Sprintf("%s:%s", metadata.DstIP, metadata.DstPort)
if count, ok := sd.skipList.Get(dst); ok && count > 5 {
log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
defer sd.rwMux.RUnlock()
return
}
sd.rwMux.RUnlock()
if host, err := sd.sniffDomain(bufConn, metadata); err != nil {
sd.cacheSniffFailed(metadata)
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
return
} else {
@ -85,51 +71,36 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
return
}
sd.rwMux.RLock()
sd.skipList.Delete(dst)
sd.rwMux.RUnlock()
sd.replaceDomain(metadata, host)
}
}
}
func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string) {
dstIP := ""
if metadata.DstIP.IsValid() {
dstIP = metadata.DstIP.String()
}
originHost := metadata.Host
if originHost != host {
log.Infoln("[Sniffer] Sniff TCP [%s:%s]-->[%s:%s] success, replace domain [%s]-->[%s]",
metadata.SrcIP, metadata.SrcPort,
dstIP, metadata.DstPort,
metadata.Host, host)
} else {
log.Debugln("[Sniffer] Sniff TCP [%s:%s]-->[%s:%s] success, replace domain [%s]-->[%s]",
metadata.SrcIP, metadata.SrcPort,
dstIP, metadata.DstPort,
metadata.Host, host)
}
log.Debugln("[Sniffer] Sniff TCP [%s:%s]-->[%s:%s] success, replace domain [%s]-->[%s]",
metadata.SrcIP, metadata.SrcPort,
metadata.DstIP, metadata.DstPort,
metadata.Host, host)
metadata.AddrType = C.AtypDomainName
metadata.Host = host
metadata.DNSMode = C.DNSNormal
metadata.DNSMode = C.DNSMapping
resolver.InsertHostByIP(metadata.DstIP, host)
}
func (sd *SnifferDispatcher) Enable() bool {
return sd.enable
}
func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
for _, s := range sd.sniffers {
if s.SupportNetwork() == C.TCP {
_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second))
func (sd *SnifferDispatcher) sniffDomain(conn *CN.BufferedConn, metadata *C.Metadata) (string, error) {
for _, sniffer := range sd.sniffers {
if sniffer.SupportNetwork() == C.TCP {
_ = conn.SetReadDeadline(time.Now().Add(3 * time.Second))
_, err := conn.Peek(1)
_ = conn.SetReadDeadline(time.Time{})
if err != nil {
_, ok := err.(*net.OpError)
if ok {
sd.cacheSniffFailed(metadata)
log.Errorln("[Sniffer] [%s] may not have any sent data, Consider adding skip", metadata.DstIP.String())
_ = conn.Close()
}
@ -144,15 +115,15 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
continue
}
host, err := s.SniffTCP(bytes)
host, err := sniffer.SniffTCP(bytes)
if err != nil {
//log.Debugln("[Sniffer] [%s] Sniff data failed %s", s.Protocol(), metadata.DstIP)
//log.Debugln("[Sniffer] [%s] Sniff data failed %s", sniffer.Protocol(), metadata.DstIP)
continue
}
_, err = netip.ParseAddr(host)
if err == nil {
//log.Debugln("[Sniffer] [%s] Sniff data failed %s", s.Protocol(), metadata.DstIP)
//log.Debugln("[Sniffer] [%s] Sniff data failed %s", sniffer.Protocol(), metadata.DstIP)
continue
}
@ -163,17 +134,6 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
return "", ErrorSniffFailed
}
func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
sd.rwMux.Lock()
dst := fmt.Sprintf("%s:%s", metadata.DstIP, metadata.DstPort)
count, _ := sd.skipList.Get(dst)
if count <= 5 {
count++
}
sd.skipList.Set(dst, count)
sd.rwMux.Unlock()
}
func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: false,
@ -182,27 +142,23 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
return &dispatcher, nil
}
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[struct{}],
skipSNI *trie.DomainTrie[struct{}], ports *[]utils.Range[uint16],
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
func NewSnifferDispatcher(needSniffer []sniffer.Type, forceDomain *trie.DomainTrie[bool],
skipSNI *trie.DomainTrie[bool], ports *[]utils.Range[uint16]) (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: true,
forceDomain: forceDomain,
skipSNI: skipSNI,
portRanges: ports,
skipList: cache.NewLRUCache[string, uint8](cache.WithSize[string, uint8](128), cache.WithAge[string, uint8](600)),
forceDnsMapping: forceDnsMapping,
parsePureIp: parsePureIp,
enable: true,
foreDomain: forceDomain,
skipSNI: skipSNI,
portRanges: ports,
}
for _, snifferName := range needSniffer {
s, err := NewSniffer(snifferName)
sniffer, err := NewSniffer(snifferName)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
return &SnifferDispatcher{enable: false}, err
}
dispatcher.sniffers = append(dispatcher.sniffers, s)
dispatcher.sniffers = append(dispatcher.sniffers, sniffer)
}
return &dispatcher, nil

View File

@ -3,11 +3,9 @@ package sniffer
import (
"bytes"
"errors"
"fmt"
C "github.com/Dreamacro/clash/constant"
"net"
"strings"
C "github.com/Dreamacro/clash/constant"
)
var (
@ -90,32 +88,13 @@ func SniffHTTP(b []byte) (*string, error) {
host, _, err := net.SplitHostPort(rawHost)
if err != nil {
if addrError, ok := err.(*net.AddrError); ok && strings.Contains(addrError.Err, "missing port") {
return parseHost(rawHost)
host = rawHost
} else {
return nil, err
}
}
if net.ParseIP(host) != nil {
return nil, fmt.Errorf("host is ip")
}
return &host, nil
}
}
return nil, ErrNoClue
}
func parseHost(host string) (*string, error) {
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
if net.ParseIP(host[1:len(host)-1]) != nil {
return nil, fmt.Errorf("host is ip")
}
}
if net.ParseIP(host) != nil {
return nil, fmt.Errorf("host is ip")
}
return &host, nil
}

View File

@ -1,140 +0,0 @@
package tls
import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"fmt"
xtls "github.com/xtls/go"
"sync"
"time"
)
var globalFingerprints = make([][32]byte, 0, 0)
var mutex sync.Mutex
func verifyPeerCertificateAndFingerprints(fingerprints *[][32]byte, insecureSkipVerify bool) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if insecureSkipVerify {
return nil
}
var preErr error
for i := range rawCerts {
rawCert := rawCerts[i]
cert, err := x509.ParseCertificate(rawCert)
if err == nil {
opts := x509.VerifyOptions{
CurrentTime: time.Now(),
}
if _, err := cert.Verify(opts); err == nil {
return nil
} else {
fingerprint := sha256.Sum256(cert.Raw)
for _, fp := range *fingerprints {
if bytes.Equal(fingerprint[:], fp[:]) {
return nil
}
}
preErr = err
}
}
}
return preErr
}
}
func AddCertFingerprint(fingerprint string) error {
fpByte, err2 := convertFingerprint(fingerprint)
if err2 != nil {
return err2
}
mutex.Lock()
globalFingerprints = append(globalFingerprints, *fpByte)
mutex.Unlock()
return nil
}
func convertFingerprint(fingerprint string) (*[32]byte, error) {
fpByte, err := hex.DecodeString(fingerprint)
if err != nil {
return nil, err
}
if len(fpByte) != 32 {
return nil, fmt.Errorf("fingerprint string length error,need sha25 fingerprint")
}
return (*[32]byte)(fpByte), nil
}
func GetDefaultTLSConfig() *tls.Config {
return GetGlobalFingerprintTLCConfig(nil)
}
// GetSpecifiedFingerprintTLSConfig specified fingerprint
func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) {
if fingerprintBytes, err := convertFingerprint(fingerprint); err != nil {
return nil, err
} else {
if tlsConfig == nil {
return &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateAndFingerprints(&[][32]byte{*fingerprintBytes}, false),
}, nil
} else {
tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints(&[][32]byte{*fingerprintBytes}, tlsConfig.InsecureSkipVerify)
tlsConfig.InsecureSkipVerify = true
return tlsConfig, nil
}
}
}
func GetGlobalFingerprintTLCConfig(tlsConfig *tls.Config) *tls.Config {
if tlsConfig == nil {
return &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateAndFingerprints(&globalFingerprints, false),
}
}
tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints(&globalFingerprints, tlsConfig.InsecureSkipVerify)
tlsConfig.InsecureSkipVerify = true
return tlsConfig
}
// GetSpecifiedFingerprintXTLSConfig specified fingerprint
func GetSpecifiedFingerprintXTLSConfig(tlsConfig *xtls.Config, fingerprint string) (*xtls.Config, error) {
if fingerprintBytes, err := convertFingerprint(fingerprint); err != nil {
return nil, err
} else {
if tlsConfig == nil {
return &xtls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateAndFingerprints(&[][32]byte{*fingerprintBytes}, false),
}, nil
} else {
tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints(&[][32]byte{*fingerprintBytes}, tlsConfig.InsecureSkipVerify)
tlsConfig.InsecureSkipVerify = true
return tlsConfig, nil
}
}
}
func GetGlobalFingerprintXTLCConfig(tlsConfig *xtls.Config) *xtls.Config {
if tlsConfig == nil {
return &xtls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateAndFingerprints(&globalFingerprints, false),
}
}
tlsConfig.VerifyPeerCertificate = verifyPeerCertificateAndFingerprints(&globalFingerprints, tlsConfig.InsecureSkipVerify)
tlsConfig.InsecureSkipVerify = true
return tlsConfig
}

View File

@ -17,7 +17,7 @@ var ErrInvalidDomain = errors.New("invalid domain")
// DomainTrie contains the main logic for adding and searching nodes for domain segments.
// support wildcard domain (e.g *.google.com)
type DomainTrie[T any] struct {
type DomainTrie[T comparable] struct {
root *Node[T]
}
@ -74,13 +74,13 @@ func (t *DomainTrie[T]) insert(parts []string, data T) {
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]
if !node.hasChild(part) {
node.addChild(part, newNode[T]())
node.addChild(part, newNode(getZero[T]()))
}
node = node.getChild(part)
}
node.setData(data)
node.Data = data
}
// Search is the most important part of the Trie.
@ -96,7 +96,7 @@ func (t *DomainTrie[T]) Search(domain string) *Node[T] {
n := t.search(t.root, parts)
if n.isEmpty() {
if n == nil || n.Data == getZero[T]() {
return nil
}
@ -109,13 +109,13 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
}
if c := node.getChild(parts[len(parts)-1]); c != nil {
if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() {
return n
}
}
if c := node.getChild(wildcard); c != nil {
if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() {
return n
}
}
@ -124,6 +124,6 @@ func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
}
// New returns a new, empty Trie.
func New[T any]() *DomainTrie[T] {
return &DomainTrie[T]{root: newNode[T]()}
func New[T comparable]() *DomainTrie[T] {
return &DomainTrie[T]{root: newNode[T](getZero[T]())}
}

View File

@ -23,7 +23,7 @@ func TestTrie_Basic(t *testing.T) {
node := tree.Search("example.com")
assert.NotNil(t, node)
assert.True(t, node.Data() == localIP)
assert.True(t, node.Data == localIP)
assert.NotNil(t, tree.Insert("", localIP))
assert.Nil(t, tree.Search(""))
assert.NotNil(t, tree.Search("localhost"))
@ -75,7 +75,7 @@ func TestTrie_Priority(t *testing.T) {
assertFn := func(domain string, data int) {
node := tree.Search(domain)
assert.NotNil(t, node)
assert.Equal(t, data, node.Data())
assert.Equal(t, data, node.Data)
}
for idx, domain := range domains {

View File

@ -1,10 +1,9 @@
package trie
// Node is the trie's node
type Node[T any] struct {
type Node[T comparable] struct {
children map[string]*Node[T]
inited bool
data T
Data T
}
func (n *Node[T]) getChild(s string) *Node[T] {
@ -19,31 +18,14 @@ func (n *Node[T]) addChild(s string, child *Node[T]) {
n.children[s] = child
}
func (n *Node[T]) isEmpty() bool {
if n == nil || n.inited == false {
return true
}
return false
}
func (n *Node[T]) setData(data T) {
n.data = data
n.inited = true
}
func (n *Node[T]) Data() T {
return n.data
}
func newNode[T any]() *Node[T] {
func newNode[T comparable](data T) *Node[T] {
return &Node[T]{
Data: data,
children: map[string]*Node[T]{},
inited: false,
data: getZero[T](),
}
}
func getZero[T any]() T {
func getZero[T comparable]() T {
var result T
return result
}

View File

@ -2,9 +2,10 @@ package config
import (
"container/list"
"encoding/json"
"errors"
"fmt"
"github.com/Dreamacro/clash/constant/sniffer"
"github.com/Dreamacro/clash/listener/tun/ipstack/commons"
"net"
"net/netip"
"net/url"
@ -30,7 +31,6 @@ import (
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
providerTypes "github.com/Dreamacro/clash/constant/provider"
"github.com/Dreamacro/clash/constant/sniffer"
snifferTypes "github.com/Dreamacro/clash/constant/sniffer"
"github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log"
@ -55,24 +55,18 @@ type General struct {
EnableProcess bool `json:"enable-process"`
Tun Tun `json:"tun"`
Sniffing bool `json:"sniffing"`
EBpf EBpf `json:"-"`
}
// Inbound config
type Inbound struct {
Port int `json:"port"`
SocksPort int `json:"socks-port"`
RedirPort int `json:"redir-port"`
TProxyPort int `json:"tproxy-port"`
MixedPort int `json:"mixed-port"`
ShadowSocksConfig string `json:"ss-config"`
VmessConfig string `json:"vmess-config"`
TcpTunConfig string `json:"tcptun-config"`
UdpTunConfig string `json:"udptun-config"`
Authentication []string `json:"authentication"`
AllowLan bool `json:"allow-lan"`
BindAddress string `json:"bind-address"`
InboundTfo bool `json:"inbound-tfo"`
Port int `json:"port"`
SocksPort int `json:"socks-port"`
RedirPort int `json:"redir-port"`
TProxyPort int `json:"tproxy-port"`
MixedPort int `json:"mixed-port"`
Authentication []string `json:"authentication"`
AllowLan bool `json:"allow-lan"`
BindAddress string `json:"bind-address"`
}
// Controller config
@ -85,7 +79,6 @@ type Controller struct {
// DNS config
type DNS struct {
Enable bool `yaml:"enable"`
PreferH3 bool `yaml:"prefer-h3"`
IPv6 bool `yaml:"ipv6"`
NameServer []dns.NameServer `yaml:"nameserver"`
Fallback []dns.NameServer `yaml:"fallback"`
@ -122,73 +115,7 @@ type Tun struct {
DNSHijack []netip.AddrPort `yaml:"dns-hijack" json:"dns-hijack"`
AutoRoute bool `yaml:"auto-route" json:"auto-route"`
AutoDetectInterface bool `yaml:"auto-detect-interface" json:"auto-detect-interface"`
RedirectToTun []string `yaml:"-" json:"-"`
MTU uint32 `yaml:"mtu" json:"mtu,omitempty"`
Inet4Address []ListenPrefix `yaml:"inet4-address" json:"inet4-address,omitempty"`
Inet6Address []ListenPrefix `yaml:"inet6-address" json:"inet6-address,omitempty"`
StrictRoute bool `yaml:"strict-route" json:"strict-route,omitempty"`
Inet4RouteAddress []ListenPrefix `yaml:"inet4-route-address" json:"inet4-route-address,omitempty"`
Inet6RouteAddress []ListenPrefix `yaml:"inet6-route-address" json:"inet6-route-address,omitempty"`
IncludeUID []uint32 `yaml:"include-uid" json:"include-uid,omitempty"`
IncludeUIDRange []string `yaml:"include-uid-range" json:"include-uid-range,omitempty"`
ExcludeUID []uint32 `yaml:"exclude-uid" json:"exclude-uid,omitempty"`
ExcludeUIDRange []string `yaml:"exclude-uid-range" json:"exclude-uid-range,omitempty"`
IncludeAndroidUser []int `yaml:"include-android-user" json:"include-android-user,omitempty"`
IncludePackage []string `yaml:"include-package" json:"include-package,omitempty"`
ExcludePackage []string `yaml:"exclude-package" json:"exclude-package,omitempty"`
EndpointIndependentNat bool `yaml:"endpoint-independent-nat" json:"endpoint-independent-nat,omitempty"`
UDPTimeout int64 `yaml:"udp-timeout" json:"udp-timeout,omitempty"`
}
type ListenPrefix netip.Prefix
func (p ListenPrefix) MarshalJSON() ([]byte, error) {
prefix := netip.Prefix(p)
if !prefix.IsValid() {
return json.Marshal(nil)
}
return json.Marshal(prefix.String())
}
func (p ListenPrefix) MarshalYAML() (interface{}, error) {
prefix := netip.Prefix(p)
if !prefix.IsValid() {
return nil, nil
}
return prefix.String(), nil
}
func (p *ListenPrefix) UnmarshalJSON(bytes []byte) error {
var value string
err := json.Unmarshal(bytes, &value)
if err != nil {
return err
}
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*p = ListenPrefix(prefix)
return nil
}
func (p *ListenPrefix) UnmarshalYAML(node *yaml.Node) error {
var value string
err := node.Decode(&value)
if err != nil {
return err
}
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*p = ListenPrefix(prefix)
return nil
}
func (p ListenPrefix) Build() netip.Prefix {
return netip.Prefix(p)
TunAddressPrefix netip.Prefix `yaml:"-" json:"-"`
}
// IPTables config
@ -199,31 +126,27 @@ type IPTables struct {
}
type Sniffer struct {
Enable bool
Sniffers []sniffer.Type
Reverses *trie.DomainTrie[struct{}]
ForceDomain *trie.DomainTrie[struct{}]
SkipDomain *trie.DomainTrie[struct{}]
Ports *[]utils.Range[uint16]
ForceDnsMapping bool
ParsePureIp bool
Enable bool
Sniffers []sniffer.Type
Reverses *trie.DomainTrie[bool]
ForceDomain *trie.DomainTrie[bool]
SkipDomain *trie.DomainTrie[bool]
Ports *[]utils.Range[uint16]
}
// Experimental config
type Experimental struct {
Fingerprints []string `yaml:"fingerprints"`
}
type Experimental struct{}
// Config is clash config manager
type Config struct {
General *General
Tun *Tun
IPTables *IPTables
DNS *DNS
Experimental *Experimental
Hosts *trie.DomainTrie[netip.Addr]
Profile *Profile
Rules []C.Rule
SubRules *map[string][]C.Rule
Users []auth.AuthUser
Proxies map[string]C.Proxy
Providers map[string]providerTypes.ProxyProvider
@ -233,7 +156,6 @@ type Config struct {
type RawDNS struct {
Enable bool `yaml:"enable"`
PreferH3 bool `yaml:"prefer-h3"`
IPv6 bool `yaml:"ipv6"`
UseHosts bool `yaml:"use-hosts"`
NameServer []string `yaml:"nameserver"`
@ -263,23 +185,6 @@ type RawTun struct {
DNSHijack []string `yaml:"dns-hijack" json:"dns-hijack"`
AutoRoute bool `yaml:"auto-route" json:"auto-route"`
AutoDetectInterface bool `yaml:"auto-detect-interface"`
RedirectToTun []string `yaml:"-" json:"-"`
MTU uint32 `yaml:"mtu" json:"mtu,omitempty"`
//Inet4Address []ListenPrefix `yaml:"inet4-address" json:"inet4_address,omitempty"`
Inet6Address []ListenPrefix `yaml:"inet6-address" json:"inet6_address,omitempty"`
StrictRoute bool `yaml:"strict-route" json:"strict_route,omitempty"`
Inet4RouteAddress []ListenPrefix `yaml:"inet4_route_address" json:"inet4_route_address,omitempty"`
Inet6RouteAddress []ListenPrefix `yaml:"inet6_route_address" json:"inet6_route_address,omitempty"`
IncludeUID []uint32 `yaml:"include-uid" json:"include_uid,omitempty"`
IncludeUIDRange []string `yaml:"include-uid-range" json:"include_uid_range,omitempty"`
ExcludeUID []uint32 `yaml:"exclude-uid" json:"exclude_uid,omitempty"`
ExcludeUIDRange []string `yaml:"exclude-uid-range" json:"exclude_uid_range,omitempty"`
IncludeAndroidUser []int `yaml:"include-android-user" json:"include_android_user,omitempty"`
IncludePackage []string `yaml:"include-package" json:"include_package,omitempty"`
ExcludePackage []string `yaml:"exclude-package" json:"exclude_package,omitempty"`
EndpointIndependentNat bool `yaml:"endpoint-independent-nat" json:"endpoint_independent_nat,omitempty"`
UDPTimeout int64 `yaml:"udp-timeout" json:"udp_timeout,omitempty"`
}
type RawConfig struct {
@ -288,11 +193,6 @@ type RawConfig struct {
RedirPort int `yaml:"redir-port"`
TProxyPort int `yaml:"tproxy-port"`
MixedPort int `yaml:"mixed-port"`
ShadowSocksConfig string `yaml:"ss-config"`
VmessConfig string `yaml:"vmess-config"`
TcpTunConfig string `yaml:"tcptun-config"`
UdpTunConfig string `yaml:"udptun-config"`
InboundTfo bool `yaml:"inbound-tfo"`
Authentication []string `yaml:"authentication"`
AllowLan bool `yaml:"allow-lan"`
BindAddress string `yaml:"bind-address"`
@ -316,7 +216,6 @@ type RawConfig struct {
Hosts map[string]string `yaml:"hosts"`
DNS RawDNS `yaml:"dns"`
Tun RawTun `yaml:"tun"`
EBpf EBpf `yaml:"ebpf"`
IPTables IPTables `yaml:"iptables"`
Experimental Experimental `yaml:"experimental"`
Profile Profile `yaml:"profile"`
@ -324,7 +223,6 @@ type RawConfig struct {
Proxy []map[string]any `yaml:"proxies"`
ProxyGroup []map[string]any `yaml:"proxy-groups"`
Rule []string `yaml:"rules"`
SubRules map[string][]string `yaml:"sub-rules"`
}
type RawGeoXUrl struct {
@ -334,27 +232,13 @@ type RawGeoXUrl struct {
}
type RawSniffer struct {
Enable bool `yaml:"enable" json:"enable"`
Sniffing []string `yaml:"sniffing" json:"sniffing"`
ForceDomain []string `yaml:"force-domain" json:"force-domain"`
SkipDomain []string `yaml:"skip-domain" json:"skip-domain"`
Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"`
ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"`
Enable bool `yaml:"enable" json:"enable"`
Sniffing []string `yaml:"sniffing" json:"sniffing"`
ForceDomain []string `yaml:"force-domain" json:"force-domain"`
SkipDomain []string `yaml:"skip-domain" json:"skip-domain"`
Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
}
// EBpf config
type EBpf struct {
RedirectToTun []string `yaml:"redirect-to-tun" json:"redirect-to-tun"`
AutoRedir []string `yaml:"auto-redir" json:"auto-redir"`
}
var (
GroupsList = list.New()
ProxiesList = list.New()
ParsingProxiesCallback func(groupsList *list.List, proxiesList *list.List)
)
// Parse config
func Parse(buf []byte) (*Config, error) {
rawCfg, err := UnmarshalRawConfig(buf)
@ -388,13 +272,8 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
Device: "",
Stack: C.TunGvisor,
DNSHijack: []string{"0.0.0.0:53"}, // default hijack all dns query
AutoRoute: true,
AutoDetectInterface: true,
Inet6Address: []ListenPrefix{ListenPrefix(netip.MustParsePrefix("fdfe:dcba:9876::1/126"))},
},
EBpf: EBpf{
RedirectToTun: []string{},
AutoRedir: []string{},
AutoRoute: false,
AutoDetectInterface: false,
},
IPTables: IPTables{
Enable: false,
@ -430,13 +309,11 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
},
},
Sniffer: RawSniffer{
Enable: false,
Sniffing: []string{},
ForceDomain: []string{},
SkipDomain: []string{},
Ports: []string{},
ForceDnsMapping: true,
ParsePureIp: true,
Enable: false,
Sniffing: []string{},
ForceDomain: []string{},
SkipDomain: []string{},
Ports: []string{},
},
Profile: Profile{
StoreSelected: true,
@ -478,18 +355,12 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) {
config.Proxies = proxies
config.Providers = providers
subRules, ruleProviders, err := parseSubRules(rawCfg, proxies)
if err != nil {
return nil, err
}
config.SubRules = subRules
config.RuleProviders = ruleProviders
rules, err := parseRules(rawCfg, proxies, subRules)
rules, ruleProviders, err := parseRules(rawCfg, proxies)
if err != nil {
return nil, err
}
config.Rules = rules
config.RuleProviders = ruleProviders
hosts, err := parseHosts(rawCfg)
if err != nil {
@ -503,10 +374,11 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) {
}
config.DNS = dnsCfg
err = parseTun(rawCfg.Tun, config.General)
tunCfg, err := parseTun(rawCfg.Tun, config.General, dnsCfg)
if err != nil {
return nil, err
}
config.Tun = tunCfg
config.Users = parseAuthentication(rawCfg.Authentication)
@ -531,21 +403,16 @@ func parseGeneral(cfg *RawConfig) (*General, error) {
return nil, fmt.Errorf("external-ui: %s not exist", externalUI)
}
}
cfg.Tun.RedirectToTun = cfg.EBpf.RedirectToTun
return &General{
Inbound: Inbound{
Port: cfg.Port,
SocksPort: cfg.SocksPort,
RedirPort: cfg.RedirPort,
TProxyPort: cfg.TProxyPort,
MixedPort: cfg.MixedPort,
ShadowSocksConfig: cfg.ShadowSocksConfig,
VmessConfig: cfg.VmessConfig,
TcpTunConfig: cfg.TcpTunConfig,
UdpTunConfig: cfg.UdpTunConfig,
AllowLan: cfg.AllowLan,
BindAddress: cfg.BindAddress,
InboundTfo: cfg.InboundTfo,
Port: cfg.Port,
SocksPort: cfg.SocksPort,
RedirPort: cfg.RedirPort,
TProxyPort: cfg.TProxyPort,
MixedPort: cfg.MixedPort,
AllowLan: cfg.AllowLan,
BindAddress: cfg.BindAddress,
},
Controller: Controller{
ExternalController: cfg.ExternalController,
@ -562,7 +429,6 @@ func parseGeneral(cfg *RawConfig) (*General, error) {
GeodataLoader: cfg.GeodataLoader,
TCPConcurrent: cfg.TCPConcurrent,
EnableProcess: cfg.EnableProcess,
EBpf: cfg.EBpf,
}, nil
}
@ -660,18 +526,12 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[
[]providerTypes.ProxyProvider{pd},
)
proxies["GLOBAL"] = adapter.NewProxy(global)
ProxiesList = proxiesList
GroupsList = groupsList
if ParsingProxiesCallback != nil {
// refresh tray menu
go ParsingProxiesCallback(GroupsList, ProxiesList)
}
return proxies, providersMap, nil
}
func parseSubRules(cfg *RawConfig, proxies map[string]C.Proxy) (subRules *map[string][]C.Rule, ruleProviders map[string]providerTypes.RuleProvider, err error) {
ruleProviders = map[string]providerTypes.RuleProvider{}
subRules = &map[string][]C.Rule{}
func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, map[string]providerTypes.RuleProvider, error) {
ruleProviders := map[string]providerTypes.RuleProvider{}
log.Infoln("Geodata Loader mode: %s", geodata.LoaderName())
// parse rule provider
for name, mapping := range cfg.RuleProvider {
@ -684,102 +544,6 @@ func parseSubRules(cfg *RawConfig, proxies map[string]C.Proxy) (subRules *map[st
RP.SetRuleProvider(rp)
}
for name, rawRules := range cfg.SubRules {
var rules []C.Rule
for idx, line := range rawRules {
rawRule := trimArr(strings.Split(line, ","))
var (
payload string
target string
params []string
ruleName = strings.ToUpper(rawRule[0])
)
l := len(rawRule)
if ruleName == "NOT" || ruleName == "OR" || ruleName == "AND" || ruleName == "SUB-RULE" {
target = rawRule[l-1]
payload = strings.Join(rawRule[1:l-1], ",")
} else {
if l < 2 {
return nil, nil, fmt.Errorf("sub-rules[%d] [%s] error: format invalid", idx, line)
}
if l < 4 {
rawRule = append(rawRule, make([]string, 4-l)...)
}
if ruleName == "MATCH" {
l = 2
}
if l >= 3 {
l = 3
payload = rawRule[1]
}
target = rawRule[l-1]
params = rawRule[l:]
}
if _, ok := proxies[target]; !ok && ruleName != "SUB-RULE" {
return nil, nil, fmt.Errorf("sub-rules[%d:%s] [%s] error: proxy [%s] not found", idx, name, line, target)
}
params = trimArr(params)
parsed, parseErr := R.ParseRule(ruleName, payload, target, params, subRules)
if parseErr != nil {
return nil, nil, fmt.Errorf("sub-rules[%d] [%s] error: %s", idx, line, parseErr.Error())
}
rules = append(rules, parsed)
}
(*subRules)[name] = rules
}
if err = verifySubRule(subRules); err != nil {
return nil, nil, err
}
return
}
func verifySubRule(subRules *map[string][]C.Rule) error {
for name := range *subRules {
err := verifySubRuleCircularReferences(name, subRules, []string{})
if err != nil {
return err
}
}
return nil
}
func verifySubRuleCircularReferences(n string, subRules *map[string][]C.Rule, arr []string) error {
isInArray := func(v string, array []string) bool {
for _, c := range array {
if v == c {
return true
}
}
return false
}
arr = append(arr, n)
for i, rule := range (*subRules)[n] {
if rule.RuleType() == C.SubRules {
if _, ok := (*subRules)[rule.Adapter()]; !ok {
return fmt.Errorf("sub-rule[%d:%s] error: [%s] not found", i, n, rule.Adapter())
}
if isInArray(rule.Adapter(), arr) {
arr = append(arr, rule.Adapter())
return fmt.Errorf("sub-rule error: circular references [%s]", strings.Join(arr, "->"))
}
if err := verifySubRuleCircularReferences(rule.Adapter(), subRules, arr); err != nil {
return err
}
}
}
return nil
}
func parseRules(cfg *RawConfig, proxies map[string]C.Proxy, subRules *map[string][]C.Rule) ([]C.Rule, error) {
var rules []C.Rule
rulesConfig := cfg.Rule
@ -795,12 +559,12 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy, subRules *map[string
l := len(rule)
if ruleName == "NOT" || ruleName == "OR" || ruleName == "AND" || ruleName == "SUB-RULE" {
if ruleName == "NOT" || ruleName == "OR" || ruleName == "AND" {
target = rule[l-1]
payload = strings.Join(rule[1:l-1], ",")
} else {
if l < 2 {
return nil, fmt.Errorf("rules[%d] [%s] error: format invalid", idx, line)
return nil, nil, fmt.Errorf("rules[%d] [%s] error: format invalid", idx, line)
}
if l < 4 {
rule = append(rule, make([]string, 4-l)...)
@ -815,18 +579,15 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy, subRules *map[string
target = rule[l-1]
params = rule[l:]
}
if _, ok := proxies[target]; !ok {
if ruleName != "SUB-RULE" {
return nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target)
} else if _, ok = (*subRules)[target]; !ok {
return nil, fmt.Errorf("rules[%d] [%s] error: sub-rule [%s] not found", idx, line, target)
}
return nil, nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target)
}
params = trimArr(params)
parsed, parseErr := R.ParseRule(ruleName, payload, target, params, subRules)
parsed, parseErr := R.ParseRule(ruleName, payload, target, params)
if parseErr != nil {
return nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error())
return nil, nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error())
}
rules = append(rules, parsed)
@ -834,7 +595,7 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy, subRules *map[string
runtime.GC()
return rules, nil
return rules, ruleProviders, nil
}
func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) {
@ -875,7 +636,7 @@ func hostWithDefaultPort(host string, defPort string) (string, error) {
return net.JoinHostPort(hostname, port), nil
}
func parseNameServer(servers []string, preferH3 bool) ([]dns.NameServer, error) {
func parseNameServer(servers []string) ([]dns.NameServer, error) {
var nameservers []dns.NameServer
for idx, server := range servers {
@ -888,8 +649,7 @@ func parseNameServer(servers []string, preferH3 bool) ([]dns.NameServer, error)
return nil, fmt.Errorf("DNS NameServer[%d] format error: %s", idx, err.Error())
}
var addr, dnsNetType, proxyAdapter string
params := map[string]string{}
var addr, dnsNetType string
switch u.Scheme {
case "udp":
addr, err = hostWithDefaultPort(u.Host, "53")
@ -901,31 +661,9 @@ func parseNameServer(servers []string, preferH3 bool) ([]dns.NameServer, error)
addr, err = hostWithDefaultPort(u.Host, "853")
dnsNetType = "tcp-tls" // DNS over TLS
case "https":
host := u.Host
if _, _, err := net.SplitHostPort(host); err != nil && strings.Contains(err.Error(), "missing port in address") {
host = net.JoinHostPort(host, "443")
} else {
if err!=nil{
return nil,err
}
}
clearURL := url.URL{Scheme: "https", Host: host, Path: u.Path}
clearURL := url.URL{Scheme: "https", Host: u.Host, Path: u.Path}
addr = clearURL.String()
dnsNetType = "https" // DNS over HTTPS
if len(u.Fragment) != 0 {
for _, s := range strings.Split(u.Fragment, "&") {
arr := strings.Split(s, "=")
if len(arr) == 0 {
continue
} else if len(arr) == 1 {
proxyAdapter = arr[0]
} else if len(arr) == 2 {
params[arr[0]] = arr[1]
} else {
params[arr[0]] = strings.Join(arr[1:], "=")
}
}
}
case "dhcp":
addr = u.Host
dnsNetType = "dhcp" // UDP from DHCP
@ -945,21 +683,19 @@ func parseNameServer(servers []string, preferH3 bool) ([]dns.NameServer, error)
dns.NameServer{
Net: dnsNetType,
Addr: addr,
ProxyAdapter: proxyAdapter,
ProxyAdapter: u.Fragment,
Interface: dialer.DefaultInterface,
Params: params,
PreferH3: preferH3,
},
)
}
return nameservers, nil
}
func parseNameServerPolicy(nsPolicy map[string]string, preferH3 bool) (map[string]dns.NameServer, error) {
func parseNameServerPolicy(nsPolicy map[string]string) (map[string]dns.NameServer, error) {
policy := map[string]dns.NameServer{}
for domain, server := range nsPolicy {
nameservers, err := parseNameServer([]string{server}, preferH3)
nameservers, err := parseNameServer([]string{server})
if err != nil {
return nil, err
}
@ -1030,7 +766,6 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
dnsCfg := &DNS{
Enable: cfg.Enable,
Listen: cfg.Listen,
PreferH3: cfg.PreferH3,
IPv6: cfg.IPv6,
EnhancedMode: cfg.EnhancedMode,
FallbackFilter: FallbackFilter{
@ -1039,26 +774,26 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
},
}
var err error
if dnsCfg.NameServer, err = parseNameServer(cfg.NameServer, cfg.PreferH3); err != nil {
if dnsCfg.NameServer, err = parseNameServer(cfg.NameServer); err != nil {
return nil, err
}
if dnsCfg.Fallback, err = parseNameServer(cfg.Fallback, cfg.PreferH3); err != nil {
if dnsCfg.Fallback, err = parseNameServer(cfg.Fallback); err != nil {
return nil, err
}
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, cfg.PreferH3); err != nil {
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy); err != nil {
return nil, err
}
if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver, cfg.PreferH3); err != nil {
if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver); err != nil {
return nil, err
}
if len(cfg.DefaultNameserver) == 0 {
return nil, errors.New("default nameserver should have at least one nameserver")
}
if dnsCfg.DefaultNameserver, err = parseNameServer(cfg.DefaultNameserver, cfg.PreferH3); err != nil {
if dnsCfg.DefaultNameserver, err = parseNameServer(cfg.DefaultNameserver); err != nil {
return nil, err
}
// check default nameserver is pure ip addr
@ -1066,44 +801,41 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
host, _, err := net.SplitHostPort(ns.Addr)
if err != nil || net.ParseIP(host) == nil {
u, err := url.Parse(ns.Addr)
if err == nil && net.ParseIP(u.Host) == nil {
if ip, _, err := net.SplitHostPort(u.Host); err != nil || net.ParseIP(ip) == nil {
return nil, errors.New("default nameserver should be pure IP")
}
if err != nil || net.ParseIP(u.Host) == nil {
return nil, errors.New("default nameserver should be pure IP")
}
}
}
fakeIPRange, err := netip.ParsePrefix(cfg.FakeIPRange)
T.SetFakeIPRange(fakeIPRange)
if cfg.EnhancedMode == C.DNSFakeIP {
ipnet, err := netip.ParsePrefix(cfg.FakeIPRange)
if err != nil {
return nil, err
}
var host *trie.DomainTrie[struct{}]
var host *trie.DomainTrie[bool]
// fake ip skip host filter
if len(cfg.FakeIPFilter) != 0 {
host = trie.New[struct{}]()
host = trie.New[bool]()
for _, domain := range cfg.FakeIPFilter {
_ = host.Insert(domain, struct{}{})
_ = host.Insert(domain, true)
}
}
if len(dnsCfg.Fallback) != 0 {
if host == nil {
host = trie.New[struct{}]()
host = trie.New[bool]()
}
for _, fb := range dnsCfg.Fallback {
if net.ParseIP(fb.Addr) != nil {
continue
}
_ = host.Insert(fb.Addr, struct{}{})
_ = host.Insert(fb.Addr, true)
}
}
pool, err := fakeip.New(fakeip.Options{
IPNet: &fakeIPRange,
IPNet: &ipnet,
Size: 1000,
Host: host,
Persistence: rawCfg.Profile.StoreFakeIP,
@ -1146,7 +878,18 @@ func parseAuthentication(rawRecords []string) []auth.AuthUser {
return users
}
func parseTun(rawTun RawTun, general *General) error {
func parseTun(rawTun RawTun, general *General, dnsCfg *DNS) (*Tun, error) {
if rawTun.Enable && rawTun.AutoDetectInterface {
autoDetectInterfaceName, err := commons.GetAutoDetectInterface()
if err != nil {
log.Warnln("Can not find auto detect interface.[%s]", err)
} else {
log.Warnln("Auto detect interface: %s", autoDetectInterfaceName)
}
general.Interface = autoDetectInterfaceName
}
var dnsHijack []netip.AddrPort
for _, d := range rawTun.DNSHijack {
@ -1156,58 +899,38 @@ func parseTun(rawTun RawTun, general *General) error {
d = strings.Replace(d, "any", "0.0.0.0", 1)
addrPort, err := netip.ParseAddrPort(d)
if err != nil {
return fmt.Errorf("parse dns-hijack url error: %w", err)
return nil, fmt.Errorf("parse dns-hijack url error: %w", err)
}
dnsHijack = append(dnsHijack, addrPort)
}
tunAddressPrefix := T.FakeIPRange()
if !tunAddressPrefix.IsValid() {
var tunAddressPrefix netip.Prefix
if dnsCfg.FakeIPRange != nil {
tunAddressPrefix = *dnsCfg.FakeIPRange.IPNet()
} else {
tunAddressPrefix = netip.MustParsePrefix("198.18.0.1/16")
}
tunAddressPrefix = netip.PrefixFrom(tunAddressPrefix.Addr(), 30)
general.Tun = Tun{
return &Tun{
Enable: rawTun.Enable,
Device: rawTun.Device,
Stack: rawTun.Stack,
DNSHijack: dnsHijack,
AutoRoute: rawTun.AutoRoute,
AutoDetectInterface: rawTun.AutoDetectInterface,
RedirectToTun: rawTun.RedirectToTun,
MTU: rawTun.MTU,
Inet4Address: []ListenPrefix{ListenPrefix(tunAddressPrefix)},
Inet6Address: rawTun.Inet6Address,
StrictRoute: rawTun.StrictRoute,
Inet4RouteAddress: rawTun.Inet4RouteAddress,
Inet6RouteAddress: rawTun.Inet6RouteAddress,
IncludeUID: rawTun.IncludeUID,
IncludeUIDRange: rawTun.IncludeUIDRange,
ExcludeUID: rawTun.ExcludeUID,
ExcludeUIDRange: rawTun.ExcludeUIDRange,
IncludeAndroidUser: rawTun.IncludeAndroidUser,
IncludePackage: rawTun.IncludePackage,
ExcludePackage: rawTun.ExcludePackage,
EndpointIndependentNat: rawTun.EndpointIndependentNat,
UDPTimeout: rawTun.UDPTimeout,
}
return nil
TunAddressPrefix: tunAddressPrefix,
}, nil
}
func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
sniffer := &Sniffer{
Enable: snifferRaw.Enable,
ForceDnsMapping: snifferRaw.ForceDnsMapping,
ParsePureIp: snifferRaw.ParsePureIp,
Enable: snifferRaw.Enable,
}
var ports []utils.Range[uint16]
if len(snifferRaw.Ports) == 0 {
ports = append(ports, *utils.NewRange[uint16](80, 80))
ports = append(ports, *utils.NewRange[uint16](443, 443))
ports = append(ports, *utils.NewRange[uint16](0, 65535))
} else {
for _, portRange := range snifferRaw.Ports {
portRaws := strings.Split(portRange, "-")
@ -1252,17 +975,17 @@ func parseSniffer(snifferRaw RawSniffer) (*Sniffer, error) {
for st := range loadSniffer {
sniffer.Sniffers = append(sniffer.Sniffers, st)
}
sniffer.ForceDomain = trie.New[struct{}]()
sniffer.ForceDomain = trie.New[bool]()
for _, domain := range snifferRaw.ForceDomain {
err := sniffer.ForceDomain.Insert(domain, struct{}{})
err := sniffer.ForceDomain.Insert(domain, true)
if err != nil {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
}
}
sniffer.SkipDomain = trie.New[struct{}]()
sniffer.SkipDomain = trie.New[bool]()
for _, domain := range snifferRaw.SkipDomain {
err := sniffer.SkipDomain.Insert(domain, struct{}{})
err := sniffer.SkipDomain.Insert(domain, true)
if err != nil {
return nil, fmt.Errorf("error domian[%s] in force-domain, error:%v", domain, err)
}

View File

@ -6,9 +6,8 @@ import (
_ "github.com/Dreamacro/clash/component/geodata/standard"
C "github.com/Dreamacro/clash/constant"
"github.com/oschwald/geoip2-golang"
"io"
"io/ioutil"
"net/http"
"os"
"runtime"
)
@ -73,9 +72,9 @@ func downloadForBytes(url string) ([]byte, error) {
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
return ioutil.ReadAll(resp.Body)
}
func saveFile(bytes []byte, path string) error {
return os.WriteFile(path, bytes, 0o644)
return ioutil.WriteFile(path, bytes, 0o644)
}

View File

@ -31,7 +31,6 @@ const (
Vless
Trojan
Hysteria
WireGuard
)
const (
@ -107,13 +106,12 @@ type ProxyAdapter interface {
ListenPacketOnStreamConn(c net.Conn, metadata *Metadata) (PacketConn, error)
// Unwrap extracts the proxy from a proxy-group. It returns nil when nothing to extract.
Unwrap(metadata *Metadata, touch bool) Proxy
Unwrap(metadata *Metadata) Proxy
}
type Group interface {
URLTest(ctx context.Context, url string) (mp map[string]uint16, err error)
GetProxies(touch bool) []Proxy
Touch()
}
type DelayHistory struct {
@ -166,8 +164,6 @@ func (at AdapterType) String() string {
return "Trojan"
case Hysteria:
return "Hysteria"
case WireGuard:
return "WireGuard"
case Relay:
return "Relay"
@ -202,7 +198,3 @@ type UDPPacket interface {
// LocalAddr returns the source IP/Port of packet
LocalAddr() net.Addr
}
type UDPPacketInAddr interface {
InAddr() net.Addr
}

View File

@ -16,7 +16,6 @@ const (
DNSNormal DNSMode = iota
DNSFakeIP
DNSMapping
DNSHosts
)
type DNSMode int
@ -65,63 +64,7 @@ func (e DNSMode) String() string {
return "fake-ip"
case DNSMapping:
return "redir-host"
case DNSHosts:
return "hosts"
default:
return "unknown"
}
}
type DNSPrefer int
const (
DualStack DNSPrefer = iota
IPv4Only
IPv6Only
IPv4Prefer
IPv6Prefer
)
var dnsPreferMap = map[string]DNSPrefer{
DualStack.String(): DualStack,
IPv4Only.String(): IPv4Only,
IPv6Only.String(): IPv6Only,
IPv4Prefer.String(): IPv4Prefer,
IPv6Prefer.String(): IPv6Prefer,
}
func (d DNSPrefer) String() string {
switch d {
case DualStack:
return "dual"
case IPv4Only:
return "ipv4"
case IPv6Only:
return "ipv6"
case IPv4Prefer:
return "ipv4-prefer"
case IPv6Prefer:
return "ipv6-prefer"
default:
return "dual"
}
}
func NewDNSPrefer(prefer string) DNSPrefer {
if p, ok := dnsPreferMap[prefer]; ok {
return p
} else {
return DualStack
}
}
type HTTPVersion string
const (
// HTTPVersion11 is HTTP/1.1.
HTTPVersion11 HTTPVersion = "http/1.1"
// HTTPVersion2 is HTTP/2.
HTTPVersion2 HTTPVersion = "h2"
// HTTPVersion3 is HTTP/3.
HTTPVersion3 HTTPVersion = "h3"
)

View File

@ -1,20 +0,0 @@
package constant
import (
"net/netip"
"github.com/Dreamacro/clash/transport/socks5"
)
const (
BpfFSPath = "/sys/fs/bpf/clash"
TcpAutoRedirPort = 't'<<8 | 'r'<<0
ClashTrafficMark = 'c'<<24 | 'l'<<16 | 't'<<8 | 'm'<<0
)
type EBpf interface {
Start() error
Close()
Lookup(srcAddrPort netip.AddrPort) (socks5.Addr, error)
}

View File

@ -1,15 +1,7 @@
package constant
import "net"
type Listener interface {
RawAddress() string
Address() string
Close() error
}
type AdvanceListener interface {
Close()
Config() string
HandleConn(conn net.Conn, in chan<- ConnContext)
}

View File

@ -6,12 +6,14 @@ import (
"net"
"net/netip"
"strconv"
"github.com/Dreamacro/clash/transport/socks5"
)
// Socks addr type
const (
AtypIPv4 = 1
AtypDomainName = 3
AtypIPv6 = 4
TCP NetWork = iota
UDP
ALLNet
@ -20,12 +22,8 @@ const (
HTTPS
SOCKS4
SOCKS5
SHADOWSOCKS
VMESS
REDIR
TPROXY
TCPTUN
UDPTUN
TUN
INNER
)
@ -57,18 +55,10 @@ func (t Type) String() string {
return "Socks4"
case SOCKS5:
return "Socks5"
case SHADOWSOCKS:
return "ShadowSocks"
case VMESS:
return "Vmess"
case REDIR:
return "Redir"
case TPROXY:
return "TProxy"
case TCPTUN:
return "TcpTun"
case UDPTUN:
return "UdpTun"
case TUN:
return "Tun"
case INNER:
@ -115,8 +105,7 @@ type Metadata struct {
DstIP netip.Addr `json:"destinationIP"`
SrcPort string `json:"sourcePort"`
DstPort string `json:"destinationPort"`
InIP netip.Addr `json:"inboundIP"`
InPort string `json:"inboundPort"`
AddrType int `json:"-"`
Host string `json:"host"`
DNSMode DNSMode `json:"dnsMode"`
Uid *int32 `json:"uid"`
@ -149,17 +138,6 @@ func (m *Metadata) SourceDetail() string {
}
}
func (m *Metadata) AddrType() int {
switch true {
case m.Host != "" || !m.DstIP.IsValid():
return socks5.AtypDomainName
case m.DstIP.Is4():
return socks5.AtypIPv4
default:
return socks5.AtypIPv6
}
}
func (m *Metadata) Resolved() bool {
return m.DstIP.IsValid()
}
@ -167,9 +145,14 @@ func (m *Metadata) Resolved() bool {
// Pure is used to solve unexpected behavior
// when dialing proxy connection in DNSMapping mode.
func (m *Metadata) Pure() *Metadata {
if (m.DNSMode == DNSMapping || m.DNSMode == DNSHosts) && m.DstIP.IsValid() {
if m.DNSMode == DNSMapping && m.DstIP.IsValid() {
copyM := *m
copyM.Host = ""
if copyM.DstIP.Is4() {
copyM.AddrType = AtypIPv4
} else {
copyM.AddrType = AtypIPv6
}
return &copyM
}

View File

@ -1,6 +1,7 @@
package constant
import (
"io/ioutil"
"os"
P "path"
"path/filepath"
@ -57,7 +58,7 @@ func (p *path) Resolve(path string) string {
}
func (p *path) MMDB() string {
files, err := os.ReadDir(p.homeDir)
files, err := ioutil.ReadDir(p.homeDir)
if err != nil {
return ""
}
@ -84,7 +85,7 @@ func (p *path) Cache() string {
}
func (p *path) GeoIP() string {
files, err := os.ReadDir(p.homeDir)
files, err := ioutil.ReadDir(p.homeDir)
if err != nil {
return ""
}
@ -103,7 +104,7 @@ func (p *path) GeoIP() string {
}
func (p *path) GeoSite() string {
files, err := os.ReadDir(p.homeDir)
files, err := ioutil.ReadDir(p.homeDir)
if err != nil {
return ""
}

View File

@ -68,7 +68,7 @@ type ProxyProvider interface {
Proxies() []C.Proxy
Touch()
HealthCheck()
Version() uint32
Version() uint
}
// Rule Type

View File

@ -13,14 +13,12 @@ const (
SrcIPSuffix
SrcPort
DstPort
InPort
Process
ProcessPath
RuleSet
Network
Uid
INTYPE
SubRules
MATCH
AND
OR
@ -53,8 +51,6 @@ func (rt RuleType) String() string {
return "SrcPort"
case DstPort:
return "DstPort"
case InPort:
return "InPort"
case Process:
return "Process"
case ProcessPath:
@ -69,8 +65,6 @@ func (rt RuleType) String() string {
return "Uid"
case INTYPE:
return "InType"
case SubRules:
return "SubRules"
case AND:
return "AND"
case OR:
@ -84,9 +78,11 @@ func (rt RuleType) String() string {
type Rule interface {
RuleType() RuleType
Match(metadata *Metadata) (bool, string)
Match(metadata *Metadata) bool
Adapter() string
Payload() string
ShouldResolveIP() bool
ShouldFindProcess() bool
RuleExtra() *RuleExtra
SetRuleExtra(re *RuleExtra)
}

View File

@ -1,9 +1,48 @@
package constant
import (
"net/netip"
"strings"
"github.com/Dreamacro/clash/component/geodata/router"
)
type RuleExtra struct {
Network NetWork
SourceIPs []*netip.Prefix
ProcessNames []string
}
func (re *RuleExtra) NotMatchNetwork(network NetWork) bool {
return re.Network != ALLNet && re.Network != network
}
func (re *RuleExtra) NotMatchSourceIP(srcIP netip.Addr) bool {
if re.SourceIPs == nil {
return false
}
for _, ips := range re.SourceIPs {
if ips.Contains(srcIP) {
return false
}
}
return true
}
func (re *RuleExtra) NotMatchProcessName(processName string) bool {
if re.ProcessNames == nil {
return false
}
for _, pn := range re.ProcessNames {
if strings.EqualFold(pn, processName) {
return false
}
}
return true
}
type RuleGeoSite interface {
GetDomainMatcher() *router.DomainMatcher
}

View File

@ -7,15 +7,13 @@ import (
)
var StackTypeMapping = map[string]TUNStack{
strings.ToLower(TunGvisor.String()): TunGvisor,
strings.ToLower(TunSystem.String()): TunSystem,
strings.ToLower(TunLWIP.String()): TunLWIP,
strings.ToUpper(TunGvisor.String()): TunGvisor,
strings.ToUpper(TunSystem.String()): TunSystem,
}
const (
TunGvisor TUNStack = iota
TunSystem
TunLWIP
)
type TUNStack int
@ -26,7 +24,7 @@ func (e *TUNStack) UnmarshalYAML(unmarshal func(any) error) error {
if err := unmarshal(&tp); err != nil {
return err
}
mode, exist := StackTypeMapping[strings.ToLower(tp)]
mode, exist := StackTypeMapping[strings.ToUpper(tp)]
if !exist {
return errors.New("invalid tun stack")
}
@ -43,7 +41,7 @@ func (e TUNStack) MarshalYAML() (any, error) {
func (e *TUNStack) UnmarshalJSON(data []byte) error {
var tp string
json.Unmarshal(data, &tp)
mode, exist := StackTypeMapping[strings.ToLower(tp)]
mode, exist := StackTypeMapping[strings.ToUpper(tp)]
if !exist {
return errors.New("invalid tun stack")
}
@ -62,8 +60,6 @@ func (e TUNStack) String() string {
return "gVisor"
case TunSystem:
return "System"
case TunLWIP:
return "LWIP"
default:
return "unknown"
}

View File

@ -3,9 +3,8 @@ package context
import (
"net"
N "github.com/Dreamacro/clash/common/net"
CN "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
@ -17,11 +16,10 @@ type ConnContext struct {
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
id, _ := uuid.NewV4()
return &ConnContext{
id: id,
metadata: metadata,
conn: N.NewBufferedConn(conn),
conn: CN.NewBufferedConn(conn),
}
}

View File

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"fmt"
tlsC "github.com/Dreamacro/clash/component/tls"
"go.uber.org/atomic"
"net"
"net/netip"
@ -78,7 +77,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
ch := make(chan result, 1)
go func() {
if strings.HasSuffix(c.Client.Net, "tls") {
conn = tls.Client(conn, tlsC.GetGlobalFingerprintTLCConfig(c.Client.TLSConfig))
conn = tls.Client(conn, c.Client.TLSConfig)
}
msg, _, err := c.Client.ExchangeWithConn(m, &D.Conn{

View File

@ -1,730 +1,108 @@
package dns
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"runtime"
"strconv"
"sync"
"time"
"github.com/Dreamacro/clash/component/dialer"
tlsC "github.com/Dreamacro/clash/component/tls"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/miekg/dns"
"github.com/Dreamacro/clash/component/resolver"
D "github.com/miekg/dns"
"golang.org/x/net/http2"
)
// Values to configure HTTP and HTTP/2 transport.
const (
// transportDefaultReadIdleTimeout is the default timeout for pinging
// idle connections in HTTP/2 transport.
transportDefaultReadIdleTimeout = 30 * time.Second
// transportDefaultIdleConnTimeout is the default timeout for idle
// connections in HTTP transport.
transportDefaultIdleConnTimeout = 5 * time.Minute
// dohMaxConnsPerHost controls the maximum number of connections for
// each host.
dohMaxConnsPerHost = 1
dialTimeout = 10 * time.Second
// dohMaxIdleConns controls the maximum number of connections being idle
// at the same time.
dohMaxIdleConns = 1
maxElapsedTime = time.Second * 30
// dotMimeType is the DoH mimetype that should be used.
dotMimeType = "application/dns-message"
)
var DefaultHTTPVersions = []C.HTTPVersion{C.HTTPVersion11, C.HTTPVersion2}
// dnsOverHTTPS is a struct that implements the Upstream interface for the
// DNS-over-HTTPS protocol.
type dnsOverHTTPS struct {
// The Client's Transport typically has internal state (cached TCP
// connections), so Clients should be reused instead of created as
// needed. Clients are safe for concurrent use by multiple goroutines.
client *http.Client
clientMu sync.Mutex
// quicConfig is the QUIC configuration that is used if HTTP/3 is enabled
// for this upstream.
quicConfig *quic.Config
quicConfigGuard sync.Mutex
url *url.URL
r *Resolver
httpVersions []C.HTTPVersion
proxyAdapter string
type dohClient struct {
url string
proxyAdapter string
transport *http.Transport
}
// type check
var _ dnsClient = (*dnsOverHTTPS)(nil)
// newDoH returns the DNS-over-HTTPS Upstream.
func newDoHClient(urlString string, r *Resolver, preferH3 bool, params map[string]string, proxyAdapter string) dnsClient {
u, _ := url.Parse(urlString)
httpVersions := DefaultHTTPVersions
if preferH3 {
httpVersions = append(httpVersions, C.HTTPVersion3)
}
if params["h3"] == "true" {
httpVersions = []C.HTTPVersion{C.HTTPVersion3}
}
doh := &dnsOverHTTPS{
url: u,
r: r,
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
},
httpVersions: httpVersions,
}
runtime.SetFinalizer(doh, (*dnsOverHTTPS).Close)
return doh
func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) {
return dc.ExchangeContext(context.Background(), m)
}
// Address implements the Upstream interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Address() string { return p.url.String() }
func (p *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
// Quote from https://www.rfc-editor.org/rfc/rfc8484.html:
// In order to maximize HTTP cache friendliness, DoH clients using media
// formats that include the ID field from the DNS message header, such
// as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS
// request.
id := m.Id
m.Id = 0
defer func() {
// Restore the original ID to not break compatibility with proxies.
m.Id = id
if msg != nil {
msg.Id = id
}
}()
// Check if there was already an active client before sending the request.
// We'll only attempt to re-connect if there was one.
client, isCached, err := p.getClient()
func (dc *dohClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
// https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
// In order to maximize cache friendliness, SHOULD use a DNS ID of 0 in every DNS request.
newM := *m
newM.Id = 0
req, err := dc.newRequest(&newM)
if err != nil {
return nil, fmt.Errorf("failed to init http client: %w", err)
return nil, err
}
// Make the first attempt to send the DNS query.
msg, err = p.exchangeHTTPS(ctx, client, m)
// Make up to 2 attempts to re-create the HTTP client and send the request
// again. There are several cases (mostly, with QUIC) where this workaround
// is necessary to make HTTP client usable. We need to make 2 attempts in
// the case when the connection was closed (due to inactivity for example)
// AND the server refuses to open a 0-RTT connection.
for i := 0; isCached && p.shouldRetry(err) && i < 2; i++ {
client, err = p.resetClient(err)
if err != nil {
return nil, fmt.Errorf("failed to reset http client: %w", err)
}
msg, err = p.exchangeHTTPS(ctx, client, m)
req = req.WithContext(ctx)
msg, err = dc.doRequest(req)
if err == nil {
msg.Id = m.Id
}
return
}
// newRequest returns a new DoH request given a dns.Msg.
func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) {
buf, err := m.Pack()
if err != nil {
// If the request failed anyway, make sure we don't use this client.
_, resErr := p.resetClient(err)
return nil, fmt.Errorf("err:%v,resErr:%v", err, resErr)
return nil, err
}
req, err := http.NewRequest(http.MethodPost, dc.url, bytes.NewReader(buf))
if err != nil {
return req, err
}
req.Header.Set("content-type", dotMimeType)
req.Header.Set("accept", dotMimeType)
return req, nil
}
func (dc *dohClient) doRequest(req *http.Request) (*D.Msg, error) {
client := &http.Client{Transport: dc.transport}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
buf, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
msg := &D.Msg{}
err = msg.Unpack(buf)
return msg, err
}
// Exchange implements the Upstream interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
return p.ExchangeContext(context.Background(), m)
}
func newDoHClient(url string, r *Resolver, proxyAdapter string) *dohClient {
return &dohClient{
url: url,
proxyAdapter: proxyAdapter,
transport: &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
// Close implements the Upstream interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Close() (err error) {
p.clientMu.Lock()
defer p.clientMu.Unlock()
ip, err := resolver.ResolveIPWithResolver(host, r)
if err != nil {
return nil, err
}
runtime.SetFinalizer(p, nil)
if p.client == nil {
return nil
}
return p.closeClient(p.client)
}
// closeClient cleans up resources used by client if necessary. Note, that at
// this point it should only be done for HTTP/3 as it may leak due to keep-alive
// connections.
func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
if isHTTP3(client) {
return client.Transport.(io.Closer).Close()
}
return nil
}
// exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient.
func (p *dnsOverHTTPS) exchangeHTTPS(ctx context.Context, client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) {
resp, err = p.exchangeHTTPSClient(ctx, client, req)
return resp, err
}
// exchangeHTTPSClient sends the DNS query to a DoH resolver using the specified
// http.Client instance.
func (p *dnsOverHTTPS) exchangeHTTPSClient(
ctx context.Context,
client *http.Client,
req *dns.Msg,
) (resp *dns.Msg, err error) {
buf, err := req.Pack()
if err != nil {
return nil, fmt.Errorf("packing message: %w", err)
}
// It appears, that GET requests are more memory-efficient with Golang
// implementation of HTTP/2.
method := http.MethodGet
if isHTTP3(client) {
// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
method = http3.MethodGet0RTT
}
p.url.RawQuery = fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf))
httpReq, err := http.NewRequest(method, p.url.String(), nil)
if err != nil {
return nil, fmt.Errorf("creating http request to %s: %w", p.url, err)
}
httpReq.Header.Set("Accept", "application/dns-message")
httpReq.Header.Set("User-Agent", "")
_ = httpReq.WithContext(ctx)
httpResp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("requesting %s: %w", p.url, err)
}
defer httpResp.Body.Close()
body, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, fmt.Errorf("reading %s: %w", p.url, err)
}
if httpResp.StatusCode != http.StatusOK {
return nil,
fmt.Errorf(
"expected status %d, got %d from %s",
http.StatusOK,
httpResp.StatusCode,
p.url,
)
}
resp = &dns.Msg{}
err = resp.Unpack(body)
if err != nil {
return nil, fmt.Errorf(
"unpacking response from %s: body is %s: %w",
p.url,
body,
err,
)
}
if resp.Id != req.Id {
err = dns.ErrId
}
return resp, err
}
// shouldRetry checks what error we have received and returns true if we should
// re-create the HTTP client and retry the request.
func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
if err == nil {
return false
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
// If this is a timeout error, trying to forcibly re-create the HTTP
// client instance. This is an attempt to fix an issue with DoH client
// stalling after a network change.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3217.
return true
}
if isQUICRetryError(err) {
return true
}
return false
}
// resetClient triggers re-creation of the *http.Client that is used by this
// upstream. This method accepts the error that caused resetting client as
// depending on the error we may also reset the QUIC config.
func (p *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err error) {
p.clientMu.Lock()
defer p.clientMu.Unlock()
if errors.Is(resetErr, quic.Err0RTTRejected) {
// Reset the TokenStore only if 0-RTT was rejected.
p.resetQUICConfig()
}
oldClient := p.client
if oldClient != nil {
closeErr := p.closeClient(oldClient)
if closeErr != nil {
log.Warnln("warning: failed to close the old http client: %v", closeErr)
}
}
log.Debugln("re-creating the http client due to %v", resetErr)
p.client, err = p.createClient()
return p.client, err
}
// getQUICConfig returns the QUIC config in a thread-safe manner. Note, that
// this method returns a pointer, it is forbidden to change its properties.
func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
return p.quicConfig
}
// resetQUICConfig Re-create the token store to make sure we're not trying to
// use invalid for 0-RTT.
func (p *dnsOverHTTPS) resetQUICConfig() {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
p.quicConfig = p.quicConfig.Clone()
p.quicConfig.TokenStore = newQUICTokenStore()
}
// getClient gets or lazily initializes an HTTP client (and transport) that will
// be used for this DoH resolver.
func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
startTime := time.Now()
p.clientMu.Lock()
defer p.clientMu.Unlock()
if p.client != nil {
return p.client, true, nil
}
// Timeout can be exceeded while waiting for the lock. This happens quite
// often on mobile devices.
elapsed := time.Since(startTime)
if elapsed > maxElapsedTime {
return nil, false, fmt.Errorf("timeout exceeded: %s", elapsed)
}
log.Debugln("creating a new http client")
p.client, err = p.createClient()
return p.client, false, err
}
// createClient creates a new *http.Client instance. The HTTP protocol version
// will depend on whether HTTP3 is allowed and provided by this upstream. Note,
// that we'll attempt to establish a QUIC connection when creating the client in
// order to check whether HTTP3 is supported.
func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
transport, err := p.createTransport()
if err != nil {
return nil, fmt.Errorf("initializing http transport: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: DefaultTimeout,
Jar: nil,
}
p.client = client
return p.client, nil
}
// createTransport initializes an HTTP transport that will be used specifically
// for this DoH resolver. This HTTP transport ensures that the HTTP requests
// will be sent exactly to the IP address got from the bootstrap resolver. Note,
// that this function will first attempt to establish a QUIC connection (if
// HTTP3 is enabled in the upstream options). If this attempt is successful,
// it returns an HTTP3 transport, otherwise it returns the H1/H2 transport.
func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
&tls.Config{
InsecureSkipVerify: false,
MinVersion: tls.VersionTLS12,
SessionTicketsDisabled: false,
})
var nextProtos []string
for _, v := range p.httpVersions {
nextProtos = append(nextProtos, string(v))
}
tlsConfig.NextProtos = nextProtos
dialContext := getDialHandler(p.r, p.proxyAdapter)
// First, we attempt to create an HTTP3 transport. If the probe QUIC
// connection is established successfully, we'll be using HTTP3 for this
// upstream.
transportH3, err := p.createTransportH3(tlsConfig, dialContext)
if err == nil {
log.Debugln("using HTTP/3 for this upstream: QUIC was faster")
return transportH3, nil
}
log.Debugln("using HTTP/2 for this upstream: %v", err)
if !p.supportsHTTP() {
return nil, errors.New("HTTP1/1 and HTTP2 are not supported by this upstream")
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
DisableCompression: true,
DialContext: dialContext,
IdleConnTimeout: transportDefaultIdleConnTimeout,
MaxConnsPerHost: dohMaxConnsPerHost,
MaxIdleConns: dohMaxIdleConns,
// Since we have a custom DialContext, we need to use this field to
// make golang http.Client attempt to use HTTP/2. Otherwise, it would
// only be used when negotiated on the TLS level.
ForceAttemptHTTP2: true,
}
// Explicitly configure transport to use HTTP/2.
//
// See https://github.com/AdguardTeam/dnsproxy/issues/11.
var transportH2 *http2.Transport
transportH2, err = http2.ConfigureTransports(transport)
if err != nil {
return nil, err
}
// Enable HTTP/2 pings on idle connections.
transportH2.ReadIdleTimeout = transportDefaultReadIdleTimeout
return transport, nil
}
// http3Transport is a wrapper over *http3.RoundTripper that tries to optimize
// its behavior. The main thing that it does is trying to force use a single
// connection to a host instead of creating a new one all the time. It also
// helps mitigate race issues with quic-go.
type http3Transport struct {
baseTransport *http3.RoundTripper
closed bool
mu sync.RWMutex
}
// type check
var _ http.RoundTripper = (*http3Transport)(nil)
// RoundTrip implements the http.RoundTripper interface for *http3Transport.
func (h *http3Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
h.mu.RLock()
defer h.mu.RUnlock()
if h.closed {
return nil, net.ErrClosed
}
// Try to use cached connection to the target host if it's available.
resp, err = h.baseTransport.RoundTripOpt(req, http3.RoundTripOpt{OnlyCachedConn: true})
if errors.Is(err, http3.ErrNoCachedConn) {
// If there are no cached connection, trigger creating a new one.
resp, err = h.baseTransport.RoundTrip(req)
}
return resp, err
}
// type check
var _ io.Closer = (*http3Transport)(nil)
// Close implements the io.Closer interface for *http3Transport.
func (h *http3Transport) Close() (err error) {
h.mu.Lock()
defer h.mu.Unlock()
h.closed = true
return h.baseTransport.Close()
}
// createTransportH3 tries to create an HTTP/3 transport for this upstream.
// We should be able to fall back to H1/H2 in case if HTTP/3 is unavailable or
// if it is too slow. In order to do that, this method will run two probes
// in parallel (one for TLS, the other one for QUIC) and if QUIC is faster it
// will create the *http3.RoundTripper instance.
func (doh *dnsOverHTTPS) createTransportH3(
tlsConfig *tls.Config,
dialContext dialHandler,
) (roundTripper http.RoundTripper, err error) {
if !doh.supportsH3() {
return nil, errors.New("HTTP3 support is not enabled")
}
addr, err := doh.probeH3(tlsConfig, dialContext)
if err != nil {
return nil, err
}
rt := &http3.RoundTripper{
Dial: func(
ctx context.Context,
// Ignore the address and always connect to the one that we got
// from the bootstrapper.
_ string,
tlsCfg *tls.Config,
cfg *quic.Config,
) (c quic.EarlyConnection, err error) {
return doh.dialQuic(ctx, addr, tlsCfg, cfg)
if proxyAdapter == "" {
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), port))
} else {
return dialContextExtra(ctx, proxyAdapter, "tcp", ip, port)
}
},
},
DisableCompression: true,
TLSClientConfig: tlsConfig,
QuicConfig: doh.getQUICConfig(),
}
return &http3Transport{baseTransport: rt}, nil
}
func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
portInt, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
udpAddr := net.UDPAddr{
IP: net.ParseIP(ip),
Port: portInt,
}
var conn net.PacketConn
if doh.proxyAdapter == "" {
conn, err = dialer.ListenPacket(ctx, "udp", "")
if err != nil {
return nil, err
}
} else {
if wrapConn, err := dialContextExtra(ctx, doh.proxyAdapter, "udp", udpAddr.AddrPort().Addr(), port); err == nil {
if pc, ok := wrapConn.(*wrapPacketConn); ok {
conn = pc
} else {
return nil, fmt.Errorf("conn isn't wrapPacketConn")
}
} else {
return nil, err
}
}
return quic.DialEarlyContext(ctx, conn, &udpAddr, doh.url.Host, tlsCfg, cfg)
}
// probeH3 runs a test to check whether QUIC is faster than TLS for this
// upstream. If the test is successful it will return the address that we
// should use to establish the QUIC connections.
func (p *dnsOverHTTPS) probeH3(
tlsConfig *tls.Config,
dialContext dialHandler,
) (addr string, err error) {
// We're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there are v4/v6 addresses).
rawConn, err := dialContext(context.Background(), "udp", p.url.Host)
if err != nil {
return "", fmt.Errorf("failed to dial: %w", err)
}
// It's never actually used.
_ = rawConn.Close()
udpConn, ok := rawConn.(*net.UDPConn)
if !ok {
return "", fmt.Errorf("not a UDP connection to %s", p.Address())
}
addr = udpConn.RemoteAddr().String()
// Avoid spending time on probing if this upstream only supports HTTP/3.
if p.supportsH3() && !p.supportsHTTP() {
return addr, nil
}
// Use a new *tls.Config with empty session cache for probe connections.
// Surprisingly, this is really important since otherwise it invalidates
// the existing cache.
// TODO(ameshkov): figure out why the sessions cache invalidates here.
probeTLSCfg := tlsConfig.Clone()
probeTLSCfg.ClientSessionCache = nil
// Do not expose probe connections to the callbacks that are passed to
// the bootstrap options to avoid side-effects.
// TODO(ameshkov): consider exposing, somehow mark that this is a probe.
probeTLSCfg.VerifyPeerCertificate = nil
probeTLSCfg.VerifyConnection = nil
// Run probeQUIC and probeTLS in parallel and see which one is faster.
chQuic := make(chan error, 1)
chTLS := make(chan error, 1)
go p.probeQUIC(addr, probeTLSCfg, chQuic)
go p.probeTLS(dialContext, probeTLSCfg, chTLS)
select {
case quicErr := <-chQuic:
if quicErr != nil {
// QUIC failed, return error since HTTP3 was not preferred.
return "", quicErr
}
// Return immediately, QUIC was faster.
return addr, quicErr
case tlsErr := <-chTLS:
if tlsErr != nil {
// Return immediately, TLS failed.
log.Debugln("probing TLS: %v", tlsErr)
return addr, nil
}
return "", errors.New("TLS was faster than QUIC, prefer it")
}
}
// probeQUIC attempts to establish a QUIC connection to the specified address.
// We run probeQUIC and probeTLS in parallel and see which one is faster.
func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()
timeout := DefaultTimeout
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout))
defer cancel()
conn, err := p.dialQuic(ctx, addr, tlsConfig, p.getQUICConfig())
if err != nil {
ch <- fmt.Errorf("opening QUIC connection to %s: %w", p.Address(), err)
return
}
// Ignore the error since there's no way we can use it for anything useful.
_ = conn.CloseWithError(QUICCodeNoError, "")
ch <- nil
elapsed := time.Now().Sub(startTime)
log.Debugln("elapsed on establishing a QUIC connection: %s", elapsed)
}
// probeTLS attempts to establish a TLS connection to the specified address. We
// run probeQUIC and probeTLS in parallel and see which one is faster.
func (p *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()
conn, err := p.tlsDial(dialContext, "tcp", tlsConfig)
if err != nil {
ch <- fmt.Errorf("opening TLS connection: %w", err)
return
}
// Ignore the error since there's no way we can use it for anything useful.
_ = conn.Close()
ch <- nil
elapsed := time.Now().Sub(startTime)
log.Debugln("elapsed on establishing a TLS connection: %s", elapsed)
}
// supportsH3 returns true if HTTP/3 is supported by this upstream.
func (p *dnsOverHTTPS) supportsH3() (ok bool) {
for _, v := range p.supportedHTTPVersions() {
if v == C.HTTPVersion3 {
return true
}
}
return false
}
// supportsHTTP returns true if HTTP/1.1 or HTTP2 is supported by this upstream.
func (p *dnsOverHTTPS) supportsHTTP() (ok bool) {
for _, v := range p.supportedHTTPVersions() {
if v == C.HTTPVersion11 || v == C.HTTPVersion2 {
return true
}
}
return false
}
// supportedHTTPVersions returns the list of supported HTTP versions.
func (p *dnsOverHTTPS) supportedHTTPVersions() (v []C.HTTPVersion) {
v = p.httpVersions
if v == nil {
v = DefaultHTTPVersions
}
return v
}
// isHTTP3 checks if the *http.Client is an HTTP/3 client.
func isHTTP3(client *http.Client) (ok bool) {
_, ok = client.Transport.(*http3Transport)
return ok
}
// tlsDial is basically the same as tls.DialWithDialer, but we will call our own
// dialContext function to get connection.
func (doh *dnsOverHTTPS) tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
// We're using bootstrapped address instead of what's passed
// to the function.
rawConn, err := dialContext(context.Background(), network, doh.url.Host)
if err != nil {
return nil, err
}
// We want the timeout to cover the whole process: TCP connection and
// TLS handshake dialTimeout will be used as connection deadLine.
conn := tls.Client(rawConn, config)
err = conn.SetDeadline(time.Now().Add(dialTimeout))
if err != nil {
// Must not happen in normal circumstances.
panic(fmt.Errorf("cannot set deadline: %w", err))
}
err = conn.Handshake()
if err != nil {
defer conn.Close()
return nil, err
}
return conn, nil
}

View File

@ -1,354 +1,174 @@
package dns
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/resolver"
"github.com/lucas-clemente/quic-go"
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"time"
"github.com/Dreamacro/clash/component/dialer"
tlsC "github.com/Dreamacro/clash/component/tls"
"github.com/lucas-clemente/quic-go"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
)
const NextProtoDQ = "doq"
const (
// QUICCodeNoError is used when the connection or stream needs to be closed,
// but there is no error to signal.
QUICCodeNoError = quic.ApplicationErrorCode(0)
// QUICCodeInternalError signals that the DoQ implementation encountered
// an internal error and is incapable of pursuing the transaction or the
// connection.
QUICCodeInternalError = quic.ApplicationErrorCode(1)
// QUICKeepAlivePeriod is the value that we pass to *quic.Config and that
// controls the period with with keep-alive frames are being sent to the
// connection. We set it to 20s as it would be in the quic-go@v0.27.1 with
// KeepAlive field set to true This value is specified in
// https://pkg.go.dev/github.com/lucas-clemente/quic-go/internal/protocol#MaxKeepAliveInterval.
//
// TODO(ameshkov): Consider making it configurable.
QUICKeepAlivePeriod = time.Second * 20
DefaultTimeout = time.Second * 5
)
type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error)
// dnsOverQUIC is a struct that implements the Upstream interface for the
// DNS-over-QUIC protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html).
type dnsOverQUIC struct {
// quicConfig is the QUIC configuration that is used for establishing
// connections to the upstream. This configuration includes the TokenStore
// that needs to be stored for the lifetime of dnsOverQUIC since we can
// re-create the connection.
quicConfig *quic.Config
quicConfigGuard sync.Mutex
// conn is the current active QUIC connection. It can be closed and
// re-opened when needed.
conn quic.Connection
connMu sync.RWMutex
// bytesPool is a *sync.Pool we use to store byte buffers in. These byte
// buffers are used to read responses from the upstream.
bytesPool *sync.Pool
bytesPoolGuard sync.Mutex
var bytesPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }}
type quicClient struct {
addr string
proxyAdapter string
r *Resolver
session quic.Connection
proxyAdapter string
sync.RWMutex // protects session and bytesPool
}
// type check
var _ dnsClient = (*dnsOverQUIC)(nil)
// newDoQ returns the DNS-over-QUIC Upstream.
func newDoQ(resolver *Resolver, addr string, adapter string) (dnsClient, error) {
doq := &dnsOverQUIC{
func newDOQ(r *Resolver, addr, proxyAdapter string) *quicClient {
return &quicClient{
addr: addr,
proxyAdapter: adapter,
r: resolver,
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
},
r: r,
proxyAdapter: proxyAdapter,
}
runtime.SetFinalizer(doq, (*dnsOverQUIC).Close)
return doq, nil
}
// Address implements the Upstream interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Address() string { return p.addr }
func (p *dnsOverQUIC) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
// When sending queries over a QUIC connection, the DNS Message ID MUST be
// set to zero.
id := m.Id
m.Id = 0
defer func() {
// Restore the original ID to not break compatibility with proxies.
m.Id = id
if msg != nil {
msg.Id = id
}
}()
// Check if there was already an active conn before sending the request.
// We'll only attempt to re-connect if there was one.
hasConnection := p.hasConnection()
// Make the first attempt to send the DNS query.
msg, err = p.exchangeQUIC(ctx, m)
// Make up to 2 attempts to re-open the QUIC connection and send the request
// again. There are several cases where this workaround is necessary to
// make DoQ usable. We need to make 2 attempts in the case when the
// connection was closed (due to inactivity for example) AND the server
// refuses to open a 0-RTT connection.
for i := 0; hasConnection && p.shouldRetry(err) && i < 2; i++ {
log.Debugln("re-creating the QUIC connection and retrying due to %v", err)
// Close the active connection to make sure we'll try to re-connect.
p.closeConnWithError(err)
// Retry sending the request.
msg, err = p.exchangeQUIC(ctx, m)
}
func (dc *quicClient) Exchange(m *D.Msg) (msg *D.Msg, err error) {
return dc.ExchangeContext(context.Background(), m)
}
func (dc *quicClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
stream, err := dc.openStream(ctx)
if err != nil {
// If we're unable to exchange messages, make sure the connection is
// closed and signal about an internal error.
p.closeConnWithError(err)
return nil, fmt.Errorf("failed to open new stream to %s", dc.addr)
}
return msg, err
}
// Exchange implements the Upstream interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Exchange(m *D.Msg) (msg *D.Msg, err error) {
return p.ExchangeContext(context.Background(), m)
}
// Close implements the Upstream interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Close() (err error) {
p.connMu.Lock()
defer p.connMu.Unlock()
runtime.SetFinalizer(p, nil)
if p.conn != nil {
err = p.conn.CloseWithError(QUICCodeNoError, "")
}
return err
}
// exchangeQUIC attempts to open a QUIC connection, send the DNS message
// through it and return the response it got from the server.
func (p *dnsOverQUIC) exchangeQUIC(ctx context.Context, msg *D.Msg) (resp *D.Msg, err error) {
var conn quic.Connection
conn, err = p.getConnection(true)
buf, err := m.Pack()
if err != nil {
return nil, err
}
var buf []byte
buf, err = msg.Pack()
if err != nil {
return nil, fmt.Errorf("failed to pack DNS message for DoQ: %w", err)
}
var stream quic.Stream
stream, err = p.openStream(ctx, conn)
_, err = stream.Write(buf)
if err != nil {
return nil, err
}
_, err = stream.Write(AddPrefix(buf))
if err != nil {
return nil, fmt.Errorf("failed to write to a QUIC stream: %w", err)
}
// The client MUST send the DNS query over the selected stream, and MUST
// indicate through the STREAM FIN mechanism that no further data will
// be sent on that stream. Note, that stream.Close() closes the
// write-direction of the stream, but does not prevent reading from it.
// be sent on that stream.
// stream.Close() -- closes the write-direction of the stream.
_ = stream.Close()
return p.readMsg(stream)
}
respBuf := bytesPool.Get().(*bytes.Buffer)
defer bytesPool.Put(respBuf)
defer respBuf.Reset()
// AddPrefix adds a 2-byte prefix with the DNS message length.
func AddPrefix(b []byte) (m []byte) {
m = make([]byte, 2+len(b))
binary.BigEndian.PutUint16(m, uint16(len(b)))
copy(m[2:], b)
return m
}
// shouldRetry checks what error we received and decides whether it is required
// to re-open the connection and retry sending the request.
func (p *dnsOverQUIC) shouldRetry(err error) (ok bool) {
return isQUICRetryError(err)
}
// getBytesPool returns (creates if needed) a pool we store byte buffers in.
func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
p.bytesPoolGuard.Lock()
defer p.bytesPoolGuard.Unlock()
if p.bytesPool == nil {
p.bytesPool = &sync.Pool{
New: func() interface{} {
b := make([]byte, MaxMsgSize)
return &b
},
}
n, err := respBuf.ReadFrom(stream)
if err != nil && n == 0 {
return nil, err
}
return p.bytesPool
reply := new(D.Msg)
err = reply.Unpack(respBuf.Bytes())
if err != nil {
return nil, err
}
return reply, nil
}
// getConnection opens or returns an existing quic.Connection. useCached
// argument controls whether we should try to use the existing cached
// connection. If it is false, we will forcibly create a new connection and
// close the existing one if needed.
func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
var conn quic.Connection
p.connMu.RLock()
conn = p.conn
if conn != nil && useCached {
p.connMu.RUnlock()
return conn, nil
func isActive(s quic.Connection) bool {
select {
case <-s.Context().Done():
return false
default:
return true
}
if conn != nil {
// we're recreating the connection, let's create a new one.
_ = conn.CloseWithError(QUICCodeNoError, "")
}
p.connMu.RUnlock()
}
p.connMu.Lock()
defer p.connMu.Unlock()
// getSession - opens or returns an existing quic.Connection
// useCached - if true and cached session exists, return it right away
// otherwise - forcibly creates a new session
func (dc *quicClient) getSession(ctx context.Context) (quic.Connection, error) {
var session quic.Connection
dc.RLock()
session = dc.session
if session != nil && isActive(session) {
dc.RUnlock()
return session, nil
}
if session != nil {
// we're recreating the session, let's create a new one
_ = session.CloseWithError(0, "")
}
dc.RUnlock()
dc.Lock()
defer dc.Unlock()
var err error
conn, err = p.openConnection()
session, err = dc.openSession(ctx)
if err != nil {
// This does not look too nice, but QUIC (or maybe quic-go)
// doesn't seem stable enough.
// Maybe retransmissions aren't fully implemented in quic-go?
// Anyways, the simple solution is to make a second try when
// it fails to open the QUIC session.
session, err = dc.openSession(ctx)
if err != nil {
return nil, err
}
}
dc.session = session
return session, nil
}
func (dc *quicClient) openSession(ctx context.Context) (quic.Connection, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: false,
NextProtos: []string{
NextProtoDQ,
},
SessionTicketsDisabled: false,
}
quicConfig := &quic.Config{
ConnectionIDLength: 12,
HandshakeIdleTimeout: time.Second * 8,
MaxIncomingStreams: 4,
MaxIdleTimeout: time.Second * 45,
}
log.Debugln("opening session to %s", dc.addr)
var (
udp net.PacketConn
err error
)
host, port, err := net.SplitHostPort(dc.addr)
if err != nil {
return nil, err
}
p.conn = conn
return conn, nil
}
// hasConnection returns true if there's an active QUIC connection.
func (p *dnsOverQUIC) hasConnection() (ok bool) {
p.connMu.Lock()
defer p.connMu.Unlock()
return p.conn != nil
}
// getQUICConfig returns the QUIC config in a thread-safe manner. Note, that
// this method returns a pointer, it is forbidden to change its properties.
func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
return p.quicConfig
}
// resetQUICConfig re-creates the tokens store as we may need to use a new one
// if we failed to connect.
func (p *dnsOverQUIC) resetQUICConfig() {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
p.quicConfig = p.quicConfig.Clone()
p.quicConfig.TokenStore = newQUICTokenStore()
}
// openStream opens a new QUIC stream for the specified connection.
func (p *dnsOverQUIC) openStream(ctx context.Context, conn quic.Connection) (quic.Stream, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stream, err := conn.OpenStreamSync(ctx)
if err == nil {
return stream, nil
}
// We can get here if the old QUIC connection is not valid anymore. We
// should try to re-create the connection again in this case.
newConn, err := p.getConnection(false)
if err != nil {
return nil, err
}
// Open a new stream.
return newConn.OpenStreamSync(ctx)
}
// openConnection opens a new QUIC connection.
func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
tlsConfig := tlsC.GetGlobalFingerprintTLCConfig(
&tls.Config{
InsecureSkipVerify: false,
NextProtos: []string{
NextProtoDQ,
},
SessionTicketsDisabled: false,
})
// we're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there're v4/v6 addresses).
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
rawConn, err := getDialHandler(doq.r, doq.proxyAdapter)(ctx, "udp", doq.addr)
if err != nil {
return nil, fmt.Errorf("failed to open a QUIC connection: %w", err)
}
// It's never actually used
_ = rawConn.Close()
cancel()
udpConn, ok := rawConn.(*net.UDPConn)
if !ok {
return nil, fmt.Errorf("failed to open connection to %s", doq.addr)
}
addr := udpConn.RemoteAddr().String()
ip, port, err := net.SplitHostPort(addr)
ip, err := resolver.ResolveIPv4WithResolver(host, dc.r)
if err != nil {
return nil, err
}
p, err := strconv.Atoi(port)
udpAddr := net.UDPAddr{IP: net.ParseIP(ip), Port: p}
var udp net.PacketConn
if doq.proxyAdapter == "" {
udpAddr := net.UDPAddr{IP: ip.AsSlice(), Port: p}
if dc.proxyAdapter == "" {
udp, err = dialer.ListenPacket(ctx, "udp", "")
if err != nil {
return nil, err
}
} else {
ipAddr, err := netip.ParseAddr(ip)
if err != nil {
return nil, err
}
conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", ipAddr, port)
conn, err := dialContextExtra(ctx, dc.proxyAdapter, "udp", ip, port)
if err != nil {
return nil, err
}
@ -361,158 +181,20 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
udp = wrapConn
}
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
host, _, err := net.SplitHostPort(doq.addr)
session, err := quic.DialContext(ctx, udp, &udpAddr, host, tlsConfig, quicConfig)
if err != nil {
return nil, fmt.Errorf("failed to open QUIC session: %w", err)
}
return session, nil
}
func (dc *quicClient) openStream(ctx context.Context) (quic.Stream, error) {
session, err := dc.getSession(ctx)
if err != nil {
return nil, err
}
conn, err = quic.DialContext(ctx, udp, &udpAddr, host, tlsConfig, doq.getQUICConfig())
if err != nil {
return nil, fmt.Errorf("opening quic connection to %s: %w", doq.addr, err)
}
return conn, nil
}
// closeConnWithError closes the active connection with error to make sure that
// new queries were processed in another connection. We can do that in the case
// of a fatal error.
func (p *dnsOverQUIC) closeConnWithError(err error) {
p.connMu.Lock()
defer p.connMu.Unlock()
if p.conn == nil {
// Do nothing, there's no active conn anyways.
return
}
code := QUICCodeNoError
if err != nil {
code = QUICCodeInternalError
}
if errors.Is(err, quic.Err0RTTRejected) {
// Reset the TokenStore only if 0-RTT was rejected.
p.resetQUICConfig()
}
err = p.conn.CloseWithError(code, "")
if err != nil {
log.Errorln("failed to close the conn: %v", err)
}
p.conn = nil
}
// readMsg reads the incoming DNS message from the QUIC stream.
func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *D.Msg, err error) {
pool := p.getBytesPool()
bufPtr := pool.Get().(*[]byte)
defer pool.Put(bufPtr)
respBuf := *bufPtr
n, err := stream.Read(respBuf)
if err != nil && n == 0 {
return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err)
}
// All DNS messages (queries and responses) sent over DoQ connections MUST
// be encoded as a 2-octet length field followed by the message content as
// specified in [RFC1035].
// IMPORTANT: Note, that we ignore this prefix here as this implementation
// does not support receiving multiple messages over a single connection.
m = new(D.Msg)
err = m.Unpack(respBuf[2:])
if err != nil {
return nil, fmt.Errorf("unpacking response from %s: %w", p.Address(), err)
}
return m, nil
}
// newQUICTokenStore creates a new quic.TokenStore that is necessary to have
// in order to benefit from 0-RTT.
func newQUICTokenStore() (s quic.TokenStore) {
// You can read more on address validation here:
// https://datatracker.ietf.org/doc/html/rfc9000#section-8.1
// Setting maxOrigins to 1 and tokensPerOrigin to 10 assuming that this is
// more than enough for the way we use it (one connection per upstream).
return quic.NewLRUTokenStore(1, 10)
}
// isQUICRetryError checks the error and determines whether it may signal that
// we should re-create the QUIC connection. This requirement is caused by
// quic-go issues, see the comments inside this function.
// TODO(ameshkov): re-test when updating quic-go.
func isQUICRetryError(err error) (ok bool) {
var qAppErr *quic.ApplicationError
if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 {
// This error is often returned when the server has been restarted,
// and we try to use the same connection on the client-side. It seems,
// that the old connections aren't closed immediately on the server-side
// and that's why one can run into this.
// In addition to that, quic-go HTTP3 client implementation does not
// clean up dead connections (this one is specific to DoH3 upstream):
// https://github.com/lucas-clemente/quic-go/issues/765
return true
}
var qIdleErr *quic.IdleTimeoutError
if errors.As(err, &qIdleErr) {
// This error means that the connection was closed due to being idle.
// In this case we should forcibly re-create the QUIC connection.
// Reproducing is rather simple, stop the server and wait for 30 seconds
// then try to send another request via the same upstream.
return true
}
var resetErr *quic.StatelessResetError
if errors.As(err, &resetErr) {
// A stateless reset is sent when a server receives a QUIC packet that
// it doesn't know how to decrypt. For instance, it may happen when
// the server was recently rebooted. We should reconnect and try again
// in this case.
return true
}
var qTransportError *quic.TransportError
if errors.As(err, &qTransportError) && qTransportError.ErrorCode == quic.NoError {
// A transport error with the NO_ERROR error code could be sent by the
// server when it considers that it's time to close the connection.
// For example, Google DNS eventually closes an active connection with
// the NO_ERROR code and "Connection max age expired" message:
// https://github.com/AdguardTeam/dnsproxy/issues/283
return true
}
if errors.Is(err, quic.Err0RTTRejected) {
// This error happens when we try to establish a 0-RTT connection with
// a token the server is no more aware of. This can be reproduced by
// restarting the QUIC server (it will clear its tokens cache). The
// next connection attempt will return this error until the client's
// tokens cache is purged.
return true
}
return false
}
func getDialHandler(r *Resolver, proxyAdapter string) dialHandler {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ip, err := r.ResolveIP(host)
if err != nil {
return nil, err
}
if len(proxyAdapter) == 0 {
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port), dialer.WithDirect())
} else {
return dialContextExtra(ctx, proxyAdapter, network, ip.Unmap(), port, dialer.WithDirect())
}
}
// open a new stream
return session.OpenStreamSync(ctx)
}

Some files were not shown because too many files have changed in this diff Show More