From 985dc99b5d1d641acb52080b95e1549e181a52e1 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Sat, 28 May 2022 09:50:09 +0800 Subject: [PATCH] Refactor: use native Win32 API to detect interface changed on Windows --- constant/tun.go | 11 ++ listener/listener.go | 26 ++++- listener/tun/ipstack/commons/router.go | 61 ++-------- listener/tun/ipstack/commons/router_darwin.go | 68 ++++++++++- listener/tun/ipstack/commons/router_linux.go | 58 +++++++++- listener/tun/ipstack/commons/router_others.go | 4 + .../tun/ipstack/commons/router_windows.go | 107 +++++++++++++++--- listener/tun/tun_adapter.go | 3 +- 8 files changed, 262 insertions(+), 76 deletions(-) diff --git a/constant/tun.go b/constant/tun.go index ddd65d71..408ada8e 100644 --- a/constant/tun.go +++ b/constant/tun.go @@ -18,10 +18,21 @@ var StackTypeMapping = map[string]TUNStack{ const ( TunGvisor TUNStack = iota TunSystem + + TunDisabled TUNState = iota + TunEnabled + TunPaused ) type TUNStack int +type TUNState int + +type TUNChangeCallback interface { + Pause() + Resume() +} + // UnmarshalYAML unserialize TUNStack with yaml func (e *TUNStack) UnmarshalYAML(unmarshal func(any) error) error { var tp string diff --git a/listener/listener.go b/listener/listener.go index c9913a22..bb6484c5 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -385,7 +385,13 @@ func ReCreateTun(tunConf *config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- * return } - tunStackListener, err = tun.New(tunConf, tunAddressPrefix, tcpIn, udpIn) + callback := &tunChangeCallback{ + tunConf: *tunConf, + tcpIn: tcpIn, + udpIn: udpIn, + } + + tunStackListener, err = tun.New(tunConf, tunAddressPrefix, tcpIn, udpIn, callback) if err != nil { return } @@ -563,6 +569,24 @@ func hasTunConfigChange(tunConf *config.Tun, tunAddressPrefix *netip.Prefix) boo return false } +type tunChangeCallback struct { + tunConf config.Tun + tcpIn chan<- C.ConnContext + udpIn chan<- *inbound.PacketAdapter +} + +func (t *tunChangeCallback) Pause() { + conf := t.tunConf + conf.Enable = false + ReCreateTun(&conf, t.tcpIn, t.udpIn) +} + +func (t *tunChangeCallback) Resume() { + conf := t.tunConf + conf.Enable = true + ReCreateTun(&conf, t.tcpIn, t.udpIn) +} + func initCert() error { if _, err := os.Stat(C.Path.RootCA()); os.IsNotExist(err) { log.Infoln("Can't find mitm_ca.crt, start generate") diff --git a/listener/tun/ipstack/commons/router.go b/listener/tun/ipstack/commons/router.go index d4ac5743..798e1dd6 100644 --- a/listener/tun/ipstack/commons/router.go +++ b/listener/tun/ipstack/commons/router.go @@ -1,14 +1,13 @@ package commons import ( + "errors" "fmt" "net" "sync" "time" - "github.com/Dreamacro/clash/component/dialer" - "github.com/Dreamacro/clash/component/iface" - "github.com/Dreamacro/clash/log" + C "github.com/Dreamacro/clash/constant" ) var ( @@ -18,6 +17,10 @@ var ( monitorStarted = false monitorStop = make(chan struct{}, 2) monitorMux sync.Mutex + + tunStatus = C.TunDisabled + tunChangeCallback C.TUNChangeCallback + errInterfaceNotFound = errors.New("default interface not found") ) func ipv4MaskString(bits int) string { @@ -29,54 +32,6 @@ func ipv4MaskString(bits int) string { return fmt.Sprintf("%d.%d.%d.%d", m[0], m[1], m[2], m[3]) } -func StartDefaultInterfaceChangeMonitor() { - monitorMux.Lock() - if monitorStarted { - monitorMux.Unlock() - return - } - monitorStarted = true - monitorMux.Unlock() - - select { - case <-monitorStop: - default: - } - - t := time.NewTicker(monitorDuration) - defer t.Stop() - - for { - select { - case <-t.C: - interfaceName, err := GetAutoDetectInterface() - if err != nil { - log.Warnln("[TUN] default interface monitor err: %v", err) - continue - } - - old := dialer.DefaultInterface.Load() - if interfaceName == old { - continue - } - - dialer.DefaultInterface.Store(interfaceName) - - iface.FlushCache() - - log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, interfaceName) - case <-monitorStop: - break - } - } -} - -func StopDefaultInterfaceChangeMonitor() { - monitorMux.Lock() - defer monitorMux.Unlock() - - if monitorStarted { - monitorStop <- struct{}{} - monitorStarted = false - } +func SetTunChangeCallback(callback C.TUNChangeCallback) { + tunChangeCallback = callback } diff --git a/listener/tun/ipstack/commons/router_darwin.go b/listener/tun/ipstack/commons/router_darwin.go index a889aaf5..5998c881 100644 --- a/listener/tun/ipstack/commons/router_darwin.go +++ b/listener/tun/ipstack/commons/router_darwin.go @@ -6,22 +6,26 @@ import ( "net/netip" "strings" "syscall" + "time" "github.com/Dreamacro/clash/common/cmd" + "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/iface" "github.com/Dreamacro/clash/listener/tun/device" + "github.com/Dreamacro/clash/log" "golang.org/x/net/route" ) func GetAutoDetectInterface() (string, error) { - iface, err := defaultRouteInterface() + ifaceM, err := defaultRouteInterface() if err != nil { return "", err } - return iface.Name, nil + return ifaceM.Name, nil } -func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, forceMTU int, autoRoute bool) error { +func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, _ int, autoRoute bool) error { if !addr.Addr().Is4() { return fmt.Errorf("supported ipv4 only") } @@ -96,17 +100,69 @@ func defaultRouteInterface() (*net.Interface, error) { continue } - iface, err1 := net.InterfaceByIndex(routeMessage.Index) + ifaceM, err1 := net.InterfaceByIndex(routeMessage.Index) if err1 != nil { continue } - if strings.HasPrefix(iface.Name, "utun") { + if strings.HasPrefix(ifaceM.Name, "utun") { continue } - return iface, nil + return ifaceM, nil } return nil, fmt.Errorf("ambiguous gateway interfaces found") } + +func StartDefaultInterfaceChangeMonitor() { + monitorMux.Lock() + if monitorStarted { + monitorMux.Unlock() + return + } + monitorStarted = true + monitorMux.Unlock() + + select { + case <-monitorStop: + default: + } + + t := time.NewTicker(monitorDuration) + defer t.Stop() + + for { + select { + case <-t.C: + interfaceName, err := GetAutoDetectInterface() + if err != nil { + log.Warnln("[TUN] default interface monitor err: %v", err) + continue + } + + old := dialer.DefaultInterface.Load() + if interfaceName == old { + continue + } + + dialer.DefaultInterface.Store(interfaceName) + + iface.FlushCache() + + log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, interfaceName) + case <-monitorStop: + break + } + } +} + +func StopDefaultInterfaceChangeMonitor() { + monitorMux.Lock() + defer monitorMux.Unlock() + + if monitorStarted { + monitorStop <- struct{}{} + monitorStarted = false + } +} diff --git a/listener/tun/ipstack/commons/router_linux.go b/listener/tun/ipstack/commons/router_linux.go index 06df068c..e03acf39 100644 --- a/listener/tun/ipstack/commons/router_linux.go +++ b/listener/tun/ipstack/commons/router_linux.go @@ -3,16 +3,20 @@ package commons import ( "fmt" "net/netip" + "time" "github.com/Dreamacro/clash/common/cmd" + "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/iface" "github.com/Dreamacro/clash/listener/tun/device" + "github.com/Dreamacro/clash/log" ) func GetAutoDetectInterface() (string, error) { return cmd.ExecCmd("bash -c ip route show | grep 'default via' | awk -F ' ' 'NR==1{print $5}' | xargs echo -n") } -func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, forceMTU int, autoRoute bool) error { +func ConfigInterfaceAddress(dev device.Device, addr netip.Prefix, _ int, autoRoute bool) error { var ( interfaceName = dev.Name() ip = addr.Masked().Addr().Next() @@ -51,3 +55,55 @@ func execRouterCmd(action, route string, interfaceName string, linkIP string) er _, err := cmd.ExecCmd(cmdStr) return err } + +func StartDefaultInterfaceChangeMonitor() { + monitorMux.Lock() + if monitorStarted { + monitorMux.Unlock() + return + } + monitorStarted = true + monitorMux.Unlock() + + select { + case <-monitorStop: + default: + } + + t := time.NewTicker(monitorDuration) + defer t.Stop() + + for { + select { + case <-t.C: + interfaceName, err := GetAutoDetectInterface() + if err != nil { + log.Warnln("[TUN] default interface monitor err: %v", err) + continue + } + + old := dialer.DefaultInterface.Load() + if interfaceName == old { + continue + } + + dialer.DefaultInterface.Store(interfaceName) + + iface.FlushCache() + + log.Warnln("[TUN] default interface changed by monitor, %s => %s", old, interfaceName) + case <-monitorStop: + break + } + } +} + +func StopDefaultInterfaceChangeMonitor() { + monitorMux.Lock() + defer monitorMux.Unlock() + + if monitorStarted { + monitorStop <- struct{}{} + monitorStarted = false + } +} diff --git a/listener/tun/ipstack/commons/router_others.go b/listener/tun/ipstack/commons/router_others.go index 6c8ea341..85733f4a 100644 --- a/listener/tun/ipstack/commons/router_others.go +++ b/listener/tun/ipstack/commons/router_others.go @@ -17,3 +17,7 @@ func GetAutoDetectInterface() (string, error) { func ConfigInterfaceAddress(device.Device, netip.Prefix, int, bool) error { return fmt.Errorf("unsupported on this OS: %s", runtime.GOOS) } + +func StartDefaultInterfaceChangeMonitor() {} + +func StopDefaultInterfaceChangeMonitor() {} diff --git a/listener/tun/ipstack/commons/router_windows.go b/listener/tun/ipstack/commons/router_windows.go index b29b20ab..76a5462a 100644 --- a/listener/tun/ipstack/commons/router_windows.go +++ b/listener/tun/ipstack/commons/router_windows.go @@ -1,12 +1,16 @@ package commons import ( - "errors" "fmt" + "net" "net/netip" + "sync" "time" "github.com/Dreamacro/clash/common/nnip" + "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/iface" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/listener/tun/device" "github.com/Dreamacro/clash/listener/tun/device/tun" "github.com/Dreamacro/clash/log" @@ -16,7 +20,11 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -var wintunInterfaceName string +var ( + wintunInterfaceName string + unicastAddressChangeCallback *winipcfg.UnicastAddressChangeCallback + unicastAddressChangeLock sync.Mutex +) func GetAutoDetectInterface() (string, error) { ifname, err := getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET)) @@ -220,15 +228,15 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add if err != nil { return } - for _, iface := range interfaces { - if iface.OperStatus == winipcfg.IfOperStatusUp { + for _, ifaceM := range interfaces { + if ifaceM.OperStatus == winipcfg.IfOperStatusUp { continue } - for address := iface.FirstUnicastAddress; address != nil; address = address.Next { + for address := ifaceM.FirstUnicastAddress; address != nil; address = address.Next { if ip := nnip.IpToAddr(address.Address.IP()); addrHash[ip] { prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength)) - log.Infoln("[TUN] cleaning up stale address %s from interface ā€˜%s’", prefix.String(), iface.FriendlyName()) - _ = iface.LUID.DeleteIPAddress(prefix) + log.Infoln("[TUN] cleaning up stale address %s from interface ā€˜%s’", prefix.String(), ifaceM.FriendlyName()) + _ = ifaceM.LUID.DeleteIPAddress(prefix) } } } @@ -237,7 +245,7 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add func getAutoDetectInterfaceByFamily(family winipcfg.AddressFamily) (string, error) { interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeGateways) if err != nil { - return "", fmt.Errorf("get ethernet interface failure. %w", err) + return "", fmt.Errorf("get default interface failure. %w", err) } var destination netip.Prefix @@ -247,25 +255,96 @@ func getAutoDetectInterfaceByFamily(family winipcfg.AddressFamily) (string, erro destination = netip.PrefixFrom(netip.IPv6Unspecified(), 0) } - for _, iface := range interfaces { - if iface.OperStatus != winipcfg.IfOperStatusUp { + for _, ifaceM := range interfaces { + if ifaceM.OperStatus != winipcfg.IfOperStatusUp { continue } - ifname := iface.FriendlyName() + ifname := ifaceM.FriendlyName() if wintunInterfaceName == ifname { continue } - for gatewayAddress := iface.FirstGatewayAddress; gatewayAddress != nil; gatewayAddress = gatewayAddress.Next { + for gatewayAddress := ifaceM.FirstGatewayAddress; gatewayAddress != nil; gatewayAddress = gatewayAddress.Next { nextHop := nnip.IpToAddr(gatewayAddress.Address.IP()) - if _, err = iface.LUID.Route(destination, nextHop); err == nil { + if _, err = ifaceM.LUID.Route(destination, nextHop); err == nil { return ifname, nil } } } - return "", errors.New("ethernet interface not found") + return "", errInterfaceNotFound +} + +func unicastAddressChange(_ winipcfg.MibNotificationType, unicastAddress *winipcfg.MibUnicastIPAddressRow) { + unicastAddressChangeLock.Lock() + defer unicastAddressChangeLock.Unlock() + + interfaceName, err := GetAutoDetectInterface() + if err != nil { + if err == errInterfaceNotFound && tunStatus == C.TunEnabled { + log.Warnln("[TUN] lost the default interface, pause tun adapter") + + tunStatus = C.TunPaused + tunChangeCallback.Pause() + } + return + } + + ifaceM, err := net.InterfaceByIndex(int(unicastAddress.InterfaceIndex)) + if err != nil { + log.Warnln("[TUN] default interface monitor err: %v", err) + return + } + + newName := ifaceM.Name + + if newName != interfaceName { + return + } + + dialer.DefaultInterface.Store(interfaceName) + + iface.FlushCache() + + if tunStatus == C.TunPaused { + log.Warnln("[TUN] found interface %s(%s), resume tun adapter", interfaceName, unicastAddress.Address.Addr()) + + tunStatus = C.TunEnabled + tunChangeCallback.Resume() + return + } + + log.Warnln("[TUN] default interface changed to %s(%s) by monitor", interfaceName, unicastAddress.Address.Addr()) +} + +func StartDefaultInterfaceChangeMonitor() { + if unicastAddressChangeCallback != nil { + return + } + + var err error + unicastAddressChangeCallback, err = winipcfg.RegisterUnicastAddressChangeCallback(unicastAddressChange) + + if err != nil { + log.Errorln("[TUN] register uni-cast address change callback failed: %v", err) + return + } + + tunStatus = C.TunEnabled + + log.Infoln("[TUN] register uni-cast address change callback") +} + +func StopDefaultInterfaceChangeMonitor() { + if unicastAddressChangeCallback == nil || tunStatus == C.TunPaused { + return + } + + _ = unicastAddressChangeCallback.Unregister() + unicastAddressChangeCallback = nil + tunChangeCallback = nil + tunStatus = C.TunDisabled } diff --git a/listener/tun/tun_adapter.go b/listener/tun/tun_adapter.go index d9178f3e..a03a2e81 100644 --- a/listener/tun/tun_adapter.go +++ b/listener/tun/tun_adapter.go @@ -24,7 +24,7 @@ import ( ) // New TunAdapter -func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { +func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter, tunChangeCallback C.TUNChangeCallback) (ipstack.Stack, error) { var ( tunAddress = netip.Prefix{} devName = tunConf.Device @@ -98,6 +98,7 @@ func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.Con } if tunConf.AutoDetectInterface { + commons.SetTunChangeCallback(tunChangeCallback) go commons.StartDefaultInterfaceChangeMonitor() }