mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-28 08:37:57 +01:00
fix(acl): update each peer's network when rule,group or peer changed (#333)
* fix(acl): update each peer's network when rule,group or peer changed * fix(ACL): update network test * fix(acl): cleanup indexes before update them * fix(acl): clean up rules indexes only for account
This commit is contained in:
parent
fa0399d975
commit
d005cd32b0
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -513,6 +514,189 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.AddAccount("test_account", "account_creator", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
setupKey = key
|
||||
if setupKey.Type == SetupKeyReusable {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if setupKey == nil {
|
||||
t.Errorf("expecting account to have a default setup key")
|
||||
return
|
||||
}
|
||||
|
||||
if account.Network.Serial != 0 {
|
||||
t.Errorf("expecting account network to have an initial Serial=0")
|
||||
return
|
||||
}
|
||||
|
||||
getPeer := func() *Peer {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, err := manager.AddPeer(setupKey.Key, "", &Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: PeerSystemMeta{},
|
||||
Name: expectedPeerKey,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expecting peer1 to be added, got failure %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
peer1 := getPeer()
|
||||
peer2 := getPeer()
|
||||
peer3 := getPeer()
|
||||
|
||||
account, err = manager.GetAccountById(account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(peer1.Key)
|
||||
defer manager.peersUpdateManager.CloseChannel(peer1.Key)
|
||||
|
||||
group := Group{
|
||||
ID: "group-id",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.Key, peer2.Key, peer3.Key},
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Source: []string{"group-id"},
|
||||
Destination: []string{"group-id"},
|
||||
Flow: TrafficFlowBidirect,
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
t.Run("save group update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SaveGroup(account.Id, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("delete rule update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
var defaultRule *Rule
|
||||
for _, r := range account.Rules {
|
||||
defaultRule = r
|
||||
}
|
||||
|
||||
if err := manager.DeleteRule(account.Id, defaultRule.ID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("save rule update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SaveRule(account.Id, &rule); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("delete peer update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := manager.DeletePeer(account.Id, peer3.Key); err != nil {
|
||||
t.Errorf("delete peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("delete group update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeleteGroup(account.Id, group.ID); err != nil {
|
||||
t.Errorf("delete group rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@ -664,7 +848,6 @@ func TestAccountManager_UpdatePeerMeta(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(t, newMeta, p.Meta)
|
||||
|
||||
}
|
||||
|
||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
@ -180,6 +180,8 @@ func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error)
|
||||
|
||||
delete(account.Peers, peerKey)
|
||||
delete(s.PeerKeyId2AccountId, peerKey)
|
||||
delete(s.PeerKeyId2DstRulesId, peerKey)
|
||||
delete(s.PeerKeyId2SrcRulesId, peerKey)
|
||||
|
||||
// cleanup groups
|
||||
var peers []string
|
||||
@ -240,9 +242,34 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.PeerKeyId2AccountId[peer.Key] = account.Id
|
||||
}
|
||||
|
||||
// remove all peers related to account from rules indexes
|
||||
cleanIDs := make([]string, 0)
|
||||
for key := range s.PeerKeyId2SrcRulesId {
|
||||
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
|
||||
cleanIDs = append(cleanIDs, key)
|
||||
}
|
||||
}
|
||||
for _, key := range cleanIDs {
|
||||
delete(s.PeerKeyId2SrcRulesId, key)
|
||||
}
|
||||
cleanIDs = cleanIDs[:0]
|
||||
for key := range s.PeerKeyId2DstRulesId {
|
||||
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
|
||||
cleanIDs = append(cleanIDs, key)
|
||||
}
|
||||
}
|
||||
for _, key := range cleanIDs {
|
||||
delete(s.PeerKeyId2DstRulesId, key)
|
||||
}
|
||||
|
||||
// rebuild rule indexes
|
||||
for _, rule := range account.Rules {
|
||||
for _, gid := range rule.Source {
|
||||
for _, pid := range account.Groups[gid].Peers {
|
||||
g, ok := account.Groups[gid]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
for _, pid := range g.Peers {
|
||||
rules := s.PeerKeyId2SrcRulesId[pid]
|
||||
if rules == nil {
|
||||
rules = map[string]struct{}{}
|
||||
@ -252,7 +279,11 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
}
|
||||
}
|
||||
for _, gid := range rule.Destination {
|
||||
for _, pid := range account.Groups[gid].Peers {
|
||||
g, ok := account.Groups[gid]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
for _, pid := range g.Peers {
|
||||
rules := s.PeerKeyId2DstRulesId[pid]
|
||||
if rules == nil {
|
||||
rules = map[string]struct{}{}
|
||||
|
@ -54,7 +54,13 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
|
||||
}
|
||||
|
||||
account.Groups[group.ID] = group
|
||||
return am.Store.SaveAccount(account)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
@ -69,7 +75,12 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
return am.Store.SaveAccount(account)
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
@ -116,7 +127,12 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
|
||||
group.Peers = append(group.Peers, peerKey)
|
||||
}
|
||||
|
||||
return am.Store.SaveAccount(account)
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
@ -134,14 +150,17 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
|
||||
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerKey {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
return am.Store.SaveAccount(account)
|
||||
if err := am.Store.SaveAccount(account); err != nil {
|
||||
return status.Errorf(codes.Internal, "can't save account")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// GroupListPeers returns list of the peers from the group
|
||||
|
@ -134,6 +134,16 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
// delete peer from groups
|
||||
for _, g := range account.Groups {
|
||||
for i, pk := range g.Peers {
|
||||
if pk == peerKey {
|
||||
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := am.Store.DeletePeer(accountId, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -163,39 +173,10 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// notify other peers of the change
|
||||
peers, err := am.Store.GetAccountPeers(accountId)
|
||||
if err != nil {
|
||||
if err := am.updateAccountPeers(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, p := range peers {
|
||||
peersToSend := []*Peer{}
|
||||
for _, remote := range peers {
|
||||
if p.Key != remote.Key {
|
||||
peersToSend = append(peersToSend, remote)
|
||||
}
|
||||
}
|
||||
update := toRemotePeerConfig(peersToSend)
|
||||
err = am.peersUpdateManager.SendUpdate(p.Key,
|
||||
&UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
// fill those field for backward compatibility
|
||||
RemotePeers: update,
|
||||
RemotePeersIsEmpty: len(update) == 0,
|
||||
// new field
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: account.Network.CurrentSerial(),
|
||||
RemotePeers: update,
|
||||
RemotePeersIsEmpty: len(update) == 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
am.peersUpdateManager.CloseChannel(peerKey)
|
||||
return peer, nil
|
||||
}
|
||||
@ -229,56 +210,8 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
|
||||
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey)
|
||||
}
|
||||
|
||||
var res []*Peer
|
||||
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
|
||||
if err != nil {
|
||||
return &NetworkMap{
|
||||
Peers: res,
|
||||
Network: account.Network.Copy(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
|
||||
if err != nil {
|
||||
return &NetworkMap{
|
||||
Peers: res,
|
||||
Network: account.Network.Copy(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
groups := map[string]*Group{}
|
||||
for _, r := range srcRules {
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Destination {
|
||||
groups[gid] = account.Groups[gid]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range dstRules {
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Source {
|
||||
groups[gid] = account.Groups[gid]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
for _, pid := range g.Peers {
|
||||
peer, ok := account.Peers[pid]
|
||||
if !ok {
|
||||
log.Warnf("peer %s found in group %s but doesn't belong to account %s", pid, g.ID, account.Id)
|
||||
continue
|
||||
}
|
||||
// exclude original peer
|
||||
if peer.Key != peerKey {
|
||||
res = append(res, peer.Copy())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
Peers: res,
|
||||
Peers: am.getPeersByACL(account, peerKey),
|
||||
Network: account.Network.Copy(),
|
||||
}, err
|
||||
}
|
||||
@ -411,3 +344,93 @@ func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemM
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPeersByACL allowed for given peer by ACL
|
||||
func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string) []*Peer {
|
||||
var peers []*Peer
|
||||
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
|
||||
if err != nil {
|
||||
srcRules = []*Rule{}
|
||||
}
|
||||
|
||||
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
|
||||
if err != nil {
|
||||
dstRules = []*Rule{}
|
||||
}
|
||||
|
||||
groups := map[string]*Group{}
|
||||
for _, r := range srcRules {
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Destination {
|
||||
if group, ok := account.Groups[gid]; ok {
|
||||
groups[gid] = group
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range dstRules {
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Source {
|
||||
if group, ok := account.Groups[gid]; ok {
|
||||
groups[gid] = group
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersSet := make(map[string]struct{})
|
||||
for _, g := range groups {
|
||||
for _, pid := range g.Peers {
|
||||
peer, ok := account.Peers[pid]
|
||||
if !ok {
|
||||
log.Warnf(
|
||||
"peer %s found in group %s but doesn't belong to account %s",
|
||||
pid,
|
||||
g.ID,
|
||||
account.Id,
|
||||
)
|
||||
continue
|
||||
}
|
||||
// exclude original peer
|
||||
if _, ok := peersSet[peer.Key]; peer.Key != peerKey && !ok {
|
||||
peersSet[peer.Key] = struct{}{}
|
||||
peers = append(peers, peer.Copy())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peers
|
||||
}
|
||||
|
||||
// updateAccountPeers network map constructed by ACL
|
||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
||||
// notify other peers of the change
|
||||
peers, err := am.Store.GetAccountPeers(account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, p := range peers {
|
||||
update := toRemotePeerConfig(am.getPeersByACL(account, p.Key))
|
||||
err = am.peersUpdateManager.SendUpdate(p.Key,
|
||||
&UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
// fill those field for backward compatibility
|
||||
RemotePeers: update,
|
||||
RemotePeersIsEmpty: len(update) == 0,
|
||||
// new field
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: account.Network.CurrentSerial(),
|
||||
RemotePeers: update,
|
||||
RemotePeersIsEmpty: len(update) == 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -192,6 +192,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
|
||||
len(networkMap1.Peers),
|
||||
networkMap1.Peers,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if networkMap1.Peers[0].Key != peerKey2.PublicKey().String() {
|
||||
|
@ -70,7 +70,13 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
|
||||
}
|
||||
|
||||
account.Rules[rule.ID] = rule
|
||||
return am.Store.SaveAccount(account)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// DeleteRule of ACL from the store
|
||||
@ -85,7 +91,12 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
|
||||
|
||||
delete(account.Rules, ruleID)
|
||||
|
||||
return am.Store.SaveAccount(account)
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
// ListRules of ACL from the store
|
||||
|
Loading…
Reference in New Issue
Block a user