diff --git a/internal/glance/widget-dns-stats.go b/internal/glance/widget-dns-stats.go index 73de4d9..cbb9b11 100644 --- a/internal/glance/widget-dns-stats.go +++ b/internal/glance/widget-dns-stats.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "html/template" "io" "net/http" @@ -348,6 +349,35 @@ func piholeGetSID(instanceURL, appPassword string, allowInsecure bool) (string, return jsonResponse.Session.SID, nil } +// checkPiholeSID checks if the SID is valid by checking HTTP response status code from /api/auth. +func checkPiholeSID(instanceURL string, appPassword, sid string, allowInsecure bool) error { + requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth?sid=" + sid + + request, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return err + } + + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + + response, err := client.Do(request) + if err != nil { + return err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return errors.New("SID is invalid, received status: " + response.Status) + } + + return nil +} + // fetchPiholeTopDomains fetches the top blocked domains for Pi-hole v6+. func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (piholeTopDomainsResponse, error) { requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/top_domains?blocked=true&sid=" + sid @@ -367,6 +397,35 @@ func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) ( return decodeJsonFromRequest[piholeTopDomainsResponse](client, request) } +// fetchPiholeSeries fetches the series data for Pi-hole v6+ (QueriesSeries and BlockedSeries). +func fetchPiholeSeries(instanceURL string, sid string, allowInsecure bool) (piholeQueriesSeries, map[int64]int, error) { + requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/over_time_data?sid=" + sid + + request, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, nil, err + } + + var client requestDoer + if !allowInsecure { + client = defaultHTTPClient + } else { + client = defaultInsecureHTTPClient + } + + var responseJson struct { + QueriesSeries piholeQueriesSeries `json:"queries_over_time"` + BlockedSeries map[int64]int `json:"blocked_over_time"` + } + + err = decodeJsonFromRequest[&responseJson](client, request) + if err != nil { + return nil, nil, err + } + + return responseJson.QueriesSeries, responseJson.BlockedSeries, nil +} + // Helper functions to process the responses func parsePiholeStats(r *piholeStatsResponse, topDomains piholeTopDomainsResponse) *dnsStats { @@ -461,64 +520,75 @@ func parsePiholeStatsLegacy(r *legacyPiholeStatsResponse, noGraph bool) *dnsStat } func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) { + instanceURL = strings.TrimRight(instanceURL, "/") var requestURL string var sid string + isV6 := version == "" || version == "6" - // Handle Pi-hole v6 authentication - if version == "" || version == "6" { + if isV6 { if appPassword == "" { return nil, errors.New("missing app password") } - // If SID env var is not set, get a new SID - if os.Getenv("SID") == "" { - sid, err := piholeGetSID(instanceURL, appPassword, allowInsecure) - if err != nil { - return nil, err - } - os.Setenv("SID", sid) + sid = os.Getenv("SID") + // Only get a new SID if it's not set or is invalid + if sid == "" { + newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure) + if err != nil { + return nil, fmt.Errorf("failed to get SID: %w", err) // Use %w for wrapping + } + sid = newSid + os.Setenv("SID", sid) + } else { + // Check existing SID validity. Only get a new one if the check fails. + err := checkPiholeSID(instanceURL, appPassword, sid, allowInsecure) + if err != nil { + newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure) + if err != nil { + return nil, fmt.Errorf("failed to get SID after invalid SID check: %w", err) + } + sid = newSid + os.Setenv("SID", sid) + } } - sid := os.Getenv("SID") - requestURL = strings.TrimRight(instanceURL, "/") + "/api/stats/summary?sid=" + sid + + requestURL = instanceURL + "/api/stats/summary?sid=" + sid } else { if token == "" { return nil, errors.New("missing API token") } - requestURL = strings.TrimRight(instanceURL, "/") + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token - + requestURL = instanceURL + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token } request, err := http.NewRequest("GET", requestURL, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create HTTP request: %w", err) } var client requestDoer - if !allowInsecure { - client = defaultHTTPClient - } else { + client = defaultHTTPClient + if allowInsecure { client = defaultInsecureHTTPClient } var responseJson interface{} - - if version == "" || version == "6" { + if isV6 { responseJson, err = decodeJsonFromRequest[piholeStatsResponse](client, request) - } else { responseJson, err = decodeJsonFromRequest[legacyPiholeStatsResponse](client, request) } + if err != nil { - return nil, err + return nil, fmt.Errorf("failed to decode JSON response: %w", err) } switch r := responseJson.(type) { case *piholeStatsResponse: - // Fetch top domains separately for v6 + // Fetch top domains separately for v6+. topDomains, err := fetchPiholeTopDomains(instanceURL, sid, allowInsecure) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to fetch top domains: %w", err) } return parsePiholeStats(r, topDomains), nil case *legacyPiholeStatsResponse: