Pushing latest fixes to handle the two different JSON responses.

This commit is contained in:
Keith Carichner Jr 2025-02-20 14:54:11 -05:00
parent 72c1ebf66d
commit e643a44b59

View File

@ -7,7 +7,6 @@ import (
"errors" "errors"
"html/template" "html/template"
"io" "io"
"log/slog"
"net/http" "net/http"
"os" "os"
"sort" "sort"
@ -15,9 +14,6 @@ import (
"time" "time"
) )
// Global HTTP client for reuse
var httpClient = &http.Client{}
var dnsStatsWidgetTemplate = mustParseTemplate("dns-stats.html", "widget-base.html") var dnsStatsWidgetTemplate = mustParseTemplate("dns-stats.html", "widget-base.html")
type dnsStatsWidget struct { type dnsStatsWidget struct {
@ -235,7 +231,8 @@ func fetchAdguardStats(instanceURL string, allowInsecure bool, username, passwor
return stats, nil return stats, nil
} }
type piholeStatsResponse struct { // Legacy Pi-hole stats response (before v6)
type legacyPiholeStatsResponse struct {
TotalQueries int `json:"dns_queries_today"` TotalQueries int `json:"dns_queries_today"`
QueriesSeries piholeQueriesSeries `json:"domains_over_time"` QueriesSeries piholeQueriesSeries `json:"domains_over_time"`
BlockedQueries int `json:"ads_blocked_today"` BlockedQueries int `json:"ads_blocked_today"`
@ -245,6 +242,24 @@ type piholeStatsResponse struct {
DomainsBlocked int `json:"domains_being_blocked"` DomainsBlocked int `json:"domains_being_blocked"`
} }
// Pi-hole v6+ response format
type piholeStatsResponse struct {
Queries struct {
Total int `json:"total"`
Blocked int `json:"blocked"`
PercentBlocked float64 `json:"percent_blocked"`
} `json:"queries"`
Gravity struct {
DomainsBlocked int `json:"domains_being_blocked"`
} `json:"gravity"`
//Note we do not need the full structure. We extract the values needed
//Adding dummy fields to allow easier json parsing.
QueriesSeries piholeQueriesSeries `json:"domains_over_time"` // Will always be empty
BlockedSeries map[int64]int `json:"ads_over_time"` // Will always be empty.
}
type piholeTopDomainsResponse map[string]int
// If the user has query logging disabled it's possible for domains_over_time to be returned as an // If the user has query logging disabled it's possible for domains_over_time to be returned as an
// empty array rather than a map which will prevent unmashalling the rest of the data so we use // empty array rather than a map which will prevent unmashalling the rest of the data so we use
// custom unmarshal behavior to fallback to an empty map. // custom unmarshal behavior to fallback to an empty map.
@ -284,7 +299,14 @@ func (p *piholeTopBlockedDomains) UnmarshalJSON(data []byte) error {
} }
// piholeGetSID retrieves a new SID from Pi-hole using the app password. // piholeGetSID retrieves a new SID from Pi-hole using the app password.
func piholeGetSID(instanceURL, appPassword string) (string, error) { func piholeGetSID(instanceURL, appPassword string, allowInsecure bool) (string, error) {
var client requestDoer
if !allowInsecure {
client = defaultHTTPClient
} else {
client = defaultInsecureHTTPClient
}
requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth" requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth"
requestBody := []byte(`{"password":"` + appPassword + `"}`) requestBody := []byte(`{"password":"` + appPassword + `"}`)
@ -294,7 +316,7 @@ func piholeGetSID(instanceURL, appPassword string) (string, error) {
} }
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
response, err := httpClient.Do(request) response, err := client.Do(request)
if err != nil { if err != nil {
return "", errors.New("failed to send authentication request: " + err.Error()) return "", errors.New("failed to send authentication request: " + err.Error())
} }
@ -326,8 +348,121 @@ func piholeGetSID(instanceURL, appPassword string) (string, error) {
return jsonResponse.Session.SID, nil return jsonResponse.Session.SID, nil
} }
// fetchPiholeTopDomains fetches the top blocked domains for Pi-hole v6+.
func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (piholeTopDomainsResponse, error) {
requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/top_domains?blocked=true&sid=" + sid
request, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return nil, err
}
var client requestDoer
if !allowInsecure {
client = defaultHTTPClient
} else {
client = defaultInsecureHTTPClient
}
return decodeJsonFromRequest[piholeTopDomainsResponse](client, request)
}
// Helper functions to process the responses
func parsePiholeStats(r *piholeStatsResponse, topDomains piholeTopDomainsResponse) *dnsStats {
stats := &dnsStats{
TotalQueries: r.Queries.Total,
BlockedQueries: r.Queries.Blocked,
BlockedPercent: int(r.Queries.PercentBlocked),
DomainsBlocked: r.Gravity.DomainsBlocked,
}
if len(topDomains) > 0 {
domains := make([]dnsStatsBlockedDomain, 0, len(topDomains))
for domain, count := range topDomains {
domains = append(domains, dnsStatsBlockedDomain{
Domain: domain,
PercentBlocked: int(float64(count) / float64(r.Queries.Blocked) * 100), // Calculate percentage here
})
}
sort.Slice(domains, func(a, b int) bool {
return domains[a].PercentBlocked > domains[b].PercentBlocked
})
stats.TopBlockedDomains = domains[:min(len(domains), 5)]
}
return stats
}
func parsePiholeStatsLegacy(r *legacyPiholeStatsResponse, noGraph bool) *dnsStats {
stats := &dnsStats{
TotalQueries: r.TotalQueries,
BlockedQueries: r.BlockedQueries,
BlockedPercent: int(r.BlockedPercentage),
DomainsBlocked: r.DomainsBlocked,
}
if len(r.TopBlockedDomains) > 0 {
domains := make([]dnsStatsBlockedDomain, 0, len(r.TopBlockedDomains))
for domain, count := range r.TopBlockedDomains {
domains = append(domains, dnsStatsBlockedDomain{
Domain: domain,
PercentBlocked: int(float64(count) / float64(r.BlockedQueries) * 100),
})
}
sort.Slice(domains, func(a, b int) bool {
return domains[a].PercentBlocked > domains[b].PercentBlocked
})
stats.TopBlockedDomains = domains[:min(len(domains), 5)]
}
if noGraph {
return stats
}
// Pihole _should_ return data for the last 24 hours in a 10 minute interval, 6*24 = 144
if len(r.QueriesSeries) != 144 || len(r.BlockedSeries) != 144 {
return stats
}
var lowestTimestamp int64 = 0
for timestamp := range r.QueriesSeries {
if lowestTimestamp == 0 || timestamp < lowestTimestamp {
lowestTimestamp = timestamp
}
}
maxQueriesInSeries := 0
for i := 0; i < 8; i++ {
queries := 0
blocked := 0
for j := 0; j < 18; j++ {
index := lowestTimestamp + int64(i*10800+j*600)
queries += r.QueriesSeries[index]
blocked += r.BlockedSeries[index]
}
if queries > maxQueriesInSeries {
maxQueriesInSeries = queries
}
stats.Series[i] = dnsStatsSeries{
Queries: queries,
Blocked: blocked,
}
if queries > 0 {
stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100)
}
}
for i := 0; i < 8; i++ {
stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100)
}
return stats
}
func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) { func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) {
var requestURL string var requestURL string
var sid string
// Handle Pi-hole v6 authentication // Handle Pi-hole v6 authentication
if version == "" || version == "6" { if version == "" || version == "6" {
@ -336,20 +471,22 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr
} }
// If SID env var is not set, get a new SID // If SID env var is not set, get a new SID
if os.Getenv("SID") == "" { if os.Getenv("SID") == "" {
sid, err := piholeGetSID(instanceURL, appPassword) sid, err := piholeGetSID(instanceURL, appPassword, allowInsecure)
os.Setenv("SID", sid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
os.Setenv("SID", sid)
} }
sid := os.Getenv("SID") sid := os.Getenv("SID")
requestURL = strings.TrimRight(instanceURL, "/") + "/api/stats/summary?sid=" + sid requestURL = strings.TrimRight(instanceURL, "/") + "/api/stats/summary?sid=" + sid
} else { } else {
if token == "" { if token == "" {
return nil, errors.New("missing API token") return nil, errors.New("missing API token")
} }
requestURL = strings.TrimRight(instanceURL, "/") + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token requestURL = strings.TrimRight(instanceURL, "/") + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token
} }
request, err := http.NewRequest("GET", requestURL, nil) request, err := http.NewRequest("GET", requestURL, nil)
@ -364,87 +501,29 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr
client = defaultInsecureHTTPClient client = defaultInsecureHTTPClient
} }
responseJson, err := decodeJsonFromRequest[piholeStatsResponse](client, request) var responseJson interface{}
if version == "" || version == "6" {
responseJson, err = decodeJsonFromRequest[piholeStatsResponse](client, request)
} else {
responseJson, err = decodeJsonFromRequest[legacyPiholeStatsResponse](client, request)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
stats := &dnsStats{ switch r := responseJson.(type) {
TotalQueries: responseJson.TotalQueries, case *piholeStatsResponse:
BlockedQueries: responseJson.BlockedQueries, // Fetch top domains separately for v6
BlockedPercent: int(responseJson.BlockedPercentage), topDomains, err := fetchPiholeTopDomains(instanceURL, sid, allowInsecure)
DomainsBlocked: responseJson.DomainsBlocked, if err != nil {
return nil, err
} }
return parsePiholeStats(r, topDomains), nil
if len(responseJson.TopBlockedDomains) > 0 { case *legacyPiholeStatsResponse:
domains := make([]dnsStatsBlockedDomain, 0, len(responseJson.TopBlockedDomains)) return parsePiholeStatsLegacy(r, noGraph), nil
default:
for domain, count := range responseJson.TopBlockedDomains { return nil, errors.New("unexpected response type")
domains = append(domains, dnsStatsBlockedDomain{
Domain: domain,
PercentBlocked: int(float64(count) / float64(responseJson.BlockedQueries) * 100),
})
}
sort.Slice(domains, func(a, b int) bool {
return domains[a].PercentBlocked > domains[b].PercentBlocked
})
stats.TopBlockedDomains = domains[:min(len(domains), 5)]
}
if noGraph {
return stats, nil
}
// Pihole _should_ return data for the last 24 hours in a 10 minute interval, 6*24 = 144
if len(responseJson.QueriesSeries) != 144 || len(responseJson.BlockedSeries) != 144 {
slog.Warn(
"DNS stats for pihole: did not get expected 144 data points",
"len(queries)", len(responseJson.QueriesSeries),
"len(blocked)", len(responseJson.BlockedSeries),
)
return stats, nil
}
var lowestTimestamp int64 = 0
for timestamp := range responseJson.QueriesSeries {
if lowestTimestamp == 0 || timestamp < lowestTimestamp {
lowestTimestamp = timestamp
} }
} }
maxQueriesInSeries := 0
for i := 0; i < 8; i++ {
queries := 0
blocked := 0
for j := 0; j < 18; j++ {
index := lowestTimestamp + int64(i*10800+j*600)
queries += responseJson.QueriesSeries[index]
blocked += responseJson.BlockedSeries[index]
}
if queries > maxQueriesInSeries {
maxQueriesInSeries = queries
}
stats.Series[i] = dnsStatsSeries{
Queries: queries,
Blocked: blocked,
}
if queries > 0 {
stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100)
}
}
for i := 0; i < 8; i++ {
stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100)
}
return stats, nil
}