chore: migrate from gorilla/websocket to gobwas/ws
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
package vmess
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
@ -14,27 +15,24 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/common/buf"
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
tlsC "github.com/Dreamacro/clash/component/tls"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/zhangyunhao116/fastrand"
|
||||
)
|
||||
|
||||
type websocketConn struct {
|
||||
conn *websocket.Conn
|
||||
reader io.Reader
|
||||
remoteAddr net.Addr
|
||||
net.Conn
|
||||
state ws.State
|
||||
reader *wsutil.Reader
|
||||
controlHandler wsutil.FrameHandlerFunc
|
||||
|
||||
rawWriter N.ExtendedWriter
|
||||
|
||||
// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
|
||||
rMux sync.Mutex
|
||||
wMux sync.Mutex
|
||||
}
|
||||
|
||||
type websocketWithEarlyDataConn struct {
|
||||
@ -61,32 +59,48 @@ type WebsocketConfig struct {
|
||||
}
|
||||
|
||||
// Read implements net.Conn.Read()
|
||||
func (wsc *websocketConn) Read(b []byte) (int, error) {
|
||||
wsc.rMux.Lock()
|
||||
defer wsc.rMux.Unlock()
|
||||
// modify from gobwas/ws/wsutil.readData
|
||||
func (wsc *websocketConn) Read(b []byte) (n int, err error) {
|
||||
var header ws.Header
|
||||
for {
|
||||
reader, err := wsc.getReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
n, err = wsc.reader.Read(b)
|
||||
// in gobwas/ws: "The error is io.EOF only if all of message bytes were read."
|
||||
// but maybe next frame still have data, so drop it
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
|
||||
nBytes, err := reader.Read(b)
|
||||
if err == io.EOF {
|
||||
wsc.reader = nil
|
||||
if !errors.Is(err, wsutil.ErrNoFrameAdvance) {
|
||||
return
|
||||
}
|
||||
header, err = wsc.reader.NextFrame()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if header.OpCode.IsControl() {
|
||||
err = wsc.controlHandler(header, wsc.reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if header.OpCode&(ws.OpBinary|ws.OpText) == 0 {
|
||||
err = wsc.reader.Discard()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nBytes, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (wsc *websocketConn) Write(b []byte) (int, error) {
|
||||
wsc.wMux.Lock()
|
||||
defer wsc.wMux.Unlock()
|
||||
if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
|
||||
return 0, err
|
||||
func (wsc *websocketConn) Write(b []byte) (n int, err error) {
|
||||
err = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpBinary, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return len(b), nil
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
@ -108,7 +122,7 @@ func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
|
||||
header := buffer.ExtendHeader(headerLen)
|
||||
_ = header[2] // bounds check hint to compiler
|
||||
header[0] = websocket.BinaryMessage | 1<<7
|
||||
header[0] = byte(ws.OpBinary) | 0x80
|
||||
header[1] = 1 << 7
|
||||
|
||||
if dataLen < 126 {
|
||||
@ -121,12 +135,12 @@ func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
|
||||
binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
|
||||
}
|
||||
|
||||
maskKey := fastrand.Uint32()
|
||||
binary.LittleEndian.PutUint32(header[1+payloadBitLength:], maskKey)
|
||||
N.MaskWebSocket(maskKey, data)
|
||||
if wsc.state.ClientSide() {
|
||||
maskKey := fastrand.Uint32()
|
||||
binary.LittleEndian.PutUint32(header[1+payloadBitLength:], maskKey)
|
||||
N.MaskWebSocket(maskKey, data)
|
||||
}
|
||||
|
||||
wsc.wMux.Lock()
|
||||
defer wsc.wMux.Unlock()
|
||||
return wsc.rawWriter.WriteBuffer(buffer)
|
||||
}
|
||||
|
||||
@ -135,59 +149,16 @@ func (wsc *websocketConn) FrontHeadroom() int {
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) Upstream() any {
|
||||
return wsc.conn.UnderlyingConn()
|
||||
return wsc.Conn
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) Close() error {
|
||||
var e []string
|
||||
if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
|
||||
e = append(e, err.Error())
|
||||
}
|
||||
if err := wsc.conn.Close(); err != nil {
|
||||
e = append(e, err.Error())
|
||||
}
|
||||
if len(e) > 0 {
|
||||
return fmt.Errorf("failed to close connection: %s", strings.Join(e, ","))
|
||||
}
|
||||
_ = wsc.Conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
|
||||
_ = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, ""))
|
||||
_ = wsc.Conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) getReader() (io.Reader, error) {
|
||||
if wsc.reader != nil {
|
||||
return wsc.reader, nil
|
||||
}
|
||||
|
||||
_, reader, err := wsc.conn.NextReader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wsc.reader = reader
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) LocalAddr() net.Addr {
|
||||
return wsc.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) RemoteAddr() net.Addr {
|
||||
return wsc.remoteAddr
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) SetDeadline(t time.Time) error {
|
||||
if err := wsc.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return wsc.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) SetReadDeadline(t time.Time) error {
|
||||
return wsc.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (wsc *websocketConn) SetWriteDeadline(t time.Time) error {
|
||||
return wsc.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
|
||||
base64DataBuf := &bytes.Buffer{}
|
||||
base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
|
||||
@ -341,29 +312,25 @@ func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Co
|
||||
}
|
||||
|
||||
func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
|
||||
|
||||
dialer := &websocket.Dialer{
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
dialer := ws.Dialer{
|
||||
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return conn, nil
|
||||
},
|
||||
ReadBufferSize: 4 * 1024,
|
||||
WriteBufferSize: 4 * 1024,
|
||||
HandshakeTimeout: time.Second * 8,
|
||||
TLSConfig: c.TLSConfig,
|
||||
}
|
||||
|
||||
scheme := "ws"
|
||||
if c.TLS {
|
||||
scheme = "wss"
|
||||
dialer.TLSClientConfig = c.TLSConfig
|
||||
if len(c.ClientFingerprint) != 0 {
|
||||
if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists {
|
||||
dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (net.Conn, error) {
|
||||
utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint)
|
||||
utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint)
|
||||
|
||||
if err := utlsConn.(*tlsC.UConn).WebsocketHandshake(); err != nil {
|
||||
return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
|
||||
}
|
||||
return utlsConn, nil
|
||||
if err := utlsConn.(*tlsC.UConn).BuildWebsocketHandshakeState(); err != nil {
|
||||
return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
|
||||
}
|
||||
|
||||
dialer.TLSClient = func(conn net.Conn, hostname string) net.Conn {
|
||||
return utlsConn
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -381,38 +348,47 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig,
|
||||
RawQuery: u.RawQuery,
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers := http.Header{"User-Agent": []string{"Go-http-client/1.1"}} // match golang's net/http
|
||||
if c.Headers != nil {
|
||||
for k := range c.Headers {
|
||||
headers.Add(k, c.Headers.Get(k))
|
||||
cHeaders := c.Headers
|
||||
// gobwas/ws send "Host" directly in Upgrade() by `httpWriteHeader(bw, headerHost, u.Host)`
|
||||
// if headers has "Host" will send repeatedly
|
||||
if host := cHeaders.Get("Host"); host != "" {
|
||||
cHeaders.Del("Host")
|
||||
uri.Host = host
|
||||
}
|
||||
for k := range cHeaders {
|
||||
headers.Add(k, cHeaders.Get(k))
|
||||
}
|
||||
}
|
||||
|
||||
if earlyData != nil {
|
||||
earlyDataString := earlyData.String()
|
||||
if c.EarlyDataHeaderName == "" {
|
||||
uri.Path += earlyData.String()
|
||||
uri.Path += earlyDataString
|
||||
} else {
|
||||
headers.Set(c.EarlyDataHeaderName, earlyData.String())
|
||||
// gobwas/ws will check server's response "Sec-Websocket-Protocol" so must add Protocols to ws.Dialer
|
||||
// if not will cause ws.ErrHandshakeBadSubProtocol
|
||||
if c.EarlyDataHeaderName == "Sec-WebSocket-Protocol" {
|
||||
// gobwas/ws will set "Sec-Websocket-Protocol" according dialer.Protocols
|
||||
// to avoid send repeatedly don't set it to headers
|
||||
dialer.Protocols = []string{earlyDataString}
|
||||
} else {
|
||||
headers.Set(c.EarlyDataHeaderName, earlyDataString)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
wsConn, resp, err := dialer.DialContext(ctx, uri.String(), headers)
|
||||
dialer.Header = ws.HandshakeHeaderHTTP(headers)
|
||||
|
||||
conn, reader, _, err := dialer.Dial(ctx, uri.String())
|
||||
if err != nil {
|
||||
reason := err
|
||||
if resp != nil {
|
||||
reason = errors.New(resp.Status)
|
||||
}
|
||||
return nil, fmt.Errorf("dial %s error: %w", uri.Host, reason)
|
||||
return nil, fmt.Errorf("dial %s error: %w", uri.Host, err)
|
||||
}
|
||||
|
||||
conn = &websocketConn{
|
||||
conn: wsConn,
|
||||
rawWriter: N.NewExtendedWriter(wsConn.UnderlyingConn()),
|
||||
remoteAddr: conn.RemoteAddr(),
|
||||
}
|
||||
conn = newWebsocketConn(conn, reader, ws.StateClientSide)
|
||||
// websocketConn can't correct handle ReadDeadline
|
||||
// gorilla/websocket will cache the os.ErrDeadlineExceeded from conn.Read()
|
||||
// it will cause read fail and event panic in *websocket.Conn.NextReader()
|
||||
// so call N.NewDeadlineConn to add a safe wrapper
|
||||
return N.NewDeadlineConn(conn), nil
|
||||
}
|
||||
@ -436,3 +412,68 @@ func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig)
|
||||
|
||||
return streamWebsocketConn(ctx, conn, c, nil)
|
||||
}
|
||||
|
||||
func newWebsocketConn(conn net.Conn, br *bufio.Reader, state ws.State) *websocketConn {
|
||||
controlHandler := wsutil.ControlFrameHandler(conn, state)
|
||||
var reader io.Reader
|
||||
if br != nil && br.Buffered() > 0 {
|
||||
reader = br
|
||||
} else {
|
||||
reader = conn
|
||||
}
|
||||
return &websocketConn{
|
||||
Conn: conn,
|
||||
state: state,
|
||||
reader: &wsutil.Reader{
|
||||
Source: reader,
|
||||
State: state,
|
||||
SkipHeaderCheck: true,
|
||||
CheckUTF8: false,
|
||||
OnIntermediate: controlHandler,
|
||||
},
|
||||
controlHandler: controlHandler,
|
||||
rawWriter: N.NewExtendedWriter(conn),
|
||||
}
|
||||
}
|
||||
|
||||
var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "")
|
||||
|
||||
func decodeEd(s string) ([]byte, error) {
|
||||
return base64.RawURLEncoding.DecodeString(replacer.Replace(s))
|
||||
}
|
||||
|
||||
func decodeXray0rtt(requestHeader http.Header) ([]byte, http.Header) {
|
||||
var edBuf []byte
|
||||
responseHeader := http.Header{}
|
||||
// read inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws
|
||||
if secProtocol := requestHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 {
|
||||
if buf, err := decodeEd(secProtocol); err == nil { // sure could base64 decode
|
||||
edBuf = buf
|
||||
}
|
||||
}
|
||||
return edBuf, responseHeader
|
||||
}
|
||||
|
||||
func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
|
||||
edBuf, responseHeader := decodeXray0rtt(r.Header)
|
||||
wsConn, rw, _, err := ws.HTTPUpgrader{
|
||||
Header: responseHeader,
|
||||
}.Upgrade(r, w)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := newWebsocketConn(wsConn, rw.Reader, ws.StateServerSide)
|
||||
if len(edBuf) > 0 {
|
||||
return &websocketWithReaderConn{conn, io.MultiReader(bytes.NewReader(edBuf), conn)}, nil
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type websocketWithReaderConn struct {
|
||||
*websocketConn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (ws *websocketWithReaderConn) Read(b []byte) (n int, err error) {
|
||||
return ws.reader.Read(b)
|
||||
}
|
||||
|
Reference in New Issue
Block a user