Fix(socks5): fully udp associate support (#233)
This commit is contained in:
98
component/nat-table/nat.go
Normal file
98
component/nat-table/nat.go
Normal file
@ -0,0 +1,98 @@
|
||||
package nat
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
*table
|
||||
}
|
||||
|
||||
type table struct {
|
||||
mapping sync.Map
|
||||
janitor *janitor
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
type element struct {
|
||||
Expired time.Time
|
||||
RemoteAddr net.Addr
|
||||
RemoteConn net.PacketConn
|
||||
}
|
||||
|
||||
func (t *table) Set(key net.Addr, rConn net.PacketConn, rAddr net.Addr) {
|
||||
// set conn read timeout
|
||||
rConn.SetReadDeadline(time.Now().Add(t.timeout))
|
||||
t.mapping.Store(key, &element{
|
||||
RemoteAddr: rAddr,
|
||||
RemoteConn: rConn,
|
||||
Expired: time.Now().Add(t.timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *table) Get(key net.Addr) (rConn net.PacketConn, rAddr net.Addr) {
|
||||
item, exist := t.mapping.Load(key)
|
||||
if !exist {
|
||||
return
|
||||
}
|
||||
elm := item.(*element)
|
||||
// expired
|
||||
if time.Since(elm.Expired) > 0 {
|
||||
t.mapping.Delete(key)
|
||||
elm.RemoteConn.Close()
|
||||
return
|
||||
}
|
||||
// reset expired time
|
||||
elm.Expired = time.Now().Add(t.timeout)
|
||||
return elm.RemoteConn, elm.RemoteAddr
|
||||
}
|
||||
|
||||
func (t *table) cleanup() {
|
||||
t.mapping.Range(func(k, v interface{}) bool {
|
||||
key := k.(net.Addr)
|
||||
elm := v.(*element)
|
||||
if time.Since(elm.Expired) > 0 {
|
||||
t.mapping.Delete(key)
|
||||
elm.RemoteConn.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
type janitor struct {
|
||||
interval time.Duration
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
func (j *janitor) process(t *table) {
|
||||
ticker := time.NewTicker(j.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.cleanup()
|
||||
case <-j.stop:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stopJanitor(t *Table) {
|
||||
t.janitor.stop <- struct{}{}
|
||||
}
|
||||
|
||||
// New return *Cache
|
||||
func New(interval time.Duration) *Table {
|
||||
j := &janitor{
|
||||
interval: interval,
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
t := &table{janitor: j, timeout: interval}
|
||||
go j.process(t)
|
||||
T := &Table{t}
|
||||
runtime.SetFinalizer(T, stopJanitor)
|
||||
return T
|
||||
}
|
@ -41,7 +41,25 @@ const MaxAddrLen = 1 + 1 + 255 + 2
|
||||
const MaxAuthLen = 255
|
||||
|
||||
// Addr represents a SOCKS address as defined in RFC 1928 section 5.
|
||||
type Addr = []byte
|
||||
type Addr []byte
|
||||
|
||||
func (a Addr) String() string {
|
||||
var host, port string
|
||||
|
||||
switch a[0] {
|
||||
case AtypDomainName:
|
||||
host = string(a[2 : 2+int(a[1])])
|
||||
port = strconv.Itoa((int(a[2+int(a[1])]) << 8) | int(a[2+int(a[1])+1]))
|
||||
case AtypIPv4:
|
||||
host = net.IP(a[1 : 1+net.IPv4len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
|
||||
case AtypIPv6:
|
||||
host = net.IP(a[1 : 1+net.IPv6len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
|
||||
}
|
||||
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// SOCKS errors as defined in RFC 1928 section 6.
|
||||
const (
|
||||
@ -138,23 +156,33 @@ func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr,
|
||||
return
|
||||
}
|
||||
|
||||
if buf[1] != CmdConnect && buf[1] != CmdUDPAssociate {
|
||||
err = ErrCommandNotSupported
|
||||
return
|
||||
}
|
||||
|
||||
command = buf[1]
|
||||
addr, err = readAddr(rw, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// write VER REP RSV ATYP BND.ADDR BND.PORT
|
||||
_, err = rw.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0})
|
||||
|
||||
switch command {
|
||||
case CmdConnect, CmdUDPAssociate:
|
||||
// Acquire server listened address info
|
||||
localAddr := ParseAddr(rw.LocalAddr().String())
|
||||
if localAddr == nil {
|
||||
err = ErrAddressNotSupported
|
||||
} else {
|
||||
// write VER REP RSV ATYP BND.ADDR BND.PORT
|
||||
_, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
|
||||
}
|
||||
case CmdBind:
|
||||
fallthrough
|
||||
default:
|
||||
err = ErrCommandNotSupported
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
|
||||
func ClientHandshake(rw io.ReadWriter, addr Addr, cammand Command, user *User) error {
|
||||
func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
|
||||
buf := make([]byte, MaxAddrLen)
|
||||
var err error
|
||||
|
||||
@ -165,16 +193,16 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, cammand Command, user *User) e
|
||||
_, err = rw.Write([]byte{5, 1, 0})
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// VER, METHOD
|
||||
if _, err := io.ReadFull(rw, buf[:2]); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if buf[0] != 5 {
|
||||
return errors.New("SOCKS version error")
|
||||
return nil, errors.New("SOCKS version error")
|
||||
}
|
||||
|
||||
if buf[1] == 2 {
|
||||
@ -187,30 +215,31 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, cammand Command, user *User) e
|
||||
authMsg.WriteString(user.Password)
|
||||
|
||||
if _, err := rw.Write(authMsg.Bytes()); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(rw, buf[:2]); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if buf[1] != 0 {
|
||||
return errors.New("rejected username/password")
|
||||
return nil, errors.New("rejected username/password")
|
||||
}
|
||||
} else if buf[1] != 0 {
|
||||
return errors.New("SOCKS need auth")
|
||||
return nil, errors.New("SOCKS need auth")
|
||||
}
|
||||
|
||||
// VER, CMD, RSV, ADDR
|
||||
if _, err := rw.Write(bytes.Join([][]byte{{5, cammand, 0}, addr}, []byte(""))); err != nil {
|
||||
return err
|
||||
if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(rw, buf[:10]); err != nil {
|
||||
return err
|
||||
// VER, REP, RSV
|
||||
if _, err := io.ReadFull(rw, buf[:3]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil
|
||||
return readAddr(rw, buf)
|
||||
}
|
||||
|
||||
func readAddr(r io.Reader, b []byte) (Addr, error) {
|
||||
@ -307,3 +336,39 @@ func ParseAddr(s string) Addr {
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
|
||||
if len(packet) < 5 {
|
||||
err = errors.New("insufficient length of packet")
|
||||
return
|
||||
}
|
||||
|
||||
// packet[0] and packet[1] are reserved
|
||||
if !bytes.Equal(packet[:2], []byte{0, 0}) {
|
||||
err = errors.New("reserved fields should be zero")
|
||||
return
|
||||
}
|
||||
|
||||
if packet[2] != 0 /* fragments */ {
|
||||
err = errors.New("discarding fragmented payload")
|
||||
return
|
||||
}
|
||||
|
||||
addr = SplitAddr(packet[3:])
|
||||
if addr == nil {
|
||||
err = errors.New("failed to read UDP header")
|
||||
}
|
||||
|
||||
payload = bytes.Join([][]byte{packet[3+len(addr):]}, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
func EncodeUDPPacket(addr string, payload []byte) (packet []byte, err error) {
|
||||
rAddr := ParseAddr(addr)
|
||||
if rAddr == nil {
|
||||
err = errors.New("cannot parse addr")
|
||||
return
|
||||
}
|
||||
packet = bytes.Join([][]byte{{0, 0, 0}, rAddr, payload}, []byte{})
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user