mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 18:01:13 +01:00
124 lines
2.9 KiB
Go
124 lines
2.9 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
|
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
"github.com/netbirdio/netbird/client/proto"
|
|
)
|
|
|
|
type packetTracer interface {
|
|
TracePacketFromBuilder(builder *uspfilter.PacketBuilder) (*uspfilter.PacketTrace, error)
|
|
}
|
|
|
|
func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (*proto.TracePacketResponse, error) {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.connectClient == nil {
|
|
return nil, fmt.Errorf("connect client not initialized")
|
|
}
|
|
engine := s.connectClient.Engine()
|
|
if engine == nil {
|
|
return nil, fmt.Errorf("engine not initialized")
|
|
}
|
|
|
|
fwManager := engine.GetFirewallManager()
|
|
if fwManager == nil {
|
|
return nil, fmt.Errorf("firewall manager not initialized")
|
|
}
|
|
|
|
tracer, ok := fwManager.(packetTracer)
|
|
if !ok {
|
|
return nil, fmt.Errorf("firewall manager does not support packet tracing")
|
|
}
|
|
|
|
srcIP := net.ParseIP(req.GetSourceIp())
|
|
if req.GetSourceIp() == "self" {
|
|
srcIP = engine.GetWgAddr()
|
|
}
|
|
|
|
dstIP := net.ParseIP(req.GetDestinationIp())
|
|
if req.GetDestinationIp() == "self" {
|
|
dstIP = engine.GetWgAddr()
|
|
}
|
|
|
|
if srcIP == nil || dstIP == nil {
|
|
return nil, fmt.Errorf("invalid IP address")
|
|
}
|
|
|
|
var tcpState *uspfilter.TCPState
|
|
if flags := req.GetTcpFlags(); flags != nil {
|
|
tcpState = &uspfilter.TCPState{
|
|
SYN: flags.GetSyn(),
|
|
ACK: flags.GetAck(),
|
|
FIN: flags.GetFin(),
|
|
RST: flags.GetRst(),
|
|
PSH: flags.GetPsh(),
|
|
URG: flags.GetUrg(),
|
|
}
|
|
}
|
|
|
|
var dir fw.RuleDirection
|
|
switch req.GetDirection() {
|
|
case "in":
|
|
dir = fw.RuleDirectionIN
|
|
case "out":
|
|
dir = fw.RuleDirectionOUT
|
|
default:
|
|
return nil, fmt.Errorf("invalid direction")
|
|
}
|
|
|
|
var protocol fw.Protocol
|
|
switch req.GetProtocol() {
|
|
case "tcp":
|
|
protocol = fw.ProtocolTCP
|
|
case "udp":
|
|
protocol = fw.ProtocolUDP
|
|
case "icmp":
|
|
protocol = fw.ProtocolICMP
|
|
default:
|
|
return nil, fmt.Errorf("invalid protocolcol")
|
|
}
|
|
|
|
builder := &uspfilter.PacketBuilder{
|
|
SrcIP: srcIP,
|
|
DstIP: dstIP,
|
|
Protocol: protocol,
|
|
SrcPort: uint16(req.GetSourcePort()),
|
|
DstPort: uint16(req.GetDestinationPort()),
|
|
Direction: dir,
|
|
TCPState: tcpState,
|
|
ICMPType: uint8(req.GetIcmpType()),
|
|
ICMPCode: uint8(req.GetIcmpCode()),
|
|
}
|
|
trace, err := tracer.TracePacketFromBuilder(builder)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("trace packet: %w", err)
|
|
}
|
|
|
|
resp := &proto.TracePacketResponse{}
|
|
|
|
for _, result := range trace.Results {
|
|
stage := &proto.TraceStage{
|
|
Name: result.Stage.String(),
|
|
Message: result.Message,
|
|
Allowed: result.Allowed,
|
|
}
|
|
if result.ForwarderAction != nil {
|
|
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
|
|
stage.ForwardingDetails = &details
|
|
}
|
|
resp.Stages = append(resp.Stages, stage)
|
|
}
|
|
|
|
if len(trace.Results) > 0 {
|
|
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
|
|
}
|
|
|
|
return resp, nil
|
|
}
|