diff --git a/internal/glance/widget-dns-stats.go b/internal/glance/widget-dns-stats.go index 8944603..73de4d9 100644 --- a/internal/glance/widget-dns-stats.go +++ b/internal/glance/widget-dns-stats.go @@ -7,7 +7,6 @@ import ( "errors" "html/template" "io" - "log/slog" "net/http" "os" "sort" @@ -15,9 +14,6 @@ import ( "time" ) -// Global HTTP client for reuse -var httpClient = &http.Client{} - var dnsStatsWidgetTemplate = mustParseTemplate("dns-stats.html", "widget-base.html") type dnsStatsWidget struct { @@ -235,7 +231,8 @@ func fetchAdguardStats(instanceURL string, allowInsecure bool, username, passwor return stats, nil } -type piholeStatsResponse struct { +// Legacy Pi-hole stats response (before v6) +type legacyPiholeStatsResponse struct { TotalQueries int `json:"dns_queries_today"` QueriesSeries piholeQueriesSeries `json:"domains_over_time"` BlockedQueries int `json:"ads_blocked_today"` @@ -245,6 +242,24 @@ type piholeStatsResponse struct { DomainsBlocked int `json:"domains_being_blocked"` } +// Pi-hole v6+ response format +type piholeStatsResponse struct { + Queries struct { + Total int `json:"total"` + Blocked int `json:"blocked"` + PercentBlocked float64 `json:"percent_blocked"` + } `json:"queries"` + Gravity struct { + DomainsBlocked int `json:"domains_being_blocked"` + } `json:"gravity"` + //Note we do not need the full structure. We extract the values needed + //Adding dummy fields to allow easier json parsing. + QueriesSeries piholeQueriesSeries `json:"domains_over_time"` // Will always be empty + BlockedSeries map[int64]int `json:"ads_over_time"` // Will always be empty. +} + +type piholeTopDomainsResponse map[string]int + // If the user has query logging disabled it's possible for domains_over_time to be returned as an // empty array rather than a map which will prevent unmashalling the rest of the data so we use // custom unmarshal behavior to fallback to an empty map. @@ -284,7 +299,14 @@ func (p *piholeTopBlockedDomains) UnmarshalJSON(data []byte) error { } // piholeGetSID retrieves a new SID from Pi-hole using the app password. -func piholeGetSID(instanceURL, appPassword string) (string, error) { +func piholeGetSID(instanceURL, appPassword string, allowInsecure bool) (string, error) { + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth" requestBody := []byte(`{"password":"` + appPassword + `"}`) @@ -294,7 +316,7 @@ func piholeGetSID(instanceURL, appPassword string) (string, error) { } request.Header.Set("Content-Type", "application/json") - response, err := httpClient.Do(request) + response, err := client.Do(request) if err != nil { return "", errors.New("failed to send authentication request: " + err.Error()) } @@ -326,8 +348,121 @@ func piholeGetSID(instanceURL, appPassword string) (string, error) { return jsonResponse.Session.SID, nil } +// fetchPiholeTopDomains fetches the top blocked domains for Pi-hole v6+. +func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (piholeTopDomainsResponse, error) { + requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/top_domains?blocked=true&sid=" + sid + + request, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, err + } + + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + + return decodeJsonFromRequest[piholeTopDomainsResponse](client, request) +} + +// Helper functions to process the responses +func parsePiholeStats(r *piholeStatsResponse, topDomains piholeTopDomainsResponse) *dnsStats { + + stats := &dnsStats{ + TotalQueries: r.Queries.Total, + BlockedQueries: r.Queries.Blocked, + BlockedPercent: int(r.Queries.PercentBlocked), + DomainsBlocked: r.Gravity.DomainsBlocked, + } + + if len(topDomains) > 0 { + domains := make([]dnsStatsBlockedDomain, 0, len(topDomains)) + for domain, count := range topDomains { + domains = append(domains, dnsStatsBlockedDomain{ + Domain: domain, + PercentBlocked: int(float64(count) / float64(r.Queries.Blocked) * 100), // Calculate percentage here + }) + } + + sort.Slice(domains, func(a, b int) bool { + return domains[a].PercentBlocked > domains[b].PercentBlocked + }) + stats.TopBlockedDomains = domains[:min(len(domains), 5)] + } + + return stats +} +func parsePiholeStatsLegacy(r *legacyPiholeStatsResponse, noGraph bool) *dnsStats { + + stats := &dnsStats{ + TotalQueries: r.TotalQueries, + BlockedQueries: r.BlockedQueries, + BlockedPercent: int(r.BlockedPercentage), + DomainsBlocked: r.DomainsBlocked, + } + if len(r.TopBlockedDomains) > 0 { + domains := make([]dnsStatsBlockedDomain, 0, len(r.TopBlockedDomains)) + + for domain, count := range r.TopBlockedDomains { + domains = append(domains, dnsStatsBlockedDomain{ + Domain: domain, + PercentBlocked: int(float64(count) / float64(r.BlockedQueries) * 100), + }) + } + + sort.Slice(domains, func(a, b int) bool { + return domains[a].PercentBlocked > domains[b].PercentBlocked + }) + + stats.TopBlockedDomains = domains[:min(len(domains), 5)] + } + if noGraph { + return stats + } + + // Pihole _should_ return data for the last 24 hours in a 10 minute interval, 6*24 = 144 + if len(r.QueriesSeries) != 144 || len(r.BlockedSeries) != 144 { + return stats + } + + var lowestTimestamp int64 = 0 + for timestamp := range r.QueriesSeries { + if lowestTimestamp == 0 || timestamp < lowestTimestamp { + lowestTimestamp = timestamp + } + } + maxQueriesInSeries := 0 + + for i := 0; i < 8; i++ { + queries := 0 + blocked := 0 + for j := 0; j < 18; j++ { + index := lowestTimestamp + int64(i*10800+j*600) + queries += r.QueriesSeries[index] + blocked += r.BlockedSeries[index] + } + if queries > maxQueriesInSeries { + maxQueriesInSeries = queries + } + stats.Series[i] = dnsStatsSeries{ + Queries: queries, + Blocked: blocked, + } + if queries > 0 { + stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100) + } + } + for i := 0; i < 8; i++ { + stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100) + } + return stats +} + func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) { var requestURL string + var sid string // Handle Pi-hole v6 authentication if version == "" || version == "6" { @@ -336,20 +471,22 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr } // If SID env var is not set, get a new SID if os.Getenv("SID") == "" { - sid, err := piholeGetSID(instanceURL, appPassword) - os.Setenv("SID", sid) + sid, err := piholeGetSID(instanceURL, appPassword, allowInsecure) if err != nil { return nil, err } + os.Setenv("SID", sid) + } sid := os.Getenv("SID") - requestURL = strings.TrimRight(instanceURL, "/") + "/api/stats/summary?sid=" + sid + } else { if token == "" { return nil, errors.New("missing API token") } requestURL = strings.TrimRight(instanceURL, "/") + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token + } request, err := http.NewRequest("GET", requestURL, nil) @@ -364,87 +501,29 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr client = defaultInsecureHTTPClient } - responseJson, err := decodeJsonFromRequest[piholeStatsResponse](client, request) + var responseJson interface{} + + if version == "" || version == "6" { + responseJson, err = decodeJsonFromRequest[piholeStatsResponse](client, request) + + } else { + responseJson, err = decodeJsonFromRequest[legacyPiholeStatsResponse](client, request) + } if err != nil { return nil, err } - stats := &dnsStats{ - TotalQueries: responseJson.TotalQueries, - BlockedQueries: responseJson.BlockedQueries, - BlockedPercent: int(responseJson.BlockedPercentage), - DomainsBlocked: responseJson.DomainsBlocked, - } - - if len(responseJson.TopBlockedDomains) > 0 { - domains := make([]dnsStatsBlockedDomain, 0, len(responseJson.TopBlockedDomains)) - - for domain, count := range responseJson.TopBlockedDomains { - domains = append(domains, dnsStatsBlockedDomain{ - Domain: domain, - PercentBlocked: int(float64(count) / float64(responseJson.BlockedQueries) * 100), - }) + switch r := responseJson.(type) { + case *piholeStatsResponse: + // Fetch top domains separately for v6 + topDomains, err := fetchPiholeTopDomains(instanceURL, sid, allowInsecure) + if err != nil { + return nil, err } - - sort.Slice(domains, func(a, b int) bool { - return domains[a].PercentBlocked > domains[b].PercentBlocked - }) - - stats.TopBlockedDomains = domains[:min(len(domains), 5)] + return parsePiholeStats(r, topDomains), nil + case *legacyPiholeStatsResponse: + return parsePiholeStatsLegacy(r, noGraph), nil + default: + return nil, errors.New("unexpected response type") } - - if noGraph { - return stats, nil - } - - // Pihole _should_ return data for the last 24 hours in a 10 minute interval, 6*24 = 144 - if len(responseJson.QueriesSeries) != 144 || len(responseJson.BlockedSeries) != 144 { - slog.Warn( - "DNS stats for pihole: did not get expected 144 data points", - "len(queries)", len(responseJson.QueriesSeries), - "len(blocked)", len(responseJson.BlockedSeries), - ) - return stats, nil - } - - var lowestTimestamp int64 = 0 - - for timestamp := range responseJson.QueriesSeries { - if lowestTimestamp == 0 || timestamp < lowestTimestamp { - lowestTimestamp = timestamp - } - } - - maxQueriesInSeries := 0 - - for i := 0; i < 8; i++ { - queries := 0 - blocked := 0 - - for j := 0; j < 18; j++ { - index := lowestTimestamp + int64(i*10800+j*600) - - queries += responseJson.QueriesSeries[index] - blocked += responseJson.BlockedSeries[index] - } - - if queries > maxQueriesInSeries { - maxQueriesInSeries = queries - } - - stats.Series[i] = dnsStatsSeries{ - Queries: queries, - Blocked: blocked, - } - - if queries > 0 { - stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100) - } - } - - for i := 0; i < 8; i++ { - stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100) - } - - return stats, nil }