chore: add WaitReadFrom support in ssr
This commit is contained in:
@ -7,6 +7,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/transport/shadowsocks/shadowaead"
|
||||
"github.com/Dreamacro/clash/transport/shadowsocks/shadowstream"
|
||||
)
|
||||
@ -21,7 +22,7 @@ type StreamConnCipher interface {
|
||||
}
|
||||
|
||||
type PacketConnCipher interface {
|
||||
PacketConn(net.PacketConn) net.PacketConn
|
||||
PacketConn(N.EnhancePacketConn) N.EnhancePacketConn
|
||||
}
|
||||
|
||||
// ErrCipherNotSupported occurs when a cipher is not supported (likely because of security concerns).
|
||||
@ -128,7 +129,7 @@ type AeadCipher struct {
|
||||
}
|
||||
|
||||
func (aead *AeadCipher) StreamConn(c net.Conn) net.Conn { return shadowaead.NewConn(c, aead) }
|
||||
func (aead *AeadCipher) PacketConn(c net.PacketConn) net.PacketConn {
|
||||
func (aead *AeadCipher) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn {
|
||||
return shadowaead.NewPacketConn(c, aead)
|
||||
}
|
||||
|
||||
@ -139,7 +140,7 @@ type StreamCipher struct {
|
||||
}
|
||||
|
||||
func (ciph *StreamCipher) StreamConn(c net.Conn) net.Conn { return shadowstream.NewConn(c, ciph) }
|
||||
func (ciph *StreamCipher) PacketConn(c net.PacketConn) net.PacketConn {
|
||||
func (ciph *StreamCipher) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn {
|
||||
return shadowstream.NewPacketConn(c, ciph)
|
||||
}
|
||||
|
||||
@ -147,8 +148,8 @@ func (ciph *StreamCipher) PacketConn(c net.PacketConn) net.PacketConn {
|
||||
|
||||
type dummy struct{}
|
||||
|
||||
func (dummy) StreamConn(c net.Conn) net.Conn { return c }
|
||||
func (dummy) PacketConn(c net.PacketConn) net.PacketConn { return c }
|
||||
func (dummy) StreamConn(c net.Conn) net.Conn { return c }
|
||||
func (dummy) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { return c }
|
||||
|
||||
// key-derivation function from original Shadowsocks
|
||||
func Kdf(password string, keyLen int) []byte {
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
)
|
||||
|
||||
@ -57,15 +58,15 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
|
||||
}
|
||||
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
N.EnhancePacketConn
|
||||
Cipher
|
||||
}
|
||||
|
||||
const maxPacketSize = 64 * 1024
|
||||
|
||||
// NewPacketConn wraps a net.PacketConn with cipher
|
||||
func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
|
||||
return &PacketConn{PacketConn: c, Cipher: ciph}
|
||||
// NewPacketConn wraps an N.EnhancePacketConn with cipher
|
||||
func NewPacketConn(c N.EnhancePacketConn, ciph Cipher) *PacketConn {
|
||||
return &PacketConn{EnhancePacketConn: c, Cipher: ciph}
|
||||
}
|
||||
|
||||
// WriteTo encrypts b and write to addr using the embedded PacketConn.
|
||||
@ -76,13 +77,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(buf, addr)
|
||||
_, err = c.EnhancePacketConn.WriteTo(buf, addr)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
// ReadFrom reads from the embedded PacketConn and decrypts into b.
|
||||
func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(b)
|
||||
n, addr, err := c.EnhancePacketConn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return n, addr, err
|
||||
}
|
||||
@ -93,3 +94,20 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
copy(b, bb)
|
||||
return len(bb), addr, err
|
||||
}
|
||||
|
||||
func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
|
||||
data, put, addr, err = c.EnhancePacketConn.WaitReadFrom()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
data, err = Unpack(data[c.Cipher.SaltSize():], data, c)
|
||||
if err != nil {
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
data = nil
|
||||
put = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
)
|
||||
|
||||
@ -43,13 +44,13 @@ func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) {
|
||||
}
|
||||
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
N.EnhancePacketConn
|
||||
Cipher
|
||||
}
|
||||
|
||||
// NewPacketConn wraps a net.PacketConn with stream cipher encryption/decryption.
|
||||
func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
|
||||
return &PacketConn{PacketConn: c, Cipher: ciph}
|
||||
// NewPacketConn wraps an N.EnhancePacketConn with stream cipher encryption/decryption.
|
||||
func NewPacketConn(c N.EnhancePacketConn, ciph Cipher) *PacketConn {
|
||||
return &PacketConn{EnhancePacketConn: c, Cipher: ciph}
|
||||
}
|
||||
|
||||
const maxPacketSize = 64 * 1024
|
||||
@ -61,12 +62,12 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(buf, addr)
|
||||
_, err = c.EnhancePacketConn.WriteTo(buf, addr)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(b)
|
||||
n, addr, err := c.EnhancePacketConn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return n, addr, err
|
||||
}
|
||||
@ -77,3 +78,20 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
copy(b, bb)
|
||||
return len(bb), addr, err
|
||||
}
|
||||
|
||||
func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
|
||||
data, put, addr, err = c.EnhancePacketConn.WaitReadFrom()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
data, err = Unpack(data[c.IVSize():], data, c)
|
||||
if err != nil {
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
data = nil
|
||||
put = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/Dreamacro/clash/transport/ssr/tools"
|
||||
@ -82,13 +83,13 @@ func (a *authAES128) StreamConn(c net.Conn, iv []byte) net.Conn {
|
||||
return &Conn{Conn: c, Protocol: p}
|
||||
}
|
||||
|
||||
func (a *authAES128) PacketConn(c net.PacketConn) net.PacketConn {
|
||||
func (a *authAES128) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn {
|
||||
p := &authAES128{
|
||||
Base: a.Base,
|
||||
authAES128Function: a.authAES128Function,
|
||||
userData: a.userData,
|
||||
}
|
||||
return &PacketConn{PacketConn: c, Protocol: p}
|
||||
return &PacketConn{EnhancePacketConn: c, Protocol: p}
|
||||
}
|
||||
|
||||
func (a *authAES128) Decode(dst, src *bytes.Buffer) error {
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/Dreamacro/clash/transport/shadowsocks/core"
|
||||
@ -83,13 +84,13 @@ func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn {
|
||||
return &Conn{Conn: c, Protocol: p}
|
||||
}
|
||||
|
||||
func (a *authChainA) PacketConn(c net.PacketConn) net.PacketConn {
|
||||
func (a *authChainA) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn {
|
||||
p := &authChainA{
|
||||
Base: a.Base,
|
||||
salt: a.salt,
|
||||
userData: a.userData,
|
||||
}
|
||||
return &PacketConn{PacketConn: c, Protocol: p}
|
||||
return &PacketConn{EnhancePacketConn: c, Protocol: p}
|
||||
}
|
||||
|
||||
func (a *authChainA) Decode(dst, src *bytes.Buffer) error {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"hash/crc32"
|
||||
"net"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/transport/ssr/tools"
|
||||
|
||||
@ -35,7 +36,7 @@ func (a *authSHA1V4) StreamConn(c net.Conn, iv []byte) net.Conn {
|
||||
return &Conn{Conn: c, Protocol: p}
|
||||
}
|
||||
|
||||
func (a *authSHA1V4) PacketConn(c net.PacketConn) net.PacketConn {
|
||||
func (a *authSHA1V4) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn {
|
||||
return c
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,8 @@ package protocol
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
)
|
||||
|
||||
type origin struct{}
|
||||
@ -13,7 +15,7 @@ func newOrigin(b *Base) Protocol { return &origin{} }
|
||||
|
||||
func (o *origin) StreamConn(c net.Conn, iv []byte) net.Conn { return c }
|
||||
|
||||
func (o *origin) PacketConn(c net.PacketConn) net.PacketConn { return c }
|
||||
func (o *origin) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn { return c }
|
||||
|
||||
func (o *origin) Decode(dst, src *bytes.Buffer) error {
|
||||
dst.ReadFrom(src)
|
||||
|
@ -3,11 +3,12 @@ package protocol
|
||||
import (
|
||||
"net"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
)
|
||||
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
N.EnhancePacketConn
|
||||
Protocol
|
||||
}
|
||||
|
||||
@ -18,12 +19,12 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = c.PacketConn.WriteTo(buf.Bytes(), addr)
|
||||
_, err = c.EnhancePacketConn.WriteTo(buf.Bytes(), addr)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, addr, err := c.PacketConn.ReadFrom(b)
|
||||
n, addr, err := c.EnhancePacketConn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return n, addr, err
|
||||
}
|
||||
@ -34,3 +35,20 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
copy(b, decoded)
|
||||
return len(decoded), addr, nil
|
||||
}
|
||||
|
||||
func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
|
||||
data, put, addr, err = c.EnhancePacketConn.WaitReadFrom()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
data, err = c.DecodePacket(data)
|
||||
if err != nil {
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
data = nil
|
||||
put = nil
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
|
||||
"github.com/zhangyunhao116/fastrand"
|
||||
)
|
||||
|
||||
@ -22,7 +24,7 @@ var (
|
||||
|
||||
type Protocol interface {
|
||||
StreamConn(net.Conn, []byte) net.Conn
|
||||
PacketConn(net.PacketConn) net.PacketConn
|
||||
PacketConn(N.EnhancePacketConn) N.EnhancePacketConn
|
||||
Decode(dst, src *bytes.Buffer) error
|
||||
Encode(buf *bytes.Buffer, b []byte) error
|
||||
DecodePacket([]byte) ([]byte, error)
|
||||
|
@ -370,7 +370,9 @@ func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, er
|
||||
|
||||
_, err = io.ReadFull(pc.Conn, data[:2+2]) // u16be length + CR LF
|
||||
if err != nil {
|
||||
put()
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
length := binary.BigEndian.Uint16(data)
|
||||
@ -379,11 +381,15 @@ func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, er
|
||||
data = data[:length]
|
||||
_, err = io.ReadFull(pc.Conn, data)
|
||||
if err != nil {
|
||||
put()
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
} else {
|
||||
put()
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
return nil, nil, addr, nil
|
||||
}
|
||||
|
||||
|
@ -205,7 +205,6 @@ func (q *quicStreamPacketConn) WaitReadFrom() (data []byte, put func(), addr net
|
||||
return
|
||||
}
|
||||
data = packet.DATA
|
||||
put = N.NilPut
|
||||
addr = packet.ADDR.UDPAddr()
|
||||
} else {
|
||||
err = net.ErrClosed
|
||||
|
Reference in New Issue
Block a user