From 208a2b7169dbdb59f983d8478754b06af5061e51 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 10 Oct 2024 14:14:56 +0200 Subject: [PATCH 01/16] Add billing user role (#2714) --- management/server/user.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/management/server/user.go b/management/server/user.go index 38a8ac0c4..71608ef20 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -19,10 +19,11 @@ import ( ) const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -41,6 +42,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleAdmin case "user": return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin default: return UserRoleUnknown } From 09bdd271f10fa80f42424ffdb14deb1db60e55a9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:54:34 +0200 Subject: [PATCH 02/16] [client] Improve route acl (#2705) - Update nftables library to v0.2.0 - Mark traffic that was originally destined for local and applies the input rules in the forward chain if said traffic was redirected (e.g. by Docker) - Add nft rules to internal map only if flush was successful - Improve error message if handle is 0 (= not found or hasn't been refreshed) - Add debug logging when route rules are added - Replace nftables userdata (rule ID) with a rule hash --- client/firewall/iptables/acl_linux.go | 57 +++++++- client/firewall/iptables/manager_linux.go | 2 +- client/firewall/iptables/router_linux.go | 15 ++- client/firewall/manager/firewall.go | 6 +- client/firewall/nftables/acl_linux.go | 124 +++++++++++++++--- client/firewall/nftables/router_linux.go | 35 +++-- client/firewall/nftables/router_linux_test.go | 8 +- client/internal/acl/id/id.go | 41 +++++- go.mod | 18 +-- go.sum | 32 ++--- util/net/net.go | 3 +- 11 files changed, 267 insertions(+), 74 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c6a96a876..c271e592d 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -21,13 +22,19 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type entry struct { + spec []string + position int +} + type aclManager struct { iptablesClient *iptables.IPTables wgIface iFaceMapper routingFwChainName string - entries map[string][][]string - ipsetStore *ipsetStore + entries map[string][][]string + optionalEntries map[string][]entry + ipsetStore *ipsetStore } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi wgIface: wgIface, routingFwChainName: routingFwChainName, - entries: make(map[string][][]string), - ipsetStore: newIpsetStore(), + entries: make(map[string][][]string), + optionalEntries: make(map[string][]entry), + ipsetStore: newIpsetStore(), } err := ipset.Init() 
@@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi } m.seedInitialEntries() + m.seedInitialOptionalEntries() err = m.cleanChains() if err != nil { @@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error { } } + ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + if ok { + for _, rule := range m.entries["PREROUTING"] { + err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error { } } + for chainName, entries := range m.optionalEntries { + for _, entry := range entries { + if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil { + log.Errorf("failed to insert optional entry %v: %v", entry.spec, err) + continue + } + m.entries[chainName] = append(m.entries[chainName], entry.spec) + } + } + clear(m.optionalEntries) + return nil } @@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } +func (m *aclManager) seedInitialOptionalEntries() { + m.optionalEntries["FORWARD"] = []entry{ + { + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + position: 2, + }, + } + + m.optionalEntries["PREROUTING"] = []entry{ + { + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + position: 1, + }, + } +} + func (m *aclManager) appendToEntries(chainName string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 6fefd58e6..94bd2fccf 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering( } func (m *Manager) AddRouteFiltering( - sources [] netip.Prefix, + sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 737b20785..e60c352d5 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error { log.Debug("flushing routing related tables") for _, chain := range []string{chainRTFWD, chainRTNAT} { - table := tableFilter - if chain == chainRTNAT { - table = tableNat - } + table := r.getTableForChain(chain) ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { @@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error { func (r *router) createContainers() error { for _, chain := range []string{chainRTFWD, chainRTNAT} { if err := r.createAndSetupChain(chain); err != nil { - return fmt.Errorf("create chain %s: %v", chain, err) + return fmt.Errorf("create chain %s: %w", chain, err) } } if err := r.insertEstablishedRule(chainRTFWD); err != nil { - return fmt.Errorf("insert established rule: %v", err) + return fmt.Errorf("insert 
established rule: %w", err) } - return r.addJumpRules() + if err := r.addJumpRules(); err != nil { + return fmt.Errorf("add jump rules: %w", err) + } + + return nil } func (r *router) createAndSetupChain(chain string) error { diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index a6185d370..556bda0d6 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { // GenerateSetName generates a unique name for an ipset based on the given sources. func GenerateSetName(sources []netip.Prefix) string { // sort for consistent naming - sortPrefixes(sources) + SortPrefixes(sources) var sourcesStr strings.Builder for _, src := range sources { @@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { return merged } -// sortPrefixes sorts the given slice of netip.Prefix in place. +// SortPrefixes sorts the given slice of netip.Prefix in place. // It sorts first by IP address, then by prefix length (most specific to least specific). -func sortPrefixes(prefixes []netip.Prefix) { +func SortPrefixes(prefixes []netip.Prefix) { sort.Slice(prefixes, func(i, j int) bool { addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) if addrCmp != 0 { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index eaf7fb6a0..61434f035 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -11,12 +11,14 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -29,6 +31,7 @@ const ( chainNameInputFilter = "netbird-acl-input-filter" chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" + chainNamePrerouting = "netbird-rt-prerouting" allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) @@ -40,15 +43,14 @@ var ( ) type AclManager struct { - rConn *nftables.Conn - sConn *nftables.Conn - wgIface iFaceMapper - routeingFwChainName string + rConn *nftables.Conn + sConn *nftables.Conn + wgIface iFaceMapper + routingFwChainName string workTable *nftables.Table chainInputRules *nftables.Chain chainOutputRules *nftables.Chain - chainFwFilter *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -61,7 +63,7 @@ type iFaceMapper interface { IsUserspaceBind() bool } -func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { +func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) // and is permanent. 
Using same connection for both type of operations @@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa } m := &AclManager{ - rConn: &nftables.Conn{}, - sConn: sConn, - wgIface: wgIface, - workTable: table, - routeingFwChainName: routeingFwChainName, + rConn: &nftables.Conn{}, + sConn: sConn, + wgIface: wgIface, + workTable: table, + routingFwChainName: routingFwChainName, ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) { } // netbird-acl-forward-filter - m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to netbird-rt-fwd - m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) + m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd + m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) err = m.rConn.Flush() if err != nil { @@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) { return fmt.Errorf(flushError, err) } + if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { + log.Errorf("failed to allow redirected traffic: %s", err) + } + return nil } -func (m *AclManager) addJumpRulesToRtForward() { +// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) +// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the +// netbird peer IP. +func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { + preroutingChain := m.rConn.AddChain(&nftables.Chain{ + Name: chainNamePrerouting, + Table: m.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + }) + + m.addPreroutingRule(preroutingChain) + + m.addFwmarkToForward(chainFwFilter) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { + m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: preroutingChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Fib{ + Register: 1, + ResultADDRTYPE: true, + FlagDADDR: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + }, + }) +} + +func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { + m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: chainFwFilter, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.chainInputRules.Name, + }, + }, + }) +} + +func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) { expressions := []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() { }, &expr.Verdict{ Kind: 
expr.VerdictJump, - Chain: m.routeingFwChainName, + Chain: m.routingFwChainName, }, } _ = m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, - Chain: m.chainFwFilter, + Chain: chainFwFilter, Exprs: expressions, }) } @@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain { return chain } -func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { +func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { polAccept := nftables.ChainPolicyAccept chain := &nftables.Chain{ Name: name, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aa61e1858..9b8fdbda5 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,7 @@ import ( "net/netip" "strings" + "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -24,7 +25,7 @@ import ( const ( chainNameRoutingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" + chainNameRoutingNat = "netbird-rt-postrouting" chainNameForward = "FORWARD" userDataAcceptForwardRuleIif = "frwacceptiif" @@ -149,7 +150,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { } func (r *router) createContainers() error { - r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, Table: r.workTable, @@ -157,25 +157,26 @@ func (r *router) createContainers() error { insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + prio := *nftables.ChainPriorityNATSource - 1 + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, Table: r.workTable, Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, + Priority: &prio, Type: nftables.ChainTypeNAT, }) r.acceptForwardRules() - err := r.refreshRulesMap() - if err != nil { + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.conn.Flush() - if err != nil { + if err := r.conn.Flush(); err != nil { return fmt.Errorf("nftables: unable to initialize table: %v", err) } + return nil } @@ -188,6 +189,7 @@ func (r *router) AddRouteFiltering( dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) if _, ok := r.rules[string(ruleKey)]; ok { return ruleKey, nil @@ -248,9 +250,18 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - r.rules[string(ruleKey)] = r.conn.AddRule(rule) + rule = r.conn.AddRule(rule) - return ruleKey, r.conn.Flush() + log.Tracef("Adding route rule %s", spew.Sdump(rule)) + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf(flushError, err) + } + + r.rules[string(ruleKey)] = rule + + log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + + return ruleKey, nil } func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { @@ -288,6 +299,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } + if nftRule.Handle == 0 { + return fmt.Errorf("route rule %s has no handle", ruleKey) + } + setName := r.findSetNameInRule(nftRule) if err := r.deleteNftRule(nftRule, ruleKey); err != nil { @@ -658,7 +673,7 @@ func (r 
*router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed rules for %s", pair.Destination) + log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index bbf92f3be..25b7587ac 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -314,6 +314,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) { ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") + t.Cleanup(func() { + require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule") + }) + // Check if the rule is in the internal map rule, ok := r.rules[ruleKey.GetRuleID()] assert.True(t, ok, "Rule not found in internal map") @@ -346,10 +350,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) { // Verify actual nftables rule content verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) - - // Clean up - err = r.DeleteRouteRule(ruleKey) - require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index e27fce439..8ce73655d 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -1,8 +1,11 @@ package id import ( + "crypto/sha256" + "encoding/hex" "fmt" "net/netip" + "strconv" "github.com/netbirdio/netbird/client/firewall/manager" ) @@ -21,5 +24,41 @@ func GenerateRouteRuleKey( dPort *manager.Port, action manager.Action, ) RuleID { - return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) + manager.SortPrefixes(sources) + + h := sha256.New() + + // Write all fields to the hasher, with delimiters + h.Write([]byte("sources:")) + for _, src := range sources { + h.Write([]byte(src.String())) + h.Write([]byte(",")) + } + + h.Write([]byte("destination:")) + h.Write([]byte(destination.String())) + + h.Write([]byte("proto:")) + h.Write([]byte(proto)) + + h.Write([]byte("sPort:")) + if sPort != nil { + h.Write([]byte(sPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("dPort:")) + if dPort != nil { + h.Write([]byte(dPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("action:")) + h.Write([]byte(strconv.Itoa(int(action)))) + hash := hex.EncodeToString(h.Sum(nil)) + + // prepend destination prefix to be able to identify the rule + return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16])) } diff --git a/go.mod b/go.mod index e7137ce5b..cb37ca4bb 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.24.0 - golang.org/x/sys v0.21.0 + golang.org/x/crypto v0.28.0 + golang.org/x/sys v0.26.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -38,6 +38,7 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 + github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/v3 v3.1.1 github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.4 @@ -45,7 +46,7 @@ 
require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/google/nftables v0.2.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -55,7 +56,7 @@ require ( github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 - github.com/mdlayher/socket v0.4.1 + github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 @@ -89,10 +90,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.26.0 + golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.7.0 - golang.org/x/term v0.21.0 + golang.org/x/sync v0.8.0 + golang.org/x/term v0.25.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -133,7 +134,6 @@ require ( github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect @@ -219,7 +219,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.19.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 4563dc933..05df5c66e 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= +github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= +github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= 
-github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -871,8 +871,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/util/net/net.go b/util/net/net.go index 61b47dbe7..035d7552b 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,7 +11,8 @@ import ( const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 + NetbirdFwmark = 0x1BD00 + PreroutingFwmark = 0x1BD01 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) From b2379175fe856e24c71263b7f0dcfac77ab8a722 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:23:46 +0200 Subject: [PATCH 03/16] [signal] new signal dispatcher version (#2722) --- go.mod | 2 
+- go.sum | 4 ++-- signal/server/signal.go | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index cb37ca4bb..e7e3c17a6 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 05df5c66e..e9bc318d6 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811- github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/server/signal.go b/signal/server/signal.go index 63cc43bd7..305fd052b 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -6,6 +6,7 @@ import ( "io" "time" + "github.com/netbirdio/signal-dispatcher/dispatcher" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -13,8 +14,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/netbirdio/signal-dispatcher/dispatcher" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/proto" From 0e95f16cdd8462242a6912ef061347caba67041b Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 11 Oct 2024 16:24:30 +0200 Subject: [PATCH 04/16] [relay,client] Relay/fix/wg roaming (#2691) If a peer connection switches from Relayed to ICE P2P, the Relayed proxy still consumes the data the other peer sends. Because the proxy is operating, the WireGuard switches back to the Relayed proxy automatically, thanks to the roaming feature. Extend the Proxy implementation with pause/resume functions. Before switching to the p2p connection, pause the WireGuard proxy operation to prevent unnecessary package sources. Consider waiting some milliseconds after the pause to be sure the WireGuard engine already processed all UDP msg in from the pipe. 
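For illustration only, a rough sketch of the intended switch-over order. The switchToP2P helper below is hypothetical and not part of this patch; the real logic lives in Conn.iCEConnectionIsReady and uses the Work/Pause methods added to the wgproxy.Proxy interface:

    package peer

    import (
        "net"

        "github.com/netbirdio/netbird/client/internal/wgproxy"
    )

    // switchToP2P sketches the order of operations when upgrading from the
    // relayed proxy to a direct ICE P2P endpoint (error handling trimmed).
    func switchToP2P(relayProxy wgproxy.Proxy, p2pEndpoint *net.UDPAddr, configureWGEndpoint func(*net.UDPAddr) error) error {
        // Pause the relayed proxy first so it stops feeding packets into
        // WireGuard; otherwise roaming flips the peer back to the relay.
        relayProxy.Pause()

        // Point the WireGuard peer at the direct endpoint.
        if err := configureWGEndpoint(p2pEndpoint); err != nil {
            // Roll back: resume the relayed proxy so traffic keeps flowing.
            relayProxy.Work()
            return err
        }

        // A short wait here lets the WireGuard engine drain any relayed UDP
        // messages still queued in the pipe before the switch is considered
        // final.
        return nil
    }

Pausing instead of closing keeps the relay path warm: if the P2P connection later drops, the relayed proxy is resumed with Work() and the WireGuard endpoint is switched back to it.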
--- client/internal/peer/conn.go | 229 +++++++++++++----------- client/internal/wgproxy/ebpf/proxy.go | 35 +--- client/internal/wgproxy/ebpf/wrapper.go | 102 +++++++++-- client/internal/wgproxy/proxy.go | 5 +- client/internal/wgproxy/proxy_test.go | 2 +- client/internal/wgproxy/usp/proxy.go | 93 +++++++--- 6 files changed, 298 insertions(+), 168 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0d4ad2396..1b740388d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -82,8 +82,6 @@ type Conn struct { config ConnConfig statusRecorder *Status wgProxyFactory *wgproxy.Factory - wgProxyICE wgproxy.Proxy - wgProxyRelay wgproxy.Proxy signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager @@ -106,7 +104,8 @@ type Conn struct { beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc - endpointRelay *net.UDPAddr + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy // for reconnection operations iCEDisconnected chan bool @@ -257,8 +256,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE.Set(StatusConnected) - - defer conn.updateIceState(iceConnInfo) - if conn.currentConnPriority > priority { + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) return } conn.log.Infof("set ICE to active connection") - endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) - if err != nil { - return + var ( + ep *net.UDPAddr + wgProxy wgproxy.Proxy + err error + ) + if iceConnInfo.RelayedOnLocal { + wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) + if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + return + } + ep = wgProxy.EndpointAddr() + conn.wgProxyICE = wgProxy + } else { + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + if err != nil { + log.Errorf("failed to resolveUDPaddr") + conn.handleConfigurationFailure(err, nil) + return + } + ep = directEp } - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } conn.workerRelay.DisableWgWatcher() - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { - if wgProxy != nil { - if err := wgProxy.CloseConn(); err != nil { - conn.log.Warnf("Failed to close turn connection: %v", err) - } - } - conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Pause() + } + + if wgProxy != nil { + wgProxy.Work() + } + + if err = conn.configureWGEndpoint(ep); err != nil { + conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() - - if conn.wgProxyICE != nil { - if err := conn.wgProxyICE.CloseConn(); err != nil { - 
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyICE = wgProxy - conn.currentConnPriority = priority - + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } @@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Tracef("ICE connection state changed to %s", newState) + if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + // switch back to relay connection - if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + if conn.isReadyToUpgrade() { conn.log.Debugf("ICE disconnected, set Relay to active connection") - err := conn.configureWGEndpoint(conn.endpointRelay) - if err != nil { + conn.wgProxyRelay.Work() + + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - select { - case conn.iCEDisconnected <- changed: - default: - } + conn.notifyReconnectLoopICEDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.ctx.Err() != nil { if err := rci.relayedConn.Close(); err != nil { - log.Warnf("failed to close unnecessary relayed connection: %v", err) + conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) } return } - conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay.Set(StatusConnected) + conn.log.Debugf("Relay connection has been established, setup the WireGuard") - wgProxy := conn.wgProxyFactory.GetProxy() - endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + wgProxy, err := conn.newProxy(rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } - conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.endpointRelay = endpointUdpAddr - conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) - - if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE.Get() == StatusConnected { - log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - return - } + if conn.iceP2PIsActive() { + conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + conn.wgProxyRelay = wgProxy + conn.statusRelay.Set(StatusConnected) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + return } - conn.connIDRelay = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { + conn.log.Errorf("Before add peer hook failed: 
%v", err) } - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { + wgProxy.Work() + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } - conn.log.Errorf("Failed to update wg peer configuration: %v", err) + conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } conn.workerRelay.EnableWgWatcher(conn.ctx) + wgConfigWorkaround() - - if conn.wgProxyRelay != nil { - if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyRelay = wgProxy conn.currentConnPriority = connPriorityRelay - + conn.statusRelay.Set(StatusConnected) + conn.wgProxyRelay = wgProxy + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } @@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - log.Debugf("relay connection is disconnected") + conn.log.Debugf("relay connection is disconnected") if conn.currentConnPriority == connPriorityRelay { - log.Debugf("clean up WireGuard config") - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + conn.log.Debugf("clean up WireGuard config") + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } } if conn.wgProxyRelay != nil { - conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil } changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - - select { - case conn.relayDisconnected <- changed: - default: - } + conn.notifyReconnectLoopRelayDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) } } @@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, ip); err != nil { + return err + } + } + return nil +} + func (conn *Conn) freeUpConnID() { if conn.connIDRelay != "" { for _, hook := range conn.afterRemovePeerHooks { @@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() { } } -func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { - if !iceConnInfo.RelayedOnLocal { - return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil - } - conn.log.Debugf("setup ice turn connection") +func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { + conn.log.Debugf("setup proxied WireGuard connection") wgProxy := conn.wgProxyFactory.GetProxy() - ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) - if err != nil { + if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) - if errClose := wgProxy.CloseConn(); errClose != nil { - 
conn.log.Warnf("failed to close turn proxy connection: %v", errClose) - } - return nil, nil, err + return nil, err + } + return wgProxy, nil +} + +func (conn *Conn) isReadyToUpgrade() bool { + return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay +} + +func (conn *Conn) iceP2PIsActive() bool { + return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +} + +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + +func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { + select { + case conn.relayDisconnected <- changed: + default: + } +} + +func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { + select { + case conn.iCEDisconnected <- changed: + default: + } +} + +func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if wgProxy != nil { + if ierr := wgProxy.CloseConn(); ierr != nil { + conn.log.Warnf("Failed to close wg proxy: %v", ierr) + } + } + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Work() } - return ep, wgProxy, nil } func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 27ede3ef1..e850f4533 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. 
func (p *WGEBPFProxy) proxyToRemote() { @@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return packetConn, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { localhost := net.ParseIP("127.0.0.1") payload := gopacket.Payload(data) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840..b6a8ac452 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -4,8 +4,13 @@ package ebpf import ( "context" + "errors" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,20 +18,55 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread + ctx context.Context + cancel context.CancelFunc + + wgEndpointAddr *net.UDPAddr + + pausedMu sync.Mutex + paused bool + isStarted bool } -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) - +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() - return nil, fmt.Errorf("add turn conn: %w", err) + return fmt.Errorf("add turn conn: %w", err) } - e.remoteConn = remoteConn - e.cancel = cancel - return addr, err + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointAddr = addr + return err +} + +func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { + return p.wgEndpointAddr +} + +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map @@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + + buf := make([]byte, 1500) + for { + n, err := p.readFromRemote(ctx, buf) + if err != nil { + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} + +func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !errors.Is(err, io.EOF) { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + } + return 0, err + } + return n, nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd1..558121cdd 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -7,6 +7,9 @@ import ( 
// Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) error + EndpointAddr() *net.UDPAddr + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go index b09e6be55..b88ff3f83 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/internal/wgproxy/proxy_test.go @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d8..f73500717 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. 
+func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err + return err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn - return p.localConn.LocalAddr(), err + return err +} + +func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { + if p.localConn == nil { + return nil + } + endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) + return endpointUdpAddr +} + +// Work starts the proxy or resumes it if it was paused +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +// Pause pauses the proxy from receiving data from the remote peer +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +// if the proxy is paused it will drain the remote conn and drop the packets +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for { n, err := p.remoteConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to write to wg interface conn: %s", err) From da3a053e2bed950bf9cf382f0690435548221745 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 12 Oct 2024 08:35:51 +0200 Subject: [PATCH 05/16] [management] Refactor getAccountIDWithAuthorizationClaims (#2715) This change restructures the getAccountIDWithAuthorizationClaims method to improve 
readability, maintainability, and performance. - have dedicated methods to handle possible cases - introduced Store.UpdateAccountDomainAttributes and Store.GetAccountUsers methods - Remove GetAccount and SaveAccount dependency - added tests --- management/server/account.go | 353 +++++++++++++++++----------- management/server/account_test.go | 280 +++++++++++----------- management/server/sql_store.go | 37 +++ management/server/sql_store_test.go | 60 +++++ management/server/store.go | 2 + 5 files changed, 450 insertions(+), 282 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 6ee0015f8..c468b5ecc 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "errors" "fmt" "hash/crc32" "math/rand" @@ -50,6 +51,8 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -1285,7 +1288,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { return "", err } return account.Id, nil @@ -1300,28 +1303,39 @@ func isNil(i idp.Manager) bool { } // addAccountIDToIDPAppMeta update user's app metadata in idp manager -func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { + accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + if err != nil { + return err + } + cachedAccount := &Account{ + Id: accountID, + Users: make(map[string]*User), + } + for _, user := range accountUsers { + cachedAccount.Users[user.Id] = user + } // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, account) + user, err := am.lookupUserInCache(ctx, userID, cachedAccount) if err != nil { return err } - if user != nil && user.AppMetadata.WTAccountID == account.Id { + if user != nil && user.AppMetadata.WTAccountID == accountID { // it was already set, so we skip the unnecessary update log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", - account.Id, userID) + accountID, userID) return nil } - err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) + err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update - _, err = am.refreshCache(ctx, account.Id) + _, err = am.refreshCache(ctx, accountID) if err != nil { return err } @@ -1545,48 +1559,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) } -// updateAccountDomainAttributes updates the account domain attributes and then, 
saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, +// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { - - if claims.Domain != "" { - account.IsDomainPrimaryAccount = primaryDomain - - lowerDomain := strings.ToLower(claims.Domain) - userObj := account.Users[claims.UserId] - if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { - account.Domain = lowerDomain - } - // prevent updating category for different domain until admin logs in - if account.Domain == lowerDomain { - account.DomainCategory = claims.DomainCategory - } - } else { + if claims.Domain == "" { log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + return nil } - err := am.Store.SaveAccount(ctx, account) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlockAccount() + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) if err != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err } - return nil + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return nil + } + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting user: %v", err) + return err + } + + newDomain := accountDomain + newCategoty := domainCategory + + lowerDomain := strings.ToLower(claims.Domain) + if accountDomain != lowerDomain && user.HasAdminPower() { + newDomain = lowerDomain + } + + if accountDomain == lowerDomain { + newCategoty = claims.DomainCategory + } + + return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. +// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, +// we compare the account's ID with the domain account ID, and if they don't match, we set the account as +// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain +// was previously unclassified or classified as public so N users that logged int that time, has they own account +// and peers that shouldn't be lost. 
func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, - existingAcc *Account, - primaryDomain bool, + userAccountID string, + domainAccountID string, claims jwtclaims.AuthorizationClaims, ) error { - err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) + primaryDomain := domainAccountID == "" || userAccountID == domainAccountID + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) if err != nil { return err } @@ -1594,44 +1629,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount( return nil } -// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, +// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } - var ( - account *Account - err error - ) + lowerDomain := strings.ToLower(claims.Domain) - // if domain already has a primary account, add regular user - if domainAcc != nil { - account = domainAcc - account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } else { - account, err = am.newAccount(ctx, claims.UserId, lowerDomain) - if err != nil { - return nil, err - } - err = am.updateAccountDomainAttributes(ctx, account, claims, true) - if err != nil { - return nil, err - } - } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) + newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { - return nil, err + return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + newAccount.Domain = lowerDomain + newAccount.DomainCategory = claims.DomainCategory + newAccount.IsDomainPrimaryAccount = true - return account, nil + err = am.Store.SaveAccount(ctx, newAccount) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + + return newAccount.Id, nil +} + +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + + usersMap := make(map[string]*User) + usersMap[claims.UserId] = NewRegularUser(claims.UserId) + err := am.Store.SaveUsers(domainAccountID, usersMap) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, 
claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) + + return domainAccountID, nil } // redeemInvite checks whether user has been invited and redeems the invite @@ -1775,7 +1824,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s // GetAccountIDFromToken returns an account ID associated with this token. func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return "", "", fmt.Errorf("user ID is empty") + return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1961,16 +2010,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. +// if domain is not private or domain is invalid, it will return the account ID by user ID. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // // Use cases: // -// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) +// New user + New account + New domain -> create account, user role = owner (if private domain, index domain) // -// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) +// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin) // -// New user + New account + Existing Public Domain -> create account, user role = admin +// New user + New account + Existing Public Domain -> create account, user role = owner // // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // @@ -1980,98 +2030,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. 
User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + if claims.UserId == "" { - return "", fmt.Errorf("user ID is empty") + return "", errors.New(emptyUserID) } - // if Account ID is part of the claims - // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - if claims.AccountId != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) - } - return claims.AccountId, nil - } return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) - } else if claims.AccountId != "" { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err != nil { - return "", err - } - - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) - } - - domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - - if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { - return userAccountID, nil - } } - start := time.Now() - unlock := am.Store.AcquireGlobalLock(ctx) - defer unlock() - log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + if claims.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + if cancel != nil { + defer cancel() + } if err != nil { - // if NotFound we are good to continue, otherwise return error - e, ok := status.FromError(err) - if !ok || e.Type() != status.NotFound { - return "", err - } + return "", err } userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) - defer unlockAccount() - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) - if err != nil { - return "", err - } - // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, - // we compare the account's ID with the domain account ID, and if they don't match, we set the account as - // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain - // was previously unclassified or classified as public so N users that logged int that time, has they own account - // and peers that shouldn't be lost. 
- primaryDomain := domainAccountID == "" || account.Id == domainAccountID - if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { - return "", err - } - - return account.Id, nil - } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - var domainAccount *Account - if domainAccountID != "" { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return "", err - } - } - - account, err := am.handleNewUserAccount(ctx, domainAccount, claims) - if err != nil { - return "", err - } - return account.Id, nil - } else { - // other error + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } + + if userAccountID != "" { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { + return "", err + } + + return userAccountID, nil + } + + if domainAccountID != "" { + return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + } + + return am.addNewPrivateAccount(ctx, domainAccountID, claims) +} +func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + if domainAccountID != "" { + return domainAccountID, nil, nil + } + + log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain) + cancel := am.Store.AcquireGlobalLock(ctx) + + // check again if the domain has a primary account because of simultaneous requests + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + return domainAccountID, cancel, nil +} + +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + } + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return "", err + } + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return claims.AccountId, nil + } + + // We checked if the domain has a primary account already + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", err + } + + err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) + if err != nil { + return "", err + } + + return claims.AccountId, nil +} + +func handleNotFound(err error) error { + if err == nil { + return nil + } + + e, ok := status.FromError(err) + 
if !ok || e.Type() != status.NotFound { + return err + } + return nil +} + +func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { + return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 4dd58e88e..b20071cba 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims - type test struct { + var ( + publicDomain = "public.com" + privateDomain = "private.com" + unknownDomain = "unknown.com" + ) + + defaultInitAccount := initUserParams{ + Domain: publicDomain, + UserId: "defaultUser", + } + + initUnknown := defaultInitAccount + initUnknown.DomainCategory = UnknownCategory + initUnknown.Domain = unknownDomain + + privateInitAccount := defaultInitAccount + privateInitAccount.Domain = privateDomain + privateInitAccount.DomainCategory = PrivateCategory + + testCases := []struct { name string inputClaims jwtclaims.AuthorizationClaims inputInitUserParams initUserParams @@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { expectedPrimaryDomainStatus bool expectedCreatedBy string expectedUsers []string - } - - var ( - publicDomain = "public.com" - privateDomain = "private.com" - unknownDomain = "unknown.com" - ) - - defaultInitAccount := initUserParams{ - Domain: publicDomain, - UserId: "defaultUser", - } - - testCase1 := test{ - name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: publicDomain, - UserId: "pub-domain-user", - DomainCategory: PublicCategory, + }{ + { + name: "New User With Public Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: publicDomain, + UserId: "pub-domain-user", + DomainCategory: PublicCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomainCategory: "", + expectedDomain: publicDomain, + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pub-domain-user", + expectedUsers: []string{"pub-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomainCategory: "", - expectedDomain: publicDomain, - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pub-domain-user", - expectedUsers: []string{"pub-domain-user"}, - } - - initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory - initUnknown.Domain = unknownDomain - - testCase2 := test{ - name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: unknownDomain, - UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + { + name: "New User With Unknown Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: unknownDomain, + UserId: "unknown-domain-user", + DomainCategory: UnknownCategory, + }, + inputInitUserParams: initUnknown, + testingFunc: 
require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: unknownDomain, + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "unknown-domain-user", + expectedUsers: []string{"unknown-domain-user"}, }, - inputInitUserParams: initUnknown, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: unknownDomain, - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "unknown-domain-user", - expectedUsers: []string{"unknown-domain-user"}, - } - - testCase3 := test{ - name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New User With Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - - privateInitAccount := defaultInitAccount - privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory - - testCase4 := test{ - name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "New Regular User With Existing Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "new-pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputUpdateAttrs: true, + inputInitUserParams: privateInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleUser, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, }, - inputUpdateAttrs: true, - inputInitUserParams: privateInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, - } - - testCase5 := test{ - name: "Existing User With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing User With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + 
UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase6 := test{ - name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing Account Id With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputUpdateClaimAccount: true, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputUpdateClaimAccount: true, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase7 := test{ - name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: "", - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "User With Private Category And Empty Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: "", + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: "", + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: "", - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, } - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} { + for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -640,7 +634,7 @@ func 
TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 615203bee..de3dfa945 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. return nil } +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + accountCopy := Account{ + Domain: domain, + DomainCategory: category, + IsDomainPrimaryAccount: isPrimaryDomain, + } + + fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} + result := s.db.WithContext(ctx).Model(&Account{}). + Select(fieldsToUpdate). + Where(idQueryCondition, accountID). + Updates(&accountCopy) + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account %s", accountID) + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } +func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + var users []*User + result := s.db.Find(&users, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting users from store") + } + + return users, nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 06e118fd2..20e812ea7 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1191,3 +1191,63 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { }) assert.NoError(t, err) } + +func TestSqlite_GetAccoundUsers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + users, err := store.GetAccountUsers(context.Background(), accountID) + require.NoError(t, err) + require.Len(t, users, 
len(account.Users)) +} + +func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + t.Run("Should update attributes with public domain", func(t *testing.T) { + require.NoError(t, err) + domain := "example.com" + category := "public" + IsDomainPrimaryAccount := false + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should update attributes with private domain", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should fail when account does not exist", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + require.Error(t, err) + }) + +} diff --git a/management/server/store.go b/management/server/store.go index d914bb8f7..131fd8aaa 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -58,9 +58,11 @@ type Store interface { GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error + UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) + GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error From 3a88ac78ff80b77eecfdcf9d7d66663b017419aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 12 Oct 2024 10:44:48 +0200 Subject: [PATCH 06/16] [client] Add table filter rules using iptables (#2727) This specifically concerns the established/related rule since this one is not compatible with iptables-nft even if it is generated the same way by iptables-translate. 
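
For illustration only, a minimal standalone sketch of the approach this patch takes (the two rule specs mirror getAcceptForwardRules in the diff below; the interface name "wt0", the main function, and the logging are assumptions added here, not part of the patch):

    package main

    import (
    	"log"

    	"github.com/coreos/go-iptables/iptables"
    )

    // acceptForwardRules returns the two FORWARD accept rules for the given
    // WireGuard interface: accept inbound traffic on the interface, and accept
    // established/related return traffic leaving through it.
    func acceptForwardRules(ifName string) [][]string {
    	return [][]string{
    		{"-i", ifName, "-j", "ACCEPT"},
    		{"-o", ifName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
    	}
    }

    func main() {
    	ipt, err := iptables.New()
    	if err != nil {
    		// This is the point where the patch falls back to managing the
    		// filter table via nftables instead.
    		log.Fatalf("iptables unavailable, would fall back to nftables: %v", err)
    	}
    	for _, rule := range acceptForwardRules("wt0") { // "wt0" is an illustrative interface name
    		// Insert at position 1 so the accept rules take effect before any
    		// existing restrictive FORWARD rules.
    		if err := ipt.Insert("filter", "FORWARD", 1, rule...); err != nil {
    			log.Printf("insert rule %v: %v", rule, err)
    		}
    	}
    }

Cleanup in the patch uses DeleteIfExists with the same rule specs, which keeps removal idempotent; the nftables path is only used when iptables itself is not available on the host.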
--- client/firewall/nftables/manager_linux.go | 47 +++-- .../firewall/nftables/manager_linux_test.go | 1 + client/firewall/nftables/router_linux.go | 186 ++++++++++++------ 3 files changed, 148 insertions(+), 86 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index d2258ae08..01b08bd71 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain * rule := &nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 1, - }, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, + Exprs: getEstablishedExprs(1), } conn.InsertRule(rule) } + +func getEstablishedExprs(register uint32) []expr.Any { + return []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: register, + }, + &expr.Bitwise{ + SourceRegister: register, + DestRegister: register, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: register, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 904050a51..bbe18ab07 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) { Register: 1, Data: []byte{0, 0, 0, 0}, }, + &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 9b8fdbda5..404ba6957 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,7 @@ import ( "net/netip" "strings" + "github.com/coreos/go-iptables/iptables" "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" @@ -81,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa } } - err = r.cleanUpDefaultForwardRules() + err = r.removeAcceptForwardRules() if err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } @@ -98,40 +99,7 @@ func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() - return r.cleanUpDefaultForwardRules() -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list chains: %v", err) - } - - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { - continue - } - - rules, err := r.conn.GetRules(r.filterTable, chain) - if err != nil { - return fmt.Errorf("get rules: %v", err) - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - 
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } - } - } - } - - return r.conn.Flush() + return r.removeAcceptForwardRules() } func (r *router) loadFilterTable() (*nftables.Table, error) { @@ -167,7 +135,9 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) - r.acceptForwardRules() + if err := r.acceptForwardRules(); err != nil { + log.Errorf("failed to add accept rules for the forward chain: %s", err) + } if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) @@ -577,19 +547,60 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. -func (r *router) acceptForwardRules() { +func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") - return + return nil } + fw := "iptables" + + defer func() { + log.Debugf("Used %s to add accept forward rules", fw) + }() + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + // filter table exists but iptables is not + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + + fw = "nftables" + return r.acceptForwardRulesNftables() + } + + return r.acceptForwardRulesIptables(ipt) +} + +func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + } else { + log.Debugf("added iptables rule: %v", rule) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) getAcceptForwardRules() [][]string { + intf := r.wgIface.Name() + return [][]string{ + {"-i", intf, "-j", "ACCEPT"}, + {"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, + } +} + +func (r *router) acceptForwardRulesNftables() error { intf := ifname(r.wgIface.Name()) // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -609,6 +620,15 @@ func (r *router) acceptForwardRules() { } r.conn.InsertRule(iifRule) + oifExprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, @@ -619,36 +639,72 @@ func (r *router) acceptForwardRules() { Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityFilter, }, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 2, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | 
expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Counter{}, - &expr.Verdict{Kind: expr.VerdictAccept}, - }, + Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } r.conn.InsertRule(oifRule) + + return nil +} + +func (r *router) removeAcceptForwardRules() error { + if r.filterTable == nil { + return nil + } + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + return r.removeAcceptForwardRulesNftables() + } + + return r.removeAcceptForwardRulesIptables(ipt) +} + +func (r *router) removeAcceptForwardRulesNftables() error { + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } // RemoveNatRule removes a nftables rule pair from nat chains From d93dd4fc7f47c9e1ac597af2570e6faaa36f1219 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 12 Oct 2024 18:21:34 +0200 Subject: [PATCH 07/16] [relay-server] Move the handshake logic to separated struct (#2648) * Move the handshake logic to separated struct - The server will response to the client after it ready to process the peer - Preload the response messages * Fix deprecated lint issue * Fix error handling * [relay-server] Relay measure auth time (#2675) Measure the Relay client's authentication time --- relay/metrics/realy.go | 52 +++++++++++-- relay/server/handshake.go | 153 ++++++++++++++++++++++++++++++++++++++ relay/server/relay.go | 127 ++++++------------------------- 3 files changed, 223 insertions(+), 109 deletions(-) create mode 100644 relay/server/handshake.go diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go index 13799713a..4dc98a0e0 100644 --- a/relay/metrics/realy.go +++ b/relay/metrics/realy.go @@ -16,8 +16,10 @@ const ( type Metrics struct { metric.Meter - TransferBytesSent metric.Int64Counter - TransferBytesRecv metric.Int64Counter + TransferBytesSent metric.Int64Counter + TransferBytesRecv metric.Int64Counter + AuthenticationTime metric.Float64Histogram + PeerStoreTime metric.Float64Histogram peers metric.Int64UpDownCounter peerActivityChan chan string @@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { return nil, err } + authTime, err := 
meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + + peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + m := &Metrics{ - Meter: meter, - TransferBytesSent: bytesSent, - TransferBytesRecv: bytesRecv, - peers: peers, + Meter: meter, + TransferBytesSent: bytesSent, + TransferBytesRecv: bytesRecv, + AuthenticationTime: authTime, + PeerStoreTime: peerStoreTime, + peers: peers, ctx: ctx, peerActivityChan: make(chan string, 10), @@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) { m.peerLastActive[id] = time.Time{} } +// RecordAuthenticationTime measures the time taken for peer authentication +func (m *Metrics) RecordAuthenticationTime(duration time.Duration) { + m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + +// RecordPeerStoreTime measures the time to store the peer in map +func (m *Metrics) RecordPeerStoreTime(duration time.Duration) { + m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + // PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections func (m *Metrics) PeerDisconnected(id string) { m.peers.Add(m.ctx, -1) @@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() { } } } + +func getStandardBucketBoundaries() []float64 { + return []float64{ + 0.1, + 0.5, + 1, + 5, + 10, + 50, + 100, + 500, + 1000, + 5000, + 10000, + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go new file mode 100644 index 000000000..0257300f8 --- /dev/null +++ b/relay/server/handshake.go @@ -0,0 +1,153 @@ +package server + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/messages" + //nolint:staticcheck + "github.com/netbirdio/netbird/relay/messages/address" + //nolint:staticcheck + authmsg "github.com/netbirdio/netbird/relay/messages/auth" +) + +// preparedMsg contains the marshalled success response messages +type preparedMsg struct { + responseHelloMsg []byte + responseAuthMsg []byte +} + +func newPreparedMsg(instanceURL string) (*preparedMsg, error) { + rhm, err := marshalResponseHelloMsg(instanceURL) + if err != nil { + return nil, err + } + + ram, err := messages.MarshalAuthResponse(instanceURL) + if err != nil { + return nil, fmt.Errorf("failed to marshal auth response msg: %w", err) + } + + return &preparedMsg{ + responseHelloMsg: rhm, + responseAuthMsg: ram, + }, nil +} + +func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { + addr := &address.Address{URL: instanceURL} + addrData, err := addr.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal response address: %w", err) + } + + //nolint:staticcheck + responseMsg, err := messages.MarshalHelloResponse(addrData) + if err != nil { + return nil, fmt.Errorf("failed to marshal hello response: %w", err) + } + return responseMsg, nil +} + +type handshake struct { + conn net.Conn + validator auth.Validator + preparedMsg *preparedMsg + + handshakeMethodAuth bool + peerID string +} + +func (h *handshake) handshakeReceive() ([]byte, error) { + buf := make([]byte, messages.MaxHandshakeSize) + n, err := h.conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) + } + + _, err = 
messages.ValidateVersion(buf[:n]) + if err != nil { + return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) + } + + msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) + } + + var ( + bytePeerID []byte + peerID string + ) + switch msgType { + //nolint:staticcheck + case messages.MsgTypeHello: + bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) + case messages.MsgTypeAuth: + h.handshakeMethodAuth = true + bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) + } + if err != nil { + return nil, err + } + h.peerID = peerID + return bytePeerID, nil +} + +func (h *handshake) handshakeResponse() error { + var responseMsg []byte + if h.handshakeMethodAuth { + responseMsg = h.preparedMsg.responseAuthMsg + } else { + responseMsg = h.preparedMsg.responseHelloMsg + } + + if _, err := h.conn.Write(responseMsg); err != nil { + return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) + } + + return nil +} + +func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { + //nolint:staticcheck + rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) + + authMsg, err := authmsg.UnmarshalMsg(authData) + if err != nil { + return nil, "", fmt.Errorf("unmarshal auth message: %w", err) + } + + //nolint:staticcheck + if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} + +func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { + rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + + if err := h.validator.Validate(authPayload); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} diff --git a/relay/server/relay.go b/relay/server/relay.go index 76c01a697..6cd8506ae 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -7,16 +7,13 @@ import ( "net/url" "strings" "sync" + "time" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/auth" - "github.com/netbirdio/netbird/relay/messages" //nolint:staticcheck - "github.com/netbirdio/netbird/relay/messages/address" - //nolint:staticcheck - authmsg "github.com/netbirdio/netbird/relay/messages/auth" "github.com/netbirdio/netbird/relay/metrics" ) @@ -28,6 +25,7 @@ type Relay struct { store *Store instanceURL string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida return nil, fmt.Errorf("get instance URL: %v", err) } + r.preparedMsg, err = newPreparedMsg(r.instanceURL) + if err != nil { + metricsCancel() + return nil, fmt.Errorf("prepare message: %v", err) + } + return r, nil } @@ -100,17 +104,22 @@ func 
getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { // Accept start to handle a new peer connection func (r *Relay) Accept(conn net.Conn) { + acceptTime := time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() if r.closed { return } - peerID, err := r.handshake(conn) + h := handshake{ + conn: conn, + validator: r.validator, + preparedMsg: r.preparedMsg, + } + peerID, err := h.handshakeReceive() if err != nil { log.Errorf("failed to handshake: %s", err) - cErr := conn.Close() - if cErr != nil { + if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } return @@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) { peer := NewPeer(r.metrics, peerID, conn, r.store) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) + storeTime := time.Now() r.store.AddPeer(peer) + r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() @@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) { peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) }() + + if err := h.handshakeResponse(); err != nil { + log.Errorf("failed to send handshake response, close peer: %s", err) + peer.Close() + } + r.metrics.RecordAuthenticationTime(time.Since(acceptTime)) } // Shutdown closes the relay server @@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } - -func (r *Relay) handshake(conn net.Conn) ([]byte, error) { - buf := make([]byte, messages.MaxHandshakeSize) - n, err := conn.Read(buf) - if err != nil { - return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err) - } - - _, err = messages.ValidateVersion(buf[:n]) - if err != nil { - return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err) - } - - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) - if err != nil { - return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) - } - - var ( - responseMsg []byte - peerID []byte - ) - switch msgType { - //nolint:staticcheck - case messages.MsgTypeHello: - peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - case messages.MsgTypeAuth: - peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - default: - return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) - } - if err != nil { - return nil, err - } - - _, err = conn.Write(responseMsg) - if err != nil { - return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - return peerID, nil -} - -func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) { - //nolint:staticcheck - rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr) - - authMsg, err := authmsg.UnmarshalMsg(authData) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal auth message: %w", err) - } - - //nolint:staticcheck - if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) - } - - addr := &address.Address{URL: r.instanceURL} - addrData, err := 
addr.Marshal() - if err != nil { - return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) - } - - //nolint:staticcheck - responseMsg, err := messages.MarshalHelloResponse(addrData) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) - } - return rawPeerID, responseMsg, nil -} - -func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) { - rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - - if err := r.validator.Validate(authPayload); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) - } - - responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) - } - - return rawPeerID, responseMsg, nil -} From 49e65109d25a98bfa73245e1252663d37184533a Mon Sep 17 00:00:00 2001 From: ctrl-zzz <78654296+ctrl-zzz@users.noreply.github.com> Date: Sun, 13 Oct 2024 14:52:43 +0200 Subject: [PATCH 08/16] Add session expire functionality based on inactivity (#2326) Implemented inactivity expiration by checking the status of a peer: after a configurable period of time following netbird down, the peer shows login required. --- management/server/account.go | 133 +++++++++ management/server/account_test.go | 315 +++++++++++++++++++++ management/server/activity/codes.go | 14 + management/server/file_store.go | 3 + management/server/http/accounts_handler.go | 3 + management/server/http/api/openapi.yml | 19 ++ management/server/http/api/types.gen.go | 21 +- management/server/http/peers_handler.go | 54 ++-- management/server/peer.go | 97 +++++-- management/server/peer/peer.go | 20 ++ management/server/peer_test.go | 62 ++++ 11 files changed, 682 insertions(+), 59 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c468b5ecc..7c84ad1ca 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -51,6 +51,7 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) @@ -181,6 +182,8 @@ type DefaultAccountManager struct { dnsDomain string peerLoginExpiry Scheduler + peerInactivityExpiry Scheduler + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool @@ -198,6 +201,13 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. 
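	// (The corresponding per-peer flag introduced by this patch is
	// Peer.InactivityExpirationEnabled, defined in management/server/peer/peer.go
	// further below.)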
+ PeerInactivityExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements RegularUsersViewBlocked bool @@ -228,6 +238,9 @@ func (s *Settings) Copy() *Settings { GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -609,6 +622,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { return peers } +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*nbpeer.Peer { var peers []*nbpeer.Peer @@ -975,6 +1042,7 @@ func BuildManager( dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), + peerInactivityExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, metrics: metrics, @@ -1103,6 +1171,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.checkAndSchedulePeerLoginExpiration(ctx, account) } + err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, err + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1113,6 +1186,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, 
newSettings *Settings, userID, accountID string) error { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) + } + + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + + return nil +} + func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -1148,6 +1241,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context } } +// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found +func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { + return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + log.Errorf("failed getting account %s expiring peers", account.Id) + return account.GetNextInactivePeerExpiration() + } + + expiredPeers := account.GetInactivePeers() + var peerIDs []string + for _, peer := range expiredPeers { + peerIDs = append(peerIDs, peer.ID) + } + + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextInactivePeerExpiration() + } + + return account.GetNextInactivePeerExpiration() + } +} + +// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { + am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) + if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) + } +} + // newAccount creates a new Account with a generated ID and generated default setup keys. 
// If ID is already in use (due to collision) we try one more time before returning error func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { @@ -2412,6 +2542,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, }, } diff --git a/management/server/account_test.go b/management/server/account_test.go index b20071cba..19514dad1 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1957,6 +1957,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestAccount_GetInactivePeers(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + testCases := []test{ + { + name: "Peers with inactivity expiration disabled, no expired peers", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + }, + "peer-2": { + InactivityExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Two peers expired", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-30 * time.Minute), + UserID: userID, + }, + "peer-2": { + ID: "peer-2", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-2 * time.Hour), + UserID: userID, + }, + "peer-3": { + ID: "peer-3", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-1 * time.Hour), + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + "peer-2": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + } + + expiredPeers := account.GetInactivePeers() + assert.Len(t, expiredPeers, len(testCase.expectedPeers)) + for _, peer := range expiredPeers { + if _, ok := testCase.expectedPeers[peer.ID]; !ok { + t.Fatalf("expected to have peer %s expired", peer.ID) + } + } + }) + } +} + func TestAccount_GetPeersWithExpiration(t *testing.T) { type test struct { name string @@ -2026,6 +2110,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { } } +func TestAccount_GetPeersWithInactivity(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + + testCases := []test{ + { + name: "No account peers, no peers with expiration", + peers: map[string]*nbpeer.Peer{}, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration disabled, no peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers 
with login expiration enabled, return peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + } + + actual := account.GetPeersWithInactivity() + assert.Len(t, actual, len(testCase.expectedPeers)) + if len(testCase.expectedPeers) > 0 { + for k := range testCase.expectedPeers { + contains := false + for _, peer := range actual { + if k == peer.ID { + contains = true + } + } + assert.True(t, contains) + } + } + }) + } +} + func TestAccount_GetNextPeerExpiration(t *testing.T) { type test struct { name string @@ -2187,6 +2340,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } } +func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expiration time.Duration + expirationEnabled bool + expectedNextRun bool + expectedNextExpiration time.Duration + } + + expectedNextExpiration := time.Minute + testCases := []test{ + { + name: "No peers, no expiration", + peers: map[string]*nbpeer.Peer{}, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "No connected peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Connected peers with disabled expiration, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Expired peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "To be expired peer, return expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + LoginExpired: false, + LastSeen: time.Now().Add(-1 * time.Second), + }, + InactivityExpirationEnabled: true, + LastLogin: time.Now().UTC(), + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Minute, + expirationEnabled: false, + expectedNextRun: true, + expectedNextExpiration: 
expectedNextExpiration, + }, + { + name: "Peers added with setup keys, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + } + + expiration, ok := account.GetNextInactivePeerExpiration() + assert.Equal(t, testCase.expectedNextRun, ok) + if testCase.expectedNextRun { + assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration) + } else { + assert.Equal(t, expiration, testCase.expectedNextExpiration) + } + }) + } +} + func TestAccount_SetJWTGroups(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4ee57f181..188494241 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -139,6 +139,13 @@ const ( PostureCheckUpdated Activity = 61 // PostureCheckDeleted indicates that the user deleted a posture check PostureCheckDeleted Activity = 62 + + PeerInactivityExpirationEnabled Activity = 63 + PeerInactivityExpirationDisabled Activity = 64 + + AccountPeerInactivityExpirationEnabled Activity = 65 + AccountPeerInactivityExpirationDisabled Activity = 66 + AccountPeerInactivityExpirationDurationUpdated Activity = 67 ) var activityMap = map[Activity]Code{ @@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{ PostureCheckCreated: {"Posture check created", "posture.check.created"}, PostureCheckUpdated: {"Posture check updated", "posture.check.updated"}, PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"}, + + PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"}, + PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"}, + + AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, + AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, + AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/file_store.go b/management/server/file_store.go index df3e9bb77..561e133ce 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) { account.Settings = &Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: DefaultPeerLoginExpiration, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, } } diff --git a/management/server/http/accounts_handler.go 
b/management/server/http/accounts_handler.go index 91caa1512..4d4066de4 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), } if req.Settings.Extra != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index fd0343e97..9d5148248 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,14 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + peer_inactivity_expiration_enabled: + description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + type: boolean + example: true + peer_inactivity_expiration: + description: Period of time of inactivity after which peer session expires (seconds). + type: integer + example: 43200 regular_users_view_blocked: description: Allows blocking regular users from viewing parts of the system. type: boolean @@ -81,6 +89,8 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - peer_inactivity_expiration_enabled + - peer_inactivity_expiration - regular_users_view_blocked AccountExtraSettings: type: object @@ -243,6 +253,9 @@ components: login_expiration_enabled: type: boolean example: false + inactivity_expiration_enabled: + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -251,6 +264,7 @@ components: - name - ssh_enabled - login_expiration_enabled + - inactivity_expiration_enabled Peer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -327,6 +341,10 @@ components: type: string format: date-time example: "2023-05-05T09:00:35.477782Z" + inactivity_expiration_enabled: + description: Indicates whether peer inactivity expiration has been enabled or not + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -354,6 +372,7 @@ components: - last_seen - login_expiration_enabled - login_expired + - inactivity_expiration_enabled - os - ssh_enabled - user_id diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 570ec03c5..e2870d5d8 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -220,6 +220,12 @@ type AccountSettings struct { // JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups. JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"` + // PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds). + PeerInactivityExpiration int `json:"peer_inactivity_expiration"` + + // PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. 
After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"` + // PeerLoginExpiration Period of time after which peer login expires (seconds). PeerLoginExpiration int `json:"peer_login_expiration"` @@ -538,6 +544,9 @@ type Peer struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -613,6 +622,9 @@ type PeerBatch struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string // PeerRequest defines model for PeerRequest. type PeerRequest struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` - LoginExpirationEnabled bool `json:"login_expiration_enabled"` - Name string `json:"name"` - SshEnabled bool `json:"ssh_enabled"` + ApprovalRequired *bool `json:"approval_required,omitempty"` + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + LoginExpirationEnabled bool `json:"login_expiration_enabled"` + Name string `json:"name"` + SshEnabled bool `json:"ssh_enabled"` } // PersonalAccessToken defines model for PersonalAccessToken. diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 4fbbc3106..a5856a0e4 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -7,6 +7,8 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -14,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) // PeersHandler is a handler that returns peers of the account @@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, SSHEnabled: req.SshEnabled, Name: req.Name, LoginExpirationEnabled: req.LoginExpirationEnabled, + + InactivityExpirationEnabled: req.InactivityExpirationEnabled, } if req.ApprovalRequired != nil { @@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD } return &api.Peer{ - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, - LoginExpired: 
peer.Status.LoginExpired, - ApprovalRequired: !approved, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.LastLogin, + LoginExpired: peer.Status.LoginExpired, + ApprovalRequired: !approved, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } @@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, SerialNumber: peer.Meta.SystemSerialNumber, + + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } diff --git a/management/server/peer.go b/management/server/peer.go index a7d4f3b06..a4c7e1266 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -110,6 +110,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + if err != nil { + return err + } + + if peer.AddedWithSSOLogin() { + if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, account) + } + + if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + + if expired { + // we need to update other peers because when peer login expires all other peers are notified to disconnect from + // the expired one. Here we notify them that connection is now allowed again. + am.updateAccountPeers(ctx, account) + } + + return nil +} + +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -138,25 +163,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { - return err + return false, err } - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) - } - - if oldStatus.LoginExpired { - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) - } - - return nil + return oldStatus.LoginExpired, nil } -// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. 
+// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -219,6 +234,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + + if !peer.AddedWithSSOLogin() { + return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + } + + peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + + event := activity.PeerInactivityExpirationEnabled + if !update.InactivityExpirationEnabled { + event = activity.PeerInactivityExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + account.UpdatePeer(peer) err = am.Store.SaveAccount(ctx, account) @@ -442,23 +476,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s registrationTime := time.Now().UTC() newPeer = &nbpeer.Peer{ - ID: xid.New().String(), - AccountID: accountID, - Key: peer.Key, - SetupKey: upperKey, - IP: freeIP, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: freeLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + InactivityExpirationEnabled: addedByUser, } opEvent.TargetID = newPeer.ID opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 3d9ba18e9..9a53459a8 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -38,6 +38,8 @@ type Peer struct { // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin LoginExpirationEnabled bool + + InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation LastLogin time.Time // CreatedAt records the time the peer was created @@ -187,6 +189,8 @@ func (p *Peer) Copy() *Peer { CreatedAt: p.CreatedAt, Ephemeral: p.Ephemeral, Location: p.Location, + + InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } @@ -219,6 +223,22 @@ func (p *Peer) MarkLoginExpired(expired bool) { p.Status = newStatus } +// SessionExpired indicates whether the peer's session has expired or not. +// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired. 
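// (As implemented below, the expiry window is measured from Status.LastSeen plus
// expiresIn, i.e. from the peer's last activity, rather than from LastLogin.)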
+// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired). +// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. +// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +// Only peers added by interactive SSO login can be expired. +func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) { + if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected { + return false, 0 + } + expiresAt := p.Status.LastSeen.Add(expiresIn) + now := time.Now() + timeLeft := expiresAt.Sub(now) + return timeLeft <= 0, timeLeft +} + // LoginExpired indicates whether the peer's login has expired or not. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). diff --git a/management/server/peer_test.go b/management/server/peer_test.go index f3bf0ddba..c5edb5636 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) { } } +func TestPeer_SessionExpired(t *testing.T) { + tt := []struct { + name string + expirationEnabled bool + lastLogin time.Time + connected bool + expected bool + accountSettings *Settings + }{ + { + name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire", + expirationEnabled: false, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Hour, + }, + expected: false, + }, + { + name: "Peer Inactivity Should Expire", + expirationEnabled: true, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: true, + }, + { + name: "Peer Inactivity Should Not Expire", + expirationEnabled: true, + connected: true, + lastLogin: time.Now().UTC(), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: false, + }, + } + + for _, c := range tt { + t.Run(c.name, func(t *testing.T) { + peerStatus := &nbpeer.PeerStatus{ + Connected: c.connected, + } + peer := &nbpeer.Peer{ + InactivityExpirationEnabled: c.expirationEnabled, + LastLogin: c.lastLogin, + Status: peerStatus, + UserID: userID, + } + + expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration) + assert.Equal(t, expired, c.expected) + }) + } +} + func TestAccountManager_GetNetworkMap(t *testing.T) { manager, err := createManager(t) if err != nil { From cee95461d15fac87ba01ddb63f0715099e985d4f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 15 Oct 2024 15:03:17 +0200 Subject: [PATCH 09/16] [client] Add universal bin build and update sign workflow version (#2738) * Add universal binaries build for macOS * update sign pipeline version * handle info.plist in sign workflow --- .github/workflows/release.yml | 4 ++-- .goreleaser.yaml | 3 +++ .goreleaser_ui_darwin.yaml | 3 +++ client/ui/Info.plist | 12 ------------ 4 files changed, 8 insertions(+), 14 deletions(-) delete mode 100644 client/ui/Info.plist diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b2e2437e6..1b85ec7ef 100644 --- a/.github/workflows/release.yml 
+++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.14" + SIGN_PIPE_VER: "v0.0.15" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" @@ -223,4 +223,4 @@ jobs: repo: netbirdio/sign-pipelines ref: ${{ env.SIGN_PIPE_VER }} token: ${{ secrets.SIGN_GITHUB_TOKEN }} - inputs: '{ "tag": "${{ github.ref }}" }' + inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' diff --git a/.goreleaser.yaml b/.goreleaser.yaml index cf2ce4f4f..e718b3fcd 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -96,6 +96,9 @@ builds: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser mod_timestamp: "{{ .CommitTimestamp }}" +universal_binaries: + - id: netbird + archives: - builds: - netbird diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index bccb7f471..0a0082075 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -23,6 +23,9 @@ builds: tags: - load_wgnt_from_rsrc +universal_binaries: + - id: netbird-ui-darwin + archives: - builds: - netbird-ui-darwin diff --git a/client/ui/Info.plist b/client/ui/Info.plist deleted file mode 100644 index 8441110b9..000000000 --- a/client/ui/Info.plist +++ /dev/null @@ -1,12 +0,0 @@ - - - - - CFBundleExecutable - netbird-ui - CFBundleIconFile - Netbird - LSUIElement - 1 - - From 8c8900be57b76e40bedcb4c6c56d4a57d2afd9bf Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 16 Oct 2024 17:35:59 +0200 Subject: [PATCH 10/16] [client] Exclude loopback from NAT (#2747) --- client/firewall/iptables/router_linux.go | 4 +++- client/firewall/nftables/router_linux.go | 15 +++++++++++++++ client/firewall/nftables/router_linux_test.go | 12 ++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index e60c352d5..129323928 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -433,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { intdir := "-i" + lointdir := "-o" if inverse { intdir = "-o" + lointdir = "-i" } - return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} + return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump} } func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 404ba6957..03526fee7 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -425,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { destExp := generateCIDRMatcherExpressions(false, pair.Destination) dir := expr.MetaKeyIIFNAME + notDir := expr.MetaKeyOIFNAME if pair.Inverse { dir = expr.MetaKeyOIFNAME + notDir = expr.MetaKeyIIFNAME } + lo := ifname("lo") intf := ifname(r.wgIface.Name()) + exprs := []expr.Any{ &expr.Meta{ Key: dir, @@ -440,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { Register: 1, Data: intf, }, + + // We need to exclude the loopback interface as this changes the ebpf proxy port + &expr.Meta{ + Key: notDir, + Register: 1, + }, 
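		// (The Cmp that follows matches only when the interface name loaded into
		// register 1 is not "lo", so loopback traffic is never masqueraded.)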
+ &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: lo, + }, } exprs = append(exprs, sourceExp...) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 25b7587ac..c07111b4e 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) { Register: 1, Data: ifname(ifaceMock.Name()), }, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, ) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) @@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) { Register: 1, Data: ifname(ifaceMock.Name()), }, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, ) inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) From f942491b91d8d4627402512f5ec8ff5054a570f7 Mon Sep 17 00:00:00 2001 From: Emre Oksum Date: Wed, 16 Oct 2024 18:51:21 +0300 Subject: [PATCH 11/16] Update Zitadel version on quickstart script (#2744) Update Zitadel version at docker compose in quickstart script from 2.54.3 to 2.54.10 because 2.54.3 isn't stable and has a lot of bugs. --- infrastructure_files/getting-started-with-zitadel.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2c5c35d53..16b2364fb 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -873,7 +873,7 @@ services: zitadel: restart: 'always' networks: [netbird] - image: 'ghcr.io/zitadel/zitadel:v2.54.3' + image: 'ghcr.io/zitadel/zitadel:v2.54.10' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' env_file: - ./zitadel.env From 96d22076849027e7b8179feabbdd9892d600eb5a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 16 Oct 2024 18:55:30 +0300 Subject: [PATCH 12/16] Fix JSON function compatibility for SQLite and PostgreSQL (#2746) resolves the issue with json_array_length compatibility between SQLite and PostgreSQL. It adjusts the query to conditionally cast types: PostgreSQL: Casts to json with ::json. SQLite: Uses the text representation directly. --- management/server/sql_store.go | 12 ++++++++++-- management/server/sql_store_test.go | 13 +++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index de3dfa945..47395f511 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1154,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). - Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) + // TODO: This fix is accepted for now, but if we need to handle this more frequently + // we may need to reconsider changing the types. 
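	// (Resulting ORDER BY clauses: "json_array_length(peers::json) DESC" on
	// PostgreSQL, where the text column must first be cast to json, and
	// "json_array_length(peers) DESC" on SQLite, whose json_array_length accepts
	// the text representation directly.)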
+ query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + if s.storeEngine == PostgresStoreEngine { + query = query.Order("json_array_length(peers::json) DESC") + } else { + query = query.Order("json_array_length(peers) DESC") + } + + result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 20e812ea7..000eb1b11 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1251,3 +1251,16 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { }) } + +func TestSqlite_GetGroupByName(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + require.NoError(t, err) + require.Equal(t, "All", group.Name) +} From ccd4ae6315853249002de2cb6d31dea6a3d330f4 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 17 Oct 2024 19:21:35 +0200 Subject: [PATCH 13/16] Fix domain information is up to date check (#2754) --- management/server/account.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 7c84ad1ca..4c4806bb5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -45,15 +45,15 @@ import ( ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerInactivityExpiration = 10 * time.Minute - emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -1440,7 +1440,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, u return err } cachedAccount := &Account{ - Id: accountID, + Id: accountID, Users: make(map[string]*User), } for _, user := range accountUsers { @@ -2276,7 +2276,7 @@ func handleNotFound(err error) error { } func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory + return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { From 507a40bd7f7a9119fd9ddbfe72ef84160b864dae 
Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 17 Oct 2024 20:39:59 +0200 Subject: [PATCH 14/16] Fix decompress zip path (#2755) Since 0.30.2 the decompressed binary path from the signed package has changed now it doesn't contain the arch suffix this change handles that --- client/ui/netbird-ui.rb.tmpl | 6 +++--- release_files/install.sh | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/ui/netbird-ui.rb.tmpl b/client/ui/netbird-ui.rb.tmpl index 9efddd750..06971909d 100644 --- a/client/ui/netbird-ui.rb.tmpl +++ b/client/ui/netbird-ui.rb.tmpl @@ -8,11 +8,11 @@ cask "{{ $projectName }}" do if Hardware::CPU.intel? url "{{ $amdURL }}" sha256 "{{ crypto.SHA256 $amdFileBytes }}" - app "netbird_ui_darwin_amd64", target: "Netbird UI.app" + app "netbird_ui_darwin", target: "Netbird UI.app" else url "{{ $armURL }}" sha256 "{{ crypto.SHA256 $armFileBytes }}" - app "netbird_ui_darwin_arm64", target: "Netbird UI.app" + app "netbird_ui_darwin", target: "Netbird UI.app" end depends_on formula: "netbird" @@ -36,4 +36,4 @@ cask "{{ $projectName }}" do name "Netbird UI" desc "Netbird UI Client" homepage "https://www.netbird.io/" -end \ No newline at end of file +end diff --git a/release_files/install.sh b/release_files/install.sh index b7a6c08f9..b0fec2733 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -86,7 +86,7 @@ download_release_binary() { # Unzip the app and move to INSTALL_DIR unzip -q -o "$BINARY_NAME" - mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" + mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" else ${SUDO} mkdir -p "$INSTALL_DIR" tar -xzvf "$BINARY_NAME" From c8d8748dcf051b647f751511a385c180556ee929 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 18 Oct 2024 17:28:58 +0200 Subject: [PATCH 15/16] Update sign workflow version (#2756) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1b85ec7ef..14e383a27 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.15" + SIGN_PIPE_VER: "v0.0.16" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" From 88e4fc2245e5490b40678e51998e08bfd49adad9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 19 Oct 2024 18:32:17 +0200 Subject: [PATCH 16/16] Release global lock on early error (#2760) --- management/server/account.go | 1 + 1 file changed, 1 insertion(+) diff --git a/management/server/account.go b/management/server/account.go index 4c4806bb5..cca3b4e52 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2220,6 +2220,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont // check again if the domain has a primary account because of simultaneous requests domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if handleNotFound(err) != nil { + cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", nil, err }
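
For reference, below is a minimal, self-contained Go sketch of the inactivity-expiration logic introduced in PATCH 08/16. The peer and peerStatus types, the sessionExpired helper and the sample data are simplified stand-ins rather than the management-server implementation; only the decision rule (SSO-added, inactivity expiration enabled, currently disconnected, Status.LastSeen older than the configured window, 10 minutes by default) follows the diff.

package main

import (
	"fmt"
	"time"
)

// Simplified stand-ins for the types touched by this patch series; the field
// names follow the diff, everything else is illustrative only.
type peerStatus struct {
	LastSeen  time.Time
	Connected bool
}

type peer struct {
	ID                          string
	UserID                      string
	InactivityExpirationEnabled bool
	Status                      peerStatus
}

// sessionExpired mirrors the rule in Peer.SessionExpired: only peers added via
// interactive SSO login (UserID set), with inactivity expiration enabled and
// currently disconnected can expire, counted from Status.LastSeen.
func sessionExpired(p peer, expiresIn time.Duration) (bool, time.Duration) {
	if p.UserID == "" || !p.InactivityExpirationEnabled || p.Status.Connected {
		return false, 0
	}
	timeLeft := time.Until(p.Status.LastSeen.Add(expiresIn))
	return timeLeft <= 0, timeLeft
}

func main() {
	// DefaultPeerInactivityExpiration in the diff is 10 * time.Minute.
	const inactivityExpiration = 10 * time.Minute

	peers := []peer{
		{ID: "peer-1", UserID: "user-1", InactivityExpirationEnabled: true,
			Status: peerStatus{LastSeen: time.Now().Add(-15 * time.Minute)}},
		{ID: "peer-2", UserID: "user-1", InactivityExpirationEnabled: true,
			Status: peerStatus{LastSeen: time.Now().Add(-2 * time.Minute)}},
	}

	// Roughly what peerInactivityExpirationJob does: collect expired peers and
	// work out how long the scheduler should wait before the next check.
	var nextRun time.Duration
	for _, p := range peers {
		expired, left := sessionExpired(p, inactivityExpiration)
		if expired {
			fmt.Printf("%s would be marked as login expired\n", p.ID)
			continue
		}
		if left > 0 && (nextRun == 0 || left < nextRun) {
			nextRun = left
		}
	}
	fmt.Printf("next inactivity check in ~%s\n", nextRun.Round(time.Second))
}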