Merge branch 'main' into userspace-router

This commit is contained in:
Viktor Liu 2025-01-15 17:00:37 +01:00
commit ea6c947f5d
50 changed files with 1255 additions and 443 deletions

View File

@ -44,4 +44,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@ -134,7 +134,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management:
needs: [ build-cache ]
@ -194,7 +194,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
needs: [ build-cache ]
@ -254,7 +254,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
api_benchmark:
needs: [ build-cache ]

View File

@ -65,7 +65,7 @@ jobs:
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@ -19,8 +19,7 @@ const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
chainNameInputRules = "NETBIRD-ACL-INPUT"
)
type aclEntries map[string][][]string
@ -84,7 +83,6 @@ func (m *aclManager) AddPeerFiltering(
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
@ -97,15 +95,10 @@ func (m *aclManager) AddPeerFiltering(
sPortVal = strconv.Itoa(sPort.Values[0])
}
var chain string
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName)
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
@ -214,28 +207,7 @@ func (m *aclManager) Reset() error {
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
@ -295,12 +267,6 @@ func (m *aclManager) createDefaultChains() error {
return err
}
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries {
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
@ -329,8 +295,6 @@ func (m *aclManager) createDefaultChains() error {
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
@ -390,30 +354,18 @@ func (m *aclManager) updateState() {
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
switch direction {
case firewall.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
if protocol != "all" {

View File

@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
_ string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error {
"all",
nil,
nil,
firewall.RuleDirectionIN,
firewall.ActionAccept,
"",
"",

View File

@ -68,27 +68,13 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{8043: 8046},
}
rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@ -97,15 +83,6 @@ func TestIptablesManager(t *testing.T) {
}
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
@ -119,7 +96,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil)
@ -135,9 +112,6 @@ func TestIptablesManager(t *testing.T) {
}
func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
@ -167,33 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{443},
}
rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@ -201,15 +155,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
@ -270,11 +215,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}

View File

@ -69,7 +69,6 @@ type Manager interface {
proto Protocol,
sPort *Port,
dPort *Port,
direction RuleDirection,
action Action,
ipsetName string,
comment string,

View File

@ -22,8 +22,7 @@ import (
const (
// rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules"
chainNameOutputRules = "netbird-acl-output-rules"
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
@ -45,9 +44,8 @@ type AclManager struct {
wgIface iFaceMapper
routingFwChainName string
workTable *nftables.Table
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
workTable *nftables.Table
chainInputRules *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@ -89,7 +87,6 @@ func (m *AclManager) AddPeerFiltering(
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
@ -104,7 +101,7 @@ func (m *AclManager) AddPeerFiltering(
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil {
return nil, err
}
@ -214,38 +211,6 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expIn,
})
expOut := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainOutputRules,
Position: 0,
Exprs: expOut,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
@ -264,15 +229,19 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
return nil
}
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
@ -310,9 +279,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrOffset := uint32(12)
if direction == firewall.RuleDirectionOUT {
addrOffset += 4 // is ipv4 address length
}
expressions = append(expressions,
&expr.Payload{
@ -383,12 +349,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
var chain *nftables.Chain
if direction == firewall.RuleDirectionIN {
chain = m.chainInputRules
} else {
chain = m.chainOutputRules
}
chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
@ -419,15 +380,6 @@ func (m *AclManager) createDefaultChains() (err error) {
}
m.chainInputRules = chain
// chainNameOutputRules
chain = m.createChain(chainNameOutputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
return err
}
m.chainOutputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
@ -720,15 +672,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil
}
func generatePeerRuleId(
ip net.IP,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipset *nftables.Set,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":"
if sPort != nil {
rulesetID += sPort.String()
}

View File

@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
}
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@ -210,11 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(

View File

@ -16,7 +16,6 @@ type PeerRule struct {
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
direction firewall.RuleDirection
sPort uint16
dPort uint16
drop bool

View File

@ -45,7 +45,9 @@ type RuleSet map[string]PeerRule
// Manager userspace firewall manager
type Manager struct {
outgoingRules map[string]RuleSet
// outgoingRules is used for hooks only
outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet
routeRules map[string]RouteRule
wgNetwork *net.IPNet
@ -297,9 +299,8 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
_ string,
comment string,
) ([]firewall.Rule, error) {
r := PeerRule{
@ -307,7 +308,6 @@ func (m *Manager) AddPeerFiltering(
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
direction: direction,
drop: action == firewall.ActionDrop,
comment: comment,
}
@ -343,17 +343,10 @@ func (m *Manager) AddPeerFiltering(
}
m.mutex.Lock()
if direction == firewall.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
@ -416,19 +409,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if r.direction == firewall.RuleDirectionIN {
_, ok := m.incomingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else {
_, ok := m.outgoingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
return nil
}
@ -918,7 +902,6 @@ func (m *Manager) AddUDPPacketHook(
protoLayer: layers.LayerTypeUDP,
dPort: dPort,
ipLayer: layers.LayerTypeIPv6,
direction: firewall.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook,
}
@ -929,7 +912,6 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock()
if in {
r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
}
@ -948,19 +930,22 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeletePeerRule(&rule)
delete(arr, r.id)
return nil
}
}
}
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeletePeerRule(&rule)
delete(arr, r.id)
return nil
}
}
}

View File

@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
fw.ActionAccept, "", "allow all")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@ -117,7 +117,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}},
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
fw.ActionAccept, "", "explicit return")
require.NoError(b, err)
}
},
@ -129,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
fw.ActionDrop, "", "default drop")
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@ -593,7 +593,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@ -684,7 +684,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@ -802,7 +802,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@ -889,7 +889,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}

View File

@ -91,11 +91,10 @@ func TestManagerAddPeerFiltering(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@ -126,37 +125,15 @@ func TestManagerDeleteRule(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop
comment := "Test rule"
comment := "Test rule 2"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
ip = net.ParseIP("192.168.1.1")
proto = fw.ProtocolTCP
port = &fw.Port{Values: []int{80}}
direction = fw.RuleDirectionIN
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
@ -246,10 +223,6 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if tt.expDir != addedRule.direction {
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
@ -272,11 +245,10 @@ func TestManagerReset(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@ -319,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
direction := fw.RuleDirectionIN
action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@ -523,11 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}

View File

@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.rollBack(newRulePairs)
break
}
if len(rules) > 0 {
if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
@ -288,6 +288,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@ -308,25 +310,12 @@ func (d *DefaultManager) addInRules(
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
return nil, fmt.Errorf("add firewall rule: %w", err)
}
rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
return rule, nil
}
func (d *DefaultManager) addOutRules(
@ -337,25 +326,16 @@ func (d *DefaultManager) addOutRules(
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
return nil, nil
}
rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return append(rules, rule...), nil
return rule, nil
}
// getPeerRuleID() returns unique ID for the rule based on its parameters.

View File

@ -120,8 +120,8 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
}
})
@ -358,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 4 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return
}
}

View File

@ -48,11 +48,17 @@ type restoreHostManager interface {
func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, err
return nil, fmt.Errorf("get os dns manager type: %w", err)
}
log.Infof("System DNS manager discovered: %s", osManager)
return newHostManagerFromType(wgInterface, osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
return nil, fmt.Errorf("create host manager: %w", err)
}
return mgr, nil
}
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {

View File

@ -12,6 +12,7 @@ import (
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
@ -239,7 +240,10 @@ func (s *DefaultServer) Initialize() (err error) {
s.stateManager.RegisterState(&ShutdownState{})
if s.disableSys {
// use noop host manager if requested or running in netstack mode.
// Netstack mode currently doesn't have a way to receive DNS requests.
// TODO: Use listener on localhost in netstack mode when running as root.
if s.disableSys || netstack.IsEnabled() {
log.Info("system DNS is disabled, not setting up host manager")
s.hostManager = &noopHostConfigurator{}
return nil

View File

@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
return nil
}
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err

View File

@ -499,7 +499,6 @@ func (e *Engine) initFirewall() error {
manager.ProtocolUDP,
nil,
&port,
manager.RuleDirectionIN,
manager.ActionAccept,
"",
"",

View File

@ -4,13 +4,27 @@ import (
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"runtime"
"github.com/netbirdio/netbird/util"
)
func isRoot() bool {
return os.Geteuid() == 0
}
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
if !isRoot() {
shell := getUserShell(user)
if shell == "" {
shell = "/bin/sh"
}
return shell, []string{"-l"}, nil
}
loginPath, err = exec.LookPath("login")
if err != nil {
return "", nil, err

View File

@ -6,5 +6,9 @@ package ssh
import "os/user"
func userNameLookup(username string) (*user.User, error) {
if username == "" || (username == "root" && !isRoot()) {
return user.Current()
}
return user.Lookup(username)
}

View File

@ -12,6 +12,10 @@ import (
)
func userNameLookup(username string) (*user.User, error) {
if username == "" || (username == "root" && !isRoot()) {
return user.Current()
}
var userObject *user.User
userObject, err := user.Lookup(username)
if err != nil && err.Error() == user.UnknownUserError(username).Error() {

View File

@ -39,6 +39,9 @@ func GetInfo(ctx context.Context) *Info {
WiretrusteeVersion: version.NetbirdVersion(),
UIVersion: extractUIVersion(ctx),
KernelVersion: kernelVersion,
SystemSerialNumber: serial(),
SystemProductName: productModel(),
SystemManufacturer: productManufacturer(),
}
return gio
@ -49,13 +52,42 @@ func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
func serial() string {
// try to fetch serial ID using different properties
properties := []string{"ril.serialnumber", "ro.serialno", "ro.boot.serialno", "sys.serialnumber"}
var value string
for _, property := range properties {
value = getprop(property)
if len(value) > 0 {
return value
}
}
// unable to get serial ID, fallback to ANDROID_ID
return androidId()
}
func androidId() string {
// this is a uniq id defined on first initialization, id will be a new one if user wipes his device
return run("/system/bin/settings", "get", "secure", "android_id")
}
func productModel() string {
return getprop("ro.product.model")
}
func productManufacturer() string {
return getprop("ro.product.manufacturer")
}
func uname() []string {
res := run("/system/bin/uname", "-a")
return strings.Split(res, " ")
}
func osVersion() string {
return run("/system/bin/getprop", "ro.build.version.release")
return getprop("ro.build.version.release")
}
func extractUIVersion(ctx context.Context) string {
@ -66,6 +98,10 @@ func extractUIVersion(ctx context.Context) string {
return v
}
func getprop(arg ...string) string {
return run("/system/bin/getprop", arg...)
}
func run(name string, arg ...string) string {
cmd := exec.Command(name, arg...)
cmd.Stdin = strings.NewReader("some")

4
go.mod
View File

@ -71,6 +71,7 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/quic-go/quic-go v0.48.2
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@ -156,11 +157,13 @@ require (
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
@ -222,6 +225,7 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect

6
go.sum
View File

@ -405,6 +405,7 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@ -610,6 +611,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
@ -761,6 +764,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
@ -970,6 +975,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -10,6 +10,8 @@ import (
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer"
"github.com/netbirdio/netbird/relay/client/dialer/quic"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) {
msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
}
}
@ -179,8 +179,7 @@ func (c *Client) Connect() error {
return nil
}
err := c.connect()
if err != nil {
if err := c.connect(); err != nil {
return err
}
@ -264,14 +263,14 @@ func (c *Client) Close() error {
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.connectionURL)
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
if err = c.handShake(); err != nil {
cErr := conn.Close()
if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr)
@ -306,7 +305,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("validate version: %w", err)
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
return err
@ -317,7 +316,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("unexpected message type")
}
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil {
return err
}
@ -345,27 +344,30 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
c.log.Errorf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
c.bufPool.Put(bufPtr)
break
}
_, err := messages.ValidateVersion(buf[:n])
buf = buf[:n]
_, err := messages.ValidateVersion(buf)
if err != nil {
c.log.Errorf("failed to validate protocol version: %s", err)
c.bufPool.Put(bufPtr)
continue
}
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf)
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr)
continue
}
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
break
}
}

View File

@ -0,0 +1,7 @@
package net
import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
)

View File

@ -0,0 +1,97 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/relay/client/dialer/net"
)
const (
Network = "quic"
)
type Addr struct {
addr string
}
func (a Addr) Network() string {
return Network
}
func (a Addr) String() string {
return a.addr
}
type Conn struct {
session quic.Connection
ctx context.Context
}
func NewConn(session quic.Connection) net.Conn {
return &Conn{
session: session,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
n = copy(b, dgram)
return n, nil
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
}
return len(b), nil
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) LocalAddr() net.Addr {
if c.session != nil {
return c.session.LocalAddr()
}
return Addr{addr: "unknown"}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return nil
}
func (c *Conn) Close() error {
return c.session.CloseWithError(0, "normal closure")
}
func (c *Conn) remoteCloseErrHandling(err error) error {
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
return netErr.ErrClosedByServer
}
return err
}

View File

@ -0,0 +1,71 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/relay/tls"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Dialer struct {
}
func (d Dialer) Protocol() string {
return Network
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
return nil, err
}
quicConfig := &quic.Config{
KeepAlivePeriod: 30 * time.Second,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
log.Errorf("failed to listen on UDP: %s", err)
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
if err != nil {
log.Errorf("failed to resolve UDP address: %s", err)
return nil, err
}
session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
return nil, err
}
conn := NewConn(session)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
if strings.HasPrefix(address, "rels://") {
return address[7:], nil
}
return address[6:], nil
}

View File

@ -0,0 +1,96 @@
package dialer
import (
"context"
"errors"
"net"
"time"
log "github.com/sirupsen/logrus"
)
var (
connectionTimeout = 30 * time.Second
)
type DialeFn interface {
Dial(ctx context.Context, address string) (net.Conn, error)
Protocol() string
}
type dialResult struct {
Conn net.Conn
Protocol string
Err error
}
type RaceDial struct {
log *log.Entry
serverURL string
dialerFns []DialeFn
}
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
}
}
func (r *RaceDial) Dial() (net.Conn, error) {
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(context.Background())
defer abort()
for _, dfn := range r.dialerFns {
go r.dial(dfn, abortCtx, connChan)
}
go r.processResults(connChan, winnerConn, abort)
conn, ok := <-winnerConn
if !ok {
return nil, errors.New("failed to dial to Relay server on any protocol")
}
return conn, nil
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(ctx, r.serverURL)
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
var hasWinner bool
for i := 0; i < len(r.dialerFns); i++ {
dr := <-connChan
if dr.Err != nil {
if errors.Is(dr.Err, context.Canceled) {
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
} else {
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
}
continue
}
if hasWinner {
if cerr := dr.Conn.Close(); cerr != nil {
r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr)
}
continue
}
r.log.Infof("successfully dialed via: %s", dr.Protocol)
abort()
hasWinner = true
winnerConn <- dr.Conn
}
close(winnerConn)
}

View File

@ -0,0 +1,252 @@
package dialer
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/sirupsen/logrus"
)
type MockAddr struct {
network string
}
func (m *MockAddr) Network() string {
return m.network
}
func (m *MockAddr) String() string {
return "1.2.3.4"
}
// MockDialer is a mock implementation of DialeFn
type MockDialer struct {
dialFunc func(ctx context.Context, address string) (net.Conn, error)
protocolStr string
}
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return m.dialFunc(ctx, address)
}
func (m *MockDialer) Protocol() string {
return m.protocolStr
}
// MockConn implements net.Conn for testing
type MockConn struct {
remoteAddr net.Addr
}
func (m *MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Close() error {
return nil
}
func (m *MockConn) LocalAddr() net.Addr {
return nil
}
func (m *MockConn) RemoteAddr() net.Addr {
return m.remoteAddr
}
func (m *MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
}
}
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto := "test-protocol"
mockConn := &MockConn{
remoteAddr: &MockAddr{network: proto},
}
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn, nil
},
protocolStr: proto,
}
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
}
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto2 := "protocol2"
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "proto1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn2, nil
},
protocolStr: "proto2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
}
}
func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
return nil, ctx.Err()
},
protocolStr: "proto1",
}
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialAllDialersFail(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "protocol1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("second dialer failed")
},
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto1 := "protocol1"
proto2 := "protocol2"
mockConn1 := &MockConn{
remoteAddr: &MockAddr{network: proto1},
}
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
time.Sleep(1 * time.Second)
return mockConn1, nil
},
protocolStr: proto1,
}
mock2err := make(chan error)
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
mock2err <- ctx.Err()
return mockConn2, ctx.Err()
},
protocolStr: proto2,
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
if conn != mockConn1 {
t.Errorf("Expected first connection, got %v", conn)
}
select {
case <-time.After(3 * time.Second):
t.Errorf("Timed out waiting for second dialer to finish")
case err := <-mock2err:
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled error, got %v", err)
}
}
}

View File

@ -1,11 +1,15 @@
package ws
const (
Network = "ws"
)
type WebsocketAddr struct {
addr string
}
func (a WebsocketAddr) Network() string {
return "websocket"
return Network
}
func (a WebsocketAddr) String() string {

View File

@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
// todo use ErrClosedByServer
return 0, err
}

View File

@ -2,6 +2,7 @@ package ws
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@ -15,7 +16,14 @@ import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func Dial(address string) (net.Conn, error) {
type Dialer struct {
}
func (d Dialer) Protocol() string {
return "WS"
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) {
}
parsedURL.Path = ws.URLPath
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}

View File

@ -23,20 +23,26 @@ const (
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
SizeOfVersionByte = 1
SizeOfMsgType = 1
// base size of the message
sizeOfVersionByte = 1
sizeOfMsgType = 1
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
sizeOfMagicByte = 4
headerSizeTransport = IDSize
// auth message
sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + IDSize
offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
// hello message
headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0
headerSizeAuth = sizeOfMagicByte + IDSize
headerSizeAuthResp = 0
// transport
headerSizeTransport = IDSize
offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
)
var (
@ -73,7 +79,7 @@ func (m MsgType) String() string {
// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
version := int(msg[0])
@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) {
// DetermineClientMessageType determines the message type from the first the message
func DetermineClientMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHello,
@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
// DetermineServerMessageType determines the message type from the first the message
func DetermineServerMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHelloResponse,
@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, additions...)
@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
}
// Deprecated: Use MarshalAuthResponse instead.
@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
return nil, ErrInvalidMessageLength
}
return msg, nil
@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
copy(msg[sizeOfProtoHeader:], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, authPayload...)
@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth {
if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
}
// MarshalAuthResponse creates a response message to the auth.
@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
@ -243,39 +249,34 @@ func MarshalAuthResponse(address string) ([]byte, error) {
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 {
if len(msg) < sizeOfProtoHeader+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
return string(msg[sizeOfProtoHeader:]), nil
}
// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.
func MarshalCloseMsg() []byte {
msg := make([]byte, SizeOfProtoHeader)
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeClose)
return msg
return []byte{
byte(CurrentProtocolVersion),
byte(MsgTypeClose),
}
}
// MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID.
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport)
copy(msg[SizeOfProtoHeader:], peerID)
copy(msg[sizeOfProtoHeader:], peerID)
msg = append(msg, payload...)
return msg, nil
@ -283,29 +284,29 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
if len(buf) < headerSizeTransport {
if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
}
// UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < headerSizeTransport {
if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], nil
return buf[offsetTransportID:headerTotalSizeTransport], nil
}
// UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < len(peerID) {
if len(msg) < offsetTransportID+len(peerID) {
return ErrInvalidMessageLength
}
copy(msg, peerID)
copy(msg[offsetTransportID:], peerID)
return nil
}

View File

@ -6,12 +6,21 @@ import (
func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID, nil)
msg, err := MarshalHelloMsg(peerID, nil)
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeHello {
t.Errorf("expected %d, got %d", MsgTypeHello, msgType)
}
receivedPeerID, _, err := UnmarshalHelloMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
@ -22,12 +31,21 @@ func TestMarshalHelloMsg(t *testing.T) {
func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalAuthMsg(peerID, []byte{})
msg, err := MarshalAuthMsg(peerID, []byte{})
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeAuth {
t.Errorf("expected %d, got %d", MsgTypeAuth, msgType)
}
receivedPeerID, _, err := UnmarshalAuthMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
@ -36,6 +54,31 @@ func TestMarshalAuthMsg(t *testing.T) {
}
}
func TestMarshalAuthResponse(t *testing.T) {
address := "myaddress"
msg, err := MarshalAuthResponse(address)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineServerMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeAuthResponse {
t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType)
}
respAddr, err := UnmarshalAuthResponse(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if respAddr != address {
t.Errorf("expected %s, got %s", address, respAddr)
}
}
func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")
@ -44,7 +87,25 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("error: %v", err)
}
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeTransport {
t.Errorf("expected %d, got %d", MsgTypeTransport, msgType)
}
uPeerID, err := UnmarshalTransportID(msg)
if err != nil {
t.Fatalf("failed to unmarshal transport id: %v", err)
}
if string(uPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, uPeerID)
}
id, respPayload, err := UnmarshalTransportMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
@ -57,3 +118,21 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Errorf("expected %s, got %s", payload, respPayload)
}
}
func TestMarshalHealthcheck(t *testing.T) {
msg := MarshalHealthcheck()
_, err := ValidateVersion(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineServerMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeHealthCheck {
t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType)
}
}

View File

@ -68,12 +68,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
buf = buf[:n]
_, err = messages.ValidateVersion(buf)
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineClientMessageType(buf)
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
}
@ -85,10 +87,10 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
bytePeerID, peerID, err = h.handleHelloMsg(buf)
case messages.MsgTypeAuth:
h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
bytePeerID, peerID, err = h.handleAuthMsg(buf)
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
}

View File

@ -0,0 +1,101 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/quic-go/quic-go"
)
type Conn struct {
session quic.Connection
closed bool
closedMu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConn(session quic.Connection) *Conn {
ctx, cancel := context.WithCancel(context.Background())
return &Conn{
session: session,
ctx: ctx,
ctxCancel: cancel,
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
// Copy data to b, ensuring we dont exceed the size of b
n = copy(b, dgram)
return n, nil
}
func (c *Conn) Write(b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err)
}
return len(b), nil
}
func (c *Conn) LocalAddr() net.Addr {
return c.session.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {
c.closedMu.Unlock()
return nil
}
c.closed = true
c.closedMu.Unlock()
c.ctxCancel() // Cancel the context
sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr
}
func (c *Conn) isClosed() bool {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
func (c *Conn) remoteCloseErrHandling(err error) error {
if c.isClosed() {
return net.ErrClosed
}
// Check if the connection was closed remotely
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
return net.ErrClosed
}
return err
}

View File

@ -0,0 +1,66 @@
package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type Listener struct {
// Address is the address to listen on
Address string
// TLSConfig is the TLS configuration for the server
TLSConfig *tls.Config
listener *quic.Listener
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{
EnableDatagrams: true,
}
listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg)
if err != nil {
return fmt.Errorf("failed to create QUIC listener: %v", err)
}
l.listener = listener
log.Infof("QUIC server listening on address: %s", l.Address)
for {
session, err := listener.Accept(context.Background())
if err != nil {
if errors.Is(err, quic.ErrServerClosed) {
return nil
}
log.Errorf("Failed to accept QUIC session: %v", err)
continue
}
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session)
l.acceptFn(conn)
}
}
func (l *Listener) Shutdown(ctx context.Context) error {
if l.listener == nil {
return nil
}
log.Infof("stopping QUIC listener")
if err := l.listener.Close(); err != nil {
return fmt.Errorf("listener shutdown failed: %v", err)
}
log.Infof("QUIC listener stopped")
return nil
}

View File

@ -88,6 +88,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return
}
log.Infof("WS client connected from: %s", rAddr)
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
}

View File

@ -84,7 +84,7 @@ func (p *Peer) Work() {
return
}
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
msgType, err := messages.DetermineClientMessageType(msg)
if err != nil {
p.log.Errorf("failed to determine message type: %s", err)
return
@ -191,7 +191,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
}
func (p *Peer) handleTransportMsg(msg []byte) {
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
peerID, err := messages.UnmarshalTransportID(msg)
if err != nil {
p.log.Errorf("failed to unmarshal transport message: %s", err)
return
@ -204,7 +204,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return
}
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
err = messages.UpdateTransportMsg(msg, p.idB)
if err != nil {
p.log.Errorf("failed to update transport message: %s", err)
return

View File

@ -150,6 +150,8 @@ func (r *Relay) Accept(conn net.Conn) {
func (r *Relay) Shutdown(ctx context.Context) {
log.Infof("close connection with all peers")
r.closeMu.Lock()
defer r.closeMu.Unlock()
wg := sync.WaitGroup{}
peers := r.store.Peers()
for _, peer := range peers {
@ -161,7 +163,7 @@ func (r *Relay) Shutdown(ctx context.Context) {
}
wg.Wait()
r.metricsCancel()
r.closeMu.Unlock()
r.closed = true
}
// InstanceURL returns the instance URL of the relay server

View File

@ -3,13 +3,17 @@ package server
import (
"context"
"crypto/tls"
"sync"
log "github.com/sirupsen/logrus"
"github.com/hashicorp/go-multierror"
"go.opentelemetry.io/otel/metric"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/relay/tls"
)
// ListenerConfig is the configuration for the listener.
@ -24,8 +28,8 @@ type ListenerConfig struct {
// It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
relay *Relay
wSListener listener.Listener
relay *Relay
listeners []listener.Listener
}
// NewServer creates a new relay server instance.
@ -39,35 +43,63 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV
return nil, err
}
return &Server{
relay: relay,
relay: relay,
listeners: make([]listener.Listener, 0, 2),
}, nil
}
// Listen starts the relay server.
func (r *Server) Listen(cfg ListenerConfig) error {
r.wSListener = &ws.Listener{
wSListener := &ws.Listener{
Address: cfg.Address,
TLSConfig: cfg.TLSConfig,
}
r.listeners = append(r.listeners, wSListener)
wslErr := r.wSListener.Listen(r.relay.Accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
if err != nil {
return err
}
return wslErr
quicListener := &quic.Listener{
Address: cfg.Address,
TLSConfig: tlsConfigQUIC,
}
r.listeners = append(r.listeners, quicListener)
errChan := make(chan error, len(r.listeners))
wg := sync.WaitGroup{}
for _, l := range r.listeners {
wg.Add(1)
go func(listener listener.Listener) {
defer wg.Done()
errChan <- listener.Listen(r.relay.Accept)
}(l)
}
wg.Wait()
close(errChan)
var multiErr *multierror.Error
for err := range errChan {
multiErr = multierror.Append(multiErr, err)
}
return nberrors.FormatErrorOrNil(multiErr)
}
// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
// the connections will be forcefully closed.
func (r *Server) Shutdown(ctx context.Context) (err error) {
// stop service new connections
if r.wSListener != nil {
err = r.wSListener.Shutdown(ctx)
}
func (r *Server) Shutdown(ctx context.Context) error {
r.relay.Shutdown(ctx)
return
var multiErr *multierror.Error
for _, l := range r.listeners {
if err := l.Shutdown(ctx); err != nil {
multiErr = multierror.Append(multiErr, err)
}
}
return nberrors.FormatErrorOrNil(multiErr)
}
// InstanceURL returns the instance URL of the relay server.

3
relay/tls/alpn.go Normal file
View File

@ -0,0 +1,3 @@
package tls
const nbalpn = "nb-quic"

12
relay/tls/client_dev.go Normal file
View File

@ -0,0 +1,12 @@
//go:build devcert
package tls
import "crypto/tls"
func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true, // Debug mode allows insecure connections
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
}
}

11
relay/tls/client_prod.go Normal file
View File

@ -0,0 +1,11 @@
//go:build !devcert
package tls
import "crypto/tls"
func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{
NextProtos: []string{nbalpn},
}
}

36
relay/tls/doc.go Normal file
View File

@ -0,0 +1,36 @@
// Package tls provides utilities for configuring and managing Transport Layer
// Security (TLS) in server and client environments, with a focus on QUIC
// protocol support and testing configurations.
//
// The package includes functions for cloning and customizing TLS
// configurations as well as generating self-signed certificates for
// development and testing purposes.
//
// Key Features:
//
// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored
// for QUIC protocol with specified or default settings. QUIC requires a
// specific TLS configuration with proper ALPN (Application-Layer Protocol
// Negotiation) support, making the TLS settings crucial for establishing
// secure connections.
//
// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable
// for QUIC protocol. The configuration differs between development
// (insecure testing) and production (strict verification).
//
// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for
// use in local development and testing scenarios.
//
// Usage:
//
// This package provides separate implementations for development and production
// environments. The development implementation (guarded by `//go:build devcert`)
// supports testing configurations with self-signed certificates and insecure
// client connections. The production implementation (guarded by `//go:build
// !devcert`) ensures that valid and secure TLS configurations are supplied
// and used.
//
// The QUIC protocol is highly reliant on properly configured TLS settings,
// and this package ensures that configurations meet the requirements for
// secure and efficient QUIC communication.
package tls

79
relay/tls/server_dev.go Normal file
View File

@ -0,0 +1,79 @@
//go:build devcert
package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
log "github.com/sirupsen/logrus"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
log.Warnf("QUIC server will use self signed certificate for testing!")
return generateTestTLSConfig()
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}
// GenerateTestTLSConfig creates a self-signed certificate for testing
func generateTestTLSConfig() (*tls.Config, error) {
log.Infof("generating test TLS config")
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
// Create certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, err
}
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{nbalpn},
}, nil
}

17
relay/tls/server_prod.go Normal file
View File

@ -0,0 +1,17 @@
//go:build !devcert
package tls
import (
"crypto/tls"
"fmt"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}