gotosocial/vendor/github.com/go-pg/pg/v10/orm/tables.go

137 lines
2.3 KiB
Go
Raw Normal View History

package orm
import (
"fmt"
"reflect"
"sync"
"github.com/go-pg/pg/v10/types"
)
// _tables is the package-level table registry that GetTable and
// RegisterTable delegate to.
var _tables = newTables()
// tableInProgress pairs a Table that is still being initialised with the
// guards for its two initialisation phases.
type tableInProgress struct {
	// table is the Table being initialised.
	table *Table

	// One guard per phase, so that init1 and init2 each execute at most once.
	init1Once, init2Once sync.Once
}
// newTableInProgress wraps table in a tableInProgress whose phase guards
// have not fired yet.
func newTableInProgress(table *Table) *tableInProgress {
	inp := new(tableInProgress)
	inp.table = table
	return inp
}
// init1 runs the table's first initialisation phase at most once.
//
// It returns true only for the call that actually executed the phase;
// any other call returns false.
func (inp *tableInProgress) init1() bool {
	ran := false
	inp.init1Once.Do(func() {
		inp.table.init1()
		ran = true
	})
	return ran
}
// init2 runs the table's second initialisation phase at most once.
//
// It returns true only for the call that actually executed the phase;
// any other call returns false.
func (inp *tableInProgress) init2() bool {
	ran := false
	inp.init2Once.Do(func() {
		inp.table.init2()
		ran = true
	})
	return ran
}
// GetTable returns the Table for the struct type typ, looked up in the
// package-level registry.
//
// It panics if typ's kind is not reflect.Struct.
func GetTable(typ reflect.Type) *Table {
return _tables.Get(typ)
}
// RegisterTable registers a struct as an SQL table in the package-level
// registry. It is usually used to register an intermediate table
// in a many-to-many relationship.
//
// strct may be a struct value or a pointer to a struct.
func RegisterTable(strct interface{}) {
_tables.Register(strct)
}
// tables is a registry of Table values keyed by reflect.Type.
type tables struct {
// tables stores tables whose second initialisation phase has run,
// keyed by reflect.Type with *Table values.
tables sync.Map
// mu is held around every access to inProgress.
mu sync.RWMutex
// inProgress tracks tables that have been created but not yet moved
// into tables.
inProgress map[reflect.Type]*tableInProgress
}
// newTables returns an empty registry with its in-progress map allocated
// and ready for writes.
func newTables() *tables {
	reg := new(tables)
	reg.inProgress = make(map[reflect.Type]*tableInProgress)
	return reg
}
// Register makes sure the registry holds a Table for the type of strct,
// creating it on first use.
//
// strct may be a struct value or a pointer to a struct; a pointer is
// unwrapped one level before the lookup.
//
// It panics if strct is a nil interface value, and (through Get) if the
// resulting type is not a struct.
func (t *tables) Register(strct interface{}) {
	typ := reflect.TypeOf(strct)
	if typ == nil {
		// reflect.TypeOf returns nil for a nil interface value; calling Kind
		// on it would fail with an opaque nil-pointer dereference, so report
		// the problem in the same form get uses for non-struct kinds.
		panic(fmt.Errorf("got nil, wanted %s", reflect.Struct))
	}
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	// Only the side effect of creating and storing the table is wanted.
	_ = t.Get(typ)
}
// get returns the Table for the struct type typ, creating and initialising
// it on first use. It panics if typ's kind is not reflect.Struct.
//
// When allowInProgress is true, the table is returned once the init1 call
// below has returned; init2 is not run by this call and the table is not
// stored in t.tables. When allowInProgress is false, init2 is run as well,
// and the call that actually executed init2 moves the table from
// t.inProgress into t.tables.
func (t *tables) get(typ reflect.Type, allowInProgress bool) *Table {
if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
}
// Fast path: a table already stored in t.tables is returned without
// taking t.mu.
if v, ok := t.tables.Load(typ); ok {
return v.(*Table)
}
t.mu.Lock()
// Check again under the lock: the table may have been stored between the
// lookup above and acquiring t.mu.
if v, ok := t.tables.Load(typ); ok {
t.mu.Unlock()
return v.(*Table)
}
var table *Table
// Reuse the in-progress entry for typ if one exists, so that every caller
// sees the same *Table; otherwise create the table and record it.
inProgress := t.inProgress[typ]
if inProgress == nil {
table = newTable(typ)
inProgress = newTableInProgress(table)
t.inProgress[typ] = inProgress
} else {
table = inProgress.table
}
// NOTE(review): t.mu is released before the init phases run, presumably so
// that initialisation can look up other tables through this registry
// without deadlocking — confirm against Table.init1/init2.
t.mu.Unlock()
inProgress.init1()
if allowInProgress {
return table
}
// init2 reports true only for the call that executed the phase, so the
// move from t.inProgress to t.tables is performed by that call alone.
if inProgress.init2() {
t.mu.Lock()
delete(t.inProgress, typ)
t.tables.Store(typ, table)
t.mu.Unlock()
}
return table
}
// Get returns the Table for the struct type typ. It calls get with
// allowInProgress set to false, so both initialisation phases are requested
// before the table is returned.
//
// It panics if typ's kind is not reflect.Struct.
func (t *tables) Get(typ reflect.Type) *Table {
return t.get(typ, false)
}
// getByName scans the tables stored in t.tables for one whose SQLName equals
// name and returns it, or nil when none matches.
//
// Only t.tables is searched; entries in t.inProgress are not consulted.
func (t *tables) getByName(name types.Safe) *Table {
	var match *Table
	t.tables.Range(func(_, v interface{}) bool {
		tbl := v.(*Table)
		if tbl.SQLName != name {
			return true // no match yet, keep scanning
		}
		match = tbl
		return false // stop at the first hit
	})
	return match
}