3
0
mirror of https://github.com/ergochat/ergo.git synced 2026-04-25 15:28:13 +02:00

Merge pull request #2392 from ergochat/dependabot/go_modules/github.com/jackc/pgx/v5-5.9.2

Bump github.com/jackc/pgx/v5 from 5.8.0 to 5.9.2
This commit is contained in:
Shivaram Lingamneni 2026-04-22 21:50:41 -07:00 committed by GitHub
commit 20dde44b62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
54 changed files with 2616 additions and 571 deletions

2
go.mod
View File

@ -27,7 +27,7 @@ require (
github.com/emersion/go-msgauth v0.7.0
github.com/ergochat/webpush-go/v2 v2.0.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/jackc/pgx/v5 v5.8.0
github.com/jackc/pgx/v5 v5.9.2
modernc.org/sqlite v1.42.2
)

4
go.sum
View File

@ -44,8 +44,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=

View File

@ -1,9 +1,14 @@
# See for configurations: https://golangci-lint.run/usage/configuration/
version: 2
version: "2"
linters:
default: none
enable:
- govet
- ineffassign
# See: https://golangci-lint.run/usage/formatters/
formatters:
default: none
enable:
- gofmt # https://pkg.go.dev/cmd/gofmt
- gofumpt # https://github.com/mvdan/gofumpt

View File

@ -1,3 +1,71 @@
# 5.9.2 (April 18, 2026)
Fix SQL Injection via placeholder confusion with dollar quoted string literals (GHSA-j88v-2chj-qfwx)
SQL injection can occur when:
1. The non-default simple protocol is used.
2. A dollar quoted string literal is used in the SQL query.
3. That query contains text that would be interpreted as a placeholder outside of a string literal.
4. The value of that placeholder is controllable by the attacker.
e.g.
```go
attackValue := `$tag$; drop table canary; --`
_, err = tx.Exec(ctx, `select $tag$ $1 $tag$, $1`, pgx.QueryExecModeSimpleProtocol, attackValue)
```
This is unlikely to occur outside of a contrived scenario.
# 5.9.1 (March 22, 2026)
* Fix: batch result format corruption when using cached prepared statements (reported by Dirkjan Bussink)
# 5.9.0 (March 21, 2026)
This release includes a number of new features such as SCRAM-SHA-256-PLUS support, OAuth authentication support, and
PostgreSQL protocol 3.2 support.
It significantly reduces the amount of network traffic when using prepared statements (which are used automatically by
default) by avoiding unnecessary Describe Portal messages. This also reduces local memory usage.
It also includes multiple fixes for potential DoS due to panic or OOM if connected to a malicious server that sends
deliberately malformed messages.
* Require Go 1.25+
* Add SCRAM-SHA-256-PLUS support (Adam Brightwell)
* Add OAuth authentication support for PostgreSQL 18 (David Schneider)
* Add PostgreSQL protocol 3.2 support (Dirkjan Bussink)
* Add tsvector type support (Adam Brightwell)
* Skip Describe Portal for cached prepared statements reducing network round trips
* Make LoadTypes query easier to support on "postgres-like" servers (Jelte Fennema-Nio)
* Default empty user to current OS user matching libpq behavior (ShivangSrivastava)
* Optimize LRU statement cache with custom linked list and node pooling (Mathias Bogaert)
* Optimize date scanning by replacing regex with manual parsing (Mathias Bogaert)
* Optimize pgio append/set functions with direct byte shifts (Mathias Bogaert)
* Make RowsAffected faster (Abhishek Chanda)
* Fix: Pipeline.Close panic when server sends multiple FATAL errors (Varun Chawla)
* Fix: ContextWatcher goroutine leak (Hank Donnay)
* Fix: stdlib discard connections with open transactions in ResetSession (Jeremy Schneider)
* Fix: pipelineBatchResults.Exec silently swallowing lastRows error
* Fix: ColumnTypeLength using BPCharArrayOID instead of BPCharOID
* Fix: TSVector text encoding returning nil for valid empty tsvector
* Fix: wrong error messages for Int2 and Int4 underflow
* Fix: Numeric nil Int pointer dereference with Valid: true
* Fix: reversed strings.ContainsAny arguments in Numeric.ScanScientific
* Fix: message length parsing on 32-bit platforms
* Fix: FunctionCallResponse.Decode mishandling of signed result size
* Fix: returning wrong error in configTLS when DecryptPEMBlock fails (Maxim Motyshen)
* Fix: misleading ParseConfig error when default_query_exec_mode is invalid (Skarm)
* Fix: missed Unwatch in Pipeline error paths
* Clarify too many failed acquire attempts error message
* Better error wrapping with context and SQL statement (Aneesh Makala)
* Enable govet and ineffassign linters (Federico Guerinoni)
* Guard against various malformed binary messages (arrays, hstore, multirange, protocol messages)
* Fix various godoc comments (ferhat elmas)
* Fix typos in comments (Oleksandr Redko)
# 5.8.0 (December 26, 2025)
* Require Go 1.24+

73
vendor/github.com/jackc/pgx/v5/CLAUDE.md generated vendored Normal file
View File

@ -0,0 +1,73 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
pgx is a PostgreSQL driver and toolkit for Go (`github.com/jackc/pgx/v5`). It provides both a native PostgreSQL interface and a `database/sql` compatible driver. Requires Go 1.25+ and supports PostgreSQL 14+ and CockroachDB.
## Build & Test Commands
```bash
# Run all tests (requires PGX_TEST_DATABASE to be set)
go test ./...
# Run a specific test
go test -run TestFunctionName ./...
# Run tests for a specific package
go test ./pgconn/...
# Run tests with race detector
go test -race ./...
# DevContainer: run tests against specific PostgreSQL versions
./test.sh pg18 # Default: PostgreSQL 18
./test.sh pg16 -run TestConnect # Specific test against PG16
./test.sh crdb # CockroachDB
./test.sh all # All targets (pg14-18 + crdb)
# Format (always run after making changes)
goimports -w .
# Lint
golangci-lint run ./...
```
## Test Database Setup
Tests require `PGX_TEST_DATABASE` environment variable. In the devcontainer, `test.sh` handles this. For local development:
```bash
export PGX_TEST_DATABASE="host=localhost user=postgres password=postgres dbname=pgx_test"
```
The test database needs extensions: `hstore`, `ltree`, and a `uint64` domain. See `testsetup/postgresql_setup.sql` for full setup. Many tests are skipped unless additional `PGX_TEST_*` env vars are set (for TLS, SCRAM, MD5, unix socket, PgBouncer testing).
## Architecture
The codebase is a layered architecture, bottom-up:
- **pgproto3/** — PostgreSQL wire protocol v3 encoder/decoder. Defines `FrontendMessage` and `BackendMessage` types for every protocol message.
- **pgconn/** — Low-level connection layer (roughly libpq-equivalent). Handles authentication, TLS, query execution, COPY protocol, and notifications. `PgConn` is the core type.
- **pgx** (root package) — High-level query interface built on `pgconn`. Provides `Conn`, `Rows`, `Tx`, `Batch`, `CopyFrom`, and generic helpers like `CollectRows`/`ForEachRow`. Includes automatic statement caching (LRU).
- **pgtype/** — Type system mapping between Go and PostgreSQL types (70+ types). Key interfaces: `Codec`, `Type`, `TypeMap`. Custom types (enums, composites, domains) are registered through `TypeMap`.
- **pgxpool/** — Concurrency-safe connection pool built on `puddle/v2`. `Pool` is the main type; wraps `pgx.Conn`.
- **stdlib/**`database/sql` compatibility adapter.
Supporting packages:
- **internal/stmtcache/** — Prepared statement cache with LRU eviction
- **internal/sanitize/** — SQL query sanitization
- **tracelog/** — Logging adapter that implements tracer interfaces
- **multitracer/** — Composes multiple tracers into one
- **pgxtest/** — Test helpers for running tests across connection types
## Key Design Conventions
- **Semantic versioning** — strictly followed. Do not break the public API (no removing or renaming exported types, functions, methods, or fields; no changing function signatures).
- **Minimal dependencies** — adding new dependencies is strongly discouraged (see CONTRIBUTING.md).
- **Context-based** — all blocking operations take `context.Context`.
- **Tracer interfaces** — observability via `QueryTracer`, `BatchTracer`, `CopyFromTracer`, `PrepareTracer` on `ConnConfig.Tracer`.
- **Formatting** — always run `goimports -w .` after making changes to ensure code is properly formatted. CI checks formatting via `gofmt -l -s -w . && git diff --exit-code`. `gofumpt` with extra rules is also enforced via `golangci-lint`.
- **Linters**`govet` and `ineffassign` only (configured in `.golangci.yml`).
- **CI matrix** — tests run against Go 1.25/1.26 × PostgreSQL 14-18 + CockroachDB, on Linux and Windows. Race detector enabled on Linux only.

View File

@ -10,6 +10,18 @@ proposal. This will help to ensure your proposed change has a reasonable chance
Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change
that adds a dependency is no.
## AI
Using AI is acceptable (not that it can really be stopped) under one of the following conditions.
* AI was used, but you deeply understand the code and you can answer questions regarding your change. You are not going
to answer questions with "I don't know", AI did it. You are not going to "answer" questions by relaying them to your
agent. This is wasteful of the code reviewer's time.
* AI was used to solve a problem without your deep understanding. This can still be a good starting point for a fix or
feature. But you need to clearly state that this is an AI proposal. You should include additional information such as
the AI used and what prompts were used. You should also be aware that large, complicated, or subtle changes may be
rejected simply because the reviewer is not confident in a change that no human understands.
## Development Environment Setup
pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE`
@ -17,7 +29,12 @@ environment variable. The `PGX_TEST_DATABASE` environment variable can either be
the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to
simplify environment variable handling.
### Using an Existing PostgreSQL Cluster
### Devcontainer
The easiest way to start development is with the included devcontainer. It includes containers for each supported
PostgreSQL version as well as CockroachDB. `./test.sh all` will run the tests against all database types.
### Using an Existing PostgreSQL Cluster Outside of a Devcontainer
If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx
test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different
@ -49,7 +66,7 @@ go test ./...
This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods).
### Creating a New PostgreSQL Cluster Exclusively for Testing
### Creating a New PostgreSQL Cluster Exclusively for Testing Outside of a Devcontainer
The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is
highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`.
@ -63,10 +80,11 @@ export POSTGRESQL_DATA_DIR=postgresql
export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test"
export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test"
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test channel_binding=disable"
export PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test channel_binding=require"
export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret"
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem"
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem channel_binding=disable"
export PGX_SSL_PASSWORD=certpw
export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key"
```

View File

@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.24 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.25 and higher and PostgreSQL 14 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
## Version Policy
@ -120,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
* [github.com/ColeBurch/pgx-govalues-decimal](https://github.com/ColeBurch/pgx-govalues-decimal)
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)

View File

@ -8,7 +8,7 @@ import (
"github.com/jackc/pgx/v5/pgconn"
)
// QueuedQuery is a query that has been queued for execution via a Batch.
// QueuedQuery is a query that has been queued for execution via a [Batch].
type QueuedQuery struct {
SQL string
Arguments []any
@ -46,7 +46,7 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
//
// Note: for simple batch insert uses where it is not required to handle
// each potential error individually, it's sufficient to not set any callbacks,
// and just handle the return value of BatchResults.Close.
// and just handle the return value of [BatchResults.Close].
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
qq.Fn = func(br BatchResults) error {
ct, err := br.Exec()
@ -65,12 +65,13 @@ type Batch struct {
}
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option
// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode.
// argument that is supported is [QueryRewriter]. Queries are executed using the connection's DefaultQueryExecMode
// (see [ConnConfig.DefaultQueryExecMode]).
//
// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should
// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query,
// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that
// include the current query may reference the wrong query.
// While query can contain multiple statements if the connection's DefaultQueryExecMode is [QueryExecModeSimpleProtocol],
// this should be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is,
// [QueuedQuery.Query], [QueuedQuery.QueryRow], and [QueuedQuery.Exec] must not be called. In addition, any error
// messages or tracing that include the current query may reference the wrong query.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{
SQL: query,
@ -86,20 +87,20 @@ func (b *Batch) Len() int {
}
type BatchResults interface {
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
// Exec reads the results from the next query in the batch as if the query has been sent with [Conn.Exec]. Prefer
// calling Exec on the QueuedQuery, or just calling Close.
Exec() (pgconn.CommandTag, error)
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
// calling Query on the QueuedQuery.
// Query reads the results from the next query in the batch as if the query has been sent with [Conn.Query]. Prefer
// calling [QueuedQuery.Query].
Query() (Rows, error)
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
// Prefer calling QueryRow on the QueuedQuery.
// QueryRow reads the results from the next query in the batch as if the query has been sent with [Conn.QueryRow].
// Prefer calling [QueuedQuery.QueryRow].
QueryRow() Row
// Close closes the batch operation. All unread results are read and any callback functions registered with
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
// [QueuedQuery.Query], [QueuedQuery.QueryRow], or [QueuedQuery.Exec] will be called. If a callback function returns an
// error or the batch encounters an error subsequent callback functions will not be called.
//
// For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks,
@ -272,7 +273,7 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
ok = true
br.qqIdx++
}
return
return query, args, ok
}
type pipelineBatchResults struct {
@ -296,6 +297,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
if br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return pgconn.CommandTag{}, br.err
}
@ -505,3 +507,31 @@ func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) {
}
}
}
// ErrPreprocessingBatch occurs when an error is encountered while preprocessing a batch.
// Preprocessing happens in two steps: "prepare" (server-side SQL parse/plan) and
// "build" (client-side argument encoding); step records which one failed.
type ErrPreprocessingBatch struct {
	step string // either "prepare" or "build"
	sql  string // the SQL of the query that failed preprocessing
	err  error  // the underlying cause
}

// newErrPreprocessingBatch constructs an ErrPreprocessingBatch for the given step,
// query text, and underlying error.
func newErrPreprocessingBatch(step, sql string, err error) ErrPreprocessingBatch {
	e := ErrPreprocessingBatch{}
	e.step = step
	e.sql = sql
	e.err = err
	return e
}

// Error implements the error interface. The SQL query is deliberately omitted from
// the message so potentially sensitive data does not leak into logs; callers that
// need the query text can retrieve it via SQL().
func (e ErrPreprocessingBatch) Error() string {
	return fmt.Sprintf("error preprocessing batch (%s): %v", e.step, e.err)
}

// Unwrap returns the underlying error, enabling errors.Is / errors.As matching.
func (e ErrPreprocessingBatch) Unwrap() error {
	return e.err
}

// SQL returns the text of the query that failed preprocessing.
func (e ErrPreprocessingBatch) SQL() string {
	return e.sql
}

View File

@ -17,8 +17,8 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and
// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic.
// ConnConfig contains all the options used to establish a connection. It must be created by [ParseConfig] and
// then it can be modified. A manually initialized ConnConfig will cause [ConnectConfig] to panic.
type ConnConfig struct {
pgconn.Config
@ -37,8 +37,8 @@ type ConnConfig struct {
// DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol
// and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as
// PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
// functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument.
// PGBouncer. In this case it may be preferable to use [QueryExecModeExec] or [QueryExecModeSimpleProtocol]. The same
// functionality can be controlled on a per query basis by passing a [QueryExecMode] as the first query argument.
DefaultQueryExecMode QueryExecMode
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
@ -131,7 +131,7 @@ var (
)
// Connect establishes a connection with a PostgreSQL server with a connection string. See
// pgconn.Connect for details.
// [pgconn.Connect] for details.
func Connect(ctx context.Context, connString string) (*Conn, error) {
connConfig, err := ParseConfig(connString)
if err != nil {
@ -141,7 +141,7 @@ func Connect(ctx context.Context, connString string) (*Conn, error) {
}
// ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to
// provide a GetSSLPassword function.
// provide a [pgconn.GetSSLPasswordFunc] function.
func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
connConfig, err := ParseConfigWithOptions(connString, options)
if err != nil {
@ -151,7 +151,7 @@ func ConnectWithOptions(ctx context.Context, connString string, options ParseCon
}
// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
// connConfig must have been created by ParseConfig.
// connConfig must have been created by [ParseConfig].
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
// In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other
// connections with the same config. See https://github.com/jackc/pgx/issues/618.
@ -160,8 +160,8 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
return connect(ctx, connConfig)
}
// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is
// only used to provide a GetSSLPassword function.
// ParseConfigWithOptions behaves exactly as [ParseConfig] does with the addition of options. At the present options is
// only used to provide a [pgconn.GetSSLPasswordFunc] function.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
if err != nil {
@ -203,7 +203,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "simple_protocol":
defaultQueryExecMode = QueryExecModeSimpleProtocol
default:
return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err)
return nil, pgconn.NewParseConfigError(
connString, "invalid default_query_exec_mode", fmt.Errorf("unknown value %q", s),
)
}
}
@ -306,8 +308,8 @@ func (c *Conn) Close(ctx context.Context) error {
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These
// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and
// Exec to execute the statement. It can also be used with Batch.Queue.
// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with [Conn.Query],
// [Conn.QueryRow], and [Conn.Exec] to execute the statement. It can also be used with [Batch.Queue].
//
// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if
// name == sql.
@ -608,7 +610,7 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
return pgconn.CommandTag{}, err
}
result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
result := c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
return result.CommandTag, result.Err
}
@ -842,7 +844,7 @@ optionLoop:
if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe {
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats)
} else {
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats)
rows.resultReader = c.pgConn.ExecStatement(ctx, sd, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats)
}
} else if mode == QueryExecModeExec {
err := c.eqb.Build(c.typeMap, nil, args)
@ -931,7 +933,7 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
}
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// explicit transaction control statements are executed. The returned [BatchResults] must be closed before the connection
// is used again.
//
// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table
@ -1192,7 +1194,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return err
return newErrPreprocessingBatch("prepare", sd.SQL, err)
}
resultSD, ok := results.(*pgconn.StatementDescription)
@ -1226,15 +1228,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
for _, bi := range b.QueuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
if err != nil {
// we wrap the error so we the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
err = newErrPreprocessingBatch("build", bi.SQL, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
if bi.sd.Name == "" {
pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
// Copy ResultFormats because SendQueryStatement stores the slice for later use, and eqb.Build reuses the
// backing array on the next iteration.
resultFormats := make([]int16, len(c.eqb.ResultFormats))
copy(resultFormats, c.eqb.ResultFormats)
pipeline.SendQueryStatement(bi.sd, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats)
}
}
@ -1272,7 +1277,7 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
return sanitize.SanitizeSQL(sql, valueArgs...)
}
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
// LoadType inspects the database for typeName and produces a [pgtype.Type] suitable for registration. typeName must be
// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
// typeName must be one of the following:
// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.

View File

@ -10,8 +10,8 @@ import (
"github.com/jackc/pgx/v5/pgconn"
)
// CopyFromRows returns a [CopyFromSource] interface over the provided rows slice
// making it usable by [Conn.CopyFrom].
func CopyFromRows(rows [][]any) CopyFromSource {
	src := &copyFromRows{
		rows: rows,
		idx:  -1, // start before the first row; Next advances to index 0
	}
	return src
}
@ -34,8 +34,8 @@ func (ctr *copyFromRows) Err() error {
return nil
}
// CopyFromSlice returns a [CopyFromSource] interface over a dynamic func
// making it usable by [Conn.CopyFrom].
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
	src := &copyFromSlice{
		len:  length,
		next: next,
		idx:  -1, // start before the first element; Next advances to index 0
	}
	return src
}
@ -64,7 +64,7 @@ func (cts *copyFromSlice) Err() error {
return cts.err
}
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// CopyFromFunc returns a [CopyFromSource] interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
@ -91,7 +91,7 @@ func (g *copyFromFunc) Err() error {
return g.err
}
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
// CopyFromSource is the interface used by [Conn.CopyFrom] as the source for copy data.
type CopyFromSource interface {
// Next returns true if there is another row and makes the next row data
// available to Values(). When there are no more rows available or an error
@ -260,8 +260,8 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
// for the type of each column. Almost all types implemented by pgx support the binary format.
//
// Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with
// Conn.LoadType and pgtype.Map.RegisterType.
// Even though enum types appear to be strings they still must be registered to use with [Conn.CopyFrom]. This can be done with
// [Conn.LoadType] and [pgtype.Map.RegisterType].
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
ct := &copyFrom{
conn: c,

View File

@ -24,7 +24,7 @@ func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
// This should not occur; this will not return any types
typeNamesClause = "= ''"
} else {
typeNamesClause = "= ANY($1)"
typeNamesClause = "= ANY($1::text[])"
}
parts := make([]string, 0, 10)
@ -169,7 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
// the SQL not support recent structures such as multirange
serverVersion, _ := serverVersion(c)
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
rows, err := c.Query(ctx, sql, QueryResultFormats{TextFormatCode}, typeNames)
if err != nil {
return nil, fmt.Errorf("While generating load types query: %w", err)
}
@ -227,7 +227,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
}
// the type_ is imposible to be null
// the type_ cannot be null
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}

View File

@ -1,8 +1,8 @@
// Package pgx is a PostgreSQL database driver.
/*
pgx provides a native PostgreSQL driver and can act as a database/sql driver. The native PostgreSQL interface is similar
to the database/sql interface while providing better speed and access to PostgreSQL specific features. Use
github.com/jackc/pgx/v5/stdlib to use pgx as a database/sql compatible driver. See that package's documentation for
pgx provides a native PostgreSQL driver and can act as a [database/sql/driver]. The native PostgreSQL interface is similar
to the [database/sql] interface while providing better speed and access to PostgreSQL specific features. Use
[github.com/jackc/pgx/v5/stdlib] to use pgx as a database/sql compatible driver. See that package's documentation for
details.
Establishing a Connection
@ -19,15 +19,15 @@ string.
Connection Pool
[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
[github.com/jackc/pgx/v5/pgxpool] for a concurrency safe connection pool.
Query Interface
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
rows.Scan, and rows.Err().
pgx implements [Conn.Query] in the familiar database/sql style. However, pgx provides generic functions such as [CollectRows] and
[ForEachRow] that are a simpler and safer way of processing rows than manually calling defer [Rows.Close], [Rows.Next],
[Rows.Scan], and [Rows.Err].
CollectRows can be used collect all returned rows into a slice.
[CollectRows] can be used to collect all returned rows into a slice.
rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 5)
numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
@ -36,7 +36,7 @@ CollectRows can be used collect all returned rows into a slice.
}
// numbers => [1 2 3 4 5]
ForEachRow can be used to execute a callback function for every row. This is often easier than iterating over rows
[ForEachRow] can be used to execute a callback function for every row. This is often easier than iterating over rows
directly.
var sum, n int32
@ -49,7 +49,7 @@ directly.
return err
}
pgx also implements QueryRow in the same style as database/sql.
pgx also implements [Conn.QueryRow] in the same style as database/sql.
var name string
var weight int64
@ -58,7 +58,7 @@ pgx also implements QueryRow in the same style as database/sql.
return err
}
Use Exec to execute a query that does not return a result set.
Use [Conn.Exec] to execute a query that does not return a result set.
commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42)
if err != nil {
@ -70,13 +70,13 @@ Use Exec to execute a query that does not return a result set.
PostgreSQL Data Types
pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types
pgx uses the [pgtype] package to convert Go values to and from PostgreSQL values. It supports many PostgreSQL types
directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
require type registration. See that package's documentation for details.
Transactions
Transactions are started by calling Begin.
Transactions are started by calling [Conn.Begin].
tx, err := conn.Begin(context.Background())
if err != nil {
@ -96,13 +96,13 @@ Transactions are started by calling Begin.
return err
}
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
The [Tx] returned from [Conn.Begin] also implements the [Tx.Begin] method. This can be used to implement pseudo nested transactions.
These are internally implemented with savepoints.
Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
Use [Conn.BeginTx] to control the transaction mode. [Conn.BeginTx] also can be used to ensure a new transaction is created instead of
a pseudo nested transaction.
BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
[BeginFunc] and [BeginTxFunc] are functions that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use.
err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
@ -115,16 +115,16 @@ transaction depending on the return value of the function. These can be simpler
Prepared Statements
Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx
includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are
automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig
for information on how to customize or disable the statement cache.
Prepared statements can be manually created with the [Conn.Prepare] method. However, this is rarely necessary because pgx
includes an automatic statement cache by default. Queries run through the normal [Conn.Query], [Conn.QueryRow], and [Conn.Exec]
functions are automatically prepared on first execution and the prepared statement is reused on subsequent executions.
See [ParseConfig] for information on how to customize or disable the statement cache.
Copy Protocol
Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a
CopyFromSource interface. If the data is already in a [][]any use CopyFromRows to wrap it in a CopyFromSource interface.
Or implement CopyFromSource to avoid buffering the entire data set in memory.
Use [Conn.CopyFrom] to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. [Conn.CopyFrom] accepts a
[CopyFromSource] interface. If the data is already in a [][]any use [CopyFromRows] to wrap it in a [CopyFromSource] interface.
Or implement [CopyFromSource] to avoid buffering the entire data set in memory.
rows := [][]any{
{"John", "Smith", int32(36)},
@ -138,7 +138,7 @@ Or implement CopyFromSource to avoid buffering the entire data set in memory.
pgx.CopyFromRows(rows),
)
When you already have a typed array using CopyFromSlice can be more convenient.
When you already have a typed array, using [CopyFromSlice] can be more convenient.
rows := []User{
{"John", "Smith", 36},
@ -158,7 +158,7 @@ CopyFrom can be faster than an insert with as few as 5 rows.
Listen and Notify
pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a
pgx can listen to the PostgreSQL notification system with the [Conn.WaitForNotification] method. It blocks until a
notification is received or the context is canceled.
_, err := conn.Exec(context.Background(), "listen channelname")
@ -175,20 +175,25 @@ notification is received or the context is canceled.
Tracing and Logging
pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer.
pgx supports tracing by setting [ConnConfig.Tracer]. To combine several tracers you can use the [github.com/jackc/pgx/v5/multitracer.Tracer].
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
In addition, the [github.com/jackc/pgx/v5/tracelog] package provides the [github.com/jackc/pgx/v5/tracelog.TraceLog] type which lets a
traditional logger act as a [QueryTracer].
For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3.
For debug tracing of the actual PostgreSQL wire protocol messages see [github.com/jackc/pgx/v5/pgproto3].
Lower Level PostgreSQL Functionality
github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.
[github.com/jackc/pgx/v5/pgconn] contains a lower level PostgreSQL driver roughly at the level of libpq. [Conn] is
implemented on top of [pgconn.PgConn]. The [Conn.PgConn] method can be used to access this lower layer.
PgBouncer
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
disabled by setting a different [QueryExecMode] in [ConnConfig.DefaultQueryExecMode].
*/
package pgx
import (
_ "github.com/jackc/pgx/v5/pgconn" // Just for allowing godoc to resolve "pgconn"
)

View File

@ -1,26 +1,18 @@
package pgio
import "encoding/binary"
func AppendUint16(buf []byte, n uint16) []byte {
wp := len(buf)
buf = append(buf, 0, 0)
binary.BigEndian.PutUint16(buf[wp:], n)
return buf
return append(buf, byte(n>>8), byte(n))
}
func AppendUint32(buf []byte, n uint32) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0)
binary.BigEndian.PutUint32(buf[wp:], n)
return buf
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
}
func AppendUint64(buf []byte, n uint64) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
binary.BigEndian.PutUint64(buf[wp:], n)
return buf
return append(buf,
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32),
byte(n>>24), byte(n>>16), byte(n>>8), byte(n),
)
}
func AppendInt16(buf []byte, n int16) []byte {
@ -36,5 +28,5 @@ func AppendInt64(buf []byte, n int64) []byte {
}
func SetInt32(buf []byte, n int32) {
binary.BigEndian.PutUint32(buf, uint32(n))
*(*[4]byte)(buf) = [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
}

View File

@ -42,7 +42,7 @@ for i in "${!commits[@]}"; do
exit 1
}
# Sanitized commmit message
# Sanitized commit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there

View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/hex"
"fmt"
"math"
"slices"
"strconv"
"strings"
@ -202,12 +203,13 @@ func QuoteBytes(dst, buf []byte) []byte {
}
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []Part
src string
start int
pos int
nested int // multiline comment nesting level.
dollarTag string // active tag while inside a dollar-quoted string (may be empty for $$).
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
@ -237,6 +239,15 @@ func rawState(l *sqlLexer) stateFn {
l.start = l.pos
return placeholderState
}
// PostgreSQL dollar-quoted string: $[tag]$...$[tag]$. The $ was
// just consumed; try to match the rest of the opening tag.
// Without this, placeholders embedded inside dollar-quoted
// literals would be incorrectly substituted.
if tagLen, ok := scanDollarQuoteTag(l.src[l.pos:]); ok {
l.dollarTag = l.src[l.pos : l.pos+tagLen]
l.pos += tagLen + 1 // advance past tag and closing '$'
return dollarQuoteState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
@ -319,8 +330,16 @@ func placeholderState(l *sqlLexer) stateFn {
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
// Clamp rather than silently wrap on pathological input like
// "$92233720368547758070" which would otherwise overflow int and
// could land on a valid args index. Any value above MaxInt32 far
// exceeds any plausible args length, so Sanitize will correctly
// return "insufficient arguments".
if num > (math.MaxInt32-9)/10 {
num = math.MaxInt32
} else {
num = num*10 + int(r-'0')
}
} else {
l.parts = append(l.parts, num)
l.pos -= width
@ -330,6 +349,68 @@ func placeholderState(l *sqlLexer) stateFn {
}
}
// dollarQuoteState consumes the body of a PostgreSQL dollar-quoted string
// ($[tag]$...$[tag]$). The opening tag (including its terminating '$') has
// already been consumed.
func dollarQuoteState(l *sqlLexer) stateFn {
closer := "$" + l.dollarTag + "$"
idx := strings.Index(l.src[l.pos:], closer)
if idx < 0 {
// Unterminated — mirror the behavior of other quoted-string states by
// consuming the remaining input into the current part and stopping.
if len(l.src)-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:])
l.start = len(l.src)
}
l.pos = len(l.src)
return nil
}
l.pos += idx + len(closer)
l.dollarTag = ""
return rawState
}
// scanDollarQuoteTag checks whether src begins with an optional dollar-quoted
// string tag followed by a closing '$'. src must point just past the opening
// '$'. Returns the byte length of the tag (zero for an anonymous $$) and
// whether a valid tag was found.
//
// Tag grammar matches the PostgreSQL lexer (scan.l):
//
// dolq_start: [A-Za-z_\x80-\xff]
// dolq_cont: [A-Za-z0-9_\x80-\xff]
func scanDollarQuoteTag(src string) (int, bool) {
first := true
for i := 0; i < len(src); {
r, w := utf8.DecodeRuneInString(src[i:])
if r == '$' {
return i, true
}
if !isDollarTagRune(r, first) {
return 0, false
}
first = false
i += w
}
return 0, false
}
func isDollarTagRune(r rune, first bool) bool {
switch {
case r == '_':
return true
case 'a' <= r && r <= 'z':
return true
case 'A' <= r && r <= 'Z':
return true
case !first && '0' <= r && r <= '9':
return true
case r >= 0x80 && r != utf8.RuneError:
return true
}
return false
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])

View File

@ -1,38 +1,54 @@
package stmtcache
import (
"container/list"
"github.com/jackc/pgx/v5/pgconn"
)
// lruNode is a typed doubly-linked list node with freelist support.
type lruNode struct {
sd *pgconn.StatementDescription
prev *lruNode
next *lruNode
}
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
type LRUCache struct {
cap int
m map[string]*list.Element
l *list.List
m map[string]*lruNode
head *lruNode
tail *lruNode
len int
cap int
freelist *lruNode
invalidStmts []*pgconn.StatementDescription
invalidSet map[string]struct{}
}
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
func NewLRUCache(cap int) *LRUCache {
head := &lruNode{}
tail := &lruNode{}
head.next = tail
tail.prev = head
return &LRUCache{
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
m: make(map[string]*lruNode, cap),
head: head,
tail: tail,
invalidSet: make(map[string]struct{}),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
if el, ok := c.m[key]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription)
node, ok := c.m[key]
if !ok {
return nil
}
return nil
c.moveToFront(node)
return node.sd
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
@ -51,37 +67,45 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
return
}
if c.l.Len() == c.cap {
if c.len == c.cap {
c.invalidateOldest()
}
el := c.l.PushFront(sd)
c.m[sd.SQL] = el
node := c.allocNode()
node.sd = sd
c.insertAfter(c.head, node)
c.m[sd.SQL] = node
c.len++
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *LRUCache) Invalidate(sql string) {
if el, ok := c.m[sql]; ok {
delete(c.m, sql)
sd := el.Value.(*pgconn.StatementDescription)
c.invalidStmts = append(c.invalidStmts, sd)
c.invalidSet[sql] = struct{}{}
c.l.Remove(el)
node, ok := c.m[sql]
if !ok {
return
}
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, node.sd)
c.invalidSet[sql] = struct{}{}
c.unlink(node)
c.len--
c.freeNode(node)
}
// InvalidateAll invalidates all statement descriptions.
func (c *LRUCache) InvalidateAll() {
el := c.l.Front()
for el != nil {
sd := el.Value.(*pgconn.StatementDescription)
c.invalidStmts = append(c.invalidStmts, sd)
c.invalidSet[sd.SQL] = struct{}{}
el = el.Next()
for node := c.head.next; node != c.tail; {
next := node.next
c.invalidStmts = append(c.invalidStmts, node.sd)
c.invalidSet[node.sd.SQL] = struct{}{}
c.freeNode(node)
node = next
}
c.m = make(map[string]*list.Element)
c.l = list.New()
clear(c.m)
c.head.next = c.tail
c.tail.prev = c.head
c.len = 0
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
@ -93,13 +117,13 @@ func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = nil
c.invalidSet = make(map[string]struct{})
c.invalidStmts = c.invalidStmts[:0]
clear(c.invalidSet)
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRUCache) Len() int {
return c.l.Len()
return c.len
}
// Cap returns the maximum number of cached prepared statement descriptions.
@ -108,10 +132,56 @@ func (c *LRUCache) Cap() int {
}
func (c *LRUCache) invalidateOldest() {
oldest := c.l.Back()
sd := oldest.Value.(*pgconn.StatementDescription)
c.invalidStmts = append(c.invalidStmts, sd)
c.invalidSet[sd.SQL] = struct{}{}
delete(c.m, sd.SQL)
c.l.Remove(oldest)
node := c.tail.prev
if node == c.head {
return
}
c.invalidStmts = append(c.invalidStmts, node.sd)
c.invalidSet[node.sd.SQL] = struct{}{}
delete(c.m, node.sd.SQL)
c.unlink(node)
c.len--
c.freeNode(node)
}
// List operations - sentinel nodes eliminate nil checks
func (c *LRUCache) insertAfter(at, node *lruNode) {
node.prev = at
node.next = at.next
at.next.prev = node
at.next = node
}
func (c *LRUCache) unlink(node *lruNode) {
node.prev.next = node.next
node.next.prev = node.prev
}
func (c *LRUCache) moveToFront(node *lruNode) {
if node.prev == c.head {
return
}
c.unlink(node)
c.insertAfter(c.head, node)
}
// Node pool operations - reuse evicted nodes to avoid allocations
func (c *LRUCache) allocNode() *lruNode {
if c.freelist != nil {
node := c.freelist
c.freelist = node.next
node.next = nil
node.prev = nil
return node
}
return &lruNode{}
}
func (c *LRUCache) freeNode(node *lruNode) {
node.sd = nil
node.prev = nil
node.next = c.freelist
c.freelist = node
}

67
vendor/github.com/jackc/pgx/v5/pgconn/auth_oauth.go generated vendored Normal file
View File

@ -0,0 +1,67 @@
package pgconn
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/jackc/pgx/v5/pgproto3"
)
func (c *PgConn) oauthAuth(ctx context.Context) error {
if c.config.OAuthTokenProvider == nil {
return errors.New("OAuth authentication required but no token provider configured")
}
token, err := c.config.OAuthTokenProvider(ctx)
if err != nil {
return fmt.Errorf("failed to obtain OAuth token: %w", err)
}
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1
initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01")
saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "OAUTHBEARER",
Data: initialResponse,
}
c.frontend.Send(saslInitialResponse)
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}
msg, err := c.receiveMessage()
if err != nil {
return err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationOk:
return nil
case *pgproto3.AuthenticationSASLContinue:
// Server sent error response in SASL continue
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3
errResponse := struct {
Status string `json:"status"`
Scope string `json:"scope"`
OpenIDConfiguration string `json:"openid-configuration"`
}{}
err := json.Unmarshal(m.Data, &errResponse)
if err != nil {
return fmt.Errorf("invalid OAuth error response from server: %w", err)
}
// Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01.
// However, since the connection will be closed anyway, we can skip this
return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status)
case *pgproto3.ErrorResponse:
return ErrorResponseToPgError(m)
default:
return fmt.Errorf("unexpected message type during OAuth auth: %T", msg)
}
}

View File

@ -1,7 +1,8 @@
// SCRAM-SHA-256 authentication
// SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication
//
// Resources:
// https://tools.ietf.org/html/rfc5802
// https://tools.ietf.org/html/rfc5929
// https://tools.ietf.org/html/rfc8265
// https://www.postgresql.org/docs/current/sasl-authentication.html
//
@ -18,9 +19,13 @@ import (
"crypto/pbkdf2"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"hash"
"slices"
"strconv"
@ -28,7 +33,11 @@ import (
"golang.org/x/text/secure/precis"
)
const clientNonceLen = 18
const (
clientNonceLen = 18
scramSHA256Name = "SCRAM-SHA-256"
scramSHA256PlusName = "SCRAM-SHA-256-PLUS"
)
// Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
@ -37,9 +46,35 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
return err
}
serverHasPlus := slices.Contains(sc.serverAuthMechanisms, scramSHA256PlusName)
if c.config.ChannelBinding == "require" && !serverHasPlus {
return errors.New("channel binding required but server does not support SCRAM-SHA-256-PLUS")
}
// If we have a TLS connection and channel binding is not disabled, attempt to
// extract the server certificate hash for tls-server-end-point channel binding.
if tlsConn, ok := c.conn.(*tls.Conn); ok && c.config.ChannelBinding != "disable" {
certHash, err := getTLSCertificateHash(tlsConn)
if err != nil && c.config.ChannelBinding == "require" {
return fmt.Errorf("channel binding required but failed to get server certificate hash: %w", err)
}
// Upgrade to SCRAM-SHA-256-PLUS if we have binding data and the server supports it.
if certHash != nil && serverHasPlus {
sc.authMechanism = scramSHA256PlusName
}
sc.channelBindingData = certHash
sc.hasTLS = true
}
if c.config.ChannelBinding == "require" && sc.channelBindingData == nil {
return errors.New("channel binding required but channel binding data is not available")
}
// Send client-first-message in a SASLInitialResponse
saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "SCRAM-SHA-256",
AuthMechanism: sc.authMechanism,
Data: sc.clientFirstMessage(),
}
c.frontend.Send(saslInitialResponse)
@ -111,7 +146,28 @@ type scramClient struct {
password string
clientNonce []byte
// authMechanism is the selected SASL mechanism for the client. Must be
// either SCRAM-SHA-256 (default) or SCRAM-SHA-256-PLUS.
//
// Upgraded to SCRAM-SHA-256-PLUS during authentication when channel binding
// is not disabled, channel binding data is available (TLS connection with
// an obtainable server certificate hash) and the server advertises
// SCRAM-SHA-256-PLUS.
authMechanism string
// hasTLS indicates whether the connection is using TLS. This is
// needed because the GS2 header must distinguish between a client that
// supports channel binding but the server does not ("y,,") versus one
// that does not support it at all ("n,,").
hasTLS bool
// channelBindingData is the hash of the server's TLS certificate, computed
// per the tls-server-end-point channel binding type (RFC 5929). Used as
// the binding input in SCRAM-SHA-256-PLUS. nil when not in use.
channelBindingData []byte
clientFirstMessageBare []byte
clientGS2Header []byte
serverFirstMessage []byte
clientAndServerNonce []byte
@ -125,11 +181,14 @@ type scramClient struct {
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
sc := &scramClient{
serverAuthMechanisms: serverAuthMechanisms,
authMechanism: scramSHA256Name,
}
// Ensure server supports SCRAM-SHA-256
hasScramSHA256 := slices.Contains(sc.serverAuthMechanisms, "SCRAM-SHA-256")
if !hasScramSHA256 {
// Ensure the server supports SCRAM-SHA-256. SCRAM-SHA-256-PLUS is the
// channel binding variant and is only advertised when the server supports
// SSL. PostgreSQL always advertises the base SCRAM-SHA-256 mechanism
// regardless of SSL.
if !slices.Contains(sc.serverAuthMechanisms, scramSHA256Name) {
return nil, errors.New("server does not support SCRAM-SHA-256")
}
@ -153,8 +212,32 @@ func newScramClient(serverAuthMechanisms []string, password string) (*scramClien
}
func (sc *scramClient) clientFirstMessage() []byte {
// The client-first-message is the GS2 header concatenated with the bare
// message (username + client nonce). The GS2 header communicates the
// client's channel binding capability to the server:
//
// "n,," - client is not using TLS (channel binding not possible)
// "y,," - client is using TLS but channel binding is not
// in use (e.g., server did not advertise SCRAM-SHA-256-PLUS
// or the server certificate hash was not obtainable)
// "p=tls-server-end-point,," - channel binding is active via SCRAM-SHA-256-PLUS
//
// See:
// https://www.rfc-editor.org/rfc/rfc5802#section-6
// https://www.rfc-editor.org/rfc/rfc5929#section-4
// https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256
sc.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", sc.clientNonce)
return fmt.Appendf(nil, "n,,%s", sc.clientFirstMessageBare)
if sc.authMechanism == scramSHA256PlusName {
sc.clientGS2Header = []byte("p=tls-server-end-point,,")
} else if sc.hasTLS {
sc.clientGS2Header = []byte("y,,")
} else {
sc.clientGS2Header = []byte("n,,")
}
return append(sc.clientGS2Header, sc.clientFirstMessageBare...)
}
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
@ -213,7 +296,19 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
}
func (sc *scramClient) clientFinalMessage() string {
clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=biws,r=%s", sc.clientAndServerNonce)
// The c= attribute carries the base64-encoded channel binding input.
//
// Without channel binding this is just the GS2 header alone ("biws" for
// "n,," or "eSws" for "y,,").
//
// With channel binding, this is the GS2 header with the channel binding data
// (certificate hash) appended.
channelBindInput := sc.clientGS2Header
if sc.authMechanism == scramSHA256PlusName {
channelBindInput = slices.Concat(sc.clientGS2Header, sc.channelBindingData)
}
channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput)
clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce)
var err error
sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32)
@ -269,3 +364,36 @@ func computeServerSignature(saltedPassword, authMessage []byte) []byte {
base64.StdEncoding.Encode(buf, serverSignature)
return buf
}
// Get the server certificate hash for SCRAM channel binding type
// tls-server-end-point.
func getTLSCertificateHash(conn *tls.Conn) ([]byte, error) {
state := conn.ConnectionState()
if len(state.PeerCertificates) == 0 {
return nil, errors.New("no peer certificates for channel binding")
}
cert := state.PeerCertificates[0]
// Per RFC 5929 section 4.1: If the certificate's signatureAlgorithm uses
// MD5 or SHA-1, use SHA-256. Otherwise use the hash from the signature
// algorithm.
//
// See: https://www.rfc-editor.org/rfc/rfc5929.html#section-4.1
var h hash.Hash
switch cert.SignatureAlgorithm {
case x509.MD5WithRSA, x509.SHA1WithRSA, x509.ECDSAWithSHA1:
h = sha256.New()
case x509.SHA256WithRSA, x509.SHA256WithRSAPSS, x509.ECDSAWithSHA256:
h = sha256.New()
case x509.SHA384WithRSA, x509.SHA384WithRSAPSS, x509.ECDSAWithSHA384:
h = sha512.New384()
case x509.SHA512WithRSA, x509.SHA512WithRSAPSS, x509.ECDSAWithSHA512:
h = sha512.New()
default:
return nil, fmt.Errorf("tls-server-end-point channel binding is undefined for certificate signature algorithm %v", cert.SignatureAlgorithm)
}
h.Write(cert.Raw)
return h.Sum(nil), nil
}

View File

@ -83,6 +83,23 @@ type Config struct {
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
// OAuthTokenProvider is a function that returns an OAuth token for authentication. If set, it will be used for
// OAUTHBEARER SASL authentication when the server requests it.
OAuthTokenProvider func(context.Context) (string, error)
// MinProtocolVersion is the minimum acceptable PostgreSQL protocol version.
// If the server does not support at least this version, the connection will fail.
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0".
MinProtocolVersion string
// MaxProtocolVersion is the maximum PostgreSQL protocol version to request from the server.
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility.
MaxProtocolVersion string
// ChannelBinding is the channel_binding parameter for SCRAM-SHA-256-PLUS authentication.
// Valid values: "disable", "prefer", "require". Defaults to "prefer".
ChannelBinding string
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
@ -213,6 +230,8 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
// PGTZ
// PGMINPROTOCOLVERSION
// PGMAXPROTOCOLVERSION
//
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables.
//
@ -338,6 +357,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"target_session_attrs": {},
"service": {},
"servicefile": {},
"min_protocol_version": {},
"max_protocol_version": {},
"channel_binding": {},
}
// Adding kerberos configuration
@ -430,6 +452,52 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
minProto, err := parseProtocolVersion(settings["min_protocol_version"])
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("invalid min_protocol_version: %q", settings["min_protocol_version"]), err: err}
}
maxProto, err := parseProtocolVersion(settings["max_protocol_version"])
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("invalid max_protocol_version: %q", settings["max_protocol_version"]), err: err}
}
config.MinProtocolVersion = settings["min_protocol_version"]
config.MaxProtocolVersion = settings["max_protocol_version"]
if config.MinProtocolVersion == "" {
config.MinProtocolVersion = "3.0"
}
// When max_protocol_version is not explicitly set, default based on
// min_protocol_version. This matches libpq behavior: if min > 3.0,
// default max to latest; otherwise default to 3.0 for compatibility
// with older servers/poolers that don't support NegotiateProtocolVersion.
if config.MaxProtocolVersion == "" {
if minProto > pgproto3.ProtocolVersion30 {
config.MaxProtocolVersion = "latest"
} else {
config.MaxProtocolVersion = "3.0"
}
}
// Only error when max_protocol_version was explicitly set and conflicts
// with min_protocol_version. When max_protocol_version is not explicitly
// set, the auto-raise logic above already ensures a valid default.
if minProto > maxProto && settings["max_protocol_version"] != "" {
return nil, &ParseConfigError{ConnString: connString, msg: "min_protocol_version cannot be greater than max_protocol_version"}
}
switch channelBinding := settings["channel_binding"]; channelBinding {
case "", "prefer":
config.ChannelBinding = "prefer"
case "disable":
config.ChannelBinding = "disable"
case "require":
config.ChannelBinding = "require"
default:
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown channel_binding value: %v", channelBinding)}
}
return config, nil
}
@ -467,6 +535,8 @@ func parseEnvSettings() map[string]string {
"PGSERVICEFILE": "servicefile",
"PGTZ": "timezone",
"PGOPTIONS": "options",
"PGMINPROTOCOLVERSION": "min_protocol_version",
"PGMAXPROTOCOLVERSION": "max_protocol_version",
}
for envname, realname := range nameMap {
@ -491,7 +561,9 @@ func parseURLSettings(connString string) (map[string]string, error) {
}
if parsedURL.User != nil {
settings["user"] = parsedURL.User.Username()
if u := parsedURL.User.Username(); u != "" {
settings["user"] = u
}
if password, present := parsedURL.User.Password(); present {
settings["password"] = password
}
@ -618,6 +690,9 @@ func parseKeywordValueSettings(s string) (map[string]string, error) {
return nil, errors.New("invalid keyword/value")
}
if key == "user" && val == "" {
continue
}
settings[key] = val
}
@ -788,7 +863,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
if sslpassword != "" {
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) //nolint:ineffassign
}
// if sslpassword not provided or has decryption error when use it
// try to find sslpassword with callback function
@ -803,7 +878,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
// Should we also provide warning for PKCS#1 needed?
if decryptedError != nil {
return nil, fmt.Errorf("unable to decrypt key: %w", err)
return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError)
}
pemBytes := pem.Block{
@ -955,3 +1030,14 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn
return nil
}
// parseProtocolVersion maps a protocol version connection setting to the
// corresponding pgproto3 wire-protocol constant. An empty string defaults
// to 3.0 and "latest" selects the newest supported version (3.2).
func parseProtocolVersion(s string) (uint32, error) {
	if s == "" || s == "3.0" {
		return pgproto3.ProtocolVersion30, nil
	}
	if s == "3.2" || s == "latest" {
		return pgproto3.ProtocolVersion32, nil
	}
	return 0, fmt.Errorf("invalid protocol version: %q", s)
}

View File

@ -8,12 +8,13 @@ import (
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
handler Handler
unwatchChan chan struct{}
handler Handler
lock sync.Mutex
watchInProgress bool
onCancelWasCalled bool
// Lock protects the members below.
lock sync.Mutex
// Stop is the handle for an "after func". See [context.AfterFunc].
stop func() bool
done chan struct{}
}
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
@ -21,8 +22,7 @@ type ContextWatcher struct {
// onCancel called.
func NewContextWatcher(handler Handler) *ContextWatcher {
cw := &ContextWatcher{
handler: handler,
unwatchChan: make(chan struct{}),
handler: handler,
}
return cw
@ -33,25 +33,16 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
panic("Watch already in progress")
if cw.stop != nil {
panic("watch already in progress")
}
cw.onCancelWasCalled = false
if ctx.Done() != nil {
cw.watchInProgress = true
go func() {
select {
case <-ctx.Done():
cw.handler.HandleCancel(ctx)
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
}
}()
} else {
cw.watchInProgress = false
cw.done = make(chan struct{})
cw.stop = context.AfterFunc(ctx, func() {
cw.handler.HandleCancel(ctx)
close(cw.done)
})
}
}
@ -61,12 +52,13 @@ func (cw *ContextWatcher) Unwatch() {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
if cw.stop != nil {
if !cw.stop() {
<-cw.done
cw.handler.HandleUnwatchAfterCancel()
}
cw.watchInProgress = false
cw.stop = nil
cw.done = nil
}
}

File diff suppressed because it is too large Load Diff

View File

@ -33,6 +33,7 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
return errors.New("bad auth type")
}
dst.AuthMechanisms = dst.AuthMechanisms[:0]
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)

View File

@ -123,7 +123,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
if err != nil {
return nil, err
}
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
msgSize := int(int32(binary.BigEndian.Uint32(buf)) - 4)
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
@ -137,7 +137,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
code := binary.BigEndian.Uint32(buf)
switch code {
case ProtocolVersionNumber:
case ProtocolVersion30, ProtocolVersion32:
err = b.startupMessage.Decode(buf)
if err != nil {
return nil, err
@ -176,7 +176,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
b.msgType = header[0]
msgLength := int(binary.BigEndian.Uint32(header[1:]))
msgLength := int(int32(binary.BigEndian.Uint32(header[1:])))
if msgLength < 4 {
return nil, fmt.Errorf("invalid message length: %d", msgLength)
}

View File

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
@ -9,7 +10,7 @@ import (
type BackendKeyData struct {
ProcessID uint32
SecretKey uint32
SecretKey []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
@ -18,12 +19,13 @@ func (*BackendKeyData) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *BackendKeyData) Decode(src []byte) error {
if len(src) != 8 {
if len(src) < 8 {
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
}
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
dst.SecretKey = make([]byte, len(src)-4)
copy(dst.SecretKey, src[4:])
return nil
}
@ -32,7 +34,7 @@ func (dst *BackendKeyData) Decode(src []byte) error {
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'K')
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
dst = append(dst, src.SecretKey...)
return finishMessage(dst, sp)
}
@ -41,10 +43,29 @@ func (src BackendKeyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
SecretKey string
}{
Type: "BackendKeyData",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
SecretKey: hex.EncodeToString(src.SecretKey),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler. The secret key is
// expected as a hex-encoded string, mirroring MarshalJSON.
func (dst *BackendKeyData) UnmarshalJSON(data []byte) error {
	var aux struct {
		ProcessID uint32
		SecretKey string
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	dst.ProcessID = aux.ProcessID
	decoded, err := hex.DecodeString(aux.SecretKey)
	if err != nil {
		return err
	}
	dst.SecretKey = decoded
	return nil
}

View File

@ -82,7 +82,7 @@ func (dst *Bind) Decode(src []byte) error {
continue
}
if len(src[rp:]) < msgSize {
if msgSize < 0 || len(src[rp:]) < msgSize {
return &invalidMessageFormatErr{messageType: "Bind"}
}

View File

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
@ -12,35 +13,42 @@ const cancelRequestCode = 80877102
type CancelRequest struct {
ProcessID uint32
SecretKey uint32
SecretKey []byte
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CancelRequest) Frontend() {}
func (dst *CancelRequest) Decode(src []byte) error {
if len(src) != 12 {
return errors.New("bad cancel request size")
if len(src) < 12 {
return errors.New("cancel request too short")
}
if len(src) > 264 {
return errors.New("cancel request too long")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != cancelRequestCode {
return errors.New("bad cancel request code")
}
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
dst.SecretKey = make([]byte, len(src)-8)
copy(dst.SecretKey, src[8:])
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 16)
if len(src.SecretKey) > 256 {
return nil, errors.New("secret key too long")
}
msgLen := int32(12 + len(src.SecretKey))
dst = pgio.AppendInt32(dst, msgLen)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
dst = append(dst, src.SecretKey...)
return dst, nil
}
@ -49,10 +57,29 @@ func (src CancelRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
SecretKey string
}{
Type: "CancelRequest",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
SecretKey: hex.EncodeToString(src.SecretKey),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler. The secret key is
// expected as a hex-encoded string, mirroring MarshalJSON.
func (dst *CancelRequest) UnmarshalJSON(data []byte) error {
	var aux struct {
		ProcessID uint32
		SecretKey string
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	dst.ProcessID = aux.ProcessID
	decoded, err := hex.DecodeString(aux.SecretKey)
	if err != nil {
		return err
	}
	dst.SecretKey = decoded
	return nil
}

View File

@ -15,6 +15,10 @@ func (*CopyFail) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyFail) Decode(src []byte) error {
if len(src) == 0 {
return &invalidMessageFormatErr{messageType: "CopyFail"}
}
idx := bytes.IndexByte(src, 0)
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CopyFail"}

View File

@ -52,6 +52,7 @@ type Frontend struct {
readyForQuery ReadyForQuery
rowDescription RowDescription
portalSuspended PortalSuspended
negotiateProtocolVersion NegotiateProtocolVersion
bodyLen int
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
@ -230,7 +231,7 @@ func (f *Frontend) SendExecute(msg *Execute) {
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg)
}
}
@ -312,7 +313,7 @@ func (f *Frontend) Receive() (BackendMessage, error) {
f.msgType = header[0]
msgLength := int(binary.BigEndian.Uint32(header[1:]))
msgLength := int(int32(binary.BigEndian.Uint32(header[1:])))
if msgLength < 4 {
return nil, fmt.Errorf("invalid message length: %d", msgLength)
}
@ -383,6 +384,8 @@ func (f *Frontend) Receive() (BackendMessage, error) {
msg = &f.copyBothResponse
case 'Z':
msg = &f.readyForQuery
case 'v':
msg = &f.negotiateProtocolVersion
default:
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
}

View File

@ -23,6 +23,11 @@ func (*FunctionCall) Frontend() {}
func (dst *FunctionCall) Decode(src []byte) error {
*dst = FunctionCall{}
rp := 0
if len(src) < 8 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
// Specifies the object ID of the function to call.
dst.Function = binary.BigEndian.Uint32(src[rp:])
rp += 4
@ -32,6 +37,11 @@ func (dst *FunctionCall) Decode(src []byte) error {
// or it can equal the actual number of arguments.
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if len(src[rp:]) < nArgumentCodes*2+2 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
argumentCodes := make([]uint16, nArgumentCodes)
for i := range nArgumentCodes {
// The argument format codes. Each must presently be zero (text) or one (binary).
@ -49,13 +59,21 @@ func (dst *FunctionCall) Decode(src []byte) error {
rp += 2
arguments := make([][]byte, nArguments)
for i := range nArguments {
if len(src[rp:]) < 4 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
argumentLength := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
if argumentLength == -1 {
arguments[i] = nil
} else if argumentLength < 0 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
} else {
if len(src[rp:]) < argumentLength {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
// The value of the argument, in the format indicated by the associated format code. n is the above length.
argumentValue := src[rp : rp+argumentLength]
rp += argumentLength
@ -64,6 +82,9 @@ func (dst *FunctionCall) Decode(src []byte) error {
}
dst.Arguments = arguments
// The format code for the function result. Must presently be zero (text) or one (binary).
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
if resultFormatCode != 0 && resultFormatCode != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}

View File

@ -22,7 +22,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
rp := 0
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
resultSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
if resultSize == -1 {
@ -30,7 +30,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
return nil
}
if len(src[rp:]) != resultSize {
if resultSize < 0 || len(src[rp:]) != resultSize {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}

View File

@ -0,0 +1,93 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
// NegotiateProtocolVersion is sent by the server when the frontend requested
// a protocol version or options the server does not fully support.
type NegotiateProtocolVersion struct {
	// NewestMinorProtocol is the newest protocol version the server supports
	// (presumably for the requested major version — confirm against the
	// PostgreSQL protocol documentation).
	NewestMinorProtocol uint32
	// UnrecognizedOptions lists the startup options the server did not recognize.
	UnrecognizedOptions []string
}

// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NegotiateProtocolVersion) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NegotiateProtocolVersion) Decode(src []byte) error {
	if len(src) < 8 {
		return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)}
	}

	dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4])
	optionCount := int(binary.BigEndian.Uint32(src[4:8]))
	rest := src[8:]

	// Cap the slice's initial capacity by the remaining payload size so a
	// hostile optionCount cannot trigger an excessive allocation.
	capHint := optionCount
	if capHint > len(rest) {
		capHint = len(rest)
	}
	dst.UnrecognizedOptions = make([]string, 0, capHint)

	for n := 0; n < optionCount; n++ {
		if len(rest) == 0 {
			return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
		}
		// Each option is a null-terminated string; scan for the terminator.
		i := 0
		for i < len(rest) && rest[i] != 0 {
			i++
		}
		if i == len(rest) {
			// No terminator before the end of the message.
			return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
		}
		dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(rest[:i]))
		rest = rest[i+1:]
	}

	return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) {
	dst, sp := beginMessage(dst, 'v')
	dst = pgio.AppendUint32(dst, src.NewestMinorProtocol)
	dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions)))
	for _, opt := range src.UnrecognizedOptions {
		// Each unrecognized option is written as a null-terminated string.
		dst = append(append(dst, opt...), 0)
	}
	return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) {
	type jsonForm struct {
		Type                string
		NewestMinorProtocol uint32
		UnrecognizedOptions []string
	}
	out := jsonForm{
		Type:                "NegotiateProtocolVersion",
		NewestMinorProtocol: src.NewestMinorProtocol,
		UnrecognizedOptions: src.UnrecognizedOptions,
	}
	return json.Marshal(out)
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error {
	var aux struct {
		NewestMinorProtocol uint32
		UnrecognizedOptions []string
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	dst.NewestMinorProtocol = aux.NewestMinorProtocol
	dst.UnrecognizedOptions = aux.UnrecognizedOptions
	return nil
}

View File

@ -15,6 +15,10 @@ func (*Query) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Query) Decode(src []byte) error {
if len(src) == 0 {
return &invalidMessageFormatErr{messageType: "Query"}
}
i := bytes.IndexByte(src, 0)
if i != len(src)-1 {
return &invalidMessageFormatErr{messageType: "Query"}

View File

@ -32,6 +32,9 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
dst.AuthMechanism = string(src[rp:idx])
rp = idx + 1
if len(src[rp:]) < 4 {
return errors.New("invalid SASLInitialResponse")
}
rp += 4 // The rest of the message is data so we can just skip the size
dst.Data = src[rp:]

View File

@ -10,7 +10,12 @@ import (
"github.com/jackc/pgx/v5/internal/pgio"
)
const ProtocolVersionNumber = 196608 // 3.0
const (
ProtocolVersion30 = 196608 // 3.0
ProtocolVersion32 = 196610 // 3.2
ProtocolVersionLatest = ProtocolVersion32 // Latest is 3.2
ProtocolVersionNumber = ProtocolVersion30 // Default is still 3.0
)
type StartupMessage struct {
ProtocolVersion uint32
@ -30,8 +35,8 @@ func (dst *StartupMessage) Decode(src []byte) error {
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
rp := 4
if dst.ProtocolVersion != ProtocolVersionNumber {
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
if dst.ProtocolVersion != ProtocolVersion30 && dst.ProtocolVersion != ProtocolVersion32 {
return fmt.Errorf("Bad startup message version number. Expected %d or %d, got %d", ProtocolVersion30, ProtocolVersion32, dst.ProtocolVersion)
}
dst.Parameters = make(map[string]string)

View File

@ -82,7 +82,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) {
case *ErrorResponse:
t.traceErrorResponse(sender, encodedLen, msg)
case *Execute:
t.TraceQueryute(sender, encodedLen, msg)
t.traceExecute(sender, encodedLen, msg)
case *Flush:
t.traceFlush(sender, encodedLen, msg)
case *FunctionCall:
@ -260,7 +260,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes
t.writeTrace(sender, encodedLen, "ErrorResponse", nil)
}
func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) {
func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) {
t.writeTrace(sender, encodedLen, "Execute", func() {
fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows)
})

View File

@ -38,6 +38,10 @@ func cardinality(dimensions []ArrayDimension) int {
elementCount *= int(d.Length)
}
if elementCount < 0 {
return 0
}
return elementCount
}
@ -51,16 +55,20 @@ func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) {
numDims := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if numDims > 6 {
return 0, fmt.Errorf("array has too many dimensions: %d", numDims)
}
dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1
rp += 4
dst.ElementOID = binary.BigEndian.Uint32(src[rp:])
rp += 4
dst.Dimensions = make([]ArrayDimension, numDims)
if len(src) < 12+numDims*8 {
return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src))
}
dst.Dimensions = make([]ArrayDimension, numDims)
for i := range dst.Dimensions {
dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:]))
rp += 4
@ -299,7 +307,7 @@ func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) {
return "", false, err
}
case '"':
r, _, err = buf.ReadRune()
_, _, err = buf.ReadRune()
if err != nil {
return "", false, err
}

View File

@ -289,7 +289,7 @@ type CompositeBinaryScanner struct {
err error
}
// NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
// NewCompositeBinaryScanner a scanner over a binary encoded composite value.
func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner {
rp := 0
if len(src[rp:]) < 4 {

View File

@ -90,19 +90,19 @@ func GetAssignToDstType(dst any) (any, bool) {
func init() {
kindTypes = map[reflect.Kind]reflect.Type{
reflect.Bool: reflect.TypeOf(false),
reflect.Float32: reflect.TypeOf(float32(0)),
reflect.Float64: reflect.TypeOf(float64(0)),
reflect.Int: reflect.TypeOf(int(0)),
reflect.Int8: reflect.TypeOf(int8(0)),
reflect.Int16: reflect.TypeOf(int16(0)),
reflect.Int32: reflect.TypeOf(int32(0)),
reflect.Int64: reflect.TypeOf(int64(0)),
reflect.Uint: reflect.TypeOf(uint(0)),
reflect.Uint8: reflect.TypeOf(uint8(0)),
reflect.Uint16: reflect.TypeOf(uint16(0)),
reflect.Uint32: reflect.TypeOf(uint32(0)),
reflect.Uint64: reflect.TypeOf(uint64(0)),
reflect.String: reflect.TypeOf(""),
reflect.Bool: reflect.TypeFor[bool](),
reflect.Float32: reflect.TypeFor[float32](),
reflect.Float64: reflect.TypeFor[float64](),
reflect.Int: reflect.TypeFor[int](),
reflect.Int8: reflect.TypeFor[int8](),
reflect.Int16: reflect.TypeFor[int16](),
reflect.Int32: reflect.TypeFor[int32](),
reflect.Int64: reflect.TypeFor[int64](),
reflect.Uint: reflect.TypeFor[uint](),
reflect.Uint8: reflect.TypeFor[uint8](),
reflect.Uint16: reflect.TypeFor[uint16](),
reflect.Uint32: reflect.TypeFor[uint32](),
reflect.Uint64: reflect.TypeFor[uint64](),
reflect.String: reflect.TypeFor[string](),
}
}

View File

@ -5,7 +5,6 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"regexp"
"strconv"
"time"
@ -271,8 +270,6 @@ func (scanPlanBinaryDateToDateScanner) Scan(src []byte, dst any) error {
type scanPlanTextAnyToDateScanner struct{}
var dateRegexp = regexp.MustCompile(`^(\d{4,})-(\d\d)-(\d\d)( BC)?$`)
func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error {
scanner := (dst).(DateScanner)
@ -280,41 +277,104 @@ func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error {
return scanner.ScanDate(Date{})
}
sbuf := string(src)
match := dateRegexp.FindStringSubmatch(sbuf)
if match != nil {
year, err := strconv.ParseInt(match[1], 10, 32)
if err != nil {
return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %w", err)
}
month, err := strconv.ParseInt(match[2], 10, 32)
if err != nil {
return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err)
}
day, err := strconv.ParseInt(match[3], 10, 32)
if err != nil {
return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err)
}
// BC matched
if len(match[4]) > 0 {
year = -year + 1
}
t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC)
return scanner.ScanDate(Date{Time: t, Valid: true})
// Check infinity cases first
if len(src) == 8 && string(src) == "infinity" {
return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true})
}
if len(src) == 9 && string(src) == "-infinity" {
return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true})
}
switch sbuf {
case "infinity":
return scanner.ScanDate(Date{InfinityModifier: Infinity, Valid: true})
case "-infinity":
return scanner.ScanDate(Date{InfinityModifier: -Infinity, Valid: true})
default:
// Format: YYYY-MM-DD or YYYY...-MM-DD BC
// Minimum: 10 chars (2000-01-01), with BC: 13 chars
if len(src) < 10 {
return fmt.Errorf("invalid date format")
}
// Check for BC suffix
bc := false
datePart := src
if len(src) >= 13 && string(src[len(src)-3:]) == " BC" {
bc = true
datePart = src[:len(src)-3]
}
// Find year-month separator (first dash after at least 4 digits)
yearEnd := -1
for i := 4; i < len(datePart); i++ {
if datePart[i] == '-' {
yearEnd = i
break
}
if datePart[i] < '0' || datePart[i] > '9' {
return fmt.Errorf("invalid date format")
}
}
if yearEnd == -1 || yearEnd+6 > len(datePart) {
return fmt.Errorf("invalid date format")
}
// Validate: -MM-DD structure after year
if datePart[yearEnd+3] != '-' {
return fmt.Errorf("invalid date format")
}
// Parse year
year, err := parseDigits(datePart[:yearEnd])
if err != nil {
return fmt.Errorf("invalid date format")
}
// Parse month (2 digits)
month, err := parse2Digits(datePart[yearEnd+1 : yearEnd+3])
if err != nil {
return fmt.Errorf("invalid date format")
}
// Parse day (2 digits)
day, err := parse2Digits(datePart[yearEnd+4 : yearEnd+6])
if err != nil {
return fmt.Errorf("invalid date format")
}
// Ensure nothing extra after day
if yearEnd+6 != len(datePart) {
return fmt.Errorf("invalid date format")
}
if bc {
year = -year + 1
}
t := time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC)
return scanner.ScanDate(Date{Time: t, Valid: true})
}
// parse2Digits parses exactly 2 ASCII digits.
func parse2Digits(b []byte) (int64, error) {
	if len(b) != 2 {
		return 0, fmt.Errorf("expected 2 digits")
	}
	// Byte subtraction wraps, so any non-digit byte produces a value > 9.
	hi := b[0] - '0'
	lo := b[1] - '0'
	if hi > 9 || lo > 9 {
		return 0, fmt.Errorf("expected digits")
	}
	return int64(hi)*10 + int64(lo), nil
}
// parseDigits parses a sequence of ASCII digits.
func parseDigits(b []byte) (int64, error) {
if len(b) == 0 {
return 0, fmt.Errorf("empty")
}
var n int64
for _, c := range b {
if c < '0' || c > '9' {
return 0, fmt.Errorf("non-digit")
}
n = n*10 + int64(c-'0')
}
return n, nil
}
func (c DateCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {

View File

@ -1,10 +1,10 @@
// Package pgtype converts between Go and PostgreSQL values.
/*
The primary type is the Map type. It is a map of PostgreSQL types identified by OID (object ID) to a Codec. A Codec is
responsible for converting between Go and PostgreSQL values. NewMap creates a Map with all supported standard PostgreSQL
types already registered. Additional types can be registered with Map.RegisterType.
The primary type is the [Map] type. It is a map of PostgreSQL types identified by OID (object ID) to a [Codec]. A [Codec] is
responsible for converting between Go and PostgreSQL values. [NewMap] creates a [Map] with all supported standard PostgreSQL
types already registered. Additional types can be registered with [Map.RegisterType].
Use Map.Scan and Map.Encode to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively.
Use [Map.Scan] and [Map.Encode] to decode PostgreSQL values to Go and encode Go values to PostgreSQL respectively.
Base Type Mapping
@ -63,8 +63,8 @@ pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL
Extending Existing PostgreSQL Type Support
Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example,
PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use
pgtype.Point and application can directly use its own point type with pgtype as long as it implements those interfaces.
[PointCodec] can use any Go type that implements the [PointScanner] and [PointValuer] interfaces. So rather than use
[Point] an application can directly use its own point type with pgtype as long as it implements those interfaces.
See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type.
@ -77,10 +77,10 @@ New PostgreSQL Type Support
pgtype uses the PostgreSQL OID to determine how to encode or decode a value. pgtype supports array, composite, domain,
and enum types. However, any type created in PostgreSQL with CREATE TYPE will receive a new OID. This means that the OID
of each new PostgreSQL type must be registered for pgtype to handle values of that type with the correct Codec.
of each new PostgreSQL type must be registered for pgtype to handle values of that type with the correct [Codec].
The pgx.Conn LoadType method can return a *Type for array, composite, domain, and enum types by inspecting the database
metadata. This *Type can then be registered with Map.RegisterType.
The [github.com/jackc/pgx/v5.Conn.LoadType] method can return a [*Type] for array, composite, domain, and enum types by
inspecting the database metadata. This [*Type] can then be registered with [Map.RegisterType].
For example, the following function could be called after a connection is established:
@ -106,30 +106,30 @@ For example, the following function could be called after a connection is establ
A type cannot be registered unless all types it depends on are already registered. e.g. An array type cannot be
registered until its element type is registered.
ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an
ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays.
[ArrayCodec] implements support for arrays. If pgtype supports type T then it can easily support []T by registering an
[ArrayCodec] for the appropriate PostgreSQL OID. In addition, [Array] type can support multi-dimensional arrays.
CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of
the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and
CompositeIndexGetter.
[CompositeCodec] implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of
the struct are in the exact order and type of the PostgreSQL type or by implementing [CompositeIndexScanner] and
[CompositeIndexGetter].
Domain types are treated as their underlying type if the underlying type and the domain type are registered.
PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can
PostgreSQL enums can usually be treated as text. However, [EnumCodec] implements support for interning strings which can
reduce memory usage.
While pgtype will often still work with unregistered types it is highly recommended that all types be registered due to
an improvement in performance and the elimination of certain edge cases.
If an entirely new PostgreSQL type (e.g. PostGIS types) is used then the application or a library can create a new
Codec. Then the OID / Codec mapping can be registered with Map.RegisterType. There is no difference between a Codec
defined and registered by the application and a Codec built in to pgtype. See any of the Codecs in pgtype for Codec
[Codec]. Then the OID / [Codec] mapping can be registered with [Map.RegisterType]. There is no difference between a [Codec]
defined and registered by the application and a [Codec] built in to pgtype. See any of the [Codec]s in pgtype for [Codec]
examples and for examples of type registration.
Encoding Unknown Types
pgtype works best when the OID of the PostgreSQL type is known. But in some cases such as using the simple protocol the
OID is unknown. In this case Map.RegisterDefaultPgType can be used to register an assumed OID for a particular Go type.
OID is unknown. In this case [Map.RegisterDefaultPgType] can be used to register an assumed OID for a particular Go type.
Renamed Types
@ -137,18 +137,18 @@ If pgtype does not recognize a type and that type is a renamed simple type simpl
as if it is the underlying type. It currently cannot automatically detect the underlying type of renamed structs (eg.g.
type MyTime time.Time).
Compatibility with database/sql
Compatibility with [database/sql]
pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer
pgtype also includes support for custom types implementing the [database/sql.Scanner] and [database/sql/driver.Valuer]
interfaces.
Encoding Typed Nils
pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec
system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil).
pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the [Codec]
system. This means that [Codec]s and other encoding logic do not have to handle nil or *T(nil).
However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore,
driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See
However, [database/sql] compatibility requires Value to be called on T(nil) when T implements [database/sql/driver.Valuer]. Therefore,
[database/sql/driver.Valuer] values are only considered NULL when *T(nil) where [database/sql/driver.Valuer] is implemented on T not on *T. See
https://github.com/golang/go/issues/8415 and
https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870.
@ -159,38 +159,38 @@ example_child_records_test.go for an example.
Overview of Scanning Implementation
The first step is to use the OID to lookup the correct Codec. The Map will call the Codec's PlanScan method to get a
plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types
are interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner
and PointValuer interfaces.
The first step is to use the OID to look up the correct [Codec]. The [Map] will call the [Codec.PlanScan] method to get a
plan for scanning into the Go value. A [Codec] will support scanning into one or more Go types. Oftentimes these Go types
are interfaces rather than explicit types. For example, [PointCodec] can use any Go type that implements the [PointScanner]
and [PointValuer] interfaces.
If a Go value is not supported directly by a Codec then Map will try see if it is a sql.Scanner. If is then that
interface will be used to scan the value. Most sql.Scanners require the input to be in the text format (e.g. UUIDs and
If a Go value is not supported directly by a [Codec] then [Map] will try to see if it is a [database/sql.Scanner]. If it is then that
interface will be used to scan the value. Most [database/sql.Scanner]s require the input to be in the text format (e.g. UUIDs and
numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be
parsed, reencoded as text, and then passed to the sql.Scanner. This may incur additional overhead for query results with
parsed, reencoded as text, and then passed to the [database/sql.Scanner]. This may incur additional overhead for query results with
a large number of affected values.
If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again.
For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that
If a Go value is not supported directly by a [Codec] then [Map] will try wrapping it with additional logic and try again.
For example, [Int8Codec] does not support scanning into a renamed type (e.g. type myInt64 int64). But [Map] will detect that
myInt64 is a renamed type and create a plan that converts the value to the underlying int64 type and then passes that to
the Codec (see TryFindUnderlyingTypeScanPlan).
the [Codec] (see [TryFindUnderlyingTypeScanPlan]).
These plan wrappers are contained in Map.TryWrapScanPlanFuncs. By default these contain shared logic to handle renamed
These plan wrappers are contained in [Map.TryWrapScanPlanFuncs]. By default these contain shared logic to handle renamed
types, pointers to pointers, slices, composite types, etc. Additional plan wrappers can be added to seamlessly integrate
types that do not support pgx directly. For example, the before mentioned
https://github.com/jackc/pgx-shopspring-decimal package detects decimal.Decimal values, wraps them in something
implementing NumericScanner and passes that to the Codec.
implementing [NumericScanner] and passes that to the [Codec].
Map.Scan and Map.Encode are convenience methods that wrap Map.PlanScan and Map.PlanEncode. Determining how to scan or
[Map.Scan] and [Map.Encode] are convenience methods that wrap [Map.PlanScan] and [Map.PlanEncode]. Determining how to scan or
encode a particular type may be a time consuming operation. Hence the planning and execution steps of a conversion are
internally separated.
Reducing Compiled Binary Size
pgx.QueryExecModeExec and pgx.QueryExecModeSimpleProtocol require the default PostgreSQL type to be registered for each
Go type used as a query parameter. By default pgx does this for all supported types and their array variants. If an
application does not use those query execution modes or manually registers the default PostgreSQL type for the types it
uses as query parameters it can use the build tag nopgxregisterdefaulttypes. This omits the default type registration
and reduces the compiled binary size by ~2MB.
[github.com/jackc/pgx/v5.QueryExecModeExec] and [github.com/jackc/pgx/v5.QueryExecModeSimpleProtocol] require the default
PostgreSQL type to be registered for each Go type used as a query parameter. By default pgx does this for all supported
types and their array variants. If an application does not use those query execution modes or manually registers the default
PostgreSQL type for the types it uses as query parameters it can use the build tag nopgxregisterdefaulttypes. This omits
the default type registration and reduces the compiled binary size by ~2MB.
*/
package pgtype

View File

@ -198,6 +198,10 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += uint32Len
if pairCount < 0 {
return fmt.Errorf("hstore invalid pair count: %d", pairCount)
}
hstore := make(Hstore, pairCount)
// one allocation for all *string, rather than one per string, just like text parsing
valueStrings := make([]string, pairCount)
@ -209,6 +213,9 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += uint32Len
if keyLen < 0 {
return fmt.Errorf("hstore invalid key length: %d", keyLen)
}
if len(src[rp:]) < keyLen {
return fmt.Errorf("hstore incomplete %v", src)
}

View File

@ -78,7 +78,7 @@ func (dst *Int2) Scan(src any) error {
}
if n < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", n)
return fmt.Errorf("%d is less than minimum value for Int2", n)
}
if n > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", n)
@ -641,7 +641,7 @@ func (dst *Int4) Scan(src any) error {
}
if n < math.MinInt32 {
return fmt.Errorf("%d is greater than maximum value for Int4", n)
return fmt.Errorf("%d is less than minimum value for Int4", n)
}
if n > math.MaxInt32 {
return fmt.Errorf("%d is greater than maximum value for Int4", n)

View File

@ -210,6 +210,11 @@ func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte,
elementCount := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
// Each element requires at least 4 bytes for its length prefix.
if elementCount > len(src)/4 {
return fmt.Errorf("multirange element count %d exceeds available data", elementCount)
}
err := multirange.SetLen(elementCount)
if err != nil {
return err

View File

@ -128,7 +128,7 @@ func (n Numeric) Int64Value() (Int8, error) {
}
func (n *Numeric) ScanScientific(src string) error {
if !strings.ContainsAny("eE", src) {
if !strings.ContainsAny(src, "eE") {
return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n)
}
@ -264,6 +264,10 @@ func (n *Numeric) UnmarshalJSON(src []byte) error {
// numberString returns a string of the number. undefined if NaN, infinite, or NULL
func (n Numeric) numberTextBytes() []byte {
if n.Int == nil {
return []byte("0")
}
intStr := n.Int.String()
buf := &bytes.Buffer{}
@ -405,7 +409,7 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) {
}
var sign int16
if n.Int.Sign() < 0 {
if n.Int != nil && n.Int.Sign() < 0 {
sign = 16384
}
@ -413,7 +417,9 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) {
wholePart := &big.Int{}
fracPart := &big.Int{}
remainder := &big.Int{}
absInt.Abs(n.Int)
if n.Int != nil {
absInt.Abs(n.Int)
}
// Normalize absInt and exp to where exp is always a multiple of 4. This makes
// converting to 16-bit base 10,000 digits easier.

View File

@ -96,6 +96,8 @@ const (
RecordArrayOID = 2287
UUIDOID = 2950
UUIDArrayOID = 2951
TSVectorOID = 3614
TSVectorArrayOID = 3643
JSONBOID = 3802
JSONBArrayOID = 3807
DaterangeOID = 3912
@ -154,7 +156,7 @@ const (
BinaryFormatCode = 1
)
// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map.
// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a [Map].
type Codec interface {
// FormatSupported returns true if the format is supported.
FormatSupported(int16) bool
@ -185,7 +187,7 @@ func (e *nullAssignmentError) Error() string {
return fmt.Sprintf("cannot assign NULL to %T", e.dst)
}
// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map.
// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a [Map].
type Type struct {
Codec Codec
Name string
@ -241,6 +243,7 @@ func NewMap() *Map {
TryWrapDerefPointerEncodePlan,
TryWrapBuiltinTypeEncodePlan,
TryWrapFindUnderlyingTypeEncodePlan,
TryWrapStringerEncodePlan,
TryWrapStructEncodePlan,
TryWrapSliceEncodePlan,
TryWrapMultiDimSliceEncodePlan,
@ -266,7 +269,7 @@ func (m *Map) RegisterTypes(types []*Type) {
}
}
// RegisterType registers a data type with the Map. t must not be mutated after it is registered.
// RegisterType registers a data type with the [Map]. t must not be mutated after it is registered.
func (m *Map) RegisterType(t *Type) {
m.oidToType[t.OID] = t
m.nameToType[t.Name] = t
@ -292,7 +295,7 @@ func (m *Map) RegisterDefaultPgType(value any, name string) {
}
}
// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated.
// TypeForOID returns the [Type] registered for the given OID. The returned [Type] must not be mutated.
func (m *Map) TypeForOID(oid uint32) (*Type, bool) {
if dt, ok := m.oidToType[oid]; ok {
return dt, true
@ -302,7 +305,7 @@ func (m *Map) TypeForOID(oid uint32) (*Type, bool) {
return dt, ok
}
// TypeForName returns the Type registered for the given name. The returned Type must not be mutated.
// TypeForName returns the [Type] registered for the given name. The returned [Type] must not be mutated.
func (m *Map) TypeForName(name string) (*Type, bool) {
if dt, ok := m.nameToType[name]; ok {
return dt, true
@ -321,8 +324,8 @@ func (m *Map) buildReflectTypeToType() {
}
}
// TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode
// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. The returned Type
// TypeForValue finds a data type suitable for v. Use [Map.RegisterType] to register types that can encode and decode
// themselves. Use [Map.RegisterDefaultPgType] to register that can be handled by a registered data type. The returned [Type]
// must not be mutated.
func (m *Map) TypeForValue(v any) (*Type, bool) {
if m.reflectTypeToType == nil {
@ -523,20 +526,20 @@ type SkipUnderlyingTypePlanner interface {
}
var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{
reflect.Int: reflect.TypeOf(new(int)),
reflect.Int8: reflect.TypeOf(new(int8)),
reflect.Int16: reflect.TypeOf(new(int16)),
reflect.Int32: reflect.TypeOf(new(int32)),
reflect.Int64: reflect.TypeOf(new(int64)),
reflect.Uint: reflect.TypeOf(new(uint)),
reflect.Uint8: reflect.TypeOf(new(uint8)),
reflect.Uint16: reflect.TypeOf(new(uint16)),
reflect.Uint32: reflect.TypeOf(new(uint32)),
reflect.Uint64: reflect.TypeOf(new(uint64)),
reflect.Float32: reflect.TypeOf(new(float32)),
reflect.Float64: reflect.TypeOf(new(float64)),
reflect.String: reflect.TypeOf(new(string)),
reflect.Bool: reflect.TypeOf(new(bool)),
reflect.Int: reflect.TypeFor[*int](),
reflect.Int8: reflect.TypeFor[*int8](),
reflect.Int16: reflect.TypeFor[*int16](),
reflect.Int32: reflect.TypeFor[*int32](),
reflect.Int64: reflect.TypeFor[*int64](),
reflect.Uint: reflect.TypeFor[*uint](),
reflect.Uint8: reflect.TypeFor[*uint8](),
reflect.Uint16: reflect.TypeFor[*uint16](),
reflect.Uint32: reflect.TypeFor[*uint32](),
reflect.Uint64: reflect.TypeFor[*uint64](),
reflect.Float32: reflect.TypeFor[*float32](),
reflect.Float64: reflect.TypeFor[*float64](),
reflect.String: reflect.TypeFor[*string](),
reflect.Bool: reflect.TypeFor[*bool](),
}
type underlyingTypeScanPlan struct {
@ -901,7 +904,7 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst any) error {
return nil
}
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
// TryWrapStructScanPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) {
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
@ -1372,23 +1375,23 @@ func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter,
}
var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{
reflect.Int: reflect.TypeOf(int(0)),
reflect.Int8: reflect.TypeOf(int8(0)),
reflect.Int16: reflect.TypeOf(int16(0)),
reflect.Int32: reflect.TypeOf(int32(0)),
reflect.Int64: reflect.TypeOf(int64(0)),
reflect.Uint: reflect.TypeOf(uint(0)),
reflect.Uint8: reflect.TypeOf(uint8(0)),
reflect.Uint16: reflect.TypeOf(uint16(0)),
reflect.Uint32: reflect.TypeOf(uint32(0)),
reflect.Uint64: reflect.TypeOf(uint64(0)),
reflect.Float32: reflect.TypeOf(float32(0)),
reflect.Float64: reflect.TypeOf(float64(0)),
reflect.String: reflect.TypeOf(""),
reflect.Bool: reflect.TypeOf(false),
reflect.Int: reflect.TypeFor[int](),
reflect.Int8: reflect.TypeFor[int8](),
reflect.Int16: reflect.TypeFor[int16](),
reflect.Int32: reflect.TypeFor[int32](),
reflect.Int64: reflect.TypeFor[int64](),
reflect.Uint: reflect.TypeFor[uint](),
reflect.Uint8: reflect.TypeFor[uint8](),
reflect.Uint16: reflect.TypeFor[uint16](),
reflect.Uint32: reflect.TypeFor[uint32](),
reflect.Uint64: reflect.TypeFor[uint64](),
reflect.Float32: reflect.TypeFor[float32](),
reflect.Float64: reflect.TypeFor[float64](),
reflect.String: reflect.TypeFor[string](),
reflect.Bool: reflect.TypeFor[bool](),
}
var byteSliceType = reflect.TypeOf([]byte{})
var byteSliceType = reflect.TypeFor[[]byte]()
type underlyingTypeEncodePlan struct {
nextValueType reflect.Type
@ -1444,6 +1447,24 @@ func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextS
return nil, nil, false
}
// TryWrapStringerEncodePlan tries to wrap a fmt.Stringer type with a wrapper that provides TextValuer. This is
// intentionally a separate function from TryWrapBuiltinTypeEncodePlan so it can be ordered after
// TryWrapFindUnderlyingTypeEncodePlan. This ensures that named types with an underlying builtin type (e.g. type MyEnum
// int32 with a String() method) prefer encoding via the underlying type's codec (e.g. as an integer) rather than via
// Stringer. Stringer is only used as a fallback when no type-specific encoding plan succeeds.
// (https://github.com/jackc/pgx/discussions/2527)
func TryWrapStringerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
if _, ok := value.(driver.Valuer); ok {
return nil, nil, false
}
if s, ok := value.(fmt.Stringer); ok {
return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{s}, true
}
return nil, nil, false
}
type WrappedEncodePlanNextSetter interface {
SetNext(EncodePlan)
EncodePlan
@ -1504,8 +1525,6 @@ func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter,
return &wrapByte16EncodePlan{}, byte16Wrapper(value), true
case []byte:
return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true
case fmt.Stringer:
return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true
}
return nil, nil, false
@ -1751,7 +1770,7 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value any, buf []byte) (newBuf []b
return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf)
}
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
// TryWrapStructEncodePlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
func TryWrapStructEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
if _, ok := value.(driver.Valuer); ok {
return nil, nil, false

View File

@ -81,6 +81,7 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}})
defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
defaultMap.RegisterType(&Type{Name: "tsvector", OID: TSVectorOID, Codec: TSVectorCodec{}})
defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}})
@ -164,6 +165,7 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}})
defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}})
defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}})
defaultMap.RegisterType(&Type{Name: "_tsvector", OID: TSVectorArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TSVectorOID]}})
defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}})
defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}})
defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}})
@ -242,6 +244,7 @@ func initDefaultMap() {
registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange")
registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange")
registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange")
registerDefaultPgTypeVariants[TSVector](defaultMap, "tsvector")
registerDefaultPgTypeVariants[UUID](defaultMap, "uuid")
defaultMap.buildReflectTypeToType()

View File

@ -111,9 +111,9 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error {
case "-infinity":
*ts = Timestamp{Valid: true, InfinityModifier: -Infinity}
default:
// Parse time with or without timezonr
// Parse time with or without timezone
tss := *s
// PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestampt
// PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestamp
tim, err := time.Parse(time.RFC3339Nano, tss)
if err == nil {
*ts = Timestamp{Time: tim, Valid: true}

507
vendor/github.com/jackc/pgx/v5/pgtype/tsvector.go generated vendored Normal file
View File

@ -0,0 +1,507 @@
package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"strconv"
"strings"
"github.com/jackc/pgx/v5/internal/pgio"
)
// TSVectorScanner is implemented by types that can scan a PostgreSQL
// tsvector value into themselves.
type TSVectorScanner interface {
	ScanTSVector(TSVector) error
}

// TSVectorValuer is implemented by types that can convert themselves
// into a [TSVector].
type TSVectorValuer interface {
	TSVectorValue() (TSVector, error)
}

// TSVector represents a PostgreSQL tsvector value.
type TSVector struct {
	Lexemes []TSVectorLexeme
	Valid   bool // false when the value is SQL NULL
}

// TSVectorLexeme represents a lexeme within a tsvector, consisting of a word and its positions.
type TSVectorLexeme struct {
	Word      string
	Positions []TSVectorPosition
}
// ScanTSVector implements the [TSVectorScanner] interface.
func (t *TSVector) ScanTSVector(v TSVector) error {
	*t = v
	return nil
}

// TSVectorValue implements the [TSVectorValuer] interface.
func (t TSVector) TSVectorValue() (TSVector, error) {
	return t, nil
}

// String returns the PostgreSQL text representation of the tsvector.
// The text encoder's only error source is TSVectorValue, which cannot fail
// for TSVector itself, so the error is safely discarded. An invalid (NULL)
// TSVector renders as the empty string.
func (t TSVector) String() string {
	buf, _ := encodePlanTSVectorCodecText{}.Encode(t, nil)
	return string(buf)
}
// Scan implements the [database/sql.Scanner] interface. A nil src scans as
// SQL NULL; only string sources are supported otherwise.
func (t *TSVector) Scan(src any) error {
	if src == nil {
		*t = TSVector{}
		return nil
	}

	if s, ok := src.(string); ok {
		return scanPlanTextAnyToTSVectorScanner{}.scanString(s, t)
	}

	return fmt.Errorf("cannot scan %T", src)
}
// Value implements the [database/sql/driver.Valuer] interface. A NULL
// tsvector is reported as a nil driver.Value; otherwise the text
// representation is returned as a string.
func (t TSVector) Value() (driver.Value, error) {
	if !t.Valid {
		return nil, nil
	}

	plan := TSVectorCodec{}.PlanEncode(nil, 0, TextFormatCode, t)
	buf, err := plan.Encode(t, nil)
	if err != nil {
		return nil, err
	}
	return string(buf), nil
}
// TSVectorWeight represents the weight label of a lexeme position in a tsvector.
type TSVectorWeight byte

// Weight labels as they appear in the PostgreSQL text representation.
// D is the default weight and is omitted from the text form.
const (
	TSVectorWeightA = TSVectorWeight('A')
	TSVectorWeightB = TSVectorWeight('B')
	TSVectorWeightC = TSVectorWeight('C')
	TSVectorWeightD = TSVectorWeight('D')
)
// tsvectorWeightToBinary converts a TSVectorWeight to the 2-bit binary encoding used by PostgreSQL.
func tsvectorWeightToBinary(w TSVectorWeight) uint16 {
	if w == TSVectorWeightA {
		return 3
	}
	if w == TSVectorWeightB {
		return 2
	}
	if w == TSVectorWeightC {
		return 1
	}
	return 0 // D or unset
}
// tsvectorWeightFromBinary converts a 2-bit binary weight value to a TSVectorWeight.
// Values above 3 (impossible for a 2-bit field) map to the default weight D.
func tsvectorWeightFromBinary(b uint16) TSVectorWeight {
	if b > 3 {
		return TSVectorWeightD
	}
	table := [4]TSVectorWeight{
		TSVectorWeightD, // 0
		TSVectorWeightC, // 1
		TSVectorWeightB, // 2
		TSVectorWeightA, // 3
	}
	return table[b]
}
// TSVectorPosition represents a lexeme position and its optional weight within a tsvector.
type TSVectorPosition struct {
	Position uint16 // the binary wire format packs positions into 14 bits
	Weight   TSVectorWeight
}
// String renders the position in PostgreSQL text form: the decimal position,
// followed by the weight letter unless the weight is unset or the default D.
func (p TSVectorPosition) String() string {
	out := strconv.FormatUint(uint64(p.Position), 10)
	switch p.Weight {
	case 0, TSVectorWeightD:
		// Default weight is not rendered.
	default:
		out += string(p.Weight)
	}
	return out
}
// TSVectorCodec is a [Codec] for the PostgreSQL tsvector type.
type TSVectorCodec struct{}

// FormatSupported reports whether the format code is supported; both the
// text and binary formats are.
func (TSVectorCodec) FormatSupported(format int16) bool {
	return format == TextFormatCode || format == BinaryFormatCode
}

// PreferredFormat returns the binary format code.
func (TSVectorCodec) PreferredFormat() int16 {
	return BinaryFormatCode
}
// PlanEncode returns an encode plan for values implementing [TSVectorValuer],
// or nil when the value or format is unsupported.
func (TSVectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
	if _, ok := value.(TSVectorValuer); !ok {
		return nil
	}

	if format == BinaryFormatCode {
		return encodePlanTSVectorCodecBinary{}
	}
	if format == TextFormatCode {
		return encodePlanTSVectorCodecText{}
	}
	return nil
}
type encodePlanTSVectorCodecBinary struct{}

// Encode writes the binary wire format of a tsvector: an int32 lexeme count
// followed by, for each lexeme, a null-terminated word, a uint16 position
// count, and the packed uint16 positions (weight in the top 2 bits, position
// in the low 14 bits). A NULL tsvector encodes to nil.
func (encodePlanTSVectorCodecBinary) Encode(value any, buf []byte) ([]byte, error) {
	tsv, err := value.(TSVectorValuer).TSVectorValue()
	if err != nil {
		return nil, err
	}
	if !tsv.Valid {
		return nil, nil
	}

	buf = pgio.AppendInt32(buf, int32(len(tsv.Lexemes)))
	for _, entry := range tsv.Lexemes {
		// Words are null-terminated on the wire, so an embedded null byte
		// cannot be represented and would corrupt the message framing.
		if strings.IndexByte(entry.Word, 0x00) >= 0 {
			return nil, fmt.Errorf("tsvector lexeme %q contains a null byte", entry.Word)
		}
		buf = append(buf, entry.Word...)
		buf = append(buf, 0x00)

		// The position count is a uint16 on the wire; reject counts that
		// would silently wrap around.
		if len(entry.Positions) > 0xFFFF {
			return nil, fmt.Errorf("tsvector lexeme %q has too many positions: %d", entry.Word, len(entry.Positions))
		}
		buf = pgio.AppendUint16(buf, uint16(len(entry.Positions)))

		// Each position is a uint16: weight (2 bits) | position (14 bits)
		for _, pos := range entry.Positions {
			packed := tsvectorWeightToBinary(pos.Weight)<<14 | uint16(pos.Position)&0x3FFF
			buf = pgio.AppendUint16(buf, packed)
		}
	}
	return buf, nil
}
type scanPlanBinaryTSVectorToTSVectorScanner struct{}

// Scan decodes the binary wire format of a tsvector (int32 lexeme count, then
// per lexeme a null-terminated word, uint16 position count, and packed uint16
// positions) into a [TSVectorScanner]. A nil src scans as SQL NULL.
func (scanPlanBinaryTSVectorToTSVectorScanner) Scan(src []byte, dst any) error {
	scanner := (dst).(TSVectorScanner)

	if src == nil {
		return scanner.ScanTSVector(TSVector{})
	}

	rp := 0
	const (
		uint16Len = 2
		uint32Len = 4
	)

	if len(src[rp:]) < uint32Len {
		return fmt.Errorf("tsvector incomplete %v", src)
	}
	entryCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
	rp += uint32Len

	if entryCount < 0 {
		return fmt.Errorf("tsvector invalid lexeme count: %d", entryCount)
	}
	// Each lexeme requires at least 3 bytes: a null terminator and a uint16
	// position count. Bounding the count by the available data prevents a
	// huge allocation from a corrupt or malicious length prefix, mirroring
	// the hstore and multirange decoders.
	if entryCount > len(src[rp:])/3 {
		return fmt.Errorf("tsvector lexeme count %d exceeds available data", entryCount)
	}

	var tsv TSVector
	if entryCount > 0 {
		tsv.Lexemes = make([]TSVectorLexeme, entryCount)
	}

	for i := range entryCount {
		nullIndex := bytes.IndexByte(src[rp:], 0x00)
		if nullIndex == -1 {
			return fmt.Errorf("invalid tsvector binary format: missing null terminator")
		}
		lexeme := TSVectorLexeme{Word: string(src[rp : rp+nullIndex])}
		rp += nullIndex + 1 // skip past null terminator

		// Read position count.
		if len(src[rp:]) < uint16Len {
			return fmt.Errorf("invalid tsvector binary format: incomplete position count")
		}
		numPositions := int(binary.BigEndian.Uint16(src[rp:]))
		rp += uint16Len

		// Read each packed position: weight (2 bits) | position (14 bits)
		if len(src[rp:]) < numPositions*uint16Len {
			return fmt.Errorf("invalid tsvector binary format: incomplete positions")
		}
		if numPositions > 0 {
			lexeme.Positions = make([]TSVectorPosition, numPositions)
			for pos := range numPositions {
				packed := binary.BigEndian.Uint16(src[rp:])
				rp += uint16Len
				lexeme.Positions[pos] = TSVectorPosition{
					Position: packed & 0x3FFF,
					Weight:   tsvectorWeightFromBinary(packed >> 14),
				}
			}
		}
		tsv.Lexemes[i] = lexeme
	}
	tsv.Valid = true
	return scanner.ScanTSVector(tsv)
}
// tsvectorLexemeReplacer escapes backslashes and single quotes so a lexeme
// can be embedded in a single-quoted string in the tsvector text format.
var tsvectorLexemeReplacer = strings.NewReplacer(
	`\`, `\\`,
	`'`, `\'`,
)
type encodePlanTSVectorCodecText struct{}

// Encode renders the tsvector in PostgreSQL text form: space-separated
// quoted lexemes, each optionally followed by ':' and a comma-separated
// position list. A NULL tsvector encodes to nil.
func (encodePlanTSVectorCodecText) Encode(value any, buf []byte) ([]byte, error) {
	tsv, err := value.(TSVectorValuer).TSVectorValue()
	if err != nil {
		return nil, err
	}
	if !tsv.Valid {
		return nil, nil
	}

	// A valid but empty tsvector must produce a non-nil buffer, since a nil
	// result would be interpreted as SQL NULL.
	if buf == nil {
		buf = []byte{}
	}

	for i, lexeme := range tsv.Lexemes {
		if i != 0 {
			buf = append(buf, ' ')
		}

		buf = append(buf, '\'')
		buf = append(buf, tsvectorLexemeReplacer.Replace(lexeme.Word)...)
		buf = append(buf, '\'')

		for j, p := range lexeme.Positions {
			if j == 0 {
				buf = append(buf, ':')
			} else {
				buf = append(buf, ',')
			}
			buf = append(buf, p.String()...)
		}
	}
	return buf, nil
}
// PlanScan returns a scan plan for targets implementing [TSVectorScanner],
// or nil when the target or format is unsupported.
func (TSVectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
	if _, ok := target.(TSVectorScanner); !ok {
		return nil
	}

	switch format {
	case BinaryFormatCode:
		return scanPlanBinaryTSVectorToTSVectorScanner{}
	case TextFormatCode:
		return scanPlanTextAnyToTSVectorScanner{}
	}
	return nil
}
type scanPlanTextAnyToTSVectorScanner struct{}

// Scan parses the text representation in src into a [TSVectorScanner].
// A nil src scans as SQL NULL.
func (s scanPlanTextAnyToTSVectorScanner) Scan(src []byte, dst any) error {
	scanner := (dst).(TSVectorScanner)

	if src == nil {
		return scanner.ScanTSVector(TSVector{})
	}

	return s.scanString(string(src), scanner)
}

// scanString parses src as a tsvector text representation and hands the
// result to scanner.
func (scanPlanTextAnyToTSVectorScanner) scanString(src string, scanner TSVectorScanner) error {
	parsed, parseErr := parseTSVector(src)
	if parseErr != nil {
		return parseErr
	}
	return scanner.ScanTSVector(parsed)
}
// DecodeDatabaseSQLValue decodes src into a [driver.Value] by rendering it
// in the text format.
func (c TSVectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	return codecDecodeToTextFormat(c, m, oid, format, src)
}

// DecodeValue decodes src into a [TSVector]. A nil src decodes as nil.
func (c TSVectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	var tsv TSVector
	if err := codecScan(c, m, oid, format, src, &tsv); err != nil {
		return nil, err
	}
	return tsv, nil
}
// tsvectorParser is a cursor over the text representation of a tsvector.
type tsvectorParser struct {
	str string // input being parsed
	pos int    // index of the next unread byte
}

// atEnd reports whether the input is exhausted.
func (p *tsvectorParser) atEnd() bool {
	return p.pos >= len(p.str)
}

// peek returns the next unread byte without advancing. Callers must check
// atEnd first; peek panics (index out of range) on exhausted input.
func (p *tsvectorParser) peek() byte {
	return p.str[p.pos]
}

// consume returns the next byte and advances the cursor. The second result
// is true when the input was already exhausted, in which case the byte is 0.
func (p *tsvectorParser) consume() (byte, bool) {
	if p.pos >= len(p.str) {
		return 0, true
	}
	b := p.str[p.pos]
	p.pos++
	return b, false
}

// consumeSpaces advances past any run of space characters.
func (p *tsvectorParser) consumeSpaces() {
	for !p.atEnd() && p.peek() == ' ' {
		p.consume()
	}
}
// consumeLexeme consumes a single-quoted lexeme, handling doubled quotes ('')
// and backslash escapes. It returns the unescaped word.
func (p *tsvectorParser) consumeLexeme() (string, error) {
	first, exhausted := p.consume()
	if exhausted || first != '\'' {
		return "", fmt.Errorf("invalid tsvector format: lexeme must start with a single quote")
	}

	var word strings.Builder
	for {
		ch, exhausted := p.consume()
		if exhausted {
			return "", fmt.Errorf("invalid tsvector format: unterminated quoted lexeme")
		}

		if ch == '\'' {
			if p.atEnd() || p.peek() != '\'' {
				// Closing quote — lexeme is complete.
				return word.String(), nil
			}
			// A doubled quote ('') encodes one literal single quote.
			p.consume()
			word.WriteByte('\'')
			continue
		}

		if ch == '\\' {
			escaped, exhausted := p.consume()
			if exhausted {
				return "", fmt.Errorf("invalid tsvector format: unexpected end after backslash")
			}
			word.WriteByte(escaped)
			continue
		}

		word.WriteByte(ch)
	}
}
// consumePositions consumes a comma-separated list of position[weight] values.
func (p *tsvectorParser) consumePositions() ([]TSVectorPosition, error) {
	var out []TSVectorPosition
	for {
		next, err := p.consumePosition()
		if err != nil {
			return nil, err
		}
		out = append(out, next)

		// Stop unless another comma-separated position follows.
		if p.atEnd() || p.peek() != ',' {
			return out, nil
		}
		p.consume() // skip ','
	}
}
// consumePosition consumes a single position number with optional weight letter.
func (p *tsvectorParser) consumePosition() (TSVectorPosition, error) {
	digitsStart := p.pos
	for !p.atEnd() {
		if c := p.peek(); c < '0' || c > '9' {
			break
		}
		p.consume()
	}
	digits := p.str[digitsStart:p.pos]
	if digits == "" {
		return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: expected position number")
	}

	num, err := strconv.ParseUint(digits, 10, 16)
	if err != nil {
		return TSVectorPosition{}, fmt.Errorf("invalid tsvector format: invalid position number %q", digits)
	}

	result := TSVectorPosition{Position: uint16(num), Weight: TSVectorWeightD}

	// An optional weight letter may immediately follow the number.
	if p.atEnd() {
		return result, nil
	}
	switch p.peek() {
	case 'A', 'a':
		result.Weight = TSVectorWeightA
	case 'B', 'b':
		result.Weight = TSVectorWeightB
	case 'C', 'c':
		result.Weight = TSVectorWeightC
	case 'D', 'd':
		result.Weight = TSVectorWeightD
	default:
		return result, nil
	}
	p.consume()
	return result, nil
}
// parseTSVector parses a PostgreSQL tsvector text representation: a
// whitespace-separated sequence of quoted lexemes, each optionally followed
// by ':' and a position list.
func parseTSVector(s string) (TSVector, error) {
	p := &tsvectorParser{str: strings.TrimSpace(s)}
	var lexemes []TSVectorLexeme

	for {
		p.consumeSpaces()
		if p.atEnd() {
			break
		}

		word, err := p.consumeLexeme()
		if err != nil {
			return TSVector{}, err
		}

		lexeme := TSVectorLexeme{Word: word}
		// Positions are optional and introduced by ':'.
		if !p.atEnd() && p.peek() == ':' {
			p.consume() // skip ':'
			positions, err := p.consumePositions()
			if err != nil {
				return TSVector{}, err
			}
			lexeme.Positions = positions
		}
		lexemes = append(lexemes, lexeme)
	}

	return TSVector{Lexemes: lexemes, Valid: true}, nil
}

View File

@ -122,7 +122,7 @@ type ShouldPingParams struct {
type Config struct {
ConnConfig *pgx.ConnConfig
// BeforeConnect is called before a new connection is made. It is passed a copy of the underlying pgx.ConnConfig and
// BeforeConnect is called before a new connection is made. It is passed a copy of the underlying [pgx.ConnConfig] and
// will not impact any existing open connections.
BeforeConnect func(context.Context, *pgx.ConnConfig) error
@ -218,7 +218,7 @@ func New(ctx context.Context, connString string) (*Pool, error) {
return NewWithConfig(ctx, config)
}
// NewWithConfig creates a new Pool. config must have been created by [ParseConfig].
// NewWithConfig creates a new [Pool]. config must have been created by [ParseConfig].
func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
// zero values.
@ -453,7 +453,7 @@ func ParseConfig(connString string) (*Config, error) {
return config, nil
}
// Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned
// Close closes all connections in the pool and rejects future [Pool.Acquire] calls. Blocks until all connections are returned
// to pool and closed.
func (p *Pool) Close() {
p.closeOnce.Do(func() {
@ -558,7 +558,8 @@ func (p *Pool) checkMinConns() error {
// off this check
// Create the number of connections needed to get to both minConns and minIdleConns
toCreate := max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns())
stat := p.Stat()
toCreate := max(p.minConns-stat.TotalConns(), p.minIdleConns-stat.IdleConns())
if toCreate > 0 {
return p.createIdleResources(context.Background(), int(toCreate))
}
@ -594,7 +595,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
return firstError
}
// Acquire returns a connection (*Conn) from the Pool
// Acquire returns a connection ([Conn]) from the [Pool].
func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
if p.acquireTracer != nil {
ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{})
@ -620,14 +621,15 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()}
if p.shouldPing(ctx, shouldPingParams) {
pingCtx := ctx
if p.pingTimeout > 0 {
var cancel context.CancelFunc
pingCtx, cancel = context.WithTimeout(ctx, p.pingTimeout)
defer cancel()
}
err := cr.conn.Ping(pingCtx)
err := func() error {
pingCtx := ctx
if p.pingTimeout > 0 {
var cancel context.CancelFunc
pingCtx, cancel = context.WithTimeout(ctx, p.pingTimeout)
defer cancel()
}
return cr.conn.Ping(pingCtx)
}()
if err != nil {
res.Destroy()
continue
@ -652,11 +654,11 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
return cr.getConn(p, res), nil
}
return nil, errors.New("pgxpool: detected infinite loop acquiring connection; likely bug in PrepareConn or BeforeAcquire hook")
return nil, errors.New("pgxpool: too many failed attempts acquiring connection; likely bug in PrepareConn, BeforeAcquire, or ShouldPing hook")
}
// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the
// call of f. The return value is either an error acquiring the *Conn or the return value of f. The *Conn is
// AcquireFunc acquires a [Conn] and calls f with that [Conn]. ctx will only affect the [Pool.Acquire]. It has no effect on the
// call of f. The return value is either an error acquiring the [Conn] or the return value of f. The [Conn] is
// automatically released after the call of f.
func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error {
conn, err := p.Acquire(ctx)
@ -697,7 +699,7 @@ func (p *Pool) Reset() {
p.p.Reset()
}
// Config returns a copy of config that was used to initialize this pool.
// Config returns a copy of config that was used to initialize this [Pool].
func (p *Pool) Config() *Config { return p.config.Copy() }
// Stat returns a pgxpool.Stat struct with a snapshot of Pool statistics.
@ -710,10 +712,10 @@ func (p *Pool) Stat() *Stat {
}
}
// Exec acquires a connection from the Pool and executes the given SQL.
// Exec acquires a connection from the [Pool] and executes the given SQL.
// SQL can be either a prepared statement name or an SQL string.
// Arguments should be referenced positionally from the SQL string as $1, $2, etc.
// The acquired connection is returned to the pool when the Exec function returns.
// The acquired connection is returned to the pool when the [Pool.Exec] function returns.
func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
c, err := p.Acquire(ctx)
if err != nil {
@ -724,15 +726,15 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C
return c.Exec(ctx, sql, arguments...)
}
// Query acquires a connection and executes a query that returns pgx.Rows.
// Query acquires a connection and executes a query that returns [pgx.Rows].
// Arguments should be referenced positionally from the SQL string as $1, $2, etc.
// See pgx.Rows documentation to close the returned Rows and return the acquired connection to the Pool.
// See [pgx.Rows] documentation to close the returned [pgx.Rows] and return the acquired connection to the [Pool].
//
// If there is an error, the returned pgx.Rows will be returned in an error state.
// If preferred, ignore the error returned from Query and handle errors using the returned pgx.Rows.
// If there is an error, the returned [pgx.Rows] will be returned in an error state.
// If preferred, ignore the error returned from [Pool.Query] and handle errors using the returned [pgx.Rows].
//
// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// For extra control over how the query is executed, the types [pgx.QueryExecMode], [pgx.QueryResultFormats], and
// [pgx.QueryResultFormatsByOID] may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
c, err := p.Acquire(ctx)
@ -750,16 +752,16 @@ func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, er
}
// QueryRow acquires a connection and executes a query that is expected
// to return at most one row (pgx.Row). Errors are deferred until pgx.Row's
// Scan method is called. If the query selects no rows, pgx.Row's Scan will
// return ErrNoRows. Otherwise, pgx.Row's Scan scans the first selected row
// and discards the rest. The acquired connection is returned to the Pool when
// pgx.Row's Scan method is called.
// to return at most one row ([pgx.Row]). Errors are deferred until [pgx.Row]'s
// Scan method is called. If the query selects no rows, [pgx.Row]'s Scan will
// return [pgx.ErrNoRows]. Otherwise, [pgx.Row]'s Scan scans the first selected row
// and discards the rest. The acquired connection is returned to the [Pool] when
// [pgx.Row]'s Scan method is called.
//
// Arguments should be referenced positionally from the SQL string as $1, $2, etc.
//
// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// For extra control over how the query is executed, the types [pgx.QueryExecMode], [pgx.QueryResultFormats], and
// [pgx.QueryResultFormatsByOID] may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
c, err := p.Acquire(ctx)
@ -781,18 +783,18 @@ func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return &poolBatchResults{br: br, c: c}
}
// Begin acquires a connection from the Pool and starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no
// auto-rollback on context cancellation. Begin initiates a transaction block without explicitly setting a transaction mode for the block (see BeginTx with TxOptions if transaction mode is required).
// *pgxpool.Tx is returned, which implements the pgx.Tx interface.
// Commit or Rollback must be called on the returned transaction to finalize the transaction block.
// Begin acquires a connection from the [Pool] and starts a transaction. Unlike [database/sql], the context only affects the begin command. i.e. there is no
// auto-rollback on context cancellation. Begin initiates a transaction block without explicitly setting a transaction mode for the block (see [Pool.BeginTx] with [pgx.TxOptions] if transaction mode is required).
// [*Tx] is returned, which implements the [pgx.Tx] interface.
// [Tx.Commit] or [Tx.Rollback] must be called on the returned transaction to finalize the transaction block.
func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) {
return p.BeginTx(ctx, pgx.TxOptions{})
}
// BeginTx acquires a connection from the Pool and starts a transaction with pgx.TxOptions determining the transaction mode.
// Unlike database/sql, the context only affects the begin command. i.e. there is no auto-rollback on context cancellation.
// *pgxpool.Tx is returned, which implements the pgx.Tx interface.
// Commit or Rollback must be called on the returned transaction to finalize the transaction block.
// BeginTx acquires a connection from the [Pool] and starts a transaction with [pgx.TxOptions] determining the transaction mode.
// Unlike [database/sql], the context only affects the begin command. i.e. there is no auto-rollback on context cancellation.
// [*Tx] is returned, which implements the [pgx.Tx] interface.
// [Tx.Commit] or [Tx.Rollback] must be called on the returned transaction to finalize the transaction block.
func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
c, err := p.Acquire(ctx)
if err != nil {
@ -818,8 +820,8 @@ func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNam
return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
}
// Ping acquires a connection from the Pool and executes an empty sql statement against it.
// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
// Ping acquires a connection from the [Pool] and executes an empty sql statement against it.
// If the sql returns without error, the database [Pool.Ping] is considered successful, otherwise, the error is returned.
func (p *Pool) Ping(ctx context.Context) error {
c, err := p.Acquire(ctx)
if err != nil {

View File

@ -13,12 +13,12 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
// Rows is the result set returned from *Conn.Query. Rows must be closed before
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
// calling Next() until it returns false, or when a fatal error occurs.
// Rows is the result set returned from [Conn.Query]. Rows must be closed before
// the [Conn] can be used again. Rows are closed by explicitly calling [Rows.Close],
// calling [Rows.Next] until it returns false, or when a fatal error occurs.
//
// Once a Rows is closed the only methods that may be called are Close(), Err(),
// and CommandTag().
// Once a Rows is closed the only methods that may be called are [Rows.Close], [Rows.Err],
// and [Rows.CommandTag].
//
// Rows is an interface instead of a struct to allow tests to mock Query. However,
// adding a method to an interface is technically a breaking change. Because of this
@ -46,9 +46,9 @@ type Rows interface {
// having been read or due to an error).
//
// Callers should check rows.Err() after rows.Next() returns false to detect whether result-set reading ended
// prematurely due to an error. See Conn.Query for details.
// prematurely due to an error. See [Conn.Query] for details.
//
// For simpler error handling, consider using the higher-level pgx v5 CollectRows() and ForEachRow() helpers instead.
// For simpler error handling, consider using the higher-level pgx v5 [CollectRows()] and [ForEachRow()] helpers instead.
Next() bool
// Scan reads the values from the current row into dest values positionally. dest can include pointers to core types,
@ -70,7 +70,7 @@ type Rows interface {
Conn() *Conn
}
// Row is a convenience wrapper over Rows that is returned by QueryRow.
// Row is a convenience wrapper over [Rows] that is returned by [Conn.QueryRow].
//
// Row is an interface instead of a struct to allow tests to mock QueryRow. However,
// adding a method to an interface is technically a breaking change. Because of this
@ -358,7 +358,7 @@ func (e ScanArgError) Unwrap() error {
return e.Err
}
// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface.
// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level [pgconn] interface.
//
// typeMap - OID to Go type mapping.
// fieldDescriptions - OID and format of values
@ -386,8 +386,8 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, v
return nil
}
// RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used
// to read from the lower level pgconn interface.
// RowsFromResultReader returns a [Rows] that will read from values resultReader and decode with typeMap. It can be used
// to read from the lower level [pgconn] interface.
func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows {
return &baseRows{
typeMap: typeMap,
@ -460,7 +460,7 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
}
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// CollectOneRow is to CollectRows as QueryRow is to Query.
// CollectOneRow is to [CollectRows] as [Conn.QueryRow] is to [Conn.Query].
//
// This function closes the rows automatically on return.
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
@ -529,7 +529,7 @@ func RowTo[T any](row CollectableRow) (T, error) {
return value, err
}
// RowTo returns a the address of a T scanned from row.
// RowToAddrOf returns the address of a T scanned from row.
func RowToAddrOf[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&value)
@ -848,7 +848,7 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize
}
}
}
return
return i
}
// structRowField describes a field of a struct.

View File

@ -555,6 +555,12 @@ func (c *Conn) ResetSession(ctx context.Context) error {
return driver.ErrBadConn
}
// Discard connection if it has an open transaction. This can happen if the
// application did not properly commit or rollback a transaction.
if c.conn.PgConn().TxStatus() != 'I' {
return driver.ErrBadConn
}
now := time.Now()
idle := now.Sub(c.lastResetSessionTime)
@ -662,7 +668,7 @@ func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
switch fd.DataTypeOID {
case pgtype.TextOID, pgtype.ByteaOID:
return math.MaxInt64, true
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
case pgtype.VarcharOID, pgtype.BPCharOID:
return int64(fd.TypeModifier - varHeaderSize), true
case pgtype.VarbitOID:
return int64(fd.TypeModifier), true
@ -693,25 +699,25 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
switch fd.DataTypeOID {
case pgtype.Float8OID:
return reflect.TypeOf(float64(0))
return reflect.TypeFor[float64]()
case pgtype.Float4OID:
return reflect.TypeOf(float32(0))
return reflect.TypeFor[float32]()
case pgtype.Int8OID:
return reflect.TypeOf(int64(0))
return reflect.TypeFor[int64]()
case pgtype.Int4OID:
return reflect.TypeOf(int32(0))
return reflect.TypeFor[int32]()
case pgtype.Int2OID:
return reflect.TypeOf(int16(0))
return reflect.TypeFor[int16]()
case pgtype.BoolOID:
return reflect.TypeOf(false)
return reflect.TypeFor[bool]()
case pgtype.NumericOID:
return reflect.TypeOf(float64(0))
return reflect.TypeFor[float64]()
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
return reflect.TypeOf(time.Time{})
return reflect.TypeFor[time.Time]()
case pgtype.ByteaOID:
return reflect.TypeOf([]byte(nil))
return reflect.TypeFor[[]byte]()
default:
return reflect.TypeOf("")
return reflect.TypeFor[string]()
}
}

170
vendor/github.com/jackc/pgx/v5/test.sh generated vendored Normal file
View File

@ -0,0 +1,170 @@
#!/usr/bin/env bash
set -euo pipefail
# test.sh - Run pgx tests against specific database targets
#
# Usage:
# ./test.sh [target] [go test flags...]
#
# Targets:
# pg14 - PostgreSQL 14 (port 5414)
# pg15 - PostgreSQL 15 (port 5415)
# pg16 - PostgreSQL 16 (port 5416)
# pg17 - PostgreSQL 17 (port 5417)
# pg18 - PostgreSQL 18 (port 5432) [default]
# crdb - CockroachDB (port 26257)
# all - Run against all targets sequentially
#
# Examples:
# ./test.sh # Test against PG18
# ./test.sh pg14 # Test against PG14
# ./test.sh crdb # Test against CockroachDB
# ./test.sh all # Test against all targets
# ./test.sh pg16 -run TestConnect # Test specific test against PG16
# ./test.sh pg18 -count=1 -v # Verbose, no cache, PG18
# Color output (disabled if not a terminal).
# Default to no color, then turn the ANSI escapes on only when stdout is a TTY.
GREEN='' RED='' BLUE='' NC=''
if [ -t 1 ]; then
    GREEN='\033[0;32m'
    RED='\033[0;31m'
    BLUE='\033[0;34m'
    NC='\033[0m'
fi

# Logging helpers: each prints one "==> message" line in its color.
# log_err writes to stderr so failures survive stdout redirection.
log_info() { echo -e "${BLUE}==> $*${NC}"; }
log_ok() { echo -e "${GREEN}==> $*${NC}"; }
log_err() { echo -e "${RED}==> $*${NC}" >&2; }
# wait_for_ready CONNSTR LABEL
# Polls the database behind CONNSTR with a trivial query, once per second,
# until it answers or 30 attempts have been made. Returns 1 on timeout.
wait_for_ready() {
    local conn="$1" name="$2"
    local max_attempts=30
    local tries=0

    log_info "Waiting for $name to be ready..."
    until psql "$conn" -c "SELECT 1" > /dev/null 2>&1; do
        tries=$((tries + 1))
        if [ "$tries" -ge "$max_attempts" ]; then
            log_err "$name did not become ready after $max_attempts attempts"
            return 1
        fi
        sleep 1
    done
    log_ok "$name is ready"
}
# Directory containing this script (used to locate testsetup/)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# Base64-encoded client TLS material shipped with the repo; decoded into
# /tmp by setup_client_certs before the TLS connection-string tests run.
CERTS_DIR="$SCRIPT_DIR/testsetup/certs"
# setup_client_certs
# Decodes the base64-encoded CA and client certificate/key from
# testsetup/certs into /tmp, where the TLS connection strings expect
# them. Silently does nothing when the certs directory is absent.
setup_client_certs() {
    [ -d "$CERTS_DIR" ] || return 0
    local f
    for f in ca.pem pgx_sslcert.crt pgx_sslcert.key; do
        base64 -d "$CERTS_DIR/$f.b64" > "/tmp/$f"
    done
}
# init_crdb
# Blocks until the local CockroachDB node answers queries, then makes
# sure the pgx_test database exists (best-effort: errors are ignored,
# e.g. when the database is already present).
init_crdb() {
    local url="postgresql://root@localhost:26257/?sslmode=disable"

    wait_for_ready "$url" "CockroachDB"
    log_info "Ensuring pgx_test database exists on CockroachDB..."
    psql "$url" -c "CREATE DATABASE IF NOT EXISTS pgx_test" 2>/dev/null || true
}
# Run tests against a single target.
#
# run_tests TARGET [go test flags...]
# Maps TARGET to a label and port, exports the PGX_TEST_* connection
# strings for that database, and runs the suite once with -count=1.
# Returns non-zero for an unknown target or a test failure.
run_tests() {
    local target="$1"
    shift
    local extra_args=("$@")
    local label=""
    local port=""

    case "$target" in
        pg14) label="PostgreSQL 14"; port=5414 ;;
        pg15) label="PostgreSQL 15"; port=5415 ;;
        pg16) label="PostgreSQL 16"; port=5416 ;;
        pg17) label="PostgreSQL 17"; port=5417 ;;
        pg18) label="PostgreSQL 18"; port=5432 ;;
        crdb)
            # CockroachDB speaks the PG wire protocol but needs its own
            # connection string and database bootstrap, so it is handled
            # entirely inside this case arm.
            label="CockroachDB (port 26257)"
            init_crdb
            log_info "Testing against $label"
            # ${arr[@]+"${arr[@]}"} guards the expansion so an empty array
            # does not trip `set -u` on bash < 4.4 (e.g. stock macOS bash).
            if ! PGX_TEST_DATABASE="postgresql://root@localhost:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" \
                go test -count=1 ${extra_args[@]+"${extra_args[@]}"} ./...; then
                log_err "Tests FAILED against $label"
                return 1
            fi
            log_ok "Tests passed against $label"
            return 0
            ;;
        *)
            log_err "Unknown target: $target"
            log_err "Valid targets: pg14, pg15, pg16, pg17, pg18, crdb, all"
            return 1
            ;;
    esac

    setup_client_certs
    log_info "Testing against $label (port $port)"
    # Empty-array expansion guarded for `set -u` on bash < 4.4; see above.
    if ! PGX_TEST_DATABASE="host=localhost port=$port user=postgres password=postgres dbname=pgx_test" \
        PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql port=$port user=postgres dbname=pgx_test" \
        PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \
        PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_md5 password=secret dbname=pgx_test" \
        PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_scram password=secret dbname=pgx_test channel_binding=disable" \
        PGX_TEST_SCRAM_PLUS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=require" \
        PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 port=$port user=pgx_pw password=secret dbname=pgx_test" \
        PGX_TEST_TLS_CONN_STRING="host=localhost port=$port user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test channel_binding=disable" \
        PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost port=$port user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" \
        PGX_SSL_PASSWORD=certpw \
        go test -count=1 ${extra_args[@]+"${extra_args[@]}"} ./...; then
        log_err "Tests FAILED against $label"
        return 1
    fi
    log_ok "Tests passed against $label"
}
# Main
#
# main [target] [go test flags...]
# Dispatches to run_tests for a single target (default: pg18), or loops
# over every known target when the first argument is "all", reporting
# which ones failed at the end.
main() {
    local target="${1:-pg18}"
    shift || true

    if [ "$target" != "all" ]; then
        run_tests "$target" "$@"
        return
    fi

    local failed=()
    for t in pg14 pg15 pg16 pg17 pg18 crdb; do
        echo ""
        log_info "=========================================="
        log_info "Target: $t"
        log_info "=========================================="
        if ! run_tests "$t" "$@"; then
            failed+=("$t")
            log_err "FAILED: $t"
        fi
    done

    echo ""
    if [ ${#failed[@]} -gt 0 ]; then
        log_err "Failed targets: ${failed[*]}"
        return 1
    fi
    log_ok "All targets passed"
}

main "$@"

12
vendor/github.com/jackc/pgx/v5/tx.go generated vendored
View File

@ -89,13 +89,13 @@ var ErrTxClosed = errors.New("tx is closed")
// it is treated as ROLLBACK.
var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
// Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no
// Begin starts a transaction. Unlike [database/sql], the context only affects the begin command. i.e. there is no
// auto-rollback on context cancellation.
func (c *Conn) Begin(ctx context.Context) (Tx, error) {
return c.BeginTx(ctx, TxOptions{})
}
// BeginTx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only
// BeginTx starts a transaction with txOptions determining the transaction mode. Unlike [database/sql], the context only
// affects the begin command. i.e. there is no auto-rollback on context cancellation.
func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
_, err := c.Exec(ctx, txOptions.beginSQL())
@ -385,8 +385,8 @@ func (sp *dbSimulatedNestedTx) Conn() *Conn {
return sp.tx.Conn()
}
// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls [Tx.Commit] on db. If fn
// returns an error it calls [Tx.Rollback] on db. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
func BeginFunc(
ctx context.Context,
@ -404,8 +404,8 @@ func BeginFunc(
return beginFuncExec(ctx, tx, fn)
}
// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls [Tx.Commit] on db. If fn
// returns an error it calls [Tx.Rollback] on db. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
func BeginTxFunc(
ctx context.Context,

4
vendor/modules.txt vendored
View File

@ -56,8 +56,8 @@ github.com/jackc/pgpassfile
# github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
## explicit; go 1.14
github.com/jackc/pgservicefile
# github.com/jackc/pgx/v5 v5.8.0
## explicit; go 1.24.0
# github.com/jackc/pgx/v5 v5.9.2
## explicit; go 1.25.0
github.com/jackc/pgx/v5
github.com/jackc/pgx/v5/internal/iobufpool
github.com/jackc/pgx/v5/internal/pgio