mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-23 11:12:01 +02:00
Reduce complexity
This commit is contained in:
parent
d711172f67
commit
9490e9095b
@ -67,6 +67,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
|
||||
}
|
||||
|
||||
func (e *endpoint) Wait() {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
@ -74,6 +75,7 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
}
|
||||
|
||||
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
// not required
|
||||
}
|
||||
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
|
@ -41,16 +41,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return f.handleEchoResponse(conn, id)
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
@ -70,10 +61,18 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
@ -82,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
@ -102,7 +101,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
return false
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||
|
@ -732,16 +732,24 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
|
||||
// Default deny if no rules match
|
||||
matched := false
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
// Check destination
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
continue
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
matched = true
|
||||
if rule.action == firewall.ActionDrop {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if source matches any source prefix
|
||||
sourceMatched := false
|
||||
for _, src := range rule.sources {
|
||||
if src.Contains(srcAddr) {
|
||||
@ -750,29 +758,21 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check protocol
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check ports if specified
|
||||
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
|
||||
continue
|
||||
}
|
||||
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
if rule.action == firewall.ActionDrop {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
return matched
|
||||
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
|
||||
return false
|
||||
}
|
||||
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
|
Loading…
x
Reference in New Issue
Block a user