refactor: Implement extended IO

This commit is contained in:
H1JK
2023-01-16 09:42:03 +08:00
parent 8fa66c13a9
commit d1565bb46f
7 changed files with 219 additions and 39 deletions

View File

@ -1,14 +1,16 @@
package tunnel
import (
"context"
"errors"
"net"
"net/netip"
"time"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/sagernet/sing/common/bufio"
)
func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error {
@ -60,5 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
}
func handleSocket(ctx C.ConnContext, outbound net.Conn) {
N.Relay(ctx.Conn(), outbound)
bufio.CopyConn(context.TODO(), ctx.Conn(), outbound)
}

View File

@ -7,6 +7,9 @@ import (
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
"go.uber.org/atomic"
)
@ -29,7 +32,9 @@ type trackerInfo struct {
type tcpTracker struct {
C.Conn `json:"-"`
*trackerInfo
manager *Manager
manager *Manager
extendedReader network.ExtendedReader
extendedWriter network.ExtendedWriter
}
func (tt *tcpTracker) ID() string {
@ -44,6 +49,14 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
return n, err
}
func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) {
err = tt.extendedReader.ReadBuffer(buffer)
download := int64(buffer.Len())
tt.manager.PushDownloaded(download)
tt.DownloadTotal.Add(download)
return
}
func (tt *tcpTracker) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b)
upload := int64(n)
@ -52,11 +65,26 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
return n, err
}
func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) {
err = tt.extendedWriter.WriteBuffer(buffer)
var upload int64
if err != nil {
upload = int64(buffer.Len())
}
tt.manager.PushUploaded(upload)
tt.UploadTotal.Add(upload)
return
}
func (tt *tcpTracker) Close() error {
tt.manager.Leave(tt)
return tt.Conn.Close()
}
func (tt *tcpTracker) Upstream() any {
return tt.Conn
}
func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
uuid, _ := uuid.NewV4()
if conn != nil {
@ -79,6 +107,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
UploadTotal: atomic.NewInt64(0),
DownloadTotal: atomic.NewInt64(0),
},
extendedReader: bufio.NewExtendedReader(conn),
extendedWriter: bufio.NewExtendedWriter(conn),
}
if rule != nil {