[client] Refactor/iface pkg (#2646)

Refactor the flat code structure
This commit is contained in:
Zoltan Papp
2024-10-02 18:24:22 +02:00
committed by GitHub
parent 7e5d3bdfe2
commit fd67892cb4
105 changed files with 505 additions and 438 deletions

142
client/iface/bind/bind.go Normal file
View File

@ -0,0 +1,142 @@
package bind
import (
"fmt"
"net"
"runtime"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
muUDPMux sync.Mutex
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
filterFn FilterFn
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{
transportNet: transportNet,
filterFn: filterFn,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}

View File

@ -0,0 +1,440 @@
package bind
import (
"fmt"
"io"
"net"
"strings"
"sync"
"github.com/pion/ice/v3"
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
)
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
mu sync.Mutex
// for UDP connection listen at unspecified address
localAddrsForUnspecified []net.Addr
}
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
// Required for gathering local addresses
// in case a un UDPConn is passed which does not
// bind to a specific local address.
Net transport.Net
InterfaceFilter func(interfaceName string) bool
}
func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{}
ifaces, err := n.Interfaces()
if err != nil {
return ips, err
}
var IPv4Requested, IPv6Requested bool
for _, typ := range networkTypes {
if typ.IsIPv4() {
IPv4Requested = true
}
if typ.IsIPv6() {
IPv6Requested = true
}
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface
}
if interfaceFilter != nil && !interfaceFilter(iface.Name) {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch addr := addr.(type) {
case *net.IPNet:
ip = addr.IP
case *net.IPAddr:
ip = addr.IP
}
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue
}
if ipv4 := ip.To4(); ipv4 == nil {
if !IPv6Requested {
continue
} else if !isSupportedIPv6(ip) {
continue
}
} else if !IPv4Requested {
continue
}
if ipFilter != nil && !ipFilter(ip) {
continue
}
ips = append(ips, ip)
}
}
return ips, nil
}
// The conditions of invalidation written below are defined in
// https://tools.ietf.org/html/rfc8445#section-5.1.1.1
func isSupportedIPv6(ip net.IP) bool {
if len(ip) != net.IPv6len ||
isZeros(ip[0:12]) || // !(IPv4-compatible IPv6)
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast)
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() {
return false
}
return true
}
func isZeros(ip net.IP) bool {
for i := 0; i < len(ip); i++ {
if ip[i] != 0 {
return false
}
}
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
}
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified
}
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
return nil, fmt.Errorf("invalid address %s", addr.String())
}
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
delete(m.addressMap, addr)
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.connsIPv4 {
_ = c.Close()
}
for _, c := range m.connsIPv6 {
_ = c.Close()
}
m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)
close(m.closedChan)
_ = m.params.UDPConn.Close()
})
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.Unlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
type bufferHolder struct {
buf []byte
}
func newBufferHolder(size int) *bufferHolder {
return &bufferHolder{
buf: make([]byte, size),
}
}

View File

@ -0,0 +1,369 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
*/
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
)
// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
// ReadFromConn reads from the m.params.UDPConn provided upon the creation. It expects STUN packets only, however, will
// just ignore other packets printing an warning message.
// It is a blocking method, consider running in a go routine.
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
buf := make([]byte, 1500)
for {
select {
case <-ctx.Done():
log.Debugf("stopped reading from the UDPConn due to finished context")
return
default:
n, a, err := m.params.UDPConn.ReadFrom(buf)
if err != nil {
log.Errorf("error while reading packet: %s", err)
continue
}
msg := &stun.Message{
Raw: append([]byte{}, buf[:n]...),
}
err = msg.Decode()
if err != nil {
log.Warnf("error while parsing STUN message. The packet doesn't seem to be a STUN packet: %s", err)
continue
}
err = m.HandleSTUNMessage(msg, a)
if err != nil {
log.Errorf("error while handling STUn message: %s", err)
}
}
}
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}
return u.handleUncachedAddress(b, addr)
}
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}
a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}
func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
}
// GetSharedConn returns the shared udp conn
func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn {
return m.params.UDPConn
}
// GetListenAddresses returns the listen addr of this UDP
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
// Not implemented yet.
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
return nil, fmt.Errorf("not implemented yet")
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
// All other STUN packets will be forwarded to the UDPMux
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {
log.Debugf("%s: %v", fmt.Errorf("failed to get XOR-MAPPED-ADDRESS response"), err)
return nil
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
if !ok {
return fmt.Errorf("no XOR address mapping")
}
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err != nil {
return err
}
m.xorMappedMap[stunAddr.String()] = mappedAddr
mappedAddr.SetAddr(&addr)
return nil
}
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// if we already have a mapping for this STUN server (address already received)
// and if it is not too old we return it without making a new request to STUN server
if ok {
if mappedAddr.expired() {
mappedAddr.closeWaiters()
delete(m.xorMappedMap, serverAddr.String())
ok = false
} else if mappedAddr.pending() {
ok = false
}
}
m.mu.Unlock()
if ok {
return mappedAddr.addr, nil
}
// otherwise, make a STUN request to discover the address
// or wait for already sent request to complete
waitAddrReceived, err := m.sendSTUN(serverAddr)
if err != nil {
return nil, fmt.Errorf("%s: %s", "failed to send STUN packet", err)
}
// block until response was handled by the connWorker routine and XORMappedAddress was updated
select {
case <-waitAddrReceived:
// when channel closed, addr was obtained
var addr *stun.XORMappedAddress
m.mu.Lock()
// A very odd case that mappedAddr is nil.
// Can happen when the deadline property is larger than params.XORMappedAddrCacheTTL.
// Or when we don't receive a response to our m.sendSTUN request (the response is handled asynchronously) and
// the XORMapped expires meanwhile triggering a closure of the waitAddrReceived channel.
// We protect the code from panic here.
if mappedAddr, ok := m.xorMappedMap[serverAddr.String()]; ok {
addr = mappedAddr.addr
}
m.mu.Unlock()
if addr == nil {
return nil, fmt.Errorf("no XOR address mapping")
}
return addr, nil
case <-time.After(deadline):
return nil, fmt.Errorf("timeout while waiting for XORMappedAddr")
}
}
// sendSTUN sends a STUN request via UDP conn.
//
// The returned channel is closed when the STUN response has been received.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) sendSTUN(serverAddr net.Addr) (chan struct{}, error) {
m.mu.Lock()
defer m.mu.Unlock()
// if record present in the map, we already sent a STUN request,
// just wait when waitAddrReceived will be closed
addrMap, ok := m.xorMappedMap[serverAddr.String()]
if !ok {
addrMap = &xorMapped{
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
waitAddrReceived: make(chan struct{}),
}
m.xorMappedMap[serverAddr.String()] = addrMap
}
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
if err != nil {
return nil, err
}
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
return nil, err
}
return addrMap.waitAddrReceived, nil
}
type xorMapped struct {
addr *stun.XORMappedAddress
waitAddrReceived chan struct{}
expiresAt time.Time
}
func (a *xorMapped) closeWaiters() {
select {
case <-a.waitAddrReceived:
// notify was close, ok, that means we received duplicate response
// just exit
break
default:
// notify that twe have a new addr
close(a.waitAddrReceived)
}
}
func (a *xorMapped) pending() bool {
return a.addr == nil
}
func (a *xorMapped) expired() bool {
return a.expiresAt.Before(time.Now())
}
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
a.addr = addr
a.closeWaiters()
}

View File

@ -0,0 +1,233 @@
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
*/
import (
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v3/packetio"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buf *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buf: packetio.NewBuffer(),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buf.Read(buf.buf)
if err != nil {
return 0, nil, err
}
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
}
// read data and then address
offset := 2
copy(b, buf.buf[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
offset += 2
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
return 0, nil, err
}
return dataLen, rAddr, nil
}
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, rAddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buf.Close()
close(c.closedChan)
})
return err
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buf) < len(data)+maxAddrSize {
return io.ErrShortBuffer
}
// data len
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
offset := 2
// data
copy(buf.buf[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
if err != nil {
return err
}
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
return err
}
return nil
}
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
ipData, err := addr.IP.MarshalText()
if err != nil {
return 0, err
}
total := 2 + len(ipData) + 2 + len(addr.Zone)
if total > len(buf) {
return 0, io.ErrShortBuffer
}
binary.LittleEndian.PutUint16(buf, uint16(len(ipData)))
offset := 2
n := copy(buf[offset:], ipData)
offset += n
binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
offset += 2
copy(buf[offset:], addr.Zone)
return total, nil
}
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
addr := net.UDPAddr{}
offset := 0
ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
offset += 2
// basic bounds checking
if ipLen+offset > len(buf) {
return nil, io.ErrShortBuffer
}
if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
return nil, err
}
offset += ipLen
addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
offset += 2
zone := make([]byte, len(buf[offset:]))
copy(zone, buf[offset:])
addr.Zone = string(zone)
return &addr, nil
}

View File

@ -0,0 +1,5 @@
package configurer
import "errors"
var ErrPeerNotFound = errors.New("peer not found")

View File

@ -0,0 +1,220 @@
//go:build (linux && !android) || freebsd
package configurer
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type KernelConfigurer struct {
deviceName string
}
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
return &KernelConfigurer{
deviceName: deviceName,
}
}
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
}
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, c.deviceName, port)
}
return nil
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, c.deviceName, allowedIps, endpoint.String())
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, c.deviceName)
}
return nil
}
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
}
return nil
}
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return fmt.Errorf("parse allowed IP: %w", err)
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
newAllowedIPs := existingPeer.AllowedIPs
for i, existingAllowedIP := range existingPeer.AllowedIPs {
if existingAllowedIP.String() == ipNet.String() {
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) //nolint:gocritic
break
}
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: newAllowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf("remove allowed IP %s on interface %s: %w", allowedIP, c.deviceName, err)
}
return nil
}
func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, fmt.Errorf("get device %s: %w", ifaceName, err)
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, ErrPeerNotFound
}
func (c *KernelConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
}
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
if err != nil {
return err
}
return wg.ConfigureDevice(c.deviceName, config)
}
func (c *KernelConfigurer) Close() {
}
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
}
return WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}, nil
}

View File

@ -0,0 +1,6 @@
//go:build linux || windows || freebsd
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "wt0"

View File

@ -0,0 +1,6 @@
//go:build darwin
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "utun100"

View File

@ -0,0 +1,26 @@
//go:build !windows
package configurer
import (
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err
}
listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil {
log.Errorf("failed to listen on uapi socket: %v", err)
return nil, err
}
return listener, nil
}

View File

@ -0,0 +1,11 @@
package configurer
import (
"net"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
return ipc.UAPIListen(deviceName)
}

View File

@ -0,0 +1,369 @@
package configurer
import (
"encoding/hex"
"fmt"
"net"
"os"
"runtime"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
device *device.Device
deviceName string
uapiListener net.Listener
}
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
}
wgCfg.startUAPI()
return wgCfg
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
}
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet()
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipc, "\n")
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{},
}
foundPeer := false
removedAllowedIP := false
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, reset the flag.
if strings.HasPrefix(line, "public_key=") && foundPeer {
foundPeer = false
}
// Identify the peer with the specific public key
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
// If we're within the details of the found peer and find the specific allowed IP, skip this line
if foundPeer && line == "allowed_ip="+ip {
removedAllowedIP = true
continue
}
// Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peer.AllowedIPs = append(peer.AllowedIPs, *ipNet)
}
}
if !removedAllowedIP {
return ErrAllowedIPNotFound
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error
t.uapiListener, err = openUAPI(t.deviceName)
if err != nil {
log.Errorf("failed to open uapi listener: %v", err)
return
}
go func(uapi net.Listener) {
for {
uapiConn, uapiErr := uapi.Accept()
if uapiErr != nil {
log.Tracef("%s", uapiErr)
return
}
go func() {
t.device.IpcHandle(uapiConn)
}()
}
}(t.uapiListener)
}
func (t *WGUSPConfigurer) Close() {
if t.uapiListener != nil {
err := t.uapiListener.Close()
if err != nil {
log.Errorf("failed to close uapi listener: %v", err)
}
}
if runtime.GOOS == "linux" {
sockPath := "/var/run/wireguard/" + t.deviceName + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil {
_ = os.Remove(sockPath)
}
}
}
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err)
}
stats, err := findPeerInfo(ipc, peerKey, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
}
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return nil, fmt.Errorf("parse key: %w", err)
}
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") && foundPeer {
break
}
// Identify the peer with the specific public key
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
for _, key := range searchConfigKeys {
if foundPeer && strings.HasPrefix(line, key+"=") {
v := strings.SplitN(line, "=", 2)
configFound[v[0]] = v[1]
}
}
}
// todo: use multierr
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
}
func toWgUserspaceString(wgCfg wgtypes.Config) string {
var sb strings.Builder
if wgCfg.PrivateKey != nil {
hexKey := hex.EncodeToString(wgCfg.PrivateKey[:])
sb.WriteString(fmt.Sprintf("private_key=%s\n", hexKey))
}
if wgCfg.ListenPort != nil {
sb.WriteString(fmt.Sprintf("listen_port=%d\n", *wgCfg.ListenPort))
}
if wgCfg.ReplacePeers {
sb.WriteString("replace_peers=true\n")
}
if wgCfg.FirewallMark != nil {
sb.WriteString(fmt.Sprintf("fwmark=%d\n", *wgCfg.FirewallMark))
}
for _, p := range wgCfg.Peers {
hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Remove {
sb.WriteString("remove=true")
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
}
return sb.String()
}
func getFwmark() int {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark
}
return 0
}

View File

@ -0,0 +1,104 @@
package configurer
import (
"encoding/hex"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var ipcFixture = `
private_key=e84b5a6d2717c1003a13b431570353dbaca9146cf150c5f8575680feba52027a
listen_port=12912
public_key=b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33
preshared_key=188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52
allowed_ip=192.168.4.4/32
endpoint=[abcd:23::33%2]:51820
public_key=58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376
tx_bytes=38333
rx_bytes=2224
allowed_ip=192.168.4.6/32
persistent_keepalive_interval=111
endpoint=182.122.22.19:3233
public_key=662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58
endpoint=5.152.198.39:51820
allowed_ip=192.168.4.10/32
allowed_ip=192.168.4.11/32
tx_bytes=1212111
rx_bytes=1929999999
protocol_version=1
errno=0
`
func Test_findPeerInfo(t *testing.T) {
tests := []struct {
name string
peerKey string
searchKeys []string
want map[string]string
wantErr bool
}{
{
name: "single",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
},
wantErr: false,
},
{
name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
"rx_bytes": "2224",
},
wantErr: false,
},
{
name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "1212111",
"rx_bytes": "1929999999",
},
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
res, err := hex.DecodeString(tt.peerKey)
require.NoError(t, err)
key, err := wgtypes.NewKey(res)
require.NoError(t, err)
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
})
}
}

View File

@ -0,0 +1,9 @@
package configurer
import "time"
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
}

18
client/iface/device.go Normal file
View File

@ -0,0 +1,18 @@
//go:build !android
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@ -0,0 +1,8 @@
package device
// TunAdapter is an interface for create tun device from external service
type TunAdapter interface {
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
UpdateAddr(address string) error
ProtectSocket(fd int32) bool
}

View File

@ -0,0 +1,29 @@
package device
import (
"fmt"
"net"
)
// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return WGAddress{}, err
}
return WGAddress{
IP: ip,
Network: network,
}, nil
}
func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@ -0,0 +1,6 @@
package device
type MobileIFaceArguments struct {
TunAdapter TunAdapter // only for Android
TunFd int // only for iOS
}

View File

@ -0,0 +1,140 @@
//go:build android
package device
import (
"strings"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address WGAddress
port int
key string
mtu int
iceBind *bind.ICEBind
tunAdapter TunAdapter
name string
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
tunAdapter: tunAdapter,
}
}
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
log.Info("create tun interface")
routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains)
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
return nil, err
}
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to create Android interface: %s", err)
return nil, err
}
t.name = name
t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("device is ready to use: %s", t.name)
return udpMux, nil
}
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
func (t *WGTunDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
t.device = nil
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *WGTunDevice) Device() *device.Device {
return t.device
}
func (t *WGTunDevice) DeviceName() string {
return t.name
}
func (t *WGTunDevice) WgAddress() WGAddress {
return t.address
}
func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}
func searchDomainsToString(searchDomains []string) string {
return strings.Join(searchDomains, ";")
}

View File

@ -0,0 +1,141 @@
//go:build !ios
package device
import (
"fmt"
"os/exec"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type TunDevice struct {
name string
address WGAddress
port int
key string
mtu int
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}
func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("device is ready to use: %s", t.name)
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
t.device = nil
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
return err
}
// dummy ipv6 so routing works
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
}
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
if out, err := routeCmd.CombinedOutput(); err != nil {
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
return err
}
return nil
}

View File

@ -0,0 +1,100 @@
package device
import (
"net"
"sync"
"golang.zx2c4.com/wireguard/tun"
)
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
}
// FilteredDevice to override Read or Write of packets
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
}
// newDeviceFilter constructor function
func newDeviceFilter(device tun.Device) *FilteredDevice {
return &FilteredDevice{
Device: device,
}
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
if filter == nil {
return
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
i--
}
}
return n, nil
}
// Write wraps write method with filtering feature
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
if filter == nil {
return d.Device.Write(bufs, offset)
}
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) {
filteredBufs = append(filteredBufs, buf)
dropped++
}
}
n, err := d.Device.Write(filteredBufs, offset)
n += dropped
return n, err
}
// SetFilter sets packet filter to device
func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Lock()
d.filter = filter
d.mutex.Unlock()
}

View File

@ -0,0 +1,223 @@
package device
import (
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
mocks "github.com/netbirdio/netbird/client/iface/mocks"
)
func TestDeviceWrapperRead(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
t.Run("read ICMP", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{192, 168, 0, 1},
DstIP: net.IP{100, 200, 0, 1},
}
icmpLayer := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
Id: 1,
Seq: 1,
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
icmpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
mockBufs := [][]byte{{}}
mockSizes := []int{0}
mockOffset := 0
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
bufs[0] = buffer.Bytes()
sizes[0] = len(bufs[0])
return 1, nil
})
wrapped := newDeviceFilter(tun)
bufs := [][]byte{{}}
sizes := []int{0}
offset := 0
n, err := wrapped.Read(bufs, sizes, offset)
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if n != 1 {
t.Errorf("expected n=1, got %d", n)
return
}
})
t.Run("write TCP", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{100, 200, 0, 9},
DstIP: net.IP{100, 200, 0, 10},
}
// create TCP layer packet
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(34423),
DstPort: layers.TCPPort(8080),
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
tcpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
mockBufs := [][]byte{buffer.Bytes()}
mockBufs[0] = buffer.Bytes()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
wrapped := newDeviceFilter(tun)
bufs := [][]byte{buffer.Bytes()}
n, err := wrapped.Write(bufs, 0)
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if n != 1 {
t.Errorf("expected n=1, got %d", n)
return
}
})
t.Run("drop write UDP package", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{100, 200, 0, 11},
DstIP: net.IP{100, 200, 0, 20},
}
// create TCP layer packet
tcpLayer := &layers.UDP{
SrcPort: layers.UDPPort(27278),
DstPort: layers.UDPPort(53),
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
tcpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
mockBufs := [][]byte{}
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{buffer.Bytes()}
n, err := wrapped.Write(bufs, 0)
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if n != 0 {
t.Errorf("expected n=1, got %d", n)
return
}
})
t.Run("drop read UDP package", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{100, 200, 0, 11},
DstIP: net.IP{100, 200, 0, 20},
}
// create TCP layer packet
tcpLayer := &layers.UDP{
SrcPort: layers.UDPPort(19243),
DstPort: layers.UDPPort(1024),
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
tcpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
mockBufs := [][]byte{{}}
mockSizes := []int{0}
mockOffset := 0
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
bufs[0] = buffer.Bytes()
sizes[0] = len(bufs[0])
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{{}}
sizes := []int{0}
offset := 0
n, err := wrapped.Read(bufs, sizes, offset)
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
return
}
})
}

View File

@ -0,0 +1,134 @@
//go:build ios
// +build ios
package device
import (
"os"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type TunDevice struct {
name string
address WGAddress
port int
key string
iceBind *bind.ICEBind
tunFd int
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
iceBind: bind.NewICEBind(transportNet, filterFn),
tunFd: tunFd,
}
}
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd)
if err != nil {
log.Errorf("Unable to dup tun fd: %v", err)
return nil, err
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
log.Errorf("Unable to set tun fd as non blocking: %v", err)
_ = unix.Close(dupTunFd)
return nil, err
}
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err)
_ = unix.Close(dupTunFd)
return nil, err
}
t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface")
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("device is ready to use: %s", t.name)
return udpMux, nil
}
func (t *TunDevice) Device() *device.Device {
return t.device
}
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
t.device = nil
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}

View File

@ -0,0 +1,163 @@
//go:build (linux && !android) || freebsd
package device
import (
"context"
"fmt"
"net"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/sharedsock"
)
type TunKernelDevice struct {
name string
address WGAddress
wgPort int
key string
mtu int
ctx context.Context
ctxCancel context.CancelFunc
transportNet transport.Net
link *wgLink
udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
ctxCancel: cancel,
name: name,
address: address,
wgPort: wgPort,
key: key,
mtu: mtu,
transportNet: transportNet,
}
}
func (t *TunKernelDevice) Create() (WGConfigurer, error) {
link := newWGLink(t.name)
if err := link.recreate(); err != nil {
return nil, fmt.Errorf("recreate: %w", err)
}
t.link = link
if err := t.assignAddr(); err != nil {
return nil, fmt.Errorf("assign addr: %w", err)
}
// TODO: do a MTU discovery
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
if err := link.setMTU(t.mtu); err != nil {
return nil, fmt.Errorf("set mtu: %w", err)
}
configurer := configurer.NewKernelConfigurer(t.name)
if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil {
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil
}
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
if t.link == nil {
return nil, fmt.Errorf("device is not ready yet")
}
log.Debugf("bringing up interface: %s", t.name)
if err := t.link.up(); err != nil {
log.Errorf("error bringing up interface: %s", t.name)
return nil, err
}
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter())
if err != nil {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux
log.Debugf("device is ready to use: %s", t.name)
return t.udpMux, nil
}
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *TunKernelDevice) Close() error {
if t.link == nil {
return nil
}
t.ctxCancel()
var closErr error
if err := t.link.Close(); err != nil {
log.Debugf("failed to close link: %s", err)
closErr = err
}
if t.udpMux != nil {
if err := t.udpMux.Close(); err != nil {
log.Debugf("failed to close udp mux: %s", err)
closErr = err
}
if err := t.udpMuxConn.Close(); err != nil {
log.Debugf("failed to close udp mux connection: %s", err)
closErr = err
}
}
return closErr
}
func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunKernelDevice) DeviceName() string {
return t.name
}
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil
}
// assignAddr Adds IP address to the tunnel interface
func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
}

View File

@ -0,0 +1,120 @@
//go:build !android
// +build !android
package device
import (
"fmt"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
)
type TunNetstackDevice struct {
name string
address WGAddress
port int
key string
mtu int
listenAddress string
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
nsTun *netstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
port: wgPort,
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("netstack device is ready to use")
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
func (t *TunNetstackDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}

View File

@ -0,0 +1,145 @@
//go:build (linux && !android) || freebsd
package device
import (
"fmt"
"os"
"runtime"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type USPDevice struct {
name string
address WGAddress
port int
key string
mtu int
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
return &USPDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn)}
}
func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("device is ready to use: %s", t.name)
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *USPDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *USPDevice) WgAddress() WGAddress {
return t.address
}
func (t *USPDevice) DeviceName() string {
return t.name
}
func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name)
return link.assignAddr(t.address)
}
func checkUser() {
if runtime.GOOS == "freebsd" {
euid := os.Geteuid()
if euid != 0 {
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
}
}
}

View File

@ -0,0 +1,172 @@
package device
import (
"fmt"
"net/netip"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct {
name string
address WGAddress
port int
key string
mtu int
iceBind *bind.ICEBind
device *device.Device
nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}
func getGUID() (windows.GUID, error) {
guidString := defaultWindowsGUIDSTring
if CustomWindowsGUIDString != "" {
guidString = CustomWindowsGUIDString
}
return windows.GUIDFromString(guidString)
}
func (t *TunDevice) Create() (WGConfigurer, error) {
guid, err := getGUID()
if err != nil {
log.Errorf("failed to get GUID: %s", err)
return nil, err
}
log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.nativeTunDevice = tunDevice.(*tun.NativeTun)
t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
nbiface, err := luid.IPInterface(windows.AF_INET)
if err != nil {
t.device.Close()
return nil, fmt.Errorf("got error when getting ip interface %s", err)
}
nbiface.NLMTU = uint32(t.mtu)
err = nbiface.Set()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err)
}
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("device is ready to use: %s", t.name)
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
t.device = nil
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet")
}
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
guid, err := luid.GUID()
if err != nil {
return "", err
}
return guid.String(), nil
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
}

View File

@ -0,0 +1,20 @@
package device
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@ -0,0 +1,8 @@
//go:build (!linux && !freebsd) || android
package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
}

View File

@ -0,0 +1,18 @@
package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
// we are currently do not use it, since it is required to add wireguard kernel support to
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
// - https://github.com/mdlayher/socket
// TODO: implement kernel space
return false
}
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
}

View File

@ -0,0 +1,360 @@
//go:build linux && !android
// Package iface provides wireguard network interface creation and management
package device
import (
"bufio"
"errors"
"fmt"
"io"
"io/fs"
"math"
"os"
"path/filepath"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
// Holds logic to check existence of kernel modules used by wireguard interfaces
// Copied from https://github.com/paultag/go-modprobe and
// https://github.com/pmorjan/kmod
type status int
const (
defaultModuleDir = "/lib/modules"
unknown status = iota
unloaded
unloading
loading
live
inuse
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
)
type module struct {
name string
path string
}
var (
// ErrModuleNotFound is the error resulting if a module can't be found.
ErrModuleNotFound = errors.New("module not found")
moduleLibDir = defaultModuleDir
// get the root directory for the kernel modules. If this line panics,
// it's because getModuleRoot has failed to get the uname of the running
// kernel (likely a non-POSIX system, but maybe a broken kernel?)
moduleRoot = getModuleRoot()
)
// Get the module root (/lib/modules/$(uname -r)/)
func getModuleRoot() string {
uname := unix.Utsname{}
if err := unix.Uname(&uname); err != nil {
panic(err)
}
i := 0
for ; uname.Release[i] != 0; i++ {
}
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
}
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
_, err := os.Stat("/dev/net/tun")
if err == nil {
return true
}
log.Infof("couldn't access device /dev/net/tun, go error %v, "+
"will attempt to load tun module, if running on container add flag --cap-add=NET_ADMIN", err)
tunLoaded, err := tryToLoadModule("tun")
if err != nil {
log.Errorf("unable to find or load tun module, got error: %v", err)
}
return tunLoaded
}
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
return true
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
return loaded
}
func canCreateFakeWireGuardInterface() bool {
link := newWGLink("mustnotexist")
// We willingly try to create a device with an invalid
// MTU here as the validation of the MTU will be performed after
// the validation of the link kind and hence allows us to check
// for the existence of the wireguard module without actually
// creating a link.
//
// As a side-effect, this will also let the kernel lazy-load
// the wireguard module.
link.attrs.MTU = math.MaxInt
err := netlink.LinkAdd(link)
return errors.Is(err, syscall.EINVAL)
}
func tryToLoadModule(moduleName string) (bool, error) {
if isModuleEnabled(moduleName) {
return true, nil
}
modulePath, err := getModulePath(moduleName)
if err != nil {
return false, fmt.Errorf("couldn't find module path for %s, error: %v", moduleName, err)
}
if modulePath == "" {
return false, nil
}
log.Infof("trying to load %s module", moduleName)
err = loadModuleWithDependencies(moduleName, modulePath)
if err != nil {
return false, fmt.Errorf("couldn't load %s module, error: %v", moduleName, err)
}
return true, nil
}
func isModuleEnabled(name string) bool {
builtin, builtinErr := isBuiltinModule(name)
state, statusErr := moduleStatus(name)
return (builtinErr == nil && builtin) || (statusErr == nil && state >= loading)
}
func getModulePath(name string) (string, error) {
var foundPath string
skipRemainingDirs := false
err := filepath.WalkDir(
moduleRoot,
func(path string, info fs.DirEntry, err error) error {
if skipRemainingDirs {
return fs.SkipDir
}
if err != nil {
// skip broken files
return nil //nolint:nilerr
}
if !info.Type().IsRegular() {
return nil
}
nameFromPath := pathToName(path)
if nameFromPath == name {
foundPath = path
skipRemainingDirs = true
}
return nil
})
if err != nil {
return "", err
}
return foundPath, nil
}
func pathToName(s string) string {
s = filepath.Base(s)
for ext := filepath.Ext(s); ext != ""; ext = filepath.Ext(s) {
s = strings.TrimSuffix(s, ext)
}
return cleanName(s)
}
func cleanName(s string) string {
return strings.ReplaceAll(strings.TrimSpace(s), "-", "_")
}
func isBuiltinModule(name string) (bool, error) {
f, err := os.Open(filepath.Join(moduleRoot, "/modules.builtin"))
if err != nil {
return false, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing modules.builtin file, %v", err)
}
}()
var found bool
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if pathToName(line) == name {
found = true
break
}
}
if err := scanner.Err(); err != nil {
return false, err
}
return found, nil
}
// /proc/modules
//
// name | memory size | reference count | references | state: <Live|Loading|Unloading>
// macvlan 28672 1 macvtap, Live 0x0000000000000000
func moduleStatus(name string) (status, error) {
state := unknown
f, err := os.Open("/proc/modules")
if err != nil {
return state, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing /proc/modules file, %v", err)
}
}()
state = unloaded
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if fields[0] == name {
if fields[2] != "0" {
state = inuse
break
}
switch fields[4] {
case "Live":
state = live
case "Loading":
state = loading
case "Unloading":
state = unloading
}
break
}
}
if err := scanner.Err(); err != nil {
return state, err
}
return state, nil
}
func loadModuleWithDependencies(name, path string) error {
deps, err := getModuleDependencies(name)
if err != nil {
return fmt.Errorf("couldn't load list of module %s dependencies", name)
}
for _, dep := range deps {
err = loadModule(dep.name, dep.path)
if err != nil {
return fmt.Errorf("couldn't load dependency module %s for %s", dep.name, name)
}
}
return loadModule(name, path)
}
func loadModule(name, path string) error {
state, err := moduleStatus(name)
if err != nil {
return err
}
if state >= loading {
return nil
}
f, err := os.Open(path)
if err != nil {
return err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing %s file, %v", path, err)
}
}()
// first try finit_module(2), then init_module(2)
err = unix.FinitModule(int(f.Fd()), "", 0)
if errors.Is(err, unix.ENOSYS) {
buf, err := io.ReadAll(f)
if err != nil {
return err
}
return unix.InitModule(buf, "")
}
return err
}
// getModuleDependencies returns a module dependencies
func getModuleDependencies(name string) ([]module, error) {
f, err := os.Open(filepath.Join(moduleRoot, "/modules.dep"))
if err != nil {
return nil, err
}
defer func() {
err := f.Close()
if err != nil {
log.Errorf("failed closing modules.dep file, %v", err)
}
}()
var deps []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if pathToName(strings.TrimSuffix(fields[0], ":")) == name {
deps = fields
break
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
if len(deps) == 0 {
return nil, ErrModuleNotFound
}
deps[0] = strings.TrimSuffix(deps[0], ":")
var modules []module
for _, v := range deps {
if pathToName(v) != name {
modules = append(modules, module{
name: pathToName(v),
path: filepath.Join(moduleRoot, v),
})
}
}
return modules, nil
}

View File

@ -0,0 +1,225 @@
//go:build linux && !android
package device
import (
"bufio"
"bytes"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func TestGetModuleDependencies(t *testing.T) {
testCases := []struct {
name string
module string
expected []module
}{
{
name: "Get Single Dependency",
module: "bar",
expected: []module{
{name: "foo", path: "kernel/a/foo.ko"},
},
},
{
name: "Get Multiple Dependencies",
module: "baz",
expected: []module{
{name: "foo", path: "kernel/a/foo.ko"},
{name: "bar", path: "kernel/a/bar.ko"},
},
},
{
name: "Get No Dependencies",
module: "foo",
expected: []module{},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
modules, err := getModuleDependencies(testCase.module)
require.NoError(t, err)
expected := testCase.expected
for i := range expected {
expected[i].path = moduleRoot + "/" + expected[i].path
}
require.ElementsMatchf(t, modules, expected, "returned modules should match")
})
}
}
func TestIsBuiltinModule(t *testing.T) {
testCases := []struct {
name string
module string
expected bool
}{
{
name: "Built In Should Return True",
module: "foo_bi",
expected: true,
},
{
name: "Not Built In Should Return False",
module: "not_built_in",
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
isBuiltIn, err := isBuiltinModule(testCase.module)
require.NoError(t, err)
require.Equal(t, testCase.expected, isBuiltIn)
})
}
}
func TestModuleStatus(t *testing.T) {
random, err := getRandomLoadedModule(t)
if err != nil {
t.Fatal("should be able to get random module")
}
testCases := []struct {
name string
module string
shouldBeLoaded bool
}{
{
name: "Should Return Module Loading Or Greater Status",
module: random,
shouldBeLoaded: true,
},
{
name: "Should Return Module Unloaded Or Lower Status",
module: "not_loaded_module",
shouldBeLoaded: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
defer resetGlobals()
_, _ = createFiles(t)
state, err := moduleStatus(testCase.module)
require.NoError(t, err)
if testCase.shouldBeLoaded {
require.GreaterOrEqual(t, loading, state, "moduleStatus for %s should return state loading", testCase.module)
} else {
require.Less(t, state, loading, "module should return state unloading or lower")
}
})
}
}
func resetGlobals() {
moduleLibDir = defaultModuleDir
moduleRoot = getModuleRoot()
}
func createFiles(t *testing.T) (string, []module) {
t.Helper()
writeFile := func(path, text string) {
if err := os.WriteFile(path, []byte(text), 0644); err != nil {
t.Fatal(err)
}
}
var u unix.Utsname
if err := unix.Uname(&u); err != nil {
t.Fatal(err)
}
moduleLibDir = t.TempDir()
moduleRoot = getModuleRoot()
if err := os.Mkdir(moduleRoot, 0755); err != nil {
t.Fatal(err)
}
text := "kernel/a/foo.ko:\n"
text += "kernel/a/bar.ko: kernel/a/foo.ko\n"
text += "kernel/a/baz.ko: kernel/a/bar.ko kernel/a/foo.ko\n"
writeFile(filepath.Join(moduleRoot, "/modules.dep"), text)
text = "kernel/a/foo_bi.ko\n"
text += "kernel/a/bar-bi.ko.gz\n"
writeFile(filepath.Join(moduleRoot, "/modules.builtin"), text)
modules := []module{
{name: "foo", path: "kernel/a/foo.ko"},
{name: "bar", path: "kernel/a/bar.ko"},
{name: "baz", path: "kernel/a/baz.ko"},
}
return moduleLibDir, modules
}
func getRandomLoadedModule(t *testing.T) (string, error) {
t.Helper()
f, err := os.Open("/proc/modules")
if err != nil {
return "", err
}
defer func() {
err := f.Close()
if err != nil {
t.Logf("failed closing /proc/modules file, %v", err)
}
}()
lines, err := lineCounter(f)
if err != nil {
return "", err
}
counter := 1
midLine := lines / 2
modName := ""
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if counter == midLine {
if fields[4] == "Unloading" {
continue
}
modName = fields[0]
break
}
counter++
}
if scanner.Err() != nil {
return "", scanner.Err()
}
return modName, nil
}
func lineCounter(r io.Reader) (int, error) {
buf := make([]byte, 32*1024)
count := 0
lineSep := []byte{'\n'}
for {
c, err := r.Read(buf)
count += bytes.Count(buf[:c], lineSep)
switch {
case err == io.EOF:
return count, nil
case err != nil:
return count, err
}
}
}

View File

@ -0,0 +1,81 @@
package device
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
)
type wgLink struct {
name string
link *freebsd.Link
}
func newWGLink(name string) *wgLink {
link := freebsd.NewLink(name)
return &wgLink{
name: name,
link: link,
}
}
// Type returns the interface type
func (l *wgLink) Type() string {
return "wireguard"
}
// Close deletes the link interface
func (l *wgLink) Close() error {
return l.link.Del()
}
func (l *wgLink) recreate() error {
if err := l.link.Recreate(); err != nil {
return fmt.Errorf("recreate: %w", err)
}
return nil
}
func (l *wgLink) setMTU(mtu int) error {
if err := l.link.SetMTU(mtu); err != nil {
return fmt.Errorf("set mtu: %w", err)
}
return nil
}
func (l *wgLink) up() error {
if err := l.link.Up(); err != nil {
return fmt.Errorf("up: %w", err)
}
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)
}
ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
err = link.AssignAddr(ip, mask)
if err != nil {
return fmt.Errorf("assign addr: %w", err)
}
err = link.Up()
if err != nil {
return fmt.Errorf("up: %w", err)
}
return nil
}

View File

@ -0,0 +1,133 @@
//go:build linux && !android
package device
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
type wgLink struct {
attrs *netlink.LinkAttrs
}
func newWGLink(name string) *wgLink {
attrs := netlink.NewLinkAttrs()
attrs.Name = name
return &wgLink{
attrs: &attrs,
}
}
// Attrs returns the Wireguard's default attributes
func (l *wgLink) Attrs() *netlink.LinkAttrs {
return l.attrs
}
// Type returns the interface type
func (l *wgLink) Type() string {
return "wireguard"
}
// Close deletes the link interface
func (l *wgLink) Close() error {
return netlink.LinkDel(l)
}
func (l *wgLink) recreate() error {
name := l.attrs.Name
// check if interface exists
link, err := netlink.LinkByName(name)
if err != nil {
switch err.(type) {
case netlink.LinkNotFoundError:
break
default:
return fmt.Errorf("link by name: %w", err)
}
}
// remove if interface exists
if link != nil {
err = netlink.LinkDel(l)
if err != nil {
return err
}
}
log.Debugf("adding device: %s", name)
err = netlink.LinkAdd(l)
if os.IsExist(err) {
log.Infof("interface %s already exists. Will reuse.", name)
} else if err != nil {
return fmt.Errorf("link add: %w", err)
}
return nil
}
func (l *wgLink) setMTU(mtu int) error {
if err := netlink.LinkSetMTU(l, mtu); err != nil {
log.Errorf("error setting MTU on interface: %s", l.attrs.Name)
return fmt.Errorf("link set mtu: %w", err)
}
return nil
}
func (l *wgLink) up() error {
if err := netlink.LinkSetUp(l); err != nil {
log.Errorf("error bringing up interface: %s", l.attrs.Name)
return fmt.Errorf("link setup: %w", err)
}
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {
return fmt.Errorf("list addr: %w", err)
}
if len(list) > 0 {
for _, a := range list {
addr := a
err = netlink.AddrDel(l, &addr)
if err != nil {
return fmt.Errorf("del addr: %w", err)
}
}
}
name := l.attrs.Name
addrStr := address.String()
log.Debugf("adding address %s to interface: %s", addrStr, name)
addr, err := netlink.ParseAddr(addrStr)
if err != nil {
return fmt.Errorf("parse addr: %w", err)
}
err = netlink.AddrAdd(l, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", name, addrStr)
} else if err != nil {
return fmt.Errorf("add addr: %w", err)
}
// On linux, the link must be brought up
if err := netlink.LinkSetUp(l); err != nil {
return fmt.Errorf("link setup: %w", err)
}
return nil
}

View File

@ -0,0 +1,15 @@
package device
import (
"os"
"golang.zx2c4.com/wireguard/device"
)
func wgLogLevel() int {
if os.Getenv("NB_WG_DEBUG") == "true" {
return device.LogLevelVerbose
} else {
return device.LogLevelSilent
}
}

View File

@ -0,0 +1,4 @@
package device
// CustomWindowsGUIDString is a custom GUID string for the interface
var CustomWindowsGUIDString string

View File

@ -0,0 +1,16 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@ -0,0 +1,8 @@
package freebsd
import "errors"
var (
ErrDoesNotExist = errors.New("does not exist")
ErrNameDoesNotMatch = errors.New("name does not match")
)

View File

@ -0,0 +1,108 @@
package freebsd
import (
"bufio"
"fmt"
"strconv"
"strings"
)
type iface struct {
Name string
MTU int
Group string
IPAddrs []string
}
func parseError(output []byte) error {
// TODO: implement without allocations
lines := string(output)
if strings.Contains(lines, "does not exist") {
return ErrDoesNotExist
}
return nil
}
func parseIfconfigOutput(output []byte) (*iface, error) {
// TODO: implement without allocations
lines := string(output)
scanner := bufio.NewScanner(strings.NewReader(lines))
var name, mtu, group string
var ips []string
for scanner.Scan() {
line := scanner.Text()
// If line contains ": flags", it's a line with interface information
if strings.Contains(line, ": flags") {
parts := strings.Fields(line)
if len(parts) < 4 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
name = strings.TrimSuffix(parts[0], ":")
if strings.Contains(line, "mtu") {
mtuIndex := 0
for i, part := range parts {
if part == "mtu" {
mtuIndex = i
break
}
}
mtu = parts[mtuIndex+1]
}
}
// If line contains "groups:", it's a line with interface group
if strings.Contains(line, "groups:") {
parts := strings.Fields(line)
if len(parts) < 2 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
group = parts[1]
}
// If line contains "inet ", it's a line with IP address
if strings.Contains(line, "inet ") {
parts := strings.Fields(line)
if len(parts) < 2 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
ips = append(ips, parts[1])
}
}
if name == "" {
return nil, fmt.Errorf("interface name not found in ifconfig output")
}
mtuInt, err := strconv.Atoi(mtu)
if err != nil {
return nil, fmt.Errorf("failed to parse MTU: %w", err)
}
return &iface{
Name: name,
MTU: mtuInt,
Group: group,
IPAddrs: ips,
}, nil
}
func parseIFName(output []byte) (string, error) {
// TODO: implement without allocations
lines := strings.Split(string(output), "\n")
if len(lines) == 0 || lines[0] == "" {
return "", fmt.Errorf("no output returned")
}
fields := strings.Fields(lines[0])
if len(fields) > 1 {
return "", fmt.Errorf("invalid output")
}
return fields[0], nil
}

View File

@ -0,0 +1,76 @@
package freebsd
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseIfconfigOutput(t *testing.T) {
testOutput := `wg1: flags=8080<NOARP,MULTICAST> metric 0 mtu 1420
options=80000<LINKSTATE>
groups: wg
nd6 options=109<PERFORMNUD,IFDISABLED,NO_DAD>`
expected := &iface{
Name: "wg1",
MTU: 1420,
Group: "wg",
}
result, err := parseIfconfigOutput(([]byte)(testOutput))
if err != nil {
t.Errorf("Error parsing ifconfig output: %v", err)
return
}
assert.Equal(t, expected.Name, result.Name, "Name should match")
assert.Equal(t, expected.MTU, result.MTU, "MTU should match")
assert.Equal(t, expected.Group, result.Group, "Group should match")
}
func TestParseIFName(t *testing.T) {
tests := []struct {
name string
output string
expected string
expectedErr error
}{
{
name: "ValidOutput",
output: "eth0\n",
expected: "eth0",
},
{
name: "ValidOutputOneLine",
output: "eth0",
expected: "eth0",
},
{
name: "EmptyOutput",
output: "",
expectedErr: fmt.Errorf("no output returned"),
},
{
name: "InvalidOutput",
output: "This is an invalid output\n",
expectedErr: fmt.Errorf("invalid output"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := parseIFName(([]byte)(test.output))
assert.Equal(t, test.expected, result, "Interface names should match")
if test.expectedErr != nil {
assert.NotNil(t, err, "Error should not be nil")
assert.EqualError(t, err, test.expectedErr.Error(), "Error messages should match")
} else {
assert.Nil(t, err, "Error should be nil")
}
})
}
}

View File

@ -0,0 +1,239 @@
package freebsd
import (
"bytes"
"errors"
"fmt"
"os/exec"
"strconv"
log "github.com/sirupsen/logrus"
)
const wgIFGroup = "wg"
// Link represents a network interface.
type Link struct {
name string
}
func NewLink(name string) *Link {
return &Link{
name: name,
}
}
// LinkByName retrieves a network interface by its name.
func LinkByName(name string) (*Link, error) {
out, err := exec.Command("ifconfig", name).CombinedOutput()
if err != nil {
if pErr := parseError(out); pErr != nil {
return nil, pErr
}
log.Debugf("ifconfig out: %s", out)
return nil, fmt.Errorf("command run: %w", err)
}
i, err := parseIfconfigOutput(out)
if err != nil {
return nil, fmt.Errorf("parse ifconfig output: %w", err)
}
if i.Name != name {
return nil, ErrNameDoesNotMatch
}
return &Link{name: i.Name}, nil
}
// Recreate - create new interface, remove current before create if it exists
func (l *Link) Recreate() error {
ok, err := l.isExist()
if err != nil {
return fmt.Errorf("is exist: %w", err)
}
if ok {
if err := l.del(l.name); err != nil {
return fmt.Errorf("del: %w", err)
}
}
return l.Add()
}
// Add creates a new network interface.
func (l *Link) Add() error {
parsedName, err := l.create(wgIFGroup)
if err != nil {
return fmt.Errorf("create link: %w", err)
}
if parsedName == l.name {
return nil
}
parsedName, err = l.rename(parsedName, l.name)
if err != nil {
errDel := l.del(parsedName)
if errDel != nil {
return fmt.Errorf("del on rename link: %w: %w", err, errDel)
}
return fmt.Errorf("rename link: %w", err)
}
return nil
}
// Del removes an existing network interface.
func (l *Link) Del() error {
return l.del(l.name)
}
// SetMTU sets the MTU of the network interface.
func (l *Link) SetMTU(mtu int) error {
return l.setMTU(mtu)
}
// AssignAddr assigns an IP address and netmask to the network interface.
func (l *Link) AssignAddr(ip, netmask string) error {
return l.setAddr(ip, netmask)
}
func (l *Link) Up() error {
return l.up(l.name)
}
func (l *Link) Down() error {
return l.down(l.name)
}
func (l *Link) isExist() (bool, error) {
_, err := LinkByName(l.name)
if errors.Is(err, ErrDoesNotExist) {
return false, nil
}
if err != nil {
return false, fmt.Errorf("link by name: %w", err)
}
return true, nil
}
func (l *Link) create(groupName string) (string, error) {
cmd := exec.Command("ifconfig", groupName, "create")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("ifconfig out: %s", output)
return "", fmt.Errorf("create %s interface: %w", groupName, err)
}
interfaceName, err := parseIFName(output)
if err != nil {
return "", fmt.Errorf("parse interface name: %w", err)
}
return interfaceName, nil
}
func (l *Link) rename(oldName, newName string) (string, error) {
cmd := exec.Command("ifconfig", oldName, "name", newName)
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("ifconfig out: %s", output)
return "", fmt.Errorf("change name %q -> %q: %w", oldName, newName, err)
}
interfaceName, err := parseIFName(output)
if err != nil {
return "", fmt.Errorf("parse new name: %w", err)
}
return interfaceName, nil
}
func (l *Link) del(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "destroy")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("destroy %s interface: %w", name, err)
}
return nil
}
func (l *Link) setMTU(mtu int) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", l.name, "mtu", strconv.Itoa(mtu))
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("set interface mtu: %w", err)
}
return nil
}
func (l *Link) setAddr(ip, netmask string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", l.name, "inet", ip, "netmask", netmask)
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("set interface addr: %w", err)
}
return nil
}
func (l *Link) up(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "up")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("up %s interface: %w", name, err)
}
return nil
}
func (l *Link) down(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "down")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("down %s interface: %w", name, err)
}
return nil
}

207
client/iface/iface.go Normal file
View File

@ -0,0 +1,207 @@
package iface
import (
"fmt"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
DefaultMTU = 1280
DefaultWgPort = 51820
WgInterfaceDefault = configurer.WgInterfaceDefault
)
type WGAddress = device.WGAddress
// WGIface represents an interface instance
type WGIface struct {
tun WGTunDevice
userspaceBind bool
mu sync.Mutex
configurer device.WGConfigurer
filter device.PacketFilter
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind
}
// Name returns the interface name
func (w *WGIface) Name() string {
return w.tun.DeviceName()
}
// Address returns the interface address
func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress()
}
// ToInterface returns the net.Interface for the Wireguard interface
func (r *WGIface) ToInterface() *net.Interface {
name := r.tun.DeviceName()
intf, err := net.InterfaceByName(name)
if err != nil {
log.Warnf("Failed to get interface by name %s: %v", name, err)
intf = &net.Interface{
Name: name,
}
}
return intf
}
// Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.Up()
}
// UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := device.ParseWGAddress(newAddr)
if err != nil {
return err
}
return w.tun.UpdateAddr(addr)
}
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.RemovePeer(peerKey)
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.AddAllowedIP(peerKey, allowedIP)
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
}
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
err := w.tun.Close()
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
}
err = w.waitUntilRemoved()
if err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy()
if err != nil {
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
}
log.Infof("interface %s successfully removed", w.Name())
}
return nil
}
// SetFilter sets packet filters for the userspace implementation
func (w *WGIface) SetFilter(filter device.PacketFilter) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun.FilteredDevice() == nil {
return fmt.Errorf("userspace packet filtering not handled on this device")
}
w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter)
return nil
}
// GetFilter returns packet filter used by interface if it uses userspace device implementation
func (w *WGIface) GetFilter() device.PacketFilter {
w.mu.Lock()
defer w.mu.Unlock()
return w.filter
}
// GetDevice to interact with raw device (with filtering)
func (w *WGIface) GetDevice() *device.FilteredDevice {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.FilteredDevice()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey)
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)
defer timeout.Stop()
for {
iface, err := net.InterfaceByName(w.Name())
if err != nil {
if _, ok := err.(*net.OpError); ok {
log.Infof("interface %s has been removed", w.Name())
return nil
}
log.Debugf("failed to get interface by name %s: %v", w.Name(), err)
} else if iface == nil {
log.Infof("interface %s has been removed", w.Name())
return nil
}
select {
case <-timeout.C:
return fmt.Errorf("timeout when waiting for interface %s to be removed", w.Name())
default:
time.Sleep(100 * time.Millisecond)
}
}
}

View File

@ -0,0 +1,43 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@ -0,0 +1,19 @@
//go:build (!android && !darwin) || ios
package iface
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create()
if err != nil {
return err
}
w.configurer = cfgr
return nil
}

View File

@ -0,0 +1,67 @@
//go:build !ios
package iface
import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
backOff := &backoff.ExponentialBackOff{
InitialInterval: 20 * time.Millisecond,
MaxElapsedTime: 500 * time.Millisecond,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
operation := func() error {
cfgr, err := w.tun.Create()
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
return backoff.Retry(operation, backOff)
}

View File

@ -0,0 +1,17 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package iface
import (
"fmt"
"os/exec"
)
func (w *WGIface) Destroy() error {
out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}

View File

@ -0,0 +1,22 @@
//go:build linux && !android
package iface
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (w *WGIface) Destroy() error {
link, err := netlink.LinkByName(w.Name())
if err != nil {
return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
}
if err := netlink.LinkDel(link); err != nil {
return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
}
return nil
}

View File

@ -0,0 +1,9 @@
//go:build android || (ios && !darwin)
package iface
import "errors"
func (w *WGIface) Destroy() error {
return errors.New("not supported on mobile")
}

View File

@ -0,0 +1,32 @@
//go:build windows
package iface
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (w *WGIface) Destroy() error {
netshCmd := GetSystem32Command("netsh")
out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

31
client/iface/iface_ios.go Normal file
View File

@ -0,0 +1,31 @@
//go:build ios
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

105
client/iface/iface_moc.go Normal file
View File

@ -0,0 +1,105 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}

529
client/iface/iface_test.go Normal file
View File

@ -0,0 +1,529 @@
package iface
import (
"fmt"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
)
// keep darwin compatibility
const (
WgIntNumber = 2000
)
var (
key string
peerPubKey string
)
func init() {
log.SetLevel(log.DebugLevel)
privateKey, _ := wgtypes.GeneratePrivateKey()
key = privateKey.String()
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
peerPubKey = peerPrivateKey.PublicKey().String()
}
func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
wgPort := 33100
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = iface.Close()
if err != nil {
t.Error(err)
}
}()
_, err = iface.Up()
if err != nil {
t.Fatal(err)
}
addrs, err := getIfaceAddrs(ifaceName)
if err != nil {
t.Error(err)
}
assert.Equal(t, addr, addrs[0].String())
//update WireGuard address
addr = "100.64.0.2/8"
err = iface.UpdateAddr(addr)
if err != nil {
t.Fatal(err)
}
addrs, err = getIfaceAddrs(ifaceName)
if err != nil {
t.Error(err)
}
var found bool
for _, a := range addrs {
prefix, err := netip.ParsePrefix(a.String())
assert.NoError(t, err)
if prefix.Addr().Is4() {
found = true
assert.Equal(t, addr, prefix.String())
}
}
if !found {
t.Fatal("v4 address not found")
}
}
func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
ief, err := net.InterfaceByName(ifaceName)
if err != nil {
return nil, err
}
addrs, err := ief.Addrs()
if err != nil {
return nil, err
}
return addrs, nil
}
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = iface.Close()
if err != nil {
t.Error(err)
}
}()
wg, err := wgctrl.New()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wg.Close()
if err != nil {
t.Error(err)
}
}()
}
func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
wg, err := wgctrl.New()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wg.Close()
if err != nil {
t.Error(err)
}
}()
err = iface.Close()
if err != nil {
t.Fatal(err)
}
}
func TestRecreation(t *testing.T) {
for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("down-%d", i), func(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
for {
_, err = net.InterfaceByName(ifaceName)
if err != nil {
t.Logf("interface %s not found: err: %s", ifaceName, err)
break
}
t.Logf("interface %s found", ifaceName)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
wg, err := wgctrl.New()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wg.Close()
if err != nil {
t.Error(err)
}
}()
_, err = iface.Up()
if err != nil {
t.Fatal(err)
}
for {
_, err = net.InterfaceByName(ifaceName)
if err == nil {
t.Logf("interface %s found", ifaceName)
break
}
t.Logf("interface %s not found: err: %s", ifaceName, err)
}
start := time.Now()
err = iface.Close()
t.Logf("down time: %s", time.Since(start))
if err != nil {
t.Fatal(err)
}
})
}
}
func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
wgPort := 33100
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = iface.Close()
if err != nil {
t.Error(err)
}
}()
_, err = iface.Up()
if err != nil {
t.Fatal(err)
}
wg, err := wgctrl.New()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wg.Close()
if err != nil {
t.Error(err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
t.Fatal(err)
}
if wgDevice.PrivateKey.String() != key {
t.Fatalf("Private keys don't match after configure: %s != %s", key, wgDevice.PrivateKey.String())
}
}
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = iface.Close()
if err != nil {
t.Error(err)
}
}()
_, err = iface.Up()
if err != nil {
t.Fatal(err)
}
keepAlive := 15 * time.Second
allowedIP := "10.99.99.10/32"
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900")
if err != nil {
t.Fatal(err)
}
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil)
if err != nil {
t.Fatal(err)
}
peer, err := getPeer(ifaceName, peerPubKey)
if err != nil {
t.Fatal(err)
}
if peer.PersistentKeepaliveInterval != keepAlive {
t.Fatal("configured peer with mismatched keepalive interval value")
}
if peer.Endpoint.String() != endpoint.String() {
t.Fatal("configured peer with mismatched endpoint")
}
var foundAllowedIP bool
for _, aip := range peer.AllowedIPs {
if aip.String() == allowedIP {
foundAllowedIP = true
break
}
}
if !foundAllowedIP {
t.Fatal("configured peer with mismatched Allowed IPs")
}
}
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = iface.Close()
if err != nil {
t.Error(err)
}
}()
_, err = iface.Up()
if err != nil {
t.Fatal(err)
}
keepAlive := 15 * time.Second
allowedIP := "10.99.99.14/32"
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface.RemovePeer(peerPubKey)
if err != nil {
t.Fatal(err)
}
_, err = getPeer(ifaceName, peerPubKey)
if err.Error() != "peer not found" {
t.Fatal(err)
}
}
func Test_ConnectPeers(t *testing.T) {
peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400)
peer1wgIP := "10.99.99.17/30"
peer1Key, _ := wgtypes.GeneratePrivateKey()
peer1wgPort := 33100
peer2ifaceName := "utun500"
peer2wgIP := "10.99.99.18/30"
peer2Key, _ := wgtypes.GeneratePrivateKey()
peer2wgPort := 33200
keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface1.Create()
if err != nil {
t.Fatal(err)
}
_, err = iface1.Up()
if err != nil {
t.Fatal(err)
}
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort))
if err != nil {
t.Fatal(err)
}
guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
err = iface2.Create()
if err != nil {
t.Fatal(err)
}
_, err = iface2.Up()
if err != nil {
t.Fatal(err)
}
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort))
if err != nil {
t.Fatal(err)
}
defer func() {
err = iface1.Close()
if err != nil {
t.Error(err)
}
err = iface2.Close()
if err != nil {
t.Error(err)
}
}()
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil)
if err != nil {
t.Fatal(err)
}
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil)
if err != nil {
t.Fatal(err)
}
// todo: investigate why in some tests execution we need 30s
timeout := 30 * time.Second
timeoutChannel := time.After(timeout)
for {
select {
case <-timeoutChannel:
t.Fatalf("waiting for peer handshake timeout after %s", timeout.String())
default:
}
peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String())
if gpErr != nil {
t.Fatal(gpErr)
}
if !peer.LastHandshakeTime.IsZero() {
t.Log("peers successfully handshake")
break
}
}
}
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, err
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, err
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}

View File

@ -0,0 +1,49 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"runtime"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
}

View File

@ -0,0 +1,41 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@ -0,0 +1,34 @@
//go:build !windows
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@ -0,0 +1,33 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error)
}

View File

@ -0,0 +1,7 @@
## Mocks
To generate (or refresh) mocks from iface package interfaces please install [mockgen](https://github.com/golang/mock).
Run this command to update PacketFilter mock:
```bash
mockgen -destination iface/mocks/filter.go -package mocks github.com/netbirdio/netbird/iface PacketFilter
```

View File

@ -0,0 +1,103 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@ -0,0 +1,87 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

152
client/iface/mocks/tun.go Normal file
View File

@ -0,0 +1,152 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: golang.zx2c4.com/wireguard/tun (interfaces: Device)
// Package mocks is a generated GoMock package.
package mocks
import (
os "os"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
tun "golang.zx2c4.com/wireguard/tun"
)
// MockDevice is a mock of Device interface.
type MockDevice struct {
ctrl *gomock.Controller
recorder *MockDeviceMockRecorder
}
// MockDeviceMockRecorder is the mock recorder for MockDevice.
type MockDeviceMockRecorder struct {
mock *MockDevice
}
// NewMockDevice creates a new mock instance.
func NewMockDevice(ctrl *gomock.Controller) *MockDevice {
mock := &MockDevice{ctrl: ctrl}
mock.recorder = &MockDeviceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDevice) EXPECT() *MockDeviceMockRecorder {
return m.recorder
}
// BatchSize mocks base method.
func (m *MockDevice) BatchSize() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchSize")
ret0, _ := ret[0].(int)
return ret0
}
// BatchSize indicates an expected call of BatchSize.
func (mr *MockDeviceMockRecorder) BatchSize() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchSize", reflect.TypeOf((*MockDevice)(nil).BatchSize))
}
// Close mocks base method.
func (m *MockDevice) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockDeviceMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDevice)(nil).Close))
}
// Events mocks base method.
func (m *MockDevice) Events() <-chan tun.Event {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Events")
ret0, _ := ret[0].(<-chan tun.Event)
return ret0
}
// Events indicates an expected call of Events.
func (mr *MockDeviceMockRecorder) Events() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Events", reflect.TypeOf((*MockDevice)(nil).Events))
}
// File mocks base method.
func (m *MockDevice) File() *os.File {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "File")
ret0, _ := ret[0].(*os.File)
return ret0
}
// File indicates an expected call of File.
func (mr *MockDeviceMockRecorder) File() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "File", reflect.TypeOf((*MockDevice)(nil).File))
}
// MTU mocks base method.
func (m *MockDevice) MTU() (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MTU")
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MTU indicates an expected call of MTU.
func (mr *MockDeviceMockRecorder) MTU() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MTU", reflect.TypeOf((*MockDevice)(nil).MTU))
}
// Name mocks base method.
func (m *MockDevice) Name() (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Name indicates an expected call of Name.
func (mr *MockDeviceMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockDevice)(nil).Name))
}
// Read mocks base method.
func (m *MockDevice) Read(arg0 [][]byte, arg1 []int, arg2 int) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read", arg0, arg1, arg2)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockDeviceMockRecorder) Read(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockDevice)(nil).Read), arg0, arg1, arg2)
}
// Write mocks base method.
func (m *MockDevice) Write(arg0 [][]byte, arg1 int) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0, arg1)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write.
func (mr *MockDeviceMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockDevice)(nil).Write), arg0, arg1)
}

View File

@ -0,0 +1,32 @@
package netstack
import (
"context"
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
)
type Dialer interface {
Dial(ctx context.Context, network, addr string) (net.Conn, error)
}
type NSDialer struct {
net *netstack.Net
}
func NewNSDialer(net *netstack.Net) *NSDialer {
return &NSDialer{
net: net,
}
}
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugf("dialing %s %s", network, addr)
conn, err := d.net.Dial(network, addr)
if err != nil {
log.Debugf("failed to deal connection: %s", err)
}
return conn, err
}

View File

@ -0,0 +1,33 @@
package netstack
import (
"fmt"
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
// IsEnabled todo: move these function to cmd layer
func IsEnabled() bool {
return os.Getenv("NB_USE_NETSTACK_MODE") == "true"
}
func ListenAddr() string {
sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT")
port, err := strconv.Atoi(sPort)
if err != nil {
log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port)
return listenAddr(DefaultSocks5Port)
}
if port < 1 || port > 65535 {
log.Warnf("invalid socks5 listener port, it should be in the range 1-65535, falling back to default: %d", DefaultSocks5Port)
return listenAddr(DefaultSocks5Port)
}
return listenAddr(port)
}
func listenAddr(port int) string {
return fmt.Sprintf("0.0.0.0:%d", port)
}

View File

@ -0,0 +1,65 @@
package netstack
import (
"net"
"github.com/things-go/go-socks5"
log "github.com/sirupsen/logrus"
)
const (
DefaultSocks5Port = 1080
)
// Proxy todo close server
type Proxy struct {
server *socks5.Server
listener net.Listener
closed bool
}
func NewSocks5(dialer Dialer) (*Proxy, error) {
server := socks5.NewServer(
socks5.WithDial(dialer.Dial),
)
return &Proxy{
server: server,
}, nil
}
func (s *Proxy) ListenAndServe(addr string) error {
listener, err := net.Listen("tcp", addr)
if err != nil {
log.Errorf("failed to create listener for socks5 proxy: %s", err)
return err
}
s.listener = listener
for {
conn, err := listener.Accept()
if err != nil {
if s.closed {
return nil
}
return err
}
go func() {
if err := s.server.ServeConn(conn); err != nil {
log.Errorf("failed to serve a connection: %s", err)
}
}()
}
}
func (s *Proxy) Close() error {
if s.listener == nil {
return nil
}
s.closed = true
return s.listener.Close()
}

View File

@ -0,0 +1,74 @@
package netstack
import (
"net/netip"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
)
type NetStackTun struct { //nolint:revive
address string
mtu int
listenAddress string
proxy *Proxy
tundev tun.Device
}
func NewNetStackTun(listenAddress string, address string, mtu int) *NetStackTun {
return &NetStackTun{
address: address,
mtu: mtu,
listenAddress: listenAddress,
}
}
func (t *NetStackTun) Create() (tun.Device, error) {
nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(t.address)},
[]netip.Addr{},
t.mtu)
if err != nil {
return nil, err
}
t.tundev = nsTunDev
dialer := NewNSDialer(tunNet)
t.proxy, err = NewSocks5(dialer)
if err != nil {
_ = t.tundev.Close()
return nil, err
}
go func() {
err := t.proxy.ListenAndServe(t.listenAddress)
if err != nil {
log.Errorf("error in socks5 proxy serving: %s", err)
}
}()
return nsTunDev, nil
}
func (t *NetStackTun) Close() error {
var err error
if t.proxy != nil {
pErr := t.proxy.Close()
if pErr != nil {
log.Errorf("failed to close socks5 proxy: %s", pErr)
err = pErr
}
}
if t.tundev != nil {
dErr := t.tundev.Close()
if dErr != nil {
log.Errorf("failed to close netstack tun device: %s", dErr)
err = dErr
}
}
return err
}