Chore: use custom buffer pool for lwIP stack
This commit is contained in:
@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/Dreamacro/clash/adapter/inbound"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/component/resolver"
|
||||
"github.com/Dreamacro/clash/config"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
@ -83,7 +84,7 @@ func NewAdapter(device dev.TunDevice, conf config.Tun, tunAddress string, tcpIn
|
||||
// TCP handler
|
||||
// maximum number of half-open tcp connection set to 1024
|
||||
// receive buffer size set to 20k
|
||||
tcpFwd := tcp.NewForwarder(ipstack, 20*1024, 1024, func(r *tcp.ForwarderRequest) {
|
||||
tcpFwd := tcp.NewForwarder(ipstack, pool.RelayBufferSize, 1024, func(r *tcp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
"github.com/yaling888/go-lwip"
|
||||
)
|
||||
|
||||
const defaultDnsReadTimeout = time.Second * 30
|
||||
const defaultDnsReadTimeout = time.Second * 8
|
||||
|
||||
func shouldHijackDns(dnsIP net.IP, targetIp net.IP, targetPort int) bool {
|
||||
if targetPort != 53 {
|
||||
@ -28,6 +28,10 @@ func hijackUDPDns(conn golwip.UDPConn, pkt []byte, addr *net.UDPAddr) {
|
||||
_ = conn.Close()
|
||||
}(conn)
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
answer, err := D.RelayDnsPacket(pkt)
|
||||
if err != nil {
|
||||
return
|
||||
@ -42,11 +46,11 @@ func hijackTCPDns(conn net.Conn) {
|
||||
_ = conn.Close()
|
||||
}(conn)
|
||||
|
||||
for {
|
||||
if err := conn.SetDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
if err := conn.SetDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
var length uint16
|
||||
if binary.Read(conn, binary.BigEndian, &length) != nil {
|
||||
return
|
||||
@ -68,7 +72,7 @@ func hijackTCPDns(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := conn.Write(rb); err != nil {
|
||||
if _, err = conn.Write(rb); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/Dreamacro/clash/adapter/inbound"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
"github.com/Dreamacro/clash/config"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/listener/tun/dev"
|
||||
@ -43,12 +44,19 @@ func NewAdapter(device dev.TunDevice, conf config.Tun, mtu int, tcpIn chan<- C.C
|
||||
|
||||
dnsIP := net.ParseIP(dnsHost)
|
||||
|
||||
// Register output function, write packets from lwip stack to tun device
|
||||
golwip.RegisterOutputFn(func(data []byte) (int, error) {
|
||||
return device.Write(data)
|
||||
})
|
||||
|
||||
// Set custom buffer pool
|
||||
golwip.SetPoolAllocator(&lwipPool{})
|
||||
|
||||
// Setup TCP/IP stack.
|
||||
lwipStack := golwip.NewLWIPStack(mtu)
|
||||
lwipStack, err := golwip.NewLWIPStack(mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adapter.lwipStack = lwipStack
|
||||
|
||||
golwip.RegisterDnsHandler(NewDnsHandler())
|
||||
@ -59,7 +67,7 @@ func NewAdapter(device dev.TunDevice, conf config.Tun, mtu int, tcpIn chan<- C.C
|
||||
go func(lwipStack golwip.LWIPStack, device dev.TunDevice, mtu int) {
|
||||
_, err := io.CopyBuffer(lwipStack.(io.Writer), device, make([]byte, mtu))
|
||||
if err != nil {
|
||||
log.Errorln("copying data failed: %v", err)
|
||||
log.Debugln("copying data failed: %v", err)
|
||||
}
|
||||
}(lwipStack, device, mtu)
|
||||
|
||||
@ -97,3 +105,13 @@ func (l *lwipAdapter) stopLocked() {
|
||||
l.lwipStack = nil
|
||||
l.device = nil
|
||||
}
|
||||
|
||||
type lwipPool struct{}
|
||||
|
||||
func (p lwipPool) Get(size int) []byte {
|
||||
return pool.Get(size)
|
||||
}
|
||||
|
||||
func (p lwipPool) Put(buf []byte) error {
|
||||
return pool.Put(buf)
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
"github.com/kr328/tun2socket/redirect"
|
||||
)
|
||||
|
||||
const defaultDnsReadTimeout = time.Second * 30
|
||||
const defaultDnsReadTimeout = time.Second * 10
|
||||
|
||||
func shouldHijackDns(dnsAddr binding.Address, targetAddr binding.Address) bool {
|
||||
if targetAddr.Port != 53 {
|
||||
@ -41,11 +41,11 @@ func hijackTCPDns(conn net.Conn) {
|
||||
_ = conn.Close()
|
||||
}(conn)
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
if err := conn.SetReadDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
var length uint16
|
||||
if binary.Read(conn, binary.BigEndian, &length) != nil {
|
||||
return
|
||||
@ -67,7 +67,7 @@ func hijackTCPDns(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := conn.Write(rb); err != nil {
|
||||
if _, err = conn.Write(rb); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user