Refactor: use native Win32 API to detect interface changed on Windows

This commit is contained in:
yaling888 2022-05-28 09:50:09 +08:00
parent 67905bcf7e
commit 985dc99b5d
8 changed files with 262 additions and 76 deletions

View File

@ -18,10 +18,21 @@ var StackTypeMapping = map[string]TUNStack{
const (
TunGvisor TUNStack = iota
TunSystem
TunDisabled TUNState = iota
TunEnabled
TunPaused
)
type TUNStack int
type TUNState int
type TUNChangeCallback interface {
Pause()
Resume()
}
// UnmarshalYAML unserialize TUNStack with yaml
func (e *TUNStack) UnmarshalYAML(unmarshal func(any) error) error {
var tp string

View File

@ -385,7 +385,13 @@ func ReCreateTun(tunConf *config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *
return
}
tunStackListener, err = tun.New(tunConf, tunAddressPrefix, tcpIn, udpIn)
callback := &tunChangeCallback{
tunConf: *tunConf,
tcpIn: tcpIn,
udpIn: udpIn,
}
tunStackListener, err = tun.New(tunConf, tunAddressPrefix, tcpIn, udpIn, callback)
if err != nil {
return
}
@ -563,6 +569,24 @@ func hasTunConfigChange(tunConf *config.Tun, tunAddressPrefix *netip.Prefix) boo
return false
}
type tunChangeCallback struct {
tunConf config.Tun
tcpIn chan<- C.ConnContext
udpIn chan<- *inbound.PacketAdapter
}
func (t *tunChangeCallback) Pause() {
conf := t.tunConf
conf.Enable = false
ReCreateTun(&conf, t.tcpIn, t.udpIn)
}
func (t *tunChangeCallback) Resume() {
conf := t.tunConf
conf.Enable = true
ReCreateTun(&conf, t.tcpIn, t.udpIn)
}
func initCert() error {
if _, err := os.Stat(C.Path.RootCA()); os.IsNotExist(err) {
log.Infoln("Can't find mitm_ca.crt, start generate")

View File

@ -1,14 +1,13 @@
package commons
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/iface"
"github.com/Dreamacro/clash/log"
C "github.com/Dreamacro/clash/constant"
)
var (
@ -18,6 +17,10 @@ var (
monitorStarted = false
monitorStop = make(chan struct{}, 2)
monitorMux sync.Mutex
tunStatus = C.TunDisabled
tunChangeCallback C.TUNChangeCallback
errInterfaceNotFound = errors.New("default interface not found")
)
func ipv4MaskString(bits int) string {
@ -29,54 +32,6 @@ func ipv4MaskString(bits int) string {
return fmt.Sprintf("%d.%d.%d.%d", m[0], m[1], m[2], m[3])
}
func StartDefaultInterfaceChangeMonitor() {
monitorMux.Lock()
if monitorStarted {
monitorMux.Unlock()
return
}
monitorStarted = true
monitorMux.Unlock()
select {
case <-monitorStop:
default:
}
t := time.NewTicker(monitorDuration)
defer t.Stop()
for {
select {
case <-t.C:
interfaceName, err := GetAutoDetectInterface()
if err != nil {
log.Warnln("[TUN] default interface monitor err: %v", err)
continue
}
old := dialer.DefaultInterface.Load()
if interfaceName == old {
continue
}
dialer.DefaultInterface.Store(interfaceName)
iface.FlushCache()
log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, interfaceName)
case <-monitorStop:
break
}
}
}
func StopDefaultInterfaceChangeMonitor() {
monitorMux.Lock()
defer monitorMux.Unlock()
if monitorStarted {
monitorStop <- struct{}{}
monitorStarted = false
}
func SetTunChangeCallback(callback C.TUNChangeCallback) {
tunChangeCallback = callback
}

View File

@ -6,22 +6,26 @@ import (
"net/netip"
"strings"
"syscall"
"time"
"github.com/Dreamacro/clash/common/cmd"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/iface"
"github.com/Dreamacro/clash/listener/tun/device"
"github.com/Dreamacro/clash/log"
"golang.org/x/net/route"
)
func GetAutoDetectInterface() (string, error) {
iface, err := defaultRouteInterface()
ifaceM, err := defaultRouteInterface()
if err != nil {
return "", err
}
return iface.Name, nil
return ifaceM.Name, nil
}
func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, forceMTU int, autoRoute bool) error {
func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, _ int, autoRoute bool) error {
if !addr.Addr().Is4() {
return fmt.Errorf("supported ipv4 only")
}
@ -96,17 +100,69 @@ func defaultRouteInterface() (*net.Interface, error) {
continue
}
iface, err1 := net.InterfaceByIndex(routeMessage.Index)
ifaceM, err1 := net.InterfaceByIndex(routeMessage.Index)
if err1 != nil {
continue
}
if strings.HasPrefix(iface.Name, "utun") {
if strings.HasPrefix(ifaceM.Name, "utun") {
continue
}
return iface, nil
return ifaceM, nil
}
return nil, fmt.Errorf("ambiguous gateway interfaces found")
}
func StartDefaultInterfaceChangeMonitor() {
monitorMux.Lock()
if monitorStarted {
monitorMux.Unlock()
return
}
monitorStarted = true
monitorMux.Unlock()
select {
case <-monitorStop:
default:
}
t := time.NewTicker(monitorDuration)
defer t.Stop()
for {
select {
case <-t.C:
interfaceName, err := GetAutoDetectInterface()
if err != nil {
log.Warnln("[TUN] default interface monitor err: %v", err)
continue
}
old := dialer.DefaultInterface.Load()
if interfaceName == old {
continue
}
dialer.DefaultInterface.Store(interfaceName)
iface.FlushCache()
log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, interfaceName)
case <-monitorStop:
break
}
}
}
func StopDefaultInterfaceChangeMonitor() {
monitorMux.Lock()
defer monitorMux.Unlock()
if monitorStarted {
monitorStop <- struct{}{}
monitorStarted = false
}
}

View File

@ -3,16 +3,20 @@ package commons
import (
"fmt"
"net/netip"
"time"
"github.com/Dreamacro/clash/common/cmd"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/iface"
"github.com/Dreamacro/clash/listener/tun/device"
"github.com/Dreamacro/clash/log"
)
func GetAutoDetectInterface() (string, error) {
return cmd.ExecCmd("bash -c ip route show | grep 'default via' | awk -F ' ' 'NR==1{print $5}' | xargs echo -n")
}
func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, forceMTU int, autoRoute bool) error {
func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, _ int, autoRoute bool) error {
var (
interfaceName = dev.Name()
ip = addr.Masked().Addr().Next()
@ -51,3 +55,55 @@ func execRouterCmd(action, route string, interfaceName string, linkIP string) er
_, err := cmd.ExecCmd(cmdStr)
return err
}
func StartDefaultInterfaceChangeMonitor() {
monitorMux.Lock()
if monitorStarted {
monitorMux.Unlock()
return
}
monitorStarted = true
monitorMux.Unlock()
select {
case <-monitorStop:
default:
}
t := time.NewTicker(monitorDuration)
defer t.Stop()
for {
select {
case <-t.C:
interfaceName, err := GetAutoDetectInterface()
if err != nil {
log.Warnln("[TUN] default interface monitor err: %v", err)
continue
}
old := dialer.DefaultInterface.Load()
if interfaceName == old {
continue
}
dialer.DefaultInterface.Store(interfaceName)
iface.FlushCache()
log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, interfaceName)
case <-monitorStop:
break
}
}
}
func StopDefaultInterfaceChangeMonitor() {
monitorMux.Lock()
defer monitorMux.Unlock()
if monitorStarted {
monitorStop <- struct{}{}
monitorStarted = false
}
}

View File

@ -17,3 +17,7 @@ func GetAutoDetectInterface() (string, error) {
func ConfigInterfaceAddress(device.Device, netip.Prefix, int, bool) error {
return fmt.Errorf("unsupported on this OS: %s", runtime.GOOS)
}
func StartDefaultInterfaceChangeMonitor() {}
func StopDefaultInterfaceChangeMonitor() {}

View File

@ -1,12 +1,16 @@
package commons
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/Dreamacro/clash/common/nnip"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/iface"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/listener/tun/device"
"github.com/Dreamacro/clash/listener/tun/device/tun"
"github.com/Dreamacro/clash/log"
@ -16,7 +20,11 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
var wintunInterfaceName string
var (
wintunInterfaceName string
unicastAddressChangeCallback *winipcfg.UnicastAddressChangeCallback
unicastAddressChangeLock sync.Mutex
)
func GetAutoDetectInterface() (string, error) {
ifname, err := getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET))
@ -220,15 +228,15 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
if err != nil {
return
}
for _, iface := range interfaces {
if iface.OperStatus == winipcfg.IfOperStatusUp {
for _, ifaceM := range interfaces {
if ifaceM.OperStatus == winipcfg.IfOperStatusUp {
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
for address := ifaceM.FirstUnicastAddress; address != nil; address = address.Next {
if ip := nnip.IpToAddr(address.Address.IP()); addrHash[ip] {
prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength))
log.Infoln("[TUN] cleaning up stale address %s from interface %s", prefix.String(), iface.FriendlyName())
_ = iface.LUID.DeleteIPAddress(prefix)
log.Infoln("[TUN] cleaning up stale address %s from interface %s", prefix.String(), ifaceM.FriendlyName())
_ = ifaceM.LUID.DeleteIPAddress(prefix)
}
}
}
@ -237,7 +245,7 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
func getAutoDetectInterfaceByFamily(family winipcfg.AddressFamily) (string, error) {
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeGateways)
if err != nil {
return "", fmt.Errorf("get ethernet interface failure. %w", err)
return "", fmt.Errorf("get default interface failure. %w", err)
}
var destination netip.Prefix
@ -247,25 +255,96 @@ func getAutoDetectInterfaceByFamily(family winipcfg.AddressFamily) (string, erro
destination = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
for _, iface := range interfaces {
if iface.OperStatus != winipcfg.IfOperStatusUp {
for _, ifaceM := range interfaces {
if ifaceM.OperStatus != winipcfg.IfOperStatusUp {
continue
}
ifname := iface.FriendlyName()
ifname := ifaceM.FriendlyName()
if wintunInterfaceName == ifname {
continue
}
for gatewayAddress := iface.FirstGatewayAddress; gatewayAddress != nil; gatewayAddress = gatewayAddress.Next {
for gatewayAddress := ifaceM.FirstGatewayAddress; gatewayAddress != nil; gatewayAddress = gatewayAddress.Next {
nextHop := nnip.IpToAddr(gatewayAddress.Address.IP())
if _, err = iface.LUID.Route(destination, nextHop); err == nil {
if _, err = ifaceM.LUID.Route(destination, nextHop); err == nil {
return ifname, nil
}
}
}
return "", errors.New("ethernet interface not found")
return "", errInterfaceNotFound
}
func unicastAddressChange(_ winipcfg.MibNotificationType, unicastAddress *winipcfg.MibUnicastIPAddressRow) {
unicastAddressChangeLock.Lock()
defer unicastAddressChangeLock.Unlock()
interfaceName, err := GetAutoDetectInterface()
if err != nil {
if err == errInterfaceNotFound && tunStatus == C.TunEnabled {
log.Warnln("[TUN] lost the default interface, pause tun adapter")
tunStatus = C.TunPaused
tunChangeCallback.Pause()
}
return
}
ifaceM, err := net.InterfaceByIndex(int(unicastAddress.InterfaceIndex))
if err != nil {
log.Warnln("[TUN] default interface monitor err: %v", err)
return
}
newName := ifaceM.Name
if newName != interfaceName {
return
}
dialer.DefaultInterface.Store(interfaceName)
iface.FlushCache()
if tunStatus == C.TunPaused {
log.Warnln("[TUN] found interface %s(%s), resume tun adapter", interfaceName, unicastAddress.Address.Addr())
tunStatus = C.TunEnabled
tunChangeCallback.Resume()
return
}
log.Warnln("[TUN] default interface changed to %s(%s) by monitor", interfaceName, unicastAddress.Address.Addr())
}
func StartDefaultInterfaceChangeMonitor() {
if unicastAddressChangeCallback != nil {
return
}
var err error
unicastAddressChangeCallback, err = winipcfg.RegisterUnicastAddressChangeCallback(unicastAddressChange)
if err != nil {
log.Errorln("[TUN] register uni-cast address change callback failed: %v", err)
return
}
tunStatus = C.TunEnabled
log.Infoln("[TUN] register uni-cast address change callback")
}
func StopDefaultInterfaceChangeMonitor() {
if unicastAddressChangeCallback == nil || tunStatus == C.TunPaused {
return
}
_ = unicastAddressChangeCallback.Unregister()
unicastAddressChangeCallback = nil
tunChangeCallback = nil
tunStatus = C.TunDisabled
}

View File

@ -24,7 +24,7 @@ import (
)
// New TunAdapter
func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) {
func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter, tunChangeCallback C.TUNChangeCallback) (ipstack.Stack, error) {
var (
tunAddress = netip.Prefix{}
devName = tunConf.Device
@ -98,6 +98,7 @@ func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.Con
}
if tunConf.AutoDetectInterface {
commons.SetTunChangeCallback(tunChangeCallback)
go commons.StartDefaultInterfaceChangeMonitor()
}