Merge branch 'main' into userspace-router

This commit is contained in:
Viktor Liu 2025-01-15 17:00:37 +01:00
commit ea6c947f5d
50 changed files with 1255 additions and 443 deletions

View File

@ -44,4 +44,5 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@ -134,7 +134,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management: test_management:
needs: [ build-cache ] needs: [ build-cache ]
@ -194,7 +194,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8 run: docker pull mlsmaycon/warmed-mysql:8
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark: benchmark:
needs: [ build-cache ] needs: [ build-cache ]
@ -254,7 +254,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8 run: docker pull mlsmaycon/warmed-mysql:8
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
api_benchmark: api_benchmark:
needs: [ build-cache ] needs: [ build-cache ]

View File

@ -65,7 +65,7 @@ jobs:
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output - name: test output
if: ${{ always() }} if: ${{ always() }}
run: Get-Content test-out.txt run: Get-Content test-out.txt

View File

@ -19,8 +19,7 @@ const (
tableName = "filter" tableName = "filter"
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
) )
type aclEntries map[string][][]string type aclEntries map[string][][]string
@ -84,7 +83,6 @@ func (m *aclManager) AddPeerFiltering(
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
@ -97,15 +95,10 @@ func (m *aclManager) AddPeerFiltering(
sPortVal = strconv.Itoa(sPort.Values[0]) sPortVal = strconv.Itoa(sPort.Values[0])
} }
var chain string chain := chainNameInputRules
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal) ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName)
if ipsetName != "" { if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
@ -214,28 +207,7 @@ func (m *aclManager) Reset() error {
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error { func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil { if err != nil {
log.Debugf("failed to list chains: %s", err) log.Debugf("failed to list chains: %s", err)
return err return err
@ -295,12 +267,6 @@ func (m *aclManager) createDefaultChains() error {
return err return err
} }
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
@ -329,8 +295,6 @@ func (m *aclManager) createDefaultChains() error {
// The existing FORWARD rules/policies decide outbound traffic towards our interface. // The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished() established := getConntrackEstablished()
@ -390,30 +354,18 @@ func (m *aclManager) updateState() {
} }
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" { if ip.String() == "0.0.0.0" {
matchByIP = false matchByIP = false
} }
switch direction {
case firewall.RuleDirectionIN: if matchByIP {
if matchByIP { if ipsetName != "" {
if ipsetName != "" { specs = append(specs, "-m", "set", "--set", ipsetName, "src")
specs = append(specs, "-m", "set", "--set", ipsetName, "src") } else {
} else { specs = append(specs, "-s", ip.String())
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
} }
} }
if protocol != "all" { if protocol != "all" {

View File

@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, _ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error {
"all", "all",
nil, nil,
nil, nil,
firewall.RuleDirectionIN,
firewall.ActionAccept, firewall.ActionAccept,
"", "",
"", "",

View File

@ -68,27 +68,13 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@ -97,15 +83,6 @@ func TestIptablesManager(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@ -119,7 +96,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Reset(nil)
@ -135,9 +112,6 @@ func TestIptablesManager(t *testing.T) {
} }
func TestIptablesManagerIPSet(t *testing.T) { func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
@ -167,33 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []int{443}, Values: []int{443},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@ -201,15 +155,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@ -270,11 +215,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@ -69,7 +69,6 @@ type Manager interface {
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
dPort *Port, dPort *Port,
direction RuleDirection,
action Action, action Action,
ipsetName string, ipsetName string,
comment string, comment string,

View File

@ -22,8 +22,7 @@ import (
const ( const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules" chainNameInputRules = "netbird-acl-input-rules"
chainNameOutputRules = "netbird-acl-output-rules"
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
@ -45,9 +44,8 @@ type AclManager struct {
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
@ -89,7 +87,6 @@ func (m *AclManager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
@ -104,7 +101,7 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules := make([]firewall.Rule, 0, 2) newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,38 +211,6 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expIn, Exprs: expIn,
}) })
expOut := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainOutputRules,
Position: 0,
Exprs: expOut,
})
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@ -264,15 +229,19 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
return nil return nil
} }
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { func (m *AclManager) addIOFiltering(
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, r.nftRule,
@ -310,9 +279,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
if !bytes.HasPrefix(anyIP, rawIP) { if !bytes.HasPrefix(anyIP, rawIP) {
// source address position // source address position
addrOffset := uint32(12) addrOffset := uint32(12)
if direction == firewall.RuleDirectionOUT {
addrOffset += 4 // is ipv4 address length
}
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
@ -383,12 +349,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
userData := []byte(strings.Join([]string{ruleId, comment}, " ")) userData := []byte(strings.Join([]string{ruleId, comment}, " "))
var chain *nftables.Chain chain := m.chainInputRules
if direction == firewall.RuleDirectionIN {
chain = m.chainInputRules
} else {
chain = m.chainOutputRules
}
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
@ -419,15 +380,6 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
m.chainInputRules = chain m.chainInputRules = chain
// chainNameOutputRules
chain = m.createChain(chainNameOutputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
return err
}
m.chainOutputRules = chain
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
@ -720,15 +672,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil return nil
} }
func generatePeerRuleId( func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
ip net.IP, rulesetID := ":"
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipset *nftables.Set,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil { if sPort != nil {
rulesetID += sPort.String() rulesetID += sPort.String()
} }

View File

@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering( rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@ -210,11 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
}) })
ip := net.ParseIP("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering( _, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(

View File

@ -16,7 +16,6 @@ type PeerRule struct {
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
direction firewall.RuleDirection
sPort uint16 sPort uint16
dPort uint16 dPort uint16
drop bool drop bool

View File

@ -45,7 +45,9 @@ type RuleSet map[string]PeerRule
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules map[string]RuleSet // outgoingRules is used for hooks only
outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
routeRules map[string]RouteRule routeRules map[string]RouteRule
wgNetwork *net.IPNet wgNetwork *net.IPNet
@ -297,9 +299,8 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, _ string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
r := PeerRule{ r := PeerRule{
@ -307,7 +308,6 @@ func (m *Manager) AddPeerFiltering(
ip: ip, ip: ip,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
direction: direction,
drop: action == firewall.ActionDrop, drop: action == firewall.ActionDrop,
comment: comment, comment: comment,
} }
@ -343,17 +343,10 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if direction == firewall.RuleDirectionIN { if _, ok := m.incomingRules[r.ip.String()]; !ok {
if _, ok := m.incomingRules[r.ip.String()]; !ok { m.incomingRules[r.ip.String()] = make(RuleSet)
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
@ -416,19 +409,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if r.direction == firewall.RuleDirectionIN { if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
_, ok := m.incomingRules[r.ip.String()][r.id] return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else {
_, ok := m.outgoingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
} }
delete(m.incomingRules[r.ip.String()], r.id)
return nil return nil
} }
@ -918,7 +902,6 @@ func (m *Manager) AddUDPPacketHook(
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
dPort: dPort, dPort: dPort,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
direction: firewall.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook, udpHook: hook,
} }
@ -929,7 +912,6 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock() m.mutex.Lock()
if in { if in {
r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule) m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
} }
@ -948,19 +930,22 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID // RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error { func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, arr := range m.incomingRules { for _, arr := range m.incomingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r delete(arr, r.id)
return m.DeletePeerRule(&rule) return nil
} }
} }
} }
for _, arr := range m.outgoingRules { for _, arr := range m.outgoingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r delete(arr, r.id)
return m.DeletePeerRule(&rule) return nil
} }
} }
} }

View File

@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Single rule allowing all traffic // Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") fw.ActionAccept, "", "allow all")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Baseline: Single 'allow all' rule without connection tracking", desc: "Baseline: Single 'allow all' rule without connection tracking",
@ -117,7 +117,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}}, &fw.Port{Values: []int{1024 + i}},
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") fw.ActionAccept, "", "explicit return")
require.NoError(b, err) require.NoError(b, err)
} }
}, },
@ -129,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections // Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") fw.ActionDrop, "", "default drop")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Connection tracking with established connections", desc: "Connection tracking with established connections",
@ -593,7 +593,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -684,7 +684,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -802,7 +802,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -889,7 +889,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }

View File

@ -91,11 +91,10 @@ func TestManagerAddPeerFiltering(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -126,37 +125,15 @@ func TestManagerDeleteRule(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule 2"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
ip = net.ParseIP("192.168.1.1")
proto = fw.ProtocolTCP
port = &fw.Port{Values: []int{80}}
direction = fw.RuleDirectionIN
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
@ -246,10 +223,6 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return return
} }
if tt.expDir != addedRule.direction {
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
return
}
if addedRule.udpHook == nil { if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set") t.Errorf("expected udpHook to be set")
return return
@ -272,11 +245,10 @@ func TestManagerReset(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -319,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
direction := fw.RuleDirectionIN
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -523,11 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.rollBack(newRulePairs) d.rollBack(newRulePairs)
break break
} }
if len(rules) > 0 { if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair newRulePairs[pairID] = rulePair
} }
@ -288,6 +288,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@ -308,25 +310,12 @@ func (d *DefaultManager) addInRules(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("add firewall rule: %w", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
} }
rule, err = d.firewall.AddPeerFiltering( return rule, nil
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
} }
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
@ -337,25 +326,16 @@ func (d *DefaultManager) addOutRules(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return rules, nil return nil, nil
} }
rule, err = d.firewall.AddPeerFiltering( rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
return append(rules, rule...), nil return rule, nil
} }
// getPeerRuleID() returns unique ID for the rule based on its parameters. // getPeerRuleID() returns unique ID for the rule based on its parameters.

View File

@ -120,8 +120,8 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 { if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return return
} }
}) })
@ -358,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 4 { if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return return
} }
} }

View File

@ -48,11 +48,17 @@ type restoreHostManager interface {
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("get os dns manager type: %w", err)
} }
log.Infof("System DNS manager discovered: %s", osManager) log.Infof("System DNS manager discovered: %s", osManager)
return newHostManagerFromType(wgInterface, osManager) mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
return nil, fmt.Errorf("create host manager: %w", err)
}
return mgr, nil
} }
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) { func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {

View File

@ -12,6 +12,7 @@ import (
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@ -239,7 +240,10 @@ func (s *DefaultServer) Initialize() (err error) {
s.stateManager.RegisterState(&ShutdownState{}) s.stateManager.RegisterState(&ShutdownState{})
if s.disableSys { // use noop host manager if requested or running in netstack mode.
// Netstack mode currently doesn't have a way to receive DNS requests.
// TODO: Use listener on localhost in netstack mode when running as root.
if s.disableSys || netstack.IsEnabled() {
log.Info("system DNS is disabled, not setting up host manager") log.Info("system DNS is disabled, not setting up host manager")
s.hostManager = &noopHostConfigurator{} s.hostManager = &noopHostConfigurator{}
return nil return nil

View File

@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
return nil return nil
} }
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err

View File

@ -499,7 +499,6 @@ func (e *Engine) initFirewall() error {
manager.ProtocolUDP, manager.ProtocolUDP,
nil, nil,
&port, &port,
manager.RuleDirectionIN,
manager.ActionAccept, manager.ActionAccept,
"", "",
"", "",

View File

@ -4,13 +4,27 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"os/exec" "os/exec"
"runtime" "runtime"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func isRoot() bool {
return os.Geteuid() == 0
}
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) { func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
if !isRoot() {
shell := getUserShell(user)
if shell == "" {
shell = "/bin/sh"
}
return shell, []string{"-l"}, nil
}
loginPath, err = exec.LookPath("login") loginPath, err = exec.LookPath("login")
if err != nil { if err != nil {
return "", nil, err return "", nil, err

View File

@ -6,5 +6,9 @@ package ssh
import "os/user" import "os/user"
func userNameLookup(username string) (*user.User, error) { func userNameLookup(username string) (*user.User, error) {
if username == "" || (username == "root" && !isRoot()) {
return user.Current()
}
return user.Lookup(username) return user.Lookup(username)
} }

View File

@ -12,6 +12,10 @@ import (
) )
func userNameLookup(username string) (*user.User, error) { func userNameLookup(username string) (*user.User, error) {
if username == "" || (username == "root" && !isRoot()) {
return user.Current()
}
var userObject *user.User var userObject *user.User
userObject, err := user.Lookup(username) userObject, err := user.Lookup(username)
if err != nil && err.Error() == user.UnknownUserError(username).Error() { if err != nil && err.Error() == user.UnknownUserError(username).Error() {

View File

@ -39,6 +39,9 @@ func GetInfo(ctx context.Context) *Info {
WiretrusteeVersion: version.NetbirdVersion(), WiretrusteeVersion: version.NetbirdVersion(),
UIVersion: extractUIVersion(ctx), UIVersion: extractUIVersion(ctx),
KernelVersion: kernelVersion, KernelVersion: kernelVersion,
SystemSerialNumber: serial(),
SystemProductName: productModel(),
SystemManufacturer: productManufacturer(),
} }
return gio return gio
@ -49,13 +52,42 @@ func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil return []File{}, nil
} }
func serial() string {
// try to fetch serial ID using different properties
properties := []string{"ril.serialnumber", "ro.serialno", "ro.boot.serialno", "sys.serialnumber"}
var value string
for _, property := range properties {
value = getprop(property)
if len(value) > 0 {
return value
}
}
// unable to get serial ID, fallback to ANDROID_ID
return androidId()
}
func androidId() string {
// this is a uniq id defined on first initialization, id will be a new one if user wipes his device
return run("/system/bin/settings", "get", "secure", "android_id")
}
func productModel() string {
return getprop("ro.product.model")
}
func productManufacturer() string {
return getprop("ro.product.manufacturer")
}
func uname() []string { func uname() []string {
res := run("/system/bin/uname", "-a") res := run("/system/bin/uname", "-a")
return strings.Split(res, " ") return strings.Split(res, " ")
} }
func osVersion() string { func osVersion() string {
return run("/system/bin/getprop", "ro.build.version.release") return getprop("ro.build.version.release")
} }
func extractUIVersion(ctx context.Context) string { func extractUIVersion(ctx context.Context) string {
@ -66,6 +98,10 @@ func extractUIVersion(ctx context.Context) string {
return v return v
} }
func getprop(arg ...string) string {
return run("/system/bin/getprop", arg...)
}
func run(name string, arg ...string) string { func run(name string, arg ...string) string {
cmd := exec.Command(name, arg...) cmd := exec.Command(name, arg...)
cmd.Stdin = strings.NewReader("some") cmd.Stdin = strings.NewReader("some")

4
go.mod
View File

@ -71,6 +71,7 @@ require (
github.com/pion/transport/v3 v3.0.1 github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_golang v1.19.1
github.com/quic-go/quic-go v0.48.2
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4 github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@ -156,11 +157,13 @@ require (
github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.0 // indirect github.com/go-text/typesetting v0.2.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/s2a-go v0.1.7 // indirect github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect
@ -222,6 +225,7 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect golang.org/x/mod v0.17.0 // indirect

6
go.sum
View File

@ -405,6 +405,7 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@ -610,6 +611,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
@ -761,6 +764,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
@ -970,6 +975,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -10,6 +10,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac" auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer"
"github.com/netbirdio/netbird/relay/client/dialer/quic"
"github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) {
msg.Free() msg.Free()
default: default:
msg.Free() msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
} }
} }
@ -179,8 +179,7 @@ func (c *Client) Connect() error {
return nil return nil
} }
err := c.connect() if err := c.connect(); err != nil {
if err != nil {
return err return err
} }
@ -264,14 +263,14 @@ func (c *Client) Close() error {
} }
func (c *Client) connect() error { func (c *Client) connect() error {
conn, err := ws.Dial(c.connectionURL) rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil { if err != nil {
return err return err
} }
c.relayConn = conn c.relayConn = conn
err = c.handShake() if err = c.handShake(); err != nil {
if err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
@ -306,7 +305,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("validate version: %w", err) return fmt.Errorf("validate version: %w", err)
} }
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil { if err != nil {
c.log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
return err return err
@ -317,7 +316,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("unexpected message type") return fmt.Errorf("unexpected message type")
} }
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n]) addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil { if err != nil {
return err return err
} }
@ -345,27 +344,30 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Infof("start to Relay read loop exit") c.log.Infof("start to Relay read loop exit")
c.mu.Lock() c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() { if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit) c.log.Errorf("failed to read message from relay server: %s", errExit)
} }
c.mu.Unlock() c.mu.Unlock()
c.bufPool.Put(bufPtr)
break break
} }
_, err := messages.ValidateVersion(buf[:n]) buf = buf[:n]
_, err := messages.ValidateVersion(buf)
if err != nil { if err != nil {
c.log.Errorf("failed to validate protocol version: %s", err) c.log.Errorf("failed to validate protocol version: %s", err)
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
continue continue
} }
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) msgType, err := messages.DetermineServerMessageType(buf)
if err != nil { if err != nil {
c.log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
continue continue
} }
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) { if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
break break
} }
} }

View File

@ -0,0 +1,7 @@
package net
import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
)

View File

@ -0,0 +1,97 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/relay/client/dialer/net"
)
const (
Network = "quic"
)
type Addr struct {
addr string
}
func (a Addr) Network() string {
return Network
}
func (a Addr) String() string {
return a.addr
}
type Conn struct {
session quic.Connection
ctx context.Context
}
func NewConn(session quic.Connection) net.Conn {
return &Conn{
session: session,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
n = copy(b, dgram)
return n, nil
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
}
return len(b), nil
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) LocalAddr() net.Addr {
if c.session != nil {
return c.session.LocalAddr()
}
return Addr{addr: "unknown"}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return nil
}
func (c *Conn) Close() error {
return c.session.CloseWithError(0, "normal closure")
}
func (c *Conn) remoteCloseErrHandling(err error) error {
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
return netErr.ErrClosedByServer
}
return err
}

View File

@ -0,0 +1,71 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/relay/tls"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Dialer struct {
}
func (d Dialer) Protocol() string {
return Network
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
return nil, err
}
quicConfig := &quic.Config{
KeepAlivePeriod: 30 * time.Second,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
log.Errorf("failed to listen on UDP: %s", err)
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
if err != nil {
log.Errorf("failed to resolve UDP address: %s", err)
return nil, err
}
session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
return nil, err
}
conn := NewConn(session)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
if strings.HasPrefix(address, "rels://") {
return address[7:], nil
}
return address[6:], nil
}

View File

@ -0,0 +1,96 @@
package dialer
import (
"context"
"errors"
"net"
"time"
log "github.com/sirupsen/logrus"
)
var (
connectionTimeout = 30 * time.Second
)
type DialeFn interface {
Dial(ctx context.Context, address string) (net.Conn, error)
Protocol() string
}
type dialResult struct {
Conn net.Conn
Protocol string
Err error
}
type RaceDial struct {
log *log.Entry
serverURL string
dialerFns []DialeFn
}
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
}
}
func (r *RaceDial) Dial() (net.Conn, error) {
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(context.Background())
defer abort()
for _, dfn := range r.dialerFns {
go r.dial(dfn, abortCtx, connChan)
}
go r.processResults(connChan, winnerConn, abort)
conn, ok := <-winnerConn
if !ok {
return nil, errors.New("failed to dial to Relay server on any protocol")
}
return conn, nil
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(ctx, r.serverURL)
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
var hasWinner bool
for i := 0; i < len(r.dialerFns); i++ {
dr := <-connChan
if dr.Err != nil {
if errors.Is(dr.Err, context.Canceled) {
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
} else {
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
}
continue
}
if hasWinner {
if cerr := dr.Conn.Close(); cerr != nil {
r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr)
}
continue
}
r.log.Infof("successfully dialed via: %s", dr.Protocol)
abort()
hasWinner = true
winnerConn <- dr.Conn
}
close(winnerConn)
}

View File

@ -0,0 +1,252 @@
package dialer
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/sirupsen/logrus"
)
type MockAddr struct {
network string
}
func (m *MockAddr) Network() string {
return m.network
}
func (m *MockAddr) String() string {
return "1.2.3.4"
}
// MockDialer is a mock implementation of DialeFn
type MockDialer struct {
dialFunc func(ctx context.Context, address string) (net.Conn, error)
protocolStr string
}
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return m.dialFunc(ctx, address)
}
func (m *MockDialer) Protocol() string {
return m.protocolStr
}
// MockConn implements net.Conn for testing
type MockConn struct {
remoteAddr net.Addr
}
func (m *MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Close() error {
return nil
}
func (m *MockConn) LocalAddr() net.Addr {
return nil
}
func (m *MockConn) RemoteAddr() net.Addr {
return m.remoteAddr
}
func (m *MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
}
}
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto := "test-protocol"
mockConn := &MockConn{
remoteAddr: &MockAddr{network: proto},
}
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn, nil
},
protocolStr: proto,
}
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
}
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto2 := "protocol2"
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "proto1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn2, nil
},
protocolStr: "proto2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
}
}
func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
return nil, ctx.Err()
},
protocolStr: "proto1",
}
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialAllDialersFail(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "protocol1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("second dialer failed")
},
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto1 := "protocol1"
proto2 := "protocol2"
mockConn1 := &MockConn{
remoteAddr: &MockAddr{network: proto1},
}
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
time.Sleep(1 * time.Second)
return mockConn1, nil
},
protocolStr: proto1,
}
mock2err := make(chan error)
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
mock2err <- ctx.Err()
return mockConn2, ctx.Err()
},
protocolStr: proto2,
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
if conn != mockConn1 {
t.Errorf("Expected first connection, got %v", conn)
}
select {
case <-time.After(3 * time.Second):
t.Errorf("Timed out waiting for second dialer to finish")
case err := <-mock2err:
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled error, got %v", err)
}
}
}

View File

@ -1,11 +1,15 @@
package ws package ws
const (
Network = "ws"
)
type WebsocketAddr struct { type WebsocketAddr struct {
addr string addr string
} }
func (a WebsocketAddr) Network() string { func (a WebsocketAddr) Network() string {
return "websocket" return Network
} }
func (a WebsocketAddr) String() string { func (a WebsocketAddr) String() string {

View File

@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx) t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil { if err != nil {
// todo use ErrClosedByServer
return 0, err return 0, err
} }

View File

@ -2,6 +2,7 @@ package ws
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -15,7 +16,14 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func Dial(address string) (net.Conn, error) { type Dialer struct {
}
func (d Dialer) Protocol() string {
return "WS"
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
wsURL, err := prepareURL(address) wsURL, err := prepareURL(address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) {
} }
parsedURL.Path = ws.URLPath parsedURL.Path = ws.URLPath
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts) wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err) log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err return nil, err
} }

View File

@ -23,20 +23,26 @@ const (
MsgTypeAuth = 6 MsgTypeAuth = 6
MsgTypeAuthResponse = 7 MsgTypeAuthResponse = 7
SizeOfVersionByte = 1 // base size of the message
SizeOfMsgType = 1 sizeOfVersionByte = 1
sizeOfMsgType = 1
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType // auth message
sizeOfMagicByte = 4
sizeOfMagicByte = 4 headerSizeAuth = sizeOfMagicByte + IDSize
offsetMagicByte = sizeOfProtoHeader
headerSizeTransport = IDSize offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
// hello message
headerSizeHello = sizeOfMagicByte + IDSize headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0 headerSizeHelloResp = 0
headerSizeAuth = sizeOfMagicByte + IDSize // transport
headerSizeAuthResp = 0 headerSizeTransport = IDSize
offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
) )
var ( var (
@ -73,7 +79,7 @@ func (m MsgType) String() string {
// ValidateVersion checks if the given version is supported by the protocol // ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) { func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte { if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength return 0, ErrInvalidMessageLength
} }
version := int(msg[0]) version := int(msg[0])
@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) {
// DetermineClientMessageType determines the message type from the first the message // DetermineClientMessageType determines the message type from the first the message
func DetermineClientMessageType(msg []byte) (MsgType, error) { func DetermineClientMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType { if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength return 0, ErrInvalidMessageLength
} }
msgType := MsgType(msg[0]) msgType := MsgType(msg[1])
switch msgType { switch msgType {
case case
MsgTypeHello, MsgTypeHello,
@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
// DetermineServerMessageType determines the message type from the first the message // DetermineServerMessageType determines the message type from the first the message
func DetermineServerMessageType(msg []byte) (MsgType, error) { func DetermineServerMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType { if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength return 0, ErrInvalidMessageLength
} }
msgType := MsgType(msg[0]) msgType := MsgType(msg[1])
switch msgType { switch msgType {
case case
MsgTypeHelloResponse, MsgTypeHelloResponse,
@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
} }
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions)) msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello) msg[1] = byte(MsgTypeHello)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID...)
msg = append(msg, additions...) msg = append(msg, additions...)
@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello { if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead. // Deprecated: Use MarshalAuthResponse instead.
@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay // instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers. // servers.
func MarshalHelloResponse(additionalData []byte) ([]byte, error) { func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData)) msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse) msg[1] = byte(MsgTypeHelloResponse)
@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthResponse instead. // Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message. // UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) { func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp { if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return msg, nil return msg, nil
@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
} }
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload)) msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth) msg[1] = byte(MsgTypeAuth)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) copy(msg[sizeOfProtoHeader:], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID...)
msg = append(msg, authPayload...) msg = append(msg, authPayload...)
@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
// UnmarshalAuthMsg extracts peerID and the auth payload from the message // UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth { if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
} }
// MarshalAuthResponse creates a response message to the auth. // MarshalAuthResponse creates a response message to the auth.
@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
// servers. // servers.
func MarshalAuthResponse(address string) ([]byte, error) { func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address) ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab)) msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse) msg[1] = byte(MsgTypeAuthResponse)
@ -243,39 +249,34 @@ func MarshalAuthResponse(address string) ([]byte, error) {
// UnmarshalAuthResponse it is a confirmation message to auth success // UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) { func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 { if len(msg) < sizeOfProtoHeader+1 {
return "", ErrInvalidMessageLength return "", ErrInvalidMessageLength
} }
return string(msg), nil return string(msg[sizeOfProtoHeader:]), nil
} }
// MarshalCloseMsg creates a close message. // MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the // The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection. // client can send this message. After receiving this message, the server or client will close the connection.
func MarshalCloseMsg() []byte { func MarshalCloseMsg() []byte {
msg := make([]byte, SizeOfProtoHeader) return []byte{
byte(CurrentProtocolVersion),
msg[0] = byte(CurrentProtocolVersion) byte(MsgTypeClose),
msg[1] = byte(MsgTypeClose) }
return msg
} }
// MarshalTransportMsg creates a transport message. // MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID. // destination peer hashed ID.
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize { if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
} }
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload)) msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport) msg[1] = byte(MsgTypeTransport)
copy(msg[sizeOfProtoHeader:], peerID)
copy(msg[SizeOfProtoHeader:], peerID)
msg = append(msg, payload...) msg = append(msg, payload...)
return msg, nil return msg, nil
@ -283,29 +284,29 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message. // UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
if len(buf) < headerSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
} }
// UnmarshalTransportID extracts the peerID from the transport message. // UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) { func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < headerSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return buf[:headerSizeTransport], nil return buf[offsetTransportID:headerTotalSizeTransport], nil
} }
// UpdateTransportMsg updates the peerID in the transport message. // UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice. // need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error { func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < len(peerID) { if len(msg) < offsetTransportID+len(peerID) {
return ErrInvalidMessageLength return ErrInvalidMessageLength
} }
copy(msg, peerID) copy(msg[offsetTransportID:], peerID)
return nil return nil
} }

View File

@ -6,12 +6,21 @@ import (
func TestMarshalHelloMsg(t *testing.T) { func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID, nil) msg, err := MarshalHelloMsg(peerID, nil)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:]) msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeHello {
t.Errorf("expected %d, got %d", MsgTypeHello, msgType)
}
receivedPeerID, _, err := UnmarshalHelloMsg(msg)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
@ -22,12 +31,21 @@ func TestMarshalHelloMsg(t *testing.T) {
func TestMarshalAuthMsg(t *testing.T) { func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalAuthMsg(peerID, []byte{}) msg, err := MarshalAuthMsg(peerID, []byte{})
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:]) msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeAuth {
t.Errorf("expected %d, got %d", MsgTypeAuth, msgType)
}
receivedPeerID, _, err := UnmarshalAuthMsg(msg)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
@ -36,6 +54,31 @@ func TestMarshalAuthMsg(t *testing.T) {
} }
} }
func TestMarshalAuthResponse(t *testing.T) {
address := "myaddress"
msg, err := MarshalAuthResponse(address)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineServerMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeAuthResponse {
t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType)
}
respAddr, err := UnmarshalAuthResponse(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if respAddr != address {
t.Errorf("expected %s, got %s", address, respAddr)
}
}
func TestMarshalTransportMsg(t *testing.T) { func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload") payload := []byte("payload")
@ -44,7 +87,25 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:]) msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeTransport {
t.Errorf("expected %d, got %d", MsgTypeTransport, msgType)
}
uPeerID, err := UnmarshalTransportID(msg)
if err != nil {
t.Fatalf("failed to unmarshal transport id: %v", err)
}
if string(uPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, uPeerID)
}
id, respPayload, err := UnmarshalTransportMsg(msg)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
@ -57,3 +118,21 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Errorf("expected %s, got %s", payload, respPayload) t.Errorf("expected %s, got %s", payload, respPayload)
} }
} }
func TestMarshalHealthcheck(t *testing.T) {
msg := MarshalHealthcheck()
_, err := ValidateVersion(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineServerMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeHealthCheck {
t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType)
}
}

View File

@ -68,12 +68,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
} }
_, err = messages.ValidateVersion(buf[:n]) buf = buf[:n]
_, err = messages.ValidateVersion(buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
} }
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) msgType, err := messages.DetermineClientMessageType(buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
} }
@ -85,10 +87,10 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
switch msgType { switch msgType {
//nolint:staticcheck //nolint:staticcheck
case messages.MsgTypeHello: case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) bytePeerID, peerID, err = h.handleHelloMsg(buf)
case messages.MsgTypeAuth: case messages.MsgTypeAuth:
h.handshakeMethodAuth = true h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) bytePeerID, peerID, err = h.handleAuthMsg(buf)
default: default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
} }

View File

@ -0,0 +1,101 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/quic-go/quic-go"
)
type Conn struct {
session quic.Connection
closed bool
closedMu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConn(session quic.Connection) *Conn {
ctx, cancel := context.WithCancel(context.Background())
return &Conn{
session: session,
ctx: ctx,
ctxCancel: cancel,
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
// Copy data to b, ensuring we dont exceed the size of b
n = copy(b, dgram)
return n, nil
}
func (c *Conn) Write(b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err)
}
return len(b), nil
}
func (c *Conn) LocalAddr() net.Addr {
return c.session.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {
c.closedMu.Unlock()
return nil
}
c.closed = true
c.closedMu.Unlock()
c.ctxCancel() // Cancel the context
sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr
}
func (c *Conn) isClosed() bool {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
func (c *Conn) remoteCloseErrHandling(err error) error {
if c.isClosed() {
return net.ErrClosed
}
// Check if the connection was closed remotely
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
return net.ErrClosed
}
return err
}

View File

@ -0,0 +1,66 @@
package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type Listener struct {
// Address is the address to listen on
Address string
// TLSConfig is the TLS configuration for the server
TLSConfig *tls.Config
listener *quic.Listener
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{
EnableDatagrams: true,
}
listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg)
if err != nil {
return fmt.Errorf("failed to create QUIC listener: %v", err)
}
l.listener = listener
log.Infof("QUIC server listening on address: %s", l.Address)
for {
session, err := listener.Accept(context.Background())
if err != nil {
if errors.Is(err, quic.ErrServerClosed) {
return nil
}
log.Errorf("Failed to accept QUIC session: %v", err)
continue
}
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session)
l.acceptFn(conn)
}
}
func (l *Listener) Shutdown(ctx context.Context) error {
if l.listener == nil {
return nil
}
log.Infof("stopping QUIC listener")
if err := l.listener.Close(); err != nil {
return fmt.Errorf("listener shutdown failed: %v", err)
}
log.Infof("QUIC listener stopped")
return nil
}

View File

@ -88,6 +88,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return return
} }
log.Infof("WS client connected from: %s", rAddr)
conn := NewConn(wsConn, lAddr, rAddr) conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn) l.acceptFn(conn)
} }

View File

@ -84,7 +84,7 @@ func (p *Peer) Work() {
return return
} }
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:]) msgType, err := messages.DetermineClientMessageType(msg)
if err != nil { if err != nil {
p.log.Errorf("failed to determine message type: %s", err) p.log.Errorf("failed to determine message type: %s", err)
return return
@ -191,7 +191,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
} }
func (p *Peer) handleTransportMsg(msg []byte) { func (p *Peer) handleTransportMsg(msg []byte) {
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:]) peerID, err := messages.UnmarshalTransportID(msg)
if err != nil { if err != nil {
p.log.Errorf("failed to unmarshal transport message: %s", err) p.log.Errorf("failed to unmarshal transport message: %s", err)
return return
@ -204,7 +204,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return return
} }
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB) err = messages.UpdateTransportMsg(msg, p.idB)
if err != nil { if err != nil {
p.log.Errorf("failed to update transport message: %s", err) p.log.Errorf("failed to update transport message: %s", err)
return return

View File

@ -150,6 +150,8 @@ func (r *Relay) Accept(conn net.Conn) {
func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) Shutdown(ctx context.Context) {
log.Infof("close connection with all peers") log.Infof("close connection with all peers")
r.closeMu.Lock() r.closeMu.Lock()
defer r.closeMu.Unlock()
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
peers := r.store.Peers() peers := r.store.Peers()
for _, peer := range peers { for _, peer := range peers {
@ -161,7 +163,7 @@ func (r *Relay) Shutdown(ctx context.Context) {
} }
wg.Wait() wg.Wait()
r.metricsCancel() r.metricsCancel()
r.closeMu.Unlock() r.closed = true
} }
// InstanceURL returns the instance URL of the relay server // InstanceURL returns the instance URL of the relay server

View File

@ -3,13 +3,17 @@ package server
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"sync"
log "github.com/sirupsen/logrus" "github.com/hashicorp/go-multierror"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/relay/tls"
) )
// ListenerConfig is the configuration for the listener. // ListenerConfig is the configuration for the listener.
@ -24,8 +28,8 @@ type ListenerConfig struct {
// It is the gate between the WebSocket listener and the Relay server logic. // It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct { type Server struct {
relay *Relay relay *Relay
wSListener listener.Listener listeners []listener.Listener
} }
// NewServer creates a new relay server instance. // NewServer creates a new relay server instance.
@ -39,35 +43,63 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV
return nil, err return nil, err
} }
return &Server{ return &Server{
relay: relay, relay: relay,
listeners: make([]listener.Listener, 0, 2),
}, nil }, nil
} }
// Listen starts the relay server. // Listen starts the relay server.
func (r *Server) Listen(cfg ListenerConfig) error { func (r *Server) Listen(cfg ListenerConfig) error {
r.wSListener = &ws.Listener{ wSListener := &ws.Listener{
Address: cfg.Address, Address: cfg.Address,
TLSConfig: cfg.TLSConfig, TLSConfig: cfg.TLSConfig,
} }
r.listeners = append(r.listeners, wSListener)
wslErr := r.wSListener.Listen(r.relay.Accept) tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
if wslErr != nil { if err != nil {
log.Errorf("failed to bind ws server: %s", wslErr) return err
} }
return wslErr quicListener := &quic.Listener{
Address: cfg.Address,
TLSConfig: tlsConfigQUIC,
}
r.listeners = append(r.listeners, quicListener)
errChan := make(chan error, len(r.listeners))
wg := sync.WaitGroup{}
for _, l := range r.listeners {
wg.Add(1)
go func(listener listener.Listener) {
defer wg.Done()
errChan <- listener.Listen(r.relay.Accept)
}(l)
}
wg.Wait()
close(errChan)
var multiErr *multierror.Error
for err := range errChan {
multiErr = multierror.Append(multiErr, err)
}
return nberrors.FormatErrorOrNil(multiErr)
} }
// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context, // Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
// the connections will be forcefully closed. // the connections will be forcefully closed.
func (r *Server) Shutdown(ctx context.Context) (err error) { func (r *Server) Shutdown(ctx context.Context) error {
// stop service new connections
if r.wSListener != nil {
err = r.wSListener.Shutdown(ctx)
}
r.relay.Shutdown(ctx) r.relay.Shutdown(ctx)
return
var multiErr *multierror.Error
for _, l := range r.listeners {
if err := l.Shutdown(ctx); err != nil {
multiErr = multierror.Append(multiErr, err)
}
}
return nberrors.FormatErrorOrNil(multiErr)
} }
// InstanceURL returns the instance URL of the relay server. // InstanceURL returns the instance URL of the relay server.

3
relay/tls/alpn.go Normal file
View File

@ -0,0 +1,3 @@
package tls
const nbalpn = "nb-quic"

12
relay/tls/client_dev.go Normal file
View File

@ -0,0 +1,12 @@
//go:build devcert
package tls
import "crypto/tls"
func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true, // Debug mode allows insecure connections
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
}
}

11
relay/tls/client_prod.go Normal file
View File

@ -0,0 +1,11 @@
//go:build !devcert
package tls
import "crypto/tls"
func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{
NextProtos: []string{nbalpn},
}
}

36
relay/tls/doc.go Normal file
View File

@ -0,0 +1,36 @@
// Package tls provides utilities for configuring and managing Transport Layer
// Security (TLS) in server and client environments, with a focus on QUIC
// protocol support and testing configurations.
//
// The package includes functions for cloning and customizing TLS
// configurations as well as generating self-signed certificates for
// development and testing purposes.
//
// Key Features:
//
// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored
// for QUIC protocol with specified or default settings. QUIC requires a
// specific TLS configuration with proper ALPN (Application-Layer Protocol
// Negotiation) support, making the TLS settings crucial for establishing
// secure connections.
//
// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable
// for QUIC protocol. The configuration differs between development
// (insecure testing) and production (strict verification).
//
// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for
// use in local development and testing scenarios.
//
// Usage:
//
// This package provides separate implementations for development and production
// environments. The development implementation (guarded by `//go:build devcert`)
// supports testing configurations with self-signed certificates and insecure
// client connections. The production implementation (guarded by `//go:build
// !devcert`) ensures that valid and secure TLS configurations are supplied
// and used.
//
// The QUIC protocol is highly reliant on properly configured TLS settings,
// and this package ensures that configurations meet the requirements for
// secure and efficient QUIC communication.
package tls

79
relay/tls/server_dev.go Normal file
View File

@ -0,0 +1,79 @@
//go:build devcert
package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
log "github.com/sirupsen/logrus"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
log.Warnf("QUIC server will use self signed certificate for testing!")
return generateTestTLSConfig()
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}
// GenerateTestTLSConfig creates a self-signed certificate for testing
func generateTestTLSConfig() (*tls.Config, error) {
log.Infof("generating test TLS config")
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
// Create certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, err
}
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{nbalpn},
}, nil
}

17
relay/tls/server_prod.go Normal file
View File

@ -0,0 +1,17 @@
//go:build !devcert
package tls
import (
"crypto/tls"
"fmt"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}