netbird/client/internal/dns/service_memory.go
Zoltan Papp 7ebe58f20a
Feature/permanent dns (#967)
* Add DNS list argument for mobile client

* Write testable code

Many places are checked the wgInterface != nil condition.
It is doing it just because to avoid the real wgInterface creation for tests.
Instead of this involve a wgInterface interface what is moc-able.

* Refactor the DNS server internal code structure

With the fake resolver has been involved several
if-else statement and generated some unused
variables to distinguish the listener and fake
resolver solutions at running time. With this
commit the fake resolver and listener based
solution has been moved into two separated
structure. Name of this layer is the 'service'.
With this modification the unit test looks
simpler and open the option to add new logic for
the permanent DNS service usage for mobile
systems.



* Remove is running check in test

We can not ensure the state well so remove this
check. The test will fail if the server is not
running well.
2023-07-14 21:56:22 +02:00

140 lines
3.1 KiB
Go

package dns
import (
"fmt"
"math/big"
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type serviceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
runtimePort int
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
runtimePort: defaultPort,
}
return s
}
func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return err
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
return nil
}
func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
}
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}