Compare commits

..

79 Commits

Author SHA1 Message Date
6c83ff3496 Chore: update dependencies 2021-03-10 21:13:23 +08:00
f7f97ef625 Fix: some HTTP proxy request broken 2021-03-10 16:23:55 +08:00
5acdd72a1d Fix: remove host if host is ip string 2021-03-10 12:49:30 +08:00
f53686103d Chore: reset udp timeout after sending each packet (#1260) 2021-02-26 10:40:55 +08:00
f63c9eb22f Chore: update staticcheck command on actions 2021-02-21 19:37:37 +08:00
a37243cf30 Fix: store cache correctly 2021-02-21 01:07:22 +08:00
b3c1b4a840 Chore: update dependencies 2021-02-19 20:35:10 +08:00
14bbf6eedc Feature: support store group selected node to cache (enable by default) 2021-02-18 23:41:50 +08:00
aa81193d5b Feature: add darwin arm64 to Makefile (Apple Silicon) (#1234) 2021-02-18 18:15:09 +08:00
9eb98e399d Improve: refactor ssr and fix #995 (#1189)
Co-authored-by: goomada <madao@DESKTOP-IOEBS0C.localdomain>
2021-02-15 14:32:03 +08:00
d48cfecf60 Chore: API support patch ipv6 config (#1217) 2021-02-05 16:43:42 +08:00
6036fb63ba Chore: avoid provider unnecessary write file operations (#1210) 2021-02-02 17:52:46 +08:00
cd48f69b1f Fix: wrap net.Conn to avoid using *net.TCPConn.(ReadFrom) (#1209) 2021-02-01 20:06:45 +08:00
fcc594ae26 Chore: use jsdelivr CDN for Country.mmdb (#1057) 2021-01-30 00:40:35 +08:00
f4de055aa1 Refactor: make inbound request contextual 2021-01-23 14:58:09 +08:00
35925cb3da Chore: standardized Dockerfile label (#1191)
Signed-off-by: Junjie Yuan <yuan@junjie.pro>
2021-01-20 16:08:24 +08:00
ff430df845 Fix: connectivity of ssr auth_chain_(ab) protocol (#1180) 2021-01-13 23:35:41 +08:00
e4cdea2111 chore: use singleDo to get interface info 2021-01-13 17:30:54 +08:00
b6ee47a541 Fix: get general should return correct result (#1172) 2021-01-07 13:59:39 +08:00
b25009cde7 Fix: unnecessary write operation on provider (#1170) 2021-01-06 14:20:15 +08:00
6fedd7ec84 Fix: dns client should not bind local address 2021-01-04 00:51:53 +08:00
9619c3fb20 Fix: support unspecified UDP bind address (#1159) 2020-12-31 18:58:03 +08:00
02d029dd2d Fix: close http Response body on provider (#1154) 2020-12-29 11:28:22 +08:00
09c28e0355 Fix: fallback bind fn should not bind global unicast 2020-12-28 22:24:58 +08:00
3600077f3b Chore: update dependencies 2020-12-27 18:59:59 +08:00
de7656a787 Chore: update premium README 2020-12-27 00:14:24 +08:00
5dfe7f8561 Fix: handle keep alive on http connect proxy 2020-12-24 14:55:11 +08:00
ed27898a33 Fix: snell should support the config without obfs 2020-12-24 13:47:56 +08:00
532396d25c Fix: PROCESS-NAME rule for UDP sessions on Windows (#1140) 2020-12-22 15:13:44 +08:00
4b1b494164 Chore: move find process name to a single part 2020-12-17 22:17:27 +08:00
0d33dc3eb9 Chore: health checks return immediately if completed (#1097) 2020-11-24 22:52:23 +08:00
994cbff215 Fix: should not log rule when rule = nil 2020-11-22 23:38:12 +08:00
bea2ee8bf2 Chore: log rule msg on dial error 2020-11-22 19:12:36 +08:00
1e5593f1a9 Chore: update dependencies 2020-11-20 20:36:20 +08:00
34febc4579 Chore: more detailed error when dial failed 2020-11-20 00:27:37 +08:00
97581148b5 Fix: static check 2020-11-19 00:56:36 +08:00
0402878daa Feature: add lazy for proxy group and provider 2020-11-19 00:53:22 +08:00
4735f61fd1 Feature: add disable-udp option for all proxy group 2020-11-13 21:48:52 +08:00
16ae107e70 Chore: push image to github docker registry 2020-11-10 15:19:12 +08:00
83efe2ae57 Feature: add TCP TPROXY support (#1049) 2020-11-09 10:46:10 +08:00
87e4d94290 Fix: tunnel manager & tracker race condition (#1048) 2020-10-29 17:51:14 +08:00
b98e9ea202 Improve: #1038 and #1041 2020-10-29 00:32:31 +08:00
9a62b1081d Feature: support round-robin strategy for load-balance group (#1044) 2020-10-28 22:35:02 +08:00
2cd1b890ce Fix: tunnel UDP race condition (#1043) 2020-10-28 21:26:50 +08:00
ba060bd0ee Fix: should not bind interface on local address 2020-10-25 20:31:01 +08:00
b1795b1e3d Fix: stale typo 2020-10-25 11:53:03 +08:00
76c9820065 Fix: undefined variable 2020-10-23 17:49:34 +08:00
2db4ce57ef Chore: make stale time into 60 days 2020-10-23 00:30:17 +08:00
50b3d497f6 Feature: use native syscall to bind interface on Linux and macOS 2020-10-22 22:32:03 +08:00
2321e9139d Chore: deprecated eapache/channels 2020-10-20 17:44:39 +08:00
baabf21340 Chore: update github workflow 2020-10-17 13:46:05 +08:00
d3bb4c65a8 Fix: missing fake-ip record should return error 2020-10-17 12:52:43 +08:00
8c3e2a7559 Chore: fix typo (#1017) 2020-10-14 19:56:02 +08:00
bc52f8e4fd Chore: return empty record in SVCB/HTTPSSVC on fake-ip mode 2020-10-13 00:15:49 +08:00
d3b14c325f Fix: the priority of fake-ip-filter 2020-10-09 00:04:24 +08:00
4859b158b4 Chore: make builds reproducible (#1006) 2020-10-08 17:54:38 +08:00
d65b51c62b Feature: http support custom sni 2020-10-02 11:34:40 +08:00
a6444bb449 Feature: support domain in fallback filter (#964) 2020-09-28 22:17:10 +08:00
e09931dcf7 Chore: remove broken test temporarily 2020-09-26 20:36:52 +08:00
5bd189f2d0 Feature: support VMess HTTP/2 transport (#903) 2020-09-26 20:33:57 +08:00
8766287e72 Chore: sync necessary changes from premium 2020-09-21 22:22:07 +08:00
10f9571c9e Fix: pool gc test 2020-09-21 00:44:47 +08:00
96a8259c42 Feature: support snell v2 (#952)
Co-authored-by: Dreamacro <8615343+Dreamacro@users.noreply.github.com>
2020-09-21 00:33:13 +08:00
68dd0622b8 Chore: code style 2020-09-20 15:53:27 +08:00
558ac6b965 Chore: split enhanced mode instance (#936)
Co-authored-by: Dreamacro <305009791@qq.com>
2020-09-17 10:48:42 +08:00
e773f95f21 Fix: PROCESS-NAME on FreeBSD 11.x (#947) 2020-09-07 17:43:34 +08:00
314ce1c249 Feature: vmess network http support TLS (https) 2020-09-04 21:27:19 +08:00
13275b1aa6 Chore: use only one goroutine to handle statistic (#940) 2020-09-03 10:30:18 +08:00
02d9169b5d Fix: potential PCB buffer overflow on bsd systems (#941) 2020-09-03 10:27:20 +08:00
7631bcc99e Improve: use atomic for connection statistic (#938) 2020-09-02 16:34:12 +08:00
a32ee13fc9 Feature: reuse dns resolver cache when hot reload 2020-08-31 00:32:18 +08:00
b8ed738238 Chore: update actions version 2020-08-30 23:06:21 +08:00
687c2a21cf Fix: vmess UDP option should be effect 2020-08-30 22:49:55 +08:00
ad18064e6b Chore: code style (#933) 2020-08-30 19:53:00 +08:00
c9735ef75b Fix: static check 2020-08-25 22:36:38 +08:00
b70882f01a Chore: add static check 2020-08-25 22:32:23 +08:00
5805334ccd Chore: pass staticcheck 2020-08-25 22:19:59 +08:00
c1b4382fe8 Feature: add Windows ARM32 build (#902)
Co-authored-by: MarksonHon <50002150+MarksonHon@users.noreply.github.com>
2020-08-16 13:50:56 +08:00
008743f20b Chore: update dependencies 2020-08-16 11:32:51 +08:00
134 changed files with 4678 additions and 3272 deletions

View File

@ -17,37 +17,60 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Set up QEMU
uses: docker/setup-qemu-action@v1
with:
platforms: all
- name: Set up docker buildx - name: Set up docker buildx
id: buildx id: buildx
uses: crazy-max/ghaction-docker-buildx@v2 uses: docker/setup-buildx-action@v1
with: with:
buildx-version: latest version: latest
skip-cache: false
qemu-version: latest
- name: Docker login - name: Login to DockerHub
env: uses: docker/login-action@v1
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} with:
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} username: ${{ secrets.DOCKER_USERNAME }}
run: | password: ${{ secrets.DOCKER_PASSWORD }}
echo "${DOCKER_PASSWORD}" | docker login --username "${DOCKER_USERNAME}" --password-stdin
- name: Docker buildx image and push on dev branch - name: Login to Github Package
uses: docker/login-action@v1
with:
registry: ghcr.io
username: Dreamacro
password: ${{ secrets.PACKAGE_TOKEN }}
- name: Build dev branch and push
if: github.ref == 'refs/heads/dev' if: github.ref == 'refs/heads/dev'
run: | uses: docker/build-push-action@v2
docker buildx build --output "type=image,push=true" --platform=linux/amd64,linux/arm/v7,linux/arm64 --tag dreamacro/clash:dev . with:
context: .
platforms: linux/amd64,linux/arm/v7,linux/arm64
push: true
tags: 'dreamacro/clash:dev,ghcr.io/dreamacro/clash:dev'
- name: Replace tag without `v` - name: Get all docker tags
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
uses: actions/github-script@v1 uses: actions/github-script@v3
id: version id: tags
with: with:
script: | script: |
return context.payload.ref.replace(/\/?refs\/tags\/v/, '') const ref = `${context.payload.ref.replace(/\/?refs\/tags\//, '')}`
const tags = [
'dreamacro/clash:latest',
`dreamacro/clash:${ref}`,
'ghcr.io/dreamacro/clash:latest',
`ghcr.io/dreamacro/clash:${ref}`
]
return tags.join(',')
result-encoding: string result-encoding: string
- name: Docker buildx image and push on release - name: Build release and push
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
run: | uses: docker/build-push-action@v2
docker buildx build --output "type=image,push=true" --platform=linux/amd64,linux/arm/v7,linux/arm64 --tag dreamacro/clash:${{steps.version.outputs.result}} . with:
docker buildx build --output "type=image,push=true" --platform=linux/amd64,linux/arm/v7,linux/arm64 --tag dreamacro/clash:latest . context: .
platforms: linux/amd64,linux/arm/v7,linux/arm64
push: true
tags: ${{steps.tags.outputs.result}}

View File

@ -9,7 +9,7 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.15.x go-version: 1.16
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -22,9 +22,12 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Get dependencies and run test - name: Get dependencies, run test and static check
run: | run: |
go test ./... go test ./...
go vet ./...
go install honnef.co/go/tools/cmd/staticcheck@latest
staticcheck -- $(go list ./...)
- name: Build - name: Build
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')

View File

@ -11,9 +11,9 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/stale@v1 - uses: actions/stale@v3
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'This issue is stale because it has been open 120 days with no activity. Remove stale label or comment or this will be closed in 5 days' stale-issue-message: 'This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 5 days'
days-before-stale: 120 days-before-stale: 60
days-before-close: 5 days-before-close: 5

View File

@ -10,6 +10,7 @@ RUN go mod download && \
mv ./bin/clash-docker /clash mv ./bin/clash-docker /clash
FROM alpine:latest FROM alpine:latest
LABEL org.opencontainers.image.source="https://github.com/Dreamacro/clash"
RUN apk add --no-cache ca-certificates RUN apk add --no-cache ca-certificates
COPY --from=builder /Country.mmdb /root/.config/clash/ COPY --from=builder /Country.mmdb /root/.config/clash/

View File

@ -4,10 +4,11 @@ VERSION=$(shell git describe --tags || echo "unknown version")
BUILDTIME=$(shell date -u) BUILDTIME=$(shell date -u)
GOBUILD=CGO_ENABLED=0 go build -trimpath -ldflags '-X "github.com/Dreamacro/clash/constant.Version=$(VERSION)" \ GOBUILD=CGO_ENABLED=0 go build -trimpath -ldflags '-X "github.com/Dreamacro/clash/constant.Version=$(VERSION)" \
-X "github.com/Dreamacro/clash/constant.BuildTime=$(BUILDTIME)" \ -X "github.com/Dreamacro/clash/constant.BuildTime=$(BUILDTIME)" \
-w -s' -w -s -buildid='
PLATFORM_LIST = \ PLATFORM_LIST = \
darwin-amd64 \ darwin-amd64 \
darwin-arm64 \
linux-386 \ linux-386 \
linux-amd64 \ linux-amd64 \
linux-armv5 \ linux-armv5 \
@ -25,7 +26,8 @@ PLATFORM_LIST = \
WINDOWS_ARCH_LIST = \ WINDOWS_ARCH_LIST = \
windows-386 \ windows-386 \
windows-amd64 windows-amd64 \
windows-arm32v7
all: linux-amd64 darwin-amd64 windows-amd64 # Most used all: linux-amd64 darwin-amd64 windows-amd64 # Most used
@ -35,6 +37,9 @@ docker:
darwin-amd64: darwin-amd64:
GOARCH=amd64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=amd64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-arm64:
GOARCH=arm64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-386: linux-386:
GOARCH=386 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=386 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
@ -83,6 +88,9 @@ windows-386:
windows-amd64: windows-amd64:
GOARCH=amd64 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe GOARCH=amd64 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
windows-arm32v7:
GOARCH=arm GOOS=windows GOARM=7 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
gz_releases=$(addsuffix .gz, $(PLATFORM_LIST)) gz_releases=$(addsuffix .gz, $(PLATFORM_LIST))
zip_releases=$(addsuffix .zip, $(WINDOWS_ARCH_LIST)) zip_releases=$(addsuffix .zip, $(WINDOWS_ARCH_LIST))

View File

@ -28,9 +28,18 @@
- Netfilter TCP redirecting. Deploy Clash on your Internet gateway with `iptables`. - Netfilter TCP redirecting. Deploy Clash on your Internet gateway with `iptables`.
- Comprehensive HTTP RESTful API controller - Comprehensive HTTP RESTful API controller
## Premium Features
- TUN mode on macOS, Linux and Windows. [Doc](https://github.com/Dreamacro/clash/wiki/premium-core-features#tun-device)
- Match your tunnel by [Script](https://github.com/Dreamacro/clash/wiki/premium-core-features#script)
- [Rule Provider](https://github.com/Dreamacro/clash/wiki/premium-core-features#rule-providers)
## Getting Started ## Getting Started
Documentations are now moved to [GitHub Wiki](https://github.com/Dreamacro/clash/wiki). Documentations are now moved to [GitHub Wiki](https://github.com/Dreamacro/clash/wiki).
## Premium Release
[Release](https://github.com/Dreamacro/clash/releases/tag/premium)
## Credits ## Credits
* [riobard/go-shadowsocks2](https://github.com/riobard/go-shadowsocks2) * [riobard/go-shadowsocks2](https://github.com/riobard/go-shadowsocks2)

View File

@ -6,33 +6,18 @@ import (
"strings" "strings"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
) )
// HTTPAdapter is a adapter for HTTP connection // NewHTTP recieve normal http request and return HTTPContext
type HTTPAdapter struct { func NewHTTP(request *http.Request, conn net.Conn) *context.HTTPContext {
net.Conn
metadata *C.Metadata
R *http.Request
}
// Metadata return destination metadata
func (h *HTTPAdapter) Metadata() *C.Metadata {
return h.metadata
}
// NewHTTP is HTTPAdapter generator
func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTP metadata.Type = C.HTTP
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip metadata.SrcIP = ip
metadata.SrcPort = port metadata.SrcPort = port
} }
return &HTTPAdapter{ return context.NewHTTPContext(conn, request, metadata)
metadata: metadata,
R: request,
Conn: conn,
}
} }
// RemoveHopByHopHeaders remove hop-by-hop header // RemoveHopByHopHeaders remove hop-by-hop header
@ -58,3 +43,19 @@ func RemoveHopByHopHeaders(header http.Header) {
header.Del(strings.TrimSpace(h)) header.Del(strings.TrimSpace(h))
} }
} }
// RemoveExtraHTTPHostPort remove extra host port (example.com:80 --> example.com)
// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com)
func RemoveExtraHTTPHostPort(req *http.Request) {
host := req.Host
if host == "" {
host = req.URL.Host
}
if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" {
host = pHost
}
req.Host = host
req.URL.Host = host
}

View File

@ -5,18 +5,16 @@ import (
"net/http" "net/http"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
) )
// NewHTTPS is HTTPAdapter generator // NewHTTPS recieve CONNECT request and return ConnContext
func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter { func NewHTTPS(request *http.Request, conn net.Conn) *context.ConnContext {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTPCONNECT metadata.Type = C.HTTPCONNECT
if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil { if ip, port, err := parseAddr(conn.RemoteAddr().String()); err == nil {
metadata.SrcIP = ip metadata.SrcIP = ip
metadata.SrcPort = port metadata.SrcPort = port
} }
return &SocketAdapter{ return context.NewConnContext(conn, metadata)
metadata: metadata,
Conn: conn,
}
} }

View File

@ -5,21 +5,11 @@ import (
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
) )
// SocketAdapter is a adapter for socks and redir connection // NewSocket recieve TCP inbound and return ConnContext
type SocketAdapter struct { func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *context.ConnContext {
net.Conn
metadata *C.Metadata
}
// Metadata return destination metadata
func (s *SocketAdapter) Metadata() *C.Metadata {
return s.metadata
}
// NewSocket is SocketAdapter generator
func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter {
metadata := parseSocksAddr(target) metadata := parseSocksAddr(target)
metadata.NetWork = C.TCP metadata.NetWork = C.TCP
metadata.Type = source metadata.Type = source
@ -28,8 +18,5 @@ func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *SocketAdapter
metadata.SrcPort = port metadata.SrcPort = port
} }
return &SocketAdapter{ return context.NewConnContext(conn, metadata)
Conn: conn,
metadata: metadata,
}
} }

View File

@ -6,15 +6,12 @@ import (
"errors" "errors"
"net" "net"
"net/http" "net/http"
"sync/atomic"
"time" "time"
"github.com/Dreamacro/clash/common/queue" "github.com/Dreamacro/clash/common/queue"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
)
var ( "go.uber.org/atomic"
defaultURLTestTimeout = time.Second * 5
) )
type Base struct { type Base struct {
@ -99,11 +96,11 @@ func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn {
type Proxy struct { type Proxy struct {
C.ProxyAdapter C.ProxyAdapter
history *queue.Queue history *queue.Queue
alive uint32 alive *atomic.Bool
} }
func (p *Proxy) Alive() bool { func (p *Proxy) Alive() bool {
return atomic.LoadUint32(&p.alive) > 0 return p.alive.Load()
} }
func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
@ -115,7 +112,7 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata) conn, err := p.ProxyAdapter.DialContext(ctx, metadata)
if err != nil { if err != nil {
atomic.StoreUint32(&p.alive, 0) p.alive.Store(false)
} }
return conn, err return conn, err
} }
@ -132,7 +129,7 @@ func (p *Proxy) DelayHistory() []C.DelayHistory {
// LastDelay return last history record. if proxy is not alive, return the max value of uint16. // LastDelay return last history record. if proxy is not alive, return the max value of uint16.
func (p *Proxy) LastDelay() (delay uint16) { func (p *Proxy) LastDelay() (delay uint16) {
var max uint16 = 0xffff var max uint16 = 0xffff
if atomic.LoadUint32(&p.alive) == 0 { if !p.alive.Load() {
return max return max
} }
@ -163,11 +160,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
// URLTest get the delay for the specified URL // URLTest get the delay for the specified URL
func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
defer func() { defer func() {
if err == nil { p.alive.Store(err == nil)
atomic.StoreUint32(&p.alive, 1)
} else {
atomic.StoreUint32(&p.alive, 0)
}
record := C.DelayHistory{Time: time.Now()} record := C.DelayHistory{Time: time.Now()}
if err == nil { if err == nil {
record.Delay = t record.Delay = t
@ -223,5 +216,5 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
} }
func NewProxy(adapter C.ProxyAdapter) *Proxy { func NewProxy(adapter C.ProxyAdapter) *Proxy {
return &Proxy{adapter, queue.New(10), 1} return &Proxy{adapter, queue.New(10), atomic.NewBool(true)}
} }

View File

@ -31,6 +31,7 @@ type HttpOption struct {
UserName string `proxy:"username,omitempty"` UserName string `proxy:"username,omitempty"`
Password string `proxy:"password,omitempty"` Password string `proxy:"password,omitempty"`
TLS bool `proxy:"tls,omitempty"` TLS bool `proxy:"tls,omitempty"`
SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
} }
@ -114,10 +115,14 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error {
func NewHttp(option HttpOption) *Http { func NewHttp(option HttpOption) *Http {
var tlsConfig *tls.Config var tlsConfig *tls.Config
if option.TLS { if option.TLS {
sni := option.Server
if option.SNI != "" {
sni = option.SNI
}
tlsConfig = &tls.Config{ tlsConfig = &tls.Config{
InsecureSkipVerify: option.SkipCertVerify, InsecureSkipVerify: option.SkipCertVerify,
ClientSessionCache: getClientSessionCache(), ClientSessionCache: getClientSessionCache(),
ServerName: option.Server, ServerName: sni,
} }
} }

View File

@ -11,11 +11,13 @@ func ParseProxy(mapping map[string]interface{}) (C.Proxy, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "proxy", WeaklyTypedInput: true}) decoder := structure.NewDecoder(structure.Option{TagName: "proxy", WeaklyTypedInput: true})
proxyType, existType := mapping["type"].(string) proxyType, existType := mapping["type"].(string)
if !existType { if !existType {
return nil, fmt.Errorf("Missing type") return nil, fmt.Errorf("missing type")
} }
var proxy C.ProxyAdapter var (
err := fmt.Errorf("Cannot parse") proxy C.ProxyAdapter
err error
)
switch proxyType { switch proxyType {
case "ss": case "ss":
ssOption := &ShadowSocksOption{} ssOption := &ShadowSocksOption{}
@ -72,7 +74,7 @@ func ParseProxy(mapping map[string]interface{}) (C.Proxy, error) {
} }
proxy, err = NewTrojan(*trojanOption) proxy, err = NewTrojan(*trojanOption)
default: default:
return nil, fmt.Errorf("Unsupport proxy type: %s", proxyType) return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
} }
if err != nil { if err != nil {

View File

@ -40,7 +40,7 @@ type ShadowSocksOption struct {
} }
type simpleObfsOption struct { type simpleObfsOption struct {
Mode string `obfs:"mode"` Mode string `obfs:"mode,omitempty"`
Host string `obfs:"host,omitempty"` Host string `obfs:"host,omitempty"`
} }

View File

@ -12,12 +12,13 @@ import (
"github.com/Dreamacro/clash/component/ssr/protocol" "github.com/Dreamacro/clash/component/ssr/protocol"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/go-shadowsocks2/core" "github.com/Dreamacro/go-shadowsocks2/core"
"github.com/Dreamacro/go-shadowsocks2/shadowaead"
"github.com/Dreamacro/go-shadowsocks2/shadowstream" "github.com/Dreamacro/go-shadowsocks2/shadowstream"
) )
type ShadowSocksR struct { type ShadowSocksR struct {
*Base *Base
cipher *core.StreamCipher cipher core.Cipher
obfs obfs.Obfs obfs obfs.Obfs
protocol protocol.Protocol protocol protocol.Protocol
} }
@ -36,17 +37,22 @@ type ShadowSocksROption struct {
} }
func (ssr *ShadowSocksR) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (ssr *ShadowSocksR) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
c = obfs.NewConn(c, ssr.obfs) c = ssr.obfs.StreamConn(c)
c = ssr.cipher.StreamConn(c) c = ssr.cipher.StreamConn(c)
conn, ok := c.(*shadowstream.Conn) var (
if !ok { iv []byte
return nil, fmt.Errorf("invalid connection type") err error
} )
iv, err := conn.ObtainWriteIV() switch conn := c.(type) {
case *shadowstream.Conn:
iv, err = conn.ObtainWriteIV()
if err != nil { if err != nil {
return nil, err return nil, err
} }
c = protocol.NewConn(c, ssr.protocol, iv) case *shadowaead.Conn:
return nil, fmt.Errorf("invalid connection type")
}
c = ssr.protocol.StreamConn(c, iv)
_, err = c.Write(serializesSocksAddr(metadata)) _, err = c.Write(serializesSocksAddr(metadata))
return c, err return c, err
} }
@ -74,7 +80,7 @@ func (ssr *ShadowSocksR) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
} }
pc = ssr.cipher.PacketConn(pc) pc = ssr.cipher.PacketConn(pc)
pc = protocol.NewPacketConn(pc, ssr.protocol) pc = ssr.protocol.PacketConn(pc)
return newPacketConn(&ssPacketConn{PacketConn: pc, rAddr: addr}, ssr), nil return newPacketConn(&ssPacketConn{PacketConn: pc, rAddr: addr}, ssr), nil
} }
@ -90,19 +96,29 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
password := option.Password password := option.Password
coreCiph, err := core.PickCipher(cipher, nil, password) coreCiph, err := core.PickCipher(cipher, nil, password)
if err != nil { if err != nil {
return nil, fmt.Errorf("ssr %s initialize cipher error: %w", addr, err) return nil, fmt.Errorf("ssr %s initialize error: %w", addr, err)
} }
var (
ivSize int
key []byte
)
if option.Cipher == "dummy" {
ivSize = 0
key = core.Kdf(option.Password, 16)
} else {
ciph, ok := coreCiph.(*core.StreamCipher) ciph, ok := coreCiph.(*core.StreamCipher)
if !ok { if !ok {
return nil, fmt.Errorf("%s is not a supported stream cipher in ssr", cipher) return nil, fmt.Errorf("%s is not dummy or a supported stream cipher in ssr", cipher)
}
ivSize = ciph.IVSize()
key = ciph.Key
} }
obfs, err := obfs.PickObfs(option.Obfs, &obfs.Base{ obfs, obfsOverhead, err := obfs.PickObfs(option.Obfs, &obfs.Base{
IVSize: ciph.IVSize(),
Key: ciph.Key,
HeadLen: 30,
Host: option.Server, Host: option.Server,
Port: option.Port, Port: option.Port,
Key: key,
IVSize: ivSize,
Param: option.ObfsParam, Param: option.ObfsParam,
}) })
if err != nil { if err != nil {
@ -110,15 +126,13 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
} }
protocol, err := protocol.PickProtocol(option.Protocol, &protocol.Base{ protocol, err := protocol.PickProtocol(option.Protocol, &protocol.Base{
IV: nil, Key: key,
Key: ciph.Key, Overhead: obfsOverhead,
TCPMss: 1460,
Param: option.ProtocolParam, Param: option.ProtocolParam,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("ssr %s initialize protocol error: %w", addr, err) return nil, fmt.Errorf("ssr %s initialize protocol error: %w", addr, err)
} }
protocol.SetOverhead(obfs.GetObfsOverhead() + protocol.GetProtocolOverhead())
return &ShadowSocksR{ return &ShadowSocksR{
Base: &Base{ Base: &Base{
@ -127,7 +141,7 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
tp: C.ShadowsocksR, tp: C.ShadowsocksR,
udp: option.UDP, udp: option.UDP,
}, },
cipher: ciph, cipher: coreCiph,
obfs: obfs, obfs: obfs,
protocol: protocol, protocol: protocol,
}, nil }, nil

View File

@ -16,7 +16,9 @@ import (
type Snell struct { type Snell struct {
*Base *Base
psk []byte psk []byte
pool *snell.Pool
obfsOption *simpleObfsOption obfsOption *simpleObfsOption
version int
} }
type SnellOption struct { type SnellOption struct {
@ -24,24 +26,47 @@ type SnellOption struct {
Server string `proxy:"server"` Server string `proxy:"server"`
Port int `proxy:"port"` Port int `proxy:"port"`
Psk string `proxy:"psk"` Psk string `proxy:"psk"`
Version int `proxy:"version,omitempty"`
ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"` ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"`
} }
func (s *Snell) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { type streamOption struct {
switch s.obfsOption.Mode { psk []byte
case "tls": version int
c = obfs.NewTLSObfs(c, s.obfsOption.Host) addr string
case "http": obfsOption *simpleObfsOption
_, port, _ := net.SplitHostPort(s.addr)
c = obfs.NewHTTPObfs(c, s.obfsOption.Host, port)
} }
c = snell.StreamConn(c, s.psk)
func streamConn(c net.Conn, option streamOption) *snell.Snell {
switch option.obfsOption.Mode {
case "tls":
c = obfs.NewTLSObfs(c, option.obfsOption.Host)
case "http":
_, port, _ := net.SplitHostPort(option.addr)
c = obfs.NewHTTPObfs(c, option.obfsOption.Host, port)
}
return snell.StreamConn(c, option.psk, option.version)
}
func (s *Snell) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption})
port, _ := strconv.Atoi(metadata.DstPort) port, _ := strconv.Atoi(metadata.DstPort)
err := snell.WriteHeader(c, metadata.String(), uint(port)) err := snell.WriteHeader(c, metadata.String(), uint(port), s.version)
return c, err return c, err
} }
func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
if s.version == snell.Version2 {
c, err := s.pool.Get()
if err != nil {
return nil, err
}
port, _ := strconv.Atoi(metadata.DstPort)
err = snell.WriteHeader(c, metadata.String(), uint(port), s.version)
return NewConn(c, s), err
}
c, err := dialer.DialContext(ctx, "tcp", s.addr) c, err := dialer.DialContext(ctx, "tcp", s.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err) return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
@ -62,11 +87,22 @@ func NewSnell(option SnellOption) (*Snell, error) {
return nil, fmt.Errorf("snell %s initialize obfs error: %w", addr, err) return nil, fmt.Errorf("snell %s initialize obfs error: %w", addr, err)
} }
if obfsOption.Mode != "tls" && obfsOption.Mode != "http" { switch obfsOption.Mode {
case "tls", "http", "":
break
default:
return nil, fmt.Errorf("snell %s obfs mode error: %s", addr, obfsOption.Mode) return nil, fmt.Errorf("snell %s obfs mode error: %s", addr, obfsOption.Mode)
} }
return &Snell{ // backward compatible
if option.Version == 0 {
option.Version = snell.DefaultSnellVersion
}
if option.Version != snell.Version1 && option.Version != snell.Version2 {
return nil, fmt.Errorf("snell version error: %d", option.Version)
}
s := &Snell{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: addr, addr: addr,
@ -74,5 +110,19 @@ func NewSnell(option SnellOption) (*Snell, error) {
}, },
psk: psk, psk: psk,
obfsOption: obfsOption, obfsOption: obfsOption,
}, nil version: option.Version,
}
if option.Version == snell.Version2 {
s.pool = snell.NewPool(func(ctx context.Context) (*snell.Snell, error) {
c, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
tcpKeepAlive(c)
return streamConn(c, streamOption{psk, option.Version, addr, obfsOption}), nil
})
}
return s, nil
} }

View File

@ -122,7 +122,21 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) {
pc.Close() pc.Close()
}() }()
return newPacketConn(&socksPacketConn{PacketConn: pc, rAddr: bindAddr.UDPAddr(), tcpConn: c}, ss), nil // Support unspecified UDP bind address.
bindUDPAddr := bindAddr.UDPAddr()
if bindUDPAddr == nil {
err = errors.New("invalid UDP bind address")
return
} else if bindUDPAddr.IP.IsUnspecified() {
serverAddr, err := resolveUDPAddr("udp", ss.Addr())
if err != nil {
return nil, err
}
bindUDPAddr.IP = serverAddr.IP
}
return newPacketConn(&socksPacketConn{PacketConn: pc, rAddr: bindUDPAddr, tcpConn: c}, ss), nil
} }
func NewSocks5(option Socks5Option) *Socks5 { func NewSocks5(option Socks5Option) *Socks5 {

View File

@ -32,6 +32,7 @@ type VmessOption struct {
UDP bool `proxy:"udp,omitempty"` UDP bool `proxy:"udp,omitempty"`
Network string `proxy:"network,omitempty"` Network string `proxy:"network,omitempty"`
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
WSPath string `proxy:"ws-path,omitempty"` WSPath string `proxy:"ws-path,omitempty"`
WSHeaders map[string]string `proxy:"ws-headers,omitempty"` WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
@ -44,6 +45,11 @@ type HTTPOptions struct {
Headers map[string][]string `proxy:"headers,omitempty"` Headers map[string][]string `proxy:"headers,omitempty"`
} }
type HTTP2Options struct {
Host []string `proxy:"host,omitempty"`
Path string `proxy:"path,omitempty"`
}
func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error var err error
switch v.option.Network { switch v.option.Network {
@ -71,6 +77,25 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
} }
c, err = vmess.StreamWebsocketConn(c, wsOpts) c, err = vmess.StreamWebsocketConn(c, wsOpts)
case "http": case "http":
// readability first, so just copy default TLS logic
if v.option.TLS {
host, _, _ := net.SplitHostPort(v.addr)
tlsOpts := &vmess.TLSConfig{
Host: host,
SkipCertVerify: v.option.SkipCertVerify,
SessionCache: getClientSessionCache(),
}
if v.option.ServerName != "" {
tlsOpts.Host = v.option.ServerName
}
c, err = vmess.StreamTLSConn(c, tlsOpts)
if err != nil {
return nil, err
}
}
host, _, _ := net.SplitHostPort(v.addr) host, _, _ := net.SplitHostPort(v.addr)
httpOpts := &vmess.HTTPConfig{ httpOpts := &vmess.HTTPConfig{
Host: host, Host: host,
@ -80,6 +105,30 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
} }
c = vmess.StreamHTTPConn(c, httpOpts) c = vmess.StreamHTTPConn(c, httpOpts)
case "h2":
host, _, _ := net.SplitHostPort(v.addr)
tlsOpts := vmess.TLSConfig{
Host: host,
SkipCertVerify: v.option.SkipCertVerify,
SessionCache: getClientSessionCache(),
NextProtos: []string{"h2"},
}
if v.option.ServerName != "" {
tlsOpts.Host = v.option.ServerName
}
c, err = vmess.StreamTLSConn(c, &tlsOpts)
if err != nil {
return nil, err
}
h2Opts := &vmess.H2Config{
Hosts: v.option.HTTP2Opts.Host,
Path: v.option.HTTP2Opts.Path,
}
c, err = vmess.StreamH2Conn(c, h2Opts)
default: default:
// handle TLS // handle TLS
if v.option.TLS { if v.option.TLS {
@ -152,13 +201,16 @@ func NewVmess(option VmessOption) (*Vmess, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if option.Network == "h2" && !option.TLS {
return nil, fmt.Errorf("TLS must be true with h2 network")
}
return &Vmess{ return &Vmess{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vmess, tp: C.Vmess,
udp: true, udp: option.UDP,
}, },
client: client, client: client,
option: &option, option: &option,

View File

@ -11,10 +11,14 @@ const (
defaultGetProxiesDuration = time.Second * 5 defaultGetProxiesDuration = time.Second * 5
) )
func getProvidersProxies(providers []provider.ProxyProvider) []C.Proxy { func getProvidersProxies(providers []provider.ProxyProvider, touch bool) []C.Proxy {
proxies := []C.Proxy{} proxies := []C.Proxy{}
for _, provider := range providers { for _, provider := range providers {
if touch {
proxies = append(proxies, provider.ProxiesWithTouch()...)
} else {
proxies = append(proxies, provider.Proxies()...) proxies = append(proxies, provider.Proxies()...)
} }
}
return proxies return proxies
} }

View File

@ -12,17 +12,18 @@ import (
type Fallback struct { type Fallback struct {
*outbound.Base *outbound.Base
disableUDP bool
single *singledo.Single single *singledo.Single
providers []provider.ProxyProvider providers []provider.ProxyProvider
} }
func (f *Fallback) Now() string { func (f *Fallback) Now() string {
proxy := f.findAliveProxy() proxy := f.findAliveProxy(false)
return proxy.Name() return proxy.Name()
} }
func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
proxy := f.findAliveProxy() proxy := f.findAliveProxy(true)
c, err := proxy.DialContext(ctx, metadata) c, err := proxy.DialContext(ctx, metadata)
if err == nil { if err == nil {
c.AppendToChains(f) c.AppendToChains(f)
@ -31,7 +32,7 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
} }
func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
proxy := f.findAliveProxy() proxy := f.findAliveProxy(true)
pc, err := proxy.DialUDP(metadata) pc, err := proxy.DialUDP(metadata)
if err == nil { if err == nil {
pc.AppendToChains(f) pc.AppendToChains(f)
@ -40,13 +41,17 @@ func (f *Fallback) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
} }
func (f *Fallback) SupportUDP() bool { func (f *Fallback) SupportUDP() bool {
proxy := f.findAliveProxy() if f.disableUDP {
return false
}
proxy := f.findAliveProxy(false)
return proxy.SupportUDP() return proxy.SupportUDP()
} }
func (f *Fallback) MarshalJSON() ([]byte, error) { func (f *Fallback) MarshalJSON() ([]byte, error) {
var all []string var all []string
for _, proxy := range f.proxies() { for _, proxy := range f.proxies(false) {
all = append(all, proxy.Name()) all = append(all, proxy.Name())
} }
return json.Marshal(map[string]interface{}{ return json.Marshal(map[string]interface{}{
@ -57,33 +62,34 @@ func (f *Fallback) MarshalJSON() ([]byte, error) {
} }
func (f *Fallback) Unwrap(metadata *C.Metadata) C.Proxy { func (f *Fallback) Unwrap(metadata *C.Metadata) C.Proxy {
proxy := f.findAliveProxy() proxy := f.findAliveProxy(true)
return proxy return proxy
} }
func (f *Fallback) proxies() []C.Proxy { func (f *Fallback) proxies(touch bool) []C.Proxy {
elm, _, _ := f.single.Do(func() (interface{}, error) { elm, _, _ := f.single.Do(func() (interface{}, error) {
return getProvidersProxies(f.providers), nil return getProvidersProxies(f.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm.([]C.Proxy)
} }
func (f *Fallback) findAliveProxy() C.Proxy { func (f *Fallback) findAliveProxy(touch bool) C.Proxy {
proxies := f.proxies() proxies := f.proxies(touch)
for _, proxy := range proxies { for _, proxy := range proxies {
if proxy.Alive() { if proxy.Alive() {
return proxy return proxy
} }
} }
return f.proxies()[0] return proxies[0]
} }
func NewFallback(name string, providers []provider.ProxyProvider) *Fallback { func NewFallback(options *GroupCommonOption, providers []provider.ProxyProvider) *Fallback {
return &Fallback{ return &Fallback{
Base: outbound.NewBase(name, "", C.Fallback, false), Base: outbound.NewBase(options.Name, "", C.Fallback, false),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle(defaultGetProxiesDuration),
providers: providers, providers: providers,
disableUDP: options.DisableUDP,
} }
} }

View File

@ -3,6 +3,8 @@ package outboundgroup
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt"
"net" "net"
"github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/outbound"
@ -14,11 +16,25 @@ import (
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
) )
type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy
type LoadBalance struct { type LoadBalance struct {
*outbound.Base *outbound.Base
disableUDP bool
single *singledo.Single single *singledo.Single
maxRetry int
providers []provider.ProxyProvider providers []provider.ProxyProvider
strategyFn strategyFn
}
var errStrategy = errors.New("unsupported strategy")
func parseStrategy(config map[string]interface{}) string {
if elm, ok := config["strategy"]; ok {
if strategy, ok := elm.(string); ok {
return strategy
}
}
return "consistent-hashing"
} }
func getKey(metadata *C.Metadata) string { func getKey(metadata *C.Metadata) string {
@ -78,14 +94,31 @@ func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (pc C.PacketConn, err error
} }
func (lb *LoadBalance) SupportUDP() bool { func (lb *LoadBalance) SupportUDP() bool {
return true return !lb.disableUDP
} }
func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy { func strategyRoundRobin() strategyFn {
idx := 0
return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
length := len(proxies)
for i := 0; i < length; i++ {
idx = (idx + 1) % length
proxy := proxies[idx]
if proxy.Alive() {
return proxy
}
}
return proxies[0]
}
}
func strategyConsistentHashing() strategyFn {
maxRetry := 5
return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
key := uint64(murmur3.Sum32([]byte(getKey(metadata)))) key := uint64(murmur3.Sum32([]byte(getKey(metadata))))
proxies := lb.proxies()
buckets := int32(len(proxies)) buckets := int32(len(proxies))
for i := 0; i < lb.maxRetry; i, key = i+1, key+1 { for i := 0; i < maxRetry; i, key = i+1, key+1 {
idx := jumpHash(key, buckets) idx := jumpHash(key, buckets)
proxy := proxies[idx] proxy := proxies[idx]
if proxy.Alive() { if proxy.Alive() {
@ -95,10 +128,16 @@ func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy {
return proxies[0] return proxies[0]
} }
}
func (lb *LoadBalance) proxies() []C.Proxy { func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy {
proxies := lb.proxies(true)
return lb.strategyFn(proxies, metadata)
}
func (lb *LoadBalance) proxies(touch bool) []C.Proxy {
elm, _, _ := lb.single.Do(func() (interface{}, error) { elm, _, _ := lb.single.Do(func() (interface{}, error) {
return getProvidersProxies(lb.providers), nil return getProvidersProxies(lb.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm.([]C.Proxy)
@ -106,7 +145,7 @@ func (lb *LoadBalance) proxies() []C.Proxy {
func (lb *LoadBalance) MarshalJSON() ([]byte, error) { func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
var all []string var all []string
for _, proxy := range lb.proxies() { for _, proxy := range lb.proxies(false) {
all = append(all, proxy.Name()) all = append(all, proxy.Name())
} }
return json.Marshal(map[string]interface{}{ return json.Marshal(map[string]interface{}{
@ -115,11 +154,21 @@ func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
}) })
} }
func NewLoadBalance(name string, providers []provider.ProxyProvider) *LoadBalance { func NewLoadBalance(options *GroupCommonOption, providers []provider.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
var strategyFn strategyFn
switch strategy {
case "consistent-hashing":
strategyFn = strategyConsistentHashing()
case "round-robin":
strategyFn = strategyRoundRobin()
default:
return nil, fmt.Errorf("%w: %s", errStrategy, strategy)
}
return &LoadBalance{ return &LoadBalance{
Base: outbound.NewBase(name, "", C.LoadBalance, false), Base: outbound.NewBase(options.Name, "", C.LoadBalance, false),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle(defaultGetProxiesDuration),
maxRetry: 3,
providers: providers, providers: providers,
} strategyFn: strategyFn,
disableUDP: options.DisableUDP,
}, nil
} }

View File

@ -12,7 +12,6 @@ import (
var ( var (
errFormat = errors.New("format error") errFormat = errors.New("format error")
errType = errors.New("unsupport type") errType = errors.New("unsupport type")
errMissUse = errors.New("`use` field should not be empty")
errMissProxy = errors.New("`use` or `proxies` missing") errMissProxy = errors.New("`use` or `proxies` missing")
errMissHealthCheck = errors.New("`url` or `interval` missing") errMissHealthCheck = errors.New("`url` or `interval` missing")
errDuplicateProvider = errors.New("`duplicate provider name") errDuplicateProvider = errors.New("`duplicate provider name")
@ -25,12 +24,16 @@ type GroupCommonOption struct {
Use []string `group:"use,omitempty"` Use []string `group:"use,omitempty"`
URL string `group:"url,omitempty"` URL string `group:"url,omitempty"`
Interval int `group:"interval,omitempty"` Interval int `group:"interval,omitempty"`
Lazy bool `group:"lazy,omitempty"`
DisableUDP bool `group:"disable-udp,omitempty"`
} }
func ParseProxyGroup(config map[string]interface{}, proxyMap map[string]C.Proxy, providersMap map[string]provider.ProxyProvider) (C.ProxyAdapter, error) { func ParseProxyGroup(config map[string]interface{}, proxyMap map[string]C.Proxy, providersMap map[string]provider.ProxyProvider) (C.ProxyAdapter, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "group", WeaklyTypedInput: true}) decoder := structure.NewDecoder(structure.Option{TagName: "group", WeaklyTypedInput: true})
groupOption := &GroupCommonOption{} groupOption := &GroupCommonOption{
Lazy: true,
}
if err := decoder.Decode(config, groupOption); err != nil { if err := decoder.Decode(config, groupOption); err != nil {
return nil, errFormat return nil, errFormat
} }
@ -55,7 +58,7 @@ func ParseProxyGroup(config map[string]interface{}, proxyMap map[string]C.Proxy,
// if Use not empty, drop health check options // if Use not empty, drop health check options
if len(groupOption.Use) != 0 { if len(groupOption.Use) != 0 {
hc := provider.NewHealthCheck(ps, "", 0) hc := provider.NewHealthCheck(ps, "", 0, true)
pd, err := provider.NewCompatibleProvider(groupName, ps, hc) pd, err := provider.NewCompatibleProvider(groupName, ps, hc)
if err != nil { if err != nil {
return nil, err return nil, err
@ -63,9 +66,13 @@ func ParseProxyGroup(config map[string]interface{}, proxyMap map[string]C.Proxy,
providers = append(providers, pd) providers = append(providers, pd)
} else { } else {
if _, ok := providersMap[groupName]; ok {
return nil, errDuplicateProvider
}
// select don't need health check // select don't need health check
if groupOption.Type == "select" || groupOption.Type == "relay" { if groupOption.Type == "select" || groupOption.Type == "relay" {
hc := provider.NewHealthCheck(ps, "", 0) hc := provider.NewHealthCheck(ps, "", 0, true)
pd, err := provider.NewCompatibleProvider(groupName, ps, hc) pd, err := provider.NewCompatibleProvider(groupName, ps, hc)
if err != nil { if err != nil {
return nil, err return nil, err
@ -78,7 +85,7 @@ func ParseProxyGroup(config map[string]interface{}, proxyMap map[string]C.Proxy,
return nil, errMissHealthCheck return nil, errMissHealthCheck
} }
hc := provider.NewHealthCheck(ps, groupOption.URL, uint(groupOption.Interval)) hc := provider.NewHealthCheck(ps, groupOption.URL, uint(groupOption.Interval), groupOption.Lazy)
pd, err := provider.NewCompatibleProvider(groupName, ps, hc) pd, err := provider.NewCompatibleProvider(groupName, ps, hc)
if err != nil { if err != nil {
return nil, err return nil, err
@ -102,15 +109,16 @@ func ParseProxyGroup(config map[string]interface{}, proxyMap map[string]C.Proxy,
switch groupOption.Type { switch groupOption.Type {
case "url-test": case "url-test":
opts := parseURLTestOption(config) opts := parseURLTestOption(config)
group = NewURLTest(groupName, providers, opts...) group = NewURLTest(groupOption, providers, opts...)
case "select": case "select":
group = NewSelector(groupName, providers) group = NewSelector(groupOption, providers)
case "fallback": case "fallback":
group = NewFallback(groupName, providers) group = NewFallback(groupOption, providers)
case "load-balance": case "load-balance":
group = NewLoadBalance(groupName, providers) strategy := parseStrategy(config)
return NewLoadBalance(groupOption, providers, strategy)
case "relay": case "relay":
group = NewRelay(groupName, providers) group = NewRelay(groupOption, providers)
default: default:
return nil, fmt.Errorf("%w: %s", errType, groupOption.Type) return nil, fmt.Errorf("%w: %s", errType, groupOption.Type)
} }

View File

@ -20,9 +20,9 @@ type Relay struct {
} }
func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
proxies := r.proxies(metadata) proxies := r.proxies(metadata, true)
if len(proxies) == 0 { if len(proxies) == 0 {
return nil, errors.New("Proxy does not exist") return nil, errors.New("proxy does not exist")
} }
first := proxies[0] first := proxies[0]
last := proxies[len(proxies)-1] last := proxies[len(proxies)-1]
@ -58,7 +58,7 @@ func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
func (r *Relay) MarshalJSON() ([]byte, error) { func (r *Relay) MarshalJSON() ([]byte, error) {
var all []string var all []string
for _, proxy := range r.rawProxies() { for _, proxy := range r.rawProxies(false) {
all = append(all, proxy.Name()) all = append(all, proxy.Name())
} }
return json.Marshal(map[string]interface{}{ return json.Marshal(map[string]interface{}{
@ -67,16 +67,16 @@ func (r *Relay) MarshalJSON() ([]byte, error) {
}) })
} }
func (r *Relay) rawProxies() []C.Proxy { func (r *Relay) rawProxies(touch bool) []C.Proxy {
elm, _, _ := r.single.Do(func() (interface{}, error) { elm, _, _ := r.single.Do(func() (interface{}, error) {
return getProvidersProxies(r.providers), nil return getProvidersProxies(r.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm.([]C.Proxy)
} }
func (r *Relay) proxies(metadata *C.Metadata) []C.Proxy { func (r *Relay) proxies(metadata *C.Metadata, touch bool) []C.Proxy {
proxies := r.rawProxies() proxies := r.rawProxies(touch)
for n, proxy := range proxies { for n, proxy := range proxies {
subproxy := proxy.Unwrap(metadata) subproxy := proxy.Unwrap(metadata)
@ -89,9 +89,9 @@ func (r *Relay) proxies(metadata *C.Metadata) []C.Proxy {
return proxies return proxies
} }
func NewRelay(name string, providers []provider.ProxyProvider) *Relay { func NewRelay(options *GroupCommonOption, providers []provider.ProxyProvider) *Relay {
return &Relay{ return &Relay{
Base: outbound.NewBase(name, "", C.Relay, false), Base: outbound.NewBase(options.Name, "", C.Relay, false),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle(defaultGetProxiesDuration),
providers: providers, providers: providers,
} }

View File

@ -13,13 +13,14 @@ import (
type Selector struct { type Selector struct {
*outbound.Base *outbound.Base
disableUDP bool
single *singledo.Single single *singledo.Single
selected string selected string
providers []provider.ProxyProvider providers []provider.ProxyProvider
} }
func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := s.selectedProxy().DialContext(ctx, metadata) c, err := s.selectedProxy(true).DialContext(ctx, metadata)
if err == nil { if err == nil {
c.AppendToChains(s) c.AppendToChains(s)
} }
@ -27,7 +28,7 @@ func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
} }
func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
pc, err := s.selectedProxy().DialUDP(metadata) pc, err := s.selectedProxy(true).DialUDP(metadata)
if err == nil { if err == nil {
pc.AppendToChains(s) pc.AppendToChains(s)
} }
@ -35,12 +36,16 @@ func (s *Selector) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
} }
func (s *Selector) SupportUDP() bool { func (s *Selector) SupportUDP() bool {
return s.selectedProxy().SupportUDP() if s.disableUDP {
return false
}
return s.selectedProxy(false).SupportUDP()
} }
func (s *Selector) MarshalJSON() ([]byte, error) { func (s *Selector) MarshalJSON() ([]byte, error) {
var all []string var all []string
for _, proxy := range getProvidersProxies(s.providers) { for _, proxy := range getProvidersProxies(s.providers, false) {
all = append(all, proxy.Name()) all = append(all, proxy.Name())
} }
@ -52,11 +57,11 @@ func (s *Selector) MarshalJSON() ([]byte, error) {
} }
func (s *Selector) Now() string { func (s *Selector) Now() string {
return s.selectedProxy().Name() return s.selectedProxy(false).Name()
} }
func (s *Selector) Set(name string) error { func (s *Selector) Set(name string) error {
for _, proxy := range getProvidersProxies(s.providers) { for _, proxy := range getProvidersProxies(s.providers, false) {
if proxy.Name() == name { if proxy.Name() == name {
s.selected = name s.selected = name
s.single.Reset() s.single.Reset()
@ -64,16 +69,16 @@ func (s *Selector) Set(name string) error {
} }
} }
return errors.New("Proxy does not exist") return errors.New("proxy not exist")
} }
func (s *Selector) Unwrap(metadata *C.Metadata) C.Proxy { func (s *Selector) Unwrap(metadata *C.Metadata) C.Proxy {
return s.selectedProxy() return s.selectedProxy(true)
} }
func (s *Selector) selectedProxy() C.Proxy { func (s *Selector) selectedProxy(touch bool) C.Proxy {
elm, _, _ := s.single.Do(func() (interface{}, error) { elm, _, _ := s.single.Do(func() (interface{}, error) {
proxies := getProvidersProxies(s.providers) proxies := getProvidersProxies(s.providers, touch)
for _, proxy := range proxies { for _, proxy := range proxies {
if proxy.Name() == s.selected { if proxy.Name() == s.selected {
return proxy, nil return proxy, nil
@ -86,12 +91,13 @@ func (s *Selector) selectedProxy() C.Proxy {
return elm.(C.Proxy) return elm.(C.Proxy)
} }
func NewSelector(name string, providers []provider.ProxyProvider) *Selector { func NewSelector(options *GroupCommonOption, providers []provider.ProxyProvider) *Selector {
selected := providers[0].Proxies()[0].Name() selected := providers[0].Proxies()[0].Name()
return &Selector{ return &Selector{
Base: outbound.NewBase(name, "", C.Selector, false), Base: outbound.NewBase(options.Name, "", C.Selector, false),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle(defaultGetProxiesDuration),
providers: providers, providers: providers,
selected: selected, selected: selected,
disableUDP: options.DisableUDP,
} }
} }

View File

@ -22,6 +22,7 @@ func urlTestWithTolerance(tolerance uint16) urlTestOption {
type URLTest struct { type URLTest struct {
*outbound.Base *outbound.Base
tolerance uint16 tolerance uint16
disableUDP bool
fastNode C.Proxy fastNode C.Proxy
single *singledo.Single single *singledo.Single
fastSingle *singledo.Single fastSingle *singledo.Single
@ -29,11 +30,11 @@ type URLTest struct {
} }
func (u *URLTest) Now() string { func (u *URLTest) Now() string {
return u.fast().Name() return u.fast(false).Name()
} }
func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) {
c, err = u.fast().DialContext(ctx, metadata) c, err = u.fast(true).DialContext(ctx, metadata)
if err == nil { if err == nil {
c.AppendToChains(u) c.AppendToChains(u)
} }
@ -41,7 +42,7 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Co
} }
func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
pc, err := u.fast().DialUDP(metadata) pc, err := u.fast(true).DialUDP(metadata)
if err == nil { if err == nil {
pc.AppendToChains(u) pc.AppendToChains(u)
} }
@ -49,20 +50,20 @@ func (u *URLTest) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
} }
func (u *URLTest) Unwrap(metadata *C.Metadata) C.Proxy { func (u *URLTest) Unwrap(metadata *C.Metadata) C.Proxy {
return u.fast() return u.fast(true)
} }
func (u *URLTest) proxies() []C.Proxy { func (u *URLTest) proxies(touch bool) []C.Proxy {
elm, _, _ := u.single.Do(func() (interface{}, error) { elm, _, _ := u.single.Do(func() (interface{}, error) {
return getProvidersProxies(u.providers), nil return getProvidersProxies(u.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm.([]C.Proxy)
} }
func (u *URLTest) fast() C.Proxy { func (u *URLTest) fast(touch bool) C.Proxy {
elm, _, _ := u.fastSingle.Do(func() (interface{}, error) { elm, _, _ := u.fastSingle.Do(func() (interface{}, error) {
proxies := u.proxies() proxies := u.proxies(touch)
fast := proxies[0] fast := proxies[0]
min := fast.LastDelay() min := fast.LastDelay()
for _, proxy := range proxies[1:] { for _, proxy := range proxies[1:] {
@ -89,12 +90,16 @@ func (u *URLTest) fast() C.Proxy {
} }
func (u *URLTest) SupportUDP() bool { func (u *URLTest) SupportUDP() bool {
return u.fast().SupportUDP() if u.disableUDP {
return false
}
return u.fast(false).SupportUDP()
} }
func (u *URLTest) MarshalJSON() ([]byte, error) { func (u *URLTest) MarshalJSON() ([]byte, error) {
var all []string var all []string
for _, proxy := range u.proxies() { for _, proxy := range u.proxies(false) {
all = append(all, proxy.Name()) all = append(all, proxy.Name())
} }
return json.Marshal(map[string]interface{}{ return json.Marshal(map[string]interface{}{
@ -117,12 +122,13 @@ func parseURLTestOption(config map[string]interface{}) []urlTestOption {
return opts return opts
} }
func NewURLTest(name string, providers []provider.ProxyProvider, options ...urlTestOption) *URLTest { func NewURLTest(commonOptions *GroupCommonOption, providers []provider.ProxyProvider, options ...urlTestOption) *URLTest {
urlTest := &URLTest{ urlTest := &URLTest{
Base: outbound.NewBase(name, "", C.URLTest, false), Base: outbound.NewBase(commonOptions.Name, "", C.URLTest, false),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle(defaultGetProxiesDuration),
fastSingle: singledo.NewSingle(time.Second * 10), fastSingle: singledo.NewSingle(time.Second * 10),
providers: providers, providers: providers,
disableUDP: commonOptions.DisableUDP,
} }
for _, option := range options { for _, option := range options {

View File

@ -74,7 +74,7 @@ func (f *fetcher) Initial() (interface{}, error) {
} }
} }
if f.vehicle.Type() != File { if f.vehicle.Type() != File && !isLocal {
if err := safeWrite(f.vehicle.Path(), buf); err != nil { if err := safeWrite(f.vehicle.Path(), buf); err != nil {
return nil, err return nil, err
} }
@ -108,9 +108,11 @@ func (f *fetcher) Update() (interface{}, bool, error) {
return nil, false, err return nil, false, err
} }
if f.vehicle.Type() != File {
if err := safeWrite(f.vehicle.Path(), buf); err != nil { if err := safeWrite(f.vehicle.Path(), buf); err != nil {
return nil, false, err return nil, false, err
} }
}
f.updatedAt = &now f.updatedAt = &now
f.hash = hash f.hash = hash

View File

@ -2,9 +2,12 @@ package provider
import ( import (
"context" "context"
"sync"
"time" "time"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"go.uber.org/atomic"
) )
const ( const (
@ -20,6 +23,8 @@ type HealthCheck struct {
url string url string
proxies []C.Proxy proxies []C.Proxy
interval uint interval uint
lazy bool
lastTouch *atomic.Int64
done chan struct{} done chan struct{}
} }
@ -30,7 +35,10 @@ func (hc *HealthCheck) process() {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
now := time.Now().Unix()
if !hc.lazy || now-hc.lastTouch.Load() < int64(hc.interval) {
hc.check() hc.check()
}
case <-hc.done: case <-hc.done:
ticker.Stop() ticker.Stop()
return return
@ -46,13 +54,24 @@ func (hc *HealthCheck) auto() bool {
return hc.interval != 0 return hc.interval != 0
} }
func (hc *HealthCheck) check() { func (hc *HealthCheck) touch() {
ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout) hc.lastTouch.Store(time.Now().Unix())
for _, proxy := range hc.proxies {
go proxy.URLTest(ctx, hc.url)
} }
<-ctx.Done() func (hc *HealthCheck) check() {
ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout)
wg := &sync.WaitGroup{}
for _, proxy := range hc.proxies {
wg.Add(1)
go func(p C.Proxy) {
p.URLTest(ctx, hc.url)
wg.Done()
}(proxy)
}
wg.Wait()
cancel() cancel()
} }
@ -60,11 +79,13 @@ func (hc *HealthCheck) close() {
hc.done <- struct{}{} hc.done <- struct{}{}
} }
func NewHealthCheck(proxies []C.Proxy, url string, interval uint) *HealthCheck { func NewHealthCheck(proxies []C.Proxy, url string, interval uint, lazy bool) *HealthCheck {
return &HealthCheck{ return &HealthCheck{
proxies: proxies, proxies: proxies,
url: url, url: url,
interval: interval, interval: interval,
lazy: lazy,
lastTouch: atomic.NewInt64(0),
done: make(chan struct{}, 1), done: make(chan struct{}, 1),
} }
} }

View File

@ -17,6 +17,7 @@ type healthCheckSchema struct {
Enable bool `provider:"enable"` Enable bool `provider:"enable"`
URL string `provider:"url"` URL string `provider:"url"`
Interval int `provider:"interval"` Interval int `provider:"interval"`
Lazy bool `provider:"lazy,omitempty"`
} }
type proxyProviderSchema struct { type proxyProviderSchema struct {
@ -30,7 +31,11 @@ type proxyProviderSchema struct {
func ParseProxyProvider(name string, mapping map[string]interface{}) (ProxyProvider, error) { func ParseProxyProvider(name string, mapping map[string]interface{}) (ProxyProvider, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "provider", WeaklyTypedInput: true}) decoder := structure.NewDecoder(structure.Option{TagName: "provider", WeaklyTypedInput: true})
schema := &proxyProviderSchema{} schema := &proxyProviderSchema{
HealthCheck: healthCheckSchema{
Lazy: true,
},
}
if err := decoder.Decode(mapping, schema); err != nil { if err := decoder.Decode(mapping, schema); err != nil {
return nil, err return nil, err
} }
@ -39,7 +44,7 @@ func ParseProxyProvider(name string, mapping map[string]interface{}) (ProxyProvi
if schema.HealthCheck.Enable { if schema.HealthCheck.Enable {
hcInterval = uint(schema.HealthCheck.Interval) hcInterval = uint(schema.HealthCheck.Interval)
} }
hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, hcInterval) hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, hcInterval, schema.HealthCheck.Lazy)
path := C.Path.Resolve(schema.Path) path := C.Path.Resolve(schema.Path)

View File

@ -50,6 +50,9 @@ type Provider interface {
type ProxyProvider interface { type ProxyProvider interface {
Provider Provider
Proxies() []C.Proxy Proxies() []C.Proxy
// ProxiesWithTouch is used to inform the provider that the proxy is actually being used while getting the list of proxies.
// Commonly used in Dial and DialUDP
ProxiesWithTouch() []C.Proxy
HealthCheck() HealthCheck()
} }
@ -112,6 +115,11 @@ func (pp *proxySetProvider) Proxies() []C.Proxy {
return pp.proxies return pp.proxies
} }
func (pp *proxySetProvider) ProxiesWithTouch() []C.Proxy {
pp.healthCheck.touch()
return pp.Proxies()
}
func proxiesParse(buf []byte) (interface{}, error) { func proxiesParse(buf []byte) (interface{}, error) {
schema := &ProxySchema{} schema := &ProxySchema{}
@ -120,20 +128,20 @@ func proxiesParse(buf []byte) (interface{}, error) {
} }
if schema.Proxies == nil { if schema.Proxies == nil {
return nil, errors.New("File must have a `proxies` field") return nil, errors.New("file must have a `proxies` field")
} }
proxies := []C.Proxy{} proxies := []C.Proxy{}
for idx, mapping := range schema.Proxies { for idx, mapping := range schema.Proxies {
proxy, err := outbound.ParseProxy(mapping) proxy, err := outbound.ParseProxy(mapping)
if err != nil { if err != nil {
return nil, fmt.Errorf("Proxy %d error: %w", idx, err) return nil, fmt.Errorf("proxy %d error: %w", idx, err)
} }
proxies = append(proxies, proxy) proxies = append(proxies, proxy)
} }
if len(proxies) == 0 { if len(proxies) == 0 {
return nil, errors.New("File doesn't have any valid proxy") return nil, errors.New("file doesn't have any valid proxy")
} }
return proxies, nil return proxies, nil
@ -223,6 +231,11 @@ func (cp *compatibleProvider) Proxies() []C.Proxy {
return cp.proxies return cp.proxies
} }
func (cp *compatibleProvider) ProxiesWithTouch() []C.Proxy {
cp.healthCheck.touch()
return cp.Proxies()
}
func stopCompatibleProvider(pd *CompatibleProvider) { func stopCompatibleProvider(pd *CompatibleProvider) {
pd.healthCheck.close() pd.healthCheck.close()
} }

View File

@ -107,6 +107,7 @@ func (h *HTTPVehicle) Read() ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close()
buf, err := ioutil.ReadAll(resp.Body) buf, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {

View File

@ -121,7 +121,7 @@ func (c *LruCache) Set(key interface{}, value interface{}) {
c.SetWithExpire(key, value, time.Unix(expires, 0)) c.SetWithExpire(key, value, time.Unix(expires, 0))
} }
// SetWithExpire stores the interface{} representation of a response for a given key and given exires. // SetWithExpire stores the interface{} representation of a response for a given key and given expires.
// The expires time will round to second. // The expires time will round to second.
func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) { func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) {
c.mu.Lock() c.mu.Lock()
@ -146,6 +146,23 @@ func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires tim
c.maybeDeleteOldest() c.maybeDeleteOldest()
} }
// CloneTo clone and overwrite elements to another LruCache
func (c *LruCache) CloneTo(n *LruCache) {
c.mu.Lock()
defer c.mu.Unlock()
n.mu.Lock()
defer n.mu.Unlock()
n.lru = list.New()
n.cache = make(map[interface{}]*list.Element)
for e := c.lru.Front(); e != nil; e = e.Next() {
elm := e.Value.(*entry)
n.cache[elm.key] = n.lru.PushBack(elm)
}
}
func (c *LruCache) get(key interface{}) *entry { func (c *LruCache) get(key interface{}) *entry {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -171,7 +188,7 @@ func (c *LruCache) get(key interface{}) *entry {
} }
// Delete removes the value associated with a key. // Delete removes the value associated with a key.
func (c *LruCache) Delete(key string) { func (c *LruCache) Delete(key interface{}) {
c.mu.Lock() c.mu.Lock()
if le, ok := c.cache[key]; ok { if le, ok := c.cache[key]; ok {

View File

@ -164,3 +164,21 @@ func TestStale(t *testing.T) {
assert.Equal(t, tenSecBefore, expires) assert.Equal(t, tenSecBefore, expires)
assert.Equal(t, true, exist) assert.Equal(t, true, exist)
} }
func TestCloneTo(t *testing.T) {
o := NewLRUCache(WithSize(10))
o.Set("1", 1)
o.Set("2", 2)
n := NewLRUCache(WithSize(2))
n.Set("3", 3)
n.Set("4", 4)
o.CloneTo(n)
assert.False(t, n.Exist("3"))
assert.True(t, n.Exist("1"))
n.Set("5", 5)
assert.False(t, n.Exist("1"))
}

11
common/net/io.go Normal file
View File

@ -0,0 +1,11 @@
package net
import "io"
type ReadOnlyReader struct {
io.Reader
}
type WriteOnlyWriter struct {
io.Writer
}

View File

@ -2,11 +2,11 @@ package observable
import ( import (
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/atomic"
) )
func iterator(item []interface{}) chan interface{} { func iterator(item []interface{}) chan interface{} {
@ -33,25 +33,25 @@ func TestObservable(t *testing.T) {
assert.Equal(t, count, 5) assert.Equal(t, count, 5)
} }
func TestObservable_MutilSubscribe(t *testing.T) { func TestObservable_MultiSubscribe(t *testing.T) {
iter := iterator([]interface{}{1, 2, 3, 4, 5}) iter := iterator([]interface{}{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable(iter)
ch1, _ := src.Subscribe() ch1, _ := src.Subscribe()
ch2, _ := src.Subscribe() ch2, _ := src.Subscribe()
var count int32 var count = atomic.NewInt32(0)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
waitCh := func(ch <-chan interface{}) { waitCh := func(ch <-chan interface{}) {
for range ch { for range ch {
atomic.AddInt32(&count, 1) count.Inc()
} }
wg.Done() wg.Done()
} }
go waitCh(ch1) go waitCh(ch1)
go waitCh(ch2) go waitCh(ch2)
wg.Wait() wg.Wait()
assert.Equal(t, int32(10), count) assert.Equal(t, int32(10), count.Load())
} }
func TestObservable_UnSubscribe(t *testing.T) { func TestObservable_UnSubscribe(t *testing.T) {
@ -113,3 +113,34 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) {
_, more := <-list[0] _, more := <-list[0]
assert.False(t, more) assert.False(t, more)
} }
func Benchmark_Observable_1000(b *testing.B) {
ch := make(chan interface{})
o := NewObservable(ch)
num := 1000
subs := []Subscription{}
for i := 0; i < num; i++ {
sub, _ := o.Subscribe()
subs = append(subs, sub)
}
wg := sync.WaitGroup{}
wg.Add(num)
b.ResetTimer()
for _, sub := range subs {
go func(s Subscription) {
for range s {
}
wg.Done()
}(sub)
}
for i := 0; i < b.N; i++ {
ch <- i
}
close(ch)
wg.Wait()
}

View File

@ -2,34 +2,32 @@ package observable
import ( import (
"sync" "sync"
"gopkg.in/eapache/channels.v1"
) )
type Subscription <-chan interface{} type Subscription <-chan interface{}
type Subscriber struct { type Subscriber struct {
buffer *channels.InfiniteChannel buffer chan interface{}
once sync.Once once sync.Once
} }
func (s *Subscriber) Emit(item interface{}) { func (s *Subscriber) Emit(item interface{}) {
s.buffer.In() <- item s.buffer <- item
} }
func (s *Subscriber) Out() Subscription { func (s *Subscriber) Out() Subscription {
return s.buffer.Out() return s.buffer
} }
func (s *Subscriber) Close() { func (s *Subscriber) Close() {
s.once.Do(func() { s.once.Do(func() {
s.buffer.Close() close(s.buffer)
}) })
} }
func newSubscriber() *Subscriber { func newSubscriber() *Subscriber {
sub := &Subscriber{ sub := &Subscriber{
buffer: channels.NewInfiniteChannel(), buffer: make(chan interface{}, 200),
} }
return sub return sub
} }

View File

@ -55,11 +55,13 @@ func (alloc *Allocator) Put(buf []byte) error {
if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits { if cap(buf) == 0 || cap(buf) > 65536 || cap(buf) != 1<<bits {
return errors.New("allocator Put() incorrect buffer size") return errors.New("allocator Put() incorrect buffer size")
} }
//lint:ignore SA6002 ignore temporarily
alloc.buffers[bits].Put(buf) alloc.buffers[bits].Put(buf)
return nil return nil
} }
// msb return the pos of most significiant bit // msb return the pos of most significant bit
func msb(size int) uint16 { func msb(size int) uint16 {
return uint16(bits.Len32(uint32(size)) - 1) return uint16(bits.Len32(uint32(size)) - 1)
} }

View File

@ -25,11 +25,11 @@ func TestAllocGet(t *testing.T) {
func TestAllocPut(t *testing.T) { func TestAllocPut(t *testing.T) {
alloc := NewAllocator() alloc := NewAllocator()
assert.NotNil(t, alloc.Put(nil), "put nil misbehavior") assert.NotNil(t, alloc.Put(nil), "put nil misbehavior")
assert.NotNil(t, alloc.Put(make([]byte, 3, 3)), "put elem:3 []bytes misbehavior") assert.NotNil(t, alloc.Put(make([]byte, 3)), "put elem:3 []bytes misbehavior")
assert.Nil(t, alloc.Put(make([]byte, 4, 4)), "put elem:4 []bytes misbehavior") assert.Nil(t, alloc.Put(make([]byte, 4)), "put elem:4 []bytes misbehavior")
assert.Nil(t, alloc.Put(make([]byte, 1023, 1024)), "put elem:1024 []bytes misbehavior") assert.Nil(t, alloc.Put(make([]byte, 1023, 1024)), "put elem:1024 []bytes misbehavior")
assert.Nil(t, alloc.Put(make([]byte, 65536, 65536)), "put elem:65536 []bytes misbehavior") assert.Nil(t, alloc.Put(make([]byte, 65536)), "put elem:65536 []bytes misbehavior")
assert.NotNil(t, alloc.Put(make([]byte, 65537, 65537)), "put elem:65537 []bytes misbehavior") assert.NotNil(t, alloc.Put(make([]byte, 65537)), "put elem:65537 []bytes misbehavior")
} }
func TestAllocPutThenGet(t *testing.T) { func TestAllocPutThenGet(t *testing.T) {

View File

@ -24,6 +24,8 @@ type Result struct {
Err error Err error
} }
// Do single.Do likes sync.singleFlight
//lint:ignore ST1008 it likes sync.singleFlight
func (s *Single) Do(fn func() (interface{}, error)) (v interface{}, err error, shared bool) { func (s *Single) Do(fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
s.mux.Lock() s.mux.Lock()
now := time.Now() now := time.Now()

View File

@ -2,17 +2,17 @@ package singledo
import ( import (
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/atomic"
) )
func TestBasic(t *testing.T) { func TestBasic(t *testing.T) {
single := NewSingle(time.Millisecond * 30) single := NewSingle(time.Millisecond * 30)
foo := 0 foo := 0
var shardCount int32 = 0 var shardCount = atomic.NewInt32(0)
call := func() (interface{}, error) { call := func() (interface{}, error) {
foo++ foo++
time.Sleep(time.Millisecond * 5) time.Sleep(time.Millisecond * 5)
@ -26,7 +26,7 @@ func TestBasic(t *testing.T) {
go func() { go func() {
_, _, shard := single.Do(call) _, _, shard := single.Do(call)
if shard { if shard {
atomic.AddInt32(&shardCount, 1) shardCount.Inc()
} }
wg.Done() wg.Done()
}() }()
@ -34,7 +34,7 @@ func TestBasic(t *testing.T) {
wg.Wait() wg.Wait()
assert.Equal(t, 1, foo) assert.Equal(t, 1, foo)
assert.Equal(t, int32(4), shardCount) assert.Equal(t, int32(4), shardCount.Load())
} }
func TestTimer(t *testing.T) { func TestTimer(t *testing.T) {

118
component/dialer/bind.go Normal file
View File

@ -0,0 +1,118 @@
package dialer
import (
"errors"
"net"
"time"
"github.com/Dreamacro/clash/common/singledo"
)
// In some OS, such as Windows, it takes a little longer to get interface information
var ifaceSingle = singledo.NewSingle(time.Second * 20)
var (
errPlatformNotSupport = errors.New("unsupport platform")
)
func lookupTCPAddr(ip net.IP, addrs []net.Addr) (*net.TCPAddr, error) {
ipv4 := ip.To4() != nil
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok {
continue
}
addrV4 := addr.IP.To4() != nil
if addrV4 && ipv4 {
return &net.TCPAddr{IP: addr.IP, Port: 0}, nil
} else if !addrV4 && !ipv4 {
return &net.TCPAddr{IP: addr.IP, Port: 0}, nil
}
}
return nil, ErrAddrNotFound
}
func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) {
ipv4 := ip.To4() != nil
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok {
continue
}
addrV4 := addr.IP.To4() != nil
if addrV4 && ipv4 {
return &net.UDPAddr{IP: addr.IP, Port: 0}, nil
} else if !addrV4 && !ipv4 {
return &net.UDPAddr{IP: addr.IP, Port: 0}, nil
}
}
return nil, ErrAddrNotFound
}
func fallbackBindToDialer(dialer *net.Dialer, network string, ip net.IP, name string) error {
if !ip.IsGlobalUnicast() {
return nil
}
iface, err, _ := ifaceSingle.Do(func() (interface{}, error) {
return net.InterfaceByName(name)
})
if err != nil {
return err
}
addrs, err := iface.(*net.Interface).Addrs()
if err != nil {
return err
}
switch network {
case "tcp", "tcp4", "tcp6":
if addr, err := lookupTCPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr
} else {
return err
}
case "udp", "udp4", "udp6":
if addr, err := lookupUDPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr
} else {
return err
}
}
return nil
}
func fallbackBindToListenConfig(name string) (string, error) {
iface, err, _ := ifaceSingle.Do(func() (interface{}, error) {
return net.InterfaceByName(name)
})
if err != nil {
return "", err
}
addrs, err := iface.(*net.Interface).Addrs()
if err != nil {
return "", err
}
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok || addr.IP.To4() == nil {
continue
}
return net.JoinHostPort(addr.IP.String(), "0"), nil
}
return "", ErrAddrNotFound
}

View File

@ -0,0 +1,53 @@
package dialer
import (
"net"
"syscall"
)
type controlFn = func(network, address string, c syscall.RawConn) error
func bindControl(ifaceIdx int) controlFn {
return func(network, address string, c syscall.RawConn) error {
ipStr, _, err := net.SplitHostPort(address)
if err == nil {
ip := net.ParseIP(ipStr)
if ip != nil && !ip.IsGlobalUnicast() {
return nil
}
}
return c.Control(func(fd uintptr) {
switch network {
case "tcp4", "udp4":
syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, ifaceIdx)
case "tcp6", "udp6":
syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, ifaceIdx)
}
})
}
}
func bindIfaceToDialer(dialer *net.Dialer, ifaceName string) error {
iface, err, _ := ifaceSingle.Do(func() (interface{}, error) {
return net.InterfaceByName(ifaceName)
})
if err != nil {
return err
}
dialer.Control = bindControl(iface.(*net.Interface).Index)
return nil
}
func bindIfaceToListenConfig(lc *net.ListenConfig, ifaceName string) error {
iface, err, _ := ifaceSingle.Do(func() (interface{}, error) {
return net.InterfaceByName(ifaceName)
})
if err != nil {
return err
}
lc.Control = bindControl(iface.(*net.Interface).Index)
return nil
}

View File

@ -0,0 +1,36 @@
package dialer
import (
"net"
"syscall"
)
type controlFn = func(network, address string, c syscall.RawConn) error
func bindControl(ifaceName string) controlFn {
return func(network, address string, c syscall.RawConn) error {
ipStr, _, err := net.SplitHostPort(address)
if err == nil {
ip := net.ParseIP(ipStr)
if ip != nil && !ip.IsGlobalUnicast() {
return nil
}
}
return c.Control(func(fd uintptr) {
syscall.BindToDevice(int(fd), ifaceName)
})
}
}
func bindIfaceToDialer(dialer *net.Dialer, ifaceName string) error {
dialer.Control = bindControl(ifaceName)
return nil
}
func bindIfaceToListenConfig(lc *net.ListenConfig, ifaceName string) error {
lc.Control = bindControl(ifaceName)
return nil
}

View File

@ -0,0 +1,13 @@
// +build !linux,!darwin
package dialer
import "net"
func bindIfaceToDialer(dialer *net.Dialer, ifaceName string) error {
return errPlatformNotSupport
}
func bindIfaceToListenConfig(lc *net.ListenConfig, ifaceName string) error {
return errPlatformNotSupport
}

View File

@ -19,17 +19,6 @@ func Dialer() (*net.Dialer, error) {
return dialer, nil return dialer, nil
} }
func ListenConfig() (*net.ListenConfig, error) {
cfg := &net.ListenConfig{}
if ListenConfigHook != nil {
if err := ListenConfigHook(cfg); err != nil {
return nil, err
}
}
return cfg, nil
}
func Dial(network, address string) (net.Conn, error) { func Dial(network, address string) (net.Conn, error) {
return DialContext(context.Background(), network, address) return DialContext(context.Background(), network, address)
} }
@ -73,19 +62,16 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
func ListenPacket(network, address string) (net.PacketConn, error) { func ListenPacket(network, address string) (net.PacketConn, error) {
lc, err := ListenConfig() cfg := &net.ListenConfig{}
if ListenPacketHook != nil {
var err error
address, err = ListenPacketHook(cfg, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
if ListenPacketHook != nil && address == "" { return cfg.ListenPacket(context.Background(), network, address)
ip, err := ListenPacketHook()
if err != nil {
return nil, err
}
address = net.JoinHostPort(ip.String(), "0")
}
return lc.ListenPacket(context.Background(), network, address)
} }
func dualStackDialContext(ctx context.Context, network, address string) (net.Conn, error) { func dualStackDialContext(ctx context.Context, network, address string) (net.Conn, error) {
@ -147,9 +133,7 @@ func dualStackDialContext(ctx context.Context, network, address string) (net.Con
go startRacer(ctx, network+"4", host, false) go startRacer(ctx, network+"4", host, false)
go startRacer(ctx, network+"6", host, true) go startRacer(ctx, network+"6", host, true)
for { for res := range results {
select {
case res := <-results:
if res.error == nil { if res.error == nil {
return res.Conn, nil return res.Conn, nil
} }
@ -170,5 +154,6 @@ func dualStackDialContext(ctx context.Context, network, address string) (net.Con
} }
} }
} }
}
return nil, errors.New("never touched")
} }

View File

@ -3,20 +3,15 @@ package dialer
import ( import (
"errors" "errors"
"net" "net"
"time"
"github.com/Dreamacro/clash/common/singledo"
) )
type DialerHookFunc = func(dialer *net.Dialer) error type DialerHookFunc = func(dialer *net.Dialer) error
type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) error type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) error
type ListenConfigHookFunc = func(*net.ListenConfig) error type ListenPacketHookFunc = func(lc *net.ListenConfig, address string) (string, error)
type ListenPacketHookFunc = func() (net.IP, error)
var ( var (
DialerHook DialerHookFunc DialerHook DialerHookFunc
DialHook DialHookFunc DialHook DialHookFunc
ListenConfigHook ListenConfigHookFunc
ListenPacketHook ListenPacketHookFunc ListenPacketHook ListenPacketHookFunc
) )
@ -25,124 +20,24 @@ var (
ErrNetworkNotSupport = errors.New("network not support") ErrNetworkNotSupport = errors.New("network not support")
) )
func lookupTCPAddr(ip net.IP, addrs []net.Addr) (*net.TCPAddr, error) {
ipv4 := ip.To4() != nil
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok {
continue
}
addrV4 := addr.IP.To4() != nil
if addrV4 && ipv4 {
return &net.TCPAddr{IP: addr.IP, Port: 0}, nil
} else if !addrV4 && !ipv4 {
return &net.TCPAddr{IP: addr.IP, Port: 0}, nil
}
}
return nil, ErrAddrNotFound
}
func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) {
ipv4 := ip.To4() != nil
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok {
continue
}
addrV4 := addr.IP.To4() != nil
if addrV4 && ipv4 {
return &net.UDPAddr{IP: addr.IP, Port: 0}, nil
} else if !addrV4 && !ipv4 {
return &net.UDPAddr{IP: addr.IP, Port: 0}, nil
}
}
return nil, ErrAddrNotFound
}
func ListenPacketWithInterface(name string) ListenPacketHookFunc { func ListenPacketWithInterface(name string) ListenPacketHookFunc {
single := singledo.NewSingle(5 * time.Second) return func(lc *net.ListenConfig, address string) (string, error) {
err := bindIfaceToListenConfig(lc, name)
return func() (net.IP, error) { if err == errPlatformNotSupport {
elm, err, _ := single.Do(func() (interface{}, error) { address, err = fallbackBindToListenConfig(name)
iface, err := net.InterfaceByName(name)
if err != nil {
return nil, err
} }
addrs, err := iface.Addrs() return address, err
if err != nil {
return nil, err
}
return addrs, nil
})
if err != nil {
return nil, err
}
addrs := elm.([]net.Addr)
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok || addr.IP.To4() == nil {
continue
}
return addr.IP, nil
}
return nil, ErrAddrNotFound
} }
} }
func DialerWithInterface(name string) DialHookFunc { func DialerWithInterface(name string) DialHookFunc {
single := singledo.NewSingle(5 * time.Second)
return func(dialer *net.Dialer, network string, ip net.IP) error { return func(dialer *net.Dialer, network string, ip net.IP) error {
elm, err, _ := single.Do(func() (interface{}, error) { err := bindIfaceToDialer(dialer, name)
iface, err := net.InterfaceByName(name) if err == errPlatformNotSupport {
if err != nil { err = fallbackBindToDialer(dialer, network, ip, name)
return nil, err
} }
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
return addrs, nil
})
if err != nil {
return err
}
addrs := elm.([]net.Addr)
switch network {
case "tcp", "tcp4", "tcp6":
if addr, err := lookupTCPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr
} else {
return err
}
case "udp", "udp4", "udp6":
if addr, err := lookupUDPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr
} else {
return err return err
} }
} }
return nil
}
}

View File

@ -17,6 +17,7 @@ type Pool struct {
offset uint32 offset uint32
mux sync.Mutex mux sync.Mutex
host *trie.DomainTrie host *trie.DomainTrie
ipnet *net.IPNet
cache *cache.LruCache cache *cache.LruCache
} }
@ -89,6 +90,16 @@ func (p *Pool) Gateway() net.IP {
return uintToIP(p.gateway) return uintToIP(p.gateway)
} }
// IPNet return raw ipnet
func (p *Pool) IPNet() *net.IPNet {
return p.ipnet
}
// PatchFrom clone cache from old pool
func (p *Pool) PatchFrom(o *Pool) {
o.cache.CloneTo(p.cache)
}
func (p *Pool) get(host string) net.IP { func (p *Pool) get(host string) net.IP {
current := p.offset current := p.offset
for { for {
@ -116,7 +127,7 @@ func ipToUint(ip net.IP) uint32 {
} }
func uintToIP(v uint32) net.IP { func uintToIP(v uint32) net.IP {
return net.IPv4(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) return net.IP{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
} }
// New return Pool instance // New return Pool instance
@ -136,6 +147,7 @@ func New(ipnet *net.IPNet, size int, host *trie.DomainTrie) (*Pool, error) {
max: max, max: max,
gateway: min - 1, gateway: min - 1,
host: host, host: host,
ipnet: ipnet,
cache: cache.NewLRUCache(cache.WithSize(size * 2)), cache: cache.NewLRUCache(cache.WithSize(size * 2)),
}, nil }, nil
} }

View File

@ -22,9 +22,9 @@ func (t *Table) Get(key string) C.PacketConn {
return item.(C.PacketConn) return item.(C.PacketConn)
} }
func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) { func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{}) item, loaded := t.mapping.LoadOrStore(key, sync.NewCond(&sync.Mutex{}))
return item.(*sync.WaitGroup), loaded return item.(*sync.Cond), loaded
} }
func (t *Table) Delete(key string) { func (t *Table) Delete(key string) {

114
component/pool/pool.go Normal file
View File

@ -0,0 +1,114 @@
package pool
import (
"context"
"runtime"
"time"
)
type Factory = func(context.Context) (interface{}, error)
type entry struct {
elm interface{}
time time.Time
}
type Option func(*pool)
// WithEvict set the evict callback
func WithEvict(cb func(interface{})) Option {
return func(p *pool) {
p.evict = cb
}
}
// WithAge defined element max age (millisecond)
func WithAge(maxAge int64) Option {
return func(p *pool) {
p.maxAge = maxAge
}
}
// WithSize defined max size of Pool
func WithSize(maxSize int) Option {
return func(p *pool) {
p.ch = make(chan interface{}, maxSize)
}
}
// Pool is for GC, see New for detail
type Pool struct {
*pool
}
type pool struct {
ch chan interface{}
factory Factory
evict func(interface{})
maxAge int64
}
func (p *pool) GetContext(ctx context.Context) (interface{}, error) {
now := time.Now()
for {
select {
case item := <-p.ch:
elm := item.(*entry)
if p.maxAge != 0 && now.Sub(item.(*entry).time).Milliseconds() > p.maxAge {
if p.evict != nil {
p.evict(elm.elm)
}
continue
}
return elm.elm, nil
default:
return p.factory(ctx)
}
}
}
func (p *pool) Get() (interface{}, error) {
return p.GetContext(context.Background())
}
func (p *pool) Put(item interface{}) {
e := &entry{
elm: item,
time: time.Now(),
}
select {
case p.ch <- e:
return
default:
// pool is full
if p.evict != nil {
p.evict(item)
}
return
}
}
func recycle(p *Pool) {
for item := range p.pool.ch {
if p.pool.evict != nil {
p.pool.evict(item.(*entry).elm)
}
}
}
func New(factory Factory, options ...Option) *Pool {
p := &pool{
ch: make(chan interface{}, 10),
factory: factory,
}
for _, option := range options {
option(p)
}
P := &Pool{p}
runtime.SetFinalizer(P, recycle)
return P
}

View File

@ -0,0 +1,73 @@
package pool
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func lg() Factory {
initial := -1
return func(context.Context) (interface{}, error) {
initial++
return initial, nil
}
}
func TestPool_Basic(t *testing.T) {
g := lg()
pool := New(g)
elm, _ := pool.Get()
assert.Equal(t, 0, elm.(int))
pool.Put(elm)
elm, _ = pool.Get()
assert.Equal(t, 0, elm.(int))
elm, _ = pool.Get()
assert.Equal(t, 1, elm.(int))
}
func TestPool_MaxSize(t *testing.T) {
g := lg()
size := 5
pool := New(g, WithSize(size))
items := []interface{}{}
for i := 0; i < size; i++ {
item, _ := pool.Get()
items = append(items, item)
}
extra, _ := pool.Get()
assert.Equal(t, size, extra.(int))
for _, item := range items {
pool.Put(item)
}
pool.Put(extra)
for _, item := range items {
elm, _ := pool.Get()
assert.Equal(t, item.(int), elm.(int))
}
}
func TestPool_MaxAge(t *testing.T) {
g := lg()
pool := New(g, WithAge(20))
elm, _ := pool.Get()
pool.Put(elm)
elm, _ = pool.Get()
assert.Equal(t, 0, elm.(int))
pool.Put(elm)
time.Sleep(time.Millisecond * 22)
elm, _ = pool.Get()
assert.Equal(t, 1, elm.(int))
}

View File

@ -0,0 +1,21 @@
package process
import (
"errors"
"net"
)
var (
ErrInvalidNetwork = errors.New("invalid network")
ErrPlatformNotSupport = errors.New("not support on this platform")
ErrNotFound = errors.New("process not found")
)
const (
TCP = "tcp"
UDP = "udp"
)
func FindProcessName(network string, srcIP net.IP, srcPort int) (string, error) {
return findProcessName(network, srcIP, srcPort)
}

View File

@ -1,109 +1,26 @@
package rules package process
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"fmt"
"net" "net"
"path/filepath" "path/filepath"
"strconv"
"strings"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/Dreamacro/clash/common/cache"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
) )
// store process name for when dealing with multiple PROCESS-NAME rules
var processCache = cache.NewLRUCache(cache.WithAge(2), cache.WithSize(64))
type Process struct {
adapter string
process string
}
func (ps *Process) RuleType() C.RuleType {
return C.Process
}
func (ps *Process) Match(metadata *C.Metadata) bool {
key := fmt.Sprintf("%s:%s:%s", metadata.NetWork.String(), metadata.SrcIP.String(), metadata.SrcPort)
cached, hit := processCache.Get(key)
if !hit {
name, err := getExecPathFromAddress(metadata)
if err != nil {
log.Debugln("[%s] getExecPathFromAddress error: %s", C.Process.String(), err.Error())
}
processCache.Set(key, name)
cached = name
}
return strings.EqualFold(cached.(string), ps.process)
}
func (p *Process) Adapter() string {
return p.adapter
}
func (p *Process) Payload() string {
return p.process
}
func (p *Process) ShouldResolveIP() bool {
return false
}
func NewProcess(process string, adapter string) (*Process, error) {
return &Process{
adapter: adapter,
process: process,
}, nil
}
const ( const (
procpidpathinfo = 0xb procpidpathinfo = 0xb
procpidpathinfosize = 1024 procpidpathinfosize = 1024
proccallnumpidinfo = 0x2 proccallnumpidinfo = 0x2
) )
func getExecPathFromPID(pid uint32) (string, error) { func findProcessName(network string, ip net.IP, port int) (string, error) {
buf := make([]byte, procpidpathinfosize)
_, _, errno := syscall.Syscall6(
syscall.SYS_PROC_INFO,
proccallnumpidinfo,
uintptr(pid),
procpidpathinfo,
0,
uintptr(unsafe.Pointer(&buf[0])),
procpidpathinfosize)
if errno != 0 {
return "", errno
}
firstZero := bytes.IndexByte(buf, 0)
if firstZero <= 0 {
return "", nil
}
return filepath.Base(string(buf[:firstZero])), nil
}
func getExecPathFromAddress(metadata *C.Metadata) (string, error) {
ip := metadata.SrcIP
port, err := strconv.Atoi(metadata.SrcPort)
if err != nil {
return "", err
}
var spath string var spath string
switch metadata.NetWork { switch network {
case C.TCP: case TCP:
spath = "net.inet.tcp.pcblist_n" spath = "net.inet.tcp.pcblist_n"
case C.UDP: case UDP:
spath = "net.inet.udp.pcblist_n" spath = "net.inet.udp.pcblist_n"
default: default:
return "", ErrInvalidNetwork return "", ErrInvalidNetwork
@ -123,12 +40,12 @@ func getExecPathFromAddress(metadata *C.Metadata) (string, error) {
// rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) + // rup8(sizeof(xinpcb_n)) + rup8(sizeof(xsocket_n)) +
// 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n)) // 2 * rup8(sizeof(xsockbuf_n)) + rup8(sizeof(xsockstat_n))
itemSize := 384 itemSize := 384
if metadata.NetWork == C.TCP { if network == TCP {
// rup8(sizeof(xtcpcb_n)) // rup8(sizeof(xtcpcb_n))
itemSize += 208 itemSize += 208
} }
// skip the first and last xinpgen(24 bytes) block // skip the first xinpgen(24 bytes) block
for i := 24; i < len(buf)-24; i += itemSize { for i := 24; i+itemSize <= len(buf); i += itemSize {
// offset of xinpcb_n and xsocket_n // offset of xinpcb_n and xsocket_n
inp, so := i, i+104 inp, so := i, i+104
@ -161,7 +78,28 @@ func getExecPathFromAddress(metadata *C.Metadata) (string, error) {
return getExecPathFromPID(pid) return getExecPathFromPID(pid)
} }
return "", errors.New("process not found") return "", ErrNotFound
}
func getExecPathFromPID(pid uint32) (string, error) {
buf := make([]byte, procpidpathinfosize)
_, _, errno := syscall.Syscall6(
syscall.SYS_PROC_INFO,
proccallnumpidinfo,
uintptr(pid),
procpidpathinfo,
0,
uintptr(unsafe.Pointer(&buf[0])),
procpidpathinfosize)
if errno != 0 {
return "", errno
}
firstZero := bytes.IndexByte(buf, 0)
if firstZero <= 0 {
return "", nil
}
return filepath.Base(string(buf[:firstZero])), nil
} }
func readNativeUint32(b []byte) uint32 { func readNativeUint32(b []byte) uint32 {

View File

@ -0,0 +1,228 @@
package process
import (
"encoding/binary"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"unsafe"
"github.com/Dreamacro/clash/log"
)
// store process name for when dealing with multiple PROCESS-NAME rules
var (
defaultSearcher *searcher
once sync.Once
)
func findProcessName(network string, ip net.IP, srcPort int) (string, error) {
once.Do(func() {
if err := initSearcher(); err != nil {
log.Errorln("Initialize PROCESS-NAME failed: %s", err.Error())
log.Warnln("All PROCESS-NAME rules will be skipped")
return
}
})
var spath string
isTCP := network == TCP
switch network {
case TCP:
spath = "net.inet.tcp.pcblist"
case UDP:
spath = "net.inet.udp.pcblist"
default:
return "", ErrInvalidNetwork
}
value, err := syscall.Sysctl(spath)
if err != nil {
return "", err
}
buf := []byte(value)
pid, err := defaultSearcher.Search(buf, ip, uint16(srcPort), isTCP)
if err != nil {
return "", err
}
return getExecPathFromPID(pid)
}
func getExecPathFromPID(pid uint32) (string, error) {
buf := make([]byte, 2048)
size := uint64(len(buf))
// CTL_KERN, KERN_PROC, KERN_PROC_PATHNAME, pid
mib := [4]uint32{1, 14, 12, pid}
_, _, errno := syscall.Syscall6(
syscall.SYS___SYSCTL,
uintptr(unsafe.Pointer(&mib[0])),
uintptr(len(mib)),
uintptr(unsafe.Pointer(&buf[0])),
uintptr(unsafe.Pointer(&size)),
0,
0)
if errno != 0 || size == 0 {
return "", errno
}
return filepath.Base(string(buf[:size-1])), nil
}
func readNativeUint32(b []byte) uint32 {
return *(*uint32)(unsafe.Pointer(&b[0]))
}
type searcher struct {
// sizeof(struct xinpgen)
headSize int
// sizeof(struct xtcpcb)
tcpItemSize int
// sizeof(struct xinpcb)
udpItemSize int
udpInpOffset int
port int
ip int
vflag int
socket int
// sizeof(struct xfile)
fileItemSize int
data int
pid int
}
func (s *searcher) Search(buf []byte, ip net.IP, port uint16, isTCP bool) (uint32, error) {
var itemSize int
var inpOffset int
if isTCP {
// struct xtcpcb
itemSize = s.tcpItemSize
inpOffset = 8
} else {
// struct xinpcb
itemSize = s.udpItemSize
inpOffset = s.udpInpOffset
}
isIPv4 := ip.To4() != nil
// skip the first xinpgen block
for i := s.headSize; i+itemSize <= len(buf); i += itemSize {
inp := i + inpOffset
srcPort := binary.BigEndian.Uint16(buf[inp+s.port : inp+s.port+2])
if port != srcPort {
continue
}
// xinpcb.inp_vflag
flag := buf[inp+s.vflag]
var srcIP net.IP
switch {
case flag&0x1 > 0 && isIPv4:
// ipv4
srcIP = net.IP(buf[inp+s.ip : inp+s.ip+4])
case flag&0x2 > 0 && !isIPv4:
// ipv6
srcIP = net.IP(buf[inp+s.ip-12 : inp+s.ip+4])
default:
continue
}
if !ip.Equal(srcIP) {
continue
}
// xsocket.xso_so, interpreted as big endian anyway since it's only used for comparison
socket := binary.BigEndian.Uint64(buf[inp+s.socket : inp+s.socket+8])
return s.searchSocketPid(socket)
}
return 0, ErrNotFound
}
func (s *searcher) searchSocketPid(socket uint64) (uint32, error) {
value, err := syscall.Sysctl("kern.file")
if err != nil {
return 0, err
}
buf := []byte(value)
// struct xfile
itemSize := s.fileItemSize
for i := 0; i+itemSize <= len(buf); i += itemSize {
// xfile.xf_data
data := binary.BigEndian.Uint64(buf[i+s.data : i+s.data+8])
if data == socket {
// xfile.xf_pid
pid := readNativeUint32(buf[i+s.pid : i+s.pid+4])
return pid, nil
}
}
return 0, ErrNotFound
}
func newSearcher(major int) *searcher {
var s *searcher = nil
switch major {
case 11:
s = &searcher{
headSize: 32,
tcpItemSize: 1304,
udpItemSize: 632,
port: 198,
ip: 228,
vflag: 116,
socket: 88,
fileItemSize: 80,
data: 56,
pid: 8,
udpInpOffset: 8,
}
case 12:
s = &searcher{
headSize: 64,
tcpItemSize: 744,
udpItemSize: 400,
port: 254,
ip: 284,
vflag: 392,
socket: 16,
fileItemSize: 128,
data: 56,
pid: 8,
}
}
return s
}
func initSearcher() error {
osRelease, err := syscall.Sysctl("kern.osrelease")
if err != nil {
return err
}
dot := strings.Index(osRelease, ".")
if dot != -1 {
osRelease = osRelease[:dot]
}
major, err := strconv.Atoi(osRelease)
if err != nil {
return err
}
defaultSearcher = newSearcher(major)
if defaultSearcher == nil {
return fmt.Errorf("unsupported freebsd version %d", major)
}
return nil
}

View File

@ -1,4 +1,4 @@
package rules package process
import ( import (
"bytes" "bytes"
@ -9,15 +9,10 @@ import (
"net" "net"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
"strings"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
) )
// from https://github.com/vishvananda/netlink/blob/bca67dfc8220b44ef582c9da4e9172bf1c9ec973/nl/nl_linux.go#L52-L62 // from https://github.com/vishvananda/netlink/blob/bca67dfc8220b44ef582c9da4e9172bf1c9ec973/nl/nl_linux.go#L52-L62
@ -30,7 +25,7 @@ func init() {
} }
} }
type SocketResolver func(metadata *C.Metadata) (inode, uid int, err error) type SocketResolver func(network string, ip net.IP, srcPort int) (inode, uid int, err error)
type ProcessNameResolver func(inode, uid int) (name string, err error) type ProcessNameResolver func(inode, uid int) (name string, err error)
// export for android // export for android
@ -39,51 +34,6 @@ var (
DefaultProcessNameResolver ProcessNameResolver = resolveProcessNameByProcSearch DefaultProcessNameResolver ProcessNameResolver = resolveProcessNameByProcSearch
) )
type Process struct {
adapter string
process string
}
func (p *Process) RuleType() C.RuleType {
return C.Process
}
func (p *Process) Match(metadata *C.Metadata) bool {
key := fmt.Sprintf("%s:%s:%s", metadata.NetWork.String(), metadata.SrcIP.String(), metadata.SrcPort)
cached, hit := processCache.Get(key)
if !hit {
processName, err := resolveProcessName(metadata)
if err != nil {
log.Debugln("[%s] Resolve process of %s failure: %s", C.Process.String(), key, err.Error())
}
processCache.Set(key, processName)
cached = processName
}
return strings.EqualFold(cached.(string), p.process)
}
func (p *Process) Adapter() string {
return p.adapter
}
func (p *Process) Payload() string {
return p.process
}
func (p *Process) ShouldResolveIP() bool {
return false
}
func NewProcess(process string, adapter string) (*Process, error) {
return &Process{
adapter: adapter,
process: process,
}, nil
}
const ( const (
sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + 8 + 48 sizeOfSocketDiagRequest = syscall.SizeofNlMsghdr + 8 + 48
socketDiagByFamily = 20 socketDiagByFamily = 20
@ -92,10 +42,8 @@ const (
var nativeEndian binary.ByteOrder = binary.LittleEndian var nativeEndian binary.ByteOrder = binary.LittleEndian
var processCache = cache.NewLRUCache(cache.WithAge(2), cache.WithSize(64)) func findProcessName(network string, ip net.IP, srcPort int) (string, error) {
inode, uid, err := DefaultSocketResolver(network, ip, srcPort)
func resolveProcessName(metadata *C.Metadata) (string, error) {
inode, uid, err := DefaultSocketResolver(metadata)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -103,31 +51,26 @@ func resolveProcessName(metadata *C.Metadata) (string, error) {
return DefaultProcessNameResolver(inode, uid) return DefaultProcessNameResolver(inode, uid)
} }
func resolveSocketByNetlink(metadata *C.Metadata) (int, int, error) { func resolveSocketByNetlink(network string, ip net.IP, srcPort int) (int, int, error) {
var family byte var family byte
var protocol byte var protocol byte
switch metadata.NetWork { switch network {
case C.TCP: case TCP:
protocol = syscall.IPPROTO_TCP protocol = syscall.IPPROTO_TCP
case C.UDP: case UDP:
protocol = syscall.IPPROTO_UDP protocol = syscall.IPPROTO_UDP
default: default:
return 0, 0, ErrInvalidNetwork return 0, 0, ErrInvalidNetwork
} }
if metadata.SrcIP.To4() != nil { if ip.To4() != nil {
family = syscall.AF_INET family = syscall.AF_INET
} else { } else {
family = syscall.AF_INET6 family = syscall.AF_INET6
} }
srcPort, err := strconv.Atoi(metadata.SrcPort) req := packSocketDiagRequest(family, protocol, ip, uint16(srcPort))
if err != nil {
return 0, 0, err
}
req := packSocketDiagRequest(family, protocol, metadata.SrcIP, uint16(srcPort))
socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG) socket, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_INET_DIAG)
if err != nil { if err != nil {
@ -257,7 +200,7 @@ func resolveProcessNameByProcSearch(inode, uid int) (string, error) {
continue continue
} }
if bytes.Compare(buffer[:n], socket) == 0 { if bytes.Equal(buffer[:n], socket) {
cmdline, err := ioutil.ReadFile(path.Join(processPath, "cmdline")) cmdline, err := ioutil.ReadFile(path.Join(processPath, "cmdline"))
if err != nil { if err != nil {
return "", err return "", err

View File

@ -0,0 +1,10 @@
// +build !darwin,!linux,!windows
// +build !freebsd !amd64
package process
import "net"
func findProcessName(network string, ip net.IP, srcPort int) (string, error) {
return "", ErrPlatformNotSupport
}

View File

@ -1,18 +1,13 @@
package rules package process
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"path/filepath" "path/filepath"
"strconv"
"strings"
"sync" "sync"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/Dreamacro/clash/common/cache"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -27,10 +22,6 @@ const (
) )
var ( var (
processCache = cache.NewLRUCache(cache.WithAge(2), cache.WithSize(64))
errNotFound = errors.New("process not found")
matchMeta = func(p *Process, m *C.Metadata) bool { return false }
getExTcpTable uintptr getExTcpTable uintptr
getExUdpTable uintptr getExUdpTable uintptr
queryProcName uintptr queryProcName uintptr
@ -67,47 +58,7 @@ func initWin32API() error {
return nil return nil
} }
type Process struct { func findProcessName(network string, ip net.IP, srcPort int) (string, error) {
adapter string
process string
}
func (p *Process) RuleType() C.RuleType {
return C.Process
}
func (p *Process) Adapter() string {
return p.adapter
}
func (p *Process) Payload() string {
return p.process
}
func (p *Process) ShouldResolveIP() bool {
return false
}
func match(p *Process, metadata *C.Metadata) bool {
key := fmt.Sprintf("%s:%s:%s", metadata.NetWork.String(), metadata.SrcIP.String(), metadata.SrcPort)
cached, hit := processCache.Get(key)
if !hit {
processName, err := resolveProcessName(metadata)
if err != nil {
log.Debugln("[%s] Resolve process of %s failed: %s", C.Process.String(), key, err.Error())
}
processCache.Set(key, processName)
cached = processName
}
return strings.EqualFold(cached.(string), p.process)
}
func (p *Process) Match(metadata *C.Metadata) bool {
return matchMeta(p, metadata)
}
func NewProcess(process string, adapter string) (*Process, error) {
once.Do(func() { once.Do(func() {
err := initWin32API() err := initWin32API()
if err != nil { if err != nil {
@ -115,16 +66,7 @@ func NewProcess(process string, adapter string) (*Process, error) {
log.Warnln("All PROCESS-NAMES rules will be skiped") log.Warnln("All PROCESS-NAMES rules will be skiped")
return return
} }
matchMeta = match
}) })
return &Process{
adapter: adapter,
process: process,
}, nil
}
func resolveProcessName(metadata *C.Metadata) (string, error) {
ip := metadata.SrcIP
family := windows.AF_INET family := windows.AF_INET
if ip.To4() == nil { if ip.To4() == nil {
family = windows.AF_INET6 family = windows.AF_INET6
@ -132,28 +74,23 @@ func resolveProcessName(metadata *C.Metadata) (string, error) {
var class int var class int
var fn uintptr var fn uintptr
switch metadata.NetWork { switch network {
case C.TCP: case TCP:
fn = getExTcpTable fn = getExTcpTable
class = tcpTablePidConn class = tcpTablePidConn
case C.UDP: case UDP:
fn = getExUdpTable fn = getExUdpTable
class = udpTablePid class = udpTablePid
default: default:
return "", ErrInvalidNetwork return "", ErrInvalidNetwork
} }
srcPort, err := strconv.Atoi(metadata.SrcPort)
if err != nil {
return "", err
}
buf, err := getTransportTable(fn, family, class) buf, err := getTransportTable(fn, family, class)
if err != nil { if err != nil {
return "", err return "", err
} }
s := newSearcher(family == windows.AF_INET, metadata.NetWork == C.TCP) s := newSearcher(family == windows.AF_INET, network == TCP)
pid, err := s.Search(buf, ip, uint16(srcPort)) pid, err := s.Search(buf, ip, uint16(srcPort))
if err != nil { if err != nil {
@ -196,14 +133,15 @@ func (s *searcher) Search(b []byte, ip net.IP, port uint16) (uint32, error) {
} }
srcIP := net.IP(row[s.ip : s.ip+s.ipSize]) srcIP := net.IP(row[s.ip : s.ip+s.ipSize])
if !ip.Equal(srcIP) { // windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
if !ip.Equal(srcIP) && (!srcIP.IsUnspecified() || s.tcpState != -1) {
continue continue
} }
pid := readNativeUint32(row[s.pid : s.pid+4]) pid := readNativeUint32(row[s.pid : s.pid+4])
return pid, nil return pid, nil
} }
return 0, errNotFound return 0, ErrNotFound
} }
func newSearcher(isV4, isTCP bool) *searcher { func newSearcher(isV4, isTCP bool) *searcher {

View File

@ -0,0 +1,101 @@
package cachefile
import (
"bytes"
"encoding/gob"
"io/ioutil"
"os"
"sync"
"github.com/Dreamacro/clash/component/profile"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
var (
initOnce sync.Once
fileMode os.FileMode = 0666
defaultCache *CacheFile
)
type cache struct {
Selected map[string]string
}
// CacheFile store and update the cache file
type CacheFile struct {
path string
model *cache
buf *bytes.Buffer
mux sync.Mutex
}
func (c *CacheFile) SetSelected(group, selected string) {
if !profile.StoreSelected.Load() {
return
}
c.mux.Lock()
defer c.mux.Unlock()
model := c.element()
model.Selected[group] = selected
c.buf.Reset()
if err := gob.NewEncoder(c.buf).Encode(model); err != nil {
log.Warnln("[CacheFile] encode gob failed: %s", err.Error())
return
}
if err := ioutil.WriteFile(c.path, c.buf.Bytes(), fileMode); err != nil {
log.Warnln("[CacheFile] write cache to %s failed: %s", c.path, err.Error())
return
}
}
func (c *CacheFile) SelectedMap() map[string]string {
if !profile.StoreSelected.Load() {
return nil
}
c.mux.Lock()
defer c.mux.Unlock()
model := c.element()
mapping := map[string]string{}
for k, v := range model.Selected {
mapping[k] = v
}
return mapping
}
func (c *CacheFile) element() *cache {
if c.model != nil {
return c.model
}
model := &cache{
Selected: map[string]string{},
}
if buf, err := ioutil.ReadFile(c.path); err == nil {
bufReader := bytes.NewBuffer(buf)
gob.NewDecoder(bufReader).Decode(model)
}
c.model = model
return c.model
}
// Cache return singleton of CacheFile
func Cache() *CacheFile {
initOnce.Do(func() {
defaultCache = &CacheFile{
path: C.Path.Cache(),
buf: &bytes.Buffer{},
}
})
return defaultCache
}

View File

@ -0,0 +1,10 @@
package profile
import (
"go.uber.org/atomic"
)
var (
// StoreSelected is a global switch for storing selected proxy to cache
StoreSelected = atomic.NewBool(true)
)

View File

@ -0,0 +1,55 @@
package resolver
import (
"net"
)
var DefaultHostMapper Enhancer
type Enhancer interface {
FakeIPEnabled() bool
MappingEnabled() bool
IsFakeIP(net.IP) bool
IsExistFakeIP(net.IP) bool
FindHostByIP(net.IP) (string, bool)
}
func FakeIPEnabled() bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.FakeIPEnabled()
}
return false
}
func MappingEnabled() bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.MappingEnabled()
}
return false
}
func IsFakeIP(ip net.IP) bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.IsFakeIP(ip)
}
return false
}
func IsExistFakeIP(ip net.IP) bool {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.IsExistFakeIP(ip)
}
return false
}
func FindHostByIP(ip net.IP) (string, bool) {
if mapper := DefaultHostMapper; mapper != nil {
return mapper.FindHostByIP(ip)
}
return "", false
}

View File

@ -1,21 +1,54 @@
package snell package snell
import ( import (
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"github.com/Dreamacro/go-shadowsocks2/shadowaead"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
"golang.org/x/crypto/chacha20poly1305"
) )
type snellCipher struct { type snellCipher struct {
psk []byte psk []byte
keySize int
makeAEAD func(key []byte) (cipher.AEAD, error) makeAEAD func(key []byte) (cipher.AEAD, error)
} }
func (sc *snellCipher) KeySize() int { return 32 } func (sc *snellCipher) KeySize() int { return sc.keySize }
func (sc *snellCipher) SaltSize() int { return 16 } func (sc *snellCipher) SaltSize() int { return 16 }
func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) { func (sc *snellCipher) Encrypter(salt []byte) (cipher.AEAD, error) {
return sc.makeAEAD(argon2.IDKey(sc.psk, salt, 3, 8, 1, uint32(sc.KeySize()))) return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
} }
func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) { func (sc *snellCipher) Decrypter(salt []byte) (cipher.AEAD, error) {
return sc.makeAEAD(argon2.IDKey(sc.psk, salt, 3, 8, 1, uint32(sc.KeySize()))) return sc.makeAEAD(snellKDF(sc.psk, salt, sc.KeySize()))
}
func snellKDF(psk, salt []byte, keySize int) []byte {
// snell use a special kdf function
return argon2.IDKey(psk, salt, 3, 8, 1, 32)[:keySize]
}
func aesGCM(key []byte) (cipher.AEAD, error) {
blk, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(blk)
}
func NewAES128GCM(psk []byte) shadowaead.Cipher {
return &snellCipher{
psk: psk,
keySize: 16,
makeAEAD: aesGCM,
}
}
func NewChacha20Poly1305(psk []byte) shadowaead.Cipher {
return &snellCipher{
psk: psk,
keySize: 32,
makeAEAD: chacha20poly1305.New,
}
} }

80
component/snell/pool.go Normal file
View File

@ -0,0 +1,80 @@
package snell
import (
"context"
"net"
"github.com/Dreamacro/clash/component/pool"
"github.com/Dreamacro/go-shadowsocks2/shadowaead"
)
type Pool struct {
pool *pool.Pool
}
func (p *Pool) Get() (net.Conn, error) {
return p.GetContext(context.Background())
}
func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
elm, err := p.pool.GetContext(ctx)
if err != nil {
return nil, err
}
return &PoolConn{elm.(*Snell), p}, nil
}
func (p *Pool) Put(conn net.Conn) {
if err := HalfClose(conn); err != nil {
conn.Close()
return
}
p.pool.Put(conn)
}
type PoolConn struct {
*Snell
pool *Pool
}
func (pc *PoolConn) Read(b []byte) (int, error) {
// save old status of reply (it mutable by Read)
reply := pc.Snell.reply
n, err := pc.Snell.Read(b)
if err == shadowaead.ErrZeroChunk {
// if reply is false, it should be client halfclose.
// ignore error and read data again.
if !reply {
pc.Snell.reply = false
return pc.Snell.Read(b)
}
}
return n, err
}
func (pc *PoolConn) Write(b []byte) (int, error) {
return pc.Snell.Write(b)
}
func (pc *PoolConn) Close() error {
pc.pool.Put(pc.Snell)
return nil
}
func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
p := pool.New(
func(ctx context.Context) (interface{}, error) {
return factory(ctx)
},
pool.WithAge(15000),
pool.WithSize(10),
pool.WithEvict(func(item interface{}) {
item.(*Snell).Close()
}),
)
return &Pool{p}
}

View File

@ -10,14 +10,21 @@ import (
"sync" "sync"
"github.com/Dreamacro/go-shadowsocks2/shadowaead" "github.com/Dreamacro/go-shadowsocks2/shadowaead"
"golang.org/x/crypto/chacha20poly1305" )
const (
Version1 = 1
Version2 = 2
DefaultSnellVersion = Version1
) )
const ( const (
CommandPing byte = 0 CommandPing byte = 0
CommandConnect byte = 1 CommandConnect byte = 1
CommandConnectV2 byte = 5
CommandTunnel byte = 0 CommandTunnel byte = 0
CommandPong byte = 1
CommandError byte = 2 CommandError byte = 2
Version byte = 1 Version byte = 1
@ -25,6 +32,7 @@ const (
var ( var (
bufferPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }} bufferPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }}
endSignal = []byte{}
) )
type Snell struct { type Snell struct {
@ -46,7 +54,7 @@ func (s *Snell) Read(b []byte) (int, error) {
if s.buffer[0] == CommandTunnel { if s.buffer[0] == CommandTunnel {
return s.Conn.Read(b) return s.Conn.Read(b)
} else if s.buffer[0] != CommandError { } else if s.buffer[0] != CommandError {
return 0, errors.New("Command not support") return 0, errors.New("command not support")
} }
// CommandError // CommandError
@ -70,12 +78,16 @@ func (s *Snell) Read(b []byte) (int, error) {
return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg)) return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
} }
func WriteHeader(conn net.Conn, host string, port uint) error { func WriteHeader(conn net.Conn, host string, port uint, version int) error {
buf := bufferPool.Get().(*bytes.Buffer) buf := bufferPool.Get().(*bytes.Buffer)
buf.Reset() buf.Reset()
defer bufferPool.Put(buf) defer bufferPool.Put(buf)
buf.WriteByte(Version) buf.WriteByte(Version)
if version == Version2 {
buf.WriteByte(CommandConnectV2)
} else {
buf.WriteByte(CommandConnect) buf.WriteByte(CommandConnect)
}
// clientID length & id // clientID length & id
buf.WriteByte(0) buf.WriteByte(0)
@ -92,7 +104,24 @@ func WriteHeader(conn net.Conn, host string, port uint) error {
return nil return nil
} }
func StreamConn(conn net.Conn, psk []byte) net.Conn { // HalfClose works only on version2
cipher := &snellCipher{psk, chacha20poly1305.New} func HalfClose(conn net.Conn) error {
if _, err := conn.Write(endSignal); err != nil {
return err
}
if s, ok := conn.(*Snell); ok {
s.reply = false
}
return nil
}
func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
var cipher shadowaead.Cipher
if version == Version2 {
cipher = NewAES128GCM(psk)
} else {
cipher = NewChacha20Poly1305(psk)
}
return &Snell{Conn: shadowaead.NewConn(conn, cipher)} return &Snell{Conn: shadowaead.NewConn(conn, cipher)}
} }

View File

@ -1,11 +1,9 @@
package obfs package obfs
// Base information for obfs
type Base struct { type Base struct {
IVSize int
Key []byte
HeadLen int
Host string Host string
Port int Port int
Key []byte
IVSize int
Param string Param string
} }

View File

@ -1,7 +1,7 @@
package obfs package obfs
func init() { func init() {
register("http_post", newHTTPPost) register("http_post", newHTTPPost, 0)
} }
func newHTTPPost(b *Base) Obfs { func newHTTPPost(b *Base) Obfs {

View File

@ -3,151 +3,157 @@ package obfs
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt"
"io" "io"
"math/rand" "math/rand"
"net"
"strconv"
"strings" "strings"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools"
) )
func init() {
register("http_simple", newHTTPSimple, 0)
}
type httpObfs struct { type httpObfs struct {
*Base *Base
firstRequest bool
firstResponse bool
post bool post bool
} }
func init() {
register("http_simple", newHTTPSimple)
}
func newHTTPSimple(b *Base) Obfs { func newHTTPSimple(b *Base) Obfs {
return &httpObfs{Base: b} return &httpObfs{Base: b}
} }
func (h *httpObfs) initForConn() Obfs { type httpConn struct {
return &httpObfs{ net.Conn
Base: h.Base, *httpObfs
firstRequest: true, hasSentHeader bool
firstResponse: true, hasRecvHeader bool
post: h.post, buf []byte
}
} }
func (h *httpObfs) GetObfsOverhead() int { func (h *httpObfs) StreamConn(c net.Conn) net.Conn {
return 0 return &httpConn{Conn: c, httpObfs: h}
} }
func (h *httpObfs) Decode(b []byte) ([]byte, bool, error) { func (c *httpConn) Read(b []byte) (int, error) {
if h.firstResponse { if c.buf != nil {
idx := bytes.Index(b, []byte("\r\n\r\n")) n := copy(b, c.buf)
if idx == -1 { if n == len(c.buf) {
return nil, false, io.EOF c.buf = nil
}
h.firstResponse = false
return b[idx+4:], false, nil
}
return b, false, nil
}
func (h *httpObfs) Encode(b []byte) ([]byte, error) {
if h.firstRequest {
bSize := len(b)
var headData []byte
if headSize := h.IVSize + h.HeadLen; bSize-headSize > 64 {
headData = make([]byte, headSize+rand.Intn(64))
} else { } else {
headData = make([]byte, bSize) c.buf = c.buf[n:]
}
copy(headData, b[:len(headData)])
host := h.Host
var customHead string
if len(h.Param) > 0 {
customHeads := strings.Split(h.Param, "#")
if len(customHeads) > 2 {
customHeads = customHeads[:2]
}
customHosts := h.Param
if len(customHeads) > 1 {
customHosts = customHeads[0]
customHead = customHeads[1]
}
hosts := strings.Split(customHosts, ",")
if len(hosts) > 0 {
host = strings.TrimSpace(hosts[rand.Intn(len(hosts))])
} }
return n, nil
} }
method := "GET /" if c.hasRecvHeader {
if h.post { return c.Conn.Read(b)
method = "POST /"
} }
requestPathIndex := rand.Intn(len(requestPath)/2) * 2
httpBuf := fmt.Sprintf("%s%s%s%s HTTP/1.1\r\nHost: %s:%d\r\n", buf := pool.Get(pool.RelayBufferSize)
method, defer pool.Put(buf)
requestPath[requestPathIndex], n, err := c.Conn.Read(buf)
data2URLEncode(headData), if err != nil {
requestPath[requestPathIndex+1], return 0, err
host, h.Port) }
if len(customHead) > 0 { pos := bytes.Index(buf[:n], []byte("\r\n\r\n"))
httpBuf = httpBuf + strings.Replace(customHead, "\\n", "\r\n", -1) + "\r\n\r\n" if pos == -1 {
return 0, io.EOF
}
c.hasRecvHeader = true
dataLength := n - pos - 4
n = copy(b, buf[4+pos:n])
if dataLength > n {
c.buf = append(c.buf, buf[4+pos+n:4+pos+dataLength]...)
}
return n, nil
}
func (c *httpConn) Write(b []byte) (int, error) {
if c.hasSentHeader {
return c.Conn.Write(b)
}
// 30: head length
headLength := c.IVSize + 30
bLength := len(b)
headDataLength := bLength
if bLength-headLength > 64 {
headDataLength = headLength + rand.Intn(65)
}
headData := b[:headDataLength]
b = b[headDataLength:]
var body string
host := c.Host
if len(c.Param) > 0 {
pos := strings.Index(c.Param, "#")
if pos != -1 {
body = strings.ReplaceAll(c.Param[pos+1:], "\n", "\r\n")
body = strings.ReplaceAll(body, "\\n", "\r\n")
host = c.Param[:pos]
} else { } else {
var contentType string host = c.Param
if h.post {
contentType = "Content-Type: multipart/form-data; boundary=" + boundary() + "\r\n"
} }
httpBuf = httpBuf + "User-agent: " + requestUserAgent[rand.Intn(len(requestUserAgent))] + "\r\n" +
"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
"Accept-Language: en-US,en;q=0.8\r\n" +
"Accept-Encoding: gzip, deflate\r\n" +
contentType +
"DNT: 1\r\n" +
"Connection: keep-alive\r\n" +
"\r\n"
} }
hosts := strings.Split(host, ",")
host = hosts[rand.Intn(len(hosts))]
var encoded []byte buf := tools.BufPool.Get().(*bytes.Buffer)
if len(headData) < bSize { defer tools.BufPool.Put(buf)
encoded = make([]byte, len(httpBuf)+(bSize-len(headData))) defer buf.Reset()
copy(encoded, []byte(httpBuf)) if c.post {
copy(encoded[len(httpBuf):], b[len(headData):]) buf.WriteString("POST /")
} else { } else {
encoded = []byte(httpBuf) buf.WriteString("GET /")
} }
h.firstRequest = false packURLEncodedHeadData(buf, headData)
return encoded, nil buf.WriteString(" HTTP/1.1\r\nHost: " + host)
if c.Port != 80 {
buf.WriteString(":" + strconv.Itoa(c.Port))
}
buf.WriteString("\r\n")
if len(body) > 0 {
buf.WriteString(body + "\r\n\r\n")
} else {
buf.WriteString("User-Agent: ")
buf.WriteString(userAgent[rand.Intn(len(userAgent))])
buf.WriteString("\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\n")
if c.post {
packBoundary(buf)
}
buf.WriteString("DNT: 1\r\nConnection: keep-alive\r\n\r\n")
}
buf.Write(b)
_, err := c.Conn.Write(buf.Bytes())
if err != nil {
return 0, nil
}
c.hasSentHeader = true
return bLength, nil
} }
return b, nil func packURLEncodedHeadData(buf *bytes.Buffer, data []byte) {
dataLength := len(data)
for i := 0; i < dataLength; i++ {
buf.WriteRune('%')
buf.WriteString(hex.EncodeToString(data[i : i+1]))
}
} }
func data2URLEncode(data []byte) (ret string) { func packBoundary(buf *bytes.Buffer) {
for i := 0; i < len(data); i++ { buf.WriteString("Content-Type: multipart/form-data; boundary=")
ret = fmt.Sprintf("%s%%%s", ret, hex.EncodeToString([]byte{data[i]}))
}
return
}
func boundary() (ret string) {
set := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" set := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
ret = fmt.Sprintf("%s%c", ret, set[rand.Intn(len(set))]) buf.WriteByte(set[rand.Intn(62)])
} }
return buf.WriteString("\r\n")
} }
var ( var userAgent = []string{
requestPath = []string{
"", "",
"login.php?redir=", "",
"register.php?code=", "",
"?keyword=", "",
"search?src=typd&q=", "&lang=en",
"s?ie=utf-8&f=8&rsv_bp=1&rsv_idx=1&ch=&bar=&wd=", "&rn=",
"post.php?id=", "&goto=view.php",
}
requestUserAgent = []string{
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.162 Safari/537.36", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.162 Safari/537.36",
"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/45.0.2454.85 Safari/537.36", "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/45.0.2454.85 Safari/537.36",
"Mozilla/5.0 (Linux; Android 7.0; Moto C Build/NRD90M.059) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Mobile Safari/537.36", "Mozilla/5.0 (Linux; Android 7.0; Moto C Build/NRD90M.059) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Mobile Safari/537.36",
@ -399,4 +405,3 @@ var (
"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.67 Safari/537.36", "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.67 Safari/537.36",
"Mozilla/5.0 (Windows NT 6.1; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/68.0.3440.106 Safari/537.36", "Mozilla/5.0 (Windows NT 6.1; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/68.0.3440.106 Safari/537.36",
} }
)

View File

@ -3,7 +3,7 @@ package obfs
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings" "net"
) )
var ( var (
@ -12,26 +12,31 @@ var (
errTLS12TicketAuthHMACError = errors.New("tls1.2_ticket_auth hmac verifying failed") errTLS12TicketAuthHMACError = errors.New("tls1.2_ticket_auth hmac verifying failed")
) )
// Obfs provides methods for decoding and encoding type authData struct {
clientID [32]byte
}
type Obfs interface { type Obfs interface {
initForConn() Obfs StreamConn(net.Conn) net.Conn
GetObfsOverhead() int
Decode(b []byte) ([]byte, bool, error)
Encode(b []byte) ([]byte, error)
} }
type obfsCreator func(b *Base) Obfs type obfsCreator func(b *Base) Obfs
var obfsList = make(map[string]obfsCreator) var obfsList = make(map[string]struct {
overhead int
new obfsCreator
})
func register(name string, c obfsCreator) { func register(name string, c obfsCreator, o int) {
obfsList[name] = c obfsList[name] = struct {
overhead int
new obfsCreator
}{overhead: o, new: c}
} }
// PickObfs returns an obfs of the given name func PickObfs(name string, b *Base) (Obfs, int, error) {
func PickObfs(name string, b *Base) (Obfs, error) { if choice, ok := obfsList[name]; ok {
if obfsCreator, ok := obfsList[strings.ToLower(name)]; ok { return choice.new(b), choice.overhead, nil
return obfsCreator(b), nil
} }
return nil, fmt.Errorf("Obfs %s not supported", name) return nil, 0, fmt.Errorf("Obfs %s not supported", name)
} }

View File

@ -1,25 +1,15 @@
package obfs package obfs
import "net"
type plain struct{} type plain struct{}
func init() { func init() {
register("plain", newPlain) register("plain", newPlain, 0)
} }
func newPlain(b *Base) Obfs { func newPlain(b *Base) Obfs {
return &plain{} return &plain{}
} }
func (p *plain) initForConn() Obfs { return &plain{} } func (p *plain) StreamConn(c net.Conn) net.Conn { return c }
func (p *plain) GetObfsOverhead() int {
return 0
}
func (p *plain) Encode(b []byte) ([]byte, error) {
return b, nil
}
func (p *plain) Decode(b []byte) ([]byte, bool, error) {
return b, false, nil
}

View File

@ -4,72 +4,68 @@ import (
"encoding/binary" "encoding/binary"
"hash/crc32" "hash/crc32"
"math/rand" "math/rand"
"net"
"github.com/Dreamacro/clash/common/pool"
) )
func init() {
register("random_head", newRandomHead, 0)
}
type randomHead struct { type randomHead struct {
*Base *Base
firstRequest bool
firstResponse bool
headerSent bool
buffer []byte
}
func init() {
register("random_head", newRandomHead)
} }
func newRandomHead(b *Base) Obfs { func newRandomHead(b *Base) Obfs {
return &randomHead{Base: b} return &randomHead{Base: b}
} }
func (r *randomHead) initForConn() Obfs { type randomHeadConn struct {
return &randomHead{ net.Conn
Base: r.Base, *randomHead
firstRequest: true, hasSentHeader bool
firstResponse: true, rawTransSent bool
} rawTransRecv bool
buf []byte
} }
func (r *randomHead) GetObfsOverhead() int { func (r *randomHead) StreamConn(c net.Conn) net.Conn {
return 0 return &randomHeadConn{Conn: c, randomHead: r}
} }
func (r *randomHead) Encode(b []byte) (encoded []byte, err error) { func (c *randomHeadConn) Read(b []byte) (int, error) {
if !r.firstRequest { if c.rawTransRecv {
return b, nil return c.Conn.Read(b)
}
buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf)
c.Conn.Read(buf)
c.rawTransRecv = true
c.Write(nil)
return 0, nil
} }
bSize := len(b) func (c *randomHeadConn) Write(b []byte) (int, error) {
if r.headerSent { if c.rawTransSent {
if bSize > 0 { return c.Conn.Write(b)
d := make([]byte, len(r.buffer)+bSize)
copy(d, r.buffer)
copy(d[len(r.buffer):], b)
r.buffer = d
} else {
encoded = r.buffer
r.buffer = nil
r.firstRequest = false
} }
} else { c.buf = append(c.buf, b...)
size := rand.Intn(96) + 8 if !c.hasSentHeader {
encoded = make([]byte, size) c.hasSentHeader = true
rand.Read(encoded) dataLength := rand.Intn(96) + 4
crc := (0xFFFFFFFF - crc32.ChecksumIEEE(encoded[:size-4])) & 0xFFFFFFFF buf := pool.Get(dataLength + 4)
binary.LittleEndian.PutUint32(encoded[size-4:], crc) defer pool.Put(buf)
rand.Read(buf[:dataLength])
d := make([]byte, bSize) binary.LittleEndian.PutUint32(buf[dataLength:], 0xffffffff-crc32.ChecksumIEEE(buf[:dataLength]))
copy(d, b) _, err := c.Conn.Write(buf)
r.buffer = d return len(b), err
} }
r.headerSent = true if c.rawTransRecv {
return encoded, nil _, err := c.Conn.Write(c.buf)
c.buf = nil
c.rawTransSent = true
return len(b), err
} }
return len(b), nil
func (r *randomHead) Decode(b []byte) ([]byte, bool, error) {
if r.firstResponse {
r.firstResponse = false
return b, true, nil
}
return b, false, nil
} }

View File

@ -1,72 +0,0 @@
package obfs
import (
"net"
"github.com/Dreamacro/clash/common/pool"
)
// NewConn wraps a stream-oriented net.Conn with obfs decoding/encoding
func NewConn(c net.Conn, o Obfs) net.Conn {
return &Conn{Conn: c, Obfs: o.initForConn()}
}
// Conn represents an obfs connection
type Conn struct {
net.Conn
Obfs
buf []byte
offset int
}
func (c *Conn) Read(b []byte) (int, error) {
if c.buf != nil {
n := copy(b, c.buf[c.offset:])
c.offset += n
if c.offset == len(c.buf) {
pool.Put(c.buf)
c.buf = nil
}
return n, nil
}
buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf)
n, err := c.Conn.Read(buf)
if err != nil {
return 0, err
}
decoded, sendback, err := c.Decode(buf[:n])
// decoded may be part of buf
decodedData := pool.Get(len(decoded))
copy(decodedData, decoded)
if err != nil {
pool.Put(decodedData)
return 0, err
}
if sendback {
c.Write(nil)
pool.Put(decodedData)
return 0, nil
}
n = copy(b, decodedData)
if len(decodedData) > len(b) {
c.buf = decodedData
c.offset = n
} else {
pool.Put(decodedData)
}
return n, err
}
func (c *Conn) Write(b []byte) (int, error) {
encoded, err := c.Encode(b)
if err != nil {
return 0, err
}
_, err = c.Conn.Write(encoded)
if err != nil {
return 0, err
}
return len(b), nil
}

View File

@ -0,0 +1,231 @@
package obfs
import (
"bytes"
"crypto/hmac"
"encoding/binary"
"math/rand"
"net"
"strings"
"time"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools"
)
func init() {
register("tls1.2_ticket_auth", newTLS12Ticket, 5)
register("tls1.2_ticket_fastauth", newTLS12Ticket, 5)
}
type tls12Ticket struct {
*Base
*authData
}
func newTLS12Ticket(b *Base) Obfs {
r := &tls12Ticket{Base: b, authData: &authData{}}
rand.Read(r.clientID[:])
return r
}
type tls12TicketConn struct {
net.Conn
*tls12Ticket
handshakeStatus int
decoded bytes.Buffer
underDecoded bytes.Buffer
sendBuf bytes.Buffer
}
func (t *tls12Ticket) StreamConn(c net.Conn) net.Conn {
return &tls12TicketConn{Conn: c, tls12Ticket: t}
}
func (c *tls12TicketConn) Read(b []byte) (int, error) {
if c.decoded.Len() > 0 {
return c.decoded.Read(b)
}
buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf)
n, err := c.Conn.Read(buf)
if err != nil {
return 0, err
}
if c.handshakeStatus == 8 {
c.underDecoded.Write(buf[:n])
for c.underDecoded.Len() > 5 {
if !bytes.Equal(c.underDecoded.Bytes()[:3], []byte{0x17, 3, 3}) {
c.underDecoded.Reset()
return 0, errTLS12TicketAuthIncorrectMagicNumber
}
size := int(binary.BigEndian.Uint16(c.underDecoded.Bytes()[3:5]))
if c.underDecoded.Len() < 5+size {
break
}
c.underDecoded.Next(5)
c.decoded.Write(c.underDecoded.Next(size))
}
n, _ = c.decoded.Read(b)
return n, nil
}
if n < 11+32+1+32 {
return 0, errTLS12TicketAuthTooShortData
}
if !hmac.Equal(buf[33:43], c.hmacSHA1(buf[11:33])[:10]) || !hmac.Equal(buf[n-10:n], c.hmacSHA1(buf[:n-10])[:10]) {
return 0, errTLS12TicketAuthHMACError
}
c.Write(nil)
return 0, nil
}
func (c *tls12TicketConn) Write(b []byte) (int, error) {
length := len(b)
if c.handshakeStatus == 8 {
buf := tools.BufPool.Get().(*bytes.Buffer)
defer tools.BufPool.Put(buf)
defer buf.Reset()
for len(b) > 2048 {
size := rand.Intn(4096) + 100
if len(b) < size {
size = len(b)
}
packData(buf, b[:size])
b = b[size:]
}
if len(b) > 0 {
packData(buf, b)
}
_, err := c.Conn.Write(buf.Bytes())
if err != nil {
return 0, err
}
return length, nil
}
if len(b) > 0 {
packData(&c.sendBuf, b)
}
if c.handshakeStatus == 0 {
c.handshakeStatus = 1
data := tools.BufPool.Get().(*bytes.Buffer)
defer tools.BufPool.Put(data)
defer data.Reset()
data.Write([]byte{3, 3})
c.packAuthData(data)
data.WriteByte(0x20)
data.Write(c.clientID[:])
data.Write([]byte{0x00, 0x1c, 0xc0, 0x2b, 0xc0, 0x2f, 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, 0xc0, 0x0a, 0xc0, 0x14, 0xc0, 0x09, 0xc0, 0x13, 0x00, 0x9c, 0x00, 0x35, 0x00, 0x2f, 0x00, 0x0a})
data.Write([]byte{0x1, 0x0})
ext := tools.BufPool.Get().(*bytes.Buffer)
defer tools.BufPool.Put(ext)
defer ext.Reset()
host := c.getHost()
ext.Write([]byte{0xff, 0x01, 0x00, 0x01, 0x00})
packSNIData(ext, host)
ext.Write([]byte{0, 0x17, 0, 0})
c.packTicketBuf(ext, host)
ext.Write([]byte{0x00, 0x0d, 0x00, 0x16, 0x00, 0x14, 0x06, 0x01, 0x06, 0x03, 0x05, 0x01, 0x05, 0x03, 0x04, 0x01, 0x04, 0x03, 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03})
ext.Write([]byte{0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00})
ext.Write([]byte{0x00, 0x12, 0x00, 0x00})
ext.Write([]byte{0x75, 0x50, 0x00, 0x00})
ext.Write([]byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00})
ext.Write([]byte{0x00, 0x0a, 0x00, 0x06, 0x00, 0x04, 0x00, 0x17, 0x00, 0x18})
binary.Write(data, binary.BigEndian, uint16(ext.Len()))
data.ReadFrom(ext)
ret := tools.BufPool.Get().(*bytes.Buffer)
defer tools.BufPool.Put(ret)
defer ret.Reset()
ret.Write([]byte{0x16, 3, 1})
binary.Write(ret, binary.BigEndian, uint16(data.Len()+4))
ret.Write([]byte{1, 0})
binary.Write(ret, binary.BigEndian, uint16(data.Len()))
ret.ReadFrom(data)
_, err := c.Conn.Write(ret.Bytes())
if err != nil {
return 0, err
}
return length, nil
} else if c.handshakeStatus == 1 && len(b) == 0 {
buf := tools.BufPool.Get().(*bytes.Buffer)
defer tools.BufPool.Put(buf)
defer buf.Reset()
buf.Write([]byte{0x14, 3, 3, 0, 1, 1, 0x16, 3, 3, 0, 0x20})
tools.AppendRandBytes(buf, 22)
buf.Write(c.hmacSHA1(buf.Bytes())[:10])
buf.ReadFrom(&c.sendBuf)
c.handshakeStatus = 8
_, err := c.Conn.Write(buf.Bytes())
return 0, err
}
return length, nil
}
func packData(buf *bytes.Buffer, data []byte) {
buf.Write([]byte{0x17, 3, 3})
binary.Write(buf, binary.BigEndian, uint16(len(data)))
buf.Write(data)
}
func (t *tls12Ticket) packAuthData(buf *bytes.Buffer) {
binary.Write(buf, binary.BigEndian, uint32(time.Now().Unix()))
tools.AppendRandBytes(buf, 18)
buf.Write(t.hmacSHA1(buf.Bytes()[buf.Len()-22:])[:10])
}
func packSNIData(buf *bytes.Buffer, u string) {
len := uint16(len(u))
buf.Write([]byte{0, 0})
binary.Write(buf, binary.BigEndian, len+5)
binary.Write(buf, binary.BigEndian, len+3)
buf.WriteByte(0)
binary.Write(buf, binary.BigEndian, len)
buf.WriteString(u)
}
func (c *tls12TicketConn) packTicketBuf(buf *bytes.Buffer, u string) {
length := 16 * (rand.Intn(17) + 8)
buf.Write([]byte{0, 0x23})
binary.Write(buf, binary.BigEndian, uint16(length))
tools.AppendRandBytes(buf, length)
}
func (t *tls12Ticket) hmacSHA1(data []byte) []byte {
key := pool.Get(len(t.Key) + 32)
defer pool.Put(key)
copy(key, t.Key)
copy(key[len(t.Key):], t.clientID[:])
sha1Data := tools.HmacSHA1(key, data)
return sha1Data[:10]
}
func (t *tls12Ticket) getHost() string {
host := t.Param
if len(host) == 0 {
host = t.Host
}
if len(host) > 0 && host[len(host)-1] >= '0' && host[len(host)-1] <= '9' {
host = ""
}
hosts := strings.Split(host, ",")
host = hosts[rand.Intn(len(hosts))]
return host
}

View File

@ -1,291 +0,0 @@
package obfs
import (
"bytes"
"crypto/hmac"
"encoding/binary"
"fmt"
"io"
"log"
"math/rand"
"strings"
"time"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools"
)
type tlsAuthData struct {
localClientID [32]byte
}
type tls12Ticket struct {
*Base
*tlsAuthData
handshakeStatus int
sendSaver bytes.Buffer
recvBuffer bytes.Buffer
buffer bytes.Buffer
}
func init() {
register("tls1.2_ticket_auth", newTLS12Ticket)
register("tls1.2_ticket_fastauth", newTLS12Ticket)
}
func newTLS12Ticket(b *Base) Obfs {
return &tls12Ticket{
Base: b,
}
}
func (t *tls12Ticket) initForConn() Obfs {
r := &tls12Ticket{
Base: t.Base,
tlsAuthData: &tlsAuthData{},
}
rand.Read(r.localClientID[:])
return r
}
func (t *tls12Ticket) GetObfsOverhead() int {
return 5
}
func (t *tls12Ticket) Decode(b []byte) ([]byte, bool, error) {
if t.handshakeStatus == -1 {
return b, false, nil
}
t.buffer.Reset()
if t.handshakeStatus == 8 {
t.recvBuffer.Write(b)
for t.recvBuffer.Len() > 5 {
var h [5]byte
t.recvBuffer.Read(h[:])
if !bytes.Equal(h[:3], []byte{0x17, 0x3, 0x3}) {
log.Println("incorrect magic number", h[:3], ", 0x170303 is expected")
return nil, false, errTLS12TicketAuthIncorrectMagicNumber
}
size := int(binary.BigEndian.Uint16(h[3:5]))
if t.recvBuffer.Len() < size {
// 不够读,下回再读吧
unread := t.recvBuffer.Bytes()
t.recvBuffer.Reset()
t.recvBuffer.Write(h[:])
t.recvBuffer.Write(unread)
break
}
d := pool.Get(size)
t.recvBuffer.Read(d)
t.buffer.Write(d)
pool.Put(d)
}
return t.buffer.Bytes(), false, nil
}
if len(b) < 11+32+1+32 {
return nil, false, errTLS12TicketAuthTooShortData
}
hash := t.hmacSHA1(b[11 : 11+22])
if !hmac.Equal(b[33:33+tools.HmacSHA1Len], hash) {
return nil, false, errTLS12TicketAuthHMACError
}
return nil, true, nil
}
func (t *tls12Ticket) Encode(b []byte) ([]byte, error) {
t.buffer.Reset()
switch t.handshakeStatus {
case 8:
if len(b) < 1024 {
d := []byte{0x17, 0x3, 0x3, 0, 0}
binary.BigEndian.PutUint16(d[3:5], uint16(len(b)&0xFFFF))
t.buffer.Write(d)
t.buffer.Write(b)
return t.buffer.Bytes(), nil
}
start := 0
var l int
for len(b)-start > 2048 {
l = rand.Intn(4096) + 100
if l > len(b)-start {
l = len(b) - start
}
packData(&t.buffer, b[start:start+l])
start += l
}
if len(b)-start > 0 {
l = len(b) - start
packData(&t.buffer, b[start:start+l])
}
return t.buffer.Bytes(), nil
case 1:
if len(b) > 0 {
if len(b) < 1024 {
packData(&t.sendSaver, b)
} else {
start := 0
var l int
for len(b)-start > 2048 {
l = rand.Intn(4096) + 100
if l > len(b)-start {
l = len(b) - start
}
packData(&t.buffer, b[start:start+l])
start += l
}
if len(b)-start > 0 {
l = len(b) - start
packData(&t.buffer, b[start:start+l])
}
io.Copy(&t.sendSaver, &t.buffer)
}
return []byte{}, nil
}
hmacData := make([]byte, 43)
handshakeFinish := []byte("\x14\x03\x03\x00\x01\x01\x16\x03\x03\x00\x20")
copy(hmacData, handshakeFinish)
rand.Read(hmacData[11:33])
h := t.hmacSHA1(hmacData[:33])
copy(hmacData[33:], h)
t.buffer.Write(hmacData)
io.Copy(&t.buffer, &t.sendSaver)
t.handshakeStatus = 8
return t.buffer.Bytes(), nil
case 0:
tlsData0 := []byte("\x00\x1c\xc0\x2b\xc0\x2f\xcc\xa9\xcc\xa8\xcc\x14\xcc\x13\xc0\x0a\xc0\x14\xc0\x09\xc0\x13\x00\x9c\x00\x35\x00\x2f\x00\x0a\x01\x00")
tlsData1 := []byte("\xff\x01\x00\x01\x00")
tlsData2 := []byte("\x00\x17\x00\x00\x00\x23\x00\xd0")
// tlsData3 := []byte("\x00\x0d\x00\x16\x00\x14\x06\x01\x06\x03\x05\x01\x05\x03\x04\x01\x04\x03\x03\x01\x03\x03\x02\x01\x02\x03\x00\x05\x00\x05\x01\x00\x00\x00\x00\x00\x12\x00\x00\x75\x50\x00\x00\x00\x0b\x00\x02\x01\x00\x00\x0a\x00\x06\x00\x04\x00\x17\x00\x18\x00\x15\x00\x66\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
tlsData3 := []byte("\x00\x0d\x00\x16\x00\x14\x06\x01\x06\x03\x05\x01\x05\x03\x04\x01\x04\x03\x03\x01\x03\x03\x02\x01\x02\x03\x00\x05\x00\x05\x01\x00\x00\x00\x00\x00\x12\x00\x00\x75\x50\x00\x00\x00\x0b\x00\x02\x01\x00\x00\x0a\x00\x06\x00\x04\x00\x17\x00\x18")
var tlsData [2048]byte
tlsDataLen := 0
copy(tlsData[0:], tlsData1)
tlsDataLen += len(tlsData1)
sni := t.sni(t.getHost())
copy(tlsData[tlsDataLen:], sni)
tlsDataLen += len(sni)
copy(tlsData[tlsDataLen:], tlsData2)
tlsDataLen += len(tlsData2)
ticketLen := rand.Intn(164)*2 + 64
tlsData[tlsDataLen-1] = uint8(ticketLen & 0xff)
tlsData[tlsDataLen-2] = uint8(ticketLen >> 8)
//ticketLen := 208
rand.Read(tlsData[tlsDataLen : tlsDataLen+ticketLen])
tlsDataLen += ticketLen
copy(tlsData[tlsDataLen:], tlsData3)
tlsDataLen += len(tlsData3)
length := 11 + 32 + 1 + 32 + len(tlsData0) + 2 + tlsDataLen
encodedData := make([]byte, length)
pdata := length - tlsDataLen
l := tlsDataLen
copy(encodedData[pdata:], tlsData[:tlsDataLen])
encodedData[pdata-1] = uint8(tlsDataLen)
encodedData[pdata-2] = uint8(tlsDataLen >> 8)
pdata -= 2
l += 2
copy(encodedData[pdata-len(tlsData0):], tlsData0)
pdata -= len(tlsData0)
l += len(tlsData0)
copy(encodedData[pdata-32:], t.localClientID[:])
pdata -= 32
l += 32
encodedData[pdata-1] = 0x20
pdata--
l++
copy(encodedData[pdata-32:], t.packAuthData())
pdata -= 32
l += 32
encodedData[pdata-1] = 0x3
encodedData[pdata-2] = 0x3 // tls version
pdata -= 2
l += 2
encodedData[pdata-1] = uint8(l)
encodedData[pdata-2] = uint8(l >> 8)
encodedData[pdata-3] = 0
encodedData[pdata-4] = 1
pdata -= 4
l += 4
encodedData[pdata-1] = uint8(l)
encodedData[pdata-2] = uint8(l >> 8)
pdata -= 2
l += 2
encodedData[pdata-1] = 0x1
encodedData[pdata-2] = 0x3 // tls version
pdata -= 2
l += 2
encodedData[pdata-1] = 0x16 // tls handshake
pdata--
l++
packData(&t.sendSaver, b)
t.handshakeStatus = 1
return encodedData, nil
default:
return nil, fmt.Errorf("unexpected handshake status: %d", t.handshakeStatus)
}
}
func (t *tls12Ticket) hmacSHA1(data []byte) []byte {
key := make([]byte, len(t.Key)+32)
copy(key, t.Key)
copy(key[len(t.Key):], t.localClientID[:])
sha1Data := tools.HmacSHA1(key, data)
return sha1Data[:tools.HmacSHA1Len]
}
func (t *tls12Ticket) sni(u string) []byte {
bURL := []byte(u)
length := len(bURL)
ret := make([]byte, length+9)
copy(ret[9:9+length], bURL)
binary.BigEndian.PutUint16(ret[7:], uint16(length&0xFFFF))
length += 3
binary.BigEndian.PutUint16(ret[4:], uint16(length&0xFFFF))
length += 2
binary.BigEndian.PutUint16(ret[2:], uint16(length&0xFFFF))
return ret
}
func (t *tls12Ticket) getHost() string {
host := t.Host
if len(t.Param) > 0 {
hosts := strings.Split(t.Param, ",")
if len(hosts) > 0 {
host = hosts[rand.Intn(len(hosts))]
host = strings.TrimSpace(host)
}
}
if len(host) > 0 && host[len(host)-1] >= byte('0') && host[len(host)-1] <= byte('9') && len(t.Param) == 0 {
host = ""
}
return host
}
func (t *tls12Ticket) packAuthData() (ret []byte) {
retSize := 32
ret = make([]byte, retSize)
now := time.Now().Unix()
binary.BigEndian.PutUint32(ret[:4], uint32(now))
rand.Read(ret[4 : 4+18])
hash := t.hmacSHA1(ret[:retSize-tools.HmacSHA1Len])
copy(ret[retSize-tools.HmacSHA1Len:], hash)
return
}
func packData(buffer *bytes.Buffer, suffix []byte) {
d := []byte{0x17, 0x3, 0x3, 0, 0}
binary.BigEndian.PutUint16(d[3:5], uint16(len(suffix)&0xFFFF))
buffer.Write(d)
buffer.Write(suffix)
return
}

View File

@ -1,310 +1,18 @@
package protocol package protocol
import ( import "github.com/Dreamacro/clash/component/ssr/tools"
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"encoding/binary"
"math/rand"
"strconv"
"strings"
"time"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools"
"github.com/Dreamacro/go-shadowsocks2/core"
)
type authAES128 struct {
*Base
*recvInfo
*authData
hasSentHeader bool
packID uint32
userKey []byte
uid [4]byte
salt string
hmac hmacMethod
hashDigest hashDigestMethod
}
func init() { func init() {
register("auth_aes128_md5", newAuthAES128MD5) register("auth_aes128_md5", newAuthAES128MD5, 9)
} }
func newAuthAES128MD5(b *Base) Protocol { func newAuthAES128MD5(b *Base) Protocol {
return &authAES128{ a := &authAES128{
Base: b, Base: b,
authData: &authData{}, authData: &authData{},
salt: "auth_aes128_md5", authAES128Function: &authAES128Function{salt: "auth_aes128_md5", hmac: tools.HmacMD5, hashDigest: tools.MD5Sum},
hmac: tools.HmacMD5, userData: &userData{},
hashDigest: tools.MD5Sum,
} }
} a.initUserData()
return a
func (a *authAES128) initForConn(iv []byte) Protocol {
return &authAES128{
Base: &Base{
IV: iv,
Key: a.Key,
TCPMss: a.TCPMss,
Overhead: a.Overhead,
Param: a.Param,
},
recvInfo: &recvInfo{recvID: 1, buffer: new(bytes.Buffer)},
authData: a.authData,
packID: 1,
salt: a.salt,
hmac: a.hmac,
hashDigest: a.hashDigest,
}
}
func (a *authAES128) GetProtocolOverhead() int {
return 9
}
func (a *authAES128) SetOverhead(overhead int) {
a.Overhead = overhead
}
func (a *authAES128) Decode(b []byte) ([]byte, int, error) {
a.buffer.Reset()
bSize := len(b)
readSize := 0
key := pool.Get(len(a.userKey) + 4)
defer pool.Put(key)
copy(key, a.userKey)
for bSize > 4 {
binary.LittleEndian.PutUint32(key[len(key)-4:], a.recvID)
h := a.hmac(key, b[:2])
if !bytes.Equal(h[:2], b[2:4]) {
return nil, 0, errAuthAES128IncorrectMAC
}
length := int(binary.LittleEndian.Uint16(b[:2]))
if length >= 8192 || length < 8 {
return nil, 0, errAuthAES128DataLengthError
}
if length > bSize {
break
}
h = a.hmac(key, b[:length-4])
if !bytes.Equal(h[:4], b[length-4:length]) {
return nil, 0, errAuthAES128IncorrectChecksum
}
a.recvID++
pos := int(b[4])
if pos < 255 {
pos += 4
} else {
pos = int(binary.LittleEndian.Uint16(b[5:7])) + 4
}
if pos > length-4 {
return nil, 0, errAuthAES128PositionTooLarge
}
a.buffer.Write(b[pos : length-4])
b = b[length:]
bSize -= length
readSize += length
}
return a.buffer.Bytes(), readSize, nil
}
func (a *authAES128) Encode(b []byte) ([]byte, error) {
a.buffer.Reset()
bSize := len(b)
offset := 0
if bSize > 0 && !a.hasSentHeader {
authSize := bSize
if authSize > 1200 {
authSize = 1200
}
a.hasSentHeader = true
a.buffer.Write(a.packAuthData(b[:authSize]))
bSize -= authSize
offset += authSize
}
const blockSize = 4096
for bSize > blockSize {
packSize, randSize := a.packedDataSize(b[offset : offset+blockSize])
pack := pool.Get(packSize)
a.packData(b[offset:offset+blockSize], pack, randSize)
a.buffer.Write(pack)
pool.Put(pack)
bSize -= blockSize
offset += blockSize
}
if bSize > 0 {
packSize, randSize := a.packedDataSize(b[offset:])
pack := pool.Get(packSize)
a.packData(b[offset:], pack, randSize)
a.buffer.Write(pack)
pool.Put(pack)
}
return a.buffer.Bytes(), nil
}
func (a *authAES128) DecodePacket(b []byte) ([]byte, int, error) {
bSize := len(b)
h := a.hmac(a.Key, b[:bSize-4])
if !bytes.Equal(h[:4], b[bSize-4:]) {
return nil, 0, errAuthAES128IncorrectMAC
}
return b[:bSize-4], bSize - 4, nil
}
func (a *authAES128) EncodePacket(b []byte) ([]byte, error) {
a.initUserKeyAndID()
var buf bytes.Buffer
buf.Write(b)
buf.Write(a.uid[:])
h := a.hmac(a.userKey, buf.Bytes())
buf.Write(h[:4])
return buf.Bytes(), nil
}
func (a *authAES128) initUserKeyAndID() {
if a.userKey == nil {
params := strings.Split(a.Param, ":")
if len(params) >= 2 {
if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
binary.LittleEndian.PutUint32(a.uid[:], uint32(userID))
a.userKey = a.hashDigest([]byte(params[1]))
}
}
if a.userKey == nil {
rand.Read(a.uid[:])
a.userKey = make([]byte, len(a.Key))
copy(a.userKey, a.Key)
}
}
}
func (a *authAES128) packedDataSize(data []byte) (packSize, randSize int) {
dataSize := len(data)
randSize = 1
if dataSize <= 1200 {
if a.packID > 4 {
randSize += rand.Intn(32)
} else {
if dataSize > 900 {
randSize += rand.Intn(128)
} else {
randSize += rand.Intn(512)
}
}
}
packSize = randSize + dataSize + 8
return
}
func (a *authAES128) packData(data, ret []byte, randSize int) {
dataSize := len(data)
retSize := len(ret)
// 0~1, ret_size
binary.LittleEndian.PutUint16(ret[0:], uint16(retSize&0xFFFF))
// 2~3, hmac
key := pool.Get(len(a.userKey) + 4)
defer pool.Put(key)
copy(key, a.userKey)
binary.LittleEndian.PutUint32(key[len(key)-4:], a.packID)
h := a.hmac(key, ret[:2])
copy(ret[2:4], h[:2])
// 4~rand_size+4, rand number
rand.Read(ret[4 : 4+randSize])
// 4, rand_size
if randSize < 128 {
ret[4] = byte(randSize & 0xFF)
} else {
// 4, magic number 0xFF
ret[4] = 0xFF
// 5~6, rand_size
binary.LittleEndian.PutUint16(ret[5:], uint16(randSize&0xFFFF))
}
// rand_size+4~ret_size-4, data
if dataSize > 0 {
copy(ret[randSize+4:], data)
}
a.packID++
h = a.hmac(key, ret[:retSize-4])
copy(ret[retSize-4:], h[:4])
}
func (a *authAES128) packAuthData(data []byte) (ret []byte) {
dataSize := len(data)
var randSize int
if dataSize > 400 {
randSize = rand.Intn(512)
} else {
randSize = rand.Intn(1024)
}
dataOffset := randSize + 16 + 4 + 4 + 7
retSize := dataOffset + dataSize + 4
ret = make([]byte, retSize)
encrypt := make([]byte, 24)
key := make([]byte, len(a.IV)+len(a.Key))
copy(key, a.IV)
copy(key[len(a.IV):], a.Key)
rand.Read(ret[dataOffset-randSize:])
a.mutex.Lock()
defer a.mutex.Unlock()
a.connectionID++
if a.connectionID > 0xFF000000 {
a.clientID = nil
}
if len(a.clientID) == 0 {
a.clientID = make([]byte, 8)
rand.Read(a.clientID)
b := make([]byte, 4)
rand.Read(b)
a.connectionID = binary.LittleEndian.Uint32(b) & 0xFFFFFF
}
copy(encrypt[4:], a.clientID)
binary.LittleEndian.PutUint32(encrypt[8:], a.connectionID)
now := time.Now().Unix()
binary.LittleEndian.PutUint32(encrypt[:4], uint32(now))
binary.LittleEndian.PutUint16(encrypt[12:], uint16(retSize&0xFFFF))
binary.LittleEndian.PutUint16(encrypt[14:], uint16(randSize&0xFFFF))
a.initUserKeyAndID()
aesCipherKey := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+a.salt, 16)
block, err := aes.NewCipher(aesCipherKey)
if err != nil {
return nil
}
encryptData := make([]byte, 16)
iv := make([]byte, aes.BlockSize)
cbc := cipher.NewCBCEncrypter(block, iv)
cbc.CryptBlocks(encryptData, encrypt[:16])
copy(encrypt[:4], a.uid[:])
copy(encrypt[4:4+16], encryptData)
h := a.hmac(key, encrypt[:20])
copy(encrypt[20:], h[:4])
rand.Read(ret[:1])
h = a.hmac(key, ret[:1])
copy(ret[1:], h[:7-1])
copy(ret[7:], encrypt)
copy(ret[dataOffset:], data)
h = a.hmac(a.userKey, ret[:retSize-4])
copy(ret[retSize-4:], h[:4])
return
} }

View File

@ -2,21 +2,274 @@ package protocol
import ( import (
"bytes" "bytes"
"encoding/binary"
"math"
"math/rand"
"net"
"strconv"
"strings"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools" "github.com/Dreamacro/clash/component/ssr/tools"
"github.com/Dreamacro/clash/log"
) )
type hmacMethod func(key, data []byte) []byte
type hashDigestMethod func([]byte) []byte
func init() { func init() {
register("auth_aes128_sha1", newAuthAES128SHA1) register("auth_aes128_sha1", newAuthAES128SHA1, 9)
}
type authAES128Function struct {
salt string
hmac hmacMethod
hashDigest hashDigestMethod
}
type authAES128 struct {
*Base
*authData
*authAES128Function
*userData
iv []byte
hasSentHeader bool
rawTrans bool
packID uint32
recvID uint32
} }
func newAuthAES128SHA1(b *Base) Protocol { func newAuthAES128SHA1(b *Base) Protocol {
return &authAES128{ a := &authAES128{
Base: b, Base: b,
recvInfo: &recvInfo{buffer: new(bytes.Buffer)},
authData: &authData{}, authData: &authData{},
salt: "auth_aes128_sha1", authAES128Function: &authAES128Function{salt: "auth_aes128_sha1", hmac: tools.HmacSHA1, hashDigest: tools.SHA1Sum},
hmac: tools.HmacSHA1, userData: &userData{},
hashDigest: tools.SHA1Sum, }
a.initUserData()
return a
}
func (a *authAES128) initUserData() {
params := strings.Split(a.Param, ":")
if len(params) > 1 {
if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
binary.LittleEndian.PutUint32(a.userID[:], uint32(userID))
a.userKey = a.hashDigest([]byte(params[1]))
} else {
log.Warnln("Wrong protocol-param for %s, only digits are expected before ':'", a.salt)
} }
} }
if len(a.userKey) == 0 {
a.userKey = a.Key
rand.Read(a.userID[:])
}
}
func (a *authAES128) StreamConn(c net.Conn, iv []byte) net.Conn {
p := &authAES128{
Base: a.Base,
authData: a.next(),
authAES128Function: a.authAES128Function,
userData: a.userData,
packID: 1,
recvID: 1,
}
p.iv = iv
return &Conn{Conn: c, Protocol: p}
}
func (a *authAES128) PacketConn(c net.PacketConn) net.PacketConn {
p := &authAES128{
Base: a.Base,
authAES128Function: a.authAES128Function,
userData: a.userData,
}
return &PacketConn{PacketConn: c, Protocol: p}
}
func (a *authAES128) Decode(dst, src *bytes.Buffer) error {
if a.rawTrans {
dst.ReadFrom(src)
return nil
}
for src.Len() > 4 {
macKey := pool.Get(len(a.userKey) + 4)
defer pool.Put(macKey)
copy(macKey, a.userKey)
binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.recvID)
if !bytes.Equal(a.hmac(macKey, src.Bytes()[:2])[:2], src.Bytes()[2:4]) {
src.Reset()
return errAuthAES128MACError
}
length := int(binary.LittleEndian.Uint16(src.Bytes()[:2]))
if length >= 8192 || length < 7 {
a.rawTrans = true
src.Reset()
return errAuthAES128LengthError
}
if length > src.Len() {
break
}
if !bytes.Equal(a.hmac(macKey, src.Bytes()[:length-4])[:4], src.Bytes()[length-4:length]) {
a.rawTrans = true
src.Reset()
return errAuthAES128ChksumError
}
a.recvID++
pos := int(src.Bytes()[4])
if pos < 255 {
pos += 4
} else {
pos = int(binary.LittleEndian.Uint16(src.Bytes()[5:7])) + 4
}
dst.Write(src.Bytes()[pos : length-4])
src.Next(length)
}
return nil
}
func (a *authAES128) Encode(buf *bytes.Buffer, b []byte) error {
fullDataLength := len(b)
if !a.hasSentHeader {
dataLength := getDataLength(b)
a.packAuthData(buf, b[:dataLength])
b = b[dataLength:]
a.hasSentHeader = true
}
for len(b) > 8100 {
a.packData(buf, b[:8100], fullDataLength)
b = b[8100:]
}
if len(b) > 0 {
a.packData(buf, b, fullDataLength)
}
return nil
}
func (a *authAES128) DecodePacket(b []byte) ([]byte, error) {
if !bytes.Equal(a.hmac(a.Key, b[:len(b)-4])[:4], b[len(b)-4:]) {
return nil, errAuthAES128ChksumError
}
return b[:len(b)-4], nil
}
func (a *authAES128) EncodePacket(buf *bytes.Buffer, b []byte) error {
buf.Write(b)
buf.Write(a.userID[:])
buf.Write(a.hmac(a.userKey, buf.Bytes())[:4])
return nil
}
func (a *authAES128) packData(poolBuf *bytes.Buffer, data []byte, fullDataLength int) {
dataLength := len(data)
randDataLength := a.getRandDataLengthForPackData(dataLength, fullDataLength)
/*
2: uint16 LittleEndian packedDataLength
2: hmac of packedDataLength
3: maxRandDataLengthPrefix (min:1)
4: hmac of packedData except the last 4 bytes
*/
packedDataLength := 2 + 2 + 3 + randDataLength + dataLength + 4
if randDataLength < 128 {
packedDataLength -= 2
}
macKey := pool.Get(len(a.userKey) + 4)
defer pool.Put(macKey)
copy(macKey, a.userKey)
binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.packID)
a.packID++
binary.Write(poolBuf, binary.LittleEndian, uint16(packedDataLength))
poolBuf.Write(a.hmac(macKey, poolBuf.Bytes()[poolBuf.Len()-2:])[:2])
a.packRandData(poolBuf, randDataLength)
poolBuf.Write(data)
poolBuf.Write(a.hmac(macKey, poolBuf.Bytes()[poolBuf.Len()-packedDataLength+4:])[:4])
}
func trapezoidRandom(max int, d float64) int {
base := rand.Float64()
if d-0 > 1e-6 {
a := 1 - d
base = (math.Sqrt(a*a+4*d*base) - a) / (2 * d)
}
return int(base * float64(max))
}
func (a *authAES128) getRandDataLengthForPackData(dataLength, fullDataLength int) int {
if fullDataLength >= 32*1024-a.Overhead {
return 0
}
// 1460: tcp_mss
revLength := 1460 - dataLength - 9
if revLength == 0 {
return 0
}
if revLength < 0 {
if revLength > -1460 {
return trapezoidRandom(revLength+1460, -0.3)
}
return rand.Intn(32)
}
if dataLength > 900 {
return rand.Intn(revLength)
}
return trapezoidRandom(revLength, -0.3)
}
func (a *authAES128) packAuthData(poolBuf *bytes.Buffer, data []byte) {
if len(data) == 0 {
return
}
dataLength := len(data)
randDataLength := a.getRandDataLengthForPackAuthData(dataLength)
/*
7: checkHead(1) and hmac of checkHead(6)
4: userID
16: encrypted data of authdata(12), uint16 BigEndian packedDataLength(2) and uint16 BigEndian randDataLength(2)
4: hmac of userID and encrypted data
4: hmac of packedAuthData except the last 4 bytes
*/
packedAuthDataLength := 7 + 4 + 16 + 4 + randDataLength + dataLength + 4
macKey := pool.Get(len(a.iv) + len(a.Key))
defer pool.Put(macKey)
copy(macKey, a.iv)
copy(macKey[len(a.iv):], a.Key)
poolBuf.WriteByte(byte(rand.Intn(256)))
poolBuf.Write(a.hmac(macKey, poolBuf.Bytes())[:6])
poolBuf.Write(a.userID[:])
err := a.authData.putEncryptedData(poolBuf, a.userKey, [2]int{packedAuthDataLength, randDataLength}, a.salt)
if err != nil {
poolBuf.Reset()
return
}
poolBuf.Write(a.hmac(macKey, poolBuf.Bytes()[7:])[:4])
tools.AppendRandBytes(poolBuf, randDataLength)
poolBuf.Write(data)
poolBuf.Write(a.hmac(a.userKey, poolBuf.Bytes())[:4])
}
func (a *authAES128) getRandDataLengthForPackAuthData(size int) int {
if size > 400 {
return rand.Intn(512)
}
return rand.Intn(1024)
}
func (a *authAES128) packRandData(poolBuf *bytes.Buffer, size int) {
if size < 128 {
poolBuf.WriteByte(byte(size + 1))
tools.AppendRandBytes(poolBuf, size)
return
}
poolBuf.WriteByte(255)
binary.Write(poolBuf, binary.LittleEndian, uint16(size+3))
tools.AppendRandBytes(poolBuf, size)
}

View File

@ -2,427 +2,308 @@ package protocol
import ( import (
"bytes" "bytes"
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand"
"crypto/rc4" "crypto/rc4"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"math/rand" "net"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools" "github.com/Dreamacro/clash/component/ssr/tools"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/go-shadowsocks2/core" "github.com/Dreamacro/go-shadowsocks2/core"
) )
type authChain struct { func init() {
*Base register("auth_chain_a", newAuthChainA, 4)
*recvInfo
*authData
randomClient shift128PlusContext
randomServer shift128PlusContext
enc cipher.Stream
dec cipher.Stream
headerSent bool
lastClientHash []byte
lastServerHash []byte
userKey []byte
uid [4]byte
salt string
hmac hmacMethod
hashDigest hashDigestMethod
rnd rndMethod
dataSizeList []int
dataSizeList2 []int
chunkID uint32
} }
func init() { type randDataLengthMethod func(int, []byte, *tools.XorShift128Plus) int
register("auth_chain_a", newAuthChainA)
type authChainA struct {
*Base
*authData
*userData
iv []byte
salt string
hasSentHeader bool
rawTrans bool
lastClientHash []byte
lastServerHash []byte
encrypter cipher.Stream
decrypter cipher.Stream
randomClient tools.XorShift128Plus
randomServer tools.XorShift128Plus
randDataLength randDataLengthMethod
packID uint32
recvID uint32
} }
func newAuthChainA(b *Base) Protocol { func newAuthChainA(b *Base) Protocol {
return &authChain{ a := &authChainA{
Base: b, Base: b,
authData: &authData{}, authData: &authData{},
userData: &userData{},
salt: "auth_chain_a", salt: "auth_chain_a",
hmac: tools.HmacMD5, }
hashDigest: tools.SHA1Sum, a.initUserData()
rnd: authChainAGetRandLen, return a
}
func (a *authChainA) initUserData() {
params := strings.Split(a.Param, ":")
if len(params) > 1 {
if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
binary.LittleEndian.PutUint32(a.userID[:], uint32(userID))
a.userKey = []byte(params[1])
} else {
log.Warnln("Wrong protocol-param for %s, only digits are expected before ':'", a.salt)
}
}
if len(a.userKey) == 0 {
a.userKey = a.Key
rand.Read(a.userID[:])
} }
} }
func (a *authChain) initForConn(iv []byte) Protocol { func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn {
r := &authChain{ p := &authChainA{
Base: &Base{ Base: a.Base,
IV: iv, authData: a.next(),
Key: a.Key, userData: a.userData,
TCPMss: a.TCPMss,
Overhead: a.Overhead,
Param: a.Param,
},
recvInfo: &recvInfo{recvID: 1, buffer: new(bytes.Buffer)},
authData: a.authData,
salt: a.salt, salt: a.salt,
hmac: a.hmac, packID: 1,
hashDigest: a.hashDigest, recvID: 1,
rnd: a.rnd,
} }
if r.salt == "auth_chain_b" { p.iv = iv
initDataSize(r) p.randDataLength = p.getRandLength
} return &Conn{Conn: c, Protocol: p}
return r
} }
func (a *authChain) GetProtocolOverhead() int { func (a *authChainA) PacketConn(c net.PacketConn) net.PacketConn {
return 4 p := &authChainA{
Base: a.Base,
salt: a.salt,
userData: a.userData,
}
return &PacketConn{PacketConn: c, Protocol: p}
} }
func (a *authChain) SetOverhead(overhead int) { func (a *authChainA) Decode(dst, src *bytes.Buffer) error {
a.Overhead = overhead if a.rawTrans {
dst.ReadFrom(src)
return nil
} }
for src.Len() > 4 {
macKey := pool.Get(len(a.userKey) + 4)
defer pool.Put(macKey)
copy(macKey, a.userKey)
binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.recvID)
dataLength := int(binary.LittleEndian.Uint16(src.Bytes()[:2]) ^ binary.LittleEndian.Uint16(a.lastServerHash[14:16]))
randDataLength := a.randDataLength(dataLength, a.lastServerHash, &a.randomServer)
length := dataLength + randDataLength
func (a *authChain) Decode(b []byte) ([]byte, int, error) {
a.buffer.Reset()
key := pool.Get(len(a.userKey) + 4)
defer pool.Put(key)
readSize := 0
copy(key, a.userKey)
for len(b) > 4 {
binary.LittleEndian.PutUint32(key[len(a.userKey):], a.recvID)
dataLen := (int)((uint(b[1]^a.lastServerHash[15]) << 8) + uint(b[0]^a.lastServerHash[14]))
randLen := a.getServerRandLen(dataLen, a.Overhead)
length := randLen + dataLen
if length >= 4096 { if length >= 4096 {
return nil, 0, errAuthChainDataLengthError a.rawTrans = true
src.Reset()
return errAuthChainLengthError
} }
length += 4
if length > len(b) { if 4+length > src.Len() {
break break
} }
hash := a.hmac(key, b[:length-2]) serverHash := tools.HmacMD5(macKey, src.Bytes()[:length+2])
if !bytes.Equal(hash[:2], b[length-2:length]) { if !bytes.Equal(serverHash[:2], src.Bytes()[length+2:length+4]) {
return nil, 0, errAuthChainHMACError a.rawTrans = true
src.Reset()
return errAuthChainChksumError
} }
var dataPos int a.lastServerHash = serverHash
if dataLen > 0 && randLen > 0 {
dataPos = 2 + getRandStartPos(&a.randomServer, randLen) pos := 2
} else { if dataLength > 0 && randDataLength > 0 {
dataPos = 2 pos += getRandStartPos(randDataLength, &a.randomServer)
} }
d := pool.Get(dataLen) wantedData := src.Bytes()[pos : pos+dataLength]
a.dec.XORKeyStream(d, b[dataPos:dataPos+dataLen]) a.decrypter.XORKeyStream(wantedData, wantedData)
a.buffer.Write(d)
pool.Put(d)
if a.recvID == 1 { if a.recvID == 1 {
a.TCPMss = int(binary.LittleEndian.Uint16(a.buffer.Next(2))) dst.Write(wantedData[2:])
} else {
dst.Write(wantedData)
} }
a.lastServerHash = hash
a.recvID++ a.recvID++
b = b[length:] src.Next(length + 4)
readSize += length
} }
return a.buffer.Bytes(), readSize, nil return nil
} }
func (a *authChain) Encode(b []byte) ([]byte, error) { func (a *authChainA) Encode(buf *bytes.Buffer, b []byte) error {
a.buffer.Reset() if !a.hasSentHeader {
bSize := len(b) dataLength := getDataLength(b)
offset := 0 a.packAuthData(buf, b[:dataLength])
if bSize > 0 && !a.headerSent { b = b[dataLength:]
headSize := 1200 a.hasSentHeader = true
if headSize > bSize {
headSize = bSize
} }
a.buffer.Write(a.packAuthData(b[:headSize])) for len(b) > 2800 {
offset += headSize a.packData(buf, b[:2800])
bSize -= headSize b = b[2800:]
a.headerSent = true
} }
var unitSize = a.TCPMss - a.Overhead if len(b) > 0 {
for bSize > unitSize { a.packData(buf, b)
dataLen, randLength := a.packedDataLen(b[offset : offset+unitSize])
d := pool.Get(dataLen)
a.packData(d, b[offset:offset+unitSize], randLength)
a.buffer.Write(d)
pool.Put(d)
bSize -= unitSize
offset += unitSize
} }
if bSize > 0 { return nil
dataLen, randLength := a.packedDataLen(b[offset:])
d := pool.Get(dataLen)
a.packData(d, b[offset:], randLength)
a.buffer.Write(d)
pool.Put(d)
}
return a.buffer.Bytes(), nil
} }
func (a *authChain) DecodePacket(b []byte) ([]byte, int, error) { func (a *authChainA) DecodePacket(b []byte) ([]byte, error) {
bSize := len(b) if len(b) < 9 {
if bSize < 9 { return nil, errAuthChainLengthError
return nil, 0, errAuthChainDataLengthError
} }
h := a.hmac(a.userKey, b[:bSize-1]) if !bytes.Equal(tools.HmacMD5(a.userKey, b[:len(b)-1])[:1], b[len(b)-1:]) {
if h[0] != b[bSize-1] { return nil, errAuthChainChksumError
return nil, 0, errAuthChainHMACError
} }
hash := a.hmac(a.Key, b[bSize-8:bSize-1]) md5Data := tools.HmacMD5(a.Key, b[len(b)-8:len(b)-1])
cipherKey := a.getRC4CipherKey(hash)
dec, _ := rc4.NewCipher(cipherKey) randDataLength := udpGetRandLength(md5Data, &a.randomServer)
randLength := udpGetRandLen(&a.randomServer, hash)
bSize -= 8 + randLength key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
dec.XORKeyStream(b, b[:bSize]) rc4Cipher, err := rc4.NewCipher(key)
return b, bSize, nil if err != nil {
return nil, err
}
wantedData := b[:len(b)-8-randDataLength]
rc4Cipher.XORKeyStream(wantedData, wantedData)
return wantedData, nil
} }
func (a *authChain) EncodePacket(b []byte) ([]byte, error) { func (a *authChainA) EncodePacket(buf *bytes.Buffer, b []byte) error {
a.initUserKeyAndID()
authData := pool.Get(3) authData := pool.Get(3)
defer pool.Put(authData) defer pool.Put(authData)
rand.Read(authData) rand.Read(authData)
hash := a.hmac(a.Key, authData)
uid := pool.Get(4)
defer pool.Put(uid)
for i := 0; i < 4; i++ {
uid[i] = a.uid[i] ^ hash[i]
}
cipherKey := a.getRC4CipherKey(hash) md5Data := tools.HmacMD5(a.Key, authData)
enc, _ := rc4.NewCipher(cipherKey)
var buf bytes.Buffer
enc.XORKeyStream(b, b)
buf.Write(b)
randLength := udpGetRandLen(&a.randomClient, hash) randDataLength := udpGetRandLength(md5Data, &a.randomClient)
randBytes := pool.Get(randLength)
defer pool.Put(randBytes)
buf.Write(randBytes)
buf.Write(authData) key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
buf.Write(uid) rc4Cipher, err := rc4.NewCipher(key)
h := a.hmac(a.userKey, buf.Bytes())
buf.Write(h[:1])
return buf.Bytes(), nil
}
func (a *authChain) getRC4CipherKey(hash []byte) []byte {
base64UserKey := base64.StdEncoding.EncodeToString(a.userKey)
return a.calcRC4CipherKey(hash, base64UserKey)
}
func (a *authChain) calcRC4CipherKey(hash []byte, base64UserKey string) []byte {
password := pool.Get(len(base64UserKey) + base64.StdEncoding.EncodedLen(16))
defer pool.Put(password)
copy(password, base64UserKey)
base64.StdEncoding.Encode(password[len(base64UserKey):], hash[:16])
return core.Kdf(string(password), 16)
}
func (a *authChain) initUserKeyAndID() {
if a.userKey == nil {
params := strings.Split(a.Param, ":")
if len(params) >= 2 {
if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
binary.LittleEndian.PutUint32(a.uid[:], uint32(userID))
a.userKey = []byte(params[1])[:len(a.userKey)]
}
}
if a.userKey == nil {
rand.Read(a.uid[:])
a.userKey = make([]byte, len(a.Key))
copy(a.userKey, a.Key)
}
}
}
func (a *authChain) getClientRandLen(dataLength int, overhead int) int {
return a.rnd(dataLength, &a.randomClient, a.lastClientHash, a.dataSizeList, a.dataSizeList2, overhead)
}
func (a *authChain) getServerRandLen(dataLength int, overhead int) int {
return a.rnd(dataLength, &a.randomServer, a.lastServerHash, a.dataSizeList, a.dataSizeList2, overhead)
}
func (a *authChain) packedDataLen(data []byte) (chunkLength, randLength int) {
dataLength := len(data)
randLength = a.getClientRandLen(dataLength, a.Overhead)
chunkLength = randLength + dataLength + 2 + 2
return
}
func (a *authChain) packData(outData []byte, data []byte, randLength int) {
dataLength := len(data)
outLength := randLength + dataLength + 2
outData[0] = byte(dataLength) ^ a.lastClientHash[14]
outData[1] = byte(dataLength>>8) ^ a.lastClientHash[15]
{
if dataLength > 0 {
randPart1Length := getRandStartPos(&a.randomClient, randLength)
rand.Read(outData[2 : 2+randPart1Length])
a.enc.XORKeyStream(outData[2+randPart1Length:], data)
rand.Read(outData[2+randPart1Length+dataLength : outLength])
} else {
rand.Read(outData[2 : 2+randLength])
}
}
userKeyLen := uint8(len(a.userKey))
key := pool.Get(int(userKeyLen + 4))
defer pool.Put(key)
copy(key, a.userKey)
a.chunkID++
binary.LittleEndian.PutUint32(key[userKeyLen:], a.chunkID)
a.lastClientHash = a.hmac(key, outData[:outLength])
copy(outData[outLength:], a.lastClientHash[:2])
return
}
const authHeadLength = 4 + 8 + 4 + 16 + 4
func (a *authChain) packAuthData(data []byte) (outData []byte) {
outData = make([]byte, authHeadLength, authHeadLength+1500)
a.mutex.Lock()
defer a.mutex.Unlock()
a.connectionID++
if a.connectionID > 0xFF000000 {
rand.Read(a.clientID)
b := make([]byte, 4)
rand.Read(b)
a.connectionID = binary.LittleEndian.Uint32(b) & 0xFFFFFF
}
var key = make([]byte, len(a.IV)+len(a.Key))
copy(key, a.IV)
copy(key[len(a.IV):], a.Key)
encrypt := make([]byte, 20)
t := time.Now().Unix()
binary.LittleEndian.PutUint32(encrypt[:4], uint32(t))
copy(encrypt[4:8], a.clientID)
binary.LittleEndian.PutUint32(encrypt[8:], a.connectionID)
binary.LittleEndian.PutUint16(encrypt[12:], uint16(a.Overhead))
binary.LittleEndian.PutUint16(encrypt[14:], 0)
// first 12 bytes
{
rand.Read(outData[:4])
a.lastClientHash = a.hmac(key, outData[:4])
copy(outData[4:], a.lastClientHash[:8])
}
var base64UserKey string
// uid & 16 bytes auth data
{
a.initUserKeyAndID()
uid := make([]byte, 4)
for i := 0; i < 4; i++ {
uid[i] = a.uid[i] ^ a.lastClientHash[8+i]
}
base64UserKey = base64.StdEncoding.EncodeToString(a.userKey)
aesCipherKey := core.Kdf(base64UserKey+a.salt, 16)
block, err := aes.NewCipher(aesCipherKey)
if err != nil { if err != nil {
return err
}
rc4Cipher.XORKeyStream(b, b)
buf.Write(b)
tools.AppendRandBytes(buf, randDataLength)
buf.Write(authData)
binary.Write(buf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(md5Data[:4]))
buf.Write(tools.HmacMD5(a.userKey, buf.Bytes())[:1])
return nil
}
func (a *authChainA) packAuthData(poolBuf *bytes.Buffer, data []byte) {
/*
dataLength := len(data)
12: checkHead(4) and hmac of checkHead(8)
4: uint32 LittleEndian uid (uid = userID ^ last client hash)
16: encrypted data of authdata(12), uint16 LittleEndian overhead(2) and uint16 LittleEndian number zero(2)
4: last server hash(4)
packedAuthDataLength := 12 + 4 + 16 + 4 + dataLength
*/
macKey := pool.Get(len(a.iv) + len(a.Key))
defer pool.Put(macKey)
copy(macKey, a.iv)
copy(macKey[len(a.iv):], a.Key)
// check head
tools.AppendRandBytes(poolBuf, 4)
a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes())
a.initRC4Cipher()
poolBuf.Write(a.lastClientHash[:8])
// uid
binary.Write(poolBuf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(a.lastClientHash[8:12]))
// encrypted data
err := a.putEncryptedData(poolBuf, a.userKey, [2]int{a.Overhead, 0}, a.salt)
if err != nil {
poolBuf.Reset()
return return
} }
encryptData := make([]byte, 16) // last server hash
iv := make([]byte, aes.BlockSize) a.lastServerHash = tools.HmacMD5(a.userKey, poolBuf.Bytes()[12:])
cbc := cipher.NewCBCEncrypter(block, iv) poolBuf.Write(a.lastServerHash[:4])
cbc.CryptBlocks(encryptData, encrypt[:16]) // packed data
copy(encrypt[:4], uid[:]) a.packData(poolBuf, data)
copy(encrypt[4:4+16], encryptData)
}
// final HMAC
{
a.lastServerHash = a.hmac(a.userKey, encrypt[:20])
copy(outData[12:], encrypt)
copy(outData[12+20:], a.lastServerHash[:4])
} }
// init cipher func (a *authChainA) packData(poolBuf *bytes.Buffer, data []byte) {
cipherKey := a.calcRC4CipherKey(a.lastClientHash, base64UserKey) a.encrypter.XORKeyStream(data, data)
a.enc, _ = rc4.NewCipher(cipherKey)
a.dec, _ = rc4.NewCipher(cipherKey)
// data macKey := pool.Get(len(a.userKey) + 4)
chunkLength, randLength := a.packedDataLen(data) defer pool.Put(macKey)
if chunkLength <= 1500 { copy(macKey, a.userKey)
outData = outData[:authHeadLength+chunkLength] binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.packID)
} else { a.packID++
newOutData := make([]byte, authHeadLength+chunkLength)
copy(newOutData, outData[:authHeadLength]) length := uint16(len(data)) ^ binary.LittleEndian.Uint16(a.lastClientHash[14:16])
outData = newOutData
originalLength := poolBuf.Len()
binary.Write(poolBuf, binary.LittleEndian, length)
a.putMixedRandDataAndData(poolBuf, data)
a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes()[originalLength:])
poolBuf.Write(a.lastClientHash[:2])
} }
a.packData(outData[authHeadLength:], data, randLength)
func (a *authChainA) putMixedRandDataAndData(poolBuf *bytes.Buffer, data []byte) {
randDataLength := a.randDataLength(len(data), a.lastClientHash, &a.randomClient)
if len(data) == 0 {
tools.AppendRandBytes(poolBuf, randDataLength)
return return
} }
if randDataLength > 0 {
func getRandStartPos(random *shift128PlusContext, randLength int) int { startPos := getRandStartPos(randDataLength, &a.randomClient)
if randLength > 0 { tools.AppendRandBytes(poolBuf, startPos)
return int(random.Next() % 8589934609 % uint64(randLength)) poolBuf.Write(data)
tools.AppendRandBytes(poolBuf, randDataLength-startPos)
return
} }
return 0 poolBuf.Write(data)
} }
func authChainAGetRandLen(dataLength int, random *shift128PlusContext, lastHash []byte, dataSizeList, dataSizeList2 []int, overhead int) int { func getRandStartPos(length int, random *tools.XorShift128Plus) int {
if dataLength > 1440 { if length == 0 {
return 0 return 0
} }
random.InitFromBinDatalen(lastHash[:16], dataLength) return int(random.Next()%8589934609) % length
if dataLength > 1300 { }
func (a *authChainA) getRandLength(length int, lastHash []byte, random *tools.XorShift128Plus) int {
if length > 1440 {
return 0
}
random.InitFromBinAndLength(lastHash, length)
if length > 1300 {
return int(random.Next() % 31) return int(random.Next() % 31)
} }
if dataLength > 900 { if length > 900 {
return int(random.Next() % 127) return int(random.Next() % 127)
} }
if dataLength > 400 { if length > 400 {
return int(random.Next() % 521) return int(random.Next() % 521)
} }
return int(random.Next() % 1021) return int(random.Next() % 1021)
} }
func udpGetRandLen(random *shift128PlusContext, lastHash []byte) int { func (a *authChainA) initRC4Cipher() {
random.InitFromBin(lastHash[:16]) key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(a.lastClientHash), 16)
a.encrypter, _ = rc4.NewCipher(key)
a.decrypter, _ = rc4.NewCipher(key)
}
func udpGetRandLength(lastHash []byte, random *tools.XorShift128Plus) int {
random.InitFromBin(lastHash)
return int(random.Next() % 127) return int(random.Next() % 127)
} }
type shift128PlusContext struct {
v [2]uint64
}
func (ctx *shift128PlusContext) InitFromBin(bin []byte) {
var fillBin [16]byte
copy(fillBin[:], bin)
ctx.v[0] = binary.LittleEndian.Uint64(fillBin[:8])
ctx.v[1] = binary.LittleEndian.Uint64(fillBin[8:])
}
func (ctx *shift128PlusContext) InitFromBinDatalen(bin []byte, datalen int) {
var fillBin [16]byte
copy(fillBin[:], bin)
binary.LittleEndian.PutUint16(fillBin[:2], uint16(datalen))
ctx.v[0] = binary.LittleEndian.Uint64(fillBin[:8])
ctx.v[1] = binary.LittleEndian.Uint64(fillBin[8:])
for i := 0; i < 4; i++ {
ctx.Next()
}
}
func (ctx *shift128PlusContext) Next() uint64 {
x := ctx.v[0]
y := ctx.v[1]
ctx.v[0] = y
x ^= x << 23
x ^= y ^ (x >> 17) ^ (y >> 26)
ctx.v[1] = x
return x + y
}

View File

@ -1,71 +1,96 @@
package protocol package protocol
import ( import (
"net"
"sort" "sort"
"github.com/Dreamacro/clash/component/ssr/tools" "github.com/Dreamacro/clash/component/ssr/tools"
) )
func init() { func init() {
register("auth_chain_b", newAuthChainB) register("auth_chain_b", newAuthChainB, 4)
}
type authChainB struct {
*authChainA
dataSizeList []int
dataSizeList2 []int
} }
func newAuthChainB(b *Base) Protocol { func newAuthChainB(b *Base) Protocol {
return &authChain{ a := &authChainB{
authChainA: &authChainA{
Base: b, Base: b,
authData: &authData{}, authData: &authData{},
userData: &userData{},
salt: "auth_chain_b", salt: "auth_chain_b",
hmac: tools.HmacMD5, },
hashDigest: tools.SHA1Sum,
rnd: authChainBGetRandLen,
} }
a.initUserData()
return a
} }
func initDataSize(r *authChain) { func (a *authChainB) StreamConn(c net.Conn, iv []byte) net.Conn {
random := &r.randomServer p := &authChainB{
random.InitFromBin(r.Key) authChainA: &authChainA{
len := random.Next()%8 + 4 Base: a.Base,
r.dataSizeList = make([]int, len) authData: a.next(),
for i := 0; i < int(len); i++ { userData: a.userData,
r.dataSizeList[i] = int(random.Next() % 2340 % 2040 % 1440) salt: a.salt,
packID: 1,
recvID: 1,
},
} }
sort.Ints(r.dataSizeList) p.iv = iv
p.randDataLength = p.getRandLength
len = random.Next()%16 + 8 p.initDataSize()
r.dataSizeList2 = make([]int, len) return &Conn{Conn: c, Protocol: p}
for i := 0; i < int(len); i++ {
r.dataSizeList2[i] = int(random.Next() % 2340 % 2040 % 1440)
}
sort.Ints(r.dataSizeList2)
} }
func authChainBGetRandLen(dataLength int, random *shift128PlusContext, lastHash []byte, dataSizeList, dataSizeList2 []int, overhead int) int { func (a *authChainB) initDataSize() {
if dataLength > 1440 { a.dataSizeList = a.dataSizeList[:0]
a.dataSizeList2 = a.dataSizeList2[:0]
a.randomServer.InitFromBin(a.Key)
length := a.randomServer.Next()%8 + 4
for ; length > 0; length-- {
a.dataSizeList = append(a.dataSizeList, int(a.randomServer.Next()%2340%2040%1440))
}
sort.Ints(a.dataSizeList)
length = a.randomServer.Next()%16 + 8
for ; length > 0; length-- {
a.dataSizeList2 = append(a.dataSizeList2, int(a.randomServer.Next()%2340%2040%1440))
}
sort.Ints(a.dataSizeList2)
}
func (a *authChainB) getRandLength(length int, lashHash []byte, random *tools.XorShift128Plus) int {
if length >= 1440 {
return 0 return 0
} }
random.InitFromBinDatalen(lastHash[:16], dataLength) random.InitFromBinAndLength(lashHash, length)
pos := sort.Search(len(dataSizeList), func(i int) bool { return dataSizeList[i] > dataLength+overhead }) pos := sort.Search(len(a.dataSizeList), func(i int) bool { return a.dataSizeList[i] >= length+a.Overhead })
finalPos := uint64(pos) + random.Next()%uint64(len(dataSizeList)) finalPos := pos + int(random.Next()%uint64(len(a.dataSizeList)))
if finalPos < uint64(len(dataSizeList)) { if finalPos < len(a.dataSizeList) {
return dataSizeList[finalPos] - dataLength - overhead return a.dataSizeList[finalPos] - length - a.Overhead
} }
pos = sort.Search(len(dataSizeList2), func(i int) bool { return dataSizeList2[i] > dataLength+overhead }) pos = sort.Search(len(a.dataSizeList2), func(i int) bool { return a.dataSizeList2[i] >= length+a.Overhead })
finalPos = uint64(pos) + random.Next()%uint64(len(dataSizeList2)) finalPos = pos + int(random.Next()%uint64(len(a.dataSizeList2)))
if finalPos < uint64(len(dataSizeList2)) { if finalPos < len(a.dataSizeList2) {
return dataSizeList2[finalPos] - dataLength - overhead return a.dataSizeList2[finalPos] - length - a.Overhead
} }
if finalPos < uint64(pos+len(dataSizeList2)-1) { if finalPos < pos+len(a.dataSizeList2)-1 {
return 0 return 0
} }
if length > 1300 {
if dataLength > 1300 {
return int(random.Next() % 31) return int(random.Next() % 31)
} }
if dataLength > 900 { if length > 900 {
return int(random.Next() % 127) return int(random.Next() % 127)
} }
if dataLength > 400 { if length > 400 {
return int(random.Next() % 521) return int(random.Next() % 521)
} }
return int(random.Next() % 1021) return int(random.Next() % 1021)

View File

@ -6,248 +6,177 @@ import (
"hash/adler32" "hash/adler32"
"hash/crc32" "hash/crc32"
"math/rand" "math/rand"
"time" "net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools" "github.com/Dreamacro/clash/component/ssr/tools"
) )
func init() {
register("auth_sha1_v4", newAuthSHA1V4, 7)
}
type authSHA1V4 struct { type authSHA1V4 struct {
*Base *Base
*authData *authData
headerSent bool iv []byte
buffer bytes.Buffer hasSentHeader bool
} rawTrans bool
func init() {
register("auth_sha1_v4", newAuthSHA1V4)
} }
func newAuthSHA1V4(b *Base) Protocol { func newAuthSHA1V4(b *Base) Protocol {
return &authSHA1V4{Base: b, authData: &authData{}} return &authSHA1V4{Base: b, authData: &authData{}}
} }
func (a *authSHA1V4) initForConn(iv []byte) Protocol { func (a *authSHA1V4) StreamConn(c net.Conn, iv []byte) net.Conn {
return &authSHA1V4{ p := &authSHA1V4{Base: a.Base, authData: a.next()}
Base: &Base{ p.iv = iv
IV: iv, return &Conn{Conn: c, Protocol: p}
Key: a.Key,
TCPMss: a.TCPMss,
Overhead: a.Overhead,
Param: a.Param,
},
authData: a.authData,
}
} }
func (a *authSHA1V4) GetProtocolOverhead() int { func (a *authSHA1V4) PacketConn(c net.PacketConn) net.PacketConn {
return 7 return c
} }
func (a *authSHA1V4) SetOverhead(overhead int) { func (a *authSHA1V4) Decode(dst, src *bytes.Buffer) error {
a.Overhead = overhead if a.rawTrans {
dst.ReadFrom(src)
return nil
}
for src.Len() > 4 {
if uint16(crc32.ChecksumIEEE(src.Bytes()[:2])&0xffff) != binary.LittleEndian.Uint16(src.Bytes()[2:4]) {
src.Reset()
return errAuthSHA1V4CRC32Error
} }
func (a *authSHA1V4) Decode(b []byte) ([]byte, int, error) { length := int(binary.BigEndian.Uint16(src.Bytes()[:2]))
a.buffer.Reset() if length >= 8192 || length < 7 {
bSize := len(b) a.rawTrans = true
originalSize := bSize src.Reset()
for bSize > 4 { return errAuthSHA1V4LengthError
crc := crc32.ChecksumIEEE(b[:2]) & 0xFFFF
if binary.LittleEndian.Uint16(b[2:4]) != uint16(crc) {
return nil, 0, errAuthSHA1v4CRC32Error
} }
length := int(binary.BigEndian.Uint16(b[:2])) if length > src.Len() {
if length >= 8192 || length < 8 {
return nil, 0, errAuthSHA1v4DataLengthError
}
if length > bSize {
break break
} }
if adler32.Checksum(b[:length-4]) == binary.LittleEndian.Uint32(b[length-4:]) { if adler32.Checksum(src.Bytes()[:length-4]) != binary.LittleEndian.Uint32(src.Bytes()[length-4:length]) {
pos := int(b[4]) a.rawTrans = true
if pos != 0xFF { src.Reset()
return errAuthSHA1V4Adler32Error
}
pos := int(src.Bytes()[4])
if pos < 255 {
pos += 4 pos += 4
} else { } else {
pos = int(binary.BigEndian.Uint16(b[5:5+2])) + 4 pos = int(binary.BigEndian.Uint16(src.Bytes()[5:7])) + 4
} }
retSize := length - pos - 4 dst.Write(src.Bytes()[pos : length-4])
a.buffer.Write(b[pos : pos+retSize]) src.Next(length)
bSize -= length
b = b[length:]
} else {
return nil, 0, errAuthSHA1v4IncorrectChecksum
} }
} return nil
return a.buffer.Bytes(), originalSize - bSize, nil
} }
func (a *authSHA1V4) Encode(b []byte) ([]byte, error) { func (a *authSHA1V4) Encode(buf *bytes.Buffer, b []byte) error {
a.buffer.Reset() if !a.hasSentHeader {
bSize := len(b) dataLength := getDataLength(b)
offset := 0
if !a.headerSent && bSize > 0 { a.packAuthData(buf, b[:dataLength])
headSize := getHeadSize(b, 30) b = b[dataLength:]
if headSize > bSize {
headSize = bSize a.hasSentHeader = true
} }
a.buffer.Write(a.packAuthData(b[:headSize])) for len(b) > 8100 {
offset += headSize a.packData(buf, b[:8100])
bSize -= headSize b = b[8100:]
a.headerSent = true
} }
const blockSize = 4096 if len(b) > 0 {
for bSize > blockSize { a.packData(buf, b)
packSize, randSize := a.packedDataSize(b[offset : offset+blockSize])
pack := pool.Get(packSize)
a.packData(b[offset:offset+blockSize], pack, randSize)
a.buffer.Write(pack)
pool.Put(pack)
offset += blockSize
bSize -= blockSize
}
if bSize > 0 {
packSize, randSize := a.packedDataSize(b[offset:])
pack := pool.Get(packSize)
a.packData(b[offset:], pack, randSize)
a.buffer.Write(pack)
pool.Put(pack)
}
return a.buffer.Bytes(), nil
} }
func (a *authSHA1V4) DecodePacket(b []byte) ([]byte, int, error) { return nil
return b, len(b), nil
} }
func (a *authSHA1V4) EncodePacket(b []byte) ([]byte, error) { func (a *authSHA1V4) DecodePacket(b []byte) ([]byte, error) { return b, nil }
return b, nil
func (a *authSHA1V4) EncodePacket(buf *bytes.Buffer, b []byte) error {
buf.Write(b)
return nil
} }
func (a *authSHA1V4) packedDataSize(data []byte) (packSize, randSize int) { func (a *authSHA1V4) packData(poolBuf *bytes.Buffer, data []byte) {
dataSize := len(data) dataLength := len(data)
randSize = 1 randDataLength := a.getRandDataLength(dataLength)
if dataSize <= 1300 { /*
if dataSize > 400 { 2: uint16 BigEndian packedDataLength
randSize += rand.Intn(128) 2: uint16 LittleEndian crc32Data & 0xffff
} else { 3: maxRandDataLengthPrefix (min:1)
randSize += rand.Intn(1024) 4: adler32Data
} */
} packedDataLength := 2 + 2 + 3 + randDataLength + dataLength + 4
packSize = randSize + dataSize + 8 if randDataLength < 128 {
return packedDataLength -= 2
} }
func (a *authSHA1V4) packData(data, ret []byte, randSize int) { binary.Write(poolBuf, binary.BigEndian, uint16(packedDataLength))
dataSize := len(data) binary.Write(poolBuf, binary.LittleEndian, uint16(crc32.ChecksumIEEE(poolBuf.Bytes()[poolBuf.Len()-2:])&0xffff))
retSize := len(ret) a.packRandData(poolBuf, randDataLength)
// 0~1, ret size poolBuf.Write(data)
binary.BigEndian.PutUint16(ret[:2], uint16(retSize&0xFFFF)) binary.Write(poolBuf, binary.LittleEndian, adler32.Checksum(poolBuf.Bytes()[poolBuf.Len()-packedDataLength+4:]))
// 2~3, crc of ret size
crc := crc32.ChecksumIEEE(ret[:2]) & 0xFFFF
binary.LittleEndian.PutUint16(ret[2:4], uint16(crc))
// 4, rand size
if randSize < 128 {
ret[4] = uint8(randSize & 0xFF)
} else {
ret[4] = uint8(0xFF)
binary.BigEndian.PutUint16(ret[5:7], uint16(randSize&0xFFFF))
}
// (rand size+4)~(ret size-4), data
if dataSize > 0 {
copy(ret[randSize+4:], data)
}
// (ret size-4)~end, adler32 of full data
adler := adler32.Checksum(ret[:retSize-4])
binary.LittleEndian.PutUint32(ret[retSize-4:], adler)
} }
func (a *authSHA1V4) packAuthData(data []byte) (ret []byte) { func (a *authSHA1V4) packAuthData(poolBuf *bytes.Buffer, data []byte) {
dataSize := len(data) dataLength := len(data)
randSize := 1 randDataLength := a.getRandDataLength(12 + dataLength)
if dataSize <= 1300 { /*
if dataSize > 400 { 2: uint16 BigEndian packedAuthDataLength
randSize += rand.Intn(128) 4: uint32 LittleEndian crc32Data
} else { 3: maxRandDataLengthPrefix (min: 1)
randSize += rand.Intn(1024) 12: authDataLength
10: hmacSHA1DataLength
*/
packedAuthDataLength := 2 + 4 + 3 + randDataLength + 12 + dataLength + 10
if randDataLength < 128 {
packedAuthDataLength -= 2
} }
}
dataOffset := randSize + 4 + 2
retSize := dataOffset + dataSize + 12 + tools.HmacSHA1Len
ret = make([]byte, retSize)
a.mutex.Lock()
defer a.mutex.Unlock()
a.connectionID++
if a.connectionID > 0xFF000000 {
a.clientID = nil
}
if len(a.clientID) == 0 {
a.clientID = make([]byte, 8)
rand.Read(a.clientID)
b := make([]byte, 4)
rand.Read(b)
a.connectionID = binary.LittleEndian.Uint32(b) & 0xFFFFFF
}
// 0~1, ret size
binary.BigEndian.PutUint16(ret[:2], uint16(retSize&0xFFFF))
// 2~6, crc of (ret size+salt+key)
salt := []byte("auth_sha1_v4") salt := []byte("auth_sha1_v4")
crcData := make([]byte, len(salt)+len(a.Key)+2) crcData := pool.Get(len(salt) + len(a.Key) + 2)
copy(crcData[:2], ret[:2]) defer pool.Put(crcData)
binary.BigEndian.PutUint16(crcData, uint16(packedAuthDataLength))
copy(crcData[2:], salt) copy(crcData[2:], salt)
copy(crcData[2+len(salt):], a.Key) copy(crcData[2+len(salt):], a.Key)
crc := crc32.ChecksumIEEE(crcData) & 0xFFFFFFFF
// 2~6, crc of (ret size+salt+key)
binary.LittleEndian.PutUint32(ret[2:], crc)
// 6~(rand size+6), rand numbers
rand.Read(ret[dataOffset-randSize : dataOffset])
// 6, rand size
if randSize < 128 {
ret[6] = byte(randSize & 0xFF)
} else {
// 6, magic number 0xFF
ret[6] = 0xFF
// 7~8, rand size
binary.BigEndian.PutUint16(ret[7:9], uint16(randSize&0xFFFF))
}
// rand size+6~(rand size+10), time stamp
now := time.Now().Unix()
binary.LittleEndian.PutUint32(ret[dataOffset:dataOffset+4], uint32(now))
// rand size+10~(rand size+14), client ID
copy(ret[dataOffset+4:dataOffset+4+4], a.clientID[:4])
// rand size+14~(rand size+18), connection ID
binary.LittleEndian.PutUint32(ret[dataOffset+8:dataOffset+8+4], a.connectionID)
// rand size+18~(rand size+18)+data length, data
copy(ret[dataOffset+12:], data)
key := make([]byte, len(a.IV)+len(a.Key)) key := pool.Get(len(a.iv) + len(a.Key))
copy(key, a.IV) defer pool.Put(key)
copy(key[len(a.IV):], a.Key) copy(key, a.iv)
copy(key[len(a.iv):], a.Key)
h := tools.HmacSHA1(key, ret[:retSize-tools.HmacSHA1Len]) poolBuf.Write(crcData[:2])
// (ret size-10)~(ret size)/(rand size)+18+data length~end, hmac binary.Write(poolBuf, binary.LittleEndian, crc32.ChecksumIEEE(crcData))
copy(ret[retSize-tools.HmacSHA1Len:], h[:tools.HmacSHA1Len]) a.packRandData(poolBuf, randDataLength)
return ret a.putAuthData(poolBuf)
poolBuf.Write(data)
poolBuf.Write(tools.HmacSHA1(key, poolBuf.Bytes()[poolBuf.Len()-packedAuthDataLength+10:])[:10])
} }
func getHeadSize(data []byte, defaultValue int) int { func (a *authSHA1V4) packRandData(poolBuf *bytes.Buffer, size int) {
if data == nil || len(data) < 2 { if size < 128 {
return defaultValue poolBuf.WriteByte(byte(size + 1))
tools.AppendRandBytes(poolBuf, size)
return
} }
headType := data[0] & 0x07 poolBuf.WriteByte(255)
switch headType { binary.Write(poolBuf, binary.BigEndian, uint16(size+3))
case 1: tools.AppendRandBytes(poolBuf, size)
// IPv4 1+4+2
return 7
case 4:
// IPv6 1+16+2
return 19
case 3:
// domain name, variant length
return 4 + int(data[1])
} }
return defaultValue func (a *authSHA1V4) getRandDataLength(size int) int {
if size > 1200 {
return 0
}
if size > 400 {
return rand.Intn(256)
}
return rand.Intn(512)
} }

View File

@ -1,10 +1,77 @@
package protocol package protocol
// Base information for protocol import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"encoding/binary"
"math/rand"
"sync"
"time"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/go-shadowsocks2/core"
)
type Base struct { type Base struct {
IV []byte
Key []byte Key []byte
TCPMss int
Overhead int Overhead int
Param string Param string
} }
type userData struct {
userKey []byte
userID [4]byte
}
type authData struct {
clientID [4]byte
connectionID uint32
mutex sync.Mutex
}
func (a *authData) next() *authData {
r := &authData{}
a.mutex.Lock()
defer a.mutex.Unlock()
if a.connectionID > 0xff000000 || a.connectionID == 0 {
rand.Read(a.clientID[:])
a.connectionID = rand.Uint32() & 0xffffff
}
a.connectionID++
copy(r.clientID[:], a.clientID[:])
r.connectionID = a.connectionID
return r
}
func (a *authData) putAuthData(buf *bytes.Buffer) {
binary.Write(buf, binary.LittleEndian, uint32(time.Now().Unix()))
buf.Write(a.clientID[:])
binary.Write(buf, binary.LittleEndian, a.connectionID)
}
func (a *authData) putEncryptedData(b *bytes.Buffer, userKey []byte, paddings [2]int, salt string) error {
encrypt := pool.Get(16)
defer pool.Put(encrypt)
binary.LittleEndian.PutUint32(encrypt, uint32(time.Now().Unix()))
copy(encrypt[4:], a.clientID[:])
binary.LittleEndian.PutUint32(encrypt[8:], a.connectionID)
binary.LittleEndian.PutUint16(encrypt[12:], uint16(paddings[0]))
binary.LittleEndian.PutUint16(encrypt[14:], uint16(paddings[1]))
cipherKey := core.Kdf(base64.StdEncoding.EncodeToString(userKey)+salt, 16)
block, err := aes.NewCipher(cipherKey)
if err != nil {
log.Warnln("New cipher error: %s", err.Error())
return err
}
iv := bytes.Repeat([]byte{0}, 16)
cbcCipher := cipher.NewCBCEncrypter(block, iv)
cbcCipher.CryptBlocks(encrypt, encrypt)
b.Write(encrypt)
return nil
}

View File

@ -1,36 +1,33 @@
package protocol package protocol
type origin struct{ *Base } import (
"bytes"
"net"
)
func init() { type origin struct{}
register("origin", newOrigin)
func init() { register("origin", newOrigin, 0) }
func newOrigin(b *Base) Protocol { return &origin{} }
func (o *origin) StreamConn(c net.Conn, iv []byte) net.Conn { return c }
func (o *origin) PacketConn(c net.PacketConn) net.PacketConn { return c }
func (o *origin) Decode(dst, src *bytes.Buffer) error {
dst.ReadFrom(src)
return nil
} }
func newOrigin(b *Base) Protocol { func (o *origin) Encode(buf *bytes.Buffer, b []byte) error {
return &origin{} buf.Write(b)
return nil
} }
func (o *origin) initForConn(iv []byte) Protocol { return &origin{} } func (o *origin) DecodePacket(b []byte) ([]byte, error) { return b, nil }
func (o *origin) GetProtocolOverhead() int { func (o *origin) EncodePacket(buf *bytes.Buffer, b []byte) error {
return 0 buf.Write(b)
} return nil
func (o *origin) SetOverhead(overhead int) {
}
func (o *origin) Decode(b []byte) ([]byte, int, error) {
return b, len(b), nil
}
func (o *origin) Encode(b []byte) ([]byte, error) {
return b, nil
}
func (o *origin) DecodePacket(b []byte) ([]byte, int, error) {
return b, len(b), nil
}
func (o *origin) EncodePacket(b []byte) ([]byte, error) {
return b, nil
} }

View File

@ -1,30 +1,26 @@
package protocol package protocol
import ( import (
"bytes"
"net" "net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/component/ssr/tools"
) )
// NewPacketConn returns a net.NewPacketConn with protocol decoding/encoding
func NewPacketConn(pc net.PacketConn, p Protocol) net.PacketConn {
return &PacketConn{PacketConn: pc, Protocol: p.initForConn(nil)}
}
// PacketConn represents a protocol packet connection
type PacketConn struct { type PacketConn struct {
net.PacketConn net.PacketConn
Protocol Protocol
} }
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
buf := pool.Get(pool.RelayBufferSize) buf := tools.BufPool.Get().(*bytes.Buffer)
defer pool.Put(buf) defer tools.BufPool.Put(buf)
buf, err := c.EncodePacket(b) defer buf.Reset()
err := c.EncodePacket(buf, b)
if err != nil { if err != nil {
return 0, err return 0, err
} }
_, err = c.PacketConn.WriteTo(buf, addr) _, err = c.PacketConn.WriteTo(buf.Bytes(), addr)
return len(b), err return len(b), err
} }
@ -33,10 +29,10 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
if err != nil { if err != nil {
return n, addr, err return n, addr, err
} }
bb, length, err := c.DecodePacket(b[:n]) decoded, err := c.DecodePacket(b[:n])
if err != nil { if err != nil {
return n, addr, err return n, addr, err
} }
copy(b, bb) copy(b, decoded)
return length, addr, err return len(decoded), addr, nil
} }

View File

@ -4,60 +4,73 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"strings" "math/rand"
"sync" "net"
) )
var ( var (
errAuthAES128IncorrectMAC = errors.New("auth_aes128_* post decrypt incorrect mac") errAuthSHA1V4CRC32Error = errors.New("auth_sha1_v4 decode data wrong crc32")
errAuthAES128DataLengthError = errors.New("auth_aes128_* post decrypt length mismatch") errAuthSHA1V4LengthError = errors.New("auth_sha1_v4 decode data wrong length")
errAuthAES128IncorrectChecksum = errors.New("auth_aes128_* post decrypt incorrect checksum") errAuthSHA1V4Adler32Error = errors.New("auth_sha1_v4 decode data wrong adler32")
errAuthAES128PositionTooLarge = errors.New("auth_aes128_* post decrypt position is too large") errAuthAES128MACError = errors.New("auth_aes128 decode data wrong mac")
errAuthSHA1v4CRC32Error = errors.New("auth_sha1_v4 post decrypt data crc32 error") errAuthAES128LengthError = errors.New("auth_aes128 decode data wrong length")
errAuthSHA1v4DataLengthError = errors.New("auth_sha1_v4 post decrypt data length error") errAuthAES128ChksumError = errors.New("auth_aes128 decode data wrong checksum")
errAuthSHA1v4IncorrectChecksum = errors.New("auth_sha1_v4 post decrypt incorrect checksum") errAuthChainLengthError = errors.New("auth_chain decode data wrong length")
errAuthChainDataLengthError = errors.New("auth_chain_* post decrypt length mismatch") errAuthChainChksumError = errors.New("auth_chain decode data wrong checksum")
errAuthChainHMACError = errors.New("auth_chain_* post decrypt hmac error")
) )
type authData struct {
clientID []byte
connectionID uint32
mutex sync.Mutex
}
type recvInfo struct {
recvID uint32
buffer *bytes.Buffer
}
type hmacMethod func(key []byte, data []byte) []byte
type hashDigestMethod func(data []byte) []byte
type rndMethod func(dataSize int, random *shift128PlusContext, lastHash []byte, dataSizeList, dataSizeList2 []int, overhead int) int
// Protocol provides methods for decoding, encoding and iv setting
type Protocol interface { type Protocol interface {
initForConn(iv []byte) Protocol StreamConn(net.Conn, []byte) net.Conn
GetProtocolOverhead() int PacketConn(net.PacketConn) net.PacketConn
SetOverhead(int) Decode(dst, src *bytes.Buffer) error
Decode([]byte) ([]byte, int, error) Encode(buf *bytes.Buffer, b []byte) error
Encode([]byte) ([]byte, error) DecodePacket([]byte) ([]byte, error)
DecodePacket([]byte) ([]byte, int, error) EncodePacket(buf *bytes.Buffer, b []byte) error
EncodePacket([]byte) ([]byte, error)
} }
type protocolCreator func(b *Base) Protocol type protocolCreator func(b *Base) Protocol
var protocolList = make(map[string]protocolCreator) var protocolList = make(map[string]struct {
overhead int
new protocolCreator
})
func register(name string, c protocolCreator) { func register(name string, c protocolCreator, o int) {
protocolList[name] = c protocolList[name] = struct {
overhead int
new protocolCreator
}{overhead: o, new: c}
} }
// PickProtocol returns a protocol of the given name
func PickProtocol(name string, b *Base) (Protocol, error) { func PickProtocol(name string, b *Base) (Protocol, error) {
if protocolCreator, ok := protocolList[strings.ToLower(name)]; ok { if choice, ok := protocolList[name]; ok {
return protocolCreator(b), nil b.Overhead += choice.overhead
return choice.new(b), nil
} }
return nil, fmt.Errorf("Protocol %s not supported", name) return nil, fmt.Errorf("protocol %s not supported", name)
}
func getHeadSize(b []byte, defaultValue int) int {
if len(b) < 2 {
return defaultValue
}
headType := b[0] & 7
switch headType {
case 1:
return 7
case 4:
return 19
case 3:
return 4 + int(b[1])
}
return defaultValue
}
func getDataLength(b []byte) int {
bLength := len(b)
dataLength := getHeadSize(b, 30) + rand.Intn(32)
if bLength < dataLength {
return bLength
}
return dataLength
} }

View File

@ -5,31 +5,21 @@ import (
"net" "net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/ssr/tools"
) )
// NewConn wraps a stream-oriented net.Conn with protocol decoding/encoding
func NewConn(c net.Conn, p Protocol, iv []byte) net.Conn {
return &Conn{Conn: c, Protocol: p.initForConn(iv)}
}
// Conn represents a protocol connection
type Conn struct { type Conn struct {
net.Conn net.Conn
Protocol Protocol
buf []byte decoded bytes.Buffer
offset int
underDecoded bytes.Buffer underDecoded bytes.Buffer
} }
func (c *Conn) Read(b []byte) (int, error) { func (c *Conn) Read(b []byte) (int, error) {
if c.buf != nil { if c.decoded.Len() > 0 {
n := copy(b, c.buf[c.offset:]) return c.decoded.Read(b)
c.offset += n
if c.offset == len(c.buf) {
c.buf = nil
}
return n, nil
} }
buf := pool.Get(pool.RelayBufferSize) buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf) defer pool.Put(buf)
n, err := c.Conn.Read(buf) n, err := c.Conn.Read(buf)
@ -37,32 +27,26 @@ func (c *Conn) Read(b []byte) (int, error) {
return 0, err return 0, err
} }
c.underDecoded.Write(buf[:n]) c.underDecoded.Write(buf[:n])
underDecoded := c.underDecoded.Bytes() err = c.Decode(&c.decoded, &c.underDecoded)
decoded, length, err := c.Decode(underDecoded)
if err != nil { if err != nil {
c.underDecoded.Reset() return 0, err
return 0, nil
}
if length == 0 {
return 0, nil
}
c.underDecoded.Next(length)
n = copy(b, decoded)
if len(decoded) > len(b) {
c.buf = decoded
c.offset = n
} }
n, _ = c.decoded.Read(b)
return n, nil return n, nil
} }
func (c *Conn) Write(b []byte) (int, error) { func (c *Conn) Write(b []byte) (int, error) {
encoded, err := c.Encode(b) bLength := len(b)
buf := tools.BufPool.Get().(*bytes.Buffer)
defer tools.BufPool.Put(buf)
defer buf.Reset()
err := c.Encode(buf, b)
if err != nil { if err != nil {
return 0, err return 0, err
} }
_, err = c.Conn.Write(encoded) _, err = c.Conn.Write(buf.Bytes())
if err != nil { if err != nil {
return 0, err return 0, err
} }
return len(b), nil return bLength, nil
} }

View File

@ -0,0 +1,18 @@
package tools
import (
"bytes"
"math/rand"
"sync"
"github.com/Dreamacro/clash/common/pool"
)
var BufPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }}
func AppendRandBytes(b *bytes.Buffer, length int) {
randBytes := pool.Get(length)
defer pool.Put(randBytes)
rand.Read(randBytes)
b.Write(randBytes)
}

View File

@ -11,13 +11,13 @@ const HmacSHA1Len = 10
func HmacMD5(key, data []byte) []byte { func HmacMD5(key, data []byte) []byte {
hmacMD5 := hmac.New(md5.New, key) hmacMD5 := hmac.New(md5.New, key)
hmacMD5.Write(data) hmacMD5.Write(data)
return hmacMD5.Sum(nil)[:16] return hmacMD5.Sum(nil)
} }
func HmacSHA1(key, data []byte) []byte { func HmacSHA1(key, data []byte) []byte {
hmacSHA1 := hmac.New(sha1.New, key) hmacSHA1 := hmac.New(sha1.New, key)
hmacSHA1.Write(data) hmacSHA1.Write(data)
return hmacSHA1.Sum(nil)[:20] return hmacSHA1.Sum(nil)
} }
func MD5Sum(b []byte) []byte { func MD5Sum(b []byte) []byte {

View File

@ -0,0 +1,57 @@
package tools
import (
"encoding/binary"
"github.com/Dreamacro/clash/common/pool"
)
// XorShift128Plus - a pseudorandom number generator
type XorShift128Plus struct {
s [2]uint64
}
func (r *XorShift128Plus) Next() uint64 {
x := r.s[0]
y := r.s[1]
r.s[0] = y
x ^= x << 23
x ^= y ^ (x >> 17) ^ (y >> 26)
r.s[1] = x
return x + y
}
func (r *XorShift128Plus) InitFromBin(bin []byte) {
var full []byte
if len(bin) < 16 {
full := pool.Get(16)[:0]
defer pool.Put(full)
full = append(full, bin...)
for len(full) < 16 {
full = append(full, 0)
}
} else {
full = bin
}
r.s[0] = binary.LittleEndian.Uint64(full[:8])
r.s[1] = binary.LittleEndian.Uint64(full[8:16])
}
func (r *XorShift128Plus) InitFromBinAndLength(bin []byte, length int) {
var full []byte
if len(bin) < 16 {
full := pool.Get(16)[:0]
defer pool.Put(full)
full = append(full, bin...)
for len(full) < 16 {
full = append(full, 0)
}
}
full = bin
binary.LittleEndian.PutUint16(full, uint16(length))
r.s[0] = binary.LittleEndian.Uint64(full[:8])
r.s[1] = binary.LittleEndian.Uint64(full[8:16])
for i := 0; i < 4; i++ {
r.Next()
}
}

View File

@ -86,7 +86,7 @@ func (r *aeadReader) Read(b []byte) (int, error) {
size := int(binary.BigEndian.Uint16(r.sizeBuf)) size := int(binary.BigEndian.Uint16(r.sizeBuf))
if size > maxSize { if size > maxSize {
return 0, errors.New("Buffer is larger than standard") return 0, errors.New("buffer is larger than standard")
} }
buf := pool.Get(size) buf := pool.Get(size)

View File

@ -47,7 +47,7 @@ func (cr *chunkReader) Read(b []byte) (int, error) {
size := int(binary.BigEndian.Uint16(cr.sizeBuf)) size := int(binary.BigEndian.Uint16(cr.sizeBuf))
if size > maxSize { if size > maxSize {
return 0, errors.New("Buffer is larger than standard") return 0, errors.New("buffer is larger than standard")
} }
if len(b) >= size { if len(b) >= size {

112
component/vmess/h2.go Normal file
View File

@ -0,0 +1,112 @@
package vmess
import (
"io"
"math/rand"
"net"
"net/http"
"net/url"
"golang.org/x/net/http2"
)
type h2Conn struct {
net.Conn
*http2.ClientConn
pwriter *io.PipeWriter
res *http.Response
cfg *H2Config
}
type H2Config struct {
Hosts []string
Path string
}
func (hc *h2Conn) establishConn() error {
preader, pwriter := io.Pipe()
host := hc.cfg.Hosts[rand.Intn(len(hc.cfg.Hosts))]
path := hc.cfg.Path
// TODO: connect use VMess Host instead of H2 Host
req := http.Request{
Method: "PUT",
Host: host,
URL: &url.URL{
Scheme: "https",
Host: host,
Path: path,
},
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
Body: preader,
Header: map[string][]string{
"Accept-Encoding": {"identity"},
},
}
// it will be close at : `func (hc *h2Conn) Close() error`
res, err := hc.ClientConn.RoundTrip(&req)
if err != nil {
return err
}
hc.pwriter = pwriter
hc.res = res
return nil
}
// Read implements net.Conn.Read()
func (hc *h2Conn) Read(b []byte) (int, error) {
if hc.res != nil && !hc.res.Close {
n, err := hc.res.Body.Read(b)
return n, err
}
if err := hc.establishConn(); err != nil {
return 0, err
}
return hc.res.Body.Read(b)
}
// Write implements io.Writer.
func (hc *h2Conn) Write(b []byte) (int, error) {
if hc.pwriter != nil {
return hc.pwriter.Write(b)
}
if err := hc.establishConn(); err != nil {
return 0, err
}
return hc.pwriter.Write(b)
}
func (hc *h2Conn) Close() error {
if err := hc.pwriter.Close(); err != nil {
return err
}
if err := hc.ClientConn.Shutdown(hc.res.Request.Context()); err != nil {
return err
}
if err := hc.Conn.Close(); err != nil {
return err
}
return nil
}
func StreamH2Conn(conn net.Conn, cfg *H2Config) (net.Conn, error) {
transport := &http2.Transport{}
cconn, err := transport.NewClientConn(conn)
if err != nil {
return nil, err
}
return &h2Conn{
Conn: conn,
ClientConn: cconn,
cfg: cfg,
}, nil
}

View File

@ -9,6 +9,7 @@ type TLSConfig struct {
Host string Host string
SkipCertVerify bool SkipCertVerify bool
SessionCache tls.ClientSessionCache SessionCache tls.ClientSessionCache
NextProtos []string
} }
func StreamTLSConn(conn net.Conn, cfg *TLSConfig) (net.Conn, error) { func StreamTLSConn(conn net.Conn, cfg *TLSConfig) (net.Conn, error) {
@ -16,6 +17,7 @@ func StreamTLSConn(conn net.Conn, cfg *TLSConfig) (net.Conn, error) {
ServerName: cfg.Host, ServerName: cfg.Host,
InsecureSkipVerify: cfg.SkipCertVerify, InsecureSkipVerify: cfg.SkipCertVerify,
ClientSessionCache: cfg.SessionCache, ClientSessionCache: cfg.SessionCache,
NextProtos: cfg.NextProtos,
} }
tlsConn := tls.Client(conn, tlsConfig) tlsConn := tls.Client(conn, tlsConfig)

View File

@ -1,12 +1,10 @@
package vmess package vmess
import ( import (
"crypto/tls"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"runtime" "runtime"
"sync"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
) )
@ -37,11 +35,6 @@ var CipherMapping = map[string]byte{
"chacha20-poly1305": SecurityCHACHA20POLY1305, "chacha20-poly1305": SecurityCHACHA20POLY1305,
} }
var (
clientSessionCache tls.ClientSessionCache
once sync.Once
)
// Command types // Command types
const ( const (
CommandTCP byte = 1 CommandTCP byte = 1
@ -106,7 +99,7 @@ func NewClient(config Config) (*Client, error) {
security = SecurityAES128GCM security = SecurityAES128GCM
} }
default: default:
return nil, fmt.Errorf("Unknown security type: %s", config.Security) return nil, fmt.Errorf("unknown security type: %s", config.Security)
} }
return &Client{ return &Client{

View File

@ -73,7 +73,7 @@ func (wsc *websocketConn) Close() error {
errors = append(errors, err.Error()) errors = append(errors, err.Error())
} }
if len(errors) > 0 { if len(errors) > 0 {
return fmt.Errorf("Failed to close connection: %s", strings.Join(errors, ",")) return fmt.Errorf("failed to close connection: %s", strings.Join(errors, ","))
} }
return nil return nil
} }
@ -159,7 +159,7 @@ func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
if resp != nil { if resp != nil {
reason = resp.Status reason = resp.Status
} }
return nil, fmt.Errorf("Dial %s error: %s", uri.Host, reason) return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason)
} }
return &websocketConn{ return &websocketConn{

View File

@ -30,7 +30,7 @@ type General struct {
Mode T.TunnelMode `json:"mode"` Mode T.TunnelMode `json:"mode"`
LogLevel log.LogLevel `json:"log-level"` LogLevel log.LogLevel `json:"log-level"`
IPv6 bool `json:"ipv6"` IPv6 bool `json:"ipv6"`
Interface string `json:"interface-name"` Interface string `json:"-"`
} }
// Inbound // Inbound
@ -38,6 +38,7 @@ type Inbound struct {
Port int `json:"port"` Port int `json:"port"`
SocksPort int `json:"socks-port"` SocksPort int `json:"socks-port"`
RedirPort int `json:"redir-port"` RedirPort int `json:"redir-port"`
TProxyPort int `json:"tproxy-port"`
MixedPort int `json:"mixed-port"` MixedPort int `json:"mixed-port"`
Authentication []string `json:"authentication"` Authentication []string `json:"authentication"`
AllowLan bool `json:"allow-lan"` AllowLan bool `json:"allow-lan"`
@ -69,6 +70,12 @@ type DNS struct {
type FallbackFilter struct { type FallbackFilter struct {
GeoIP bool `yaml:"geoip"` GeoIP bool `yaml:"geoip"`
IPCIDR []*net.IPNet `yaml:"ipcidr"` IPCIDR []*net.IPNet `yaml:"ipcidr"`
Domain []string `yaml:"domain"`
}
// Profile config
type Profile struct {
StoreSelected bool `yaml:"store-selected"`
} }
// Experimental config // Experimental config
@ -80,6 +87,7 @@ type Config struct {
DNS *DNS DNS *DNS
Experimental *Experimental Experimental *Experimental
Hosts *trie.DomainTrie Hosts *trie.DomainTrie
Profile *Profile
Rules []C.Rule Rules []C.Rule
Users []auth.AuthUser Users []auth.AuthUser
Proxies map[string]C.Proxy Proxies map[string]C.Proxy
@ -103,12 +111,14 @@ type RawDNS struct {
type RawFallbackFilter struct { type RawFallbackFilter struct {
GeoIP bool `yaml:"geoip"` GeoIP bool `yaml:"geoip"`
IPCIDR []string `yaml:"ipcidr"` IPCIDR []string `yaml:"ipcidr"`
Domain []string `yaml:"domain"`
} }
type RawConfig struct { type RawConfig struct {
Port int `yaml:"port"` Port int `yaml:"port"`
SocksPort int `yaml:"socks-port"` SocksPort int `yaml:"socks-port"`
RedirPort int `yaml:"redir-port"` RedirPort int `yaml:"redir-port"`
TProxyPort int `yaml:"tproxy-port"`
MixedPort int `yaml:"mixed-port"` MixedPort int `yaml:"mixed-port"`
Authentication []string `yaml:"authentication"` Authentication []string `yaml:"authentication"`
AllowLan bool `yaml:"allow-lan"` AllowLan bool `yaml:"allow-lan"`
@ -125,6 +135,7 @@ type RawConfig struct {
Hosts map[string]string `yaml:"hosts"` Hosts map[string]string `yaml:"hosts"`
DNS RawDNS `yaml:"dns"` DNS RawDNS `yaml:"dns"`
Experimental Experimental `yaml:"experimental"` Experimental Experimental `yaml:"experimental"`
Profile Profile `yaml:"profile"`
Proxy []map[string]interface{} `yaml:"proxies"` Proxy []map[string]interface{} `yaml:"proxies"`
ProxyGroup []map[string]interface{} `yaml:"proxy-groups"` ProxyGroup []map[string]interface{} `yaml:"proxy-groups"`
Rule []string `yaml:"rules"` Rule []string `yaml:"rules"`
@ -141,7 +152,7 @@ func Parse(buf []byte) (*Config, error) {
} }
func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
// config with some default value // config with default value
rawCfg := &RawConfig{ rawCfg := &RawConfig{
AllowLan: false, AllowLan: false,
BindAddress: "*", BindAddress: "*",
@ -165,6 +176,9 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
"8.8.8.8", "8.8.8.8",
}, },
}, },
Profile: Profile{
StoreSelected: true,
},
} }
if err := yaml.Unmarshal(buf, &rawCfg); err != nil { if err := yaml.Unmarshal(buf, &rawCfg); err != nil {
@ -178,6 +192,7 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) {
config := &Config{} config := &Config{}
config.Experimental = &rawCfg.Experimental config.Experimental = &rawCfg.Experimental
config.Profile = &rawCfg.Profile
general, err := parseGeneral(rawCfg) general, err := parseGeneral(rawCfg)
if err != nil { if err != nil {
@ -232,6 +247,7 @@ func parseGeneral(cfg *RawConfig) (*General, error) {
Port: cfg.Port, Port: cfg.Port,
SocksPort: cfg.SocksPort, SocksPort: cfg.SocksPort,
RedirPort: cfg.RedirPort, RedirPort: cfg.RedirPort,
TProxyPort: cfg.TProxyPort,
MixedPort: cfg.MixedPort, MixedPort: cfg.MixedPort,
AllowLan: cfg.AllowLan, AllowLan: cfg.AllowLan,
BindAddress: cfg.BindAddress, BindAddress: cfg.BindAddress,
@ -264,11 +280,11 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[
for idx, mapping := range proxiesConfig { for idx, mapping := range proxiesConfig {
proxy, err := outbound.ParseProxy(mapping) proxy, err := outbound.ParseProxy(mapping)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("Proxy %d: %w", idx, err) return nil, nil, fmt.Errorf("proxy %d: %w", idx, err)
} }
if _, exist := proxies[proxy.Name()]; exist { if _, exist := proxies[proxy.Name()]; exist {
return nil, nil, fmt.Errorf("Proxy %s is the duplicate name", proxy.Name()) return nil, nil, fmt.Errorf("proxy %s is the duplicate name", proxy.Name())
} }
proxies[proxy.Name()] = proxy proxies[proxy.Name()] = proxy
proxyList = append(proxyList, proxy.Name()) proxyList = append(proxyList, proxy.Name())
@ -278,7 +294,7 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[
for idx, mapping := range groupsConfig { for idx, mapping := range groupsConfig {
groupName, existName := mapping["name"].(string) groupName, existName := mapping["name"].(string)
if !existName { if !existName {
return nil, nil, fmt.Errorf("ProxyGroup %d: missing name", idx) return nil, nil, fmt.Errorf("proxy group %d: missing name", idx)
} }
proxyList = append(proxyList, groupName) proxyList = append(proxyList, groupName)
} }
@ -313,12 +329,12 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[
for idx, mapping := range groupsConfig { for idx, mapping := range groupsConfig {
group, err := outboundgroup.ParseProxyGroup(mapping, proxies, providersMap) group, err := outboundgroup.ParseProxyGroup(mapping, proxies, providersMap)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("ProxyGroup[%d]: %w", idx, err) return nil, nil, fmt.Errorf("proxy group[%d]: %w", idx, err)
} }
groupName := group.Name() groupName := group.Name()
if _, exist := proxies[groupName]; exist { if _, exist := proxies[groupName]; exist {
return nil, nil, fmt.Errorf("ProxyGroup %s: the duplicate name", groupName) return nil, nil, fmt.Errorf("proxy group %s: the duplicate name", groupName)
} }
proxies[groupName] = outbound.NewProxy(group) proxies[groupName] = outbound.NewProxy(group)
@ -340,11 +356,16 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[
for _, v := range proxyList { for _, v := range proxyList {
ps = append(ps, proxies[v]) ps = append(ps, proxies[v])
} }
hc := provider.NewHealthCheck(ps, "", 0) hc := provider.NewHealthCheck(ps, "", 0, true)
pd, _ := provider.NewCompatibleProvider(provider.ReservedName, ps, hc) pd, _ := provider.NewCompatibleProvider(provider.ReservedName, ps, hc)
providersMap[provider.ReservedName] = pd providersMap[provider.ReservedName] = pd
global := outboundgroup.NewSelector("GLOBAL", []provider.ProxyProvider{pd}) global := outboundgroup.NewSelector(
&outboundgroup.GroupCommonOption{
Name: "GLOBAL",
},
[]provider.ProxyProvider{pd},
)
proxies["GLOBAL"] = outbound.NewProxy(global) proxies["GLOBAL"] = outbound.NewProxy(global)
return proxies, providersMap, nil return proxies, providersMap, nil
} }
@ -373,11 +394,11 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) {
target = rule[2] target = rule[2]
params = rule[3:] params = rule[3:]
default: default:
return nil, fmt.Errorf("Rules[%d] [%s] error: format invalid", idx, line) return nil, fmt.Errorf("rules[%d] [%s] error: format invalid", idx, line)
} }
if _, ok := proxies[target]; !ok { if _, ok := proxies[target]; !ok {
return nil, fmt.Errorf("Rules[%d] [%s] error: proxy [%s] not found", idx, line, target) return nil, fmt.Errorf("rules[%d] [%s] error: proxy [%s] not found", idx, line, target)
} }
rule = trimArr(rule) rule = trimArr(rule)
@ -385,11 +406,7 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) {
parsed, parseErr := R.ParseRule(rule[0], payload, target, params) parsed, parseErr := R.ParseRule(rule[0], payload, target, params)
if parseErr != nil { if parseErr != nil {
if parseErr == R.ErrPlatformNotSupport { return nil, fmt.Errorf("rules[%d] [%s] error: %s", idx, line, parseErr.Error())
log.Warnln("Rules[%d] [%s] don't support current OS, skip", idx, line)
continue
}
return nil, fmt.Errorf("Rules[%d] [%s] error: %s", idx, line, parseErr.Error())
} }
rules = append(rules, parsed) rules = append(rules, parsed)
@ -403,7 +420,7 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie, error) {
// add default hosts // add default hosts
if err := tree.Insert("localhost", net.IP{127, 0, 0, 1}); err != nil { if err := tree.Insert("localhost", net.IP{127, 0, 0, 1}); err != nil {
println(err.Error()) log.Errorln("insert localhost to host error: %s", err.Error())
} }
if len(cfg.Hosts) != 0 { if len(cfg.Hosts) != 0 {
@ -499,7 +516,7 @@ func parseFallbackIPCIDR(ips []string) ([]*net.IPNet, error) {
func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) { func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) {
if cfg.Enable && len(cfg.NameServer) == 0 { if cfg.Enable && len(cfg.NameServer) == 0 {
return nil, fmt.Errorf("If DNS configuration is turned on, NameServer cannot be empty") return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty")
} }
dnsCfg := &DNS{ dnsCfg := &DNS{
@ -561,6 +578,7 @@ func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) {
if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil { if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil {
dnsCfg.FallbackFilter.IPCIDR = fallbackip dnsCfg.FallbackFilter.IPCIDR = fallbackip
} }
dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain
if cfg.UseHosts { if cfg.UseHosts {
dnsCfg.Hosts = hosts dnsCfg.Hosts = hosts

View File

@ -12,7 +12,7 @@ import (
) )
func downloadMMDB(path string) (err error) { func downloadMMDB(path string) (err error) {
resp, err := http.Get("https://github.com/Dreamacro/maxmind-geoip/releases/latest/download/Country.mmdb") resp, err := http.Get("https://cdn.jsdelivr.net/gh/Dreamacro/maxmind-geoip@release/Country.mmdb")
if err != nil { if err != nil {
return return
} }
@ -32,18 +32,18 @@ func initMMDB() error {
if _, err := os.Stat(C.Path.MMDB()); os.IsNotExist(err) { if _, err := os.Stat(C.Path.MMDB()); os.IsNotExist(err) {
log.Infoln("Can't find MMDB, start download") log.Infoln("Can't find MMDB, start download")
if err := downloadMMDB(C.Path.MMDB()); err != nil { if err := downloadMMDB(C.Path.MMDB()); err != nil {
return fmt.Errorf("Can't download MMDB: %s", err.Error()) return fmt.Errorf("can't download MMDB: %s", err.Error())
} }
} }
if !mmdb.Verify() { if !mmdb.Verify() {
log.Warnln("MMDB invalid, remove and download") log.Warnln("MMDB invalid, remove and download")
if err := os.Remove(C.Path.MMDB()); err != nil { if err := os.Remove(C.Path.MMDB()); err != nil {
return fmt.Errorf("Can't remove invalid MMDB: %s", err.Error()) return fmt.Errorf("can't remove invalid MMDB: %s", err.Error())
} }
if err := downloadMMDB(C.Path.MMDB()); err != nil { if err := downloadMMDB(C.Path.MMDB()); err != nil {
return fmt.Errorf("Can't download MMDB: %s", err.Error()) return fmt.Errorf("can't download MMDB: %s", err.Error())
} }
} }
@ -55,7 +55,7 @@ func Init(dir string) error {
// initial homedir // initial homedir
if _, err := os.Stat(dir); os.IsNotExist(err) { if _, err := os.Stat(dir); os.IsNotExist(err) {
if err := os.MkdirAll(dir, 0777); err != nil { if err := os.MkdirAll(dir, 0777); err != nil {
return fmt.Errorf("Can't create config directory %s: %s", dir, err.Error()) return fmt.Errorf("can't create config directory %s: %s", dir, err.Error())
} }
} }
@ -64,7 +64,7 @@ func Init(dir string) error {
log.Infoln("Can't find config, create a initial config file") log.Infoln("Can't find config, create a initial config file")
f, err := os.OpenFile(C.Path.Config(), os.O_CREATE|os.O_WRONLY, 0644) f, err := os.OpenFile(C.Path.Config(), os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return fmt.Errorf("Can't create file %s: %s", C.Path.Config(), err.Error()) return fmt.Errorf("can't create file %s: %s", C.Path.Config(), err.Error())
} }
f.Write([]byte(`port: 7890`)) f.Write([]byte(`port: 7890`))
f.Close() f.Close()
@ -72,7 +72,7 @@ func Init(dir string) error {
// initial mmdb // initial mmdb
if err := initMMDB(); err != nil { if err := initMMDB(); err != nil {
return fmt.Errorf("Can't initial MMDB: %w", err) return fmt.Errorf("can't initial MMDB: %w", err)
} }
return nil return nil
} }

View File

@ -15,15 +15,6 @@ func trimArr(arr []string) (r []string) {
return return
} }
func or(pointers ...*int) *int {
for _, p := range pointers {
if p != nil {
return p
}
}
return pointers[len(pointers)-1]
}
// Check if ProxyGroups form DAG(Directed Acyclic Graph), and sort all ProxyGroups by dependency order. // Check if ProxyGroups form DAG(Directed Acyclic Graph), and sort all ProxyGroups by dependency order.
// Meanwhile, record the original index in the config file. // Meanwhile, record the original index in the config file.
// If loop is detected, return an error with location of loop. // If loop is detected, return an error with location of loop.
@ -32,7 +23,7 @@ func proxyGroupsDagSort(groupsConfig []map[string]interface{}) error {
indegree int indegree int
// topological order // topological order
topo int topo int
// the origional data in `groupsConfig` // the original data in `groupsConfig`
data map[string]interface{} data map[string]interface{}
// `outdegree` and `from` are used in loop locating // `outdegree` and `from` are used in loop locating
outdegree int outdegree int
@ -74,7 +65,7 @@ func proxyGroupsDagSort(groupsConfig []map[string]interface{}) error {
index := 0 index := 0
queue := make([]string, 0) queue := make([]string, 0)
for name, node := range graph { for name, node := range graph {
// in the begning, put nodes that have `node.indegree == 0` into queue. // in the beginning, put nodes that have `node.indegree == 0` into queue.
if node.indegree == 0 { if node.indegree == 0 {
queue = append(queue, name) queue = append(queue, name)
} }
@ -153,5 +144,5 @@ func proxyGroupsDagSort(groupsConfig []map[string]interface{}) error {
loopElements = append(loopElements, name) loopElements = append(loopElements, name)
delete(graph, name) delete(graph, name)
} }
return fmt.Errorf("Loop is detected in ProxyGroup, please check following ProxyGroups: %v", loopElements) return fmt.Errorf("loop is detected in ProxyGroup, please check following ProxyGroups: %v", loopElements)
} }

View File

@ -27,11 +27,6 @@ const (
LoadBalance LoadBalance
) )
type ServerAdapter interface {
net.Conn
Metadata() *Metadata
}
type Connection interface { type Connection interface {
Chains() Chain Chains() Chain
AppendToChains(adapter ProxyAdapter) AppendToChains(adapter ProxyAdapter)
@ -50,6 +45,15 @@ func (c Chain) String() string {
} }
} }
func (c Chain) Last() string {
switch len(c) {
case 0:
return ""
default:
return c[0]
}
}
type Conn interface { type Conn interface {
net.Conn net.Conn
Connection Connection
@ -137,7 +141,7 @@ type UDPPacket interface {
// WriteBack writes the payload with source IP/Port equals addr // WriteBack writes the payload with source IP/Port equals addr
// - variable source IP/Port is important to STUN // - variable source IP/Port is important to STUN
// - if addr is not provided, WriteBack will wirte out UDP packet with SourceIP/Prot equals to origional Target, // - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target,
// this is important when using Fake-IP. // this is important when using Fake-IP.
WriteBack(b []byte, addr net.Addr) (n int, err error) WriteBack(b []byte, addr net.Addr) (n int, err error)

23
constant/context.go Normal file
View File

@ -0,0 +1,23 @@
package constant
import (
"net"
"github.com/gofrs/uuid"
)
type PlainContext interface {
ID() uuid.UUID
}
type ConnContext interface {
PlainContext
Metadata() *Metadata
Conn() net.Conn
}
type PacketConnContext interface {
PlainContext
Metadata() *Metadata
PacketConn() net.PacketConn
}

View File

@ -19,6 +19,7 @@ const (
HTTPCONNECT HTTPCONNECT
SOCKS SOCKS
REDIR REDIR
TPROXY
) )
type NetWork int type NetWork int
@ -46,6 +47,8 @@ func (t Type) String() string {
return "Socks5" return "Socks5"
case REDIR: case REDIR:
return "Redir" return "Redir"
case TPROXY:
return "TProxy"
default: default:
return "Unknown" return "Unknown"
} }

View File

@ -56,3 +56,7 @@ func (p *path) Resolve(path string) string {
func (p *path) MMDB() string { func (p *path) MMDB() string {
return P.Join(p.homeDir, "Country.mmdb") return P.Join(p.homeDir, "Country.mmdb")
} }
func (p *path) Cache() string {
return P.Join(p.homeDir, ".cache")
}

39
context/conn.go Normal file
View File

@ -0,0 +1,39 @@
package context
import (
"net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type ConnContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
}
func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
id, _ := uuid.NewV4()
return &ConnContext{
id: id,
metadata: metadata,
conn: conn,
}
}
// ID implement C.ConnContext ID
func (c *ConnContext) ID() uuid.UUID {
return c.id
}
// Metadata implement C.ConnContext Metadata
func (c *ConnContext) Metadata() *C.Metadata {
return c.metadata
}
// Conn implement C.ConnContext Conn
func (c *ConnContext) Conn() net.Conn {
return c.conn
}

41
context/dns.go Normal file
View File

@ -0,0 +1,41 @@
package context
import (
"github.com/gofrs/uuid"
"github.com/miekg/dns"
)
const (
DNSTypeHost = "host"
DNSTypeFakeIP = "fakeip"
DNSTypeRaw = "raw"
)
type DNSContext struct {
id uuid.UUID
msg *dns.Msg
tp string
}
func NewDNSContext(msg *dns.Msg) *DNSContext {
id, _ := uuid.NewV4()
return &DNSContext{
id: id,
msg: msg,
}
}
// ID implement C.PlainContext ID
func (c *DNSContext) ID() uuid.UUID {
return c.id
}
// SetType set type of response
func (c *DNSContext) SetType(tp string) {
c.tp = tp
}
// Type return type of response
func (c *DNSContext) Type() string {
return c.tp
}

47
context/http.go Normal file
View File

@ -0,0 +1,47 @@
package context
import (
"net"
"net/http"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type HTTPContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
req *http.Request
}
func NewHTTPContext(conn net.Conn, req *http.Request, metadata *C.Metadata) *HTTPContext {
id, _ := uuid.NewV4()
return &HTTPContext{
id: id,
metadata: metadata,
conn: conn,
req: req,
}
}
// ID implement C.ConnContext ID
func (hc *HTTPContext) ID() uuid.UUID {
return hc.id
}
// Metadata implement C.ConnContext Metadata
func (hc *HTTPContext) Metadata() *C.Metadata {
return hc.metadata
}
// Conn implement C.ConnContext Conn
func (hc *HTTPContext) Conn() net.Conn {
return hc.conn
}
// Request return the http request struct
func (hc *HTTPContext) Request() *http.Request {
return hc.req
}

43
context/packetconn.go Normal file
View File

@ -0,0 +1,43 @@
package context
import (
"net"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
)
type PacketConnContext struct {
id uuid.UUID
metadata *C.Metadata
packetConn net.PacketConn
}
func NewPacketConnContext(metadata *C.Metadata) *PacketConnContext {
id, _ := uuid.NewV4()
return &PacketConnContext{
id: id,
metadata: metadata,
}
}
// ID implement C.PacketConnContext ID
func (pc *PacketConnContext) ID() uuid.UUID {
return pc.id
}
// Metadata implement C.PacketConnContext Metadata
func (pc *PacketConnContext) Metadata() *C.Metadata {
return pc.metadata
}
// PacketConn implement C.PacketConnContext PacketConn
func (pc *PacketConnContext) PacketConn() net.PacketConn {
return pc.packetConn
}
// InjectPacketConn injectPacketConn manually
func (pc *PacketConnContext) InjectPacketConn(pconn C.PacketConn) {
pc.packetConn = pconn
}

View File

@ -39,7 +39,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err
return nil, err return nil, err
} }
if dialer.DialHook != nil { if ip != nil && ip.IsGlobalUnicast() && dialer.DialHook != nil {
network := "udp" network := "udp"
if strings.HasPrefix(c.Client.Net, "tcp") { if strings.HasPrefix(c.Client.Net, "tcp") {
network = "tcp" network = "tcp"

Some files were not shown because too many files have changed in this diff Show More