Chore: mitm proxy with authenticate
This commit is contained in:
parent
30025c0241
commit
22458ad0be
@ -62,7 +62,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
|
|||||||
request.RequestURI = ""
|
request.RequestURI = ""
|
||||||
|
|
||||||
if isUpgradeRequest(request) {
|
if isUpgradeRequest(request) {
|
||||||
if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil {
|
if resp = HandleUpgrade(conn, nil, request, in); resp == nil {
|
||||||
return // hijack connection
|
return // hijack connection
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -18,58 +16,43 @@ func isUpgradeRequest(req *http.Request) bool {
|
|||||||
return strings.EqualFold(req.Header.Get("Connection"), "Upgrade")
|
return strings.EqualFold(req.Header.Get("Connection"), "Upgrade")
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) {
|
func HandleUpgrade(localConn net.Conn, serverConn *N.BufferedConn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) {
|
||||||
removeProxyHeaders(request.Header)
|
removeProxyHeaders(request.Header)
|
||||||
RemoveExtraHTTPHostPort(request)
|
RemoveExtraHTTPHostPort(request)
|
||||||
|
|
||||||
address := request.Host
|
if serverConn == nil {
|
||||||
if _, _, err := net.SplitHostPort(address); err != nil {
|
address := request.Host
|
||||||
port := "80"
|
if _, _, err := net.SplitHostPort(address); err != nil {
|
||||||
if request.TLS != nil {
|
port := "80"
|
||||||
port = "443"
|
if request.TLS != nil {
|
||||||
|
port = "443"
|
||||||
|
}
|
||||||
|
address = net.JoinHostPort(address, port)
|
||||||
}
|
}
|
||||||
address = net.JoinHostPort(address, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
dstAddr := socks5.ParseAddr(address)
|
dstAddr := socks5.ParseAddr(address)
|
||||||
if dstAddr == nil {
|
if dstAddr == nil {
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
left, right := net.Pipe()
|
|
||||||
|
|
||||||
in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right)
|
|
||||||
|
|
||||||
var remoteServer *N.BufferedConn
|
|
||||||
if request.TLS != nil {
|
|
||||||
tlsConn := tls.Client(left, &tls.Config{
|
|
||||||
ServerName: request.URL.Hostname(),
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
|
|
||||||
defer cancel()
|
|
||||||
if tlsConn.HandshakeContext(ctx) != nil {
|
|
||||||
_ = localConn.Close()
|
|
||||||
_ = left.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteServer = N.NewBufferedConn(tlsConn)
|
left, right := net.Pipe()
|
||||||
} else {
|
|
||||||
remoteServer = N.NewBufferedConn(left)
|
in <- inbound.NewHTTP(dstAddr, localConn.RemoteAddr(), right)
|
||||||
|
|
||||||
|
serverConn = N.NewBufferedConn(left)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = serverConn.Close()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
err := request.Write(serverConn)
|
||||||
_ = remoteServer.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := request.Write(remoteServer)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = localConn.Close()
|
_ = localConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err = http.ReadResponse(remoteServer.Reader(), request)
|
resp, err = http.ReadResponse(serverConn.Reader(), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = localConn.Close()
|
_ = localConn.Close()
|
||||||
return
|
return
|
||||||
@ -88,7 +71,7 @@ func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, i
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
N.Relay(remoteServer, localConn) // blocking here
|
N.Relay(serverConn, localConn) // blocking here
|
||||||
_ = localConn.Close()
|
_ = localConn.Close()
|
||||||
resp = nil
|
resp = nil
|
||||||
}
|
}
|
||||||
|
@ -3,53 +3,53 @@ package mitm
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/adapter/inbound"
|
"github.com/Dreamacro/clash/adapter/inbound"
|
||||||
|
N "github.com/Dreamacro/clash/common/net"
|
||||||
C "github.com/Dreamacro/clash/constant"
|
C "github.com/Dreamacro/clash/constant"
|
||||||
"github.com/Dreamacro/clash/transport/socks5"
|
"github.com/Dreamacro/clash/transport/socks5"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client {
|
func getServerConn(serverConn *N.BufferedConn, request *http.Request, srcAddr net.Addr, in chan<- C.ConnContext) (*N.BufferedConn, error) {
|
||||||
return &http.Client{
|
if serverConn != nil {
|
||||||
Transport: &http.Transport{
|
return serverConn, nil
|
||||||
// excepted HTTP/2
|
|
||||||
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
|
|
||||||
// only needed 1 connection
|
|
||||||
MaxIdleConns: 1,
|
|
||||||
MaxIdleConnsPerHost: 1,
|
|
||||||
MaxConnsPerHost: 1,
|
|
||||||
IdleConnTimeout: 60 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
GetClientCertificate: func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, e error) {
|
|
||||||
return nil, ErrCertUnsupported
|
|
||||||
},
|
|
||||||
},
|
|
||||||
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
|
|
||||||
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
||||||
return nil, errors.New("unsupported network " + network)
|
|
||||||
}
|
|
||||||
|
|
||||||
dstAddr := socks5.ParseAddr(address)
|
|
||||||
if dstAddr == nil {
|
|
||||||
return nil, socks5.ErrAddressNotSupported
|
|
||||||
}
|
|
||||||
|
|
||||||
left, right := net.Pipe()
|
|
||||||
|
|
||||||
in <- inbound.NewMitm(dstAddr, source, userAgent, right)
|
|
||||||
|
|
||||||
return left, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
},
|
|
||||||
Timeout: 120 * time.Second,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
address := request.Host
|
||||||
|
if _, _, err := net.SplitHostPort(address); err != nil {
|
||||||
|
port := "80"
|
||||||
|
if request.TLS != nil {
|
||||||
|
port = "443"
|
||||||
|
}
|
||||||
|
address = net.JoinHostPort(address, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
dstAddr := socks5.ParseAddr(address)
|
||||||
|
if dstAddr == nil {
|
||||||
|
return nil, socks5.ErrAddressNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
|
||||||
|
in <- inbound.NewMitm(dstAddr, srcAddr, request.Header.Get("User-Agent"), right)
|
||||||
|
|
||||||
|
if request.TLS != nil {
|
||||||
|
tlsConn := tls.Client(left, &tls.Config{
|
||||||
|
ServerName: request.TLS.ServerName,
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
|
||||||
|
defer cancel()
|
||||||
|
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConn = N.NewBufferedConn(tlsConn)
|
||||||
|
} else {
|
||||||
|
serverConn = N.NewBufferedConn(left)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverConn, nil
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -19,27 +18,31 @@ import (
|
|||||||
H "github.com/Dreamacro/clash/listener/http"
|
H "github.com/Dreamacro/clash/listener/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) {
|
func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) {
|
||||||
var (
|
var (
|
||||||
source net.Addr
|
clientIP = netip.MustParseAddrPort(clientConn.RemoteAddr().String()).Addr()
|
||||||
client *http.Client
|
sourceAddr net.Addr
|
||||||
|
serverConn *N.BufferedConn
|
||||||
)
|
)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if client != nil {
|
if serverConn != nil {
|
||||||
client.CloseIdleConnections()
|
_ = serverConn.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
startOver:
|
startOver:
|
||||||
var conn *N.BufferedConn
|
var conn *N.BufferedConn
|
||||||
if bufConn, ok := c.(*N.BufferedConn); ok {
|
if bufConn, ok := clientConn.(*N.BufferedConn); ok {
|
||||||
conn = bufConn
|
conn = bufConn
|
||||||
} else {
|
} else {
|
||||||
conn = N.NewBufferedConn(c)
|
conn = N.NewBufferedConn(clientConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
trusted := cache == nil // disable authenticate if cache is nil
|
trusted := cache == nil // disable authenticate if cache is nil
|
||||||
|
if !trusted {
|
||||||
|
trusted = clientIP.IsLoopback()
|
||||||
|
}
|
||||||
|
|
||||||
readLoop:
|
readLoop:
|
||||||
for {
|
for {
|
||||||
@ -57,8 +60,8 @@ readLoop:
|
|||||||
|
|
||||||
session := newSession(conn, request, response)
|
session := newSession(conn, request, response)
|
||||||
|
|
||||||
source = parseSourceAddress(session.request, c.RemoteAddr(), source)
|
sourceAddr = parseSourceAddress(session.request, clientConn.RemoteAddr(), sourceAddr)
|
||||||
session.request.RemoteAddr = source.String()
|
session.request.RemoteAddr = sourceAddr.String()
|
||||||
|
|
||||||
if !trusted {
|
if !trusted {
|
||||||
session.response = H.Authenticate(session.request, cache)
|
session.response = H.Authenticate(session.request, cache)
|
||||||
@ -95,15 +98,15 @@ readLoop:
|
|||||||
break // close connection
|
break // close connection
|
||||||
}
|
}
|
||||||
|
|
||||||
c = tlsConn
|
clientConn = tlsConn
|
||||||
} else {
|
} else {
|
||||||
c = conn
|
clientConn = conn
|
||||||
}
|
}
|
||||||
|
|
||||||
goto startOver
|
goto startOver
|
||||||
}
|
}
|
||||||
|
|
||||||
prepareRequest(c, session.request)
|
prepareRequest(clientConn, session.request)
|
||||||
|
|
||||||
// hijack api
|
// hijack api
|
||||||
if session.request.URL.Hostname() == opt.ApiHost {
|
if session.request.URL.Hostname() == opt.ApiHost {
|
||||||
@ -115,17 +118,22 @@ readLoop:
|
|||||||
|
|
||||||
// forward websocket
|
// forward websocket
|
||||||
if isWebsocketRequest(request) {
|
if isWebsocketRequest(request) {
|
||||||
|
serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
session.request.RequestURI = ""
|
session.request.RequestURI = ""
|
||||||
if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil {
|
if session.response = H.HandleUpgrade(conn, serverConn, request, in); session.response == nil {
|
||||||
return // hijack connection
|
return // hijack connection
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
H.RemoveHopByHopHeaders(session.request.Header)
|
if session.response == nil {
|
||||||
H.RemoveExtraHTTPHostPort(session.request)
|
H.RemoveHopByHopHeaders(session.request.Header)
|
||||||
|
H.RemoveExtraHTTPHostPort(session.request)
|
||||||
|
|
||||||
// hijack custom request and write back custom response if necessary
|
// hijack custom request and write back custom response if necessary
|
||||||
if opt.Handler != nil && session.response == nil {
|
|
||||||
newReq, newRes := opt.Handler.HandleRequest(session)
|
newReq, newRes := opt.Handler.HandleRequest(session)
|
||||||
if newReq != nil {
|
if newReq != nil {
|
||||||
session.request = newReq
|
session.request = newReq
|
||||||
@ -139,26 +147,26 @@ readLoop:
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if session.response == nil {
|
|
||||||
session.request.RequestURI = ""
|
session.request.RequestURI = ""
|
||||||
|
|
||||||
if session.request.URL.Host == "" {
|
if session.request.URL.Host == "" {
|
||||||
session.response = session.NewErrorResponse(ErrInvalidURL)
|
session.response = session.NewErrorResponse(ErrInvalidURL)
|
||||||
} else {
|
} else {
|
||||||
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
|
serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
// send the request to remote server
|
// send the request to remote server
|
||||||
session.response, err = client.Do(session.request)
|
err = session.request.Write(serverConn)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(opt, session, err)
|
break
|
||||||
session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err))
|
}
|
||||||
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
|
|
||||||
_ = writeResponse(session, false)
|
session.response, err = http.ReadResponse(serverConn.Reader(), request)
|
||||||
break
|
if err != nil {
|
||||||
}
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -174,11 +182,9 @@ readLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func writeResponseWithHandler(session *Session, opt *Option) error {
|
func writeResponseWithHandler(session *Session, opt *Option) error {
|
||||||
if opt.Handler != nil {
|
res := opt.Handler.HandleResponse(session)
|
||||||
res := opt.Handler.HandleResponse(session)
|
if res != nil {
|
||||||
if res != nil {
|
session.response = res
|
||||||
session.response = res
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeResponse(session, true)
|
return writeResponse(session, true)
|
||||||
@ -220,10 +226,8 @@ func handleApiRequest(session *Session, opt *Option) error {
|
|||||||
</body></html>
|
</body></html>
|
||||||
`
|
`
|
||||||
|
|
||||||
if opt.Handler != nil {
|
if opt.Handler.HandleApiRequest(session) {
|
||||||
if opt.Handler.HandleApiRequest(session) {
|
return nil
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
b = fmt.Sprintf(b, session.request.URL.Path)
|
b = fmt.Sprintf(b, session.request.URL.Path)
|
||||||
@ -243,9 +247,7 @@ func handleError(opt *Option, session *Session, err error) {
|
|||||||
_ = session.response.Body.Close()
|
_ = session.response.Body.Close()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if opt.Handler != nil {
|
opt.Handler.HandleError(session, err)
|
||||||
opt.Handler.HandleError(session, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRequest(conn net.Conn, request *http.Request) {
|
func prepareRequest(conn net.Conn, request *http.Request) {
|
||||||
@ -297,10 +299,6 @@ func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client {
|
func isWebsocketRequest(req *http.Request) bool {
|
||||||
if cli != nil {
|
return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket"
|
||||||
return cli
|
|
||||||
}
|
|
||||||
|
|
||||||
return newClient(source, req.Header.Get("User-Agent"), in)
|
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@ func (l *Listener) Close() error {
|
|||||||
|
|
||||||
// New the MITM proxy actually is a type of HTTP proxy
|
// New the MITM proxy actually is a type of HTTP proxy
|
||||||
func New(option *Option, in chan<- C.ConnContext) (*Listener, error) {
|
func New(option *Option, in chan<- C.ConnContext) (*Listener, error) {
|
||||||
return NewWithAuthenticate(option, in, false)
|
return NewWithAuthenticate(option, in, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) {
|
func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) {
|
||||||
@ -65,7 +65,7 @@ func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate b
|
|||||||
|
|
||||||
var c *cache.Cache[string, bool]
|
var c *cache.Cache[string, bool]
|
||||||
if authenticate {
|
if authenticate {
|
||||||
c = cache.New[string, bool](time.Second * 30)
|
c = cache.New[string, bool](time.Second * 90)
|
||||||
}
|
}
|
||||||
|
|
||||||
hl := &Listener{
|
hl := &Listener{
|
||||||
|
@ -15,15 +15,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrCertUnsupported = errors.New("tls: client cert unsupported")
|
|
||||||
ErrInvalidResponse = errors.New("invalid response")
|
ErrInvalidResponse = errors.New("invalid response")
|
||||||
ErrInvalidURL = errors.New("invalid URL")
|
ErrInvalidURL = errors.New("invalid URL")
|
||||||
)
|
)
|
||||||
|
|
||||||
func isWebsocketRequest(req *http.Request) bool {
|
|
||||||
return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket"
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
|
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
|
||||||
if body == nil {
|
if body == nil {
|
||||||
body = &bytes.Buffer{}
|
body = &bytes.Buffer{}
|
||||||
|
Reference in New Issue
Block a user