mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-22 22:08:39 +01:00
Merge branch 'main' into groups-get-account-refactoring
# Conflicts: # management/server/group.go # management/server/group/group.go # management/server/setupkey.go # management/server/sql_store.go # management/server/status/error.go # management/server/store.go
This commit is contained in:
commit
010a8bfdc1
@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -13,10 +14,11 @@ import (
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
serverInstanceMu sync.Mutex
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
|
@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||
|
||||
p.serverInstanceMu.Lock()
|
||||
p.serverInstance = serverInstance
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
in := new(proto.DownRequest)
|
||||
_, err := p.serverInstance.Down(p.ctx, in)
|
||||
@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
|
||||
log.Errorf("failed to stop daemon: %v", err)
|
||||
}
|
||||
}
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
p.cancel()
|
||||
|
||||
|
@ -2,6 +2,7 @@ package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
|
||||
return p.remoteConn.Close()
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
|
@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
|
@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||
defer func() {
|
||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkSignalDisconnected(err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
|
@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
wgConfigWorkaround()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = proxy
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
|
||||
func (s *State) GetRoutes() map[string]struct{} {
|
||||
s.Mux.RLock()
|
||||
defer s.Mux.RUnlock()
|
||||
return s.routes
|
||||
return maps.Clone(s.routes)
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
if receivedState.GetRoutes() != nil {
|
||||
peerState.SetRoutes(receivedState.GetRoutes())
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
|
||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||
@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.AddRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.DeleteRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
|
||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
ch, found := d.changeNotify[peer]
|
||||
if !found || ch == nil {
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
if found {
|
||||
return ch
|
||||
}
|
||||
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
ch, found := d.changeNotify[peerID]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(d.changeNotify, peerID)
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(d.numOfPeers())
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerState(peerState)
|
||||
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
select {
|
||||
|
@ -57,6 +57,9 @@ type WorkerICE struct {
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
||||
@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
|
||||
return
|
||||
}
|
||||
|
||||
err := w.agent.Close()
|
||||
if err != nil {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
}
|
||||
@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
|
||||
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
|
||||
w.muxAgent.Lock()
|
||||
agentCancel()
|
||||
_ = agent.Close()
|
||||
w.agent = nil
|
||||
|
||||
w.muxAgent.Unlock()
|
||||
switch state {
|
||||
case ice.ConnectionStateConnected:
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
}
|
||||
w.closeAgent(agentCancel)
|
||||
default:
|
||||
return
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
cancel()
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
w.agent = nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
// wait local endpoint configuration
|
||||
time.Sleep(time.Second)
|
||||
|
@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
tempScore = float64(metricDiff) * 10
|
||||
}
|
||||
|
||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||
latency := time.Second
|
||||
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
|
||||
latency := 999 * time.Millisecond
|
||||
if peerStatus.latency != 0 {
|
||||
latency = peerStatus.latency
|
||||
} else {
|
||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
||||
}
|
||||
|
||||
// avoid negative tempScore on the higher latency calculation
|
||||
if latency > 1*time.Second {
|
||||
latency = 999 * time.Millisecond
|
||||
}
|
||||
|
||||
// higher latency is worse score
|
||||
tempScore += 1 - latency.Seconds()
|
||||
|
||||
if !peerStatus.relayed {
|
||||
@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
||||
|
||||
switch {
|
||||
case chosen == "":
|
||||
var peers []string
|
||||
@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
for _, r := range c.routes {
|
||||
_, found := c.routePeersNotifiers[r.Peer]
|
||||
if !found {
|
||||
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
|
||||
closerChan := make(chan struct{})
|
||||
c.routePeersNotifiers[r.Peer] = closerChan
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
} else {
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
}
|
||||
@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
c.addStateRoute()
|
||||
|
||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add peer state route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
|
@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@ -38,6 +39,7 @@ type GRPCServer struct {
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// maps internal internalStatus.Error to gRPC status.Error
|
||||
func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
|
@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
||||
}
|
||||
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"hash/fnv"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -236,6 +237,10 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
var plainKey string
|
||||
var eventsToStore []func()
|
||||
@ -289,6 +294,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var oldKey *SetupKey
|
||||
var newKey *SetupKey
|
||||
var eventsToStore []func()
|
||||
@ -414,10 +423,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, groupID := range autoGroupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group not found: %s", groupID)
|
||||
}
|
||||
|
||||
if group.IsGroupAll() {
|
||||
@ -432,26 +446,37 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI
|
||||
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range removedGroups {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
|
||||
continue
|
||||
}
|
||||
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
|
||||
})
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
|
||||
continue
|
||||
}
|
||||
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
|
@ -485,9 +485,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
||||
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return nil, status.NewSetupKeyNotFoundError(setupKey)
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get account by setup key from store")
|
||||
}
|
||||
|
||||
if key.AccountID == "" {
|
||||
@ -570,7 +571,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@ -756,9 +757,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return "", status.NewSetupKeyNotFoundError(setupKey)
|
||||
}
|
||||
return "", status.NewSetupKeyNotFoundError(result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error)
|
||||
return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store")
|
||||
}
|
||||
|
||||
if accountID == "" {
|
||||
@ -985,9 +987,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
||||
First(&setupKey, keyQueryCondition, key)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
return nil, status.NewSetupKeyNotFoundError(key)
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
||||
}
|
||||
return &setupKey, nil
|
||||
}
|
||||
@ -1005,7 +1008,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "setup key not found")
|
||||
return status.NewSetupKeyNotFoundError(setupKeyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -1207,6 +1210,23 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
||||
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
// SaveGroup saves a group to the store.
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
@ -1278,7 +1298,7 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
var setupKeys []*SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&setupKeys, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
|
||||
@ -1291,11 +1311,11 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
|
||||
var setupKey *SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
|
||||
@ -1306,8 +1326,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
|
||||
|
||||
// SaveSetupKey saves a setup key to the database.
|
||||
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
|
||||
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save setup key to store")
|
||||
@ -1318,15 +1337,14 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
|
||||
|
||||
// DeleteSetupKey deletes a setup key from the database.
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete setup key from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "setup key not found")
|
||||
return status.NewSetupKeyNotFoundError(keyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -103,8 +103,8 @@ func NewPeerLoginExpiredError() error {
|
||||
}
|
||||
|
||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||
func NewSetupKeyNotFoundError(err error) error {
|
||||
return Errorf(NotFound, "setup key not found: %s", err)
|
||||
func NewSetupKeyNotFoundError(setupKeyID string) error {
|
||||
return Errorf(NotFound, "setup key: %s not found", setupKeyID)
|
||||
}
|
||||
|
||||
func NewGetAccountFromStoreError(err error) error {
|
||||
@ -126,11 +126,6 @@ func NewAdminPermissionError() error {
|
||||
return Errorf(PermissionDenied, "admin role required to perform this action")
|
||||
}
|
||||
|
||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
||||
}
|
||||
|
||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||
func NewInvalidKeyIDError() error {
|
||||
return Errorf(InvalidArgument, "invalid key ID")
|
||||
|
@ -71,8 +71,9 @@ type Store interface {
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
|
||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
|
@ -3,7 +3,6 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
||||
conn, ok := c.conns[id]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
if conn.conn != connReference {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
if container.conn != connReference {
|
||||
|
@ -1,7 +1,6 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
|
||||
|
||||
func (c *Conn) ioErrHandling(err error) error {
|
||||
if c.isClosed() {
|
||||
return io.EOF
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
var wErr *websocket.CloseError
|
||||
@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
|
||||
return err
|
||||
}
|
||||
if wErr.Code == websocket.StatusNormalClosure {
|
||||
return io.EOF
|
||||
return net.ErrClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@ -57,7 +57,7 @@ func (p *Peer) Work() {
|
||||
for {
|
||||
n, err := p.conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf("failed to read message: %s", err)
|
||||
}
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user