diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b2e2437e6..14e383a27 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.14" + SIGN_PIPE_VER: "v0.0.16" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" @@ -223,4 +223,4 @@ jobs: repo: netbirdio/sign-pipelines ref: ${{ env.SIGN_PIPE_VER }} token: ${{ secrets.SIGN_GITHUB_TOKEN }} - inputs: '{ "tag": "${{ github.ref }}" }' + inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' diff --git a/.goreleaser.yaml b/.goreleaser.yaml index cf2ce4f4f..e718b3fcd 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -96,6 +96,9 @@ builds: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser mod_timestamp: "{{ .CommitTimestamp }}" +universal_binaries: + - id: netbird + archives: - builds: - netbird diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index bccb7f471..0a0082075 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -23,6 +23,9 @@ builds: tags: - load_wgnt_from_rsrc +universal_binaries: + - id: netbird-ui-darwin + archives: - builds: - netbird-ui-darwin diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c6a96a876..c271e592d 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -21,13 +22,19 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type entry struct { + spec []string + position int +} + type aclManager struct { iptablesClient *iptables.IPTables wgIface iFaceMapper routingFwChainName string - entries map[string][][]string - ipsetStore *ipsetStore + entries map[string][][]string + optionalEntries map[string][]entry + ipsetStore *ipsetStore } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi wgIface: wgIface, routingFwChainName: routingFwChainName, - entries: make(map[string][][]string), - ipsetStore: newIpsetStore(), + entries: make(map[string][][]string), + optionalEntries: make(map[string][]entry), + ipsetStore: newIpsetStore(), } err := ipset.Init() @@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi } m.seedInitialEntries() + m.seedInitialOptionalEntries() err = m.cleanChains() if err != nil { @@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error { } } + ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + if ok { + for _, rule := range m.entries["PREROUTING"] { + err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error { } } + for chainName, entries := range m.optionalEntries { + for _, entry := range entries { + if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil { + log.Errorf("failed to insert optional entry %v: %v", entry.spec, err) + continue + } + m.entries[chainName] = append(m.entries[chainName], entry.spec) + } + } + clear(m.optionalEntries) + return nil } @@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } +func (m *aclManager) seedInitialOptionalEntries() { + m.optionalEntries["FORWARD"] = []entry{ + { + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + position: 2, + }, + } + + m.optionalEntries["PREROUTING"] = []entry{ + { + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + position: 1, + }, + } +} + func (m *aclManager) appendToEntries(chainName string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 6fefd58e6..94bd2fccf 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering( } func (m *Manager) AddRouteFiltering( - sources [] netip.Prefix, + sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 737b20785..129323928 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error { log.Debug("flushing routing related tables") for _, chain := range []string{chainRTFWD, chainRTNAT} { - table := tableFilter - if chain == chainRTNAT { - table = tableNat - } + table := r.getTableForChain(chain) ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { @@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error { func (r *router) createContainers() error { for _, chain := range []string{chainRTFWD, chainRTNAT} { if err := r.createAndSetupChain(chain); err != nil { - return fmt.Errorf("create chain %s: %v", chain, err) + return fmt.Errorf("create chain %s: %w", chain, err) } } if err := r.insertEstablishedRule(chainRTFWD); err != nil { - return fmt.Errorf("insert established rule: %v", err) + return fmt.Errorf("insert established rule: %w", err) } - return r.addJumpRules() + if err := r.addJumpRules(); err != nil { + return fmt.Errorf("add jump rules: %w", err) + } + + return nil } func (r *router) createAndSetupChain(chain string) error { @@ -432,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { intdir := "-i" + lointdir := "-o" if inverse { intdir = "-o" + lointdir = "-i" } - return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} + return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump} } func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index a6185d370..556bda0d6 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { // GenerateSetName generates a unique name for an ipset based on the given sources. func GenerateSetName(sources []netip.Prefix) string { // sort for consistent naming - sortPrefixes(sources) + SortPrefixes(sources) var sourcesStr strings.Builder for _, src := range sources { @@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { return merged } -// sortPrefixes sorts the given slice of netip.Prefix in place. +// SortPrefixes sorts the given slice of netip.Prefix in place. // It sorts first by IP address, then by prefix length (most specific to least specific). -func sortPrefixes(prefixes []netip.Prefix) { +func SortPrefixes(prefixes []netip.Prefix) { sort.Slice(prefixes, func(i, j int) bool { addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) if addrCmp != 0 { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index eaf7fb6a0..61434f035 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -11,12 +11,14 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -29,6 +31,7 @@ const ( chainNameInputFilter = "netbird-acl-input-filter" chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" + chainNamePrerouting = "netbird-rt-prerouting" allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) @@ -40,15 +43,14 @@ var ( ) type AclManager struct { - rConn *nftables.Conn - sConn *nftables.Conn - wgIface iFaceMapper - routeingFwChainName string + rConn *nftables.Conn + sConn *nftables.Conn + wgIface iFaceMapper + routingFwChainName string workTable *nftables.Table chainInputRules *nftables.Chain chainOutputRules *nftables.Chain - chainFwFilter *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -61,7 +63,7 @@ type iFaceMapper interface { IsUserspaceBind() bool } -func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { +func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) // and is permanent. Using same connection for both type of operations @@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa } m := &AclManager{ - rConn: &nftables.Conn{}, - sConn: sConn, - wgIface: wgIface, - workTable: table, - routeingFwChainName: routeingFwChainName, + rConn: &nftables.Conn{}, + sConn: sConn, + wgIface: wgIface, + workTable: table, + routingFwChainName: routingFwChainName, ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) { } // netbird-acl-forward-filter - m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to netbird-rt-fwd - m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) + m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd + m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) err = m.rConn.Flush() if err != nil { @@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) { return fmt.Errorf(flushError, err) } + if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { + log.Errorf("failed to allow redirected traffic: %s", err) + } + return nil } -func (m *AclManager) addJumpRulesToRtForward() { +// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) +// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the +// netbird peer IP. +func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { + preroutingChain := m.rConn.AddChain(&nftables.Chain{ + Name: chainNamePrerouting, + Table: m.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + }) + + m.addPreroutingRule(preroutingChain) + + m.addFwmarkToForward(chainFwFilter) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { + m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: preroutingChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Fib{ + Register: 1, + ResultADDRTYPE: true, + FlagDADDR: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + }, + }) +} + +func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { + m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: chainFwFilter, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.chainInputRules.Name, + }, + }, + }) +} + +func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) { expressions := []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() { }, &expr.Verdict{ Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, + Chain: m.routingFwChainName, }, } _ = m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, - Chain: m.chainFwFilter, + Chain: chainFwFilter, Exprs: expressions, }) } @@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain { return chain } -func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { +func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { polAccept := nftables.ChainPolicyAccept chain := &nftables.Chain{ Name: name, diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index d2258ae08..01b08bd71 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain * rule := &nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 1, - }, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, + Exprs: getEstablishedExprs(1), } conn.InsertRule(rule) } + +func getEstablishedExprs(register uint32) []expr.Any { + return []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: register, + }, + &expr.Bitwise{ + SourceRegister: register, + DestRegister: register, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: register, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 904050a51..bbe18ab07 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) { Register: 1, Data: []byte{0, 0, 0, 0}, }, + &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aa61e1858..03526fee7 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,8 @@ import ( "net/netip" "strings" + "github.com/coreos/go-iptables/iptables" + "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -24,7 +26,7 @@ import ( const ( chainNameRoutingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" + chainNameRoutingNat = "netbird-rt-postrouting" chainNameForward = "FORWARD" userDataAcceptForwardRuleIif = "frwacceptiif" @@ -80,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa } } - err = r.cleanUpDefaultForwardRules() + err = r.removeAcceptForwardRules() if err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } @@ -97,40 +99,7 @@ func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() - return r.cleanUpDefaultForwardRules() -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list chains: %v", err) - } - - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { - continue - } - - rules, err := r.conn.GetRules(r.filterTable, chain) - if err != nil { - return fmt.Errorf("get rules: %v", err) - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } - } - } - } - - return r.conn.Flush() + return r.removeAcceptForwardRules() } func (r *router) loadFilterTable() (*nftables.Table, error) { @@ -149,7 +118,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { } func (r *router) createContainers() error { - r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, Table: r.workTable, @@ -157,25 +125,28 @@ func (r *router) createContainers() error { insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + prio := *nftables.ChainPriorityNATSource - 1 + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, Table: r.workTable, Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, + Priority: &prio, Type: nftables.ChainTypeNAT, }) - r.acceptForwardRules() + if err := r.acceptForwardRules(); err != nil { + log.Errorf("failed to add accept rules for the forward chain: %s", err) + } - err := r.refreshRulesMap() - if err != nil { + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.conn.Flush() - if err != nil { + if err := r.conn.Flush(); err != nil { return fmt.Errorf("nftables: unable to initialize table: %v", err) } + return nil } @@ -188,6 +159,7 @@ func (r *router) AddRouteFiltering( dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) if _, ok := r.rules[string(ruleKey)]; ok { return ruleKey, nil @@ -248,9 +220,18 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - r.rules[string(ruleKey)] = r.conn.AddRule(rule) + rule = r.conn.AddRule(rule) - return ruleKey, r.conn.Flush() + log.Tracef("Adding route rule %s", spew.Sdump(rule)) + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf(flushError, err) + } + + r.rules[string(ruleKey)] = rule + + log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + + return ruleKey, nil } func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { @@ -288,6 +269,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } + if nftRule.Handle == 0 { + return fmt.Errorf("route rule %s has no handle", ruleKey) + } + setName := r.findSetNameInRule(nftRule) if err := r.deleteNftRule(nftRule, ruleKey); err != nil { @@ -440,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { destExp := generateCIDRMatcherExpressions(false, pair.Destination) dir := expr.MetaKeyIIFNAME + notDir := expr.MetaKeyOIFNAME if pair.Inverse { dir = expr.MetaKeyOIFNAME + notDir = expr.MetaKeyIIFNAME } + lo := ifname("lo") intf := ifname(r.wgIface.Name()) + exprs := []expr.Any{ &expr.Meta{ Key: dir, @@ -455,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { Register: 1, Data: intf, }, + + // We need to exclude the loopback interface as this changes the ebpf proxy port + &expr.Meta{ + Key: notDir, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: lo, + }, } exprs = append(exprs, sourceExp...) @@ -562,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. -func (r *router) acceptForwardRules() { +func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") - return + return nil } + fw := "iptables" + + defer func() { + log.Debugf("Used %s to add accept forward rules", fw) + }() + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + // filter table exists but iptables is not + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + + fw = "nftables" + return r.acceptForwardRulesNftables() + } + + return r.acceptForwardRulesIptables(ipt) +} + +func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + } else { + log.Debugf("added iptables rule: %v", rule) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) getAcceptForwardRules() [][]string { + intf := r.wgIface.Name() + return [][]string{ + {"-i", intf, "-j", "ACCEPT"}, + {"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, + } +} + +func (r *router) acceptForwardRulesNftables() error { intf := ifname(r.wgIface.Name()) // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -594,6 +635,15 @@ func (r *router) acceptForwardRules() { } r.conn.InsertRule(iifRule) + oifExprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, @@ -604,36 +654,72 @@ func (r *router) acceptForwardRules() { Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityFilter, }, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 2, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Counter{}, - &expr.Verdict{Kind: expr.VerdictAccept}, - }, + Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } r.conn.InsertRule(oifRule) + + return nil +} + +func (r *router) removeAcceptForwardRules() error { + if r.filterTable == nil { + return nil + } + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + return r.removeAcceptForwardRulesNftables() + } + + return r.removeAcceptForwardRulesIptables(ipt) +} + +func (r *router) removeAcceptForwardRulesNftables() error { + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } // RemoveNatRule removes a nftables rule pair from nat chains @@ -658,7 +744,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed rules for %s", pair.Destination) + log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index bbf92f3be..c07111b4e 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) { Register: 1, Data: ifname(ifaceMock.Name()), }, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, ) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) @@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) { Register: 1, Data: ifname(ifaceMock.Name()), }, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, ) inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) @@ -314,6 +326,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) { ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") + t.Cleanup(func() { + require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule") + }) + // Check if the rule is in the internal map rule, ok := r.rules[ruleKey.GetRuleID()] assert.True(t, ok, "Rule not found in internal map") @@ -346,10 +362,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) { // Verify actual nftables rule content verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) - - // Clean up - err = r.DeleteRouteRule(ruleKey) - require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index e27fce439..8ce73655d 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -1,8 +1,11 @@ package id import ( + "crypto/sha256" + "encoding/hex" "fmt" "net/netip" + "strconv" "github.com/netbirdio/netbird/client/firewall/manager" ) @@ -21,5 +24,41 @@ func GenerateRouteRuleKey( dPort *manager.Port, action manager.Action, ) RuleID { - return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) + manager.SortPrefixes(sources) + + h := sha256.New() + + // Write all fields to the hasher, with delimiters + h.Write([]byte("sources:")) + for _, src := range sources { + h.Write([]byte(src.String())) + h.Write([]byte(",")) + } + + h.Write([]byte("destination:")) + h.Write([]byte(destination.String())) + + h.Write([]byte("proto:")) + h.Write([]byte(proto)) + + h.Write([]byte("sPort:")) + if sPort != nil { + h.Write([]byte(sPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("dPort:")) + if dPort != nil { + h.Write([]byte(dPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("action:")) + h.Write([]byte(strconv.Itoa(int(action)))) + hash := hex.EncodeToString(h.Sum(nil)) + + // prepend destination prefix to be able to identify the rule + return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16])) } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0d4ad2396..1b740388d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -82,8 +82,6 @@ type Conn struct { config ConnConfig statusRecorder *Status wgProxyFactory *wgproxy.Factory - wgProxyICE wgproxy.Proxy - wgProxyRelay wgproxy.Proxy signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager @@ -106,7 +104,8 @@ type Conn struct { beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc - endpointRelay *net.UDPAddr + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy // for reconnection operations iCEDisconnected chan bool @@ -257,8 +256,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE.Set(StatusConnected) - - defer conn.updateIceState(iceConnInfo) - if conn.currentConnPriority > priority { + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) return } conn.log.Infof("set ICE to active connection") - endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) - if err != nil { - return + var ( + ep *net.UDPAddr + wgProxy wgproxy.Proxy + err error + ) + if iceConnInfo.RelayedOnLocal { + wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) + if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + return + } + ep = wgProxy.EndpointAddr() + conn.wgProxyICE = wgProxy + } else { + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + if err != nil { + log.Errorf("failed to resolveUDPaddr") + conn.handleConfigurationFailure(err, nil) + return + } + ep = directEp } - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } conn.workerRelay.DisableWgWatcher() - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { - if wgProxy != nil { - if err := wgProxy.CloseConn(); err != nil { - conn.log.Warnf("Failed to close turn connection: %v", err) - } - } - conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Pause() + } + + if wgProxy != nil { + wgProxy.Work() + } + + if err = conn.configureWGEndpoint(ep); err != nil { + conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() - - if conn.wgProxyICE != nil { - if err := conn.wgProxyICE.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyICE = wgProxy - conn.currentConnPriority = priority - + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } @@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Tracef("ICE connection state changed to %s", newState) + if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + // switch back to relay connection - if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + if conn.isReadyToUpgrade() { conn.log.Debugf("ICE disconnected, set Relay to active connection") - err := conn.configureWGEndpoint(conn.endpointRelay) - if err != nil { + conn.wgProxyRelay.Work() + + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - select { - case conn.iCEDisconnected <- changed: - default: - } + conn.notifyReconnectLoopICEDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.ctx.Err() != nil { if err := rci.relayedConn.Close(); err != nil { - log.Warnf("failed to close unnecessary relayed connection: %v", err) + conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) } return } - conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay.Set(StatusConnected) + conn.log.Debugf("Relay connection has been established, setup the WireGuard") - wgProxy := conn.wgProxyFactory.GetProxy() - endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + wgProxy, err := conn.newProxy(rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } - conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.endpointRelay = endpointUdpAddr - conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) - - if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE.Get() == StatusConnected { - log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - return - } + if conn.iceP2PIsActive() { + conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + conn.wgProxyRelay = wgProxy + conn.statusRelay.Set(StatusConnected) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + return } - conn.connIDRelay = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { + wgProxy.Work() + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } - conn.log.Errorf("Failed to update wg peer configuration: %v", err) + conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } conn.workerRelay.EnableWgWatcher(conn.ctx) + wgConfigWorkaround() - - if conn.wgProxyRelay != nil { - if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyRelay = wgProxy conn.currentConnPriority = connPriorityRelay - + conn.statusRelay.Set(StatusConnected) + conn.wgProxyRelay = wgProxy + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } @@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - log.Debugf("relay connection is disconnected") + conn.log.Debugf("relay connection is disconnected") if conn.currentConnPriority == connPriorityRelay { - log.Debugf("clean up WireGuard config") - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + conn.log.Debugf("clean up WireGuard config") + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } } if conn.wgProxyRelay != nil { - conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil } changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - - select { - case conn.relayDisconnected <- changed: - default: - } + conn.notifyReconnectLoopRelayDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) } } @@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, ip); err != nil { + return err + } + } + return nil +} + func (conn *Conn) freeUpConnID() { if conn.connIDRelay != "" { for _, hook := range conn.afterRemovePeerHooks { @@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() { } } -func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { - if !iceConnInfo.RelayedOnLocal { - return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil - } - conn.log.Debugf("setup ice turn connection") +func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { + conn.log.Debugf("setup proxied WireGuard connection") wgProxy := conn.wgProxyFactory.GetProxy() - ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) - if err != nil { + if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) - if errClose := wgProxy.CloseConn(); errClose != nil { - conn.log.Warnf("failed to close turn proxy connection: %v", errClose) - } - return nil, nil, err + return nil, err + } + return wgProxy, nil +} + +func (conn *Conn) isReadyToUpgrade() bool { + return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay +} + +func (conn *Conn) iceP2PIsActive() bool { + return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +} + +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + +func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { + select { + case conn.relayDisconnected <- changed: + default: + } +} + +func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { + select { + case conn.iCEDisconnected <- changed: + default: + } +} + +func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if wgProxy != nil { + if ierr := wgProxy.CloseConn(); ierr != nil { + conn.log.Warnf("Failed to close wg proxy: %v", ierr) + } + } + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Work() } - return ep, wgProxy, nil } func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 27ede3ef1..e850f4533 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { @@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return packetConn, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { localhost := net.ParseIP("127.0.0.1") payload := gopacket.Payload(data) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840..b6a8ac452 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -4,8 +4,13 @@ package ebpf import ( "context" + "errors" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,20 +18,55 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread + ctx context.Context + cancel context.CancelFunc + + wgEndpointAddr *net.UDPAddr + + pausedMu sync.Mutex + paused bool + isStarted bool } -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) - +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() - return nil, fmt.Errorf("add turn conn: %w", err) + return fmt.Errorf("add turn conn: %w", err) } - e.remoteConn = remoteConn - e.cancel = cancel - return addr, err + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointAddr = addr + return err +} + +func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { + return p.wgEndpointAddr +} + +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map @@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + + buf := make([]byte, 1500) + for { + n, err := p.readFromRemote(ctx, buf) + if err != nil { + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} + +func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !errors.Is(err, io.EOF) { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + } + return 0, err + } + return n, nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd1..558121cdd 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -7,6 +7,9 @@ import ( // Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) error + EndpointAddr() *net.UDPAddr + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go index b09e6be55..b88ff3f83 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/internal/wgproxy/proxy_test.go @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d8..f73500717 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err + return err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn - return p.localConn.LocalAddr(), err + return err +} + +func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { + if p.localConn == nil { + return nil + } + endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) + return endpointUdpAddr +} + +// Work starts the proxy or resumes it if it was paused +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +// Pause pauses the proxy from receiving data from the remote peer +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +// if the proxy is paused it will drain the remote conn and drop the packets +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for { n, err := p.remoteConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to write to wg interface conn: %s", err) diff --git a/client/ui/Info.plist b/client/ui/Info.plist deleted file mode 100644 index 8441110b9..000000000 --- a/client/ui/Info.plist +++ /dev/null @@ -1,12 +0,0 @@ - - - - - CFBundleExecutable - netbird-ui - CFBundleIconFile - Netbird - LSUIElement - 1 - - diff --git a/client/ui/netbird-ui.rb.tmpl b/client/ui/netbird-ui.rb.tmpl index 9efddd750..06971909d 100644 --- a/client/ui/netbird-ui.rb.tmpl +++ b/client/ui/netbird-ui.rb.tmpl @@ -8,11 +8,11 @@ cask "{{ $projectName }}" do if Hardware::CPU.intel? url "{{ $amdURL }}" sha256 "{{ crypto.SHA256 $amdFileBytes }}" - app "netbird_ui_darwin_amd64", target: "Netbird UI.app" + app "netbird_ui_darwin", target: "Netbird UI.app" else url "{{ $armURL }}" sha256 "{{ crypto.SHA256 $armFileBytes }}" - app "netbird_ui_darwin_arm64", target: "Netbird UI.app" + app "netbird_ui_darwin", target: "Netbird UI.app" end depends_on formula: "netbird" @@ -36,4 +36,4 @@ cask "{{ $projectName }}" do name "Netbird UI" desc "Netbird UI Client" homepage "https://www.netbird.io/" -end \ No newline at end of file +end diff --git a/go.mod b/go.mod index 552c9339d..a6b83794d 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.24.0 - golang.org/x/sys v0.21.0 + golang.org/x/crypto v0.28.0 + golang.org/x/sys v0.26.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -38,6 +38,7 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 + github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/v3 v3.1.1 github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.4 @@ -45,7 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/google/nftables v0.2.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -55,12 +56,12 @@ require ( github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 - github.com/mdlayher/socket v0.4.1 + github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible @@ -90,10 +91,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.26.0 + golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.7.0 - golang.org/x/term v0.21.0 + golang.org/x/sync v0.8.0 + golang.org/x/term v0.25.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -134,7 +135,6 @@ require ( github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect @@ -222,7 +222,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.19.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index ff721504a..412542d5e 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= +github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= +github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811- github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= @@ -780,8 +780,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -877,8 +877,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -907,8 +907,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -980,8 +980,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -989,8 +989,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1005,8 +1005,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2c5c35d53..16b2364fb 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -873,7 +873,7 @@ services: zitadel: restart: 'always' networks: [netbird] - image: 'ghcr.io/zitadel/zitadel:v2.54.3' + image: 'ghcr.io/zitadel/zitadel:v2.54.10' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' env_file: - ./zitadel.env diff --git a/management/server/account.go b/management/server/account.go index d99664750..8e933404e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "errors" "fmt" "hash/crc32" "math/rand" @@ -44,12 +45,15 @@ import ( ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -177,6 +181,8 @@ type DefaultAccountManager struct { dnsDomain string peerLoginExpiry Scheduler + peerInactivityExpiry Scheduler + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool @@ -194,6 +200,13 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. + PeerInactivityExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements RegularUsersViewBlocked bool @@ -224,6 +237,9 @@ func (s *Settings) Copy() *Settings { GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -605,6 +621,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { return peers } +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*nbpeer.Peer { var peers []*nbpeer.Peer @@ -971,6 +1041,7 @@ func BuildManager( dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), + peerInactivityExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, metrics: metrics, @@ -1099,6 +1170,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.checkAndSchedulePeerLoginExpiration(ctx, account) } + err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, err + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1109,6 +1185,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) + } + + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + + return nil +} + func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -1144,6 +1240,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context } } +// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found +func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { + return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + log.Errorf("failed getting account %s expiring peers", account.Id) + return account.GetNextInactivePeerExpiration() + } + + expiredPeers := account.GetInactivePeers() + var peerIDs []string + for _, peer := range expiredPeers { + peerIDs = append(peerIDs, peer.ID) + } + + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextInactivePeerExpiration() + } + + return account.GetNextInactivePeerExpiration() + } +} + +// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { + am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) + if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) + } +} + // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { @@ -1284,7 +1417,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { return "", err } return account.Id, nil @@ -1299,28 +1432,39 @@ func isNil(i idp.Manager) bool { } // addAccountIDToIDPAppMeta update user's app metadata in idp manager -func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { + accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + if err != nil { + return err + } + cachedAccount := &Account{ + Id: accountID, + Users: make(map[string]*User), + } + for _, user := range accountUsers { + cachedAccount.Users[user.Id] = user + } // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, account) + user, err := am.lookupUserInCache(ctx, userID, cachedAccount) if err != nil { return err } - if user != nil && user.AppMetadata.WTAccountID == account.Id { + if user != nil && user.AppMetadata.WTAccountID == accountID { // it was already set, so we skip the unnecessary update log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", - account.Id, userID) + accountID, userID) return nil } - err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) + err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update - _, err = am.refreshCache(ctx, account.Id) + _, err = am.refreshCache(ctx, accountID) if err != nil { return err } @@ -1544,48 +1688,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) } -// updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, +// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { - - if claims.Domain != "" { - account.IsDomainPrimaryAccount = primaryDomain - - lowerDomain := strings.ToLower(claims.Domain) - userObj := account.Users[claims.UserId] - if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { - account.Domain = lowerDomain - } - // prevent updating category for different domain until admin logs in - if account.Domain == lowerDomain { - account.DomainCategory = claims.DomainCategory - } - } else { + if claims.Domain == "" { log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + return nil } - err := am.Store.SaveAccount(ctx, account) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlockAccount() + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) if err != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err } - return nil + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return nil + } + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting user: %v", err) + return err + } + + newDomain := accountDomain + newCategoty := domainCategory + + lowerDomain := strings.ToLower(claims.Domain) + if accountDomain != lowerDomain && user.HasAdminPower() { + newDomain = lowerDomain + } + + if accountDomain == lowerDomain { + newCategoty = claims.DomainCategory + } + + return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. +// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, +// we compare the account's ID with the domain account ID, and if they don't match, we set the account as +// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain +// was previously unclassified or classified as public so N users that logged int that time, has they own account +// and peers that shouldn't be lost. func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, - existingAcc *Account, - primaryDomain bool, + userAccountID string, + domainAccountID string, claims jwtclaims.AuthorizationClaims, ) error { - err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) + primaryDomain := domainAccountID == "" || userAccountID == domainAccountID + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) if err != nil { return err } @@ -1593,44 +1758,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount( return nil } -// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, +// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } - var ( - account *Account - err error - ) + lowerDomain := strings.ToLower(claims.Domain) - // if domain already has a primary account, add regular user - if domainAcc != nil { - account = domainAcc - account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } else { - account, err = am.newAccount(ctx, claims.UserId, lowerDomain) - if err != nil { - return nil, err - } - err = am.updateAccountDomainAttributes(ctx, account, claims, true) - if err != nil { - return nil, err - } - } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) + newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { - return nil, err + return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + newAccount.Domain = lowerDomain + newAccount.DomainCategory = claims.DomainCategory + newAccount.IsDomainPrimaryAccount = true - return account, nil + err = am.Store.SaveAccount(ctx, newAccount) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + + return newAccount.Id, nil +} + +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + + usersMap := make(map[string]*User) + usersMap[claims.UserId] = NewRegularUser(claims.UserId) + err := am.Store.SaveUsers(domainAccountID, usersMap) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) + + return domainAccountID, nil } // redeemInvite checks whether user has been invited and redeems the invite @@ -1774,7 +1953,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s // GetAccountIDFromToken returns an account ID associated with this token. func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return "", "", fmt.Errorf("user ID is empty") + return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1962,16 +2141,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. +// if domain is not private or domain is invalid, it will return the account ID by user ID. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // // Use cases: // -// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) +// New user + New account + New domain -> create account, user role = owner (if private domain, index domain) // -// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) +// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin) // -// New user + New account + Existing Public Domain -> create account, user role = admin +// New user + New account + Existing Public Domain -> create account, user role = owner // // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // @@ -1981,98 +2161,124 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + if claims.UserId == "" { - return "", fmt.Errorf("user ID is empty") + return "", errors.New(emptyUserID) } - // if Account ID is part of the claims - // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - if claims.AccountId != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) - } - return claims.AccountId, nil - } return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) - } else if claims.AccountId != "" { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err != nil { - return "", err - } - - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) - } - - domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - - if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { - return userAccountID, nil - } } - start := time.Now() - unlock := am.Store.AcquireGlobalLock(ctx) - defer unlock() - log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + if claims.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + if cancel != nil { + defer cancel() + } if err != nil { - // if NotFound we are good to continue, otherwise return error - e, ok := status.FromError(err) - if !ok || e.Type() != status.NotFound { - return "", err - } + return "", err } userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) - defer unlockAccount() - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) - if err != nil { - return "", err - } - // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, - // we compare the account's ID with the domain account ID, and if they don't match, we set the account as - // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain - // was previously unclassified or classified as public so N users that logged int that time, has they own account - // and peers that shouldn't be lost. - primaryDomain := domainAccountID == "" || account.Id == domainAccountID - if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { - return "", err - } - - return account.Id, nil - } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - var domainAccount *Account - if domainAccountID != "" { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return "", err - } - } - - account, err := am.handleNewUserAccount(ctx, domainAccount, claims) - if err != nil { - return "", err - } - return account.Id, nil - } else { - // other error + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } + + if userAccountID != "" { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { + return "", err + } + + return userAccountID, nil + } + + if domainAccountID != "" { + return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + } + + return am.addNewPrivateAccount(ctx, domainAccountID, claims) +} +func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + if domainAccountID != "" { + return domainAccountID, nil, nil + } + + log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain) + cancel := am.Store.AcquireGlobalLock(ctx) + + // check again if the domain has a primary account because of simultaneous requests + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + cancel() + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + return domainAccountID, cancel, nil +} + +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + } + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return "", err + } + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return claims.AccountId, nil + } + + // We checked if the domain has a primary account already + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", err + } + + err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) + if err != nil { + return "", err + } + + return claims.AccountId, nil +} + +func handleNotFound(err error) error { + if err == nil { + return nil + } + + e, ok := status.FromError(err) + if !ok || e.Type() != status.NotFound { + return err + } + return nil +} + +func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { + return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { @@ -2338,6 +2544,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, }, } diff --git a/management/server/account_test.go b/management/server/account_test.go index d94110c65..3c3fcebc6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims - type test struct { + var ( + publicDomain = "public.com" + privateDomain = "private.com" + unknownDomain = "unknown.com" + ) + + defaultInitAccount := initUserParams{ + Domain: publicDomain, + UserId: "defaultUser", + } + + initUnknown := defaultInitAccount + initUnknown.DomainCategory = UnknownCategory + initUnknown.Domain = unknownDomain + + privateInitAccount := defaultInitAccount + privateInitAccount.Domain = privateDomain + privateInitAccount.DomainCategory = PrivateCategory + + testCases := []struct { name string inputClaims jwtclaims.AuthorizationClaims inputInitUserParams initUserParams @@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { expectedPrimaryDomainStatus bool expectedCreatedBy string expectedUsers []string - } - - var ( - publicDomain = "public.com" - privateDomain = "private.com" - unknownDomain = "unknown.com" - ) - - defaultInitAccount := initUserParams{ - Domain: publicDomain, - UserId: "defaultUser", - } - - testCase1 := test{ - name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: publicDomain, - UserId: "pub-domain-user", - DomainCategory: PublicCategory, + }{ + { + name: "New User With Public Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: publicDomain, + UserId: "pub-domain-user", + DomainCategory: PublicCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomainCategory: "", + expectedDomain: publicDomain, + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pub-domain-user", + expectedUsers: []string{"pub-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomainCategory: "", - expectedDomain: publicDomain, - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pub-domain-user", - expectedUsers: []string{"pub-domain-user"}, - } - - initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory - initUnknown.Domain = unknownDomain - - testCase2 := test{ - name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: unknownDomain, - UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + { + name: "New User With Unknown Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: unknownDomain, + UserId: "unknown-domain-user", + DomainCategory: UnknownCategory, + }, + inputInitUserParams: initUnknown, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: unknownDomain, + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "unknown-domain-user", + expectedUsers: []string{"unknown-domain-user"}, }, - inputInitUserParams: initUnknown, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: unknownDomain, - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "unknown-domain-user", - expectedUsers: []string{"unknown-domain-user"}, - } - - testCase3 := test{ - name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New User With Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - - privateInitAccount := defaultInitAccount - privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory - - testCase4 := test{ - name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New Regular User With Existing Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "new-pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputUpdateAttrs: true, + inputInitUserParams: privateInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleUser, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, }, - inputUpdateAttrs: true, - inputInitUserParams: privateInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, - } - - testCase5 := test{ - name: "Existing User With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing User With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase6 := test{ - name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing Account Id With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputUpdateClaimAccount: true, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputUpdateClaimAccount: true, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase7 := test{ - name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: "", - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "User With Private Category And Empty Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: "", + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: "", + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: "", - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, } - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} { + for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -2025,6 +2019,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestAccount_GetInactivePeers(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + testCases := []test{ + { + name: "Peers with inactivity expiration disabled, no expired peers", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + }, + "peer-2": { + InactivityExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Two peers expired", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-30 * time.Minute), + UserID: userID, + }, + "peer-2": { + ID: "peer-2", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-2 * time.Hour), + UserID: userID, + }, + "peer-3": { + ID: "peer-3", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-1 * time.Hour), + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + "peer-2": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + } + + expiredPeers := account.GetInactivePeers() + assert.Len(t, expiredPeers, len(testCase.expectedPeers)) + for _, peer := range expiredPeers { + if _, ok := testCase.expectedPeers[peer.ID]; !ok { + t.Fatalf("expected to have peer %s expired", peer.ID) + } + } + }) + } +} + func TestAccount_GetPeersWithExpiration(t *testing.T) { type test struct { name string @@ -2094,6 +2172,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { } } +func TestAccount_GetPeersWithInactivity(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + + testCases := []test{ + { + name: "No account peers, no peers with expiration", + peers: map[string]*nbpeer.Peer{}, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration disabled, no peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration enabled, return peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + } + + actual := account.GetPeersWithInactivity() + assert.Len(t, actual, len(testCase.expectedPeers)) + if len(testCase.expectedPeers) > 0 { + for k := range testCase.expectedPeers { + contains := false + for _, peer := range actual { + if k == peer.ID { + contains = true + } + } + assert.True(t, contains) + } + } + }) + } +} + func TestAccount_GetNextPeerExpiration(t *testing.T) { type test struct { name string @@ -2255,6 +2402,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } } +func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expiration time.Duration + expirationEnabled bool + expectedNextRun bool + expectedNextExpiration time.Duration + } + + expectedNextExpiration := time.Minute + testCases := []test{ + { + name: "No peers, no expiration", + peers: map[string]*nbpeer.Peer{}, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "No connected peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Connected peers with disabled expiration, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Expired peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "To be expired peer, return expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + LoginExpired: false, + LastSeen: time.Now().Add(-1 * time.Second), + }, + InactivityExpirationEnabled: true, + LastLogin: time.Now().UTC(), + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Minute, + expirationEnabled: false, + expectedNextRun: true, + expectedNextExpiration: expectedNextExpiration, + }, + { + name: "Peers added with setup keys, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + } + + expiration, ok := account.GetNextInactivePeerExpiration() + assert.Equal(t, testCase.expectedNextRun, ok) + if testCase.expectedNextRun { + assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration) + } else { + assert.Equal(t, expiration, testCase.expectedNextExpiration) + } + }) + } +} + func TestAccount_SetJWTGroups(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4ee57f181..188494241 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -139,6 +139,13 @@ const ( PostureCheckUpdated Activity = 61 // PostureCheckDeleted indicates that the user deleted a posture check PostureCheckDeleted Activity = 62 + + PeerInactivityExpirationEnabled Activity = 63 + PeerInactivityExpirationDisabled Activity = 64 + + AccountPeerInactivityExpirationEnabled Activity = 65 + AccountPeerInactivityExpirationDisabled Activity = 66 + AccountPeerInactivityExpirationDurationUpdated Activity = 67 ) var activityMap = map[Activity]Code{ @@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{ PostureCheckCreated: {"Posture check created", "posture.check.created"}, PostureCheckUpdated: {"Posture check updated", "posture.check.updated"}, PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"}, + + PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"}, + PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"}, + + AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, + AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, + AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/file_store.go b/management/server/file_store.go index df3e9bb77..561e133ce 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) { account.Settings = &Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: DefaultPeerLoginExpiration, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, } } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 91caa1512..4d4066de4 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), } if req.Settings.Extra != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index fd0343e97..9d5148248 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,14 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + peer_inactivity_expiration_enabled: + description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + type: boolean + example: true + peer_inactivity_expiration: + description: Period of time of inactivity after which peer session expires (seconds). + type: integer + example: 43200 regular_users_view_blocked: description: Allows blocking regular users from viewing parts of the system. type: boolean @@ -81,6 +89,8 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - peer_inactivity_expiration_enabled + - peer_inactivity_expiration - regular_users_view_blocked AccountExtraSettings: type: object @@ -243,6 +253,9 @@ components: login_expiration_enabled: type: boolean example: false + inactivity_expiration_enabled: + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -251,6 +264,7 @@ components: - name - ssh_enabled - login_expiration_enabled + - inactivity_expiration_enabled Peer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -327,6 +341,10 @@ components: type: string format: date-time example: "2023-05-05T09:00:35.477782Z" + inactivity_expiration_enabled: + description: Indicates whether peer inactivity expiration has been enabled or not + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -354,6 +372,7 @@ components: - last_seen - login_expiration_enabled - login_expired + - inactivity_expiration_enabled - os - ssh_enabled - user_id diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 570ec03c5..e2870d5d8 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -220,6 +220,12 @@ type AccountSettings struct { // JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups. JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"` + // PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds). + PeerInactivityExpiration int `json:"peer_inactivity_expiration"` + + // PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"` + // PeerLoginExpiration Period of time after which peer login expires (seconds). PeerLoginExpiration int `json:"peer_login_expiration"` @@ -538,6 +544,9 @@ type Peer struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -613,6 +622,9 @@ type PeerBatch struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string // PeerRequest defines model for PeerRequest. type PeerRequest struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` - LoginExpirationEnabled bool `json:"login_expiration_enabled"` - Name string `json:"name"` - SshEnabled bool `json:"ssh_enabled"` + ApprovalRequired *bool `json:"approval_required,omitempty"` + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + LoginExpirationEnabled bool `json:"login_expiration_enabled"` + Name string `json:"name"` + SshEnabled bool `json:"ssh_enabled"` } // PersonalAccessToken defines model for PersonalAccessToken. diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 4fbbc3106..a5856a0e4 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -7,6 +7,8 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -14,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) // PeersHandler is a handler that returns peers of the account @@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, SSHEnabled: req.SshEnabled, Name: req.Name, LoginExpirationEnabled: req.LoginExpirationEnabled, + + InactivityExpirationEnabled: req.InactivityExpirationEnabled, } if req.ApprovalRequired != nil { @@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD } return &api.Peer{ - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, - LoginExpired: peer.Status.LoginExpired, - ApprovalRequired: !approved, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.LastLogin, + LoginExpired: peer.Status.LoginExpired, + ApprovalRequired: !approved, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } @@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, SerialNumber: peer.Meta.SystemSerialNumber, + + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } diff --git a/management/server/peer.go b/management/server/peer.go index a85e8c6b2..5246e1fdc 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -111,6 +111,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + if err != nil { + return err + } + + if peer.AddedWithSSOLogin() { + if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, account) + } + + if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + + if expired { + // we need to update other peers because when peer login expires all other peers are notified to disconnect from + // the expired one. Here we notify them that connection is now allowed again. + am.updateAccountPeers(ctx, account) + } + + return nil +} + +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -139,25 +164,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { - return err + return false, err } - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) - } - - if oldStatus.LoginExpired { - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) - } - - return nil + return oldStatus.LoginExpired, nil } -// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. +// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -222,6 +237,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + + if !peer.AddedWithSSOLogin() { + return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + } + + peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + + event := activity.PeerInactivityExpirationEnabled + if !update.InactivityExpirationEnabled { + event = activity.PeerInactivityExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + account.UpdatePeer(peer) err = am.Store.SaveAccount(ctx, account) @@ -454,23 +488,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s registrationTime := time.Now().UTC() newPeer = &nbpeer.Peer{ - ID: xid.New().String(), - AccountID: accountID, - Key: peer.Key, - SetupKey: upperKey, - IP: freeIP, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: freeLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + InactivityExpirationEnabled: addedByUser, } opEvent.TargetID = newPeer.ID opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 40b3d71d9..ef96bce7d 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -38,6 +38,8 @@ type Peer struct { // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin LoginExpirationEnabled bool `diff:"-"` + + InactivityExpirationEnabled bool `diff:"-"` // LastLogin the time when peer performed last login operation LastLogin time.Time `diff:"-"` // CreatedAt records the time the peer was created @@ -187,6 +189,7 @@ func (p *Peer) Copy() *Peer { CreatedAt: p.CreatedAt, Ephemeral: p.Ephemeral, Location: p.Location, + InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } @@ -219,6 +222,22 @@ func (p *Peer) MarkLoginExpired(expired bool) { p.Status = newStatus } +// SessionExpired indicates whether the peer's session has expired or not. +// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired. +// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired). +// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. +// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +// Only peers added by interactive SSO login can be expired. +func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) { + if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected { + return false, 0 + } + expiresAt := p.Status.LastSeen.Add(expiresIn) + now := time.Now() + timeLeft := expiresAt.Sub(now) + return timeLeft <= 0, timeLeft +} + // LoginExpired indicates whether the peer's login has expired or not. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 170e3ba56..b8506bc50 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) { } } +func TestPeer_SessionExpired(t *testing.T) { + tt := []struct { + name string + expirationEnabled bool + lastLogin time.Time + connected bool + expected bool + accountSettings *Settings + }{ + { + name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire", + expirationEnabled: false, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Hour, + }, + expected: false, + }, + { + name: "Peer Inactivity Should Expire", + expirationEnabled: true, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: true, + }, + { + name: "Peer Inactivity Should Not Expire", + expirationEnabled: true, + connected: true, + lastLogin: time.Now().UTC(), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: false, + }, + } + + for _, c := range tt { + t.Run(c.name, func(t *testing.T) { + peerStatus := &nbpeer.PeerStatus{ + Connected: c.connected, + } + peer := &nbpeer.Peer{ + InactivityExpirationEnabled: c.expirationEnabled, + LastLogin: c.lastLogin, + Status: peerStatus, + UserID: userID, + } + + expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration) + assert.Equal(t, expired, c.expected) + }) + } +} + func TestAccountManager_GetNetworkMap(t *testing.T) { manager, err := createManager(t) if err != nil { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 615203bee..47395f511 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. return nil } +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + accountCopy := Account{ + Domain: domain, + DomainCategory: category, + IsDomainPrimaryAccount: isPrimaryDomain, + } + + fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} + result := s.db.WithContext(ctx).Model(&Account{}). + Select(fieldsToUpdate). + Where(idQueryCondition, accountID). + Updates(&accountCopy) + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account %s", accountID) + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } +func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + var users []*User + result := s.db.Find(&users, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting users from store") + } + + return users, nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) @@ -1117,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). - Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) + // TODO: This fix is accepted for now, but if we need to handle this more frequently + // we may need to reconsider changing the types. + query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + if s.storeEngine == PostgresStoreEngine { + query = query.Order("json_array_length(peers::json) DESC") + } else { + query = query.Order("json_array_length(peers) DESC") + } + + result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 06e118fd2..000eb1b11 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1191,3 +1191,76 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { }) assert.NoError(t, err) } + +func TestSqlite_GetAccoundUsers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + users, err := store.GetAccountUsers(context.Background(), accountID) + require.NoError(t, err) + require.Len(t, users, len(account.Users)) +} + +func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + t.Run("Should update attributes with public domain", func(t *testing.T) { + require.NoError(t, err) + domain := "example.com" + category := "public" + IsDomainPrimaryAccount := false + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should update attributes with private domain", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should fail when account does not exist", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + require.Error(t, err) + }) + +} + +func TestSqlite_GetGroupByName(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + require.NoError(t, err) + require.Equal(t, "All", group.Name) +} diff --git a/management/server/store.go b/management/server/store.go index d914bb8f7..131fd8aaa 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -58,9 +58,11 @@ type Store interface { GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error + UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) + GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error diff --git a/management/server/user.go b/management/server/user.go index e40fc67eb..a14dcde09 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -20,10 +20,11 @@ import ( ) const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -42,6 +43,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleAdmin case "user": return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin default: return UserRoleUnknown } diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go index 13799713a..4dc98a0e0 100644 --- a/relay/metrics/realy.go +++ b/relay/metrics/realy.go @@ -16,8 +16,10 @@ const ( type Metrics struct { metric.Meter - TransferBytesSent metric.Int64Counter - TransferBytesRecv metric.Int64Counter + TransferBytesSent metric.Int64Counter + TransferBytesRecv metric.Int64Counter + AuthenticationTime metric.Float64Histogram + PeerStoreTime metric.Float64Histogram peers metric.Int64UpDownCounter peerActivityChan chan string @@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { return nil, err } + authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + + peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + m := &Metrics{ - Meter: meter, - TransferBytesSent: bytesSent, - TransferBytesRecv: bytesRecv, - peers: peers, + Meter: meter, + TransferBytesSent: bytesSent, + TransferBytesRecv: bytesRecv, + AuthenticationTime: authTime, + PeerStoreTime: peerStoreTime, + peers: peers, ctx: ctx, peerActivityChan: make(chan string, 10), @@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) { m.peerLastActive[id] = time.Time{} } +// RecordAuthenticationTime measures the time taken for peer authentication +func (m *Metrics) RecordAuthenticationTime(duration time.Duration) { + m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + +// RecordPeerStoreTime measures the time to store the peer in map +func (m *Metrics) RecordPeerStoreTime(duration time.Duration) { + m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + // PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections func (m *Metrics) PeerDisconnected(id string) { m.peers.Add(m.ctx, -1) @@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() { } } } + +func getStandardBucketBoundaries() []float64 { + return []float64{ + 0.1, + 0.5, + 1, + 5, + 10, + 50, + 100, + 500, + 1000, + 5000, + 10000, + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go new file mode 100644 index 000000000..0257300f8 --- /dev/null +++ b/relay/server/handshake.go @@ -0,0 +1,153 @@ +package server + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/messages" + //nolint:staticcheck + "github.com/netbirdio/netbird/relay/messages/address" + //nolint:staticcheck + authmsg "github.com/netbirdio/netbird/relay/messages/auth" +) + +// preparedMsg contains the marshalled success response messages +type preparedMsg struct { + responseHelloMsg []byte + responseAuthMsg []byte +} + +func newPreparedMsg(instanceURL string) (*preparedMsg, error) { + rhm, err := marshalResponseHelloMsg(instanceURL) + if err != nil { + return nil, err + } + + ram, err := messages.MarshalAuthResponse(instanceURL) + if err != nil { + return nil, fmt.Errorf("failed to marshal auth response msg: %w", err) + } + + return &preparedMsg{ + responseHelloMsg: rhm, + responseAuthMsg: ram, + }, nil +} + +func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { + addr := &address.Address{URL: instanceURL} + addrData, err := addr.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal response address: %w", err) + } + + //nolint:staticcheck + responseMsg, err := messages.MarshalHelloResponse(addrData) + if err != nil { + return nil, fmt.Errorf("failed to marshal hello response: %w", err) + } + return responseMsg, nil +} + +type handshake struct { + conn net.Conn + validator auth.Validator + preparedMsg *preparedMsg + + handshakeMethodAuth bool + peerID string +} + +func (h *handshake) handshakeReceive() ([]byte, error) { + buf := make([]byte, messages.MaxHandshakeSize) + n, err := h.conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) + } + + _, err = messages.ValidateVersion(buf[:n]) + if err != nil { + return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) + } + + msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) + } + + var ( + bytePeerID []byte + peerID string + ) + switch msgType { + //nolint:staticcheck + case messages.MsgTypeHello: + bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) + case messages.MsgTypeAuth: + h.handshakeMethodAuth = true + bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) + } + if err != nil { + return nil, err + } + h.peerID = peerID + return bytePeerID, nil +} + +func (h *handshake) handshakeResponse() error { + var responseMsg []byte + if h.handshakeMethodAuth { + responseMsg = h.preparedMsg.responseAuthMsg + } else { + responseMsg = h.preparedMsg.responseHelloMsg + } + + if _, err := h.conn.Write(responseMsg); err != nil { + return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) + } + + return nil +} + +func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { + //nolint:staticcheck + rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) + + authMsg, err := authmsg.UnmarshalMsg(authData) + if err != nil { + return nil, "", fmt.Errorf("unmarshal auth message: %w", err) + } + + //nolint:staticcheck + if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} + +func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { + rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + + if err := h.validator.Validate(authPayload); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} diff --git a/relay/server/relay.go b/relay/server/relay.go index 76c01a697..6cd8506ae 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -7,16 +7,13 @@ import ( "net/url" "strings" "sync" + "time" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/auth" - "github.com/netbirdio/netbird/relay/messages" //nolint:staticcheck - "github.com/netbirdio/netbird/relay/messages/address" - //nolint:staticcheck - authmsg "github.com/netbirdio/netbird/relay/messages/auth" "github.com/netbirdio/netbird/relay/metrics" ) @@ -28,6 +25,7 @@ type Relay struct { store *Store instanceURL string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida return nil, fmt.Errorf("get instance URL: %v", err) } + r.preparedMsg, err = newPreparedMsg(r.instanceURL) + if err != nil { + metricsCancel() + return nil, fmt.Errorf("prepare message: %v", err) + } + return r, nil } @@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { // Accept start to handle a new peer connection func (r *Relay) Accept(conn net.Conn) { + acceptTime := time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() if r.closed { return } - peerID, err := r.handshake(conn) + h := handshake{ + conn: conn, + validator: r.validator, + preparedMsg: r.preparedMsg, + } + peerID, err := h.handshakeReceive() if err != nil { log.Errorf("failed to handshake: %s", err) - cErr := conn.Close() - if cErr != nil { + if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } return @@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) { peer := NewPeer(r.metrics, peerID, conn, r.store) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) + storeTime := time.Now() r.store.AddPeer(peer) + r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() @@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) { peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) }() + + if err := h.handshakeResponse(); err != nil { + log.Errorf("failed to send handshake response, close peer: %s", err) + peer.Close() + } + r.metrics.RecordAuthenticationTime(time.Since(acceptTime)) } // Shutdown closes the relay server @@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } - -func (r *Relay) handshake(conn net.Conn) ([]byte, error) { - buf := make([]byte, messages.MaxHandshakeSize) - n, err := conn.Read(buf) - if err != nil { - return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err) - } - - _, err = messages.ValidateVersion(buf[:n]) - if err != nil { - return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err) - } - - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) - if err != nil { - return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) - } - - var ( - responseMsg []byte - peerID []byte - ) - switch msgType { - //nolint:staticcheck - case messages.MsgTypeHello: - peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - case messages.MsgTypeAuth: - peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - default: - return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) - } - if err != nil { - return nil, err - } - - _, err = conn.Write(responseMsg) - if err != nil { - return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - return peerID, nil -} - -func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) { - //nolint:staticcheck - rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr) - - authMsg, err := authmsg.UnmarshalMsg(authData) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal auth message: %w", err) - } - - //nolint:staticcheck - if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) - } - - addr := &address.Address{URL: r.instanceURL} - addrData, err := addr.Marshal() - if err != nil { - return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) - } - - //nolint:staticcheck - responseMsg, err := messages.MarshalHelloResponse(addrData) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) - } - return rawPeerID, responseMsg, nil -} - -func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) { - rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - - if err := r.validator.Validate(authPayload); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) - } - - responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) - } - - return rawPeerID, responseMsg, nil -} diff --git a/release_files/install.sh b/release_files/install.sh index b7a6c08f9..b0fec2733 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -86,7 +86,7 @@ download_release_binary() { # Unzip the app and move to INSTALL_DIR unzip -q -o "$BINARY_NAME" - mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" + mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" else ${SUDO} mkdir -p "$INSTALL_DIR" tar -xzvf "$BINARY_NAME" diff --git a/signal/server/signal.go b/signal/server/signal.go index 63cc43bd7..305fd052b 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -6,6 +6,7 @@ import ( "io" "time" + "github.com/netbirdio/signal-dispatcher/dispatcher" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -13,8 +14,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/netbirdio/signal-dispatcher/dispatcher" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/proto" diff --git a/util/net/net.go b/util/net/net.go index 61b47dbe7..035d7552b 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,7 +11,8 @@ import ( const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 + NetbirdFwmark = 0x1BD00 + PreroutingFwmark = 0x1BD01 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" )