chore: embed hysteria, clean irrelevant codes, code from https://github.com/HyNetwork/hysteria
This commit is contained in:
100
transport/hysteria/acl/engine.go
Normal file
100
transport/hysteria/acl/engine.go
Normal file
@ -0,0 +1,100 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"github.com/Dreamacro/clash/transport/hysteria/utils"
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"net"
|
||||
)
|
||||
|
||||
const entryCacheSize = 1024
|
||||
|
||||
type Engine struct {
|
||||
DefaultAction Action
|
||||
Entries []Entry
|
||||
Cache *lru.ARCCache
|
||||
ResolveIPAddr func(string) (*net.IPAddr, error)
|
||||
GeoIPReader *geoip2.Reader
|
||||
}
|
||||
|
||||
type cacheKey struct {
|
||||
Host string
|
||||
Port uint16
|
||||
IsUDP bool
|
||||
}
|
||||
|
||||
type cacheValue struct {
|
||||
Action Action
|
||||
Arg string
|
||||
}
|
||||
|
||||
// action, arg, isDomain, resolvedIP, error
|
||||
func (e *Engine) ResolveAndMatch(host string, port uint16, isUDP bool) (Action, string, bool, *net.IPAddr, error) {
|
||||
ip, zone := utils.ParseIPZone(host)
|
||||
if ip == nil {
|
||||
// Domain
|
||||
ipAddr, err := e.ResolveIPAddr(host)
|
||||
if v, ok := e.Cache.Get(cacheKey{host, port, isUDP}); ok {
|
||||
// Cache hit
|
||||
ce := v.(cacheValue)
|
||||
return ce.Action, ce.Arg, true, ipAddr, err
|
||||
}
|
||||
for _, entry := range e.Entries {
|
||||
mReq := MatchRequest{
|
||||
Domain: host,
|
||||
Port: port,
|
||||
DB: e.GeoIPReader,
|
||||
}
|
||||
if ipAddr != nil {
|
||||
mReq.IP = ipAddr.IP
|
||||
}
|
||||
if isUDP {
|
||||
mReq.Protocol = ProtocolUDP
|
||||
} else {
|
||||
mReq.Protocol = ProtocolTCP
|
||||
}
|
||||
if entry.Match(mReq) {
|
||||
e.Cache.Add(cacheKey{host, port, isUDP},
|
||||
cacheValue{entry.Action, entry.ActionArg})
|
||||
return entry.Action, entry.ActionArg, true, ipAddr, err
|
||||
}
|
||||
}
|
||||
e.Cache.Add(cacheKey{host, port, isUDP}, cacheValue{e.DefaultAction, ""})
|
||||
return e.DefaultAction, "", true, ipAddr, err
|
||||
} else {
|
||||
// IP
|
||||
if v, ok := e.Cache.Get(cacheKey{ip.String(), port, isUDP}); ok {
|
||||
// Cache hit
|
||||
ce := v.(cacheValue)
|
||||
return ce.Action, ce.Arg, false, &net.IPAddr{
|
||||
IP: ip,
|
||||
Zone: zone,
|
||||
}, nil
|
||||
}
|
||||
for _, entry := range e.Entries {
|
||||
mReq := MatchRequest{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
DB: e.GeoIPReader,
|
||||
}
|
||||
if isUDP {
|
||||
mReq.Protocol = ProtocolUDP
|
||||
} else {
|
||||
mReq.Protocol = ProtocolTCP
|
||||
}
|
||||
if entry.Match(mReq) {
|
||||
e.Cache.Add(cacheKey{ip.String(), port, isUDP},
|
||||
cacheValue{entry.Action, entry.ActionArg})
|
||||
return entry.Action, entry.ActionArg, false, &net.IPAddr{
|
||||
IP: ip,
|
||||
Zone: zone,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
e.Cache.Add(cacheKey{ip.String(), port, isUDP}, cacheValue{e.DefaultAction, ""})
|
||||
return e.DefaultAction, "", false, &net.IPAddr{
|
||||
IP: ip,
|
||||
Zone: zone,
|
||||
}, nil
|
||||
}
|
||||
}
|
154
transport/hysteria/acl/engine_test.go
Normal file
154
transport/hysteria/acl/engine_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEngine_ResolveAndMatch(t *testing.T) {
|
||||
cache, _ := lru.NewARC(16)
|
||||
e := &Engine{
|
||||
DefaultAction: ActionDirect,
|
||||
Entries: []Entry{
|
||||
{
|
||||
Action: ActionProxy,
|
||||
ActionArg: "",
|
||||
Matcher: &domainMatcher{
|
||||
matcherBase: matcherBase{
|
||||
Protocol: ProtocolTCP,
|
||||
Port: 443,
|
||||
},
|
||||
Domain: "google.com",
|
||||
Suffix: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: ActionHijack,
|
||||
ActionArg: "good.org",
|
||||
Matcher: &domainMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Domain: "evil.corp",
|
||||
Suffix: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: ActionProxy,
|
||||
ActionArg: "",
|
||||
Matcher: &netMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Net: &net.IPNet{
|
||||
IP: net.ParseIP("10.0.0.0"),
|
||||
Mask: net.CIDRMask(8, 32),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: ActionBlock,
|
||||
ActionArg: "",
|
||||
Matcher: &allMatcher{},
|
||||
},
|
||||
},
|
||||
Cache: cache,
|
||||
ResolveIPAddr: func(s string) (*net.IPAddr, error) {
|
||||
if strings.Contains(s, "evil.corp") {
|
||||
return nil, errors.New("resolve error")
|
||||
}
|
||||
return net.ResolveIPAddr("ip", s)
|
||||
},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
port uint16
|
||||
isUDP bool
|
||||
wantAction Action
|
||||
wantArg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "domain proxy",
|
||||
host: "google.com",
|
||||
port: 443,
|
||||
isUDP: false,
|
||||
wantAction: ActionProxy,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "domain block",
|
||||
host: "google.com",
|
||||
port: 80,
|
||||
isUDP: false,
|
||||
wantAction: ActionBlock,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "domain suffix 1",
|
||||
host: "evil.corp",
|
||||
port: 8899,
|
||||
isUDP: true,
|
||||
wantAction: ActionHijack,
|
||||
wantArg: "good.org",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "domain suffix 2",
|
||||
host: "notevil.corp",
|
||||
port: 22,
|
||||
isUDP: false,
|
||||
wantAction: ActionBlock,
|
||||
wantArg: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "domain suffix 3",
|
||||
host: "im.real.evil.corp",
|
||||
port: 443,
|
||||
isUDP: true,
|
||||
wantAction: ActionHijack,
|
||||
wantArg: "good.org",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ip match",
|
||||
host: "10.2.3.4",
|
||||
port: 80,
|
||||
isUDP: false,
|
||||
wantAction: ActionProxy,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "ip mismatch",
|
||||
host: "100.5.6.0",
|
||||
port: 1234,
|
||||
isUDP: false,
|
||||
wantAction: ActionBlock,
|
||||
wantArg: "",
|
||||
},
|
||||
{
|
||||
name: "domain proxy cache",
|
||||
host: "google.com",
|
||||
port: 443,
|
||||
isUDP: false,
|
||||
wantAction: ActionProxy,
|
||||
wantArg: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotAction, gotArg, _, _, err := e.ResolveAndMatch(tt.host, tt.port, tt.isUDP)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotAction != tt.wantAction {
|
||||
t.Errorf("ResolveAndMatch() gotAction = %v, wantAction %v", gotAction, tt.wantAction)
|
||||
}
|
||||
if gotArg != tt.wantArg {
|
||||
t.Errorf("ResolveAndMatch() gotArg = %v, wantAction %v", gotArg, tt.wantArg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
331
transport/hysteria/acl/entry.go
Normal file
331
transport/hysteria/acl/entry.go
Normal file
@ -0,0 +1,331 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Action byte
|
||||
type Protocol byte
|
||||
|
||||
const (
|
||||
ActionDirect = Action(iota)
|
||||
ActionProxy
|
||||
ActionBlock
|
||||
ActionHijack
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolAll = Protocol(iota)
|
||||
ProtocolTCP
|
||||
ProtocolUDP
|
||||
)
|
||||
|
||||
var protocolPortAliases = map[string]string{
|
||||
"echo": "*/7",
|
||||
"ftp-data": "*/20",
|
||||
"ftp": "*/21",
|
||||
"ssh": "*/22",
|
||||
"telnet": "*/23",
|
||||
"domain": "*/53",
|
||||
"dns": "*/53",
|
||||
"http": "*/80",
|
||||
"sftp": "*/115",
|
||||
"ntp": "*/123",
|
||||
"https": "*/443",
|
||||
"quic": "udp/443",
|
||||
"socks": "*/1080",
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
Action Action
|
||||
ActionArg string
|
||||
Matcher Matcher
|
||||
}
|
||||
|
||||
type MatchRequest struct {
|
||||
IP net.IP
|
||||
Domain string
|
||||
|
||||
Protocol Protocol
|
||||
Port uint16
|
||||
|
||||
DB *geoip2.Reader
|
||||
}
|
||||
|
||||
type Matcher interface {
|
||||
Match(MatchRequest) bool
|
||||
}
|
||||
|
||||
type matcherBase struct {
|
||||
Protocol Protocol
|
||||
Port uint16 // 0 for all ports
|
||||
}
|
||||
|
||||
func (m *matcherBase) MatchProtocolPort(p Protocol, port uint16) bool {
|
||||
return (m.Protocol == ProtocolAll || m.Protocol == p) && (m.Port == 0 || m.Port == port)
|
||||
}
|
||||
|
||||
func parseProtocolPort(s string) (Protocol, uint16, error) {
|
||||
if protocolPortAliases[s] != "" {
|
||||
s = protocolPortAliases[s]
|
||||
}
|
||||
if len(s) == 0 || s == "*" {
|
||||
return ProtocolAll, 0, nil
|
||||
}
|
||||
parts := strings.Split(s, "/")
|
||||
if len(parts) != 2 {
|
||||
return ProtocolAll, 0, errors.New("invalid protocol/port syntax")
|
||||
}
|
||||
protocol := ProtocolAll
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
protocol = ProtocolTCP
|
||||
case "udp":
|
||||
protocol = ProtocolUDP
|
||||
case "*":
|
||||
protocol = ProtocolAll
|
||||
default:
|
||||
return ProtocolAll, 0, errors.New("invalid protocol")
|
||||
}
|
||||
if parts[1] == "*" {
|
||||
return protocol, 0, nil
|
||||
}
|
||||
port, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return ProtocolAll, 0, errors.New("invalid port")
|
||||
}
|
||||
return protocol, uint16(port), nil
|
||||
}
|
||||
|
||||
type netMatcher struct {
|
||||
matcherBase
|
||||
Net *net.IPNet
|
||||
}
|
||||
|
||||
func (m *netMatcher) Match(r MatchRequest) bool {
|
||||
if r.IP == nil {
|
||||
return false
|
||||
}
|
||||
return m.Net.Contains(r.IP) && m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
type domainMatcher struct {
|
||||
matcherBase
|
||||
Domain string
|
||||
Suffix bool
|
||||
}
|
||||
|
||||
func (m *domainMatcher) Match(r MatchRequest) bool {
|
||||
if len(r.Domain) == 0 {
|
||||
return false
|
||||
}
|
||||
domain := strings.ToLower(r.Domain)
|
||||
return (m.Domain == domain || (m.Suffix && strings.HasSuffix(domain, "."+m.Domain))) &&
|
||||
m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
type countryMatcher struct {
|
||||
matcherBase
|
||||
Country string // ISO 3166-1 alpha-2 country code, upper case
|
||||
}
|
||||
|
||||
func (m *countryMatcher) Match(r MatchRequest) bool {
|
||||
if r.IP == nil || r.DB == nil {
|
||||
return false
|
||||
}
|
||||
c, err := r.DB.Country(r.IP)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return c.Country.IsoCode == m.Country && m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
type allMatcher struct {
|
||||
matcherBase
|
||||
}
|
||||
|
||||
func (m *allMatcher) Match(r MatchRequest) bool {
|
||||
return m.MatchProtocolPort(r.Protocol, r.Port)
|
||||
}
|
||||
|
||||
func (e Entry) Match(r MatchRequest) bool {
|
||||
return e.Matcher.Match(r)
|
||||
}
|
||||
|
||||
func ParseEntry(s string) (Entry, error) {
|
||||
fields := strings.Fields(s)
|
||||
if len(fields) < 2 {
|
||||
return Entry{}, fmt.Errorf("expected at least 2 fields, got %d", len(fields))
|
||||
}
|
||||
e := Entry{}
|
||||
action := fields[0]
|
||||
conds := fields[1:]
|
||||
switch strings.ToLower(action) {
|
||||
case "direct":
|
||||
e.Action = ActionDirect
|
||||
case "proxy":
|
||||
e.Action = ActionProxy
|
||||
case "block":
|
||||
e.Action = ActionBlock
|
||||
case "hijack":
|
||||
if len(conds) < 2 {
|
||||
return Entry{}, fmt.Errorf("hijack requires at least 3 fields, got %d", len(fields))
|
||||
}
|
||||
e.Action = ActionHijack
|
||||
e.ActionArg = conds[len(conds)-1]
|
||||
conds = conds[:len(conds)-1]
|
||||
default:
|
||||
return Entry{}, fmt.Errorf("invalid action %s", fields[0])
|
||||
}
|
||||
m, err := condsToMatcher(conds)
|
||||
if err != nil {
|
||||
return Entry{}, err
|
||||
}
|
||||
e.Matcher = m
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func condsToMatcher(conds []string) (Matcher, error) {
|
||||
if len(conds) < 1 {
|
||||
return nil, errors.New("no condition specified")
|
||||
}
|
||||
typ, args := conds[0], conds[1:]
|
||||
switch strings.ToLower(typ) {
|
||||
case "domain":
|
||||
// domain <domain> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for domain: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &domainMatcher{
|
||||
matcherBase: mb,
|
||||
Domain: args[0],
|
||||
Suffix: false,
|
||||
}, nil
|
||||
case "domain-suffix":
|
||||
// domain-suffix <domain> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for domain-suffix: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &domainMatcher{
|
||||
matcherBase: mb,
|
||||
Domain: args[0],
|
||||
Suffix: true,
|
||||
}, nil
|
||||
case "cidr":
|
||||
// cidr <cidr> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for cidr: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
_, ipNet, err := net.ParseCIDR(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &netMatcher{
|
||||
matcherBase: mb,
|
||||
Net: ipNet,
|
||||
}, nil
|
||||
case "ip":
|
||||
// ip <ip> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for ip: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
ip := net.ParseIP(args[0])
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid ip: %s", args[0])
|
||||
}
|
||||
var ipNet *net.IPNet
|
||||
if ip.To4() != nil {
|
||||
ipNet = &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(32, 32),
|
||||
}
|
||||
} else {
|
||||
ipNet = &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(128, 128),
|
||||
}
|
||||
}
|
||||
return &netMatcher{
|
||||
matcherBase: mb,
|
||||
Net: ipNet,
|
||||
}, nil
|
||||
case "country":
|
||||
// country <country> <optional: protocol/port>
|
||||
if len(args) == 0 || len(args) > 2 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for country: %d, expected 1 or 2", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 2 {
|
||||
protocol, port, err := parseProtocolPort(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &countryMatcher{
|
||||
matcherBase: mb,
|
||||
Country: strings.ToUpper(args[0]),
|
||||
}, nil
|
||||
case "all":
|
||||
// all <optional: protocol/port>
|
||||
if len(args) > 1 {
|
||||
return nil, fmt.Errorf("invalid number of arguments for all: %d, expected 0 or 1", len(args))
|
||||
}
|
||||
mb := matcherBase{}
|
||||
if len(args) == 1 {
|
||||
protocol, port, err := parseProtocolPort(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mb.Protocol = protocol
|
||||
mb.Port = port
|
||||
}
|
||||
return &allMatcher{
|
||||
matcherBase: mb,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid condition type: %s", typ)
|
||||
}
|
||||
}
|
75
transport/hysteria/acl/entry_test.go
Normal file
75
transport/hysteria/acl/entry_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseEntry(t *testing.T) {
|
||||
_, ok3net, _ := net.ParseCIDR("8.8.8.0/24")
|
||||
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Entry
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "empty", args: args{""}, want: Entry{}, wantErr: true},
|
||||
{name: "ok 1", args: args{"direct domain-suffix google.com"},
|
||||
want: Entry{ActionDirect, "", &domainMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Domain: "google.com",
|
||||
Suffix: true,
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "ok 2", args: args{"proxy domain shithole"},
|
||||
want: Entry{ActionProxy, "", &domainMatcher{
|
||||
matcherBase: matcherBase{},
|
||||
Domain: "shithole",
|
||||
Suffix: false,
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "ok 3", args: args{"block cidr 8.8.8.0/24 */53"},
|
||||
want: Entry{ActionBlock, "", &netMatcher{
|
||||
matcherBase: matcherBase{ProtocolAll, 53},
|
||||
Net: ok3net,
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "ok 4", args: args{"hijack all udp/* udpblackhole.net"},
|
||||
want: Entry{ActionHijack, "udpblackhole.net", &allMatcher{
|
||||
matcherBase: matcherBase{ProtocolUDP, 0},
|
||||
}},
|
||||
wantErr: false},
|
||||
{name: "err 1", args: args{"what the heck"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 2", args: args{"proxy sucks ass"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 3", args: args{"block ip 999.999.999.999"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 4", args: args{"hijack domain google.com"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
{name: "err 5", args: args{"hijack domain google.com bing.com 123"},
|
||||
want: Entry{},
|
||||
wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseEntry(tt.args.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseEntry() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseEntry() got = %v, wantAction %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
145
transport/hysteria/congestion/brutal.go
Normal file
145
transport/hysteria/congestion/brutal.go
Normal file
@ -0,0 +1,145 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
initMaxDatagramSize = 1252
|
||||
|
||||
pktInfoSlotCount = 4
|
||||
minSampleCount = 50
|
||||
minAckRate = 0.8
|
||||
)
|
||||
|
||||
type BrutalSender struct {
|
||||
rttStats congestion.RTTStatsProvider
|
||||
bps congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
pacer *pacer
|
||||
|
||||
pktInfoSlots [pktInfoSlotCount]pktInfo
|
||||
ackRate float64
|
||||
}
|
||||
|
||||
type pktInfo struct {
|
||||
Timestamp int64
|
||||
AckCount uint64
|
||||
LossCount uint64
|
||||
}
|
||||
|
||||
func NewBrutalSender(bps congestion.ByteCount) *BrutalSender {
|
||||
bs := &BrutalSender{
|
||||
bps: bps,
|
||||
maxDatagramSize: initMaxDatagramSize,
|
||||
ackRate: 1,
|
||||
}
|
||||
bs.pacer = newPacer(func() congestion.ByteCount {
|
||||
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
||||
})
|
||||
return bs
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
|
||||
b.rttStats = rttStats
|
||||
}
|
||||
|
||||
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
||||
return b.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) HasPacingBudget() bool {
|
||||
return b.pacer.Budget(time.Now()) >= b.maxDatagramSize
|
||||
}
|
||||
|
||||
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||
return bytesInFlight < b.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
||||
rtt := maxDuration(b.rttStats.LatestRTT(), b.rttStats.SmoothedRTT())
|
||||
if rtt <= 0 {
|
||||
return 10240
|
||||
}
|
||||
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
||||
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool) {
|
||||
b.pacer.SentPacket(sentTime, bytes)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount, eventTime time.Time) {
|
||||
currentTimestamp := eventTime.Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].AckCount++
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 1
|
||||
b.pktInfoSlots[slot].LossCount = 0
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount) {
|
||||
currentTimestamp := time.Now().Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].LossCount++
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 0
|
||||
b.pktInfoSlots[slot].LossCount = 1
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
||||
b.maxDatagramSize = size
|
||||
b.pacer.SetMaxDatagramSize(size)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
||||
minTimestamp := currentTimestamp - pktInfoSlotCount
|
||||
var ackCount, lossCount uint64
|
||||
for _, info := range b.pktInfoSlots {
|
||||
if info.Timestamp < minTimestamp {
|
||||
continue
|
||||
}
|
||||
ackCount += info.AckCount
|
||||
lossCount += info.LossCount
|
||||
}
|
||||
if ackCount+lossCount < minSampleCount {
|
||||
b.ackRate = 1
|
||||
}
|
||||
rate := float64(ackCount) / float64(ackCount+lossCount)
|
||||
if rate < minAckRate {
|
||||
b.ackRate = minAckRate
|
||||
}
|
||||
b.ackRate = rate
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InSlowStart() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InRecovery() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) MaybeExitSlowStart() {}
|
||||
|
||||
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
|
||||
|
||||
func maxDuration(a, b time.Duration) time.Duration {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
85
transport/hysteria/congestion/pacer.go
Normal file
85
transport/hysteria/congestion/pacer.go
Normal file
@ -0,0 +1,85 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBurstPackets = 10
|
||||
minPacingDelay = time.Millisecond
|
||||
)
|
||||
|
||||
// The pacer implements a token bucket pacing algorithm.
|
||||
type pacer struct {
|
||||
budgetAtLastSent congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
lastSentTime time.Time
|
||||
getBandwidth func() congestion.ByteCount // in bytes/s
|
||||
}
|
||||
|
||||
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
|
||||
p := &pacer{
|
||||
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
|
||||
maxDatagramSize: initMaxDatagramSize,
|
||||
getBandwidth: getBandwidth,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size > budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
return minByteCount(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
||||
return maxByteCount(
|
||||
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
|
||||
maxBurstPackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||
func (p *pacer) TimeUntilSend() time.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return time.Time{}
|
||||
}
|
||||
return p.lastSentTime.Add(maxDuration(
|
||||
minPacingDelay,
|
||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
|
||||
float64(p.getBandwidth())))*time.Nanosecond,
|
||||
))
|
||||
}
|
||||
|
||||
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
||||
|
||||
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
1
transport/hysteria/conns/faketcp/LICENSE
Normal file
1
transport/hysteria/conns/faketcp/LICENSE
Normal file
@ -0,0 +1 @@
|
||||
Grabbed from https://github.com/xtaci/tcpraw with modifications
|
102
transport/hysteria/conns/faketcp/obfs.go
Normal file
102
transport/hysteria/conns/faketcp/obfs.go
Normal file
@ -0,0 +1,102 @@
|
||||
package faketcp
|
||||
|
||||
import (
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/obfs"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const udpBufferSize = 65535
|
||||
|
||||
type ObfsFakeTCPConn struct {
|
||||
orig *TCPConn
|
||||
obfs obfs.Obfuscator
|
||||
closed bool
|
||||
readBuf []byte
|
||||
readMutex sync.Mutex
|
||||
writeBuf []byte
|
||||
writeMutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewObfsFakeTCPConn(orig *TCPConn, obfs obfs.Obfuscator) *ObfsFakeTCPConn {
|
||||
return &ObfsFakeTCPConn{
|
||||
orig: orig,
|
||||
obfs: obfs,
|
||||
readBuf: make([]byte, udpBufferSize),
|
||||
writeBuf: make([]byte, udpBufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
for {
|
||||
c.readMutex.Lock()
|
||||
if c.closed {
|
||||
log.Infoln("read faketcp obfs before")
|
||||
}
|
||||
n, addr, err := c.orig.ReadFrom(c.readBuf)
|
||||
if c.closed {
|
||||
log.Infoln("read faketcp obfs after")
|
||||
}
|
||||
if n <= 0 {
|
||||
c.readMutex.Unlock()
|
||||
return 0, addr, err
|
||||
}
|
||||
newN := c.obfs.Deobfuscate(c.readBuf[:n], p)
|
||||
c.readMutex.Unlock()
|
||||
if newN > 0 {
|
||||
// Valid packet
|
||||
return newN, addr, err
|
||||
} else if err != nil {
|
||||
// Not valid and orig.ReadFrom had some error
|
||||
return 0, addr, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
c.writeMutex.Lock()
|
||||
bn := c.obfs.Obfuscate(p, c.writeBuf)
|
||||
_, err = c.orig.WriteTo(c.writeBuf[:bn], addr)
|
||||
c.writeMutex.Unlock()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return len(p), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) Close() error {
|
||||
c.closed = true
|
||||
return c.orig.Close()
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) LocalAddr() net.Addr {
|
||||
return c.orig.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) SetDeadline(t time.Time) error {
|
||||
return c.orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) SetReadDeadline(t time.Time) error {
|
||||
return c.orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.orig.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) SetReadBuffer(bytes int) error {
|
||||
return c.orig.SetReadBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) SetWriteBuffer(bytes int) error {
|
||||
return c.orig.SetWriteBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *ObfsFakeTCPConn) SyscallConn() (syscall.RawConn, error) {
|
||||
return c.orig.SyscallConn()
|
||||
}
|
616
transport/hysteria/conns/faketcp/tcp_linux.go
Normal file
616
transport/hysteria/conns/faketcp/tcp_linux.go
Normal file
@ -0,0 +1,616 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package faketcp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
var (
|
||||
errOpNotImplemented = errors.New("operation not implemented")
|
||||
errTimeout = errors.New("timeout")
|
||||
expire = time.Minute
|
||||
)
|
||||
|
||||
// a message from NIC
|
||||
type message struct {
|
||||
bts []byte
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
// a tcp flow information of a connection pair
|
||||
type tcpFlow struct {
|
||||
conn *net.TCPConn // the related system TCP connection of this flow
|
||||
handle *net.IPConn // the handle to send packets
|
||||
seq uint32 // TCP sequence number
|
||||
ack uint32 // TCP acknowledge number
|
||||
networkLayer gopacket.SerializableLayer // network layer header for tx
|
||||
ts time.Time // last packet incoming time
|
||||
buf gopacket.SerializeBuffer // a buffer for write
|
||||
tcpHeader layers.TCP
|
||||
}
|
||||
|
||||
// TCPConn defines a TCP-packet oriented connection
|
||||
type TCPConn struct {
|
||||
die chan struct{}
|
||||
dieOnce sync.Once
|
||||
|
||||
// the main golang sockets
|
||||
tcpconn *net.TCPConn // from net.Dial
|
||||
listener *net.TCPListener // from net.Listen
|
||||
|
||||
// handles
|
||||
handles []*net.IPConn
|
||||
|
||||
// packets captured from all related NICs will be delivered to this channel
|
||||
chMessage chan message
|
||||
|
||||
// all TCP flows
|
||||
flowTable map[string]*tcpFlow
|
||||
flowsLock sync.Mutex
|
||||
|
||||
// iptables
|
||||
iptables *iptables.IPTables
|
||||
iprule []string
|
||||
|
||||
ip6tables *iptables.IPTables
|
||||
ip6rule []string
|
||||
|
||||
// deadlines
|
||||
readDeadline atomic.Value
|
||||
writeDeadline atomic.Value
|
||||
|
||||
// serialization
|
||||
opts gopacket.SerializeOptions
|
||||
}
|
||||
|
||||
// lockflow locks the flow table and apply function `f` to the entry, and create one if not exist
|
||||
func (conn *TCPConn) lockflow(addr net.Addr, f func(e *tcpFlow)) {
|
||||
key := addr.String()
|
||||
conn.flowsLock.Lock()
|
||||
e := conn.flowTable[key]
|
||||
if e == nil { // entry first visit
|
||||
e = new(tcpFlow)
|
||||
e.ts = time.Now()
|
||||
e.buf = gopacket.NewSerializeBuffer()
|
||||
}
|
||||
f(e)
|
||||
conn.flowTable[key] = e
|
||||
conn.flowsLock.Unlock()
|
||||
}
|
||||
|
||||
// clean expired flows
|
||||
func (conn *TCPConn) cleaner() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
select {
|
||||
case <-conn.die:
|
||||
return
|
||||
case <-ticker.C:
|
||||
conn.flowsLock.Lock()
|
||||
for k, v := range conn.flowTable {
|
||||
if time.Now().Sub(v.ts) > expire {
|
||||
if v.conn != nil {
|
||||
setTTL(v.conn, 64)
|
||||
v.conn.Close()
|
||||
}
|
||||
delete(conn.flowTable, k)
|
||||
}
|
||||
}
|
||||
conn.flowsLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// captureFlow capture every inbound packets based on rules of BPF
|
||||
func (conn *TCPConn) captureFlow(handle *net.IPConn, port int) {
|
||||
buf := make([]byte, 2048)
|
||||
opt := gopacket.DecodeOptions{NoCopy: true, Lazy: true}
|
||||
for {
|
||||
n, addr, err := handle.ReadFromIP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// try decoding TCP frame from buf[:n]
|
||||
packet := gopacket.NewPacket(buf[:n], layers.LayerTypeTCP, opt)
|
||||
transport := packet.TransportLayer()
|
||||
tcp, ok := transport.(*layers.TCP)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// port filtering
|
||||
if int(tcp.DstPort) != port {
|
||||
continue
|
||||
}
|
||||
|
||||
// address building
|
||||
var src net.TCPAddr
|
||||
src.IP = addr.IP
|
||||
src.Port = int(tcp.SrcPort)
|
||||
|
||||
var orphan bool
|
||||
// flow maintaince
|
||||
conn.lockflow(&src, func(e *tcpFlow) {
|
||||
if e.conn == nil { // make sure it's related to net.TCPConn
|
||||
orphan = true // mark as orphan if it's not related net.TCPConn
|
||||
}
|
||||
|
||||
// to keep track of TCP header related to this source
|
||||
e.ts = time.Now()
|
||||
if tcp.ACK {
|
||||
e.seq = tcp.Ack
|
||||
}
|
||||
if tcp.SYN {
|
||||
e.ack = tcp.Seq + 1
|
||||
}
|
||||
if tcp.PSH {
|
||||
if e.ack == tcp.Seq {
|
||||
e.ack = tcp.Seq + uint32(len(tcp.Payload))
|
||||
}
|
||||
}
|
||||
e.handle = handle
|
||||
})
|
||||
|
||||
// push data if it's not orphan
|
||||
if !orphan && tcp.PSH {
|
||||
payload := make([]byte, len(tcp.Payload))
|
||||
copy(payload, tcp.Payload)
|
||||
select {
|
||||
case conn.chMessage <- message{payload, &src}:
|
||||
case <-conn.die:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom implements the PacketConn ReadFrom method.
|
||||
func (conn *TCPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
var timer *time.Timer
|
||||
var deadline <-chan time.Time
|
||||
if d, ok := conn.readDeadline.Load().(time.Time); ok && !d.IsZero() {
|
||||
timer = time.NewTimer(time.Until(d))
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
return 0, nil, errTimeout
|
||||
case <-conn.die:
|
||||
return 0, nil, io.EOF
|
||||
case packet := <-conn.chMessage:
|
||||
n = copy(p, packet.bts)
|
||||
return n, packet.addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo implements the PacketConn WriteTo method.
|
||||
func (conn *TCPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
var deadline <-chan time.Time
|
||||
if d, ok := conn.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
|
||||
timer := time.NewTimer(time.Until(d))
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
return 0, errTimeout
|
||||
case <-conn.die:
|
||||
return 0, io.EOF
|
||||
default:
|
||||
raddr, err := net.ResolveTCPAddr("tcp", addr.String())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var lport int
|
||||
if conn.tcpconn != nil {
|
||||
lport = conn.tcpconn.LocalAddr().(*net.TCPAddr).Port
|
||||
} else {
|
||||
lport = conn.listener.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
conn.lockflow(addr, func(e *tcpFlow) {
|
||||
// if the flow doesn't have handle , assume this packet has lost, without notification
|
||||
if e.handle == nil {
|
||||
n = len(p)
|
||||
return
|
||||
}
|
||||
|
||||
// build tcp header with local and remote port
|
||||
e.tcpHeader.SrcPort = layers.TCPPort(lport)
|
||||
e.tcpHeader.DstPort = layers.TCPPort(raddr.Port)
|
||||
binary.Read(rand.Reader, binary.LittleEndian, &e.tcpHeader.Window)
|
||||
e.tcpHeader.Window |= 0x8000 // make sure it's larger than 32768
|
||||
e.tcpHeader.Ack = e.ack
|
||||
e.tcpHeader.Seq = e.seq
|
||||
e.tcpHeader.PSH = true
|
||||
e.tcpHeader.ACK = true
|
||||
|
||||
// build IP header with src & dst ip for TCP checksum
|
||||
if raddr.IP.To4() != nil {
|
||||
ip := &layers.IPv4{
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To4(),
|
||||
DstIP: raddr.IP.To4(),
|
||||
}
|
||||
e.tcpHeader.SetNetworkLayerForChecksum(ip)
|
||||
} else {
|
||||
ip := &layers.IPv6{
|
||||
NextHeader: layers.IPProtocolTCP,
|
||||
SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To16(),
|
||||
DstIP: raddr.IP.To16(),
|
||||
}
|
||||
e.tcpHeader.SetNetworkLayerForChecksum(ip)
|
||||
}
|
||||
|
||||
e.buf.Clear()
|
||||
gopacket.SerializeLayers(e.buf, conn.opts, &e.tcpHeader, gopacket.Payload(p))
|
||||
if conn.tcpconn != nil {
|
||||
_, err = e.handle.Write(e.buf.Bytes())
|
||||
} else {
|
||||
_, err = e.handle.WriteToIP(e.buf.Bytes(), &net.IPAddr{IP: raddr.IP})
|
||||
}
|
||||
// increase seq in flow
|
||||
e.seq += uint32(len(p))
|
||||
n = len(p)
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (conn *TCPConn) Close() error {
|
||||
var err error
|
||||
conn.dieOnce.Do(func() {
|
||||
// signal closing
|
||||
close(conn.die)
|
||||
|
||||
// close all established tcp connections
|
||||
if conn.tcpconn != nil { // client
|
||||
setTTL(conn.tcpconn, 64)
|
||||
err = conn.tcpconn.Close()
|
||||
} else if conn.listener != nil {
|
||||
err = conn.listener.Close() // server
|
||||
conn.flowsLock.Lock()
|
||||
for k, v := range conn.flowTable {
|
||||
if v.conn != nil {
|
||||
setTTL(v.conn, 64)
|
||||
v.conn.Close()
|
||||
}
|
||||
delete(conn.flowTable, k)
|
||||
}
|
||||
conn.flowsLock.Unlock()
|
||||
}
|
||||
|
||||
// close handles
|
||||
for k := range conn.handles {
|
||||
conn.handles[k].Close()
|
||||
}
|
||||
|
||||
// delete iptable
|
||||
if conn.iptables != nil {
|
||||
conn.iptables.Delete("filter", "OUTPUT", conn.iprule...)
|
||||
}
|
||||
if conn.ip6tables != nil {
|
||||
conn.ip6tables.Delete("filter", "OUTPUT", conn.ip6rule...)
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (conn *TCPConn) LocalAddr() net.Addr {
|
||||
if conn.tcpconn != nil {
|
||||
return conn.tcpconn.LocalAddr()
|
||||
} else if conn.listener != nil {
|
||||
return conn.listener.Addr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline implements the Conn SetDeadline method.
|
||||
func (conn *TCPConn) SetDeadline(t time.Time) error {
|
||||
if err := conn.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.SetWriteDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the Conn SetReadDeadline method.
|
||||
func (conn *TCPConn) SetReadDeadline(t time.Time) error {
|
||||
conn.readDeadline.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the Conn SetWriteDeadline method.
|
||||
func (conn *TCPConn) SetWriteDeadline(t time.Time) error {
|
||||
conn.writeDeadline.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
|
||||
func (conn *TCPConn) SetDSCP(dscp int) error {
|
||||
for k := range conn.handles {
|
||||
if err := setDSCP(conn.handles[k], dscp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadBuffer sets the size of the operating system's receive buffer associated with the connection.
|
||||
func (conn *TCPConn) SetReadBuffer(bytes int) error {
|
||||
var err error
|
||||
for k := range conn.handles {
|
||||
if err := conn.handles[k].SetReadBuffer(bytes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SetWriteBuffer sets the size of the operating system's transmit buffer associated with the connection.
|
||||
func (conn *TCPConn) SetWriteBuffer(bytes int) error {
|
||||
var err error
|
||||
for k := range conn.handles {
|
||||
if err := conn.handles[k].SetWriteBuffer(bytes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (conn *TCPConn) SyscallConn() (syscall.RawConn, error) {
|
||||
if len(conn.handles) == 0 {
|
||||
return nil, errors.New("no handles")
|
||||
// How is it possible?
|
||||
}
|
||||
return conn.handles[0].SyscallConn()
|
||||
}
|
||||
|
||||
// Dial connects to the remote TCP port,
|
||||
// and returns a single packet-oriented connection
|
||||
func Dial(network, address string) (*TCPConn, error) {
|
||||
// remote address resolve
|
||||
raddr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// AF_INET
|
||||
handle, err := net.DialIP("ip:tcp", nil, &net.IPAddr{IP: raddr.IP})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create an established tcp connection
|
||||
// will hack this tcp connection for packet transmission
|
||||
tcpconn, err := net.DialTCP(network, nil, raddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// fields
|
||||
conn := new(TCPConn)
|
||||
conn.die = make(chan struct{})
|
||||
conn.flowTable = make(map[string]*tcpFlow)
|
||||
conn.tcpconn = tcpconn
|
||||
conn.chMessage = make(chan message)
|
||||
conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn })
|
||||
conn.handles = append(conn.handles, handle)
|
||||
conn.opts = gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
go conn.captureFlow(handle, tcpconn.LocalAddr().(*net.TCPAddr).Port)
|
||||
go conn.cleaner()
|
||||
|
||||
// iptables
|
||||
err = setTTL(tcpconn, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil {
|
||||
rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"}
|
||||
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
|
||||
if !exists {
|
||||
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
|
||||
conn.iprule = rule
|
||||
conn.iptables = ipt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil {
|
||||
rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"}
|
||||
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
|
||||
if !exists {
|
||||
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
|
||||
conn.ip6rule = rule
|
||||
conn.ip6tables = ipt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// discard everything
|
||||
go io.Copy(ioutil.Discard, tcpconn)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Listen acts like net.ListenTCP,
|
||||
// and returns a single packet-oriented connection
|
||||
func Listen(network, address string) (*TCPConn, error) {
|
||||
// fields
|
||||
conn := new(TCPConn)
|
||||
conn.flowTable = make(map[string]*tcpFlow)
|
||||
conn.die = make(chan struct{})
|
||||
conn.chMessage = make(chan message)
|
||||
conn.opts = gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
// resolve address
|
||||
laddr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// AF_INET
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if laddr.IP == nil || laddr.IP.IsUnspecified() { // if address is not specified, capture on all ifaces
|
||||
var lasterr error
|
||||
for _, iface := range ifaces {
|
||||
if addrs, err := iface.Addrs(); err == nil {
|
||||
for _, addr := range addrs {
|
||||
if ipaddr, ok := addr.(*net.IPNet); ok {
|
||||
if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: ipaddr.IP}); err == nil {
|
||||
conn.handles = append(conn.handles, handle)
|
||||
go conn.captureFlow(handle, laddr.Port)
|
||||
} else {
|
||||
lasterr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(conn.handles) == 0 {
|
||||
return nil, lasterr
|
||||
}
|
||||
} else {
|
||||
if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: laddr.IP}); err == nil {
|
||||
conn.handles = append(conn.handles, handle)
|
||||
go conn.captureFlow(handle, laddr.Port)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// start listening
|
||||
l, err := net.ListenTCP(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.listener = l
|
||||
|
||||
// start cleaner
|
||||
go conn.cleaner()
|
||||
|
||||
// iptables drop packets marked with TTL = 1
|
||||
// TODO: what if iptables is not available, the next hop will send back ICMP Time Exceeded,
|
||||
// is this still an acceptable behavior?
|
||||
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil {
|
||||
rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"}
|
||||
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
|
||||
if !exists {
|
||||
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
|
||||
conn.iprule = rule
|
||||
conn.iptables = ipt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil {
|
||||
rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"}
|
||||
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
|
||||
if !exists {
|
||||
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
|
||||
conn.ip6rule = rule
|
||||
conn.ip6tables = ipt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// discard everything in original connection
|
||||
go func() {
|
||||
for {
|
||||
tcpconn, err := l.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// if we cannot set TTL = 1, the only thing reasonable is panic
|
||||
if err := setTTL(tcpconn, 1); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// record net.Conn
|
||||
conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn })
|
||||
|
||||
// discard everything
|
||||
go io.Copy(ioutil.Discard, tcpconn)
|
||||
}
|
||||
}()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// setTTL sets the Time-To-Live field on a given connection
|
||||
func setTTL(c *net.TCPConn, ttl int) error {
|
||||
raw, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr := c.LocalAddr().(*net.TCPAddr)
|
||||
|
||||
if addr.IP.To4() == nil {
|
||||
raw.Control(func(fd uintptr) {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, ttl)
|
||||
})
|
||||
} else {
|
||||
raw.Control(func(fd uintptr) {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, ttl)
|
||||
})
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// setDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
|
||||
func setDSCP(c *net.IPConn, dscp int) error {
|
||||
raw, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addr := c.LocalAddr().(*net.IPAddr)
|
||||
|
||||
if addr.IP.To4() == nil {
|
||||
raw.Control(func(fd uintptr) {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, dscp)
|
||||
})
|
||||
} else {
|
||||
raw.Control(func(fd uintptr) {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, dscp<<2)
|
||||
})
|
||||
}
|
||||
return err
|
||||
}
|
21
transport/hysteria/conns/faketcp/tcp_stub.go
Normal file
21
transport/hysteria/conns/faketcp/tcp_stub.go
Normal file
@ -0,0 +1,21 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package faketcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type TCPConn struct{ *net.UDPConn }
|
||||
|
||||
// Dial connects to the remote TCP port,
|
||||
// and returns a single packet-oriented connection
|
||||
func Dial(network, address string) (*TCPConn, error) {
|
||||
return nil, errors.New("faketcp is not supported on this platform")
|
||||
}
|
||||
|
||||
func Listen(network, address string) (*TCPConn, error) {
|
||||
return nil, errors.New("faketcp is not supported on this platform")
|
||||
}
|
196
transport/hysteria/conns/faketcp/tcp_test.go
Normal file
196
transport/hysteria/conns/faketcp/tcp_test.go
Normal file
@ -0,0 +1,196 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package faketcp
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"testing"
|
||||
)
|
||||
|
||||
//const testPortStream = "127.0.0.1:3456"
|
||||
//const testPortPacket = "127.0.0.1:3457"
|
||||
|
||||
const testPortStream = "127.0.0.1:3456"
|
||||
const portServerPacket = "[::]:3457"
|
||||
const portRemotePacket = "127.0.0.1:3457"
|
||||
|
||||
func init() {
|
||||
startTCPServer()
|
||||
startTCPRawServer()
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
||||
}()
|
||||
}
|
||||
|
||||
func startTCPServer() net.Listener {
|
||||
l, err := net.Listen("tcp", testPortStream)
|
||||
if err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer l.Close()
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
go handleRequest(conn)
|
||||
}
|
||||
}()
|
||||
return l
|
||||
}
|
||||
|
||||
func startTCPRawServer() *TCPConn {
|
||||
conn, err := Listen("tcp", portServerPacket)
|
||||
if err != nil {
|
||||
log.Panicln(err)
|
||||
}
|
||||
err = conn.SetReadBuffer(1024 * 1024)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
err = conn.SetWriteBuffer(1024 * 1024)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, addr, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
log.Println("server readfrom:", err)
|
||||
return
|
||||
}
|
||||
//echo
|
||||
n, err = conn.WriteTo(buf[:n], addr)
|
||||
if err != nil {
|
||||
log.Println("server writeTo:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return conn
|
||||
}
|
||||
|
||||
func handleRequest(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
buf := make([]byte, 1024)
|
||||
size, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Println("handleRequest:", err)
|
||||
return
|
||||
}
|
||||
data := buf[:size]
|
||||
conn.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTCPStream(t *testing.T) {
|
||||
conn, err := Dial("tcp", testPortStream)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", testPortStream)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n, err := conn.WriteTo([]byte("abc"), addr)
|
||||
if err != nil {
|
||||
t.Fatal(n, err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
if n, addr, err := conn.ReadFrom(buf); err != nil {
|
||||
t.Fatal(n, addr, err)
|
||||
} else {
|
||||
log.Println(string(buf[:n]), "from:", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialToTCPPacket(t *testing.T) {
|
||||
conn, err := Dial("tcp", portRemotePacket)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", portRemotePacket)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n, err := conn.WriteTo([]byte("abc"), addr)
|
||||
if err != nil {
|
||||
t.Fatal(n, err)
|
||||
}
|
||||
log.Println("written")
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
log.Println("readfrom buf")
|
||||
if n, addr, err := conn.ReadFrom(buf); err != nil {
|
||||
log.Println(err)
|
||||
t.Fatal(n, addr, err)
|
||||
} else {
|
||||
log.Println(string(buf[:n]), "from:", addr)
|
||||
}
|
||||
|
||||
log.Println("complete")
|
||||
}
|
||||
|
||||
func TestSettings(t *testing.T) {
|
||||
conn, err := Dial("tcp", portRemotePacket)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := conn.SetDSCP(46); err != nil {
|
||||
log.Fatal("SetDSCP:", err)
|
||||
}
|
||||
if err := conn.SetReadBuffer(4096); err != nil {
|
||||
log.Fatal("SetReaderBuffer:", err)
|
||||
}
|
||||
if err := conn.SetWriteBuffer(4096); err != nil {
|
||||
log.Fatal("SetWriteBuffer:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEcho(b *testing.B) {
|
||||
conn, err := Dial("tcp", portRemotePacket)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", portRemotePacket)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(buf)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
n, err := conn.WriteTo(buf, addr)
|
||||
if err != nil {
|
||||
b.Fatal(n, err)
|
||||
}
|
||||
|
||||
if n, addr, err := conn.ReadFrom(buf); err != nil {
|
||||
b.Fatal(n, addr, err)
|
||||
}
|
||||
}
|
||||
}
|
89
transport/hysteria/conns/udp/obfs.go
Normal file
89
transport/hysteria/conns/udp/obfs.go
Normal file
@ -0,0 +1,89 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/obfs"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const udpBufferSize = 65535
|
||||
|
||||
type ObfsUDPConn struct {
|
||||
orig net.PacketConn
|
||||
obfs obfs.Obfuscator
|
||||
readBuf []byte
|
||||
readMutex sync.Mutex
|
||||
writeBuf []byte
|
||||
writeMutex sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewObfsUDPConn(orig net.PacketConn, obfs obfs.Obfuscator) *ObfsUDPConn {
|
||||
return &ObfsUDPConn{
|
||||
orig: orig,
|
||||
obfs: obfs,
|
||||
readBuf: make([]byte, udpBufferSize),
|
||||
writeBuf: make([]byte, udpBufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsUDPConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
for {
|
||||
c.readMutex.Lock()
|
||||
if c.closed {
|
||||
log.Infoln("read udp obfs before")
|
||||
}
|
||||
n, addr, err := c.orig.ReadFrom(c.readBuf)
|
||||
if c.closed {
|
||||
log.Infoln("read udp obfs after")
|
||||
}
|
||||
if n <= 0 {
|
||||
c.readMutex.Unlock()
|
||||
return 0, addr, err
|
||||
}
|
||||
newN := c.obfs.Deobfuscate(c.readBuf[:n], p)
|
||||
c.readMutex.Unlock()
|
||||
if newN > 0 {
|
||||
// Valid packet
|
||||
return newN, addr, err
|
||||
} else if err != nil {
|
||||
// Not valid and orig.ReadFrom had some error
|
||||
return 0, addr, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsUDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
c.writeMutex.Lock()
|
||||
bn := c.obfs.Obfuscate(p, c.writeBuf)
|
||||
_, err = c.orig.WriteTo(c.writeBuf[:bn], addr)
|
||||
c.writeMutex.Unlock()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return len(p), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsUDPConn) Close() error {
|
||||
c.closed = true
|
||||
return c.orig.Close()
|
||||
}
|
||||
|
||||
func (c *ObfsUDPConn) LocalAddr() net.Addr {
|
||||
return c.orig.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *ObfsUDPConn) SetDeadline(t time.Time) error {
|
||||
return c.orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *ObfsUDPConn) SetReadDeadline(t time.Time) error {
|
||||
return c.orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *ObfsUDPConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.orig.SetWriteDeadline(t)
|
||||
}
|
105
transport/hysteria/conns/wechat/obfs.go
Normal file
105
transport/hysteria/conns/wechat/obfs.go
Normal file
@ -0,0 +1,105 @@
|
||||
package wechat
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/obfs"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const udpBufferSize = 65535
|
||||
|
||||
type ObfsWeChatUDPConn struct {
|
||||
orig net.PacketConn
|
||||
obfs obfs.Obfuscator
|
||||
closed bool
|
||||
readBuf []byte
|
||||
readMutex sync.Mutex
|
||||
writeBuf []byte
|
||||
writeMutex sync.Mutex
|
||||
sn uint32
|
||||
}
|
||||
|
||||
func NewObfsWeChatUDPConn(orig net.PacketConn, obfs obfs.Obfuscator) *ObfsWeChatUDPConn {
|
||||
log.Infoln("new wechat")
|
||||
return &ObfsWeChatUDPConn{
|
||||
orig: orig,
|
||||
obfs: obfs,
|
||||
readBuf: make([]byte, udpBufferSize),
|
||||
writeBuf: make([]byte, udpBufferSize),
|
||||
sn: rand.Uint32() & 0xFFFF,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsWeChatUDPConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
for {
|
||||
c.readMutex.Lock()
|
||||
if c.closed {
|
||||
log.Infoln("read wechat obfs before")
|
||||
}
|
||||
n, addr, err := c.orig.ReadFrom(c.readBuf)
|
||||
if c.closed {
|
||||
log.Infoln("read wechat obfs after")
|
||||
}
|
||||
if n <= 13 {
|
||||
c.readMutex.Unlock()
|
||||
return 0, addr, err
|
||||
}
|
||||
newN := c.obfs.Deobfuscate(c.readBuf[13:n], p)
|
||||
c.readMutex.Unlock()
|
||||
if newN > 0 {
|
||||
// Valid packet
|
||||
return newN, addr, err
|
||||
} else if err != nil {
|
||||
// Not valid and orig.ReadFrom had some error
|
||||
return 0, addr, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsWeChatUDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
c.writeMutex.Lock()
|
||||
c.writeBuf[0] = 0xa1
|
||||
c.writeBuf[1] = 0x08
|
||||
binary.BigEndian.PutUint32(c.writeBuf[2:], c.sn)
|
||||
c.sn++
|
||||
c.writeBuf[6] = 0x00
|
||||
c.writeBuf[7] = 0x10
|
||||
c.writeBuf[8] = 0x11
|
||||
c.writeBuf[9] = 0x18
|
||||
c.writeBuf[10] = 0x30
|
||||
c.writeBuf[11] = 0x22
|
||||
c.writeBuf[12] = 0x30
|
||||
bn := c.obfs.Obfuscate(p, c.writeBuf[13:])
|
||||
_, err = c.orig.WriteTo(c.writeBuf[:13+bn], addr)
|
||||
c.writeMutex.Unlock()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
return len(p), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfsWeChatUDPConn) Close() error {
|
||||
c.closed = true
|
||||
return c.orig.Close()
|
||||
}
|
||||
|
||||
func (c *ObfsWeChatUDPConn) LocalAddr() net.Addr {
|
||||
return c.orig.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *ObfsWeChatUDPConn) SetDeadline(t time.Time) error {
|
||||
return c.orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *ObfsWeChatUDPConn) SetReadDeadline(t time.Time) error {
|
||||
return c.orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *ObfsWeChatUDPConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.orig.SetWriteDeadline(t)
|
||||
}
|
422
transport/hysteria/core/client.go
Normal file
422
transport/hysteria/core/client.go
Normal file
@ -0,0 +1,422 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/obfs"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/pmtud_fix"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/transport"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/utils"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lunixbochs/struc"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosed = errors.New("closed")
|
||||
)
|
||||
|
||||
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
|
||||
|
||||
type Client struct {
|
||||
transport *transport.ClientTransport
|
||||
serverAddr string
|
||||
protocol string
|
||||
sendBPS, recvBPS uint64
|
||||
auth []byte
|
||||
congestionFactory CongestionFactory
|
||||
obfuscator obfs.Obfuscator
|
||||
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
|
||||
quicSession quic.Connection
|
||||
reconnectMutex sync.Mutex
|
||||
closed bool
|
||||
|
||||
udpSessionMutex sync.RWMutex
|
||||
udpSessionMap map[uint32]chan *udpMessage
|
||||
udpDefragger defragger
|
||||
}
|
||||
|
||||
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
|
||||
obfuscator obfs.Obfuscator) (*Client, error) {
|
||||
quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery
|
||||
c := &Client{
|
||||
transport: transport,
|
||||
serverAddr: serverAddr,
|
||||
protocol: protocol,
|
||||
sendBPS: sendBPS,
|
||||
recvBPS: recvBPS,
|
||||
auth: auth,
|
||||
congestionFactory: congestionFactory,
|
||||
obfuscator: obfuscator,
|
||||
tlsConfig: tlsConfig,
|
||||
quicConfig: quicConfig,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) connectToServer(dialer transport.PacketDialer) error {
|
||||
qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator, dialer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Control stream
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
|
||||
stream, err := qs.OpenStreamSync(ctx)
|
||||
ctxCancel()
|
||||
if err != nil {
|
||||
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
|
||||
return err
|
||||
}
|
||||
ok, msg, err := c.handleControlStream(qs, stream)
|
||||
if err != nil {
|
||||
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
|
||||
return fmt.Errorf("auth error: %s", msg)
|
||||
}
|
||||
// All good
|
||||
c.udpSessionMap = make(map[uint32]chan *udpMessage)
|
||||
go c.handleMessage(qs)
|
||||
c.quicSession = qs
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bool, string, error) {
|
||||
// Send protocol version
|
||||
_, err := stream.Write([]byte{protocolVersion})
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
// Send client hello
|
||||
err = struc.Pack(stream, &clientHello{
|
||||
Rate: transmissionRate{
|
||||
SendBPS: c.sendBPS,
|
||||
RecvBPS: c.recvBPS,
|
||||
},
|
||||
Auth: c.auth,
|
||||
})
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
// Receive server hello
|
||||
var sh serverHello
|
||||
err = struc.Unpack(stream, &sh)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
// Set the congestion accordingly
|
||||
if sh.OK && c.congestionFactory != nil {
|
||||
qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS))
|
||||
}
|
||||
return sh.OK, sh.Message, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleMessage(qs quic.Connection) {
|
||||
for {
|
||||
msg, err := qs.ReceiveMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
var udpMsg udpMessage
|
||||
err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
dfMsg := c.udpDefragger.Feed(udpMsg)
|
||||
if dfMsg == nil {
|
||||
continue
|
||||
}
|
||||
c.udpSessionMutex.RLock()
|
||||
ch, ok := c.udpSessionMap[dfMsg.SessionID]
|
||||
if ok {
|
||||
select {
|
||||
case ch <- dfMsg:
|
||||
// OK
|
||||
default:
|
||||
// Silently drop the message when the channel is full
|
||||
}
|
||||
}
|
||||
c.udpSessionMutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) openStreamWithReconnect(dialer transport.PacketDialer) (quic.Connection, quic.Stream, error) {
|
||||
c.reconnectMutex.Lock()
|
||||
defer c.reconnectMutex.Unlock()
|
||||
if c.closed {
|
||||
return nil, nil, ErrClosed
|
||||
}
|
||||
if c.quicSession == nil {
|
||||
if err := c.connectToServer(dialer); err != nil {
|
||||
// Still error, oops
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
stream, err := c.quicSession.OpenStream()
|
||||
if err == nil {
|
||||
// All good
|
||||
return c.quicSession, &wrappedQUICStream{stream}, nil
|
||||
}
|
||||
// Something is wrong
|
||||
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
|
||||
// Temporary error, just return
|
||||
return nil, nil, err
|
||||
}
|
||||
// Permanent error, need to reconnect
|
||||
if err := c.connectToServer(dialer); err != nil {
|
||||
// Still error, oops
|
||||
return nil, nil, err
|
||||
}
|
||||
// We are not going to try again even if it still fails the second time
|
||||
stream, err = c.quicSession.OpenStream()
|
||||
return c.quicSession, &wrappedQUICStream{stream}, err
|
||||
}
|
||||
|
||||
func (c *Client) DialTCP(addr string, dialer transport.PacketDialer) (net.Conn, error) {
|
||||
host, port, err := utils.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, stream, err := c.openStreamWithReconnect(dialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send request
|
||||
err = struc.Pack(stream, &clientRequest{
|
||||
UDP: false,
|
||||
Host: host,
|
||||
Port: port,
|
||||
})
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
var sr serverResponse
|
||||
err = struc.Unpack(stream, &sr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !sr.OK {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
|
||||
}
|
||||
return &quicConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: session.LocalAddr(),
|
||||
PseudoRemoteAddr: session.RemoteAddr(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialUDP(dialer transport.PacketDialer) (UDPConn, error) {
|
||||
session, stream, err := c.openStreamWithReconnect(dialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send request
|
||||
err = struc.Pack(stream, &clientRequest{
|
||||
UDP: true,
|
||||
})
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
var sr serverResponse
|
||||
err = struc.Unpack(stream, &sr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !sr.OK {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
|
||||
}
|
||||
|
||||
// Create a session in the map
|
||||
c.udpSessionMutex.Lock()
|
||||
nCh := make(chan *udpMessage, 1024)
|
||||
// Store the current session map for CloseFunc below
|
||||
// to ensures that we are adding and removing sessions on the same map,
|
||||
// as reconnecting will reassign the map
|
||||
sessionMap := c.udpSessionMap
|
||||
sessionMap[sr.UDPSessionID] = nCh
|
||||
c.udpSessionMutex.Unlock()
|
||||
|
||||
pktConn := &quicPktConn{
|
||||
Session: session,
|
||||
Stream: stream,
|
||||
CloseFunc: func() {
|
||||
c.udpSessionMutex.Lock()
|
||||
if ch, ok := sessionMap[sr.UDPSessionID]; ok {
|
||||
close(ch)
|
||||
delete(sessionMap, sr.UDPSessionID)
|
||||
}
|
||||
c.udpSessionMutex.Unlock()
|
||||
},
|
||||
UDPSessionID: sr.UDPSessionID,
|
||||
MsgCh: nCh,
|
||||
}
|
||||
go pktConn.Hold()
|
||||
return pktConn, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.reconnectMutex.Lock()
|
||||
defer c.reconnectMutex.Unlock()
|
||||
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
|
||||
c.closed = true
|
||||
return err
|
||||
}
|
||||
|
||||
type quicConn struct {
|
||||
Orig quic.Stream
|
||||
PseudoLocalAddr net.Addr
|
||||
PseudoRemoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *quicConn) Read(b []byte) (n int, err error) {
|
||||
return w.Orig.Read(b)
|
||||
}
|
||||
|
||||
func (w *quicConn) Write(b []byte) (n int, err error) {
|
||||
return w.Orig.Write(b)
|
||||
}
|
||||
|
||||
func (w *quicConn) Close() error {
|
||||
return w.Orig.Close()
|
||||
}
|
||||
|
||||
func (w *quicConn) LocalAddr() net.Addr {
|
||||
return w.PseudoLocalAddr
|
||||
}
|
||||
|
||||
func (w *quicConn) RemoteAddr() net.Addr {
|
||||
return w.PseudoRemoteAddr
|
||||
}
|
||||
|
||||
func (w *quicConn) SetDeadline(t time.Time) error {
|
||||
return w.Orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (w *quicConn) SetReadDeadline(t time.Time) error {
|
||||
return w.Orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (w *quicConn) SetWriteDeadline(t time.Time) error {
|
||||
return w.Orig.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type UDPConn interface {
|
||||
ReadFrom() ([]byte, string, error)
|
||||
WriteTo([]byte, string) error
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
SetDeadline(t time.Time) error
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type quicPktConn struct {
|
||||
Session quic.Connection
|
||||
Stream quic.Stream
|
||||
CloseFunc func()
|
||||
UDPSessionID uint32
|
||||
MsgCh <-chan *udpMessage
|
||||
}
|
||||
|
||||
func (c *quicPktConn) Hold() {
|
||||
// Hold the stream until it's closed
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := c.Stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
|
||||
msg := <-c.MsgCh
|
||||
if msg == nil {
|
||||
// Closed
|
||||
return nil, "", ErrClosed
|
||||
}
|
||||
return msg.Data, net.JoinHostPort(msg.Host, strconv.Itoa(int(msg.Port))), nil
|
||||
}
|
||||
|
||||
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
|
||||
host, port, err := utils.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := udpMessage{
|
||||
SessionID: c.UDPSessionID,
|
||||
Host: host,
|
||||
Port: port,
|
||||
FragCount: 1,
|
||||
Data: p,
|
||||
}
|
||||
// try no frag first
|
||||
var msgBuf bytes.Buffer
|
||||
_ = struc.Pack(&msgBuf, &msg)
|
||||
err = c.Session.SendMessage(msgBuf.Bytes())
|
||||
if err != nil {
|
||||
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
|
||||
// need to frag
|
||||
msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||
fragMsgs := fragUDPMessage(msg, int(errSize))
|
||||
for _, fragMsg := range fragMsgs {
|
||||
msgBuf.Reset()
|
||||
_ = struc.Pack(&msgBuf, &fragMsg)
|
||||
err = c.Session.SendMessage(msgBuf.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
// some other error
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *quicPktConn) Close() error {
|
||||
c.CloseFunc()
|
||||
return c.Stream.Close()
|
||||
}
|
||||
|
||||
func (c *quicPktConn) LocalAddr() net.Addr {
|
||||
return c.Session.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *quicPktConn) SetDeadline(t time.Time) error {
|
||||
return c.Stream.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *quicPktConn) SetReadDeadline(t time.Time) error {
|
||||
return c.Stream.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *quicPktConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.Stream.SetWriteDeadline(t)
|
||||
}
|
67
transport/hysteria/core/frag.go
Normal file
67
transport/hysteria/core/frag.go
Normal file
@ -0,0 +1,67 @@
|
||||
package core
|
||||
|
||||
func fragUDPMessage(m udpMessage, maxSize int) []udpMessage {
|
||||
if m.Size() <= maxSize {
|
||||
return []udpMessage{m}
|
||||
}
|
||||
fullPayload := m.Data
|
||||
maxPayloadSize := maxSize - m.HeaderSize()
|
||||
off := 0
|
||||
fragID := uint8(0)
|
||||
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
|
||||
var frags []udpMessage
|
||||
for off < len(fullPayload) {
|
||||
payloadSize := len(fullPayload) - off
|
||||
if payloadSize > maxPayloadSize {
|
||||
payloadSize = maxPayloadSize
|
||||
}
|
||||
frag := m
|
||||
frag.FragID = fragID
|
||||
frag.FragCount = fragCount
|
||||
frag.DataLen = uint16(payloadSize)
|
||||
frag.Data = fullPayload[off : off+payloadSize]
|
||||
frags = append(frags, frag)
|
||||
off += payloadSize
|
||||
fragID++
|
||||
}
|
||||
return frags
|
||||
}
|
||||
|
||||
type defragger struct {
|
||||
msgID uint16
|
||||
frags []*udpMessage
|
||||
count uint8
|
||||
}
|
||||
|
||||
func (d *defragger) Feed(m udpMessage) *udpMessage {
|
||||
if m.FragCount <= 1 {
|
||||
return &m
|
||||
}
|
||||
if m.FragID >= m.FragCount {
|
||||
// wtf is this?
|
||||
return nil
|
||||
}
|
||||
if m.MsgID != d.msgID {
|
||||
// new message, clear previous state
|
||||
d.msgID = m.MsgID
|
||||
d.frags = make([]*udpMessage, m.FragCount)
|
||||
d.count = 1
|
||||
d.frags[m.FragID] = &m
|
||||
} else if d.frags[m.FragID] == nil {
|
||||
d.frags[m.FragID] = &m
|
||||
d.count++
|
||||
if int(d.count) == len(d.frags) {
|
||||
// all fragments received, assemble
|
||||
var data []byte
|
||||
for _, frag := range d.frags {
|
||||
data = append(data, frag.Data...)
|
||||
}
|
||||
m.DataLen = uint16(len(data))
|
||||
m.Data = data
|
||||
m.FragID = 0
|
||||
m.FragCount = 1
|
||||
return &m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
346
transport/hysteria/core/frag_test.go
Normal file
346
transport/hysteria/core/frag_test.go
Normal file
@ -0,0 +1,346 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_fragUDPMessage(t *testing.T) {
|
||||
type args struct {
|
||||
m udpMessage
|
||||
maxSize int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []udpMessage
|
||||
}{
|
||||
{
|
||||
"no frag",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
100,
|
||||
},
|
||||
[]udpMessage{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"2 frags",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
22,
|
||||
},
|
||||
[]udpMessage{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 2,
|
||||
DataLen: 4,
|
||||
Data: []byte("hell"),
|
||||
},
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 1,
|
||||
FragCount: 2,
|
||||
DataLen: 1,
|
||||
Data: []byte("o"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"4 frags",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 20,
|
||||
Data: []byte("wow wow wow lol lmao"),
|
||||
},
|
||||
23,
|
||||
},
|
||||
[]udpMessage{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 4,
|
||||
DataLen: 5,
|
||||
Data: []byte("wow w"),
|
||||
},
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 1,
|
||||
FragCount: 4,
|
||||
DataLen: 5,
|
||||
Data: []byte("ow wo"),
|
||||
},
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 2,
|
||||
FragCount: 4,
|
||||
DataLen: 5,
|
||||
Data: []byte("w lol"),
|
||||
},
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 3,
|
||||
FragCount: 4,
|
||||
DataLen: 5,
|
||||
Data: []byte(" lmao"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := fragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("fragUDPMessage() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_defragger_Feed(t *testing.T) {
|
||||
d := &defragger{}
|
||||
type args struct {
|
||||
m udpMessage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *udpMessage
|
||||
}{
|
||||
{
|
||||
"no frag",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
},
|
||||
&udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 123,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"frag 1 - 1/3",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 666,
|
||||
FragID: 0,
|
||||
FragCount: 3,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"frag 1 - 2/3",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 666,
|
||||
FragID: 1,
|
||||
FragCount: 3,
|
||||
DataLen: 8,
|
||||
Data: []byte(" shitty "),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"frag 1 - 3/3",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 666,
|
||||
FragID: 2,
|
||||
FragCount: 3,
|
||||
DataLen: 7,
|
||||
Data: []byte("world!!"),
|
||||
},
|
||||
},
|
||||
&udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 666,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 20,
|
||||
Data: []byte("hello shitty world!!"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"frag 2 - 1/2",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 777,
|
||||
FragID: 0,
|
||||
FragCount: 2,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"frag 3 - 2/2",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 778,
|
||||
FragID: 1,
|
||||
FragCount: 2,
|
||||
DataLen: 5,
|
||||
Data: []byte(" moto"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"frag 2 - 2/2",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 777,
|
||||
FragID: 1,
|
||||
FragCount: 2,
|
||||
DataLen: 5,
|
||||
Data: []byte(" moto"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"frag 2 - 1/2 re",
|
||||
args{
|
||||
udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 777,
|
||||
FragID: 0,
|
||||
FragCount: 2,
|
||||
DataLen: 5,
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
},
|
||||
&udpMessage{
|
||||
SessionID: 123,
|
||||
HostLen: 4,
|
||||
Host: "test",
|
||||
Port: 123,
|
||||
MsgID: 777,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
DataLen: 10,
|
||||
Data: []byte("hello moto"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := d.Feed(tt.args.m); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Feed() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
76
transport/hysteria/core/protocol.go
Normal file
76
transport/hysteria/core/protocol.go
Normal file
@ -0,0 +1,76 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
protocolVersion = uint8(3)
|
||||
protocolVersionV2 = uint8(2)
|
||||
protocolTimeout = 10 * time.Second
|
||||
|
||||
closeErrorCodeGeneric = 0
|
||||
closeErrorCodeProtocol = 1
|
||||
closeErrorCodeAuth = 2
|
||||
)
|
||||
|
||||
type transmissionRate struct {
|
||||
SendBPS uint64
|
||||
RecvBPS uint64
|
||||
}
|
||||
|
||||
type clientHello struct {
|
||||
Rate transmissionRate
|
||||
AuthLen uint16 `struc:"sizeof=Auth"`
|
||||
Auth []byte
|
||||
}
|
||||
|
||||
type serverHello struct {
|
||||
OK bool
|
||||
Rate transmissionRate
|
||||
MessageLen uint16 `struc:"sizeof=Message"`
|
||||
Message string
|
||||
}
|
||||
|
||||
type clientRequest struct {
|
||||
UDP bool
|
||||
HostLen uint16 `struc:"sizeof=Host"`
|
||||
Host string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
type serverResponse struct {
|
||||
OK bool
|
||||
UDPSessionID uint32
|
||||
MessageLen uint16 `struc:"sizeof=Message"`
|
||||
Message string
|
||||
}
|
||||
|
||||
type udpMessage struct {
|
||||
SessionID uint32
|
||||
HostLen uint16 `struc:"sizeof=Host"`
|
||||
Host string
|
||||
Port uint16
|
||||
MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
|
||||
FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
|
||||
FragCount uint8 // must be 1 when not fragmented
|
||||
DataLen uint16 `struc:"sizeof=Data"`
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (m udpMessage) HeaderSize() int {
|
||||
return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
|
||||
}
|
||||
|
||||
func (m udpMessage) Size() int {
|
||||
return m.HeaderSize() + len(m.Data)
|
||||
}
|
||||
|
||||
type udpMessageV2 struct {
|
||||
SessionID uint32
|
||||
HostLen uint16 `struc:"sizeof=Host"`
|
||||
Host string
|
||||
Port uint16
|
||||
DataLen uint16 `struc:"sizeof=Data"`
|
||||
Data []byte
|
||||
}
|
54
transport/hysteria/core/stream.go
Normal file
54
transport/hysteria/core/stream.go
Normal file
@ -0,0 +1,54 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Handle stream close properly
|
||||
// Ref: https://github.com/libp2p/go-libp2p-quic-transport/blob/master/stream.go
|
||||
type wrappedQUICStream struct {
|
||||
Stream quic.Stream
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) StreamID() quic.StreamID {
|
||||
return s.Stream.StreamID()
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) Read(p []byte) (n int, err error) {
|
||||
return s.Stream.Read(p)
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) CancelRead(code quic.StreamErrorCode) {
|
||||
s.Stream.CancelRead(code)
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) SetReadDeadline(t time.Time) error {
|
||||
return s.Stream.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) Write(p []byte) (n int, err error) {
|
||||
return s.Stream.Write(p)
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) Close() error {
|
||||
s.Stream.CancelRead(0)
|
||||
return s.Stream.Close()
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) CancelWrite(code quic.StreamErrorCode) {
|
||||
s.Stream.CancelWrite(code)
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) Context() context.Context {
|
||||
return s.Stream.Context()
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) SetWriteDeadline(t time.Time) error {
|
||||
return s.Stream.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (s *wrappedQUICStream) SetDeadline(t time.Time) error {
|
||||
return s.Stream.SetDeadline(t)
|
||||
}
|
6
transport/hysteria/obfs/obfs.go
Normal file
6
transport/hysteria/obfs/obfs.go
Normal file
@ -0,0 +1,6 @@
|
||||
package obfs
|
||||
|
||||
type Obfuscator interface {
|
||||
Deobfuscate(in []byte, out []byte) int
|
||||
Obfuscate(in []byte, out []byte) int
|
||||
}
|
52
transport/hysteria/obfs/xplus.go
Normal file
52
transport/hysteria/obfs/xplus.go
Normal file
@ -0,0 +1,52 @@
|
||||
package obfs
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// [salt][obfuscated payload]
|
||||
|
||||
const saltLen = 16
|
||||
|
||||
type XPlusObfuscator struct {
|
||||
Key []byte
|
||||
RandSrc *rand.Rand
|
||||
|
||||
lk sync.Mutex
|
||||
}
|
||||
|
||||
func NewXPlusObfuscator(key []byte) *XPlusObfuscator {
|
||||
return &XPlusObfuscator{
|
||||
Key: key,
|
||||
RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (x *XPlusObfuscator) Deobfuscate(in []byte, out []byte) int {
|
||||
pLen := len(in) - saltLen
|
||||
if pLen <= 0 || len(out) < pLen {
|
||||
// Invalid
|
||||
return 0
|
||||
}
|
||||
key := sha256.Sum256(append(x.Key, in[:saltLen]...))
|
||||
// Deobfuscate the payload
|
||||
for i, c := range in[saltLen:] {
|
||||
out[i] = c ^ key[i%sha256.Size]
|
||||
}
|
||||
return pLen
|
||||
}
|
||||
|
||||
func (x *XPlusObfuscator) Obfuscate(in []byte, out []byte) int {
|
||||
x.lk.Lock()
|
||||
_, _ = x.RandSrc.Read(out[:saltLen]) // salt
|
||||
x.lk.Unlock()
|
||||
// Obfuscate the payload
|
||||
key := sha256.Sum256(append(x.Key, out[:saltLen]...))
|
||||
for i, c := range in {
|
||||
out[i+saltLen] = c ^ key[i%sha256.Size]
|
||||
}
|
||||
return len(in) + saltLen
|
||||
}
|
31
transport/hysteria/obfs/xplus_test.go
Normal file
31
transport/hysteria/obfs/xplus_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package obfs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestXPlusObfuscator(t *testing.T) {
|
||||
x := NewXPlusObfuscator([]byte("Vaundy"))
|
||||
tests := []struct {
|
||||
name string
|
||||
p []byte
|
||||
}{
|
||||
{name: "1", p: []byte("HelloWorld")},
|
||||
{name: "2", p: []byte("Regret is just a horrible attempt at time travel that ends with you feeling like crap")},
|
||||
{name: "3", p: []byte("To be, or not to be, that is the question:\nWhether 'tis nobler in the mind to suffer\n" +
|
||||
"The slings and arrows of outrageous fortune,\nOr to take arms against a sea of troubles\n" +
|
||||
"And by opposing end them. To die—to sleep,\nNo more; and by a sleep to say we end")},
|
||||
{name: "empty", p: []byte("")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := make([]byte, 10240)
|
||||
n := x.Obfuscate(tt.p, buf)
|
||||
n2 := x.Deobfuscate(buf[:n], buf[n:])
|
||||
if !bytes.Equal(tt.p, buf[n:n+n2]) {
|
||||
t.Errorf("Inconsistent deobfuscate result: got %v, want %v", buf[n:n+n2], tt.p)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
8
transport/hysteria/pmtud_fix/avail.go
Normal file
8
transport/hysteria/pmtud_fix/avail.go
Normal file
@ -0,0 +1,8 @@
|
||||
//go:build linux || windows
|
||||
// +build linux windows
|
||||
|
||||
package pmtud_fix
|
||||
|
||||
const (
|
||||
DisablePathMTUDiscovery = false
|
||||
)
|
8
transport/hysteria/pmtud_fix/unavail.go
Normal file
8
transport/hysteria/pmtud_fix/unavail.go
Normal file
@ -0,0 +1,8 @@
|
||||
//go:build !linux && !windows
|
||||
// +build !linux,!windows
|
||||
|
||||
package pmtud_fix
|
||||
|
||||
const (
|
||||
DisablePathMTUDiscovery = true
|
||||
)
|
106
transport/hysteria/transport/client.go
Normal file
106
transport/hysteria/transport/client.go
Normal file
@ -0,0 +1,106 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/Dreamacro/clash/component/resolver"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/conns/faketcp"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/conns/udp"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/conns/wechat"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/obfs"
|
||||
"github.com/Dreamacro/clash/transport/hysteria/utils"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"net"
|
||||
)
|
||||
|
||||
type ClientTransport struct {
|
||||
Dialer *net.Dialer
|
||||
PrefEnabled bool
|
||||
PrefIPv6 bool
|
||||
PrefExclusive bool
|
||||
}
|
||||
|
||||
func (ct *ClientTransport) quicPacketConn(proto string, server string, obfs obfs.Obfuscator, dialer PacketDialer) (net.PacketConn, error) {
|
||||
if len(proto) == 0 || proto == "udp" {
|
||||
conn, err := dialer.ListenPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if obfs != nil {
|
||||
oc := udp.NewObfsUDPConn(conn, obfs)
|
||||
return oc, nil
|
||||
} else {
|
||||
return conn, nil
|
||||
}
|
||||
} else if proto == "wechat-video" {
|
||||
conn, err := dialer.ListenPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if obfs != nil {
|
||||
oc := wechat.NewObfsWeChatUDPConn(conn, obfs)
|
||||
return oc, nil
|
||||
} else {
|
||||
return conn, nil
|
||||
}
|
||||
} else if proto == "faketcp" {
|
||||
var conn *faketcp.TCPConn
|
||||
conn, err := faketcp.Dial("tcp", server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if obfs != nil {
|
||||
oc := faketcp.NewObfsFakeTCPConn(conn, obfs)
|
||||
return oc, nil
|
||||
} else {
|
||||
return conn, nil
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
}
|
||||
|
||||
type PacketDialer interface {
|
||||
ListenPacket() (net.PacketConn, error)
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
func (ct *ClientTransport) QUICDial(proto string, server string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfs.Obfuscator, dialer PacketDialer) (quic.Connection, error) {
|
||||
ipStr, port, err := utils.SplitHostPort(server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip, err := resolver.ResolveProxyServerHost(ipStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverUDPAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", ip, port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pktConn, err := ct.quicPacketConn(proto, server, obfs, dialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qs, err := quic.DialContext(dialer.Context(), pktConn, serverUDPAddr, server, tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
_ = pktConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return qs, nil
|
||||
}
|
||||
|
||||
func (ct *ClientTransport) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
conn, err := ct.Dialer.Dial("tcp", raddr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn.(*net.TCPConn), nil
|
||||
}
|
||||
|
||||
func (ct *ClientTransport) ListenUDP() (*net.UDPConn, error) {
|
||||
return net.ListenUDP("udp", nil)
|
||||
}
|
42
transport/hysteria/utils/misc.go
Normal file
42
transport/hysteria/utils/misc.go
Normal file
@ -0,0 +1,42 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func SplitHostPort(hostport string) (string, uint16, error) {
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
portUint, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return host, uint16(portUint), err
|
||||
}
|
||||
|
||||
func ParseIPZone(s string) (net.IP, string) {
|
||||
s, zone := splitHostZone(s)
|
||||
return net.ParseIP(s), zone
|
||||
}
|
||||
|
||||
func splitHostZone(s string) (host, zone string) {
|
||||
if i := last(s, '%'); i > 0 {
|
||||
host, zone = s[:i], s[i+1:]
|
||||
} else {
|
||||
host = s
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func last(s string, b byte) int {
|
||||
i := len(s)
|
||||
for i--; i >= 0; i-- {
|
||||
if s[i] == b {
|
||||
break
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
Reference in New Issue
Block a user