mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-05 21:49:03 +01:00
f
This commit is contained in:
parent
16a2867d69
commit
9d820f1eae
@ -158,7 +158,8 @@ func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) e
|
|||||||
|
|
||||||
log.Debugf("registering handler %s", handler)
|
log.Debugf("registering handler %s", handler)
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
pattern := dns.Fqdn(domain)
|
wosuff, _ := strings.CutPrefix(domain, "*.")
|
||||||
|
pattern := dns.Fqdn(wosuff)
|
||||||
s.service.RegisterMux(pattern, handler)
|
s.service.RegisterMux(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,33 +43,33 @@ func New(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DnsInterceptor) String() string {
|
func (d *DnsInterceptor) String() string {
|
||||||
s, err := h.route.Domains.String()
|
s, err := d.route.Domains.String()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return h.route.Domains.PunycodeString()
|
return d.route.Domains.PunycodeString()
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DnsInterceptor) AddRoute(context.Context) error {
|
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||||
return h.dnsServer.RegisterHandler(h.route.Domains.ToPunycodeList(), h)
|
return d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DnsInterceptor) RemoveRoute() error {
|
func (d *DnsInterceptor) RemoveRoute() error {
|
||||||
h.mu.Lock()
|
d.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
// Remove all intercepted IPs
|
// Remove all intercepted IPs
|
||||||
for key, prefix := range h.interceptedIPs {
|
for key, prefix := range d.interceptedIPs {
|
||||||
if _, err := h.routeRefCounter.Decrement(prefix); err != nil {
|
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||||
log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
|
log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
|
||||||
}
|
}
|
||||||
if h.currentPeerKey != "" {
|
if d.currentPeerKey != "" {
|
||||||
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(h.interceptedIPs, key)
|
delete(d.interceptedIPs, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove from mux
|
// TODO: remove from mux
|
||||||
@ -77,15 +77,15 @@ func (h *DnsInterceptor) RemoveRoute() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||||
h.mu.Lock()
|
d.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
h.currentPeerKey = peerKey
|
d.currentPeerKey = peerKey
|
||||||
|
|
||||||
// Re-add all intercepted IPs for the new peer
|
// Re-add all intercepted IPs for the new peer
|
||||||
for _, prefix := range h.interceptedIPs {
|
for _, prefix := range d.interceptedIPs {
|
||||||
if _, err := h.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
if _, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||||
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
|
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -93,71 +93,91 @@ func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DnsInterceptor) RemoveAllowedIPs() error {
|
func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||||
h.mu.Lock()
|
d.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
if h.currentPeerKey != "" {
|
if d.currentPeerKey != "" {
|
||||||
for _, prefix := range h.interceptedIPs {
|
for _, prefix := range d.interceptedIPs {
|
||||||
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.currentPeerKey = ""
|
d.currentPeerKey = ""
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS implements the dns.Handler interface
|
// ServeDNS implements the dns.Handler interface
|
||||||
func (h *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
log.Debugf("Received DNS request: %v", r)
|
log.Debugf("received DNS request: %v", r)
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create response interceptor to capture the response
|
if err := d.writeMsg(w, r); err != nil {
|
||||||
interceptor := &responseInterceptor{
|
log.Errorf("failed writing DNS response: %v", err)
|
||||||
ResponseWriter: w,
|
|
||||||
handler: h,
|
|
||||||
question: r.Question[0],
|
|
||||||
answered: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Let the request pass through with our interceptor
|
|
||||||
err := interceptor.WriteMsg(r)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed writing DNS response: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
|
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||||
h.mu.Lock()
|
if r == nil || len(r.Answer) == 0 {
|
||||||
defer h.mu.Unlock()
|
return w.WriteMsg(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ans := range r.Answer {
|
||||||
|
var ip netip.Addr
|
||||||
|
switch rr := ans.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.A)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
case *dns.AAAA:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
d.processMatch(r.Question[0].Name, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.WriteMsg(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
network := netip.PrefixFrom(ip, ip.BitLen())
|
network := netip.PrefixFrom(ip, ip.BitLen())
|
||||||
key := fmt.Sprintf("%s:%s", domain, network.String())
|
key := fmt.Sprintf("%s:%s", domain, network.String())
|
||||||
|
|
||||||
if _, exists := h.interceptedIPs[key]; exists {
|
if _, exists := d.interceptedIPs[key]; exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := h.routeRefCounter.Increment(network, struct{}{}); err != nil {
|
if _, err := d.routeRefCounter.Increment(network, struct{}{}); err != nil {
|
||||||
log.Errorf("Failed to add route for IP %s: %v", network, err)
|
log.Errorf("Failed to add route for IP %s: %v", network, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.currentPeerKey != "" {
|
if d.currentPeerKey != "" {
|
||||||
if _, err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
|
if _, err := d.allowedIPsRefcounter.Increment(network, d.currentPeerKey); err != nil {
|
||||||
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
||||||
// Rollback route addition
|
// Rollback route addition
|
||||||
if _, err := h.routeRefCounter.Decrement(network); err != nil {
|
if _, err := d.routeRefCounter.Decrement(network); err != nil {
|
||||||
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
|
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.interceptedIPs[key] = network
|
d.interceptedIPs[key] = network
|
||||||
log.Debugf("Added route for domain %s -> %s", domain, network)
|
log.Debugf("Added route for domain %s -> %s", domain, network)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user