mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-04 22:10:56 +01:00
Refactor Sync
This commit is contained in:
parent
2f09c3d2c4
commit
241b819156
@ -97,6 +97,7 @@ type AccountManager interface {
|
||||
GetPeer(accountID, peerID, userID string) (*Peer, error)
|
||||
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
||||
LoginPeer(login PeerLogin) (*Peer, error)
|
||||
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
|
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -118,44 +119,18 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||
}
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", peerKey.String())
|
||||
return status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", peerKey.String())
|
||||
}
|
||||
|
||||
peer, err := s.accountManager.GetPeerByKey(peerKey.String())
|
||||
if err != nil {
|
||||
p, _ := gRPCPeer.FromContext(srv.Context())
|
||||
msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String())
|
||||
log.Debug(msg)
|
||||
return msg
|
||||
}
|
||||
|
||||
account, err := s.accountManager.GetAccountByPeerID(peer.ID)
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, "internal server error")
|
||||
}
|
||||
expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration)
|
||||
expired = account.Settings.PeerLoginExpirationEnabled && expired
|
||||
if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
|
||||
err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer login expired %s %v", peerKey, err)
|
||||
}
|
||||
return status.Errorf(codes.PermissionDenied, "peer login has expired %v ago. Please log in once more", left)
|
||||
}
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
|
||||
peerKey, err := s.parseRequest(req, syncReq)
|
||||
if err != nil {
|
||||
p, _ := gRPCPeer.FromContext(srv.Context())
|
||||
msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String())
|
||||
log.Debug(msg)
|
||||
return msg
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, srv)
|
||||
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||
if err != nil {
|
||||
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
@ -263,21 +238,19 @@ func extractPeerMeta(loginReq *proto.LoginRequest) PeerSystemMeta {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GRPCServer) parseLoginRequest(req *proto.EncryptedMessage) (*proto.LoginRequest, wgtypes.Key, error) {
|
||||
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
return nil, wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||
}
|
||||
|
||||
loginReq := &proto.LoginRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, loginReq)
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
|
||||
if err != nil {
|
||||
return nil, wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||
}
|
||||
|
||||
return loginReq, peerKey, nil
|
||||
|
||||
return peerKey, nil
|
||||
}
|
||||
|
||||
// Login endpoint first checks whether peer is registered under any account
|
||||
@ -293,7 +266,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||
}
|
||||
|
||||
loginReq, peerKey, err := s.parseLoginRequest(req)
|
||||
loginReq := &proto.LoginRequest{}
|
||||
peerKey, err := s.parseRequest(req, loginReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -310,7 +284,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
// JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered,
|
||||
// or it uses a setup key to register.
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
// todo what about the case when JWT provided expired?
|
||||
// todo what about the case when JWT provided is expired?
|
||||
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||
if err != nil {
|
||||
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
|
||||
@ -473,13 +447,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
|
||||
networkMap, err := s.accountManager.GetNetworkMap(peer.ID)
|
||||
if err != nil {
|
||||
log.Warnf("error getting a list of peers for a peer %s", peer.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
|
||||
// make secret time based TURN credentials optional
|
||||
var turnCredentials *TURNCredentials
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
|
@ -73,6 +73,7 @@ type MockAccountManager struct {
|
||||
GetAccountByPeerIDFunc func(peerID string) (*server.Account, error)
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, error)
|
||||
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
@ -564,3 +565,11 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*server.Peer, e
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(sync)
|
||||
}
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||
}
|
||||
|
@ -36,6 +36,12 @@ type PeerStatus struct {
|
||||
LoginExpired bool
|
||||
}
|
||||
|
||||
// PeerSync used as a data object between the gRPC API and AccountManager on Sync request.
|
||||
type PeerSync struct {
|
||||
// WireGuardPubKey is a peers WireGuard public key
|
||||
WireGuardPubKey string
|
||||
}
|
||||
|
||||
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
|
||||
type PeerLogin struct {
|
||||
// WireGuardPubKey is a peers WireGuard public key
|
||||
@ -454,6 +460,34 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
|
||||
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getNetworkMap(peer *Peer, account *Account) *NetworkMap {
|
||||
aclPeers := account.getPeersByACL(peer.ID)
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
routesUpdate := account.getRoutesToSync(peer.ID, aclPeers)
|
||||
|
||||
dnsManagementStatus := account.getPeerDNSManagementStatus(peer.ID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
peersCustomZone := getPeersCustomZone(account, am.dnsDomain)
|
||||
if peersCustomZone.Domain != "" {
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(account, peer.ID)
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
Peers: aclPeers,
|
||||
Network: account.Network.Copy(),
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, error) {
|
||||
|
||||
@ -467,31 +501,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro
|
||||
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
aclPeers := account.getPeersByACL(peerID)
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
routesUpdate := account.getRoutesToSync(peerID, aclPeers)
|
||||
|
||||
dnsManagementStatus := account.getPeerDNSManagementStatus(peerID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
peersCustomZone := getPeersCustomZone(account, am.dnsDomain)
|
||||
if peersCustomZone.Domain != "" {
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(account, peerID)
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
Peers: aclPeers,
|
||||
Network: account.Network.Copy(),
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
}, err
|
||||
return am.getNetworkMap(peer, account), nil
|
||||
}
|
||||
|
||||
// GetPeerNetwork returns the Network for a given peer
|
||||
@ -643,13 +653,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
return newPeer, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer *Peer, account *Account) error {
|
||||
func (am *DefaultAccountManager) checkPeerLoginExpiration(loginUserID string, peer *Peer, account *Account) error {
|
||||
if peer.AddedWithSSOLogin() {
|
||||
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
|
||||
expired = account.Settings.PeerLoginExpirationEnabled && expired
|
||||
if expired || peer.Status.LoginExpired {
|
||||
log.Debugf("peer %s login expired", peer.ID)
|
||||
if login.UserID == "" {
|
||||
if loginUserID == "" {
|
||||
// absence of a user ID indicates that JWT wasn't provided.
|
||||
_, err := am.markPeerLoginExpired(peer, account, true)
|
||||
if err != nil {
|
||||
@ -659,8 +669,8 @@ func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer
|
||||
"peer login has expired %v ago. Please log in once more", expiresIn)
|
||||
} else {
|
||||
// user ID is there meaning that JWT validation passed successfully in the API layer.
|
||||
if peer.UserID != login.UserID {
|
||||
log.Warnf("user mismatch when loggin in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
||||
if peer.UserID != loginUserID {
|
||||
log.Warnf("user mismatch when loggin in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
|
||||
return status.Errorf(status.Unauthenticated, "can't login")
|
||||
}
|
||||
_ = am.updatePeerLastLogin(peer, account)
|
||||
@ -671,6 +681,37 @@ func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// we found the peer, and we follow a normal login flow
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.checkPeerLoginExpiration("", peer, account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return peer, am.getNetworkMap(peer, account), nil
|
||||
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
|
||||
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, error) {
|
||||
@ -705,7 +746,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.checkPeerLoginExpiration(login, peer, account)
|
||||
err = am.checkPeerLoginExpiration(login.UserID, peer, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user