package bun import ( "context" "reflect" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" ) type relationJoin struct { Parent *relationJoin BaseModel tableModel JoinModel tableModel Relation *schema.Relation apply func(*SelectQuery) *SelectQuery columns []schema.QueryWithArgs } func (j *relationJoin) applyTo(q *SelectQuery) { if j.apply == nil { return } var table *schema.Table var columns []schema.QueryWithArgs // Save state. table, q.table = q.table, j.JoinModel.Table() columns, q.columns = q.columns, nil q = j.apply(q) // Restore state. q.table = table j.columns, q.columns = q.columns, columns } func (j *relationJoin) Select(ctx context.Context, q *SelectQuery) error { switch j.Relation.Type { } panic("not reached") } func (j *relationJoin) selectMany(ctx context.Context, q *SelectQuery) error { q = j.manyQuery(q) if q == nil { return nil } return q.Scan(ctx) } func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { hasManyModel := newHasManyModel(j) if hasManyModel == nil { return nil } q = q.Model(hasManyModel) var where []byte if len(j.Relation.JoinFields) > 1 { where = append(where, '(') } where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields) if len(j.Relation.JoinFields) > 1 { where = append(where, ')') } where = append(where, " IN ("...) where = appendChildValues( q.db.Formatter(), where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Relation.BaseFields, ) where = append(where, ")"...) q = q.Where(internal.String(where)) if j.Relation.PolymorphicField != nil { q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) } j.applyTo(q) q = q.Apply(j.hasManyColumns) return q } func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery { if j.Relation.M2MTable != nil { q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*") } b := make([]byte, 0, 32) if len(j.columns) > 0 { for i, col := range j.columns { if i > 0 { b = append(b, ", "...) } var err error b, err = col.AppendQuery(q.db.fmter, b) if err != nil { q.err = err return q } } } else { joinTable := j.JoinModel.Table() b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields) } q = q.ColumnExpr(internal.String(b)) return q } func (j *relationJoin) selectM2M(ctx context.Context, q *SelectQuery) error { q = j.m2mQuery(q) if q == nil { return nil } return q.Scan(ctx) } func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { fmter := q.db.fmter m2mModel := newM2MModel(j) if m2mModel == nil { return nil } q = q.Model(m2mModel) index := j.JoinModel.ParentIndex() baseTable := j.BaseModel.Table() //nolint var join []byte join = append(join, "JOIN "...) join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name)) join = append(join, " AS "...) join = append(join, j.Relation.M2MTable.SQLAlias...) join = append(join, " ON ("...) for i, col := range j.Relation.M2MBaseFields { if i > 0 { join = append(join, ", "...) } join = append(join, j.Relation.M2MTable.SQLAlias...) join = append(join, '.') join = append(join, col.SQLName...) } join = append(join, ") IN ("...) join = appendChildValues(fmter, join, j.BaseModel.Root(), index, baseTable.PKs) join = append(join, ")"...) q = q.Join(internal.String(join)) joinTable := j.JoinModel.Table() for i, m2mJoinField := range j.Relation.M2MJoinFields { joinField := j.Relation.JoinFields[i] q = q.Where("?.? = ?.?", joinTable.SQLAlias, joinField.SQLName, j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) } j.applyTo(q) q = q.Apply(j.hasManyColumns) return q } func (j *relationJoin) hasParent() bool { if j.Parent != nil { switch j.Parent.Relation.Type { case schema.HasOneRelation, schema.BelongsToRelation: return true } } return false } func (j *relationJoin) appendAlias(fmter schema.Formatter, b []byte) []byte { quote := fmter.IdentQuote() b = append(b, quote) b = appendAlias(b, j) b = append(b, quote) return b } func (j *relationJoin) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { quote := fmter.IdentQuote() b = append(b, quote) b = appendAlias(b, j) b = append(b, "__"...) b = append(b, column...) b = append(b, quote) return b } func (j *relationJoin) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { quote := fmter.IdentQuote() if j.hasParent() { b = append(b, quote) b = appendAlias(b, j.Parent) b = append(b, quote) return b } return append(b, j.BaseModel.Table().SQLAlias...) } func (j *relationJoin) appendSoftDelete(b []byte, flags internal.Flag) []byte { b = append(b, '.') b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...) if flags.Has(deletedFlag) { b = append(b, " IS NOT NULL"...) } else { b = append(b, " IS NULL"...) } return b } func appendAlias(b []byte, j *relationJoin) []byte { if j.hasParent() { b = appendAlias(b, j.Parent) b = append(b, "__"...) } b = append(b, j.Relation.Field.Name...) return b } func (j *relationJoin) appendHasOneJoin( fmter schema.Formatter, b []byte, q *SelectQuery, ) (_ []byte, err error) { isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) b = append(b, "LEFT JOIN "...) b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) b = append(b, " AS "...) b = j.appendAlias(fmter, b) b = append(b, " ON "...) b = append(b, '(') for i, baseField := range j.Relation.BaseFields { if i > 0 { b = append(b, " AND "...) } b = j.appendAlias(fmter, b) b = append(b, '.') b = append(b, j.Relation.JoinFields[i].SQLName...) b = append(b, " = "...) b = j.appendBaseAlias(fmter, b) b = append(b, '.') b = append(b, baseField.SQLName...) } b = append(b, ')') if isSoftDelete { b = append(b, " AND "...) b = j.appendAlias(fmter, b) b = j.appendSoftDelete(b, q.flags) } return b, nil } func appendChildValues( fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field, ) []byte { seen := make(map[string]struct{}) walk(v, index, func(v reflect.Value) { start := len(b) if len(fields) > 1 { b = append(b, '(') } for i, f := range fields { if i > 0 { b = append(b, ", "...) } b = f.AppendValue(fmter, b, v) } if len(fields) > 1 { b = append(b, ')') } b = append(b, ", "...) if _, ok := seen[string(b[start:])]; ok { b = b[:start] } else { seen[string(b[start:])] = struct{}{} } }) if len(seen) > 0 { b = b[:len(b)-2] // trim ", " } return b }