This commit is contained in:
Viktor Liu 2024-12-10 17:56:41 +01:00
parent 16a2867d69
commit 9d820f1eae
2 changed files with 71 additions and 50 deletions

View File

@ -158,7 +158,8 @@ func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) e
log.Debugf("registering handler %s", handler) log.Debugf("registering handler %s", handler)
for _, domain := range domains { for _, domain := range domains {
pattern := dns.Fqdn(domain) wosuff, _ := strings.CutPrefix(domain, "*.")
pattern := dns.Fqdn(wosuff)
s.service.RegisterMux(pattern, handler) s.service.RegisterMux(pattern, handler)
} }

View File

@ -43,33 +43,33 @@ func New(
} }
} }
func (h *DnsInterceptor) String() string { func (d *DnsInterceptor) String() string {
s, err := h.route.Domains.String() s, err := d.route.Domains.String()
if err != nil { if err != nil {
return h.route.Domains.PunycodeString() return d.route.Domains.PunycodeString()
} }
return s return s
} }
func (h *DnsInterceptor) AddRoute(context.Context) error { func (d *DnsInterceptor) AddRoute(context.Context) error {
return h.dnsServer.RegisterHandler(h.route.Domains.ToPunycodeList(), h) return d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d)
} }
func (h *DnsInterceptor) RemoveRoute() error { func (d *DnsInterceptor) RemoveRoute() error {
h.mu.Lock() d.mu.Lock()
defer h.mu.Unlock() defer d.mu.Unlock()
// Remove all intercepted IPs // Remove all intercepted IPs
for key, prefix := range h.interceptedIPs { for key, prefix := range d.interceptedIPs {
if _, err := h.routeRefCounter.Decrement(prefix); err != nil { if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove route for IP %s: %v", prefix, err) log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
} }
if h.currentPeerKey != "" { if d.currentPeerKey != "" {
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
} }
} }
delete(h.interceptedIPs, key) delete(d.interceptedIPs, key)
} }
// TODO: remove from mux // TODO: remove from mux
@ -77,15 +77,15 @@ func (h *DnsInterceptor) RemoveRoute() error {
return nil return nil
} }
func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error { func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
h.mu.Lock() d.mu.Lock()
defer h.mu.Unlock() defer d.mu.Unlock()
h.currentPeerKey = peerKey d.currentPeerKey = peerKey
// Re-add all intercepted IPs for the new peer // Re-add all intercepted IPs for the new peer
for _, prefix := range h.interceptedIPs { for _, prefix := range d.interceptedIPs {
if _, err := h.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { if _, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", prefix, err) log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
} }
} }
@ -93,71 +93,91 @@ func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error {
return nil return nil
} }
func (h *DnsInterceptor) RemoveAllowedIPs() error { func (d *DnsInterceptor) RemoveAllowedIPs() error {
h.mu.Lock() d.mu.Lock()
defer h.mu.Unlock() defer d.mu.Unlock()
if h.currentPeerKey != "" { if d.currentPeerKey != "" {
for _, prefix := range h.interceptedIPs { for _, prefix := range d.interceptedIPs {
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
} }
} }
} }
h.currentPeerKey = "" d.currentPeerKey = ""
return nil return nil
} }
// ServeDNS implements the dns.Handler interface // ServeDNS implements the dns.Handler interface
func (h *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Debugf("Received DNS request: %v", r) log.Debugf("received DNS request: %v", r)
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
// Create response interceptor to capture the response if err := d.writeMsg(w, r); err != nil {
interceptor := &responseInterceptor{ log.Errorf("failed writing DNS response: %v", err)
ResponseWriter: w,
handler: h,
question: r.Question[0],
answered: false,
}
// Let the request pass through with our interceptor
err := interceptor.WriteMsg(r)
if err != nil {
log.Errorf("Failed writing DNS response: %v", err)
} }
} }
func (h *DnsInterceptor) processMatch(domain string, ip netip.Addr) { func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
h.mu.Lock() if r == nil || len(r.Answer) == 0 {
defer h.mu.Unlock() return w.WriteMsg(r)
}
for _, ans := range r.Answer {
var ip netip.Addr
switch rr := ans.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
ip = addr
default:
continue
}
d.processMatch(r.Question[0].Name, ip)
}
return w.WriteMsg(r)
}
func (d *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
d.mu.Lock()
defer d.mu.Unlock()
network := netip.PrefixFrom(ip, ip.BitLen()) network := netip.PrefixFrom(ip, ip.BitLen())
key := fmt.Sprintf("%s:%s", domain, network.String()) key := fmt.Sprintf("%s:%s", domain, network.String())
if _, exists := h.interceptedIPs[key]; exists { if _, exists := d.interceptedIPs[key]; exists {
return return
} }
if _, err := h.routeRefCounter.Increment(network, struct{}{}); err != nil { if _, err := d.routeRefCounter.Increment(network, struct{}{}); err != nil {
log.Errorf("Failed to add route for IP %s: %v", network, err) log.Errorf("Failed to add route for IP %s: %v", network, err)
return return
} }
if h.currentPeerKey != "" { if d.currentPeerKey != "" {
if _, err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil { if _, err := d.allowedIPsRefcounter.Increment(network, d.currentPeerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", network, err) log.Errorf("Failed to add allowed IP %s: %v", network, err)
// Rollback route addition // Rollback route addition
if _, err := h.routeRefCounter.Decrement(network); err != nil { if _, err := d.routeRefCounter.Decrement(network); err != nil {
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err) log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
} }
return return
} }
} }
h.interceptedIPs[key] = network d.interceptedIPs[key] = network
log.Debugf("Added route for domain %s -> %s", domain, network) log.Debugf("Added route for domain %s -> %s", domain, network)
} }