mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 18:00:49 +01:00
Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association
This commit is contained in:
commit
53218f99bc
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.14"
|
||||
SIGN_PIPE_VER: "v0.0.16"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
@ -223,4 +223,4 @@ jobs:
|
||||
repo: netbirdio/sign-pipelines
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||
|
@ -96,6 +96,9 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
- netbird
|
||||
|
@ -23,6 +23,9 @@ builds:
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird-ui-darwin
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
- netbird-ui-darwin
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -21,13 +22,19 @@ const (
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
spec []string
|
||||
position int
|
||||
}
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
entries map[string][][]string
|
||||
ipsetStore *ipsetStore
|
||||
entries map[string][][]string
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
||||
wgIface: wgIface,
|
||||
routingFwChainName: routingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
ipsetStore: newIpsetStore(),
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
||||
}
|
||||
|
||||
m.seedInitialEntries()
|
||||
m.seedInitialOptionalEntries()
|
||||
|
||||
err = m.cleanChains()
|
||||
if err != nil {
|
||||
@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error {
|
||||
}
|
||||
}
|
||||
|
||||
for chainName, entries := range m.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(m.optionalEntries)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
|
||||
m.optionalEntries["PREROUTING"] = []entry{
|
||||
{
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
|
||||
position: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources [] netip.Prefix,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
|
@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
table := tableFilter
|
||||
if chain == chainRTNAT {
|
||||
table = tableNat
|
||||
}
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||
if err != nil {
|
||||
@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
func (r *router) createContainers() error {
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
if err := r.createAndSetupChain(chain); err != nil {
|
||||
return fmt.Errorf("create chain %s: %v", chain, err)
|
||||
return fmt.Errorf("create chain %s: %w", chain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||
return fmt.Errorf("insert established rule: %v", err)
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
return r.addJumpRules()
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
@ -432,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
lointdir := "-o"
|
||||
if inverse {
|
||||
intdir = "-o"
|
||||
lointdir = "-i"
|
||||
}
|
||||
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
|
@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
sortPrefixes(sources)
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
return merged
|
||||
}
|
||||
|
||||
// sortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func sortPrefixes(prefixes []netip.Prefix) {
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
sort.Slice(prefixes, func(i, j int) bool {
|
||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||
if addrCmp != 0 {
|
||||
|
@ -11,12 +11,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -29,6 +31,7 @@ const (
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameOutputFilter = "netbird-acl-output-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNamePrerouting = "netbird-rt-prerouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
@ -40,15 +43,14 @@ var (
|
||||
)
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routeingFwChainName string
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainOutputRules *nftables.Chain
|
||||
chainFwFilter *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
@ -61,7 +63,7 @@ type iFaceMapper interface {
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for both type of operations
|
||||
@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
|
||||
}
|
||||
|
||||
m := &AclManager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routeingFwChainName: routeingFwChainName,
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward() // to netbird-rt-fwd
|
||||
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRulesToRtForward() {
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
})
|
||||
|
||||
m.addPreroutingRule(preroutingChain)
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
||||
m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: preroutingChain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.chainInputRules.Name,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() {
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.routeingFwChainName,
|
||||
Chain: m.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainFwFilter,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
return chain
|
||||
}
|
||||
|
||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
|
||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
|
@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
Exprs: getEstablishedExprs(1),
|
||||
}
|
||||
|
||||
conn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func getEstablishedExprs(register uint32) []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: register,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: register,
|
||||
DestRegister: register,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: register,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
@ -24,7 +26,7 @@ import (
|
||||
|
||||
const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameForward = "FORWARD"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
@ -80,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
err = r.removeAcceptForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
@ -97,40 +99,7 @@ func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
return r.cleanUpDefaultForwardRules()
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
return r.removeAcceptForwardRules()
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
@ -149,7 +118,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
@ -157,25 +125,28 @@ func (r *router) createContainers() error {
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Priority: &prio,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.acceptForwardRules()
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -188,6 +159,7 @@ func (r *router) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
@ -248,9 +220,18 @@ func (r *router) AddRouteFiltering(
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
|
||||
rule = r.conn.AddRule(rule)
|
||||
|
||||
return ruleKey, r.conn.Flush()
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
@ -288,6 +269,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if nftRule.Handle == 0 {
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
}
|
||||
|
||||
setName := r.findSetNameInRule(nftRule)
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
@ -440,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
dir := expr.MetaKeyIIFNAME
|
||||
notDir := expr.MetaKeyOIFNAME
|
||||
if pair.Inverse {
|
||||
dir = expr.MetaKeyOIFNAME
|
||||
notDir = expr.MetaKeyIIFNAME
|
||||
}
|
||||
|
||||
lo := ifname("lo")
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: dir,
|
||||
@ -455,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: notDir,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: lo,
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
@ -562,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (r *router) acceptForwardRules() {
|
||||
func (r *router) acceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
// filter table exists but iptables is not
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptForwardRulesNftables()
|
||||
}
|
||||
|
||||
return r.acceptForwardRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) getAcceptForwardRules() [][]string {
|
||||
intf := r.wgIface.Name()
|
||||
return [][]string{
|
||||
{"-i", intf, "-j", "ACCEPT"},
|
||||
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesNftables() error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
@ -594,6 +635,15 @@ func (r *router) acceptForwardRules() {
|
||||
}
|
||||
r.conn.InsertRule(iifRule)
|
||||
|
||||
oifExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
@ -604,36 +654,72 @@ func (r *router) acceptForwardRules() {
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 2,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptForwardRulesNftables()
|
||||
}
|
||||
|
||||
return r.removeAcceptForwardRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||
@ -658,7 +744,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
@ -314,6 +326,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
@ -346,10 +362,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
// Verify actual nftables rule content
|
||||
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,11 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
@ -21,5 +24,41 @@ func GenerateRouteRuleKey(
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
) RuleID {
|
||||
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
|
||||
manager.SortPrefixes(sources)
|
||||
|
||||
h := sha256.New()
|
||||
|
||||
// Write all fields to the hasher, with delimiters
|
||||
h.Write([]byte("sources:"))
|
||||
for _, src := range sources {
|
||||
h.Write([]byte(src.String()))
|
||||
h.Write([]byte(","))
|
||||
}
|
||||
|
||||
h.Write([]byte("destination:"))
|
||||
h.Write([]byte(destination.String()))
|
||||
|
||||
h.Write([]byte("proto:"))
|
||||
h.Write([]byte(proto))
|
||||
|
||||
h.Write([]byte("sPort:"))
|
||||
if sPort != nil {
|
||||
h.Write([]byte(sPort.String()))
|
||||
} else {
|
||||
h.Write([]byte("<nil>"))
|
||||
}
|
||||
|
||||
h.Write([]byte("dPort:"))
|
||||
if dPort != nil {
|
||||
h.Write([]byte(dPort.String()))
|
||||
} else {
|
||||
h.Write([]byte("<nil>"))
|
||||
}
|
||||
|
||||
h.Write([]byte("action:"))
|
||||
h.Write([]byte(strconv.Itoa(int(action))))
|
||||
hash := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
// prepend destination prefix to be able to identify the rule
|
||||
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
|
||||
}
|
||||
|
@ -82,8 +82,6 @@ type Conn struct {
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
@ -106,7 +104,8 @@ type Conn struct {
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
|
||||
endpointRelay *net.UDPAddr
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
|
||||
// for reconnection operations
|
||||
iCEDisconnected chan bool
|
||||
@ -257,8 +256,7 @@ func (conn *Conn) Close() {
|
||||
conn.wgProxyICE = nil
|
||||
}
|
||||
|
||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
if err != nil {
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
|
||||
defer conn.updateIceState(iceConnInfo)
|
||||
|
||||
if conn.currentConnPriority > priority {
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Infof("set ICE to active connection")
|
||||
|
||||
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
|
||||
if err != nil {
|
||||
return
|
||||
var (
|
||||
ep *net.UDPAddr
|
||||
wgProxy wgproxy.Proxy
|
||||
err error
|
||||
)
|
||||
if iceConnInfo.RelayedOnLocal {
|
||||
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
ep = wgProxy.EndpointAddr()
|
||||
conn.wgProxyICE = wgProxy
|
||||
} else {
|
||||
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolveUDPaddr")
|
||||
conn.handleConfigurationFailure(err, nil)
|
||||
return
|
||||
}
|
||||
ep = directEp
|
||||
}
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
|
||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
||||
if err != nil {
|
||||
if wgProxy != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close turn connection: %v", err)
|
||||
}
|
||||
}
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
}
|
||||
|
||||
if wgProxy != nil {
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyICE = wgProxy
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
|
||||
conn.statusICE.Set(StatusConnected)
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
}
|
||||
|
||||
@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
|
||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
||||
|
||||
if conn.wgProxyICE != nil {
|
||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// switch back to relay connection
|
||||
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
||||
if conn.isReadyToUpgrade() {
|
||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
||||
err := conn.configureWGEndpoint(conn.endpointRelay)
|
||||
if err != nil {
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||
conn.statusICE.Set(newState)
|
||||
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
conn.notifyReconnectLoopICEDisconnected(changed)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
if err := rci.relayedConn.Close(); err != nil {
|
||||
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("Relay connection is ready to use")
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
|
||||
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||
return
|
||||
}
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.endpointRelay = endpointUdpAddr
|
||||
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
|
||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
|
||||
if conn.currentConnPriority > connPriorityRelay {
|
||||
if conn.statusICE.Get() == StatusConnected {
|
||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
return
|
||||
}
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
}
|
||||
|
||||
conn.connIDRelay = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
||||
if err != nil {
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
}
|
||||
@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("relay connection is disconnected")
|
||||
conn.log.Debugf("relay connection is disconnected")
|
||||
|
||||
if conn.currentConnPriority == connPriorityRelay {
|
||||
log.Debugf("clean up WireGuard config")
|
||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
if err != nil {
|
||||
conn.log.Debugf("clean up WireGuard config")
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.endpointRelay = nil
|
||||
_ = conn.wgProxyRelay.CloseConn()
|
||||
conn.wgProxyRelay = nil
|
||||
}
|
||||
|
||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||
conn.statusRelay.Set(StatusDisconnected)
|
||||
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
conn.notifyReconnectLoopRelayDisconnected(changed)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
||||
Relayed: conn.isRelayed(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
if err != nil {
|
||||
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||
}
|
||||
}
|
||||
@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDRelay != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
|
||||
if !iceConnInfo.RelayedOnLocal {
|
||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
||||
}
|
||||
conn.log.Debugf("setup ice turn connection")
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
|
||||
if err != nil {
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
||||
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
||||
}
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
return wgProxy, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||
}
|
||||
|
||||
func (conn *Conn) iceP2PIsActive() bool {
|
||||
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.relayDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
|
||||
select {
|
||||
case conn.iCEDisconnected <- changed:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||
}
|
||||
}
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Work()
|
||||
}
|
||||
return ep, wgProxy, nil
|
||||
}
|
||||
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
|
@ -5,7 +5,6 @@ package ebpf
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
}
|
||||
|
||||
// AddTurnConn add new turn connection for the proxy
|
||||
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
|
||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||
|
||||
wgEndpoint := &net.UDPAddr{
|
||||
@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
|
||||
defer p.removeTurnConn(endpointPort)
|
||||
|
||||
var (
|
||||
err error
|
||||
n int
|
||||
)
|
||||
buf := make([]byte, 1500)
|
||||
for ctx.Err() == nil {
|
||||
n, err = remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err != io.EOF {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
|
||||
if ctx.Err() != nil || p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
@ -4,8 +4,13 @@ package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
@ -13,20 +18,55 @@ type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||
ctxConn, cancel := context.WithCancel(ctx)
|
||||
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("add turn conn: %w", err)
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
e.remoteConn = remoteConn
|
||||
e.cancel = cancel
|
||||
return addr, err
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
@ -7,6 +7,9 @@ import (
|
||||
|
||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
|
||||
AddTurnConn(ctx context.Context, turnConn net.Conn) error
|
||||
EndpointAddr() *net.UDPAddr
|
||||
Work()
|
||||
Pause()
|
||||
CloseConn() error
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
@ -15,13 +15,17 @@ import (
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||
@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn start the proxy with the given remote conn
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
// AddTurnConn
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
go p.proxyToRemote()
|
||||
go p.proxyToLocal()
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return p.localConn.LocalAddr(), err
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
return endpointUdpAddr
|
||||
}
|
||||
|
||||
// Work starts the proxy or resumes it if it was paused
|
||||
func (p *WGUserSpaceProxy) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
func (p *WGUserSpaceProxy) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
|
||||
}
|
||||
|
||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to remote loop: %s", err)
|
||||
@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
for ctx.Err() == nil {
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
||||
}
|
||||
|
||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
|
@ -1,12 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>CFBundleExecutable</key>
|
||||
<string>netbird-ui</string>
|
||||
<key>CFBundleIconFile</key>
|
||||
<string>Netbird</string>
|
||||
<key>LSUIElement</key>
|
||||
<string>1</string>
|
||||
</dict>
|
||||
</plist>
|
@ -8,11 +8,11 @@ cask "{{ $projectName }}" do
|
||||
if Hardware::CPU.intel?
|
||||
url "{{ $amdURL }}"
|
||||
sha256 "{{ crypto.SHA256 $amdFileBytes }}"
|
||||
app "netbird_ui_darwin_amd64", target: "Netbird UI.app"
|
||||
app "netbird_ui_darwin", target: "Netbird UI.app"
|
||||
else
|
||||
url "{{ $armURL }}"
|
||||
sha256 "{{ crypto.SHA256 $armFileBytes }}"
|
||||
app "netbird_ui_darwin_arm64", target: "Netbird UI.app"
|
||||
app "netbird_ui_darwin", target: "Netbird UI.app"
|
||||
end
|
||||
|
||||
depends_on formula: "netbird"
|
||||
@ -36,4 +36,4 @@ cask "{{ $projectName }}" do
|
||||
name "Netbird UI"
|
||||
desc "Netbird UI Client"
|
||||
homepage "https://www.netbird.io/"
|
||||
end
|
||||
end
|
||||
|
20
go.mod
20
go.mod
@ -19,8 +19,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/crypto v0.24.0
|
||||
golang.org/x/sys v0.21.0
|
||||
golang.org/x/crypto v0.28.0
|
||||
golang.org/x/sys v0.26.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
@ -38,6 +38,7 @@ require (
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/eko/gocache/v3 v3.1.1
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
@ -45,7 +46,7 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/google/nftables v0.2.0
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
@ -55,12 +56,12 @@ require (
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/magiconair/properties v1.8.7
|
||||
github.com/mattn/go-sqlite3 v1.14.19
|
||||
github.com/mdlayher/socket v0.4.1
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@ -90,10 +91,10 @@ require (
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/net v0.30.0
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/term v0.21.0
|
||||
golang.org/x/sync v0.8.0
|
||||
golang.org/x/term v0.25.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
@ -134,7 +135,6 @@ require (
|
||||
github.com/containerd/containerd v1.7.16 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
@ -222,7 +222,7 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
|
36
go.sum
36
go.sum
@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
|
||||
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
|
||||
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
|
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
||||
@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
|
||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
@ -780,8 +780,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@ -877,8 +877,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@ -907,8 +907,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@ -980,8 +980,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@ -989,8 +989,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
|
||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
|
||||
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
||||
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@ -1005,8 +1005,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
@ -873,7 +873,7 @@ services:
|
||||
zitadel:
|
||||
restart: 'always'
|
||||
networks: [netbird]
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
|
||||
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
env_file:
|
||||
- ./zitadel.env
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"math/rand"
|
||||
@ -44,12 +45,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
@ -177,6 +181,8 @@ type DefaultAccountManager struct {
|
||||
dnsDomain string
|
||||
peerLoginExpiry Scheduler
|
||||
|
||||
peerInactivityExpiry Scheduler
|
||||
|
||||
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
||||
userDeleteFromIDPEnabled bool
|
||||
|
||||
@ -194,6 +200,13 @@ type Settings struct {
|
||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||
PeerLoginExpiration time.Duration
|
||||
|
||||
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
|
||||
PeerInactivityExpirationEnabled bool
|
||||
|
||||
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
|
||||
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
|
||||
PeerInactivityExpiration time.Duration
|
||||
|
||||
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
|
||||
RegularUsersViewBlocked bool
|
||||
|
||||
@ -224,6 +237,9 @@ func (s *Settings) Copy() *Settings {
|
||||
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
||||
JWTAllowGroups: s.JWTAllowGroups,
|
||||
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: s.PeerInactivityExpiration,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
@ -605,6 +621,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer {
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetInactivePeers returns peers that have been expired by inactivity
|
||||
func (a *Account) GetInactivePeers() []*nbpeer.Peer {
|
||||
var peers []*nbpeer.Peer
|
||||
for _, inactivePeer := range a.GetPeersWithInactivity() {
|
||||
inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||
if inactive {
|
||||
peers = append(peers, inactivePeer)
|
||||
}
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||
func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) {
|
||||
peersWithExpiry := a.GetPeersWithInactivity()
|
||||
if len(peersWithExpiry) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
var nextExpiry *time.Duration
|
||||
for _, peer := range peersWithExpiry {
|
||||
if peer.Status.LoginExpired || peer.Status.Connected {
|
||||
continue
|
||||
}
|
||||
_, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||
if nextExpiry == nil || duration < *nextExpiry {
|
||||
// if expiration is below 1s return 1s duration
|
||||
// this avoids issues with ticker that can't be set to < 0
|
||||
if duration < time.Second {
|
||||
return time.Second, true
|
||||
}
|
||||
nextExpiry = &duration
|
||||
}
|
||||
}
|
||||
|
||||
if nextExpiry == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return *nextExpiry, true
|
||||
}
|
||||
|
||||
// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user
|
||||
func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer {
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
for _, peer := range a.Peers {
|
||||
if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetPeers returns a list of all Account peers
|
||||
func (a *Account) GetPeers() []*nbpeer.Peer {
|
||||
var peers []*nbpeer.Peer
|
||||
@ -971,6 +1041,7 @@ func BuildManager(
|
||||
dnsDomain: dnsDomain,
|
||||
eventStore: eventStore,
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
peerInactivityExpiry: NewDefaultScheduler(),
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
metrics: metrics,
|
||||
@ -1099,6 +1170,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
|
||||
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updatedAccount := account.UpdateSettings(newSettings)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@ -1109,6 +1185,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return updatedAccount, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
@ -1144,6 +1240,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context
|
||||
}
|
||||
}
|
||||
|
||||
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
|
||||
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting account %s expiring peers", account.Id)
|
||||
return account.GetNextInactivePeerExpiration()
|
||||
}
|
||||
|
||||
expiredPeers := account.GetInactivePeers()
|
||||
var peerIDs []string
|
||||
for _, peer := range expiredPeers {
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
}
|
||||
|
||||
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||
|
||||
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
|
||||
return account.GetNextInactivePeerExpiration()
|
||||
}
|
||||
|
||||
return account.GetNextInactivePeerExpiration()
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
|
||||
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
|
||||
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
|
||||
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
|
||||
}
|
||||
}
|
||||
|
||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||
// If ID is already in use (due to collision) we try one more time before returning error
|
||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
|
||||
@ -1284,7 +1417,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
@ -1299,28 +1432,39 @@ func isNil(i idp.Manager) bool {
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cachedAccount := &Account{
|
||||
Id: accountID,
|
||||
Users: make(map[string]*User),
|
||||
}
|
||||
for _, user := range accountUsers {
|
||||
cachedAccount.Users[user.Id] = user
|
||||
}
|
||||
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.AppMetadata.WTAccountID == account.Id {
|
||||
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||
// it was already set, so we skip the unnecessary update
|
||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||
account.Id, userID)
|
||||
accountID, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1544,48 +1688,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
||||
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims,
|
||||
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
|
||||
if claims.Domain != "" {
|
||||
account.IsDomainPrimaryAccount = primaryDomain
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
userObj := account.Users[claims.UserId]
|
||||
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
||||
account.Domain = lowerDomain
|
||||
}
|
||||
// prevent updating category for different domain until admin logs in
|
||||
if account.Domain == lowerDomain {
|
||||
account.DomainCategory = claims.DomainCategory
|
||||
}
|
||||
} else {
|
||||
if claims.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(ctx, account)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
newDomain := accountDomain
|
||||
newCategoty := domainCategory
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||
newDomain = lowerDomain
|
||||
}
|
||||
|
||||
if accountDomain == lowerDomain {
|
||||
newCategoty = claims.DomainCategory
|
||||
}
|
||||
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||
}
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
ctx context.Context,
|
||||
existingAcc *Account,
|
||||
primaryDomain bool,
|
||||
userAccountID string,
|
||||
domainAccountID string,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) error {
|
||||
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
|
||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1593,44 +1758,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
)
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
// if domain already has a primary account, add regular user
|
||||
if domainAcc != nil {
|
||||
account = domainAcc
|
||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.updateAccountDomainAttributes(ctx, account, claims, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
|
||||
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
|
||||
newAccount.Domain = lowerDomain
|
||||
newAccount.DomainCategory = claims.DomainCategory
|
||||
newAccount.IsDomainPrimaryAccount = true
|
||||
|
||||
return account, nil
|
||||
err = am.Store.SaveAccount(ctx, newAccount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
usersMap := make(map[string]*User)
|
||||
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
|
||||
return domainAccountID, nil
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
@ -1774,7 +1953,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if claims.UserId == "" {
|
||||
return "", "", fmt.Errorf("user ID is empty")
|
||||
return "", "", errors.New(emptyUserID)
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
@ -1962,16 +2141,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||
// if domain is not private or domain is invalid, it will return the account ID by user ID.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
// Use cases:
|
||||
//
|
||||
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
|
||||
// New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
|
||||
//
|
||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
|
||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
|
||||
//
|
||||
// New user + New account + Existing Public Domain -> create account, user role = admin
|
||||
// New user + New account + Existing Public Domain -> create account, user role = owner
|
||||
//
|
||||
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
||||
//
|
||||
@ -1981,98 +2161,124 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
|
||||
if claims.UserId == "" {
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
return "", errors.New(emptyUserID)
|
||||
}
|
||||
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
if claims.AccountId != "" {
|
||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
|
||||
}
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||
} else if claims.AccountId != "" {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||
return userAccountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
unlock := am.Store.AcquireGlobalLock(ctx)
|
||||
defer unlock()
|
||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||
if claims.AccountId != "" {
|
||||
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if err != nil {
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return "", err
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||
defer unlockAccount()
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account.Id, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
var domainAccount *Account
|
||||
if domainAccountID != "" {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
} else {
|
||||
// other error
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != "" {
|
||||
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userAccountID, nil
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
|
||||
}
|
||||
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
|
||||
}
|
||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return domainAccountID, nil, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
|
||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||
|
||||
// check again if the domain has a primary account because of simultaneous requests
|
||||
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
cancel()
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return domainAccountID, cancel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
|
||||
func handleNotFound(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
|
||||
return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
@ -2338,6 +2544,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
GroupsPropagationEnabled: true,
|
||||
RegularUsersViewBlocked: true,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
|
||||
type test struct {
|
||||
var (
|
||||
publicDomain = "public.com"
|
||||
privateDomain = "private.com"
|
||||
unknownDomain = "unknown.com"
|
||||
)
|
||||
|
||||
defaultInitAccount := initUserParams{
|
||||
Domain: publicDomain,
|
||||
UserId: "defaultUser",
|
||||
}
|
||||
|
||||
initUnknown := defaultInitAccount
|
||||
initUnknown.DomainCategory = UnknownCategory
|
||||
initUnknown.Domain = unknownDomain
|
||||
|
||||
privateInitAccount := defaultInitAccount
|
||||
privateInitAccount.Domain = privateDomain
|
||||
privateInitAccount.DomainCategory = PrivateCategory
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputClaims jwtclaims.AuthorizationClaims
|
||||
inputInitUserParams initUserParams
|
||||
@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
expectedPrimaryDomainStatus bool
|
||||
expectedCreatedBy string
|
||||
expectedUsers []string
|
||||
}
|
||||
|
||||
var (
|
||||
publicDomain = "public.com"
|
||||
privateDomain = "private.com"
|
||||
unknownDomain = "unknown.com"
|
||||
)
|
||||
|
||||
defaultInitAccount := initUserParams{
|
||||
Domain: publicDomain,
|
||||
UserId: "defaultUser",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "New User With Public Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
}{
|
||||
{
|
||||
name: "New User With Public Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomainCategory: "",
|
||||
expectedDomain: publicDomain,
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pub-domain-user",
|
||||
expectedUsers: []string{"pub-domain-user"},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomainCategory: "",
|
||||
expectedDomain: publicDomain,
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pub-domain-user",
|
||||
expectedUsers: []string{"pub-domain-user"},
|
||||
}
|
||||
|
||||
initUnknown := defaultInitAccount
|
||||
initUnknown.DomainCategory = UnknownCategory
|
||||
initUnknown.Domain = unknownDomain
|
||||
|
||||
testCase2 := test{
|
||||
name: "New User With Unknown Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: UnknownCategory,
|
||||
{
|
||||
name: "New User With Unknown Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: UnknownCategory,
|
||||
},
|
||||
inputInitUserParams: initUnknown,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: unknownDomain,
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "unknown-domain-user",
|
||||
expectedUsers: []string{"unknown-domain-user"},
|
||||
},
|
||||
inputInitUserParams: initUnknown,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: unknownDomain,
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "unknown-domain-user",
|
||||
expectedUsers: []string{"unknown-domain-user"},
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "New User With Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "New User With Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
}
|
||||
|
||||
privateInitAccount := defaultInitAccount
|
||||
privateInitAccount.Domain = privateDomain
|
||||
privateInitAccount.DomainCategory = PrivateCategory
|
||||
|
||||
testCase4 := test{
|
||||
name: "New Regular User With Existing Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "new-pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "New Regular User With Existing Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "new-pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputUpdateAttrs: true,
|
||||
inputInitUserParams: privateInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleUser,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
||||
},
|
||||
inputUpdateAttrs: true,
|
||||
inputInitUserParams: privateInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleUser,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
||||
}
|
||||
|
||||
testCase5 := test{
|
||||
name: "Existing User With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "Existing User With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
}
|
||||
|
||||
testCase6 := test{
|
||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputUpdateClaimAccount: true,
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
},
|
||||
inputUpdateClaimAccount: true,
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
}
|
||||
|
||||
testCase7 := test{
|
||||
name: "User With Private Category And Empty Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: "",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "User With Private Category And Empty Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: "",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: "",
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: "",
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
}
|
||||
|
||||
@ -2025,6 +2019,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
peers map[string]*nbpeer.Peer
|
||||
expectedPeers map[string]struct{}
|
||||
}
|
||||
testCases := []test{
|
||||
{
|
||||
name: "Peers with inactivity expiration disabled, no expired peers",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
InactivityExpirationEnabled: false,
|
||||
},
|
||||
"peer-2": {
|
||||
InactivityExpirationEnabled: false,
|
||||
},
|
||||
},
|
||||
expectedPeers: map[string]struct{}{},
|
||||
},
|
||||
{
|
||||
name: "Two peers expired",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
ID: "peer-1",
|
||||
InactivityExpirationEnabled: true,
|
||||
Status: &nbpeer.PeerStatus{
|
||||
LastSeen: time.Now().UTC().Add(-45 * time.Second),
|
||||
Connected: false,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LastLogin: time.Now().UTC().Add(-30 * time.Minute),
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
ID: "peer-2",
|
||||
InactivityExpirationEnabled: true,
|
||||
Status: &nbpeer.PeerStatus{
|
||||
LastSeen: time.Now().UTC().Add(-45 * time.Second),
|
||||
Connected: false,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LastLogin: time.Now().UTC().Add(-2 * time.Hour),
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-3": {
|
||||
ID: "peer-3",
|
||||
InactivityExpirationEnabled: true,
|
||||
Status: &nbpeer.PeerStatus{
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LastLogin: time.Now().UTC().Add(-1 * time.Hour),
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expectedPeers: map[string]struct{}{
|
||||
"peer-1": {},
|
||||
"peer-2": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Peers: testCase.peers,
|
||||
Settings: &Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
expiredPeers := account.GetInactivePeers()
|
||||
assert.Len(t, expiredPeers, len(testCase.expectedPeers))
|
||||
for _, peer := range expiredPeers {
|
||||
if _, ok := testCase.expectedPeers[peer.ID]; !ok {
|
||||
t.Fatalf("expected to have peer %s expired", peer.ID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
@ -2094,6 +2172,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetPeersWithInactivity(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
peers map[string]*nbpeer.Peer
|
||||
expectedPeers map[string]struct{}
|
||||
}
|
||||
|
||||
testCases := []test{
|
||||
{
|
||||
name: "No account peers, no peers with expiration",
|
||||
peers: map[string]*nbpeer.Peer{},
|
||||
expectedPeers: map[string]struct{}{},
|
||||
},
|
||||
{
|
||||
name: "Peers with login expiration disabled, no peers with expiration",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
InactivityExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
InactivityExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expectedPeers: map[string]struct{}{},
|
||||
},
|
||||
{
|
||||
name: "Peers with login expiration enabled, return peers with expiration",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
ID: "peer-1",
|
||||
InactivityExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
InactivityExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expectedPeers: map[string]struct{}{
|
||||
"peer-1": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Peers: testCase.peers,
|
||||
}
|
||||
|
||||
actual := account.GetPeersWithInactivity()
|
||||
assert.Len(t, actual, len(testCase.expectedPeers))
|
||||
if len(testCase.expectedPeers) > 0 {
|
||||
for k := range testCase.expectedPeers {
|
||||
contains := false
|
||||
for _, peer := range actual {
|
||||
if k == peer.ID {
|
||||
contains = true
|
||||
}
|
||||
}
|
||||
assert.True(t, contains)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
@ -2255,6 +2402,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
peers map[string]*nbpeer.Peer
|
||||
expiration time.Duration
|
||||
expirationEnabled bool
|
||||
expectedNextRun bool
|
||||
expectedNextExpiration time.Duration
|
||||
}
|
||||
|
||||
expectedNextExpiration := time.Minute
|
||||
testCases := []test{
|
||||
{
|
||||
name: "No peers, no expiration",
|
||||
peers: map[string]*nbpeer.Peer{},
|
||||
expiration: time.Second,
|
||||
expirationEnabled: false,
|
||||
expectedNextRun: false,
|
||||
expectedNextExpiration: time.Duration(0),
|
||||
},
|
||||
{
|
||||
name: "No connected peers, no expiration",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: false,
|
||||
},
|
||||
InactivityExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: false,
|
||||
},
|
||||
InactivityExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
expirationEnabled: false,
|
||||
expectedNextRun: false,
|
||||
expectedNextExpiration: time.Duration(0),
|
||||
},
|
||||
{
|
||||
name: "Connected peers with disabled expiration, no expiration",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
},
|
||||
InactivityExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
},
|
||||
InactivityExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
expirationEnabled: false,
|
||||
expectedNextRun: false,
|
||||
expectedNextExpiration: time.Duration(0),
|
||||
},
|
||||
{
|
||||
name: "Expired peers, no expiration",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
LoginExpired: true,
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
LoginExpired: true,
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
expirationEnabled: false,
|
||||
expectedNextRun: false,
|
||||
expectedNextExpiration: time.Duration(0),
|
||||
},
|
||||
{
|
||||
name: "To be expired peer, return expiration",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: false,
|
||||
LoginExpired: false,
|
||||
LastSeen: time.Now().Add(-1 * time.Second),
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
LastLogin: time.Now().UTC(),
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
LoginExpired: true,
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Minute,
|
||||
expirationEnabled: false,
|
||||
expectedNextRun: true,
|
||||
expectedNextExpiration: expectedNextExpiration,
|
||||
},
|
||||
{
|
||||
name: "Peers added with setup keys, no expiration",
|
||||
peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
InactivityExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
expirationEnabled: false,
|
||||
expectedNextRun: false,
|
||||
expectedNextExpiration: time.Duration(0),
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Peers: testCase.peers,
|
||||
Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
|
||||
}
|
||||
|
||||
expiration, ok := account.GetNextInactivePeerExpiration()
|
||||
assert.Equal(t, testCase.expectedNextRun, ok)
|
||||
if testCase.expectedNextRun {
|
||||
assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration)
|
||||
} else {
|
||||
assert.Equal(t, expiration, testCase.expectedNextExpiration)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
@ -139,6 +139,13 @@ const (
|
||||
PostureCheckUpdated Activity = 61
|
||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||
PostureCheckDeleted Activity = 62
|
||||
|
||||
PeerInactivityExpirationEnabled Activity = 63
|
||||
PeerInactivityExpirationDisabled Activity = 64
|
||||
|
||||
AccountPeerInactivityExpirationEnabled Activity = 65
|
||||
AccountPeerInactivityExpirationDisabled Activity = 66
|
||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{
|
||||
PostureCheckCreated: {"Posture check created", "posture.check.created"},
|
||||
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
|
||||
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
|
||||
|
||||
PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"},
|
||||
PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"},
|
||||
|
||||
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
|
||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
account.Settings = &Settings{
|
||||
PeerLoginExpirationEnabled: false,
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||
}
|
||||
|
||||
if req.Settings.Extra != nil {
|
||||
|
@ -54,6 +54,14 @@ components:
|
||||
description: Period of time after which peer login expires (seconds).
|
||||
type: integer
|
||||
example: 43200
|
||||
peer_inactivity_expiration_enabled:
|
||||
description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||
type: boolean
|
||||
example: true
|
||||
peer_inactivity_expiration:
|
||||
description: Period of time of inactivity after which peer session expires (seconds).
|
||||
type: integer
|
||||
example: 43200
|
||||
regular_users_view_blocked:
|
||||
description: Allows blocking regular users from viewing parts of the system.
|
||||
type: boolean
|
||||
@ -81,6 +89,8 @@ components:
|
||||
required:
|
||||
- peer_login_expiration_enabled
|
||||
- peer_login_expiration
|
||||
- peer_inactivity_expiration_enabled
|
||||
- peer_inactivity_expiration
|
||||
- regular_users_view_blocked
|
||||
AccountExtraSettings:
|
||||
type: object
|
||||
@ -243,6 +253,9 @@ components:
|
||||
login_expiration_enabled:
|
||||
type: boolean
|
||||
example: false
|
||||
inactivity_expiration_enabled:
|
||||
type: boolean
|
||||
example: false
|
||||
approval_required:
|
||||
description: (Cloud only) Indicates whether peer needs approval
|
||||
type: boolean
|
||||
@ -251,6 +264,7 @@ components:
|
||||
- name
|
||||
- ssh_enabled
|
||||
- login_expiration_enabled
|
||||
- inactivity_expiration_enabled
|
||||
Peer:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PeerMinimum'
|
||||
@ -327,6 +341,10 @@ components:
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2023-05-05T09:00:35.477782Z"
|
||||
inactivity_expiration_enabled:
|
||||
description: Indicates whether peer inactivity expiration has been enabled or not
|
||||
type: boolean
|
||||
example: false
|
||||
approval_required:
|
||||
description: (Cloud only) Indicates whether peer needs approval
|
||||
type: boolean
|
||||
@ -354,6 +372,7 @@ components:
|
||||
- last_seen
|
||||
- login_expiration_enabled
|
||||
- login_expired
|
||||
- inactivity_expiration_enabled
|
||||
- os
|
||||
- ssh_enabled
|
||||
- user_id
|
||||
|
@ -220,6 +220,12 @@ type AccountSettings struct {
|
||||
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
||||
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
||||
|
||||
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
|
||||
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
|
||||
|
||||
// PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||
PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"`
|
||||
|
||||
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
||||
PeerLoginExpiration int `json:"peer_login_expiration"`
|
||||
|
||||
@ -538,6 +544,9 @@ type Peer struct {
|
||||
// Id Peer ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||
|
||||
// Ip Peer's IP address
|
||||
Ip string `json:"ip"`
|
||||
|
||||
@ -613,6 +622,9 @@ type PeerBatch struct {
|
||||
// Id Peer ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||
|
||||
// Ip Peer's IP address
|
||||
Ip string `json:"ip"`
|
||||
|
||||
@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string
|
||||
// PeerRequest defines model for PeerRequest.
|
||||
type PeerRequest struct {
|
||||
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
||||
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
Name string `json:"name"`
|
||||
SshEnabled bool `json:"ssh_enabled"`
|
||||
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
||||
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||
Name string `json:"name"`
|
||||
SshEnabled bool `json:"ssh_enabled"`
|
||||
}
|
||||
|
||||
// PersonalAccessToken defines model for PersonalAccessToken.
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@ -14,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PeersHandler is a handler that returns peers of the account
|
||||
@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
|
||||
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
|
||||
}
|
||||
|
||||
if req.ApprovalRequired != nil {
|
||||
@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
}
|
||||
|
||||
return &api.Peer{
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.LastLogin,
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
ApprovalRequired: !approved,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.LastLogin,
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
ApprovalRequired: !approved,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,6 +111,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return err
|
||||
}
|
||||
|
||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
@ -139,25 +164,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
|
||||
if oldStatus.LoginExpired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
return oldStatus.LoginExpired, nil
|
||||
}
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
@ -222,6 +237,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
|
||||
|
||||
if !peer.AddedWithSSOLogin() {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
|
||||
}
|
||||
|
||||
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
|
||||
|
||||
event := activity.PeerInactivityExpirationEnabled
|
||||
if !update.InactivityExpirationEnabled {
|
||||
event = activity.PeerInactivityExpirationDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@ -454,23 +488,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
newPeer = &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: freeLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: freeLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: registrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
InactivityExpirationEnabled: addedByUser,
|
||||
}
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
|
@ -38,6 +38,8 @@ type Peer struct {
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool `diff:"-"`
|
||||
|
||||
InactivityExpirationEnabled bool `diff:"-"`
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time `diff:"-"`
|
||||
// CreatedAt records the time the peer was created
|
||||
@ -187,6 +189,7 @@ func (p *Peer) Copy() *Peer {
|
||||
CreatedAt: p.CreatedAt,
|
||||
Ephemeral: p.Ephemeral,
|
||||
Location: p.Location,
|
||||
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@ -219,6 +222,22 @@ func (p *Peer) MarkLoginExpired(expired bool) {
|
||||
p.Status = newStatus
|
||||
}
|
||||
|
||||
// SessionExpired indicates whether the peer's session has expired or not.
|
||||
// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired.
|
||||
// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired).
|
||||
// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
|
||||
// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
|
||||
// Only peers added by interactive SSO login can be expired.
|
||||
func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||
if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected {
|
||||
return false, 0
|
||||
}
|
||||
expiresAt := p.Status.LastSeen.Add(expiresIn)
|
||||
now := time.Now()
|
||||
timeLeft := expiresAt.Sub(now)
|
||||
return timeLeft <= 0, timeLeft
|
||||
}
|
||||
|
||||
// LoginExpired indicates whether the peer's login has expired or not.
|
||||
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
||||
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
||||
|
@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeer_SessionExpired(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expirationEnabled bool
|
||||
lastLogin time.Time
|
||||
connected bool
|
||||
expected bool
|
||||
accountSettings *Settings
|
||||
}{
|
||||
{
|
||||
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
|
||||
expirationEnabled: false,
|
||||
connected: false,
|
||||
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||
accountSettings: &Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Peer Inactivity Should Expire",
|
||||
expirationEnabled: true,
|
||||
connected: false,
|
||||
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||
accountSettings: &Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Peer Inactivity Should Not Expire",
|
||||
expirationEnabled: true,
|
||||
connected: true,
|
||||
lastLogin: time.Now().UTC(),
|
||||
accountSettings: &Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range tt {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
peerStatus := &nbpeer.PeerStatus{
|
||||
Connected: c.connected,
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
InactivityExpirationEnabled: c.expirationEnabled,
|
||||
LastLogin: c.lastLogin,
|
||||
Status: peerStatus,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration)
|
||||
assert.Equal(t, expired, c.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
|
@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
accountCopy := Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
IsDomainPrimaryAccount: isPrimaryDomain,
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account %s", accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
@ -1117,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||
// we may need to reconsider changing the types.
|
||||
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
|
||||
if s.storeEngine == PostgresStoreEngine {
|
||||
query = query.Order("json_array_length(peers::json) DESC")
|
||||
} else {
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
}
|
||||
|
||||
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
|
@ -1191,3 +1191,76 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, len(account.Users))
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
t.Run("Should update attributes with public domain", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "example.com"
|
||||
category := "public"
|
||||
IsDomainPrimaryAccount := false
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, domain, account.Domain)
|
||||
require.Equal(t, category, account.DomainCategory)
|
||||
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||
})
|
||||
|
||||
t.Run("Should update attributes with private domain", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, domain, account.Domain)
|
||||
require.Equal(t, category, account.DomainCategory)
|
||||
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||
})
|
||||
|
||||
t.Run("Should fail when account does not exist", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "All", group.Name)
|
||||
}
|
||||
|
@ -58,9 +58,11 @@ type Store interface {
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
|
@ -20,10 +20,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
UserRoleOwner UserRole = "owner"
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
UserRoleOwner UserRole = "owner"
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
UserRoleBillingAdmin UserRole = "billing_admin"
|
||||
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
@ -42,6 +43,8 @@ func StrRoleToUserRole(strRole string) UserRole {
|
||||
return UserRoleAdmin
|
||||
case "user":
|
||||
return UserRoleUser
|
||||
case "billing_admin":
|
||||
return UserRoleBillingAdmin
|
||||
default:
|
||||
return UserRoleUnknown
|
||||
}
|
||||
|
@ -16,8 +16,10 @@ const (
|
||||
type Metrics struct {
|
||||
metric.Meter
|
||||
|
||||
TransferBytesSent metric.Int64Counter
|
||||
TransferBytesRecv metric.Int64Counter
|
||||
TransferBytesSent metric.Int64Counter
|
||||
TransferBytesRecv metric.Int64Counter
|
||||
AuthenticationTime metric.Float64Histogram
|
||||
PeerStoreTime metric.Float64Histogram
|
||||
|
||||
peers metric.Int64UpDownCounter
|
||||
peerActivityChan chan string
|
||||
@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Metrics{
|
||||
Meter: meter,
|
||||
TransferBytesSent: bytesSent,
|
||||
TransferBytesRecv: bytesRecv,
|
||||
peers: peers,
|
||||
Meter: meter,
|
||||
TransferBytesSent: bytesSent,
|
||||
TransferBytesRecv: bytesRecv,
|
||||
AuthenticationTime: authTime,
|
||||
PeerStoreTime: peerStoreTime,
|
||||
peers: peers,
|
||||
|
||||
ctx: ctx,
|
||||
peerActivityChan: make(chan string, 10),
|
||||
@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
|
||||
m.peerLastActive[id] = time.Time{}
|
||||
}
|
||||
|
||||
// RecordAuthenticationTime measures the time taken for peer authentication
|
||||
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
|
||||
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// RecordPeerStoreTime measures the time to store the peer in map
|
||||
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
|
||||
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
||||
func (m *Metrics) PeerDisconnected(id string) {
|
||||
m.peers.Add(m.ctx, -1)
|
||||
@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getStandardBucketBoundaries() []float64 {
|
||||
return []float64{
|
||||
0.1,
|
||||
0.5,
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
50,
|
||||
100,
|
||||
500,
|
||||
1000,
|
||||
5000,
|
||||
10000,
|
||||
}
|
||||
}
|
||||
|
153
relay/server/handshake.go
Normal file
153
relay/server/handshake.go
Normal file
@ -0,0 +1,153 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
//nolint:staticcheck
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
)
|
||||
|
||||
// preparedMsg contains the marshalled success response messages
|
||||
type preparedMsg struct {
|
||||
responseHelloMsg []byte
|
||||
responseAuthMsg []byte
|
||||
}
|
||||
|
||||
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
|
||||
rhm, err := marshalResponseHelloMsg(instanceURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ram, err := messages.MarshalAuthResponse(instanceURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
|
||||
}
|
||||
|
||||
return &preparedMsg{
|
||||
responseHelloMsg: rhm,
|
||||
responseAuthMsg: ram,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
||||
addr := &address.Address{URL: instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal response address: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
|
||||
}
|
||||
return responseMsg, nil
|
||||
}
|
||||
|
||||
type handshake struct {
|
||||
conn net.Conn
|
||||
validator auth.Validator
|
||||
preparedMsg *preparedMsg
|
||||
|
||||
handshakeMethodAuth bool
|
||||
peerID string
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := h.conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
var (
|
||||
bytePeerID []byte
|
||||
peerID string
|
||||
)
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
case messages.MsgTypeAuth:
|
||||
h.handshakeMethodAuth = true
|
||||
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.peerID = peerID
|
||||
return bytePeerID, nil
|
||||
}
|
||||
|
||||
func (h *handshake) handshakeResponse() error {
|
||||
var responseMsg []byte
|
||||
if h.handshakeMethodAuth {
|
||||
responseMsg = h.preparedMsg.responseAuthMsg
|
||||
} else {
|
||||
responseMsg = h.preparedMsg.responseHelloMsg
|
||||
}
|
||||
|
||||
if _, err := h.conn.Write(responseMsg); err != nil {
|
||||
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
||||
//nolint:staticcheck
|
||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return rawPeerID, peerID, nil
|
||||
}
|
||||
|
||||
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
|
||||
if err := h.validator.Validate(authPayload); err != nil {
|
||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return rawPeerID, peerID, nil
|
||||
}
|
@ -7,16 +7,13 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/auth"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
//nolint:staticcheck
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
@ -28,6 +25,7 @@ type Relay struct {
|
||||
|
||||
store *Store
|
||||
instanceURL string
|
||||
preparedMsg *preparedMsg
|
||||
|
||||
closed bool
|
||||
closeMu sync.RWMutex
|
||||
@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
|
||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||
}
|
||||
|
||||
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||
if err != nil {
|
||||
metricsCancel()
|
||||
return nil, fmt.Errorf("prepare message: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
||||
|
||||
// Accept start to handle a new peer connection
|
||||
func (r *Relay) Accept(conn net.Conn) {
|
||||
acceptTime := time.Now()
|
||||
r.closeMu.RLock()
|
||||
defer r.closeMu.RUnlock()
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
|
||||
peerID, err := r.handshake(conn)
|
||||
h := handshake{
|
||||
conn: conn,
|
||||
validator: r.validator,
|
||||
preparedMsg: r.preparedMsg,
|
||||
}
|
||||
peerID, err := h.handshakeReceive()
|
||||
if err != nil {
|
||||
log.Errorf("failed to handshake: %s", err)
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
if cErr := conn.Close(); cErr != nil {
|
||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||
}
|
||||
return
|
||||
@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
|
||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||
storeTime := time.Now()
|
||||
r.store.AddPeer(peer)
|
||||
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||
r.metrics.PeerConnected(peer.String())
|
||||
go func() {
|
||||
peer.Work()
|
||||
@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
|
||||
peer.log.Debugf("relay connection closed")
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
}()
|
||||
|
||||
if err := h.handshakeResponse(); err != nil {
|
||||
log.Errorf("failed to send handshake response, close peer: %s", err)
|
||||
peer.Close()
|
||||
}
|
||||
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
|
||||
}
|
||||
|
||||
// Shutdown closes the relay server
|
||||
@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
||||
func (r *Relay) InstanceURL() string {
|
||||
return r.instanceURL
|
||||
}
|
||||
|
||||
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, messages.MaxHandshakeSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
var (
|
||||
responseMsg []byte
|
||||
peerID []byte
|
||||
)
|
||||
switch msgType {
|
||||
//nolint:staticcheck
|
||||
case messages.MsgTypeHello:
|
||||
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
case messages.MsgTypeAuth:
|
||||
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = conn.Write(responseMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
||||
//nolint:staticcheck
|
||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
||||
|
||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
addr := &address.Address{URL: r.instanceURL}
|
||||
addrData, err := addr.Marshal()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
||||
}
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
||||
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||
}
|
||||
|
||||
peerID := messages.HashIDToString(rawPeerID)
|
||||
|
||||
if err := r.validator.Validate(authPayload); err != nil {
|
||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
||||
}
|
||||
|
||||
return rawPeerID, responseMsg, nil
|
||||
}
|
||||
|
@ -86,7 +86,7 @@ download_release_binary() {
|
||||
|
||||
# Unzip the app and move to INSTALL_DIR
|
||||
unzip -q -o "$BINARY_NAME"
|
||||
mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
|
||||
mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
|
||||
else
|
||||
${SUDO} mkdir -p "$INSTALL_DIR"
|
||||
tar -xzvf "$BINARY_NAME"
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
@ -13,8 +14,6 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
||||
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/signal/peer"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
|
@ -11,7 +11,8 @@ import (
|
||||
|
||||
const (
|
||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||
NetbirdFwmark = 0x1BD00
|
||||
NetbirdFwmark = 0x1BD00
|
||||
PreroutingFwmark = 0x1BD01
|
||||
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user