Code: refresh code

This commit is contained in:
yaling888
2021-07-01 22:49:29 +08:00
parent 3ca5d17c40
commit d7732f6ebc
104 changed files with 11329 additions and 136 deletions

View File

@ -3,16 +3,20 @@ package proxy
import (
"fmt"
"net"
"runtime"
"strconv"
"sync"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/listener/http"
"github.com/Dreamacro/clash/listener/mixed"
"github.com/Dreamacro/clash/listener/redir"
"github.com/Dreamacro/clash/listener/socks"
"github.com/Dreamacro/clash/listener/tproxy"
"github.com/Dreamacro/clash/listener/tun"
"github.com/Dreamacro/clash/listener/tun/ipstack"
"github.com/Dreamacro/clash/log"
)
@ -29,6 +33,7 @@ var (
tproxyUDPListener *tproxy.UDPListener
mixedListener *mixed.Listener
mixedUDPLister *socks.UDPListener
tunAdapter ipstack.TunAdapter
// lock for recreate function
socksMux sync.Mutex
@ -36,6 +41,7 @@ var (
redirMux sync.Mutex
tproxyMux sync.Mutex
mixedMux sync.Mutex
tunMux sync.Mutex
)
type Ports struct {
@ -58,6 +64,18 @@ func SetAllowLan(al bool) {
allowLan = al
}
func Tun() config.Tun {
if tunAdapter == nil {
return config.Tun{}
}
return config.Tun{
Enable: true,
Stack: tunAdapter.Stack(),
DNSListen: tunAdapter.DNSListen(),
AutoRoute: tunAdapter.AutoRoute(),
}
}
func SetBindAddress(host string) {
bindAddress = host
}
@ -275,6 +293,25 @@ func ReCreateMixed(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.P
return nil
}
func ReCreateTun(conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) error {
tunMux.Lock()
defer tunMux.Unlock()
if tunAdapter != nil {
tunAdapter.Close()
tunAdapter = nil
}
if !conf.Enable {
return nil
}
var err error
tunAdapter, err = tun.New(conf, tcpIn, udpIn)
return err
}
// GetPorts return the ports of proxy servers
func GetPorts() *Ports {
ports := &Ports{}
@ -330,3 +367,12 @@ func genAddr(host string, port int, allowLan bool) string {
return fmt.Sprintf("127.0.0.1:%d", port)
}
// CleanUp clean up something
func CleanUp() {
if runtime.GOOS == "windows" {
if tunAdapter != nil {
tunAdapter.Close()
}
}
}

View File

@ -0,0 +1,192 @@
package tproxy
import (
"errors"
"fmt"
"os/exec"
U "os/user"
"runtime"
"strings"
"github.com/Dreamacro/clash/log"
)
var (
interfaceName = ""
tproxyPort = 0
dnsPort = 0
)
const (
PROXY_FWMARK = "0x2d0"
PROXY_ROUTE_TABLE = "0x2d0"
USERNAME = "clash"
)
func SetTProxyLinuxIPTables(ifname string, tport int, dport int) error {
var err error
if _, err = execCmd("iptables -V"); err != nil {
return fmt.Errorf("current operations system [%s] are not support iptables or command iptables does not exist", runtime.GOOS)
}
//if _, err = execCmd("modprobe xt_TPROXY"); err != nil {
// return errors.New("xt_TPROXY module does not exist, please install it")
//}
user, err := U.Lookup(USERNAME)
if err != nil {
return fmt.Errorf("the user \" %s\" does not exist, please create it", USERNAME)
}
if ifname == "" {
return errors.New("interface name can not be empty")
}
ownerUid := user.Uid
interfaceName = ifname
tproxyPort = tport
dnsPort = dport
// add route
execCmd(fmt.Sprintf("ip -f inet rule add fwmark %s lookup %s", PROXY_FWMARK, PROXY_ROUTE_TABLE))
execCmd(fmt.Sprintf("ip -f inet route add local default dev %s table %s", interfaceName, PROXY_ROUTE_TABLE))
// set FORWARD
execCmd("sysctl -w net.ipv4.ip_forward=1")
execCmd(fmt.Sprintf("iptables -t filter -A FORWARD -o %s -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT", interfaceName))
execCmd(fmt.Sprintf("iptables -t filter -A FORWARD -o %s -j ACCEPT", interfaceName))
execCmd(fmt.Sprintf("iptables -t filter -A FORWARD -i %s ! -o %s -j ACCEPT", interfaceName, interfaceName))
execCmd(fmt.Sprintf("iptables -t filter -A FORWARD -i %s -o %s -j ACCEPT", interfaceName, interfaceName))
// set clash divert
execCmd("iptables -t mangle -N clash_divert")
execCmd("iptables -t mangle -F clash_divert")
execCmd(fmt.Sprintf("iptables -t mangle -A clash_divert -j MARK --set-mark %s", PROXY_FWMARK))
execCmd("iptables -t mangle -A clash_divert -j ACCEPT")
// set pre routing
execCmd("iptables -t mangle -N clash_prerouting")
execCmd("iptables -t mangle -F clash_prerouting")
execCmd("iptables -t mangle -A clash_prerouting -s 172.17.0.0/16 -j RETURN")
execCmd("iptables -t mangle -A clash_prerouting -p udp --dport 53 -j ACCEPT")
execCmd("iptables -t mangle -A clash_prerouting -p tcp --dport 53 -j ACCEPT")
execCmd("iptables -t mangle -A clash_prerouting -m addrtype --dst-type LOCAL -j RETURN")
addLocalnetworkToChain("clash_prerouting")
execCmd("iptables -t mangle -A clash_prerouting -p tcp -m socket -j clash_divert")
execCmd("iptables -t mangle -A clash_prerouting -p udp -m socket -j clash_divert")
execCmd(fmt.Sprintf("iptables -t mangle -A clash_prerouting -p tcp -j TPROXY --on-port %d --tproxy-mark %s/%s", tproxyPort, PROXY_FWMARK, PROXY_FWMARK))
execCmd(fmt.Sprintf("iptables -t mangle -A clash_prerouting -p udp -j TPROXY --on-port %d --tproxy-mark %s/%s", tproxyPort, PROXY_FWMARK, PROXY_FWMARK))
execCmd("iptables -t mangle -A PREROUTING -j clash_prerouting")
execCmd(fmt.Sprintf("iptables -t nat -I PREROUTING ! -s 172.17.0.0/16 ! -d 127.0.0.0/8 -p tcp --dport 53 -j REDIRECT --to %d", dnsPort))
execCmd(fmt.Sprintf("iptables -t nat -I PREROUTING ! -s 172.17.0.0/16 ! -d 127.0.0.0/8 -p udp --dport 53 -j REDIRECT --to %d", dnsPort))
// set post routing
execCmd(fmt.Sprintf("iptables -t nat -A POSTROUTING -o %s -m addrtype ! --src-type LOCAL -j MASQUERADE", interfaceName))
// set output
execCmd("iptables -t mangle -N clash_output")
execCmd("iptables -t mangle -F clash_output")
execCmd(fmt.Sprintf("iptables -t mangle -A clash_output -m owner --uid-owner %s -j RETURN", ownerUid))
execCmd("iptables -t mangle -A clash_output -p udp -m multiport --dports 53,123,137 -j ACCEPT")
execCmd("iptables -t mangle -A clash_output -p tcp --dport 53 -j ACCEPT")
execCmd("iptables -t mangle -A clash_output -m addrtype --dst-type LOCAL -j RETURN")
execCmd("iptables -t mangle -A clash_output -m addrtype --dst-type BROADCAST -j RETURN")
addLocalnetworkToChain("clash_output")
execCmd(fmt.Sprintf("iptables -t mangle -A clash_output -p tcp -j MARK --set-mark %s", PROXY_FWMARK))
execCmd(fmt.Sprintf("iptables -t mangle -A clash_output -p udp -j MARK --set-mark %s", PROXY_FWMARK))
execCmd(fmt.Sprintf("iptables -t mangle -I OUTPUT -o %s -j clash_output", interfaceName))
// set dns output
execCmd("iptables -t nat -N clash_dns_output")
execCmd("iptables -t nat -F clash_dns_output")
execCmd(fmt.Sprintf("iptables -t nat -A clash_dns_output -m owner --uid-owner %s -j RETURN", ownerUid))
execCmd("iptables -t nat -A clash_dns_output -s 172.17.0.0/16 -j RETURN")
execCmd(fmt.Sprintf("iptables -t nat -A clash_dns_output -p udp -j REDIRECT --to-ports %d", dnsPort))
execCmd(fmt.Sprintf("iptables -t nat -A clash_dns_output -p tcp -j REDIRECT --to-ports %d", dnsPort))
execCmd("iptables -t nat -I OUTPUT -p tcp --dport 53 -j clash_dns_output")
execCmd("iptables -t nat -I OUTPUT -p udp --dport 53 -j clash_dns_output")
log.Infoln("[TProxy] Setting iptables completed")
return nil
}
func CleanUpTProxyLinuxIPTables() {
if interfaceName == "" || tproxyPort == 0 || dnsPort == 0 {
return
}
log.Warnln("Clean up tproxy linux iptables")
if _, err := execCmd("iptables -t mangle -L clash_divert"); err != nil {
return
}
// clean route
execCmd(fmt.Sprintf("ip -f inet rule del fwmark %s lookup %s", PROXY_FWMARK, PROXY_ROUTE_TABLE))
execCmd(fmt.Sprintf("ip -f inet route del local default dev %s table %s", interfaceName, PROXY_ROUTE_TABLE))
// clean FORWARD
//execCmd("sysctl -w net.ipv4.ip_forward=0")
execCmd(fmt.Sprintf("iptables -t filter -D FORWARD -i %s ! -o %s -j ACCEPT", interfaceName, interfaceName))
execCmd(fmt.Sprintf("iptables -t filter -D FORWARD -i %s -o %s -j ACCEPT", interfaceName, interfaceName))
execCmd(fmt.Sprintf("iptables -t filter -D FORWARD -o %s -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT", interfaceName))
execCmd(fmt.Sprintf("iptables -t filter -D FORWARD -o %s -j ACCEPT", interfaceName))
// clean PREROUTING
execCmd(fmt.Sprintf("iptables -t nat -D PREROUTING ! -s 172.17.0.0/16 ! -d 127.0.0.0/8 -p tcp --dport 53 -j REDIRECT --to %d", dnsPort))
execCmd(fmt.Sprintf("iptables -t nat -D PREROUTING ! -s 172.17.0.0/16 ! -d 127.0.0.0/8 -p udp --dport 53 -j REDIRECT --to %d", dnsPort))
execCmd("iptables -t mangle -D PREROUTING -j clash_prerouting")
// clean POSTROUTING
execCmd(fmt.Sprintf("iptables -t nat -D POSTROUTING -o %s -m addrtype ! --src-type LOCAL -j MASQUERADE", interfaceName))
// clean OUTPUT
execCmd(fmt.Sprintf("iptables -t mangle -D OUTPUT -o %s -j clash_output", interfaceName))
execCmd("iptables -t nat -D OUTPUT -p tcp --dport 53 -j clash_dns_output")
execCmd("iptables -t nat -D OUTPUT -p udp --dport 53 -j clash_dns_output")
// clean chain
execCmd("iptables -t mangle -F clash_prerouting")
execCmd("iptables -t mangle -X clash_prerouting")
execCmd("iptables -t mangle -F clash_divert")
execCmd("iptables -t mangle -X clash_divert")
execCmd("iptables -t mangle -F clash_output")
execCmd("iptables -t mangle -X clash_output")
execCmd("iptables -t nat -F clash_dns_output")
execCmd("iptables -t nat -X clash_dns_output")
}
func addLocalnetworkToChain(chain string) {
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 0.0.0.0/8 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 10.0.0.0/8 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 100.64.0.0/10 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 127.0.0.0/8 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 169.254.0.0/16 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 172.16.0.0/12 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 192.0.0.0/24 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 192.0.2.0/24 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 192.88.99.0/24 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 192.168.0.0/16 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 198.51.100.0/24 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 203.0.113.0/24 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 224.0.0.0/4 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 240.0.0.0/4 -j RETURN", chain))
execCmd(fmt.Sprintf("iptables -t mangle -A %s -d 255.255.255.255/32 -j RETURN", chain))
}
func execCmd(cmdstr string) (string, error) {
log.Debugln("[TProxy] %s", cmdstr)
args := strings.Split(cmdstr, " ")
cmd := exec.Command(args[0], args[1:]...)
out, err := cmd.CombinedOutput()
if err != nil {
log.Errorln("[TProxy] error: %s, %s", err.Error(), string(out))
return "", err
}
return string(out), nil
}

66
listener/tun/dev/dev.go Normal file
View File

@ -0,0 +1,66 @@
package dev
import (
"os/exec"
"runtime"
"github.com/Dreamacro/clash/log"
)
// TunDevice is cross-platform tun interface
type TunDevice interface {
Name() string
URL() string
MTU() (int, error)
IsClose() bool
Close() error
Read(buff []byte) (int, error)
Write(buff []byte) (int, error)
}
func SetLinuxAutoRoute() {
log.Infoln("Tun adapter auto setting MacOS route")
addLinuxSystemRoute("1")
addLinuxSystemRoute("2/7")
addLinuxSystemRoute("4/6")
addLinuxSystemRoute("8/5")
addLinuxSystemRoute("16/4")
addLinuxSystemRoute("32/3")
addLinuxSystemRoute("64/2")
addLinuxSystemRoute("128.0/1")
addLinuxSystemRoute("198.18.0/16")
}
func RemoveLinuxAutoRoute() {
log.Infoln("Tun adapter removing MacOS route")
delLinuxSystemRoute("1")
delLinuxSystemRoute("2/7")
delLinuxSystemRoute("4/6")
delLinuxSystemRoute("8/5")
delLinuxSystemRoute("16/4")
delLinuxSystemRoute("32/3")
delLinuxSystemRoute("64/2")
delLinuxSystemRoute("128.0/1")
delLinuxSystemRoute("198.18.0/16")
}
func addLinuxSystemRoute(net string) {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
return
}
cmd := exec.Command("route", "add", "-net", net, "198.18.0.1")
if err := cmd.Run(); err != nil {
log.Errorln("[MacOS auto route] Failed to add system route: %s, cmd: %s", err.Error(), cmd.String())
}
}
func delLinuxSystemRoute(net string) {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
return
}
cmd := exec.Command("route", "delete", "-net", net, "198.18.0.1")
_ = cmd.Run()
//if err := cmd.Run(); err != nil {
// log.Errorln("[MacOS auto route]Failed to delete system route: %s, cmd: %s", err.Error(), cmd.String())
//}
}

View File

@ -0,0 +1,506 @@
// +build darwin
package dev
import (
"bytes"
"errors"
"fmt"
"net"
"os"
"os/exec"
"sync"
"syscall"
"unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"github.com/Dreamacro/clash/common/pool"
)
const utunControlName = "com.apple.net.utun_control"
const _IOC_OUT = 0x40000000
const _IOC_IN = 0x80000000
const _IOC_INOUT = _IOC_IN | _IOC_OUT
// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h
// https://github.com/apple/darwin-xnu/blob/master/bsd/sys/ioccom.h
// #define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) /* get id from name */ = 0xc0644e03
const _CTLIOCGINFO = _IOC_INOUT | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
// #define SIOCAIFADDR_IN6 _IOW('i', 26, struct in6_aliasreq) = 0x8080691a
//const _SIOCAIFADDR_IN6 = _IOC_IN | ((128 & 0x1fff) << 16) | uint32(byte('i'))<<8 | 26
// #define SIOCPROTOATTACH_IN6 _IOWR('i', 110, struct in6_aliasreq_64)
const _SIOCPROTOATTACH_IN6 = _IOC_INOUT | ((128 & 0x1fff) << 16) | uint32(byte('i'))<<8 | 110
// #define SIOCLL_START _IOWR('i', 130, struct in6_aliasreq)
const _SIOCLL_START = _IOC_INOUT | ((128 & 0x1fff) << 16) | uint32(byte('i'))<<8 | 130
// https://github.com/apple/darwin-xnu/blob/a449c6a3b8014d9406c2ddbdc81795da24aa7443/bsd/netinet6/nd6.h#L469
const ND6_INFINITE_LIFETIME = 0xffffffff
// Following the wireguard-go solution:
// These unix.SYS_* constants were removed from golang.org/x/sys/unix
// so copy them here for now.
// See https://github.com/golang/go/issues/41868
const (
sys_IOCTL = 54
sys_CONNECT = 98
sys_GETSOCKOPT = 118
)
type tunDarwin struct {
//url string
name string
tunAddress string
autoRoute bool
tunFile *os.File
errors chan error
closed bool
stopOnce sync.Once
}
// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h
type sockaddrCtl struct {
scLen uint8
scFamily uint8
ssSysaddr uint16
scID uint32
scUnit uint32
scReserved [5]uint32
}
// https://github.com/apple/darwin-xnu/blob/a449c6a3b8014d9406c2ddbdc81795da24aa7443/bsd/net/if.h#L402-L563
//type ifreqAddr struct {
// Name [unix.IFNAMSIZ]byte
// Addr unix.RawSockaddrInet4
// Pad [8]byte
//}
var sockaddrCtlSize uintptr = 32
// OpenTunDevice return a TunDevice according a URL
func OpenTunDevice(tunAddress string, autoRoute bool) (TunDevice, error) {
name := "utun"
// TODO: configure the MTU
mtu := 9000
ifIndex := -1
if name != "utun" {
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
if err != nil || ifIndex < 0 {
return nil, fmt.Errorf("interface name must be utun[0-9]*")
}
}
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
if err != nil {
return nil, err
}
var ctlInfo = &struct {
ctlID uint32
ctlName [96]byte
}{}
copy(ctlInfo.ctlName[:], []byte(utunControlName))
_, _, errno := unix.Syscall(
sys_IOCTL,
uintptr(fd),
uintptr(_CTLIOCGINFO),
uintptr(unsafe.Pointer(ctlInfo)),
)
if errno != 0 {
return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno)
}
sc := sockaddrCtl{
scLen: uint8(sockaddrCtlSize),
scFamily: unix.AF_SYSTEM,
ssSysaddr: 2,
scID: ctlInfo.ctlID,
scUnit: uint32(ifIndex) + 1,
}
scPointer := unsafe.Pointer(&sc)
_, _, errno = unix.RawSyscall(
sys_CONNECT,
uintptr(fd),
uintptr(scPointer),
uintptr(sockaddrCtlSize),
)
if errno != 0 {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
}
err = syscall.SetNonblock(fd, true)
if err != nil {
return nil, err
}
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu, tunAddress, autoRoute)
if err != nil {
return nil, err
}
if autoRoute {
SetLinuxAutoRoute()
}
return tun, nil
}
func CreateTUNFromFile(file *os.File, mtu int, tunAddress string, autoRoute bool) (TunDevice, error) {
tun := &tunDarwin{
tunFile: file,
tunAddress: tunAddress,
autoRoute: autoRoute,
errors: make(chan error, 5),
}
name, err := tun.getName()
if err != nil {
tun.tunFile.Close()
return nil, err
}
tun.name = name
if err != nil {
tun.tunFile.Close()
return nil, err
}
if mtu > 0 {
err = tun.setMTU(mtu)
if err != nil {
tun.Close()
return nil, err
}
}
// This address doesn't mean anything here. NIC just net an IP address to set route upon.
// TODO: maybe let user config it. And I'm doubt whether we really need it.
p2pAddress := net.ParseIP("198.18.0.1")
err = tun.setTunAddress(p2pAddress)
if err != nil {
tun.Close()
return nil, err
}
err = tun.attachLinkLocal()
if err != nil {
tun.Close()
return nil, err
}
return tun, nil
}
func (t *tunDarwin) Name() string {
return t.name
}
func (t *tunDarwin) URL() string {
return fmt.Sprintf("dev://%s", t.Name())
}
func (t *tunDarwin) MTU() (int, error) {
return t.getInterfaceMtu()
}
func (t *tunDarwin) Read(buff []byte) (int, error) {
select {
case err := <-t.errors:
return 0, err
default:
n, err := t.tunFile.Read(buff)
if n < 4 {
return 0, err
}
copy(buff[:], buff[4:])
return n - 4, err
}
}
func (t *tunDarwin) Write(buff []byte) (int, error) {
// reserve space for header
buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf[:cap(buf)])
buf[0] = 0x00
buf[1] = 0x00
buf[2] = 0x00
copy(buf[4:], buff)
if buf[4]>>4 == ipv6.Version {
buf[3] = unix.AF_INET6
} else {
buf[3] = unix.AF_INET
}
// write
return t.tunFile.Write(buf[:4+len(buff)])
}
func (t *tunDarwin) IsClose() bool {
return t.closed
}
func (t *tunDarwin) Close() error {
t.stopOnce.Do(func() {
if t.autoRoute {
RemoveLinuxAutoRoute()
}
t.closed = true
t.tunFile.Close()
})
return nil
}
func (t *tunDarwin) getInterfaceMtu() (int, error) {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
// do ioctl call
var ifr [64]byte
copy(ifr[:], t.name)
_, _, errno := unix.Syscall(
sys_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", t.name)
}
return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil
}
func (t *tunDarwin) getName() (string, error) {
var ifName struct {
name [16]byte
}
ifNameSize := uintptr(16)
var errno syscall.Errno
t.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall6(
sys_GETSOCKOPT,
fd,
2, /* #define SYSPROTO_CONTROL 2 */
2, /* #define UTUN_OPT_IFNAME 2 */
uintptr(unsafe.Pointer(&ifName)),
uintptr(unsafe.Pointer(&ifNameSize)), 0)
})
if errno != 0 {
return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
}
t.name = string(ifName.name[:ifNameSize-1])
return t.name, nil
}
func (t *tunDarwin) setMTU(n int) error {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
// do ioctl call
var ifr [32]byte
copy(ifr[:], t.name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
sys_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", t.name)
}
return nil
}
func (t *tunDarwin) operateOnFd(fn func(fd uintptr)) {
sysconn, err := t.tunFile.SyscallConn()
// TODO: consume the errors
if err != nil {
t.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error())
return
}
err = sysconn.Control(fn)
if err != nil {
t.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error())
}
}
func (t *tunDarwin) setTunAddress(addr net.IP) error {
var ifr [unix.IFNAMSIZ]byte
copy(ifr[:], t.name)
// set IPv4 address
fd4, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer syscall.Close(fd4)
// https://github.com/apple/darwin-xnu/blob/a449c6a3b8014d9406c2ddbdc81795da24aa7443/bsd/sys/sockio.h#L107
// https://github.com/apple/darwin-xnu/blob/a449c6a3b8014d9406c2ddbdc81795da24aa7443/bsd/net/if.h#L570-L575
// https://man.openbsd.org/netintro.4#SIOCAIFADDR
type aliasreq struct {
ifra_name [unix.IFNAMSIZ]byte
ifra_addr unix.RawSockaddrInet4
ifra_dstaddr unix.RawSockaddrInet4
ifra_mask unix.RawSockaddrInet4
}
var ip4 [4]byte
copy(ip4[:], addr.To4())
ip4mask := [4]byte{255, 255, 0, 0}
ifra4 := aliasreq{
ifra_name: ifr,
ifra_addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: ip4,
},
ifra_dstaddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: ip4,
},
ifra_mask: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: ip4mask,
},
}
if _, _, errno := unix.Syscall(
sys_IOCTL,
uintptr(fd4),
uintptr(unix.SIOCAIFADDR),
uintptr(unsafe.Pointer(&ifra4)),
); errno != 0 {
return fmt.Errorf("failed to set ip address on %s: %v", t.name, errno)
}
return nil
}
func (t *tunDarwin) attachLinkLocal() error {
var ifr [unix.IFNAMSIZ]byte
copy(ifr[:], t.name)
// attach link-local address
fd6, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer syscall.Close(fd6)
// SIOCAIFADDR_IN6
// https://github.com/apple/darwin-xnu/blob/a449c6a3b8014d9406c2ddbdc81795da24aa7443/bsd/netinet6/in6_var.h#L114-L119
// https://opensource.apple.com/source/network_cmds/network_cmds-543.260.3/
type in6_addrlifetime struct {
//ia6t_expire uint64
//ia6t_preferred uint64
//ia6t_vltime uint32
//ia6t_pltime uint32
}
// https://github.com/apple/darwin-xnu/blob/a449c6a3b8014d9406c2ddbdc81795da24aa7443/bsd/netinet6/in6_var.h#L336-L343
// https://github.com/apple/darwin-xnu/blob/a449c6a3b8014d9406c2ddbdc81795da24aa7443/bsd/netinet6/in6.h#L174-L181
type in6_aliasreq struct {
ifra_name [unix.IFNAMSIZ]byte
ifra_addr unix.RawSockaddrInet6
ifra_dstaddr unix.RawSockaddrInet6
ifra_prefixmask unix.RawSockaddrInet6
ifra_flags int32
ifra_lifetime in6_addrlifetime
}
// Attach link-local address
ifra6 := in6_aliasreq{
ifra_name: ifr,
}
if _, _, errno := unix.Syscall(
sys_IOCTL,
uintptr(fd6),
uintptr(_SIOCPROTOATTACH_IN6),
uintptr(unsafe.Pointer(&ifra6)),
); errno != 0 {
return fmt.Errorf("failed to attach link-local address on %s: SIOCPROTOATTACH_IN6 %v", t.name, errno)
}
if _, _, errno := unix.Syscall(
sys_IOCTL,
uintptr(fd6),
uintptr(_SIOCLL_START),
uintptr(unsafe.Pointer(&ifra6)),
); errno != 0 {
return fmt.Errorf("failed to set ipv6 address on %s: SIOCLL_START %v", t.name, errno)
}
return nil
}
// GetAutoDetectInterface get ethernet interface
func GetAutoDetectInterface() (string, error) {
cmd := exec.Command("bash", "-c", "netstat -rnf inet | grep 'default' | awk -F ' ' 'NR==1{print $6}' | xargs echo -n")
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
return "", err
}
if out.Len() == 0 {
return "", errors.New("interface not found by default route")
}
return out.String(), nil
}

View File

@ -0,0 +1,254 @@
// +build linux android
package dev
import (
"bytes"
"errors"
"fmt"
"net/url"
"os"
"os/exec"
"strconv"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const (
cloneDevicePath = "/dev/net/tun"
ifReqSize = unix.IFNAMSIZ + 64
)
type tunLinux struct {
url string
name string
tunAddress string
autoRoute bool
tunFile *os.File
mtu int
closed bool
stopOnce sync.Once
}
// OpenTunDevice return a TunDevice according a URL
func OpenTunDevice(tunAddress string, autoRoute bool) (TunDevice, error) {
deviceURL, _ := url.Parse("dev://clash0")
mtu, _ := strconv.ParseInt(deviceURL.Query().Get("mtu"), 0, 32)
t := &tunLinux{
url: deviceURL.String(),
mtu: int(mtu),
tunAddress: tunAddress,
autoRoute: autoRoute,
}
switch deviceURL.Scheme {
case "dev":
var err error
var dev TunDevice
dev, err = t.openDeviceByName(deviceURL.Host)
if err != nil {
return nil, err
}
if autoRoute {
SetLinuxAutoRoute()
}
return dev, nil
case "fd":
fd, err := strconv.ParseInt(deviceURL.Host, 10, 32)
if err != nil {
return nil, err
}
var dev TunDevice
dev, err = t.openDeviceByFd(int(fd))
if err != nil {
return nil, err
}
if autoRoute {
SetLinuxAutoRoute()
}
return dev, nil
}
return nil, fmt.Errorf("unsupported device type `%s`", deviceURL.Scheme)
}
func (t *tunLinux) Name() string {
return t.name
}
func (t *tunLinux) URL() string {
return t.url
}
func (t *tunLinux) Write(buff []byte) (int, error) {
return t.tunFile.Write(buff)
}
func (t *tunLinux) Read(buff []byte) (int, error) {
return t.tunFile.Read(buff)
}
func (t *tunLinux) IsClose() bool {
return t.closed
}
func (t *tunLinux) Close() error {
t.stopOnce.Do(func() {
if t.autoRoute {
RemoveLinuxAutoRoute()
}
t.closed = true
t.tunFile.Close()
})
return nil
}
func (t *tunLinux) MTU() (int, error) {
// Sometime, we can't read MTU by SIOCGIFMTU. Then we should return the preset MTU
if t.mtu > 0 {
return t.mtu, nil
}
mtu, err := t.getInterfaceMtu()
return int(mtu), err
}
func (t *tunLinux) openDeviceByName(name string) (TunDevice, error) {
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
if err != nil {
return nil, err
}
var ifr [ifReqSize]byte
var flags uint16 = unix.IFF_TUN | unix.IFF_NO_PI
nameBytes := []byte(name)
if len(nameBytes) >= unix.IFNAMSIZ {
return nil, errors.New("interface name too long")
}
copy(ifr[:], nameBytes)
*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(nfd),
uintptr(unix.TUNSETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return nil, errno
}
err = unix.SetNonblock(nfd, true)
if err != nil {
return nil, err
}
// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
t.tunFile = os.NewFile(uintptr(nfd), cloneDevicePath)
t.name, err = t.getName()
if err != nil {
t.tunFile.Close()
return nil, err
}
return t, nil
}
func (t *tunLinux) openDeviceByFd(fd int) (TunDevice, error) {
var ifr struct {
name [16]byte
flags uint16
_ [22]byte
}
fd, err := syscall.Dup(fd)
if err != nil {
return nil, err
}
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNGETIFF, uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
return nil, errno
}
if ifr.flags&syscall.IFF_TUN == 0 || ifr.flags&syscall.IFF_NO_PI == 0 {
return nil, errors.New("only tun device and no pi mode supported")
}
nullStr := ifr.name[:]
i := bytes.IndexByte(nullStr, 0)
if i != -1 {
nullStr = nullStr[:i]
}
t.name = string(nullStr)
t.tunFile = os.NewFile(uintptr(fd), "/dev/tun")
return t, nil
}
func (t *tunLinux) getInterfaceMtu() (uint32, error) {
fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
if err != nil {
return 0, err
}
defer syscall.Close(fd)
var ifreq struct {
name [16]byte
mtu int32
_ [20]byte
}
copy(ifreq.name[:], t.name)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
if errno != 0 {
return 0, errno
}
return uint32(ifreq.mtu), nil
}
func (t *tunLinux) getName() (string, error) {
sysconn, err := t.tunFile.SyscallConn()
if err != nil {
return "", err
}
var ifr [ifReqSize]byte
var errno syscall.Errno
err = sysconn.Control(func(fd uintptr) {
_, _, errno = unix.Syscall(
unix.SYS_IOCTL,
fd,
uintptr(unix.TUNGETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
})
if err != nil {
return "", errors.New("failed to get name of TUN device: " + err.Error())
}
if errno != 0 {
return "", errors.New("failed to get name of TUN device: " + errno.Error())
}
nullStr := ifr[:]
i := bytes.IndexByte(nullStr, 0)
if i != -1 {
nullStr = nullStr[:i]
}
t.name = string(nullStr)
return t.name, nil
}
// GetAutoDetectInterface get ethernet interface
func GetAutoDetectInterface() (string, error) {
cmd := exec.Command("bash", "-c", "ip route show | grep 'default via' | awk -F ' ' 'NR==1{print $5}' | xargs echo -n")
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
return "", err
}
return out.String(), nil
}

View File

@ -0,0 +1,17 @@
// +build !linux,!android,!darwin,!windows
package dev
import (
"errors"
"runtime"
)
func OpenTunDevice(tunAddress string, autoRute bool) (TunDevice, error) {
return nil, errors.New("Unsupported platform " + runtime.GOOS + "/" + runtime.GOARCH)
}
// GetAutoDetectInterface get ethernet interface
func GetAutoDetectInterface() (string, error) {
return "", nil
}

View File

@ -0,0 +1,552 @@
// +build windows
package dev
import (
"bytes"
"errors"
"fmt"
"net"
"os"
"sort"
"sync"
"sync/atomic"
"time"
_ "unsafe"
"github.com/Dreamacro/clash/listener/tun/dev/winipcfg"
"github.com/Dreamacro/clash/listener/tun/dev/wintun"
"github.com/Dreamacro/clash/log"
"golang.org/x/sys/windows"
)
const (
rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
spinloopRateThreshold = 800000000 / 8 // 800mbps
spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
messageTransportHeaderSize = 0 // size of data preceding content in transport message
)
type rateJuggler struct {
current uint64
nextByteCount uint64
nextStartTime int64
changing int32
}
type tunWindows struct {
wt *wintun.Adapter
handle windows.Handle
closed bool
closing sync.RWMutex
forcedMTU int
rate rateJuggler
session wintun.Session
readWait windows.Handle
stopOnce sync.Once
url string
name string
tunAddress string
autoRoute bool
}
var WintunPool, _ = wintun.MakePool("Clash")
var WintunStaticRequestedGUID *windows.GUID
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
// OpenTunDevice return a TunDevice according a URL
func OpenTunDevice(tunAddress string, autoRoute bool) (TunDevice, error) {
requestedGUID, err := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}")
if err == nil {
WintunStaticRequestedGUID = &requestedGUID
log.Debugln("Generate GUID: %s", WintunStaticRequestedGUID.String())
} else {
log.Warnln("Error parese GUID from string: %v", err)
}
interfaceName := "Clash"
mtu := 9000
tun, err := CreateTUN(interfaceName, mtu, tunAddress, autoRoute)
if err != nil {
return nil, err
}
return tun, nil
}
//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
//
func CreateTUN(ifname string, mtu int, tunAddress string, autoRoute bool) (TunDevice, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu, tunAddress, autoRoute)
}
//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int, tunAddress string, autoRoute bool) (TunDevice, error) {
var err error
var wt *wintun.Adapter
// Does an interface with this name already exist?
wt, err = WintunPool.OpenAdapter(ifname)
if err == nil {
// If so, we delete it, in case it has weird residual configuration.
_, err = wt.Delete(false)
if err != nil {
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
}
}
wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err)
}
if rebootRequired {
log.Infoln("Windows indicated a reboot is required.")
}
forcedMTU := 1420
if mtu > 0 {
forcedMTU = mtu
}
tun := &tunWindows{
wt: wt,
handle: windows.InvalidHandle,
forcedMTU: forcedMTU,
tunAddress: tunAddress,
autoRoute: autoRoute,
}
// config tun ip
err = tun.configureInterface()
if err != nil {
tun.wt.Delete(false)
return nil, fmt.Errorf("Error configure interface: %w", err)
}
realInterfaceName, err2 := wt.Name()
if err2 == nil {
ifname = realInterfaceName
tun.name = realInterfaceName
}
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
tun.wt.Delete(false)
return nil, fmt.Errorf("Error starting session: %w", err)
}
tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
func (tun *tunWindows) getName() (string, error) {
tun.closing.RLock()
defer tun.closing.RUnlock()
if tun.closed {
return "", os.ErrClosed
}
return tun.wt.Name()
}
func (tun *tunWindows) IsClose() bool {
return tun.closed
}
func (tun *tunWindows) Close() error {
tun.stopOnce.Do(func() {
//tun.closing.Lock()
//defer tun.closing.Unlock()
tun.closed = true
tun.session.End()
if tun.wt != nil {
forceCloseSessions := false
rebootRequired, err := tun.wt.Delete(forceCloseSessions)
if rebootRequired {
log.Infoln("Delete Wintun failure, Windows indicated a reboot is required.")
} else {
log.Infoln("Delete Wintun success.")
}
if err != nil {
log.Errorln("Close Wintun Sessions failure: %v", err)
}
}
})
return nil
}
func (tun *tunWindows) MTU() (int, error) {
return tun.forcedMTU, nil
}
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *tunWindows) ForceMTU(mtu int) {
tun.forcedMTU = mtu
}
func (tun *tunWindows) Read(buff []byte) (int, error) {
return tun.ReadO(buff, messageTransportHeaderSize)
}
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *tunWindows) ReadO(buff []byte, offset int) (int, error) {
tun.closing.RLock()
defer tun.closing.RUnlock()
retry:
if tun.closed {
return 0, os.ErrClosed
}
start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
for {
if tun.closed {
return 0, os.ErrClosed
}
packet, err := tun.session.ReceivePacket()
switch err {
case nil:
packetSize := len(packet)
copy(buff[offset:], packet)
tun.session.ReleaseReceivePacket(packet)
tun.rate.update(uint64(packetSize))
return packetSize, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
goto retry
}
procyield(1)
continue
case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed
case windows.ERROR_INVALID_DATA:
return 0, errors.New("Send ring corrupt")
}
return 0, fmt.Errorf("Read failed: %w", err)
}
}
func (tun *tunWindows) Flush() error {
return nil
}
func (tun *tunWindows) Write(buff []byte) (int, error) {
return tun.WriteO(buff, messageTransportHeaderSize)
}
func (tun *tunWindows) WriteO(buff []byte, offset int) (int, error) {
tun.closing.RLock()
defer tun.closing.RUnlock()
if tun.closed {
return 0, os.ErrClosed
}
packetSize := len(buff) - offset
tun.rate.update(uint64(packetSize))
packet, err := tun.session.AllocateSendPacket(packetSize)
if err == nil {
copy(packet, buff[offset:])
tun.session.SendPacket(packet)
return packetSize, nil
}
switch err {
case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed
case windows.ERROR_BUFFER_OVERFLOW:
return 0, nil // Dropping when ring is full.
}
return 0, fmt.Errorf("Write failed: %w", err)
}
// LUID returns Windows interface instance ID.
func (tun *tunWindows) LUID() uint64 {
tun.closing.RLock()
defer tun.closing.RUnlock()
if tun.closed {
return 0
}
return tun.wt.LUID()
}
// RunningVersion returns the running version of the Wintun driver.
func (tun *tunWindows) RunningVersion() (version uint32, err error) {
return wintun.RunningVersion()
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
if period >= rateMeasurementGranularity {
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
return
}
atomic.StoreInt64(&rate.nextStartTime, now)
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
atomic.StoreUint64(&rate.nextByteCount, 0)
atomic.StoreInt32(&rate.changing, 0)
}
}
func (tun *tunWindows) Name() string {
return tun.name
}
func (t *tunWindows) URL() string {
return fmt.Sprintf("dev://%s", t.Name())
}
func (tun *tunWindows) configureInterface() error {
luid := winipcfg.LUID(tun.LUID())
mtu, err := tun.MTU()
if err != nil {
return errors.New("unable to get device mtu")
}
family := winipcfg.AddressFamily(windows.AF_INET)
familyV6 := winipcfg.AddressFamily(windows.AF_INET6)
tunAddress := winipcfg.ParseIPCidr("198.18.0.1/16")
addresses := []net.IPNet{tunAddress.IPNet()}
err = luid.FlushIPAddresses(familyV6)
if err != nil {
return err
}
err = luid.FlushDNS(family)
if err != nil {
return err
}
err = luid.FlushDNS(familyV6)
if err != nil {
return err
}
err = luid.FlushRoutes(familyV6)
if err != nil {
return err
}
err = luid.SetIPAddressesForFamily(family, addresses)
if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
cleanupAddressesOnDisconnectedInterfaces(family, addresses)
err = luid.SetIPAddressesForFamily(family, addresses)
}
if err != nil {
return err
}
foundDefault4 := false
foundDefault6 := false
if tun.autoRoute {
allowedIPs := []*winipcfg.IPCidr{
winipcfg.ParseIPCidr("1.0.0.0/8"),
winipcfg.ParseIPCidr("2.0.0.0/7"),
winipcfg.ParseIPCidr("4.0.0.0/6"),
winipcfg.ParseIPCidr("8.0.0.0/5"),
winipcfg.ParseIPCidr("16.0.0.0/4"),
winipcfg.ParseIPCidr("32.0.0.0/3"),
winipcfg.ParseIPCidr("64.0.0.0/2"),
winipcfg.ParseIPCidr("128.0.0.0/1"),
//winipcfg.ParseIPCidr("198.18.0.0/16"),
//winipcfg.ParseIPCidr("198.18.0.1/32"),
//winipcfg.ParseIPCidr("198.18.255.255/32"),
winipcfg.ParseIPCidr("224.0.0.0/4"),
winipcfg.ParseIPCidr("255.255.255.255/32"),
}
estimatedRouteCount := len(allowedIPs)
routes := make([]winipcfg.RouteData, 0, estimatedRouteCount)
var haveV4Address, haveV6Address bool = true, false
for _, allowedip := range allowedIPs {
allowedip.MaskSelf()
if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) {
continue
}
route := winipcfg.RouteData{
Destination: allowedip.IPNet(),
Metric: 0,
}
if allowedip.Bits() == 32 {
if allowedip.Cidr == 0 {
foundDefault4 = true
}
route.NextHop = net.IPv4zero
} else if allowedip.Bits() == 128 {
if allowedip.Cidr == 0 {
foundDefault6 = true
}
route.NextHop = net.IPv6zero
}
routes = append(routes, route)
}
deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
sort.Slice(routes, func(i, j int) bool {
if routes[i].Metric != routes[j].Metric {
return routes[i].Metric < routes[j].Metric
}
if c := bytes.Compare(routes[i].NextHop, routes[j].NextHop); c != 0 {
return c < 0
}
if c := bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP); c != 0 {
return c < 0
}
if c := bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask); c != 0 {
return c < 0
}
return false
})
for i := 0; i < len(routes); i++ {
if i > 0 && routes[i].Metric == routes[i-1].Metric &&
bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) &&
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
continue
}
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
if err != nil {
return err
}
}
ipif, err := luid.IPInterface(family)
if err != nil {
return err
}
ipif.NLMTU = uint32(mtu)
if family == windows.AF_INET {
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
} else if family == windows.AF_INET6 {
if foundDefault6 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
ipif.DadTransmits = 0
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
}
err = ipif.Set()
if err != nil {
return err
}
ipif6, err := luid.IPInterface(familyV6)
if err != nil {
return err
}
err = ipif6.Set()
if err != nil {
return err
}
return luid.SetDNS(family, []net.IP{net.ParseIP("198.18.0.2")}, nil)
}
func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
if len(addresses) == 0 {
return
}
includedInAddresses := func(a net.IPNet) bool {
// TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer!
for _, addr := range addresses {
ip := addr.IP
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
mA, _ := addr.Mask.Size()
mB, _ := a.Mask.Size()
if bytes.Equal(ip, a.IP) && mA == mB {
return true
}
}
return false
}
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
return
}
for _, iface := range interfaces {
if iface.OperStatus == winipcfg.IfOperStatusUp {
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
ip := address.Address.IP()
ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))}
if includedInAddresses(ipnet) {
log.Infoln("[Wintun] Cleaning up stale address %s from interface %s", ipnet.String(), iface.FriendlyName())
iface.LUID.DeleteIPAddress(ipnet)
}
}
}
}
// GetAutoDetectInterface get ethernet interface
func GetAutoDetectInterface() (string, error) {
ifname, err := getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET))
if err == nil {
return ifname, err
}
return getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET6))
}
func getAutoDetectInterfaceByFamily(family winipcfg.AddressFamily) (string, error) {
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeGateways)
if err != nil {
return "", fmt.Errorf("find ethernet interface failure. %w", err)
}
for _, iface := range interfaces {
if iface.OperStatus != winipcfg.IfOperStatusUp {
continue
}
ifname := iface.FriendlyName()
if ifname == "Clash" {
continue
}
for gatewayAddress := iface.FirstGatewayAddress; gatewayAddress != nil; gatewayAddress = gatewayAddress.Next {
nextHop := gatewayAddress.Address.IP()
var ipnet net.IPNet
if family == windows.AF_INET {
ipnet = net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
} else {
ipnet = net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
}
if _, err = iface.LUID.Route(ipnet, nextHop); err == nil {
return ifname, nil
}
}
}
return "", errors.New("ethernet interface not found")
}

View File

@ -0,0 +1,56 @@
// +build windows
package winipcfg
import (
"fmt"
"net"
"strconv"
"strings"
)
type IPCidr struct {
IP net.IP
Cidr uint8
}
func (r *IPCidr) String() string {
return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
}
func (r *IPCidr) Bits() uint8 {
if r.IP.To4() != nil {
return 32
}
return 128
}
func (r *IPCidr) IPNet() net.IPNet {
return net.IPNet{
IP: r.IP,
Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
}
}
func (r *IPCidr) MaskSelf() {
bits := int(r.Bits())
mask := net.CIDRMask(int(r.Cidr), bits)
for i := 0; i < bits/8; i++ {
r.IP[i] &= mask[i]
}
}
func ParseIPCidr(ipcidr string) *IPCidr {
s := strings.Split(ipcidr, "/")
if len(s) != 2 {
return nil
}
cidr, err := strconv.Atoi(s[1])
if err != nil {
return nil
}
return &IPCidr{
IP: net.ParseIP(s[0]),
Cidr: uint8(cidr),
}
}

View File

@ -0,0 +1,85 @@
// +build windows
package winipcfg
import (
"sync"
"golang.org/x/sys/windows"
)
// InterfaceChangeCallback structure allows interface change callback handling.
type InterfaceChangeCallback struct {
cb func(notificationType MibNotificationType, iface *MibIPInterfaceRow)
wait sync.WaitGroup
}
var (
interfaceChangeAddRemoveMutex = sync.Mutex{}
interfaceChangeMutex = sync.Mutex{}
interfaceChangeCallbacks = make(map[*InterfaceChangeCallback]bool)
interfaceChangeHandle = windows.Handle(0)
)
// RegisterInterfaceChangeCallback registers a new InterfaceChangeCallback. If this particular callback is already
// registered, the function will silently return. Returned InterfaceChangeCallback.Unregister method should be used
// to unregister.
func RegisterInterfaceChangeCallback(callback func(notificationType MibNotificationType, iface *MibIPInterfaceRow)) (*InterfaceChangeCallback, error) {
s := &InterfaceChangeCallback{cb: callback}
interfaceChangeAddRemoveMutex.Lock()
defer interfaceChangeAddRemoveMutex.Unlock()
interfaceChangeMutex.Lock()
defer interfaceChangeMutex.Unlock()
interfaceChangeCallbacks[s] = true
if interfaceChangeHandle == 0 {
err := notifyIPInterfaceChange(windows.AF_UNSPEC, windows.NewCallback(interfaceChanged), 0, false, &interfaceChangeHandle)
if err != nil {
delete(interfaceChangeCallbacks, s)
interfaceChangeHandle = 0
return nil, err
}
}
return s, nil
}
// Unregister unregisters the callback.
func (callback *InterfaceChangeCallback) Unregister() error {
interfaceChangeAddRemoveMutex.Lock()
defer interfaceChangeAddRemoveMutex.Unlock()
interfaceChangeMutex.Lock()
delete(interfaceChangeCallbacks, callback)
removeIt := len(interfaceChangeCallbacks) == 0 && interfaceChangeHandle != 0
interfaceChangeMutex.Unlock()
callback.wait.Wait()
if removeIt {
err := cancelMibChangeNotify2(interfaceChangeHandle)
if err != nil {
return err
}
interfaceChangeHandle = 0
}
return nil
}
func interfaceChanged(callerContext uintptr, row *MibIPInterfaceRow, notificationType MibNotificationType) uintptr {
rowCopy := *row
interfaceChangeMutex.Lock()
for cb := range interfaceChangeCallbacks {
cb.wait.Add(1)
go func(cb *InterfaceChangeCallback) {
cb.cb(notificationType, &rowCopy)
cb.wait.Done()
}(cb)
}
interfaceChangeMutex.Unlock()
return 0
}

View File

@ -0,0 +1,383 @@
// +build windows
package winipcfg
import (
"errors"
"net"
"strings"
"golang.org/x/sys/windows"
)
// LUID represents a network interface.
type LUID uint64
// IPInterface method retrieves IP information for the specified interface on the local computer.
func (luid LUID) IPInterface(family AddressFamily) (*MibIPInterfaceRow, error) {
row := &MibIPInterfaceRow{}
row.Init()
row.InterfaceLUID = luid
row.Family = family
err := row.get()
if err != nil {
return nil, err
}
return row, nil
}
// Interface method retrieves information for the specified adapter on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getifentry2
func (luid LUID) Interface() (*MibIfRow2, error) {
row := &MibIfRow2{}
row.InterfaceLUID = luid
err := row.get()
if err != nil {
return nil, err
}
return row, nil
}
// GUID method converts a locally unique identifier (LUID) for a network interface to a globally unique identifier (GUID) for the interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceluidtoguid
func (luid LUID) GUID() (*windows.GUID, error) {
guid := &windows.GUID{}
err := convertInterfaceLUIDToGUID(&luid, guid)
if err != nil {
return nil, err
}
return guid, nil
}
// LUIDFromGUID function converts a globally unique identifier (GUID) for a network interface to the locally unique identifier (LUID) for the interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceguidtoluid
func LUIDFromGUID(guid *windows.GUID) (LUID, error) {
var luid LUID
err := convertInterfaceGUIDToLUID(guid, &luid)
if err != nil {
return 0, err
}
return luid, nil
}
// LUIDFromIndex function converts a local index for a network interface to the locally unique identifier (LUID) for the interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceindextoluid
func LUIDFromIndex(index uint32) (LUID, error) {
var luid LUID
err := convertInterfaceIndexToLUID(index, &luid)
if err != nil {
return 0, err
}
return luid, nil
}
// IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry)
func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
row := &MibUnicastIPAddressRow{InterfaceLUID: luid}
err := row.Address.SetIP(ip, 0)
if err != nil {
return nil, err
}
err = row.get()
if err != nil {
return nil, err
}
return row, nil
}
// AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
func (luid LUID) AddIPAddress(address net.IPNet) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
err := row.Address.SetIP(address.IP, 0)
if err != nil {
return err
}
ones, _ := address.Mask.Size()
row.OnLinkPrefixLength = uint8(ones)
return row.Create()
}
// AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
for i := range addresses {
err := luid.AddIPAddress(addresses[i])
if err != nil {
return err
}
}
return nil
}
// SetIPAddresses method sets new unicast IP addresses to the interface.
func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
err := luid.FlushIPAddresses(windows.AF_UNSPEC)
if err != nil {
return err
}
return luid.AddIPAddresses(addresses)
}
// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface.
func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error {
err := luid.FlushIPAddresses(family)
if err != nil {
return err
}
for i := range addresses {
asV4 := addresses[i].IP.To4()
if asV4 == nil && family == windows.AF_INET {
continue
} else if asV4 != nil && family == windows.AF_INET6 {
continue
}
err := luid.AddIPAddress(addresses[i])
if err != nil {
return err
}
}
return nil
}
// DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry).
func (luid LUID) DeleteIPAddress(address net.IPNet) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
err := row.Address.SetIP(address.IP, 0)
if err != nil {
return err
}
// Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry().
ones, _ := address.Mask.Size()
row.OnLinkPrefixLength = uint8(ones)
return row.Delete()
}
// FlushIPAddresses method deletes all interface's unicast IP addresses.
func (luid LUID) FlushIPAddresses(family AddressFamily) error {
var tab *mibUnicastIPAddressTable
err := getUnicastIPAddressTable(family, &tab)
if err != nil {
return err
}
t := tab.get()
for i := range t {
if t[i].InterfaceLUID == luid {
t[i].Delete()
}
}
tab.free()
return nil
}
// Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2).
// NOTE: If the corresponding route isn't found, the method will return error.
func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
err := row.DestinationPrefix.SetIPNet(destination)
if err != nil {
return nil, err
}
err = row.NextHop.SetIP(nextHop, 0)
if err != nil {
return nil, err
}
err = row.get()
if err != nil {
return nil, err
}
return row, nil
}
// AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature.
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2)
func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
err := row.DestinationPrefix.SetIPNet(destination)
if err != nil {
return err
}
err = row.NextHop.SetIP(nextHop, 0)
if err != nil {
return err
}
row.Metric = metric
return row.Create()
}
// AddRoutes method adds multiple routes to the interface.
func (luid LUID) AddRoutes(routesData []*RouteData) error {
for _, rd := range routesData {
err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
if err != nil {
return err
}
}
return nil
}
// SetRoutes method sets (flush than add) multiple routes to the interface.
func (luid LUID) SetRoutes(routesData []*RouteData) error {
err := luid.FlushRoutes(windows.AF_UNSPEC)
if err != nil {
return err
}
return luid.AddRoutes(routesData)
}
// SetRoutesForFamily method sets (flush than add) multiple routes for a specific family to the interface.
func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteData) error {
err := luid.FlushRoutes(family)
if err != nil {
return err
}
for _, rd := range routesData {
asV4 := rd.Destination.IP.To4()
if asV4 == nil && family == windows.AF_INET {
continue
} else if asV4 != nil && family == windows.AF_INET6 {
continue
}
err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
if err != nil {
return err
}
}
return nil
}
// DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2).
func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
err := row.DestinationPrefix.SetIPNet(destination)
if err != nil {
return err
}
err = row.NextHop.SetIP(nextHop, 0)
if err != nil {
return err
}
err = row.get()
if err != nil {
return err
}
return row.Delete()
}
// FlushRoutes method deletes all interface's routes.
// It continues on failures, and returns the last error afterwards.
func (luid LUID) FlushRoutes(family AddressFamily) error {
var tab *mibIPforwardTable2
err := getIPForwardTable2(family, &tab)
if err != nil {
return err
}
t := tab.get()
for i := range t {
if t[i].InterfaceLUID == luid {
err2 := t[i].Delete()
if err2 != nil {
err = err2
}
}
}
tab.free()
return err
}
// DNS method returns all DNS server addresses associated with the adapter.
func (luid LUID) DNS() ([]net.IP, error) {
addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault)
if err != nil {
return nil, err
}
r := make([]net.IP, 0, len(addresses))
for _, addr := range addresses {
if addr.LUID == luid {
for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next {
if ip := dns.Address.IP(); ip != nil {
r = append(r, ip)
} else {
return nil, windows.ERROR_INVALID_PARAMETER
}
}
}
}
return r, nil
}
// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family.
func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string) error {
if family != windows.AF_INET && family != windows.AF_INET6 {
return windows.ERROR_PROTOCOL_UNREACHABLE
}
var filteredServers []string
for _, server := range servers {
if v4 := server.To4(); v4 != nil && family == windows.AF_INET {
filteredServers = append(filteredServers, v4.String())
} else if v6 := server.To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
filteredServers = append(filteredServers, v6.String())
}
}
servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ","))
if err != nil {
return err
}
domains16, err := windows.UTF16PtrFromString(strings.Join(domains, ","))
if err != nil {
return err
}
guid, err := luid.GUID()
if err != nil {
return err
}
var maybeV6 uint64
if family == windows.AF_INET6 {
maybeV6 = disFlagsIPv6
}
// For >= Windows 10 1809
err = setInterfaceDnsSettings(*guid, &dnsInterfaceSettings{
Version: disVersion1,
Flags: disFlagsNameServer | disFlagsSearchList | maybeV6,
NameServer: servers16,
SearchList: domains16,
})
if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
return err
}
// For < Windows 10 1809
err = luid.fallbackSetDNSForFamily(family, servers)
if err != nil {
return err
}
if len(domains) > 0 {
return luid.fallbackSetDNSDomain(domains[0])
} else {
return luid.fallbackSetDNSDomain("")
}
}
// FlushDNS method clears all DNS servers associated with the adapter.
func (luid LUID) FlushDNS(family AddressFamily) error {
return luid.SetDNS(family, nil, nil)
}

View File

@ -0,0 +1,3 @@
package winipcfg
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zwinipcfg_windows.go winipcfg.go

View File

@ -0,0 +1,105 @@
// +build windows
package winipcfg
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"os/exec"
"path/filepath"
"strings"
"syscall"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
func runNetsh(cmds []string) error {
system32, err := windows.GetSystemDirectory()
if err != nil {
return err
}
cmd := exec.Command(filepath.Join(system32, "netsh.exe"))
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("runNetsh stdin pipe - %w", err)
}
go func() {
defer stdin.Close()
io.WriteString(stdin, strings.Join(append(cmds, "exit\r\n"), "\r\n"))
}()
output, err := cmd.CombinedOutput()
// Horrible kludges, sorry.
cleaned := bytes.ReplaceAll(output, []byte{'\r', '\n'}, []byte{'\n'})
cleaned = bytes.ReplaceAll(cleaned, []byte("netsh>"), []byte{})
cleaned = bytes.ReplaceAll(cleaned, []byte("There are no Domain Name Servers (DNS) configured on this computer."), []byte{})
cleaned = bytes.TrimSpace(cleaned)
if len(cleaned) != 0 && err == nil {
return fmt.Errorf("netsh: %#q", string(cleaned))
} else if err != nil {
return fmt.Errorf("netsh: %v: %#q", err, string(cleaned))
}
return nil
}
const (
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
)
func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) error {
var templateFlush string
if family == windows.AF_INET {
templateFlush = netshCmdTemplateFlush4
} else if family == windows.AF_INET6 {
templateFlush = netshCmdTemplateFlush6
}
cmds := make([]string, 0, 1+len(dnses))
ipif, err := luid.IPInterface(family)
if err != nil {
return err
}
cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex))
for i := 0; i < len(dnses); i++ {
if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET {
cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String()))
} else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String()))
}
}
return runNetsh(cmds)
}
func (luid LUID) fallbackSetDNSDomain(domain string) error {
guid, err := luid.GUID()
if err != nil {
return fmt.Errorf("Error converting luid to guid: %w", err)
}
key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE)
if err != nil {
return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %w", err)
}
paths, _, err := key.GetStringsValue("IpConfig")
key.Close()
if err != nil {
return fmt.Errorf("Error reading IpConfig registry key: %w", err)
}
if len(paths) == 0 {
return errors.New("No TCP/IP interfaces found on adapter")
}
key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE)
if err != nil {
return fmt.Errorf("Unable to open TCP/IP network registry key: %w", err)
}
err = key.SetStringValue("Domain", domain)
key.Close()
return err
}

View File

@ -0,0 +1,85 @@
// +build windows
package winipcfg
import (
"sync"
"golang.org/x/sys/windows"
)
// RouteChangeCallback structure allows route change callback handling.
type RouteChangeCallback struct {
cb func(notificationType MibNotificationType, route *MibIPforwardRow2)
wait sync.WaitGroup
}
var (
routeChangeAddRemoveMutex = sync.Mutex{}
routeChangeMutex = sync.Mutex{}
routeChangeCallbacks = make(map[*RouteChangeCallback]bool)
routeChangeHandle = windows.Handle(0)
)
// RegisterRouteChangeCallback registers a new RouteChangeCallback. If this particular callback is already
// registered, the function will silently return. Returned RouteChangeCallback.Unregister method should be used
// to unregister.
func RegisterRouteChangeCallback(callback func(notificationType MibNotificationType, route *MibIPforwardRow2)) (*RouteChangeCallback, error) {
s := &RouteChangeCallback{cb: callback}
routeChangeAddRemoveMutex.Lock()
defer routeChangeAddRemoveMutex.Unlock()
routeChangeMutex.Lock()
defer routeChangeMutex.Unlock()
routeChangeCallbacks[s] = true
if routeChangeHandle == 0 {
err := notifyRouteChange2(windows.AF_UNSPEC, windows.NewCallback(routeChanged), 0, false, &routeChangeHandle)
if err != nil {
delete(routeChangeCallbacks, s)
routeChangeHandle = 0
return nil, err
}
}
return s, nil
}
// Unregister unregisters the callback.
func (callback *RouteChangeCallback) Unregister() error {
routeChangeAddRemoveMutex.Lock()
defer routeChangeAddRemoveMutex.Unlock()
routeChangeMutex.Lock()
delete(routeChangeCallbacks, callback)
removeIt := len(routeChangeCallbacks) == 0 && routeChangeHandle != 0
routeChangeMutex.Unlock()
callback.wait.Wait()
if removeIt {
err := cancelMibChangeNotify2(routeChangeHandle)
if err != nil {
return err
}
routeChangeHandle = 0
}
return nil
}
func routeChanged(callerContext uintptr, row *MibIPforwardRow2, notificationType MibNotificationType) uintptr {
rowCopy := *row
routeChangeMutex.Lock()
for cb := range routeChangeCallbacks {
cb.wait.Add(1)
go func(cb *RouteChangeCallback) {
cb.cb(notificationType, &rowCopy)
cb.wait.Done()
}(cb)
}
routeChangeMutex.Unlock()
return 0
}

View File

@ -0,0 +1,993 @@
// +build windows
package winipcfg
import (
"net"
"unsafe"
"golang.org/x/sys/windows"
)
const (
anySize = 1
maxDNSSuffixStringLength = 256
maxDHCPv6DUIDLength = 130
ifMaxStringSize = 256
ifMaxPhysAddressLength = 32
)
// AddressFamily enumeration specifies protocol family and is one of the windows.AF_* constants.
type AddressFamily uint16
// IPAAFlags enumeration describes adapter addresses flags
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh
type IPAAFlags uint32
const (
IPAAFlagDdnsEnabled IPAAFlags = 1 << iota
IPAAFlagRegisterAdapterSuffix
IPAAFlagDhcpv4Enabled
IPAAFlagReceiveOnly
IPAAFlagNoMulticast
IPAAFlagIpv6OtherStatefulConfig
IPAAFlagNetbiosOverTcpipEnabled
IPAAFlagIpv4Enabled
IPAAFlagIpv6Enabled
IPAAFlagIpv6ManagedAddressConfigurationSupported
)
// IfOperStatus enumeration specifies the operational status of an interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-if_oper_status
type IfOperStatus uint32
const (
IfOperStatusUp IfOperStatus = iota + 1
IfOperStatusDown
IfOperStatusTesting
IfOperStatusUnknown
IfOperStatusDormant
IfOperStatusNotPresent
IfOperStatusLowerLayerDown
)
// IfType enumeration specifies interface type.
type IfType uint32
const (
IfTypeOther IfType = 1 // None of the below
IfTypeRegular1822 = 2
IfTypeHdh1822 = 3
IfTypeDdnX25 = 4
IfTypeRfc877X25 = 5
IfTypeEthernetCSMACD = 6
IfTypeISO88023CSMACD = 7
IfTypeISO88024Tokenbus = 8
IfTypeISO88025Tokenring = 9
IfTypeISO88026Man = 10
IfTypeStarlan = 11
IfTypeProteon10Mbit = 12
IfTypeProteon80Mbit = 13
IfTypeHyperchannel = 14
IfTypeFddi = 15
IfTypeLapB = 16
IfTypeSdlc = 17
IfTypeDs1 = 18 // DS1-MIB
IfTypeE1 = 19 // Obsolete; see DS1-MIB
IfTypeBasicISDN = 20
IfTypePrimaryISDN = 21
IfTypePropPoint2PointSerial = 22 // proprietary serial
IfTypePPP = 23
IfTypeSoftwareLoopback = 24
IfTypeEon = 25 // CLNP over IP
IfTypeEthernet3Mbit = 26
IfTypeNsip = 27 // XNS over IP
IfTypeSlip = 28 // Generic Slip
IfTypeUltra = 29 // ULTRA Technologies
IfTypeDs3 = 30 // DS3-MIB
IfTypeSip = 31 // SMDS, coffee
IfTypeFramerelay = 32 // DTE only
IfTypeRs232 = 33
IfTypePara = 34 // Parallel port
IfTypeArcnet = 35
IfTypeArcnetPlus = 36
IfTypeAtm = 37 // ATM cells
IfTypeMioX25 = 38
IfTypeSonet = 39 // SONET or SDH
IfTypeX25Ple = 40
IfTypeIso88022LLC = 41
IfTypeLocaltalk = 42
IfTypeSmdsDxi = 43
IfTypeFramerelayService = 44 // FRNETSERV-MIB
IfTypeV35 = 45
IfTypeHssi = 46
IfTypeHippi = 47
IfTypeModem = 48 // Generic Modem
IfTypeAal5 = 49 // AAL5 over ATM
IfTypeSonetPath = 50
IfTypeSonetVt = 51
IfTypeSmdsIcip = 52 // SMDS InterCarrier Interface
IfTypePropVirtual = 53 // Proprietary virtual/internal
IfTypePropMultiplexor = 54 // Proprietary multiplexing
IfTypeIEEE80212 = 55 // 100BaseVG
IfTypeFibrechannel = 56
IfTypeHippiinterface = 57
IfTypeFramerelayInterconnect = 58 // Obsolete, use 32 or 44
IfTypeAflane8023 = 59 // ATM Emulated LAN for 802.3
IfTypeAflane8025 = 60 // ATM Emulated LAN for 802.5
IfTypeCctemul = 61 // ATM Emulated circuit
IfTypeFastether = 62 // Fast Ethernet (100BaseT)
IfTypeISDN = 63 // ISDN and X.25
IfTypeV11 = 64 // CCITT V.11/X.21
IfTypeV36 = 65 // CCITT V.36
IfTypeG703_64k = 66 // CCITT G703 at 64Kbps
IfTypeG703_2mb = 67 // Obsolete; see DS1-MIB
IfTypeQllc = 68 // SNA QLLC
IfTypeFastetherFX = 69 // Fast Ethernet (100BaseFX)
IfTypeChannel = 70
IfTypeIEEE80211 = 71 // Radio spread spectrum
IfTypeIBM370parchan = 72 // IBM System 360/370 OEMI Channel
IfTypeEscon = 73 // IBM Enterprise Systems Connection
IfTypeDlsw = 74 // Data Link Switching
IfTypeISDNS = 75 // ISDN S/T interface
IfTypeISDNU = 76 // ISDN U interface
IfTypeLapD = 77 // Link Access Protocol D
IfTypeIpswitch = 78 // IP Switching Objects
IfTypeRsrb = 79 // Remote Source Route Bridging
IfTypeAtmLogical = 80 // ATM Logical Port
IfTypeDs0 = 81 // Digital Signal Level 0
IfTypeDs0Bundle = 82 // Group of ds0s on the same ds1
IfTypeBsc = 83 // Bisynchronous Protocol
IfTypeAsync = 84 // Asynchronous Protocol
IfTypeCnr = 85 // Combat Net Radio
IfTypeIso88025rDtr = 86 // ISO 802.5r DTR
IfTypeEplrs = 87 // Ext Pos Loc Report Sys
IfTypeArap = 88 // Appletalk Remote Access Protocol
IfTypePropCnls = 89 // Proprietary Connectionless Proto
IfTypeHostpad = 90 // CCITT-ITU X.29 PAD Protocol
IfTypeTermpad = 91 // CCITT-ITU X.3 PAD Facility
IfTypeFramerelayMpi = 92 // Multiproto Interconnect over FR
IfTypeX213 = 93 // CCITT-ITU X213
IfTypeAdsl = 94 // Asymmetric Digital Subscrbr Loop
IfTypeRadsl = 95 // Rate-Adapt Digital Subscrbr Loop
IfTypeSdsl = 96 // Symmetric Digital Subscriber Loop
IfTypeVdsl = 97 // Very H-Speed Digital Subscrb Loop
IfTypeIso88025Crfprint = 98 // ISO 802.5 CRFP
IfTypeMyrinet = 99 // Myricom Myrinet
IfTypeVoiceEm = 100 // Voice recEive and transMit
IfTypeVoiceFxo = 101 // Voice Foreign Exchange Office
IfTypeVoiceFxs = 102 // Voice Foreign Exchange Station
IfTypeVoiceEncap = 103 // Voice encapsulation
IfTypeVoiceOverip = 104 // Voice over IP encapsulation
IfTypeAtmDxi = 105 // ATM DXI
IfTypeAtmFuni = 106 // ATM FUNI
IfTypeAtmIma = 107 // ATM IMA
IfTypePPPmultilinkbundle = 108 // PPP Multilink Bundle
IfTypeIpoverCdlc = 109 // IBM ipOverCdlc
IfTypeIpoverClaw = 110 // IBM Common Link Access to Workstn
IfTypeStacktostack = 111 // IBM stackToStack
IfTypeVirtualipaddress = 112 // IBM VIPA
IfTypeMpc = 113 // IBM multi-proto channel support
IfTypeIpoverAtm = 114 // IBM ipOverAtm
IfTypeIso88025Fiber = 115 // ISO 802.5j Fiber Token Ring
IfTypeTdlc = 116 // IBM twinaxial data link control
IfTypeGigabitethernet = 117
IfTypeHdlc = 118
IfTypeLapF = 119
IfTypeV37 = 120
IfTypeX25Mlp = 121 // Multi-Link Protocol
IfTypeX25Huntgroup = 122 // X.25 Hunt Group
IfTypeTransphdlc = 123
IfTypeInterleave = 124 // Interleave channel
IfTypeFast = 125 // Fast channel
IfTypeIP = 126 // IP (for APPN HPR in IP networks)
IfTypeDocscableMaclayer = 127 // CATV Mac Layer
IfTypeDocscableDownstream = 128 // CATV Downstream interface
IfTypeDocscableUpstream = 129 // CATV Upstream interface
IfTypeA12mppswitch = 130 // Avalon Parallel Processor
IfTypeTunnel = 131 // Encapsulation interface
IfTypeCoffee = 132 // Coffee pot
IfTypeCes = 133 // Circuit Emulation Service
IfTypeAtmSubinterface = 134 // ATM Sub Interface
IfTypeL2Vlan = 135 // Layer 2 Virtual LAN using 802.1Q
IfTypeL3Ipvlan = 136 // Layer 3 Virtual LAN using IP
IfTypeL3Ipxvlan = 137 // Layer 3 Virtual LAN using IPX
IfTypeDigitalpowerline = 138 // IP over Power Lines
IfTypeMediamailoverip = 139 // Multimedia Mail over IP
IfTypeDtm = 140 // Dynamic syncronous Transfer Mode
IfTypeDcn = 141 // Data Communications Network
IfTypeIpforward = 142 // IP Forwarding Interface
IfTypeMsdsl = 143 // Multi-rate Symmetric DSL
IfTypeIEEE1394 = 144 // IEEE1394 High Perf Serial Bus
IfTypeIfGsn = 145
IfTypeDvbrccMaclayer = 146
IfTypeDvbrccDownstream = 147
IfTypeDvbrccUpstream = 148
IfTypeAtmVirtual = 149
IfTypeMplsTunnel = 150
IfTypeSrp = 151
IfTypeVoiceoveratm = 152
IfTypeVoiceoverframerelay = 153
IfTypeIdsl = 154
IfTypeCompositelink = 155
IfTypeSs7Siglink = 156
IfTypePropWirelessP2P = 157
IfTypeFrForward = 158
IfTypeRfc1483 = 159
IfTypeUsb = 160
IfTypeIEEE8023adLag = 161
IfTypeBgpPolicyAccounting = 162
IfTypeFrf16MfrBundle = 163
IfTypeH323Gatekeeper = 164
IfTypeH323Proxy = 165
IfTypeMpls = 166
IfTypeMfSiglink = 167
IfTypeHdsl2 = 168
IfTypeShdsl = 169
IfTypeDs1Fdl = 170
IfTypePos = 171
IfTypeDvbAsiIn = 172
IfTypeDvbAsiOut = 173
IfTypePlc = 174
IfTypeNfas = 175
IfTypeTr008 = 176
IfTypeGr303Rdt = 177
IfTypeGr303Idt = 178
IfTypeIsup = 179
IfTypePropDocsWirelessMaclayer = 180
IfTypePropDocsWirelessDownstream = 181
IfTypePropDocsWirelessUpstream = 182
IfTypeHiperlan2 = 183
IfTypePropBwaP2MP = 184
IfTypeSonetOverheadChannel = 185
IfTypeDigitalWrapperOverheadChannel = 186
IfTypeAal2 = 187
IfTypeRadioMac = 188
IfTypeAtmRadio = 189
IfTypeImt = 190
IfTypeMvl = 191
IfTypeReachDsl = 192
IfTypeFrDlciEndpt = 193
IfTypeAtmVciEndpt = 194
IfTypeOpticalChannel = 195
IfTypeOpticalTransport = 196
IfTypeIEEE80216Wman = 237
IfTypeWwanpp = 243 // WWAN devices based on GSM technology
IfTypeWwanpp2 = 244 // WWAN devices based on CDMA technology
IfTypeIEEE802154 = 259 // IEEE 802.15.4 WPAN interface
IfTypeXboxWireless = 281
)
// MibIfEntryLevel enumeration specifies level of interface information to retrieve in GetIfTable2Ex function call.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getifentry2ex
type MibIfEntryLevel uint32
const (
MibIfEntryNormal MibIfEntryLevel = 0
MibIfEntryNormalWithoutStatistics = 2
)
// NdisMedium enumeration type identifies the medium types that NDIS drivers support.
// https://docs.microsoft.com/en-us/windows-hardware/drivers/ddi/content/ntddndis/ne-ntddndis-_ndis_medium
type NdisMedium uint32
const (
NdisMedium802_3 NdisMedium = iota
NdisMedium802_5
NdisMediumFddi
NdisMediumWan
NdisMediumLocalTalk
NdisMediumDix // defined for convenience, not a real medium
NdisMediumArcnetRaw
NdisMediumArcnet878_2
NdisMediumAtm
NdisMediumWirelessWan
NdisMediumIrda
NdisMediumBpc
NdisMediumCoWan
NdisMedium1394
NdisMediumInfiniBand
NdisMediumTunnel
NdisMediumNative802_11
NdisMediumLoopback
NdisMediumWiMAX
NdisMediumIP
NdisMediumMax
)
// NdisPhysicalMedium describes NDIS physical medium type.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2
type NdisPhysicalMedium uint32
const (
NdisPhysicalMediumUnspecified NdisPhysicalMedium = iota
NdisPhysicalMediumWirelessLan
NdisPhysicalMediumCableModem
NdisPhysicalMediumPhoneLine
NdisPhysicalMediumPowerLine
NdisPhysicalMediumDSL // includes ADSL and UADSL (G.Lite)
NdisPhysicalMediumFibreChannel
NdisPhysicalMedium1394
NdisPhysicalMediumWirelessWan
NdisPhysicalMediumNative802_11
NdisPhysicalMediumBluetooth
NdisPhysicalMediumInfiniband
NdisPhysicalMediumWiMax
NdisPhysicalMediumUWB
NdisPhysicalMedium802_3
NdisPhysicalMedium802_5
NdisPhysicalMediumIrda
NdisPhysicalMediumWiredWAN
NdisPhysicalMediumWiredCoWan
NdisPhysicalMediumOther
NdisPhysicalMediumNative802_15_4
NdisPhysicalMediumMax
)
// NetIfAccessType enumeration type specifies the NDIS network interface access type.
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-_net_if_access_type
type NetIfAccessType uint32
const (
NetIfAccessLoopback NetIfAccessType = iota + 1
NetIfAccessBroadcast
NetIfAccessPointToPoint
NetIfAccessPointToMultiPoint
NetIfAccessMax
)
// NetIfAdminStatus enumeration type specifies the NDIS network interface administrative status, as described in RFC 2863.
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-net_if_admin_status
type NetIfAdminStatus uint32
const (
NetIfAdminStatusUp NetIfAdminStatus = iota + 1
NetIfAdminStatusDown
NetIfAdminStatusTesting
)
// NetIfConnectionType enumeration type specifies the NDIS network interface connection type.
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-_net_if_connection_type
type NetIfConnectionType uint32
const (
NetIfConnectionDedicated NetIfConnectionType = iota + 1
NetIfConnectionPassive
NetIfConnectionDemand
NetIfConnectionMaximum
)
// NetIfDirectionType enumeration type specifies the NDIS network interface direction type.
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-net_if_direction_type
type NetIfDirectionType uint32
const (
NetIfDirectionSendReceive NetIfDirectionType = iota
NetIfDirectionSendOnly
NetIfDirectionReceiveOnly
NetIfDirectionMaximum
)
// NetIfMediaConnectState enumeration type specifies the NDIS network interface connection state.
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-_net_if_media_connect_state
type NetIfMediaConnectState uint32
const (
MediaConnectStateUnknown NetIfMediaConnectState = iota
MediaConnectStateConnected
MediaConnectStateDisconnected
)
// DadState enumeration specifies information about the duplicate address detection (DAD) state for an IPv4 or IPv6 address.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_dad_state
type DadState uint32
const (
DadStateInvalid DadState = iota
DadStateTentative
DadStateDuplicate
DadStateDeprecated
DadStatePreferred
)
// PrefixOrigin enumeration specifies the origin of an IPv4 or IPv6 address prefix, and is used with the IP_ADAPTER_UNICAST_ADDRESS structure.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_prefix_origin
type PrefixOrigin uint32
const (
PrefixOriginOther PrefixOrigin = iota
PrefixOriginManual
PrefixOriginWellKnown
PrefixOriginDHCP
PrefixOriginRouterAdvertisement
PrefixOriginUnchanged = 1 << 4
)
// LinkLocalAddressBehavior enumeration type defines the link local address behavior.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-_nl_link_local_address_behavior
type LinkLocalAddressBehavior int32
const (
LinkLocalAddressAlwaysOff LinkLocalAddressBehavior = iota // Never use link locals.
LinkLocalAddressDelayed // Use link locals only if no other addresses. (default for IPv4). Legacy mapping: IPAutoconfigurationEnabled.
LinkLocalAddressAlwaysOn // Always use link locals (default for IPv6).
LinkLocalAddressUnchanged = -1
)
// OffloadRod enumeration specifies a set of flags that indicate the offload capabilities for an IP interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ns-nldef-_nl_interface_offload_rod
type OffloadRod uint8
const (
ChecksumSupported OffloadRod = 1 << iota
OptionsSupported
DatagramChecksumSupported
StreamChecksumSupported
StreamOptionsSupported
FastPathCompatible
LargeSendOffloadSupported
GiantSendOffloadSupported
)
// RouteOrigin enumeration type defines the origin of the IP route.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_route_origin
type RouteOrigin uint32
const (
RouteOriginManual RouteOrigin = iota
RouteOriginWellKnown
RouteOriginDHCP
RouteOriginRouterAdvertisement
RouteOrigin6to4
)
// RouteProtocol enumeration type defines the routing mechanism that an IP route was added with, as described in RFC 4292.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_route_protocol
type RouteProtocol uint32
const (
RouteProtocolOther RouteProtocol = iota + 1
RouteProtocolLocal
RouteProtocolNetMgmt
RouteProtocolIcmp
RouteProtocolEgp
RouteProtocolGgp
RouteProtocolHello
RouteProtocolRip
RouteProtocolIsIs
RouteProtocolEsIs
RouteProtocolCisco
RouteProtocolBbn
RouteProtocolOspf
RouteProtocolBgp
RouteProtocolIdpr
RouteProtocolEigrp
RouteProtocolDvmrp
RouteProtocolRpl
RouteProtocolDHCP
RouteProtocolNTAutostatic = 10002
RouteProtocolNTStatic = 10006
RouteProtocolNTStaticNonDOD = 10007
)
// RouterDiscoveryBehavior enumeration type defines the router discovery behavior, as described in RFC 2461.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-_nl_router_discovery_behavior
type RouterDiscoveryBehavior int32
const (
RouterDiscoveryDisabled RouterDiscoveryBehavior = iota
RouterDiscoveryEnabled
RouterDiscoveryDHCP
RouterDiscoveryUnchanged = -1
)
// SuffixOrigin enumeration specifies the origin of an IPv4 or IPv6 address suffix, and is used with the IP_ADAPTER_UNICAST_ADDRESS structure.
// https://docs.microsoft.com/en-us/windows/desktop/api/nldef/ne-nldef-nl_suffix_origin
type SuffixOrigin uint32
const (
SuffixOriginOther SuffixOrigin = iota
SuffixOriginManual
SuffixOriginWellKnown
SuffixOriginDHCP
SuffixOriginLinkLayerAddress
SuffixOriginRandom
SuffixOriginUnchanged = 1 << 4
)
// MibNotificationType enumeration defines the notification type passed to a callback function when a notification occurs.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ne-netioapi-_mib_notification_type
type MibNotificationType uint32
const (
MibParameterNotification MibNotificationType = iota // Parameter change
MibAddInstance // Addition
MibDeleteInstance // Deletion
MibInitialNotification // Initial notification
)
type ChangeCallback interface {
Unregister() error
}
// TunnelType enumeration type defines the encapsulation method used by a tunnel, as described by the Internet Assigned Names Authority (IANA).
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-tunnel_type
type TunnelType uint32
const (
TunnelTypeNone TunnelType = 0
TunnelTypeOther = 1
TunnelTypeDirect = 2
TunnelType6to4 = 11
TunnelTypeIsatap = 13
TunnelTypeTeredo = 14
TunnelTypeIPHTTPS = 15
)
// InterfaceAndOperStatusFlags enumeration type defines interface and operation flags
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2
type InterfaceAndOperStatusFlags uint8
const (
IAOSFHardwareInterface InterfaceAndOperStatusFlags = 1 << iota
IAOSFFilterInterface
IAOSFConnectorPresent
IAOSFNotAuthenticated
IAOSFNotMediaConnected
IAOSFPaused
IAOSFLowPower
IAOSFEndPointInterface
)
// GAAFlags enumeration defines flags used in GetAdaptersAddresses calls
// https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
type GAAFlags uint32
const (
GAAFlagSkipUnicast GAAFlags = 1 << iota
GAAFlagSkipAnycast
GAAFlagSkipMulticast
GAAFlagSkipDNSServer
GAAFlagIncludePrefix
GAAFlagSkipFriendlyName
GAAFlagIncludeWinsInfo
GAAFlagIncludeGateways
GAAFlagIncludeAllInterfaces
GAAFlagIncludeAllCompartments
GAAFlagIncludeTunnelBindingOrder
GAAFlagSkipDNSInfo
GAAFlagDefault GAAFlags = 0
GAAFlagSkipAll = GAAFlagSkipUnicast | GAAFlagSkipAnycast | GAAFlagSkipMulticast | GAAFlagSkipDNSServer | GAAFlagSkipFriendlyName | GAAFlagSkipDNSInfo
GAAFlagIncludeAll = GAAFlagIncludePrefix | GAAFlagIncludeWinsInfo | GAAFlagIncludeGateways | GAAFlagIncludeAllInterfaces | GAAFlagIncludeAllCompartments | GAAFlagIncludeTunnelBindingOrder
)
// ScopeLevel enumeration is used with the IP_ADAPTER_ADDRESSES structure to identify scope levels for IPv6 addresses.
// https://docs.microsoft.com/en-us/windows/desktop/api/ws2def/ne-ws2def-scope_level
type ScopeLevel uint32
const (
ScopeLevelInterface ScopeLevel = 1
ScopeLevelLink = 2
ScopeLevelSubnet = 3
ScopeLevelAdmin = 4
ScopeLevelSite = 5
ScopeLevelOrganization = 8
ScopeLevelGlobal = 14
ScopeLevelCount = 16
)
// RouteData structure describes a route to add
type RouteData struct {
Destination net.IPNet
NextHop net.IP
Metric uint32
}
// IPAdapterDNSSuffix structure stores a DNS suffix in a linked list of DNS suffixes for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_dns_suffix
type IPAdapterDNSSuffix struct {
Next *IPAdapterDNSSuffix
str [maxDNSSuffixStringLength]uint16
}
// String method returns the DNS suffix for this DNS suffix entry.
func (obj *IPAdapterDNSSuffix) String() string {
return windows.UTF16ToString(obj.str[:])
}
// AdapterName method returns the name of the adapter with which these addresses are associated.
// Unlike an adapter's friendly name, the adapter name returned by AdapterName is permanent and cannot be modified by the user.
func (addr *IPAdapterAddresses) AdapterName() string {
return windows.BytePtrToString(addr.adapterName)
}
// DNSSuffix method returns adapter DNS suffix associated with this adapter.
func (addr *IPAdapterAddresses) DNSSuffix() string {
if addr.dnsSuffix == nil {
return ""
}
return windows.UTF16PtrToString(addr.dnsSuffix)
}
// Description method returns description for the adapter.
func (addr *IPAdapterAddresses) Description() string {
if addr.description == nil {
return ""
}
return windows.UTF16PtrToString(addr.description)
}
// FriendlyName method returns a user-friendly name for the adapter. For example: "Local Area Connection 1."
// This name appears in contexts such as the ipconfig command line program and the Connection folder.
func (addr *IPAdapterAddresses) FriendlyName() string {
if addr.friendlyName == nil {
return ""
}
return windows.UTF16PtrToString(addr.friendlyName)
}
// PhysicalAddress method returns the Media Access Control (MAC) address for the adapter.
// For example, on an Ethernet network this member would specify the Ethernet hardware address.
func (addr *IPAdapterAddresses) PhysicalAddress() []byte {
return addr.physicalAddress[:addr.physicalAddressLength]
}
// DHCPv6ClientDUID method returns the DHCP unique identifier (DUID) for the DHCPv6 client.
// This information is only applicable to an IPv6 adapter address configured using DHCPv6.
func (addr *IPAdapterAddresses) DHCPv6ClientDUID() []byte {
return addr.dhcpv6ClientDUID[:addr.dhcpv6ClientDUIDLength]
}
// Init method initializes the members of an MIB_IPINTERFACE_ROW entry with default values.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeipinterfaceentry
func (row *MibIPInterfaceRow) Init() {
initializeIPInterfaceEntry(row)
}
// get method retrieves IP information for the specified interface on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipinterfaceentry
func (row *MibIPInterfaceRow) get() error {
if err := getIPInterfaceEntry(row); err != nil {
return err
}
// Patch that fixes SitePrefixLength issue
// https://stackoverflow.com/questions/54857292/setipinterfaceentry-returns-error-invalid-parameter?noredirect=1
switch row.Family {
case windows.AF_INET:
if row.SitePrefixLength > 32 {
row.SitePrefixLength = 0
}
case windows.AF_INET6:
if row.SitePrefixLength > 128 {
row.SitePrefixLength = 128
}
}
return nil
}
// Set method sets the properties of an IP interface on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-setipinterfaceentry
func (row *MibIPInterfaceRow) Set() error {
return setIPInterfaceEntry(row)
}
// get method returns all table rows as a Go slice.
func (tab *mibIPInterfaceTable) get() (s []MibIPInterfaceRow) {
unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
return
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable
func (tab *mibIPInterfaceTable) free() {
freeMibTable(unsafe.Pointer(tab))
}
// Alias method returns a string that contains the alias name of the network interface.
func (row *MibIfRow2) Alias() string {
return windows.UTF16ToString(row.alias[:])
}
// Description method returns a string that contains a description of the network interface.
func (row *MibIfRow2) Description() string {
return windows.UTF16ToString(row.description[:])
}
// PhysicalAddress method returns the physical hardware address of the adapter for this network interface.
func (row *MibIfRow2) PhysicalAddress() []byte {
return row.physicalAddress[:row.physicalAddressLength]
}
// PermanentPhysicalAddress method returns the permanent physical hardware address of the adapter for this network interface.
func (row *MibIfRow2) PermanentPhysicalAddress() []byte {
return row.permanentPhysicalAddress[:row.physicalAddressLength]
}
// get method retrieves information for the specified interface on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getifentry2
func (row *MibIfRow2) get() (ret error) {
return getIfEntry2(row)
}
// get method returns all table rows as a Go slice.
func (tab *mibIfTable2) get() (s []MibIfRow2) {
unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
return
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable
func (tab *mibIfTable2) free() {
freeMibTable(unsafe.Pointer(tab))
}
// RawSockaddrInet union contains an IPv4, an IPv6 address, or an address family.
// https://docs.microsoft.com/en-us/windows/desktop/api/ws2ipdef/ns-ws2ipdef-_sockaddr_inet
type RawSockaddrInet struct {
Family AddressFamily
data [26]byte
}
// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port.
// All other members of the structure are set to zero.
func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
if v4 := ip.To4(); v4 != nil {
addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr))
addr4.Family = windows.AF_INET
copy(addr4.Addr[:], v4)
addr4.Port = port
for i := 0; i < 8; i++ {
addr4.Zero[i] = 0
}
return nil
}
if v6 := ip.To16(); v6 != nil {
addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
addr6.Family = windows.AF_INET6
addr6.Port = port
addr6.Flowinfo = 0
copy(addr6.Addr[:], v6)
addr6.Scope_id = 0
return nil
}
return windows.ERROR_INVALID_PARAMETER
}
// IP method returns IPv4 or IPv6 address.
// If the address is neither IPv4 not IPv6 nil is returned.
func (addr *RawSockaddrInet) IP() net.IP {
switch addr.Family {
case windows.AF_INET:
return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:]
case windows.AF_INET6:
return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:]
}
return nil
}
// Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeunicastipaddressentry
func (row *MibUnicastIPAddressRow) Init() {
initializeUnicastIPAddressEntry(row)
}
// get method retrieves information for an existing unicast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry
func (row *MibUnicastIPAddressRow) get() error {
return getUnicastIPAddressEntry(row)
}
// Set method sets the properties of an existing unicast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-setunicastipaddressentry
func (row *MibUnicastIPAddressRow) Set() error {
return setUnicastIPAddressEntry(row)
}
// Create method adds a new unicast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry
func (row *MibUnicastIPAddressRow) Create() error {
return createUnicastIPAddressEntry(row)
}
// Delete method deletes an existing unicast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry
func (row *MibUnicastIPAddressRow) Delete() error {
return deleteUnicastIPAddressEntry(row)
}
// get method returns all table rows as a Go slice.
func (tab *mibUnicastIPAddressTable) get() (s []MibUnicastIPAddressRow) {
unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
return
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable
func (tab *mibUnicastIPAddressTable) free() {
freeMibTable(unsafe.Pointer(tab))
}
// get method retrieves information for an existing anycast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getanycastipaddressentry
func (row *MibAnycastIPAddressRow) get() error {
return getAnycastIPAddressEntry(row)
}
// Create method adds a new anycast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createanycastipaddressentry
func (row *MibAnycastIPAddressRow) Create() error {
return createAnycastIPAddressEntry(row)
}
// Delete method deletes an existing anycast IP address entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteanycastipaddressentry
func (row *MibAnycastIPAddressRow) Delete() error {
return deleteAnycastIPAddressEntry(row)
}
// get method returns all table rows as a Go slice.
func (tab *mibAnycastIPAddressTable) get() (s []MibAnycastIPAddressRow) {
unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
return
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable
func (tab *mibAnycastIPAddressTable) free() {
freeMibTable(unsafe.Pointer(tab))
}
// IPAddressPrefix structure stores an IP address prefix.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix
type IPAddressPrefix struct {
Prefix RawSockaddrInet
PrefixLength uint8
_ [2]byte
}
// SetIPNet method sets IP address prefix using net.IPNet.
func (prefix *IPAddressPrefix) SetIPNet(net net.IPNet) error {
err := prefix.Prefix.SetIP(net.IP, 0)
if err != nil {
return err
}
ones, _ := net.Mask.Size()
prefix.PrefixLength = uint8(ones)
return nil
}
// IPNet method returns IP address prefix as net.IPNet.
// If the address is neither IPv4 not IPv6 an empty net.IPNet is returned. The resulting net.IPNet should be checked appropriately.
func (prefix *IPAddressPrefix) IPNet() net.IPNet {
switch prefix.Prefix.Family {
case windows.AF_INET:
return net.IPNet{IP: (*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv4len)}
case windows.AF_INET6:
return net.IPNet{IP: (*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv6len)}
}
return net.IPNet{}
}
// MibIPforwardRow2 structure stores information about an IP route entry.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_row2
type MibIPforwardRow2 struct {
InterfaceLUID LUID
InterfaceIndex uint32
DestinationPrefix IPAddressPrefix
NextHop RawSockaddrInet
SitePrefixLength uint8
ValidLifetime uint32
PreferredLifetime uint32
Metric uint32
Protocol RouteProtocol
Loopback bool
AutoconfigureAddress bool
Publish bool
Immortal bool
Age uint32
Origin RouteOrigin
}
// Init method initializes a MIB_IPFORWARD_ROW2 structure with default values for an IP route entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-initializeipforwardentry
func (row *MibIPforwardRow2) Init() {
initializeIPForwardEntry(row)
}
// get method retrieves information for an IP route entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2
func (row *MibIPforwardRow2) get() error {
return getIPForwardEntry2(row)
}
// Set method sets the properties of an IP route entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-setipforwardentry2
func (row *MibIPforwardRow2) Set() error {
return setIPForwardEntry2(row)
}
// Create method creates a new IP route entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2
func (row *MibIPforwardRow2) Create() error {
return createIPForwardEntry2(row)
}
// Delete method deletes an IP route entry on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2
func (row *MibIPforwardRow2) Delete() error {
return deleteIPForwardEntry2(row)
}
// get method returns all table rows as a Go slice.
func (tab *mibIPforwardTable2) get() (s []MibIPforwardRow2) {
unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
return
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-freemibtable
func (tab *mibIPforwardTable2) free() {
freeMibTable(unsafe.Pointer(tab))
}
//
// Undocumented DNS API
//
// dnsInterfaceSettings is mean to be used with setInterfaceDnsSettings
type dnsInterfaceSettings struct {
Version uint32
_ [4]byte
Flags uint64
Domain *uint16
NameServer *uint16
SearchList *uint16
RegistrationEnabled uint32
RegisterAdapterName uint32
EnableLLMNR uint32
QueryAdapterName uint32
ProfileNameServer *uint16
}
const (
disVersion1 = 1
disVersion2 = 2
disFlagsIPv6 = 0x1
disFlagsNameServer = 0x2
disFlagsSearchList = 0x4
disFlagsRegistrationEnabled = 0x8
disFlagsRegisterAdapterName = 0x10
disFlagsDomain = 0x20
disFlagsHostname = 0x40 // ??
disFlagsEnableLLMNR = 0x80
disFlagsQueryAdapterName = 0x100
disFlagsProfileNameServer = 0x200
disFlagsVersion2 = 0x400 // ?? - v2 only
disFlagsMoreFlags = 0x800 // ?? - v2 only
)
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@ -0,0 +1,227 @@
// +build 386 arm
package winipcfg
import (
"golang.org/x/sys/windows"
)
// IPAdapterWINSServerAddress structure stores a single Windows Internet Name Service (WINS) server address in a linked list of WINS server addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_wins_server_address_lh
type IPAdapterWINSServerAddress struct {
Length uint32
_ uint32
Next *IPAdapterWINSServerAddress
Address windows.SocketAddress
_ [4]byte
}
// IPAdapterGatewayAddress structure stores a single gateway address in a linked list of gateway addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_gateway_address_lh
type IPAdapterGatewayAddress struct {
Length uint32
_ uint32
Next *IPAdapterGatewayAddress
Address windows.SocketAddress
_ [4]byte
}
// IPAdapterAddresses structure is the header node for a linked list of addresses for a particular adapter. This structure can simultaneously be used as part of a linked list of IP_ADAPTER_ADDRESSES structures.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh
// This is a modified and extended version of windows.IpAdapterAddresses.
type IPAdapterAddresses struct {
Length uint32
IfIndex uint32
Next *IPAdapterAddresses
adapterName *byte
FirstUnicastAddress *windows.IpAdapterUnicastAddress
FirstAnycastAddress *windows.IpAdapterAnycastAddress
FirstMulticastAddress *windows.IpAdapterMulticastAddress
FirstDNSServerAddress *windows.IpAdapterDnsServerAdapter
dnsSuffix *uint16
description *uint16
friendlyName *uint16
physicalAddress [windows.MAX_ADAPTER_ADDRESS_LENGTH]byte
physicalAddressLength uint32
Flags IPAAFlags
MTU uint32
IfType IfType
OperStatus IfOperStatus
IPv6IfIndex uint32
ZoneIndices [16]uint32
FirstPrefix *windows.IpAdapterPrefix
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
FirstWINSServerAddress *IPAdapterWINSServerAddress
FirstGatewayAddress *IPAdapterGatewayAddress
Ipv4Metric uint32
Ipv6Metric uint32
LUID LUID
DHCPv4Server windows.SocketAddress
CompartmentID uint32
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
TunnelType TunnelType
DHCPv6Server windows.SocketAddress
dhcpv6ClientDUID [maxDHCPv6DUIDLength]byte
dhcpv6ClientDUIDLength uint32
DHCPv6IAID uint32
FirstDNSSuffix *IPAdapterDNSSuffix
_ [4]byte
}
// MibIPInterfaceRow structure stores interface management information for a particular IP address family on a network interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_row
type MibIPInterfaceRow struct {
Family AddressFamily
_ [4]byte
InterfaceLUID LUID
InterfaceIndex uint32
MaxReassemblySize uint32
InterfaceIdentifier uint64
MinRouterAdvertisementInterval uint32
MaxRouterAdvertisementInterval uint32
AdvertisingEnabled bool
ForwardingEnabled bool
WeakHostSend bool
WeakHostReceive bool
UseAutomaticMetric bool
UseNeighborUnreachabilityDetection bool
ManagedAddressConfigurationSupported bool
OtherStatefulConfigurationSupported bool
AdvertiseDefaultRoute bool
RouterDiscoveryBehavior RouterDiscoveryBehavior
DadTransmits uint32
BaseReachableTime uint32
RetransmitTime uint32
PathMTUDiscoveryTimeout uint32
LinkLocalAddressBehavior LinkLocalAddressBehavior
LinkLocalAddressTimeout uint32
ZoneIndices [ScopeLevelCount]uint32
SitePrefixLength uint32
Metric uint32
NLMTU uint32
Connected bool
SupportsWakeUpPatterns bool
SupportsNeighborDiscovery bool
SupportsRouterDiscovery bool
ReachableTime uint32
TransmitOffload OffloadRod
ReceiveOffload OffloadRod
DisableDefaultRoutes bool
}
// mibIPInterfaceTable structure contains a table of IP interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_table
type mibIPInterfaceTable struct {
numEntries uint32
_ [4]byte
table [anySize]MibIPInterfaceRow
}
// MibIfRow2 structure stores information about a particular interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2
type MibIfRow2 struct {
InterfaceLUID LUID
InterfaceIndex uint32
InterfaceGUID windows.GUID
alias [ifMaxStringSize + 1]uint16
description [ifMaxStringSize + 1]uint16
physicalAddressLength uint32
physicalAddress [ifMaxPhysAddressLength]byte
permanentPhysicalAddress [ifMaxPhysAddressLength]byte
MTU uint32
Type IfType
TunnelType TunnelType
MediaType NdisMedium
PhysicalMediumType NdisPhysicalMedium
AccessType NetIfAccessType
DirectionType NetIfDirectionType
InterfaceAndOperStatusFlags InterfaceAndOperStatusFlags
OperStatus IfOperStatus
AdminStatus NetIfAdminStatus
MediaConnectState NetIfMediaConnectState
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
_ [4]byte
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
InOctets uint64
InUcastPkts uint64
InNUcastPkts uint64
InDiscards uint64
InErrors uint64
InUnknownProtos uint64
InUcastOctets uint64
InMulticastOctets uint64
InBroadcastOctets uint64
OutOctets uint64
OutUcastPkts uint64
OutNUcastPkts uint64
OutDiscards uint64
OutErrors uint64
OutUcastOctets uint64
OutMulticastOctets uint64
OutBroadcastOctets uint64
OutQLen uint64
}
// mibIfTable2 structure contains a table of logical and physical interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_table2
type mibIfTable2 struct {
numEntries uint32
_ [4]byte
table [anySize]MibIfRow2
}
// MibUnicastIPAddressRow structure stores information about a unicast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_row
type MibUnicastIPAddressRow struct {
Address RawSockaddrInet
_ [4]byte
InterfaceLUID LUID
InterfaceIndex uint32
PrefixOrigin PrefixOrigin
SuffixOrigin SuffixOrigin
ValidLifetime uint32
PreferredLifetime uint32
OnLinkPrefixLength uint8
SkipAsSource bool
DadState DadState
ScopeID uint32
CreationTimeStamp int64
}
// mibUnicastIPAddressTable structure contains a table of unicast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_table
type mibUnicastIPAddressTable struct {
numEntries uint32
_ [4]byte
table [anySize]MibUnicastIPAddressRow
}
// MibAnycastIPAddressRow structure stores information about an anycast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_anycastipaddress_row
type MibAnycastIPAddressRow struct {
Address RawSockaddrInet
_ [4]byte
InterfaceLUID LUID
InterfaceIndex uint32
ScopeID uint32
}
// mibAnycastIPAddressTable structure contains a table of anycast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-mib_anycastipaddress_table
type mibAnycastIPAddressTable struct {
numEntries uint32
_ [4]byte
table [anySize]MibAnycastIPAddressRow
}
// mibIPforwardTable2 structure contains a table of IP route entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_table2
type mibIPforwardTable2 struct {
numEntries uint32
_ [4]byte
table [anySize]MibIPforwardRow2
}

View File

@ -0,0 +1,216 @@
// +build windows
// +build amd64 arm64
package winipcfg
import (
"golang.org/x/sys/windows"
)
// IPAdapterWINSServerAddress structure stores a single Windows Internet Name Service (WINS) server address in a linked list of WINS server addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_wins_server_address_lh
type IPAdapterWINSServerAddress struct {
Length uint32
_ uint32
Next *IPAdapterWINSServerAddress
Address windows.SocketAddress
}
// IPAdapterGatewayAddress structure stores a single gateway address in a linked list of gateway addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_gateway_address_lh
type IPAdapterGatewayAddress struct {
Length uint32
_ uint32
Next *IPAdapterGatewayAddress
Address windows.SocketAddress
}
// IPAdapterAddresses structure is the header node for a linked list of addresses for a particular adapter. This structure can simultaneously be used as part of a linked list of IP_ADAPTER_ADDRESSES structures.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh
// This is a modified and extended version of windows.IpAdapterAddresses.
type IPAdapterAddresses struct {
Length uint32
IfIndex uint32
Next *IPAdapterAddresses
adapterName *byte
FirstUnicastAddress *windows.IpAdapterUnicastAddress
FirstAnycastAddress *windows.IpAdapterAnycastAddress
FirstMulticastAddress *windows.IpAdapterMulticastAddress
FirstDNSServerAddress *windows.IpAdapterDnsServerAdapter
dnsSuffix *uint16
description *uint16
friendlyName *uint16
physicalAddress [windows.MAX_ADAPTER_ADDRESS_LENGTH]byte
physicalAddressLength uint32
Flags IPAAFlags
MTU uint32
IfType IfType
OperStatus IfOperStatus
IPv6IfIndex uint32
ZoneIndices [16]uint32
FirstPrefix *windows.IpAdapterPrefix
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
FirstWINSServerAddress *IPAdapterWINSServerAddress
FirstGatewayAddress *IPAdapterGatewayAddress
Ipv4Metric uint32
Ipv6Metric uint32
LUID LUID
DHCPv4Server windows.SocketAddress
CompartmentID uint32
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
TunnelType TunnelType
DHCPv6Server windows.SocketAddress
dhcpv6ClientDUID [maxDHCPv6DUIDLength]byte
dhcpv6ClientDUIDLength uint32
DHCPv6IAID uint32
FirstDNSSuffix *IPAdapterDNSSuffix
}
// MibIPInterfaceRow structure stores interface management information for a particular IP address family on a network interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_row
type MibIPInterfaceRow struct {
Family AddressFamily
InterfaceLUID LUID
InterfaceIndex uint32
MaxReassemblySize uint32
InterfaceIdentifier uint64
MinRouterAdvertisementInterval uint32
MaxRouterAdvertisementInterval uint32
AdvertisingEnabled bool
ForwardingEnabled bool
WeakHostSend bool
WeakHostReceive bool
UseAutomaticMetric bool
UseNeighborUnreachabilityDetection bool
ManagedAddressConfigurationSupported bool
OtherStatefulConfigurationSupported bool
AdvertiseDefaultRoute bool
RouterDiscoveryBehavior RouterDiscoveryBehavior
DadTransmits uint32
BaseReachableTime uint32
RetransmitTime uint32
PathMTUDiscoveryTimeout uint32
LinkLocalAddressBehavior LinkLocalAddressBehavior
LinkLocalAddressTimeout uint32
ZoneIndices [ScopeLevelCount]uint32
SitePrefixLength uint32
Metric uint32
NLMTU uint32
Connected bool
SupportsWakeUpPatterns bool
SupportsNeighborDiscovery bool
SupportsRouterDiscovery bool
ReachableTime uint32
TransmitOffload OffloadRod
ReceiveOffload OffloadRod
DisableDefaultRoutes bool
}
// mibIPInterfaceTable structure contains a table of IP interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_table
type mibIPInterfaceTable struct {
numEntries uint32
table [anySize]MibIPInterfaceRow
}
// MibIfRow2 structure stores information about a particular interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2
type MibIfRow2 struct {
InterfaceLUID LUID
InterfaceIndex uint32
InterfaceGUID windows.GUID
alias [ifMaxStringSize + 1]uint16
description [ifMaxStringSize + 1]uint16
physicalAddressLength uint32
physicalAddress [ifMaxPhysAddressLength]byte
permanentPhysicalAddress [ifMaxPhysAddressLength]byte
MTU uint32
Type IfType
TunnelType TunnelType
MediaType NdisMedium
PhysicalMediumType NdisPhysicalMedium
AccessType NetIfAccessType
DirectionType NetIfDirectionType
InterfaceAndOperStatusFlags InterfaceAndOperStatusFlags
OperStatus IfOperStatus
AdminStatus NetIfAdminStatus
MediaConnectState NetIfMediaConnectState
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
InOctets uint64
InUcastPkts uint64
InNUcastPkts uint64
InDiscards uint64
InErrors uint64
InUnknownProtos uint64
InUcastOctets uint64
InMulticastOctets uint64
InBroadcastOctets uint64
OutOctets uint64
OutUcastPkts uint64
OutNUcastPkts uint64
OutDiscards uint64
OutErrors uint64
OutUcastOctets uint64
OutMulticastOctets uint64
OutBroadcastOctets uint64
OutQLen uint64
}
// mibIfTable2 structure contains a table of logical and physical interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_table2
type mibIfTable2 struct {
numEntries uint32
table [anySize]MibIfRow2
}
// MibUnicastIPAddressRow structure stores information about a unicast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_row
type MibUnicastIPAddressRow struct {
Address RawSockaddrInet
InterfaceLUID LUID
InterfaceIndex uint32
PrefixOrigin PrefixOrigin
SuffixOrigin SuffixOrigin
ValidLifetime uint32
PreferredLifetime uint32
OnLinkPrefixLength uint8
SkipAsSource bool
DadState DadState
ScopeID uint32
CreationTimeStamp int64
}
// mibUnicastIPAddressTable structure contains a table of unicast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_table
type mibUnicastIPAddressTable struct {
numEntries uint32
table [anySize]MibUnicastIPAddressRow
}
// MibAnycastIPAddressRow structure stores information about an anycast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_anycastipaddress_row
type MibAnycastIPAddressRow struct {
Address RawSockaddrInet
InterfaceLUID LUID
InterfaceIndex uint32
ScopeID uint32
}
// mibAnycastIPAddressTable structure contains a table of anycast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-mib_anycastipaddress_table
type mibAnycastIPAddressTable struct {
numEntries uint32
table [anySize]MibAnycastIPAddressRow
}
// mibIPforwardTable2 structure contains a table of IP route entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_table2
type mibIPforwardTable2 struct {
numEntries uint32
table [anySize]MibIPforwardRow2
}

View File

@ -0,0 +1,85 @@
// +build windows
package winipcfg
import (
"sync"
"golang.org/x/sys/windows"
)
// UnicastAddressChangeCallback structure allows unicast address change callback handling.
type UnicastAddressChangeCallback struct {
cb func(notificationType MibNotificationType, unicastAddress *MibUnicastIPAddressRow)
wait sync.WaitGroup
}
var (
unicastAddressChangeAddRemoveMutex = sync.Mutex{}
unicastAddressChangeMutex = sync.Mutex{}
unicastAddressChangeCallbacks = make(map[*UnicastAddressChangeCallback]bool)
unicastAddressChangeHandle = windows.Handle(0)
)
// RegisterUnicastAddressChangeCallback registers a new UnicastAddressChangeCallback. If this particular callback is already
// registered, the function will silently return. Returned UnicastAddressChangeCallback.Unregister method should be used
// to unregister.
func RegisterUnicastAddressChangeCallback(callback func(notificationType MibNotificationType, unicastAddress *MibUnicastIPAddressRow)) (*UnicastAddressChangeCallback, error) {
s := &UnicastAddressChangeCallback{cb: callback}
unicastAddressChangeAddRemoveMutex.Lock()
defer unicastAddressChangeAddRemoveMutex.Unlock()
unicastAddressChangeMutex.Lock()
defer unicastAddressChangeMutex.Unlock()
unicastAddressChangeCallbacks[s] = true
if unicastAddressChangeHandle == 0 {
err := notifyUnicastIPAddressChange(windows.AF_UNSPEC, windows.NewCallback(unicastAddressChanged), 0, false, &unicastAddressChangeHandle)
if err != nil {
delete(unicastAddressChangeCallbacks, s)
unicastAddressChangeHandle = 0
return nil, err
}
}
return s, nil
}
// Unregister unregisters the callback.
func (callback *UnicastAddressChangeCallback) Unregister() error {
unicastAddressChangeAddRemoveMutex.Lock()
defer unicastAddressChangeAddRemoveMutex.Unlock()
unicastAddressChangeMutex.Lock()
delete(unicastAddressChangeCallbacks, callback)
removeIt := len(unicastAddressChangeCallbacks) == 0 && unicastAddressChangeHandle != 0
unicastAddressChangeMutex.Unlock()
callback.wait.Wait()
if removeIt {
err := cancelMibChangeNotify2(unicastAddressChangeHandle)
if err != nil {
return err
}
unicastAddressChangeHandle = 0
}
return nil
}
func unicastAddressChanged(callerContext uintptr, row *MibUnicastIPAddressRow, notificationType MibNotificationType) uintptr {
rowCopy := *row
unicastAddressChangeMutex.Lock()
for cb := range unicastAddressChangeCallbacks {
cb.wait.Add(1)
go func(cb *UnicastAddressChangeCallback) {
cb.cb(notificationType, &rowCopy)
cb.wait.Done()
}(cb)
}
unicastAddressChangeMutex.Unlock()
return 0
}

View File

@ -0,0 +1,193 @@
// +build windows
package winipcfg
import (
"runtime"
"unsafe"
"golang.org/x/sys/windows"
)
//
// Common functions
//
//sys freeMibTable(memory unsafe.Pointer) = iphlpapi.FreeMibTable
//
// Interface-related functions
//
//sys initializeIPInterfaceEntry(row *MibIPInterfaceRow) = iphlpapi.InitializeIpInterfaceEntry
//sys getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) = iphlpapi.GetIpInterfaceTable
//sys getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) = iphlpapi.GetIpInterfaceEntry
//sys setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) = iphlpapi.SetIpInterfaceEntry
//sys getIfEntry2(row *MibIfRow2) (ret error) = iphlpapi.GetIfEntry2
//sys getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) = iphlpapi.GetIfTable2Ex
//sys convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
//sys convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceGuidToLuid
//sys convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceIndexToLuid
// GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
func GetAdaptersAddresses(family AddressFamily, flags GAAFlags) ([]*IPAdapterAddresses, error) {
var b []byte
size := uint32(15000)
for {
b = make([]byte, size)
err := windows.GetAdaptersAddresses(uint32(family), uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size)
if err == nil {
break
}
if err != windows.ERROR_BUFFER_OVERFLOW || size <= uint32(len(b)) {
return nil, err
}
}
result := make([]*IPAdapterAddresses, 0, uintptr(size)/unsafe.Sizeof(IPAdapterAddresses{}))
for wtiaa := (*IPAdapterAddresses)(unsafe.Pointer(&b[0])); wtiaa != nil; wtiaa = wtiaa.Next {
result = append(result, wtiaa)
}
return result, nil
}
// GetIPInterfaceTable function retrieves the IP interface entries on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipinterfacetable
func GetIPInterfaceTable(family AddressFamily) ([]MibIPInterfaceRow, error) {
var tab *mibIPInterfaceTable
err := getIPInterfaceTable(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibIPInterfaceRow, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
// GetIfTable2Ex function retrieves the MIB-II interface table.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getiftable2ex
func GetIfTable2Ex(level MibIfEntryLevel) ([]MibIfRow2, error) {
var tab *mibIfTable2
err := getIfTable2Ex(level, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibIfRow2, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Unicast IP address-related functions
//
//sys getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) = iphlpapi.GetUnicastIpAddressTable
//sys initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) = iphlpapi.InitializeUnicastIpAddressEntry
//sys getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.GetUnicastIpAddressEntry
//sys setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.SetUnicastIpAddressEntry
//sys createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.CreateUnicastIpAddressEntry
//sys deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.DeleteUnicastIpAddressEntry
// GetUnicastIPAddressTable function retrieves the unicast IP address table on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddresstable
func GetUnicastIPAddressTable(family AddressFamily) ([]MibUnicastIPAddressRow, error) {
var tab *mibUnicastIPAddressTable
err := getUnicastIPAddressTable(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibUnicastIPAddressRow, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Anycast IP address-related functions
//
//sys getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) = iphlpapi.GetAnycastIpAddressTable
//sys getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.GetAnycastIpAddressEntry
//sys createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.CreateAnycastIpAddressEntry
//sys deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.DeleteAnycastIpAddressEntry
// GetAnycastIPAddressTable function retrieves the anycast IP address table on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getanycastipaddresstable
func GetAnycastIPAddressTable(family AddressFamily) ([]MibAnycastIPAddressRow, error) {
var tab *mibAnycastIPAddressTable
err := getAnycastIPAddressTable(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibAnycastIPAddressRow, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Routing-related functions
//
//sys getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) = iphlpapi.GetIpForwardTable2
//sys initializeIPForwardEntry(route *MibIPforwardRow2) = iphlpapi.InitializeIpForwardEntry
//sys getIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.GetIpForwardEntry2
//sys setIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.SetIpForwardEntry2
//sys createIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.CreateIpForwardEntry2
//sys deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.DeleteIpForwardEntry2
// GetIPForwardTable2 function retrieves the IP route entries on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardtable2
func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) {
var tab *mibIPforwardTable2
err := getIPForwardTable2(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibIPforwardRow2, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Notifications-related functions
//
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyipinterfacechange
//sys notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyIpInterfaceChange
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyunicastipaddresschange
//sys notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyUnicastIpAddressChange
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyroutechange2
//sys notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyRouteChange2
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-cancelmibchangenotify2
//sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2
//
// Undocumented DNS API
//
//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
// The GUID is passed by value, not by reference, which means different
// things on different calling conventions. On amd64, this means it's
// passed by reference anyway, while on arm, arm64, and 386, it's split
// into words.
func setInterfaceDnsSettings(guid windows.GUID, settings *dnsInterfaceSettings) error {
words := (*[4]uintptr)(unsafe.Pointer(&guid))
switch runtime.GOARCH {
case "amd64":
return setInterfaceDnsSettingsByPtr(&guid, settings)
case "arm64":
return setInterfaceDnsSettingsByQwords(words[0], words[1], settings)
case "arm", "386":
return setInterfaceDnsSettingsByDwords(words[0], words[1], words[2], words[3], settings)
default:
panic("unknown calling convention")
}
}

View File

@ -0,0 +1,350 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winipcfg
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
procConvertInterfaceGuidToLuid = modiphlpapi.NewProc("ConvertInterfaceGuidToLuid")
procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
procCreateAnycastIpAddressEntry = modiphlpapi.NewProc("CreateAnycastIpAddressEntry")
procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
procCreateUnicastIpAddressEntry = modiphlpapi.NewProc("CreateUnicastIpAddressEntry")
procDeleteAnycastIpAddressEntry = modiphlpapi.NewProc("DeleteAnycastIpAddressEntry")
procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
procDeleteUnicastIpAddressEntry = modiphlpapi.NewProc("DeleteUnicastIpAddressEntry")
procFreeMibTable = modiphlpapi.NewProc("FreeMibTable")
procGetAnycastIpAddressEntry = modiphlpapi.NewProc("GetAnycastIpAddressEntry")
procGetAnycastIpAddressTable = modiphlpapi.NewProc("GetAnycastIpAddressTable")
procGetIfEntry2 = modiphlpapi.NewProc("GetIfEntry2")
procGetIfTable2Ex = modiphlpapi.NewProc("GetIfTable2Ex")
procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
procGetIpForwardTable2 = modiphlpapi.NewProc("GetIpForwardTable2")
procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry")
procGetIpInterfaceTable = modiphlpapi.NewProc("GetIpInterfaceTable")
procGetUnicastIpAddressEntry = modiphlpapi.NewProc("GetUnicastIpAddressEntry")
procGetUnicastIpAddressTable = modiphlpapi.NewProc("GetUnicastIpAddressTable")
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
procInitializeIpInterfaceEntry = modiphlpapi.NewProc("InitializeIpInterfaceEntry")
procInitializeUnicastIpAddressEntry = modiphlpapi.NewProc("InitializeUnicastIpAddressEntry")
procNotifyIpInterfaceChange = modiphlpapi.NewProc("NotifyIpInterfaceChange")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange")
procSetInterfaceDnsSettings = modiphlpapi.NewProc("SetInterfaceDnsSettings")
procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2")
procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry")
procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry")
)
func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) {
r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceIndexToLuid.Addr(), 2, uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func freeMibTable(memory unsafe.Pointer) {
syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0)
return
}
func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) {
r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIfEntry2(row *MibIfRow2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) {
r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func initializeIPForwardEntry(route *MibIPforwardRow2) {
syscall.Syscall(procInitializeIpForwardEntry.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
return
}
func initializeIPInterfaceEntry(row *MibIPInterfaceRow) {
syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
return
}
func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) {
syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
return
}
func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
var _p0 uint32
if initialNotification {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
var _p0 uint32
if initialNotification {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
var _p0 uint32
if initialNotification {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
r0, _, _ := syscall.Syscall6(procSetInterfaceDnsSettings.Addr(), 5, uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings)))
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}

View File

@ -0,0 +1,49 @@
// +build !load_wintun_from_rsrc
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
const (
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
)
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}

View File

@ -0,0 +1,56 @@
// +build load_wintun_from_rsrc
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"github.com/Dreamacro/clash/listener/tun/dev/wintun/memmod"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module *memmod.Module
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != nil {
return nil
}
const ourModule windows.Handle = 0
resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA)
if err != nil {
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
}
data, err := windows.LoadResourceData(ourModule, resInfo)
if err != nil {
return fmt.Errorf("Unable to load resource: %w", err)
}
module, err := memmod.LoadLibrary(data)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return p.dll.module.ProcAddressByName(p.Name)
}

View File

@ -0,0 +1,54 @@
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
)
func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
return &lazyDLL{Name: name, onLoad: onLoad}
}
func (d *lazyDLL) NewProc(name string) *lazyProc {
return &lazyProc{dll: d, Name: name}
}
type lazyProc struct {
Name string
mu sync.Mutex
dll *lazyDLL
addr uintptr
}
func (p *lazyProc) Find() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
if p.addr != 0 {
return nil
}
err := p.dll.Load()
if err != nil {
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
}
addr, err := p.nameToAddr()
if err != nil {
return fmt.Errorf("Error getting %v address: %w", p.Name, err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
return nil
}
func (p *lazyProc) Addr() uintptr {
err := p.Find()
if err != nil {
panic(err)
}
return p.addr
}

View File

@ -0,0 +1,620 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type addressList struct {
next *addressList
address uintptr
}
func (head *addressList) free() {
for node := head; node != nil; node = node.next {
windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
}
}
type Module struct {
headers *IMAGE_NT_HEADERS
codeBase uintptr
modules []windows.Handle
initialized bool
isDLL bool
isRelocated bool
nameExports map[string]uint16
entry uintptr
blockedMemory *addressList
}
func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
return &module.headers.OptionalHeader.DataDirectory[idx]
}
func (module *Module) copySections(address uintptr, size uintptr, old_headers *IMAGE_NT_HEADERS) error {
sections := module.headers.Sections()
for i := range sections {
if sections[i].SizeOfRawData == 0 {
// Section doesn't contain data in the dll itself, but may define uninitialized data.
sectionSize := old_headers.OptionalHeader.SectionAlignment
if sectionSize == 0 {
continue
}
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sectionSize),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating section: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
dest = module.codeBase + uintptr(sections[i].VirtualAddress)
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
var dst []byte
unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
for j := range dst {
dst[j] = 0
}
continue
}
if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
return errors.New("Incomplete section")
}
// Commit memory block and copy data from dll.
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sections[i].SizeOfRawData),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
memcpy(
module.codeBase+uintptr(sections[i].VirtualAddress),
address+uintptr(sections[i].PointerToRawData),
uintptr(sections[i].SizeOfRawData))
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
}
return nil
}
func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
size := section.SizeOfRawData
if size != 0 {
return uintptr(size)
}
if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
}
if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
}
return 0
}
type sectionFinalizeData struct {
address uintptr
alignedAddress uintptr
size uintptr
characteristics uint32
last bool
}
func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
if sectionData.size == 0 {
return nil
}
if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
// Section is not needed any more and can safely be freed.
if sectionData.address == sectionData.alignedAddress &&
(sectionData.last ||
(sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
// Only allowed to decommit whole pages.
windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
}
return nil
}
// determine protection flags based on characteristics
var ProtectionFlags = [8]uint32{
windows.PAGE_NOACCESS, // not writeable, not readable, not executable
windows.PAGE_EXECUTE, // not writeable, not readable, executable
windows.PAGE_READONLY, // not writeable, readable, not executable
windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
windows.PAGE_WRITECOPY, // writeable, not readable, not executable
windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
windows.PAGE_READWRITE, // writeable, readable, not executable
windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
}
protect := ProtectionFlags[sectionData.characteristics>>29]
if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
protect |= windows.PAGE_NOCACHE
}
// Change memory access flags.
var oldProtect uint32
err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
if err != nil {
return fmt.Errorf("Error protecting memory page: %w", err)
}
return nil
}
func (module *Module) finalizeSections() error {
sections := module.headers.Sections()
imageOffset := module.headers.OptionalHeader.imageOffset()
sectionData := sectionFinalizeData{}
sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionData.size = module.realSectionSize(&sections[0])
sectionData.characteristics = sections[0].Characteristics
// Loop through all sections and change access flags.
for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionSize := module.realSectionSize(&sections[i])
// Combine access flags of all sections that share a page.
// TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
// Section shares page with previous.
if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
} else {
sectionData.characteristics |= sections[i].Characteristics
}
sectionData.size = sectionAddress + sectionSize - sectionData.address
continue
}
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
sectionData.address = sectionAddress
sectionData.alignedAddress = alignedAddress
sectionData.size = sectionSize
sectionData.characteristics = sections[i].Characteristics
}
sectionData.last = true
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
return nil
}
func (module *Module) executeTLS() {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
if directory.VirtualAddress == 0 {
return
}
tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
callback := tls.AddressOfCallbacks
if callback != 0 {
for {
f := *(*uintptr)(a2p(callback))
if f == 0 {
break
}
syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
callback += unsafe.Sizeof(f)
}
}
}
func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
if directory.Size == 0 {
return delta == 0, nil
}
relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for relocationHdr.VirtualAddress > 0 {
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
var relInfos []uint16
unsafeSlice(
unsafe.Pointer(&relInfos),
a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
for _, relInfo := range relInfos {
// The upper 4 bits define the type of relocation.
relType := relInfo >> 12
// The lower 12 bits define the offset.
relOffset := uintptr(relInfo & 0xfff)
switch relType {
case IMAGE_REL_BASED_ABSOLUTE:
// Skip relocation.
case IMAGE_REL_BASED_LOW:
*(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
break
case IMAGE_REL_BASED_HIGH:
*(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
break
case IMAGE_REL_BASED_HIGHLOW:
*(*uint32)(a2p(dest + relOffset)) += uint32(delta)
case IMAGE_REL_BASED_DIR64:
*(*uint64)(a2p(dest + relOffset)) += uint64(delta)
case IMAGE_REL_BASED_THUMB_MOV32:
inst := *(*uint32)(a2p(dest + relOffset))
imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f240 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
}
imm16 += uint32(delta) & 0xffff
hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
*(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
if hiDelta != 0 {
inst = *(*uint32)(a2p(dest + relOffset + 4))
imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f2c0 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
}
imm16 += hiDelta
if imm16 > 0xffff {
return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
}
*(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
}
default:
return false, fmt.Errorf("Unsupported relocation: %v", relType)
}
}
// Advance to next relocation block.
relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
}
return true, nil
}
func (module *Module) buildImportTable() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
if directory.Size == 0 {
return nil
}
module.modules = make([]windows.Handle, 0, 16)
importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for importDesc.Name != 0 {
handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Error loading module: %w", err)
}
var thunkRef, funcRef *uintptr
if importDesc.OriginalFirstThunk() != 0 {
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
} else {
// No hint table.
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
}
for *thunkRef != 0 {
if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
*funcRef, err = windows.GetProcAddressByOrdinal(handle, IMAGE_ORDINAL(*thunkRef))
} else {
thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
*funcRef, err = windows.GetProcAddress(handle, windows.BytePtrToString(&thunkData.Name[0]))
}
if err != nil {
windows.FreeLibrary(handle)
return fmt.Errorf("Error getting function address: %w", err)
}
thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
}
module.modules = append(module.modules, handle)
importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
}
return nil
}
func (module *Module) buildNameExports() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
return errors.New("No functions exported")
}
if exports.NumberOfNames == 0 {
return errors.New("No functions exported by name")
}
var nameRefs []uint32
unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
var ordinals []uint16
unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
module.nameExports = make(map[string]uint16)
for i := range nameRefs {
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
module.nameExports[nameArray] = ordinals[i]
}
return nil
}
// LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0]))
size := uintptr(len(data))
if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
}
dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
}
if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
}
oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
if oldHeader.Signature != IMAGE_NT_SIGNATURE {
return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
}
if oldHeader.FileHeader.Machine != imageFileProcess {
return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
}
if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
return nil, errors.New("Unaligned section")
}
lastSectionEnd := uintptr(0)
sections := oldHeader.Sections()
optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
for i := range sections {
var endOfSection uintptr
if sections[i].SizeOfRawData == 0 {
// Section without data in the DLL
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
} else {
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
}
if endOfSection > lastSectionEnd {
lastSectionEnd = endOfSection
}
}
alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
return nil, errors.New("Section is not page-aligned")
}
module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
defer func() {
if err != nil {
module.Free()
module = nil
}
}()
// Reserve memory for image of library.
// TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
// Try to allocate memory at arbitrary position.
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating code: %w", err)
return
}
}
err = module.check4GBBoundaries(alignedImageSize)
if err != nil {
err = fmt.Errorf("Error reallocating code: %w", err)
return
}
if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
err = errors.New("Incomplete headers")
return
}
// Commit memory for headers.
headers, err := windows.VirtualAlloc(module.codeBase,
uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating headers: %w", err)
return
}
// Copy PE header to code.
memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
// Update position.
module.headers.OptionalHeader.ImageBase = module.codeBase
// Copy sections from DLL file block to new memory location.
err = module.copySections(addr, size, oldHeader)
if err != nil {
err = fmt.Errorf("Error copying sections: %w", err)
return
}
// Adjust base address of imported data.
locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
if locationDelta != 0 {
module.isRelocated, err = module.performBaseRelocation(locationDelta)
if err != nil {
err = fmt.Errorf("Error relocating module: %w", err)
return
}
} else {
module.isRelocated = true
}
// Load required dlls and adjust function table of imports.
err = module.buildImportTable()
if err != nil {
err = fmt.Errorf("Error building import table: %w", err)
return
}
// Mark memory pages depending on section headers and release sections that are marked as "discardable".
err = module.finalizeSections()
if err != nil {
err = fmt.Errorf("Error finalizing sections: %w", err)
return
}
// TLS callbacks are executed BEFORE the main loading.
module.executeTLS()
// Get entry point of loaded module.
if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
if module.isDLL {
// Notify library about attaching to process.
r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
successful := r0 != 0
if !successful {
err = windows.ERROR_DLL_INIT_FAILED
return
}
module.initialized = true
}
}
module.buildNameExports()
return
}
// Free releases module resources and unloads it.
func (module *Module) Free() {
if module.initialized {
// Notify library about detaching from process.
syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
module.initialized = false
}
if module.modules != nil {
// Free previously opened libraries.
for _, handle := range module.modules {
windows.FreeLibrary(handle)
}
module.modules = nil
}
if module.codeBase != 0 {
windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
module.codeBase = 0
}
if module.blockedMemory != nil {
module.blockedMemory.free()
module.blockedMemory = nil
}
}
// ProcAddressByName returns function address by exported name.
func (module *Module) ProcAddressByName(name string) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if module.nameExports == nil {
return 0, errors.New("No functions exported by name")
}
if idx, ok := module.nameExports[name]; ok {
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
return 0, errors.New("Function not found by name")
}
// ProcAddressByOrdinal returns function address by exported ordinal.
func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if uint32(ordinal) < exports.Base {
return 0, errors.New("Ordinal number too low")
}
idx := ordinal - uint16(exports.Base)
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
func alignDown(value, alignment uintptr) uintptr {
return value & ^(alignment - 1)
}
func alignUp(value, alignment uintptr) uintptr {
return (value + alignment - 1) & ^(alignment - 1)
}
func a2p(addr uintptr) unsafe.Pointer {
return unsafe.Pointer(addr)
}
func memcpy(dst, src, size uintptr) {
var d, s []byte
unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
copy(d, s)
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@ -0,0 +1,16 @@
// +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return 0
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
return
}

View File

@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_I386

View File

@ -0,0 +1,36 @@
// +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"fmt"
"golang.org/x/sys/windows"
)
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return uintptr(opthdr.ImageBase & 0xffffffff00000000)
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
node := &addressList{
next: module.blockedMemory,
address: module.codeBase,
}
module.blockedMemory = node
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
}
return
}

View File

@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_AMD64

View File

@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT

View File

@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARM64

View File

@ -0,0 +1,339 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import "unsafe"
const (
IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
IMAGE_OS2_SIGNATURE = 0x454E // NE
IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
IMAGE_VXD_SIGNATURE = 0x454C // LE
IMAGE_NT_SIGNATURE = 0x00004550 // PE00
)
// DOS .EXE header
type IMAGE_DOS_HEADER struct {
E_magic uint16 // Magic number
E_cblp uint16 // Bytes on last page of file
E_cp uint16 // Pages in file
E_crlc uint16 // Relocations
E_cparhdr uint16 // Size of header in paragraphs
E_minalloc uint16 // Minimum extra paragraphs needed
E_maxalloc uint16 // Maximum extra paragraphs needed
E_ss uint16 // Initial (relative) SS value
E_sp uint16 // Initial SP value
E_csum uint16 // Checksum
E_ip uint16 // Initial IP value
E_cs uint16 // Initial (relative) CS value
E_lfarlc uint16 // File address of relocation table
E_ovno uint16 // Overlay number
E_res [4]uint16 // Reserved words
E_oemid uint16 // OEM identifier (for e_oeminfo)
E_oeminfo uint16 // OEM information; e_oemid specific
E_res2 [10]uint16 // Reserved words
E_lfanew int32 // File address of new exe header
}
// File header format
type IMAGE_FILE_HEADER struct {
Machine uint16
NumberOfSections uint16
TimeDateStamp uint32
PointerToSymbolTable uint32
NumberOfSymbols uint32
SizeOfOptionalHeader uint16
Characteristics uint16
}
const (
IMAGE_SIZEOF_FILE_HEADER = 20
IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
IMAGE_FILE_SYSTEM = 0x1000 // System File.
IMAGE_FILE_DLL = 0x2000 // File is a DLL.
IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
IMAGE_FILE_MACHINE_UNKNOWN = 0
IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_AM33 = 0x01d3
IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
IMAGE_FILE_MACHINE_CEF = 0x0CEF
IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
IMAGE_FILE_MACHINE_CEE = 0xC0EE
)
// Directory format
type IMAGE_DATA_DIRECTORY struct {
VirtualAddress uint32
Size uint32
}
const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
type IMAGE_NT_HEADERS struct {
Signature uint32
FileHeader IMAGE_FILE_HEADER
OptionalHeader IMAGE_OPTIONAL_HEADER
}
func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
(uintptr)(unsafe.Pointer(ntheader)) +
unsafe.Offsetof(ntheader.OptionalHeader) +
uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
}
const (
IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
)
const IMAGE_SIZEOF_SHORT_NAME = 8
// Section header format
type IMAGE_SECTION_HEADER struct {
Name [IMAGE_SIZEOF_SHORT_NAME]byte
physicalAddressOrVirtualSize uint32
VirtualAddress uint32
SizeOfRawData uint32
PointerToRawData uint32
PointerToRelocations uint32
PointerToLinenumbers uint32
NumberOfRelocations uint16
NumberOfLinenumbers uint16
Characteristics uint32
}
func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
const (
// Section characteristics.
IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
IMAGE_SCN_MEM_FARDATA = 0x00008000
IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
IMAGE_SCN_MEM_PURGEABLE = 0x00020000
IMAGE_SCN_MEM_16BIT = 0x00020000
IMAGE_SCN_MEM_LOCKED = 0x00040000
IMAGE_SCN_MEM_PRELOAD = 0x00080000
IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
IMAGE_SCN_ALIGN_MASK = 0x00F00000
IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
// TLS Characteristic Flags
IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
)
// Based relocation format
type IMAGE_BASE_RELOCATION struct {
VirtualAddress uint32
SizeOfBlock uint32
}
const (
IMAGE_REL_BASED_ABSOLUTE = 0
IMAGE_REL_BASED_HIGH = 1
IMAGE_REL_BASED_LOW = 2
IMAGE_REL_BASED_HIGHLOW = 3
IMAGE_REL_BASED_HIGHADJ = 4
IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
IMAGE_REL_BASED_RESERVED = 6
IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
IMAGE_REL_BASED_DIR64 = 10
IMAGE_REL_BASED_IA64_IMM64 = 9
IMAGE_REL_BASED_MIPS_JMPADDR = 5
IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
IMAGE_REL_BASED_ARM_MOV32 = 5
IMAGE_REL_BASED_THUMB_MOV32 = 7
)
// Export Format
type IMAGE_EXPORT_DIRECTORY struct {
Characteristics uint32
TimeDateStamp uint32
MajorVersion uint16
MinorVersion uint16
Name uint32
Base uint32
NumberOfFunctions uint32
NumberOfNames uint32
AddressOfFunctions uint32 // RVA from base of image
AddressOfNames uint32 // RVA from base of image
AddressOfNameOrdinals uint32 // RVA from base of image
}
type IMAGE_IMPORT_BY_NAME struct {
Hint uint16
Name [1]byte
}
func IMAGE_ORDINAL(ordinal uintptr) uintptr {
return ordinal & 0xffff
}
func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
return (ordinal & IMAGE_ORDINAL_FLAG) != 0
}
// Thread Local Storage
type IMAGE_TLS_DIRECTORY struct {
StartAddressOfRawData uintptr
EndAddressOfRawData uintptr
AddressOfIndex uintptr // PDWORD
AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
SizeOfZeroFill uint32
Characteristics uint32
}
type IMAGE_IMPORT_DESCRIPTOR struct {
characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
// RVA to original unbound IAT (PIMAGE_THUNK_DATA)
TimeDateStamp uint32 // 0 if not bound,
// -1 if bound, and real date\time stamp
// in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
// O.W. date/time stamp of DLL bound to (Old BIND)
ForwarderChain uint32 // -1 if no forwarders
Name uint32
FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
const (
DLL_PROCESS_ATTACH = 1
DLL_THREAD_ATTACH = 2
DLL_THREAD_DETACH = 3
DLL_PROCESS_DETACH = 0
)
type SYSTEM_INFO struct {
ProcessorArchitecture uint16
Reserved uint16
PageSize uint32
MinimumApplicationAddress uintptr
MaximumApplicationAddress uintptr
ActiveProcessorMask uintptr
NumberOfProcessors uint32
ProcessorType uint32
AllocationGranularity uint32
ProcessorLevel uint16
ProcessorRevision uint16
}

View File

@ -0,0 +1,45 @@
// +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
BaseOfData uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x80000000

View File

@ -0,0 +1,44 @@
// +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000

View File

@ -0,0 +1,103 @@
package wintun
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type Session struct {
handle uintptr
}
const (
PacketSizeMax = 0xffff // Maximum packet size
RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB)
RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB)
)
// Packet with data
type Packet struct {
Next *Packet // Pointer to next packet in queue
Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE)
Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet
}
var (
procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket")
procWintunEndSession = modwintun.NewProc("WintunEndSession")
procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent")
procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket")
procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket")
procWintunSendPacket = modwintun.NewProc("WintunSendPacket")
procWintunStartSession = modwintun.NewProc("WintunStartSession")
)
func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) {
r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0)
if r0 == 0 {
err = e1
} else {
session = Session{r0}
}
return
}
func (session Session) End() {
syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0)
session.handle = 0
}
func (session Session) ReadWaitEvent() (handle windows.Handle) {
r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0)
handle = windows.Handle(r0)
return
}
func (session Session) ReceivePacket() (packet []byte, err error) {
var packetSize uint32
r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0)
if r0 == 0 {
err = e1
return
}
unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
return
}
func (session Session) ReleaseReceivePacket(packet []byte) {
syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
}
func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) {
r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0)
if r0 == 0 {
err = e1
return
}
unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
return
}
func (session Session) SendPacket(packet []byte) {
syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@ -0,0 +1,221 @@
package wintun
import (
"errors"
"runtime"
"syscall"
"unsafe"
"github.com/Dreamacro/clash/log"
"golang.org/x/sys/windows"
)
type loggerLevel int
const (
logInfo loggerLevel = iota
logWarn
logErr
)
const (
PoolNameMax = 256
AdapterNameMax = 128
)
type Pool [PoolNameMax]uint16
type Adapter struct {
handle uintptr
}
var (
modwintun = newLazyDLL("wintun.dll", setupLogger)
procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter")
procWintunDeletePoolDriver = modwintun.NewProc("WintunDeletePoolDriver")
procWintunEnumAdapters = modwintun.NewProc("WintunEnumAdapters")
procWintunFreeAdapter = modwintun.NewProc("WintunFreeAdapter")
procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter")
procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID")
procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName")
procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName")
)
func setupLogger(dll *lazyDLL) {
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int {
var lv log.LogLevel
switch level {
case logInfo:
lv = log.INFO
case logWarn:
lv = log.WARNING
case logErr:
lv = log.ERROR
default:
lv = log.INFO
}
log.PrintLog(lv, "[Wintun] %s", windows.UTF16PtrToString(msg))
return 0
}), 0, 0)
}
func MakePool(poolName string) (pool *Pool, err error) {
poolName16, err := windows.UTF16FromString(poolName)
if err != nil {
return
}
if len(poolName16) > PoolNameMax {
err = errors.New("Pool name too long")
return
}
pool = &Pool{}
copy(pool[:], poolName16)
return
}
func (pool *Pool) String() string {
return windows.UTF16ToString(pool[:])
}
func freeAdapter(wintun *Adapter) {
syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0)
}
// OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or
// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a
// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
// released after use.
func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) {
ifname16, err := windows.UTF16PtrFromString(ifname)
if err != nil {
return nil, err
}
r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0)
if r0 == 0 {
err = e1
return
}
wintun = &Adapter{r0}
runtime.SetFinalizer(wintun, freeAdapter)
return
}
// CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while
// requestedGUID is the GUID of the created network adapter, which then influences NLA generation
// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a
// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it
// uses is completely undocumented, and so there could be minor interesting complications with its
// usage. This function returns the network adapter ID and a flag if reboot is required.
func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) {
var ifname16 *uint16
ifname16, err = windows.UTF16PtrFromString(ifname)
if err != nil {
return
}
var _p0 uint32
r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
rebootRequired = _p0 != 0
if r0 == 0 {
err = e1
return
}
wintun = &Adapter{r0}
runtime.SetFinalizer(wintun, freeAdapter)
return
}
// Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns
// a bool indicating whether a reboot is required.
func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) {
var _p0 uint32
if forceCloseSessions {
_p0 = 1
}
var _p1 uint32
r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1)))
rebootRequired = _p1 != 0
if r1 == 0 {
err = e1
}
return
}
// DeleteMatchingAdapters deletes all Wintun adapters, which match
// given criteria, and returns which ones it deleted, whether a reboot
// is required after, and which errors occurred during the process.
func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) {
cb := func(handle uintptr, _ uintptr) int {
adapter := &Adapter{handle}
if !matches(adapter) {
return 1
}
rebootRequired2, err := adapter.Delete(forceCloseSessions)
if err != nil {
errors = append(errors, err)
return 1
}
rebootRequired = rebootRequired || rebootRequired2
return 1
}
r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
if r1 == 0 {
errors = append(errors, e1)
}
return
}
// Name returns the name of the Wintun adapter.
func (wintun *Adapter) Name() (ifname string, err error) {
var ifname16 [AdapterNameMax]uint16
r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
if r1 == 0 {
err = e1
return
}
ifname = windows.UTF16ToString(ifname16[:])
return
}
// DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other
// pools, also removes Wintun from the driver store, usually called by uninstallers.
func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
var _p0 uint32
r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
rebootRequired = _p0 != 0
if r1 == 0 {
err = e1
}
return
}
// SetName sets name of the Wintun adapter.
func (wintun *Adapter) SetName(ifname string) (err error) {
ifname16, err := windows.UTF16FromString(ifname)
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
if r1 == 0 {
err = e1
}
return
}
// RunningVersion returns the version of the running Wintun driver.
func RunningVersion() (version uint32, err error) {
r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
version = uint32(r0)
if version == 0 {
err = e1
}
return
}
// LUID returns the LUID of the adapter.
func (wintun *Adapter) LUID() (luid uint64) {
syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)
return
}

View File

@ -0,0 +1,263 @@
package gvisor
import (
"encoding/binary"
"errors"
"fmt"
"net"
"strings"
"sync"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/listener/tun/dev"
"github.com/Dreamacro/clash/listener/tun/ipstack"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/socks5"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const nicID tcpip.NICID = 1
type gvisorAdapter struct {
device dev.TunDevice
ipstack *stack.Stack
dnsserver *DNSServer
udpIn chan<- *inbound.PacketAdapter
stackName string
autoRoute bool
linkCache *channel.Endpoint
wg sync.WaitGroup // wait for goroutines to stop
writeHandle *channel.NotificationHandle
}
// GvisorAdapter create GvisorAdapter
func NewAdapter(device dev.TunDevice, conf config.Tun, tunAddress string, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) {
ipstack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
})
adapter := &gvisorAdapter{
device: device,
ipstack: ipstack,
udpIn: udpIn,
stackName: conf.Stack,
autoRoute: conf.AutoRoute,
}
linkEP, err := adapter.AsLinkEndpoint()
if err != nil {
return nil, fmt.Errorf("unable to create virtual endpoint: %v", err)
}
if err := ipstack.CreateNIC(nicID, linkEP); err != nil {
return nil, fmt.Errorf("fail to create NIC in ipstack: %v", err)
}
ipstack.SetPromiscuousMode(nicID, true) // Accept all the traffice from this NIC
ipstack.SetSpoofing(nicID, true) // Otherwise our TCP connection can not find the route backward
// Add route for ipv4 & ipv6
// So FindRoute will return correct route to tun NIC
subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
ipstack.AddRoute(tcpip.Route{Destination: subnet, Gateway: "", NIC: nicID})
subnet, _ = tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 6)), tcpip.AddressMask(strings.Repeat("\x00", 6)))
ipstack.AddRoute(tcpip.Route{Destination: subnet, Gateway: "", NIC: nicID})
// TCP handler
// maximum number of half-open tcp connection set to 1024
// receive buffer size set to 20k
tcpFwd := tcp.NewForwarder(ipstack, 20*1024, 1024, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq)
if err != nil {
log.Warnln("Can't create TCP Endpoint in ipstack: %v", err)
r.Complete(true)
return
}
r.Complete(false)
conn := gonet.NewTCPConn(&wq, ep)
// if the endpoint is not in connected state, conn.RemoteAddr() will return nil
// this protection may be not enough, but will help us debug the panic
if conn.RemoteAddr() == nil {
log.Warnln("TCP endpoint is not connected, current state: %v", tcp.EndpointState(ep.State()))
conn.Close()
return
}
target := getAddr(ep.Info().(*stack.TransportEndpointInfo).ID)
tcpIn <- inbound.NewSocket(target, conn, C.TUN)
})
ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
// UDP handler
ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, adapter.udpHandlePacket)
if resolver.DefaultResolver != nil {
err = adapter.ReCreateDNSServer(resolver.DefaultResolver.(*dns.Resolver), resolver.DefaultHostMapper.(*dns.ResolverEnhancer), conf.DNSListen)
if err != nil {
return nil, err
}
}
return adapter, nil
}
func (t *gvisorAdapter) Stack() string {
return t.stackName
}
func (t *gvisorAdapter) AutoRoute() bool {
return t.autoRoute
}
// Close close the TunAdapter
func (t *gvisorAdapter) Close() {
if t.dnsserver != nil {
t.dnsserver.Stop()
}
if t.ipstack != nil {
t.ipstack.Close()
}
if t.device != nil {
_ = t.device.Close()
}
}
func (t *gvisorAdapter) udpHandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
// ref: gvisor pkg/tcpip/transport/udp/endpoint.go HandlePacket
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
// Malformed packet.
t.ipstack.Stats().UDP.MalformedPacketsReceived.Increment()
return true
}
target := getAddr(id)
packet := &fakeConn{
id: id,
pkt: pkt,
s: t.ipstack,
payload: pkt.Data().AsRange().ToOwnedView(),
}
select {
case t.udpIn <- inbound.NewPacket(target, packet, C.TUN):
default:
}
return true
}
// Wait wait goroutines to exit
func (t *gvisorAdapter) Wait() {
t.wg.Wait()
}
func (t *gvisorAdapter) AsLinkEndpoint() (result stack.LinkEndpoint, err error) {
if t.linkCache != nil {
return t.linkCache, nil
}
mtu, err := t.device.MTU()
if err != nil {
return nil, errors.New("unable to get device mtu")
}
linkEP := channel.New(512, uint32(mtu), "")
// start Read loop. read ip packet from tun and write it to ipstack
t.wg.Add(1)
go func() {
for !t.device.IsClose() {
packet := make([]byte, mtu)
n, err := t.device.Read(packet)
if err != nil && !t.device.IsClose() {
log.Errorln("can not read from tun: %v", err)
}
var p tcpip.NetworkProtocolNumber
switch header.IPVersion(packet) {
case header.IPv4Version:
p = header.IPv4ProtocolNumber
case header.IPv6Version:
p = header.IPv6ProtocolNumber
}
if linkEP.IsAttached() {
linkEP.InjectInbound(p, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(packet[:n]).ToVectorisedView(),
}))
} else {
log.Debugln("received packet from tun when %s is not attached to any dispatcher.", t.device.Name())
}
}
t.wg.Done()
t.Close()
log.Debugln("%v stop read loop", t.device.Name())
}()
// start write notification
t.writeHandle = linkEP.AddNotify(t)
t.linkCache = linkEP
return t.linkCache, nil
}
// WriteNotify implements channel.Notification.WriteNotify.
func (t *gvisorAdapter) WriteNotify() {
packet, ok := t.linkCache.Read()
if ok {
var vv buffer.VectorisedView
// Append upper headers.
vv.AppendView(packet.Pkt.NetworkHeader().View())
vv.AppendView(packet.Pkt.TransportHeader().View())
// Append data payload.
vv.Append(packet.Pkt.Data().ExtractVV())
_, err := t.device.Write(vv.ToView())
if err != nil && !t.device.IsClose() {
log.Errorln("can not write to tun: %v", err)
}
}
}
func getAddr(id stack.TransportEndpointID) socks5.Addr {
ipv4 := id.LocalAddress.To4()
// get the big-endian binary represent of port
port := make([]byte, 2)
binary.BigEndian.PutUint16(port, id.LocalPort)
if ipv4 != "" {
addr := make([]byte, 1+net.IPv4len+2)
addr[0] = socks5.AtypIPv4
copy(addr[1:1+net.IPv4len], []byte(ipv4))
addr[1+net.IPv4len], addr[1+net.IPv4len+1] = port[0], port[1]
return addr
} else {
addr := make([]byte, 1+net.IPv6len+2)
addr[0] = socks5.AtypIPv6
copy(addr[1:1+net.IPv6len], []byte(id.LocalAddress))
addr[1+net.IPv6len], addr[1+net.IPv6len+1] = port[0], port[1]
return addr
}
}

View File

@ -0,0 +1,280 @@
package gvisor
import (
"fmt"
"net"
"github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
var (
ipv4Zero = tcpip.Address(net.IPv4zero.To4())
ipv6Zero = tcpip.Address(net.IPv6zero.To16())
)
// DNSServer is DNS Server listening on tun devcice
type DNSServer struct {
*dns.Server
resolver *dns.Resolver
stack *stack.Stack
tcpListener net.Listener
udpEndpoint *dnsEndpoint
udpEndpointID *stack.TransportEndpointID
tcpip.NICID
}
// dnsEndpoint is a TransportEndpoint that will register to stack
type dnsEndpoint struct {
stack.TransportEndpoint
stack *stack.Stack
uniqueID uint64
server *dns.Server
}
// Keep track of the source of DNS request
type dnsResponseWriter struct {
s *stack.Stack
pkt *stack.PacketBuffer // The request packet
id stack.TransportEndpointID
}
func (e *dnsEndpoint) UniqueID() uint64 {
return e.uniqueID
}
func (e *dnsEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return
}
// server DNS
var msg D.Msg
msg.Unpack(pkt.Data().AsRange().ToOwnedView())
writer := dnsResponseWriter{s: e.stack, pkt: pkt, id: id}
go e.server.ServeDNS(&writer, &msg)
}
func (e *dnsEndpoint) Close() {
}
func (e *dnsEndpoint) Wait() {
}
func (e *dnsEndpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) {
log.Warnln("DNS endpoint get a transport error: %v", transErr)
log.Debugln("DNS endpoint transport error packet : %v", pkt)
}
// Abort implements stack.TransportEndpoint.Abort.
func (e *dnsEndpoint) Abort() {
e.Close()
}
func (w *dnsResponseWriter) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.IP(w.id.LocalAddress), Port: int(w.id.LocalPort)}
}
func (w *dnsResponseWriter) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.IP(w.id.RemoteAddress), Port: int(w.id.RemotePort)}
}
func (w *dnsResponseWriter) WriteMsg(msg *D.Msg) error {
b, err := msg.Pack()
if err != nil {
return err
}
_, err = w.Write(b)
return err
}
func (w *dnsResponseWriter) TsigStatus() error {
// Unsupported
return nil
}
func (w *dnsResponseWriter) TsigTimersOnly(bool) {
// Unsupported
}
func (w *dnsResponseWriter) Hijack() {
// Unsupported
}
func (w *dnsResponseWriter) Write(b []byte) (int, error) {
v := buffer.NewView(len(b))
copy(v, b)
data := v.ToVectorisedView()
// w.id.LocalAddress is the source ip of DNS response
r, _ := w.s.FindRoute(w.pkt.NICID, w.id.LocalAddress, w.id.RemoteAddress, w.pkt.NetworkProtocolNumber, false /* multicastLoop */)
return writeUDP(r, data, w.id.LocalPort, w.id.RemotePort)
}
func (w *dnsResponseWriter) Close() error {
return nil
}
// CreateDNSServer create a dns server on given netstack
func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, ip net.IP, port int, nicID tcpip.NICID) (*DNSServer, error) {
var v4 bool
var err error
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
if ip.To4() != nil {
v4 = true
address.Addr = tcpip.Address(ip.To4())
// netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints
s.AddAddress(nicID, ipv4.ProtocolNumber, address.Addr)
} else {
v4 = false
address.Addr = tcpip.Address(ip.To16())
s.AddAddress(nicID, ipv6.ProtocolNumber, address.Addr)
}
if address.Addr == ipv4Zero || address.Addr == ipv6Zero {
address.Addr = ""
}
handler := dns.NewHandler(resolver, mapper)
serverIn := &dns.Server{}
serverIn.SetHandler(handler)
// UDP DNS
id := &stack.TransportEndpointID{
LocalAddress: address.Addr,
LocalPort: uint16(port),
RemotePort: 0,
RemoteAddress: "",
}
// TransportEndpoint for DNS
endpoint := &dnsEndpoint{
stack: s,
uniqueID: s.UniqueID(),
server: serverIn,
}
if tcpiperr := s.RegisterTransportEndpoint(
[]tcpip.NetworkProtocolNumber{
ipv4.ProtocolNumber,
ipv6.ProtocolNumber,
},
udp.ProtocolNumber,
*id,
endpoint,
ports.Flags{LoadBalanced: true}, // it's actually the SO_REUSEPORT. Not sure it take effect.
nicID); err != nil {
log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String())
}
// TCP DNS
var tcpListener net.Listener
if v4 {
tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber)
} else {
tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber)
}
if err != nil {
return nil, fmt.Errorf("can not listen on tun: %v", err)
}
server := &DNSServer{
Server: serverIn,
resolver: resolver,
stack: s,
tcpListener: tcpListener,
udpEndpoint: endpoint,
udpEndpointID: id,
NICID: nicID,
}
server.SetHandler(handler)
server.Server.Server = &D.Server{Listener: tcpListener, Handler: server}
go func() {
server.ActivateAndServe()
}()
return server, err
}
// Stop stop the DNS Server on tun
func (s *DNSServer) Stop() {
// shutdown TCP DNS Server
s.Server.Shutdown()
// remove TCP endpoint from stack
if s.Listener != nil {
s.Listener.Close()
}
// remove udp endpoint from stack
s.stack.UnregisterTransportEndpoint(
[]tcpip.NetworkProtocolNumber{
ipv4.ProtocolNumber,
ipv6.ProtocolNumber,
},
udp.ProtocolNumber,
*s.udpEndpointID,
s.udpEndpoint,
ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint
s.NICID)
}
// DNSListen return the listening address of DNS Server
func (t *gvisorAdapter) DNSListen() string {
if t.dnsserver != nil {
id := t.dnsserver.udpEndpointID
return fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
}
return ""
}
// Stop stop the DNS Server on tun
func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, addr string) error {
if addr == "" && t.dnsserver == nil {
return nil
}
if addr == t.DNSListen() && t.dnsserver != nil && t.dnsserver.resolver == resolver {
return nil
}
if t.dnsserver != nil {
t.dnsserver.Stop()
t.dnsserver = nil
log.Debugln("tun DNS server stoped")
}
var err error
_, port, err := net.SplitHostPort(addr)
if port == "0" || port == "" || err != nil {
return nil
}
if resolver == nil {
return fmt.Errorf("failed to create DNS server on tun: resolver not provided")
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
server, err := CreateDNSServer(t.ipstack, resolver, mapper, udpAddr.IP, udpAddr.Port, nicID)
if err != nil {
return err
}
t.dnsserver = server
log.Infoln("Tun DNS server listening at: %s, fake ip enabled: %v", addr, mapper.FakeIPEnabled())
return nil
}

View File

@ -0,0 +1,109 @@
package gvisor
import (
"fmt"
"net"
"github.com/Dreamacro/clash/component/resolver"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type fakeConn struct {
id stack.TransportEndpointID // The endpoint of incomming packet, it's remote address is the source address it sent from
pkt *stack.PacketBuffer // The original packet comming from tun
s *stack.Stack
payload []byte
fakeip *bool
}
func (c *fakeConn) Data() []byte {
return c.payload
}
func (c *fakeConn) WriteBack(b []byte, addr net.Addr) (n int, err error) {
v := buffer.View(b)
data := v.ToVectorisedView()
var localAddress tcpip.Address
var localPort uint16
// if addr is not provided, write back use original dst Addr as src Addr
if c.FakeIP() || addr == nil {
localAddress = c.id.LocalAddress
localPort = c.id.LocalPort
} else {
udpaddr, _ := addr.(*net.UDPAddr)
localAddress = tcpip.Address(udpaddr.IP)
localPort = uint16(udpaddr.Port)
}
r, _ := c.s.FindRoute(c.pkt.NICID, localAddress, c.id.RemoteAddress, c.pkt.NetworkProtocolNumber, false /* multicastLoop */)
return writeUDP(r, data, localPort, c.id.RemotePort)
}
func (c *fakeConn) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.IP(c.id.RemoteAddress), Port: int(c.id.RemotePort)}
}
func (c *fakeConn) Close() error {
return nil
}
func (c *fakeConn) Drop() {
}
func (c *fakeConn) FakeIP() bool {
if c.fakeip != nil {
return *c.fakeip
}
fakeip := resolver.IsFakeIP(net.IP(c.id.LocalAddress.To4()))
c.fakeip = &fakeip
return fakeip
}
func writeUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16) (int, error) {
const protocol = udp.ProtocolNumber
// Allocate a buffer for the UDP header.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
Data: data,
})
// Initialize the header.
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
Length: length,
})
// Set the checksum field unless TX checksum offload is enabled.
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
if r.RequiresTXTransportChecksum() {
xsum := r.PseudoHeaderChecksum(protocol, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
}
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
ttl := r.DefaultTTL()
if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: protocol, TTL: ttl, TOS: 0 /* default */}, pkt); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return 0, fmt.Errorf("%v", err)
}
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
return data.Size(), nil
}

View File

@ -0,0 +1,9 @@
package ipstack
// TunAdapter hold the state of tun/tap interface
type TunAdapter interface {
Close()
Stack() string
DNSListen() string
AutoRoute() bool
}

View File

@ -0,0 +1,101 @@
package system
import (
"encoding/binary"
"io"
"net"
"time"
"github.com/Dreamacro/clash/component/resolver"
D "github.com/miekg/dns"
"github.com/kr328/tun2socket/binding"
"github.com/kr328/tun2socket/redirect"
)
const defaultDnsReadTimeout = time.Second * 30
func shouldHijackDns(dnsAddr binding.Address, targetAddr binding.Address) bool {
if targetAddr.Port != 53 {
return false
}
return dnsAddr.IP.Equal(net.IPv4zero) || dnsAddr.IP.Equal(targetAddr.IP)
}
func hijackUDPDns(pkt []byte, ep *binding.Endpoint, sender redirect.UDPSender) {
go func() {
answer, err := relayDnsPacket(pkt)
if err != nil {
return
}
_ = sender(answer, &binding.Endpoint{
Source: ep.Target,
Target: ep.Source,
})
}()
}
func hijackTCPDns(conn net.Conn) {
go func() {
defer conn.Close()
for {
if err := conn.SetReadDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
return
}
var length uint16
if binary.Read(conn, binary.BigEndian, &length) != nil {
return
}
data := make([]byte, length)
_, err := io.ReadFull(conn, data)
if err != nil {
return
}
rb, err := relayDnsPacket(data)
if err != nil {
continue
}
if binary.Write(conn, binary.BigEndian, uint16(len(rb))) != nil {
return
}
if _, err := conn.Write(rb); err != nil {
return
}
}
}()
}
func relayDnsPacket(payload []byte) ([]byte, error) {
msg := &D.Msg{}
if err := msg.Unpack(payload); err != nil {
return nil, err
}
r, err := resolver.ServeMsg(msg)
if err != nil {
return nil, err
}
for _, ans := range r.Answer {
header := ans.Header()
if header.Class == D.ClassINET && (header.Rrtype == D.TypeA || header.Rrtype == D.TypeAAAA) {
header.Ttl = 1
}
}
r.SetRcode(msg, r.Rcode)
r.Compress = true
return r.Pack()
}

View File

@ -0,0 +1,21 @@
package system
import "github.com/Dreamacro/clash/log"
type logger struct{}
func (l *logger) D(format string, args ...interface{}) {
log.Debugln("[TUN] "+format, args...)
}
func (l *logger) I(format string, args ...interface{}) {
log.Infoln("[TUN] "+format, args...)
}
func (l *logger) W(format string, args ...interface{}) {
log.Warnln("[TUN] "+format, args...)
}
func (l *logger) E(format string, args ...interface{}) {
log.Errorln("[TUN] "+format, args...)
}

View File

@ -0,0 +1,37 @@
package system
import (
"net"
"strconv"
"github.com/kr328/tun2socket/binding"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
)
func handleTCP(conn net.Conn, endpoint *binding.Endpoint, tcpIn chan<- C.ConnContext) {
src := &net.TCPAddr{
IP: endpoint.Source.IP,
Port: int(endpoint.Source.Port),
Zone: "",
}
dst := &net.TCPAddr{
IP: endpoint.Target.IP,
Port: int(endpoint.Target.Port),
Zone: "",
}
metadata := &C.Metadata{
NetWork: C.TCP,
Type: C.TUN,
SrcIP: src.IP,
DstIP: dst.IP,
SrcPort: strconv.Itoa(src.Port),
DstPort: strconv.Itoa(dst.Port),
AddrType: C.AtypIPv4,
Host: "",
}
tcpIn <- context.NewConnContext(conn, metadata)
}

View File

@ -0,0 +1,125 @@
package system
import (
"net"
"strconv"
"sync"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/listener/tun/dev"
"github.com/Dreamacro/clash/listener/tun/ipstack"
"github.com/Dreamacro/clash/log"
"github.com/kr328/tun2socket"
"github.com/kr328/tun2socket/binding"
"github.com/kr328/tun2socket/redirect"
)
type systemAdapter struct {
device dev.TunDevice
tun *tun2socket.Tun2Socket
lock sync.Mutex
stackName string
dnsListen string
autoRoute bool
}
func NewAdapter(device dev.TunDevice, conf config.Tun, mtu int, gateway, mirror string, onStop func(), tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) {
adapter := &systemAdapter{
device: device,
stackName: conf.Stack,
dnsListen: conf.DNSListen,
autoRoute: conf.AutoRoute,
}
adapter.lock.Lock()
defer adapter.lock.Unlock()
//adapter.stopLocked()
dnsHost, dnsPort, err := net.SplitHostPort(conf.DNSListen)
if err != nil {
return nil, err
}
dnsP, err := strconv.Atoi(dnsPort)
if err != nil {
return nil, err
}
dnsAddr := binding.Address{
IP: net.ParseIP(dnsHost),
Port: uint16(dnsP),
}
t := tun2socket.NewTun2Socket(device, mtu, net.ParseIP(gateway), net.ParseIP(mirror))
t.SetAllocator(allocUDP)
t.SetClosedHandler(onStop)
t.SetLogger(&logger{})
t.SetTCPHandler(func(conn net.Conn, endpoint *binding.Endpoint) {
if shouldHijackDns(dnsAddr, endpoint.Target) {
hijackTCPDns(conn)
if log.Level() == log.DEBUG {
log.Debugln("[TUN] hijack dns tcp: %s:%d", endpoint.Target.IP.String(), endpoint.Target.Port)
}
return
}
handleTCP(conn, endpoint, tcpIn)
})
t.SetUDPHandler(func(payload []byte, endpoint *binding.Endpoint, sender redirect.UDPSender) {
if shouldHijackDns(dnsAddr, endpoint.Target) {
hijackUDPDns(payload, endpoint, sender)
if log.Level() == log.DEBUG {
log.Debugln("[TUN] hijack dns udp: %s:%d", endpoint.Target.IP.String(), endpoint.Target.Port)
}
return
}
handleUDP(payload, endpoint, sender, udpIn)
})
t.Start()
adapter.tun = t
return adapter, nil
}
func (t *systemAdapter) Stack() string {
return t.stackName
}
func (t *systemAdapter) AutoRoute() bool {
return t.autoRoute
}
func (t *systemAdapter) DNSListen() string {
return t.dnsListen
}
func (t *systemAdapter) Close() {
t.lock.Lock()
defer t.lock.Unlock()
t.stopLocked()
}
func (t *systemAdapter) stopLocked() {
if t.tun != nil {
t.tun.Close()
}
if t.device != nil {
_ = t.device.Close()
}
t.tun = nil
t.device = nil
}

View File

@ -0,0 +1,74 @@
package system
import (
"io"
"net"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
"github.com/kr328/tun2socket/binding"
"github.com/kr328/tun2socket/redirect"
)
type udpPacket struct {
source binding.Address
data []byte
send redirect.UDPSender
}
func (u *udpPacket) Data() []byte {
return u.data
}
func (u *udpPacket) WriteBack(b []byte, addr net.Addr) (n int, err error) {
uAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, io.ErrClosedPipe
}
return len(b), u.send(b, &binding.Endpoint{
Source: binding.Address{IP: uAddr.IP, Port: uint16(uAddr.Port)},
Target: u.source,
})
}
func (u *udpPacket) Drop() {
recycleUDP(u.data)
}
func (u *udpPacket) LocalAddr() net.Addr {
return &net.UDPAddr{
IP: u.source.IP,
Port: int(u.source.Port),
Zone: "",
}
}
func handleUDP(payload []byte, endpoint *binding.Endpoint, sender redirect.UDPSender, udpIn chan<- *inbound.PacketAdapter) {
pkt := &udpPacket{
source: endpoint.Source,
data: payload,
send: sender,
}
rAddr := &net.UDPAddr{
IP: endpoint.Target.IP,
Port: int(endpoint.Target.Port),
Zone: "",
}
select {
case udpIn <- inbound.NewPacket(socks5.ParseAddrToSocksAddr(rAddr), pkt, C.TUN):
default:
}
}
func allocUDP(size int) []byte {
return pool.Get(size)
}
func recycleUDP(payload []byte) {
_ = pool.Put(payload)
}

View File

@ -0,0 +1,51 @@
package tun
import (
"errors"
"fmt"
"strings"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/listener/tun/dev"
"github.com/Dreamacro/clash/listener/tun/ipstack"
"github.com/Dreamacro/clash/listener/tun/ipstack/gvisor"
"github.com/Dreamacro/clash/listener/tun/ipstack/system"
"github.com/Dreamacro/clash/log"
)
// New create TunAdapter
func New(conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) {
tunAddress := "198.18.0.1"
autoRoute := conf.AutoRoute
stack := conf.Stack
var tunAdapter ipstack.TunAdapter
device, err := dev.OpenTunDevice(tunAddress, autoRoute)
if err != nil {
return nil, fmt.Errorf("can't open tun: %v", err)
}
mtu, err := device.MTU()
if err != nil {
_ = device.Close()
return nil, errors.New("unable to get device mtu")
}
if strings.EqualFold(stack, "system") {
tunAdapter, err = system.NewAdapter(device, conf, mtu, tunAddress, tunAddress, func() {}, tcpIn, udpIn)
} else if strings.EqualFold(stack, "gvisor") {
tunAdapter, err = gvisor.NewAdapter(device, conf, tunAddress, tcpIn, udpIn)
} else {
err = fmt.Errorf("can not support tun ip stack: %s, only support \"system\" and \"gvisor\"", stack)
}
if err != nil {
_ = device.Close()
return nil, err
}
log.Infoln("Tun adapter listening at: %s(%s), mtu: %d, auto route: %v, ip stack: %s", device.Name(), tunAddress, mtu, autoRoute, stack)
return tunAdapter, nil
}