mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
Merge branch 'main' into groups-get-account-refactoring
# Conflicts: # management/server/group.go # management/server/group/group.go # management/server/setupkey.go # management/server/sql_store.go # management/server/status/error.go # management/server/store.go
This commit is contained in:
commit
010a8bfdc1
@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -13,10 +14,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type program struct {
|
type program struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
serv *grpc.Server
|
serv *grpc.Server
|
||||||
serverInstance *server.Server
|
serverInstance *server.Server
|
||||||
|
serverInstanceMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||||
|
@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||||
|
|
||||||
|
p.serverInstanceMu.Lock()
|
||||||
p.serverInstance = serverInstance
|
p.serverInstance = serverInstance
|
||||||
|
p.serverInstanceMu.Unlock()
|
||||||
|
|
||||||
log.Printf("started daemon server: %v", split[1])
|
log.Printf("started daemon server: %v", split[1])
|
||||||
if err := p.serv.Serve(listen); err != nil {
|
if err := p.serv.Serve(listen); err != nil {
|
||||||
@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *program) Stop(srv service.Service) error {
|
func (p *program) Stop(srv service.Service) error {
|
||||||
|
p.serverInstanceMu.Lock()
|
||||||
if p.serverInstance != nil {
|
if p.serverInstance != nil {
|
||||||
in := new(proto.DownRequest)
|
in := new(proto.DownRequest)
|
||||||
_, err := p.serverInstance.Down(p.ctx, in)
|
_, err := p.serverInstance.Down(p.ctx, in)
|
||||||
@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
log.Errorf("failed to stop daemon: %v", err)
|
log.Errorf("failed to stop daemon: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
p.serverInstanceMu.Unlock()
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package bind
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
|
|||||||
|
|
||||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||||
|
|
||||||
return p.remoteConn.Close()
|
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||||
|
return rErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||||
|
@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
|
|
||||||
e.cancel()
|
e.cancel()
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil {
|
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
|
|||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
if err := p.remoteConn.Close(); err != nil {
|
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,7 +207,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
|
|
||||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||||
defer func() {
|
defer func() {
|
||||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
_, err := state.Status()
|
||||||
|
c.statusRecorder.MarkSignalDisconnected(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||||
|
@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
if conn.iceP2PIsActive() {
|
if conn.iceP2PIsActive() {
|
||||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||||
conn.wgProxyRelay = wgProxy
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
return
|
return
|
||||||
@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
conn.wgProxyRelay = wgProxy
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
conn.log.Infof("start to communicate with peer via relay")
|
conn.log.Infof("start to communicate with peer via relay")
|
||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||||
|
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.wgProxyRelay = proxy
|
||||||
|
}
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
|
|||||||
func (s *State) GetRoutes() map[string]struct{} {
|
func (s *State) GetRoutes() map[string]struct{} {
|
||||||
s.Mux.RLock()
|
s.Mux.RLock()
|
||||||
defer s.Mux.RUnlock()
|
defer s.Mux.RUnlock()
|
||||||
return s.routes
|
return maps.Clone(s.routes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
// LocalPeerState contains the latest state of the local peer
|
||||||
@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
peerState.IP = receivedState.IP
|
peerState.IP = receivedState.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
if receivedState.GetRoutes() != nil {
|
|
||||||
peerState.SetRoutes(receivedState.GetRoutes())
|
|
||||||
}
|
|
||||||
|
|
||||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||||
|
|
||||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||||
@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerListChanged()
|
||||||
if found && ch != nil {
|
return nil
|
||||||
close(ch)
|
}
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
|
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peer]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerState.AddRoute(route)
|
||||||
|
d.peers[peer] = peerState
|
||||||
|
|
||||||
|
// todo: consider to make sense of this notification or not
|
||||||
|
d.notifyPeerListChanged()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peer]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.DeleteRoute(route)
|
||||||
|
d.peers[peer] = peerState
|
||||||
|
|
||||||
|
// todo: consider to make sense of this notification or not
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
|
|||||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
ch, found := d.changeNotify[peer]
|
ch, found := d.changeNotify[peer]
|
||||||
if !found || ch == nil {
|
if found {
|
||||||
ch = make(chan struct{})
|
return ch
|
||||||
d.changeNotify[peer] = ch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ch = make(chan struct{})
|
||||||
|
d.changeNotify[peer] = ch
|
||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
|
|||||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||||
|
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||||
|
ch, found := d.changeNotify[peerID]
|
||||||
|
if !found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(ch)
|
||||||
|
delete(d.changeNotify, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Status) notifyPeerListChanged() {
|
func (d *Status) notifyPeerListChanged() {
|
||||||
d.notifier.peerListChanged(d.numOfPeers())
|
d.notifier.peerListChanged(d.numOfPeers())
|
||||||
}
|
}
|
||||||
|
@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
|||||||
|
|
||||||
peerState.IP = ip
|
peerState.IP = ip
|
||||||
|
|
||||||
err := status.UpdatePeerState(peerState)
|
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
@ -57,6 +57,9 @@ type WorkerICE struct {
|
|||||||
|
|
||||||
localUfrag string
|
localUfrag string
|
||||||
localPwd string
|
localPwd string
|
||||||
|
|
||||||
|
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||||
|
lastKnownState ice.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
||||||
@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := w.agent.Close()
|
if err := w.agent.Close(); err != nil {
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
|||||||
|
|
||||||
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||||
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
|
switch state {
|
||||||
w.conn.OnStatusChanged(StatusDisconnected)
|
case ice.ConnectionStateConnected:
|
||||||
|
w.lastKnownState = ice.ConnectionStateConnected
|
||||||
w.muxAgent.Lock()
|
return
|
||||||
agentCancel()
|
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||||
_ = agent.Close()
|
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||||
w.agent = nil
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
|
w.conn.OnStatusChanged(StatusDisconnected)
|
||||||
w.muxAgent.Unlock()
|
}
|
||||||
|
w.closeAgent(agentCancel)
|
||||||
|
default:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
|||||||
return agent, nil
|
return agent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
|
||||||
|
w.muxAgent.Lock()
|
||||||
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
if err := w.agent.Close(); err != nil {
|
||||||
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
|
}
|
||||||
|
w.agent = nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
// wait local endpoint configuration
|
// wait local endpoint configuration
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
|||||||
tempScore = float64(metricDiff) * 10
|
tempScore = float64(metricDiff) * 10
|
||||||
}
|
}
|
||||||
|
|
||||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
|
||||||
latency := time.Second
|
latency := 999 * time.Millisecond
|
||||||
if peerStatus.latency != 0 {
|
if peerStatus.latency != 0 {
|
||||||
latency = peerStatus.latency
|
latency = peerStatus.latency
|
||||||
} else {
|
} else {
|
||||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// avoid negative tempScore on the higher latency calculation
|
||||||
|
if latency > 1*time.Second {
|
||||||
|
latency = 999 * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
// higher latency is worse score
|
||||||
tempScore += 1 - latency.Seconds()
|
tempScore += 1 - latency.Seconds()
|
||||||
|
|
||||||
if !peerStatus.relayed {
|
if !peerStatus.relayed {
|
||||||
@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case chosen == "":
|
case chosen == "":
|
||||||
var peers []string
|
var peers []string
|
||||||
@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
|||||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
_, found := c.routePeersNotifiers[r.Peer]
|
_, found := c.routePeersNotifiers[r.Peer]
|
||||||
if !found {
|
if found {
|
||||||
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
continue
|
||||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
closerChan := make(chan struct{})
|
||||||
|
c.routePeersNotifiers[r.Peer] = closerChan
|
||||||
|
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||||
c.removeStateRoute()
|
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
||||||
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||||
@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
|||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||||
}
|
}
|
||||||
if err := c.handler.RemoveRoute(); err != nil {
|
if err := c.handler.RemoveRoute(); err != nil {
|
||||||
@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Otherwise, remove the allowed IPs from the previous peer first
|
// Otherwise, remove the allowed IPs from the previous peer first
|
||||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.addStateRoute()
|
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("add peer state route: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) addStateRoute() {
|
|
||||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to get peer state: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
state.AddRoute(c.handler.String())
|
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientNetwork) removeStateRoute() {
|
|
||||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to get peer state: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
state.DeleteRoute(c.handler.String())
|
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||||
go func() {
|
go func() {
|
||||||
c.routeUpdate <- update
|
c.routeUpdate <- update
|
||||||
|
@ -217,6 +217,11 @@ func (rm *Counter[Key, I, O]) Clear() {
|
|||||||
|
|
||||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||||
|
rm.refCountMu.Lock()
|
||||||
|
defer rm.refCountMu.Unlock()
|
||||||
|
rm.idMu.Lock()
|
||||||
|
defer rm.idMu.Unlock()
|
||||||
|
|
||||||
return json.Marshal(struct {
|
return json.Marshal(struct {
|
||||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||||
IDMap map[string][]Key `json:"idMap"`
|
IDMap map[string][]Key `json:"idMap"`
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pb "github.com/golang/protobuf/proto" // nolint
|
pb "github.com/golang/protobuf/proto" // nolint
|
||||||
@ -38,6 +39,7 @@ type GRPCServer struct {
|
|||||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
appMetrics telemetry.AppMetrics
|
appMetrics telemetry.AppMetrics
|
||||||
ephemeralManager *EphemeralManager
|
ephemeralManager *EphemeralManager
|
||||||
|
peerLocks sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||||
|
|
||||||
|
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||||
|
defer func() {
|
||||||
|
if unlock != nil {
|
||||||
|
unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||||
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||||
s.secretsManager.CancelRefresh(peer.ID)
|
s.secretsManager.CancelRefresh(peer.ID)
|
||||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
|
||||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
|||||||
return claims.UserId, nil
|
return claims.UserId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||||
|
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||||
|
mtx := value.(*sync.RWMutex)
|
||||||
|
mtx.Lock()
|
||||||
|
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||||
|
start = time.Now()
|
||||||
|
|
||||||
|
unlock = func() {
|
||||||
|
mtx.Unlock()
|
||||||
|
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
return unlock
|
||||||
|
}
|
||||||
|
|
||||||
// maps internal internalStatus.Error to gRPC status.Error
|
// maps internal internalStatus.Error to gRPC status.Error
|
||||||
func mapError(ctx context.Context, err error) error {
|
func mapError(ctx context.Context, err error) error {
|
||||||
if e, ok := internalStatus.FromError(err); ok {
|
if e, ok := internalStatus.FromError(err); ok {
|
||||||
|
@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.Peer == nil && req.PeerGroups == nil {
|
if req.Peer == nil && req.PeerGroups == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
|
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Peer != nil && req.PeerGroups != nil {
|
if req.Peer != nil && req.PeerGroups != nil {
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
b64 "encoding/base64"
|
b64 "encoding/base64"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -236,6 +237,10 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
|||||||
return nil, status.NewUserNotPartOfAccountError()
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
var setupKey *SetupKey
|
var setupKey *SetupKey
|
||||||
var plainKey string
|
var plainKey string
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
@ -289,6 +294,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
return nil, status.NewUserNotPartOfAccountError()
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
var oldKey *SetupKey
|
var oldKey *SetupKey
|
||||||
var newKey *SetupKey
|
var newKey *SetupKey
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
@ -414,10 +423,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
|
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, groupID := range autoGroupIDs {
|
for _, groupID := range autoGroupIDs {
|
||||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
group, ok := groups[groupID]
|
||||||
if err != nil {
|
if !ok {
|
||||||
return err
|
return status.Errorf(status.NotFound, "group not found: %s", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if group.IsGroupAll() {
|
if group.IsGroupAll() {
|
||||||
@ -432,26 +446,37 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI
|
|||||||
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
|
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
|
||||||
|
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, g := range removedGroups {
|
for _, g := range removedGroups {
|
||||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
group, ok := groups[g]
|
||||||
if err != nil {
|
if !ok {
|
||||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
eventsToStore = append(eventsToStore, func() {
|
||||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
|
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||||
|
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range addedGroups {
|
for _, g := range addedGroups {
|
||||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
group, ok := groups[g]
|
||||||
if err != nil {
|
if !ok {
|
||||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
eventsToStore = append(eventsToStore, func() {
|
||||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
|
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||||
|
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return eventsToStore
|
return eventsToStore
|
||||||
|
@ -485,9 +485,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
|||||||
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
|
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.NewSetupKeyNotFoundError(setupKey)
|
||||||
}
|
}
|
||||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get account by setup key from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
if key.AccountID == "" {
|
if key.AccountID == "" {
|
||||||
@ -570,7 +571,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
|
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
|
||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
@ -756,9 +757,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
|||||||
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return "", status.NewSetupKeyNotFoundError(setupKey)
|
||||||
}
|
}
|
||||||
return "", status.NewSetupKeyNotFoundError(result.Error)
|
log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error)
|
||||||
|
return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
if accountID == "" {
|
if accountID == "" {
|
||||||
@ -985,9 +987,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
|||||||
First(&setupKey, keyQueryCondition, key)
|
First(&setupKey, keyQueryCondition, key)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
return nil, status.NewSetupKeyNotFoundError(key)
|
||||||
}
|
}
|
||||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
||||||
}
|
}
|
||||||
return &setupKey, nil
|
return &setupKey, nil
|
||||||
}
|
}
|
||||||
@ -1005,7 +1008,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
return status.Errorf(status.NotFound, "setup key not found")
|
return status.NewSetupKeyNotFoundError(setupKeyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -1207,6 +1210,23 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
|||||||
return &group, nil
|
return &group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
||||||
|
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
|
||||||
|
var groups []*nbgroup.Group
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsMap := make(map[string]*nbgroup.Group)
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupsMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SaveGroup saves a group to the store.
|
// SaveGroup saves a group to the store.
|
||||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||||
@ -1278,7 +1298,7 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
|
|||||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||||
var setupKeys []*SetupKey
|
var setupKeys []*SetupKey
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
Find(&setupKeys, accountIDCondition, accountID)
|
Find(&setupKeys, accountIDCondition, accountID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
|
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
|
||||||
@ -1291,11 +1311,11 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
|
|||||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
|
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
|
||||||
var setupKey *SetupKey
|
var setupKey *SetupKey
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
|
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
|
||||||
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
|
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
|
||||||
@ -1306,8 +1326,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
|
|||||||
|
|
||||||
// SaveSetupKey saves a setup key to the database.
|
// SaveSetupKey saves a setup key to the database.
|
||||||
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
|
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
|
||||||
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
||||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
|
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
|
||||||
return status.Errorf(status.Internal, "failed to save setup key to store")
|
return status.Errorf(status.Internal, "failed to save setup key to store")
|
||||||
@ -1318,15 +1337,14 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
|
|||||||
|
|
||||||
// DeleteSetupKey deletes a setup key from the database.
|
// DeleteSetupKey deletes a setup key from the database.
|
||||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
|
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
||||||
Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
|
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
|
||||||
return status.Errorf(status.Internal, "failed to delete setup key from store")
|
return status.Errorf(status.Internal, "failed to delete setup key from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
return status.Errorf(status.NotFound, "setup key not found")
|
return status.NewSetupKeyNotFoundError(keyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -103,8 +103,8 @@ func NewPeerLoginExpiredError() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||||
func NewSetupKeyNotFoundError(err error) error {
|
func NewSetupKeyNotFoundError(setupKeyID string) error {
|
||||||
return Errorf(NotFound, "setup key not found: %s", err)
|
return Errorf(NotFound, "setup key: %s not found", setupKeyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGetAccountFromStoreError(err error) error {
|
func NewGetAccountFromStoreError(err error) error {
|
||||||
@ -126,11 +126,6 @@ func NewAdminPermissionError() error {
|
|||||||
return Errorf(PermissionDenied, "admin role required to perform this action")
|
return Errorf(PermissionDenied, "admin role required to perform this action")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
|
||||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
|
||||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||||
func NewInvalidKeyIDError() error {
|
func NewInvalidKeyIDError() error {
|
||||||
return Errorf(InvalidArgument, "invalid key ID")
|
return Errorf(InvalidArgument, "invalid key ID")
|
||||||
|
@ -71,8 +71,9 @@ type Store interface {
|
|||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
|
|
||||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
||||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||||
|
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
|
||||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||||
|
@ -3,7 +3,6 @@ package client
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
|||||||
conn, ok := c.conns[id]
|
conn, ok := c.conns[id]
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, io.EOF
|
return 0, net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.conn != connReference {
|
if conn.conn != connReference {
|
||||||
return 0, io.EOF
|
return 0, net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: use buffer pool instead of create new transport msg.
|
// todo: use buffer pool instead of create new transport msg.
|
||||||
@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
|||||||
|
|
||||||
container, ok := c.conns[id]
|
container, ok := c.conns[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("connection already closed")
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
if container.conn != connReference {
|
if container.conn != connReference {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
|
|||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
msg, ok := <-c.messageChan
|
msg, ok := <-c.messageChan
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, io.EOF
|
return 0, net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
n = copy(b, msg.Payload)
|
n = copy(b, msg.Payload)
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
|
|||||||
|
|
||||||
func (c *Conn) ioErrHandling(err error) error {
|
func (c *Conn) ioErrHandling(err error) error {
|
||||||
if c.isClosed() {
|
if c.isClosed() {
|
||||||
return io.EOF
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
var wErr *websocket.CloseError
|
var wErr *websocket.CloseError
|
||||||
@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if wErr.Code == websocket.StatusNormalClosure {
|
if wErr.Code == websocket.StatusNormalClosure {
|
||||||
return io.EOF
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -57,7 +57,7 @@ func (p *Peer) Work() {
|
|||||||
for {
|
for {
|
||||||
n, err := p.conn.Read(buf)
|
n, err := p.conn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != io.EOF {
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
p.log.Errorf("failed to read message: %s", err)
|
p.log.Errorf("failed to read message: %s", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user