mirror of
https://github.com/KusakabeShi/EtherGuard-VPN.git
synced 2024-12-28 00:18:48 +01:00
Merge branch 'source-caching'
This commit is contained in:
commit
b5ae42349c
115
src/conn.go
115
src/conn.go
@ -2,10 +2,35 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
|
||||
*/
|
||||
type Bind interface {
|
||||
SetMark(value uint32) error
|
||||
ReceiveIPv6(buff []byte) (int, Endpoint, error)
|
||||
ReceiveIPv4(buff []byte) (int, Endpoint, error)
|
||||
Send(buff []byte, end Endpoint) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
/* An Endpoint maintains the source/destination caching for a peer
|
||||
*
|
||||
* dst : the remote address of a peer ("endpoint" in uapi terminology)
|
||||
* src : the local address from which datagrams originate going to the peer
|
||||
*/
|
||||
type Endpoint interface {
|
||||
ClearSrc() // clears the source address
|
||||
SrcToString() string // returns the local source address (ip:port)
|
||||
DstToString() string // returns the destination address (ip:port)
|
||||
DstToBytes() []byte // used for mac2 cookie calculations
|
||||
DstIP() net.IP
|
||||
SrcIP() net.IP
|
||||
}
|
||||
|
||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||
|
||||
// ensure that the host is an IP address
|
||||
@ -27,63 +52,83 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||
return addr, err
|
||||
}
|
||||
|
||||
func updateUDPConn(device *Device) error {
|
||||
/* Must hold device and net lock
|
||||
*/
|
||||
func unsafeCloseUDPListener(device *Device) error {
|
||||
var err error
|
||||
netc := &device.net
|
||||
if netc.bind != nil {
|
||||
err = netc.bind.Close()
|
||||
netc.bind = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// must inform all listeners
|
||||
func UpdateUDPListener(device *Device) error {
|
||||
device.mutex.Lock()
|
||||
defer device.mutex.Unlock()
|
||||
|
||||
netc := &device.net
|
||||
netc.mutex.Lock()
|
||||
defer netc.mutex.Unlock()
|
||||
|
||||
// close existing connection
|
||||
// close existing sockets
|
||||
|
||||
if netc.conn != nil {
|
||||
netc.conn.Close()
|
||||
netc.conn = nil
|
||||
|
||||
// We need for that fd to be closed in all other go routines, which
|
||||
// means we have to wait. TODO: find less horrible way of doing this.
|
||||
time.Sleep(time.Second / 2)
|
||||
if err := unsafeCloseUDPListener(device); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open new connection
|
||||
// assumption: netc.update WaitGroup should be exactly 1
|
||||
|
||||
// open new sockets
|
||||
|
||||
if device.tun.isUp.Get() {
|
||||
|
||||
// listen on new address
|
||||
device.log.Debug.Println("UDP bind updating")
|
||||
|
||||
conn, err := net.ListenUDP("udp", netc.addr)
|
||||
// bind to new port
|
||||
|
||||
var err error
|
||||
netc.bind, netc.port, err = CreateBind(netc.port)
|
||||
if err != nil {
|
||||
netc.bind = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// set mark
|
||||
|
||||
err = netc.bind.SetMark(netc.fwmark)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set fwmark
|
||||
// clear cached source addresses
|
||||
|
||||
err = setMark(netc.conn, netc.fwmark)
|
||||
if err != nil {
|
||||
return err
|
||||
for _, peer := range device.peers {
|
||||
peer.mutex.Lock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.mutex.Unlock()
|
||||
}
|
||||
|
||||
// retrieve port (may have been chosen by kernel)
|
||||
// decrease waitgroup to 0
|
||||
|
||||
addr := conn.LocalAddr()
|
||||
netc.conn = conn
|
||||
netc.addr, _ = net.ResolveUDPAddr(
|
||||
addr.Network(),
|
||||
addr.String(),
|
||||
)
|
||||
go device.RoutineReceiveIncomming(ipv4.Version, netc.bind)
|
||||
go device.RoutineReceiveIncomming(ipv6.Version, netc.bind)
|
||||
|
||||
// notify goroutines
|
||||
|
||||
signalSend(device.signal.newUDPConn)
|
||||
device.log.Debug.Println("UDP bind has been updated")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeUDPConn(device *Device) {
|
||||
netc := &device.net
|
||||
netc.mutex.Lock()
|
||||
if netc.conn != nil {
|
||||
netc.conn.Close()
|
||||
}
|
||||
netc.mutex.Unlock()
|
||||
signalSend(device.signal.newUDPConn)
|
||||
func CloseUDPListener(device *Device) error {
|
||||
device.mutex.Lock()
|
||||
device.net.mutex.Lock()
|
||||
err := unsafeCloseUDPListener(device)
|
||||
device.net.mutex.Unlock()
|
||||
device.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
|
@ -6,6 +6,126 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func setMark(conn *net.UDPConn, value uint32) error {
|
||||
/* This code is meant to be a temporary solution
|
||||
* on platforms for which the sticky socket / source caching behavior
|
||||
* has not yet been implemented.
|
||||
*
|
||||
* See conn_linux.go for an implementation on the linux platform.
|
||||
*/
|
||||
|
||||
type NativeBind struct {
|
||||
ipv4 *net.UDPConn
|
||||
ipv6 *net.UDPConn
|
||||
}
|
||||
|
||||
type NativeEndpoint net.UDPAddr
|
||||
|
||||
var _ Bind = (*NativeBind)(nil)
|
||||
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||
|
||||
func CreateEndpoint(s string) (Endpoint, error) {
|
||||
addr, err := parseEndpoint(s)
|
||||
return (*NativeEndpoint)(addr), err
|
||||
}
|
||||
|
||||
func (_ *NativeEndpoint) ClearSrc() {}
|
||||
|
||||
func (e *NativeEndpoint) DstIP() net.IP {
|
||||
return (*net.UDPAddr)(e).IP
|
||||
}
|
||||
|
||||
func (e *NativeEndpoint) SrcIP() net.IP {
|
||||
return nil // not supported
|
||||
}
|
||||
|
||||
func (e *NativeEndpoint) DstToBytes() []byte {
|
||||
addr := (*net.UDPAddr)(e)
|
||||
out := addr.IP
|
||||
out = append(out, byte(addr.Port&0xff))
|
||||
out = append(out, byte((addr.Port>>8)&0xff))
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *NativeEndpoint) DstToString() string {
|
||||
return (*net.UDPAddr)(e).String()
|
||||
}
|
||||
|
||||
func (e *NativeEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
|
||||
// listen
|
||||
|
||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// retrieve port
|
||||
|
||||
laddr := conn.LocalAddr()
|
||||
uaddr, err := net.ResolveUDPAddr(
|
||||
laddr.Network(),
|
||||
laddr.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return conn, uaddr.Port, nil
|
||||
}
|
||||
|
||||
func CreateBind(uport uint16) (Bind, uint16, error) {
|
||||
var err error
|
||||
var bind NativeBind
|
||||
|
||||
port := int(uport)
|
||||
|
||||
bind.ipv4, port, err = listenNet("udp4", port)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
bind.ipv6, port, err = listenNet("udp6", port)
|
||||
if err != nil {
|
||||
bind.ipv4.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return &bind, uint16(port), nil
|
||||
}
|
||||
|
||||
func (bind *NativeBind) Close() error {
|
||||
err1 := bind.ipv4.Close()
|
||||
err2 := bind.ipv6.Close()
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
|
||||
return n, (*NativeEndpoint)(endpoint), err
|
||||
}
|
||||
|
||||
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
|
||||
return n, (*NativeEndpoint)(endpoint), err
|
||||
}
|
||||
|
||||
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
|
||||
var err error
|
||||
nend := endpoint.(*NativeEndpoint)
|
||||
if nend.IP.To16() != nil {
|
||||
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||
} else {
|
||||
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (bind *NativeBind) SetMark(_ uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"golang.org/x/sys/unix"
|
||||
"net"
|
||||
@ -15,20 +16,230 @@ import (
|
||||
)
|
||||
|
||||
/* Supports source address caching
|
||||
*
|
||||
* It is important that the endpoint is only updated after the packet content has been authenticated.
|
||||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code is remains platform dependent.
|
||||
*/
|
||||
type Endpoint struct {
|
||||
// source (selected based on dst type)
|
||||
// (could use RawSockaddrAny and unsafe)
|
||||
srcIPv6 unix.RawSockaddrInet6
|
||||
srcIPv4 unix.RawSockaddrInet4
|
||||
srcIf4 int32
|
||||
type NativeEndpoint struct {
|
||||
src unix.RawSockaddrInet6
|
||||
dst unix.RawSockaddrInet6
|
||||
}
|
||||
|
||||
dst unix.RawSockaddrAny
|
||||
type NativeBind struct {
|
||||
sock4 int
|
||||
sock6 int
|
||||
}
|
||||
|
||||
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||
var _ Bind = NativeBind{}
|
||||
|
||||
type IPv4Source struct {
|
||||
src unix.RawSockaddrInet4
|
||||
Ifindex int32
|
||||
}
|
||||
|
||||
func htons(val uint16) uint16 {
|
||||
var out [unsafe.Sizeof(val)]byte
|
||||
binary.BigEndian.PutUint16(out[:], val)
|
||||
return *((*uint16)(unsafe.Pointer(&out[0])))
|
||||
}
|
||||
|
||||
func ntohs(val uint16) uint16 {
|
||||
tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
|
||||
return binary.BigEndian.Uint16((*tmp)[:])
|
||||
}
|
||||
|
||||
func CreateEndpoint(s string) (Endpoint, error) {
|
||||
var end NativeEndpoint
|
||||
addr, err := parseEndpoint(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipv4 := addr.IP.To4()
|
||||
if ipv4 != nil {
|
||||
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
|
||||
dst.Family = unix.AF_INET
|
||||
dst.Port = htons(uint16(addr.Port))
|
||||
dst.Zero = [8]byte{}
|
||||
copy(dst.Addr[:], ipv4)
|
||||
end.ClearSrc()
|
||||
return &end, nil
|
||||
}
|
||||
|
||||
ipv6 := addr.IP.To16()
|
||||
if ipv6 != nil {
|
||||
zone, err := zoneToUint32(addr.Zone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dst := &end.dst
|
||||
dst.Family = unix.AF_INET6
|
||||
dst.Port = htons(uint16(addr.Port))
|
||||
dst.Flowinfo = 0
|
||||
dst.Scope_id = zone
|
||||
copy(dst.Addr[:], ipv6[:])
|
||||
end.ClearSrc()
|
||||
return &end, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("Failed to recognize IP address format")
|
||||
}
|
||||
|
||||
func CreateBind(port uint16) (Bind, uint16, error) {
|
||||
var err error
|
||||
var bind NativeBind
|
||||
|
||||
bind.sock6, port, err = create6(port)
|
||||
if err != nil {
|
||||
return nil, port, err
|
||||
}
|
||||
|
||||
bind.sock4, port, err = create4(port)
|
||||
if err != nil {
|
||||
unix.Close(bind.sock6)
|
||||
}
|
||||
return bind, port, err
|
||||
}
|
||||
|
||||
func (bind NativeBind) SetMark(value uint32) error {
|
||||
err := unix.SetsockoptInt(
|
||||
bind.sock6,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.SetsockoptInt(
|
||||
bind.sock4,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
)
|
||||
}
|
||||
|
||||
func closeUnblock(fd int) error {
|
||||
// shutdown to unblock readers
|
||||
unix.Shutdown(fd, unix.SHUT_RD)
|
||||
return unix.Close(fd)
|
||||
}
|
||||
|
||||
func (bind NativeBind) Close() error {
|
||||
err1 := closeUnblock(bind.sock6)
|
||||
err2 := closeUnblock(bind.sock4)
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||
var end NativeEndpoint
|
||||
n, err := receive6(
|
||||
bind.sock6,
|
||||
buff,
|
||||
&end,
|
||||
)
|
||||
return n, &end, err
|
||||
}
|
||||
|
||||
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||
var end NativeEndpoint
|
||||
n, err := receive4(
|
||||
bind.sock4,
|
||||
buff,
|
||||
&end,
|
||||
)
|
||||
return n, &end, err
|
||||
}
|
||||
|
||||
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
|
||||
nend := end.(*NativeEndpoint)
|
||||
switch nend.dst.Family {
|
||||
case unix.AF_INET6:
|
||||
return send6(bind.sock6, nend, buff)
|
||||
case unix.AF_INET:
|
||||
return send4(bind.sock4, nend, buff)
|
||||
default:
|
||||
return errors.New("Unknown address family of destination")
|
||||
}
|
||||
}
|
||||
|
||||
func sockaddrToString(addr unix.RawSockaddrInet6) string {
|
||||
var udpAddr net.UDPAddr
|
||||
|
||||
switch addr.Family {
|
||||
case unix.AF_INET6:
|
||||
udpAddr.Port = int(ntohs(addr.Port))
|
||||
udpAddr.IP = addr.Addr[:]
|
||||
return udpAddr.String()
|
||||
|
||||
case unix.AF_INET:
|
||||
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
|
||||
udpAddr.Port = int(ntohs(ptr.Port))
|
||||
udpAddr.IP = net.IPv4(
|
||||
ptr.Addr[0],
|
||||
ptr.Addr[1],
|
||||
ptr.Addr[2],
|
||||
ptr.Addr[3],
|
||||
)
|
||||
return udpAddr.String()
|
||||
|
||||
default:
|
||||
return "<unknown address family>"
|
||||
}
|
||||
}
|
||||
|
||||
func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
|
||||
switch addr.Family {
|
||||
case unix.AF_INET6:
|
||||
return addr.Addr[:]
|
||||
case unix.AF_INET:
|
||||
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
|
||||
return net.IPv4(
|
||||
ptr.Addr[0],
|
||||
ptr.Addr[1],
|
||||
ptr.Addr[2],
|
||||
ptr.Addr[3],
|
||||
)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (end *NativeEndpoint) SrcIP() net.IP {
|
||||
return rawAddrToIP(end.src)
|
||||
}
|
||||
|
||||
func (end *NativeEndpoint) DstIP() net.IP {
|
||||
return rawAddrToIP(end.dst)
|
||||
}
|
||||
|
||||
func (end *NativeEndpoint) DstToBytes() []byte {
|
||||
ptr := unsafe.Pointer(&end.src)
|
||||
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
|
||||
return arr[:]
|
||||
}
|
||||
|
||||
func (end *NativeEndpoint) SrcToString() string {
|
||||
return sockaddrToString(end.src)
|
||||
}
|
||||
|
||||
func (end *NativeEndpoint) DstToString() string {
|
||||
return sockaddrToString(end.dst)
|
||||
}
|
||||
|
||||
func (end *NativeEndpoint) ClearDst() {
|
||||
end.dst = unix.RawSockaddrInet6{}
|
||||
}
|
||||
|
||||
func (end *NativeEndpoint) ClearSrc() {
|
||||
end.src = unix.RawSockaddrInet6{}
|
||||
}
|
||||
|
||||
func zoneToUint32(zone string) (uint32, error) {
|
||||
@ -42,51 +253,116 @@ func zoneToUint32(zone string) (uint32, error) {
|
||||
return uint32(n), err
|
||||
}
|
||||
|
||||
func (end *Endpoint) ClearSrc() {
|
||||
end.srcIf4 = 0
|
||||
end.srcIPv4 = unix.RawSockaddrInet4{}
|
||||
end.srcIPv6 = unix.RawSockaddrInet6{}
|
||||
}
|
||||
func create4(port uint16) (int, uint16, error) {
|
||||
|
||||
// create socket
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
0,
|
||||
)
|
||||
|
||||
func (end *Endpoint) Set(s string) error {
|
||||
addr, err := parseEndpoint(s)
|
||||
if err != nil {
|
||||
return err
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
ipv6 := addr.IP.To16()
|
||||
if ipv6 != nil {
|
||||
zone, err := zoneToUint32(addr.Zone)
|
||||
if err != nil {
|
||||
addr := unix.SockaddrInet4{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
if err := func() error {
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_REUSEADDR,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
|
||||
ptr.Family = unix.AF_INET6
|
||||
ptr.Port = uint16(addr.Port)
|
||||
ptr.Flowinfo = 0
|
||||
ptr.Scope_id = zone
|
||||
copy(ptr.Addr[:], ipv6[:])
|
||||
end.ClearSrc()
|
||||
return nil
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IP,
|
||||
unix.IP_PKTINFO,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.Bind(fd, &addr)
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
ipv4 := addr.IP.To4()
|
||||
if ipv4 != nil {
|
||||
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
|
||||
ptr.Family = unix.AF_INET
|
||||
ptr.Port = uint16(addr.Port)
|
||||
ptr.Zero = [8]byte{}
|
||||
copy(ptr.Addr[:], ipv4)
|
||||
end.ClearSrc()
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("Failed to recognize IP address format")
|
||||
return fd, uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
||||
var iovec unix.Iovec
|
||||
func create6(port uint16) (int, uint16, error) {
|
||||
|
||||
// create socket
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET6,
|
||||
unix.SOCK_DGRAM,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
addr := unix.SockaddrInet6{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
if err := func() error {
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_REUSEADDR,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IPV6,
|
||||
unix.IPV6_RECVPKTINFO,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IPV6,
|
||||
unix.IPV6_V6ONLY,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.Bind(fd, &addr)
|
||||
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
return fd, uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
||||
|
||||
// construct message header
|
||||
|
||||
var iovec unix.Iovec
|
||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||
iovec.SetLen(len(buff))
|
||||
|
||||
@ -97,11 +373,11 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
||||
unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IPV6,
|
||||
Type: unix.IPV6_PKTINFO,
|
||||
Len: unix.SizeofInet6Pktinfo,
|
||||
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
||||
},
|
||||
unix.Inet6Pktinfo{
|
||||
Addr: end.srcIPv6.Addr,
|
||||
Ifindex: end.srcIPv6.Scope_id,
|
||||
Addr: end.src.Addr,
|
||||
Ifindex: end.src.Scope_id,
|
||||
},
|
||||
}
|
||||
|
||||
@ -119,22 +395,41 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_SENDMSG,
|
||||
sock,
|
||||
uintptr(sock),
|
||||
uintptr(unsafe.Pointer(&msghdr)),
|
||||
0,
|
||||
)
|
||||
|
||||
if errno == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear src and retry
|
||||
|
||||
if errno == unix.EINVAL {
|
||||
end.ClearSrc()
|
||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||
_, _, errno = unix.Syscall(
|
||||
unix.SYS_SENDMSG,
|
||||
uintptr(sock),
|
||||
uintptr(unsafe.Pointer(&msghdr)),
|
||||
0,
|
||||
)
|
||||
}
|
||||
|
||||
return errno
|
||||
}
|
||||
|
||||
func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
||||
var iovec unix.Iovec
|
||||
func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
||||
|
||||
// construct message header
|
||||
|
||||
var iovec unix.Iovec
|
||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||
iovec.SetLen(len(buff))
|
||||
|
||||
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
|
||||
|
||||
cmsg := struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet4Pktinfo
|
||||
@ -142,11 +437,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
||||
unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IP,
|
||||
Type: unix.IP_PKTINFO,
|
||||
Len: unix.SizeofInet6Pktinfo,
|
||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||
},
|
||||
unix.Inet4Pktinfo{
|
||||
Spec_dst: end.srcIPv4.Addr,
|
||||
Ifindex: end.srcIf4,
|
||||
Spec_dst: src4.src.Addr,
|
||||
Ifindex: src4.Ifindex,
|
||||
},
|
||||
}
|
||||
|
||||
@ -156,51 +451,44 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
||||
Name: (*byte)(unsafe.Pointer(&end.dst)),
|
||||
Namelen: unix.SizeofSockaddrInet4,
|
||||
Control: (*byte)(unsafe.Pointer(&cmsg)),
|
||||
Flags: 0,
|
||||
}
|
||||
|
||||
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
|
||||
|
||||
// sendmsg(sock, &msghdr, 0)
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_SENDMSG,
|
||||
sock,
|
||||
uintptr(sock),
|
||||
uintptr(unsafe.Pointer(&msghdr)),
|
||||
0,
|
||||
)
|
||||
|
||||
// clear source and try again
|
||||
|
||||
if errno == unix.EINVAL {
|
||||
end.ClearSrc()
|
||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||
_, _, errno = unix.Syscall(
|
||||
unix.SYS_SENDMSG,
|
||||
uintptr(sock),
|
||||
uintptr(unsafe.Pointer(&msghdr)),
|
||||
0,
|
||||
)
|
||||
}
|
||||
|
||||
// errno = 0 is still an error instance
|
||||
|
||||
if errno == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errno
|
||||
}
|
||||
|
||||
func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
|
||||
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||
|
||||
// extract underlying file descriptor
|
||||
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sock := file.Fd()
|
||||
|
||||
// send depending on address family of dst
|
||||
|
||||
family := *((*uint16)(unsafe.Pointer(&end.dst)))
|
||||
if family == unix.AF_INET {
|
||||
return send4(sock, end, buff)
|
||||
} else if family == unix.AF_INET6 {
|
||||
return send6(sock, end, buff)
|
||||
}
|
||||
return errors.New("Unknown address family of source")
|
||||
}
|
||||
|
||||
func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
|
||||
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return err, nil, nil
|
||||
}
|
||||
// contruct message header
|
||||
|
||||
var iovec unix.Iovec
|
||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||
@ -208,60 +496,87 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
|
||||
|
||||
var cmsg struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet6Pktinfo // big enough
|
||||
pktinfo unix.Inet4Pktinfo
|
||||
}
|
||||
|
||||
var msghdr unix.Msghdr
|
||||
msghdr.Iov = &iovec
|
||||
msghdr.Iovlen = 1
|
||||
msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
|
||||
msghdr.Namelen = unix.SizeofSockaddrInet4
|
||||
msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
|
||||
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
|
||||
|
||||
// recvmsg(sock, &mskhdr, 0)
|
||||
|
||||
size, _, errno := unix.Syscall(
|
||||
unix.SYS_RECVMSG,
|
||||
uintptr(sock),
|
||||
uintptr(unsafe.Pointer(&msghdr)),
|
||||
0,
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
// update source cache
|
||||
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
|
||||
src4.src.Family = unix.AF_INET
|
||||
src4.src.Addr = cmsg.pktinfo.Spec_dst
|
||||
src4.Ifindex = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
return int(size), nil
|
||||
}
|
||||
|
||||
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||
|
||||
// contruct message header
|
||||
|
||||
var iovec unix.Iovec
|
||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||
iovec.SetLen(len(buff))
|
||||
|
||||
var cmsg struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet6Pktinfo
|
||||
}
|
||||
|
||||
var msg unix.Msghdr
|
||||
msg.Iov = &iovec
|
||||
msg.Iovlen = 1
|
||||
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
|
||||
msg.Namelen = uint32(unix.SizeofSockaddrAny)
|
||||
msg.Namelen = uint32(unix.SizeofSockaddrInet6)
|
||||
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
|
||||
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
// recvmsg(sock, &mskhdr, 0)
|
||||
|
||||
size, _, errno := unix.Syscall(
|
||||
unix.SYS_RECVMSG,
|
||||
file.Fd(),
|
||||
uintptr(sock),
|
||||
uintptr(unsafe.Pointer(&msg)),
|
||||
0,
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return errno, nil, nil
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
// update source cache
|
||||
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
||||
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
||||
|
||||
end.src.Family = unix.AF_INET6
|
||||
end.src.Addr = cmsg.pktinfo.Addr
|
||||
end.src.Scope_id = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
|
||||
println(info)
|
||||
|
||||
}
|
||||
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func setMark(conn *net.UDPConn, value uint32) error {
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := conn.File()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.SetsockoptInt(
|
||||
int(file.Fd()),
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
)
|
||||
return int(size), nil
|
||||
}
|
||||
|
@ -5,10 +5,8 @@ import (
|
||||
"crypto/rand"
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type CookieChecker struct {
|
||||
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
||||
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
||||
}
|
||||
|
||||
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
||||
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
|
||||
st.mutex.RLock()
|
||||
defer st.mutex.RUnlock()
|
||||
|
||||
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
||||
var cookie [blake2s.Size128]byte
|
||||
func() {
|
||||
mac, _ := blake2s.New128(st.mac2.secret[:])
|
||||
mac.Write(src.IP)
|
||||
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
|
||||
mac.Write(src)
|
||||
mac.Sum(cookie[:0])
|
||||
}()
|
||||
|
||||
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
||||
func (st *CookieChecker) CreateReply(
|
||||
msg []byte,
|
||||
recv uint32,
|
||||
src *net.UDPAddr,
|
||||
src []byte,
|
||||
) (*MessageCookieReply, error) {
|
||||
|
||||
st.mutex.RLock()
|
||||
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
|
||||
var cookie [blake2s.Size128]byte
|
||||
func() {
|
||||
mac, _ := blake2s.New128(st.mac2.secret[:])
|
||||
mac.Write(src.IP)
|
||||
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
|
||||
mac.Write(src)
|
||||
mac.Sum(cookie[:0])
|
||||
}()
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
|
||||
|
||||
// check mac1
|
||||
|
||||
src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000")
|
||||
src := []byte{192, 168, 13, 37, 10, 10, 10}
|
||||
|
||||
checkMAC1 := func(msg []byte) {
|
||||
generator.AddMacs(msg)
|
||||
@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
|
||||
|
||||
msg[5] ^= 0x20
|
||||
|
||||
srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001")
|
||||
srcBad1 := []byte{192, 168, 13, 37, 40, 01}
|
||||
if checker.CheckMAC2(msg, srcBad1) {
|
||||
t.Fatal("MAC2 generation/verification failed")
|
||||
}
|
||||
|
||||
srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000")
|
||||
srcBad2 := []byte{192, 168, 13, 38, 40, 01}
|
||||
if checker.CheckMAC2(msg, srcBad2) {
|
||||
t.Fatal("MAC2 generation/verification failed")
|
||||
}
|
||||
|
@ -2,29 +2,25 @@ package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
/* Daemonizes the process on linux
|
||||
*
|
||||
* This is done by spawning and releasing a copy with the --foreground flag
|
||||
*
|
||||
* TODO: Use env variable to spawn in background
|
||||
*/
|
||||
func Daemonize(attr *os.ProcAttr) error {
|
||||
// I would like to use os.Executable,
|
||||
// however this means dropping support for Go <1.8
|
||||
path, err := exec.LookPath(os.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func Daemonize() error {
|
||||
argv := []string{os.Args[0], "--foreground"}
|
||||
argv = append(argv, os.Args[1:]...)
|
||||
attr := &os.ProcAttr{
|
||||
Dir: ".",
|
||||
Env: os.Environ(),
|
||||
Files: []*os.File{
|
||||
os.Stdin,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
}
|
||||
process, err := os.StartProcess(
|
||||
argv[0],
|
||||
path,
|
||||
argv,
|
||||
attr,
|
||||
)
|
||||
|
@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -9,8 +8,9 @@ import (
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
log *Logger // collection of loggers for levels
|
||||
idCounter uint // for assigning debug ids to peers
|
||||
closed AtomicBool // device is closed? (acting as guard)
|
||||
log *Logger // collection of loggers for levels
|
||||
idCounter uint // for assigning debug ids to peers
|
||||
fwMark uint32
|
||||
tun struct {
|
||||
device TUNDevice
|
||||
@ -22,9 +22,9 @@ type Device struct {
|
||||
}
|
||||
net struct {
|
||||
mutex sync.RWMutex
|
||||
addr *net.UDPAddr // UDP source address
|
||||
conn *net.UDPConn // UDP "connection"
|
||||
fwmark uint32
|
||||
bind Bind // bind interface
|
||||
port uint16 // listening port
|
||||
fwmark uint32 // mark value (0 = disabled)
|
||||
}
|
||||
mutex sync.RWMutex
|
||||
privateKey NoisePrivateKey
|
||||
@ -37,8 +37,7 @@ type Device struct {
|
||||
handshake chan QueueHandshakeElement
|
||||
}
|
||||
signal struct {
|
||||
stop chan struct{} // halts all go routines
|
||||
newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
|
||||
stop chan struct{}
|
||||
}
|
||||
underLoadUntil atomic.Value
|
||||
ratelimiter Ratelimiter
|
||||
@ -128,21 +127,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
|
||||
device.pool.messageBuffers.Put(msg)
|
||||
}
|
||||
|
||||
func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||
func NewDevice(tun TUNDevice, logger *Logger) *Device {
|
||||
device := new(Device)
|
||||
|
||||
device.mutex.Lock()
|
||||
defer device.mutex.Unlock()
|
||||
|
||||
device.log = NewLogger(logLevel, "("+tun.Name()+") ")
|
||||
device.log = logger
|
||||
device.peers = make(map[NoisePublicKey]*Peer)
|
||||
device.tun.device = tun
|
||||
|
||||
device.indices.Init()
|
||||
device.ratelimiter.Init()
|
||||
|
||||
device.routingTable.Reset()
|
||||
device.underLoadUntil.Store(time.Time{})
|
||||
|
||||
// setup pools
|
||||
// setup buffer pool
|
||||
|
||||
device.pool.messageBuffers = sync.Pool{
|
||||
New: func() interface{} {
|
||||
@ -159,7 +160,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||
// prepare signals
|
||||
|
||||
device.signal.stop = make(chan struct{})
|
||||
device.signal.newUDPConn = make(chan struct{}, 1)
|
||||
|
||||
// prepare net
|
||||
|
||||
device.net.port = 0
|
||||
device.net.bind = nil
|
||||
|
||||
// start workers
|
||||
|
||||
@ -168,12 +173,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||
go device.RoutineDecryption()
|
||||
go device.RoutineHandshake()
|
||||
}
|
||||
|
||||
go device.RoutineReadFromTUN()
|
||||
go device.RoutineTUNEventReader()
|
||||
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
||||
go device.RoutineReadFromTUN()
|
||||
go device.RoutineReceiveIncomming()
|
||||
|
||||
return device
|
||||
}
|
||||
|
||||
@ -202,9 +204,13 @@ func (device *Device) RemoveAllPeers() {
|
||||
}
|
||||
|
||||
func (device *Device) Close() {
|
||||
if device.closed.Swap(true) {
|
||||
return
|
||||
}
|
||||
device.log.Info.Println("Closing device")
|
||||
device.RemoveAllPeers()
|
||||
close(device.signal.stop)
|
||||
closeUDPConn(device)
|
||||
CloseUDPListener(device)
|
||||
device.tun.device.Close()
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -15,6 +16,10 @@ type DummyTUN struct {
|
||||
events chan TUNEvent
|
||||
}
|
||||
|
||||
func (tun *DummyTUN) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *DummyTUN) Name() string {
|
||||
return tun.name
|
||||
}
|
||||
@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tun, _ := CreateDummyTUN("dummy")
|
||||
device := NewDevice(tun, LogLevelError)
|
||||
logger := NewLogger(LogLevelError, "")
|
||||
device := NewDevice(tun, logger)
|
||||
device.SetPrivateKey(sk)
|
||||
return device
|
||||
}
|
||||
|
134
src/main.go
134
src/main.go
@ -2,10 +2,15 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||
)
|
||||
|
||||
func printUsage() {
|
||||
@ -43,28 +48,6 @@ func main() {
|
||||
interfaceName = os.Args[1]
|
||||
}
|
||||
|
||||
// daemonize the process
|
||||
|
||||
if !foreground {
|
||||
err := Daemonize()
|
||||
if err != nil {
|
||||
log.Println("Failed to daemonize:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// increase number of go workers (for Go <1.5)
|
||||
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
|
||||
// open TUN device
|
||||
|
||||
tun, err := CreateTUN(interfaceName)
|
||||
if err != nil {
|
||||
log.Println("Failed to create tun device:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// get log level (default: info)
|
||||
|
||||
logLevel := func() int {
|
||||
@ -79,25 +62,103 @@ func main() {
|
||||
return LogLevelInfo
|
||||
}()
|
||||
|
||||
logger := NewLogger(
|
||||
logLevel,
|
||||
fmt.Sprintf("(%s) ", interfaceName),
|
||||
)
|
||||
|
||||
logger.Debug.Println("Debug log enabled")
|
||||
|
||||
// open TUN device (or use supplied fd)
|
||||
|
||||
tun, err := func() (TUNDevice, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
if tunFdStr == "" {
|
||||
return CreateTUN(interfaceName)
|
||||
}
|
||||
|
||||
// construct tun device from supplied fd
|
||||
|
||||
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
return CreateTUNFromFile(interfaceName, file)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
logger.Error.Println("Failed to create TUN device:", err)
|
||||
os.Exit(ExitSetupFailed)
|
||||
}
|
||||
|
||||
// open UAPI file (or use supplied fd)
|
||||
|
||||
fileUAPI, err := func() (*os.File, error) {
|
||||
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
|
||||
if uapiFdStr == "" {
|
||||
return UAPIOpen(interfaceName)
|
||||
}
|
||||
|
||||
// use supplied fd
|
||||
|
||||
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(fd), ""), nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
logger.Error.Println("UAPI listen error:", err)
|
||||
os.Exit(ExitSetupFailed)
|
||||
return
|
||||
}
|
||||
// daemonize the process
|
||||
|
||||
if !foreground {
|
||||
env := os.Environ()
|
||||
env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
|
||||
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
|
||||
attr := &os.ProcAttr{
|
||||
Files: []*os.File{
|
||||
nil, // stdin
|
||||
nil, // stdout
|
||||
nil, // stderr
|
||||
tun.File(),
|
||||
fileUAPI,
|
||||
},
|
||||
Dir: ".",
|
||||
Env: env,
|
||||
}
|
||||
err = Daemonize(attr)
|
||||
if err != nil {
|
||||
logger.Error.Println("Failed to daemonize:", err)
|
||||
os.Exit(ExitSetupFailed)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// increase number of go workers (for Go <1.5)
|
||||
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
|
||||
// create wireguard device
|
||||
|
||||
device := NewDevice(tun, logLevel)
|
||||
device := NewDevice(tun, logger)
|
||||
|
||||
logInfo := device.log.Info
|
||||
logError := device.log.Error
|
||||
logInfo.Println("Starting device")
|
||||
logger.Info.Println("Device started")
|
||||
|
||||
// start configuration lister
|
||||
|
||||
uapi, err := NewUAPIListener(interfaceName)
|
||||
if err != nil {
|
||||
logError.Fatal("UAPI listen error:", err)
|
||||
}
|
||||
// start uapi listener
|
||||
|
||||
errs := make(chan error)
|
||||
term := make(chan os.Signal)
|
||||
wait := device.WaitChannel()
|
||||
|
||||
uapi, err := UAPIListen(interfaceName, fileUAPI)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := uapi.Accept()
|
||||
@ -109,7 +170,7 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
logInfo.Println("UAPI listener started")
|
||||
logger.Info.Println("UAPI listener started")
|
||||
|
||||
// wait for program to terminate
|
||||
|
||||
@ -122,9 +183,10 @@ func main() {
|
||||
case <-errs:
|
||||
}
|
||||
|
||||
// clean up UAPI bind
|
||||
// clean up
|
||||
|
||||
uapi.Close()
|
||||
device.Close()
|
||||
|
||||
logInfo.Println("Closing")
|
||||
logger.Info.Println("Shutting down")
|
||||
}
|
||||
|
@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
|
||||
return atomic.LoadInt32(&a.flag) == AtomicTrue
|
||||
}
|
||||
|
||||
func (a *AtomicBool) Swap(val bool) bool {
|
||||
flag := AtomicFalse
|
||||
if val {
|
||||
flag = AtomicTrue
|
||||
}
|
||||
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
|
||||
}
|
||||
|
||||
func (a *AtomicBool) Set(val bool) {
|
||||
flag := AtomicFalse
|
||||
if val {
|
||||
|
@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
|
||||
var err error
|
||||
var out []byte
|
||||
var nonce [12]byte
|
||||
out = key1.send.aead.Seal(out, nonce[:], testMsg, nil)
|
||||
out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil)
|
||||
out = key1.send.Seal(out, nonce[:], testMsg, nil)
|
||||
out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
|
||||
assertNil(t, err)
|
||||
assertEqual(t, out, testMsg)
|
||||
}()
|
||||
@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
|
||||
var err error
|
||||
var out []byte
|
||||
var nonce [12]byte
|
||||
out = key2.send.aead.Seal(out, nonce[:], testMsg, nil)
|
||||
out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil)
|
||||
out = key2.send.Seal(out, nonce[:], testMsg, nil)
|
||||
out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
|
||||
assertNil(t, err)
|
||||
assertEqual(t, out, testMsg)
|
||||
}()
|
||||
|
29
src/peer.go
29
src/peer.go
@ -4,7 +4,6 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -16,7 +15,7 @@ type Peer struct {
|
||||
keyPairs KeyPairs
|
||||
handshake Handshake
|
||||
device *Device
|
||||
endpoint *net.UDPAddr
|
||||
endpoint Endpoint
|
||||
stats struct {
|
||||
txBytes uint64 // bytes send to peer (endpoint)
|
||||
rxBytes uint64 // bytes received from peer
|
||||
@ -106,6 +105,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
// reset endpoint
|
||||
|
||||
peer.endpoint = nil
|
||||
|
||||
// prepare queuing
|
||||
|
||||
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||
@ -130,11 +133,31 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
peer.device.net.mutex.RLock()
|
||||
defer peer.device.net.mutex.RUnlock()
|
||||
peer.mutex.RLock()
|
||||
defer peer.mutex.RUnlock()
|
||||
if peer.endpoint == nil {
|
||||
return errors.New("No known endpoint for peer")
|
||||
}
|
||||
return peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||
}
|
||||
|
||||
/* Returns a short string identification for logging
|
||||
*/
|
||||
func (peer *Peer) String() string {
|
||||
if peer.endpoint == nil {
|
||||
return fmt.Sprintf(
|
||||
"peer(%d unknown %s)",
|
||||
peer.id,
|
||||
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"peer(%d %s %s)",
|
||||
peer.id,
|
||||
peer.endpoint.String(),
|
||||
peer.endpoint.DstToString(),
|
||||
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||
)
|
||||
}
|
||||
|
271
src/receive.go
271
src/receive.go
@ -13,19 +13,20 @@ import (
|
||||
)
|
||||
|
||||
type QueueHandshakeElement struct {
|
||||
msgType uint32
|
||||
packet []byte
|
||||
buffer *[MaxMessageSize]byte
|
||||
source *net.UDPAddr
|
||||
msgType uint32
|
||||
packet []byte
|
||||
endpoint Endpoint
|
||||
buffer *[MaxMessageSize]byte
|
||||
}
|
||||
|
||||
type QueueInboundElement struct {
|
||||
dropped int32
|
||||
mutex sync.Mutex
|
||||
buffer *[MaxMessageSize]byte
|
||||
packet []byte
|
||||
counter uint64
|
||||
keyPair *KeyPair
|
||||
dropped int32
|
||||
mutex sync.Mutex
|
||||
buffer *[MaxMessageSize]byte
|
||||
packet []byte
|
||||
counter uint64
|
||||
keyPair *KeyPair
|
||||
endpoint Endpoint
|
||||
}
|
||||
|
||||
func (elem *QueueInboundElement) Drop() {
|
||||
@ -92,130 +93,122 @@ func (device *Device) addToHandshakeQueue(
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) RoutineReceiveIncomming() {
|
||||
func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
|
||||
|
||||
logDebug := device.log.Debug
|
||||
logDebug.Println("Routine, receive incomming, started")
|
||||
logDebug.Println("Routine, receive incomming, IP version:", IP)
|
||||
|
||||
for {
|
||||
|
||||
// wait for new conn
|
||||
// receive datagrams until conn is closed
|
||||
|
||||
logDebug.Println("Waiting for udp socket")
|
||||
buffer := device.GetMessageBuffer()
|
||||
|
||||
select {
|
||||
case <-device.signal.stop:
|
||||
return
|
||||
var (
|
||||
err error
|
||||
size int
|
||||
endpoint Endpoint
|
||||
)
|
||||
|
||||
case <-device.signal.newUDPConn:
|
||||
for {
|
||||
|
||||
// fetch connection
|
||||
// read next datagram
|
||||
|
||||
device.net.mutex.RLock()
|
||||
conn := device.net.conn
|
||||
device.net.mutex.RUnlock()
|
||||
if conn == nil {
|
||||
switch IP {
|
||||
case ipv4.Version:
|
||||
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
|
||||
case ipv6.Version:
|
||||
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
logDebug.Println("Listening for inbound packets")
|
||||
// check size of packet
|
||||
|
||||
// receive datagrams until conn is closed
|
||||
packet := buffer[:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
buffer := device.GetMessageBuffer()
|
||||
var okay bool
|
||||
|
||||
for {
|
||||
switch msgType {
|
||||
|
||||
// read next datagram
|
||||
// check if transport
|
||||
|
||||
size, raddr, err := conn.ReadFromUDP(buffer[:])
|
||||
case MessageTransportType:
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
// check size
|
||||
|
||||
if size < MinMessageSize {
|
||||
if len(packet) < MessageTransportType {
|
||||
continue
|
||||
}
|
||||
|
||||
// check size of packet
|
||||
// lookup key pair
|
||||
|
||||
packet := buffer[:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
var okay bool
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
||||
case MessageTransportType:
|
||||
|
||||
// check size
|
||||
|
||||
if len(packet) < MessageTransportType {
|
||||
continue
|
||||
}
|
||||
|
||||
// lookup key pair
|
||||
|
||||
receiver := binary.LittleEndian.Uint32(
|
||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||
)
|
||||
value := device.indices.Lookup(receiver)
|
||||
keyPair := value.keyPair
|
||||
if keyPair == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// check key-pair expiry
|
||||
|
||||
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create work element
|
||||
|
||||
peer := value.peer
|
||||
elem := &QueueInboundElement{
|
||||
packet: packet,
|
||||
buffer: buffer,
|
||||
keyPair: keyPair,
|
||||
dropped: AtomicFalse,
|
||||
}
|
||||
elem.mutex.Lock()
|
||||
|
||||
// add to decryption queues
|
||||
|
||||
device.addToDecryptionQueue(device.queue.decryption, elem)
|
||||
device.addToInboundQueue(peer.queue.inbound, elem)
|
||||
buffer = device.GetMessageBuffer()
|
||||
receiver := binary.LittleEndian.Uint32(
|
||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||
)
|
||||
value := device.indices.Lookup(receiver)
|
||||
keyPair := value.keyPair
|
||||
if keyPair == nil {
|
||||
continue
|
||||
|
||||
// otherwise it is a handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
okay = len(packet) == MessageInitiationSize
|
||||
|
||||
case MessageResponseType:
|
||||
okay = len(packet) == MessageResponseSize
|
||||
|
||||
case MessageCookieReplyType:
|
||||
okay = len(packet) == MessageCookieReplySize
|
||||
}
|
||||
|
||||
if okay {
|
||||
device.addToHandshakeQueue(
|
||||
device.queue.handshake,
|
||||
QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
packet: packet,
|
||||
source: raddr,
|
||||
},
|
||||
)
|
||||
buffer = device.GetMessageBuffer()
|
||||
// check key-pair expiry
|
||||
|
||||
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create work element
|
||||
|
||||
peer := value.peer
|
||||
elem := &QueueInboundElement{
|
||||
packet: packet,
|
||||
buffer: buffer,
|
||||
keyPair: keyPair,
|
||||
dropped: AtomicFalse,
|
||||
endpoint: endpoint,
|
||||
}
|
||||
elem.mutex.Lock()
|
||||
|
||||
// add to decryption queues
|
||||
|
||||
device.addToDecryptionQueue(device.queue.decryption, elem)
|
||||
device.addToInboundQueue(peer.queue.inbound, elem)
|
||||
buffer = device.GetMessageBuffer()
|
||||
continue
|
||||
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
okay = len(packet) == MessageInitiationSize
|
||||
|
||||
case MessageResponseType:
|
||||
okay = len(packet) == MessageResponseSize
|
||||
|
||||
case MessageCookieReplyType:
|
||||
okay = len(packet) == MessageCookieReplySize
|
||||
}
|
||||
|
||||
if okay {
|
||||
device.addToHandshakeQueue(
|
||||
device.queue.handshake,
|
||||
QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
packet: packet,
|
||||
endpoint: endpoint,
|
||||
},
|
||||
)
|
||||
buffer = device.GetMessageBuffer()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -293,8 +286,6 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
// unmarshal packet
|
||||
|
||||
logDebug.Println("Process cookie reply from:", elem.source.String())
|
||||
|
||||
var reply MessageCookieReply
|
||||
reader := bytes.NewReader(elem.packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||
@ -321,15 +312,25 @@ func (device *Device) RoutineHandshake() {
|
||||
return
|
||||
}
|
||||
|
||||
// endpoints destination address is the source of the datagram
|
||||
|
||||
srcBytes := elem.endpoint.DstToBytes()
|
||||
|
||||
if device.IsUnderLoad() {
|
||||
if !device.mac.CheckMAC2(elem.packet, elem.source) {
|
||||
|
||||
// verify MAC2 field
|
||||
|
||||
if !device.mac.CheckMAC2(elem.packet, srcBytes) {
|
||||
|
||||
// construct cookie reply
|
||||
|
||||
logDebug.Println("Sending cookie reply to:", elem.source.String())
|
||||
logDebug.Println(
|
||||
"Sending cookie reply to:",
|
||||
elem.endpoint.DstToString(),
|
||||
)
|
||||
|
||||
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
|
||||
reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
|
||||
sender := binary.LittleEndian.Uint32(elem.packet[4:8])
|
||||
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
|
||||
if err != nil {
|
||||
logError.Println("Failed to create cookie reply:", err)
|
||||
return
|
||||
@ -339,17 +340,16 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
writer := bytes.NewBuffer(temp[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
_, err = device.net.conn.WriteToUDP(
|
||||
writer.Bytes(),
|
||||
elem.source,
|
||||
)
|
||||
device.net.bind.Send(writer.Bytes(), elem.endpoint)
|
||||
if err != nil {
|
||||
logDebug.Println("Failed to send cookie reply:", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !device.ratelimiter.Allow(elem.source.IP) {
|
||||
// check ratelimiter
|
||||
|
||||
if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
@ -380,8 +380,7 @@ func (device *Device) RoutineHandshake() {
|
||||
if peer == nil {
|
||||
logInfo.Println(
|
||||
"Recieved invalid initiation message from",
|
||||
elem.source.IP.String(),
|
||||
elem.source.Port,
|
||||
elem.endpoint.DstToString(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
@ -392,10 +391,9 @@ func (device *Device) RoutineHandshake() {
|
||||
peer.TimerAnyAuthenticatedPacketReceived()
|
||||
|
||||
// update endpoint
|
||||
// TODO: Discover destination address also, only update on change
|
||||
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint = elem.source
|
||||
peer.endpoint = elem.endpoint
|
||||
peer.mutex.Unlock()
|
||||
|
||||
// create response
|
||||
@ -418,9 +416,11 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
// send response
|
||||
|
||||
_, err = peer.SendBuffer(packet)
|
||||
err = peer.SendBuffer(packet)
|
||||
if err == nil {
|
||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||
} else {
|
||||
logError.Println("Failed to send response to:", peer.String(), err)
|
||||
}
|
||||
|
||||
case MessageResponseType:
|
||||
@ -441,12 +441,17 @@ func (device *Device) RoutineHandshake() {
|
||||
if peer == nil {
|
||||
logInfo.Println(
|
||||
"Recieved invalid response message from",
|
||||
elem.source.IP.String(),
|
||||
elem.source.Port,
|
||||
elem.endpoint.DstToString(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// update endpoint
|
||||
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint = elem.endpoint
|
||||
peer.mutex.Unlock()
|
||||
|
||||
logDebug.Println("Received handshake initation from", peer)
|
||||
|
||||
peer.TimerEphemeralKeyCreated()
|
||||
@ -515,6 +520,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
}
|
||||
kp.mutex.Unlock()
|
||||
|
||||
// update endpoint
|
||||
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint = elem.endpoint
|
||||
peer.mutex.Unlock()
|
||||
|
||||
// check for keep-alive
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
@ -546,7 +557,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.routingTable.LookupIPv4(src) != peer {
|
||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||
logInfo.Println(
|
||||
"IPv4 packet with unallowed source address from",
|
||||
peer.String(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -571,7 +585,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.routingTable.LookupIPv6(src) != peer {
|
||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||
logInfo.Println(
|
||||
"IPv6 packet with unallowed source address from",
|
||||
peer.String(),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -580,7 +597,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
continue
|
||||
}
|
||||
|
||||
// write to tun
|
||||
// write to tun device
|
||||
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||
_, err := device.tun.device.Write(elem.packet)
|
||||
|
23
src/send.go
23
src/send.go
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
@ -105,26 +104,6 @@ func addToEncryptionQueue(
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
|
||||
peer.device.net.mutex.RLock()
|
||||
defer peer.device.net.mutex.RUnlock()
|
||||
|
||||
peer.mutex.RLock()
|
||||
defer peer.mutex.RUnlock()
|
||||
|
||||
endpoint := peer.endpoint
|
||||
if endpoint == nil {
|
||||
return 0, errors.New("No known endpoint for peer")
|
||||
}
|
||||
|
||||
conn := peer.device.net.conn
|
||||
if conn == nil {
|
||||
return 0, errors.New("No UDP socket for device")
|
||||
}
|
||||
|
||||
return conn.WriteToUDP(buffer, endpoint)
|
||||
}
|
||||
|
||||
/* Reads packets from the TUN and inserts
|
||||
* into nonce queue for peer
|
||||
*
|
||||
@ -343,7 +322,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||
// send message and return buffer to pool
|
||||
|
||||
length := uint64(len(elem.packet))
|
||||
_, err := peer.SendBuffer(elem.packet)
|
||||
err := peer.SendBuffer(elem.packet)
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
if err != nil {
|
||||
logDebug.Println("Failed to send authenticated packet to peer", peer.String())
|
||||
|
@ -20,6 +20,14 @@
|
||||
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
|
||||
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
|
||||
# details on how this is accomplished.
|
||||
|
||||
# This code is ported to the WireGuard-Go directly from the kernel project.
|
||||
#
|
||||
# Please ensure that you have installed the newest version of the WireGuard
|
||||
# tools from the WireGuard project and before running these tests as:
|
||||
#
|
||||
# ./netns.sh <path to wireguard-go>
|
||||
|
||||
set -e
|
||||
|
||||
exec 3>&1
|
||||
@ -27,8 +35,8 @@ export WG_HIDE_KEYS=never
|
||||
netns0="wg-test-$$-0"
|
||||
netns1="wg-test-$$-1"
|
||||
netns2="wg-test-$$-2"
|
||||
program="../wireguard-go"
|
||||
export LOG_LEVEL="error"
|
||||
program=$1
|
||||
export LOG_LEVEL="info"
|
||||
|
||||
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
|
||||
pp() { pretty "" "$*"; "$@"; }
|
||||
@ -72,13 +80,11 @@ pp ip netns add $netns2
|
||||
ip0 link set up dev lo
|
||||
|
||||
# ip0 link add dev wg1 type wireguard
|
||||
n0 $program -f wg1 &
|
||||
sleep 1
|
||||
n0 $program wg1
|
||||
ip0 link set wg1 netns $netns1
|
||||
|
||||
# ip0 link add dev wg1 type wireguard
|
||||
n0 $program -f wg2 &
|
||||
sleep 1
|
||||
n0 $program wg2
|
||||
ip0 link set wg2 netns $netns2
|
||||
|
||||
key1="$(pp wg genkey)"
|
||||
@ -185,14 +191,14 @@ ip0 -4 addr del 127.0.0.1/8 dev lo
|
||||
ip0 -4 addr add 127.212.121.99/8 dev lo
|
||||
n0 wg set wg1 listen-port 9999
|
||||
n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
|
||||
n1 ping6 -W 1 -c 1 fd00::20000
|
||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
|
||||
n1 ping6 -W 1 -c 1 fd00::2
|
||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
|
||||
|
||||
# Test using IPv6 that roaming works
|
||||
n1 wg set wg1 listen-port 9998
|
||||
n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
|
||||
n1 ping -W 1 -c 1 192.168.241.2
|
||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
|
||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
|
||||
|
||||
# Test that crypto-RP filter works
|
||||
n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
|
||||
@ -212,7 +218,7 @@ n2 ncat -u 192.168.241.1 1111 <<<"X"
|
||||
! read -r -N 1 -t 1 out <&4
|
||||
kill $nmap_pid
|
||||
n0 wg set wg1 peer "$more_specific_key" remove
|
||||
[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
|
||||
[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
|
||||
|
||||
ip1 link del wg1
|
||||
ip2 link del wg2
|
||||
@ -263,7 +269,7 @@ n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to
|
||||
n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
|
||||
n1 ping -W 1 -c 1 192.168.241.2
|
||||
n2 ping -W 1 -c 1 192.168.241.1
|
||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
||||
[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
||||
# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
|
||||
pp sleep 3
|
||||
n2 ping -W 1 -c 1 192.168.241.1
|
||||
@ -289,7 +295,7 @@ ip2 link del wg2
|
||||
# ip1 link add dev wg1 type wireguard
|
||||
# ip2 link add dev wg1 type wireguard
|
||||
n1 $program wg1
|
||||
n2 $program wg1
|
||||
n2 $program wg2
|
||||
|
||||
configure_peers
|
||||
|
||||
@ -336,17 +342,83 @@ waitiface $netns1 veth1
|
||||
waitiface $netns2 veth2
|
||||
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
|
||||
n2 ping -W 1 -c 1 192.168.241.1
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
|
||||
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
|
||||
n2 ping -W 1 -c 1 192.168.241.1
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
|
||||
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
|
||||
n2 ping -W 1 -c 1 192.168.241.1
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
|
||||
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
|
||||
n2 ping -W 1 -c 1 192.168.241.1
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
|
||||
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
|
||||
|
||||
ip1 link del veth1
|
||||
ip1 link del wg1
|
||||
ip2 link del wg2
|
||||
|
||||
# Test that Netlink/IPC is working properly by doing things that usually cause split responses
|
||||
|
||||
n0 $program wg0
|
||||
sleep 5
|
||||
config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
|
||||
for a in {1..255}; do
|
||||
for b in {0..255}; do
|
||||
config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
|
||||
done
|
||||
done
|
||||
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
|
||||
i=0
|
||||
for ip in $(n0 wg show wg0 allowed-ips); do
|
||||
((++i))
|
||||
done
|
||||
((i == 255*256*2+1))
|
||||
ip0 link del wg0
|
||||
|
||||
n0 $program wg0
|
||||
config=( "[Interface]" "PrivateKey=$(wg genkey)" )
|
||||
for a in {1..40}; do
|
||||
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
|
||||
for b in {1..52}; do
|
||||
config+=( "AllowedIPs=$a.$b.0.0/16" )
|
||||
done
|
||||
done
|
||||
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
|
||||
i=0
|
||||
while read -r line; do
|
||||
j=0
|
||||
for ip in $line; do
|
||||
((++j))
|
||||
done
|
||||
((j == 53))
|
||||
((++i))
|
||||
done < <(n0 wg show wg0 allowed-ips)
|
||||
((i == 40))
|
||||
ip0 link del wg0
|
||||
|
||||
n0 $program wg0
|
||||
config=( )
|
||||
for i in {1..29}; do
|
||||
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
|
||||
done
|
||||
config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
|
||||
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
|
||||
n0 wg showconf wg0 > /dev/null
|
||||
ip0 link del wg0
|
||||
|
||||
! n0 wg show doesnotexist || false
|
||||
|
||||
declare -A objects
|
||||
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
|
||||
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
|
||||
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
|
||||
done < /dev/kmsg
|
||||
alldeleted=1
|
||||
for object in "${!objects[@]}"; do
|
||||
if [[ ${objects["$object"]} != *createddestroyed ]]; then
|
||||
echo "Error: $object: merely ${objects["$object"]}" >&3
|
||||
alldeleted=0
|
||||
fi
|
||||
done
|
||||
[[ $alldeleted -eq 1 ]]
|
||||
pretty "" "Objects that were created were also destroyed."
|
||||
|
@ -279,34 +279,31 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||
break AttemptHandshakes
|
||||
}
|
||||
|
||||
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
|
||||
|
||||
// marshal and send
|
||||
// marshal handshake message
|
||||
|
||||
writer := bytes.NewBuffer(temp[:0])
|
||||
binary.Write(writer, binary.LittleEndian, msg)
|
||||
packet := writer.Bytes()
|
||||
peer.mac.AddMacs(packet)
|
||||
|
||||
_, err = peer.SendBuffer(packet)
|
||||
if err != nil {
|
||||
// send to endpoint
|
||||
|
||||
err = peer.SendBuffer(packet)
|
||||
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
|
||||
timeout := time.NewTimer(RekeyTimeout + jitter)
|
||||
if err == nil {
|
||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||
logDebug.Println(
|
||||
"Handshake initiation attempt",
|
||||
attempts, "sent to", peer.String(),
|
||||
)
|
||||
} else {
|
||||
logError.Println(
|
||||
"Failed to send handshake initiation message to",
|
||||
peer.String(), ":", err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||
|
||||
// set handshake timeout
|
||||
|
||||
timeout := time.NewTimer(RekeyTimeout + jitter)
|
||||
logDebug.Println(
|
||||
"Handshake initiation attempt",
|
||||
attempts, "sent to", peer.String(),
|
||||
)
|
||||
|
||||
// wait for handshake or timeout
|
||||
|
||||
select {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
@ -15,6 +16,7 @@ const (
|
||||
)
|
||||
|
||||
type TUNDevice interface {
|
||||
File() *os.File // returns the file descriptor of the device
|
||||
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
|
||||
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
|
||||
MTU() (int, error) // returns the MTU of the device
|
||||
@ -47,7 +49,7 @@ func (device *Device) RoutineTUNEventReader() {
|
||||
if !device.tun.isUp.Get() {
|
||||
logInfo.Println("Interface set up")
|
||||
device.tun.isUp.Set(true)
|
||||
updateUDPConn(device)
|
||||
UpdateUDPListener(device)
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,7 +57,7 @@ func (device *Device) RoutineTUNEventReader() {
|
||||
if device.tun.isUp.Get() {
|
||||
logInfo.Println("Interface set down")
|
||||
device.tun.isUp.Set(false)
|
||||
closeUDPConn(device)
|
||||
CloseUDPListener(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -56,6 +56,10 @@ type NativeTun struct {
|
||||
events chan TUNEvent // device related events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) File() *os.File {
|
||||
return tun.fd
|
||||
}
|
||||
|
||||
func (tun *NativeTun) RoutineNetlinkListener() {
|
||||
sock := int(C.bind_rtmgrp())
|
||||
if sock < 0 {
|
||||
@ -222,7 +226,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
|
||||
val := binary.LittleEndian.Uint32(ifr[16:20])
|
||||
if val >= (1 << 31) {
|
||||
return int(val-(1<<31)) - (1 << 31), nil
|
||||
return int(toInt32(val)), nil
|
||||
}
|
||||
return int(val), nil
|
||||
}
|
||||
@ -248,6 +252,29 @@ func (tun *NativeTun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
|
||||
device := &NativeTun{
|
||||
fd: fd,
|
||||
name: name,
|
||||
events: make(chan TUNEvent, 5),
|
||||
errors: make(chan error, 5),
|
||||
}
|
||||
|
||||
// start event listener
|
||||
|
||||
var err error
|
||||
device.index, err = getIFIndex(device.name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go device.RoutineNetlinkListener()
|
||||
|
||||
// set default MTU
|
||||
|
||||
return device, device.setMTU(DefaultMTU)
|
||||
}
|
||||
|
||||
func CreateTUN(name string) (TUNDevice, error) {
|
||||
|
||||
// open clone device
|
||||
|
83
src/uapi.go
83
src/uapi.go
@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
send("private_key=" + device.privateKey.ToHex())
|
||||
}
|
||||
|
||||
if device.net.addr != nil {
|
||||
send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
|
||||
if device.net.port != 0 {
|
||||
send(fmt.Sprintf("listen_port=%d", device.net.port))
|
||||
}
|
||||
|
||||
if device.net.fwmark != 0 {
|
||||
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
|
||||
}
|
||||
@ -53,7 +54,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
send("public_key=" + peer.handshake.remoteStatic.ToHex())
|
||||
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
|
||||
if peer.endpoint != nil {
|
||||
send("endpoint=" + peer.endpoint.String())
|
||||
send("endpoint=" + peer.endpoint.DstToString())
|
||||
}
|
||||
|
||||
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
||||
@ -134,56 +135,38 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
case "listen_port":
|
||||
port, err := strconv.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
logError.Println("Failed to set listen_port:", err)
|
||||
logError.Println("Failed to parse listen_port:", err)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
logError.Println("Failed to set listen_port:", err)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
|
||||
device.net.mutex.Lock()
|
||||
device.net.addr = addr
|
||||
device.net.mutex.Unlock()
|
||||
|
||||
err = updateUDPConn(device)
|
||||
if err != nil {
|
||||
device.net.port = uint16(port)
|
||||
if err := UpdateUDPListener(device); err != nil {
|
||||
logError.Println("Failed to set listen_port:", err)
|
||||
return &IPCError{Code: ipcErrorPortInUse}
|
||||
}
|
||||
|
||||
// TODO: Clear source address of all peers
|
||||
|
||||
case "fwmark":
|
||||
fwmark, err := strconv.ParseUint(value, 10, 32)
|
||||
|
||||
// parse fwmark field
|
||||
|
||||
fwmark, err := func() (uint32, error) {
|
||||
if value == "" {
|
||||
return 0, nil
|
||||
}
|
||||
mark, err := strconv.ParseUint(value, 10, 32)
|
||||
return uint32(mark), err
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
logError.Println("Invalid fwmark", err)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
|
||||
device.net.mutex.Lock()
|
||||
if fwmark > 0 || device.net.fwmark > 0 {
|
||||
device.net.fwmark = uint32(fwmark)
|
||||
err := setMark(
|
||||
device.net.conn,
|
||||
device.net.fwmark,
|
||||
)
|
||||
if err != nil {
|
||||
logError.Println("Failed to set fwmark:", err)
|
||||
device.net.mutex.Unlock()
|
||||
return &IPCError{Code: ipcErrorIO}
|
||||
}
|
||||
|
||||
// TODO: Clear source address of all peers
|
||||
}
|
||||
device.net.fwmark = uint32(fwmark)
|
||||
device.net.mutex.Unlock()
|
||||
|
||||
case "public_key":
|
||||
|
||||
// switch to peer configuration
|
||||
|
||||
deviceConfig = false
|
||||
|
||||
case "replace_peers":
|
||||
@ -218,7 +201,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
device.mutex.RLock()
|
||||
if device.publicKey.Equals(pubKey) {
|
||||
|
||||
// create dummy instance
|
||||
// create dummy instance (not added to device)
|
||||
|
||||
peer = &Peer{}
|
||||
dummy = true
|
||||
@ -244,6 +227,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
}
|
||||
|
||||
case "remove":
|
||||
|
||||
// remove currently selected peer from device
|
||||
|
||||
if value != "true" {
|
||||
logError.Println("Failed to set remove, invalid value:", value)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
@ -256,6 +242,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
dummy = true
|
||||
|
||||
case "preshared_key":
|
||||
|
||||
// update PSK
|
||||
|
||||
peer.mutex.Lock()
|
||||
err := peer.handshake.presharedKey.FromHex(value)
|
||||
peer.mutex.Unlock()
|
||||
@ -265,15 +254,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
}
|
||||
|
||||
case "endpoint":
|
||||
addr, err := parseEndpoint(value)
|
||||
|
||||
// set endpoint destination
|
||||
|
||||
err := func() error {
|
||||
peer.mutex.Lock()
|
||||
defer peer.mutex.Unlock()
|
||||
endpoint, err := CreateEndpoint(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
peer.endpoint = endpoint
|
||||
signalSend(peer.signal.handshakeReset)
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
logError.Println("Failed to set endpoint:", value)
|
||||
return &IPCError{Code: ipcErrorInvalid}
|
||||
}
|
||||
peer.mutex.Lock()
|
||||
peer.endpoint = addr
|
||||
peer.mutex.Unlock()
|
||||
signalSend(peer.signal.handshakeReset)
|
||||
|
||||
case "persistent_keepalive_interval":
|
||||
|
||||
|
@ -10,12 +10,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ipcErrorIO = -int64(unix.EIO)
|
||||
ipcErrorProtocol = -int64(unix.EPROTO)
|
||||
ipcErrorInvalid = -int64(unix.EINVAL)
|
||||
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||
socketDirectory = "/var/run/wireguard"
|
||||
socketName = "%s.sock"
|
||||
ipcErrorIO = -int64(unix.EIO)
|
||||
ipcErrorProtocol = -int64(unix.EPROTO)
|
||||
ipcErrorInvalid = -int64(unix.EINVAL)
|
||||
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||
socketDirectory = "/var/run/wireguard"
|
||||
socketName = "%s.sock"
|
||||
)
|
||||
|
||||
type UAPIListener struct {
|
||||
@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func connectUnixSocket(path string) (net.Listener, error) {
|
||||
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||
|
||||
// attempt inital connection
|
||||
// wrap file in listener
|
||||
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// check if active
|
||||
|
||||
_, err = net.Dial("unix", path)
|
||||
if err == nil {
|
||||
return nil, errors.New("Unix socket in use")
|
||||
}
|
||||
|
||||
// attempt cleanup
|
||||
|
||||
err = os.Remove(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return net.Listen("unix", path)
|
||||
}
|
||||
|
||||
func NewUAPIListener(name string) (net.Listener, error) {
|
||||
|
||||
// check if path exist
|
||||
|
||||
err := os.MkdirAll(socketDirectory, 077)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// open UNIX socket
|
||||
|
||||
socketPath := path.Join(
|
||||
socketDirectory,
|
||||
fmt.Sprintf(socketName, name),
|
||||
)
|
||||
|
||||
listener, err := connectUnixSocket(socketPath)
|
||||
listener, err := net.FileListener(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
||||
|
||||
// watch for deletion of socket
|
||||
|
||||
socketPath := path.Join(
|
||||
socketDirectory,
|
||||
fmt.Sprintf(socketName, name),
|
||||
)
|
||||
|
||||
uapi.inotifyFd, err = unix.InotifyInit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
||||
go func(l *UAPIListener) {
|
||||
var buff [4096]byte
|
||||
for {
|
||||
unix.Read(uapi.inotifyFd, buff[:])
|
||||
// start with lstat to avoid race condition
|
||||
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
|
||||
l.connErr <- err
|
||||
return
|
||||
}
|
||||
unix.Read(uapi.inotifyFd, buff[:])
|
||||
}
|
||||
}(uapi)
|
||||
|
||||
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
||||
|
||||
return uapi, nil
|
||||
}
|
||||
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
|
||||
// check if path exist
|
||||
|
||||
err := os.MkdirAll(socketDirectory, 0600)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// open UNIX socket
|
||||
|
||||
socketPath := path.Join(
|
||||
socketDirectory,
|
||||
fmt.Sprintf(socketName, name),
|
||||
)
|
||||
|
||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, err := func() (*net.UnixListener, error) {
|
||||
|
||||
// initial connection attempt
|
||||
|
||||
listener, err := net.ListenUnix("unix", addr)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// check if socket already active
|
||||
|
||||
_, err = net.Dial("unix", socketPath)
|
||||
if err == nil {
|
||||
return nil, errors.New("unix socket in use")
|
||||
}
|
||||
|
||||
// cleanup & attempt again
|
||||
|
||||
err = os.Remove(socketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUnix("unix", addr)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return listener.File()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user