mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-01-10 16:28:13 +01:00
242 lines
7.2 KiB
Go
242 lines
7.2 KiB
Go
package sqlschema
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/uptrace/bun"
|
|
"github.com/uptrace/bun/schema"
|
|
orderedmap "github.com/wk8/go-ordered-map/v2"
|
|
)
|
|
|
|
type InspectorDialect interface {
|
|
schema.Dialect
|
|
|
|
// Inspector returns a new instance of Inspector for the dialect.
|
|
// Dialects MAY set their default InspectorConfig values in constructor
|
|
// but MUST apply InspectorOptions to ensure they can be overriden.
|
|
//
|
|
// Use ApplyInspectorOptions to reduce boilerplate.
|
|
NewInspector(db *bun.DB, options ...InspectorOption) Inspector
|
|
|
|
// CompareType returns true if col1 and co2 SQL types are equivalent,
|
|
// i.e. they might use dialect-specifc type aliases (SERIAL ~ SMALLINT)
|
|
// or specify the same VARCHAR length differently (VARCHAR(255) ~ VARCHAR).
|
|
CompareType(Column, Column) bool
|
|
}
|
|
|
|
// InspectorConfig controls the scope of migration by limiting the objects Inspector should return.
|
|
// Inspectors SHOULD use the configuration directly instead of copying it, or MAY choose to embed it,
|
|
// to make sure options are always applied correctly.
|
|
type InspectorConfig struct {
|
|
// SchemaName limits inspection to tables in a particular schema.
|
|
SchemaName string
|
|
|
|
// ExcludeTables from inspection.
|
|
ExcludeTables []string
|
|
}
|
|
|
|
// Inspector reads schema state.
|
|
type Inspector interface {
|
|
Inspect(ctx context.Context) (Database, error)
|
|
}
|
|
|
|
func WithSchemaName(schemaName string) InspectorOption {
|
|
return func(cfg *InspectorConfig) {
|
|
cfg.SchemaName = schemaName
|
|
}
|
|
}
|
|
|
|
// WithExcludeTables works in append-only mode, i.e. tables cannot be re-included.
|
|
func WithExcludeTables(tables ...string) InspectorOption {
|
|
return func(cfg *InspectorConfig) {
|
|
cfg.ExcludeTables = append(cfg.ExcludeTables, tables...)
|
|
}
|
|
}
|
|
|
|
// NewInspector creates a new database inspector, if the dialect supports it.
|
|
func NewInspector(db *bun.DB, options ...InspectorOption) (Inspector, error) {
|
|
dialect, ok := (db.Dialect()).(InspectorDialect)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s does not implement sqlschema.Inspector", db.Dialect().Name())
|
|
}
|
|
return &inspector{
|
|
Inspector: dialect.NewInspector(db, options...),
|
|
}, nil
|
|
}
|
|
|
|
func NewBunModelInspector(tables *schema.Tables, options ...InspectorOption) *BunModelInspector {
|
|
bmi := &BunModelInspector{
|
|
tables: tables,
|
|
}
|
|
ApplyInspectorOptions(&bmi.InspectorConfig, options...)
|
|
return bmi
|
|
}
|
|
|
|
type InspectorOption func(*InspectorConfig)
|
|
|
|
func ApplyInspectorOptions(cfg *InspectorConfig, options ...InspectorOption) {
|
|
for _, opt := range options {
|
|
opt(cfg)
|
|
}
|
|
}
|
|
|
|
// inspector is opaque pointer to a database inspector.
|
|
type inspector struct {
|
|
Inspector
|
|
}
|
|
|
|
// BunModelInspector creates the current project state from the passed bun.Models.
|
|
// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run.
|
|
type BunModelInspector struct {
|
|
InspectorConfig
|
|
tables *schema.Tables
|
|
}
|
|
|
|
var _ Inspector = (*BunModelInspector)(nil)
|
|
|
|
func (bmi *BunModelInspector) Inspect(ctx context.Context) (Database, error) {
|
|
state := BunModelSchema{
|
|
BaseDatabase: BaseDatabase{
|
|
ForeignKeys: make(map[ForeignKey]string),
|
|
},
|
|
Tables: orderedmap.New[string, Table](),
|
|
}
|
|
for _, t := range bmi.tables.All() {
|
|
if t.Schema != bmi.SchemaName {
|
|
continue
|
|
}
|
|
|
|
columns := orderedmap.New[string, Column]()
|
|
for _, f := range t.Fields {
|
|
|
|
sqlType, length, err := parseLen(f.CreateTableSQLType)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse length in %q: %w", f.CreateTableSQLType, err)
|
|
}
|
|
columns.Set(f.Name, &BaseColumn{
|
|
Name: f.Name,
|
|
SQLType: strings.ToLower(sqlType), // TODO(dyma): maybe this is not necessary after Column.Eq()
|
|
VarcharLen: length,
|
|
DefaultValue: exprToLower(f.SQLDefault),
|
|
IsNullable: !f.NotNull,
|
|
IsAutoIncrement: f.AutoIncrement,
|
|
IsIdentity: f.Identity,
|
|
})
|
|
}
|
|
|
|
var unique []Unique
|
|
for name, group := range t.Unique {
|
|
// Create a separate unique index for single-column unique constraints
|
|
// let each dialect apply the default naming convention.
|
|
if name == "" {
|
|
for _, f := range group {
|
|
unique = append(unique, Unique{Columns: NewColumns(f.Name)})
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Set the name if it is a "unique group", in which case the user has provided the name.
|
|
var columns []string
|
|
for _, f := range group {
|
|
columns = append(columns, f.Name)
|
|
}
|
|
unique = append(unique, Unique{Name: name, Columns: NewColumns(columns...)})
|
|
}
|
|
|
|
var pk *PrimaryKey
|
|
if len(t.PKs) > 0 {
|
|
var columns []string
|
|
for _, f := range t.PKs {
|
|
columns = append(columns, f.Name)
|
|
}
|
|
pk = &PrimaryKey{Columns: NewColumns(columns...)}
|
|
}
|
|
|
|
// In cases where a table is defined in a non-default schema in the `bun:table` tag,
|
|
// schema.Table only extracts the name of the schema, but passes the entire tag value to t.Name
|
|
// for backwads-compatibility. For example, a bun model like this:
|
|
// type Model struct { bun.BaseModel `bun:"table:favourite.books` }
|
|
// produces
|
|
// schema.Table{ Schema: "favourite", Name: "favourite.books" }
|
|
tableName := strings.TrimPrefix(t.Name, t.Schema+".")
|
|
state.Tables.Set(tableName, &BunTable{
|
|
BaseTable: BaseTable{
|
|
Schema: t.Schema,
|
|
Name: tableName,
|
|
Columns: columns,
|
|
UniqueConstraints: unique,
|
|
PrimaryKey: pk,
|
|
},
|
|
Model: t.ZeroIface,
|
|
})
|
|
|
|
for _, rel := range t.Relations {
|
|
// These relations are nominal and do not need a foreign key to be declared in the current table.
|
|
// They will be either expressed as N:1 relations in an m2m mapping table, or will be referenced by the other table if it's a 1:N.
|
|
if rel.Type == schema.ManyToManyRelation ||
|
|
rel.Type == schema.HasManyRelation {
|
|
continue
|
|
}
|
|
|
|
var fromCols, toCols []string
|
|
for _, f := range rel.BasePKs {
|
|
fromCols = append(fromCols, f.Name)
|
|
}
|
|
for _, f := range rel.JoinPKs {
|
|
toCols = append(toCols, f.Name)
|
|
}
|
|
|
|
target := rel.JoinTable
|
|
state.ForeignKeys[ForeignKey{
|
|
From: NewColumnReference(t.Name, fromCols...),
|
|
To: NewColumnReference(target.Name, toCols...),
|
|
}] = ""
|
|
}
|
|
}
|
|
return state, nil
|
|
}
|
|
|
|
func parseLen(typ string) (string, int, error) {
|
|
paren := strings.Index(typ, "(")
|
|
if paren == -1 {
|
|
return typ, 0, nil
|
|
}
|
|
length, err := strconv.Atoi(typ[paren+1 : len(typ)-1])
|
|
if err != nil {
|
|
return typ, 0, err
|
|
}
|
|
return typ[:paren], length, nil
|
|
}
|
|
|
|
// exprToLower converts string to lowercase, if it does not contain a string literal 'lit'.
|
|
// Use it to ensure that user-defined default values in the models are always comparable
|
|
// to those returned by the database inspector, regardless of the case convention in individual drivers.
|
|
func exprToLower(s string) string {
|
|
if strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'") {
|
|
return s
|
|
}
|
|
return strings.ToLower(s)
|
|
}
|
|
|
|
// BunModelSchema is the schema state derived from bun table models.
|
|
type BunModelSchema struct {
|
|
BaseDatabase
|
|
|
|
Tables *orderedmap.OrderedMap[string, Table]
|
|
}
|
|
|
|
func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] {
|
|
return ms.Tables
|
|
}
|
|
|
|
// BunTable provides additional table metadata that is only accessible from scanning bun models.
|
|
type BunTable struct {
|
|
BaseTable
|
|
|
|
// Model stores the zero interface to the underlying Go struct.
|
|
Model interface{}
|
|
}
|