Reduce complexity

This commit is contained in:
Viktor Liu 2025-01-03 11:50:25 +01:00
parent d711172f67
commit 9490e9095b
3 changed files with 49 additions and 48 deletions

View File

@ -67,6 +67,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
}
func (e *endpoint) Wait() {
// not required
}
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
@ -74,6 +75,7 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
}
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
// not required
}
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {

View File

@ -41,16 +41,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
return f.handleEchoResponse(conn, id)
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
@ -70,10 +61,18 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
return true
}
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) bool {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return false
return true
}
response := make([]byte, f.endpoint.mtu)
@ -82,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return false
return true
}
ipHdr := make([]byte, header.IPv4MinimumSize)
@ -102,7 +101,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return false
return true
}
f.logger.Trace("Forwarded ICMP echo reply for %v", id)

View File

@ -732,16 +732,24 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
// Default deny if no rules match
matched := false
for _, rule := range m.routeRules {
// Check destination
if !rule.destination.Contains(dstAddr) {
continue
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
matched = true
if rule.action == firewall.ActionDrop {
return false
}
}
}
return matched
}
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) {
return false
}
// Check if source matches any source prefix
sourceMatched := false
for _, src := range rule.sources {
if src.Contains(srcAddr) {
@ -750,29 +758,21 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr
}
}
if !sourceMatched {
continue
}
// Check protocol
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
continue
}
// Check ports if specified
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
continue
}
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
continue
}
matched = true
if rule.action == firewall.ActionDrop {
return false
}
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
return matched
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
return false
}
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
return false
}
return true
}
// SetNetwork of the wireguard interface to which filtering applied