chore: embed hysteria, clean irrelevant codes, code from https://github.com/HyNetwork/hysteria

This commit is contained in:
Skyxim
2022-07-03 18:22:56 +08:00
parent 8ce9737f3d
commit 3cc1870aee
28 changed files with 3251 additions and 375 deletions

View File

@ -0,0 +1,100 @@
package acl
import (
"github.com/Dreamacro/clash/transport/hysteria/utils"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/geoip2-golang"
"net"
)
const entryCacheSize = 1024
type Engine struct {
DefaultAction Action
Entries []Entry
Cache *lru.ARCCache
ResolveIPAddr func(string) (*net.IPAddr, error)
GeoIPReader *geoip2.Reader
}
type cacheKey struct {
Host string
Port uint16
IsUDP bool
}
type cacheValue struct {
Action Action
Arg string
}
// action, arg, isDomain, resolvedIP, error
func (e *Engine) ResolveAndMatch(host string, port uint16, isUDP bool) (Action, string, bool, *net.IPAddr, error) {
ip, zone := utils.ParseIPZone(host)
if ip == nil {
// Domain
ipAddr, err := e.ResolveIPAddr(host)
if v, ok := e.Cache.Get(cacheKey{host, port, isUDP}); ok {
// Cache hit
ce := v.(cacheValue)
return ce.Action, ce.Arg, true, ipAddr, err
}
for _, entry := range e.Entries {
mReq := MatchRequest{
Domain: host,
Port: port,
DB: e.GeoIPReader,
}
if ipAddr != nil {
mReq.IP = ipAddr.IP
}
if isUDP {
mReq.Protocol = ProtocolUDP
} else {
mReq.Protocol = ProtocolTCP
}
if entry.Match(mReq) {
e.Cache.Add(cacheKey{host, port, isUDP},
cacheValue{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, true, ipAddr, err
}
}
e.Cache.Add(cacheKey{host, port, isUDP}, cacheValue{e.DefaultAction, ""})
return e.DefaultAction, "", true, ipAddr, err
} else {
// IP
if v, ok := e.Cache.Get(cacheKey{ip.String(), port, isUDP}); ok {
// Cache hit
ce := v.(cacheValue)
return ce.Action, ce.Arg, false, &net.IPAddr{
IP: ip,
Zone: zone,
}, nil
}
for _, entry := range e.Entries {
mReq := MatchRequest{
IP: ip,
Port: port,
DB: e.GeoIPReader,
}
if isUDP {
mReq.Protocol = ProtocolUDP
} else {
mReq.Protocol = ProtocolTCP
}
if entry.Match(mReq) {
e.Cache.Add(cacheKey{ip.String(), port, isUDP},
cacheValue{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, false, &net.IPAddr{
IP: ip,
Zone: zone,
}, nil
}
}
e.Cache.Add(cacheKey{ip.String(), port, isUDP}, cacheValue{e.DefaultAction, ""})
return e.DefaultAction, "", false, &net.IPAddr{
IP: ip,
Zone: zone,
}, nil
}
}

View File

@ -0,0 +1,154 @@
package acl
import (
"errors"
lru "github.com/hashicorp/golang-lru"
"net"
"strings"
"testing"
)
func TestEngine_ResolveAndMatch(t *testing.T) {
cache, _ := lru.NewARC(16)
e := &Engine{
DefaultAction: ActionDirect,
Entries: []Entry{
{
Action: ActionProxy,
ActionArg: "",
Matcher: &domainMatcher{
matcherBase: matcherBase{
Protocol: ProtocolTCP,
Port: 443,
},
Domain: "google.com",
Suffix: false,
},
},
{
Action: ActionHijack,
ActionArg: "good.org",
Matcher: &domainMatcher{
matcherBase: matcherBase{},
Domain: "evil.corp",
Suffix: true,
},
},
{
Action: ActionProxy,
ActionArg: "",
Matcher: &netMatcher{
matcherBase: matcherBase{},
Net: &net.IPNet{
IP: net.ParseIP("10.0.0.0"),
Mask: net.CIDRMask(8, 32),
},
},
},
{
Action: ActionBlock,
ActionArg: "",
Matcher: &allMatcher{},
},
},
Cache: cache,
ResolveIPAddr: func(s string) (*net.IPAddr, error) {
if strings.Contains(s, "evil.corp") {
return nil, errors.New("resolve error")
}
return net.ResolveIPAddr("ip", s)
},
}
tests := []struct {
name string
host string
port uint16
isUDP bool
wantAction Action
wantArg string
wantErr bool
}{
{
name: "domain proxy",
host: "google.com",
port: 443,
isUDP: false,
wantAction: ActionProxy,
wantArg: "",
},
{
name: "domain block",
host: "google.com",
port: 80,
isUDP: false,
wantAction: ActionBlock,
wantArg: "",
},
{
name: "domain suffix 1",
host: "evil.corp",
port: 8899,
isUDP: true,
wantAction: ActionHijack,
wantArg: "good.org",
wantErr: true,
},
{
name: "domain suffix 2",
host: "notevil.corp",
port: 22,
isUDP: false,
wantAction: ActionBlock,
wantArg: "",
wantErr: true,
},
{
name: "domain suffix 3",
host: "im.real.evil.corp",
port: 443,
isUDP: true,
wantAction: ActionHijack,
wantArg: "good.org",
wantErr: true,
},
{
name: "ip match",
host: "10.2.3.4",
port: 80,
isUDP: false,
wantAction: ActionProxy,
wantArg: "",
},
{
name: "ip mismatch",
host: "100.5.6.0",
port: 1234,
isUDP: false,
wantAction: ActionBlock,
wantArg: "",
},
{
name: "domain proxy cache",
host: "google.com",
port: 443,
isUDP: false,
wantAction: ActionProxy,
wantArg: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotAction, gotArg, _, _, err := e.ResolveAndMatch(tt.host, tt.port, tt.isUDP)
if (err != nil) != tt.wantErr {
t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotAction != tt.wantAction {
t.Errorf("ResolveAndMatch() gotAction = %v, wantAction %v", gotAction, tt.wantAction)
}
if gotArg != tt.wantArg {
t.Errorf("ResolveAndMatch() gotArg = %v, wantAction %v", gotArg, tt.wantArg)
}
})
}
}

View File

@ -0,0 +1,331 @@
package acl
import (
"errors"
"fmt"
"github.com/oschwald/geoip2-golang"
"net"
"strconv"
"strings"
)
type Action byte
type Protocol byte
const (
ActionDirect = Action(iota)
ActionProxy
ActionBlock
ActionHijack
)
const (
ProtocolAll = Protocol(iota)
ProtocolTCP
ProtocolUDP
)
var protocolPortAliases = map[string]string{
"echo": "*/7",
"ftp-data": "*/20",
"ftp": "*/21",
"ssh": "*/22",
"telnet": "*/23",
"domain": "*/53",
"dns": "*/53",
"http": "*/80",
"sftp": "*/115",
"ntp": "*/123",
"https": "*/443",
"quic": "udp/443",
"socks": "*/1080",
}
type Entry struct {
Action Action
ActionArg string
Matcher Matcher
}
type MatchRequest struct {
IP net.IP
Domain string
Protocol Protocol
Port uint16
DB *geoip2.Reader
}
type Matcher interface {
Match(MatchRequest) bool
}
type matcherBase struct {
Protocol Protocol
Port uint16 // 0 for all ports
}
func (m *matcherBase) MatchProtocolPort(p Protocol, port uint16) bool {
return (m.Protocol == ProtocolAll || m.Protocol == p) && (m.Port == 0 || m.Port == port)
}
func parseProtocolPort(s string) (Protocol, uint16, error) {
if protocolPortAliases[s] != "" {
s = protocolPortAliases[s]
}
if len(s) == 0 || s == "*" {
return ProtocolAll, 0, nil
}
parts := strings.Split(s, "/")
if len(parts) != 2 {
return ProtocolAll, 0, errors.New("invalid protocol/port syntax")
}
protocol := ProtocolAll
switch parts[0] {
case "tcp":
protocol = ProtocolTCP
case "udp":
protocol = ProtocolUDP
case "*":
protocol = ProtocolAll
default:
return ProtocolAll, 0, errors.New("invalid protocol")
}
if parts[1] == "*" {
return protocol, 0, nil
}
port, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return ProtocolAll, 0, errors.New("invalid port")
}
return protocol, uint16(port), nil
}
type netMatcher struct {
matcherBase
Net *net.IPNet
}
func (m *netMatcher) Match(r MatchRequest) bool {
if r.IP == nil {
return false
}
return m.Net.Contains(r.IP) && m.MatchProtocolPort(r.Protocol, r.Port)
}
type domainMatcher struct {
matcherBase
Domain string
Suffix bool
}
func (m *domainMatcher) Match(r MatchRequest) bool {
if len(r.Domain) == 0 {
return false
}
domain := strings.ToLower(r.Domain)
return (m.Domain == domain || (m.Suffix && strings.HasSuffix(domain, "."+m.Domain))) &&
m.MatchProtocolPort(r.Protocol, r.Port)
}
type countryMatcher struct {
matcherBase
Country string // ISO 3166-1 alpha-2 country code, upper case
}
func (m *countryMatcher) Match(r MatchRequest) bool {
if r.IP == nil || r.DB == nil {
return false
}
c, err := r.DB.Country(r.IP)
if err != nil {
return false
}
return c.Country.IsoCode == m.Country && m.MatchProtocolPort(r.Protocol, r.Port)
}
type allMatcher struct {
matcherBase
}
func (m *allMatcher) Match(r MatchRequest) bool {
return m.MatchProtocolPort(r.Protocol, r.Port)
}
func (e Entry) Match(r MatchRequest) bool {
return e.Matcher.Match(r)
}
func ParseEntry(s string) (Entry, error) {
fields := strings.Fields(s)
if len(fields) < 2 {
return Entry{}, fmt.Errorf("expected at least 2 fields, got %d", len(fields))
}
e := Entry{}
action := fields[0]
conds := fields[1:]
switch strings.ToLower(action) {
case "direct":
e.Action = ActionDirect
case "proxy":
e.Action = ActionProxy
case "block":
e.Action = ActionBlock
case "hijack":
if len(conds) < 2 {
return Entry{}, fmt.Errorf("hijack requires at least 3 fields, got %d", len(fields))
}
e.Action = ActionHijack
e.ActionArg = conds[len(conds)-1]
conds = conds[:len(conds)-1]
default:
return Entry{}, fmt.Errorf("invalid action %s", fields[0])
}
m, err := condsToMatcher(conds)
if err != nil {
return Entry{}, err
}
e.Matcher = m
return e, nil
}
func condsToMatcher(conds []string) (Matcher, error) {
if len(conds) < 1 {
return nil, errors.New("no condition specified")
}
typ, args := conds[0], conds[1:]
switch strings.ToLower(typ) {
case "domain":
// domain <domain> <optional: protocol/port>
if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for domain: %d, expected 1 or 2", len(args))
}
mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &domainMatcher{
matcherBase: mb,
Domain: args[0],
Suffix: false,
}, nil
case "domain-suffix":
// domain-suffix <domain> <optional: protocol/port>
if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for domain-suffix: %d, expected 1 or 2", len(args))
}
mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &domainMatcher{
matcherBase: mb,
Domain: args[0],
Suffix: true,
}, nil
case "cidr":
// cidr <cidr> <optional: protocol/port>
if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for cidr: %d, expected 1 or 2", len(args))
}
mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
_, ipNet, err := net.ParseCIDR(args[0])
if err != nil {
return nil, err
}
return &netMatcher{
matcherBase: mb,
Net: ipNet,
}, nil
case "ip":
// ip <ip> <optional: protocol/port>
if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for ip: %d, expected 1 or 2", len(args))
}
mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
ip := net.ParseIP(args[0])
if ip == nil {
return nil, fmt.Errorf("invalid ip: %s", args[0])
}
var ipNet *net.IPNet
if ip.To4() != nil {
ipNet = &net.IPNet{
IP: ip,
Mask: net.CIDRMask(32, 32),
}
} else {
ipNet = &net.IPNet{
IP: ip,
Mask: net.CIDRMask(128, 128),
}
}
return &netMatcher{
matcherBase: mb,
Net: ipNet,
}, nil
case "country":
// country <country> <optional: protocol/port>
if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for country: %d, expected 1 or 2", len(args))
}
mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &countryMatcher{
matcherBase: mb,
Country: strings.ToUpper(args[0]),
}, nil
case "all":
// all <optional: protocol/port>
if len(args) > 1 {
return nil, fmt.Errorf("invalid number of arguments for all: %d, expected 0 or 1", len(args))
}
mb := matcherBase{}
if len(args) == 1 {
protocol, port, err := parseProtocolPort(args[0])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &allMatcher{
matcherBase: mb,
}, nil
default:
return nil, fmt.Errorf("invalid condition type: %s", typ)
}
}

View File

@ -0,0 +1,75 @@
package acl
import (
"net"
"reflect"
"testing"
)
func TestParseEntry(t *testing.T) {
_, ok3net, _ := net.ParseCIDR("8.8.8.0/24")
type args struct {
s string
}
tests := []struct {
name string
args args
want Entry
wantErr bool
}{
{name: "empty", args: args{""}, want: Entry{}, wantErr: true},
{name: "ok 1", args: args{"direct domain-suffix google.com"},
want: Entry{ActionDirect, "", &domainMatcher{
matcherBase: matcherBase{},
Domain: "google.com",
Suffix: true,
}},
wantErr: false},
{name: "ok 2", args: args{"proxy domain shithole"},
want: Entry{ActionProxy, "", &domainMatcher{
matcherBase: matcherBase{},
Domain: "shithole",
Suffix: false,
}},
wantErr: false},
{name: "ok 3", args: args{"block cidr 8.8.8.0/24 */53"},
want: Entry{ActionBlock, "", &netMatcher{
matcherBase: matcherBase{ProtocolAll, 53},
Net: ok3net,
}},
wantErr: false},
{name: "ok 4", args: args{"hijack all udp/* udpblackhole.net"},
want: Entry{ActionHijack, "udpblackhole.net", &allMatcher{
matcherBase: matcherBase{ProtocolUDP, 0},
}},
wantErr: false},
{name: "err 1", args: args{"what the heck"},
want: Entry{},
wantErr: true},
{name: "err 2", args: args{"proxy sucks ass"},
want: Entry{},
wantErr: true},
{name: "err 3", args: args{"block ip 999.999.999.999"},
want: Entry{},
wantErr: true},
{name: "err 4", args: args{"hijack domain google.com"},
want: Entry{},
wantErr: true},
{name: "err 5", args: args{"hijack domain google.com bing.com 123"},
want: Entry{},
wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseEntry(tt.args.s)
if (err != nil) != tt.wantErr {
t.Errorf("ParseEntry() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseEntry() got = %v, wantAction %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,145 @@
package congestion
import (
"github.com/lucas-clemente/quic-go/congestion"
"time"
)
const (
initMaxDatagramSize = 1252
pktInfoSlotCount = 4
minSampleCount = 50
minAckRate = 0.8
)
type BrutalSender struct {
rttStats congestion.RTTStatsProvider
bps congestion.ByteCount
maxDatagramSize congestion.ByteCount
pacer *pacer
pktInfoSlots [pktInfoSlotCount]pktInfo
ackRate float64
}
type pktInfo struct {
Timestamp int64
AckCount uint64
LossCount uint64
}
func NewBrutalSender(bps congestion.ByteCount) *BrutalSender {
bs := &BrutalSender{
bps: bps,
maxDatagramSize: initMaxDatagramSize,
ackRate: 1,
}
bs.pacer = newPacer(func() congestion.ByteCount {
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
})
return bs
}
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
b.rttStats = rttStats
}
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
return b.pacer.TimeUntilSend()
}
func (b *BrutalSender) HasPacingBudget() bool {
return b.pacer.Budget(time.Now()) >= b.maxDatagramSize
}
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
return bytesInFlight < b.GetCongestionWindow()
}
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
rtt := maxDuration(b.rttStats.LatestRTT(), b.rttStats.SmoothedRTT())
if rtt <= 0 {
return 10240
}
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
}
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool) {
b.pacer.SentPacket(sentTime, bytes)
}
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount, eventTime time.Time) {
currentTimestamp := eventTime.Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].AckCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 1
b.pktInfoSlots[slot].LossCount = 0
}
b.updateAckRate(currentTimestamp)
}
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
priorInFlight congestion.ByteCount) {
currentTimestamp := time.Now().Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].LossCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 0
b.pktInfoSlots[slot].LossCount = 1
}
b.updateAckRate(currentTimestamp)
}
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
b.maxDatagramSize = size
b.pacer.SetMaxDatagramSize(size)
}
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
minTimestamp := currentTimestamp - pktInfoSlotCount
var ackCount, lossCount uint64
for _, info := range b.pktInfoSlots {
if info.Timestamp < minTimestamp {
continue
}
ackCount += info.AckCount
lossCount += info.LossCount
}
if ackCount+lossCount < minSampleCount {
b.ackRate = 1
}
rate := float64(ackCount) / float64(ackCount+lossCount)
if rate < minAckRate {
b.ackRate = minAckRate
}
b.ackRate = rate
}
func (b *BrutalSender) InSlowStart() bool {
return false
}
func (b *BrutalSender) InRecovery() bool {
return false
}
func (b *BrutalSender) MaybeExitSlowStart() {}
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
func maxDuration(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}

View File

@ -0,0 +1,85 @@
package congestion
import (
"github.com/lucas-clemente/quic-go/congestion"
"math"
"time"
)
const (
maxBurstPackets = 10
minPacingDelay = time.Millisecond
)
// The pacer implements a token bucket pacing algorithm.
type pacer struct {
budgetAtLastSent congestion.ByteCount
maxDatagramSize congestion.ByteCount
lastSentTime time.Time
getBandwidth func() congestion.ByteCount // in bytes/s
}
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
p := &pacer{
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
maxDatagramSize: initMaxDatagramSize,
getBandwidth: getBandwidth,
}
return p
}
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
budget := p.Budget(sendTime)
if size > budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
return minByteCount(p.maxBurstSize(), budget)
}
func (p *pacer) maxBurstSize() congestion.ByteCount {
return maxByteCount(
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
maxBurstPackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
return p.lastSentTime.Add(maxDuration(
minPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
float64(p.getBandwidth())))*time.Nanosecond,
))
}
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
p.maxDatagramSize = s
}
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return b
}
return a
}
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return a
}
return b
}

View File

@ -0,0 +1 @@
Grabbed from https://github.com/xtaci/tcpraw with modifications

View File

@ -0,0 +1,102 @@
package faketcp
import (
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/hysteria/obfs"
"net"
"sync"
"syscall"
"time"
)
const udpBufferSize = 65535
type ObfsFakeTCPConn struct {
orig *TCPConn
obfs obfs.Obfuscator
closed bool
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewObfsFakeTCPConn(orig *TCPConn, obfs obfs.Obfuscator) *ObfsFakeTCPConn {
return &ObfsFakeTCPConn{
orig: orig,
obfs: obfs,
readBuf: make([]byte, udpBufferSize),
writeBuf: make([]byte, udpBufferSize),
}
}
func (c *ObfsFakeTCPConn) ReadFrom(p []byte) (int, net.Addr, error) {
for {
c.readMutex.Lock()
if c.closed {
log.Infoln("read faketcp obfs before")
}
n, addr, err := c.orig.ReadFrom(c.readBuf)
if c.closed {
log.Infoln("read faketcp obfs after")
}
if n <= 0 {
c.readMutex.Unlock()
return 0, addr, err
}
newN := c.obfs.Deobfuscate(c.readBuf[:n], p)
c.readMutex.Unlock()
if newN > 0 {
// Valid packet
return newN, addr, err
} else if err != nil {
// Not valid and orig.ReadFrom had some error
return 0, addr, err
}
}
}
func (c *ObfsFakeTCPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.writeMutex.Lock()
bn := c.obfs.Obfuscate(p, c.writeBuf)
_, err = c.orig.WriteTo(c.writeBuf[:bn], addr)
c.writeMutex.Unlock()
if err != nil {
return 0, err
} else {
return len(p), nil
}
}
func (c *ObfsFakeTCPConn) Close() error {
c.closed = true
return c.orig.Close()
}
func (c *ObfsFakeTCPConn) LocalAddr() net.Addr {
return c.orig.LocalAddr()
}
func (c *ObfsFakeTCPConn) SetDeadline(t time.Time) error {
return c.orig.SetDeadline(t)
}
func (c *ObfsFakeTCPConn) SetReadDeadline(t time.Time) error {
return c.orig.SetReadDeadline(t)
}
func (c *ObfsFakeTCPConn) SetWriteDeadline(t time.Time) error {
return c.orig.SetWriteDeadline(t)
}
func (c *ObfsFakeTCPConn) SetReadBuffer(bytes int) error {
return c.orig.SetReadBuffer(bytes)
}
func (c *ObfsFakeTCPConn) SetWriteBuffer(bytes int) error {
return c.orig.SetWriteBuffer(bytes)
}
func (c *ObfsFakeTCPConn) SyscallConn() (syscall.RawConn, error) {
return c.orig.SyscallConn()
}

View File

@ -0,0 +1,616 @@
//go:build linux
// +build linux
package faketcp
import (
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/coreos/go-iptables/iptables"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
var (
errOpNotImplemented = errors.New("operation not implemented")
errTimeout = errors.New("timeout")
expire = time.Minute
)
// a message from NIC
type message struct {
bts []byte
addr net.Addr
}
// a tcp flow information of a connection pair
type tcpFlow struct {
conn *net.TCPConn // the related system TCP connection of this flow
handle *net.IPConn // the handle to send packets
seq uint32 // TCP sequence number
ack uint32 // TCP acknowledge number
networkLayer gopacket.SerializableLayer // network layer header for tx
ts time.Time // last packet incoming time
buf gopacket.SerializeBuffer // a buffer for write
tcpHeader layers.TCP
}
// TCPConn defines a TCP-packet oriented connection
type TCPConn struct {
die chan struct{}
dieOnce sync.Once
// the main golang sockets
tcpconn *net.TCPConn // from net.Dial
listener *net.TCPListener // from net.Listen
// handles
handles []*net.IPConn
// packets captured from all related NICs will be delivered to this channel
chMessage chan message
// all TCP flows
flowTable map[string]*tcpFlow
flowsLock sync.Mutex
// iptables
iptables *iptables.IPTables
iprule []string
ip6tables *iptables.IPTables
ip6rule []string
// deadlines
readDeadline atomic.Value
writeDeadline atomic.Value
// serialization
opts gopacket.SerializeOptions
}
// lockflow locks the flow table and apply function `f` to the entry, and create one if not exist
func (conn *TCPConn) lockflow(addr net.Addr, f func(e *tcpFlow)) {
key := addr.String()
conn.flowsLock.Lock()
e := conn.flowTable[key]
if e == nil { // entry first visit
e = new(tcpFlow)
e.ts = time.Now()
e.buf = gopacket.NewSerializeBuffer()
}
f(e)
conn.flowTable[key] = e
conn.flowsLock.Unlock()
}
// clean expired flows
func (conn *TCPConn) cleaner() {
ticker := time.NewTicker(time.Minute)
select {
case <-conn.die:
return
case <-ticker.C:
conn.flowsLock.Lock()
for k, v := range conn.flowTable {
if time.Now().Sub(v.ts) > expire {
if v.conn != nil {
setTTL(v.conn, 64)
v.conn.Close()
}
delete(conn.flowTable, k)
}
}
conn.flowsLock.Unlock()
}
}
// captureFlow capture every inbound packets based on rules of BPF
func (conn *TCPConn) captureFlow(handle *net.IPConn, port int) {
buf := make([]byte, 2048)
opt := gopacket.DecodeOptions{NoCopy: true, Lazy: true}
for {
n, addr, err := handle.ReadFromIP(buf)
if err != nil {
return
}
// try decoding TCP frame from buf[:n]
packet := gopacket.NewPacket(buf[:n], layers.LayerTypeTCP, opt)
transport := packet.TransportLayer()
tcp, ok := transport.(*layers.TCP)
if !ok {
continue
}
// port filtering
if int(tcp.DstPort) != port {
continue
}
// address building
var src net.TCPAddr
src.IP = addr.IP
src.Port = int(tcp.SrcPort)
var orphan bool
// flow maintaince
conn.lockflow(&src, func(e *tcpFlow) {
if e.conn == nil { // make sure it's related to net.TCPConn
orphan = true // mark as orphan if it's not related net.TCPConn
}
// to keep track of TCP header related to this source
e.ts = time.Now()
if tcp.ACK {
e.seq = tcp.Ack
}
if tcp.SYN {
e.ack = tcp.Seq + 1
}
if tcp.PSH {
if e.ack == tcp.Seq {
e.ack = tcp.Seq + uint32(len(tcp.Payload))
}
}
e.handle = handle
})
// push data if it's not orphan
if !orphan && tcp.PSH {
payload := make([]byte, len(tcp.Payload))
copy(payload, tcp.Payload)
select {
case conn.chMessage <- message{payload, &src}:
case <-conn.die:
return
}
}
}
}
// ReadFrom implements the PacketConn ReadFrom method.
func (conn *TCPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
var timer *time.Timer
var deadline <-chan time.Time
if d, ok := conn.readDeadline.Load().(time.Time); ok && !d.IsZero() {
timer = time.NewTimer(time.Until(d))
defer timer.Stop()
deadline = timer.C
}
select {
case <-deadline:
return 0, nil, errTimeout
case <-conn.die:
return 0, nil, io.EOF
case packet := <-conn.chMessage:
n = copy(p, packet.bts)
return n, packet.addr, nil
}
}
// WriteTo implements the PacketConn WriteTo method.
func (conn *TCPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
var deadline <-chan time.Time
if d, ok := conn.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
timer := time.NewTimer(time.Until(d))
defer timer.Stop()
deadline = timer.C
}
select {
case <-deadline:
return 0, errTimeout
case <-conn.die:
return 0, io.EOF
default:
raddr, err := net.ResolveTCPAddr("tcp", addr.String())
if err != nil {
return 0, err
}
var lport int
if conn.tcpconn != nil {
lport = conn.tcpconn.LocalAddr().(*net.TCPAddr).Port
} else {
lport = conn.listener.Addr().(*net.TCPAddr).Port
}
conn.lockflow(addr, func(e *tcpFlow) {
// if the flow doesn't have handle , assume this packet has lost, without notification
if e.handle == nil {
n = len(p)
return
}
// build tcp header with local and remote port
e.tcpHeader.SrcPort = layers.TCPPort(lport)
e.tcpHeader.DstPort = layers.TCPPort(raddr.Port)
binary.Read(rand.Reader, binary.LittleEndian, &e.tcpHeader.Window)
e.tcpHeader.Window |= 0x8000 // make sure it's larger than 32768
e.tcpHeader.Ack = e.ack
e.tcpHeader.Seq = e.seq
e.tcpHeader.PSH = true
e.tcpHeader.ACK = true
// build IP header with src & dst ip for TCP checksum
if raddr.IP.To4() != nil {
ip := &layers.IPv4{
Protocol: layers.IPProtocolTCP,
SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To4(),
DstIP: raddr.IP.To4(),
}
e.tcpHeader.SetNetworkLayerForChecksum(ip)
} else {
ip := &layers.IPv6{
NextHeader: layers.IPProtocolTCP,
SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To16(),
DstIP: raddr.IP.To16(),
}
e.tcpHeader.SetNetworkLayerForChecksum(ip)
}
e.buf.Clear()
gopacket.SerializeLayers(e.buf, conn.opts, &e.tcpHeader, gopacket.Payload(p))
if conn.tcpconn != nil {
_, err = e.handle.Write(e.buf.Bytes())
} else {
_, err = e.handle.WriteToIP(e.buf.Bytes(), &net.IPAddr{IP: raddr.IP})
}
// increase seq in flow
e.seq += uint32(len(p))
n = len(p)
})
}
return
}
// Close closes the connection.
func (conn *TCPConn) Close() error {
var err error
conn.dieOnce.Do(func() {
// signal closing
close(conn.die)
// close all established tcp connections
if conn.tcpconn != nil { // client
setTTL(conn.tcpconn, 64)
err = conn.tcpconn.Close()
} else if conn.listener != nil {
err = conn.listener.Close() // server
conn.flowsLock.Lock()
for k, v := range conn.flowTable {
if v.conn != nil {
setTTL(v.conn, 64)
v.conn.Close()
}
delete(conn.flowTable, k)
}
conn.flowsLock.Unlock()
}
// close handles
for k := range conn.handles {
conn.handles[k].Close()
}
// delete iptable
if conn.iptables != nil {
conn.iptables.Delete("filter", "OUTPUT", conn.iprule...)
}
if conn.ip6tables != nil {
conn.ip6tables.Delete("filter", "OUTPUT", conn.ip6rule...)
}
})
return err
}
// LocalAddr returns the local network address.
func (conn *TCPConn) LocalAddr() net.Addr {
if conn.tcpconn != nil {
return conn.tcpconn.LocalAddr()
} else if conn.listener != nil {
return conn.listener.Addr()
}
return nil
}
// SetDeadline implements the Conn SetDeadline method.
func (conn *TCPConn) SetDeadline(t time.Time) error {
if err := conn.SetReadDeadline(t); err != nil {
return err
}
if err := conn.SetWriteDeadline(t); err != nil {
return err
}
return nil
}
// SetReadDeadline implements the Conn SetReadDeadline method.
func (conn *TCPConn) SetReadDeadline(t time.Time) error {
conn.readDeadline.Store(t)
return nil
}
// SetWriteDeadline implements the Conn SetWriteDeadline method.
func (conn *TCPConn) SetWriteDeadline(t time.Time) error {
conn.writeDeadline.Store(t)
return nil
}
// SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
func (conn *TCPConn) SetDSCP(dscp int) error {
for k := range conn.handles {
if err := setDSCP(conn.handles[k], dscp); err != nil {
return err
}
}
return nil
}
// SetReadBuffer sets the size of the operating system's receive buffer associated with the connection.
func (conn *TCPConn) SetReadBuffer(bytes int) error {
var err error
for k := range conn.handles {
if err := conn.handles[k].SetReadBuffer(bytes); err != nil {
return err
}
}
return err
}
// SetWriteBuffer sets the size of the operating system's transmit buffer associated with the connection.
func (conn *TCPConn) SetWriteBuffer(bytes int) error {
var err error
for k := range conn.handles {
if err := conn.handles[k].SetWriteBuffer(bytes); err != nil {
return err
}
}
return err
}
func (conn *TCPConn) SyscallConn() (syscall.RawConn, error) {
if len(conn.handles) == 0 {
return nil, errors.New("no handles")
// How is it possible?
}
return conn.handles[0].SyscallConn()
}
// Dial connects to the remote TCP port,
// and returns a single packet-oriented connection
func Dial(network, address string) (*TCPConn, error) {
// remote address resolve
raddr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
// AF_INET
handle, err := net.DialIP("ip:tcp", nil, &net.IPAddr{IP: raddr.IP})
if err != nil {
return nil, err
}
// create an established tcp connection
// will hack this tcp connection for packet transmission
tcpconn, err := net.DialTCP(network, nil, raddr)
if err != nil {
return nil, err
}
// fields
conn := new(TCPConn)
conn.die = make(chan struct{})
conn.flowTable = make(map[string]*tcpFlow)
conn.tcpconn = tcpconn
conn.chMessage = make(chan message)
conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn })
conn.handles = append(conn.handles, handle)
conn.opts = gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
go conn.captureFlow(handle, tcpconn.LocalAddr().(*net.TCPAddr).Port)
go conn.cleaner()
// iptables
err = setTTL(tcpconn, 1)
if err != nil {
return nil, err
}
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil {
rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"}
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
if !exists {
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
conn.iprule = rule
conn.iptables = ipt
}
}
}
}
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil {
rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"}
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
if !exists {
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
conn.ip6rule = rule
conn.ip6tables = ipt
}
}
}
}
// discard everything
go io.Copy(ioutil.Discard, tcpconn)
return conn, nil
}
// Listen acts like net.ListenTCP,
// and returns a single packet-oriented connection
func Listen(network, address string) (*TCPConn, error) {
// fields
conn := new(TCPConn)
conn.flowTable = make(map[string]*tcpFlow)
conn.die = make(chan struct{})
conn.chMessage = make(chan message)
conn.opts = gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
// resolve address
laddr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
// AF_INET
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
if laddr.IP == nil || laddr.IP.IsUnspecified() { // if address is not specified, capture on all ifaces
var lasterr error
for _, iface := range ifaces {
if addrs, err := iface.Addrs(); err == nil {
for _, addr := range addrs {
if ipaddr, ok := addr.(*net.IPNet); ok {
if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: ipaddr.IP}); err == nil {
conn.handles = append(conn.handles, handle)
go conn.captureFlow(handle, laddr.Port)
} else {
lasterr = err
}
}
}
}
}
if len(conn.handles) == 0 {
return nil, lasterr
}
} else {
if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: laddr.IP}); err == nil {
conn.handles = append(conn.handles, handle)
go conn.captureFlow(handle, laddr.Port)
} else {
return nil, err
}
}
// start listening
l, err := net.ListenTCP(network, laddr)
if err != nil {
return nil, err
}
conn.listener = l
// start cleaner
go conn.cleaner()
// iptables drop packets marked with TTL = 1
// TODO: what if iptables is not available, the next hop will send back ICMP Time Exceeded,
// is this still an acceptable behavior?
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil {
rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"}
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
if !exists {
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
conn.iprule = rule
conn.iptables = ipt
}
}
}
}
if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil {
rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"}
if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
if !exists {
if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
conn.ip6rule = rule
conn.ip6tables = ipt
}
}
}
}
// discard everything in original connection
go func() {
for {
tcpconn, err := l.AcceptTCP()
if err != nil {
return
}
// if we cannot set TTL = 1, the only thing reasonable is panic
if err := setTTL(tcpconn, 1); err != nil {
panic(err)
}
// record net.Conn
conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn })
// discard everything
go io.Copy(ioutil.Discard, tcpconn)
}
}()
return conn, nil
}
// setTTL sets the Time-To-Live field on a given connection
func setTTL(c *net.TCPConn, ttl int) error {
raw, err := c.SyscallConn()
if err != nil {
return err
}
addr := c.LocalAddr().(*net.TCPAddr)
if addr.IP.To4() == nil {
raw.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, ttl)
})
} else {
raw.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, ttl)
})
}
return err
}
// setDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
func setDSCP(c *net.IPConn, dscp int) error {
raw, err := c.SyscallConn()
if err != nil {
return err
}
addr := c.LocalAddr().(*net.IPAddr)
if addr.IP.To4() == nil {
raw.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, dscp)
})
} else {
raw.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, dscp<<2)
})
}
return err
}

View File

@ -0,0 +1,21 @@
//go:build !linux
// +build !linux
package faketcp
import (
"errors"
"net"
)
type TCPConn struct{ *net.UDPConn }
// Dial connects to the remote TCP port,
// and returns a single packet-oriented connection
func Dial(network, address string) (*TCPConn, error) {
return nil, errors.New("faketcp is not supported on this platform")
}
func Listen(network, address string) (*TCPConn, error) {
return nil, errors.New("faketcp is not supported on this platform")
}

View File

@ -0,0 +1,196 @@
//go:build linux
// +build linux
package faketcp
import (
"log"
"net"
"net/http"
_ "net/http/pprof"
"testing"
)
//const testPortStream = "127.0.0.1:3456"
//const testPortPacket = "127.0.0.1:3457"
const testPortStream = "127.0.0.1:3456"
const portServerPacket = "[::]:3457"
const portRemotePacket = "127.0.0.1:3457"
func init() {
startTCPServer()
startTCPRawServer()
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
}
func startTCPServer() net.Listener {
l, err := net.Listen("tcp", testPortStream)
if err != nil {
log.Panicln(err)
}
go func() {
defer l.Close()
for {
conn, err := l.Accept()
if err != nil {
log.Println(err)
return
}
go handleRequest(conn)
}
}()
return l
}
func startTCPRawServer() *TCPConn {
conn, err := Listen("tcp", portServerPacket)
if err != nil {
log.Panicln(err)
}
err = conn.SetReadBuffer(1024 * 1024)
if err != nil {
log.Println(err)
}
err = conn.SetWriteBuffer(1024 * 1024)
if err != nil {
log.Println(err)
}
go func() {
defer conn.Close()
buf := make([]byte, 1024)
for {
n, addr, err := conn.ReadFrom(buf)
if err != nil {
log.Println("server readfrom:", err)
return
}
//echo
n, err = conn.WriteTo(buf[:n], addr)
if err != nil {
log.Println("server writeTo:", err)
return
}
}
}()
return conn
}
func handleRequest(conn net.Conn) {
defer conn.Close()
for {
buf := make([]byte, 1024)
size, err := conn.Read(buf)
if err != nil {
log.Println("handleRequest:", err)
return
}
data := buf[:size]
conn.Write(data)
}
}
func TestDialTCPStream(t *testing.T) {
conn, err := Dial("tcp", testPortStream)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
addr, err := net.ResolveTCPAddr("tcp", testPortStream)
if err != nil {
t.Fatal(err)
}
n, err := conn.WriteTo([]byte("abc"), addr)
if err != nil {
t.Fatal(n, err)
}
buf := make([]byte, 1024)
if n, addr, err := conn.ReadFrom(buf); err != nil {
t.Fatal(n, addr, err)
} else {
log.Println(string(buf[:n]), "from:", addr)
}
}
func TestDialToTCPPacket(t *testing.T) {
conn, err := Dial("tcp", portRemotePacket)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
addr, err := net.ResolveTCPAddr("tcp", portRemotePacket)
if err != nil {
t.Fatal(err)
}
n, err := conn.WriteTo([]byte("abc"), addr)
if err != nil {
t.Fatal(n, err)
}
log.Println("written")
buf := make([]byte, 1024)
log.Println("readfrom buf")
if n, addr, err := conn.ReadFrom(buf); err != nil {
log.Println(err)
t.Fatal(n, addr, err)
} else {
log.Println(string(buf[:n]), "from:", addr)
}
log.Println("complete")
}
func TestSettings(t *testing.T) {
conn, err := Dial("tcp", portRemotePacket)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
if err := conn.SetDSCP(46); err != nil {
log.Fatal("SetDSCP:", err)
}
if err := conn.SetReadBuffer(4096); err != nil {
log.Fatal("SetReaderBuffer:", err)
}
if err := conn.SetWriteBuffer(4096); err != nil {
log.Fatal("SetWriteBuffer:", err)
}
}
func BenchmarkEcho(b *testing.B) {
conn, err := Dial("tcp", portRemotePacket)
if err != nil {
b.Fatal(err)
}
defer conn.Close()
addr, err := net.ResolveTCPAddr("tcp", portRemotePacket)
if err != nil {
b.Fatal(err)
}
buf := make([]byte, 1024)
b.ReportAllocs()
b.SetBytes(int64(len(buf)))
for i := 0; i < b.N; i++ {
n, err := conn.WriteTo(buf, addr)
if err != nil {
b.Fatal(n, err)
}
if n, addr, err := conn.ReadFrom(buf); err != nil {
b.Fatal(n, addr, err)
}
}
}

View File

@ -0,0 +1,89 @@
package udp
import (
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/hysteria/obfs"
"net"
"sync"
"time"
)
const udpBufferSize = 65535
type ObfsUDPConn struct {
orig net.PacketConn
obfs obfs.Obfuscator
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
closed bool
}
func NewObfsUDPConn(orig net.PacketConn, obfs obfs.Obfuscator) *ObfsUDPConn {
return &ObfsUDPConn{
orig: orig,
obfs: obfs,
readBuf: make([]byte, udpBufferSize),
writeBuf: make([]byte, udpBufferSize),
}
}
func (c *ObfsUDPConn) ReadFrom(p []byte) (int, net.Addr, error) {
for {
c.readMutex.Lock()
if c.closed {
log.Infoln("read udp obfs before")
}
n, addr, err := c.orig.ReadFrom(c.readBuf)
if c.closed {
log.Infoln("read udp obfs after")
}
if n <= 0 {
c.readMutex.Unlock()
return 0, addr, err
}
newN := c.obfs.Deobfuscate(c.readBuf[:n], p)
c.readMutex.Unlock()
if newN > 0 {
// Valid packet
return newN, addr, err
} else if err != nil {
// Not valid and orig.ReadFrom had some error
return 0, addr, err
}
}
}
func (c *ObfsUDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.writeMutex.Lock()
bn := c.obfs.Obfuscate(p, c.writeBuf)
_, err = c.orig.WriteTo(c.writeBuf[:bn], addr)
c.writeMutex.Unlock()
if err != nil {
return 0, err
} else {
return len(p), nil
}
}
func (c *ObfsUDPConn) Close() error {
c.closed = true
return c.orig.Close()
}
func (c *ObfsUDPConn) LocalAddr() net.Addr {
return c.orig.LocalAddr()
}
func (c *ObfsUDPConn) SetDeadline(t time.Time) error {
return c.orig.SetDeadline(t)
}
func (c *ObfsUDPConn) SetReadDeadline(t time.Time) error {
return c.orig.SetReadDeadline(t)
}
func (c *ObfsUDPConn) SetWriteDeadline(t time.Time) error {
return c.orig.SetWriteDeadline(t)
}

View File

@ -0,0 +1,105 @@
package wechat
import (
"encoding/binary"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/hysteria/obfs"
"math/rand"
"net"
"sync"
"time"
)
const udpBufferSize = 65535
type ObfsWeChatUDPConn struct {
orig net.PacketConn
obfs obfs.Obfuscator
closed bool
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
sn uint32
}
func NewObfsWeChatUDPConn(orig net.PacketConn, obfs obfs.Obfuscator) *ObfsWeChatUDPConn {
log.Infoln("new wechat")
return &ObfsWeChatUDPConn{
orig: orig,
obfs: obfs,
readBuf: make([]byte, udpBufferSize),
writeBuf: make([]byte, udpBufferSize),
sn: rand.Uint32() & 0xFFFF,
}
}
func (c *ObfsWeChatUDPConn) ReadFrom(p []byte) (int, net.Addr, error) {
for {
c.readMutex.Lock()
if c.closed {
log.Infoln("read wechat obfs before")
}
n, addr, err := c.orig.ReadFrom(c.readBuf)
if c.closed {
log.Infoln("read wechat obfs after")
}
if n <= 13 {
c.readMutex.Unlock()
return 0, addr, err
}
newN := c.obfs.Deobfuscate(c.readBuf[13:n], p)
c.readMutex.Unlock()
if newN > 0 {
// Valid packet
return newN, addr, err
} else if err != nil {
// Not valid and orig.ReadFrom had some error
return 0, addr, err
}
}
}
func (c *ObfsWeChatUDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.writeMutex.Lock()
c.writeBuf[0] = 0xa1
c.writeBuf[1] = 0x08
binary.BigEndian.PutUint32(c.writeBuf[2:], c.sn)
c.sn++
c.writeBuf[6] = 0x00
c.writeBuf[7] = 0x10
c.writeBuf[8] = 0x11
c.writeBuf[9] = 0x18
c.writeBuf[10] = 0x30
c.writeBuf[11] = 0x22
c.writeBuf[12] = 0x30
bn := c.obfs.Obfuscate(p, c.writeBuf[13:])
_, err = c.orig.WriteTo(c.writeBuf[:13+bn], addr)
c.writeMutex.Unlock()
if err != nil {
return 0, err
} else {
return len(p), nil
}
}
func (c *ObfsWeChatUDPConn) Close() error {
c.closed = true
return c.orig.Close()
}
func (c *ObfsWeChatUDPConn) LocalAddr() net.Addr {
return c.orig.LocalAddr()
}
func (c *ObfsWeChatUDPConn) SetDeadline(t time.Time) error {
return c.orig.SetDeadline(t)
}
func (c *ObfsWeChatUDPConn) SetReadDeadline(t time.Time) error {
return c.orig.SetReadDeadline(t)
}
func (c *ObfsWeChatUDPConn) SetWriteDeadline(t time.Time) error {
return c.orig.SetWriteDeadline(t)
}

View File

@ -0,0 +1,422 @@
package core
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/Dreamacro/clash/transport/hysteria/obfs"
"github.com/Dreamacro/clash/transport/hysteria/pmtud_fix"
"github.com/Dreamacro/clash/transport/hysteria/transport"
"github.com/Dreamacro/clash/transport/hysteria/utils"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lunixbochs/struc"
"math/rand"
"net"
"strconv"
"sync"
"time"
)
var (
ErrClosed = errors.New("closed")
)
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
type Client struct {
transport *transport.ClientTransport
serverAddr string
protocol string
sendBPS, recvBPS uint64
auth []byte
congestionFactory CongestionFactory
obfuscator obfs.Obfuscator
tlsConfig *tls.Config
quicConfig *quic.Config
quicSession quic.Connection
reconnectMutex sync.Mutex
closed bool
udpSessionMutex sync.RWMutex
udpSessionMap map[uint32]chan *udpMessage
udpDefragger defragger
}
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
obfuscator obfs.Obfuscator) (*Client, error) {
quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery
c := &Client{
transport: transport,
serverAddr: serverAddr,
protocol: protocol,
sendBPS: sendBPS,
recvBPS: recvBPS,
auth: auth,
congestionFactory: congestionFactory,
obfuscator: obfuscator,
tlsConfig: tlsConfig,
quicConfig: quicConfig,
}
return c, nil
}
func (c *Client) connectToServer(dialer transport.PacketDialer) error {
qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator, dialer)
if err != nil {
return err
}
// Control stream
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
stream, err := qs.OpenStreamSync(ctx)
ctxCancel()
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return err
}
ok, msg, err := c.handleControlStream(qs, stream)
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return err
}
if !ok {
_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
return fmt.Errorf("auth error: %s", msg)
}
// All good
c.udpSessionMap = make(map[uint32]chan *udpMessage)
go c.handleMessage(qs)
c.quicSession = qs
return nil
}
func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bool, string, error) {
// Send protocol version
_, err := stream.Write([]byte{protocolVersion})
if err != nil {
return false, "", err
}
// Send client hello
err = struc.Pack(stream, &clientHello{
Rate: transmissionRate{
SendBPS: c.sendBPS,
RecvBPS: c.recvBPS,
},
Auth: c.auth,
})
if err != nil {
return false, "", err
}
// Receive server hello
var sh serverHello
err = struc.Unpack(stream, &sh)
if err != nil {
return false, "", err
}
// Set the congestion accordingly
if sh.OK && c.congestionFactory != nil {
qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS))
}
return sh.OK, sh.Message, nil
}
func (c *Client) handleMessage(qs quic.Connection) {
for {
msg, err := qs.ReceiveMessage()
if err != nil {
break
}
var udpMsg udpMessage
err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
if err != nil {
continue
}
dfMsg := c.udpDefragger.Feed(udpMsg)
if dfMsg == nil {
continue
}
c.udpSessionMutex.RLock()
ch, ok := c.udpSessionMap[dfMsg.SessionID]
if ok {
select {
case ch <- dfMsg:
// OK
default:
// Silently drop the message when the channel is full
}
}
c.udpSessionMutex.RUnlock()
}
}
func (c *Client) openStreamWithReconnect(dialer transport.PacketDialer) (quic.Connection, quic.Stream, error) {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
if c.closed {
return nil, nil, ErrClosed
}
if c.quicSession == nil {
if err := c.connectToServer(dialer); err != nil {
// Still error, oops
return nil, nil, err
}
}
stream, err := c.quicSession.OpenStream()
if err == nil {
// All good
return c.quicSession, &wrappedQUICStream{stream}, nil
}
// Something is wrong
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
// Temporary error, just return
return nil, nil, err
}
// Permanent error, need to reconnect
if err := c.connectToServer(dialer); err != nil {
// Still error, oops
return nil, nil, err
}
// We are not going to try again even if it still fails the second time
stream, err = c.quicSession.OpenStream()
return c.quicSession, &wrappedQUICStream{stream}, err
}
func (c *Client) DialTCP(addr string, dialer transport.PacketDialer) (net.Conn, error) {
host, port, err := utils.SplitHostPort(addr)
if err != nil {
return nil, err
}
session, stream, err := c.openStreamWithReconnect(dialer)
if err != nil {
return nil, err
}
// Send request
err = struc.Pack(stream, &clientRequest{
UDP: false,
Host: host,
Port: port,
})
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
var sr serverResponse
err = struc.Unpack(stream, &sr)
if err != nil {
_ = stream.Close()
return nil, err
}
if !sr.OK {
_ = stream.Close()
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
}
return &quicConn{
Orig: stream,
PseudoLocalAddr: session.LocalAddr(),
PseudoRemoteAddr: session.RemoteAddr(),
}, nil
}
func (c *Client) DialUDP(dialer transport.PacketDialer) (UDPConn, error) {
session, stream, err := c.openStreamWithReconnect(dialer)
if err != nil {
return nil, err
}
// Send request
err = struc.Pack(stream, &clientRequest{
UDP: true,
})
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
var sr serverResponse
err = struc.Unpack(stream, &sr)
if err != nil {
_ = stream.Close()
return nil, err
}
if !sr.OK {
_ = stream.Close()
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
}
// Create a session in the map
c.udpSessionMutex.Lock()
nCh := make(chan *udpMessage, 1024)
// Store the current session map for CloseFunc below
// to ensures that we are adding and removing sessions on the same map,
// as reconnecting will reassign the map
sessionMap := c.udpSessionMap
sessionMap[sr.UDPSessionID] = nCh
c.udpSessionMutex.Unlock()
pktConn := &quicPktConn{
Session: session,
Stream: stream,
CloseFunc: func() {
c.udpSessionMutex.Lock()
if ch, ok := sessionMap[sr.UDPSessionID]; ok {
close(ch)
delete(sessionMap, sr.UDPSessionID)
}
c.udpSessionMutex.Unlock()
},
UDPSessionID: sr.UDPSessionID,
MsgCh: nCh,
}
go pktConn.Hold()
return pktConn, nil
}
func (c *Client) Close() error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
c.closed = true
return err
}
type quicConn struct {
Orig quic.Stream
PseudoLocalAddr net.Addr
PseudoRemoteAddr net.Addr
}
func (w *quicConn) Read(b []byte) (n int, err error) {
return w.Orig.Read(b)
}
func (w *quicConn) Write(b []byte) (n int, err error) {
return w.Orig.Write(b)
}
func (w *quicConn) Close() error {
return w.Orig.Close()
}
func (w *quicConn) LocalAddr() net.Addr {
return w.PseudoLocalAddr
}
func (w *quicConn) RemoteAddr() net.Addr {
return w.PseudoRemoteAddr
}
func (w *quicConn) SetDeadline(t time.Time) error {
return w.Orig.SetDeadline(t)
}
func (w *quicConn) SetReadDeadline(t time.Time) error {
return w.Orig.SetReadDeadline(t)
}
func (w *quicConn) SetWriteDeadline(t time.Time) error {
return w.Orig.SetWriteDeadline(t)
}
type UDPConn interface {
ReadFrom() ([]byte, string, error)
WriteTo([]byte, string) error
Close() error
LocalAddr() net.Addr
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
}
type quicPktConn struct {
Session quic.Connection
Stream quic.Stream
CloseFunc func()
UDPSessionID uint32
MsgCh <-chan *udpMessage
}
func (c *quicPktConn) Hold() {
// Hold the stream until it's closed
buf := make([]byte, 1024)
for {
_, err := c.Stream.Read(buf)
if err != nil {
break
}
}
_ = c.Close()
}
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
msg := <-c.MsgCh
if msg == nil {
// Closed
return nil, "", ErrClosed
}
return msg.Data, net.JoinHostPort(msg.Host, strconv.Itoa(int(msg.Port))), nil
}
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
host, port, err := utils.SplitHostPort(addr)
if err != nil {
return err
}
msg := udpMessage{
SessionID: c.UDPSessionID,
Host: host,
Port: port,
FragCount: 1,
Data: p,
}
// try no frag first
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &msg)
err = c.Session.SendMessage(msgBuf.Bytes())
if err != nil {
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
// need to frag
msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
fragMsgs := fragUDPMessage(msg, int(errSize))
for _, fragMsg := range fragMsgs {
msgBuf.Reset()
_ = struc.Pack(&msgBuf, &fragMsg)
err = c.Session.SendMessage(msgBuf.Bytes())
if err != nil {
return err
}
}
return nil
} else {
// some other error
return err
}
} else {
return nil
}
}
func (c *quicPktConn) Close() error {
c.CloseFunc()
return c.Stream.Close()
}
func (c *quicPktConn) LocalAddr() net.Addr {
return c.Session.LocalAddr()
}
func (c *quicPktConn) SetDeadline(t time.Time) error {
return c.Stream.SetDeadline(t)
}
func (c *quicPktConn) SetReadDeadline(t time.Time) error {
return c.Stream.SetReadDeadline(t)
}
func (c *quicPktConn) SetWriteDeadline(t time.Time) error {
return c.Stream.SetWriteDeadline(t)
}

View File

@ -0,0 +1,67 @@
package core
func fragUDPMessage(m udpMessage, maxSize int) []udpMessage {
if m.Size() <= maxSize {
return []udpMessage{m}
}
fullPayload := m.Data
maxPayloadSize := maxSize - m.HeaderSize()
off := 0
fragID := uint8(0)
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
var frags []udpMessage
for off < len(fullPayload) {
payloadSize := len(fullPayload) - off
if payloadSize > maxPayloadSize {
payloadSize = maxPayloadSize
}
frag := m
frag.FragID = fragID
frag.FragCount = fragCount
frag.DataLen = uint16(payloadSize)
frag.Data = fullPayload[off : off+payloadSize]
frags = append(frags, frag)
off += payloadSize
fragID++
}
return frags
}
type defragger struct {
msgID uint16
frags []*udpMessage
count uint8
}
func (d *defragger) Feed(m udpMessage) *udpMessage {
if m.FragCount <= 1 {
return &m
}
if m.FragID >= m.FragCount {
// wtf is this?
return nil
}
if m.MsgID != d.msgID {
// new message, clear previous state
d.msgID = m.MsgID
d.frags = make([]*udpMessage, m.FragCount)
d.count = 1
d.frags[m.FragID] = &m
} else if d.frags[m.FragID] == nil {
d.frags[m.FragID] = &m
d.count++
if int(d.count) == len(d.frags) {
// all fragments received, assemble
var data []byte
for _, frag := range d.frags {
data = append(data, frag.Data...)
}
m.DataLen = uint16(len(data))
m.Data = data
m.FragID = 0
m.FragCount = 1
return &m
}
}
return nil
}

View File

@ -0,0 +1,346 @@
package core
import (
"reflect"
"testing"
)
func Test_fragUDPMessage(t *testing.T) {
type args struct {
m udpMessage
maxSize int
}
tests := []struct {
name string
args args
want []udpMessage
}{
{
"no frag",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 1,
DataLen: 5,
Data: []byte("hello"),
},
100,
},
[]udpMessage{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 1,
DataLen: 5,
Data: []byte("hello"),
},
},
},
{
"2 frags",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 1,
DataLen: 5,
Data: []byte("hello"),
},
22,
},
[]udpMessage{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 2,
DataLen: 4,
Data: []byte("hell"),
},
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 1,
FragCount: 2,
DataLen: 1,
Data: []byte("o"),
},
},
},
{
"4 frags",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 1,
DataLen: 20,
Data: []byte("wow wow wow lol lmao"),
},
23,
},
[]udpMessage{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 4,
DataLen: 5,
Data: []byte("wow w"),
},
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 1,
FragCount: 4,
DataLen: 5,
Data: []byte("ow wo"),
},
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 2,
FragCount: 4,
DataLen: 5,
Data: []byte("w lol"),
},
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 3,
FragCount: 4,
DataLen: 5,
Data: []byte(" lmao"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := fragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
t.Errorf("fragUDPMessage() = %v, want %v", got, tt.want)
}
})
}
}
func Test_defragger_Feed(t *testing.T) {
d := &defragger{}
type args struct {
m udpMessage
}
tests := []struct {
name string
args args
want *udpMessage
}{
{
"no frag",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 1,
DataLen: 5,
Data: []byte("hello"),
},
},
&udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 123,
FragID: 0,
FragCount: 1,
DataLen: 5,
Data: []byte("hello"),
},
},
{
"frag 1 - 1/3",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 666,
FragID: 0,
FragCount: 3,
DataLen: 5,
Data: []byte("hello"),
},
},
nil,
},
{
"frag 1 - 2/3",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 666,
FragID: 1,
FragCount: 3,
DataLen: 8,
Data: []byte(" shitty "),
},
},
nil,
},
{
"frag 1 - 3/3",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 666,
FragID: 2,
FragCount: 3,
DataLen: 7,
Data: []byte("world!!"),
},
},
&udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 666,
FragID: 0,
FragCount: 1,
DataLen: 20,
Data: []byte("hello shitty world!!"),
},
},
{
"frag 2 - 1/2",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 777,
FragID: 0,
FragCount: 2,
DataLen: 5,
Data: []byte("hello"),
},
},
nil,
},
{
"frag 3 - 2/2",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 778,
FragID: 1,
FragCount: 2,
DataLen: 5,
Data: []byte(" moto"),
},
},
nil,
},
{
"frag 2 - 2/2",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 777,
FragID: 1,
FragCount: 2,
DataLen: 5,
Data: []byte(" moto"),
},
},
nil,
},
{
"frag 2 - 1/2 re",
args{
udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 777,
FragID: 0,
FragCount: 2,
DataLen: 5,
Data: []byte("hello"),
},
},
&udpMessage{
SessionID: 123,
HostLen: 4,
Host: "test",
Port: 123,
MsgID: 777,
FragID: 0,
FragCount: 1,
DataLen: 10,
Data: []byte("hello moto"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := d.Feed(tt.args.m); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Feed() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,76 @@
package core
import (
"time"
)
const (
protocolVersion = uint8(3)
protocolVersionV2 = uint8(2)
protocolTimeout = 10 * time.Second
closeErrorCodeGeneric = 0
closeErrorCodeProtocol = 1
closeErrorCodeAuth = 2
)
type transmissionRate struct {
SendBPS uint64
RecvBPS uint64
}
type clientHello struct {
Rate transmissionRate
AuthLen uint16 `struc:"sizeof=Auth"`
Auth []byte
}
type serverHello struct {
OK bool
Rate transmissionRate
MessageLen uint16 `struc:"sizeof=Message"`
Message string
}
type clientRequest struct {
UDP bool
HostLen uint16 `struc:"sizeof=Host"`
Host string
Port uint16
}
type serverResponse struct {
OK bool
UDPSessionID uint32
MessageLen uint16 `struc:"sizeof=Message"`
Message string
}
type udpMessage struct {
SessionID uint32
HostLen uint16 `struc:"sizeof=Host"`
Host string
Port uint16
MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
FragCount uint8 // must be 1 when not fragmented
DataLen uint16 `struc:"sizeof=Data"`
Data []byte
}
func (m udpMessage) HeaderSize() int {
return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
}
func (m udpMessage) Size() int {
return m.HeaderSize() + len(m.Data)
}
type udpMessageV2 struct {
SessionID uint32
HostLen uint16 `struc:"sizeof=Host"`
Host string
Port uint16
DataLen uint16 `struc:"sizeof=Data"`
Data []byte
}

View File

@ -0,0 +1,54 @@
package core
import (
"context"
"github.com/lucas-clemente/quic-go"
"time"
)
// Handle stream close properly
// Ref: https://github.com/libp2p/go-libp2p-quic-transport/blob/master/stream.go
type wrappedQUICStream struct {
Stream quic.Stream
}
func (s *wrappedQUICStream) StreamID() quic.StreamID {
return s.Stream.StreamID()
}
func (s *wrappedQUICStream) Read(p []byte) (n int, err error) {
return s.Stream.Read(p)
}
func (s *wrappedQUICStream) CancelRead(code quic.StreamErrorCode) {
s.Stream.CancelRead(code)
}
func (s *wrappedQUICStream) SetReadDeadline(t time.Time) error {
return s.Stream.SetReadDeadline(t)
}
func (s *wrappedQUICStream) Write(p []byte) (n int, err error) {
return s.Stream.Write(p)
}
func (s *wrappedQUICStream) Close() error {
s.Stream.CancelRead(0)
return s.Stream.Close()
}
func (s *wrappedQUICStream) CancelWrite(code quic.StreamErrorCode) {
s.Stream.CancelWrite(code)
}
func (s *wrappedQUICStream) Context() context.Context {
return s.Stream.Context()
}
func (s *wrappedQUICStream) SetWriteDeadline(t time.Time) error {
return s.Stream.SetWriteDeadline(t)
}
func (s *wrappedQUICStream) SetDeadline(t time.Time) error {
return s.Stream.SetDeadline(t)
}

View File

@ -0,0 +1,6 @@
package obfs
type Obfuscator interface {
Deobfuscate(in []byte, out []byte) int
Obfuscate(in []byte, out []byte) int
}

View File

@ -0,0 +1,52 @@
package obfs
import (
"crypto/sha256"
"math/rand"
"sync"
"time"
)
// [salt][obfuscated payload]
const saltLen = 16
type XPlusObfuscator struct {
Key []byte
RandSrc *rand.Rand
lk sync.Mutex
}
func NewXPlusObfuscator(key []byte) *XPlusObfuscator {
return &XPlusObfuscator{
Key: key,
RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (x *XPlusObfuscator) Deobfuscate(in []byte, out []byte) int {
pLen := len(in) - saltLen
if pLen <= 0 || len(out) < pLen {
// Invalid
return 0
}
key := sha256.Sum256(append(x.Key, in[:saltLen]...))
// Deobfuscate the payload
for i, c := range in[saltLen:] {
out[i] = c ^ key[i%sha256.Size]
}
return pLen
}
func (x *XPlusObfuscator) Obfuscate(in []byte, out []byte) int {
x.lk.Lock()
_, _ = x.RandSrc.Read(out[:saltLen]) // salt
x.lk.Unlock()
// Obfuscate the payload
key := sha256.Sum256(append(x.Key, out[:saltLen]...))
for i, c := range in {
out[i+saltLen] = c ^ key[i%sha256.Size]
}
return len(in) + saltLen
}

View File

@ -0,0 +1,31 @@
package obfs
import (
"bytes"
"testing"
)
func TestXPlusObfuscator(t *testing.T) {
x := NewXPlusObfuscator([]byte("Vaundy"))
tests := []struct {
name string
p []byte
}{
{name: "1", p: []byte("HelloWorld")},
{name: "2", p: []byte("Regret is just a horrible attempt at time travel that ends with you feeling like crap")},
{name: "3", p: []byte("To be, or not to be, that is the question:\nWhether 'tis nobler in the mind to suffer\n" +
"The slings and arrows of outrageous fortune,\nOr to take arms against a sea of troubles\n" +
"And by opposing end them. To die—to sleep,\nNo more; and by a sleep to say we end")},
{name: "empty", p: []byte("")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := make([]byte, 10240)
n := x.Obfuscate(tt.p, buf)
n2 := x.Deobfuscate(buf[:n], buf[n:])
if !bytes.Equal(tt.p, buf[n:n+n2]) {
t.Errorf("Inconsistent deobfuscate result: got %v, want %v", buf[n:n+n2], tt.p)
}
})
}
}

View File

@ -0,0 +1,8 @@
//go:build linux || windows
// +build linux windows
package pmtud_fix
const (
DisablePathMTUDiscovery = false
)

View File

@ -0,0 +1,8 @@
//go:build !linux && !windows
// +build !linux,!windows
package pmtud_fix
const (
DisablePathMTUDiscovery = true
)

View File

@ -0,0 +1,106 @@
package transport
import (
"context"
"crypto/tls"
"fmt"
"github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/transport/hysteria/conns/faketcp"
"github.com/Dreamacro/clash/transport/hysteria/conns/udp"
"github.com/Dreamacro/clash/transport/hysteria/conns/wechat"
"github.com/Dreamacro/clash/transport/hysteria/obfs"
"github.com/Dreamacro/clash/transport/hysteria/utils"
"github.com/lucas-clemente/quic-go"
"net"
)
type ClientTransport struct {
Dialer *net.Dialer
PrefEnabled bool
PrefIPv6 bool
PrefExclusive bool
}
func (ct *ClientTransport) quicPacketConn(proto string, server string, obfs obfs.Obfuscator, dialer PacketDialer) (net.PacketConn, error) {
if len(proto) == 0 || proto == "udp" {
conn, err := dialer.ListenPacket()
if err != nil {
return nil, err
}
if obfs != nil {
oc := udp.NewObfsUDPConn(conn, obfs)
return oc, nil
} else {
return conn, nil
}
} else if proto == "wechat-video" {
conn, err := dialer.ListenPacket()
if err != nil {
return nil, err
}
if obfs != nil {
oc := wechat.NewObfsWeChatUDPConn(conn, obfs)
return oc, nil
} else {
return conn, nil
}
} else if proto == "faketcp" {
var conn *faketcp.TCPConn
conn, err := faketcp.Dial("tcp", server)
if err != nil {
return nil, err
}
if obfs != nil {
oc := faketcp.NewObfsFakeTCPConn(conn, obfs)
return oc, nil
} else {
return conn, nil
}
} else {
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
}
type PacketDialer interface {
ListenPacket() (net.PacketConn, error)
Context() context.Context
}
func (ct *ClientTransport) QUICDial(proto string, server string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfs.Obfuscator, dialer PacketDialer) (quic.Connection, error) {
ipStr, port, err := utils.SplitHostPort(server)
if err != nil {
return nil, err
}
ip, err := resolver.ResolveProxyServerHost(ipStr)
if err != nil {
return nil, err
}
serverUDPAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", ip, port))
if err != nil {
return nil, err
}
pktConn, err := ct.quicPacketConn(proto, server, obfs, dialer)
if err != nil {
return nil, err
}
qs, err := quic.DialContext(dialer.Context(), pktConn, serverUDPAddr, server, tlsConfig, quicConfig)
if err != nil {
_ = pktConn.Close()
return nil, err
}
return qs, nil
}
func (ct *ClientTransport) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) {
conn, err := ct.Dialer.Dial("tcp", raddr.String())
if err != nil {
return nil, err
}
return conn.(*net.TCPConn), nil
}
func (ct *ClientTransport) ListenUDP() (*net.UDPConn, error) {
return net.ListenUDP("udp", nil)
}

View File

@ -0,0 +1,42 @@
package utils
import (
"net"
"strconv"
)
func SplitHostPort(hostport string) (string, uint16, error) {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
return "", 0, err
}
portUint, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return "", 0, err
}
return host, uint16(portUint), err
}
func ParseIPZone(s string) (net.IP, string) {
s, zone := splitHostZone(s)
return net.ParseIP(s), zone
}
func splitHostZone(s string) (host, zone string) {
if i := last(s, '%'); i > 0 {
host, zone = s[:i], s[i+1:]
} else {
host = s
}
return
}
func last(s string, b byte) int {
i := len(s)
for i--; i >= 0; i-- {
if s[i] == b {
break
}
}
return i
}