Merge branch 'main' into groups-get-account-refactoring

# Conflicts:
#	management/server/group.go
#	management/server/group/group.go
#	management/server/setupkey.go
#	management/server/sql_store.go
#	management/server/status/error.go
#	management/server/store.go
This commit is contained in:
bcmmbaga 2024-11-11 21:10:02 +03:00
commit 010a8bfdc1
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
22 changed files with 262 additions and 144 deletions

View File

@ -2,6 +2,7 @@ package cmd
import (
"context"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
@ -13,10 +14,11 @@ import (
)
type program struct {
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
serverInstanceMu sync.Mutex
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.serverInstanceMu.Unlock()
p.cancel()

View File

@ -2,6 +2,7 @@ package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
p.Bind.RemoveEndpoint(p.wgAddr)
return p.remoteConn.Close()
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {

View File

@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel()
if err := e.remoteConn.Close(); err != nil {
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil

View File

@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}

View File

@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.MarkSignalDisconnected(nil)
defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err)
_, err := state.Status()
c.statusRecorder.MarkSignalDisconnected(err)
}()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal

View File

@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
}
}
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = proxy
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
return maps.Clone(s.routes)
}
// LocalPeerState contains the latest state of the local peer
@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
d.notifyPeerListChanged()
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.AddRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.DeleteRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()
ch, found := d.changeNotify[peer]
if !found || ch == nil {
ch = make(chan struct{})
d.changeNotify[peer] = ch
if found {
return ch
}
ch = make(chan struct{})
d.changeNotify[peer] = ch
return ch
}
@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID]
if !found {
return
}
close(ch)
delete(d.changeNotify, peerID)
}
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}

View File

@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
peerState.IP = ip
err := status.UpdatePeerState(peerState)
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error")
select {

View File

@ -57,6 +57,9 @@ type WorkerICE struct {
localUfrag string
localPwd string
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
return
}
err := w.agent.Close()
if err != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil
w.muxAgent.Unlock()
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected)
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)

View File

@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
}
// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}
// higher latency is worse score
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
}
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
switch {
case chosen == "":
var peers []string
@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if !found {
c.routePeersNotifiers[r.Peer] = make(chan struct{})
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
if found {
continue
}
closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
}
}
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
c.addStateRoute()
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update

View File

@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`

View File

@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@ -38,6 +39,7 @@ type GRPCServer struct {
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
}
// NewServer creates a new Management server
@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
unlock()
unlock = nil
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return claims.UserId, nil
}
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
start = time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok {

View File

@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
}
if req.Peer == nil && req.PeerGroups == nil {
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
}
if req.Peer != nil && req.PeerGroups != nil {

View File

@ -5,6 +5,7 @@ import (
"crypto/sha256"
b64 "encoding/base64"
"hash/fnv"
"slices"
"strconv"
"strings"
"time"
@ -236,6 +237,10 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var setupKey *SetupKey
var plainKey string
var eventsToStore []func()
@ -289,6 +294,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var oldKey *SetupKey
var newKey *SetupKey
var eventsToStore []func()
@ -414,10 +423,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
}
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
if err != nil {
return err
}
for _, groupID := range autoGroupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
group, ok := groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group not found: %s", groupID)
}
if group.IsGroupAll() {
@ -432,26 +446,37 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
return nil
}
for _, g := range removedGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
continue
}
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
})
}
for _, g := range addedGroups {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
continue
}
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
})
}
return eventsToStore

View File

@ -485,9 +485,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
return nil, status.NewSetupKeyNotFoundError(setupKey)
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account by setup key from store")
}
if key.AccountID == "" {
@ -570,7 +571,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -756,9 +757,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
return "", status.NewSetupKeyNotFoundError(setupKey)
}
return "", status.NewSetupKeyNotFoundError(result.Error)
log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error)
return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store")
}
if accountID == "" {
@ -985,9 +987,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
First(&setupKey, keyQueryCondition, key)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
return nil, status.NewSetupKeyNotFoundError(key)
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
}
return &setupKey, nil
}
@ -1005,7 +1008,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
return status.NewSetupKeyNotFoundError(setupKeyID)
}
return nil
@ -1207,6 +1210,23 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil
}
// GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
}
groupsMap := make(map[string]*nbgroup.Group)
for _, group := range groups {
groupsMap[group.ID] = group
}
return groupsMap, nil
}
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
@ -1278,7 +1298,7 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
var setupKeys []*SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&setupKeys, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
@ -1291,11 +1311,11 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
var setupKey *SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
}
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
@ -1306,8 +1326,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
// SaveSetupKey saves a setup key to the database.
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save setup key to store")
@ -1318,15 +1337,14 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
// DeleteSetupKey deletes a setup key from the database.
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete setup key from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
return status.NewSetupKeyNotFoundError(keyID)
}
return nil

View File

@ -103,8 +103,8 @@ func NewPeerLoginExpiredError() error {
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError(err error) error {
return Errorf(NotFound, "setup key not found: %s", err)
func NewSetupKeyNotFoundError(setupKeyID string) error {
return Errorf(NotFound, "setup key: %s not found", setupKeyID)
}
func NewGetAccountFromStoreError(err error) error {
@ -126,11 +126,6 @@ func NewAdminPermissionError() error {
return Errorf(PermissionDenied, "admin role required to perform this action")
}
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
}
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID")

View File

@ -71,8 +71,9 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error

View File

@ -3,7 +3,6 @@ package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
conn, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
return 0, net.ErrClosed
}
if conn.conn != connReference {
return 0, io.EOF
return 0, net.ErrClosed
}
// todo: use buffer pool instead of create new transport msg.
@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
return net.ErrClosed
}
if container.conn != connReference {

View File

@ -1,7 +1,6 @@
package client
import (
"io"
"net"
"time"
)
@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
return 0, net.ErrClosed
}
n = copy(b, msg.Payload)

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() {
return io.EOF
return net.ErrClosed
}
var wErr *websocket.CloseError
@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
return err
}
if wErr.Code == websocket.StatusNormalClosure {
return io.EOF
return net.ErrClosed
}
return err
}

View File

@ -2,7 +2,7 @@ package server
import (
"context"
"io"
"errors"
"net"
"sync"
"time"
@ -57,7 +57,7 @@ func (p *Peer) Work() {
for {
n, err := p.conn.Read(buf)
if err != nil {
if err != io.EOF {
if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err)
}
return