mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-13 18:31:18 +01:00
Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association
This commit is contained in:
commit
53218f99bc
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.14"
|
SIGN_PIPE_VER: "v0.0.16"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||||
@ -223,4 +223,4 @@ jobs:
|
|||||||
repo: netbirdio/sign-pipelines
|
repo: netbirdio/sign-pipelines
|
||||||
ref: ${{ env.SIGN_PIPE_VER }}
|
ref: ${{ env.SIGN_PIPE_VER }}
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||||
|
@ -96,6 +96,9 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
universal_binaries:
|
||||||
|
- id: netbird
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
- builds:
|
- builds:
|
||||||
- netbird
|
- netbird
|
||||||
|
@ -23,6 +23,9 @@ builds:
|
|||||||
tags:
|
tags:
|
||||||
- load_wgnt_from_rsrc
|
- load_wgnt_from_rsrc
|
||||||
|
|
||||||
|
universal_binaries:
|
||||||
|
- id: netbird-ui-darwin
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
- builds:
|
- builds:
|
||||||
- netbird-ui-darwin
|
- netbird-ui-darwin
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -21,13 +22,19 @@ const (
|
|||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
spec []string
|
||||||
|
position int
|
||||||
|
}
|
||||||
|
|
||||||
type aclManager struct {
|
type aclManager struct {
|
||||||
iptablesClient *iptables.IPTables
|
iptablesClient *iptables.IPTables
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
routingFwChainName string
|
||||||
|
|
||||||
entries map[string][][]string
|
entries map[string][][]string
|
||||||
ipsetStore *ipsetStore
|
optionalEntries map[string][]entry
|
||||||
|
ipsetStore *ipsetStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||||
@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
|||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
routingFwChainName: routingFwChainName,
|
routingFwChainName: routingFwChainName,
|
||||||
|
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
ipsetStore: newIpsetStore(),
|
optionalEntries: make(map[string][]entry),
|
||||||
|
ipsetStore: newIpsetStore(),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ipset.Init()
|
err := ipset.Init()
|
||||||
@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.seedInitialEntries()
|
m.seedInitialEntries()
|
||||||
|
m.seedInitialOptionalEntries()
|
||||||
|
|
||||||
err = m.cleanChains()
|
err = m.cleanChains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
for _, rule := range m.entries["PREROUTING"] {
|
||||||
|
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for chainName, entries := range m.optionalEntries {
|
||||||
|
for _, entry := range entries {
|
||||||
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||||
|
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clear(m.optionalEntries)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
|
m.optionalEntries["FORWARD"] = []entry{
|
||||||
|
{
|
||||||
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
|
||||||
|
position: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.optionalEntries["PREROUTING"] = []entry{
|
||||||
|
{
|
||||||
|
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
|
||||||
|
position: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||||
}
|
}
|
||||||
|
@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
sources [] netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
|
@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
|
|
||||||
log.Debug("flushing routing related tables")
|
log.Debug("flushing routing related tables")
|
||||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||||
table := tableFilter
|
table := r.getTableForChain(chain)
|
||||||
if chain == chainRTNAT {
|
|
||||||
table = tableNat
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||||
if err := r.createAndSetupChain(chain); err != nil {
|
if err := r.createAndSetupChain(chain); err != nil {
|
||||||
return fmt.Errorf("create chain %s: %v", chain, err)
|
return fmt.Errorf("create chain %s: %w", chain, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||||
return fmt.Errorf("insert established rule: %v", err)
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.addJumpRules()
|
if err := r.addJumpRules(); err != nil {
|
||||||
|
return fmt.Errorf("add jump rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createAndSetupChain(chain string) error {
|
func (r *router) createAndSetupChain(chain string) error {
|
||||||
@ -432,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||||
intdir := "-i"
|
intdir := "-i"
|
||||||
|
lointdir := "-o"
|
||||||
if inverse {
|
if inverse {
|
||||||
intdir = "-o"
|
intdir = "-o"
|
||||||
|
lointdir = "-i"
|
||||||
}
|
}
|
||||||
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
|
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
|
@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
|||||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||||
func GenerateSetName(sources []netip.Prefix) string {
|
func GenerateSetName(sources []netip.Prefix) string {
|
||||||
// sort for consistent naming
|
// sort for consistent naming
|
||||||
sortPrefixes(sources)
|
SortPrefixes(sources)
|
||||||
|
|
||||||
var sourcesStr strings.Builder
|
var sourcesStr strings.Builder
|
||||||
for _, src := range sources {
|
for _, src := range sources {
|
||||||
@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
|||||||
return merged
|
return merged
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortPrefixes sorts the given slice of netip.Prefix in place.
|
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||||
func sortPrefixes(prefixes []netip.Prefix) {
|
func SortPrefixes(prefixes []netip.Prefix) {
|
||||||
sort.Slice(prefixes, func(i, j int) bool {
|
sort.Slice(prefixes, func(i, j int) bool {
|
||||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||||
if addrCmp != 0 {
|
if addrCmp != 0 {
|
||||||
|
@ -11,12 +11,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -29,6 +31,7 @@ const (
|
|||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
chainNameOutputFilter = "netbird-acl-output-filter"
|
chainNameOutputFilter = "netbird-acl-output-filter"
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
|
chainNamePrerouting = "netbird-rt-prerouting"
|
||||||
|
|
||||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||||
)
|
)
|
||||||
@ -40,15 +43,14 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type AclManager struct {
|
type AclManager struct {
|
||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
sConn *nftables.Conn
|
sConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routeingFwChainName string
|
routingFwChainName string
|
||||||
|
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
chainInputRules *nftables.Chain
|
chainInputRules *nftables.Chain
|
||||||
chainOutputRules *nftables.Chain
|
chainOutputRules *nftables.Chain
|
||||||
chainFwFilter *nftables.Chain
|
|
||||||
|
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
rules map[string]*Rule
|
rules map[string]*Rule
|
||||||
@ -61,7 +63,7 @@ type iFaceMapper interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
|
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||||
// sConn is used for creating sets and adding/removing elements from them
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
// it's differ then rConn (which does create new conn for each flush operation)
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
// and is permanent. Using same connection for both type of operations
|
// and is permanent. Using same connection for both type of operations
|
||||||
@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
m := &AclManager{
|
m := &AclManager{
|
||||||
rConn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
sConn: sConn,
|
sConn: sConn,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
workTable: table,
|
workTable: table,
|
||||||
routeingFwChainName: routeingFwChainName,
|
routingFwChainName: routingFwChainName,
|
||||||
|
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
rules: make(map[string]*Rule),
|
rules: make(map[string]*Rule),
|
||||||
@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// netbird-acl-forward-filter
|
// netbird-acl-forward-filter
|
||||||
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||||
m.addJumpRulesToRtForward() // to netbird-rt-fwd
|
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||||
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
|
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||||
|
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||||
|
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addJumpRulesToRtForward() {
|
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||||
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
|
// netbird peer IP.
|
||||||
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
|
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNamePrerouting,
|
||||||
|
Table: m.workTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
})
|
||||||
|
|
||||||
|
m.addPreroutingRule(preroutingChain)
|
||||||
|
|
||||||
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
||||||
|
m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: preroutingChain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Fib{
|
||||||
|
Register: 1,
|
||||||
|
ResultADDRTYPE: true,
|
||||||
|
FlagDADDR: true,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||||
|
m.rConn.InsertRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: chainFwFilter,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictJump,
|
||||||
|
Chain: m.chainInputRules.Name,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||||
expressions := []expr.Any{
|
expressions := []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() {
|
|||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictJump,
|
Kind: expr.VerdictJump,
|
||||||
Chain: m.routeingFwChainName,
|
Chain: m.routingFwChainName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Chain: m.chainFwFilter,
|
Chain: chainFwFilter,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
|
|||||||
return chain
|
return chain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
|
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||||
polAccept := nftables.ChainPolicyAccept
|
polAccept := nftables.ChainPolicyAccept
|
||||||
chain := &nftables.Chain{
|
chain := &nftables.Chain{
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
|
|||||||
rule := &nftables.Rule{
|
rule := &nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: []expr.Any{
|
Exprs: getEstablishedExprs(1),
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeySTATE,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.InsertRule(rule)
|
conn.InsertRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getEstablishedExprs(register uint32) []expr.Any {
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: register,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: register,
|
||||||
|
DestRegister: register,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: register,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: []byte{0, 0, 0, 0},
|
Data: []byte{0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
|
&expr.Counter{},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
},
|
},
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
@ -24,7 +26,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-nat"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
@ -80,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.cleanUpDefaultForwardRules()
|
err = r.removeAcceptForwardRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
}
|
}
|
||||||
@ -97,40 +99,7 @@ func (r *router) Reset() error {
|
|||||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||||
r.ipsetCounter.Clear()
|
r.ipsetCounter.Clear()
|
||||||
|
|
||||||
return r.cleanUpDefaultForwardRules()
|
return r.removeAcceptForwardRules()
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) cleanUpDefaultForwardRules() error {
|
|
||||||
if r.filterTable == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list chains: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
@ -149,7 +118,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
|
|
||||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingFw,
|
Name: chainNameRoutingFw,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
@ -157,25 +125,28 @@ func (r *router) createContainers() error {
|
|||||||
|
|
||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
|
|
||||||
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
|
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingNat,
|
Name: chainNameRoutingNat,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
Priority: nftables.ChainPriorityNATSource - 1,
|
Priority: &prio,
|
||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
r.acceptForwardRules()
|
if err := r.acceptForwardRules(); err != nil {
|
||||||
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
err := r.refreshRulesMap()
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.conn.Flush()
|
if err := r.conn.Flush(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -188,6 +159,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
@ -248,9 +220,18 @@ func (r *router) AddRouteFiltering(
|
|||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
|
rule = r.conn.AddRule(rule)
|
||||||
|
|
||||||
return ruleKey, r.conn.Flush()
|
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[string(ruleKey)] = rule
|
||||||
|
|
||||||
|
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||||
@ -288,6 +269,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nftRule.Handle == 0 {
|
||||||
|
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
setName := r.findSetNameInRule(nftRule)
|
setName := r.findSetNameInRule(nftRule)
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
@ -440,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||||
|
|
||||||
dir := expr.MetaKeyIIFNAME
|
dir := expr.MetaKeyIIFNAME
|
||||||
|
notDir := expr.MetaKeyOIFNAME
|
||||||
if pair.Inverse {
|
if pair.Inverse {
|
||||||
dir = expr.MetaKeyOIFNAME
|
dir = expr.MetaKeyOIFNAME
|
||||||
|
notDir = expr.MetaKeyIIFNAME
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lo := ifname("lo")
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: dir,
|
Key: dir,
|
||||||
@ -455,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: intf,
|
Data: intf,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||||
|
&expr.Meta{
|
||||||
|
Key: notDir,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: lo,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
@ -562,19 +562,60 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
// that our traffic is not dropped by existing rules there.
|
// that our traffic is not dropped by existing rules there.
|
||||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
func (r *router) acceptForwardRules() {
|
func (r *router) acceptForwardRules() error {
|
||||||
if r.filterTable == nil {
|
if r.filterTable == nil {
|
||||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fw := "iptables"
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
log.Debugf("Used %s to add accept forward rules", fw)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
|
ipt, err := iptables.New()
|
||||||
|
if err != nil {
|
||||||
|
// filter table exists but iptables is not
|
||||||
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
|
||||||
|
fw = "nftables"
|
||||||
|
return r.acceptForwardRulesNftables()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.acceptForwardRulesIptables(ipt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
|
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||||
|
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||||
|
} else {
|
||||||
|
log.Debugf("added iptables rule: %v", rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) getAcceptForwardRules() [][]string {
|
||||||
|
intf := r.wgIface.Name()
|
||||||
|
return [][]string{
|
||||||
|
{"-i", intf, "-j", "ACCEPT"},
|
||||||
|
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptForwardRulesNftables() error {
|
||||||
intf := ifname(r.wgIface.Name())
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
// Rule for incoming interface (iif) with counter
|
// Rule for incoming interface (iif) with counter
|
||||||
iifRule := &nftables.Rule{
|
iifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: r.filterTable,
|
||||||
Chain: &nftables.Chain{
|
Chain: &nftables.Chain{
|
||||||
Name: "FORWARD",
|
Name: chainNameForward,
|
||||||
Table: r.filterTable,
|
Table: r.filterTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
Hooknum: nftables.ChainHookForward,
|
Hooknum: nftables.ChainHookForward,
|
||||||
@ -594,6 +635,15 @@ func (r *router) acceptForwardRules() {
|
|||||||
}
|
}
|
||||||
r.conn.InsertRule(iifRule)
|
r.conn.InsertRule(iifRule)
|
||||||
|
|
||||||
|
oifExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Rule for outgoing interface (oif) with counter
|
// Rule for outgoing interface (oif) with counter
|
||||||
oifRule := &nftables.Rule{
|
oifRule := &nftables.Rule{
|
||||||
Table: r.filterTable,
|
Table: r.filterTable,
|
||||||
@ -604,36 +654,72 @@ func (r *router) acceptForwardRules() {
|
|||||||
Hooknum: nftables.ChainHookForward,
|
Hooknum: nftables.ChainHookForward,
|
||||||
Priority: nftables.ChainPriorityFilter,
|
Priority: nftables.ChainPriorityFilter,
|
||||||
},
|
},
|
||||||
Exprs: []expr.Any{
|
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: intf,
|
|
||||||
},
|
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeySTATE,
|
|
||||||
Register: 2,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 2,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
},
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.conn.InsertRule(oifRule)
|
r.conn.InsertRule(oifRule)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptForwardRules() error {
|
||||||
|
if r.filterTable == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try iptables first and fallback to nftables if iptables is not available
|
||||||
|
ipt, err := iptables.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||||
|
return r.removeAcceptForwardRulesNftables()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.removeAcceptForwardRulesIptables(ipt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||||
|
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, rule := range r.getAcceptForwardRules() {
|
||||||
|
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||||
|
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||||
@ -658,7 +744,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(ifaceMock.Name()),
|
Data: ifname(ifaceMock.Name()),
|
||||||
},
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname("lo"),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(ifaceMock.Name()),
|
Data: ifname(ifaceMock.Name()),
|
||||||
},
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname("lo"),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||||
@ -314,6 +326,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||||
|
})
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
@ -346,10 +362,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
// Verify actual nftables rule content
|
// Verify actual nftables rule content
|
||||||
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||||
|
|
||||||
// Clean up
|
|
||||||
err = r.DeleteRouteRule(ruleKey)
|
|
||||||
require.NoError(t, err, "Failed to delete rule")
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
package id
|
package id
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
@ -21,5 +24,41 @@ func GenerateRouteRuleKey(
|
|||||||
dPort *manager.Port,
|
dPort *manager.Port,
|
||||||
action manager.Action,
|
action manager.Action,
|
||||||
) RuleID {
|
) RuleID {
|
||||||
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
|
manager.SortPrefixes(sources)
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
|
||||||
|
// Write all fields to the hasher, with delimiters
|
||||||
|
h.Write([]byte("sources:"))
|
||||||
|
for _, src := range sources {
|
||||||
|
h.Write([]byte(src.String()))
|
||||||
|
h.Write([]byte(","))
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Write([]byte("destination:"))
|
||||||
|
h.Write([]byte(destination.String()))
|
||||||
|
|
||||||
|
h.Write([]byte("proto:"))
|
||||||
|
h.Write([]byte(proto))
|
||||||
|
|
||||||
|
h.Write([]byte("sPort:"))
|
||||||
|
if sPort != nil {
|
||||||
|
h.Write([]byte(sPort.String()))
|
||||||
|
} else {
|
||||||
|
h.Write([]byte("<nil>"))
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Write([]byte("dPort:"))
|
||||||
|
if dPort != nil {
|
||||||
|
h.Write([]byte(dPort.String()))
|
||||||
|
} else {
|
||||||
|
h.Write([]byte("<nil>"))
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Write([]byte("action:"))
|
||||||
|
h.Write([]byte(strconv.Itoa(int(action))))
|
||||||
|
hash := hex.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
// prepend destination prefix to be able to identify the rule
|
||||||
|
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
|
||||||
}
|
}
|
||||||
|
@ -82,8 +82,6 @@ type Conn struct {
|
|||||||
config ConnConfig
|
config ConnConfig
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
wgProxyFactory *wgproxy.Factory
|
wgProxyFactory *wgproxy.Factory
|
||||||
wgProxyICE wgproxy.Proxy
|
|
||||||
wgProxyRelay wgproxy.Proxy
|
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
@ -106,7 +104,8 @@ type Conn struct {
|
|||||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||||
|
|
||||||
endpointRelay *net.UDPAddr
|
wgProxyICE wgproxy.Proxy
|
||||||
|
wgProxyRelay wgproxy.Proxy
|
||||||
|
|
||||||
// for reconnection operations
|
// for reconnection operations
|
||||||
iCEDisconnected chan bool
|
iCEDisconnected chan bool
|
||||||
@ -257,8 +256,7 @@ func (conn *Conn) Close() {
|
|||||||
conn.wgProxyICE = nil
|
conn.wgProxyICE = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
conn.log.Debugf("ICE connection is ready")
|
||||||
|
|
||||||
conn.statusICE.Set(StatusConnected)
|
|
||||||
|
|
||||||
defer conn.updateIceState(iceConnInfo)
|
|
||||||
|
|
||||||
if conn.currentConnPriority > priority {
|
if conn.currentConnPriority > priority {
|
||||||
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
conn.updateIceState(iceConnInfo)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Infof("set ICE to active connection")
|
conn.log.Infof("set ICE to active connection")
|
||||||
|
|
||||||
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
|
var (
|
||||||
if err != nil {
|
ep *net.UDPAddr
|
||||||
return
|
wgProxy wgproxy.Proxy
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if iceConnInfo.RelayedOnLocal {
|
||||||
|
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||||
|
if err != nil {
|
||||||
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ep = wgProxy.EndpointAddr()
|
||||||
|
conn.wgProxyICE = wgProxy
|
||||||
|
} else {
|
||||||
|
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to resolveUDPaddr")
|
||||||
|
conn.handleConfigurationFailure(err, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ep = directEp
|
||||||
}
|
}
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||||
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
|
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||||
|
|
||||||
conn.connIDICE = nbnet.GenerateConnID()
|
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
|
||||||
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
|
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
conn.workerRelay.DisableWgWatcher()
|
||||||
|
|
||||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
if conn.wgProxyRelay != nil {
|
||||||
if err != nil {
|
conn.wgProxyRelay.Pause()
|
||||||
if wgProxy != nil {
|
}
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("Failed to close turn connection: %v", err)
|
if wgProxy != nil {
|
||||||
}
|
wgProxy.Work()
|
||||||
}
|
}
|
||||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
|
||||||
|
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||||
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
if conn.wgProxyICE != nil {
|
|
||||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.wgProxyICE = wgProxy
|
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
conn.updateIceState(iceConnInfo)
|
||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
|
|
||||||
conn.log.Tracef("ICE connection state changed to %s", newState)
|
conn.log.Tracef("ICE connection state changed to %s", newState)
|
||||||
|
|
||||||
|
if conn.wgProxyICE != nil {
|
||||||
|
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||||
|
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
conn.log.Debugf("ICE disconnected, set Relay to active connection")
|
||||||
err := conn.configureWGEndpoint(conn.endpointRelay)
|
conn.wgProxyRelay.Work()
|
||||||
if err != nil {
|
|
||||||
|
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||||
conn.statusICE.Set(newState)
|
conn.statusICE.Set(newState)
|
||||||
|
|
||||||
select {
|
conn.notifyReconnectLoopICEDisconnected(changed)
|
||||||
case conn.iCEDisconnected <- changed:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
if conn.ctx.Err() != nil {
|
if conn.ctx.Err() != nil {
|
||||||
if err := rci.relayedConn.Close(); err != nil {
|
if err := rci.relayedConn.Close(); err != nil {
|
||||||
log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("Relay connection is ready to use")
|
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||||
conn.statusRelay.Set(StatusConnected)
|
|
||||||
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||||
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
|
|
||||||
|
|
||||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
conn.endpointRelay = endpointUdpAddr
|
|
||||||
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
|
||||||
|
|
||||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
if conn.iceP2PIsActive() {
|
||||||
|
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||||
if conn.currentConnPriority > connPriorityRelay {
|
conn.wgProxyRelay = wgProxy
|
||||||
if conn.statusICE.Get() == StatusConnected {
|
conn.statusRelay.Set(StatusConnected)
|
||||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.connIDRelay = nbnet.GenerateConnID()
|
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
conn.log.Errorf("Before add peer hook failed: %v", err)
|
||||||
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
|
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conn.configureWGEndpoint(endpointUdpAddr)
|
wgProxy.Work()
|
||||||
if err != nil {
|
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
|
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
|
||||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.wgProxyRelay = wgProxy
|
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
|
conn.statusRelay.Set(StatusConnected)
|
||||||
|
conn.wgProxyRelay = wgProxy
|
||||||
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
conn.log.Infof("start to communicate with peer via relay")
|
conn.log.Infof("start to communicate with peer via relay")
|
||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
}
|
}
|
||||||
@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("relay connection is disconnected")
|
conn.log.Debugf("relay connection is disconnected")
|
||||||
|
|
||||||
if conn.currentConnPriority == connPriorityRelay {
|
if conn.currentConnPriority == connPriorityRelay {
|
||||||
log.Debugf("clean up WireGuard config")
|
conn.log.Debugf("clean up WireGuard config")
|
||||||
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
conn.endpointRelay = nil
|
|
||||||
_ = conn.wgProxyRelay.CloseConn()
|
_ = conn.wgProxyRelay.CloseConn()
|
||||||
conn.wgProxyRelay = nil
|
conn.wgProxyRelay = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||||
conn.statusRelay.Set(StatusDisconnected)
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
|
conn.notifyReconnectLoopRelayDisconnected(changed)
|
||||||
select {
|
|
||||||
case conn.relayDisconnected <- changed:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
Relayed: conn.isRelayed(),
|
Relayed: conn.isRelayed(),
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
}
|
}
|
||||||
|
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||||
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
|
|
||||||
if err != nil {
|
|
||||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||||
|
conn.connIDICE = nbnet.GenerateConnID()
|
||||||
|
for _, hook := range conn.beforeAddPeerHooks {
|
||||||
|
if err := hook(conn.connIDICE, ip); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) freeUpConnID() {
|
func (conn *Conn) freeUpConnID() {
|
||||||
if conn.connIDRelay != "" {
|
if conn.connIDRelay != "" {
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
for _, hook := range conn.afterRemovePeerHooks {
|
||||||
@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
|
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||||
if !iceConnInfo.RelayedOnLocal {
|
conn.log.Debugf("setup proxied WireGuard connection")
|
||||||
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
|
|
||||||
}
|
|
||||||
conn.log.Debugf("setup ice turn connection")
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||||
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
|
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||||
if err != nil {
|
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
if errClose := wgProxy.CloseConn(); errClose != nil {
|
return nil, err
|
||||||
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
|
}
|
||||||
}
|
return wgProxy, nil
|
||||||
return nil, nil, err
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) isReadyToUpgrade() bool {
|
||||||
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) iceP2PIsActive() bool {
|
||||||
|
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) removeWgPeer() error {
|
||||||
|
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
|
||||||
|
select {
|
||||||
|
case conn.relayDisconnected <- changed:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
|
||||||
|
select {
|
||||||
|
case conn.iCEDisconnected <- changed:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||||
|
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||||
|
if wgProxy != nil {
|
||||||
|
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||||
|
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
conn.wgProxyRelay.Work()
|
||||||
}
|
}
|
||||||
return ep, wgProxy, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||||
|
@ -5,7 +5,6 @@ package ebpf
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn add new turn connection for the proxy
|
// AddTurnConn add new turn connection for the proxy
|
||||||
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
|
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
|
|
||||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||||
|
|
||||||
wgEndpoint := &net.UDPAddr{
|
wgEndpoint := &net.UDPAddr{
|
||||||
@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
|
|
||||||
defer p.removeTurnConn(endpointPort)
|
|
||||||
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
n int
|
|
||||||
)
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
n, err = remoteConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != io.EOF {
|
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
|
|
||||||
if ctx.Err() != nil || p.ctx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||||
// From this go routine has only one instance.
|
// From this go routine has only one instance.
|
||||||
func (p *WGEBPFProxy) proxyToRemote() {
|
func (p *WGEBPFProxy) proxyToRemote() {
|
||||||
@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
|||||||
return packetConn, nil
|
return packetConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||||
localhost := net.ParseIP("127.0.0.1")
|
localhost := net.ParseIP("127.0.0.1")
|
||||||
|
|
||||||
payload := gopacket.Payload(data)
|
payload := gopacket.Payload(data)
|
||||||
|
@ -4,8 +4,13 @@ package ebpf
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
@ -13,20 +18,55 @@ type ProxyWrapper struct {
|
|||||||
WgeBPFProxy *WGEBPFProxy
|
WgeBPFProxy *WGEBPFProxy
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
wgEndpointAddr *net.UDPAddr
|
||||||
|
|
||||||
|
pausedMu sync.Mutex
|
||||||
|
paused bool
|
||||||
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||||
ctxConn, cancel := context.WithCancel(ctx)
|
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cancel()
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
return nil, fmt.Errorf("add turn conn: %w", err)
|
|
||||||
}
|
}
|
||||||
e.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
e.cancel = cancel
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
return addr, err
|
p.wgEndpointAddr = addr
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
|
return p.wgEndpointAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) Work() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if !p.isStarted {
|
||||||
|
p.isStarted = true
|
||||||
|
go p.proxyToLocal(p.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) Pause() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = true
|
||||||
|
p.pausedMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||||
@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||||
|
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||||
|
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for {
|
||||||
|
n, err := p.readFromRemote(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
if p.paused {
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||||
|
n, err := p.remoteConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return 0, ctx.Err()
|
||||||
|
}
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
@ -7,6 +7,9 @@ import (
|
|||||||
|
|
||||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
|
AddTurnConn(ctx context.Context, turnConn net.Conn) error
|
||||||
|
EndpointAddr() *net.UDPAddr
|
||||||
|
Work()
|
||||||
|
Pause()
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
relayedConn := newMockConn()
|
relayedConn := newMockConn()
|
||||||
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error: %v", err)
|
t.Errorf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -15,13 +15,17 @@ import (
|
|||||||
// WGUserSpaceProxy proxies
|
// WGUserSpaceProxy proxies
|
||||||
type WGUserSpaceProxy struct {
|
type WGUserSpaceProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
closeMu sync.Mutex
|
closeMu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
|
pausedMu sync.Mutex
|
||||||
|
paused bool
|
||||||
|
isStarted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||||
@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
|||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
// AddTurnConn
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
|
// The provided Context must be non-nil. If the context expires before
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
// the connection is complete, an error is returned. Once successfully
|
||||||
|
// connected, any expiration of the context will not affect the
|
||||||
p.remoteConn = remoteConn
|
// connection.
|
||||||
|
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||||
var err error
|
|
||||||
dialer := net.Dialer{}
|
dialer := net.Dialer{}
|
||||||
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go p.proxyToRemote()
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
go p.proxyToLocal()
|
p.localConn = localConn
|
||||||
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
return p.localConn.LocalAddr(), err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||||
|
if p.localConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||||
|
return endpointUdpAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work starts the proxy or resumes it if it was paused
|
||||||
|
func (p *WGUserSpaceProxy) Work() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
|
if !p.isStarted {
|
||||||
|
p.isStarted = true
|
||||||
|
go p.proxyToRemote(p.ctx)
|
||||||
|
go p.proxyToLocal(p.ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pause pauses the proxy from receiving data from the remote peer
|
||||||
|
func (p *WGUserSpaceProxy) Pause() {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
p.paused = true
|
||||||
|
p.pausedMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
// CloseConn close the localConn
|
||||||
@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||||
func (p *WGUserSpaceProxy) proxyToRemote() {
|
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to remote loop: %s", err)
|
log.Warnf("error in proxy to remote loop: %s", err)
|
||||||
@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for p.ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
n, err := p.localConn.Read(buf)
|
n, err := p.localConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||||
@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
|
|
||||||
_, err = p.remoteConn.Write(buf[:n])
|
_, err = p.remoteConn.Write(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||||
func (p *WGUserSpaceProxy) proxyToLocal() {
|
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||||
|
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := p.close(); err != nil {
|
if err := p.close(); err != nil {
|
||||||
log.Warnf("error in proxy to local loop: %s", err)
|
log.Warnf("error in proxy to local loop: %s", err)
|
||||||
@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for p.ctx.Err() == nil {
|
for {
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.pausedMu.Lock()
|
||||||
|
if p.paused {
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
_, err = p.localConn.Write(buf[:n])
|
_, err = p.localConn.Write(buf[:n])
|
||||||
|
p.pausedMu.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
|
||||||
<plist version="1.0">
|
|
||||||
<dict>
|
|
||||||
<key>CFBundleExecutable</key>
|
|
||||||
<string>netbird-ui</string>
|
|
||||||
<key>CFBundleIconFile</key>
|
|
||||||
<string>Netbird</string>
|
|
||||||
<key>LSUIElement</key>
|
|
||||||
<string>1</string>
|
|
||||||
</dict>
|
|
||||||
</plist>
|
|
@ -8,11 +8,11 @@ cask "{{ $projectName }}" do
|
|||||||
if Hardware::CPU.intel?
|
if Hardware::CPU.intel?
|
||||||
url "{{ $amdURL }}"
|
url "{{ $amdURL }}"
|
||||||
sha256 "{{ crypto.SHA256 $amdFileBytes }}"
|
sha256 "{{ crypto.SHA256 $amdFileBytes }}"
|
||||||
app "netbird_ui_darwin_amd64", target: "Netbird UI.app"
|
app "netbird_ui_darwin", target: "Netbird UI.app"
|
||||||
else
|
else
|
||||||
url "{{ $armURL }}"
|
url "{{ $armURL }}"
|
||||||
sha256 "{{ crypto.SHA256 $armFileBytes }}"
|
sha256 "{{ crypto.SHA256 $armFileBytes }}"
|
||||||
app "netbird_ui_darwin_arm64", target: "Netbird UI.app"
|
app "netbird_ui_darwin", target: "Netbird UI.app"
|
||||||
end
|
end
|
||||||
|
|
||||||
depends_on formula: "netbird"
|
depends_on formula: "netbird"
|
||||||
|
20
go.mod
20
go.mod
@ -19,8 +19,8 @@ require (
|
|||||||
github.com/spf13/cobra v1.7.0
|
github.com/spf13/cobra v1.7.0
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||||
golang.org/x/crypto v0.24.0
|
golang.org/x/crypto v0.28.0
|
||||||
golang.org/x/sys v0.21.0
|
golang.org/x/sys v0.26.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
@ -38,6 +38,7 @@ require (
|
|||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.15.0
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/creack/pty v1.1.18
|
github.com/creack/pty v1.1.18
|
||||||
|
github.com/davecgh/go-spew v1.1.1
|
||||||
github.com/eko/gocache/v3 v3.1.1
|
github.com/eko/gocache/v3 v3.1.1
|
||||||
github.com/fsnotify/fsnotify v1.7.0
|
github.com/fsnotify/fsnotify v1.7.0
|
||||||
github.com/gliderlabs/ssh v0.3.4
|
github.com/gliderlabs/ssh v0.3.4
|
||||||
@ -45,7 +46,7 @@ require (
|
|||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
github.com/google/go-cmp v0.6.0
|
github.com/google/go-cmp v0.6.0
|
||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
github.com/google/nftables v0.2.0
|
||||||
github.com/gopacket/gopacket v1.1.1
|
github.com/gopacket/gopacket v1.1.1
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
@ -55,12 +56,12 @@ require (
|
|||||||
github.com/libp2p/go-netroute v0.2.1
|
github.com/libp2p/go-netroute v0.2.1
|
||||||
github.com/magiconair/properties v1.8.7
|
github.com/magiconair/properties v1.8.7
|
||||||
github.com/mattn/go-sqlite3 v1.14.19
|
github.com/mattn/go-sqlite3 v1.14.19
|
||||||
github.com/mdlayher/socket v0.4.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
@ -90,10 +91,10 @@ require (
|
|||||||
goauthentik.io/api/v3 v3.2023051.3
|
goauthentik.io/api/v3 v3.2023051.3
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||||
golang.org/x/net v0.26.0
|
golang.org/x/net v0.30.0
|
||||||
golang.org/x/oauth2 v0.19.0
|
golang.org/x/oauth2 v0.19.0
|
||||||
golang.org/x/sync v0.7.0
|
golang.org/x/sync v0.8.0
|
||||||
golang.org/x/term v0.21.0
|
golang.org/x/term v0.25.0
|
||||||
google.golang.org/api v0.177.0
|
google.golang.org/api v0.177.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gorm.io/driver/postgres v1.5.7
|
gorm.io/driver/postgres v1.5.7
|
||||||
@ -134,7 +135,6 @@ require (
|
|||||||
github.com/containerd/containerd v1.7.16 // indirect
|
github.com/containerd/containerd v1.7.16 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.1 // indirect
|
github.com/cpuguy83/dockercfg v0.3.1 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
|
||||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
@ -222,7 +222,7 @@ require (
|
|||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/image v0.18.0 // indirect
|
golang.org/x/image v0.18.0 // indirect
|
||||||
golang.org/x/mod v0.17.0 // indirect
|
golang.org/x/mod v0.17.0 // indirect
|
||||||
golang.org/x/text v0.16.0 // indirect
|
golang.org/x/text v0.19.0 // indirect
|
||||||
golang.org/x/time v0.5.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
|
36
go.sum
36
go.sum
@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
|
|||||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||||
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||||
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
|
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
|
||||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
|
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
|
||||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||||
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||||
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
|
||||||
@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
|
|||||||
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
|
||||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||||
@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-
|
|||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
@ -780,8 +780,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
|
|||||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
@ -877,8 +877,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
|||||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
@ -907,8 +907,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
@ -980,8 +980,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||||
@ -989,8 +989,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
|
|||||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||||
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
|
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
|
||||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
||||||
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
@ -1005,8 +1005,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
|||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
@ -873,7 +873,7 @@ services:
|
|||||||
zitadel:
|
zitadel:
|
||||||
restart: 'always'
|
restart: 'always'
|
||||||
networks: [netbird]
|
networks: [netbird]
|
||||||
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
|
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
|
||||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||||
env_file:
|
env_file:
|
||||||
- ./zitadel.env
|
- ./zitadel.env
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
b64 "encoding/base64"
|
b64 "encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -44,12 +45,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PublicCategory = "public"
|
PublicCategory = "public"
|
||||||
PrivateCategory = "private"
|
PrivateCategory = "private"
|
||||||
UnknownCategory = "unknown"
|
UnknownCategory = "unknown"
|
||||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||||
|
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||||
|
emptyUserID = "empty user ID in claims"
|
||||||
|
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||||
)
|
)
|
||||||
|
|
||||||
type userLoggedInOnce bool
|
type userLoggedInOnce bool
|
||||||
@ -177,6 +181,8 @@ type DefaultAccountManager struct {
|
|||||||
dnsDomain string
|
dnsDomain string
|
||||||
peerLoginExpiry Scheduler
|
peerLoginExpiry Scheduler
|
||||||
|
|
||||||
|
peerInactivityExpiry Scheduler
|
||||||
|
|
||||||
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
||||||
userDeleteFromIDPEnabled bool
|
userDeleteFromIDPEnabled bool
|
||||||
|
|
||||||
@ -194,6 +200,13 @@ type Settings struct {
|
|||||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||||
PeerLoginExpiration time.Duration
|
PeerLoginExpiration time.Duration
|
||||||
|
|
||||||
|
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
|
||||||
|
PeerInactivityExpirationEnabled bool
|
||||||
|
|
||||||
|
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
|
||||||
|
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
|
||||||
|
PeerInactivityExpiration time.Duration
|
||||||
|
|
||||||
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
|
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
|
||||||
RegularUsersViewBlocked bool
|
RegularUsersViewBlocked bool
|
||||||
|
|
||||||
@ -224,6 +237,9 @@ func (s *Settings) Copy() *Settings {
|
|||||||
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
||||||
JWTAllowGroups: s.JWTAllowGroups,
|
JWTAllowGroups: s.JWTAllowGroups,
|
||||||
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
|
||||||
|
PeerInactivityExpiration: s.PeerInactivityExpiration,
|
||||||
}
|
}
|
||||||
if s.Extra != nil {
|
if s.Extra != nil {
|
||||||
settings.Extra = s.Extra.Copy()
|
settings.Extra = s.Extra.Copy()
|
||||||
@ -605,6 +621,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer {
|
|||||||
return peers
|
return peers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInactivePeers returns peers that have been expired by inactivity
|
||||||
|
func (a *Account) GetInactivePeers() []*nbpeer.Peer {
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
for _, inactivePeer := range a.GetPeersWithInactivity() {
|
||||||
|
inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||||
|
if inactive {
|
||||||
|
peers = append(peers, inactivePeer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||||
|
// If there is no peer that expires this function returns false and a duration of 0.
|
||||||
|
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||||
|
func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) {
|
||||||
|
peersWithExpiry := a.GetPeersWithInactivity()
|
||||||
|
if len(peersWithExpiry) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
var nextExpiry *time.Duration
|
||||||
|
for _, peer := range peersWithExpiry {
|
||||||
|
if peer.Status.LoginExpired || peer.Status.Connected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration)
|
||||||
|
if nextExpiry == nil || duration < *nextExpiry {
|
||||||
|
// if expiration is below 1s return 1s duration
|
||||||
|
// this avoids issues with ticker that can't be set to < 0
|
||||||
|
if duration < time.Second {
|
||||||
|
return time.Second, true
|
||||||
|
}
|
||||||
|
nextExpiry = &duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextExpiry == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *nextExpiry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user
|
||||||
|
func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer {
|
||||||
|
peers := make([]*nbpeer.Peer, 0)
|
||||||
|
for _, peer := range a.Peers {
|
||||||
|
if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeers returns a list of all Account peers
|
// GetPeers returns a list of all Account peers
|
||||||
func (a *Account) GetPeers() []*nbpeer.Peer {
|
func (a *Account) GetPeers() []*nbpeer.Peer {
|
||||||
var peers []*nbpeer.Peer
|
var peers []*nbpeer.Peer
|
||||||
@ -971,6 +1041,7 @@ func BuildManager(
|
|||||||
dnsDomain: dnsDomain,
|
dnsDomain: dnsDomain,
|
||||||
eventStore: eventStore,
|
eventStore: eventStore,
|
||||||
peerLoginExpiry: NewDefaultScheduler(),
|
peerLoginExpiry: NewDefaultScheduler(),
|
||||||
|
peerInactivityExpiry: NewDefaultScheduler(),
|
||||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||||
integratedPeerValidator: integratedPeerValidator,
|
integratedPeerValidator: integratedPeerValidator,
|
||||||
metrics: metrics,
|
metrics: metrics,
|
||||||
@ -1099,6 +1170,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
updatedAccount := account.UpdateSettings(newSettings)
|
updatedAccount := account.UpdateSettings(newSettings)
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
err = am.Store.SaveAccount(ctx, account)
|
||||||
@ -1109,6 +1185,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return updatedAccount, nil
|
return updatedAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||||
|
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
event := activity.AccountPeerInactivityExpirationEnabled
|
||||||
|
if !newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
event = activity.AccountPeerInactivityExpirationDisabled
|
||||||
|
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||||
|
} else {
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||||
return func() (time.Duration, bool) {
|
return func() (time.Duration, bool) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
@ -1144,6 +1240,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
|
||||||
|
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||||
|
return func() (time.Duration, bool) {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed getting account %s expiring peers", account.Id)
|
||||||
|
return account.GetNextInactivePeerExpiration()
|
||||||
|
}
|
||||||
|
|
||||||
|
expiredPeers := account.GetInactivePeers()
|
||||||
|
var peerIDs []string
|
||||||
|
for _, peer := range expiredPeers {
|
||||||
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||||
|
|
||||||
|
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||||
|
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
|
||||||
|
return account.GetNextInactivePeerExpiration()
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.GetNextInactivePeerExpiration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
|
||||||
|
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
|
||||||
|
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
|
||||||
|
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
|
||||||
|
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||||
// If ID is already in use (due to collision) we try one more time before returning error
|
// If ID is already in use (due to collision) we try one more time before returning error
|
||||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
|
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
|
||||||
@ -1284,7 +1417,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
|||||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return account.Id, nil
|
return account.Id, nil
|
||||||
@ -1299,28 +1432,39 @@ func isNil(i idp.Manager) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
|
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cachedAccount := &Account{
|
||||||
|
Id: accountID,
|
||||||
|
Users: make(map[string]*User),
|
||||||
|
}
|
||||||
|
for _, user := range accountUsers {
|
||||||
|
cachedAccount.Users[user.Id] = user
|
||||||
|
}
|
||||||
|
|
||||||
// user can be nil if it wasn't found (e.g., just created)
|
// user can be nil if it wasn't found (e.g., just created)
|
||||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if user != nil && user.AppMetadata.WTAccountID == account.Id {
|
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||||
// it was already set, so we skip the unnecessary update
|
// it was already set, so we skip the unnecessary update
|
||||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||||
account.Id, userID)
|
accountID, userID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
|
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||||
}
|
}
|
||||||
// refresh cache to reflect the update
|
// refresh cache to reflect the update
|
||||||
_, err = am.refreshCache(ctx, account.Id)
|
_, err = am.refreshCache(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1544,48 +1688,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
|||||||
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
|
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims,
|
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
|
||||||
primaryDomain bool,
|
primaryDomain bool,
|
||||||
) error {
|
) error {
|
||||||
|
if claims.Domain == "" {
|
||||||
if claims.Domain != "" {
|
|
||||||
account.IsDomainPrimaryAccount = primaryDomain
|
|
||||||
|
|
||||||
lowerDomain := strings.ToLower(claims.Domain)
|
|
||||||
userObj := account.Users[claims.UserId]
|
|
||||||
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
|
||||||
account.Domain = lowerDomain
|
|
||||||
}
|
|
||||||
// prevent updating category for different domain until admin logs in
|
|
||||||
if account.Domain == lowerDomain {
|
|
||||||
account.DomainCategory = claims.DomainCategory
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := am.Store.SaveAccount(ctx, account)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlockAccount()
|
||||||
|
|
||||||
|
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newDomain := accountDomain
|
||||||
|
newCategoty := domainCategory
|
||||||
|
|
||||||
|
lowerDomain := strings.ToLower(claims.Domain)
|
||||||
|
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||||
|
newDomain = lowerDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountDomain == lowerDomain {
|
||||||
|
newCategoty = claims.DomainCategory
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||||
|
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||||
|
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||||
|
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||||
|
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||||
|
// and peers that shouldn't be lost.
|
||||||
func (am *DefaultAccountManager) handleExistingUserAccount(
|
func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
existingAcc *Account,
|
userAccountID string,
|
||||||
primaryDomain bool,
|
domainAccountID string,
|
||||||
claims jwtclaims.AuthorizationClaims,
|
claims jwtclaims.AuthorizationClaims,
|
||||||
) error {
|
) error {
|
||||||
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
|
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||||
|
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// we should register the account ID to this user's metadata in our IDP manager
|
// we should register the account ID to this user's metadata in our IDP manager
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
|
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1593,44 +1758,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||||
// otherwise it will create a new account and make it primary account for the domain.
|
// otherwise it will create a new account and make it primary account for the domain.
|
||||||
func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, fmt.Errorf("user ID is empty")
|
return "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
var (
|
|
||||||
account *Account
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
lowerDomain := strings.ToLower(claims.Domain)
|
lowerDomain := strings.ToLower(claims.Domain)
|
||||||
// if domain already has a primary account, add regular user
|
|
||||||
if domainAcc != nil {
|
|
||||||
account = domainAcc
|
|
||||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = am.updateAccountDomainAttributes(ctx, account, claims, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
|
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
|
newAccount.Domain = lowerDomain
|
||||||
|
newAccount.DomainCategory = claims.DomainCategory
|
||||||
|
newAccount.IsDomainPrimaryAccount = true
|
||||||
|
|
||||||
return account, nil
|
err = am.Store.SaveAccount(ctx, newAccount)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||||
|
|
||||||
|
return newAccount.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||||
|
defer unlockAccount()
|
||||||
|
|
||||||
|
usersMap := make(map[string]*User)
|
||||||
|
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
|
||||||
|
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
|
||||||
|
|
||||||
|
return domainAccountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// redeemInvite checks whether user has been invited and redeems the invite
|
// redeemInvite checks whether user has been invited and redeems the invite
|
||||||
@ -1774,7 +1953,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
|||||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return "", "", fmt.Errorf("user ID is empty")
|
return "", "", errors.New(emptyUserID)
|
||||||
}
|
}
|
||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
@ -1962,16 +2141,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||||
|
// if domain is not private or domain is invalid, it will return the account ID by user ID.
|
||||||
// if domain is of the PrivateCategory category, it will evaluate
|
// if domain is of the PrivateCategory category, it will evaluate
|
||||||
// if account is new, existing or if there is another account with the same domain
|
// if account is new, existing or if there is another account with the same domain
|
||||||
//
|
//
|
||||||
// Use cases:
|
// Use cases:
|
||||||
//
|
//
|
||||||
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
|
// New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
|
||||||
//
|
//
|
||||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
|
// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
|
||||||
//
|
//
|
||||||
// New user + New account + Existing Public Domain -> create account, user role = admin
|
// New user + New account + Existing Public Domain -> create account, user role = owner
|
||||||
//
|
//
|
||||||
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
||||||
//
|
//
|
||||||
@ -1981,98 +2161,124 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||||
|
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return "", fmt.Errorf("user ID is empty")
|
return "", errors.New(emptyUserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if Account ID is part of the claims
|
|
||||||
// it means that we've already classified the domain and user has an account
|
|
||||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
if claims.AccountId != "" {
|
|
||||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
|
|
||||||
}
|
|
||||||
return claims.AccountId, nil
|
|
||||||
}
|
|
||||||
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||||
} else if claims.AccountId != "" {
|
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if userAccountID != claims.AccountId {
|
|
||||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
|
||||||
}
|
|
||||||
|
|
||||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
|
||||||
return userAccountID, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
start := time.Now()
|
if claims.AccountId != "" {
|
||||||
unlock := am.Store.AcquireGlobalLock(ctx)
|
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
|
||||||
defer unlock()
|
}
|
||||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
|
||||||
|
|
||||||
// We checked if the domain has a primary account already
|
// We checked if the domain has a primary account already
|
||||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
|
||||||
|
if cancel != nil {
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if NotFound we are good to continue, otherwise return error
|
return "", err
|
||||||
e, ok := status.FromError(err)
|
|
||||||
if !ok || e.Type() != status.NotFound {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err == nil {
|
if handleNotFound(err) != nil {
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||||
defer unlockAccount()
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
|
||||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
|
||||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
|
||||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
|
||||||
// and peers that shouldn't be lost.
|
|
||||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
|
||||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Id, nil
|
|
||||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
|
||||||
var domainAccount *Account
|
|
||||||
if domainAccountID != "" {
|
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
|
||||||
defer unlockAccount()
|
|
||||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return account.Id, nil
|
|
||||||
} else {
|
|
||||||
// other error
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if userAccountID != "" {
|
||||||
|
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return userAccountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainAccountID != "" {
|
||||||
|
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
|
||||||
|
}
|
||||||
|
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||||
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
|
||||||
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainAccountID != "" {
|
||||||
|
return domainAccountID, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
|
||||||
|
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||||
|
|
||||||
|
// check again if the domain has a primary account because of simultaneous requests
|
||||||
|
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
cancel()
|
||||||
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return domainAccountID, cancel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAccountID != claims.AccountId {
|
||||||
|
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||||
|
}
|
||||||
|
|
||||||
|
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||||
|
return claims.AccountId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We checked if the domain has a primary account already
|
||||||
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims.AccountId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleNotFound(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
e, ok := status.FromError(err)
|
||||||
|
if !ok || e.Type() != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
|
||||||
|
return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||||
@ -2338,6 +2544,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
|||||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||||
GroupsPropagationEnabled: true,
|
GroupsPropagationEnabled: true,
|
||||||
RegularUsersViewBlocked: true,
|
RegularUsersViewBlocked: true,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: false,
|
||||||
|
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||||
type initUserParams jwtclaims.AuthorizationClaims
|
type initUserParams jwtclaims.AuthorizationClaims
|
||||||
|
|
||||||
type test struct {
|
var (
|
||||||
|
publicDomain = "public.com"
|
||||||
|
privateDomain = "private.com"
|
||||||
|
unknownDomain = "unknown.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
defaultInitAccount := initUserParams{
|
||||||
|
Domain: publicDomain,
|
||||||
|
UserId: "defaultUser",
|
||||||
|
}
|
||||||
|
|
||||||
|
initUnknown := defaultInitAccount
|
||||||
|
initUnknown.DomainCategory = UnknownCategory
|
||||||
|
initUnknown.Domain = unknownDomain
|
||||||
|
|
||||||
|
privateInitAccount := defaultInitAccount
|
||||||
|
privateInitAccount.Domain = privateDomain
|
||||||
|
privateInitAccount.DomainCategory = PrivateCategory
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
inputClaims jwtclaims.AuthorizationClaims
|
inputClaims jwtclaims.AuthorizationClaims
|
||||||
inputInitUserParams initUserParams
|
inputInitUserParams initUserParams
|
||||||
@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
expectedPrimaryDomainStatus bool
|
expectedPrimaryDomainStatus bool
|
||||||
expectedCreatedBy string
|
expectedCreatedBy string
|
||||||
expectedUsers []string
|
expectedUsers []string
|
||||||
}
|
}{
|
||||||
|
{
|
||||||
var (
|
name: "New User With Public Domain",
|
||||||
publicDomain = "public.com"
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
privateDomain = "private.com"
|
Domain: publicDomain,
|
||||||
unknownDomain = "unknown.com"
|
UserId: "pub-domain-user",
|
||||||
)
|
DomainCategory: PublicCategory,
|
||||||
|
},
|
||||||
defaultInitAccount := initUserParams{
|
inputInitUserParams: defaultInitAccount,
|
||||||
Domain: publicDomain,
|
testingFunc: require.NotEqual,
|
||||||
UserId: "defaultUser",
|
expectedMSG: "account IDs shouldn't match",
|
||||||
}
|
expectedUserRole: UserRoleOwner,
|
||||||
|
expectedDomainCategory: "",
|
||||||
testCase1 := test{
|
expectedDomain: publicDomain,
|
||||||
name: "New User With Public Domain",
|
expectedPrimaryDomainStatus: false,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedCreatedBy: "pub-domain-user",
|
||||||
Domain: publicDomain,
|
expectedUsers: []string{"pub-domain-user"},
|
||||||
UserId: "pub-domain-user",
|
|
||||||
DomainCategory: PublicCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
{
|
||||||
testingFunc: require.NotEqual,
|
name: "New User With Unknown Domain",
|
||||||
expectedMSG: "account IDs shouldn't match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: unknownDomain,
|
||||||
expectedDomainCategory: "",
|
UserId: "unknown-domain-user",
|
||||||
expectedDomain: publicDomain,
|
DomainCategory: UnknownCategory,
|
||||||
expectedPrimaryDomainStatus: false,
|
},
|
||||||
expectedCreatedBy: "pub-domain-user",
|
inputInitUserParams: initUnknown,
|
||||||
expectedUsers: []string{"pub-domain-user"},
|
testingFunc: require.NotEqual,
|
||||||
}
|
expectedMSG: "account IDs shouldn't match",
|
||||||
|
expectedUserRole: UserRoleOwner,
|
||||||
initUnknown := defaultInitAccount
|
expectedDomain: unknownDomain,
|
||||||
initUnknown.DomainCategory = UnknownCategory
|
expectedDomainCategory: "",
|
||||||
initUnknown.Domain = unknownDomain
|
expectedPrimaryDomainStatus: false,
|
||||||
|
expectedCreatedBy: "unknown-domain-user",
|
||||||
testCase2 := test{
|
expectedUsers: []string{"unknown-domain-user"},
|
||||||
name: "New User With Unknown Domain",
|
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
|
||||||
Domain: unknownDomain,
|
|
||||||
UserId: "unknown-domain-user",
|
|
||||||
DomainCategory: UnknownCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: initUnknown,
|
{
|
||||||
testingFunc: require.NotEqual,
|
name: "New User With Private Domain",
|
||||||
expectedMSG: "account IDs shouldn't match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: privateDomain,
|
||||||
expectedDomain: unknownDomain,
|
UserId: "pvt-domain-user",
|
||||||
expectedDomainCategory: "",
|
DomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: false,
|
},
|
||||||
expectedCreatedBy: "unknown-domain-user",
|
inputInitUserParams: defaultInitAccount,
|
||||||
expectedUsers: []string{"unknown-domain-user"},
|
testingFunc: require.NotEqual,
|
||||||
}
|
expectedMSG: "account IDs shouldn't match",
|
||||||
|
expectedUserRole: UserRoleOwner,
|
||||||
testCase3 := test{
|
expectedDomain: privateDomain,
|
||||||
name: "New User With Private Domain",
|
expectedDomainCategory: PrivateCategory,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedPrimaryDomainStatus: true,
|
||||||
Domain: privateDomain,
|
expectedCreatedBy: "pvt-domain-user",
|
||||||
UserId: "pvt-domain-user",
|
expectedUsers: []string{"pvt-domain-user"},
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
{
|
||||||
testingFunc: require.NotEqual,
|
name: "New Regular User With Existing Private Domain",
|
||||||
expectedMSG: "account IDs shouldn't match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: privateDomain,
|
||||||
expectedDomain: privateDomain,
|
UserId: "new-pvt-domain-user",
|
||||||
expectedDomainCategory: PrivateCategory,
|
DomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
},
|
||||||
expectedCreatedBy: "pvt-domain-user",
|
inputUpdateAttrs: true,
|
||||||
expectedUsers: []string{"pvt-domain-user"},
|
inputInitUserParams: privateInitAccount,
|
||||||
}
|
testingFunc: require.Equal,
|
||||||
|
expectedMSG: "account IDs should match",
|
||||||
privateInitAccount := defaultInitAccount
|
expectedUserRole: UserRoleUser,
|
||||||
privateInitAccount.Domain = privateDomain
|
expectedDomain: privateDomain,
|
||||||
privateInitAccount.DomainCategory = PrivateCategory
|
expectedDomainCategory: PrivateCategory,
|
||||||
|
expectedPrimaryDomainStatus: true,
|
||||||
testCase4 := test{
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
name: "New Regular User With Existing Private Domain",
|
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
|
||||||
Domain: privateDomain,
|
|
||||||
UserId: "new-pvt-domain-user",
|
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputUpdateAttrs: true,
|
{
|
||||||
inputInitUserParams: privateInitAccount,
|
name: "Existing User With Existing Reclassified Private Domain",
|
||||||
testingFunc: require.Equal,
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedMSG: "account IDs should match",
|
Domain: defaultInitAccount.Domain,
|
||||||
expectedUserRole: UserRoleUser,
|
UserId: defaultInitAccount.UserId,
|
||||||
expectedDomain: privateDomain,
|
DomainCategory: PrivateCategory,
|
||||||
expectedDomainCategory: PrivateCategory,
|
},
|
||||||
expectedPrimaryDomainStatus: true,
|
inputInitUserParams: defaultInitAccount,
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
testingFunc: require.Equal,
|
||||||
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
expectedMSG: "account IDs should match",
|
||||||
}
|
expectedUserRole: UserRoleOwner,
|
||||||
|
expectedDomain: defaultInitAccount.Domain,
|
||||||
testCase5 := test{
|
expectedDomainCategory: PrivateCategory,
|
||||||
name: "Existing User With Existing Reclassified Private Domain",
|
expectedPrimaryDomainStatus: true,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
Domain: defaultInitAccount.Domain,
|
expectedUsers: []string{defaultInitAccount.UserId},
|
||||||
UserId: defaultInitAccount.UserId,
|
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
{
|
||||||
testingFunc: require.Equal,
|
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||||
expectedMSG: "account IDs should match",
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedUserRole: UserRoleOwner,
|
Domain: defaultInitAccount.Domain,
|
||||||
expectedDomain: defaultInitAccount.Domain,
|
UserId: defaultInitAccount.UserId,
|
||||||
expectedDomainCategory: PrivateCategory,
|
DomainCategory: PrivateCategory,
|
||||||
expectedPrimaryDomainStatus: true,
|
},
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
inputUpdateClaimAccount: true,
|
||||||
expectedUsers: []string{defaultInitAccount.UserId},
|
inputInitUserParams: defaultInitAccount,
|
||||||
}
|
testingFunc: require.Equal,
|
||||||
|
expectedMSG: "account IDs should match",
|
||||||
testCase6 := test{
|
expectedUserRole: UserRoleOwner,
|
||||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
expectedDomain: defaultInitAccount.Domain,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedDomainCategory: PrivateCategory,
|
||||||
Domain: defaultInitAccount.Domain,
|
expectedPrimaryDomainStatus: true,
|
||||||
UserId: defaultInitAccount.UserId,
|
expectedCreatedBy: defaultInitAccount.UserId,
|
||||||
DomainCategory: PrivateCategory,
|
expectedUsers: []string{defaultInitAccount.UserId},
|
||||||
},
|
},
|
||||||
inputUpdateClaimAccount: true,
|
{
|
||||||
inputInitUserParams: defaultInitAccount,
|
name: "User With Private Category And Empty Domain",
|
||||||
testingFunc: require.Equal,
|
inputClaims: jwtclaims.AuthorizationClaims{
|
||||||
expectedMSG: "account IDs should match",
|
Domain: "",
|
||||||
expectedUserRole: UserRoleOwner,
|
UserId: "pvt-domain-user",
|
||||||
expectedDomain: defaultInitAccount.Domain,
|
DomainCategory: PrivateCategory,
|
||||||
expectedDomainCategory: PrivateCategory,
|
},
|
||||||
expectedPrimaryDomainStatus: true,
|
inputInitUserParams: defaultInitAccount,
|
||||||
expectedCreatedBy: defaultInitAccount.UserId,
|
testingFunc: require.NotEqual,
|
||||||
expectedUsers: []string{defaultInitAccount.UserId},
|
expectedMSG: "account IDs shouldn't match",
|
||||||
}
|
expectedUserRole: UserRoleOwner,
|
||||||
|
expectedDomain: "",
|
||||||
testCase7 := test{
|
expectedDomainCategory: "",
|
||||||
name: "User With Private Category And Empty Domain",
|
expectedPrimaryDomainStatus: false,
|
||||||
inputClaims: jwtclaims.AuthorizationClaims{
|
expectedCreatedBy: "pvt-domain-user",
|
||||||
Domain: "",
|
expectedUsers: []string{"pvt-domain-user"},
|
||||||
UserId: "pvt-domain-user",
|
|
||||||
DomainCategory: PrivateCategory,
|
|
||||||
},
|
},
|
||||||
inputInitUserParams: defaultInitAccount,
|
|
||||||
testingFunc: require.NotEqual,
|
|
||||||
expectedMSG: "account IDs shouldn't match",
|
|
||||||
expectedUserRole: UserRoleOwner,
|
|
||||||
expectedDomain: "",
|
|
||||||
expectedDomainCategory: "",
|
|
||||||
expectedPrimaryDomainStatus: false,
|
|
||||||
expectedCreatedBy: "pvt-domain-user",
|
|
||||||
expectedUsers: []string{"pvt-domain-user"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "get init account failed")
|
require.NoError(t, err, "get init account failed")
|
||||||
|
|
||||||
if testCase.inputUpdateAttrs {
|
if testCase.inputUpdateAttrs {
|
||||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||||
require.NoError(t, err, "update init user failed")
|
require.NoError(t, err, "update init user failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2025,6 +2019,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*nbpeer.Peer
|
||||||
|
expectedPeers map[string]struct{}
|
||||||
|
}
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "Peers with inactivity expiration disabled, no expired peers",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Two peers expired",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
ID: "peer-1",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
LastSeen: time.Now().UTC().Add(-45 * time.Second),
|
||||||
|
Connected: false,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().UTC().Add(-30 * time.Minute),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
ID: "peer-2",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
LastSeen: time.Now().UTC().Add(-45 * time.Second),
|
||||||
|
Connected: false,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().UTC().Add(-2 * time.Hour),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-3": {
|
||||||
|
ID: "peer-3",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
LastSeen: time.Now().UTC(),
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().UTC().Add(-1 * time.Hour),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{
|
||||||
|
"peer-1": {},
|
||||||
|
"peer-2": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
Settings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expiredPeers := account.GetInactivePeers()
|
||||||
|
assert.Len(t, expiredPeers, len(testCase.expectedPeers))
|
||||||
|
for _, peer := range expiredPeers {
|
||||||
|
if _, ok := testCase.expectedPeers[peer.ID]; !ok {
|
||||||
|
t.Fatalf("expected to have peer %s expired", peer.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
@ -2094,6 +2172,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetPeersWithInactivity(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*nbpeer.Peer
|
||||||
|
expectedPeers map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "No account peers, no peers with expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers with login expiration disabled, no peers with expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers with login expiration enabled, return peers with expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
ID: "peer-1",
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{
|
||||||
|
"peer-1": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
}
|
||||||
|
|
||||||
|
actual := account.GetPeersWithInactivity()
|
||||||
|
assert.Len(t, actual, len(testCase.expectedPeers))
|
||||||
|
if len(testCase.expectedPeers) > 0 {
|
||||||
|
for k := range testCase.expectedPeers {
|
||||||
|
contains := false
|
||||||
|
for _, peer := range actual {
|
||||||
|
if k == peer.ID {
|
||||||
|
contains = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, contains)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
@ -2255,6 +2402,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*nbpeer.Peer
|
||||||
|
expiration time.Duration
|
||||||
|
expirationEnabled bool
|
||||||
|
expectedNextRun bool
|
||||||
|
expectedNextExpiration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedNextExpiration := time.Minute
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "No peers, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No connected peers, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Connected peers with disabled expiration, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: false,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Expired peers, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To be expired peer, return expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: false,
|
||||||
|
LoginExpired: false,
|
||||||
|
LastSeen: time.Now().Add(-1 * time.Second),
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
LastLogin: time.Now().UTC(),
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Minute,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: true,
|
||||||
|
expectedNextExpiration: expectedNextExpiration,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers added with setup keys, no expiration",
|
||||||
|
peers: map[string]*nbpeer.Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
SetupKey: "key",
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
InactivityExpirationEnabled: true,
|
||||||
|
SetupKey: "key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
|
||||||
|
}
|
||||||
|
|
||||||
|
expiration, ok := account.GetNextInactivePeerExpiration()
|
||||||
|
assert.Equal(t, testCase.expectedNextRun, ok)
|
||||||
|
if testCase.expectedNextRun {
|
||||||
|
assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, expiration, testCase.expectedNextExpiration)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
@ -139,6 +139,13 @@ const (
|
|||||||
PostureCheckUpdated Activity = 61
|
PostureCheckUpdated Activity = 61
|
||||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||||
PostureCheckDeleted Activity = 62
|
PostureCheckDeleted Activity = 62
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled Activity = 63
|
||||||
|
PeerInactivityExpirationDisabled Activity = 64
|
||||||
|
|
||||||
|
AccountPeerInactivityExpirationEnabled Activity = 65
|
||||||
|
AccountPeerInactivityExpirationDisabled Activity = 66
|
||||||
|
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||||
)
|
)
|
||||||
|
|
||||||
var activityMap = map[Activity]Code{
|
var activityMap = map[Activity]Code{
|
||||||
@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{
|
|||||||
PostureCheckCreated: {"Posture check created", "posture.check.created"},
|
PostureCheckCreated: {"Posture check created", "posture.check.created"},
|
||||||
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
|
PostureCheckUpdated: {"Posture check updated", "posture.check.updated"},
|
||||||
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
|
PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"},
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"},
|
||||||
|
PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"},
|
||||||
|
|
||||||
|
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
|
||||||
|
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||||
|
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
|||||||
account.Settings = &Settings{
|
account.Settings = &Settings{
|
||||||
PeerLoginExpirationEnabled: false,
|
PeerLoginExpirationEnabled: false,
|
||||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: false,
|
||||||
|
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||||
|
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Settings.Extra != nil {
|
if req.Settings.Extra != nil {
|
||||||
|
@ -54,6 +54,14 @@ components:
|
|||||||
description: Period of time after which peer login expires (seconds).
|
description: Period of time after which peer login expires (seconds).
|
||||||
type: integer
|
type: integer
|
||||||
example: 43200
|
example: 43200
|
||||||
|
peer_inactivity_expiration_enabled:
|
||||||
|
description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
peer_inactivity_expiration:
|
||||||
|
description: Period of time of inactivity after which peer session expires (seconds).
|
||||||
|
type: integer
|
||||||
|
example: 43200
|
||||||
regular_users_view_blocked:
|
regular_users_view_blocked:
|
||||||
description: Allows blocking regular users from viewing parts of the system.
|
description: Allows blocking regular users from viewing parts of the system.
|
||||||
type: boolean
|
type: boolean
|
||||||
@ -81,6 +89,8 @@ components:
|
|||||||
required:
|
required:
|
||||||
- peer_login_expiration_enabled
|
- peer_login_expiration_enabled
|
||||||
- peer_login_expiration
|
- peer_login_expiration
|
||||||
|
- peer_inactivity_expiration_enabled
|
||||||
|
- peer_inactivity_expiration
|
||||||
- regular_users_view_blocked
|
- regular_users_view_blocked
|
||||||
AccountExtraSettings:
|
AccountExtraSettings:
|
||||||
type: object
|
type: object
|
||||||
@ -243,6 +253,9 @@ components:
|
|||||||
login_expiration_enabled:
|
login_expiration_enabled:
|
||||||
type: boolean
|
type: boolean
|
||||||
example: false
|
example: false
|
||||||
|
inactivity_expiration_enabled:
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
approval_required:
|
approval_required:
|
||||||
description: (Cloud only) Indicates whether peer needs approval
|
description: (Cloud only) Indicates whether peer needs approval
|
||||||
type: boolean
|
type: boolean
|
||||||
@ -251,6 +264,7 @@ components:
|
|||||||
- name
|
- name
|
||||||
- ssh_enabled
|
- ssh_enabled
|
||||||
- login_expiration_enabled
|
- login_expiration_enabled
|
||||||
|
- inactivity_expiration_enabled
|
||||||
Peer:
|
Peer:
|
||||||
allOf:
|
allOf:
|
||||||
- $ref: '#/components/schemas/PeerMinimum'
|
- $ref: '#/components/schemas/PeerMinimum'
|
||||||
@ -327,6 +341,10 @@ components:
|
|||||||
type: string
|
type: string
|
||||||
format: date-time
|
format: date-time
|
||||||
example: "2023-05-05T09:00:35.477782Z"
|
example: "2023-05-05T09:00:35.477782Z"
|
||||||
|
inactivity_expiration_enabled:
|
||||||
|
description: Indicates whether peer inactivity expiration has been enabled or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
approval_required:
|
approval_required:
|
||||||
description: (Cloud only) Indicates whether peer needs approval
|
description: (Cloud only) Indicates whether peer needs approval
|
||||||
type: boolean
|
type: boolean
|
||||||
@ -354,6 +372,7 @@ components:
|
|||||||
- last_seen
|
- last_seen
|
||||||
- login_expiration_enabled
|
- login_expiration_enabled
|
||||||
- login_expired
|
- login_expired
|
||||||
|
- inactivity_expiration_enabled
|
||||||
- os
|
- os
|
||||||
- ssh_enabled
|
- ssh_enabled
|
||||||
- user_id
|
- user_id
|
||||||
|
@ -220,6 +220,12 @@ type AccountSettings struct {
|
|||||||
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
||||||
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
||||||
|
|
||||||
|
// PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds).
|
||||||
|
PeerInactivityExpiration int `json:"peer_inactivity_expiration"`
|
||||||
|
|
||||||
|
// PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login).
|
||||||
|
PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"`
|
||||||
|
|
||||||
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
||||||
PeerLoginExpiration int `json:"peer_login_expiration"`
|
PeerLoginExpiration int `json:"peer_login_expiration"`
|
||||||
|
|
||||||
@ -538,6 +544,9 @@ type Peer struct {
|
|||||||
// Id Peer ID
|
// Id Peer ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||||
|
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||||
|
|
||||||
// Ip Peer's IP address
|
// Ip Peer's IP address
|
||||||
Ip string `json:"ip"`
|
Ip string `json:"ip"`
|
||||||
|
|
||||||
@ -613,6 +622,9 @@ type PeerBatch struct {
|
|||||||
// Id Peer ID
|
// Id Peer ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not
|
||||||
|
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||||
|
|
||||||
// Ip Peer's IP address
|
// Ip Peer's IP address
|
||||||
Ip string `json:"ip"`
|
Ip string `json:"ip"`
|
||||||
|
|
||||||
@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string
|
|||||||
// PeerRequest defines model for PeerRequest.
|
// PeerRequest defines model for PeerRequest.
|
||||||
type PeerRequest struct {
|
type PeerRequest struct {
|
||||||
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
|
||||||
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
ApprovalRequired *bool `json:"approval_required,omitempty"`
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"`
|
||||||
Name string `json:"name"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
SshEnabled bool `json:"ssh_enabled"`
|
Name string `json:"name"`
|
||||||
|
SshEnabled bool `json:"ssh_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PersonalAccessToken defines model for PersonalAccessToken.
|
// PersonalAccessToken defines model for PersonalAccessToken.
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@ -14,7 +16,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeersHandler is a handler that returns peers of the account
|
// PeersHandler is a handler that returns peers of the account
|
||||||
@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
|||||||
SSHEnabled: req.SshEnabled,
|
SSHEnabled: req.SshEnabled,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||||
|
|
||||||
|
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.ApprovalRequired != nil {
|
if req.ApprovalRequired != nil {
|
||||||
@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &api.Peer{
|
return &api.Peer{
|
||||||
Id: peer.ID,
|
Id: peer.ID,
|
||||||
Name: peer.Name,
|
Name: peer.Name,
|
||||||
Ip: peer.IP.String(),
|
Ip: peer.IP.String(),
|
||||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||||
Connected: peer.Status.Connected,
|
Connected: peer.Status.Connected,
|
||||||
LastSeen: peer.Status.LastSeen,
|
LastSeen: peer.Status.LastSeen,
|
||||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||||
KernelVersion: peer.Meta.KernelVersion,
|
KernelVersion: peer.Meta.KernelVersion,
|
||||||
GeonameId: int(peer.Location.GeoNameID),
|
GeonameId: int(peer.Location.GeoNameID),
|
||||||
Version: peer.Meta.WtVersion,
|
Version: peer.Meta.WtVersion,
|
||||||
Groups: groupsInfo,
|
Groups: groupsInfo,
|
||||||
SshEnabled: peer.SSHEnabled,
|
SshEnabled: peer.SSHEnabled,
|
||||||
Hostname: peer.Meta.Hostname,
|
Hostname: peer.Meta.Hostname,
|
||||||
UserId: peer.UserID,
|
UserId: peer.UserID,
|
||||||
UiVersion: peer.Meta.UIVersion,
|
UiVersion: peer.Meta.UIVersion,
|
||||||
DnsLabel: fqdn(peer, dnsDomain),
|
DnsLabel: fqdn(peer, dnsDomain),
|
||||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||||
LastLogin: peer.LastLogin,
|
LastLogin: peer.LastLogin,
|
||||||
LoginExpired: peer.Status.LoginExpired,
|
LoginExpired: peer.Status.LoginExpired,
|
||||||
ApprovalRequired: !approved,
|
ApprovalRequired: !approved,
|
||||||
CountryCode: peer.Location.CountryCode,
|
CountryCode: peer.Location.CountryCode,
|
||||||
CityName: peer.Location.CityName,
|
CityName: peer.Location.CityName,
|
||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
|||||||
CountryCode: peer.Location.CountryCode,
|
CountryCode: peer.Location.CountryCode,
|
||||||
CityName: peer.Location.CityName,
|
CityName: peer.Location.CityName,
|
||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
|
|
||||||
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,6 +111,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() {
|
||||||
|
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if expired {
|
||||||
|
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||||
|
// the expired one. Here we notify them that connection is now allowed again.
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
|
||||||
oldStatus := peer.Status.Copy()
|
oldStatus := peer.Status.Copy()
|
||||||
newStatus := oldStatus
|
newStatus := oldStatus
|
||||||
newStatus.LastSeen = time.Now().UTC()
|
newStatus.LastSeen = time.Now().UTC()
|
||||||
@ -139,25 +164,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
return oldStatus.LoginExpired, nil
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
if oldStatus.LoginExpired {
|
|
||||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
|
||||||
// the expired one. Here we notify them that connection is now allowed again.
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
@ -222,6 +237,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
|
||||||
|
|
||||||
|
if !peer.AddedWithSSOLogin() {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
|
||||||
|
|
||||||
|
event := activity.PeerInactivityExpirationEnabled
|
||||||
|
if !update.InactivityExpirationEnabled {
|
||||||
|
event = activity.PeerInactivityExpirationDisabled
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
err = am.Store.SaveAccount(ctx, account)
|
||||||
@ -454,23 +488,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
|
|
||||||
registrationTime := time.Now().UTC()
|
registrationTime := time.Now().UTC()
|
||||||
newPeer = &nbpeer.Peer{
|
newPeer = &nbpeer.Peer{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Key: peer.Key,
|
Key: peer.Key,
|
||||||
SetupKey: upperKey,
|
SetupKey: upperKey,
|
||||||
IP: freeIP,
|
IP: freeIP,
|
||||||
Meta: peer.Meta,
|
Meta: peer.Meta,
|
||||||
Name: peer.Meta.Hostname,
|
Name: peer.Meta.Hostname,
|
||||||
DNSLabel: freeLabel,
|
DNSLabel: freeLabel,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||||
SSHEnabled: false,
|
SSHEnabled: false,
|
||||||
SSHKey: peer.SSHKey,
|
SSHKey: peer.SSHKey,
|
||||||
LastLogin: registrationTime,
|
LastLogin: registrationTime,
|
||||||
CreatedAt: registrationTime,
|
CreatedAt: registrationTime,
|
||||||
LoginExpirationEnabled: addedByUser,
|
LoginExpirationEnabled: addedByUser,
|
||||||
Ephemeral: ephemeral,
|
Ephemeral: ephemeral,
|
||||||
Location: peer.Location,
|
Location: peer.Location,
|
||||||
|
InactivityExpirationEnabled: addedByUser,
|
||||||
}
|
}
|
||||||
opEvent.TargetID = newPeer.ID
|
opEvent.TargetID = newPeer.ID
|
||||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||||
|
@ -38,6 +38,8 @@ type Peer struct {
|
|||||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||||
// Works with LastLogin
|
// Works with LastLogin
|
||||||
LoginExpirationEnabled bool `diff:"-"`
|
LoginExpirationEnabled bool `diff:"-"`
|
||||||
|
|
||||||
|
InactivityExpirationEnabled bool `diff:"-"`
|
||||||
// LastLogin the time when peer performed last login operation
|
// LastLogin the time when peer performed last login operation
|
||||||
LastLogin time.Time `diff:"-"`
|
LastLogin time.Time `diff:"-"`
|
||||||
// CreatedAt records the time the peer was created
|
// CreatedAt records the time the peer was created
|
||||||
@ -187,6 +189,7 @@ func (p *Peer) Copy() *Peer {
|
|||||||
CreatedAt: p.CreatedAt,
|
CreatedAt: p.CreatedAt,
|
||||||
Ephemeral: p.Ephemeral,
|
Ephemeral: p.Ephemeral,
|
||||||
Location: p.Location,
|
Location: p.Location,
|
||||||
|
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,6 +222,22 @@ func (p *Peer) MarkLoginExpired(expired bool) {
|
|||||||
p.Status = newStatus
|
p.Status = newStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionExpired indicates whether the peer's session has expired or not.
|
||||||
|
// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired.
|
||||||
|
// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired).
|
||||||
|
// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
|
||||||
|
// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
|
||||||
|
// Only peers added by interactive SSO login can be expired.
|
||||||
|
func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||||
|
if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
expiresAt := p.Status.LastSeen.Add(expiresIn)
|
||||||
|
now := time.Now()
|
||||||
|
timeLeft := expiresAt.Sub(now)
|
||||||
|
return timeLeft <= 0, timeLeft
|
||||||
|
}
|
||||||
|
|
||||||
// LoginExpired indicates whether the peer's login has expired or not.
|
// LoginExpired indicates whether the peer's login has expired or not.
|
||||||
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
||||||
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
||||||
|
@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPeer_SessionExpired(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expirationEnabled bool
|
||||||
|
lastLogin time.Time
|
||||||
|
connected bool
|
||||||
|
expected bool
|
||||||
|
accountSettings *Settings
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
|
||||||
|
expirationEnabled: false,
|
||||||
|
connected: false,
|
||||||
|
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||||
|
accountSettings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Hour,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peer Inactivity Should Expire",
|
||||||
|
expirationEnabled: true,
|
||||||
|
connected: false,
|
||||||
|
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||||
|
accountSettings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Second,
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peer Inactivity Should Not Expire",
|
||||||
|
expirationEnabled: true,
|
||||||
|
connected: true,
|
||||||
|
lastLogin: time.Now().UTC(),
|
||||||
|
accountSettings: &Settings{
|
||||||
|
PeerInactivityExpirationEnabled: true,
|
||||||
|
PeerInactivityExpiration: time.Second,
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range tt {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
peerStatus := &nbpeer.PeerStatus{
|
||||||
|
Connected: c.connected,
|
||||||
|
}
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
InactivityExpirationEnabled: c.expirationEnabled,
|
||||||
|
LastLogin: c.lastLogin,
|
||||||
|
Status: peerStatus,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration)
|
||||||
|
assert.Equal(t, expired, c.expected)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||||
|
accountCopy := Account{
|
||||||
|
Domain: domain,
|
||||||
|
DomainCategory: category,
|
||||||
|
IsDomainPrimaryAccount: isPrimaryDomain,
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||||
|
result := s.db.WithContext(ctx).Model(&Account{}).
|
||||||
|
Select(fieldsToUpdate).
|
||||||
|
Where(idQueryCondition, accountID).
|
||||||
|
Updates(&accountCopy)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "account %s", accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||||
var peerCopy nbpeer.Peer
|
var peerCopy nbpeer.Peer
|
||||||
peerCopy.Status = &peerStatus
|
peerCopy.Status = &peerStatus
|
||||||
@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||||
|
var users []*User
|
||||||
|
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||||
@ -1117,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
|||||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
var group nbgroup.Group
|
var group nbgroup.Group
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
// we may need to reconsider changing the types.
|
||||||
|
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
|
||||||
|
if s.storeEngine == PostgresStoreEngine {
|
||||||
|
query = query.Order("json_array_length(peers::json) DESC")
|
||||||
|
} else {
|
||||||
|
query = query.Order("json_array_length(peers) DESC")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "group not found")
|
return nil, status.Errorf(status.NotFound, "group not found")
|
||||||
|
@ -1191,3 +1191,76 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
account, err := store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, users, len(account.Users))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
t.Run("Should update attributes with public domain", func(t *testing.T) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
domain := "example.com"
|
||||||
|
category := "public"
|
||||||
|
IsDomainPrimaryAccount := false
|
||||||
|
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||||
|
require.NoError(t, err)
|
||||||
|
account, err := store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, domain, account.Domain)
|
||||||
|
require.Equal(t, category, account.DomainCategory)
|
||||||
|
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Should update attributes with private domain", func(t *testing.T) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
domain := "test.com"
|
||||||
|
category := "private"
|
||||||
|
IsDomainPrimaryAccount := true
|
||||||
|
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||||
|
require.NoError(t, err)
|
||||||
|
account, err := store.GetAccount(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, domain, account.Domain)
|
||||||
|
require.Equal(t, category, account.DomainCategory)
|
||||||
|
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Should fail when account does not exist", func(t *testing.T) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
domain := "test.com"
|
||||||
|
category := "private"
|
||||||
|
IsDomainPrimaryAccount := true
|
||||||
|
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlite_GetGroupByName(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "All", group.Name)
|
||||||
|
}
|
||||||
|
@ -58,9 +58,11 @@ type Store interface {
|
|||||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
DeleteAccount(ctx context.Context, account *Account) error
|
DeleteAccount(ctx context.Context, account *Account) error
|
||||||
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
|
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
|
@ -20,10 +20,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UserRoleOwner UserRole = "owner"
|
UserRoleOwner UserRole = "owner"
|
||||||
UserRoleAdmin UserRole = "admin"
|
UserRoleAdmin UserRole = "admin"
|
||||||
UserRoleUser UserRole = "user"
|
UserRoleUser UserRole = "user"
|
||||||
UserRoleUnknown UserRole = "unknown"
|
UserRoleUnknown UserRole = "unknown"
|
||||||
|
UserRoleBillingAdmin UserRole = "billing_admin"
|
||||||
|
|
||||||
UserStatusActive UserStatus = "active"
|
UserStatusActive UserStatus = "active"
|
||||||
UserStatusDisabled UserStatus = "disabled"
|
UserStatusDisabled UserStatus = "disabled"
|
||||||
@ -42,6 +43,8 @@ func StrRoleToUserRole(strRole string) UserRole {
|
|||||||
return UserRoleAdmin
|
return UserRoleAdmin
|
||||||
case "user":
|
case "user":
|
||||||
return UserRoleUser
|
return UserRoleUser
|
||||||
|
case "billing_admin":
|
||||||
|
return UserRoleBillingAdmin
|
||||||
default:
|
default:
|
||||||
return UserRoleUnknown
|
return UserRoleUnknown
|
||||||
}
|
}
|
||||||
|
@ -16,8 +16,10 @@ const (
|
|||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
metric.Meter
|
metric.Meter
|
||||||
|
|
||||||
TransferBytesSent metric.Int64Counter
|
TransferBytesSent metric.Int64Counter
|
||||||
TransferBytesRecv metric.Int64Counter
|
TransferBytesRecv metric.Int64Counter
|
||||||
|
AuthenticationTime metric.Float64Histogram
|
||||||
|
PeerStoreTime metric.Float64Histogram
|
||||||
|
|
||||||
peers metric.Int64UpDownCounter
|
peers metric.Int64UpDownCounter
|
||||||
peerActivityChan chan string
|
peerActivityChan chan string
|
||||||
@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authTime, err := meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Metrics{
|
m := &Metrics{
|
||||||
Meter: meter,
|
Meter: meter,
|
||||||
TransferBytesSent: bytesSent,
|
TransferBytesSent: bytesSent,
|
||||||
TransferBytesRecv: bytesRecv,
|
TransferBytesRecv: bytesRecv,
|
||||||
peers: peers,
|
AuthenticationTime: authTime,
|
||||||
|
PeerStoreTime: peerStoreTime,
|
||||||
|
peers: peers,
|
||||||
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
peerActivityChan: make(chan string, 10),
|
peerActivityChan: make(chan string, 10),
|
||||||
@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) {
|
|||||||
m.peerLastActive[id] = time.Time{}
|
m.peerLastActive[id] = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordAuthenticationTime measures the time taken for peer authentication
|
||||||
|
func (m *Metrics) RecordAuthenticationTime(duration time.Duration) {
|
||||||
|
m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordPeerStoreTime measures the time to store the peer in map
|
||||||
|
func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
|
||||||
|
m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
||||||
func (m *Metrics) PeerDisconnected(id string) {
|
func (m *Metrics) PeerDisconnected(id string) {
|
||||||
m.peers.Add(m.ctx, -1)
|
m.peers.Add(m.ctx, -1)
|
||||||
@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStandardBucketBoundaries() []float64 {
|
||||||
|
return []float64{
|
||||||
|
0.1,
|
||||||
|
0.5,
|
||||||
|
1,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
50,
|
||||||
|
100,
|
||||||
|
500,
|
||||||
|
1000,
|
||||||
|
5000,
|
||||||
|
10000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
153
relay/server/handshake.go
Normal file
153
relay/server/handshake.go
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
//nolint:staticcheck
|
||||||
|
"github.com/netbirdio/netbird/relay/messages/address"
|
||||||
|
//nolint:staticcheck
|
||||||
|
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// preparedMsg contains the marshalled success response messages
|
||||||
|
type preparedMsg struct {
|
||||||
|
responseHelloMsg []byte
|
||||||
|
responseAuthMsg []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPreparedMsg(instanceURL string) (*preparedMsg, error) {
|
||||||
|
rhm, err := marshalResponseHelloMsg(instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ram, err := messages.MarshalAuthResponse(instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal auth response msg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &preparedMsg{
|
||||||
|
responseHelloMsg: rhm,
|
||||||
|
responseAuthMsg: ram,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
||||||
|
addr := &address.Address{URL: instanceURL}
|
||||||
|
addrData, err := addr.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal response address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal hello response: %w", err)
|
||||||
|
}
|
||||||
|
return responseMsg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshake struct {
|
||||||
|
conn net.Conn
|
||||||
|
validator auth.Validator
|
||||||
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
|
handshakeMethodAuth bool
|
||||||
|
peerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handshakeReceive() ([]byte, error) {
|
||||||
|
buf := make([]byte, messages.MaxHandshakeSize)
|
||||||
|
n, err := h.conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = messages.ValidateVersion(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
bytePeerID []byte
|
||||||
|
peerID string
|
||||||
|
)
|
||||||
|
switch msgType {
|
||||||
|
//nolint:staticcheck
|
||||||
|
case messages.MsgTypeHello:
|
||||||
|
bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n])
|
||||||
|
case messages.MsgTypeAuth:
|
||||||
|
h.handshakeMethodAuth = true
|
||||||
|
bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h.peerID = peerID
|
||||||
|
return bytePeerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handshakeResponse() error {
|
||||||
|
var responseMsg []byte
|
||||||
|
if h.handshakeMethodAuth {
|
||||||
|
responseMsg = h.preparedMsg.responseAuthMsg
|
||||||
|
} else {
|
||||||
|
responseMsg = h.preparedMsg.responseHelloMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := h.conn.Write(responseMsg); err != nil {
|
||||||
|
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
||||||
|
//nolint:staticcheck
|
||||||
|
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||||
|
|
||||||
|
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck
|
||||||
|
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPeerID, peerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
||||||
|
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerID := messages.HashIDToString(rawPeerID)
|
||||||
|
|
||||||
|
if err := h.validator.Validate(authPayload); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawPeerID, peerID, nil
|
||||||
|
}
|
@ -7,16 +7,13 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
"github.com/netbirdio/netbird/relay/auth"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/messages/address"
|
|
||||||
//nolint:staticcheck
|
|
||||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,6 +25,7 @@ type Relay struct {
|
|||||||
|
|
||||||
store *Store
|
store *Store
|
||||||
instanceURL string
|
instanceURL string
|
||||||
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
closeMu sync.RWMutex
|
closeMu sync.RWMutex
|
||||||
@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
|
|||||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
return nil, fmt.Errorf("get instance URL: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||||
|
if err != nil {
|
||||||
|
metricsCancel()
|
||||||
|
return nil, fmt.Errorf("prepare message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,17 +104,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
|||||||
|
|
||||||
// Accept start to handle a new peer connection
|
// Accept start to handle a new peer connection
|
||||||
func (r *Relay) Accept(conn net.Conn) {
|
func (r *Relay) Accept(conn net.Conn) {
|
||||||
|
acceptTime := time.Now()
|
||||||
r.closeMu.RLock()
|
r.closeMu.RLock()
|
||||||
defer r.closeMu.RUnlock()
|
defer r.closeMu.RUnlock()
|
||||||
if r.closed {
|
if r.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID, err := r.handshake(conn)
|
h := handshake{
|
||||||
|
conn: conn,
|
||||||
|
validator: r.validator,
|
||||||
|
preparedMsg: r.preparedMsg,
|
||||||
|
}
|
||||||
|
peerID, err := h.handshakeReceive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to handshake: %s", err)
|
log.Errorf("failed to handshake: %s", err)
|
||||||
cErr := conn.Close()
|
if cErr := conn.Close(); cErr != nil {
|
||||||
if cErr != nil {
|
|
||||||
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
|
|
||||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
||||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||||
|
storeTime := time.Now()
|
||||||
r.store.AddPeer(peer)
|
r.store.AddPeer(peer)
|
||||||
|
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||||
r.metrics.PeerConnected(peer.String())
|
r.metrics.PeerConnected(peer.String())
|
||||||
go func() {
|
go func() {
|
||||||
peer.Work()
|
peer.Work()
|
||||||
@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
peer.log.Debugf("relay connection closed")
|
peer.log.Debugf("relay connection closed")
|
||||||
r.metrics.PeerDisconnected(peer.String())
|
r.metrics.PeerDisconnected(peer.String())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if err := h.handshakeResponse(); err != nil {
|
||||||
|
log.Errorf("failed to send handshake response, close peer: %s", err)
|
||||||
|
peer.Close()
|
||||||
|
}
|
||||||
|
r.metrics.RecordAuthenticationTime(time.Since(acceptTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown closes the relay server
|
// Shutdown closes the relay server
|
||||||
@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
|||||||
func (r *Relay) InstanceURL() string {
|
func (r *Relay) InstanceURL() string {
|
||||||
return r.instanceURL
|
return r.instanceURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
|
|
||||||
buf := make([]byte, messages.MaxHandshakeSize)
|
|
||||||
n, err := conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = messages.ValidateVersion(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
responseMsg []byte
|
|
||||||
peerID []byte
|
|
||||||
)
|
|
||||||
switch msgType {
|
|
||||||
//nolint:staticcheck
|
|
||||||
case messages.MsgTypeHello:
|
|
||||||
peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
||||||
case messages.MsgTypeAuth:
|
|
||||||
peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.Write(responseMsg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peerID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
|
|
||||||
//nolint:staticcheck
|
|
||||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
|
|
||||||
|
|
||||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &address.Address{URL: r.instanceURL}
|
|
||||||
addrData, err := addr.Marshal()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
responseMsg, err := messages.MarshalHelloResponse(addrData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
|
|
||||||
}
|
|
||||||
return rawPeerID, responseMsg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
|
|
||||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
|
|
||||||
if err := r.validator.Validate(authPayload); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rawPeerID, responseMsg, nil
|
|
||||||
}
|
|
||||||
|
@ -86,7 +86,7 @@ download_release_binary() {
|
|||||||
|
|
||||||
# Unzip the app and move to INSTALL_DIR
|
# Unzip the app and move to INSTALL_DIR
|
||||||
unzip -q -o "$BINARY_NAME"
|
unzip -q -o "$BINARY_NAME"
|
||||||
mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
|
mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/"
|
||||||
else
|
else
|
||||||
${SUDO} mkdir -p "$INSTALL_DIR"
|
${SUDO} mkdir -p "$INSTALL_DIR"
|
||||||
tar -xzvf "$BINARY_NAME"
|
tar -xzvf "$BINARY_NAME"
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
@ -13,8 +14,6 @@ import (
|
|||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/signal/metrics"
|
"github.com/netbirdio/netbird/signal/metrics"
|
||||||
"github.com/netbirdio/netbird/signal/peer"
|
"github.com/netbirdio/netbird/signal/peer"
|
||||||
"github.com/netbirdio/netbird/signal/proto"
|
"github.com/netbirdio/netbird/signal/proto"
|
||||||
|
@ -11,7 +11,8 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||||
NetbirdFwmark = 0x1BD00
|
NetbirdFwmark = 0x1BD00
|
||||||
|
PreroutingFwmark = 0x1BD01
|
||||||
|
|
||||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user