From f750bc96cbf15cabe61368c0cfd007d3995e9cbc Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Wed, 15 Jun 2022 04:29:19 +0800 Subject: [PATCH] Chore: code style --- adapter/outbound/direct.go | 2 +- adapter/outbound/mitm.go | 1 + adapter/outbound/util.go | 1 + adapter/outbound/vmess.go | 4 +- adapter/outboundgroup/util.go | 1 + common/net/relay.go | 10 +--- .../{tproxy_iptables.go => iptables.go} | 0 listener/tproxy/packet.go | 6 +- listener/tproxy/tcp.go | 2 +- listener/tproxy/udp.go | 2 +- .../tun/ipstack/gvisor/adapter/adapter.go | 8 --- listener/tun/ipstack/gvisor/handler.go | 37 ++++-------- listener/tun/ipstack/gvisor/tcp.go | 6 -- listener/tun/ipstack/gvisor/udp.go | 9 +-- listener/tun/ipstack/system/mars/nat/nat.go | 7 +-- listener/tun/ipstack/system/mars/nat/table.go | 35 +++++++++++- listener/tun/ipstack/system/mars/nat/tcp.go | 28 +++++---- listener/tun/ipstack/system/mars/nat/udp.go | 47 +++++---------- listener/tun/ipstack/system/stack.go | 57 ++++++++----------- listener/tun/ipstack/system/udp.go | 17 ++++-- tunnel/connection.go | 12 ++-- 21 files changed, 135 insertions(+), 157 deletions(-) rename listener/tproxy/{tproxy_iptables.go => iptables.go} (100%) diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index 2fcd6c6e..f70563a7 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -23,7 +23,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ... tcpKeepAlive(c) - if !metadata.DstIP.IsValid() && c.RemoteAddr() != nil { + if !metadata.Resolved() && c.RemoteAddr() != nil { if h, _, err := net.SplitHostPort(c.RemoteAddr().String()); err == nil { metadata.DstIP = netip.MustParseAddr(h) } diff --git a/adapter/outbound/mitm.go b/adapter/outbound/mitm.go index 95499ab6..fe6aef3d 100644 --- a/adapter/outbound/mitm.go +++ b/adapter/outbound/mitm.go @@ -24,6 +24,7 @@ func (m *Mitm) DialContext(_ context.Context, metadata *C.Metadata, _ ...dialer. _ = c.SetKeepAlive(true) _ = c.SetKeepAlivePeriod(60 * time.Second) + _ = c.SetLinger(0) metadata.Type = C.MITM diff --git a/adapter/outbound/util.go b/adapter/outbound/util.go index 29d6ac08..2a99adb9 100644 --- a/adapter/outbound/util.go +++ b/adapter/outbound/util.go @@ -15,6 +15,7 @@ func tcpKeepAlive(c net.Conn) { if tcp, ok := c.(*net.TCPConn); ok { _ = tcp.SetKeepAlive(true) _ = tcp.SetKeepAlivePeriod(30 * time.Second) + _ = tcp.SetLinger(0) } } diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 8f63533e..47035a82 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -116,7 +116,9 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { wsOpts.TLSConfig.ServerName = host } } else { - wsOpts.Headers.Set("Host", convert.RandHost()) + if wsOpts.Headers.Get("Host") == "" { + wsOpts.Headers.Set("Host", convert.RandHost()) + } convert.SetUserAgent(wsOpts.Headers) } c, err = vmess.StreamWebsocketConn(c, wsOpts) diff --git a/adapter/outboundgroup/util.go b/adapter/outboundgroup/util.go index 46b86826..1c03c39d 100644 --- a/adapter/outboundgroup/util.go +++ b/adapter/outboundgroup/util.go @@ -48,5 +48,6 @@ func tcpKeepAlive(c net.Conn) { if tcp, ok := c.(*net.TCPConn); ok { _ = tcp.SetKeepAlive(true) _ = tcp.SetKeepAlivePeriod(30 * time.Second) + _ = tcp.SetLinger(0) } } diff --git a/common/net/relay.go b/common/net/relay.go index 6035a412..beaae77b 100644 --- a/common/net/relay.go +++ b/common/net/relay.go @@ -4,8 +4,6 @@ import ( "io" "net" "time" - - "github.com/Dreamacro/clash/common/pool" ) // Relay copies between left and right bidirectionally. @@ -16,18 +14,14 @@ func Relay(leftConn, rightConn net.Conn) { tcpKeepAlive(rightConn) go func() { - buf := pool.Get(pool.RelayBufferSize) // Wrapping to avoid using *net.TCPConn.(ReadFrom) // See also https://github.com/Dreamacro/clash/pull/1209 - _, err := io.CopyBuffer(WriteOnlyWriter{Writer: leftConn}, ReadOnlyReader{Reader: rightConn}, buf) - _ = pool.Put(buf) + _, err := io.Copy(WriteOnlyWriter{Writer: leftConn}, ReadOnlyReader{Reader: rightConn}) _ = leftConn.SetReadDeadline(time.Now()) ch <- err }() - buf := pool.Get(pool.RelayBufferSize) - _, _ = io.CopyBuffer(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn}, buf) - _ = pool.Put(buf) + _, _ = io.Copy(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn}) _ = rightConn.SetReadDeadline(time.Now()) <-ch } diff --git a/listener/tproxy/tproxy_iptables.go b/listener/tproxy/iptables.go similarity index 100% rename from listener/tproxy/tproxy_iptables.go rename to listener/tproxy/iptables.go diff --git a/listener/tproxy/packet.go b/listener/tproxy/packet.go index 9299df9d..54f29f89 100644 --- a/listener/tproxy/packet.go +++ b/listener/tproxy/packet.go @@ -24,15 +24,15 @@ func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { return } n, err = tc.Write(b) - tc.Close() + _ = tc.Close() return } // LocalAddr returns the source IP/Port of UDP Packet func (c *packet) LocalAddr() net.Addr { - return &net.UDPAddr{IP: c.lAddr.Addr().AsSlice(), Port: int(c.lAddr.Port()), Zone: c.lAddr.Addr().Zone()} + return net.UDPAddrFromAddrPort(c.lAddr) } func (c *packet) Drop() { - pool.Put(c.buf) + _ = pool.Put(c.buf) } diff --git a/listener/tproxy/tcp.go b/listener/tproxy/tcp.go index 1a09f366..6d6a5bee 100644 --- a/listener/tproxy/tcp.go +++ b/listener/tproxy/tcp.go @@ -32,7 +32,7 @@ func (l *Listener) Close() error { func (l *Listener) handleTProxy(conn net.Conn, in chan<- C.ConnContext) { target := socks5.ParseAddrToSocksAddr(conn.LocalAddr()) - conn.(*net.TCPConn).SetKeepAlive(true) + _ = conn.(*net.TCPConn).SetKeepAlive(true) in <- inbound.NewSocket(target, conn, C.TPROXY) } diff --git a/listener/tproxy/udp.go b/listener/tproxy/udp.go index 60783563..cf8bbd92 100644 --- a/listener/tproxy/udp.go +++ b/listener/tproxy/udp.go @@ -61,7 +61,7 @@ func NewUDP(addr string, in chan<- *inbound.PacketAdapter) (*UDPListener, error) buf := pool.Get(pool.UDPBufferSize) n, oobn, _, lAddr, err := c.ReadMsgUDPAddrPort(buf, oob) if err != nil { - pool.Put(buf) + _ = pool.Put(buf) if rl.closed { break } diff --git a/listener/tun/ipstack/gvisor/adapter/adapter.go b/listener/tun/ipstack/gvisor/adapter/adapter.go index 9a5649ef..e4e42965 100644 --- a/listener/tun/ipstack/gvisor/adapter/adapter.go +++ b/listener/tun/ipstack/gvisor/adapter/adapter.go @@ -2,23 +2,15 @@ package adapter import ( "net" - - "gvisor.dev/gvisor/pkg/tcpip/stack" ) // TCPConn implements the net.Conn interface. type TCPConn interface { net.Conn - - // ID returns the transport endpoint id of TCPConn. - ID() *stack.TransportEndpointID } // UDPConn implements net.Conn and net.PacketConn. type UDPConn interface { net.Conn net.PacketConn - - // ID returns the transport endpoint id of UDPConn. - ID() *stack.TransportEndpointID } diff --git a/listener/tun/ipstack/gvisor/handler.go b/listener/tun/ipstack/gvisor/handler.go index 476374bd..862d775a 100644 --- a/listener/tun/ipstack/gvisor/handler.go +++ b/listener/tun/ipstack/gvisor/handler.go @@ -7,7 +7,6 @@ import ( "time" "github.com/Dreamacro/clash/adapter/inbound" - "github.com/Dreamacro/clash/common/nnip" "github.com/Dreamacro/clash/common/pool" C "github.com/Dreamacro/clash/constant" D "github.com/Dreamacro/clash/listener/tun/ipstack/commons" @@ -27,15 +26,7 @@ type gvHandler struct { } func (gh *gvHandler) HandleTCP(tunConn adapter.TCPConn) { - id := tunConn.ID() - - rAddr := &net.UDPAddr{ - IP: net.IP(id.LocalAddress), - Port: int(id.LocalPort), - Zone: "", - } - - rAddrPort := netip.AddrPortFrom(nnip.IpToAddr(rAddr.IP), id.LocalPort) + rAddrPort := tunConn.LocalAddr().(*net.TCPAddr).AddrPort() if D.ShouldHijackDns(gh.dnsHijack, rAddrPort, "tcp") { go func() { @@ -43,8 +34,8 @@ func (gh *gvHandler) HandleTCP(tunConn adapter.TCPConn) { buf := pool.Get(pool.UDPBufferSize) defer func() { - _ = pool.Put(buf) _ = tunConn.Close() + _ = pool.Put(buf) }() for { @@ -78,26 +69,18 @@ func (gh *gvHandler) HandleTCP(tunConn adapter.TCPConn) { return } - gh.tcpIn <- inbound.NewSocket(socks5.ParseAddrToSocksAddr(rAddr), tunConn, C.TUN) + gh.tcpIn <- inbound.NewSocket(socks5.AddrFromStdAddrPort(rAddrPort), tunConn, C.TUN) } func (gh *gvHandler) HandleUDP(tunConn adapter.UDPConn) { - id := tunConn.ID() - - rAddr := &net.UDPAddr{ - IP: net.IP(id.LocalAddress), - Port: int(id.LocalPort), - Zone: "", - } - - rAddrPort := netip.AddrPortFrom(nnip.IpToAddr(rAddr.IP), id.LocalPort) + rAddrPort := tunConn.LocalAddr().(*net.UDPAddr).AddrPort() if rAddrPort.Addr() == gh.gateway { _ = tunConn.Close() return } - target := socks5.ParseAddrToSocksAddr(rAddr) + target := socks5.AddrFromStdAddrPort(rAddrPort) go func() { for { @@ -109,22 +92,20 @@ func (gh *gvHandler) HandleUDP(tunConn adapter.UDPConn) { break } - payload := buf[:n] - if D.ShouldHijackDns(gh.dnsHijack, rAddrPort, "udp") { go func() { defer func() { _ = pool.Put(buf) }() - msg, err1 := D.RelayDnsPacket(payload) + msg, err1 := D.RelayDnsPacket(buf[:n]) if err1 != nil { return } _, _ = tunConn.WriteTo(msg, addr) - log.Debugln("[TUN] hijack dns udp: %s", rAddr.String()) + log.Debugln("[TUN] hijack dns udp: %s", rAddrPort.String()) }() continue @@ -133,12 +114,14 @@ func (gh *gvHandler) HandleUDP(tunConn adapter.UDPConn) { gvPacket := &packet{ pc: tunConn, rAddr: addr, - payload: payload, + payload: buf, + offset: n, } select { case gh.udpIn <- inbound.NewPacket(target, gvPacket, C.TUN): default: + gvPacket.Drop() } } }() diff --git a/listener/tun/ipstack/gvisor/tcp.go b/listener/tun/ipstack/gvisor/tcp.go index 61f5d90e..0c893ce4 100644 --- a/listener/tun/ipstack/gvisor/tcp.go +++ b/listener/tun/ipstack/gvisor/tcp.go @@ -70,7 +70,6 @@ func withTCPHandler(handle adapter.TCPHandleFunc) option.Option { conn := &tcpConn{ TCPConn: gonet.NewTCPConn(&wq, ep), - id: id, } handle(conn) }) @@ -113,9 +112,4 @@ func setSocketOptions(s *stack.Stack, ep tcpip.Endpoint) tcpip.Error { type tcpConn struct { *gonet.TCPConn - id stack.TransportEndpointID -} - -func (c *tcpConn) ID() *stack.TransportEndpointID { - return &c.id } diff --git a/listener/tun/ipstack/gvisor/udp.go b/listener/tun/ipstack/gvisor/udp.go index 502e3a9c..b516068c 100644 --- a/listener/tun/ipstack/gvisor/udp.go +++ b/listener/tun/ipstack/gvisor/udp.go @@ -29,7 +29,6 @@ func withUDPHandler(handle adapter.UDPHandleFunc) option.Option { conn := &udpConn{ UDPConn: gonet.NewUDPConn(s, &wq, ep), - id: id, } handle(conn) }) @@ -40,21 +39,17 @@ func withUDPHandler(handle adapter.UDPHandleFunc) option.Option { type udpConn struct { *gonet.UDPConn - id stack.TransportEndpointID -} - -func (c *udpConn) ID() *stack.TransportEndpointID { - return &c.id } type packet struct { pc adapter.UDPConn rAddr net.Addr payload []byte + offset int } func (c *packet) Data() []byte { - return c.payload + return c.payload[:c.offset] } // WriteBack write UDP packet with source(ip, port) = `addr` diff --git a/listener/tun/ipstack/system/mars/nat/nat.go b/listener/tun/ipstack/system/mars/nat/nat.go index 9f6f57d2..b10c4070 100644 --- a/listener/tun/ipstack/system/mars/nat/nat.go +++ b/listener/tun/ipstack/system/mars/nat/nat.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" - "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/listener/tun/ipstack/system/mars/tcpip" ) @@ -22,7 +21,7 @@ func Start(device io.ReadWriter, gateway, portal, broadcast netip.Addr) (*TCP, * tab := newTable() udp := &UDP{ device: device, - buf: [pool.UDPBufferSize]byte{}, + buf: [0xffff]byte{}, } tcp := &TCP{ listener: listener, @@ -38,7 +37,7 @@ func Start(device io.ReadWriter, gateway, portal, broadcast netip.Addr) (*TCP, * _ = udp.Close() }() - buf := make([]byte, pool.RelayBufferSize) + buf := make([]byte, 0xffff) for { n, err := device.Read(buf) @@ -152,7 +151,7 @@ func Start(device io.ReadWriter, gateway, portal, broadcast netip.Addr) (*TCP, * continue } - udp.handleUDPPacket(ip, u) + go udp.handleUDPPacket(ip, u) case tcpip.ICMP: i := tcpip.ICMPPacket(ip.Payload()) diff --git a/listener/tun/ipstack/system/mars/nat/table.go b/listener/tun/ipstack/system/mars/nat/table.go index 38b7d6c6..c22be06b 100644 --- a/listener/tun/ipstack/system/mars/nat/table.go +++ b/listener/tun/ipstack/system/mars/nat/table.go @@ -2,8 +2,11 @@ package nat import ( "net/netip" + "sync" "github.com/Dreamacro/clash/common/generics/list" + + "golang.org/x/exp/maps" ) const ( @@ -27,6 +30,7 @@ type table struct { tuples map[tuple]*list.Element[*binding] ports [portLength]*list.Element[*binding] available *list.List[*binding] + mux sync.Mutex } func (t *table) tupleOf(port uint16) tuple { @@ -43,10 +47,13 @@ func (t *table) tupleOf(port uint16) tuple { } func (t *table) portOf(tuple tuple) uint16 { + t.mux.Lock() elm := t.tuples[tuple] if elm == nil { + t.mux.Unlock() return 0 } + t.mux.Unlock() t.available.MoveToFront(elm) @@ -54,18 +61,40 @@ func (t *table) portOf(tuple tuple) uint16 { } func (t *table) newConn(tuple tuple) uint16 { - elm := t.available.Back() + t.mux.Lock() + elm := t.availableConn() b := elm.Value - - delete(t.tuples, b.tuple) t.tuples[tuple] = elm b.tuple = tuple + t.mux.Unlock() t.available.MoveToFront(elm) return portBegin + b.offset } +func (t *table) availableConn() *list.Element[*binding] { + elm := t.available.Back() + offset := elm.Value.offset + _, ok := t.tuples[t.ports[offset].Value.tuple] + if !ok { + if offset != 0 && offset%portLength == 0 { // resize + tuples := make(map[tuple]*list.Element[*binding], portLength) + maps.Copy(tuples, t.tuples) + t.tuples = tuples + } + return elm + } + t.available.MoveToFront(elm) + return t.availableConn() +} + +func (t *table) closeConn(tuple tuple) { + t.mux.Lock() + delete(t.tuples, tuple) + t.mux.Unlock() +} + func newTable() *table { result := &table{ tuples: make(map[tuple]*list.Element[*binding], portLength), diff --git a/listener/tun/ipstack/system/mars/nat/tcp.go b/listener/tun/ipstack/system/mars/nat/tcp.go index cc0abe7d..2ad9c025 100644 --- a/listener/tun/ipstack/system/mars/nat/tcp.go +++ b/listener/tun/ipstack/system/mars/nat/tcp.go @@ -16,6 +16,8 @@ type conn struct { net.Conn tuple tuple + + close func() } func (t *TCP) Accept() (net.Conn, error) { @@ -24,9 +26,9 @@ func (t *TCP) Accept() (net.Conn, error) { return nil, err } - addr := c.RemoteAddr().(*net.TCPAddr) - tup := t.table.tupleOf(uint16(addr.Port)) - if !addr.IP.Equal(t.portal.AsSlice()) || tup == zeroTuple { + addr := c.RemoteAddr().(*net.TCPAddr).AddrPort() + tup := t.table.tupleOf(addr.Port()) + if addr.Addr() != t.portal || tup == zeroTuple { _ = c.Close() return nil, net.InvalidAddrError("unknown remote addr") @@ -34,9 +36,14 @@ func (t *TCP) Accept() (net.Conn, error) { addition(c) + _ = c.SetLinger(0) + return &conn{ Conn: c, tuple: tup, + close: func() { + t.table.closeConn(tup) + }, }, nil } @@ -52,16 +59,15 @@ func (t *TCP) SetDeadline(time time.Time) error { return t.listener.SetDeadline(time) } +func (c *conn) Close() error { + c.close() + return c.Conn.Close() +} + func (c *conn) LocalAddr() net.Addr { - return &net.TCPAddr{ - IP: c.tuple.SourceAddr.Addr().AsSlice(), - Port: int(c.tuple.SourceAddr.Port()), - } + return net.TCPAddrFromAddrPort(c.tuple.SourceAddr) } func (c *conn) RemoteAddr() net.Addr { - return &net.TCPAddr{ - IP: c.tuple.DestinationAddr.Addr().AsSlice(), - Port: int(c.tuple.DestinationAddr.Port()), - } + return net.TCPAddrFromAddrPort(c.tuple.DestinationAddr) } diff --git a/listener/tun/ipstack/system/mars/nat/udp.go b/listener/tun/ipstack/system/mars/nat/udp.go index 6cc7faee..4489ddaf 100644 --- a/listener/tun/ipstack/system/mars/nat/udp.go +++ b/listener/tun/ipstack/system/mars/nat/udp.go @@ -7,8 +7,6 @@ import ( "net/netip" "sync" - "github.com/Dreamacro/clash/common/nnip" - "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/listener/tun/ipstack/system/mars/tcpip" ) @@ -16,8 +14,8 @@ type call struct { cond *sync.Cond buf []byte n int - source net.Addr - destination net.Addr + source netip.AddrPort + destination netip.AddrPort } type UDP struct { @@ -26,10 +24,10 @@ type UDP struct { queueLock sync.Mutex queue []*call bufLock sync.Mutex - buf [pool.UDPBufferSize]byte + buf [0xffff]byte } -func (u *UDP) ReadFrom(buf []byte) (int, net.Addr, net.Addr, error) { +func (u *UDP) ReadFrom(buf []byte) (int, netip.AddrPort, netip.AddrPort, error) { u.queueLock.Lock() defer u.queueLock.Unlock() @@ -38,8 +36,8 @@ func (u *UDP) ReadFrom(buf []byte) (int, net.Addr, net.Addr, error) { cond: sync.NewCond(&u.queueLock), buf: buf, n: -1, - source: nil, - destination: nil, + source: netip.AddrPort{}, + destination: netip.AddrPort{}, } u.queue = append(u.queue, c) @@ -51,10 +49,10 @@ func (u *UDP) ReadFrom(buf []byte) (int, net.Addr, net.Addr, error) { } } - return -1, nil, nil, net.ErrClosed + return -1, netip.AddrPort{}, netip.AddrPort{}, net.ErrClosed } -func (u *UDP) WriteTo(buf []byte, local net.Addr, remote net.Addr) (int, error) { +func (u *UDP) WriteTo(buf []byte, local netip.AddrPort, remote netip.AddrPort) (int, error) { if u.closed { return 0, net.ErrClosed } @@ -66,16 +64,7 @@ func (u *UDP) WriteTo(buf []byte, local net.Addr, remote net.Addr) (int, error) return 0, net.InvalidAddrError("invalid ip version") } - srcAddr, srcOk := local.(*net.UDPAddr) - dstAddr, dstOk := remote.(*net.UDPAddr) - if !srcOk || !dstOk { - return 0, net.InvalidAddrError("invalid addr") - } - - srcAddrPort := netip.AddrPortFrom(nnip.IpToAddr(srcAddr.IP), uint16(srcAddr.Port)) - dstAddrPort := netip.AddrPortFrom(nnip.IpToAddr(dstAddr.IP), uint16(dstAddr.Port)) - - if !srcAddrPort.Addr().Is4() || !dstAddrPort.Addr().Is4() { + if !local.Addr().Is4() || !remote.Addr().Is4() { return 0, net.InvalidAddrError("invalid ip version") } @@ -89,13 +78,13 @@ func (u *UDP) WriteTo(buf []byte, local net.Addr, remote net.Addr) (int, error) ip.SetFragmentOffset(0) ip.SetTimeToLive(64) ip.SetProtocol(tcpip.UDP) - ip.SetSourceIP(srcAddrPort.Addr()) - ip.SetDestinationIP(dstAddrPort.Addr()) + ip.SetSourceIP(local.Addr()) + ip.SetDestinationIP(remote.Addr()) udp := tcpip.UDPPacket(ip.Payload()) udp.SetLength(tcpip.UDPHeaderSize + uint16(len(buf))) - udp.SetSourcePort(srcAddrPort.Port()) - udp.SetDestinationPort(dstAddrPort.Port()) + udp.SetSourcePort(local.Port()) + udp.SetDestinationPort(remote.Port()) copy(udp.Payload(), buf) ip.ResetChecksum() @@ -131,14 +120,8 @@ func (u *UDP) handleUDPPacket(ip tcpip.IP, pkt tcpip.UDPPacket) { u.queueLock.Unlock() if c != nil { - c.source = &net.UDPAddr{ - IP: ip.SourceIP().AsSlice(), - Port: int(pkt.SourcePort()), - } - c.destination = &net.UDPAddr{ - IP: ip.DestinationIP().AsSlice(), - Port: int(pkt.DestinationPort()), - } + c.source = netip.AddrPortFrom(ip.SourceIP(), pkt.SourcePort()) + c.destination = netip.AddrPortFrom(ip.DestinationIP(), pkt.DestinationPort()) c.n = copy(c.buf, pkt.Payload()) c.cond.Signal() } diff --git a/listener/tun/ipstack/system/stack.go b/listener/tun/ipstack/system/stack.go index fd6b7cde..095be906 100644 --- a/listener/tun/ipstack/system/stack.go +++ b/listener/tun/ipstack/system/stack.go @@ -81,26 +81,23 @@ func New(device device.Device, dnsHijack []C.DNSUrl, tunAddress netip.Prefix, tc continue } - lAddr := conn.LocalAddr().(*net.TCPAddr) - rAddr := conn.RemoteAddr().(*net.TCPAddr) + lAddr := conn.LocalAddr().(*net.TCPAddr).AddrPort() + rAddr := conn.RemoteAddr().(*net.TCPAddr).AddrPort() - lAddrPort := netip.AddrPortFrom(nnip.IpToAddr(lAddr.IP), uint16(lAddr.Port)) - rAddrPort := netip.AddrPortFrom(nnip.IpToAddr(rAddr.IP), uint16(rAddr.Port)) - - if rAddrPort.Addr().IsLoopback() { + if rAddr.Addr().IsLoopback() { _ = conn.Close() continue } - if D.ShouldHijackDns(dnsAddr, rAddrPort, "tcp") { + if D.ShouldHijackDns(dnsAddr, rAddr, "tcp") { go func() { - log.Debugln("[TUN] hijack dns tcp: %s", rAddrPort.String()) + log.Debugln("[TUN] hijack dns tcp: %s", rAddr.String()) buf := pool.Get(pool.UDPBufferSize) defer func() { - _ = pool.Put(buf) _ = conn.Close() + _ = pool.Put(buf) }() for { @@ -137,10 +134,10 @@ func New(device device.Device, dnsHijack []C.DNSUrl, tunAddress netip.Prefix, tc metadata := &C.Metadata{ NetWork: C.TCP, Type: C.TUN, - SrcIP: lAddrPort.Addr(), - DstIP: rAddrPort.Addr(), - SrcPort: strconv.Itoa(lAddr.Port), - DstPort: strconv.Itoa(rAddr.Port), + SrcIP: lAddr.Addr(), + DstIP: rAddr.Addr(), + SrcPort: strconv.FormatUint(uint64(lAddr.Port()), 10), + DstPort: strconv.FormatUint(uint64(rAddr.Port()), 10), AddrType: C.AtypIPv4, Host: "", } @@ -159,56 +156,50 @@ func New(device device.Device, dnsHijack []C.DNSUrl, tunAddress netip.Prefix, tc for !ipStack.closed { buf := pool.Get(pool.UDPBufferSize) - n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf) + n, lAddr, rAddr, err := stack.UDP().ReadFrom(buf) if err != nil { _ = pool.Put(buf) break } - raw := buf[:n] - lAddr := lRAddr.(*net.UDPAddr) - rAddr := rRAddr.(*net.UDPAddr) - - rAddrPort := netip.AddrPortFrom(nnip.IpToAddr(rAddr.IP), uint16(rAddr.Port)) - - if rAddrPort.Addr().IsLoopback() || rAddrPort.Addr() == gateway { + if rAddr.Addr().IsLoopback() || rAddr.Addr() == gateway { _ = pool.Put(buf) continue } - if D.ShouldHijackDns(dnsAddr, rAddrPort, "udp") { + if D.ShouldHijackDns(dnsAddr, rAddr, "udp") { go func() { - msg, err := D.RelayDnsPacket(raw) - if err != nil { + defer func() { _ = pool.Put(buf) + }() + + msg, err := D.RelayDnsPacket(buf[:n]) + if err != nil { return } _, _ = stack.UDP().WriteTo(msg, rAddr, lAddr) - _ = pool.Put(buf) - - log.Debugln("[TUN] hijack dns udp: %s", rAddrPort.String()) + log.Debugln("[TUN] hijack dns udp: %s", rAddr.String()) }() continue } pkt := &packet{ - local: lAddr, - data: raw, + local: lAddr, + data: buf, + offset: n, writeBack: func(b []byte, addr net.Addr) (int, error) { return stack.UDP().WriteTo(b, rAddr, lAddr) }, - drop: func() { - _ = pool.Put(buf) - }, } select { - case udpIn <- inbound.NewPacket(socks5.ParseAddrToSocksAddr(rAddr), pkt, C.TUN): + case udpIn <- inbound.NewPacket(socks5.AddrFromStdAddrPort(rAddr), pkt, C.TUN): default: + pkt.Drop() } } diff --git a/listener/tun/ipstack/system/udp.go b/listener/tun/ipstack/system/udp.go index cb2761e8..a4d416e9 100644 --- a/listener/tun/ipstack/system/udp.go +++ b/listener/tun/ipstack/system/udp.go @@ -1,16 +1,21 @@ package system -import "net" +import ( + "net" + "net/netip" + + "github.com/Dreamacro/clash/common/pool" +) type packet struct { - local *net.UDPAddr + local netip.AddrPort data []byte + offset int writeBack func(b []byte, addr net.Addr) (int, error) - drop func() } func (pkt *packet) Data() []byte { - return pkt.data + return pkt.data[:pkt.offset] } func (pkt *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { @@ -18,9 +23,9 @@ func (pkt *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { } func (pkt *packet) Drop() { - pkt.drop() + _ = pool.Put(pkt.data) } func (pkt *packet) LocalAddr() net.Addr { - return pkt.local + return net.UDPAddrFromAddrPort(pkt.local) } diff --git a/tunnel/connection.go b/tunnel/connection.go index 0384e805..7ab6c4d1 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -32,19 +32,21 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata return err } // reset timeout - pc.SetReadDeadline(time.Now().Add(udpTimeout)) + _ = pc.SetReadDeadline(time.Now().Add(udpTimeout)) return nil } func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr net.Addr) { buf := pool.Get(pool.UDPBufferSize) - defer pool.Put(buf) - defer natTable.Delete(key) - defer pc.Close() + defer func() { + _ = pc.Close() + natTable.Delete(key) + _ = pool.Put(buf) + }() for { - pc.SetReadDeadline(time.Now().Add(udpTimeout)) + _ = pc.SetReadDeadline(time.Now().Add(udpTimeout)) n, from, err := pc.ReadFrom(buf) if err != nil { return