mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-01 19:09:22 +01:00
Merge branch 'main' into userspace-router
This commit is contained in:
commit
ea6c947f5d
3
.github/workflows/golang-test-darwin.yml
vendored
3
.github/workflows/golang-test-darwin.yml
vendored
@ -44,4 +44,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
|
6
.github/workflows/golang-test-linux.yml
vendored
6
.github/workflows/golang-test-linux.yml
vendored
@ -134,7 +134,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
test_management:
|
||||
needs: [ build-cache ]
|
||||
@ -194,7 +194,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
|
||||
benchmark:
|
||||
needs: [ build-cache ]
|
||||
@ -254,7 +254,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||
|
||||
api_benchmark:
|
||||
needs: [ build-cache ]
|
||||
|
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@ -65,7 +65,7 @@ jobs:
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
@ -19,8 +19,7 @@ const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
@ -84,7 +83,6 @@ func (m *aclManager) AddPeerFiltering(
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
@ -97,15 +95,10 @@ func (m *aclManager) AddPeerFiltering(
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
|
||||
var chain string
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
chain = chainNameOutputRules
|
||||
} else {
|
||||
chain = chainNameInputRules
|
||||
}
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName)
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
@ -214,28 +207,7 @@ func (m *aclManager) Reset() error {
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
rules := m.entries["OUTPUT"]
|
||||
for _, rule := range rules {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
@ -295,12 +267,6 @@ func (m *aclManager) createDefaultChains() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// chain netbird-acl-output-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
@ -329,8 +295,6 @@ func (m *aclManager) createDefaultChains() error {
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
|
||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
@ -390,30 +354,18 @@ func (m *aclManager) updateState() {
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case firewall.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case firewall.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
|
@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error {
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionIN,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
|
@ -68,27 +68,13 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@ -97,15 +83,6 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
@ -119,7 +96,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
@ -135,9 +112,6 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
@ -167,33 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@ -201,15 +155,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
@ -270,11 +215,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
@ -69,7 +69,6 @@ type Manager interface {
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
|
@ -22,8 +22,7 @@ import (
|
||||
const (
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
chainNameOutputRules = "netbird-acl-output-rules"
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
@ -45,9 +44,8 @@ type AclManager struct {
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainOutputRules *nftables.Chain
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
@ -89,7 +87,6 @@ func (m *AclManager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@ -104,7 +101,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -214,38 +211,6 @@ func (m *AclManager) createDefaultAllowRules() error {
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
expOut := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
// mask
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainOutputRules,
|
||||
Position: 0,
|
||||
Exprs: expOut,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
@ -264,15 +229,19 @@ func (m *AclManager) Flush() error {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||
func (m *AclManager) addIOFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
comment string,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
r.nftRule,
|
||||
@ -310,9 +279,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrOffset := uint32(12)
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
addrOffset += 4 // is ipv4 address length
|
||||
}
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
@ -383,12 +349,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
||||
|
||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||
|
||||
var chain *nftables.Chain
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
chain = m.chainInputRules
|
||||
} else {
|
||||
chain = m.chainOutputRules
|
||||
}
|
||||
chain := m.chainInputRules
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
@ -419,15 +380,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
// chainNameOutputRules
|
||||
chain = m.createChain(chainNameOutputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
m.chainOutputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
@ -720,15 +672,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(
|
||||
ip net.IP,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
|
@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
|
@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@ -210,11 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"test rule",
|
||||
)
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
|
@ -16,7 +16,6 @@ type PeerRule struct {
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
direction firewall.RuleDirection
|
||||
sPort uint16
|
||||
dPort uint16
|
||||
drop bool
|
||||
|
@ -45,7 +45,9 @@ type RuleSet map[string]PeerRule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[string]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
routeRules map[string]RouteRule
|
||||
wgNetwork *net.IPNet
|
||||
@ -297,9 +299,8 @@ func (m *Manager) AddPeerFiltering(
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
r := PeerRule{
|
||||
@ -307,7 +308,6 @@ func (m *Manager) AddPeerFiltering(
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
direction: direction,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
@ -343,17 +343,10 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
@ -416,19 +409,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if r.direction == firewall.RuleDirectionIN {
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
} else {
|
||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -918,7 +902,6 @@ func (m *Manager) AddUDPPacketHook(
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: dPort,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
@ -929,7 +912,6 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
r.direction = firewall.RuleDirectionIN
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
@ -948,19 +930,22 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeletePeerRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
|
||||
fw.ActionAccept, "", "allow all")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@ -117,7 +117,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{1024 + i}},
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@ -129,7 +129,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
|
||||
fw.ActionDrop, "", "default drop")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@ -593,7 +593,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@ -684,7 +684,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@ -802,7 +802,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@ -889,7 +889,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
|
@ -91,11 +91,10 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@ -126,37 +125,15 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP("192.168.1.1")
|
||||
proto = fw.ProtocolTCP
|
||||
port = &fw.Port{Values: []int{80}}
|
||||
direction = fw.RuleDirectionIN
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule {
|
||||
err = m.DeletePeerRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
@ -246,10 +223,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if tt.expDir != addedRule.direction {
|
||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
@ -272,11 +245,10 @@ func TestManagerReset(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@ -319,11 +291,10 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
direction := fw.RuleDirectionIN
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@ -523,11 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
}
|
||||
if len(rules) > 0 {
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
@ -288,6 +288,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
@ -308,25 +310,12 @@ func (d *DefaultManager) addInRules(
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
@ -337,25 +326,16 @@ func (d *DefaultManager) addOutRules(
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.firewall.AddPeerFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err = d.firewall.AddPeerFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return append(rules, rule...), nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
|
@ -120,8 +120,8 @@ func TestDefaultManager(t *testing.T) {
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap)
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
if len(acl.peerRulesPairs) != 1 {
|
||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
})
|
||||
@ -358,8 +358,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
if len(acl.peerRulesPairs) != 4 {
|
||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
if len(acl.peerRulesPairs) != 3 {
|
||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -48,11 +48,17 @@ type restoreHostManager interface {
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
return newHostManagerFromType(wgInterface, osManager)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@ -239,7 +240,10 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
|
||||
s.stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
if s.disableSys {
|
||||
// use noop host manager if requested or running in netstack mode.
|
||||
// Netstack mode currently doesn't have a way to receive DNS requests.
|
||||
// TODO: Use listener on localhost in netstack mode when running as root.
|
||||
if s.disableSys || netstack.IsEnabled() {
|
||||
log.Info("system DNS is disabled, not setting up host manager")
|
||||
s.hostManager = &noopHostConfigurator{}
|
||||
return nil
|
||||
|
@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
|
@ -499,7 +499,6 @@ func (e *Engine) initFirewall() error {
|
||||
manager.ProtocolUDP,
|
||||
nil,
|
||||
&port,
|
||||
manager.RuleDirectionIN,
|
||||
manager.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
|
@ -4,13 +4,27 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func isRoot() bool {
|
||||
return os.Geteuid() == 0
|
||||
}
|
||||
|
||||
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
|
||||
if !isRoot() {
|
||||
shell := getUserShell(user)
|
||||
if shell == "" {
|
||||
shell = "/bin/sh"
|
||||
}
|
||||
|
||||
return shell, []string{"-l"}, nil
|
||||
}
|
||||
|
||||
loginPath, err = exec.LookPath("login")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
|
@ -6,5 +6,9 @@ package ssh
|
||||
import "os/user"
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
@ -12,6 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
var userObject *user.User
|
||||
userObject, err := user.Lookup(username)
|
||||
if err != nil && err.Error() == user.UnknownUserError(username).Error() {
|
||||
|
@ -39,6 +39,9 @@ func GetInfo(ctx context.Context) *Info {
|
||||
WiretrusteeVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUIVersion(ctx),
|
||||
KernelVersion: kernelVersion,
|
||||
SystemSerialNumber: serial(),
|
||||
SystemProductName: productModel(),
|
||||
SystemManufacturer: productManufacturer(),
|
||||
}
|
||||
|
||||
return gio
|
||||
@ -49,13 +52,42 @@ func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
func serial() string {
|
||||
// try to fetch serial ID using different properties
|
||||
properties := []string{"ril.serialnumber", "ro.serialno", "ro.boot.serialno", "sys.serialnumber"}
|
||||
var value string
|
||||
|
||||
for _, property := range properties {
|
||||
value = getprop(property)
|
||||
if len(value) > 0 {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// unable to get serial ID, fallback to ANDROID_ID
|
||||
return androidId()
|
||||
}
|
||||
|
||||
func androidId() string {
|
||||
// this is a uniq id defined on first initialization, id will be a new one if user wipes his device
|
||||
return run("/system/bin/settings", "get", "secure", "android_id")
|
||||
}
|
||||
|
||||
func productModel() string {
|
||||
return getprop("ro.product.model")
|
||||
}
|
||||
|
||||
func productManufacturer() string {
|
||||
return getprop("ro.product.manufacturer")
|
||||
}
|
||||
|
||||
func uname() []string {
|
||||
res := run("/system/bin/uname", "-a")
|
||||
return strings.Split(res, " ")
|
||||
}
|
||||
|
||||
func osVersion() string {
|
||||
return run("/system/bin/getprop", "ro.build.version.release")
|
||||
return getprop("ro.build.version.release")
|
||||
}
|
||||
|
||||
func extractUIVersion(ctx context.Context) string {
|
||||
@ -66,6 +98,10 @@ func extractUIVersion(ctx context.Context) string {
|
||||
return v
|
||||
}
|
||||
|
||||
func getprop(arg ...string) string {
|
||||
return run("/system/bin/getprop", arg...)
|
||||
}
|
||||
|
||||
func run(name string, arg ...string) string {
|
||||
cmd := exec.Command(name, arg...)
|
||||
cmd.Stdin = strings.NewReader("some")
|
||||
|
4
go.mod
4
go.mod
@ -71,6 +71,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/quic-go/quic-go v0.48.2
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
@ -156,11 +157,13 @@ require (
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/go-text/render v0.2.0 // indirect
|
||||
github.com/go-text/typesetting v0.2.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
@ -222,6 +225,7 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.26.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
|
6
go.sum
6
go.sum
@ -405,6 +405,7 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
@ -610,6 +611,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
|
||||
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@ -761,6 +764,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
@ -970,6 +975,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
@ -95,8 +97,6 @@ func (cc *connContainer) writeMsg(msg Msg) {
|
||||
msg.Free()
|
||||
default:
|
||||
msg.Free()
|
||||
cc.log.Infof("message queue is full")
|
||||
// todo consider to close the connection
|
||||
}
|
||||
}
|
||||
|
||||
@ -179,8 +179,7 @@ func (c *Client) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
if err := c.connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -264,14 +263,14 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
conn, err := ws.Dial(c.connectionURL)
|
||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
err = c.handShake()
|
||||
if err != nil {
|
||||
if err = c.handShake(); err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
c.log.Errorf("failed to close connection: %s", cErr)
|
||||
@ -306,7 +305,7 @@ func (c *Client) handShake() error {
|
||||
return fmt.Errorf("validate version: %w", err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
msgType, err := messages.DetermineServerMessageType(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
@ -317,7 +316,7 @@ func (c *Client) handShake() error {
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
|
||||
addr, err := messages.UnmarshalAuthResponse(buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -345,27 +344,30 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
c.log.Infof("start to Relay read loop exit")
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
c.log.Errorf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.bufPool.Put(bufPtr)
|
||||
break
|
||||
}
|
||||
|
||||
_, err := messages.ValidateVersion(buf[:n])
|
||||
buf = buf[:n]
|
||||
|
||||
_, err := messages.ValidateVersion(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to validate protocol version: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
msgType, err := messages.DetermineServerMessageType(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
|
||||
if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
7
relay/client/dialer/net/err.go
Normal file
7
relay/client/dialer/net/err.go
Normal file
@ -0,0 +1,7 @@
|
||||
package net
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
)
|
97
relay/client/dialer/quic/conn.go
Normal file
97
relay/client/dialer/quic/conn.go
Normal file
@ -0,0 +1,97 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
const (
|
||||
Network = "quic"
|
||||
)
|
||||
|
||||
type Addr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a Addr) Network() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (a Addr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
session quic.Connection
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(session quic.Connection) net.Conn {
|
||||
return &Conn{
|
||||
session: session,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
dgram, err := c.session.ReceiveDatagram(c.ctx)
|
||||
if err != nil {
|
||||
return 0, c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
n = copy(b, dgram)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.session.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
if c.session != nil {
|
||||
return c.session.LocalAddr()
|
||||
}
|
||||
return Addr{addr: "unknown"}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.session.CloseWithError(0, "normal closure")
|
||||
}
|
||||
|
||||
func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
var appErr *quic.ApplicationError
|
||||
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
|
||||
return netErr.ErrClosedByServer
|
||||
}
|
||||
return err
|
||||
}
|
71
relay/client/dialer/quic/quic.go
Normal file
71
relay/client/dialer/quic/quic.go
Normal file
@ -0,0 +1,71 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
quictls "github.com/netbirdio/netbird/relay/tls"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
}
|
||||
|
||||
func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
KeepAlivePeriod: 30 * time.Second,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on UDP: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve UDP address: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := NewConn(session)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") {
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(address, "rels://") {
|
||||
return address[7:], nil
|
||||
}
|
||||
return address[6:], nil
|
||||
}
|
96
relay/client/dialer/race_dialer.go
Normal file
96
relay/client/dialer/race_dialer.go
Normal file
@ -0,0 +1,96 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
Dial(ctx context.Context, address string) (net.Conn, error)
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
Conn net.Conn
|
||||
Protocol string
|
||||
Err error
|
||||
}
|
||||
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
dialerFns []DialeFn
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
return &RaceDial{
|
||||
log: log,
|
||||
serverURL: serverURL,
|
||||
dialerFns: dialerFns,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial() (net.Conn, error) {
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(context.Background())
|
||||
defer abort()
|
||||
|
||||
for _, dfn := range r.dialerFns {
|
||||
go r.dial(dfn, abortCtx, connChan)
|
||||
}
|
||||
|
||||
go r.processResults(connChan, winnerConn, abort)
|
||||
|
||||
conn, ok := <-winnerConn
|
||||
if !ok {
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(ctx, r.serverURL)
|
||||
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
|
||||
}
|
||||
|
||||
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
|
||||
var hasWinner bool
|
||||
for i := 0; i < len(r.dialerFns); i++ {
|
||||
dr := <-connChan
|
||||
if dr.Err != nil {
|
||||
if errors.Is(dr.Err, context.Canceled) {
|
||||
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
|
||||
} else {
|
||||
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if hasWinner {
|
||||
if cerr := dr.Conn.Close(); cerr != nil {
|
||||
r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
r.log.Infof("successfully dialed via: %s", dr.Protocol)
|
||||
|
||||
abort()
|
||||
hasWinner = true
|
||||
winnerConn <- dr.Conn
|
||||
}
|
||||
close(winnerConn)
|
||||
}
|
252
relay/client/dialer/race_dialer_test.go
Normal file
252
relay/client/dialer/race_dialer_test.go
Normal file
@ -0,0 +1,252 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type MockAddr struct {
|
||||
network string
|
||||
}
|
||||
|
||||
func (m *MockAddr) Network() string {
|
||||
return m.network
|
||||
}
|
||||
|
||||
func (m *MockAddr) String() string {
|
||||
return "1.2.3.4"
|
||||
}
|
||||
|
||||
// MockDialer is a mock implementation of DialeFn
|
||||
type MockDialer struct {
|
||||
dialFunc func(ctx context.Context, address string) (net.Conn, error)
|
||||
protocolStr string
|
||||
}
|
||||
|
||||
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
return m.dialFunc(ctx, address)
|
||||
}
|
||||
|
||||
func (m *MockDialer) Protocol() string {
|
||||
return m.protocolStr
|
||||
}
|
||||
|
||||
// MockConn implements net.Conn for testing
|
||||
type MockConn struct {
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
return m.remoteAddr
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRaceDialEmptyDialers(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
rd := NewRaceDial(logger, serverURL)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error with empty dialers, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto := "test-protocol"
|
||||
|
||||
mockConn := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto},
|
||||
}
|
||||
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return mockConn, nil
|
||||
},
|
||||
protocolStr: proto,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Errorf("Expected non-nil connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto2 := "protocol2"
|
||||
|
||||
mockConn2 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto2},
|
||||
}
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("first dialer failed")
|
||||
},
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return mockConn2, nil
|
||||
},
|
||||
protocolStr: "proto2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn.RemoteAddr().Network() != proto2 {
|
||||
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialTimeout(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
connectionTimeout = 3 * time.Second
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
},
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialAllDialersFail(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("first dialer failed")
|
||||
},
|
||||
protocolStr: "protocol1",
|
||||
}
|
||||
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("second dialer failed")
|
||||
},
|
||||
protocolStr: "protocol2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto1 := "protocol1"
|
||||
proto2 := "protocol2"
|
||||
|
||||
mockConn1 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto1},
|
||||
}
|
||||
|
||||
mockConn2 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto2},
|
||||
}
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
time.Sleep(1 * time.Second)
|
||||
return mockConn1, nil
|
||||
},
|
||||
protocolStr: proto1,
|
||||
}
|
||||
|
||||
mock2err := make(chan error)
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
mock2err <- ctx.Err()
|
||||
return mockConn2, ctx.Err()
|
||||
},
|
||||
protocolStr: proto2,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Errorf("Expected non-nil connection")
|
||||
}
|
||||
if conn != mockConn1 {
|
||||
t.Errorf("Expected first connection, got %v", conn)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Errorf("Timed out waiting for second dialer to finish")
|
||||
case err := <-mock2err:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,11 +1,15 @@
|
||||
package ws
|
||||
|
||||
const (
|
||||
Network = "ws"
|
||||
)
|
||||
|
||||
type WebsocketAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) Network() string {
|
||||
return "websocket"
|
||||
return Network
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) String() string {
|
||||
|
@ -26,6 +26,7 @@ func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, ioReader, err := c.Conn.Reader(c.ctx)
|
||||
if err != nil {
|
||||
// todo use ErrClosedByServer
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -15,7 +16,14 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
type Dialer struct {
|
||||
}
|
||||
|
||||
func (d Dialer) Protocol() string {
|
||||
return "WS"
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -31,8 +39,11 @@ func Dial(address string) (net.Conn, error) {
|
||||
}
|
||||
parsedURL.Path = ws.URLPath
|
||||
|
||||
wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
|
||||
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||
return nil, err
|
||||
}
|
||||
|
@ -23,20 +23,26 @@ const (
|
||||
MsgTypeAuth = 6
|
||||
MsgTypeAuthResponse = 7
|
||||
|
||||
SizeOfVersionByte = 1
|
||||
SizeOfMsgType = 1
|
||||
// base size of the message
|
||||
sizeOfVersionByte = 1
|
||||
sizeOfMsgType = 1
|
||||
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
|
||||
|
||||
SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
|
||||
|
||||
sizeOfMagicByte = 4
|
||||
|
||||
headerSizeTransport = IDSize
|
||||
// auth message
|
||||
sizeOfMagicByte = 4
|
||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||
offsetMagicByte = sizeOfProtoHeader
|
||||
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
||||
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
||||
|
||||
// hello message
|
||||
headerSizeHello = sizeOfMagicByte + IDSize
|
||||
headerSizeHelloResp = 0
|
||||
|
||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
||||
headerSizeAuthResp = 0
|
||||
// transport
|
||||
headerSizeTransport = IDSize
|
||||
offsetTransportID = sizeOfProtoHeader
|
||||
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
||||
)
|
||||
|
||||
var (
|
||||
@ -73,7 +79,7 @@ func (m MsgType) String() string {
|
||||
|
||||
// ValidateVersion checks if the given version is supported by the protocol
|
||||
func ValidateVersion(msg []byte) (int, error) {
|
||||
if len(msg) < SizeOfVersionByte {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
version := int(msg[0])
|
||||
@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) {
|
||||
|
||||
// DetermineClientMessageType determines the message type from the first the message
|
||||
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHello,
|
||||
@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
|
||||
// DetermineServerMessageType determines the message type from the first the message
|
||||
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < SizeOfMsgType {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[0])
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHelloResponse,
|
||||
@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHello)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, additions...)
|
||||
@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||
// authenticate the client with the server.
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeHello {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
|
||||
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthResponse instead.
|
||||
@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHelloResponse)
|
||||
@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
// Deprecated: Use UnmarshalAuthResponse instead.
|
||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||
if len(msg) < headerSizeHelloResp {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return msg, nil
|
||||
@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
|
||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuth)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
copy(msg[sizeOfProtoHeader:], magicHeader)
|
||||
|
||||
msg = append(msg, peerID...)
|
||||
msg = append(msg, authPayload...)
|
||||
@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
||||
|
||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||
if len(msg) < headerSizeAuth {
|
||||
if len(msg) < headerTotalSizeAuth {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
|
||||
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
|
||||
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
|
||||
}
|
||||
|
||||
// MarshalAuthResponse creates a response message to the auth.
|
||||
@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
||||
// servers.
|
||||
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
ab := []byte(address)
|
||||
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuthResponse)
|
||||
@ -243,39 +249,34 @@ func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
|
||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
if len(msg) < headerSizeAuthResp+1 {
|
||||
if len(msg) < sizeOfProtoHeader+1 {
|
||||
return "", ErrInvalidMessageLength
|
||||
}
|
||||
return string(msg), nil
|
||||
return string(msg[sizeOfProtoHeader:]), nil
|
||||
}
|
||||
|
||||
// MarshalCloseMsg creates a close message.
|
||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||
func MarshalCloseMsg() []byte {
|
||||
msg := make([]byte, SizeOfProtoHeader)
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeClose)
|
||||
|
||||
return msg
|
||||
return []byte{
|
||||
byte(CurrentProtocolVersion),
|
||||
byte(MsgTypeClose),
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalTransportMsg creates a transport message.
|
||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||
// destination peer hashed ID.
|
||||
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
||||
}
|
||||
|
||||
msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))
|
||||
|
||||
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeTransport)
|
||||
|
||||
copy(msg[SizeOfProtoHeader:], peerID)
|
||||
|
||||
copy(msg[sizeOfProtoHeader:], peerID)
|
||||
msg = append(msg, payload...)
|
||||
|
||||
return msg, nil
|
||||
@ -283,29 +284,29 @@ func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
|
||||
|
||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||
if len(buf) < headerSizeTransport {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
|
||||
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||
if len(buf) < headerSizeTransport {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return buf[:headerSizeTransport], nil
|
||||
return buf[offsetTransportID:headerTotalSizeTransport], nil
|
||||
}
|
||||
|
||||
// UpdateTransportMsg updates the peerID in the transport message.
|
||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||
// need to allocate a new byte slice.
|
||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
||||
if len(msg) < len(peerID) {
|
||||
if len(msg) < offsetTransportID+len(peerID) {
|
||||
return ErrInvalidMessageLength
|
||||
}
|
||||
copy(msg, peerID)
|
||||
copy(msg[offsetTransportID:], peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -6,12 +6,21 @@ import (
|
||||
|
||||
func TestMarshalHelloMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
bHello, err := MarshalHelloMsg(peerID, nil)
|
||||
msg, err := MarshalHelloMsg(peerID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeHello {
|
||||
t.Errorf("expected %d, got %d", MsgTypeHello, msgType)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalHelloMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
@ -22,12 +31,21 @@ func TestMarshalHelloMsg(t *testing.T) {
|
||||
|
||||
func TestMarshalAuthMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
bHello, err := MarshalAuthMsg(peerID, []byte{})
|
||||
msg, err := MarshalAuthMsg(peerID, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeAuth {
|
||||
t.Errorf("expected %d, got %d", MsgTypeAuth, msgType)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalAuthMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
@ -36,6 +54,31 @@ func TestMarshalAuthMsg(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAuthResponse(t *testing.T) {
|
||||
address := "myaddress"
|
||||
msg, err := MarshalAuthResponse(address)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineServerMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeAuthResponse {
|
||||
t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType)
|
||||
}
|
||||
|
||||
respAddr, err := UnmarshalAuthResponse(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if respAddr != address {
|
||||
t.Errorf("expected %s, got %s", address, respAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalTransportMsg(t *testing.T) {
|
||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
payload := []byte("payload")
|
||||
@ -44,7 +87,25 @@ func TestMarshalTransportMsg(t *testing.T) {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeTransport {
|
||||
t.Errorf("expected %d, got %d", MsgTypeTransport, msgType)
|
||||
}
|
||||
|
||||
uPeerID, err := UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal transport id: %v", err)
|
||||
}
|
||||
|
||||
if string(uPeerID) != string(peerID) {
|
||||
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
@ -57,3 +118,21 @@ func TestMarshalTransportMsg(t *testing.T) {
|
||||
t.Errorf("expected %s, got %s", payload, respPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalHealthcheck(t *testing.T) {
|
||||
msg := MarshalHealthcheck()
|
||||
|
||||
_, err := ValidateVersion(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineServerMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeHealthCheck {
|
||||
t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType)
|
||||
}
|
||||
}
|
||||
|
@ -68,12 +68,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
buf = buf[:n]
|
||||
|
||||
_, err = messages.ValidateVersion(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
msgType, err := messages.DetermineClientMessageType(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
@ -85,10 +87,10 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
bytePeerID, peerID, err = h.handleHelloMsg(buf)
|
||||
case messages.MsgTypeAuth:
|
||||
h.handshakeMethodAuth = true
|
||||
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
bytePeerID, peerID, err = h.handleAuthMsg(buf)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||
}
|
||||
|
101
relay/server/listener/quic/conn.go
Normal file
101
relay/server/listener/quic/conn.go
Normal file
@ -0,0 +1,101 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
session quic.Connection
|
||||
closed bool
|
||||
closedMu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewConn(session quic.Connection) *Conn {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Conn{
|
||||
session: session,
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
dgram, err := c.session.ReceiveDatagram(c.ctx)
|
||||
if err != nil {
|
||||
return 0, c.remoteCloseErrHandling(err)
|
||||
}
|
||||
// Copy data to b, ensuring we don’t exceed the size of b
|
||||
n = copy(b, dgram)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.remoteCloseErrHandling(err)
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.session.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.session.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closedMu.Lock()
|
||||
if c.closed {
|
||||
c.closedMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
|
||||
c.ctxCancel() // Cancel the context
|
||||
|
||||
sessionErr := c.session.CloseWithError(0, "normal closure")
|
||||
return sessionErr
|
||||
}
|
||||
|
||||
func (c *Conn) isClosed() bool {
|
||||
c.closedMu.Lock()
|
||||
defer c.closedMu.Unlock()
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
if c.isClosed() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
// Check if the connection was closed remotely
|
||||
var appErr *quic.ApplicationError
|
||||
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
66
relay/server/listener/quic/listener.go
Normal file
66
relay/server/listener/quic/listener.go
Normal file
@ -0,0 +1,66 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
// Address is the address to listen on
|
||||
Address string
|
||||
// TLSConfig is the TLS configuration for the server
|
||||
TLSConfig *tls.Config
|
||||
|
||||
listener *quic.Listener
|
||||
acceptFn func(conn net.Conn)
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||
l.acceptFn = acceptFn
|
||||
|
||||
quicCfg := &quic.Config{
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create QUIC listener: %v", err)
|
||||
}
|
||||
|
||||
l.listener = listener
|
||||
log.Infof("QUIC server listening on address: %s", l.Address)
|
||||
|
||||
for {
|
||||
session, err := listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
if errors.Is(err, quic.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Errorf("Failed to accept QUIC session: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
|
||||
conn := NewConn(session)
|
||||
l.acceptFn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Shutdown(ctx context.Context) error {
|
||||
if l.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("stopping QUIC listener")
|
||||
if err := l.listener.Close(); err != nil {
|
||||
return fmt.Errorf("listener shutdown failed: %v", err)
|
||||
}
|
||||
log.Infof("QUIC listener stopped")
|
||||
return nil
|
||||
}
|
@ -88,6 +88,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("WS client connected from: %s", rAddr)
|
||||
|
||||
conn := NewConn(wsConn, lAddr, rAddr)
|
||||
l.acceptFn(conn)
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ func (p *Peer) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
|
||||
msgType, err := messages.DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to determine message type: %s", err)
|
||||
return
|
||||
@ -191,7 +191,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
|
||||
}
|
||||
|
||||
func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
|
||||
peerID, err := messages.UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to unmarshal transport message: %s", err)
|
||||
return
|
||||
@ -204,7 +204,7 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
|
||||
err = messages.UpdateTransportMsg(msg, p.idB)
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to update transport message: %s", err)
|
||||
return
|
||||
|
@ -150,6 +150,8 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
func (r *Relay) Shutdown(ctx context.Context) {
|
||||
log.Infof("close connection with all peers")
|
||||
r.closeMu.Lock()
|
||||
defer r.closeMu.Unlock()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
peers := r.store.Peers()
|
||||
for _, peer := range peers {
|
||||
@ -161,7 +163,7 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
||||
}
|
||||
wg.Wait()
|
||||
r.metricsCancel()
|
||||
r.closeMu.Unlock()
|
||||
r.closed = true
|
||||
}
|
||||
|
||||
// InstanceURL returns the instance URL of the relay server
|
||||
|
@ -3,13 +3,17 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
quictls "github.com/netbirdio/netbird/relay/tls"
|
||||
)
|
||||
|
||||
// ListenerConfig is the configuration for the listener.
|
||||
@ -24,8 +28,8 @@ type ListenerConfig struct {
|
||||
// It is the gate between the WebSocket listener and the Relay server logic.
|
||||
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
|
||||
type Server struct {
|
||||
relay *Relay
|
||||
wSListener listener.Listener
|
||||
relay *Relay
|
||||
listeners []listener.Listener
|
||||
}
|
||||
|
||||
// NewServer creates a new relay server instance.
|
||||
@ -39,35 +43,63 @@ func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authV
|
||||
return nil, err
|
||||
}
|
||||
return &Server{
|
||||
relay: relay,
|
||||
relay: relay,
|
||||
listeners: make([]listener.Listener, 0, 2),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Listen starts the relay server.
|
||||
func (r *Server) Listen(cfg ListenerConfig) error {
|
||||
r.wSListener = &ws.Listener{
|
||||
wSListener := &ws.Listener{
|
||||
Address: cfg.Address,
|
||||
TLSConfig: cfg.TLSConfig,
|
||||
}
|
||||
r.listeners = append(r.listeners, wSListener)
|
||||
|
||||
wslErr := r.wSListener.Listen(r.relay.Accept)
|
||||
if wslErr != nil {
|
||||
log.Errorf("failed to bind ws server: %s", wslErr)
|
||||
tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return wslErr
|
||||
quicListener := &quic.Listener{
|
||||
Address: cfg.Address,
|
||||
TLSConfig: tlsConfigQUIC,
|
||||
}
|
||||
|
||||
r.listeners = append(r.listeners, quicListener)
|
||||
|
||||
errChan := make(chan error, len(r.listeners))
|
||||
wg := sync.WaitGroup{}
|
||||
for _, l := range r.listeners {
|
||||
wg.Add(1)
|
||||
go func(listener listener.Listener) {
|
||||
defer wg.Done()
|
||||
errChan <- listener.Listen(r.relay.Accept)
|
||||
}(l)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
var multiErr *multierror.Error
|
||||
for err := range errChan {
|
||||
multiErr = multierror.Append(multiErr, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(multiErr)
|
||||
}
|
||||
|
||||
// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
|
||||
// the connections will be forcefully closed.
|
||||
func (r *Server) Shutdown(ctx context.Context) (err error) {
|
||||
// stop service new connections
|
||||
if r.wSListener != nil {
|
||||
err = r.wSListener.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (r *Server) Shutdown(ctx context.Context) error {
|
||||
r.relay.Shutdown(ctx)
|
||||
return
|
||||
|
||||
var multiErr *multierror.Error
|
||||
for _, l := range r.listeners {
|
||||
if err := l.Shutdown(ctx); err != nil {
|
||||
multiErr = multierror.Append(multiErr, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(multiErr)
|
||||
}
|
||||
|
||||
// InstanceURL returns the instance URL of the relay server.
|
||||
|
3
relay/tls/alpn.go
Normal file
3
relay/tls/alpn.go
Normal file
@ -0,0 +1,3 @@
|
||||
package tls
|
||||
|
||||
const nbalpn = "nb-quic"
|
12
relay/tls/client_dev.go
Normal file
12
relay/tls/client_dev.go
Normal file
@ -0,0 +1,12 @@
|
||||
//go:build devcert
|
||||
|
||||
package tls
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func ClientQUICTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true, // Debug mode allows insecure connections
|
||||
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
|
||||
}
|
||||
}
|
11
relay/tls/client_prod.go
Normal file
11
relay/tls/client_prod.go
Normal file
@ -0,0 +1,11 @@
|
||||
//go:build !devcert
|
||||
|
||||
package tls
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func ClientQUICTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
NextProtos: []string{nbalpn},
|
||||
}
|
||||
}
|
36
relay/tls/doc.go
Normal file
36
relay/tls/doc.go
Normal file
@ -0,0 +1,36 @@
|
||||
// Package tls provides utilities for configuring and managing Transport Layer
|
||||
// Security (TLS) in server and client environments, with a focus on QUIC
|
||||
// protocol support and testing configurations.
|
||||
//
|
||||
// The package includes functions for cloning and customizing TLS
|
||||
// configurations as well as generating self-signed certificates for
|
||||
// development and testing purposes.
|
||||
//
|
||||
// Key Features:
|
||||
//
|
||||
// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored
|
||||
// for QUIC protocol with specified or default settings. QUIC requires a
|
||||
// specific TLS configuration with proper ALPN (Application-Layer Protocol
|
||||
// Negotiation) support, making the TLS settings crucial for establishing
|
||||
// secure connections.
|
||||
//
|
||||
// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable
|
||||
// for QUIC protocol. The configuration differs between development
|
||||
// (insecure testing) and production (strict verification).
|
||||
//
|
||||
// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for
|
||||
// use in local development and testing scenarios.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// This package provides separate implementations for development and production
|
||||
// environments. The development implementation (guarded by `//go:build devcert`)
|
||||
// supports testing configurations with self-signed certificates and insecure
|
||||
// client connections. The production implementation (guarded by `//go:build
|
||||
// !devcert`) ensures that valid and secure TLS configurations are supplied
|
||||
// and used.
|
||||
//
|
||||
// The QUIC protocol is highly reliant on properly configured TLS settings,
|
||||
// and this package ensures that configurations meet the requirements for
|
||||
// secure and efficient QUIC communication.
|
||||
package tls
|
79
relay/tls/server_dev.go
Normal file
79
relay/tls/server_dev.go
Normal file
@ -0,0 +1,79 @@
|
||||
//go:build devcert
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
|
||||
if originTLSCfg == nil {
|
||||
log.Warnf("QUIC server will use self signed certificate for testing!")
|
||||
return generateTestTLSConfig()
|
||||
}
|
||||
|
||||
cfg := originTLSCfg.Clone()
|
||||
cfg.NextProtos = []string{nbalpn}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// GenerateTestTLSConfig creates a self-signed certificate for testing
|
||||
func generateTestTLSConfig() (*tls.Config, error) {
|
||||
log.Infof("generating test TLS config")
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Organization"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
NextProtos: []string{nbalpn},
|
||||
}, nil
|
||||
}
|
17
relay/tls/server_prod.go
Normal file
17
relay/tls/server_prod.go
Normal file
@ -0,0 +1,17 @@
|
||||
//go:build !devcert
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
|
||||
if originTLSCfg == nil {
|
||||
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
|
||||
}
|
||||
cfg := originTLSCfg.Clone()
|
||||
cfg.NextProtos = []string{nbalpn}
|
||||
return cfg, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user