refactor: Implement extended IO

This commit is contained in:
H1JK
2023-01-16 09:42:03 +08:00
parent 8fa66c13a9
commit d1565bb46f
7 changed files with 219 additions and 39 deletions

View File

@ -1,7 +1,6 @@
package vless
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@ -9,12 +8,16 @@ import (
"net"
"github.com/gofrs/uuid"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
xtls "github.com/xtls/go"
"google.golang.org/protobuf/proto"
)
type Conn struct {
net.Conn
network.ExtendedConn
dst *DstAddr
id *uuid.UUID
addons *Addons
@ -23,57 +26,82 @@ type Conn struct {
func (vc *Conn) Read(b []byte) (int, error) {
if vc.received {
return vc.Conn.Read(b)
return vc.ExtendedConn.Read(b)
}
if err := vc.recvResponse(); err != nil {
return 0, err
}
vc.received = true
return vc.Conn.Read(b)
return vc.ExtendedConn.Read(b)
}
func (vc *Conn) sendRequest() error {
buf := &bytes.Buffer{}
func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
if vc.received {
return vc.ExtendedConn.ReadBuffer(buffer)
}
buf.WriteByte(Version) // protocol version
buf.Write(vc.id.Bytes()) // 16 bytes of uuid
if err := vc.recvResponse(); err != nil {
return err
}
vc.received = true
return vc.ExtendedConn.ReadBuffer(buffer)
}
func (vc *Conn) sendRequest() (err error) {
requestLen := 1 // protocol version
requestLen += 16 // UUID
requestLen += 1 // addons length
var addonsBytes []byte
if vc.addons != nil {
bytes, err := proto.Marshal(vc.addons)
addonsBytes, err = proto.Marshal(vc.addons)
if err != nil {
return err
}
buf.WriteByte(byte(len(bytes)))
buf.Write(bytes)
} else {
buf.WriteByte(0) // addon data length. 0 means no addon data
}
requestLen += len(addonsBytes)
requestLen += 1 // command
if !vc.dst.Mux {
requestLen += 2 // port
requestLen += 1 // addr type
requestLen += len(vc.dst.Addr)
}
_buffer := buf.StackNewSize(requestLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
common.Must(
buffer.WriteByte(Version), // protocol version
common.Error(buffer.Write(vc.id.Bytes())), // 16 bytes of uuid
buffer.WriteByte(byte(len(addonsBytes))),
common.Error(buffer.Write(addonsBytes)),
)
if vc.dst.Mux {
buf.WriteByte(CommandMux)
common.Must(buffer.WriteByte(CommandMux))
} else {
if vc.dst.UDP {
buf.WriteByte(CommandUDP)
common.Must(buffer.WriteByte(CommandUDP))
} else {
buf.WriteByte(CommandTCP)
common.Must(buffer.WriteByte(CommandTCP))
}
// Port AddrType Addr
binary.Write(buf, binary.BigEndian, vc.dst.Port)
buf.WriteByte(vc.dst.AddrType)
buf.Write(vc.dst.Addr)
binary.BigEndian.PutUint16(buffer.Extend(2), vc.dst.Port)
common.Must(
buffer.WriteByte(vc.dst.AddrType),
common.Error(buffer.Write(vc.dst.Addr)),
)
}
_, err := vc.Conn.Write(buf.Bytes())
return err
_, err = vc.ExtendedConn.Write(buffer.Bytes())
return
}
func (vc *Conn) recvResponse() error {
var err error
buf := make([]byte, 1)
_, err = io.ReadFull(vc.Conn, buf)
var buf [1]byte
_, err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil {
return err
}
@ -82,25 +110,32 @@ func (vc *Conn) recvResponse() error {
return errors.New("unexpected response version")
}
_, err = io.ReadFull(vc.Conn, buf)
_, err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil {
return err
}
length := int64(buf[0])
if length != 0 { // addon data length > 0
io.CopyN(io.Discard, vc.Conn, length) // just discard
io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard
}
return nil
}
func (vc *Conn) Upstream() any {
if wrapper, ok := vc.ExtendedConn.(*bufio.ExtendedConnWrapper); ok {
return wrapper.Conn
}
return vc.ExtendedConn
}
// newConn return a Conn instance
func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
c := &Conn{
Conn: conn,
id: client.uuid,
dst: dst,
ExtendedConn: bufio.NewExtendedConn(conn),
id: client.uuid,
dst: dst,
}
if !dst.UDP && client.Addons != nil {

View File

@ -5,9 +5,11 @@ import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/url"
@ -15,15 +17,24 @@ import (
"strings"
"sync"
"time"
_ "unsafe"
"github.com/gorilla/websocket"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
)
//go:linkname maskBytes github.com/gorilla/websocket.maskBytes
func maskBytes(key [4]byte, pos int, b []byte) int
type websocketConn struct {
conn *websocket.Conn
reader io.Reader
remoteAddr net.Addr
rawWriter network.ExtendedWriter
// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
rMux sync.Mutex
wMux sync.Mutex
@ -31,6 +42,7 @@ type websocketConn struct {
type websocketWithEarlyDataConn struct {
net.Conn
wsWriter network.ExtendedWriter
underlay net.Conn
closed bool
dialed chan bool
@ -79,6 +91,54 @@ func (wsc *websocketConn) Write(b []byte) (int, error) {
return len(b), nil
}
func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
var payloadBitLength int
dataLen := buffer.Len()
data := buffer.Bytes()
if dataLen < 126 {
payloadBitLength = 1
} else if dataLen < 65536 {
payloadBitLength = 3
} else {
payloadBitLength = 9
}
var headerLen int
headerLen += 1 // FIN / RSV / OPCODE
headerLen += payloadBitLength
headerLen += 4 // MASK KEY
header := buffer.ExtendHeader(headerLen)
header[0] = websocket.BinaryMessage | 1<<7
header[1] = 1 << 7
if dataLen < 126 {
header[1] |= byte(dataLen)
} else if dataLen < 65536 {
header[1] |= 126
binary.BigEndian.PutUint16(header[2:], uint16(dataLen))
} else {
header[1] |= 127
binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
}
maskKey := rand.Uint32()
binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
wsc.wMux.Lock()
defer wsc.wMux.Unlock()
return wsc.rawWriter.WriteBuffer(buffer)
}
func (wsc *websocketConn) FrontHeadroom() int {
return 14
}
func (wsc *websocketConn) Upstream() any {
return wsc.conn.UnderlyingConn()
}
func (wsc *websocketConn) Close() error {
var errors []string
if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
@ -149,6 +209,7 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
}
wsedc.dialed <- true
wsedc.wsWriter = bufio.NewExtendedWriter(wsedc.Conn)
if earlyDataBuf.Len() != 0 {
_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
}
@ -170,6 +231,20 @@ func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
return wsedc.Conn.Write(b)
}
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
if wsedc.closed {
return io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(buffer.Bytes()); err != nil {
return err
}
return nil
}
return wsedc.wsWriter.WriteBuffer(buffer)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
@ -228,6 +303,10 @@ func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
return wsedc.Conn.SetWriteDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) Upstream() any {
return wsedc.Conn
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{
@ -294,6 +373,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf
return &websocketConn{
conn: wsConn,
rawWriter: bufio.NewExtendedWriter(wsConn.UnderlyingConn()),
remoteAddr: conn.RemoteAddr(),
}, nil
}