diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 2dbeb106a..664e8be18 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -44,4 +44,5 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) + diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 5f7d7b4a3..ba5f66746 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -134,7 +134,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) test_management: needs: [ build-cache ] @@ -194,7 +194,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) benchmark: needs: [ build-cache ] @@ -254,7 +254,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... api_benchmark: needs: [ build-cache ] diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 3a3c47052..782e4c30a 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -65,7 +65,7 @@ jobs: - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d774f4538..2592ff840 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -19,8 +19,7 @@ const ( tableName = "filter" // rules chains contains the effective ACL rules - chainNameInputRules = "NETBIRD-ACL-INPUT" - chainNameOutputRules = "NETBIRD-ACL-OUTPUT" + chainNameInputRules = "NETBIRD-ACL-INPUT" ) type aclEntries map[string][][]string @@ -84,7 +83,6 @@ func (m *aclManager) AddPeerFiltering( protocol firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, - direction firewall.RuleDirection, action firewall.Action, ipsetName string, ) ([]firewall.Rule, error) { @@ -97,15 +95,10 @@ func (m *aclManager) AddPeerFiltering( sPortVal = strconv.Itoa(sPort.Values[0]) } - var chain string - if direction == firewall.RuleDirectionOUT { - chain = chainNameOutputRules - } else { - chain = chainNameInputRules - } + chain := chainNameInputRules ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal) - specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) + specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName) if ipsetName != "" { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if err := ipset.Add(ipsetName, ip.String()); err != nil { @@ -214,28 +207,7 @@ func (m *aclManager) Reset() error { // todo write less destructive cleanup mechanism func (m *aclManager) cleanChains() error { - ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) - if err != nil { - log.Debugf("failed to list chains: %s", err) - return err - } - if ok { - rules := m.entries["OUTPUT"] - for _, rule := range rules { - err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...) - if err != nil { - log.Errorf("failed to delete rule: %v, %s", rule, err) - } - } - - err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules) - if err != nil { - log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err) - return err - } - } - - ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules) + ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules) if err != nil { log.Debugf("failed to list chains: %s", err) return err @@ -295,12 +267,6 @@ func (m *aclManager) createDefaultChains() error { return err } - // chain netbird-acl-output-rules - if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil { - log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err) - return err - } - for chainName, rules := range m.entries { for _, rule := range rules { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { @@ -329,8 +295,6 @@ func (m *aclManager) createDefaultChains() error { // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. - -// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. func (m *aclManager) seedInitialEntries() { established := getConntrackEstablished() @@ -390,30 +354,18 @@ func (m *aclManager) updateState() { } // filterRuleSpecs returns the specs of a filtering rule -func filterRuleSpecs( - ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, -) (specs []string) { +func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) { matchByIP := true // don't use IP matching if IP is ip 0.0.0.0 if ip.String() == "0.0.0.0" { matchByIP = false } - switch direction { - case firewall.RuleDirectionIN: - if matchByIP { - if ipsetName != "" { - specs = append(specs, "-m", "set", "--set", ipsetName, "src") - } else { - specs = append(specs, "-s", ip.String()) - } - } - case firewall.RuleDirectionOUT: - if matchByIP { - if ipsetName != "" { - specs = append(specs, "-m", "set", "--set", ipsetName, "dst") - } else { - specs = append(specs, "-d", ip.String()) - } + + if matchByIP { + if ipsetName != "" { + specs = append(specs, "-m", "set", "--set", ipsetName, "src") + } else { + specs = append(specs, "-s", ip.String()) } } if protocol != "all" { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 8f7084bca..679f288e3 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering( protocol firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, - direction firewall.RuleDirection, action firewall.Action, ipsetName string, - comment string, + _ string, ) ([]firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) + return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error { "all", nil, nil, - firewall.RuleDirectionIN, firewall.ActionAccept, "", "", diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ebdb83137..fe0bc86de 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -68,27 +68,13 @@ func TestIptablesManager(t *testing.T) { time.Sleep(time.Second) }() - var rule1 []fw.Rule - t.Run("add first rule", func(t *testing.T) { - ip := net.ParseIP("10.20.0.2") - port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") - require.NoError(t, err, "failed to add rule") - - for _, r := range rule1 { - checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...) - } - - }) - var rule2 []fw.Rule t.Run("add second rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.3") port := &fw.Port{ Values: []int{8043: 8046}, } - rule2, err = manager.AddPeerFiltering( - ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") + rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") for _, r := range rule2 { @@ -97,15 +83,6 @@ func TestIptablesManager(t *testing.T) { } }) - t.Run("delete first rule", func(t *testing.T) { - for _, r := range rule1 { - err := manager.DeletePeerRule(r) - require.NoError(t, err, "failed to delete rule") - - checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) - } - }) - t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { err := manager.DeletePeerRule(r) @@ -119,7 +96,7 @@ func TestIptablesManager(t *testing.T) { // add second rule ip := net.ParseIP("10.20.0.3") port := &fw.Port{Values: []int{5353}} - _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") + _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") err = manager.Reset(nil) @@ -135,9 +112,6 @@ func TestIptablesManager(t *testing.T) { } func TestIptablesManagerIPSet(t *testing.T) { - ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) - require.NoError(t, err) - mock := &iFaceMock{ NameFunc: func() string { return "lo" @@ -167,33 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) { time.Sleep(time.Second) }() - var rule1 []fw.Rule - t.Run("add first rule with set", func(t *testing.T) { - ip := net.ParseIP("10.20.0.2") - port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddPeerFiltering( - ip, "tcp", nil, port, fw.RuleDirectionOUT, - fw.ActionAccept, "default", "accept HTTP traffic", - ) - require.NoError(t, err, "failed to add rule") - - for _, r := range rule1 { - checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...) - require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set") - require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set") - } - }) - var rule2 []fw.Rule t.Run("add second rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.3") port := &fw.Port{ Values: []int{443}, } - rule2, err = manager.AddPeerFiltering( - ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, - "default", "accept HTTPS traffic from ports range", - ) + rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range") for _, r := range rule2 { require.NoError(t, err, "failed to add rule") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") @@ -201,15 +155,6 @@ func TestIptablesManagerIPSet(t *testing.T) { } }) - t.Run("delete first rule", func(t *testing.T) { - for _, r := range rule1 { - err := manager.DeletePeerRule(r) - require.NoError(t, err, "failed to delete rule") - - require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") - } - }) - t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { err := manager.DeletePeerRule(r) @@ -270,11 +215,7 @@ func TestIptablesCreatePerformance(t *testing.T) { start := time.Now() for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} - if i%2 == 0 { - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") - } else { - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") - } + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 247e55686..de25ff1f1 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -69,7 +69,6 @@ type Manager interface { proto Protocol, sPort *Port, dPort *Port, - direction RuleDirection, action Action, ipsetName string, comment string, diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 852cfec8d..8c1d89e68 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -22,8 +22,7 @@ import ( const ( // rules chains contains the effective ACL rules - chainNameInputRules = "netbird-acl-input-rules" - chainNameOutputRules = "netbird-acl-output-rules" + chainNameInputRules = "netbird-acl-input-rules" // filter chains contains the rules that jump to the rules chains chainNameInputFilter = "netbird-acl-input-filter" @@ -45,9 +44,8 @@ type AclManager struct { wgIface iFaceMapper routingFwChainName string - workTable *nftables.Table - chainInputRules *nftables.Chain - chainOutputRules *nftables.Chain + workTable *nftables.Table + chainInputRules *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -89,7 +87,6 @@ func (m *AclManager) AddPeerFiltering( proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, - direction firewall.RuleDirection, action firewall.Action, ipsetName string, comment string, @@ -104,7 +101,7 @@ func (m *AclManager) AddPeerFiltering( } newRules := make([]firewall.Rule, 0, 2) - ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment) + ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment) if err != nil { return nil, err } @@ -214,38 +211,6 @@ func (m *AclManager) createDefaultAllowRules() error { Exprs: expIn, }) - expOut := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - // mask - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: []byte{0, 0, 0, 0}, - Xor: []byte{0, 0, 0, 0}, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - _ = m.rConn.InsertRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainOutputRules, - Position: 0, - Exprs: expOut, - }) - if err := m.rConn.Flush(); err != nil { return fmt.Errorf(flushError, err) } @@ -264,15 +229,19 @@ func (m *AclManager) Flush() error { log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) } - if err := m.refreshRuleHandles(m.chainOutputRules); err != nil { - log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) - } - return nil } -func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { - ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) +func (m *AclManager) addIOFiltering( + ip net.IP, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, + ipset *nftables.Set, + comment string, +) (*Rule, error) { + ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ r.nftRule, @@ -310,9 +279,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f if !bytes.HasPrefix(anyIP, rawIP) { // source address position addrOffset := uint32(12) - if direction == firewall.RuleDirectionOUT { - addrOffset += 4 // is ipv4 address length - } expressions = append(expressions, &expr.Payload{ @@ -383,12 +349,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f userData := []byte(strings.Join([]string{ruleId, comment}, " ")) - var chain *nftables.Chain - if direction == firewall.RuleDirectionIN { - chain = m.chainInputRules - } else { - chain = m.chainOutputRules - } + chain := m.chainInputRules nftRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: chain, @@ -419,15 +380,6 @@ func (m *AclManager) createDefaultChains() (err error) { } m.chainInputRules = chain - // chainNameOutputRules - chain = m.createChain(chainNameOutputRules) - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err) - return err - } - m.chainOutputRules = chain - // netbird-acl-input-filter // type filter hook input priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) @@ -720,15 +672,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { return nil } -func generatePeerRuleId( - ip net.IP, - sPort *firewall.Port, - dPort *firewall.Port, - direction firewall.RuleDirection, - action firewall.Action, - ipset *nftables.Set, -) string { - rulesetID := ":" + strconv.Itoa(int(direction)) + ":" +func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { + rulesetID := ":" if sPort != nil { rulesetID += sPort.String() } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 76390d30a..4fe52bd53 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering( proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, - direction firewall.RuleDirection, action firewall.Action, ipsetName string, comment string, @@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering( return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) } - return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) + return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment) } -func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { +func (m *Manager) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 33fdc4b3d..9c9637282 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) { testClient := &nftables.Conn{} - rule, err := manager.AddPeerFiltering( - ip, - fw.ProtocolTCP, - nil, - &fw.Port{Values: []int{53}}, - fw.RuleDirectionIN, - fw.ActionDrop, - "", - "", - ) + rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "") require.NoError(t, err, "failed to add rule") err = manager.Flush() @@ -210,11 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { start := time.Now() for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} - if i%2 == 0 { - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") - } else { - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") - } + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") if i%100 == 0 { @@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { }) ip := net.ParseIP("100.96.0.1") - _, err = manager.AddPeerFiltering( - ip, - fw.ProtocolTCP, - nil, - &fw.Port{Values: []int{80}}, - fw.RuleDirectionIN, - fw.ActionAccept, - "", - "test rule", - ) + _, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule") require.NoError(t, err, "failed to add peer filtering rule") _, err = manager.AddRouteFiltering( diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index 3d199ce65..03a2573d6 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -16,7 +16,6 @@ type PeerRule struct { ipLayer gopacket.LayerType matchByIP bool protoLayer gopacket.LayerType - direction firewall.RuleDirection sPort uint16 dPort uint16 drop bool diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 94a2f45d2..99a3dcee0 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -45,7 +45,9 @@ type RuleSet map[string]PeerRule // Manager userspace firewall manager type Manager struct { - outgoingRules map[string]RuleSet + // outgoingRules is used for hooks only + outgoingRules map[string]RuleSet + // incomingRules is used for filtering and hooks incomingRules map[string]RuleSet routeRules map[string]RouteRule wgNetwork *net.IPNet @@ -297,9 +299,8 @@ func (m *Manager) AddPeerFiltering( proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, - direction firewall.RuleDirection, action firewall.Action, - ipsetName string, + _ string, comment string, ) ([]firewall.Rule, error) { r := PeerRule{ @@ -307,7 +308,6 @@ func (m *Manager) AddPeerFiltering( ip: ip, ipLayer: layers.LayerTypeIPv6, matchByIP: true, - direction: direction, drop: action == firewall.ActionDrop, comment: comment, } @@ -343,17 +343,10 @@ func (m *Manager) AddPeerFiltering( } m.mutex.Lock() - if direction == firewall.RuleDirectionIN { - if _, ok := m.incomingRules[r.ip.String()]; !ok { - m.incomingRules[r.ip.String()] = make(RuleSet) - } - m.incomingRules[r.ip.String()][r.id] = r - } else { - if _, ok := m.outgoingRules[r.ip.String()]; !ok { - m.outgoingRules[r.ip.String()] = make(RuleSet) - } - m.outgoingRules[r.ip.String()][r.id] = r + if _, ok := m.incomingRules[r.ip.String()]; !ok { + m.incomingRules[r.ip.String()] = make(RuleSet) } + m.incomingRules[r.ip.String()][r.id] = r m.mutex.Unlock() return []firewall.Rule{&r}, nil } @@ -416,19 +409,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } - if r.direction == firewall.RuleDirectionIN { - _, ok := m.incomingRules[r.ip.String()][r.id] - if !ok { - return fmt.Errorf("delete rule: no rule with such id: %v", r.id) - } - delete(m.incomingRules[r.ip.String()], r.id) - } else { - _, ok := m.outgoingRules[r.ip.String()][r.id] - if !ok { - return fmt.Errorf("delete rule: no rule with such id: %v", r.id) - } - delete(m.outgoingRules[r.ip.String()], r.id) + if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok { + return fmt.Errorf("delete rule: no rule with such id: %v", r.id) } + delete(m.incomingRules[r.ip.String()], r.id) return nil } @@ -918,7 +902,6 @@ func (m *Manager) AddUDPPacketHook( protoLayer: layers.LayerTypeUDP, dPort: dPort, ipLayer: layers.LayerTypeIPv6, - direction: firewall.RuleDirectionOUT, comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), udpHook: hook, } @@ -929,7 +912,6 @@ func (m *Manager) AddUDPPacketHook( m.mutex.Lock() if in { - r.direction = firewall.RuleDirectionIN if _, ok := m.incomingRules[r.ip.String()]; !ok { m.incomingRules[r.ip.String()] = make(map[string]PeerRule) } @@ -948,19 +930,22 @@ func (m *Manager) AddUDPPacketHook( // RemovePacketHook removes packet hook by given ID func (m *Manager) RemovePacketHook(hookID string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + for _, arr := range m.incomingRules { for _, r := range arr { if r.id == hookID { - rule := r - return m.DeletePeerRule(&rule) + delete(arr, r.id) + return nil } } } for _, arr := range m.outgoingRules { for _, r := range arr { if r.id == hookID { - rule := r - return m.DeletePeerRule(&rule) + delete(arr, r.id) + return nil } } } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 684057d24..92f72f839 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) { setupFunc: func(m *Manager) { // Single rule allowing all traffic _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, - fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") + fw.ActionAccept, "", "allow all") require.NoError(b, err) }, desc: "Baseline: Single 'allow all' rule without connection tracking", @@ -117,7 +117,7 @@ func BenchmarkCoreFiltering(b *testing.B) { _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, &fw.Port{Values: []int{1024 + i}}, &fw.Port{Values: []int{80}}, - fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") + fw.ActionAccept, "", "explicit return") require.NoError(b, err) } }, @@ -129,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) { setupFunc: func(m *Manager) { // Add some basic rules but rely on state for established connections _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, - fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") + fw.ActionDrop, "", "default drop") require.NoError(b, err) }, desc: "Connection tracking with established connections", @@ -593,7 +593,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []int{80}}, nil, - fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + fw.ActionAccept, "", "return traffic") require.NoError(b, err) } @@ -684,7 +684,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []int{80}}, nil, - fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + fw.ActionAccept, "", "return traffic") require.NoError(b, err) } @@ -802,7 +802,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []int{80}}, nil, - fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + fw.ActionAccept, "", "return traffic") require.NoError(b, err) } @@ -889,7 +889,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []int{80}}, nil, - fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + fw.ActionAccept, "", "return traffic") require.NoError(b, err) } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index ecfc6bf96..6e3e96255 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -91,11 +91,10 @@ func TestManagerAddPeerFiltering(t *testing.T) { ip := net.ParseIP("192.168.1.1") proto := fw.ProtocolTCP port := &fw.Port{Values: []int{80}} - direction := fw.RuleDirectionOUT action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -126,37 +125,15 @@ func TestManagerDeleteRule(t *testing.T) { ip := net.ParseIP("192.168.1.1") proto := fw.ProtocolTCP port := &fw.Port{Values: []int{80}} - direction := fw.RuleDirectionOUT action := fw.ActionDrop - comment := "Test rule" + comment := "Test rule 2" - rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) + rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return } - ip = net.ParseIP("192.168.1.1") - proto = fw.ProtocolTCP - port = &fw.Port{Values: []int{80}} - direction = fw.RuleDirectionIN - action = fw.ActionDrop - comment = "Test rule 2" - - rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) - if err != nil { - t.Errorf("failed to add filtering: %v", err) - return - } - - for _, r := range rule { - err = m.DeletePeerRule(r) - if err != nil { - t.Errorf("failed to delete rule: %v", err) - return - } - } - for _, r := range rule2 { if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { t.Errorf("rule2 is not in the incomingRules") @@ -246,10 +223,6 @@ func TestAddUDPPacketHook(t *testing.T) { t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) return } - if tt.expDir != addedRule.direction { - t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction) - return - } if addedRule.udpHook == nil { t.Errorf("expected udpHook to be set") return @@ -272,11 +245,10 @@ func TestManagerReset(t *testing.T) { ip := net.ParseIP("192.168.1.1") proto := fw.ProtocolTCP port := &fw.Port{Values: []int{80}} - direction := fw.RuleDirectionOUT action := fw.ActionDrop comment := "Test rule" - _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -319,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) { ip := net.ParseIP("0.0.0.0") proto := fw.ProtocolUDP - direction := fw.RuleDirectionIN action := fw.ActionAccept comment := "Test rule" - _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -523,11 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { start := time.Now() for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} - if i%2 == 0 { - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") - } else { - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") - } + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") } diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5bb0905d2..0ade5d7ce 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { d.rollBack(newRulePairs) break } - if len(rules) > 0 { + if len(rulePair) > 0 { d.peerRulesPairs[pairID] = rulePair newRulePairs[pairID] = rulePair } @@ -288,6 +288,8 @@ func (d *DefaultManager) protoRuleToFirewallRule( case mgmProto.RuleDirection_IN: rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") case mgmProto.RuleDirection_OUT: + // TODO: Remove this soon. Outbound rules are obsolete. + // We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") default: return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") @@ -308,25 +310,12 @@ func (d *DefaultManager) addInRules( ipsetName string, comment string, ) ([]firewall.Rule, error) { - var rules []firewall.Rule - rule, err := d.firewall.AddPeerFiltering( - ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) + rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment) if err != nil { - return nil, fmt.Errorf("failed to add firewall rule: %v", err) - } - rules = append(rules, rule...) - - if shouldSkipInvertedRule(protocol, port) { - return rules, nil + return nil, fmt.Errorf("add firewall rule: %w", err) } - rule, err = d.firewall.AddPeerFiltering( - ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) - if err != nil { - return nil, fmt.Errorf("failed to add firewall rule: %v", err) - } - - return append(rules, rule...), nil + return rule, nil } func (d *DefaultManager) addOutRules( @@ -337,25 +326,16 @@ func (d *DefaultManager) addOutRules( ipsetName string, comment string, ) ([]firewall.Rule, error) { - var rules []firewall.Rule - rule, err := d.firewall.AddPeerFiltering( - ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) - if err != nil { - return nil, fmt.Errorf("failed to add firewall rule: %v", err) - } - rules = append(rules, rule...) - if shouldSkipInvertedRule(protocol, port) { - return rules, nil + return nil, nil } - rule, err = d.firewall.AddPeerFiltering( - ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) + rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment) if err != nil { - return nil, fmt.Errorf("failed to add firewall rule: %v", err) + return nil, fmt.Errorf("add firewall rule: %w", err) } - return append(rules, rule...), nil + return rule, nil } // getPeerRuleID() returns unique ID for the rule based on its parameters. diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index d146fef1f..217dbce9f 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -120,8 +120,8 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRulesIsEmpty = false acl.ApplyFiltering(networkMap) - if len(acl.peerRulesPairs) != 2 { - t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) + if len(acl.peerRulesPairs) != 1 { + t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return } }) @@ -358,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.peerRulesPairs) != 4 { - t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) + if len(acl.peerRulesPairs) != 3 { + t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) return } } diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 7bd4aec64..297d50822 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -48,11 +48,17 @@ type restoreHostManager interface { func newHostManager(wgInterface string) (hostManager, error) { osManager, err := getOSDNSManagerType() if err != nil { - return nil, err + return nil, fmt.Errorf("get os dns manager type: %w", err) } log.Infof("System DNS manager discovered: %s", osManager) - return newHostManagerFromType(wgInterface, osManager) + mgr, err := newHostManagerFromType(wgInterface, osManager) + // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value + if err != nil { + return nil, fmt.Errorf("create host manager: %w", err) + } + + return mgr, nil } func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index bb097c4cb..1fe913fd9 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -12,6 +12,7 @@ import ( "github.com/mitchellh/hashstructure/v2" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -239,7 +240,10 @@ func (s *DefaultServer) Initialize() (err error) { s.stateManager.RegisterState(&ShutdownState{}) - if s.disableSys { + // use noop host manager if requested or running in netstack mode. + // Netstack mode currently doesn't have a way to receive DNS requests. + // TODO: Use listener on localhost in netstack mode when running as root. + if s.disableSys || netstack.IsEnabled() { log.Info("system DNS is disabled, not setting up host manager") s.hostManager = &noopHostConfigurator{} return nil diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index e6dfd278e..968f2d398 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error { return nil } - dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err diff --git a/client/internal/engine.go b/client/internal/engine.go index 83b993eb1..62144efaf 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -499,7 +499,6 @@ func (e *Engine) initFirewall() error { manager.ProtocolUDP, nil, &port, - manager.RuleDirectionIN, manager.ActionAccept, "", "", diff --git a/client/ssh/login.go b/client/ssh/login.go index 578f58775..d1d56ceb0 100644 --- a/client/ssh/login.go +++ b/client/ssh/login.go @@ -4,13 +4,27 @@ import ( "fmt" "net" "net/netip" + "os" "os/exec" "runtime" "github.com/netbirdio/netbird/util" ) +func isRoot() bool { + return os.Geteuid() == 0 +} + func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) { + if !isRoot() { + shell := getUserShell(user) + if shell == "" { + shell = "/bin/sh" + } + + return shell, []string{"-l"}, nil + } + loginPath, err = exec.LookPath("login") if err != nil { return "", nil, err diff --git a/client/ssh/lookup.go b/client/ssh/lookup.go index 7acef8f0b..9a7f6ff2e 100644 --- a/client/ssh/lookup.go +++ b/client/ssh/lookup.go @@ -6,5 +6,9 @@ package ssh import "os/user" func userNameLookup(username string) (*user.User, error) { + if username == "" || (username == "root" && !isRoot()) { + return user.Current() + } + return user.Lookup(username) } diff --git a/client/ssh/lookup_darwin.go b/client/ssh/lookup_darwin.go index e6f3c3b93..913d049dc 100644 --- a/client/ssh/lookup_darwin.go +++ b/client/ssh/lookup_darwin.go @@ -12,6 +12,10 @@ import ( ) func userNameLookup(username string) (*user.User, error) { + if username == "" || (username == "root" && !isRoot()) { + return user.Current() + } + var userObject *user.User userObject, err := user.Lookup(username) if err != nil && err.Error() == user.UnknownUserError(username).Error() { diff --git a/client/system/info_android.go b/client/system/info_android.go index 7718da913..2d44a6f52 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -39,6 +39,9 @@ func GetInfo(ctx context.Context) *Info { WiretrusteeVersion: version.NetbirdVersion(), UIVersion: extractUIVersion(ctx), KernelVersion: kernelVersion, + SystemSerialNumber: serial(), + SystemProductName: productModel(), + SystemManufacturer: productManufacturer(), } return gio @@ -49,13 +52,42 @@ func checkFileAndProcess(paths []string) ([]File, error) { return []File{}, nil } +func serial() string { + // try to fetch serial ID using different properties + properties := []string{"ril.serialnumber", "ro.serialno", "ro.boot.serialno", "sys.serialnumber"} + var value string + + for _, property := range properties { + value = getprop(property) + if len(value) > 0 { + return value + } + } + + // unable to get serial ID, fallback to ANDROID_ID + return androidId() +} + +func androidId() string { + // this is a uniq id defined on first initialization, id will be a new one if user wipes his device + return run("/system/bin/settings", "get", "secure", "android_id") +} + +func productModel() string { + return getprop("ro.product.model") +} + +func productManufacturer() string { + return getprop("ro.product.manufacturer") +} + func uname() []string { res := run("/system/bin/uname", "-a") return strings.Split(res, " ") } func osVersion() string { - return run("/system/bin/getprop", "ro.build.version.release") + return getprop("ro.build.version.release") } func extractUIVersion(ctx context.Context) string { @@ -66,6 +98,10 @@ func extractUIVersion(ctx context.Context) string { return v } +func getprop(arg ...string) string { + return run("/system/bin/getprop", arg...) +} + func run(name string, arg ...string) string { cmd := exec.Command(name, arg...) cmd.Stdin = strings.NewReader("some") diff --git a/go.mod b/go.mod index be5906e70..6b7d479a8 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/quic-go/quic-go v0.48.2 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -156,11 +157,13 @@ require ( github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/btree v1.1.2 // indirect + github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect @@ -222,6 +225,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect diff --git a/go.sum b/go.sum index 723388ce7..d6e025d40 100644 --- a/go.sum +++ b/go.sum @@ -405,6 +405,7 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -610,6 +611,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= +github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -761,6 +764,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= @@ -970,6 +975,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/relay/client/client.go b/relay/client/client.go index db5252f50..3c23b70d2 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -10,6 +10,8 @@ import ( log "github.com/sirupsen/logrus" auth "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/client/dialer" + "github.com/netbirdio/netbird/relay/client/dialer/quic" "github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" @@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) { msg.Free() default: msg.Free() - cc.log.Infof("message queue is full") - // todo consider to close the connection } } @@ -179,8 +179,7 @@ func (c *Client) Connect() error { return nil } - err := c.connect() - if err != nil { + if err := c.connect(); err != nil { return err } @@ -264,14 +263,14 @@ func (c *Client) Close() error { } func (c *Client) connect() error { - conn, err := ws.Dial(c.connectionURL) + rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) + conn, err := rd.Dial() if err != nil { return err } c.relayConn = conn - err = c.handShake() - if err != nil { + if err = c.handShake(); err != nil { cErr := conn.Close() if cErr != nil { c.log.Errorf("failed to close connection: %s", cErr) @@ -306,7 +305,7 @@ func (c *Client) handShake() error { return fmt.Errorf("validate version: %w", err) } - msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineServerMessageType(buf[:n]) if err != nil { c.log.Errorf("failed to determine message type: %s", err) return err @@ -317,7 +316,7 @@ func (c *Client) handShake() error { return fmt.Errorf("unexpected message type") } - addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n]) + addr, err := messages.UnmarshalAuthResponse(buf[:n]) if err != nil { return err } @@ -345,27 +344,30 @@ func (c *Client) readLoop(relayConn net.Conn) { c.log.Infof("start to Relay read loop exit") c.mu.Lock() if c.serviceIsRunning && !internallyStoppedFlag.isSet() { - c.log.Debugf("failed to read message from relay server: %s", errExit) + c.log.Errorf("failed to read message from relay server: %s", errExit) } c.mu.Unlock() + c.bufPool.Put(bufPtr) break } - _, err := messages.ValidateVersion(buf[:n]) + buf = buf[:n] + + _, err := messages.ValidateVersion(buf) if err != nil { c.log.Errorf("failed to validate protocol version: %s", err) c.bufPool.Put(bufPtr) continue } - msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineServerMessageType(buf) if err != nil { c.log.Errorf("failed to determine message type: %s", err) c.bufPool.Put(bufPtr) continue } - if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) { + if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) { break } } diff --git a/relay/client/dialer/net/err.go b/relay/client/dialer/net/err.go new file mode 100644 index 000000000..fee844963 --- /dev/null +++ b/relay/client/dialer/net/err.go @@ -0,0 +1,7 @@ +package net + +import "errors" + +var ( + ErrClosedByServer = errors.New("closed by server") +) diff --git a/relay/client/dialer/quic/conn.go b/relay/client/dialer/quic/conn.go new file mode 100644 index 000000000..d64633c8c --- /dev/null +++ b/relay/client/dialer/quic/conn.go @@ -0,0 +1,97 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" + + netErr "github.com/netbirdio/netbird/relay/client/dialer/net" +) + +const ( + Network = "quic" +) + +type Addr struct { + addr string +} + +func (a Addr) Network() string { + return Network +} + +func (a Addr) String() string { + return a.addr +} + +type Conn struct { + session quic.Connection + ctx context.Context +} + +func NewConn(session quic.Connection) net.Conn { + return &Conn{ + session: session, + ctx: context.Background(), + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + dgram, err := c.session.ReceiveDatagram(c.ctx) + if err != nil { + return 0, c.remoteCloseErrHandling(err) + } + + n = copy(b, dgram) + return n, nil +} + +func (c *Conn) Write(b []byte) (int, error) { + err := c.session.SendDatagram(b) + if err != nil { + err = c.remoteCloseErrHandling(err) + log.Errorf("failed to write to QUIC stream: %v", err) + return 0, err + } + return len(b), nil +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.session.RemoteAddr() +} + +func (c *Conn) LocalAddr() net.Addr { + if c.session != nil { + return c.session.LocalAddr() + } + return Addr{addr: "unknown"} +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return fmt.Errorf("SetReadDeadline is not implemented") +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("SetWriteDeadline is not implemented") +} + +func (c *Conn) SetDeadline(t time.Time) error { + return nil +} + +func (c *Conn) Close() error { + return c.session.CloseWithError(0, "normal closure") +} + +func (c *Conn) remoteCloseErrHandling(err error) error { + var appErr *quic.ApplicationError + if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 { + return netErr.ErrClosedByServer + } + return err +} diff --git a/relay/client/dialer/quic/quic.go b/relay/client/dialer/quic/quic.go new file mode 100644 index 000000000..593d1334b --- /dev/null +++ b/relay/client/dialer/quic/quic.go @@ -0,0 +1,71 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "time" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" + + quictls "github.com/netbirdio/netbird/relay/tls" + nbnet "github.com/netbirdio/netbird/util/net" +) + +type Dialer struct { +} + +func (d Dialer) Protocol() string { + return Network +} + +func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { + quicURL, err := prepareURL(address) + if err != nil { + return nil, err + } + + quicConfig := &quic.Config{ + KeepAlivePeriod: 30 * time.Second, + MaxIdleTimeout: 4 * time.Minute, + EnableDatagrams: true, + } + + udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + log.Errorf("failed to listen on UDP: %s", err) + return nil, err + } + + udpAddr, err := net.ResolveUDPAddr("udp", quicURL) + if err != nil { + log.Errorf("failed to resolve UDP address: %s", err) + return nil, err + } + + session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil, err + } + log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err) + return nil, err + } + + conn := NewConn(session) + return conn, nil +} + +func prepareURL(address string) (string, error) { + if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") { + return "", fmt.Errorf("unsupported scheme: %s", address) + } + + if strings.HasPrefix(address, "rels://") { + return address[7:], nil + } + return address[6:], nil +} diff --git a/relay/client/dialer/race_dialer.go b/relay/client/dialer/race_dialer.go new file mode 100644 index 000000000..11dba5799 --- /dev/null +++ b/relay/client/dialer/race_dialer.go @@ -0,0 +1,96 @@ +package dialer + +import ( + "context" + "errors" + "net" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + connectionTimeout = 30 * time.Second +) + +type DialeFn interface { + Dial(ctx context.Context, address string) (net.Conn, error) + Protocol() string +} + +type dialResult struct { + Conn net.Conn + Protocol string + Err error +} + +type RaceDial struct { + log *log.Entry + serverURL string + dialerFns []DialeFn +} + +func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial { + return &RaceDial{ + log: log, + serverURL: serverURL, + dialerFns: dialerFns, + } +} + +func (r *RaceDial) Dial() (net.Conn, error) { + connChan := make(chan dialResult, len(r.dialerFns)) + winnerConn := make(chan net.Conn, 1) + abortCtx, abort := context.WithCancel(context.Background()) + defer abort() + + for _, dfn := range r.dialerFns { + go r.dial(dfn, abortCtx, connChan) + } + + go r.processResults(connChan, winnerConn, abort) + + conn, ok := <-winnerConn + if !ok { + return nil, errors.New("failed to dial to Relay server on any protocol") + } + return conn, nil +} + +func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) { + ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout) + defer cancel() + + r.log.Infof("dialing Relay server via %s", dfn.Protocol()) + conn, err := dfn.Dial(ctx, r.serverURL) + connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err} +} + +func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) { + var hasWinner bool + for i := 0; i < len(r.dialerFns); i++ { + dr := <-connChan + if dr.Err != nil { + if errors.Is(dr.Err, context.Canceled) { + r.log.Infof("connection attempt aborted via: %s", dr.Protocol) + } else { + r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err) + } + continue + } + + if hasWinner { + if cerr := dr.Conn.Close(); cerr != nil { + r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr) + } + continue + } + + r.log.Infof("successfully dialed via: %s", dr.Protocol) + + abort() + hasWinner = true + winnerConn <- dr.Conn + } + close(winnerConn) +} diff --git a/relay/client/dialer/race_dialer_test.go b/relay/client/dialer/race_dialer_test.go new file mode 100644 index 000000000..989abb0a6 --- /dev/null +++ b/relay/client/dialer/race_dialer_test.go @@ -0,0 +1,252 @@ +package dialer + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +type MockAddr struct { + network string +} + +func (m *MockAddr) Network() string { + return m.network +} + +func (m *MockAddr) String() string { + return "1.2.3.4" +} + +// MockDialer is a mock implementation of DialeFn +type MockDialer struct { + dialFunc func(ctx context.Context, address string) (net.Conn, error) + protocolStr string +} + +func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) { + return m.dialFunc(ctx, address) +} + +func (m *MockDialer) Protocol() string { + return m.protocolStr +} + +// MockConn implements net.Conn for testing +type MockConn struct { + remoteAddr net.Addr +} + +func (m *MockConn) Read(b []byte) (n int, err error) { + return 0, nil +} + +func (m *MockConn) Write(b []byte) (n int, err error) { + return 0, nil +} + +func (m *MockConn) Close() error { + return nil +} + +func (m *MockConn) LocalAddr() net.Addr { + return nil +} + +func (m *MockConn) RemoteAddr() net.Addr { + return m.remoteAddr +} + +func (m *MockConn) SetDeadline(t time.Time) error { + return nil +} + +func (m *MockConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (m *MockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func TestRaceDialEmptyDialers(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + + rd := NewRaceDial(logger, serverURL) + conn, err := rd.Dial() + if err == nil { + t.Errorf("Expected an error with empty dialers, got nil") + } + if conn != nil { + t.Errorf("Expected nil connection with empty dialers, got %v", conn) + } +} + +func TestRaceDialSingleSuccessfulDialer(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + proto := "test-protocol" + + mockConn := &MockConn{ + remoteAddr: &MockAddr{network: proto}, + } + + mockDialer := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return mockConn, nil + }, + protocolStr: proto, + } + + rd := NewRaceDial(logger, serverURL, mockDialer) + conn, err := rd.Dial() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if conn == nil { + t.Errorf("Expected non-nil connection") + } +} + +func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + proto2 := "protocol2" + + mockConn2 := &MockConn{ + remoteAddr: &MockAddr{network: proto2}, + } + + mockDialer1 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return nil, errors.New("first dialer failed") + }, + protocolStr: "proto1", + } + + mockDialer2 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return mockConn2, nil + }, + protocolStr: "proto2", + } + + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if conn.RemoteAddr().Network() != proto2 { + t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network()) + } +} + +func TestRaceDialTimeout(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + + connectionTimeout = 3 * time.Second + mockDialer := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + protocolStr: "proto1", + } + + rd := NewRaceDial(logger, serverURL, mockDialer) + conn, err := rd.Dial() + if err == nil { + t.Errorf("Expected an error, got nil") + } + if conn != nil { + t.Errorf("Expected nil connection, got %v", conn) + } +} + +func TestRaceDialAllDialersFail(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + + mockDialer1 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return nil, errors.New("first dialer failed") + }, + protocolStr: "protocol1", + } + + mockDialer2 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + return nil, errors.New("second dialer failed") + }, + protocolStr: "protocol2", + } + + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() + if err == nil { + t.Errorf("Expected an error, got nil") + } + if conn != nil { + t.Errorf("Expected nil connection, got %v", conn) + } +} + +func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + serverURL := "test.server.com" + proto1 := "protocol1" + proto2 := "protocol2" + + mockConn1 := &MockConn{ + remoteAddr: &MockAddr{network: proto1}, + } + + mockConn2 := &MockConn{ + remoteAddr: &MockAddr{network: proto2}, + } + + mockDialer1 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + time.Sleep(1 * time.Second) + return mockConn1, nil + }, + protocolStr: proto1, + } + + mock2err := make(chan error) + mockDialer2 := &MockDialer{ + dialFunc: func(ctx context.Context, address string) (net.Conn, error) { + <-ctx.Done() + mock2err <- ctx.Err() + return mockConn2, ctx.Err() + }, + protocolStr: proto2, + } + + rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2) + conn, err := rd.Dial() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if conn == nil { + t.Errorf("Expected non-nil connection") + } + if conn != mockConn1 { + t.Errorf("Expected first connection, got %v", conn) + } + + select { + case <-time.After(3 * time.Second): + t.Errorf("Timed out waiting for second dialer to finish") + case err := <-mock2err: + if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled error, got %v", err) + } + } +} diff --git a/relay/client/dialer/ws/addr.go b/relay/client/dialer/ws/addr.go index 43f5dd6af..11158cfbd 100644 --- a/relay/client/dialer/ws/addr.go +++ b/relay/client/dialer/ws/addr.go @@ -1,11 +1,15 @@ package ws +const ( + Network = "ws" +) + type WebsocketAddr struct { addr string } func (a WebsocketAddr) Network() string { - return "websocket" + return Network } func (a WebsocketAddr) String() string { diff --git a/relay/client/dialer/ws/conn.go b/relay/client/dialer/ws/conn.go index e7f771b8d..74bcafd82 100644 --- a/relay/client/dialer/ws/conn.go +++ b/relay/client/dialer/ws/conn.go @@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn { func (c *Conn) Read(b []byte) (n int, err error) { t, ioReader, err := c.Conn.Reader(c.ctx) if err != nil { + // todo use ErrClosedByServer return 0, err } diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index d9388aafd..df91a66d4 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -2,6 +2,7 @@ package ws import ( "context" + "errors" "fmt" "net" "net/http" @@ -15,7 +16,14 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) -func Dial(address string) (net.Conn, error) { +type Dialer struct { +} + +func (d Dialer) Protocol() string { + return "WS" +} + +func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { wsURL, err := prepareURL(address) if err != nil { return nil, err @@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) { } parsedURL.Path = ws.URLPath - wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts) + wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts) if err != nil { + if errors.Is(err, context.Canceled) { + return nil, err + } log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err) return nil, err } diff --git a/relay/messages/message.go b/relay/messages/message.go index 39ca0aa90..7794c57bc 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -23,20 +23,26 @@ const ( MsgTypeAuth = 6 MsgTypeAuthResponse = 7 - SizeOfVersionByte = 1 - SizeOfMsgType = 1 + // base size of the message + sizeOfVersionByte = 1 + sizeOfMsgType = 1 + sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType - SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType - - sizeOfMagicByte = 4 - - headerSizeTransport = IDSize + // auth message + sizeOfMagicByte = 4 + headerSizeAuth = sizeOfMagicByte + IDSize + offsetMagicByte = sizeOfProtoHeader + offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte + headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth + // hello message headerSizeHello = sizeOfMagicByte + IDSize headerSizeHelloResp = 0 - headerSizeAuth = sizeOfMagicByte + IDSize - headerSizeAuthResp = 0 + // transport + headerSizeTransport = IDSize + offsetTransportID = sizeOfProtoHeader + headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport ) var ( @@ -73,7 +79,7 @@ func (m MsgType) String() string { // ValidateVersion checks if the given version is supported by the protocol func ValidateVersion(msg []byte) (int, error) { - if len(msg) < SizeOfVersionByte { + if len(msg) < sizeOfProtoHeader { return 0, ErrInvalidMessageLength } version := int(msg[0]) @@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) { // DetermineClientMessageType determines the message type from the first the message func DetermineClientMessageType(msg []byte) (MsgType, error) { - if len(msg) < SizeOfMsgType { + if len(msg) < sizeOfProtoHeader { return 0, ErrInvalidMessageLength } - msgType := MsgType(msg[0]) + msgType := MsgType(msg[1]) switch msgType { case MsgTypeHello, @@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) { // DetermineServerMessageType determines the message type from the first the message func DetermineServerMessageType(msg []byte) (MsgType, error) { - if len(msg) < SizeOfMsgType { + if len(msg) < sizeOfProtoHeader { return 0, ErrInvalidMessageLength } - msgType := MsgType(msg[0]) + msgType := MsgType(msg[1]) switch msgType { case MsgTypeHelloResponse, @@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions)) + msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeHello) - copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) msg = append(msg, peerID...) msg = append(msg, additions...) @@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // authenticate the client with the server. func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { - if len(msg) < headerSizeHello { + if len(msg) < sizeOfProtoHeader+headerSizeHello { return nil, nil, ErrInvalidMessageLength } - if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) { return nil, nil, errors.New("invalid magic header") } - return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil + return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil } // Deprecated: Use MarshalAuthResponse instead. @@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay // servers. func MarshalHelloResponse(additionalData []byte) ([]byte, error) { - msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData)) + msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeHelloResponse) @@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) { // Deprecated: Use UnmarshalAuthResponse instead. // UnmarshalHelloResponse extracts the additional data from the hello response message. func UnmarshalHelloResponse(msg []byte) ([]byte, error) { - if len(msg) < headerSizeHelloResp { + if len(msg) < sizeOfProtoHeader+headerSizeHelloResp { return nil, ErrInvalidMessageLength } return msg, nil @@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload)) + msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuth) - copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + copy(msg[sizeOfProtoHeader:], magicHeader) msg = append(msg, peerID...) msg = append(msg, authPayload...) @@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { // UnmarshalAuthMsg extracts peerID and the auth payload from the message func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { - if len(msg) < headerSizeAuth { + if len(msg) < headerTotalSizeAuth { return nil, nil, ErrInvalidMessageLength } - if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { return nil, nil, errors.New("invalid magic header") } - return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil + return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil } // MarshalAuthResponse creates a response message to the auth. @@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { // servers. func MarshalAuthResponse(address string) ([]byte, error) { ab := []byte(address) - msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab)) + msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuthResponse) @@ -243,39 +249,34 @@ func MarshalAuthResponse(address string) ([]byte, error) { // UnmarshalAuthResponse it is a confirmation message to auth success func UnmarshalAuthResponse(msg []byte) (string, error) { - if len(msg) < headerSizeAuthResp+1 { + if len(msg) < sizeOfProtoHeader+1 { return "", ErrInvalidMessageLength } - return string(msg), nil + return string(msg[sizeOfProtoHeader:]), nil } // MarshalCloseMsg creates a close message. // The close message is used to close the connection gracefully between the client and the server. The server and the // client can send this message. After receiving this message, the server or client will close the connection. func MarshalCloseMsg() []byte { - msg := make([]byte, SizeOfProtoHeader) - - msg[0] = byte(CurrentProtocolVersion) - msg[1] = byte(MsgTypeClose) - - return msg + return []byte{ + byte(CurrentProtocolVersion), + byte(MsgTypeClose), + } } // MarshalTransportMsg creates a transport message. // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // destination peer hashed ID. -func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { +func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { if len(peerID) != IDSize { return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) } - msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload)) - + msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeTransport) - - copy(msg[SizeOfProtoHeader:], peerID) - + copy(msg[sizeOfProtoHeader:], peerID) msg = append(msg, payload...) return msg, nil @@ -283,29 +284,29 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { // UnmarshalTransportMsg extracts the peerID and the payload from the transport message. func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { - if len(buf) < headerSizeTransport { + if len(buf) < headerTotalSizeTransport { return nil, nil, ErrInvalidMessageLength } - return buf[:headerSizeTransport], buf[headerSizeTransport:], nil + return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil } // UnmarshalTransportID extracts the peerID from the transport message. func UnmarshalTransportID(buf []byte) ([]byte, error) { - if len(buf) < headerSizeTransport { + if len(buf) < headerTotalSizeTransport { return nil, ErrInvalidMessageLength } - return buf[:headerSizeTransport], nil + return buf[offsetTransportID:headerTotalSizeTransport], nil } // UpdateTransportMsg updates the peerID in the transport message. // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // need to allocate a new byte slice. func UpdateTransportMsg(msg []byte, peerID []byte) error { - if len(msg) < len(peerID) { + if len(msg) < offsetTransportID+len(peerID) { return ErrInvalidMessageLength } - copy(msg, peerID) + copy(msg[offsetTransportID:], peerID) return nil } diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go index 6e917da71..19bede07b 100644 --- a/relay/messages/message_test.go +++ b/relay/messages/message_test.go @@ -6,12 +6,21 @@ import ( func TestMarshalHelloMsg(t *testing.T) { peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") - bHello, err := MarshalHelloMsg(peerID, nil) + msg, err := MarshalHelloMsg(peerID, nil) if err != nil { t.Fatalf("error: %v", err) } - receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:]) + msgType, err := DetermineClientMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeHello { + t.Errorf("expected %d, got %d", MsgTypeHello, msgType) + } + + receivedPeerID, _, err := UnmarshalHelloMsg(msg) if err != nil { t.Fatalf("error: %v", err) } @@ -22,12 +31,21 @@ func TestMarshalHelloMsg(t *testing.T) { func TestMarshalAuthMsg(t *testing.T) { peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") - bHello, err := MarshalAuthMsg(peerID, []byte{}) + msg, err := MarshalAuthMsg(peerID, []byte{}) if err != nil { t.Fatalf("error: %v", err) } - receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:]) + msgType, err := DetermineClientMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeAuth { + t.Errorf("expected %d, got %d", MsgTypeAuth, msgType) + } + + receivedPeerID, _, err := UnmarshalAuthMsg(msg) if err != nil { t.Fatalf("error: %v", err) } @@ -36,6 +54,31 @@ func TestMarshalAuthMsg(t *testing.T) { } } +func TestMarshalAuthResponse(t *testing.T) { + address := "myaddress" + msg, err := MarshalAuthResponse(address) + if err != nil { + t.Fatalf("error: %v", err) + } + + msgType, err := DetermineServerMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeAuthResponse { + t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType) + } + + respAddr, err := UnmarshalAuthResponse(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + if respAddr != address { + t.Errorf("expected %s, got %s", address, respAddr) + } +} + func TestMarshalTransportMsg(t *testing.T) { peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") payload := []byte("payload") @@ -44,7 +87,25 @@ func TestMarshalTransportMsg(t *testing.T) { t.Fatalf("error: %v", err) } - id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:]) + msgType, err := DetermineClientMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeTransport { + t.Errorf("expected %d, got %d", MsgTypeTransport, msgType) + } + + uPeerID, err := UnmarshalTransportID(msg) + if err != nil { + t.Fatalf("failed to unmarshal transport id: %v", err) + } + + if string(uPeerID) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, uPeerID) + } + + id, respPayload, err := UnmarshalTransportMsg(msg) if err != nil { t.Fatalf("error: %v", err) } @@ -57,3 +118,21 @@ func TestMarshalTransportMsg(t *testing.T) { t.Errorf("expected %s, got %s", payload, respPayload) } } + +func TestMarshalHealthcheck(t *testing.T) { + msg := MarshalHealthcheck() + + _, err := ValidateVersion(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + msgType, err := DetermineServerMessageType(msg) + if err != nil { + t.Fatalf("error: %v", err) + } + + if msgType != MsgTypeHealthCheck { + t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType) + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 0257300f8..babd6f955 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -68,12 +68,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) { return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) } - _, err = messages.ValidateVersion(buf[:n]) + buf = buf[:n] + + _, err = messages.ValidateVersion(buf) if err != nil { return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) } - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + msgType, err := messages.DetermineClientMessageType(buf) if err != nil { return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) } @@ -85,10 +87,10 @@ func (h *handshake) handshakeReceive() ([]byte, error) { switch msgType { //nolint:staticcheck case messages.MsgTypeHello: - bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) + bytePeerID, peerID, err = h.handleHelloMsg(buf) case messages.MsgTypeAuth: h.handshakeMethodAuth = true - bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) + bytePeerID, peerID, err = h.handleAuthMsg(buf) default: return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) } diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go new file mode 100644 index 000000000..909ec1cc6 --- /dev/null +++ b/relay/server/listener/quic/conn.go @@ -0,0 +1,101 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/quic-go/quic-go" +) + +type Conn struct { + session quic.Connection + closed bool + closedMu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc +} + +func NewConn(session quic.Connection) *Conn { + ctx, cancel := context.WithCancel(context.Background()) + return &Conn{ + session: session, + ctx: ctx, + ctxCancel: cancel, + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + dgram, err := c.session.ReceiveDatagram(c.ctx) + if err != nil { + return 0, c.remoteCloseErrHandling(err) + } + // Copy data to b, ensuring we don’t exceed the size of b + n = copy(b, dgram) + return n, nil +} + +func (c *Conn) Write(b []byte) (int, error) { + if err := c.session.SendDatagram(b); err != nil { + return 0, c.remoteCloseErrHandling(err) + } + return len(b), nil +} + +func (c *Conn) LocalAddr() net.Addr { + return c.session.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.session.RemoteAddr() +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("SetWriteDeadline is not implemented") +} + +func (c *Conn) SetDeadline(t time.Time) error { + return fmt.Errorf("SetDeadline is not implemented") +} + +func (c *Conn) Close() error { + c.closedMu.Lock() + if c.closed { + c.closedMu.Unlock() + return nil + } + c.closed = true + c.closedMu.Unlock() + + c.ctxCancel() // Cancel the context + + sessionErr := c.session.CloseWithError(0, "normal closure") + return sessionErr +} + +func (c *Conn) isClosed() bool { + c.closedMu.Lock() + defer c.closedMu.Unlock() + return c.closed +} + +func (c *Conn) remoteCloseErrHandling(err error) error { + if c.isClosed() { + return net.ErrClosed + } + + // Check if the connection was closed remotely + var appErr *quic.ApplicationError + if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 { + return net.ErrClosed + } + + return err +} diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go new file mode 100644 index 000000000..b6e01994f --- /dev/null +++ b/relay/server/listener/quic/listener.go @@ -0,0 +1,66 @@ +package quic + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + + "github.com/quic-go/quic-go" + log "github.com/sirupsen/logrus" +) + +type Listener struct { + // Address is the address to listen on + Address string + // TLSConfig is the TLS configuration for the server + TLSConfig *tls.Config + + listener *quic.Listener + acceptFn func(conn net.Conn) +} + +func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { + l.acceptFn = acceptFn + + quicCfg := &quic.Config{ + EnableDatagrams: true, + } + listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg) + if err != nil { + return fmt.Errorf("failed to create QUIC listener: %v", err) + } + + l.listener = listener + log.Infof("QUIC server listening on address: %s", l.Address) + + for { + session, err := listener.Accept(context.Background()) + if err != nil { + if errors.Is(err, quic.ErrServerClosed) { + return nil + } + + log.Errorf("Failed to accept QUIC session: %v", err) + continue + } + + log.Infof("QUIC client connected from: %s", session.RemoteAddr()) + conn := NewConn(session) + l.acceptFn(conn) + } +} + +func (l *Listener) Shutdown(ctx context.Context) error { + if l.listener == nil { + return nil + } + + log.Infof("stopping QUIC listener") + if err := l.listener.Close(); err != nil { + return fmt.Errorf("listener shutdown failed: %v", err) + } + log.Infof("QUIC listener stopped") + return nil +} diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 5c62c0826..0eb244c77 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -88,6 +88,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { return } + log.Infof("WS client connected from: %s", rAddr) + conn := NewConn(wsConn, lAddr, rAddr) l.acceptFn(conn) } diff --git a/relay/server/peer.go b/relay/server/peer.go index f65fb786a..aa9790f63 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -84,7 +84,7 @@ func (p *Peer) Work() { return } - msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:]) + msgType, err := messages.DetermineClientMessageType(msg) if err != nil { p.log.Errorf("failed to determine message type: %s", err) return @@ -191,7 +191,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send } func (p *Peer) handleTransportMsg(msg []byte) { - peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:]) + peerID, err := messages.UnmarshalTransportID(msg) if err != nil { p.log.Errorf("failed to unmarshal transport message: %s", err) return @@ -204,7 +204,7 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB) + err = messages.UpdateTransportMsg(msg, p.idB) if err != nil { p.log.Errorf("failed to update transport message: %s", err) return diff --git a/relay/server/relay.go b/relay/server/relay.go index 6cd8506ae..a5e77bc61 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -150,6 +150,8 @@ func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Shutdown(ctx context.Context) { log.Infof("close connection with all peers") r.closeMu.Lock() + defer r.closeMu.Unlock() + wg := sync.WaitGroup{} peers := r.store.Peers() for _, peer := range peers { @@ -161,7 +163,7 @@ func (r *Relay) Shutdown(ctx context.Context) { } wg.Wait() r.metricsCancel() - r.closeMu.Unlock() + r.closed = true } // InstanceURL returns the instance URL of the relay server diff --git a/relay/server/server.go b/relay/server/server.go index 0036e2390..cacc3dafb 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,13 +3,17 @@ package server import ( "context" "crypto/tls" + "sync" - log "github.com/sirupsen/logrus" + "github.com/hashicorp/go-multierror" "go.opentelemetry.io/otel/metric" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server/listener" + "github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/ws" + quictls "github.com/netbirdio/netbird/relay/tls" ) // ListenerConfig is the configuration for the listener. @@ -24,8 +28,8 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { - relay *Relay - wSListener listener.Listener + relay *Relay + listeners []listener.Listener } // NewServer creates a new relay server instance. @@ -39,35 +43,63 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV return nil, err } return &Server{ - relay: relay, + relay: relay, + listeners: make([]listener.Listener, 0, 2), }, nil } // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.wSListener = &ws.Listener{ + wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, } + r.listeners = append(r.listeners, wSListener) - wslErr := r.wSListener.Listen(r.relay.Accept) - if wslErr != nil { - log.Errorf("failed to bind ws server: %s", wslErr) + tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig) + if err != nil { + return err } - return wslErr + quicListener := &quic.Listener{ + Address: cfg.Address, + TLSConfig: tlsConfigQUIC, + } + + r.listeners = append(r.listeners, quicListener) + + errChan := make(chan error, len(r.listeners)) + wg := sync.WaitGroup{} + for _, l := range r.listeners { + wg.Add(1) + go func(listener listener.Listener) { + defer wg.Done() + errChan <- listener.Listen(r.relay.Accept) + }(l) + } + + wg.Wait() + close(errChan) + var multiErr *multierror.Error + for err := range errChan { + multiErr = multierror.Append(multiErr, err) + } + + return nberrors.FormatErrorOrNil(multiErr) } // Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context, // the connections will be forcefully closed. -func (r *Server) Shutdown(ctx context.Context) (err error) { - // stop service new connections - if r.wSListener != nil { - err = r.wSListener.Shutdown(ctx) - } - +func (r *Server) Shutdown(ctx context.Context) error { r.relay.Shutdown(ctx) - return + + var multiErr *multierror.Error + for _, l := range r.listeners { + if err := l.Shutdown(ctx); err != nil { + multiErr = multierror.Append(multiErr, err) + } + } + return nberrors.FormatErrorOrNil(multiErr) } // InstanceURL returns the instance URL of the relay server. diff --git a/relay/tls/alpn.go b/relay/tls/alpn.go new file mode 100644 index 000000000..29497d401 --- /dev/null +++ b/relay/tls/alpn.go @@ -0,0 +1,3 @@ +package tls + +const nbalpn = "nb-quic" diff --git a/relay/tls/client_dev.go b/relay/tls/client_dev.go new file mode 100644 index 000000000..f6b8290a0 --- /dev/null +++ b/relay/tls/client_dev.go @@ -0,0 +1,12 @@ +//go:build devcert + +package tls + +import "crypto/tls" + +func ClientQUICTLSConfig() *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, // Debug mode allows insecure connections + NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN + } +} diff --git a/relay/tls/client_prod.go b/relay/tls/client_prod.go new file mode 100644 index 000000000..686093a37 --- /dev/null +++ b/relay/tls/client_prod.go @@ -0,0 +1,11 @@ +//go:build !devcert + +package tls + +import "crypto/tls" + +func ClientQUICTLSConfig() *tls.Config { + return &tls.Config{ + NextProtos: []string{nbalpn}, + } +} diff --git a/relay/tls/doc.go b/relay/tls/doc.go new file mode 100644 index 000000000..38b807f84 --- /dev/null +++ b/relay/tls/doc.go @@ -0,0 +1,36 @@ +// Package tls provides utilities for configuring and managing Transport Layer +// Security (TLS) in server and client environments, with a focus on QUIC +// protocol support and testing configurations. +// +// The package includes functions for cloning and customizing TLS +// configurations as well as generating self-signed certificates for +// development and testing purposes. +// +// Key Features: +// +// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored +// for QUIC protocol with specified or default settings. QUIC requires a +// specific TLS configuration with proper ALPN (Application-Layer Protocol +// Negotiation) support, making the TLS settings crucial for establishing +// secure connections. +// +// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable +// for QUIC protocol. The configuration differs between development +// (insecure testing) and production (strict verification). +// +// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for +// use in local development and testing scenarios. +// +// Usage: +// +// This package provides separate implementations for development and production +// environments. The development implementation (guarded by `//go:build devcert`) +// supports testing configurations with self-signed certificates and insecure +// client connections. The production implementation (guarded by `//go:build +// !devcert`) ensures that valid and secure TLS configurations are supplied +// and used. +// +// The QUIC protocol is highly reliant on properly configured TLS settings, +// and this package ensures that configurations meet the requirements for +// secure and efficient QUIC communication. +package tls diff --git a/relay/tls/server_dev.go b/relay/tls/server_dev.go new file mode 100644 index 000000000..1a01658fc --- /dev/null +++ b/relay/tls/server_dev.go @@ -0,0 +1,79 @@ +//go:build devcert + +package tls + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "time" + + log "github.com/sirupsen/logrus" +) + +func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) { + if originTLSCfg == nil { + log.Warnf("QUIC server will use self signed certificate for testing!") + return generateTestTLSConfig() + } + + cfg := originTLSCfg.Clone() + cfg.NextProtos = []string{nbalpn} + return cfg, nil +} + +// GenerateTestTLSConfig creates a self-signed certificate for testing +func generateTestTLSConfig() (*tls.Config, error) { + log.Infof("generating test TLS config") + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Organization"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + // Create certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM) + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{nbalpn}, + }, nil +} diff --git a/relay/tls/server_prod.go b/relay/tls/server_prod.go new file mode 100644 index 000000000..9d1c47d88 --- /dev/null +++ b/relay/tls/server_prod.go @@ -0,0 +1,17 @@ +//go:build !devcert + +package tls + +import ( + "crypto/tls" + "fmt" +) + +func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) { + if originTLSCfg == nil { + return nil, fmt.Errorf("valid TLS config is required for QUIC listener") + } + cfg := originTLSCfg.Clone() + cfg.NextProtos = []string{nbalpn} + return cfg, nil +}