Fix lint and test issues

This commit is contained in:
Viktor Liu 2024-12-31 14:19:15 +01:00
parent 9feaa8d767
commit fb1a10755a
11 changed files with 64 additions and 52 deletions

View File

@ -1,7 +1,7 @@
package common
import (
device2 "golang.zx2c4.com/wireguard/device"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
@ -11,6 +11,6 @@ import (
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
GetWGDevice() *device2.Device
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@ -75,7 +75,7 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.established.Store(true)
t.connections[key] = conn
t.logger.Trace("New UDP connection: %s", conn)
t.logger.Trace("New UDP connection: %v", conn)
}
t.mutex.Unlock()
@ -127,7 +127,7 @@ func (t *UDPTracker) cleanup() {
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("UDP connection timed out: %s", conn)
t.logger.Trace("UDP connection timed out: %v", conn)
}
}
}

View File

@ -55,7 +55,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %w", err)
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
_, bits := iface.Address().Network.Mask.Size()
@ -68,7 +68,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %w", err)
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
@ -79,11 +79,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
return nil, fmt.Errorf("creating default subnet: %w", err)
}
if s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %w", err)
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
if s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %w", err)
if err := s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@ -132,7 +132,7 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() error {
func (f *Forwarder) Stop() {
f.cancel()
if f.udpForwarder != nil {
@ -141,6 +141,4 @@ func (f *Forwarder) Stop() error {
f.stack.Close()
f.stack.Wait()
return nil
}

View File

@ -131,7 +131,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
}
f.udpForwarder.RLock()
pConn, exists := f.udpForwarder.conns[id]
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id)
@ -159,7 +159,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn = &udpPacketConn{
pConn := &udpPacketConn{
conn: inConn,
outConn: outConn,
cancel: connCancel,

View File

@ -42,27 +42,6 @@ var levelStrings = map[Level]string{
LevelTrace: "TRAC",
}
func FromLogrusLevel(level log.Level) Level {
switch level {
case log.TraceLevel:
return LevelTrace
case log.DebugLevel:
return LevelDebug
case log.InfoLevel:
return LevelInfo
case log.WarnLevel:
return LevelWarn
case log.ErrorLevel:
return LevelError
case log.FatalLevel:
return LevelFatal
case log.PanicLevel:
return LevelPanic
default:
return LevelInfo
}
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
@ -128,7 +107,7 @@ func (l *Logger) log(level Level, format string, args ...interface{}) {
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
l.buffer.Write(*bufp)
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
@ -184,7 +163,7 @@ func (l *Logger) worker() {
}
// Write batch
l.output.Write(buf[:n])
_, _ = l.output.Write(buf[:n])
}
}
}

View File

@ -83,11 +83,3 @@ func (r *ringBuffer) Read(p []byte) (n int, err error) {
return n, nil
}
// min returns the smaller of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@ -555,10 +555,6 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true
}
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:

View File

@ -10,6 +10,7 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
@ -22,6 +23,14 @@ type IFaceMock struct {
AddressFunc func() iface.WGAddress
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
return nil
}
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
return nil
}
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented")

View File

@ -1,6 +1,8 @@
package iface
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
@ -13,4 +15,5 @@ type WGTunDevice interface {
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
}

View File

@ -4,6 +4,7 @@ import (
"net"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
@ -29,6 +30,7 @@ type MockWGIface struct {
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetWGDeviceFunc func() *wgdevice.Device
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
@ -102,11 +104,14 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
return m.GetWGDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
return m.GetProxyFunc()
}

View File

@ -8,6 +8,8 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
}
// GetDevice mocks base method.
func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDevice")
ret0, _ := ret[0].(*device.FilteredDevice)
return ret0
}
// GetDevice indicates an expected call of GetDevice.
func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice))
}
// GetWGDevice mocks base method.
func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWGDevice")
ret0, _ := ret[0].(*wgdevice.Device)
return ret0
}
// GetWGDevice indicates an expected call of GetWGDevice.
func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice))
}