refactor: Implement extended IO

This commit is contained in:
H1JK
2023-01-16 09:42:03 +08:00
parent 8fa66c13a9
commit d1565bb46f
7 changed files with 219 additions and 39 deletions

View File

@ -5,9 +5,11 @@ import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/url"
@ -15,15 +17,24 @@ import (
"strings"
"sync"
"time"
_ "unsafe"
"github.com/gorilla/websocket"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
)
//go:linkname maskBytes github.com/gorilla/websocket.maskBytes
func maskBytes(key [4]byte, pos int, b []byte) int
type websocketConn struct {
conn *websocket.Conn
reader io.Reader
remoteAddr net.Addr
rawWriter network.ExtendedWriter
// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
rMux sync.Mutex
wMux sync.Mutex
@ -31,6 +42,7 @@ type websocketConn struct {
type websocketWithEarlyDataConn struct {
net.Conn
wsWriter network.ExtendedWriter
underlay net.Conn
closed bool
dialed chan bool
@ -79,6 +91,54 @@ func (wsc *websocketConn) Write(b []byte) (int, error) {
return len(b), nil
}
func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
var payloadBitLength int
dataLen := buffer.Len()
data := buffer.Bytes()
if dataLen < 126 {
payloadBitLength = 1
} else if dataLen < 65536 {
payloadBitLength = 3
} else {
payloadBitLength = 9
}
var headerLen int
headerLen += 1 // FIN / RSV / OPCODE
headerLen += payloadBitLength
headerLen += 4 // MASK KEY
header := buffer.ExtendHeader(headerLen)
header[0] = websocket.BinaryMessage | 1<<7
header[1] = 1 << 7
if dataLen < 126 {
header[1] |= byte(dataLen)
} else if dataLen < 65536 {
header[1] |= 126
binary.BigEndian.PutUint16(header[2:], uint16(dataLen))
} else {
header[1] |= 127
binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
}
maskKey := rand.Uint32()
binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
wsc.wMux.Lock()
defer wsc.wMux.Unlock()
return wsc.rawWriter.WriteBuffer(buffer)
}
func (wsc *websocketConn) FrontHeadroom() int {
return 14
}
func (wsc *websocketConn) Upstream() any {
return wsc.conn.UnderlyingConn()
}
func (wsc *websocketConn) Close() error {
var errors []string
if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
@ -149,6 +209,7 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
}
wsedc.dialed <- true
wsedc.wsWriter = bufio.NewExtendedWriter(wsedc.Conn)
if earlyDataBuf.Len() != 0 {
_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
}
@ -170,6 +231,20 @@ func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
return wsedc.Conn.Write(b)
}
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
if wsedc.closed {
return io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(buffer.Bytes()); err != nil {
return err
}
return nil
}
return wsedc.wsWriter.WriteBuffer(buffer)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
@ -228,6 +303,10 @@ func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
return wsedc.Conn.SetWriteDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) Upstream() any {
return wsedc.Conn
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{
@ -294,6 +373,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf
return &websocketConn{
conn: wsConn,
rawWriter: bufio.NewExtendedWriter(wsConn.UnderlyingConn()),
remoteAddr: conn.RemoteAddr(),
}, nil
}