From 9b5c0439e90af22616e5bd45e1b70deeeed6f0e4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 16 Jan 2025 20:28:38 +0100 Subject: [PATCH] Make debug ops a bit safer --- client/server/debug.go | 17 +++++++++++++---- client/server/trace.go | 12 ++++++++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/client/server/debug.go b/client/server/debug.go index 9de80173b..a7ab855e8 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -496,11 +496,20 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) ( log.SetLevel(level) - if s.connectClient != nil && - s.connectClient.Engine() != nil && - s.connectClient.Engine().GetFirewallManager() != nil { - s.connectClient.Engine().GetFirewallManager().SetLogLevel(level) + if s.connectClient == nil { + return nil, fmt.Errorf("connect client not initialized") } + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("engine not initialized") + } + + fwManager := engine.GetFirewallManager() + if fwManager == nil { + return nil, fmt.Errorf("firewall manager not initialized") + } + + fwManager.SetLogLevel(level) log.Infof("Log level set to %s", level.String()) diff --git a/client/server/trace.go b/client/server/trace.go index a8004f446..66b83d8cf 100644 --- a/client/server/trace.go +++ b/client/server/trace.go @@ -18,11 +18,15 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( s.mutex.Lock() defer s.mutex.Unlock() - if s.connectClient == nil || s.connectClient.Engine() == nil { + if s.connectClient == nil { + return nil, fmt.Errorf("connect client not initialized") + } + engine := s.connectClient.Engine() + if engine == nil { return nil, fmt.Errorf("engine not initialized") } - fwManager := s.connectClient.Engine().GetFirewallManager() + fwManager := engine.GetFirewallManager() if fwManager == nil { return nil, fmt.Errorf("firewall manager not initialized") } @@ -34,12 +38,12 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( srcIP := net.ParseIP(req.GetSourceIp()) if req.GetSourceIp() == "self" { - srcIP = s.connectClient.Engine().GetWgAddr() + srcIP = engine.GetWgAddr() } dstIP := net.ParseIP(req.GetDestinationIp()) if req.GetDestinationIp() == "self" { - dstIP = s.connectClient.Engine().GetWgAddr() + dstIP = engine.GetWgAddr() } if srcIP == nil || dstIP == nil {