Refactor Sync

This commit is contained in:
braginini 2023-03-01 18:54:27 +01:00
parent 2f09c3d2c4
commit 241b819156
4 changed files with 99 additions and 80 deletions

View File

@ -97,6 +97,7 @@ type AccountManager interface {
GetPeer(accountID, peerID, userID string) (*Peer, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*Peer, error)
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error)
}
type DefaultAccountManager struct {

View File

@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
pb "github.com/golang/protobuf/proto" //nolint
"strings"
"time"
@ -118,44 +119,18 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", peerKey.String())
return status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", peerKey.String())
}
peer, err := s.accountManager.GetPeerByKey(peerKey.String())
if err != nil {
p, _ := gRPCPeer.FromContext(srv.Context())
msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg)
return msg
}
account, err := s.accountManager.GetAccountByPeerID(peer.ID)
if err != nil {
return status.Error(codes.Internal, "internal server error")
}
expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired
if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true)
if err != nil {
log.Warnf("failed marking peer login expired %s %v", peerKey, err)
}
return status.Errorf(codes.PermissionDenied, "peer login has expired %v ago. Please log in once more", left)
}
syncReq := &proto.SyncRequest{}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
peerKey, err := s.parseRequest(req, syncReq)
if err != nil {
p, _ := gRPCPeer.FromContext(srv.Context())
msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg)
return msg
return err
}
err = s.sendInitialSync(peerKey, peer, srv)
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
if err != nil {
return mapError(err)
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
@ -263,21 +238,19 @@ func extractPeerMeta(loginReq *proto.LoginRequest) PeerSystemMeta {
}
}
func (s *GRPCServer) parseLoginRequest(req *proto.EncryptedMessage) (*proto.LoginRequest, wgtypes.Key, error) {
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return nil, wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
loginReq := &proto.LoginRequest{}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, loginReq)
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
if err != nil {
return nil, wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
}
return loginReq, peerKey, nil
return peerKey, nil
}
// Login endpoint first checks whether peer is registered under any account
@ -293,7 +266,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
loginReq, peerKey, err := s.parseLoginRequest(req)
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(req, loginReq)
if err != nil {
return nil, err
}
@ -310,7 +284,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
// JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered,
// or it uses a setup key to register.
if loginReq.GetJwtToken() != "" {
// todo what about the case when JWT provided expired?
// todo what about the case when JWT provided is expired?
userID, err = s.validateToken(loginReq.GetJwtToken())
if err != nil {
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
@ -473,13 +447,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
networkMap, err := s.accountManager.GetNetworkMap(peer.ID)
if err != nil {
log.Warnf("error getting a list of peers for a peer %s", peer.ID)
return err
}
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {

View File

@ -73,6 +73,7 @@ type MockAccountManager struct {
GetAccountByPeerIDFunc func(peerID string) (*server.Account, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, error)
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@ -564,3 +565,11 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*server.Peer, e
}
return nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(sync)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}

View File

@ -36,6 +36,12 @@ type PeerStatus struct {
LoginExpired bool
}
// PeerSync used as a data object between the gRPC API and AccountManager on Sync request.
type PeerSync struct {
// WireGuardPubKey is a peers WireGuard public key
WireGuardPubKey string
}
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
type PeerLogin struct {
// WireGuardPubKey is a peers WireGuard public key
@ -454,6 +460,34 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
}
func (am *DefaultAccountManager) getNetworkMap(peer *Peer, account *Account) *NetworkMap {
aclPeers := account.getPeersByACL(peer.ID)
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
routesUpdate := account.getRoutesToSync(peer.ID, aclPeers)
dnsManagementStatus := account.getPeerDNSManagementStatus(peer.ID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(account, am.dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(account, peer.ID)
}
return &NetworkMap{
Peers: aclPeers,
Network: account.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
}
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, error) {
@ -467,31 +501,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
aclPeers := account.getPeersByACL(peerID)
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
routesUpdate := account.getRoutesToSync(peerID, aclPeers)
dnsManagementStatus := account.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(account, am.dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(account, peerID)
}
return &NetworkMap{
Peers: aclPeers,
Network: account.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
}, err
return am.getNetworkMap(peer, account), nil
}
// GetPeerNetwork returns the Network for a given peer
@ -643,13 +653,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
return newPeer, nil
}
func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer *Peer, account *Account) error {
func (am *DefaultAccountManager) checkPeerLoginExpiration(loginUserID string, peer *Peer, account *Account) error {
if peer.AddedWithSSOLogin() {
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired
if expired || peer.Status.LoginExpired {
log.Debugf("peer %s login expired", peer.ID)
if login.UserID == "" {
if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided.
_, err := am.markPeerLoginExpired(peer, account, true)
if err != nil {
@ -659,8 +669,8 @@ func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer
"peer login has expired %v ago. Please log in once more", expiresIn)
} else {
// user ID is there meaning that JWT validation passed successfully in the API layer.
if peer.UserID != login.UserID {
log.Warnf("user mismatch when loggin in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
if peer.UserID != loginUserID {
log.Warnf("user mismatch when loggin in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
return status.Errorf(status.Unauthenticated, "can't login")
}
_ = am.updatePeerLastLogin(peer, account)
@ -671,6 +681,37 @@ func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer
return nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, err
}
// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, nil, err
}
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, err
}
err = am.checkPeerLoginExpiration("", peer, account)
if err != nil {
return nil, nil, err
}
return peer, am.getNetworkMap(peer, account), nil
}
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, error) {
@ -705,7 +746,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, error) {
return nil, err
}
err = am.checkPeerLoginExpiration(login, peer, account)
err = am.checkPeerLoginExpiration(login.UserID, peer, account)
if err != nil {
return nil, err
}