diff --git a/go.mod b/go.mod index 5f8675626..aa016f2c3 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/h2non/filetype v1.1.3 github.com/jackc/pgconn v1.14.0 - github.com/jackc/pgx/v5 v5.3.1 + github.com/jackc/pgx/v5 v5.4.1 github.com/microcosm-cc/bluemonday v1.0.24 github.com/miekg/dns v1.1.54 github.com/minio/minio-go/v7 v7.0.56 diff --git a/go.sum b/go.sum index 07d2ed9da..b57223ccc 100644 --- a/go.sum +++ b/go.sum @@ -389,8 +389,8 @@ github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrU github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= -github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jackc/pgx/v5 v5.4.1 h1:oKfB/FhuVtit1bBM3zNRRsZ925ZkMN3HXL+LgLUM9lE= +github.com/jackc/pgx/v5 v5.4.1/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= diff --git a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index ec90631cd..cc090da4c 100644 --- a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,31 @@ +# 5.4.1 (June 18, 2023) + +* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov) +* Add TxOptions.BeginQuery to allow overriding the default BEGIN query + +# 5.4.0 (June 14, 2023) + +* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues. +* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov) +* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic +* CancelRequest: don't try to read the reply (Nicola Murino) +* Fix: correctly handle bool type aliases (Wichert Akkerman) +* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr() +* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones) +* Add BeforeClose to pgxpool.Pool (Evan Cordell) +* Fix: various hstore fixes and optimizations (Evan Jones) +* Fix: RowToStructByPos with embedded unexported struct +* Support different bool string representations (Lev Zakharov) +* Fix: error when using BatchResults.Exec on a select that returns an error after some rows. +* Fix: pipelineBatchResults.Exec() not returning error from ResultReader +* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using + a callback. +* Fix: scanning a table type into a struct +* Fix: scan array of record to pointer to slice of struct +* Fix: handle null for json (Cemre Mengu) +* Batch Query callback is called even when there is an error +* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P) + # 5.3.1 (February 27, 2023) * Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka) diff --git a/vendor/github.com/jackc/pgx/v5/README.md b/vendor/github.com/jackc/pgx/v5/README.md index 29d9521c6..14327f2c6 100644 --- a/vendor/github.com/jackc/pgx/v5/README.md +++ b/vendor/github.com/jackc/pgx/v5/README.md @@ -132,13 +132,24 @@ These adapters can be used with the tracelog package. * [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) * [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap) * [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) +* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog) ## 3rd Party Libraries with PGX Support +### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock) + +pgxmock is a mock library implementing pgx interfaces. +pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection. + ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) Library for scanning data from a database into Go structs and more. +### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql) + +A carefully designed SQL client for making using SQL easier, +more productive, and less error-prone on Golang. + ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) Adds GSSAPI / Kerberos authentication support. diff --git a/vendor/github.com/jackc/pgx/v5/batch.go b/vendor/github.com/jackc/pgx/v5/batch.go index af62039f8..8f6ea4f0d 100644 --- a/vendor/github.com/jackc/pgx/v5/batch.go +++ b/vendor/github.com/jackc/pgx/v5/batch.go @@ -21,13 +21,10 @@ type QueuedQuery struct { // Query sets fn to be called when the response to qq is received. func (qq *QueuedQuery) Query(fn func(rows Rows) error) { qq.fn = func(br BatchResults) error { - rows, err := br.Query() - if err != nil { - return err - } + rows, _ := br.Query() defer rows.Close() - err = fn(rows) + err := fn(rows) if err != nil { return err } @@ -142,7 +139,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { } commandTag, err := br.mrr.ResultReader().Close() - br.err = err + if err != nil { + br.err = err + br.mrr.Close() + } if br.conn.batchTracer != nil { br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ @@ -228,7 +228,7 @@ func (br *batchResults) Close() error { for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b.queuedQueries[br.qqIdx].fn != nil { err := br.b.queuedQueries[br.qqIdx].fn(br) - if err != nil && br.err == nil { + if err != nil { br.err = err } } else { @@ -290,7 +290,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { results, err := br.pipeline.GetResults() if err != nil { br.err = err - return pgconn.CommandTag{}, err + return pgconn.CommandTag{}, br.err } var commandTag pgconn.CommandTag switch results := results.(type) { @@ -309,7 +309,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { }) } - return commandTag, err + return commandTag, br.err } // Query reads the results from the next query in the batch as if the query has been sent with Query. @@ -384,24 +384,20 @@ func (br *pipelineBatchResults) Close() error { } }() - if br.err != nil { - return br.err - } - - if br.lastRows != nil && br.lastRows.err != nil { + if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { br.err = br.lastRows.err return br.err } if br.closed { - return nil + return br.err } // Read and run fn for all remaining items for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b.queuedQueries[br.qqIdx].fn != nil { err := br.b.queuedQueries[br.qqIdx].fn(br) - if err != nil && br.err == nil { + if err != nil { br.err = err } } else { diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go index 92b6f3e4a..a609d1002 100644 --- a/vendor/github.com/jackc/pgx/v5/conn.go +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -178,7 +178,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "simple_protocol": defaultQueryExecMode = QueryExecModeSimpleProtocol default: - return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) + return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s) } } @@ -382,11 +382,9 @@ func quoteIdentifier(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } -// Ping executes an empty sql statement against the *Conn -// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. +// Ping delegates to the underlying *pgconn.PgConn.Ping. func (c *Conn) Ping(ctx context.Context) error { - _, err := c.Exec(ctx, ";") - return err + return c.pgConn.Ping(ctx) } // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the @@ -585,8 +583,10 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { QueryExecModeCacheDescribe // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips - // to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even - // when the the database schema is modified concurrently. + // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the + // statement description on the first round trip and then uses it to execute the query on the second round trip. This + // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe + // even when the the database schema is modified concurrently. QueryExecModeDescribeExec // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol @@ -648,6 +648,9 @@ type QueryRewriter interface { // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It // is allowed to ignore the error returned from Query and handle it in Rows. // +// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not +// return an error. +// // It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be // collected before processing rather than processed while receiving each row. This avoids the possibility of the // application processing rows from a query that the server rejected. The CollectRows function is useful here. @@ -975,7 +978,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.statementCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -1007,7 +1010,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.descriptionCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -1074,18 +1077,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } resultSD, ok := results.(*pgconn.StatementDescription) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} } // Fill in the previously empty / pending statement descriptions. @@ -1095,12 +1098,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} } } @@ -1117,7 +1120,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d if err != nil { // we wrap the error so we the user can understand which query failed inside the batch err = fmt.Errorf("error building query %s: %w", bi.query, err) - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } if bi.sd.Name == "" { @@ -1129,7 +1132,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } return &pipelineBatchResults{ @@ -1282,7 +1285,9 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com var fieldOID uint32 rows, _ := c.Query(ctx, `select attname, atttypid from pg_attribute -where attrelid=$1 and not attisdropped +where attrelid=$1 + and not attisdropped + and attnum > 0 order by attnum`, typrelid, ) @@ -1324,6 +1329,7 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error for _, sd := range invalidatedStatements { pipeline.SendDeallocate(sd.Name) + delete(c.preparedStatements, sd.Name) } err := pipeline.Sync() diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go deleted file mode 100644 index 4bf25481c..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go +++ /dev/null @@ -1,70 +0,0 @@ -package nbconn - -import ( - "sync" -) - -const minBufferQueueLen = 8 - -type bufferQueue struct { - lock sync.Mutex - queue []*[]byte - r, w int -} - -func (bq *bufferQueue) pushBack(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - bq.queue[bq.w] = buf - bq.w++ -} - -func (bq *bufferQueue) pushFront(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) - bq.queue[bq.r] = buf - bq.w++ -} - -func (bq *bufferQueue) popFront() *[]byte { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.r == bq.w { - return nil - } - - buf := bq.queue[bq.r] - bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. - bq.r++ - - if bq.r == bq.w { - bq.r = 0 - bq.w = 0 - if len(bq.queue) > minBufferQueueLen { - bq.queue = make([]*[]byte, minBufferQueueLen) - } - } - - return buf -} - -func (bq *bufferQueue) growQueue() { - desiredLen := (len(bq.queue) + 1) * 3 / 2 - if desiredLen < minBufferQueueLen { - desiredLen = minBufferQueueLen - } - - newQueue := make([]*[]byte, desiredLen) - copy(newQueue, bq.queue) - bq.queue = newQueue -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go deleted file mode 100644 index 7a38383f0..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go +++ /dev/null @@ -1,520 +0,0 @@ -// Package nbconn implements a non-blocking net.Conn wrapper. -// -// It is designed to solve three problems. -// -// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all -// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. -// -// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. -// -// The third is to efficiently check if a connection has been closed via a non-blocking read. -package nbconn - -import ( - "crypto/tls" - "errors" - "net" - "os" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/jackc/pgx/v5/internal/iobufpool" -) - -var errClosed = errors.New("closed") -var ErrWouldBlock = new(wouldBlockError) - -const fakeNonblockingWriteWaitDuration = 100 * time.Millisecond -const minNonblockingReadWaitDuration = time.Microsecond -const maxNonblockingReadWaitDuration = 100 * time.Millisecond - -// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read -// mode. -var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) - -// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to -// ignore all future calls. -var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) - -// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. -type wouldBlockError struct{} - -func (*wouldBlockError) Error() string { - return "would block" -} - -func (*wouldBlockError) Timeout() bool { return true } -func (*wouldBlockError) Temporary() bool { return true } - -// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to -// the underlying connection. -type Conn interface { - net.Conn - - // Flush flushes any buffered writes. - Flush() error - - // BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block. - BufferReadUntilBlock() error -} - -// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. -type NetConn struct { - // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit - // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and - // https://github.com/jackc/pgx/issues/1307. Only access with atomics - closed int64 // 0 = not closed, 1 = closed - - conn net.Conn - rawConn syscall.RawConn - - readQueue bufferQueue - writeQueue bufferQueue - - readFlushLock sync.Mutex - // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockWriteFunc func(fd uintptr) (done bool) - nonblockWriteBuf []byte - nonblockWriteErr error - nonblockWriteN int - - // non-blocking reads with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockReadFunc func(fd uintptr) (done bool) - nonblockReadBuf []byte - nonblockReadErr error - nonblockReadN int - - readDeadlineLock sync.Mutex - readDeadline time.Time - readNonblocking bool - fakeNonBlockingShortReadCount int - fakeNonblockingReadWaitDuration time.Duration - - writeDeadlineLock sync.Mutex - writeDeadline time.Time -} - -func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { - nc := &NetConn{ - conn: conn, - fakeNonblockingReadWaitDuration: maxNonblockingReadWaitDuration, - } - - if !fakeNonBlockingIO { - if sc, ok := conn.(syscall.Conn); ok { - if rawConn, err := sc.SyscallConn(); err == nil { - nc.rawConn = rawConn - } - } - } - - return nc -} - -// Read implements io.Reader. -func (c *NetConn) Read(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - - err = c.flush() - if err != nil { - return 0, err - } - - for n < len(b) { - buf := c.readQueue.popFront() - if buf == nil { - break - } - copiedN := copy(b[n:], *buf) - if copiedN < len(*buf) { - *buf = (*buf)[copiedN:] - c.readQueue.pushFront(buf) - } else { - iobufpool.Put(buf) - } - n += copiedN - } - - // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to - // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. - if n > 0 { - return n, nil - } - - var readNonblocking bool - c.readDeadlineLock.Lock() - readNonblocking = c.readNonblocking - c.readDeadlineLock.Unlock() - - var readN int - if readNonblocking { - readN, err = c.nonblockingRead(b[n:]) - } else { - readN, err = c.conn.Read(b[n:]) - } - n += readN - return n, err -} - -// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is -// closed. Call Flush to actually write to the underlying connection. -func (c *NetConn) Write(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - buf := iobufpool.Get(len(b)) - copy(*buf, b) - c.writeQueue.pushBack(buf) - return len(b), nil -} - -func (c *NetConn) Close() (err error) { - swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) - if !swapped { - return errClosed - } - - defer func() { - closeErr := c.conn.Close() - if err == nil { - err = closeErr - } - }() - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - err = c.flush() - if err != nil { - return err - } - - return nil -} - -func (c *NetConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *NetConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). -func (c *NetConn) SetDeadline(t time.Time) error { - err := c.SetReadDeadline(t) - if err != nil { - return err - } - return c.SetWriteDeadline(t) -} - -// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. -func (c *NetConn) SetReadDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - if c.readDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.readDeadline = t - return nil - } - - if t == NonBlockingDeadline { - c.readNonblocking = true - t = time.Time{} - } else { - c.readNonblocking = false - } - - c.readDeadline = t - - return c.conn.SetReadDeadline(t) -} - -func (c *NetConn) SetWriteDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - if c.writeDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.writeDeadline = t - return nil - } - - c.writeDeadline = t - - return c.conn.SetWriteDeadline(t) -} - -func (c *NetConn) Flush() error { - if c.isClosed() { - return errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - return c.flush() -} - -// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. -func (c *NetConn) flush() error { - var stopChan chan struct{} - var errChan chan error - - defer func() { - if stopChan != nil { - select { - case stopChan <- struct{}{}: - case <-errChan: - } - } - }() - - for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { - remainingBuf := *buf - for len(remainingBuf) > 0 { - n, err := c.nonblockingWrite(remainingBuf) - remainingBuf = remainingBuf[n:] - if err != nil { - if !errors.Is(err, ErrWouldBlock) { - *buf = (*buf)[:len(remainingBuf)] - copy(*buf, remainingBuf) - c.writeQueue.pushFront(buf) - return err - } - - // Writing was blocked. Reading might unblock it. - if stopChan == nil { - stopChan, errChan = c.bufferNonblockingRead() - } - - select { - case err := <-errChan: - stopChan = nil - return err - default: - } - - } - } - iobufpool.Put(buf) - } - - return nil -} - -func (c *NetConn) BufferReadUntilBlock() error { - for { - buf := iobufpool.Get(8 * 1024) - n, err := c.nonblockingRead(*buf) - if n > 0 { - *buf = (*buf)[:n] - c.readQueue.pushBack(buf) - } else if n == 0 { - iobufpool.Put(buf) - } - - if err != nil { - if errors.Is(err, ErrWouldBlock) { - return nil - } else { - return err - } - } - } -} - -func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { - stopChan = make(chan struct{}) - errChan = make(chan error, 1) - - go func() { - for { - err := c.BufferReadUntilBlock() - if err != nil { - errChan <- err - return - } - - select { - case <-stopChan: - return - default: - } - } - }() - - return stopChan, errChan -} - -func (c *NetConn) isClosed() bool { - closed := atomic.LoadInt64(&c.closed) - return closed == 1 -} - -func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingWrite(b) - } else { - return c.realNonblockingWrite(b) - } -} - -func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - - deadline := time.Now().Add(fakeNonblockingWriteWaitDuration) - if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { - err = c.conn.SetWriteDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetWriteDeadline(c.writeDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Write(b) -} - -func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingRead(b) - } else { - return c.realNonblockingRead(b) - } -} - -func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - - // The first 5 reads only read 1 byte at a time. This should give us 4 chances to read when we are sure the bytes are - // already in Go or the OS's receive buffer. - if c.fakeNonBlockingShortReadCount < 5 && len(b) > 0 && c.fakeNonblockingReadWaitDuration < minNonblockingReadWaitDuration { - b = b[:1] - } - - startTime := time.Now() - deadline := startTime.Add(c.fakeNonblockingReadWaitDuration) - if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { - err = c.conn.SetReadDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // If the read was successful and the wait duration is not already the minimum - if err == nil && c.fakeNonblockingReadWaitDuration > minNonblockingReadWaitDuration { - endTime := time.Now() - - if n > 0 && c.fakeNonBlockingShortReadCount < 5 { - c.fakeNonBlockingShortReadCount++ - } - - // The wait duration should be 2x the fastest read that has occurred. This should give reasonable assurance that - // a Read deadline will not block a read before it has a chance to read data already in Go or the OS's receive - // buffer. - proposedWait := endTime.Sub(startTime) * 2 - if proposedWait < minNonblockingReadWaitDuration { - proposedWait = minNonblockingReadWaitDuration - } - if proposedWait < c.fakeNonblockingReadWaitDuration { - c.fakeNonblockingReadWaitDuration = proposedWait - } - } - - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetReadDeadline(c.readDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Read(b) -} - -// syscall.Conn is interface - -// TLSClient establishes a TLS connection as a client over conn using config. -// -// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby -// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the -// *TLSConn is returned. -func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { - tc := tls.Client(conn, config) - err := tc.Handshake() - if err != nil { - return nil, err - } - - // Ensure last written part of Handshake is actually sent. - err = conn.Flush() - if err != nil { - return nil, err - } - - return &TLSConn{ - tlsConn: tc, - nbConn: conn, - }, nil -} - -// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a -// tls.Conn. -type TLSConn struct { - tlsConn *tls.Conn - nbConn *NetConn -} - -func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } -func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } -func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } -func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } -func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } -func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } - -func (tc *TLSConn) Close() error { - // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then - // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our - // own 5 second deadline then make all set deadlines no-op. - tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) - tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) - - return tc.tlsConn.Close() -} - -func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } -func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } -func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go deleted file mode 100644 index 4915c6219..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !unix - -package nbconn - -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - return c.fakeNonblockingWrite(b) -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - return c.fakeNonblockingRead(b) -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go deleted file mode 100644 index e93372f25..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go +++ /dev/null @@ -1,81 +0,0 @@ -//go:build unix - -package nbconn - -import ( - "errors" - "io" - "syscall" -) - -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - if c.nonblockWriteFunc == nil { - c.nonblockWriteFunc = func(fd uintptr) (done bool) { - c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) - return true - } - } - c.nonblockWriteBuf = b - c.nonblockWriteN = 0 - c.nonblockWriteErr = nil - - err = c.rawConn.Write(c.nonblockWriteFunc) - n = c.nonblockWriteN - c.nonblockWriteBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockWriteErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - return n, nil -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - if c.nonblockReadFunc == nil { - c.nonblockReadFunc = func(fd uintptr) (done bool) { - c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf) - return true - } - } - c.nonblockReadBuf = b - c.nonblockReadN = 0 - c.nonblockReadErr = nil - - err = c.rawConn.Read(c.nonblockReadFunc) - n = c.nonblockReadN - c.nonblockReadBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockReadErr != nil { - if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockReadErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - // syscall read did not return an error and 0 bytes were read means EOF. - if n == 0 { - return 0, io.EOF - } - - return n, nil -} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index 6ca9e3379..8c4b2de3c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: sc.clientFirstMessage(), } c.frontend.Send(saslInitialResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } @@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: []byte(sc.clientFinalMessage()), } c.frontend.Send(saslResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go new file mode 100644 index 000000000..aa1a3d39c --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go @@ -0,0 +1,132 @@ +// Package bgreader provides a io.Reader that can optionally buffer reads in the background. +package bgreader + +import ( + "io" + "sync" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +const ( + bgReaderStatusStopped = iota + bgReaderStatusRunning + bgReaderStatusStopping +) + +// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use. +type BGReader struct { + r io.Reader + + cond *sync.Cond + bgReaderStatus int32 + readResults []readResult +} + +type readResult struct { + buf *[]byte + err error +} + +// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background +// reader will stop automatically when the underlying reader returns an error. +func (r *BGReader) Start() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.bgReaderStatus { + case bgReaderStatusStopped: + r.bgReaderStatus = bgReaderStatusRunning + go r.bgRead() + case bgReaderStatusRunning: + // no-op + case bgReaderStatusStopping: + r.bgReaderStatus = bgReaderStatusRunning + } +} + +// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the +// background reader is not running. +func (r *BGReader) Stop() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.bgReaderStatus { + case bgReaderStatusStopped: + // no-op + case bgReaderStatusRunning: + r.bgReaderStatus = bgReaderStatusStopping + case bgReaderStatusStopping: + // no-op + } +} + +func (r *BGReader) bgRead() { + keepReading := true + for keepReading { + buf := iobufpool.Get(8192) + n, err := r.r.Read(*buf) + *buf = (*buf)[:n] + + r.cond.L.Lock() + r.readResults = append(r.readResults, readResult{buf: buf, err: err}) + if r.bgReaderStatus == bgReaderStatusStopping || err != nil { + r.bgReaderStatus = bgReaderStatusStopped + keepReading = false + } + r.cond.L.Unlock() + r.cond.Broadcast() + } +} + +// Read implements the io.Reader interface. +func (r *BGReader) Read(p []byte) (int, error) { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + if len(r.readResults) > 0 { + return r.readFromReadResults(p) + } + + // There are no unread background read results and the background reader is stopped. + if r.bgReaderStatus == bgReaderStatusStopped { + return r.r.Read(p) + } + + // Wait for results from the background reader + for len(r.readResults) == 0 { + r.cond.Wait() + } + return r.readFromReadResults(p) +} + +// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held. +func (r *BGReader) readFromReadResults(p []byte) (int, error) { + buf := r.readResults[0].buf + var err error + + n := copy(p, *buf) + if n == len(*buf) { + err = r.readResults[0].err + iobufpool.Put(buf) + if len(r.readResults) == 1 { + r.readResults = nil + } else { + r.readResults = r.readResults[1:] + } + } else { + *buf = (*buf)[n:] + r.readResults[0].buf = buf + } + + return n, err +} + +func New(r io.Reader) *BGReader { + return &BGReader{ + r: r, + cond: &sync.Cond{ + L: &sync.Mutex{}, + }, + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go index 969675fd2..3c1af3477 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go @@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error { Data: nextData, } c.frontend.Send(gssResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 8656ea518..9f84605fe 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -13,11 +13,12 @@ "net" "strconv" "strings" + "sync" "time" "github.com/jackc/pgx/v5/internal/iobufpool" - "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -65,17 +66,24 @@ type Notification struct { // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection + conn net.Conn pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend + bgReader *bgreader.BGReader + slowWriteTimer *time.Timer config *Config status byte // One of connStatus* constants + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error + peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources @@ -266,14 +274,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } - nbNetConn := nbconn.NewNetConn(netConn, false) - pgConn.conn = nbNetConn - pgConn.contextWatcher = newContextWatcher(nbNetConn) + pgConn.conn = netConn + pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) + nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -289,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) pgConn.status = connStatusConnecting - pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -307,9 +316,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } pgConn.frontend.Send(&startupMsg) - if err := pgConn.frontend.Flush(); err != nil { + if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} } for { @@ -392,7 +401,7 @@ func() { conn.SetDeadline(time.Time{}) }, ) } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -407,17 +416,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err return nil, errors.New("server refused TLS connection") } - tlsConn, err := nbconn.TLSClient(conn, tlsConfig) - if err != nil { - return nil, err - } - - return tlsConn, nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) - return pgConn.frontend.Flush() + return pgConn.flushWithPotentialWriteReadDeadlock() } func hexMD5(s string) string { @@ -426,6 +430,24 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -465,13 +487,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { return pgConn.peekedMsg, nil } - msg, err := pgConn.frontend.Receive() + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + msg, err = pgConn.frontend.Receive() + } + } else { + msg, err = pgConn.frontend.Receive() + } if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - return nil, err - } - // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -582,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // // See https://github.com/jackc/pgx/issues/637 pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() return pgConn.conn.Close() } @@ -609,7 +643,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() }() } @@ -784,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return nil, err @@ -857,9 +891,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + var serverNetwork string + var serverAddress string + if serverAddr.Network() == "unix" { + // for unix sockets, RemoteAddr() calls getpeername() which returns the name the + // server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432" + // so connecting to it will fail. Fall back to the config's value + serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port) + } else { + serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String() + } + cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress) if err != nil { - return err + // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the + // first connect failed, try the config. + if serverAddr.Network() != "unix" { + return err + } + serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port) + cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr) + if err != nil { + return err + } } defer cancelConn.Close() @@ -877,17 +930,11 @@ func() { cancelConn.SetDeadline(time.Time{}) }, binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) + // Postgres will process the request and close the connection + // so when don't need to read the reply + // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10 _, err = cancelConn.Write(buf) - if err != nil { - return err - } - - _, err = cancelConn.Read(buf) - if err != io.EOF { - return err - } - - return nil + return err } // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not @@ -953,7 +1000,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() @@ -1064,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() result.concludeCommand(CommandTag{}, err) @@ -1097,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.unlock() @@ -1153,85 +1200,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co defer pgConn.contextWatcher.Unwatch() } - // Send copy to command + // Send copy from query pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err } - err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking := true - defer func() { - if nonblocking { - pgConn.conn.SetReadDeadline(time.Time{}) + // Send copy data + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error, 1) + signalMessageChan := pgConn.signalMessage() + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + (*buf)[0] = 'd' + + for { + n, readErr := r.Read((*buf)[5:cap(*buf)]) + if n > 0 { + *buf = (*buf)[0 : n+5] + pgio.SetInt32((*buf)[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) + if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not + // setting pgConn.status or closing pgConn.cleanupDone for the same reason. + pgConn.conn.Close() + + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: + } } }() - buf := iobufpool.Get(65536) - defer iobufpool.Put(buf) - (*buf)[0] = 'd' - - var readErr, pgErr error - for pgErr == nil { - // Read chunk from r. - var n int - n, readErr = r.Read((*buf)[5:cap(*buf)]) - - // Send chunk to PostgreSQL. - if n > 0 { - *buf = (*buf)[0 : n+5] - pgio.SetInt32((*buf)[1:], int32(n+4)) - - writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) - if writeErr != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - } - - // Abort loop if there was a read error. - if readErr != nil { - break - } - - // Read messages until error or none available. - for pgErr == nil { - msg, err := pgConn.receiveMessage() - if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - break - } - pgConn.asyncClose() + var pgErr error + var copyErr error + for copyErr == nil && pgErr == nil { + select { + case copyErr = <-copyErrChan: + case <-signalMessageChan: + // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with + // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an + // error is found then forcibly close the connection without sending the Terminate message. + if err := pgConn.bufferingReceiveErr; err != nil { + pgConn.status = connStatusClosed + pgConn.conn.Close() + close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } + msg, _ := pgConn.receiveMessage() switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) - break + default: + signalMessageChan = pgConn.signalMessage() } } } + close(abortCopyChan) + // Make sure io goroutine finishes before writing. + wg.Wait() - err = pgConn.conn.SetReadDeadline(time.Time{}) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking = false - - if readErr == io.EOF || pgErr != nil { + if copyErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) } - err = pgConn.frontend.Flush() + err = pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1426,7 +1479,8 @@ func (rr *ResultReader) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the ResultReader is closed. +// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was +// encountered.) func (rr *ResultReader) FieldDescriptions() []FieldDescription { return rr.fieldDescriptions } @@ -1592,7 +1646,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + pgConn.enterPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) + pgConn.exitPotentialWriteReadDeadlock() if err != nil { multiResult.closed = true multiResult.err = err @@ -1620,29 +1676,72 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } -// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and -// buffering until the read would block or an error occurs. This can be used to check if the server has closed the -// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read +// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear +// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this +// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails // without the client knowing whether the server received it or not. +// +// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where +// the write would still appear to succeed. Prefer Ping unless on a high latency connection. func (pgConn *PgConn) CheckConn() error { - err := pgConn.conn.BufferReadUntilBlock() - if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { - return err + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + _, err := pgConn.ReceiveMessage(ctx) + if err != nil { + if !Timeout(err) { + return err + } } + return nil } +// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to +// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the +// chances a query will be sent that fails without the client knowing whether the server received it or not. +func (pgConn *PgConn) Ping(ctx context.Context) error { + return pgConn.Exec(ctx, "-- ping").Close() +} + // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { return CommandTag{s: string(buf)} } +// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously +// blocked writing to us. +func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { + // The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS + // outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for + // the normal case, but short enough not to kill performance if a block occurs. + // + // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is + // ineffective. + pgConn.slowWriteTimer.Reset(15 * time.Millisecond) +} + +// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. +func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { + if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) { + pgConn.slowWriteTimer.Stop() + } +} + +func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { + pgConn.enterPotentialWriteReadDeadlock() + err := pgConn.frontend.Flush() + pgConn.exitPotentialWriteReadDeadlock() + return err +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { - Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection + Conn net.Conn PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server @@ -1695,6 +1794,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) { } pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) return pgConn, nil } @@ -1817,7 +1918,7 @@ func (p *Pipeline) Flush() error { return errors.New("pipeline closed") } - err := p.conn.frontend.Flush() + err := p.conn.flushWithPotentialWriteReadDeadlock() if err != nil { err = normalizeTimeoutError(p.ctx, err) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 0fa4c129b..7dfee389e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -363,12 +363,13 @@ func quoteArrayElement(src string) string { } func isSpace(ch byte) bool { - // see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 - return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' + // see array_isspace: + // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f' } func quoteArrayElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { + if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } return src diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go index e7be27e2d..71caffa74 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go @@ -1,10 +1,12 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/json" "fmt" "strconv" + "strings" ) type BoolScanner interface { @@ -264,8 +266,8 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } p, ok := (dst).(*bool) @@ -273,7 +275,12 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return ErrScanTargetTypeChanged } - *p = src[0] == 't' + v, err := planTextToBool(src) + if err != nil { + return err + } + + *p = v return nil } @@ -309,9 +316,28 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{}) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } - return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) + v, err := planTextToBool(src) + if err != nil { + return err + } + + return s.ScanBool(Bool{Bool: v, Valid: true}) +} + +// https://www.postgresql.org/docs/11/datatype-boolean.html +func planTextToBool(src []byte) (bool, error) { + s := string(bytes.ToLower(bytes.TrimSpace(src))) + + switch { + case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1": + return true, nil + case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0": + return false, nil + default: + return false, fmt.Errorf("unknown boolean string representation %q", src) + } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go index 8a2afbe1e..7fddeaa8a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go @@ -64,6 +64,9 @@ func underlyingNumberType(val any) (any, bool) { case reflect.String: convVal := refVal.String() return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() } return nil, false @@ -262,7 +265,7 @@ func int64AssignTo(srcVal int64, srcValid bool, dst any) error { *v = uint8(srcVal) case *uint16: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) + return fmt.Errorf("%d is less than zero for uint16", srcVal) } else if srcVal > math.MaxUint16 { return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go index 4743643e5..9befabd05 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go @@ -1,14 +1,11 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "errors" "fmt" "strings" - "unicode" - "unicode/utf8" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -43,7 +40,7 @@ func (h *Hstore) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h) + return scanPlanTextAnyToHstoreScanner{}.scanString(src, h) } return fmt.Errorf("cannot scan %T", src) @@ -137,13 +134,20 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e buf = append(buf, ',') } - buf = append(buf, quoteHstoreElementIfNeeded(k)...) + // unconditionally quote hstore keys/values like Postgres does + // this avoids a Mac OS X Postgres hstore parsing bug: + // https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(k)...) + buf = append(buf, '"') buf = append(buf, "=>"...) if v == nil { buf = append(buf, "NULL"...) } else { - buf = append(buf, quoteHstoreElementIfNeeded(*v)...) + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(*v)...) + buf = append(buf, '"') } } @@ -174,25 +178,28 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { - return scanner.ScanHstore(Hstore{}) + return scanner.ScanHstore(Hstore(nil)) } rp := 0 - if len(src[rp:]) < 4 { + const uint32Len = 4 + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + rp += uint32Len hstore := make(Hstore, pairCount) + // one allocation for all *string, rather than one per string, just like text parsing + valueStrings := make([]string, pairCount) for i := 0; i < pairCount; i++ { - if len(src[rp:]) < 4 { + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + rp += uint32Len if len(src[rp:]) < keyLen { return fmt.Errorf("hstore incomplete %v", src) @@ -200,26 +207,17 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { key := string(src[rp : rp+keyLen]) rp += keyLen - if len(src[rp:]) < 4 { + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 - var valueBuf []byte if valueLen >= 0 { - valueBuf = src[rp : rp+valueLen] + valueStrings[i] = string(src[rp : rp+valueLen]) rp += valueLen - } - var value Text - err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value) - if err != nil { - return err - } - - if value.Valid { - hstore[key] = &value.String + hstore[key] = &valueStrings[i] } else { hstore[key] = nil } @@ -230,28 +228,22 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { type scanPlanTextAnyToHstoreScanner struct{} -func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { +func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { - return scanner.ScanHstore(Hstore{}) + return scanner.ScanHstore(Hstore(nil)) } + return s.scanString(string(src), scanner) +} - keys, values, err := parseHstore(string(src)) +// scanString does not return nil hstore values because string cannot be nil. +func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error { + hstore, err := parseHstore(src) if err != nil { return err } - - m := make(Hstore, len(keys)) - for i := range keys { - if values[i].Valid { - m[keys[i]] = &values[i].String - } else { - m[keys[i]] = nil - } - } - - return scanner.ScanHstore(m) + return scanner.ScanHstore(hstore) } func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { @@ -271,191 +263,217 @@ func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( return hstore, nil } -var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) - -func quoteHstoreElement(src string) string { - return `"` + quoteArrayReplacer.Replace(src) + `"` -} - -func quoteHstoreElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { - return quoteArrayElement(src) - } - return src -} - -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) - type hstoreParser struct { - str string - pos int + str string + pos int + nextBackslash int } func newHSP(in string) *hstoreParser { return &hstoreParser{ - pos: 0, - str: in, + pos: 0, + str: in, + nextBackslash: strings.IndexByte(in, '\\'), } } -func (p *hstoreParser) Consume() (r rune, end bool) { +func (p *hstoreParser) atEnd() bool { + return p.pos >= len(p.str) +} + +// consume returns the next byte of the string, or end if the string is done. +func (p *hstoreParser) consume() (b byte, end bool) { if p.pos >= len(p.str) { - end = true - return + return 0, true } - r, w := utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return + b = p.str[p.pos] + p.pos++ + return b, false } -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return +func unexpectedByteErr(actualB byte, expectedB byte) error { + return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) } -// parseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores. -func parseHstore(s string) (k []string, v []Text, err error) { - if s == "" { - return +// consumeExpectedByte consumes expectedB from the string, or returns an error. +func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { + nextB, end := p.consume() + if end { + return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB) + } + if nextB != expectedB { + return unexpectedByteErr(nextB, expectedB) + } + return nil +} + +// consumeExpected2 consumes two expected bytes or returns an error. +// This was a bit faster than using a string argument (better inlining? Not sure). +func (p *hstoreParser) consumeExpected2(one byte, two byte) error { + if p.pos+2 > len(p.str) { + return errors.New("unexpected end of string") + } + if p.str[p.pos] != one { + return unexpectedByteErr(p.str[p.pos], one) + } + if p.str[p.pos+1] != two { + return unexpectedByteErr(p.str[p.pos+1], two) + } + p.pos += 2 + return nil +} + +var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`) + +// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been +// parsed already. This copies the string from the backing string so it can be garbage collected. +func (p *hstoreParser) consumeDoubleQuoted() (string, error) { + // fast path: assume most keys/values do not contain escapes + nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"') + if nextDoubleQuote == -1 { + return "", errEOSInQuoted + } + nextDoubleQuote += p.pos + if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote { + // clone the string from the source string to ensure it can be garbage collected separately + // TODO: use strings.Clone on Go 1.20; this could get optimized away + s := strings.Clone(p.str[p.pos:nextDoubleQuote]) + p.pos = nextDoubleQuote + 1 + return s, nil } - buf := bytes.Buffer{} - keys := []string{} - values := []Text{} + // slow path: string contains escapes + s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash) + p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\') + if p.nextBackslash != -1 { + p.nextBackslash += p.pos + } + return s, err +} + +// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting +// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be +// garbage collected separately. +func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) { + // copy the prefix that does not contain backslashes + var builder strings.Builder + builder.WriteString(p.str[p.pos:firstBackslash]) + + // skip to the backslash + p.pos = firstBackslash + + // copy bytes until the end, unescaping backslashes + for { + nextB, end := p.consume() + if end { + return "", errEOSInQuoted + } else if nextB == '"' { + break + } else if nextB == '\\' { + // escape: skip the backslash and copy the char + nextB, end = p.consume() + if end { + return "", errEOSInQuoted + } + if !(nextB == '\\' || nextB == '"') { + return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB) + } + builder.WriteByte(nextB) + } else { + // normal byte: copy it + builder.WriteByte(nextB) + } + } + return builder.String(), nil +} + +// consumePairSeparator consumes the Hstore pair separator ", " or returns an error. +func (p *hstoreParser) consumePairSeparator() error { + return p.consumeExpected2(',', ' ') +} + +// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeKVSeparator() error { + return p.consumeExpected2('=', '>') +} + +// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) { + // peek at the next byte + if p.atEnd() { + return Text{}, errors.New("found end instead of value") + } + next := p.str[p.pos] + if next == 'N' { + // must be the exact string NULL: use consumeExpected2 twice + err := p.consumeExpected2('N', 'U') + if err != nil { + return Text{}, err + } + err = p.consumeExpected2('L', 'L') + if err != nil { + return Text{}, err + } + return Text{String: "", Valid: false}, nil + } else if next != '"' { + return Text{}, unexpectedByteErr(next, '"') + } + + // skip the double quote + p.pos += 1 + s, err := p.consumeDoubleQuoted() + if err != nil { + return Text{}, err + } + return Text{String: s, Valid: true}, nil +} + +func parseHstore(s string) (Hstore, error) { p := newHSP(s) - r, end := p.Consume() - state := hsPre - - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, Text{String: buf.String(), Valid: true}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, Text{}) - state = hsNext - } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) - } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expcting space") - case (unicode.IsSpace(r)): - r, end = p.Consume() - state = hsKey - default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + // This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it + // is less likely to occur in keys/values than '=' or ','. + numPairsEstimate := strings.Count(s, ">") + // makes one allocation of strings for the entire Hstore, rather than one allocation per value. + valueStrings := make([]string, 0, numPairsEstimate) + result := make(Hstore, numPairsEstimate) + first := true + for !p.atEnd() { + if !first { + err := p.consumePairSeparator() + if err != nil { + return nil, err } + } else { + first = false } + err := p.consumeExpectedByte('"') if err != nil { - return + return nil, err + } + + key, err := p.consumeDoubleQuoted() + if err != nil { + return nil, err + } + + err = p.consumeKVSeparator() + if err != nil { + return nil, err + } + + value, err := p.consumeDoubleQuotedOrNull() + if err != nil { + return nil, err + } + if value.Valid { + valueStrings = append(valueStrings, value.String) + result[key] = &valueStrings[len(valueStrings)-1] + } else { + result[key] = nil } - r, end = p.Consume() } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return - } - k = keys - v = values - return + + return result, nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/vendor/github.com/jackc/pgx/v5/pgtype/json.go index 69861bf88..753f24103 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -150,7 +150,7 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { if dstValue.Kind() == reflect.Ptr { el := dstValue.Elem() switch el.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map: + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface: el.Set(reflect.Zero(el.Type())) return nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index 83b349cee..b9cd7b410 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -147,7 +147,7 @@ func (im InfinityModifier) String() string { BinaryFormatCode = 1 ) -// A Codec converts between Go and PostgreSQL values. +// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map. type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -178,6 +178,7 @@ func (e *nullAssignmentError) Error() string { return fmt.Sprintf("cannot assign NULL to %T", e.dst) } +// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map. type Type struct { Codec Codec Name string @@ -211,7 +212,9 @@ type Map struct { } func NewMap() *Map { - m := &Map{ + defaultMapInitOnce.Do(initDefaultMap) + + return &Map{ oidToType: make(map[uint32]*Type), nameToType: make(map[string]*Type), reflectTypeToName: make(map[reflect.Type]string), @@ -240,184 +243,9 @@ func NewMap() *Map { TryWrapPtrArrayScanPlan, }, } - - // Base types - m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) - m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - m.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) - m.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - m.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) - m.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) - m.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) - m.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) - m.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) - m.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) - m.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) - m.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) - m.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) - m.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) - m.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - m.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - m.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) - m.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - m.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) - m.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) - m.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) - m.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) - m.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) - m.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) - m.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) - m.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) - m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) - m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) - m.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) - m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - - // Range types - m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) - - // Multirange types - m.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[DaterangeOID]}}) - m.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int4rangeOID]}}) - m.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int8rangeOID]}}) - m.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[NumrangeOID]}}) - m.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TsrangeOID]}}) - m.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TstzrangeOID]}}) - - // Array types - m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) - m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) - m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}}) - m.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoxOID]}}) - m.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BPCharOID]}}) - m.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ByteaOID]}}) - m.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[QCharOID]}}) - m.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDOID]}}) - m.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDROID]}}) - m.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CircleOID]}}) - m.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DaterangeOID]}}) - m.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float4OID]}}) - m.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float8OID]}}) - m.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[InetOID]}}) - m.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int2OID]}}) - m.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4rangeOID]}}) - m.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8rangeOID]}}) - m.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[IntervalOID]}}) - m.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONOID]}}) - m.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONBOID]}}) - m.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONPathOID]}}) - m.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LineOID]}}) - m.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LsegOID]}}) - m.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[MacaddrOID]}}) - m.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NameOID]}}) - m.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumrangeOID]}}) - m.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[OIDOID]}}) - m.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PathOID]}}) - m.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PointOID]}}) - m.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PolygonOID]}}) - m.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[RecordOID]}}) - m.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TextOID]}}) - m.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TIDOID]}}) - m.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimeOID]}}) - m.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestamptzOID]}}) - m.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TsrangeOID]}}) - m.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TstzrangeOID]}}) - m.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[UUIDOID]}}) - m.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarbitOID]}}) - m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) - m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) - - // Integer types that directly map to a PostgreSQL type - registerDefaultPgTypeVariants[int16](m, "int2") - registerDefaultPgTypeVariants[int32](m, "int4") - registerDefaultPgTypeVariants[int64](m, "int8") - - // Integer types that do not have a direct match to a PostgreSQL type - registerDefaultPgTypeVariants[int8](m, "int8") - registerDefaultPgTypeVariants[int](m, "int8") - registerDefaultPgTypeVariants[uint8](m, "int8") - registerDefaultPgTypeVariants[uint16](m, "int8") - registerDefaultPgTypeVariants[uint32](m, "int8") - registerDefaultPgTypeVariants[uint64](m, "numeric") - registerDefaultPgTypeVariants[uint](m, "numeric") - - registerDefaultPgTypeVariants[float32](m, "float4") - registerDefaultPgTypeVariants[float64](m, "float8") - - registerDefaultPgTypeVariants[bool](m, "bool") - registerDefaultPgTypeVariants[time.Time](m, "timestamptz") - registerDefaultPgTypeVariants[time.Duration](m, "interval") - registerDefaultPgTypeVariants[string](m, "text") - registerDefaultPgTypeVariants[[]byte](m, "bytea") - - registerDefaultPgTypeVariants[net.IP](m, "inet") - registerDefaultPgTypeVariants[net.IPNet](m, "cidr") - registerDefaultPgTypeVariants[netip.Addr](m, "inet") - registerDefaultPgTypeVariants[netip.Prefix](m, "cidr") - - // pgtype provided structs - registerDefaultPgTypeVariants[Bits](m, "varbit") - registerDefaultPgTypeVariants[Bool](m, "bool") - registerDefaultPgTypeVariants[Box](m, "box") - registerDefaultPgTypeVariants[Circle](m, "circle") - registerDefaultPgTypeVariants[Date](m, "date") - registerDefaultPgTypeVariants[Range[Date]](m, "daterange") - registerDefaultPgTypeVariants[Multirange[Range[Date]]](m, "datemultirange") - registerDefaultPgTypeVariants[Float4](m, "float4") - registerDefaultPgTypeVariants[Float8](m, "float8") - registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. - registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. - registerDefaultPgTypeVariants[Int2](m, "int2") - registerDefaultPgTypeVariants[Int4](m, "int4") - registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") - registerDefaultPgTypeVariants[Multirange[Range[Int4]]](m, "int4multirange") - registerDefaultPgTypeVariants[Int8](m, "int8") - registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") - registerDefaultPgTypeVariants[Multirange[Range[Int8]]](m, "int8multirange") - registerDefaultPgTypeVariants[Interval](m, "interval") - registerDefaultPgTypeVariants[Line](m, "line") - registerDefaultPgTypeVariants[Lseg](m, "lseg") - registerDefaultPgTypeVariants[Numeric](m, "numeric") - registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") - registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](m, "nummultirange") - registerDefaultPgTypeVariants[Path](m, "path") - registerDefaultPgTypeVariants[Point](m, "point") - registerDefaultPgTypeVariants[Polygon](m, "polygon") - registerDefaultPgTypeVariants[TID](m, "tid") - registerDefaultPgTypeVariants[Text](m, "text") - registerDefaultPgTypeVariants[Time](m, "time") - registerDefaultPgTypeVariants[Timestamp](m, "timestamp") - registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") - registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") - registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](m, "tsmultirange") - registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") - registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](m, "tstzmultirange") - registerDefaultPgTypeVariants[UUID](m, "uuid") - - return m } +// RegisterType registers a data type with the Map. t must not be mutated after it is registered. func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t @@ -449,13 +277,22 @@ func (m *Map) RegisterDefaultPgType(value any, name string) { } } +// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated. func (m *Map) TypeForOID(oid uint32) (*Type, bool) { - dt, ok := m.oidToType[oid] + if dt, ok := m.oidToType[oid]; ok { + return dt, true + } + + dt, ok := defaultMap.oidToType[oid] return dt, ok } +// TypeForName returns the Type registered for the given name. The returned Type must not be mutated. func (m *Map) TypeForName(name string) (*Type, bool) { - dt, ok := m.nameToType[name] + if dt, ok := m.nameToType[name]; ok { + return dt, true + } + dt, ok := defaultMap.nameToType[name] return dt, ok } @@ -463,30 +300,39 @@ func (m *Map) buildReflectTypeToType() { m.reflectTypeToType = make(map[reflect.Type]*Type) for reflectType, name := range m.reflectTypeToName { - if dt, ok := m.nameToType[name]; ok { + if dt, ok := m.TypeForName(name); ok { m.reflectTypeToType[reflectType] = dt } } } // TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode -// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. The returned Type +// must not be mutated. func (m *Map) TypeForValue(v any) (*Type, bool) { if m.reflectTypeToType == nil { m.buildReflectTypeToType() } - dt, ok := m.reflectTypeToType[reflect.TypeOf(v)] + if dt, ok := m.reflectTypeToType[reflect.TypeOf(v)]; ok { + return dt, true + } + + dt, ok := defaultMap.reflectTypeToType[reflect.TypeOf(v)] return dt, ok } // FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text // format code. func (m *Map) FormatCodeForOID(oid uint32) int16 { - fc, ok := m.oidToFormatCode[oid] - if ok { + if fc, ok := m.oidToFormatCode[oid]; ok { return fc } + + if fc, ok := defaultMap.oidToFormatCode[oid]; ok { + return fc + } + return TextFormatCode } @@ -587,6 +433,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { return plan.Scan(src, dst) } } + for oid := range defaultMap.oidToType { + if _, ok := plan.m.oidToType[oid]; !ok { + plan := plan.m.planScan(oid, plan.formatCode, dst) + if _, ok := plan.(*scanPlanFail); !ok { + return plan.Scan(src, dst) + } + } + } } var format string @@ -600,7 +454,7 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { } var dataTypeName string - if t, ok := plan.m.oidToType[plan.oid]; ok { + if t, ok := plan.m.TypeForOID(plan.oid); ok { dataTypeName = t.Name } else { dataTypeName = "unknown type" @@ -666,6 +520,7 @@ type SkipUnderlyingTypePlanner interface { reflect.Float32: reflect.TypeOf(new(float32)), reflect.Float64: reflect.TypeOf(new(float64)), reflect.String: reflect.TypeOf(new(string)), + reflect.Bool: reflect.TypeOf(new(bool)), } type underlyingTypeScanPlan struct { @@ -1089,15 +944,16 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true } - targetValue := reflect.ValueOf(target) - if targetValue.Kind() != reflect.Ptr { + targetType := reflect.TypeOf(target) + if targetType.Kind() != reflect.Ptr { return nil, nil, false } - targetElemValue := targetValue.Elem() + targetElemType := targetType.Elem() - if targetElemValue.Kind() == reflect.Slice { - return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true + if targetElemType.Kind() == reflect.Slice { + slice := reflect.New(targetElemType).Elem() + return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: slice}, true } return nil, nil, false } @@ -1198,6 +1054,10 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { } func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { + if target == nil { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + if _, ok := target.(*UndecodedBytes); ok { return scanPlanAnyToUndecodedBytes{} } @@ -1514,6 +1374,7 @@ func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, reflect.Float32: reflect.TypeOf(float32(0)), reflect.Float64: reflect.TypeOf(float64(0)), reflect.String: reflect.TypeOf(""), + reflect.Bool: reflect.TypeOf(false), } type underlyingTypeEncodePlan struct { @@ -2039,13 +1900,13 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) } var dataTypeName string - if t, ok := m.oidToType[oid]; ok { + if t, ok := m.TypeForOID(oid); ok { dataTypeName = t.Name } else { dataTypeName = "unknown type" } - return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %s", value, format, dataTypeName, oid, err) + return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %w", value, format, dataTypeName, oid, err) } // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go new file mode 100644 index 000000000..58f4b92c7 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -0,0 +1,223 @@ +package pgtype + +import ( + "net" + "net/netip" + "reflect" + "sync" + "time" +) + +var ( + // defaultMap contains default mappings between PostgreSQL server types and Go type handling logic. + defaultMap *Map + defaultMapInitOnce = sync.Once{} +) + +func initDefaultMap() { + defaultMap = &Map{ + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + + memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), + memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), + + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapBuiltinTypeEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, + TryWrapArrayEncodePlan, + }, + + TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ + TryPointerPointerScanPlan, + TryWrapBuiltinTypeScanPlan, + TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, + TryWrapPtrArrayScanPlan, + }, + } + + // Base types + defaultMap.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + defaultMap.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + defaultMap.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + defaultMap.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + defaultMap.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + defaultMap.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) + defaultMap.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + defaultMap.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + defaultMap.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) + defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + defaultMap.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) + defaultMap.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) + defaultMap.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + + // Range types + defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + + // Multirange types + defaultMap.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + + // Array types + defaultMap.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ACLItemOID]}}) + defaultMap.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BitOID]}}) + defaultMap.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoolOID]}}) + defaultMap.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoxOID]}}) + defaultMap.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BPCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ByteaOID]}}) + defaultMap.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[QCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDROID]}}) + defaultMap.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CircleOID]}}) + defaultMap.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float4OID]}}) + defaultMap.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float8OID]}}) + defaultMap.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[InetOID]}}) + defaultMap.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int2OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[IntervalOID]}}) + defaultMap.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONBOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONPathOID]}}) + defaultMap.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LineOID]}}) + defaultMap.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LsegOID]}}) + defaultMap.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[MacaddrOID]}}) + defaultMap.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NameOID]}}) + defaultMap.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[OIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PathOID]}}) + defaultMap.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PointOID]}}) + defaultMap.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PolygonOID]}}) + defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) + defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) + defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + defaultMap.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[UUIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) + defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants[int16](defaultMap, "int2") + registerDefaultPgTypeVariants[int32](defaultMap, "int4") + registerDefaultPgTypeVariants[int64](defaultMap, "int8") + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants[int8](defaultMap, "int8") + registerDefaultPgTypeVariants[int](defaultMap, "int8") + registerDefaultPgTypeVariants[uint8](defaultMap, "int8") + registerDefaultPgTypeVariants[uint16](defaultMap, "int8") + registerDefaultPgTypeVariants[uint32](defaultMap, "int8") + registerDefaultPgTypeVariants[uint64](defaultMap, "numeric") + registerDefaultPgTypeVariants[uint](defaultMap, "numeric") + + registerDefaultPgTypeVariants[float32](defaultMap, "float4") + registerDefaultPgTypeVariants[float64](defaultMap, "float8") + + registerDefaultPgTypeVariants[bool](defaultMap, "bool") + registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") + registerDefaultPgTypeVariants[string](defaultMap, "text") + registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") + + registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") + registerDefaultPgTypeVariants[net.IPNet](defaultMap, "cidr") + registerDefaultPgTypeVariants[netip.Addr](defaultMap, "inet") + registerDefaultPgTypeVariants[netip.Prefix](defaultMap, "cidr") + + // pgtype provided structs + registerDefaultPgTypeVariants[Bits](defaultMap, "varbit") + registerDefaultPgTypeVariants[Bool](defaultMap, "bool") + registerDefaultPgTypeVariants[Box](defaultMap, "box") + registerDefaultPgTypeVariants[Circle](defaultMap, "circle") + registerDefaultPgTypeVariants[Date](defaultMap, "date") + registerDefaultPgTypeVariants[Range[Date]](defaultMap, "daterange") + registerDefaultPgTypeVariants[Multirange[Range[Date]]](defaultMap, "datemultirange") + registerDefaultPgTypeVariants[Float4](defaultMap, "float4") + registerDefaultPgTypeVariants[Float8](defaultMap, "float8") + registerDefaultPgTypeVariants[Range[Float8]](defaultMap, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Multirange[Range[Float8]]](defaultMap, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. + registerDefaultPgTypeVariants[Int2](defaultMap, "int2") + registerDefaultPgTypeVariants[Int4](defaultMap, "int4") + registerDefaultPgTypeVariants[Range[Int4]](defaultMap, "int4range") + registerDefaultPgTypeVariants[Multirange[Range[Int4]]](defaultMap, "int4multirange") + registerDefaultPgTypeVariants[Int8](defaultMap, "int8") + registerDefaultPgTypeVariants[Range[Int8]](defaultMap, "int8range") + registerDefaultPgTypeVariants[Multirange[Range[Int8]]](defaultMap, "int8multirange") + registerDefaultPgTypeVariants[Interval](defaultMap, "interval") + registerDefaultPgTypeVariants[Line](defaultMap, "line") + registerDefaultPgTypeVariants[Lseg](defaultMap, "lseg") + registerDefaultPgTypeVariants[Numeric](defaultMap, "numeric") + registerDefaultPgTypeVariants[Range[Numeric]](defaultMap, "numrange") + registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](defaultMap, "nummultirange") + registerDefaultPgTypeVariants[Path](defaultMap, "path") + registerDefaultPgTypeVariants[Point](defaultMap, "point") + registerDefaultPgTypeVariants[Polygon](defaultMap, "polygon") + registerDefaultPgTypeVariants[TID](defaultMap, "tid") + registerDefaultPgTypeVariants[Text](defaultMap, "text") + registerDefaultPgTypeVariants[Time](defaultMap, "time") + registerDefaultPgTypeVariants[Timestamp](defaultMap, "timestamp") + registerDefaultPgTypeVariants[Timestamptz](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[Range[Timestamp]](defaultMap, "tsrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") + registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") + registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") + + defaultMap.buildReflectTypeToType() +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 9f3de2c59..35d739566 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -3,6 +3,7 @@ import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "strings" "time" @@ -66,6 +67,55 @@ func (ts Timestamp) Value() (driver.Value, error) { return ts.Time, nil } +func (ts Timestamp) MarshalJSON() ([]byte, error) { + if !ts.Valid { + return []byte("null"), nil + } + + var s string + + switch ts.InfinityModifier { + case Finite: + s = ts.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (ts *Timestamp) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *ts = Timestamp{} + return nil + } + + switch *s { + case "infinity": + *ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *ts = Timestamp{Time: tim, Valid: true} + } + + return nil +} + type TimestampCodec struct{} func (TimestampCodec) FormatSupported(format int16) bool { diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go index ffe739b02..cdd72a25f 100644 --- a/vendor/github.com/jackc/pgx/v5/rows.go +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -28,12 +28,16 @@ type Rows interface { // to call Close after rows is already closed. Close() - // Err returns any error that occurred while reading. + // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by + // calling Close or by Next returning false). If it is called early it may return nil even if there was an error + // executing the query. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. CommandTag() pgconn.CommandTag + // FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur + // when there was an error executing the query. FieldDescriptions() []pgconn.FieldDescription // Next prepares the next row for reading. It returns true if there is another @@ -533,13 +537,11 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val for i := 0; i < dstElemType.NumField(); i++ { sf := dstElemType.Field(i) - if sf.PkgPath == "" { - // Handle anonymous struct embedding, but do not try to handle embedded pointers. - if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) - } else { - scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) - } + // Handle anonymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) + } else if sf.PkgPath == "" { + scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) } } @@ -565,8 +567,28 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { return &value, err } +// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public +// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByNameLax[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + return value, err +} + +// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or +// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + return &value, err +} + type namedStructRowScanner struct { ptrToStruct any + lax bool } func (rs *namedStructRowScanner) ScanRow(rows Rows) error { @@ -578,7 +600,6 @@ func (rs *namedStructRowScanner) ScanRow(rows Rows) error { dstElemValue := dstValue.Elem() scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) - if err != nil { return err } @@ -638,7 +659,13 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s colName = sf.Name } fpos := fieldPosByName(fldDescs, colName) - if fpos == -1 || fpos >= len(scanTargets) { + if fpos == -1 { + if rs.lax { + continue + } + return nil, fmt.Errorf("cannot find field %s in returned row", colName) + } + if fpos >= len(scanTargets) && !rs.lax { return nil, fmt.Errorf("cannot find field %s in returned row", colName) } scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() diff --git a/vendor/github.com/jackc/pgx/v5/tx.go b/vendor/github.com/jackc/pgx/v5/tx.go index e57142a61..575c17a71 100644 --- a/vendor/github.com/jackc/pgx/v5/tx.go +++ b/vendor/github.com/jackc/pgx/v5/tx.go @@ -44,6 +44,10 @@ type TxOptions struct { IsoLevel TxIsoLevel AccessMode TxAccessMode DeferrableMode TxDeferrableMode + + // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax + // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. + BeginQuery string } var emptyTxOptions TxOptions @@ -53,6 +57,10 @@ func (txOptions TxOptions) beginSQL() string { return "begin" } + if txOptions.BeginQuery != "" { + return txOptions.BeginQuery + } + var buf strings.Builder buf.Grow(64) // 64 - maximum length of string with available options buf.WriteString("begin") diff --git a/vendor/modules.txt b/vendor/modules.txt index 64dfaaf59..411a11143 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -334,16 +334,16 @@ github.com/jackc/pgproto3/v2 # github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a ## explicit; go 1.14 github.com/jackc/pgservicefile -# github.com/jackc/pgx/v5 v5.3.1 +# github.com/jackc/pgx/v5 v5.4.1 ## explicit; go 1.19 github.com/jackc/pgx/v5 github.com/jackc/pgx/v5/internal/anynil github.com/jackc/pgx/v5/internal/iobufpool -github.com/jackc/pgx/v5/internal/nbconn github.com/jackc/pgx/v5/internal/pgio github.com/jackc/pgx/v5/internal/sanitize github.com/jackc/pgx/v5/internal/stmtcache github.com/jackc/pgx/v5/pgconn +github.com/jackc/pgx/v5/pgconn/internal/bgreader github.com/jackc/pgx/v5/pgconn/internal/ctxwatch github.com/jackc/pgx/v5/pgproto3 github.com/jackc/pgx/v5/pgtype