Chore: use protobytes replace most of bytes.Buffer

This commit is contained in:
Dreamacro
2023-04-17 14:08:39 +08:00
parent df61a586c9
commit b7aade5e11
15 changed files with 244 additions and 236 deletions

View File

@ -1,14 +1,15 @@
package socks4
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"net/netip"
"strconv"
"github.com/Dreamacro/clash/component/auth"
"github.com/Dreamacro/protobytes"
)
const Version = 0x04
@ -46,21 +47,17 @@ func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr s
return
}
if req[0] != Version {
r := protobytes.BytesReader(req[:])
if r.ReadUint8() != Version {
err = errVersionMismatched
return
}
if command = req[1]; command != CmdConnect {
if command = r.ReadUint8(); command != CmdConnect {
err = errCommandNotSupported
return
}
var (
dstIP = req[4:8] // [4]byte
dstPort = req[2:4] // [2]byte
)
var (
host string
port string
@ -71,7 +68,9 @@ func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr s
return
}
if isReservedIP(dstIP) {
dstAddr := r.ReadIPv4()
dstPort := r.ReadUint16be()
if isReservedIP(dstAddr) {
var target []byte
if target, err = readUntilNull(rw); err != nil {
return
@ -79,11 +78,11 @@ func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr s
host = string(target)
}
port = strconv.Itoa(int(binary.BigEndian.Uint16(dstPort)))
port = strconv.Itoa(int(dstPort))
if host != "" {
addr = net.JoinHostPort(host, port)
} else {
addr = net.JoinHostPort(net.IP(dstIP).String(), port)
addr = net.JoinHostPort(dstAddr.String(), port)
}
// SOCKS4 only support USERID auth.
@ -94,13 +93,13 @@ func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr s
err = ErrRequestIdentdMismatched
}
var reply [8]byte
reply[0] = 0x00 // reply code
reply[1] = code // result code
copy(reply[4:8], dstIP)
copy(reply[2:4], dstPort)
reply := protobytes.BytesWriter(make([]byte, 0, 8))
reply.PutUint8(0) // reply code
reply.PutUint8(code) // result code
reply.PutSlice(dstAddr.AsSlice())
reply.PutUint16be(dstPort)
_, wErr := rw.Write(reply[:])
_, wErr := rw.Write(reply.Bytes())
if err == nil {
err = wErr
}
@ -118,26 +117,24 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri
return err
}
ip := net.ParseIP(host)
if ip == nil /* HOST */ {
ip = net.IPv4(0, 0, 0, 1).To4()
} else if ip.To4() == nil /* IPv6 */ {
ip, err := netip.ParseAddr(host)
if err != nil { // Host
ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
} else if ip.Is6() { // IPv6
return errIPv6NotSupported
}
dstIP := ip.To4()
req := protobytes.BytesWriter{}
req.PutUint8(Version)
req.PutUint8(command)
req.PutUint16be(uint16(port))
req.Write(ip.AsSlice())
req.PutString(userID)
req.PutUint8(0) /* NULL */
req := &bytes.Buffer{}
req.WriteByte(Version)
req.WriteByte(command)
binary.Write(req, binary.BigEndian, uint16(port))
req.Write(dstIP)
req.WriteString(userID)
req.WriteByte(0) /* NULL */
if isReservedIP(dstIP) /* SOCKS4A */ {
req.WriteString(host)
req.WriteByte(0) /* NULL */
if isReservedIP(ip) /* SOCKS4A */ {
req.PutString(host)
req.PutUint8(0) /* NULL */
}
if _, err = rw.Write(req.Bytes()); err != nil {
@ -174,17 +171,17 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri
// Internet Assigned Numbers Authority -- such an address is inadmissible
// as a destination IP address and thus should never occur if the client
// can resolve the domain name.)
func isReservedIP(ip net.IP) bool {
subnet := net.IPNet{
IP: net.IPv4zero,
Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00),
}
func isReservedIP(ip netip.Addr) bool {
subnet := netip.PrefixFrom(
netip.AddrFrom4([4]byte{0, 0, 0, 0}),
24,
)
return !ip.IsUnspecified() && subnet.Contains(ip)
}
func readUntilNull(r io.Reader) ([]byte, error) {
buf := &bytes.Buffer{}
buf := protobytes.BytesWriter{}
var data [1]byte
for {
@ -194,6 +191,6 @@ func readUntilNull(r io.Reader) ([]byte, error) {
if data[0] == 0 {
return buf.Bytes(), nil
}
buf.WriteByte(data[0])
buf.PutUint8(data[0])
}
}