mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-01 20:43:43 +01:00
7ebe58f20a
* Add DNS list argument for mobile client * Write testable code Many places are checked the wgInterface != nil condition. It is doing it just because to avoid the real wgInterface creation for tests. Instead of this involve a wgInterface interface what is moc-able. * Refactor the DNS server internal code structure With the fake resolver has been involved several if-else statement and generated some unused variables to distinguish the listener and fake resolver solutions at running time. With this commit the fake resolver and listener based solution has been moved into two separated structure. Name of this layer is the 'service'. With this modification the unit test looks simpler and open the option to add new logic for the permanent DNS service usage for mobile systems. * Remove is running check in test We can not ensure the state well so remove this check. The test will fail if the server is not running well.
140 lines
3.1 KiB
Go
140 lines
3.1 KiB
Go
package dns
|
|
|
|
import (
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/google/gopacket"
|
|
"github.com/google/gopacket/layers"
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type serviceViaMemory struct {
|
|
wgInterface WGIface
|
|
dnsMux *dns.ServeMux
|
|
runtimeIP string
|
|
runtimePort int
|
|
udpFilterHookID string
|
|
listenerIsRunning bool
|
|
listenerFlagLock sync.Mutex
|
|
}
|
|
|
|
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
|
s := &serviceViaMemory{
|
|
wgInterface: wgIface,
|
|
dnsMux: dns.NewServeMux(),
|
|
|
|
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
|
runtimePort: defaultPort,
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (s *serviceViaMemory) Listen() error {
|
|
s.listenerFlagLock.Lock()
|
|
defer s.listenerFlagLock.Unlock()
|
|
|
|
if s.listenerIsRunning {
|
|
return nil
|
|
}
|
|
|
|
var err error
|
|
s.udpFilterHookID, err = s.filterDNSTraffic()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.listenerIsRunning = true
|
|
|
|
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
|
return nil
|
|
}
|
|
|
|
func (s *serviceViaMemory) Stop() {
|
|
s.listenerFlagLock.Lock()
|
|
defer s.listenerFlagLock.Unlock()
|
|
|
|
if !s.listenerIsRunning {
|
|
return
|
|
}
|
|
|
|
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
|
}
|
|
|
|
s.listenerIsRunning = false
|
|
}
|
|
|
|
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
|
s.dnsMux.Handle(pattern, handler)
|
|
}
|
|
|
|
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
|
s.dnsMux.HandleRemove(pattern)
|
|
}
|
|
|
|
func (s *serviceViaMemory) RuntimePort() int {
|
|
return s.runtimePort
|
|
}
|
|
|
|
func (s *serviceViaMemory) RuntimeIP() string {
|
|
return s.runtimeIP
|
|
}
|
|
|
|
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
|
filter := s.wgInterface.GetFilter()
|
|
if filter == nil {
|
|
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
|
}
|
|
|
|
firstLayerDecoder := layers.LayerTypeIPv4
|
|
if s.wgInterface.Address().Network.IP.To4() == nil {
|
|
firstLayerDecoder = layers.LayerTypeIPv6
|
|
}
|
|
|
|
hook := func(packetData []byte) bool {
|
|
// Decode the packet
|
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
|
|
|
// Get the UDP layer
|
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
udp := udpLayer.(*layers.UDP)
|
|
|
|
msg := new(dns.Msg)
|
|
if err := msg.Unpack(udp.Payload); err != nil {
|
|
log.Tracef("parse DNS request: %v", err)
|
|
return true
|
|
}
|
|
|
|
writer := responseWriter{
|
|
packet: packet,
|
|
device: s.wgInterface.GetDevice().Device,
|
|
}
|
|
go s.dnsMux.ServeDNS(&writer, msg)
|
|
return true
|
|
}
|
|
|
|
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
|
}
|
|
|
|
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
|
// Calculate the last IP in the CIDR range
|
|
var endIP net.IP
|
|
for i := 0; i < len(network.IP); i++ {
|
|
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
|
}
|
|
|
|
// convert to big.Int
|
|
endInt := big.NewInt(0)
|
|
endInt.SetBytes(endIP)
|
|
|
|
// subtract fromEnd from the last ip
|
|
fromEndBig := big.NewInt(int64(fromEnd))
|
|
resultInt := big.NewInt(0)
|
|
resultInt.Sub(endInt, fromEndBig)
|
|
|
|
return net.IP(resultInt.Bytes()).String()
|
|
}
|