mirror of
https://github.com/KusakabeShi/EtherGuard-VPN.git
synced 2024-11-22 15:23:08 +01:00
602 lines
15 KiB
Go
602 lines
15 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package conn
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"unsafe"
|
|
|
|
"golang.org/x/sys/windows"
|
|
|
|
"github.com/KusakabeSi/EtherGuard-VPN/conn/winrio"
|
|
)
|
|
|
|
// packetsPerRing is the number of slots in every RIO ring; it also sizes the
// completion queues and the request queue's outstanding-request limits.
const packetsPerRing = 1024

// bytesPerPacket is the payload capacity of one ring slot. A ringPacket is a
// 32-byte WinRingEndpoint followed by the payload, so each slot is 2048 bytes.
const bytesPerPacket = 2048 - 32

// receiveSpins is how many times Receive polls the completion queue before
// falling back to a blocking wait on the IOCP.
const receiveSpins = 15
|
|
|
|
// ringPacket is one slot of a ringBuffer: the remote address that RIO reads
// (on send) or fills in (on receive), followed by the datagram payload.
// With the 32-byte WinRingEndpoint, each slot is exactly 2048 bytes, which
// is why bytesPerPacket is 2048 - 32.
type ringPacket struct {
	addr WinRingEndpoint
	data [bytesPerPacket]byte
}
|
|
|
|
// ringBuffer is a fixed ring of packetsPerRing ringPacket slots living in a
// single VirtualAlloc'd region that is registered with RIO, together with the
// completion queue and IOCP used to learn when requests on those slots finish.
type ringBuffer struct {
	// packets is the base address of the slot array (allocated with
	// VirtualAlloc, i.e. outside the Go heap).
	packets uintptr
	// head and tail are free-running counters reduced modulo packetsPerRing.
	// Push advances tail and Return advances head. uint32 wrap-around is
	// harmless because 2^32 is a multiple of packetsPerRing (1024).
	head, tail uint32
	// id is the RIO registration of the packets region; request buffers are
	// expressed as (id, offset from packets).
	id winrio.BufferId
	// iocp is what Receive/Send block on after calling winrio.Notify on cq.
	iocp windows.Handle
	// isFull disambiguates head%packetsPerRing == tail%packetsPerRing, which
	// otherwise could mean either "empty" or "every slot in use".
	isFull bool
	// cq is the RIO completion queue for requests on this ring.
	cq winrio.Cq
	// mu serialises Receive (rx ring) or Send (tx ring) use of this ring.
	mu sync.Mutex
	// overlapped is passed by address to CreateIOCPCompletionQueue; presumably
	// it must stay valid for the lifetime of cq — TODO confirm.
	overlapped windows.Overlapped
}
|
|
|
|
func (rb *ringBuffer) Push() *ringPacket {
|
|
for rb.isFull {
|
|
panic("ring is full")
|
|
}
|
|
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
|
|
rb.tail += 1
|
|
if rb.tail%packetsPerRing == rb.head%packetsPerRing {
|
|
rb.isFull = true
|
|
}
|
|
return ret
|
|
}
|
|
|
|
func (rb *ringBuffer) Return(count uint32) {
|
|
if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
|
|
return
|
|
}
|
|
rb.head += count
|
|
rb.isFull = false
|
|
}
|
|
|
|
// afWinRingBind is the per-address-family half of a WinRingBind: one UDP
// socket created through winrio, with its receive and transmit rings.
type afWinRingBind struct {
	// sock is the UDP socket created by winrio.Socket.
	sock windows.Handle
	// rx holds receive slots; tx holds transmit slots.
	rx, tx ringBuffer
	// rq is the RIO request queue for sock, completing into rx.cq and tx.cq.
	rq winrio.Rq
	// mu is held around every winrio.ReceiveEx / winrio.SendEx call on rq.
	mu sync.Mutex
	// blackhole, when true, makes WinRingBind.Send silently drop packets
	// destined for this address family.
	blackhole bool
}
|
|
|
|
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
// It owns one IPv4 and one IPv6 socket, each with its own RIO rings.
type WinRingBind struct {
	v4, v6 afWinRingBind
	// mu is taken exclusively by Open and by the final teardown in Close, and
	// shared (read-locked) by the receive, send and interface-binding paths.
	mu sync.RWMutex
	// isOpen is accessed atomically: 0 = closed, 1 = open, 2 = Close in
	// progress (set so blocked receivers and senders bail out).
	isOpen uint32
}
|
|
|
|
// NewDefaultBind returns the Bind produced by NewWinRingBind.
func NewDefaultBind() Bind {
	return NewWinRingBind()
}
|
|
|
|
func NewWinRingBind() Bind {
|
|
if !winrio.Initialize() {
|
|
return NewStdNetBind()
|
|
}
|
|
return new(WinRingBind)
|
|
}
|
|
|
|
// WinRingEndpoint is a raw Windows socket address, 32 bytes in total: either
// the bytes GetAddrInfoW returned (ParseEndpoint) or the address RIO filled
// in for a received datagram. family is the address family. Within data,
// bytes 0-1 hold the port in network byte order and bytes 2-5 the IPv4
// address (AF_INET). For AF_INET6, bytes 6-21 hold the IPv6 address and
// bytes 22-25 the scope ID.
type WinRingEndpoint struct {
	family uint16
	data   [30]byte
}
|
|
|
|
var _ Bind = (*WinRingBind)(nil)
|
|
var _ Endpoint = (*WinRingEndpoint)(nil)
|
|
|
|
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
|
|
host, port, err := net.SplitHostPort(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
host16, err := windows.UTF16PtrFromString(host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
port16, err := windows.UTF16PtrFromString(port)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
hints := windows.AddrinfoW{
|
|
Flags: windows.AI_NUMERICHOST,
|
|
Family: windows.AF_UNSPEC,
|
|
Socktype: windows.SOCK_DGRAM,
|
|
Protocol: windows.IPPROTO_UDP,
|
|
}
|
|
var addrinfo *windows.AddrinfoW
|
|
err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer windows.FreeAddrInfoW(addrinfo)
|
|
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
|
|
return nil, windows.ERROR_INVALID_ADDRESS
|
|
}
|
|
var src []byte
|
|
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
|
|
unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen))
|
|
copy(dst[:], src)
|
|
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
|
|
}
|
|
|
|
// ClearSrc does nothing: a WinRingEndpoint never records a source address.
func (e *WinRingEndpoint) ClearSrc() {
}
|
|
|
|
func (e *WinRingEndpoint) DstIP() net.IP {
|
|
switch e.family {
|
|
case windows.AF_INET:
|
|
return append([]byte{}, e.data[2:6]...)
|
|
case windows.AF_INET6:
|
|
return append([]byte{}, e.data[6:22]...)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *WinRingEndpoint) SrcIP() net.IP {
|
|
return nil // not supported
|
|
}
|
|
|
|
func (e *WinRingEndpoint) DstToBytes() []byte {
|
|
switch e.family {
|
|
case windows.AF_INET:
|
|
b := make([]byte, 0, 6)
|
|
b = append(b, e.data[2:6]...)
|
|
b = append(b, e.data[1], e.data[0])
|
|
return b
|
|
case windows.AF_INET6:
|
|
b := make([]byte, 0, 18)
|
|
b = append(b, e.data[6:22]...)
|
|
b = append(b, e.data[1], e.data[0])
|
|
return b
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *WinRingEndpoint) DstToString() string {
|
|
switch e.family {
|
|
case windows.AF_INET:
|
|
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
|
|
return addr.String()
|
|
case windows.AF_INET6:
|
|
var zone string
|
|
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
|
zone = strconv.FormatUint(uint64(scope), 10)
|
|
}
|
|
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
|
|
return addr.String()
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (e *WinRingEndpoint) SrcToString() string {
|
|
return ""
|
|
}
|
|
|
|
func (ring *ringBuffer) CloseAndZero() {
|
|
if ring.cq != 0 {
|
|
winrio.CloseCompletionQueue(ring.cq)
|
|
ring.cq = 0
|
|
}
|
|
if ring.iocp != 0 {
|
|
windows.CloseHandle(ring.iocp)
|
|
ring.iocp = 0
|
|
}
|
|
if ring.id != 0 {
|
|
winrio.DeregisterBuffer(ring.id)
|
|
ring.id = 0
|
|
}
|
|
if ring.packets != 0 {
|
|
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
|
|
ring.packets = 0
|
|
}
|
|
ring.head = 0
|
|
ring.tail = 0
|
|
ring.isFull = false
|
|
}
|
|
|
|
func (bind *afWinRingBind) CloseAndZero() {
|
|
bind.rx.CloseAndZero()
|
|
bind.tx.CloseAndZero()
|
|
if bind.sock != 0 {
|
|
windows.CloseHandle(bind.sock)
|
|
bind.sock = 0
|
|
}
|
|
bind.blackhole = false
|
|
}
|
|
|
|
func (bind *WinRingBind) closeAndZero() {
|
|
atomic.StoreUint32(&bind.isOpen, 0)
|
|
bind.v4.CloseAndZero()
|
|
bind.v6.CloseAndZero()
|
|
}
|
|
|
|
func (ring *ringBuffer) Open() error {
|
|
var err error
|
|
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
|
|
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
|
|
var err error
|
|
bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = bind.rx.Open()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = bind.tx.Open()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = windows.Bind(bind.sock, sa)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sa, err = windows.Getsockname(bind.sock)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return sa, nil
|
|
}
|
|
|
|
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
|
|
bind.mu.Lock()
|
|
defer bind.mu.Unlock()
|
|
defer func() {
|
|
if err != nil {
|
|
bind.closeAndZero()
|
|
}
|
|
}()
|
|
if atomic.LoadUint32(&bind.isOpen) != 0 {
|
|
return nil, 0, ErrBindAlreadyOpen
|
|
}
|
|
var sa windows.Sockaddr
|
|
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
|
|
for i := 0; i < packetsPerRing; i++ {
|
|
err = bind.v4.InsertReceiveRequest()
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
err = bind.v6.InsertReceiveRequest()
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
}
|
|
atomic.StoreUint32(&bind.isOpen, 1)
|
|
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
|
}
|
|
|
|
// Close shuts down both address families. It returns nil and does nothing
// unless the bind is currently open (isOpen == 1).
//
// Shutdown happens in two phases. First, while holding only the read lock,
// isOpen is set to 2 and a zero-byte completion is posted to every rx/tx
// IOCP. Any Receive or Send blocked in GetQueuedCompletionStatus then wakes
// up, sees isOpen != 1 and returns net.ErrClosed, which releases its read
// lock. Only after that is the write lock taken to free the sockets and
// rings. Taking it up front could block indefinitely behind those readers.
func (bind *WinRingBind) Close() error {
	bind.mu.RLock()
	if atomic.LoadUint32(&bind.isOpen) != 1 {
		bind.mu.RUnlock()
		return nil
	}
	// Mark as closing so every receive/send path bails out with net.ErrClosed.
	atomic.StoreUint32(&bind.isOpen, 2)
	// Wake anything blocked on any of the four IOCPs; errors are ignored.
	windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
	bind.mu.RUnlock()
	// With the blocked readers woken, the exclusive lock can now be obtained.
	bind.mu.Lock()
	defer bind.mu.Unlock()
	bind.closeAndZero()
	return nil
}
|
|
|
|
func (bind *WinRingBind) SetMark(mark uint32) error {
|
|
return nil
|
|
}
|
|
|
|
func (bind *afWinRingBind) InsertReceiveRequest() error {
|
|
packet := bind.rx.Push()
|
|
dataBuffer := &winrio.Buffer{
|
|
Id: bind.rx.id,
|
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
|
|
Length: uint32(len(packet.data)),
|
|
}
|
|
addressBuffer := &winrio.Buffer{
|
|
Id: bind.rx.id,
|
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
|
|
Length: uint32(unsafe.Sizeof(packet.addr)),
|
|
}
|
|
bind.mu.Lock()
|
|
defer bind.mu.Unlock()
|
|
return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
|
|
}
|
|
|
|
// procyield is linked to the Go runtime's unexported runtime.procyield.
// Receive calls procyield(1) between polls of the completion queue while
// spinning. Presumably it executes a short CPU pause/yield hint — TODO
// confirm against the runtime source for the targeted Go version.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
|
|
|
|
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
|
|
if atomic.LoadUint32(isOpen) != 1 {
|
|
return 0, nil, net.ErrClosed
|
|
}
|
|
bind.rx.mu.Lock()
|
|
defer bind.rx.mu.Unlock()
|
|
|
|
var err error
|
|
var count uint32
|
|
var results [1]winrio.Result
|
|
retry:
|
|
count = 0
|
|
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
|
if tries > 0 {
|
|
if atomic.LoadUint32(isOpen) != 1 {
|
|
return 0, nil, net.ErrClosed
|
|
}
|
|
procyield(1)
|
|
}
|
|
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
|
}
|
|
if count == 0 {
|
|
err = winrio.Notify(bind.rx.cq)
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
var bytes uint32
|
|
var key uintptr
|
|
var overlapped *windows.Overlapped
|
|
err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
if atomic.LoadUint32(isOpen) != 1 {
|
|
return 0, nil, net.ErrClosed
|
|
}
|
|
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
|
if count == 0 {
|
|
return 0, nil, io.ErrNoProgress
|
|
|
|
}
|
|
}
|
|
bind.rx.Return(1)
|
|
err = bind.InsertReceiveRequest()
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
|
|
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
|
// attacker bandwidth, just like the rest of the receive path.
|
|
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
|
if atomic.LoadUint32(isOpen) != 1 {
|
|
return 0, nil, net.ErrClosed
|
|
}
|
|
goto retry
|
|
}
|
|
if results[0].Status != 0 {
|
|
return 0, nil, windows.Errno(results[0].Status)
|
|
}
|
|
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
|
|
ep := packet.addr
|
|
n := copy(buf, packet.data[:results[0].BytesTransferred])
|
|
return n, &ep, nil
|
|
}
|
|
|
|
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
|
bind.mu.RLock()
|
|
defer bind.mu.RUnlock()
|
|
return bind.v4.Receive(buf, &bind.isOpen)
|
|
}
|
|
|
|
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
|
bind.mu.RLock()
|
|
defer bind.mu.RUnlock()
|
|
return bind.v6.Receive(buf, &bind.isOpen)
|
|
}
|
|
|
|
// Send copies buf into the next transmit slot and submits it with
// winrio.SendEx to the address nend. It returns net.ErrClosed if the bind is
// not open, and io.ErrShortBuffer if buf is larger than one slot
// (bytesPerPacket).
//
// Transmit slots are recycled lazily: each call first reaps any finished
// send completions. Only when none are available and the tx ring is full
// does it block on the IOCP. The returned error reflects only whether
// SendEx accepted the request; completion statuses are never inspected.
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
	if atomic.LoadUint32(isOpen) != 1 {
		return net.ErrClosed
	}
	// A datagram must fit in a single ring slot.
	if len(buf) > bytesPerPacket {
		return io.ErrShortBuffer
	}
	bind.tx.mu.Lock()
	defer bind.tx.mu.Unlock()
	// Reap up to a whole ring's worth of finished sends.
	var results [packetsPerRing]winrio.Result
	count := winrio.DequeueCompletion(bind.tx.cq, results[:])
	if count == 0 && bind.tx.isFull {
		// Every slot is still in flight: block until the IOCP is signalled
		// (a completion, or Close's wake-up), then reap again.
		err := winrio.Notify(bind.tx.cq)
		if err != nil {
			return err
		}
		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return err
		}
		if atomic.LoadUint32(isOpen) != 1 {
			return net.ErrClosed
		}
		count = winrio.DequeueCompletion(bind.tx.cq, results[:])
		if count == 0 {
			return io.ErrNoProgress
		}
	}
	// Completed slots are recycled regardless of their per-request Status.
	if count > 0 {
		bind.tx.Return(count)
	}
	packet := bind.tx.Push()
	packet.addr = *nend
	copy(packet.data[:], buf)
	// Express the slot's payload and address as (buffer id, offset) pairs
	// within the registered tx region.
	dataBuffer := &winrio.Buffer{
		Id:     bind.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
		Length: uint32(len(buf)),
	}
	addressBuffer := &winrio.Buffer{
		Id:     bind.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
		Length: uint32(unsafe.Sizeof(packet.addr)),
	}
	bind.mu.Lock()
	defer bind.mu.Unlock()
	return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
|
|
|
|
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
|
|
nend, ok := endpoint.(*WinRingEndpoint)
|
|
if !ok {
|
|
return ErrWrongEndpointType
|
|
}
|
|
bind.mu.RLock()
|
|
defer bind.mu.RUnlock()
|
|
switch nend.family {
|
|
case windows.AF_INET:
|
|
if bind.v4.blackhole {
|
|
return nil
|
|
}
|
|
return bind.v4.Send(buf, nend, &bind.isOpen)
|
|
case windows.AF_INET6:
|
|
if bind.v6.blackhole {
|
|
return nil
|
|
}
|
|
return bind.v6.Send(buf, nend, &bind.isOpen)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
|
bind.mu.Lock()
|
|
defer bind.mu.Unlock()
|
|
sysconn, err := bind.ipv4.SyscallConn()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err2 := sysconn.Control(func(fd uintptr) {
|
|
err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
|
|
})
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bind.blackhole4 = blackhole
|
|
return nil
|
|
}
|
|
|
|
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
|
bind.mu.Lock()
|
|
defer bind.mu.Unlock()
|
|
sysconn, err := bind.ipv6.SyscallConn()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err2 := sysconn.Control(func(fd uintptr) {
|
|
err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
|
|
})
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bind.blackhole6 = blackhole
|
|
return nil
|
|
}
|
|
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
|
bind.mu.RLock()
|
|
defer bind.mu.RUnlock()
|
|
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
|
return net.ErrClosed
|
|
}
|
|
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bind.v4.blackhole = blackhole
|
|
return nil
|
|
}
|
|
|
|
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
|
bind.mu.RLock()
|
|
defer bind.mu.RUnlock()
|
|
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
|
return net.ErrClosed
|
|
}
|
|
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bind.v6.blackhole = blackhole
|
|
return nil
|
|
}
|
|
|
|
func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
|
|
const IP_UNICAST_IF = 31
|
|
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
|
var bytes [4]byte
|
|
binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
|
|
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
|
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
|
|
const IPV6_UNICAST_IF = 31
|
|
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
|
|
}
|
|
|
|
// unsafeSlice overwrites the slice header that slicePtr points to, so that the
// slice refers to the memory at data with both length and capacity set to
// lenCap. No memory is copied; the slice aliases data.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
	// Local mirror of the runtime slice header. Data is an unsafe.Pointer
	// (not a uintptr) so the garbage collector still sees the reference.
	type header struct {
		Data unsafe.Pointer
		Len  int
		Cap  int
	}
	*(*header)(slicePtr) = header{Data: data, Len: lenCap, Cap: lenCap}
}
|