Refactor: refactor find process (#2781)
This commit is contained in:
@ -1,196 +1,206 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
tcpTableFunc = "GetExtendedTcpTable"
|
||||
tcpTablePidConn = 4
|
||||
udpTableFunc = "GetExtendedUdpTable"
|
||||
udpTablePid = 1
|
||||
queryProcNameFunc = "QueryFullProcessImageNameW"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
)
|
||||
|
||||
var (
|
||||
getExTCPTable uintptr
|
||||
getExUDPTable uintptr
|
||||
queryProcName uintptr
|
||||
modIphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
|
||||
once sync.Once
|
||||
procGetExtendedTcpTable = modIphlpapi.NewProc("GetExtendedTcpTable")
|
||||
procGetExtendedUdpTable = modIphlpapi.NewProc("GetExtendedUdpTable")
|
||||
)
|
||||
|
||||
func initWin32API() error {
|
||||
h, err := windows.LoadLibrary("iphlpapi.dll")
|
||||
if err != nil {
|
||||
return fmt.Errorf("LoadLibrary iphlpapi.dll failed: %s", err.Error())
|
||||
}
|
||||
|
||||
getExTCPTable, err = windows.GetProcAddress(h, tcpTableFunc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("GetProcAddress of %s failed: %s", tcpTableFunc, err.Error())
|
||||
}
|
||||
|
||||
getExUDPTable, err = windows.GetProcAddress(h, udpTableFunc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("GetProcAddress of %s failed: %s", udpTableFunc, err.Error())
|
||||
}
|
||||
|
||||
h, err = windows.LoadLibrary("kernel32.dll")
|
||||
if err != nil {
|
||||
return fmt.Errorf("LoadLibrary kernel32.dll failed: %s", err.Error())
|
||||
}
|
||||
|
||||
queryProcName, err = windows.GetProcAddress(h, queryProcNameFunc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("GetProcAddress of %s failed: %s", queryProcNameFunc, err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func findProcessName(network string, ip net.IP, srcPort int) (string, error) {
|
||||
once.Do(func() {
|
||||
err := initWin32API()
|
||||
if err != nil {
|
||||
log.Errorln("Initialize PROCESS-NAME failed: %s", err.Error())
|
||||
log.Warnln("All PROCESS-NAMES rules will be skipped")
|
||||
return
|
||||
}
|
||||
})
|
||||
family := windows.AF_INET
|
||||
if ip.To4() == nil {
|
||||
func findProcessPath(network string, from netip.AddrPort, to netip.AddrPort) (string, error) {
|
||||
family := uint32(windows.AF_INET)
|
||||
if from.Addr().Is6() {
|
||||
family = windows.AF_INET6
|
||||
}
|
||||
|
||||
var class int
|
||||
var fn uintptr
|
||||
var protocol uint32
|
||||
switch network {
|
||||
case TCP:
|
||||
fn = getExTCPTable
|
||||
class = tcpTablePidConn
|
||||
protocol = windows.IPPROTO_TCP
|
||||
case UDP:
|
||||
fn = getExUDPTable
|
||||
class = udpTablePid
|
||||
protocol = windows.IPPROTO_UDP
|
||||
default:
|
||||
return "", ErrInvalidNetwork
|
||||
}
|
||||
|
||||
buf, err := getTransportTable(fn, family, class)
|
||||
pid, err := findPidByConnectionEndpoint(family, protocol, from, to)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s := newSearcher(family == windows.AF_INET, network == TCP)
|
||||
|
||||
pid, err := s.Search(buf, ip, uint16(srcPort))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getExecPathFromPID(pid)
|
||||
}
|
||||
|
||||
type searcher struct {
|
||||
itemSize int
|
||||
port int
|
||||
ip int
|
||||
ipSize int
|
||||
pid int
|
||||
tcpState int
|
||||
}
|
||||
func findPidByConnectionEndpoint(family uint32, protocol uint32, from netip.AddrPort, to netip.AddrPort) (uint32, error) {
|
||||
buf := pool.Get(8)
|
||||
defer pool.Put(buf)
|
||||
|
||||
func (s *searcher) Search(b []byte, ip net.IP, port uint16) (uint32, error) {
|
||||
n := int(readNativeUint32(b[:4]))
|
||||
itemSize := s.itemSize
|
||||
for i := 0; i < n; i++ {
|
||||
row := b[4+itemSize*i : 4+itemSize*(i+1)]
|
||||
bufSize := len(buf)
|
||||
|
||||
if s.tcpState >= 0 {
|
||||
tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4])
|
||||
// MIB_TCP_STATE_ESTAB, only check established connections for TCP
|
||||
if tcpState != 5 {
|
||||
continue
|
||||
loop:
|
||||
for {
|
||||
var ret uintptr
|
||||
|
||||
switch protocol {
|
||||
case windows.IPPROTO_TCP:
|
||||
ret, _, _ = procGetExtendedTcpTable.Call(
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
0,
|
||||
uintptr(family),
|
||||
4, // TCP_TABLE_OWNER_PID_CONNECTIONS
|
||||
0,
|
||||
)
|
||||
case windows.IPPROTO_UDP:
|
||||
ret, _, _ = procGetExtendedUdpTable.Call(
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bufSize)),
|
||||
0,
|
||||
uintptr(family),
|
||||
1, // UDP_TABLE_OWNER_PID
|
||||
0,
|
||||
)
|
||||
default:
|
||||
return 0, errors.New("unsupported network")
|
||||
}
|
||||
|
||||
switch ret {
|
||||
case 0:
|
||||
buf = buf[:bufSize]
|
||||
|
||||
break loop
|
||||
case uintptr(windows.ERROR_INSUFFICIENT_BUFFER):
|
||||
pool.Put(buf)
|
||||
buf = pool.Get(bufSize)
|
||||
|
||||
continue loop
|
||||
default:
|
||||
return 0, fmt.Errorf("syscall error: %d", ret)
|
||||
}
|
||||
}
|
||||
|
||||
if len(buf) < int(unsafe.Sizeof(uint32(0))) {
|
||||
return 0, fmt.Errorf("invalid table size: %d", len(buf))
|
||||
}
|
||||
|
||||
entriesSize := *(*uint32)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch protocol {
|
||||
case windows.IPPROTO_TCP:
|
||||
if family == windows.AF_INET {
|
||||
type MibTcpRowOwnerPid struct {
|
||||
State uint32
|
||||
LocalAddr [4]byte
|
||||
LocalPort uint32
|
||||
RemoteAddr [4]byte
|
||||
RemotePort uint32
|
||||
OwningPid uint32
|
||||
}
|
||||
|
||||
if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibTcpRowOwnerPid{})) {
|
||||
return 0, fmt.Errorf("invalid tables size: %d", len(buf))
|
||||
}
|
||||
|
||||
entries := unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
|
||||
for _, entry := range entries {
|
||||
localAddr := netip.AddrFrom4(entry.LocalAddr)
|
||||
localPort := windows.Ntohs(uint16(entry.LocalPort))
|
||||
remoteAddr := netip.AddrFrom4(entry.RemoteAddr)
|
||||
remotePort := windows.Ntohs(uint16(entry.RemotePort))
|
||||
|
||||
if localAddr == from.Addr() && remoteAddr == to.Addr() && localPort == from.Port() && remotePort == to.Port() {
|
||||
return entry.OwningPid, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
type MibTcp6RowOwnerPid struct {
|
||||
LocalAddr [16]byte
|
||||
LocalScopeID uint32
|
||||
LocalPort uint32
|
||||
RemoteAddr [16]byte
|
||||
RemoteScopeID uint32
|
||||
RemotePort uint32
|
||||
State uint32
|
||||
OwningPid uint32
|
||||
}
|
||||
|
||||
if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibTcp6RowOwnerPid{})) {
|
||||
return 0, fmt.Errorf("invalid tables size: %d", len(buf))
|
||||
}
|
||||
|
||||
entries := unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
|
||||
for _, entry := range entries {
|
||||
localAddr := netip.AddrFrom16(entry.LocalAddr)
|
||||
localPort := windows.Ntohs(uint16(entry.LocalPort))
|
||||
remoteAddr := netip.AddrFrom16(entry.RemoteAddr)
|
||||
remotePort := windows.Ntohs(uint16(entry.RemotePort))
|
||||
|
||||
if localAddr == from.Addr() && remoteAddr == to.Addr() && localPort == from.Port() && remotePort == to.Port() {
|
||||
return entry.OwningPid, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
case windows.IPPROTO_UDP:
|
||||
if family == windows.AF_INET {
|
||||
type MibUdpRowOwnerPid struct {
|
||||
LocalAddr [4]byte
|
||||
LocalPort uint32
|
||||
OwningPid uint32
|
||||
}
|
||||
|
||||
// according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
|
||||
// this field can be illustrated as follows depends on different machine endianess:
|
||||
// little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB)
|
||||
// big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB)
|
||||
// so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
|
||||
srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
|
||||
if srcPort != port {
|
||||
continue
|
||||
if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibUdpRowOwnerPid{})) {
|
||||
return 0, fmt.Errorf("invalid tables size: %d", len(buf))
|
||||
}
|
||||
|
||||
entries := unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
|
||||
for _, entry := range entries {
|
||||
localAddr := netip.AddrFrom4(entry.LocalAddr)
|
||||
localPort := windows.Ntohs(uint16(entry.LocalPort))
|
||||
|
||||
if localAddr == from.Addr() && localPort == from.Port() {
|
||||
return entry.OwningPid, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
type MibUdp6RowOwnerPid struct {
|
||||
LocalAddr [16]byte
|
||||
LocalScopeId uint32
|
||||
LocalPort uint32
|
||||
OwningPid uint32
|
||||
}
|
||||
|
||||
if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibUdp6RowOwnerPid{})) {
|
||||
return 0, fmt.Errorf("invalid tables size: %d", len(buf))
|
||||
}
|
||||
|
||||
entries := unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
|
||||
for _, entry := range entries {
|
||||
localAddr := netip.AddrFrom16(entry.LocalAddr)
|
||||
localPort := windows.Ntohs(uint16(entry.LocalPort))
|
||||
|
||||
if localAddr == from.Addr() && localPort == from.Port() {
|
||||
return entry.OwningPid, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
srcIP := net.IP(row[s.ip : s.ip+s.ipSize])
|
||||
// windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
|
||||
if !ip.Equal(srcIP) && (!srcIP.IsUnspecified() || s.tcpState != -1) {
|
||||
continue
|
||||
}
|
||||
|
||||
pid := readNativeUint32(row[s.pid : s.pid+4])
|
||||
return pid, nil
|
||||
default:
|
||||
return 0, ErrInvalidNetwork
|
||||
}
|
||||
|
||||
return 0, ErrNotFound
|
||||
}
|
||||
|
||||
func newSearcher(isV4, isTCP bool) *searcher {
|
||||
var itemSize, port, ip, ipSize, pid int
|
||||
tcpState := -1
|
||||
switch {
|
||||
case isV4 && isTCP:
|
||||
// struct MIB_TCPROW_OWNER_PID
|
||||
itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
|
||||
case isV4 && !isTCP:
|
||||
// struct MIB_UDPROW_OWNER_PID
|
||||
itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
|
||||
case !isV4 && isTCP:
|
||||
// struct MIB_TCP6ROW_OWNER_PID
|
||||
itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
|
||||
case !isV4 && !isTCP:
|
||||
// struct MIB_UDP6ROW_OWNER_PID
|
||||
itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
|
||||
}
|
||||
|
||||
return &searcher{
|
||||
itemSize: itemSize,
|
||||
port: port,
|
||||
ip: ip,
|
||||
ipSize: ipSize,
|
||||
pid: pid,
|
||||
tcpState: tcpState,
|
||||
}
|
||||
}
|
||||
|
||||
func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
|
||||
for size, buf := uint32(8), make([]byte, 8); ; {
|
||||
ptr := unsafe.Pointer(&buf[0])
|
||||
err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
|
||||
|
||||
switch err {
|
||||
case 0:
|
||||
return buf, nil
|
||||
case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
|
||||
buf = make([]byte, size)
|
||||
default:
|
||||
return nil, fmt.Errorf("syscall error: %d", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readNativeUint32(b []byte) uint32 {
|
||||
return *(*uint32)(unsafe.Pointer(&b[0]))
|
||||
}
|
||||
|
||||
func getExecPathFromPID(pid uint32) (string, error) {
|
||||
// kernel process starts with a colon in order to distinguish with normal processes
|
||||
switch pid {
|
||||
@ -207,17 +217,13 @@ func getExecPathFromPID(pid uint32) (string, error) {
|
||||
}
|
||||
defer windows.CloseHandle(h)
|
||||
|
||||
buf := make([]uint16, syscall.MAX_LONG_PATH)
|
||||
buf := make([]uint16, windows.MAX_LONG_PATH)
|
||||
size := uint32(len(buf))
|
||||
r1, _, err := syscall.SyscallN(
|
||||
queryProcName,
|
||||
uintptr(h),
|
||||
uintptr(1),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&size)),
|
||||
)
|
||||
if r1 == 0 {
|
||||
|
||||
err = windows.QueryFullProcessImageName(h, 0, &buf[0], &size)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return syscall.UTF16ToString(buf[:size]), nil
|
||||
|
||||
return windows.UTF16ToString(buf[:size]), nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user