Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association

This commit is contained in:
bcmmbaga 2024-10-21 12:37:24 +03:00
commit 53218f99bc
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
45 changed files with 2151 additions and 815 deletions

View File

@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.14" SIGN_PIPE_VER: "v0.0.16"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@ -223,4 +223,4 @@ jobs:
repo: netbirdio/sign-pipelines repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'

View File

@ -96,6 +96,9 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
archives: archives:
- builds: - builds:
- netbird - netbird

View File

@ -23,6 +23,9 @@ builds:
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives: archives:
- builds: - builds:
- netbird-ui-darwin - netbird-ui-darwin

View File

@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@ -21,13 +22,19 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
) )
type entry struct {
spec []string
position int
}
type aclManager struct { type aclManager struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
entries map[string][][]string entries map[string][][]string
ipsetStore *ipsetStore optionalEntries map[string][]entry
ipsetStore *ipsetStore
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
wgIface: wgIface, wgIface: wgIface,
routingFwChainName: routingFwChainName, routingFwChainName: routingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
ipsetStore: newIpsetStore(), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
} }
err := ipset.Init() err := ipset.Init()
@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
} }
m.seedInitialEntries() m.seedInitialEntries()
m.seedInitialOptionalEntries()
err = m.cleanChains() err = m.cleanChains()
if err != nil { if err != nil {
@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error {
} }
} }
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil { if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err) log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error {
} }
} }
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil return nil
} }
@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
} }
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
position: 1,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec) m.entries[chainName] = append(m.entries[chainName], spec)
} }

View File

@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
sources [] netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,

View File

@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := tableFilter table := r.getTableForChain(chain)
if chain == chainRTNAT {
table = tableNat
}
ok, err := r.iptablesClient.ChainExists(table, chain) ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil { if err != nil {
@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error {
func (r *router) createContainers() error { func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil { if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %v", chain, err) return fmt.Errorf("create chain %s: %w", chain, err)
} }
} }
if err := r.insertEstablishedRule(chainRTFWD); err != nil { if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %v", err) return fmt.Errorf("insert established rule: %w", err)
} }
return r.addJumpRules() if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
return nil
} }
func (r *router) createAndSetupChain(chain string) error { func (r *router) createAndSetupChain(chain string) error {
@ -432,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i" intdir := "-i"
lointdir := "-o"
if inverse { if inverse {
intdir = "-o" intdir = "-o"
lointdir = "-i"
} }
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
} }
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {

View File

@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
// GenerateSetName generates a unique name for an ipset based on the given sources. // GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string { func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming // sort for consistent naming
sortPrefixes(sources) SortPrefixes(sources)
var sourcesStr strings.Builder var sourcesStr strings.Builder
for _, src := range sources { for _, src := range sources {
@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged return merged
} }
// sortPrefixes sorts the given slice of netip.Prefix in place. // SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific). // It sorts first by IP address, then by prefix length (most specific to least specific).
func sortPrefixes(prefixes []netip.Prefix) { func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool { sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 { if addrCmp != 0 {

View File

@ -11,12 +11,14 @@ import (
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@ -29,6 +31,7 @@ const (
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter" chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
@ -40,15 +43,14 @@ var (
) )
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routeingFwChainName string routingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
@ -61,7 +63,7 @@ type iFaceMapper interface {
IsUserspaceBind() bool IsUserspaceBind() bool
} }
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them // sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation) // it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations // and is permanent. Using same connection for both type of operations
@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
} }
m := &AclManager{ m := &AclManager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn, sConn: sConn,
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routeingFwChainName: routeingFwChainName, routingFwChainName: routingFwChainName,
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
// netbird-acl-forward-filter // netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward() // to netbird-rt-fwd m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil return nil
} }
func (m *AclManager) addJumpRulesToRtForward() { // Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() {
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: m.routeingFwChainName, Chain: m.routingFwChainName,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: m.chainFwFilter, Chain: chainFwFilter,
Exprs: expressions, Exprs: expressions,
}) })
} }
@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
return chain return chain
} }
func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{ chain := &nftables.Chain{
Name: name, Name: name,

View File

@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
rule := &nftables.Rule{ rule := &nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Exprs: []expr.Any{ Exprs: getEstablishedExprs(1),
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
} }
conn.InsertRule(rule) conn.InsertRule(rule)
} }
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) {
Register: 1, Register: 1,
Data: []byte{0, 0, 0, 0}, Data: []byte{0, 0, 0, 0},
}, },
&expr.Counter{},
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },

View File

@ -10,6 +10,8 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
@ -24,7 +26,7 @@ import (
const ( const (
chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat" chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD" chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
@ -80,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
} }
} }
err = r.cleanUpDefaultForwardRules() err = r.removeAcceptForwardRules()
if err != nil { if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
} }
@ -97,40 +99,7 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller // clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear() r.ipsetCounter.Clear()
return r.cleanUpDefaultForwardRules() return r.removeAcceptForwardRules()
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
return r.conn.Flush()
} }
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
@ -149,7 +118,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
} }
func (r *router) createContainers() error { func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw, Name: chainNameRoutingFw,
Table: r.workTable, Table: r.workTable,
@ -157,25 +125,28 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat, Name: chainNameRoutingNat,
Table: r.workTable, Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting, Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1, Priority: &prio,
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
r.acceptForwardRules() if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
err := r.refreshRulesMap() if err := r.refreshRulesMap(); err != nil {
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
} }
err = r.conn.Flush() if err := r.conn.Flush(); err != nil {
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err) return fmt.Errorf("nftables: unable to initialize table: %v", err)
} }
return nil return nil
} }
@ -188,6 +159,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
@ -248,9 +220,18 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
} }
r.rules[string(ruleKey)] = r.conn.AddRule(rule) rule = r.conn.AddRule(rule)
return ruleKey, r.conn.Flush() log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
} }
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
@ -288,6 +269,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil return nil
} }
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule) setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil { if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@ -440,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
destExp := generateCIDRMatcherExpressions(false, pair.Destination) destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME dir := expr.MetaKeyIIFNAME
notDir := expr.MetaKeyOIFNAME
if pair.Inverse { if pair.Inverse {
dir = expr.MetaKeyOIFNAME dir = expr.MetaKeyOIFNAME
notDir = expr.MetaKeyIIFNAME
} }
lo := ifname("lo")
intf := ifname(r.wgIface.Name()) intf := ifname(r.wgIface.Name())
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Meta{ &expr.Meta{
Key: dir, Key: dir,
@ -455,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Register: 1, Register: 1,
Data: intf, Data: intf,
}, },
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: lo,
},
} }
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
@ -562,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there. // that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface. // The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() { func (r *router) acceptForwardRules() error {
if r.filterTable == nil { if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules") log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return return nil
} }
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *router) acceptForwardRulesNftables() error {
intf := ifname(r.wgIface.Name()) intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter // Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{ iifRule := &nftables.Rule{
Table: r.filterTable, Table: r.filterTable,
Chain: &nftables.Chain{ Chain: &nftables.Chain{
Name: "FORWARD", Name: chainNameForward,
Table: r.filterTable, Table: r.filterTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward, Hooknum: nftables.ChainHookForward,
@ -594,6 +635,15 @@ func (r *router) acceptForwardRules() {
} }
r.conn.InsertRule(iifRule) r.conn.InsertRule(iifRule)
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
// Rule for outgoing interface (oif) with counter // Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{ oifRule := &nftables.Rule{
Table: r.filterTable, Table: r.filterTable,
@ -604,36 +654,72 @@ func (r *router) acceptForwardRules() {
Hooknum: nftables.ChainHookForward, Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter, Priority: nftables.ChainPriorityFilter,
}, },
Exprs: []expr.Any{ Exprs: append(oifExprs, getEstablishedExprs(2)...),
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 2,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleOif), UserData: []byte(userDataAcceptForwardRuleOif),
} }
r.conn.InsertRule(oifRule) r.conn.InsertRule(oifRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// RemoveNatRule removes a nftables rule pair from nat chains // RemoveNatRule removes a nftables rule pair from nat chains
@ -658,7 +744,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
} }
log.Debugf("nftables: removed rules for %s", pair.Destination) log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil return nil
} }

View File

@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
Register: 1, Register: 1,
Data: ifname(ifaceMock.Name()), Data: ifname(ifaceMock.Name()),
}, },
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
) )
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
Register: 1, Register: 1,
Data: ifname(ifaceMock.Name()), Data: ifname(ifaceMock.Name()),
}, },
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
) )
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
@ -314,6 +326,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map // Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()] rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map") assert.True(t, ok, "Rule not found in internal map")
@ -346,10 +362,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
// Verify actual nftables rule content // Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
}) })
} }
} }

View File

@ -1,8 +1,11 @@
package id package id
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net/netip" "net/netip"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
) )
@ -21,5 +24,41 @@ func GenerateRouteRuleKey(
dPort *manager.Port, dPort *manager.Port,
action manager.Action, action manager.Action,
) RuleID { ) RuleID {
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) manager.SortPrefixes(sources)
h := sha256.New()
// Write all fields to the hasher, with delimiters
h.Write([]byte("sources:"))
for _, src := range sources {
h.Write([]byte(src.String()))
h.Write([]byte(","))
}
h.Write([]byte("destination:"))
h.Write([]byte(destination.String()))
h.Write([]byte("proto:"))
h.Write([]byte(proto))
h.Write([]byte("sPort:"))
if sPort != nil {
h.Write([]byte(sPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("dPort:"))
if dPort != nil {
h.Write([]byte(dPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("action:"))
h.Write([]byte(strconv.Itoa(int(action))))
hash := hex.EncodeToString(h.Sum(nil))
// prepend destination prefix to be able to identify the rule
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
} }

View File

@ -82,8 +82,6 @@ type Conn struct {
config ConnConfig config ConnConfig
statusRecorder *Status statusRecorder *Status
wgProxyFactory *wgproxy.Factory wgProxyFactory *wgproxy.Factory
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager relayManager *relayClient.Manager
@ -106,7 +104,8 @@ type Conn struct {
beforeAddPeerHooks []nbnet.AddHookFunc beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc
endpointRelay *net.UDPAddr wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
// for reconnection operations // for reconnection operations
iCEDisconnected chan bool iCEDisconnected chan bool
@ -257,8 +256,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err := conn.removeWgPeer(); err != nil {
if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
} }
@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready") conn.log.Debugf("ICE connection is ready")
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
if conn.currentConnPriority > priority { if conn.currentConnPriority > priority {
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
return return
} }
conn.log.Infof("set ICE to active connection") conn.log.Infof("set ICE to active connection")
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) var (
if err != nil { ep *net.UDPAddr
return wgProxy wgproxy.Proxy
err error
)
if iceConnInfo.RelayedOnLocal {
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return
}
ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy
} else {
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
if err != nil {
log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil)
return
}
ep = directEp
} }
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) conn.log.Errorf("Before add peer hook failed: %v", err)
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
} }
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
err = conn.configureWGEndpoint(endpointUdpAddr) if conn.wgProxyRelay != nil {
if err != nil { conn.wgProxyRelay.Pause()
if wgProxy != nil { }
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close turn connection: %v", err) if wgProxy != nil {
} wgProxy.Work()
} }
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if err = conn.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyICE = wgProxy
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Tracef("ICE connection state changed to %s", newState) conn.log.Tracef("ICE connection state changed to %s", newState)
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
// switch back to relay connection // switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { if conn.isReadyToUpgrade() {
conn.log.Debugf("ICE disconnected, set Relay to active connection") conn.log.Debugf("ICE disconnected, set Relay to active connection")
err := conn.configureWGEndpoint(conn.endpointRelay) conn.wgProxyRelay.Work()
if err != nil {
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState) conn.statusICE.Set(newState)
select { conn.notifyReconnectLoopICEDisconnected(changed)
case conn.iCEDisconnected <- changed:
default:
}
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil { if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err) conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
} }
return return
} }
conn.log.Debugf("Relay connection is ready to use") conn.log.Debugf("Relay connection has been established, setup the WireGuard")
conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy() wgProxy, err := conn.newProxy(rci.relayedConn)
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
conn.endpointRelay = endpointUdpAddr
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
if conn.currentConnPriority > connPriorityRelay { conn.wgProxyRelay = wgProxy
if conn.statusICE.Get() == StatusConnected { conn.statusRelay.Set(StatusConnected)
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return return
}
} }
conn.connIDRelay = nbnet.GenerateConnID() if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
for _, hook := range conn.beforeAddPeerHooks { conn.log.Errorf("Before add peer hook failed: %v", err)
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
} }
err = conn.configureWGEndpoint(endpointUdpAddr) wgProxy.Work()
if err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err) conn.log.Warnf("Failed to close relay connection: %v", err)
} }
conn.log.Errorf("Failed to update wg peer configuration: %v", err) conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = wgProxy
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay") conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
} }
@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
return return
} }
log.Debugf("relay connection is disconnected") conn.log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay { if conn.currentConnPriority == connPriorityRelay {
log.Debugf("clean up WireGuard config") conn.log.Debugf("clean up WireGuard config")
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err := conn.removeWgPeer(); err != nil {
if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
} }
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn() _ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.Set(StatusDisconnected)
conn.notifyReconnectLoopRelayDisconnected(changed)
select {
case conn.relayDisconnected <- changed:
default:
}
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
Relayed: conn.isRelayed(), Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
} }
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
if err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
} }
} }
@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() { func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" { if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
} }
} }
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
if !iceConnInfo.RelayedOnLocal { conn.log.Debugf("setup proxied WireGuard connection")
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy() wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil { return nil, err
conn.log.Warnf("failed to close turn proxy connection: %v", errClose) }
} return wgProxy, nil
return nil, nil, err }
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
}
func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
}
}
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Work()
} }
return ep, wgProxy, nil
} }
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {

View File

@ -5,7 +5,6 @@ package ebpf
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
} }
// AddTurnConn add new turn connection for the proxy // AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn) wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{ wgEndpoint := &net.UDPAddr{
@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
} }
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance. // From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
return packetConn, nil return packetConn, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1") localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)

View File

@ -4,8 +4,13 @@ package ebpf
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"sync"
log "github.com/sirupsen/logrus"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@ -13,20 +18,55 @@ type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
} }
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
ctxConn, cancel := context.WithCancel(ctx) addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil { if err != nil {
cancel() return fmt.Errorf("add turn conn: %w", err)
return nil, fmt.Errorf("add turn conn: %w", err)
} }
e.remoteConn = remoteConn p.remoteConn = remoteConn
e.cancel = cancel p.ctx, p.cancel = context.WithCancel(ctx)
return addr, err p.wgEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyWrapper) Pause() {
if p.remoteConn == nil {
return
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return 0, ctx.Err()
}
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}
return 0, err
}
return n, nil
}

View File

@ -7,6 +7,9 @@ import (
// Proxy is a transfer layer between the relayed connection and the WireGuard // Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface { type Proxy interface {
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) AddTurnConn(ctx context.Context, turnConn net.Conn) error
EndpointAddr() *net.UDPAddr
Work()
Pause()
CloseConn() error CloseConn() error
} }

View File

@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn() relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn) err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }

View File

@ -15,13 +15,17 @@ import (
// WGUserSpaceProxy proxies // WGUserSpaceProxy proxies
type WGUserSpaceProxy struct { type WGUserSpaceProxy struct {
localWGListenPort int localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
} }
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
return p return p
} }
// AddTurnConn start the proxy with the given remote conn // AddTurnConn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { // The provided Context must be non-nil. If the context expires before
p.ctx, p.cancel = context.WithCancel(ctx) // the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
p.remoteConn = remoteConn // connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
var err error
dialer := net.Dialer{} dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err) log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err return err
} }
go p.proxyToRemote() p.ctx, p.cancel = context.WithCancel(ctx)
go p.proxyToLocal() p.localConn = localConn
p.remoteConn = remoteConn
return p.localConn.LocalAddr(), err return err
}
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil {
return nil
}
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
return endpointUdpAddr
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUserSpaceProxy) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
}
// Pause pauses the proxy from receiving data from the remote peer
func (p *WGUserSpaceProxy) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
} }
// CloseConn close the localConn // CloseConn close the localConn
@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
} }
// proxyToRemote proxies from Wireguard to the RemoteKey // proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() { func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err) log.Warnf("error in proxy to remote loop: %s", err)
@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for p.ctx.Err() == nil { for ctx.Err() == nil {
n, err := p.localConn.Read(buf) n, err := p.localConn.Read(buf)
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
_, err = p.remoteConn.Write(buf[:n]) _, err = p.remoteConn.Write(buf[:n])
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
} }
// proxyToLocal proxies from the Remote peer to local WireGuard // proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() { // if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err) log.Warnf("error in proxy to local loop: %s", err)
@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for p.ctx.Err() == nil { for {
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
_, err = p.localConn.Write(buf[:n]) _, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Debugf("failed to write to wg interface conn: %s", err) log.Debugf("failed to write to wg interface conn: %s", err)

View File

@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleExecutable</key>
<string>netbird-ui</string>
<key>CFBundleIconFile</key>
<string>Netbird</string>
<key>LSUIElement</key>
<string>1</string>
</dict>
</plist>

View File

@ -8,11 +8,11 @@ cask "{{ $projectName }}" do
if Hardware::CPU.intel? if Hardware::CPU.intel?
url "{{ $amdURL }}" url "{{ $amdURL }}"
sha256 "{{ crypto.SHA256 $amdFileBytes }}" sha256 "{{ crypto.SHA256 $amdFileBytes }}"
app "netbird_ui_darwin_amd64", target: "Netbird UI.app" app "netbird_ui_darwin", target: "Netbird UI.app"
else else
url "{{ $armURL }}" url "{{ $armURL }}"
sha256 "{{ crypto.SHA256 $armFileBytes }}" sha256 "{{ crypto.SHA256 $armFileBytes }}"
app "netbird_ui_darwin_arm64", target: "Netbird UI.app" app "netbird_ui_darwin", target: "Netbird UI.app"
end end
depends_on formula: "netbird" depends_on formula: "netbird"

20
go.mod
View File

@ -19,8 +19,8 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2 github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.24.0 golang.org/x/crypto v0.28.0
golang.org/x/sys v0.21.0 golang.org/x/sys v0.26.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
@ -38,6 +38,7 @@ require (
github.com/cilium/ebpf v0.15.0 github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18 github.com/creack/pty v1.1.18
github.com/davecgh/go-spew v1.1.1
github.com/eko/gocache/v3 v3.1.1 github.com/eko/gocache/v3 v3.1.1
github.com/fsnotify/fsnotify v1.7.0 github.com/fsnotify/fsnotify v1.7.0
github.com/gliderlabs/ssh v0.3.4 github.com/gliderlabs/ssh v0.3.4
@ -45,7 +46,7 @@ require (
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/google/nftables v0.2.0
github.com/gopacket/gopacket v1.1.1 github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
@ -55,12 +56,12 @@ require (
github.com/libp2p/go-netroute v0.2.1 github.com/libp2p/go-netroute v0.2.1
github.com/magiconair/properties v1.8.7 github.com/magiconair/properties v1.8.7
github.com/mattn/go-sqlite3 v1.14.19 github.com/mattn/go-sqlite3 v1.14.19
github.com/mdlayher/socket v0.4.1 github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
@ -90,10 +91,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.26.0 golang.org/x/net v0.30.0
golang.org/x/oauth2 v0.19.0 golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0 golang.org/x/sync v0.8.0
golang.org/x/term v0.21.0 golang.org/x/term v0.25.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
@ -134,7 +135,6 @@ require (
github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/containerd v1.7.16 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect
@ -222,7 +222,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect

36
go.sum
View File

@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
@ -780,8 +780,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -877,8 +877,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -907,8 +907,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -980,8 +980,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@ -989,8 +989,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -1005,8 +1005,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -873,7 +873,7 @@ services:
zitadel: zitadel:
restart: 'always' restart: 'always'
networks: [netbird] networks: [netbird]
image: 'ghcr.io/zitadel/zitadel:v2.54.3' image: 'ghcr.io/zitadel/zitadel:v2.54.10'
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
env_file: env_file:
- ./zitadel.env - ./zitadel.env

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"errors"
"fmt" "fmt"
"hash/crc32" "hash/crc32"
"math/rand" "math/rand"
@ -44,12 +45,15 @@ import (
) )
const ( const (
PublicCategory = "public" PublicCategory = "public"
PrivateCategory = "private" PrivateCategory = "private"
UnknownCategory = "unknown" UnknownCategory = "unknown"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerLoginExpiration = 24 * time.Hour
DefaultPeerInactivityExpiration = 10 * time.Minute
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
) )
type userLoggedInOnce bool type userLoggedInOnce bool
@ -177,6 +181,8 @@ type DefaultAccountManager struct {
dnsDomain string dnsDomain string
peerLoginExpiry Scheduler peerLoginExpiry Scheduler
peerInactivityExpiry Scheduler
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
userDeleteFromIDPEnabled bool userDeleteFromIDPEnabled bool
@ -194,6 +200,13 @@ type Settings struct {
// Applies to all peers that have Peer.LoginExpirationEnabled set to true. // Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration PeerLoginExpiration time.Duration
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
PeerInactivityExpirationEnabled bool
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
PeerInactivityExpiration time.Duration
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
RegularUsersViewBlocked bool RegularUsersViewBlocked bool
@ -224,6 +237,9 @@ func (s *Settings) Copy() *Settings {
GroupsPropagationEnabled: s.GroupsPropagationEnabled, GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups, JWTAllowGroups: s.JWTAllowGroups,
RegularUsersViewBlocked: s.RegularUsersViewBlocked, RegularUsersViewBlocked: s.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: s.PeerInactivityExpiration,
} }
if s.Extra != nil { if s.Extra != nil {
settings.Extra = s.Extra.Copy() settings.Extra = s.Extra.Copy()
@ -605,6 +621,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer {
return peers return peers
} }
// GetInactivePeers returns peers that have been expired by inactivity
func (a *Account) GetInactivePeers() []*nbpeer.Peer {
var peers []*nbpeer.Peer
for _, inactivePeer := range a.GetPeersWithInactivity() {
inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration)
if inactive {
peers = append(peers, inactivePeer)
}
}
return peers
}
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) {
peersWithExpiry := a.GetPeersWithInactivity()
if len(peersWithExpiry) == 0 {
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
if peer.Status.LoginExpired || peer.Status.Connected {
continue
}
_, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user
func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer {
peers := make([]*nbpeer.Peer, 0)
for _, peer := range a.Peers {
if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() {
peers = append(peers, peer)
}
}
return peers
}
// GetPeers returns a list of all Account peers // GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*nbpeer.Peer { func (a *Account) GetPeers() []*nbpeer.Peer {
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
@ -971,6 +1041,7 @@ func BuildManager(
dnsDomain: dnsDomain, dnsDomain: dnsDomain,
eventStore: eventStore, eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
peerInactivityExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator, integratedPeerValidator: integratedPeerValidator,
metrics: metrics, metrics: metrics,
@ -1099,6 +1170,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, account)
} }
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
updatedAccount := account.UpdateSettings(newSettings) updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
@ -1109,6 +1185,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil return updatedAccount, nil
} }
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
return nil
}
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@ -1144,6 +1240,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context
} }
} }
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", account.Id)
return account.GetNextInactivePeerExpiration()
}
expiredPeers := account.GetInactivePeers()
var peerIDs []string
for _, peer := range expiredPeers {
peerIDs = append(peerIDs, peer.ID)
}
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextInactivePeerExpiration()
}
return account.GetNextInactivePeerExpiration()
}
}
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
}
}
// newAccount creates a new Account with a generated ID and generated default setup keys. // newAccount creates a new Account with a generated ID and generated default setup keys.
// If ID is already in use (due to collision) we try one more time before returning error // If ID is already in use (due to collision) we try one more time before returning error
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
@ -1284,7 +1417,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
} }
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
return "", err return "", err
} }
return account.Id, nil return account.Id, nil
@ -1299,28 +1432,39 @@ func isNil(i idp.Manager) bool {
} }
// addAccountIDToIDPAppMeta update user's app metadata in idp manager // addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
if err != nil {
return err
}
cachedAccount := &Account{
Id: accountID,
Users: make(map[string]*User),
}
for _, user := range accountUsers {
cachedAccount.Users[user.Id] = user
}
// user can be nil if it wasn't found (e.g., just created) // user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, account) user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
if err != nil { if err != nil {
return err return err
} }
if user != nil && user.AppMetadata.WTAccountID == account.Id { if user != nil && user.AppMetadata.WTAccountID == accountID {
// it was already set, so we skip the unnecessary update // it was already set, so we skip the unnecessary update
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
account.Id, userID) accountID, userID)
return nil return nil
} }
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err != nil { if err != nil {
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
} }
// refresh cache to reflect the update // refresh cache to reflect the update
_, err = am.refreshCache(ctx, account.Id) _, err = am.refreshCache(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
@ -1544,48 +1688,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
} }
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
primaryDomain bool, primaryDomain bool,
) error { ) error {
if claims.Domain == "" {
if claims.Domain != "" {
account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain)
userObj := account.Users[claims.UserId]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
}
// prevent updating category for different domain until admin logs in
if account.Domain == lowerDomain {
account.DomainCategory = claims.DomainCategory
}
} else {
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
return nil
} }
err := am.Store.SaveAccount(ctx, account) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err return err
} }
return nil
if domainIsUpToDate(accountDomain, domainCategory, claims) {
return nil
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err)
return err
}
newDomain := accountDomain
newCategoty := domainCategory
lowerDomain := strings.ToLower(claims.Domain)
if accountDomain != lowerDomain && user.HasAdminPower() {
newDomain = lowerDomain
}
if accountDomain == lowerDomain {
newCategoty = claims.DomainCategory
}
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
} }
// handleExistingUserAccount handles existing User accounts and update its domain attributes. // handleExistingUserAccount handles existing User accounts and update its domain attributes.
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount( func (am *DefaultAccountManager) handleExistingUserAccount(
ctx context.Context, ctx context.Context,
existingAcc *Account, userAccountID string,
primaryDomain bool, domainAccountID string,
claims jwtclaims.AuthorizationClaims, claims jwtclaims.AuthorizationClaims,
) error { ) error {
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
if err != nil { if err != nil {
return err return err
} }
// we should register the account ID to this user's metadata in our IDP manager // we should register the account ID to this user's metadata in our IDP manager
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
if err != nil { if err != nil {
return err return err
} }
@ -1593,44 +1758,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
return nil return nil
} }
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain. // otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
if claims.UserId == "" { if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty") return "", fmt.Errorf("user ID is empty")
} }
var (
account *Account
err error
)
lowerDomain := strings.ToLower(claims.Domain) lowerDomain := strings.ToLower(claims.Domain)
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
} else {
account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
if err != nil {
return nil, err
}
err = am.updateAccountDomainAttributes(ctx, account, claims, true)
if err != nil {
return nil, err
}
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
if err != nil { if err != nil {
return nil, err return "", err
} }
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) newAccount.Domain = lowerDomain
newAccount.DomainCategory = claims.DomainCategory
newAccount.IsDomainPrimaryAccount = true
return account, nil err = am.Store.SaveAccount(ctx, newAccount)
if err != nil {
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
if err != nil {
return "", err
}
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
return newAccount.Id, nil
}
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
usersMap := make(map[string]*User)
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
err := am.Store.SaveUsers(domainAccountID, usersMap)
if err != nil {
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
if err != nil {
return "", err
}
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil
} }
// redeemInvite checks whether user has been invited and redeems the invite // redeemInvite checks whether user has been invited and redeems the invite
@ -1774,7 +1953,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
// GetAccountIDFromToken returns an account ID associated with this token. // GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" { if claims.UserId == "" {
return "", "", fmt.Errorf("user ID is empty") return "", "", errors.New(emptyUserID)
} }
if am.singleAccountMode && am.singleAccountModeDomain != "" { if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations. // This section is mostly related to self-hosted installations.
@ -1962,16 +2141,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
// if domain is not private or domain is invalid, it will return the account ID by user ID.
// if domain is of the PrivateCategory category, it will evaluate // if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain // if account is new, existing or if there is another account with the same domain
// //
// Use cases: // Use cases:
// //
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) // New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
// //
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) // New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
// //
// New user + New account + Existing Public Domain -> create account, user role = admin // New user + New account + Existing Public Domain -> create account, user role = owner
// //
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
// //
@ -1981,98 +2161,124 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" { if claims.UserId == "" {
return "", fmt.Errorf("user ID is empty") return "", errors.New(emptyUserID)
} }
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
if claims.AccountId != "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
}
return claims.AccountId, nil
}
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
} else if claims.AccountId != "" {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
return "", err
}
if userAccountID != claims.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
return userAccountID, nil
}
} }
start := time.Now() if claims.AccountId != "" {
unlock := am.Store.AcquireGlobalLock(ctx) return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
defer unlock() }
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
if cancel != nil {
defer cancel()
}
if err != nil { if err != nil {
// if NotFound we are good to continue, otherwise return error return "", err
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return "", err
}
} }
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err == nil { if handleNotFound(err) != nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
defer unlockAccount()
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
return "", err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
return "", err
}
return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
var domainAccount *Account
if domainAccountID != "" {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
return "", err
}
}
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
if err != nil {
return "", err
}
return account.Id, nil
} else {
// other error
return "", err return "", err
} }
if userAccountID != "" {
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
return "", err
}
return userAccountID, nil
}
if domainAccountID != "" {
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
}
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
}
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", nil, err
}
if domainAccountID != "" {
return domainAccountID, nil, nil
}
log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil {
cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", nil, err
}
return domainAccountID, cancel, nil
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
}
if userAccountID != claims.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err
}
if domainIsUpToDate(accountDomain, domainCategory, claims) {
return claims.AccountId, nil
}
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err
}
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
if err != nil {
return "", err
}
return claims.AccountId, nil
}
func handleNotFound(err error) error {
if err == nil {
return nil
}
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return err
}
return nil
}
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain
} }
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
@ -2338,6 +2544,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
PeerLoginExpiration: DefaultPeerLoginExpiration, PeerLoginExpiration: DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true, GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true, RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
}, },
} }

View File

@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims type initUserParams jwtclaims.AuthorizationClaims
type test struct { var (
publicDomain = "public.com"
privateDomain = "private.com"
unknownDomain = "unknown.com"
)
defaultInitAccount := initUserParams{
Domain: publicDomain,
UserId: "defaultUser",
}
initUnknown := defaultInitAccount
initUnknown.DomainCategory = UnknownCategory
initUnknown.Domain = unknownDomain
privateInitAccount := defaultInitAccount
privateInitAccount.Domain = privateDomain
privateInitAccount.DomainCategory = PrivateCategory
testCases := []struct {
name string name string
inputClaims jwtclaims.AuthorizationClaims inputClaims jwtclaims.AuthorizationClaims
inputInitUserParams initUserParams inputInitUserParams initUserParams
@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
expectedPrimaryDomainStatus bool expectedPrimaryDomainStatus bool
expectedCreatedBy string expectedCreatedBy string
expectedUsers []string expectedUsers []string
} }{
{
var ( name: "New User With Public Domain",
publicDomain = "public.com" inputClaims: jwtclaims.AuthorizationClaims{
privateDomain = "private.com" Domain: publicDomain,
unknownDomain = "unknown.com" UserId: "pub-domain-user",
) DomainCategory: PublicCategory,
},
defaultInitAccount := initUserParams{ inputInitUserParams: defaultInitAccount,
Domain: publicDomain, testingFunc: require.NotEqual,
UserId: "defaultUser", expectedMSG: "account IDs shouldn't match",
} expectedUserRole: UserRoleOwner,
expectedDomainCategory: "",
testCase1 := test{ expectedDomain: publicDomain,
name: "New User With Public Domain", expectedPrimaryDomainStatus: false,
inputClaims: jwtclaims.AuthorizationClaims{ expectedCreatedBy: "pub-domain-user",
Domain: publicDomain, expectedUsers: []string{"pub-domain-user"},
UserId: "pub-domain-user",
DomainCategory: PublicCategory,
}, },
inputInitUserParams: defaultInitAccount, {
testingFunc: require.NotEqual, name: "New User With Unknown Domain",
expectedMSG: "account IDs shouldn't match", inputClaims: jwtclaims.AuthorizationClaims{
expectedUserRole: UserRoleOwner, Domain: unknownDomain,
expectedDomainCategory: "", UserId: "unknown-domain-user",
expectedDomain: publicDomain, DomainCategory: UnknownCategory,
expectedPrimaryDomainStatus: false, },
expectedCreatedBy: "pub-domain-user", inputInitUserParams: initUnknown,
expectedUsers: []string{"pub-domain-user"}, testingFunc: require.NotEqual,
} expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
initUnknown := defaultInitAccount expectedDomain: unknownDomain,
initUnknown.DomainCategory = UnknownCategory expectedDomainCategory: "",
initUnknown.Domain = unknownDomain expectedPrimaryDomainStatus: false,
expectedCreatedBy: "unknown-domain-user",
testCase2 := test{ expectedUsers: []string{"unknown-domain-user"},
name: "New User With Unknown Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: UnknownCategory,
}, },
inputInitUserParams: initUnknown, {
testingFunc: require.NotEqual, name: "New User With Private Domain",
expectedMSG: "account IDs shouldn't match", inputClaims: jwtclaims.AuthorizationClaims{
expectedUserRole: UserRoleOwner, Domain: privateDomain,
expectedDomain: unknownDomain, UserId: "pvt-domain-user",
expectedDomainCategory: "", DomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: false, },
expectedCreatedBy: "unknown-domain-user", inputInitUserParams: defaultInitAccount,
expectedUsers: []string{"unknown-domain-user"}, testingFunc: require.NotEqual,
} expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
testCase3 := test{ expectedDomain: privateDomain,
name: "New User With Private Domain", expectedDomainCategory: PrivateCategory,
inputClaims: jwtclaims.AuthorizationClaims{ expectedPrimaryDomainStatus: true,
Domain: privateDomain, expectedCreatedBy: "pvt-domain-user",
UserId: "pvt-domain-user", expectedUsers: []string{"pvt-domain-user"},
DomainCategory: PrivateCategory,
}, },
inputInitUserParams: defaultInitAccount, {
testingFunc: require.NotEqual, name: "New Regular User With Existing Private Domain",
expectedMSG: "account IDs shouldn't match", inputClaims: jwtclaims.AuthorizationClaims{
expectedUserRole: UserRoleOwner, Domain: privateDomain,
expectedDomain: privateDomain, UserId: "new-pvt-domain-user",
expectedDomainCategory: PrivateCategory, DomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, },
expectedCreatedBy: "pvt-domain-user", inputUpdateAttrs: true,
expectedUsers: []string{"pvt-domain-user"}, inputInitUserParams: privateInitAccount,
} testingFunc: require.Equal,
expectedMSG: "account IDs should match",
privateInitAccount := defaultInitAccount expectedUserRole: UserRoleUser,
privateInitAccount.Domain = privateDomain expectedDomain: privateDomain,
privateInitAccount.DomainCategory = PrivateCategory expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true,
testCase4 := test{ expectedCreatedBy: defaultInitAccount.UserId,
name: "New Regular User With Existing Private Domain", expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
inputClaims: jwtclaims.AuthorizationClaims{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: PrivateCategory,
}, },
inputUpdateAttrs: true, {
inputInitUserParams: privateInitAccount, name: "Existing User With Existing Reclassified Private Domain",
testingFunc: require.Equal, inputClaims: jwtclaims.AuthorizationClaims{
expectedMSG: "account IDs should match", Domain: defaultInitAccount.Domain,
expectedUserRole: UserRoleUser, UserId: defaultInitAccount.UserId,
expectedDomain: privateDomain, DomainCategory: PrivateCategory,
expectedDomainCategory: PrivateCategory, },
expectedPrimaryDomainStatus: true, inputInitUserParams: defaultInitAccount,
expectedCreatedBy: defaultInitAccount.UserId, testingFunc: require.Equal,
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, expectedMSG: "account IDs should match",
} expectedUserRole: UserRoleOwner,
expectedDomain: defaultInitAccount.Domain,
testCase5 := test{ expectedDomainCategory: PrivateCategory,
name: "Existing User With Existing Reclassified Private Domain", expectedPrimaryDomainStatus: true,
inputClaims: jwtclaims.AuthorizationClaims{ expectedCreatedBy: defaultInitAccount.UserId,
Domain: defaultInitAccount.Domain, expectedUsers: []string{defaultInitAccount.UserId},
UserId: defaultInitAccount.UserId,
DomainCategory: PrivateCategory,
}, },
inputInitUserParams: defaultInitAccount, {
testingFunc: require.Equal, name: "Existing Account Id With Existing Reclassified Private Domain",
expectedMSG: "account IDs should match", inputClaims: jwtclaims.AuthorizationClaims{
expectedUserRole: UserRoleOwner, Domain: defaultInitAccount.Domain,
expectedDomain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId,
expectedDomainCategory: PrivateCategory, DomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, },
expectedCreatedBy: defaultInitAccount.UserId, inputUpdateClaimAccount: true,
expectedUsers: []string{defaultInitAccount.UserId}, inputInitUserParams: defaultInitAccount,
} testingFunc: require.Equal,
expectedMSG: "account IDs should match",
testCase6 := test{ expectedUserRole: UserRoleOwner,
name: "Existing Account Id With Existing Reclassified Private Domain", expectedDomain: defaultInitAccount.Domain,
inputClaims: jwtclaims.AuthorizationClaims{ expectedDomainCategory: PrivateCategory,
Domain: defaultInitAccount.Domain, expectedPrimaryDomainStatus: true,
UserId: defaultInitAccount.UserId, expectedCreatedBy: defaultInitAccount.UserId,
DomainCategory: PrivateCategory, expectedUsers: []string{defaultInitAccount.UserId},
}, },
inputUpdateClaimAccount: true, {
inputInitUserParams: defaultInitAccount, name: "User With Private Category And Empty Domain",
testingFunc: require.Equal, inputClaims: jwtclaims.AuthorizationClaims{
expectedMSG: "account IDs should match", Domain: "",
expectedUserRole: UserRoleOwner, UserId: "pvt-domain-user",
expectedDomain: defaultInitAccount.Domain, DomainCategory: PrivateCategory,
expectedDomainCategory: PrivateCategory, },
expectedPrimaryDomainStatus: true, inputInitUserParams: defaultInitAccount,
expectedCreatedBy: defaultInitAccount.UserId, testingFunc: require.NotEqual,
expectedUsers: []string{defaultInitAccount.UserId}, expectedMSG: "account IDs shouldn't match",
} expectedUserRole: UserRoleOwner,
expectedDomain: "",
testCase7 := test{ expectedDomainCategory: "",
name: "User With Private Category And Empty Domain", expectedPrimaryDomainStatus: false,
inputClaims: jwtclaims.AuthorizationClaims{ expectedCreatedBy: "pvt-domain-user",
Domain: "", expectedUsers: []string{"pvt-domain-user"},
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
}, },
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
} }
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
require.NoError(t, err, "get init account failed") require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs { if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed") require.NoError(t, err, "update init user failed")
} }
@ -2025,6 +2019,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
} }
} }
func TestAccount_GetInactivePeers(t *testing.T) {
type test struct {
name string
peers map[string]*nbpeer.Peer
expectedPeers map[string]struct{}
}
testCases := []test{
{
name: "Peers with inactivity expiration disabled, no expired peers",
peers: map[string]*nbpeer.Peer{
"peer-1": {
InactivityExpirationEnabled: false,
},
"peer-2": {
InactivityExpirationEnabled: false,
},
},
expectedPeers: map[string]struct{}{},
},
{
name: "Two peers expired",
peers: map[string]*nbpeer.Peer{
"peer-1": {
ID: "peer-1",
InactivityExpirationEnabled: true,
Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC().Add(-45 * time.Second),
Connected: false,
LoginExpired: false,
},
LastLogin: time.Now().UTC().Add(-30 * time.Minute),
UserID: userID,
},
"peer-2": {
ID: "peer-2",
InactivityExpirationEnabled: true,
Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC().Add(-45 * time.Second),
Connected: false,
LoginExpired: false,
},
LastLogin: time.Now().UTC().Add(-2 * time.Hour),
UserID: userID,
},
"peer-3": {
ID: "peer-3",
InactivityExpirationEnabled: true,
Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().UTC().Add(-1 * time.Hour),
UserID: userID,
},
},
expectedPeers: map[string]struct{}{
"peer-1": {},
"peer-2": {},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
Settings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
}
expiredPeers := account.GetInactivePeers()
assert.Len(t, expiredPeers, len(testCase.expectedPeers))
for _, peer := range expiredPeers {
if _, ok := testCase.expectedPeers[peer.ID]; !ok {
t.Fatalf("expected to have peer %s expired", peer.ID)
}
}
})
}
}
func TestAccount_GetPeersWithExpiration(t *testing.T) { func TestAccount_GetPeersWithExpiration(t *testing.T) {
type test struct { type test struct {
name string name string
@ -2094,6 +2172,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
} }
} }
func TestAccount_GetPeersWithInactivity(t *testing.T) {
type test struct {
name string
peers map[string]*nbpeer.Peer
expectedPeers map[string]struct{}
}
testCases := []test{
{
name: "No account peers, no peers with expiration",
peers: map[string]*nbpeer.Peer{},
expectedPeers: map[string]struct{}{},
},
{
name: "Peers with login expiration disabled, no peers with expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
InactivityExpirationEnabled: false,
UserID: userID,
},
"peer-2": {
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expectedPeers: map[string]struct{}{},
},
{
name: "Peers with login expiration enabled, return peers with expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
ID: "peer-1",
InactivityExpirationEnabled: true,
UserID: userID,
},
"peer-2": {
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expectedPeers: map[string]struct{}{
"peer-1": {},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
}
actual := account.GetPeersWithInactivity()
assert.Len(t, actual, len(testCase.expectedPeers))
if len(testCase.expectedPeers) > 0 {
for k := range testCase.expectedPeers {
contains := false
for _, peer := range actual {
if k == peer.ID {
contains = true
}
}
assert.True(t, contains)
}
}
})
}
}
func TestAccount_GetNextPeerExpiration(t *testing.T) { func TestAccount_GetNextPeerExpiration(t *testing.T) {
type test struct { type test struct {
name string name string
@ -2255,6 +2402,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
} }
} }
func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
type test struct {
name string
peers map[string]*nbpeer.Peer
expiration time.Duration
expirationEnabled bool
expectedNextRun bool
expectedNextExpiration time.Duration
}
expectedNextExpiration := time.Minute
testCases := []test{
{
name: "No peers, no expiration",
peers: map[string]*nbpeer.Peer{},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "No connected peers, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: false,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: false,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "Connected peers with disabled expiration, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: true,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
},
InactivityExpirationEnabled: false,
UserID: userID,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "Expired peers, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: true,
},
InactivityExpirationEnabled: true,
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: true,
},
InactivityExpirationEnabled: true,
UserID: userID,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "To be expired peer, return expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: false,
LoginExpired: false,
LastSeen: time.Now().Add(-1 * time.Second),
},
InactivityExpirationEnabled: true,
LastLogin: time.Now().UTC(),
UserID: userID,
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: true,
},
InactivityExpirationEnabled: true,
UserID: userID,
},
},
expiration: time.Minute,
expirationEnabled: false,
expectedNextRun: true,
expectedNextExpiration: expectedNextExpiration,
},
{
name: "Peers added with setup keys, no expiration",
peers: map[string]*nbpeer.Peer{
"peer-1": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: false,
},
InactivityExpirationEnabled: true,
SetupKey: "key",
},
"peer-2": {
Status: &nbpeer.PeerStatus{
Connected: true,
LoginExpired: false,
},
InactivityExpirationEnabled: true,
SetupKey: "key",
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
}
expiration, ok := account.GetNextInactivePeerExpiration()
assert.Equal(t, testCase.expectedNextRun, ok)
if testCase.expectedNextRun {
assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration)
} else {
assert.Equal(t, expiration, testCase.expectedNextExpiration)
}
})
}
}
func TestAccount_SetJWTGroups(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")

View File

@ -139,6 +139,13 @@ const (
PostureCheckUpdated Activity = 61 PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check // PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted Activity = 62 PostureCheckDeleted Activity = 62
PeerInactivityExpirationEnabled Activity = 63
PeerInactivityExpirationDisabled Activity = 64
AccountPeerInactivityExpirationEnabled Activity = 65
AccountPeerInactivityExpirationDisabled Activity = 66
AccountPeerInactivityExpirationDurationUpdated Activity = 67
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{
@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{
PostureCheckCreated: {"Posture check created", "posture.check.created"}, PostureCheckCreated: {"Posture check created", "posture.check.created"},
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"}, PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"}, PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"},
PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
account.Settings = &Settings{ account.Settings = &Settings{
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
PeerLoginExpiration: DefaultPeerLoginExpiration, PeerLoginExpiration: DefaultPeerLoginExpiration,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
} }
} }

View File

@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
} }
if req.Settings.Extra != nil { if req.Settings.Extra != nil {

View File

@ -54,6 +54,14 @@ components:
description: Period of time after which peer login expires (seconds). description: Period of time after which peer login expires (seconds).
type: integer type: integer
example: 43200 example: 43200
peer_inactivity_expiration_enabled:
description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
type: boolean
example: true
peer_inactivity_expiration:
description: Period of time of inactivity after which peer session expires (seconds).
type: integer
example: 43200
regular_users_view_blocked: regular_users_view_blocked:
description: Allows blocking regular users from viewing parts of the system. description: Allows blocking regular users from viewing parts of the system.
type: boolean type: boolean
@ -81,6 +89,8 @@ components:
required: required:
- peer_login_expiration_enabled - peer_login_expiration_enabled
- peer_login_expiration - peer_login_expiration
- peer_inactivity_expiration_enabled
- peer_inactivity_expiration
- regular_users_view_blocked - regular_users_view_blocked
AccountExtraSettings: AccountExtraSettings:
type: object type: object
@ -243,6 +253,9 @@ components:
login_expiration_enabled: login_expiration_enabled:
type: boolean type: boolean
example: false example: false
inactivity_expiration_enabled:
type: boolean
example: false
approval_required: approval_required:
description: (Cloud only) Indicates whether peer needs approval description: (Cloud only) Indicates whether peer needs approval
type: boolean type: boolean
@ -251,6 +264,7 @@ components:
- name - name
- ssh_enabled - ssh_enabled
- login_expiration_enabled - login_expiration_enabled
- inactivity_expiration_enabled
Peer: Peer:
allOf: allOf:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'
@ -327,6 +341,10 @@ components:
type: string type: string
format: date-time format: date-time
example: "2023-05-05T09:00:35.477782Z" example: "2023-05-05T09:00:35.477782Z"
inactivity_expiration_enabled:
description: Indicates whether peer inactivity expiration has been enabled or not
type: boolean
example: false
approval_required: approval_required:
description: (Cloud only) Indicates whether peer needs approval description: (Cloud only) Indicates whether peer needs approval
type: boolean type: boolean
@ -354,6 +372,7 @@ components:
- last_seen - last_seen
- login_expiration_enabled - login_expiration_enabled
- login_expired - login_expired
- inactivity_expiration_enabled
- os - os
- ssh_enabled - ssh_enabled
- user_id - user_id

View File

@ -220,6 +220,12 @@ type AccountSettings struct {
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups. // JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"` JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
// PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"`
// PeerLoginExpiration Period of time after which peer login expires (seconds). // PeerLoginExpiration Period of time after which peer login expires (seconds).
PeerLoginExpiration int `json:"peer_login_expiration"` PeerLoginExpiration int `json:"peer_login_expiration"`
@ -538,6 +544,9 @@ type Peer struct {
// Id Peer ID // Id Peer ID
Id string `json:"id"` Id string `json:"id"`
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address // Ip Peer's IP address
Ip string `json:"ip"` Ip string `json:"ip"`
@ -613,6 +622,9 @@ type PeerBatch struct {
// Id Peer ID // Id Peer ID
Id string `json:"id"` Id string `json:"id"`
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
// Ip Peer's IP address // Ip Peer's IP address
Ip string `json:"ip"` Ip string `json:"ip"`
@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest. // PeerRequest defines model for PeerRequest.
type PeerRequest struct { type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"` ApprovalRequired *bool `json:"approval_required,omitempty"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"` InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
Name string `json:"name"` LoginExpirationEnabled bool `json:"login_expiration_enabled"`
SshEnabled bool `json:"ssh_enabled"` Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
} }
// PersonalAccessToken defines model for PersonalAccessToken. // PersonalAccessToken defines model for PersonalAccessToken.

View File

@ -7,6 +7,8 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
@ -14,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
// PeersHandler is a handler that returns peers of the account // PeersHandler is a handler that returns peers of the account
@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
SSHEnabled: req.SshEnabled, SSHEnabled: req.SshEnabled,
Name: req.Name, Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled, LoginExpirationEnabled: req.LoginExpirationEnabled,
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
} }
if req.ApprovalRequired != nil { if req.ApprovalRequired != nil {
@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
} }
return &api.Peer{ return &api.Peer{
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
Ip: peer.IP.String(), Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(), ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected, Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen, LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion, KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID), GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion, Version: peer.Meta.WtVersion,
Groups: groupsInfo, Groups: groupsInfo,
SshEnabled: peer.SSHEnabled, SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname, Hostname: peer.Meta.Hostname,
UserId: peer.UserID, UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion, UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain), DnsLabel: fqdn(peer, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled, LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin, LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
ApprovalRequired: !approved, ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode, CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
} }
} }
@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
CountryCode: peer.Location.CountryCode, CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
} }
} }

View File

@ -111,6 +111,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return err return err
} }
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
if err != nil {
return err
}
if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
}
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
}
return nil
}
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
oldStatus := peer.Status.Copy() oldStatus := peer.Status.Copy()
newStatus := oldStatus newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC() newStatus.LastSeen = time.Now().UTC()
@ -139,25 +164,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
account.UpdatePeer(peer) account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil { if err != nil {
return err return false, err
} }
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { return oldStatus.LoginExpired, nil
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
if oldStatus.LoginExpired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
}
return nil
} }
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@ -222,6 +237,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
} }
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
event := activity.PeerInactivityExpirationEnabled
if !update.InactivityExpirationEnabled {
event = activity.PeerInactivityExpirationDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
}
account.UpdatePeer(peer) account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
@ -454,23 +488,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
registrationTime := time.Now().UTC() registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{ newPeer = &nbpeer.Peer{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID, AccountID: accountID,
Key: peer.Key, Key: peer.Key,
SetupKey: upperKey, SetupKey: upperKey,
IP: freeIP, IP: freeIP,
Meta: peer.Meta, Meta: peer.Meta,
Name: peer.Meta.Hostname, Name: peer.Meta.Hostname,
DNSLabel: freeLabel, DNSLabel: freeLabel,
UserID: userID, UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false, SSHEnabled: false,
SSHKey: peer.SSHKey, SSHKey: peer.SSHKey,
LastLogin: registrationTime, LastLogin: registrationTime,
CreatedAt: registrationTime, CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser, LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral, Ephemeral: ephemeral,
Location: peer.Location, Location: peer.Location,
InactivityExpirationEnabled: addedByUser,
} }
opEvent.TargetID = newPeer.ID opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())

View File

@ -38,6 +38,8 @@ type Peer struct {
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin // Works with LastLogin
LoginExpirationEnabled bool `diff:"-"` LoginExpirationEnabled bool `diff:"-"`
InactivityExpirationEnabled bool `diff:"-"`
// LastLogin the time when peer performed last login operation // LastLogin the time when peer performed last login operation
LastLogin time.Time `diff:"-"` LastLogin time.Time `diff:"-"`
// CreatedAt records the time the peer was created // CreatedAt records the time the peer was created
@ -187,6 +189,7 @@ func (p *Peer) Copy() *Peer {
CreatedAt: p.CreatedAt, CreatedAt: p.CreatedAt,
Ephemeral: p.Ephemeral, Ephemeral: p.Ephemeral,
Location: p.Location, Location: p.Location,
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
} }
} }
@ -219,6 +222,22 @@ func (p *Peer) MarkLoginExpired(expired bool) {
p.Status = newStatus p.Status = newStatus
} }
// SessionExpired indicates whether the peer's session has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired.
// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired).
// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// Only peers added by interactive SSO login can be expired.
func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) {
if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected {
return false, 0
}
expiresAt := p.Status.LastSeen.Add(expiresIn)
now := time.Now()
timeLeft := expiresAt.Sub(now)
return timeLeft <= 0, timeLeft
}
// LoginExpired indicates whether the peer's login has expired or not. // LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).

View File

@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) {
} }
} }
func TestPeer_SessionExpired(t *testing.T) {
tt := []struct {
name string
expirationEnabled bool
lastLogin time.Time
connected bool
expected bool
accountSettings *Settings
}{
{
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
expirationEnabled: false,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
},
expected: false,
},
{
name: "Peer Inactivity Should Expire",
expirationEnabled: true,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
expected: true,
},
{
name: "Peer Inactivity Should Not Expire",
expirationEnabled: true,
connected: true,
lastLogin: time.Now().UTC(),
accountSettings: &Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
expected: false,
},
}
for _, c := range tt {
t.Run(c.name, func(t *testing.T) {
peerStatus := &nbpeer.PeerStatus{
Connected: c.connected,
}
peer := &nbpeer.Peer{
InactivityExpirationEnabled: c.expirationEnabled,
LastLogin: c.lastLogin,
Status: peerStatus,
UserID: userID,
}
expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration)
assert.Equal(t, expired, c.expected)
})
}
}
func TestAccountManager_GetNetworkMap(t *testing.T) { func TestAccountManager_GetNetworkMap(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {

View File

@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
return nil return nil
} }
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
accountCopy := Account{
Domain: domain,
DomainCategory: category,
IsDomainPrimaryAccount: isPrimaryDomain,
}
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
result := s.db.WithContext(ctx).Model(&Account{}).
Select(fieldsToUpdate).
Where(idQueryCondition, accountID).
Updates(&accountCopy)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account %s", accountID)
}
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus peerCopy.Status = &peerStatus
@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil return &user, nil
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
var users []*User
result := s.db.Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting users from store")
}
return users, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID) result := s.db.Find(&groups, accountIDCondition, accountID)
@ -1117,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). // TODO: This fix is accepted for now, but if we need to handle this more frequently
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) // we may need to reconsider changing the types.
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
if s.storeEngine == PostgresStoreEngine {
query = query.Order("json_array_length(peers::json) DESC")
} else {
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found") return nil, status.Errorf(status.NotFound, "group not found")

View File

@ -1191,3 +1191,76 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestSqlite_GetAccoundUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
t.Run("Should update attributes with public domain", func(t *testing.T) {
require.NoError(t, err)
domain := "example.com"
category := "public"
IsDomainPrimaryAccount := false
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, domain, account.Domain)
require.Equal(t, category, account.DomainCategory)
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
})
t.Run("Should update attributes with private domain", func(t *testing.T) {
require.NoError(t, err)
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, domain, account.Domain)
require.Equal(t, category, account.DomainCategory)
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
})
t.Run("Should fail when account does not exist", func(t *testing.T) {
require.NoError(t, err)
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
require.Error(t, err)
})
}
func TestSqlite_GetGroupByName(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
require.NoError(t, err)
require.Equal(t, "All", group.Name)
}

View File

@ -58,9 +58,11 @@ type Store interface {
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error

View File

@ -20,10 +20,11 @@ import (
) )
const ( const (
UserRoleOwner UserRole = "owner" UserRoleOwner UserRole = "owner"
UserRoleAdmin UserRole = "admin" UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user" UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown" UserRoleUnknown UserRole = "unknown"
UserRoleBillingAdmin UserRole = "billing_admin"
UserStatusActive UserStatus = "active" UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled" UserStatusDisabled UserStatus = "disabled"
@ -42,6 +43,8 @@ func StrRoleToUserRole(strRole string) UserRole {
return UserRoleAdmin return UserRoleAdmin
case "user": case "user":
return UserRoleUser return UserRoleUser
case "billing_admin":
return UserRoleBillingAdmin
default: default:
return UserRoleUnknown return UserRoleUnknown
} }

View File

@ -16,8 +16,10 @@ const (
type Metrics struct { type Metrics struct {
metric.Meter metric.Meter
TransferBytesSent metric.Int64Counter TransferBytesSent metric.Int64Counter
TransferBytesRecv metric.Int64Counter TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram
peers metric.Int64UpDownCounter peers metric.Int64UpDownCounter
peerActivityChan chan string peerActivityChan chan string
@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err return nil, err
} }
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
m := &Metrics{ m := &Metrics{
Meter: meter, Meter: meter,
TransferBytesSent: bytesSent, TransferBytesSent: bytesSent,
TransferBytesRecv: bytesRecv, TransferBytesRecv: bytesRecv,
peers: peers, AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime,
peers: peers,
ctx: ctx, ctx: ctx,
peerActivityChan: make(chan string, 10), peerActivityChan: make(chan string, 10),
@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
m.peerLastActive[id] = time.Time{} m.peerLastActive[id] = time.Time{}
} }
// RecordAuthenticationTime measures the time taken for peer authentication
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
}
// RecordPeerStoreTime measures the time to store the peer in map
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
}
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections // PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
func (m *Metrics) PeerDisconnected(id string) { func (m *Metrics) PeerDisconnected(id string) {
m.peers.Add(m.ctx, -1) m.peers.Add(m.ctx, -1)
@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
} }
} }
} }
func getStandardBucketBoundaries() []float64 {
return []float64{
0.1,
0.5,
1,
5,
10,
50,
100,
500,
1000,
5000,
10000,
}
}

153
relay/server/handshake.go Normal file
View File

@ -0,0 +1,153 @@
package server
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
)
// preparedMsg contains the marshalled success response messages
type preparedMsg struct {
responseHelloMsg []byte
responseAuthMsg []byte
}
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
rhm, err := marshalResponseHelloMsg(instanceURL)
if err != nil {
return nil, err
}
ram, err := messages.MarshalAuthResponse(instanceURL)
if err != nil {
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
}
return &preparedMsg{
responseHelloMsg: rhm,
responseAuthMsg: ram,
}, nil
}
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
addr := &address.Address{URL: instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal response address: %w", err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
}
return responseMsg, nil
}
type handshake struct {
conn net.Conn
validator auth.Validator
preparedMsg *preparedMsg
handshakeMethodAuth bool
peerID string
}
func (h *handshake) handshakeReceive() ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
}
var (
bytePeerID []byte
peerID string
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
case messages.MsgTypeAuth:
h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
}
if err != nil {
return nil, err
}
h.peerID = peerID
return bytePeerID, nil
}
func (h *handshake) handshakeResponse() error {
var responseMsg []byte
if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg
} else {
responseMsg = h.preparedMsg.responseHelloMsg
}
if _, err := h.conn.Write(responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
}
return nil
}
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
}
return rawPeerID, peerID, nil
}
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := h.validator.Validate(authPayload); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
}
return rawPeerID, peerID, nil
}

View File

@ -7,16 +7,13 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
) )
@ -28,6 +25,7 @@ type Relay struct {
store *Store store *Store
instanceURL string instanceURL string
preparedMsg *preparedMsg
closed bool closed bool
closeMu sync.RWMutex closeMu sync.RWMutex
@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
return nil, fmt.Errorf("get instance URL: %v", err) return nil, fmt.Errorf("get instance URL: %v", err)
} }
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("prepare message: %v", err)
}
return r, nil return r, nil
} }
@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
// Accept start to handle a new peer connection // Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now()
r.closeMu.RLock() r.closeMu.RLock()
defer r.closeMu.RUnlock() defer r.closeMu.RUnlock()
if r.closed { if r.closed {
return return
} }
peerID, err := r.handshake(conn) h := handshake{
conn: conn,
validator: r.validator,
preparedMsg: r.preparedMsg,
}
peerID, err := h.handshakeReceive()
if err != nil { if err != nil {
log.Errorf("failed to handshake: %s", err) log.Errorf("failed to handshake: %s", err)
cErr := conn.Close() if cErr := conn.Close(); cErr != nil {
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
} }
return return
@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, peerID, conn, r.store) peer := NewPeer(r.metrics, peerID, conn, r.store)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now()
r.store.AddPeer(peer) r.store.AddPeer(peer)
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String()) r.metrics.PeerConnected(peer.String())
go func() { go func() {
peer.Work() peer.Work()
@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
peer.log.Debugf("relay connection closed") peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
}() }()
if err := h.handshakeResponse(); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close()
}
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
} }
// Shutdown closes the relay server // Shutdown closes the relay server
@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
func (r *Relay) InstanceURL() string { func (r *Relay) InstanceURL() string {
return r.instanceURL return r.instanceURL
} }
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
}
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
}
var (
responseMsg []byte
peerID []byte
)
switch msgType {
//nolint:staticcheck
case messages.MsgTypeHello:
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
case messages.MsgTypeAuth:
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
}
if err != nil {
return nil, err
}
_, err = conn.Write(responseMsg)
if err != nil {
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
}
return peerID, nil
}
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
//nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
}
//nolint:staticcheck
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
}
addr := &address.Address{URL: r.instanceURL}
addrData, err := addr.Marshal()
if err != nil {
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
}
//nolint:staticcheck
responseMsg, err := messages.MarshalHelloResponse(addrData)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
}
return rawPeerID, responseMsg, nil
}
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
}
peerID := messages.HashIDToString(rawPeerID)
if err := r.validator.Validate(authPayload); err != nil {
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
}
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
if err != nil {
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
}
return rawPeerID, responseMsg, nil
}

View File

@ -86,7 +86,7 @@ download_release_binary() {
# Unzip the app and move to INSTALL_DIR # Unzip the app and move to INSTALL_DIR
unzip -q -o "$BINARY_NAME" unzip -q -o "$BINARY_NAME"
mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
else else
${SUDO} mkdir -p "$INSTALL_DIR" ${SUDO} mkdir -p "$INSTALL_DIR"
tar -xzvf "$BINARY_NAME" tar -xzvf "$BINARY_NAME"

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"time" "time"
"github.com/netbirdio/signal-dispatcher/dispatcher"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
@ -13,8 +14,6 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/signal-dispatcher/dispatcher"
"github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/peer"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"

View File

@ -11,7 +11,8 @@ import (
const ( const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard // NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00 NetbirdFwmark = 0x1BD00
PreroutingFwmark = 0x1BD01
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
) )