mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 09:50:47 +01:00
104 lines
2.6 KiB
Go
104 lines
2.6 KiB
Go
|
package dns
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net"
|
||
|
|
||
|
"github.com/google/gopacket"
|
||
|
"github.com/google/gopacket/layers"
|
||
|
"github.com/miekg/dns"
|
||
|
"golang.zx2c4.com/wireguard/tun"
|
||
|
)
|
||
|
|
||
|
type responseWriter struct {
|
||
|
local net.Addr
|
||
|
remote net.Addr
|
||
|
packet gopacket.Packet
|
||
|
device tun.Device
|
||
|
}
|
||
|
|
||
|
// LocalAddr returns the net.Addr of the server
|
||
|
func (r *responseWriter) LocalAddr() net.Addr {
|
||
|
return r.local
|
||
|
}
|
||
|
|
||
|
// RemoteAddr returns the net.Addr of the client that sent the current request.
|
||
|
func (r *responseWriter) RemoteAddr() net.Addr {
|
||
|
return r.remote
|
||
|
}
|
||
|
|
||
|
// WriteMsg writes a reply back to the client.
|
||
|
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
||
|
buff, err := msg.Pack()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err = r.Write(buff)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Write writes a raw buffer back to the client.
|
||
|
func (r *responseWriter) Write(data []byte) (int, error) {
|
||
|
var ip gopacket.SerializableLayer
|
||
|
|
||
|
// Get the UDP layer
|
||
|
udpLayer := r.packet.Layer(layers.LayerTypeUDP)
|
||
|
udp := udpLayer.(*layers.UDP)
|
||
|
|
||
|
// Swap the source and destination addresses for the response
|
||
|
udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort
|
||
|
|
||
|
// Check if it's an IPv4 packet
|
||
|
if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil {
|
||
|
ipv4 := ipv4Layer.(*layers.IPv4)
|
||
|
ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP
|
||
|
ip = ipv4
|
||
|
} else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil {
|
||
|
ipv6 := ipv6Layer.(*layers.IPv6)
|
||
|
ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP
|
||
|
ip = ipv6
|
||
|
}
|
||
|
|
||
|
if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil {
|
||
|
return 0, fmt.Errorf("failed to set network layer for checksum: %v", err)
|
||
|
}
|
||
|
|
||
|
// Serialize the packet
|
||
|
buffer := gopacket.NewSerializeBuffer()
|
||
|
options := gopacket.SerializeOptions{
|
||
|
ComputeChecksums: true,
|
||
|
FixLengths: true,
|
||
|
}
|
||
|
|
||
|
payload := gopacket.Payload(data)
|
||
|
err := gopacket.SerializeLayers(buffer, options, ip, udp, payload)
|
||
|
if err != nil {
|
||
|
return 0, fmt.Errorf("failed to serialize packet: %v", err)
|
||
|
}
|
||
|
|
||
|
send := buffer.Bytes()
|
||
|
sendBuffer := make([]byte, 40, len(send)+40)
|
||
|
sendBuffer = append(sendBuffer, send...)
|
||
|
|
||
|
return r.device.Write([][]byte{sendBuffer}, 40)
|
||
|
}
|
||
|
|
||
|
// Close closes the connection.
|
||
|
func (r *responseWriter) Close() error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// TsigStatus returns the status of the Tsig.
|
||
|
func (r *responseWriter) TsigStatus() error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// TsigTimersOnly sets the tsig timers only boolean.
|
||
|
func (r *responseWriter) TsigTimersOnly(bool) {
|
||
|
}
|
||
|
|
||
|
// Hijack lets the caller take over the connection.
|
||
|
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||
|
func (r *responseWriter) Hijack() {
|
||
|
}
|