netbird/client/internal/dnsfwd/manager.go

107 lines
2.3 KiB
Go
Raw Normal View History

2024-12-10 19:14:09 +01:00
package dnsfwd
import (
"context"
"fmt"
"net"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
2024-12-10 19:14:09 +01:00
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
	// ListenPort is the UDP port the DNS forwarder listens on. Client peers
	// use the same port number when they send queries to the forwarder.
	ListenPort = 5353

	// dnsTTL is the TTL passed to NewDNSForwarder, in seconds.
	dnsTTL = 60
)
// Manager owns the lifecycle of a DNSForwarder and the firewall rules that
// expose its listening port.
type Manager struct {
	// firewall is used to add and delete the inbound rules for ListenPort.
	firewall firewall.Manager

	// fwRules holds the rules returned by AddPeerFiltering so they can be
	// deleted again when the forwarder is stopped.
	fwRules []firewall.Rule

	// dnsForwarder is nil while the manager is stopped; Start sets it and
	// Stop clears it.
	dnsForwarder *DNSForwarder
}
2024-12-10 19:14:09 +01:00
func NewManager(fw firewall.Manager) *Manager {
return &Manager{
firewall: fw,
}
2024-12-10 19:14:09 +01:00
}
func (m *Manager) Start(domains []string) error {
2024-12-10 19:14:09 +01:00
log.Infof("starting DNS forwarder")
if m.dnsForwarder != nil {
2024-12-10 19:14:09 +01:00
return nil
}
if err := m.allowDNSFirewall(); err != nil {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains)
2024-12-10 19:14:09 +01:00
go func() {
if err := m.dnsForwarder.Listen(); err != nil {
2024-12-10 19:14:09 +01:00
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
}()
return nil
}
// UpdateDomains passes a new domain list to the running forwarder.
// It does nothing when the forwarder has not been started.
func (m *Manager) UpdateDomains(domains []string) {
	if fwd := m.dnsForwarder; fwd != nil {
		fwd.UpdateDomains(domains)
	}
}
2024-12-10 19:14:09 +01:00
func (m *Manager) Stop(ctx context.Context) error {
if m.dnsForwarder == nil {
2024-12-10 19:14:09 +01:00
return nil
}
var mErr *multierror.Error
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
if err := m.dnsForwarder.Close(ctx); err != nil {
mErr = multierror.Append(mErr, err)
}
m.dnsForwarder = nil
return nberrors.FormatErrorOrNil(mErr)
2024-12-10 19:14:09 +01:00
}
func (h *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []int{ListenPort},
}
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
2024-12-10 19:14:09 +01:00
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.fwRules = dnsRules
2024-12-10 19:14:09 +01:00
return nil
}
func (h *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
2024-12-10 19:14:09 +01:00
}
}
h.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
2024-12-10 19:14:09 +01:00
}