gotosocial/vendor/github.com/uptrace/bun/dialect/pgdialect/inspector.go

298 lines
9.0 KiB
Go

package pgdialect
import (
"context"
"strings"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate/sqlschema"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
// Schema is the database description produced by Inspect; it aliases
// sqlschema.BaseDatabase.
type Schema = sqlschema.BaseDatabase

// Table is the table description stored in a Schema; it aliases
// sqlschema.BaseTable.
type Table = sqlschema.BaseTable

// Column is the column description stored in a Table; it aliases
// sqlschema.BaseColumn.
type Column = sqlschema.BaseColumn
// NewInspector builds a schema inspector that issues its queries on db and is
// configured by the supplied options. It delegates to newInspector and exposes
// the result through the sqlschema.Inspector interface.
func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector {
	inspector := newInspector(db, options...)
	return inspector
}
// Inspector reads the structure of a database through a bun.DB handle.
// It embeds sqlschema.InspectorConfig, so the options applied by newInspector
// (such as the SchemaName and ExcludeTables values read by Inspect) are
// available directly on the Inspector.
type Inspector struct {
sqlschema.InspectorConfig
// db is the handle every inspection query is issued on.
db *bun.DB
}
// Compile-time check that *Inspector satisfies sqlschema.Inspector.
var _ sqlschema.Inspector = (*Inspector)(nil)
// newInspector allocates an Inspector bound to db and applies options to its
// embedded InspectorConfig before returning it.
func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector {
	inspector := new(Inspector)
	inspector.db = db
	sqlschema.ApplyInspectorOptions(&inspector.InspectorConfig, options...)
	return inspector
}
// Inspect queries the database catalog and returns a description of the
// tables, columns, primary keys, unique constraints and foreign keys found in
// in.SchemaName, leaving out every table named in in.ExcludeTables.
//
// Tables are stored in the order returned by sqlInspectTables and columns in
// the order returned by sqlInspectColumnsQuery. Unique constraints are listed
// in the order their names are first encountered while walking the columns,
// so the result does not depend on Go's randomized map iteration order.
//
// If any query fails, its error is returned unchanged together with whatever
// part of the schema had been assembled up to that point.
func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) {
	dbSchema := Schema{
		Tables:      orderedmap.New[string, sqlschema.Table](),
		ForeignKeys: make(map[sqlschema.ForeignKey]string),
	}

	exclude := in.ExcludeTables
	if len(exclude) == 0 {
		// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
		exclude = []string{""}
	}

	var tables []*InformationSchemaTable
	if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
		return dbSchema, err
	}

	// The exclusion list is passed twice: once for the referencing table and
	// once for the referenced table.
	var fks []*ForeignKey
	if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
		return dbSchema, err
	}
	// Re-create the map now that the number of foreign keys is known.
	dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks))

	for _, table := range tables {
		var columns []*InformationSchemaColumn
		if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
			return dbSchema, err
		}

		colDefs := orderedmap.New[string, sqlschema.Column]()

		// uniqueGroups maps a constraint name to its member columns.
		// groupNames remembers the order in which the names were first seen,
		// because ranging over the map would yield a different order per call.
		uniqueGroups := make(map[string][]string)
		var groupNames []string

		for _, c := range columns {
			def := c.Default
			if c.IsSerial || c.IsIdentity {
				// Serial and identity columns are reported without a default.
				def = ""
			} else if !c.IsDefaultLiteral {
				// Non-literal defaults are lower-cased; literal values are
				// kept verbatim. NOTE(review): presumably this normalizes SQL
				// expressions for later comparison — confirm.
				def = strings.ToLower(def)
			}

			colDefs.Set(c.Name, &Column{
				Name:            c.Name,
				SQLType:         c.DataType,
				VarcharLen:      c.VarcharLen,
				DefaultValue:    def,
				IsNullable:      c.IsNullable,
				IsAutoIncrement: c.IsSerial,
				IsIdentity:      c.IsIdentity,
			})

			for _, group := range c.UniqueGroups {
				if _, seen := uniqueGroups[group]; !seen {
					groupNames = append(groupNames, group)
				}
				uniqueGroups[group] = append(uniqueGroups[group], c.Name)
			}
		}

		// unique stays nil when the table has no unique constraints.
		var unique []sqlschema.Unique
		for _, name := range groupNames {
			unique = append(unique, sqlschema.Unique{
				Name:    name,
				Columns: sqlschema.NewColumns(uniqueGroups[name]...),
			})
		}

		// pk stays nil when the query reported no primary key columns.
		var pk *sqlschema.PrimaryKey
		if len(table.PrimaryKey.Columns) > 0 {
			pk = &sqlschema.PrimaryKey{
				Name:    table.PrimaryKey.ConstraintName,
				Columns: sqlschema.NewColumns(table.PrimaryKey.Columns...),
			}
		}

		dbSchema.Tables.Set(table.Name, &Table{
			Schema:            table.Schema,
			Name:              table.Name,
			Columns:           colDefs,
			PrimaryKey:        pk,
			UniqueConstraints: unique,
		})
	}

	// Foreign keys are keyed by their (source, target) column references and
	// map to the constraint name.
	for _, fk := range fks {
		dbSchema.ForeignKeys[sqlschema.ForeignKey{
			From: sqlschema.NewColumnReference(fk.SourceTable, fk.SourceColumns...),
			To:   sqlschema.NewColumnReference(fk.TargetTable, fk.TargetColumns...),
		}] = fk.ConstraintName
	}
	return dbSchema, nil
}
// InformationSchemaTable is the scan target for one row of sqlInspectTables.
type InformationSchemaTable struct {
// Schema and Name identify the table (result columns table_schema and table_name).
Schema string `bun:"table_schema,pk"`
Name string `bun:"table_name,pk"`
// PrimaryKey is filled from the result columns prefixed with primary_key_.
// Inspect treats an empty Columns list as "table has no primary key".
PrimaryKey PrimaryKey `bun:"embed:primary_key_"`
// Columns is declared as a has-many relation joined on schema and table name.
// NOTE(review): Inspect does not read this field; it runs
// sqlInspectColumnsQuery once per table instead.
Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"`
}
// InformationSchemaColumn is the scan target for one row of sqlInspectColumnsQuery.
type InformationSchemaColumn struct {
Schema string `bun:"table_schema"`
Table string `bun:"table_name"`
Name string `bun:"column_name"`
DataType string `bun:"data_type"`
// VarcharLen is character_maximum_length cast to integer.
VarcharLen int `bun:"varchar_len"`
// IsArray is true when data_type is 'ARRAY'.
IsArray bool `bun:"is_array"`
// ArrayDims is pg_attribute.attndims, COALESCEd to 0. The query derives it
// from a subquery built on pg_constraint, so a column that takes part in no
// constraint also gets 0.
ArrayDims int `bun:"array_dims"`
// Default is column_default; for defaults of the form '<text>'::<type> only
// the text between the quotes is kept.
Default string `bun:"default"`
// IsDefaultLiteral is true for a quoted-and-cast default or for a default
// consisting only of digits and dots.
IsDefaultLiteral bool `bun:"default_is_literal_expr"`
IsIdentity bool `bun:"is_identity"`
// IndentityType holds pg_attribute.attidentity, or '' when the query found none.
// The name is misspelled ("Indentity"); it is exported, so renaming it would
// break existing users of the field.
IndentityType string `bun:"identity_type"`
// IsSerial is true when the default is exactly
// nextval('<table>_<column>_seq'::regclass).
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
// UniqueGroups lists the names of the unique constraints (contype 'u') that
// include this column.
UniqueGroups []string `bun:"unique_groups,array"`
}
// ForeignKey is the scan target for one row of sqlInspectForeignKeys.
// The Source* fields describe the referencing table and the Target* fields the
// referenced table.
type ForeignKey struct {
ConstraintName string `bun:"constraint_name"`
SourceSchema string `bun:"schema_name"`
SourceTable string `bun:"table_name"`
// SourceColumns and TargetColumns are aggregated from the same result rows,
// each of which pairs a referencing column with the referenced column at the
// same key position.
SourceColumns []string `bun:"columns,array"`
TargetSchema string `bun:"target_schema"`
TargetTable string `bun:"target_table"`
TargetColumns []string `bun:"target_columns,array"`
}
// PrimaryKey holds the primary key reported by sqlInspectTables for one table.
// It is embedded in InformationSchemaTable with the column prefix primary_key_.
type PrimaryKey struct {
// ConstraintName is the relname of the primary key index.
ConstraintName string `bun:"name"`
// Columns are the names of the columns covered by that index.
Columns []string `bun:"columns,array"`
}
const (
	// sqlInspectTables retrieves all user-defined base tables in the selected
	// schema together with the name and columns of their primary key, if any.
	//
	// Args: the schema name, then bun.In([]string{...}) with the tables to
	// exclude. Pass bun.In([]string{""}) to include all tables.
	//
	// The schema and table name are combined with format('%I.%I', ...) before
	// the cast to regclass, so identifiers that need quoting (mixed case, dots,
	// spaces) resolve to the intended relation.
	sqlInspectTables = `
SELECT
"t".table_schema,
"t".table_name,
pk.name AS primary_key_name,
pk.columns AS primary_key_columns
FROM information_schema.tables "t"
LEFT JOIN (
SELECT i.indrelid, "idx".relname AS "name", ARRAY_AGG("a".attname) AS "columns"
FROM pg_index i
JOIN pg_attribute "a"
ON "a".attrelid = i.indrelid
AND "a".attnum = ANY("i".indkey)
AND i.indisprimary
JOIN pg_class "idx" ON i.indexrelid = "idx".oid
GROUP BY 1, 2
) pk
ON format('%I.%I', "t".table_schema, "t".table_name)::regclass = pk.indrelid
WHERE table_type = 'BASE TABLE'
AND "t".table_schema = ?
AND "t".table_schema NOT LIKE 'pg_%'
AND "table_name" NOT IN (?)
ORDER BY "t".table_schema, "t".table_name
`
	// sqlInspectColumnsQuery retrieves the column definitions of a single table,
	// ordered by column name.
	//
	// Unlike sqlInspectTables and sqlInspectForeignKeys, it takes no exclusion
	// list: its two args are the table_schema and the table_name to inspect.
	sqlInspectColumnsQuery = `
SELECT
"c".table_schema,
"c".table_name,
"c".column_name,
"c".data_type,
"c".character_maximum_length::integer AS varchar_len,
"c".data_type = 'ARRAY' AS is_array,
COALESCE("c".array_dims, 0) AS array_dims,
CASE
WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
ELSE "c".column_default
END AS "default",
"c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
"c"."unique_groups" AS unique_groups
FROM (
SELECT
"table_schema",
"table_name",
"column_name",
"c".data_type,
"c".character_maximum_length,
"c".column_default,
"c".is_identity,
"c".is_nullable,
att.array_dims,
att.identity_type,
att."unique_groups",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
SELECT
s.nspname AS "table_schema",
"t".relname AS "table_name",
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups",
ARRAY_AGG(con.contype) AS "constraint_type"
FROM (
SELECT
conname,
contype,
connamespace,
conrelid,
conrelid AS attrelid,
UNNEST(conkey) AS attnum
FROM pg_constraint
) con
LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
LEFT JOIN pg_namespace s ON s.oid = con.connamespace
LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
GROUP BY 1, 2, 3, 4, 5
) att USING ("table_schema", "table_name", "column_name")
) "c"
WHERE "table_schema" = ? AND "table_name" = ?
ORDER BY "table_schema", "table_name", "column_name"
`
	// sqlInspectForeignKeys gets the foreign key definitions of ordinary tables
	// (relkind 'r') in the selected schema.
	//
	// Args: the schema name, then bun.In([]string{...}) twice — first the
	// referencing tables to exclude, then the referenced tables to exclude.
	// Pass bun.In([]string{""}) to include all results.
	sqlInspectForeignKeys = `
WITH
"schemas" AS (
SELECT oid, nspname
FROM pg_namespace
),
"tables" AS (
SELECT oid, relnamespace, relname, relkind
FROM pg_class
),
"columns" AS (
SELECT attrelid, attname, attnum
FROM pg_attribute
WHERE attisdropped = false
)
SELECT DISTINCT
co.conname AS "constraint_name",
ss.nspname AS schema_name,
s.relname AS "table_name",
ARRAY_AGG(sc.attname) AS "columns",
ts.nspname AS target_schema,
"t".relname AS target_table,
ARRAY_AGG(tc.attname) AS target_columns
FROM pg_constraint co
LEFT JOIN "tables" s ON s.oid = co.conrelid
LEFT JOIN "schemas" ss ON ss.oid = s.relnamespace
LEFT JOIN "columns" sc ON sc.attrelid = s.oid AND sc.attnum = ANY(co.conkey)
LEFT JOIN "tables" t ON t.oid = co.confrelid
LEFT JOIN "schemas" ts ON ts.oid = "t".relnamespace
LEFT JOIN "columns" tc ON tc.attrelid = "t".oid AND tc.attnum = ANY(co.confkey)
WHERE co.contype = 'f'
AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r')
AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum)
AND ss.nspname = ?
AND s.relname NOT IN (?) AND "t".relname NOT IN (?)
GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table
`
)