diff --git a/management/server/account.go b/management/server/account.go index 248222ea4..30ebe4331 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2430,16 +2430,17 @@ func newAccountWithId(ctx context.Context, store Store, accountID, userID, domai return fmt.Errorf("failed to save group All: %w", err) } - id := xid.New().String() + policyID := xid.New().String() defaultPolicy := &Policy{ - ID: id, + ID: policyID, AccountID: accountID, Name: DefaultPolicyName, Description: DefaultPolicyDescription, Enabled: true, Rules: []*PolicyRule{ { - ID: id, + ID: xid.New().String(), + PolicyID: policyID, Name: DefaultRuleName, Description: DefaultRuleDescription, Enabled: true, @@ -2451,7 +2452,7 @@ func newAccountWithId(ctx context.Context, store Store, accountID, userID, domai }, }, } - if err := transaction.SavePolicy(ctx, LockingStrengthUpdate, defaultPolicy); err != nil { + if err := transaction.CreatePolicy(ctx, LockingStrengthUpdate, defaultPolicy); err != nil { return fmt.Errorf("failed to save default policy: %w", err) } diff --git a/management/server/policy.go b/management/server/policy.go index ba9a31aa1..12dac2469 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -399,7 +399,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return fmt.Errorf("failed to increment network serial: %w", err) } - if err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy); err != nil { + saveFunc := transaction.SavePolicy + if !isUpdate { + saveFunc = transaction.CreatePolicy + } + + if err := saveFunc(ctx, LockingStrengthUpdate, policy); err != nil { return fmt.Errorf("failed to save policy: %w", err) } return nil diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 6eab5814f..c8f138aeb 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -421,7 +421,8 @@ func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, // SaveUser saves the given user to the database. func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Select(clause.Associations).Save(user) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save user to store") @@ -502,15 +503,19 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&User{}, accountAndIDQueryCondition, accountID, userID) - if err := result.Error; err != nil { - log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err) - return status.Errorf(status.Internal, "failed to user policy from store") - } + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&PersonalAccessToken{}, "user_id = ?", userID) + if result.Error != nil { + return result.Error + } - if result.RowsAffected == 0 { - return status.NewUserNotFoundError(userID) + return tx.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&User{}, accountAndIDQueryCondition, accountID, userID).Error + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete user from store") } return nil @@ -1414,6 +1419,16 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return policy, nil } +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create policy in the store: %s", result.Error) + return status.Errorf(status.Internal, "failed to create policy in the store") + } + + return nil +} + // SavePolicy saves a policy to the database. func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index bb2a8f15d..ddf57b11c 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -88,6 +88,7 @@ func runLargeTest(t *testing.T, store Store) { peer := &nbpeer.Peer{ ID: peerID, + AccountID: accountID, Key: peerID, IP: netIP, Name: peerID, @@ -96,8 +97,8 @@ func runLargeTest(t *testing.T, store Store) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } - err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer) - assert.NoError(t, err, "failed to save peer") + err = store.AddPeerToAccount(context.Background(), peer) + assert.NoError(t, err, "failed to add peer") err = store.AddPeerToAllGroup(context.Background(), accountID, peerID) assert.NoError(t, err, "failed to add peer to all group") @@ -237,12 +238,14 @@ func TestSqlite_SaveAccount(t *testing.T) { err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) require.NoError(t, err, "failed to save setup key") - err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer", + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + AccountID: accountID, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, }) require.NoError(t, err, "failed to save peer") @@ -255,12 +258,14 @@ func TestSqlite_SaveAccount(t *testing.T) { err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) require.NoError(t, err, "failed to save setup key") - err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID2, &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer2", + Key: "peerkey2", + AccountID: accountID2, + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, }) require.NoError(t, err, "failed to save peer") @@ -312,12 +317,6 @@ func TestSqlite_DeleteAccount(t *testing.T) { accountID := "account_id" testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { - ID: "testtoken", - Name: "test token", - }} - err = newAccountWithId(context.Background(), store, accountID, testUserID, "") require.NoError(t, err) @@ -326,12 +325,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) require.NoError(t, err, "failed to save setup key") - err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer", + Key: "peerkey", + AccountID: accountID, + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, }) require.NoError(t, err, "failed to save peer") @@ -638,7 +639,7 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id) require.NoError(t, err) - require.Equal(t, id, user.PATs[id].ID) + require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) _, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id") require.Error(t, err) @@ -814,12 +815,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) require.NoError(t, err, "failed to save setup key") - err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer", + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + AccountID: accountID, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, }) require.NoError(t, err, "failed to save peer") @@ -833,12 +836,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) require.NoError(t, err, "failed to save setup key") - err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID2, &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer2", + Key: "peerkey2", + AccountID: accountID2, + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, }) require.NoError(t, err, "failed to save peer") @@ -907,12 +912,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) require.NoError(t, err, "failed to save setup key") - err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testingpeer", + AccountID: accountID, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, }) require.NoError(t, err, "failed to save peer") diff --git a/management/server/store.go b/management/server/store.go index bcfd680f9..190632f69 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -84,6 +84,7 @@ type Store interface { GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error diff --git a/management/server/user.go b/management/server/user.go index bd72e04a0..fab597ad0 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -69,7 +69,7 @@ type User struct { // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user AutoGroups []string `gorm:"serializer:json"` PATs map[string]*PersonalAccessToken `gorm:"-"` - PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` // Blocked indicates whether the user is blocked. Blocked users can't use the system. Blocked bool // LastLogin is the last time the user logged in to IdP