Move all transactions to the exported methods

This commit is contained in:
TwinProduction 2021-07-13 22:59:43 -04:00 committed by Chris
parent 796228466d
commit 5cc1c11b1a

View File

@ -12,6 +12,10 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
//////////////////////////////////////////////////////////////////////////////////////////////////
// Note that only exported functions in this file may create, commit, or rollback a transaction //
//////////////////////////////////////////////////////////////////////////////////////////////////
const ( const (
arraySeparator = "|~|" arraySeparator = "|~|"
) )
@ -121,12 +125,27 @@ func (s *Store) createSchema() error {
// GetAllServiceStatusesWithResultPagination returns all monitored core.ServiceStatus // GetAllServiceStatusesWithResultPagination returns all monitored core.ServiceStatus
// with a subset of core.Result defined by the page and pageSize parameters // with a subset of core.Result defined by the page and pageSize parameters
func (s *Store) GetAllServiceStatusesWithResultPagination(page, pageSize int) map[string]*core.ServiceStatus { func (s *Store) GetAllServiceStatusesWithResultPagination(page, pageSize int) map[string]*core.ServiceStatus {
serviceStatuses := s.getAllServiceStatuses(0, 0, page, pageSize) tx, err := s.db.Begin()
m := make(map[string]*core.ServiceStatus, len(serviceStatuses)) if err != nil {
for _, serviceStatus := range serviceStatuses { return nil
m[serviceStatus.Key] = serviceStatus
} }
return m keys, err := s.getAllServiceKeys(tx)
if err != nil {
_ = tx.Rollback()
return nil
}
serviceStatuses := make(map[string]*core.ServiceStatus, len(keys))
for _, key := range keys {
serviceStatus, err := s.getServiceStatusByKey(tx, key, 0, 0, page, pageSize)
if err != nil {
continue
}
serviceStatuses[key] = serviceStatus
}
if err = tx.Commit(); err != nil {
_ = tx.Rollback()
}
return serviceStatuses
} }
// GetServiceStatus returns the service status for a given service name in the given group // GetServiceStatus returns the service status for a given service name in the given group
@ -136,7 +155,18 @@ func (s *Store) GetServiceStatus(groupName, serviceName string) *core.ServiceSta
// GetServiceStatusByKey returns the service status for a given key // GetServiceStatusByKey returns the service status for a given key
func (s *Store) GetServiceStatusByKey(key string) *core.ServiceStatus { func (s *Store) GetServiceStatusByKey(key string) *core.ServiceStatus {
serviceStatus, _ := s.getServiceStatusByKey(key, 1, core.MaximumNumberOfEvents, 1, core.MaximumNumberOfResults) tx, err := s.db.Begin()
if err != nil {
return nil
}
serviceStatus, err := s.getServiceStatusByKey(tx, key, 1, core.MaximumNumberOfEvents, 1, core.MaximumNumberOfResults)
if err != nil {
_ = tx.Rollback()
return nil
}
if err = tx.Commit(); err != nil {
_ = tx.Rollback()
}
return serviceStatus return serviceStatus
} }
@ -265,24 +295,8 @@ func (s *Store) Close() {
_ = s.db.Close() _ = s.db.Close()
} }
func (s *Store) getAllServiceStatuses(eventsPage, eventsPageSize, resultsPage, resultsPageSize int) []*core.ServiceStatus { func (s *Store) getAllServiceKeys(tx *sql.Tx) (keys []string, err error) {
var serviceStatuses []*core.ServiceStatus rows, err := tx.Query("SELECT service_key FROM service")
keys, err := s.getAllServiceKeys()
if err != nil {
return nil
}
for _, key := range keys {
serviceStatus, err := s.getServiceStatusByKey(key, eventsPage, eventsPageSize, resultsPage, resultsPageSize)
if err != nil {
continue
}
serviceStatuses = append(serviceStatuses, serviceStatus)
}
return serviceStatuses
}
func (s *Store) getAllServiceKeys() (keys []string, err error) {
rows, err := s.db.Query("SELECT service_key FROM service")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -295,8 +309,8 @@ func (s *Store) getAllServiceKeys() (keys []string, err error) {
return return
} }
func (s *Store) getServiceStatusByKey(key string, eventsPage, eventsPageSize, resultsPage, resultsPageSize int) (*core.ServiceStatus, error) { // TODO: add uptimePage? func (s *Store) getServiceStatusByKey(tx *sql.Tx, key string, eventsPage, eventsPageSize, resultsPage, resultsPageSize int) (*core.ServiceStatus, error) { // TODO: add uptimePage?
serviceID, serviceName, serviceGroup, err := s.getServiceIDGroupAndNameByKey(key) serviceID, serviceName, serviceGroup, err := s.getServiceIDGroupAndNameByKey(tx, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -307,12 +321,12 @@ func (s *Store) getServiceStatusByKey(key string, eventsPage, eventsPageSize, re
Uptime: nil, Uptime: nil,
} }
if eventsPageSize > 0 { if eventsPageSize > 0 {
if serviceStatus.Events, err = s.getEventsByServiceID(serviceID, eventsPage, eventsPageSize); err != nil { if serviceStatus.Events, err = s.getEventsByServiceID(tx, serviceID, eventsPage, eventsPageSize); err != nil {
log.Printf("[database][getServiceStatusByKey] Failed to retrieve events for key=%s: %s", key, err.Error()) log.Printf("[database][getServiceStatusByKey] Failed to retrieve events for key=%s: %s", key, err.Error())
} }
} }
if resultsPageSize > 0 { if resultsPageSize > 0 {
if serviceStatus.Results, err = s.getResultsByServiceID(serviceID, resultsPage, resultsPageSize); err != nil { if serviceStatus.Results, err = s.getResultsByServiceID(tx, serviceID, resultsPage, resultsPageSize); err != nil {
log.Printf("[database][getServiceStatusByKey] Failed to retrieve results for key=%s: %s", key, err.Error()) log.Printf("[database][getServiceStatusByKey] Failed to retrieve results for key=%s: %s", key, err.Error())
} }
} }
@ -321,8 +335,8 @@ func (s *Store) getServiceStatusByKey(key string, eventsPage, eventsPageSize, re
return serviceStatus, nil return serviceStatus, nil
} }
func (s *Store) getServiceIDGroupAndNameByKey(key string) (id int64, group, name string, err error) { func (s *Store) getServiceIDGroupAndNameByKey(tx *sql.Tx, key string) (id int64, group, name string, err error) {
rows, err := s.db.Query("SELECT service_id, service_group, service_name FROM service WHERE service_key = $1 LIMIT 1", key) rows, err := tx.Query("SELECT service_id, service_group, service_name FROM service WHERE service_key = $1 LIMIT 1", key)
if err != nil { if err != nil {
return 0, "", "", err return 0, "", "", err
} }
@ -336,8 +350,8 @@ func (s *Store) getServiceIDGroupAndNameByKey(key string) (id int64, group, name
return return
} }
func (s *Store) getEventsByServiceID(serviceID int64, page, pageSize int) (events []*core.Event, err error) { func (s *Store) getEventsByServiceID(tx *sql.Tx, serviceID int64, page, pageSize int) (events []*core.Event, err error) {
rows, err := s.db.Query( rows, err := tx.Query(
` `
SELECT event_type, event_timestamp SELECT event_type, event_timestamp
FROM service_event FROM service_event
@ -361,11 +375,7 @@ func (s *Store) getEventsByServiceID(serviceID int64, page, pageSize int) (event
return return
} }
func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (results []*core.Result, err error) { func (s *Store) getResultsByServiceID(tx *sql.Tx, serviceID int64, page, pageSize int) (results []*core.Result, err error) {
tx, err := s.db.Begin()
if err != nil {
return
}
rows, err := tx.Query( rows, err := tx.Query(
` `
SELECT service_result_id, success, errors, connected, status, dns_rcode, certificate_expiration, hostname, ip, duration, timestamp SELECT service_result_id, success, errors, connected, status, dns_rcode, certificate_expiration, hostname, ip, duration, timestamp
@ -379,7 +389,6 @@ func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (resu
(page-1)*pageSize, (page-1)*pageSize,
) )
if err != nil { if err != nil {
_ = tx.Rollback()
return nil, err return nil, err
} }
idResultMap := make(map[int64]*core.Result) idResultMap := make(map[int64]*core.Result)
@ -404,7 +413,6 @@ func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (resu
serviceResultID, serviceResultID,
) )
if err != nil { if err != nil {
_ = tx.Rollback()
return return
} }
for rows.Next() { for rows.Next() {
@ -414,10 +422,6 @@ func (s *Store) getResultsByServiceID(serviceID int64, page, pageSize int) (resu
} }
_ = rows.Close() _ = rows.Close()
} }
if err = tx.Commit(); err != nil {
_ = tx.Rollback()
return
}
return return
} }
@ -564,7 +568,6 @@ func (s *Store) insertConditionResults(tx *sql.Tx, serviceResultID int64, condit
cr.Success, cr.Success,
) )
if err != nil { if err != nil {
_ = tx.Rollback()
return err return err
} }
} }