mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-24 06:48:40 +01:00
ddc365f7a0
--------- Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
107 lines
2.3 KiB
Go
107 lines
2.3 KiB
Go
package dnsfwd
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
)
|
|
|
|
const (
|
|
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
|
ListenPort = 5353
|
|
dnsTTL = 60 //seconds
|
|
)
|
|
|
|
type Manager struct {
|
|
firewall firewall.Manager
|
|
|
|
fwRules []firewall.Rule
|
|
dnsForwarder *DNSForwarder
|
|
}
|
|
|
|
func NewManager(fw firewall.Manager) *Manager {
|
|
return &Manager{
|
|
firewall: fw,
|
|
}
|
|
}
|
|
|
|
func (m *Manager) Start(domains []string) error {
|
|
log.Infof("starting DNS forwarder")
|
|
if m.dnsForwarder != nil {
|
|
return nil
|
|
}
|
|
|
|
if err := m.allowDNSFirewall(); err != nil {
|
|
return err
|
|
}
|
|
|
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
|
|
go func() {
|
|
if err := m.dnsForwarder.Listen(domains); err != nil {
|
|
// todo handle close error if it is exists
|
|
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) UpdateDomains(domains []string) {
|
|
if m.dnsForwarder == nil {
|
|
return
|
|
}
|
|
|
|
m.dnsForwarder.UpdateDomains(domains)
|
|
}
|
|
|
|
func (m *Manager) Stop(ctx context.Context) error {
|
|
if m.dnsForwarder == nil {
|
|
return nil
|
|
}
|
|
|
|
var mErr *multierror.Error
|
|
if err := m.dropDNSFirewall(); err != nil {
|
|
mErr = multierror.Append(mErr, err)
|
|
}
|
|
|
|
if err := m.dnsForwarder.Close(ctx); err != nil {
|
|
mErr = multierror.Append(mErr, err)
|
|
}
|
|
|
|
m.dnsForwarder = nil
|
|
return nberrors.FormatErrorOrNil(mErr)
|
|
}
|
|
|
|
func (h *Manager) allowDNSFirewall() error {
|
|
dport := &firewall.Port{
|
|
IsRange: false,
|
|
Values: []int{ListenPort},
|
|
}
|
|
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
|
if err != nil {
|
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
|
return err
|
|
}
|
|
h.fwRules = dnsRules
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *Manager) dropDNSFirewall() error {
|
|
var mErr *multierror.Error
|
|
for _, rule := range h.fwRules {
|
|
if err := h.firewall.DeletePeerRule(rule); err != nil {
|
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
|
}
|
|
}
|
|
|
|
h.fwRules = nil
|
|
return nberrors.FormatErrorOrNil(mErr)
|
|
}
|