3
0
mirror of https://github.com/ergochat/ergo.git synced 2026-02-19 15:58:00 +01:00

Compare commits

...

72 Commits

Author SHA1 Message Date
Shivaram Lingamneni
11a38b7119
bump irctest to test websockets (#2331)
Install python3-websockets for irctest in CI. This won't actually work
as expected until Ubuntu 26.04.
2026-02-16 18:19:04 -05:00
Shivaram Lingamneni
c79b65cf8a bump irc-go 2026-02-11 04:01:40 -05:00
Shivaram Lingamneni
8fe491156c
upgrade to go 1.26 (#2330)
Also upgrade actions for compatibility with 1.26
2026-02-10 19:19:08 -05:00
Shivaram Lingamneni
5fa9757b43
fix #2328 (#2329)
* Add API endpoint /v1/ns/passwd to change passwords
* New naming scheme for API endpoints under /v1/ns/
* /v1/ns/list no longer returns error responses
2026-02-08 01:27:13 -05:00
Shivaram Lingamneni
2b1dad982f
clarify when history.channel-length applies (#2323) 2026-01-27 08:13:21 -08:00
Shivaram Lingamneni
3ff0341944 bump irctest 2026-01-26 12:22:31 -05:00
Shivaram Lingamneni
2152000312
use correct name for RPL_INVEXLIST (#2326)
Rename the constants only; no functional change
2026-01-11 02:27:37 -05:00
Shivaram Lingamneni
2da19a0760
fix #2324 (#2325)
Validate user limit parameter
2026-01-10 18:57:49 -05:00
Shivaram Lingamneni
5c0af196da fix (noopSequence).Ephemeral() 2026-01-01 15:06:53 -05:00
Shivaram Lingamneni
63a04a5ff0 bump irctest 2026-01-01 12:33:12 -05:00
Shivaram Lingamneni
f28b10986e
redact: fix error line for missing channel (#2321) 2025-12-31 14:56:23 -05:00
Shivaram Lingamneni
d2c9c80cc1 fix mysql logline 2025-12-31 02:37:00 -05:00
Shivaram Lingamneni
6386b9ef70
fix some redact bugs (#2320)
* Consistently return UNKNOWN_MSGID for unknown or invalid msgids
* If both client's DMs are stored in persistent history, a single
  server.DeleteMessage will delete the single canonical copy of the message.
  So the second call will fail, which is fine.
2025-12-31 02:15:11 -05:00
Shivaram Lingamneni
6ba60c89c4 move mysql serialization tools to shared pkgs 2025-12-30 23:34:05 -05:00
Shivaram Lingamneni
0b7be24f80 propagate mysql close error 2025-12-30 23:13:55 -05:00
Shivaram Lingamneni
9f54ea07b7
prep for alternative history databases (#2316)
* abstract history DB interface

* make mysql error logging consistent

Consistently propagate database errors to the client, making the client
responsible for logging them.

* move ListCorrespondents from Sequence to Database/Buffer
2025-12-30 23:12:30 -05:00
Shivaram Lingamneni
3179fdffb0
REDACT: fix reason handling (#2319)
REDACT without a reason parameter was being relayed with an
empty reason parameter instead; fix this.
2025-12-30 18:41:24 -05:00
Shivaram Lingamneni
13653b5202
fix #2284 (#2317)
Send a configurable NOTICE on first connection that hopm can
detect via target_string
2025-12-30 02:27:09 -05:00
Shivaram Lingamneni
b26571e891 bump irctest 2025-12-29 12:41:20 -05:00
Shivaram Lingamneni
b7ede0730f
consistent casing for SQL queries (#2315)
Makes it easier to read diffs with other database backends
2025-12-29 00:03:58 -05:00
Shivaram Lingamneni
05b41e18af bump irctest 2025-12-28 21:47:27 -05:00
Shivaram Lingamneni
748700877e
dependency upgrades for v2.18 release cycle (#2314) 2025-12-23 00:07:39 -05:00
Shivaram Lingamneni
462e568f00
fix #2311 (#2312)
Validate bcrypt-cost config value to prevent silent errors
2025-12-22 03:26:09 -05:00
Shivaram Lingamneni
3c4c5dde4d
fix #2309 (#2310)
If bob is monitoring alice, bob should get METADATA lines for alice
even if bob doesn't have extended-monitor.

This implementation also removes the check for extended-monitor
when sending account-notify, away-notify, chghost, setname
(which are explicitly mentioned in the extended-monitor spec)
on the grounds that nothing bad will happen if clients who support
the cap receive notifications for users they're not explicitly tracking.
2025-12-22 03:25:27 -05:00
Shivaram Lingamneni
3e9fa38f6b set up new development version 2025-12-22 03:22:05 -05:00
Shivaram Lingamneni
16186d677d bump version and changelog for v2.17.0 2025-12-22 02:34:40 -05:00
Shivaram Lingamneni
d5fb189a55
changelog and version bump for v2.17.0-rc1 (#2308) 2025-12-14 04:43:30 -05:00
Shivaram Lingamneni
53664694c4
add RPL_WHOISKEYVALUE output (#2302) 2025-12-14 04:32:02 -05:00
Shivaram Lingamneni
d26aa37f2c
pin Alpine to 3.22 (#2306)
* Upgrade the production image from 3.19 to 3.22 (3.19 went EOL
  2025-11-01)
* Downgrade the build image to 3.22 (3.23 is buggy, see #2305)
2025-12-14 02:50:21 -05:00
Shivaram Lingamneni
9ca936a777 bump irctest 2025-12-14 02:15:44 -05:00
Shivaram Lingamneni
fdd261a1e6
fix #2303 (#2304)
Fix inconsistent behavior when history.enabled is set but
history.chathistory-maxmessages is not
2025-12-14 02:13:15 -05:00
Shivaram Lingamneni
aef5d77b3b
Merge pull request #2301 from slingamn/ratelimited
more metadata follow-ups
2025-12-08 01:48:39 -05:00
Shivaram Lingamneni
0ce9016098 add persistence for user metadata 2025-12-07 03:05:34 -05:00
Shivaram Lingamneni
f91d1d94f6 add METADATA response when MONITOR triggered 2025-12-01 07:03:26 +00:00
Shivaram Lingamneni
0119bbc36f implement FAIL METADATA RATE_LIMITED 2025-12-01 04:09:45 +00:00
Shivaram Lingamneni
96aa018352 bump irctest 2025-11-26 03:27:13 -05:00
Shivaram Lingamneni
68faf82787
Merge pull request #2299 from ergochat/shivaram_ping.1
configurable idle timeouts
2025-11-09 21:49:39 -05:00
Shivaram Lingamneni
5cda5bdac9
Merge pull request #2298 from slingamn/shivaram_operthrottle
changes to OPER command
2025-11-09 21:41:53 -05:00
Shivaram Lingamneni
ed841ee62a configurable idle timeouts
Fixes #2292
2025-11-09 21:11:04 -05:00
Shivaram Lingamneni
6fdac13ad4 changes to OPER command
* Impose a throttle on OPER attempts regardless of whether they caused a
  password check.
* Never disconnect the client on a failed attempt, even if there was a
  password check.
* Change error numeric to ERR_NOOPERHOST
* Explicit information about the failure in the server log (copying Insp)

Fixes #2296.
2025-11-09 19:34:31 -05:00
Shivaram Lingamneni
efc1627d23
Merge pull request #2295 from slingamn/shivaram_pushreject
fix validation of web push URLs
2025-10-26 00:58:52 -04:00
Shivaram Lingamneni
6b8265fb17 fix validation of web push URLs
They are validated by test message, but it would have been possible
to add an http url.

If an http url was added, it's still possible to remove it via
NS PUSH DELETE.
2025-10-26 00:55:30 -04:00
Shivaram Lingamneni
92f069846c
Merge pull request #2290 from slingamn/goupgrade
upgrade to go 1.25
2025-08-17 21:11:02 -07:00
Shivaram Lingamneni
8913bd7fa9 upgrade to go 1.25 2025-08-18 00:07:11 -04:00
Shivaram Lingamneni
064291e902 add an explicit note covering #2289 2025-08-12 16:23:36 -04:00
Shivaram Lingamneni
65295cbafa
Merge pull request #2283 from slingamn/makefile
refactor makefile to label individual targets phony
2025-06-26 22:17:06 -04:00
Shivaram Lingamneni
f0b1f34da7 refactor makefile to label individual targets phony 2025-06-26 22:14:19 -04:00
Shivaram Lingamneni
f918e28513 bump irctest 2025-06-26 01:31:51 -04:00
Shivaram Lingamneni
8798676ae9
update metadata corresponding to spec edits (#2282)
* spec update: metadata keys are lowercase

* add batch parameter to metadata batches

* fix: connecting clients receive METADATA, not RPL_KEYVALUE

* spec update: send RPL_METADATASUBS in a metadata-subs batch

* move some helpers

* bump irctest to forked hash

This is https://github.com/progval/irctest/pull/314 but I don't want to
couple the merges

* fix: empty value is valid

* fix: deleting a nonexistent key gets a FAIL
2025-06-22 18:59:42 -04:00
Shivaram Lingamneni
cca400de73 fix: actually broadcast prereg updates to subscribers
Missed in #2281, needs a test presumably :-)
2025-06-22 13:59:36 -04:00
Shivaram Lingamneni
73e51333ad
implement metadata before-connect (#2281)
* metadata spec update: disallow colon entirely

* refactor key validation

* implement metadata before-connect

* play the metadata in reg burst to all clients with the cap

* bump irctest

* remove all case normalization for keys

From spec discussion, we will most likely either require keys to be lowercase,
or else treat them as case-opaque, similar to message tag keys.
2025-06-22 13:57:46 -04:00
Shivaram Lingamneni
a5e435a26b bump irctest 2025-06-21 22:04:33 -04:00
Shivaram Lingamneni
17ed01c1ed
Merge pull request #2279 from slingamn/doc
fix some documentation
2025-06-19 13:29:08 -04:00
Shivaram Lingamneni
8f18454e8f fix help string for HISTORY 2025-06-19 13:25:34 -04:00
Shivaram Lingamneni
23844d4103 update documentation for globalUtf8EnforcementSetting 2025-06-19 13:22:07 -04:00
Shivaram Lingamneni
3b7db7fff7
round 1 of follow-up for metadata (#2277)
* refactoring
* send an empty batch if necessary, as per spec
* don't broadcast no-op updates
* don't trim spaces before validating the key
* bump irctest to cover metadata
* replay existing metadata to reattaching always-on clients
* use canonicalized name everywhere
* use utils.SafeErrorParam in FAIL lines
* validate key names for sub
* fix error for METADATA CLEAR
* max-keys is enforced for channels as well
* remove unlimited configurations
* maintain the limit exactly without off-by-one cases
* add final channel registration check
2025-06-18 00:22:49 -04:00
thatcher-gaming
4dcbc48159
metadata-2 (#2273)
Initial implementation of draft/metadata-2
2025-06-15 04:06:45 -04:00
Shivaram Lingamneni
0f5603eca2 bump irctest to upstream master 2025-06-09 02:20:49 -04:00
Shivaram Lingamneni
7d4f5e4adf
Merge pull request #2271 from slingamn/register
fix #2270
2025-06-09 02:19:55 -04:00
Shivaram Lingamneni
16568c5ab7 fix #2270
REGISTER should strip the guest format when applicable, same as NS REGISTER.
2025-06-08 16:50:34 -04:00
Shivaram Lingamneni
9a186f8e54
Fix invalid FAIL codes in REGISTER (#2269)
* nickserv.go: Update FAIL codes to match spec

* handlers.go: Fix FAIL code

* use ACCOUNT_EXISTS for errNameReserved

* bump irctest to development version

---------

Co-authored-by: Valerie Liu <79415174+ValwareIRC@users.noreply.github.com>
2025-06-08 01:43:43 -04:00
Shivaram Lingamneni
7828218bc7
Merge pull request #2264 from slingamn/cutprefix
fix #2147
2025-05-26 01:56:06 -04:00
Shivaram Lingamneni
7138e76151 fix #2147
use strings.CutPrefix when possible
2025-05-25 01:59:55 -04:00
Sarah Rose
e4aac56bda
API enhancements (#2261)
Fixes #2257 and #2260

* add `/v1/status` endpoint
* add `/v1/account_list` endpoint
* add fields to `/v1/account_details` response
2025-05-25 00:47:20 -04:00
Shivaram Lingamneni
4da6511674
Merge pull request #2262 from slingamn/constant
clean up constant redefinition
2025-05-23 00:36:46 -04:00
Shivaram Lingamneni
253972a9d2 clean up constant redefinition 2025-05-23 00:18:36 -04:00
Shivaram Lingamneni
a1c46a4be7 clarify channel registration instructions 2025-05-19 00:02:04 -04:00
Shivaram Lingamneni
7718081440
Merge pull request #2258 from slingamn/deps
upgrade dependencies for 2.17 release cycle
2025-05-18 02:07:40 -04:00
Shivaram Lingamneni
e7501ef847 upgrade go-msgauth 2025-05-18 01:28:48 -04:00
Shivaram Lingamneni
e404942d83 upgrade x dependencies 2025-05-18 01:27:52 -04:00
Shivaram Lingamneni
0a947115d6 set up new development version 2025-05-18 01:15:11 -04:00
Shivaram Lingamneni
9b9c39ddd4 changelog entry for API config 2025-05-18 01:09:13 -04:00
192 changed files with 10054 additions and 1851 deletions

View File

@ -15,13 +15,13 @@ jobs:
runs-on: "ubuntu-24.04" runs-on: "ubuntu-24.04"
steps: steps:
- name: "checkout repository" - name: "checkout repository"
uses: "actions/checkout@v3" uses: "actions/checkout@v6"
- name: "setup go" - name: "setup go"
uses: "actions/setup-go@v3" uses: "actions/setup-go@v6"
with: with:
go-version: "1.24" go-version: "1.26"
- name: "install python3-pytest" - name: "install python3-pytest"
run: "sudo apt install -y python3-pytest" run: "sudo apt install -y python3-pytest python3-websockets"
- name: "make install" - name: "make install"
run: "make install" run: "make install"
- name: "make test" - name: "make test"

View File

@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout Git repository - name: Checkout Git repository
uses: actions/checkout@v3 uses: actions/checkout@v6
- name: Authenticate to container registry - name: Authenticate to container registry
uses: docker/login-action@v2 uses: docker/login-action@v2

View File

@ -1,6 +1,37 @@
# Changelog # Changelog
All notable changes to Ergo will be documented in this file. All notable changes to Ergo will be documented in this file.
## [2.17.0] - 2025-12-22
We're pleased to be publishing v2.17.0, a new stable release. This release adds support for the [IRCv3 metadata specification](https://ircv3.net/specs/extensions/metadata), thanks to [@thatcher-gaming](https://github.com/thatcher-gaming), as well as bug fixes and minor updates.
This release includes changes to the config file format, all of which are fully backwards-compatible and do not require updating the file before upgrading. It includes no changes to the database file format.
Many thanks to [@branchgrove](https://github.com/branchgrove), [@Brutus5000](https://github.com/Brutus5000), [@progval](https://github.com/progval), [@SarahRoseLives](https://github.com/SarahRoseLives), [@thatcher-gaming](https://github.com/thatcher-gaming), [@ValwareIRC](https://github.com/ValwareIRC), and Xogium for contributing patches, reporting issues, and helping test.
### Config changes
* Added `accounts.metadata` block to configure the new metadata feature. If this block is absent, metadata is disabled. See `default.yaml` for an example. (#2273)
* Added `server.idle-timeouts` for configurable idle timeouts; when unset, the previous hardcoded defaults are used (#2292, thanks [@Brutus5000](https://github.com/Brutus5000)!)
* Added `server.oper-throttle` to configure throttling for failed `OPER` attempts; when unset, this defaults to 1 attempt every 10 seconds (#2296)
### Added
* Implemented support for the [draft/metadata-2](https://ircv3.net/specs/extensions/metadata) specification, allowing clients to set and retrieve metadata on accounts and channels (#2273, #2277, #2281, #2282, #2301, thanks [@thatcher-gaming](https://github.com/thatcher-gaming)!)
* Added `/v1/status` and `/v1/account_list` HTTP API endpoints (#2261, thanks [@SarahRoseLives](https://github.com/SarahRoseLives)!)
* Enhanced `/v1/account_details` API response with additional fields (#2261, thanks [@SarahRoseLives](https://github.com/SarahRoseLives)!)
### Fixed
* Fixed `REGISTER` command to strip guest format when applicable, matching `NS REGISTER` behavior (#2270, #2271, thanks [@ValwareIRC](https://github.com/ValwareIRC) and [@thatcher-gaming](https://github.com/thatcher-gaming)!)
* Fixed invalid `FAIL` codes in `REGISTER` command (#2269, thanks [@ValwareIRC](https://github.com/ValwareIRC)!)
* Fixed validation of web push URLs to reject non-HTTPS URLs (#2295)
* Fixed inconsistent behavior when `history.enabled` is set but `history.chathistory-maxmessages` is not (#2303, #2304, thanks [@branchgrove](https://github.com/branchgrove)!)
### Changed
* The `OPER` command now imposes a throttle on all attempts, never disconnects the client on failure, and logs non-sensitive information about failed attempts (#2296, #2298, thanks Xogium!)
### Internal
* Official release builds use Go 1.25 (#2290)
* Upgraded the Docker base image from Alpine 3.19 to 3.22 (#2306)
## [2.16.0] - 2025-05-18 ## [2.16.0] - 2025-05-18
We're pleased to be publishing v2.16.0, a new stable release. This release contains bug fixes and some minor updates. We're pleased to be publishing v2.16.0, a new stable release. This release contains bug fixes and some minor updates.
@ -9,6 +40,7 @@ This release includes changes to the config file format, all of which are fully
Many thanks to [@csmith](https://github.com/csmith), [@delthas](https://github.com/delthas), donio, [@emersion](https://github.com/emersion), [@KlaasT](https://github.com/KlaasT), [@knolley](https://github.com/knolley), [@Mailaender](https://github.com/Mailaender), and [@prdes](https://github.com/prdes) for reporting issues and helping test. Many thanks to [@csmith](https://github.com/csmith), [@delthas](https://github.com/delthas), donio, [@emersion](https://github.com/emersion), [@KlaasT](https://github.com/KlaasT), [@knolley](https://github.com/knolley), [@Mailaender](https://github.com/Mailaender), and [@prdes](https://github.com/prdes) for reporting issues and helping test.
### Config changes ### Config changes
* Added `api` block for configuring the new HTTP API. If this block is absent, the API is disabled (#2231)
* Added `server.additional-isupport` for publishing arbitrary ISUPPORT tokens (#2220, #2240) * Added `server.additional-isupport` for publishing arbitrary ISUPPORT tokens (#2220, #2240)
* Added `server.command-aliases` to configure aliases for server commands (#2229, #2236) * Added `server.command-aliases` to configure aliases for server commands (#2229, #2236)
* Added options to `roleplay` to customize the NUH's sent for `NPC` and `SCENE`. Roleplay remains deprecated and disabled by default. (#2237) * Added options to `roleplay` to customize the NUH's sent for `NPC` and `SCENE`. Roleplay remains deprecated and disabled by default. (#2237)

View File

@ -1,5 +1,5 @@
## build ergo binary ## build ergo binary
FROM docker.io/golang:1.24-alpine AS build-env FROM docker.io/golang:1.26-alpine3.22 AS build-env
RUN apk upgrade -U --force-refresh --no-cache && apk add --no-cache --purge --clean-protected -l -u make git RUN apk upgrade -U --force-refresh --no-cache && apk add --no-cache --purge --clean-protected -l -u make git
@ -16,7 +16,7 @@ RUN sed -i 's/^\(\s*\)\"127.0.0.1:6667\":.*$/\1":6667":/' /go/src/github.com/erg
RUN make install RUN make install
## build ergo container ## build ergo container
FROM docker.io/alpine:3.19 FROM docker.io/alpine:3.22
# metadata # metadata
LABEL maintainer="Daniel Oaks <daniel@danieloaks.net>,Daniel Thamdrup <dallemon@protonmail.com>" \ LABEL maintainer="Daniel Oaks <daniel@danieloaks.net>,Daniel Thamdrup <dallemon@protonmail.com>" \

View File

@ -1,5 +1,3 @@
.PHONY: all install build release capdefs test smoke gofmt irctest
GIT_COMMIT := $(shell git rev-parse HEAD 2> /dev/null) GIT_COMMIT := $(shell git rev-parse HEAD 2> /dev/null)
GIT_TAG := $(shell git tag --points-at HEAD 2> /dev/null | head -n 1) GIT_TAG := $(shell git tag --points-at HEAD 2> /dev/null | head -n 1)
@ -9,33 +7,42 @@ export CGO_ENABLED ?= 0
capdef_file = ./irc/caps/defs.go capdef_file = ./irc/caps/defs.go
.PHONY: all
all: build all: build
.PHONY: install
install: install:
go install -v -ldflags "-X main.commit=$(GIT_COMMIT) -X main.version=$(GIT_TAG)" go install -v -ldflags "-X main.commit=$(GIT_COMMIT) -X main.version=$(GIT_TAG)"
.PHONY: build
build: build:
go build -v -ldflags "-X main.commit=$(GIT_COMMIT) -X main.version=$(GIT_TAG)" go build -v -ldflags "-X main.commit=$(GIT_COMMIT) -X main.version=$(GIT_TAG)"
.PHONY: release
release: release:
goreleaser --skip=publish --clean goreleaser --skip=publish --clean
.PHONY: capdefs
capdefs: capdefs:
python3 ./gencapdefs.py > ${capdef_file} python3 ./gencapdefs.py > ${capdef_file}
.PHONY: test
test: test:
python3 ./gencapdefs.py | diff - ${capdef_file} python3 ./gencapdefs.py | diff - ${capdef_file}
go test ./... go test ./...
go vet ./... go vet ./...
./.check-gofmt.sh ./.check-gofmt.sh
.PHONY: smoke
smoke: install smoke: install
ergo mkcerts --conf ./default.yaml || true ergo mkcerts --conf ./default.yaml || true
ergo run --conf ./default.yaml --smoke ergo run --conf ./default.yaml --smoke
.PHONY: gofmt
gofmt: gofmt:
./.check-gofmt.sh --fix ./.check-gofmt.sh --fix
.PHONY: irctest
irctest: install irctest: install
git submodule update --init git submodule update --init
cd irctest && make ergo cd irctest && make ergo

View File

@ -180,6 +180,21 @@ server:
# if this is true, the motd is escaped using formatting codes like $c, $b, and $i # if this is true, the motd is escaped using formatting codes like $c, $b, and $i
motd-formatting: true motd-formatting: true
# send a configurable notice to clients immediately after they connect,
# as a way of detecting open proxies
#initial-notice: "*** Welcome to the Ergo IRC server"
# idle timeouts for inactive clients
idle-timeouts:
# give the client this long to complete connection registration (i.e. the initial
# IRC handshake, including capability negotiation and SASL)
registration: 60s
# if the client hasn't sent anything for this long, send them a PING
ping: 1m30s
# if the client hasn't sent anything for this long (including the PONG to the
# above PING), disconnect them
disconnect: 2m30s
# relaying using the RELAYMSG command # relaying using the RELAYMSG command
relaymsg: relaymsg:
# is relaymsg enabled at all? # is relaymsg enabled at all?
@ -358,6 +373,10 @@ server:
secure-nets: secure-nets:
# - "10.0.0.0/8" # - "10.0.0.0/8"
# allow attempts to OPER with a password at most this often. defaults to
# 10 seconds when unset.
oper-throttle: 10s
# Ergo will write files to disk under certain circumstances, e.g., # Ergo will write files to disk under certain circumstances, e.g.,
# CPU profiling or data export. by default, these files will be written # CPU profiling or data export. by default, these files will be written
# to the working directory. set this to customize: # to the working directory. set this to customize:
@ -522,7 +541,7 @@ accounts:
# 1. these nicknames cannot be registered or reserved # 1. these nicknames cannot be registered or reserved
# 2. if a client is automatically renamed by the server, # 2. if a client is automatically renamed by the server,
# this is the template that will be used (e.g., Guest-nccj6rgmt97cg) # this is the template that will be used (e.g., Guest-nccj6rgmt97cg)
# 3. if enforce-guest-format (see below) is enabled, clients without # 3. if force-guest-format (see below) is enabled, clients without
# a registered account will have this template applied to their # a registered account will have this template applied to their
# nicknames (e.g., 'katie' will become 'Guest-katie') # nicknames (e.g., 'katie' will become 'Guest-katie')
guest-nickname-format: "Guest-*" guest-nickname-format: "Guest-*"
@ -724,6 +743,7 @@ oper-classes:
- "history" # modify or delete history messages - "history" # modify or delete history messages
- "defcon" # use the DEFCON command (restrict server capabilities) - "defcon" # use the DEFCON command (restrict server capabilities)
- "massmessage" # message all users on the server - "massmessage" # message all users on the server
- "metadata" # modify arbitrary metadata on channels and users
# ircd operators # ircd operators
opers: opers:
@ -987,10 +1007,12 @@ history:
# in your country and the countries of your users. # in your country and the countries of your users.
enabled: true enabled: true
# how many channel-specific events (messages, joins, parts) should be tracked per channel? # if the in-memory backend is enabled for a channel, how many channel-specific events
# (messages, joins, parts) should be retained?
channel-length: 2048 channel-length: 2048
# how many direct messages and notices should be tracked per user? # if the in-memory backend is enabled for a user, how many direct messages
# and notices should be retained?
client-length: 256 client-length: 256
# how long should we try to preserve messages? # how long should we try to preserve messages?
@ -1087,6 +1109,20 @@ history:
# e.g., ERGO__SERVER__MAX_SENDQ=128k. see the manual for more details. # e.g., ERGO__SERVER__MAX_SENDQ=128k. see the manual for more details.
allow-environment-overrides: true allow-environment-overrides: true
# metadata support for setting key/value data on channels and nicknames.
metadata:
# can clients store metadata?
enabled: true
# how many keys can a client subscribe to?
max-subs: 100
# how many keys can be stored per entity?
max-keys: 100
# rate limiting for client metadata updates, which are expensive to process
client-throttle:
enabled: true
duration: 2m
max-attempts: 10
# experimental support for mobile push notifications # experimental support for mobile push notifications
# see the manual for potential security, privacy, and performance implications. # see the manual for potential security, privacy, and performance implications.
# DO NOT enable if you are running a Tor or I2P hidden service (i.e. one # DO NOT enable if you are running a Tor or I2P hidden service (i.e. one

View File

@ -39,19 +39,6 @@ This returns:
Endpoints Endpoints
========= =========
`/v1/account_details`
----------------
This endpoint fetches account details and returns them as JSON. The request is a JSON object with fields:
* `accountName`: string, name of the account
The response is a JSON object with fields:
* `success`: whether the account exists or not
* `accountName`: canonical, case-unfolded version of the account name
* `email`: email address of the account provided
`/v1/check_auth` `/v1/check_auth`
---------------- ----------------
@ -65,16 +52,53 @@ The response is a JSON object with fields:
* `success`: whether the credentials provided were valid * `success`: whether the credentials provided were valid
* `accountName`: canonical, case-unfolded version of the account name * `accountName`: canonical, case-unfolded version of the account name
`/v1/rehash` `/v1/ns/info`
------------ -------------
This endpoint rehashes the server (i.e. reloads the configuration file, TLS certificates, and other associated data). The body is ignored. The response is a JSON object with fields: This endpoint fetches account details and returns them as JSON. The request is a JSON object with fields:
* `success`: boolean, indicates whether the rehash was successful * `accountName`: string, name of the account
* `error`: string, optional, human-readable description of the failure
`/v1/saregister` The response is a JSON object with fields:
----------------
* `success`: whether the account exists or not
* `accountName`: canonical, case-unfolded version of the account name
* `email`: email address of the account provided
* `registeredAt`: string, registration date/time of the account (in ISO8601 format)
* `channels`: array of strings, list of channels the account is registered on or associated with
Note: this endpoint was previously named `/v1/account_details`. The old name is still accepted for backwards compatibility.
`/v1/ns/list`
-------------
This endpoint fetches a list of all accounts. The request body is ignored and can be empty.
The response is a JSON object with fields:
* `success`: whether the request succeeded
* `accounts`: array of objects, each with fields:
* `success`: boolean, whether this individual account query succeeded
* `accountName`: string, canonical, case-unfolded version of the account name
* `totalCount`: integer, total number of accounts returned
Note: this endpoint was previously named `/v1/account_list`. The old name is still accepted for backwards compatibility.
`/v1/ns/passwd`
---------------
This endpoint changes the password of an existing NickServ account. The request is a JSON object with fields:
* `accountName`: string, name of the account
* `passphrase`: string, new passphrase for the account
The response is a JSON object with fields:
* `success`: whether the password change succeeded
* `errorCode`: string, optional, machine-readable description of the error. Possible values include: `ACCOUNT_DOES_NOT_EXIST`, `INVALID_PASSPHRASE`, `CREDENTIALS_EXTERNALLY_MANAGED`, `UNKNOWN_ERROR`.
`/v1/ns/saregister`
-------------------
This endpoint registers an account in NickServ, with the same semantics as `NS SAREGISTER`. The request is a JSON object with fields: This endpoint registers an account in NickServ, with the same semantics as `NS SAREGISTER`. The request is a JSON object with fields:
@ -86,3 +110,33 @@ The response is a JSON object with fields:
* `success`: whether the account creation succeeded * `success`: whether the account creation succeeded
* `errorCode`: string, optional, machine-readable description of the error. Possible values include: `ACCOUNT_EXISTS`, `INVALID_PASSPHRASE`, `UNKNOWN_ERROR`. * `errorCode`: string, optional, machine-readable description of the error. Possible values include: `ACCOUNT_EXISTS`, `INVALID_PASSPHRASE`, `UNKNOWN_ERROR`.
* `error`: string, optional, human-readable description of the failure. * `error`: string, optional, human-readable description of the failure.
Note: this endpoint was previously named `/v1/saregister`. The old name is still accepted for backwards compatibility.
`/v1/rehash`
------------
This endpoint rehashes the server (i.e. reloads the configuration file, TLS certificates, and other associated data). The body is ignored. The response is a JSON object with fields:
* `success`: boolean, indicates whether the rehash was successful
* `error`: string, optional, human-readable description of the failure
`/v1/status`
------------
This endpoint returns status information about the running Ergo server. The request body is ignored and can be empty.
The response is a JSON object with fields:
* `success`: whether the request succeeded
* `version`: string, Ergo server version string
* `go_version`: string, version of Go runtime used
* `start_time`: string, server start time in ISO8601 format
* `users`: object with fields:
* `total`: total number of users connected
* `invisible`: number of invisible users
* `operators`: number of operators connected
* `unknown`: number of users with unknown status
* `max`: maximum number of users seen connected at once
* `channels`: integer, number of channels currently active
* `servers`: integer, number of servers connected in the network

View File

@ -171,6 +171,7 @@ Rehashing also reloads TLS certificates and the MOTD. Some configuration setting
Ergo can also be configured using environment variables, using the following technique: Ergo can also be configured using environment variables, using the following technique:
1. Ensure that `allow-environment-variables` is set to `true` in the YAML config file itself (see `default.yaml` for an example)
1. Find the "path" of the config variable you want to override in the YAML file, e.g., `server.websockets.allowed-origins` 1. Find the "path" of the config variable you want to override in the YAML file, e.g., `server.websockets.allowed-origins`
1. Convert each path component from "kebab case" to "screaming snake case", e.g., `SERVER`, `WEBSOCKETS`, and `ALLOWED_ORIGINS`. 1. Convert each path component from "kebab case" to "screaming snake case", e.g., `SERVER`, `WEBSOCKETS`, and `ALLOWED_ORIGINS`.
1. Prepend `ERGO` to the components, then join them all together using `__` as the separator, e.g., `ERGO__SERVER__WEBSOCKETS__ALLOWED_ORIGINS`. 1. Prepend `ERGO` to the components, then join them all together using `__` as the separator, e.g., `ERGO__SERVER__WEBSOCKETS__ALLOWED_ORIGINS`.
@ -526,7 +527,7 @@ If your client or bot is failing to connect to Ergo, here are some things to che
## Why can't I oper? ## Why can't I oper?
If you try to oper unsuccessfully, Ergo will disconnect you from the network. If you're unable to oper, here are some things to double-check: If your `OPER` command fails, check your server logs for more information. Here are some general issues to double-check:
1. Did you correctly generate the hashed password with `ergo genpasswd`? 1. Did you correctly generate the hashed password with `ergo genpasswd`?
1. Did you add the password hash to the correct config file, then save the file? 1. Did you add the password hash to the correct config file, then save the file?
@ -1071,9 +1072,12 @@ You can import user and channel registrations from an Anope or Atheme database i
## Hybrid Open Proxy Monitor (HOPM) ## Hybrid Open Proxy Monitor (HOPM)
[hopm](https://github.com/ircd-hybrid/hopm) can be used to monitor your server for connections from open proxies, then automatically ban them. To configure hopm to work with Ergo, add operator blocks like this to your Ergo config file, which grant hopm the necessary privileges: [hopm](https://github.com/ircd-hybrid/hopm) can be used to monitor your server for connections from open proxies, then automatically ban them. To configure hopm to work with Ergo, configure `server.initial_notice` and add operator blocks like this to your Ergo config file, which grant hopm the necessary privileges:
````yaml ````yaml
server:
initial-notice: "Welcome to the Ergo IRC server"
# operator classes # operator classes
oper-classes: oper-classes:
# hopm # hopm
@ -1108,6 +1112,9 @@ opers:
Then configure hopm like this: Then configure hopm like this:
```` ````
/* replace with the exact notice your server sends on first connection */
target_string = ":ergo.test NOTICE * :*** Welcome to the Ergo IRC server"
/* ergo */ /* ergo */
connregex = ".+-.+CONNECT.+-.+ Client Connected \\[([^ ]+)\\] \\[u:([^ ]+)\\] \\[h:([^ ]+)\\] \\[ip:([^ ]+)\\] .+"; connregex = ".+-.+CONNECT.+-.+ Client Connected \\[([^ ]+)\\] \\[u:([^ ]+)\\] \\[h:([^ ]+)\\] \\[ip:([^ ]+)\\] .+";

View File

@ -86,7 +86,7 @@ Once you've registered your nickname, you can use it to register channels. By de
/msg ChanServ register #myChannel /msg ChanServ register #myChannel
``` ```
You must already be an operator (have the `+o` channel mode --- your client may display this as an `@` next to your nickname). If you're not a channel operator in the channel you want to register, ask your server administrator for help. The channel must exist (if it doesn't, you can create it with `/join #myChannel`) and you must already be an operator (have the `+o` channel mode --- your client may display this as an `@` next to your nickname). If you're not a channel operator in the channel you want to register, ask your server administrator for help.
# Always-on # Always-on

View File

@ -237,6 +237,13 @@ CAPDEFS = [
url="https://github.com/ircv3/ircv3-specifications/pull/471", url="https://github.com/ircv3/ircv3-specifications/pull/471",
standard="Soju/Goguma vendor", standard="Soju/Goguma vendor",
), ),
CapDef(
identifier="Metadata",
name="draft/metadata-2",
url="https://ircv3.net/specs/extensions/metadata",
standard="draft IRCv3",
),
] ]
def validate_defs(): def validate_defs():

19
go.mod
View File

@ -1,6 +1,6 @@
module github.com/ergochat/ergo module github.com/ergochat/ergo
go 1.24 go 1.26
require ( require (
code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48 code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48
@ -8,8 +8,8 @@ require (
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815
github.com/ergochat/confusables v0.0.0-20201108231250-4ab98ab61fb1 github.com/ergochat/confusables v0.0.0-20201108231250-4ab98ab61fb1
github.com/ergochat/go-ident v0.0.0-20230911071154-8c30606d6881 github.com/ergochat/go-ident v0.0.0-20230911071154-8c30606d6881
github.com/ergochat/irc-go v0.5.0-rc2 github.com/ergochat/irc-go v0.5.0
github.com/go-sql-driver/mysql v1.7.0 github.com/go-sql-driver/mysql v1.9.3
github.com/gofrs/flock v0.8.1 github.com/gofrs/flock v0.8.1
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
@ -18,19 +18,20 @@ require (
github.com/stretchr/testify v1.4.0 // indirect github.com/stretchr/testify v1.4.0 // indirect
github.com/tidwall/buntdb v1.3.2 github.com/tidwall/buntdb v1.3.2
github.com/xdg-go/scram v1.0.2 github.com/xdg-go/scram v1.0.2
golang.org/x/crypto v0.32.0 golang.org/x/crypto v0.46.0
golang.org/x/term v0.28.0 golang.org/x/term v0.38.0
golang.org/x/text v0.21.0 golang.org/x/text v0.32.0
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
) )
require ( require (
github.com/emersion/go-msgauth v0.6.8 github.com/emersion/go-msgauth v0.7.0
github.com/ergochat/webpush-go/v2 v2.0.0 github.com/ergochat/webpush-go/v2 v2.0.0
github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang-jwt/jwt/v5 v5.3.0
) )
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/tidwall/btree v1.4.2 // indirect github.com/tidwall/btree v1.4.2 // indirect
github.com/tidwall/gjson v1.14.3 // indirect github.com/tidwall/gjson v1.14.3 // indirect
github.com/tidwall/grect v0.1.4 // indirect github.com/tidwall/grect v0.1.4 // indirect
@ -39,7 +40,7 @@ require (
github.com/tidwall/rtred v0.1.2 // indirect github.com/tidwall/rtred v0.1.2 // indirect
github.com/tidwall/tinyqueue v0.1.1 // indirect github.com/tidwall/tinyqueue v0.1.1 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect
golang.org/x/sys v0.29.0 // indirect golang.org/x/sys v0.39.0 // indirect
) )
replace github.com/gorilla/websocket => github.com/ergochat/websocket v1.4.2-oragono1 replace github.com/gorilla/websocket => github.com/ergochat/websocket v1.4.2-oragono1

38
go.sum
View File

@ -1,19 +1,21 @@
code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48 h1:/EMHruHCFXR9xClkGV/t0rmHrdhX4+trQUcBqjwc9xE= code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48 h1:/EMHruHCFXR9xClkGV/t0rmHrdhX4+trQUcBqjwc9xE=
code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48/go.mod h1:wN/zk7mhREp/oviagqUXY3EwuHhWyOvAdsn5Y4CzOrc= code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48/go.mod h1:wN/zk7mhREp/oviagqUXY3EwuHhWyOvAdsn5Y4CzOrc=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 h1:KeNholpO2xKjgaaSyd+DyQRrsQjhbSeS7qe4nEw8aQw= github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 h1:KeNholpO2xKjgaaSyd+DyQRrsQjhbSeS7qe4nEw8aQw=
github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962/go.mod h1:kC29dT1vFpj7py2OvG1khBdQpo3kInWP+6QipLbdngo= github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962/go.mod h1:kC29dT1vFpj7py2OvG1khBdQpo3kInWP+6QipLbdngo=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh6AyO7hdCn/PkvCZXii8TGj7sbtEbQ= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh6AyO7hdCn/PkvCZXii8TGj7sbtEbQ=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
github.com/emersion/go-msgauth v0.6.8 h1:kW/0E9E8Zx5CdKsERC/WnAvnXvX7q9wTHia1OA4944A= github.com/emersion/go-msgauth v0.7.0 h1:vj2hMn6KhFtW41kshIBTXvp6KgYSqpA/ZN9Pv4g1INc=
github.com/emersion/go-msgauth v0.6.8/go.mod h1:YDwuyTCUHu9xxmAeVj0eW4INnwB6NNZoPdLerpSxRrc= github.com/emersion/go-msgauth v0.7.0/go.mod h1:mmS9I6HkSovrNgq0HNXTeu8l3sRAAuQ9RMvbM4KU7Ck=
github.com/ergochat/confusables v0.0.0-20201108231250-4ab98ab61fb1 h1:WLHTOodthVyv5NvYLIvWl112kSFv5IInKKrRN2qpons= github.com/ergochat/confusables v0.0.0-20201108231250-4ab98ab61fb1 h1:WLHTOodthVyv5NvYLIvWl112kSFv5IInKKrRN2qpons=
github.com/ergochat/confusables v0.0.0-20201108231250-4ab98ab61fb1/go.mod h1:mov+uh1DPWsltdQnOdzn08UO9GsJ3MEvhtu0Ci37fdk= github.com/ergochat/confusables v0.0.0-20201108231250-4ab98ab61fb1/go.mod h1:mov+uh1DPWsltdQnOdzn08UO9GsJ3MEvhtu0Ci37fdk=
github.com/ergochat/go-ident v0.0.0-20230911071154-8c30606d6881 h1:+J5m88nvybxB5AnBVGzTXM/yHVytt48rXBGcJGzSbms= github.com/ergochat/go-ident v0.0.0-20230911071154-8c30606d6881 h1:+J5m88nvybxB5AnBVGzTXM/yHVytt48rXBGcJGzSbms=
github.com/ergochat/go-ident v0.0.0-20230911071154-8c30606d6881/go.mod h1:ASYJtQujNitna6cVHsNQTGrfWvMPJ5Sa2lZlmsH65uM= github.com/ergochat/go-ident v0.0.0-20230911071154-8c30606d6881/go.mod h1:ASYJtQujNitna6cVHsNQTGrfWvMPJ5Sa2lZlmsH65uM=
github.com/ergochat/irc-go v0.5.0-rc2 h1:VuSQJF5K4hWvYSzGa4b8vgL6kzw8HF6LSOejE+RWpAo= github.com/ergochat/irc-go v0.5.0 h1:woQ1RS9YbfgqPgSpPBBQeczXGIGzR0aC7dEgk469fTw=
github.com/ergochat/irc-go v0.5.0-rc2/go.mod h1:2vi7KNpIPWnReB5hmLpl92eMywQvuIeIIGdt/FQCph0= github.com/ergochat/irc-go v0.5.0/go.mod h1:2vi7KNpIPWnReB5hmLpl92eMywQvuIeIIGdt/FQCph0=
github.com/ergochat/scram v1.0.2-ergo1 h1:2bYXiRFQH636pT0msOG39fmEYl4Eq+OuutcyDsCix/g= github.com/ergochat/scram v1.0.2-ergo1 h1:2bYXiRFQH636pT0msOG39fmEYl4Eq+OuutcyDsCix/g=
github.com/ergochat/scram v1.0.2-ergo1/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= github.com/ergochat/scram v1.0.2-ergo1/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs=
github.com/ergochat/webpush-go/v2 v2.0.0 h1:n6eoJk8RpzJFeBJ6gxvqo/dngnVEmJbzJwzKtCZbByo= github.com/ergochat/webpush-go/v2 v2.0.0 h1:n6eoJk8RpzJFeBJ6gxvqo/dngnVEmJbzJwzKtCZbByo=
@ -21,12 +23,12 @@ github.com/ergochat/webpush-go/v2 v2.0.0/go.mod h1:OQlhnq8JeHDzRzAy6bdDObr19uqbH
github.com/ergochat/websocket v1.4.2-oragono1 h1:plMUunFBM6UoSCIYCKKclTdy/TkkHfUslhOfJQzfueM= github.com/ergochat/websocket v1.4.2-oragono1 h1:plMUunFBM6UoSCIYCKKclTdy/TkkHfUslhOfJQzfueM=
github.com/ergochat/websocket v1.4.2-oragono1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/ergochat/websocket v1.4.2-oragono1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
@ -68,22 +70,22 @@ github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc= github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc=
github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -52,6 +52,7 @@ const (
// (not to be confused with their amodes, which a non-always-on client can have): // (not to be confused with their amodes, which a non-always-on client can have):
keyAccountChannelToModes = "account.channeltomodes %s" keyAccountChannelToModes = "account.channeltomodes %s"
keyAccountPushSubscriptions = "account.pushsubscriptions %s" keyAccountPushSubscriptions = "account.pushsubscriptions %s"
keyAccountMetadata = "account.metadata %s"
maxCertfpsPerAccount = 5 maxCertfpsPerAccount = 5
) )
@ -137,6 +138,7 @@ func (am *AccountManager) createAlwaysOnClients(config *Config) {
am.loadModes(accountName), am.loadModes(accountName),
am.loadRealname(accountName), am.loadRealname(accountName),
am.loadPushSubscriptions(accountName), am.loadPushSubscriptions(accountName),
am.loadMetadata(accountName),
) )
} }
} }
@ -751,6 +753,40 @@ func (am *AccountManager) loadPushSubscriptions(account string) (result []stored
} }
} }
func (am *AccountManager) saveMetadata(account string, metadata map[string]string) {
j, err := json.Marshal(metadata)
if err != nil {
am.server.logger.Error("internal", "error storing metadata", err.Error())
return
}
val := string(j)
key := fmt.Sprintf(keyAccountMetadata, account)
am.server.store.Update(func(tx *buntdb.Tx) error {
tx.Set(key, val, nil)
return nil
})
return
}
func (am *AccountManager) loadMetadata(account string) (result map[string]string) {
key := fmt.Sprintf(keyAccountMetadata, account)
var val string
am.server.store.View(func(tx *buntdb.Tx) error {
val, _ = tx.Get(key)
return nil
})
if val == "" {
return nil
}
if err := json.Unmarshal([]byte(val), &result); err == nil {
return result
} else {
am.server.logger.Error("internal", "error loading metadata", err.Error())
return nil
}
}
func (am *AccountManager) addRemoveCertfp(account, certfp string, add bool, hasPrivs bool) (err error) { func (am *AccountManager) addRemoveCertfp(account, certfp string, add bool, hasPrivs bool) (err error) {
certfp, err = utils.NormalizeCertfp(certfp) certfp, err = utils.NormalizeCertfp(certfp)
if err != nil { if err != nil {
@ -1880,6 +1916,7 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
pwResetKey := fmt.Sprintf(keyAccountPwReset, casefoldedAccount) pwResetKey := fmt.Sprintf(keyAccountPwReset, casefoldedAccount)
emailChangeKey := fmt.Sprintf(keyAccountEmailChange, casefoldedAccount) emailChangeKey := fmt.Sprintf(keyAccountEmailChange, casefoldedAccount)
pushSubscriptionsKey := fmt.Sprintf(keyAccountPushSubscriptions, casefoldedAccount) pushSubscriptionsKey := fmt.Sprintf(keyAccountPushSubscriptions, casefoldedAccount)
metadataKey := fmt.Sprintf(keyAccountMetadata, casefoldedAccount)
var clients []*Client var clients []*Client
defer func() { defer func() {
@ -1939,6 +1976,7 @@ func (am *AccountManager) Unregister(account string, erase bool) error {
tx.Delete(pwResetKey) tx.Delete(pwResetKey)
tx.Delete(emailChangeKey) tx.Delete(emailChangeKey)
tx.Delete(pushSubscriptionsKey) tx.Delete(pushSubscriptionsKey)
tx.Delete(metadataKey)
return nil return nil
}) })
@ -2299,7 +2337,7 @@ func (ac *AccountCredentials) Serialize() (result string, err error) {
return string(credText), nil return string(credText), nil
} }
func (ac *AccountCredentials) SetPassphrase(passphrase string, bcryptCost uint) (err error) { func (ac *AccountCredentials) SetPassphrase(passphrase string, bcryptCost int) (err error) {
if passphrase == "" { if passphrase == "" {
ac.PassphraseHash = nil ac.PassphraseHash = nil
ac.SCRAMCreds = SCRAMCreds{} ac.SCRAMCreds = SCRAMCreds{}
@ -2310,7 +2348,7 @@ func (ac *AccountCredentials) SetPassphrase(passphrase string, bcryptCost uint)
return errAccountBadPassphrase return errAccountBadPassphrase
} }
ac.PassphraseHash, err = passwd.GenerateFromPassword([]byte(passphrase), int(bcryptCost)) ac.PassphraseHash, err = passwd.GenerateFromPassword([]byte(passphrase), bcryptCost)
if err != nil { if err != nil {
return errAccountBadPassphrase return errAccountBadPassphrase
} }

View File

@ -5,7 +5,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"runtime"
"strings" "strings"
"github.com/ergochat/ergo/irc/utils"
) )
func newAPIHandler(server *Server) http.Handler { func newAPIHandler(server *Server) http.Handler {
@ -14,10 +17,23 @@ func newAPIHandler(server *Server) http.Handler {
mux: http.NewServeMux(), mux: http.NewServeMux(),
} }
// server-level functionality:
api.mux.HandleFunc("POST /v1/rehash", api.handleRehash) api.mux.HandleFunc("POST /v1/rehash", api.handleRehash)
api.mux.HandleFunc("POST /v1/status", api.handleStatus)
// use Ergo as a source of truth for authentication in other services:
api.mux.HandleFunc("POST /v1/check_auth", api.handleCheckAuth) api.mux.HandleFunc("POST /v1/check_auth", api.handleCheckAuth)
// legacy names for /v1/ns endpoints:
api.mux.HandleFunc("POST /v1/saregister", api.handleSaregister) api.mux.HandleFunc("POST /v1/saregister", api.handleSaregister)
api.mux.HandleFunc("POST /v1/account_details", api.handleAccountDetails) api.mux.HandleFunc("POST /v1/account_details", api.handleAccountDetails)
api.mux.HandleFunc("POST /v1/account_list", api.handleAccountList)
// /v1/ns: nickserv functionality
api.mux.HandleFunc("POST /v1/ns/info", api.handleAccountDetails)
api.mux.HandleFunc("POST /v1/ns/list", api.handleAccountList)
api.mux.HandleFunc("POST /v1/ns/passwd", api.handleNsPasswd)
api.mux.HandleFunc("POST /v1/ns/saregister", api.handleSaregister)
return api return api
} }
@ -29,7 +45,6 @@ type ergoAPI struct {
func (a *ergoAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (a *ergoAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer a.server.HandlePanic(nil) defer a.server.HandlePanic(nil)
defer a.server.logger.Debug("api", r.URL.Path) defer a.server.logger.Debug("api", r.URL.Path)
if a.checkBearerAuth(r.Header.Get("Authorization")) { if a.checkBearerAuth(r.Header.Get("Authorization")) {
@ -117,8 +132,6 @@ func (a *ergoAPI) handleCheckAuth(w http.ResponseWriter, r *http.Request) {
// try passphrase if present // try passphrase if present
if request.AccountName != "" && request.Passphrase != "" { if request.AccountName != "" && request.Passphrase != "" {
// TODO this only checks the internal database, not auth-script;
// it's a little weird to use both auth-script and the API but we should probably handle it
account, err := a.server.accounts.checkPassphrase(request.AccountName, request.Passphrase) account, err := a.server.accounts.checkPassphrase(request.AccountName, request.Passphrase)
switch err { switch err {
case nil: case nil:
@ -133,7 +146,6 @@ func (a *ergoAPI) handleCheckAuth(w http.ResponseWriter, r *http.Request) {
response.Error = err.Error() response.Error = err.Error()
} }
} }
// try certfp if present // try certfp if present
if !response.Success && request.Certfp != "" { if !response.Success && request.Certfp != "" {
// TODO support cerftp // TODO support cerftp
@ -173,10 +185,37 @@ func (a *ergoAPI) handleSaregister(w http.ResponseWriter, r *http.Request) {
a.writeJSONResponse(response, w, r) a.writeJSONResponse(response, w, r)
} }
func (a *ergoAPI) handleNsPasswd(w http.ResponseWriter, r *http.Request) {
var request apiSaregisterRequest
if err := a.decodeJSONRequest(&request, w, r); err != nil {
return
}
var response apiGenericResponse
err := a.server.accounts.setPassword(request.AccountName, request.Passphrase, true)
switch err {
case nil:
response.Success = true
case errAccountDoesNotExist:
response.ErrorCode = "ACCOUNT_DOES_NOT_EXIST"
case errAccountBadPassphrase, errEmptyCredentials:
response.ErrorCode = "INVALID_PASSPHRASE"
case errCredsExternallyManaged:
response.ErrorCode = "CREDENTIALS_EXTERNALLY_MANAGED"
default:
a.server.logger.Error("api", "could not change user password:", err.Error())
response.ErrorCode = "UNKNOWN_ERROR"
}
a.writeJSONResponse(response, w, r)
}
type apiAccountDetailsResponse struct { type apiAccountDetailsResponse struct {
apiGenericResponse apiGenericResponse
AccountName string `json:"accountName,omitempty"` AccountName string `json:"accountName,omitempty"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
RegisteredAt string `json:"registeredAt,omitempty"`
Channels []string `json:"channels,omitempty"`
} }
type apiAccountDetailsRequest struct { type apiAccountDetailsRequest struct {
@ -191,8 +230,6 @@ func (a *ergoAPI) handleAccountDetails(w http.ResponseWriter, r *http.Request) {
var response apiAccountDetailsResponse var response apiAccountDetailsResponse
// TODO could probably use better error handling and more details
if request.AccountName != "" { if request.AccountName != "" {
accountData, err := a.server.accounts.LoadAccount(request.AccountName) accountData, err := a.server.accounts.LoadAccount(request.AccountName)
if err == nil { if err == nil {
@ -207,6 +244,12 @@ func (a *ergoAPI) handleAccountDetails(w http.ResponseWriter, r *http.Request) {
case nil: case nil:
response.AccountName = accountData.Name response.AccountName = accountData.Name
response.Email = accountData.Settings.Email response.Email = accountData.Settings.Email
if !accountData.RegisteredAt.IsZero() {
response.RegisteredAt = accountData.RegisteredAt.Format(utils.IRCv3TimestampFormat)
}
// Get channels the account is in
response.Channels = a.server.channels.ChannelsForAccount(accountData.NameCasefolded)
response.Success = true response.Success = true
case errAccountDoesNotExist, errAccountUnverified, errAccountSuspended: case errAccountDoesNotExist, errAccountUnverified, errAccountSuspended:
response.Success = false response.Success = false
@ -222,3 +265,81 @@ func (a *ergoAPI) handleAccountDetails(w http.ResponseWriter, r *http.Request) {
a.writeJSONResponse(response, w, r) a.writeJSONResponse(response, w, r)
} }
type apiAccountListResponse struct {
apiGenericResponse
Accounts []apiAccountDetailsResponse `json:"accounts"`
TotalCount int `json:"totalCount"`
}
func (a *ergoAPI) handleAccountList(w http.ResponseWriter, r *http.Request) {
var response apiAccountListResponse
// Get all account names
accounts := a.server.accounts.AllNicks()
response.TotalCount = len(accounts)
// Load account details
response.Accounts = make([]apiAccountDetailsResponse, 0, len(accounts))
for _, account := range accounts {
accountData, err := a.server.accounts.LoadAccount(account)
if err != nil {
// shouldn't happen
continue
}
response.Accounts = append(
response.Accounts,
apiAccountDetailsResponse{
apiGenericResponse: apiGenericResponse{
Success: true,
},
AccountName: accountData.Name,
Email: accountData.Settings.Email,
},
)
}
response.Success = true
a.writeJSONResponse(response, w, r)
}
type apiStatusResponse struct {
apiGenericResponse
Version string `json:"version"`
GoVersion string `json:"go_version"`
Commit string `json:"commit,omitempty"`
StartTime string `json:"start_time"`
Users struct {
Total int `json:"total"`
Invisible int `json:"invisible"`
Operators int `json:"operators"`
Unknown int `json:"unknown"`
Max int `json:"max"`
} `json:"users"`
Channels int `json:"channels"`
Servers int `json:"servers"`
}
func (a *ergoAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
server := a.server
stats := server.stats.GetValues()
response := apiStatusResponse{
apiGenericResponse: apiGenericResponse{Success: true},
Version: SemVer,
GoVersion: runtime.Version(),
Commit: Commit,
StartTime: server.ctime.Format(utils.IRCv3TimestampFormat),
}
response.Users.Total = stats.Total
response.Users.Invisible = stats.Invisible
response.Users.Operators = stats.Operators
response.Users.Unknown = stats.Unknown
response.Users.Max = stats.Max
response.Channels = server.channels.Len()
response.Servers = 1
a.writeJSONResponse(response, w, r)
}

View File

@ -44,10 +44,11 @@ func (b *buntdbDatastore) GetAll(table datastore.Table) (result []datastore.KV,
tablePrefix := fmt.Sprintf("%x ", table) tablePrefix := fmt.Sprintf("%x ", table)
err = b.db.View(func(tx *buntdb.Tx) error { err = b.db.View(func(tx *buntdb.Tx) error {
err := tx.AscendGreaterOrEqual("", tablePrefix, func(key, value string) bool { err := tx.AscendGreaterOrEqual("", tablePrefix, func(key, value string) bool {
if !strings.HasPrefix(key, tablePrefix) { encUUID, ok := strings.CutPrefix(key, tablePrefix)
if !ok {
return false return false
} }
uuid, err := utils.DecodeUUID(strings.TrimPrefix(key, tablePrefix)) uuid, err := utils.DecodeUUID(encUUID)
if err == nil { if err == nil {
result = append(result, datastore.KV{UUID: uuid, Value: []byte(value)}) result = append(result, datastore.KV{UUID: uuid, Value: []byte(value)})
} else { } else {

View File

@ -7,7 +7,7 @@ package caps
const ( const (
// number of recognized capabilities: // number of recognized capabilities:
numCapabs = 37 numCapabs = 38
// length of the uint32 array that represents the bitset: // length of the uint32 array that represents the bitset:
bitsetLen = 2 bitsetLen = 2
) )
@ -65,6 +65,10 @@ const (
// https://github.com/progval/ircv3-specifications/blob/redaction/extensions/message-redaction.md // https://github.com/progval/ircv3-specifications/blob/redaction/extensions/message-redaction.md
MessageRedaction Capability = iota MessageRedaction Capability = iota
// Metadata is the draft IRCv3 capability named "draft/metadata-2":
// https://ircv3.net/specs/extensions/metadata
Metadata Capability = iota
// Multiline is the proposed IRCv3 capability named "draft/multiline": // Multiline is the proposed IRCv3 capability named "draft/multiline":
// https://github.com/ircv3/ircv3-specifications/pull/398 // https://github.com/ircv3/ircv3-specifications/pull/398
Multiline Capability = iota Multiline Capability = iota
@ -178,6 +182,7 @@ var (
"draft/extended-isupport", "draft/extended-isupport",
"draft/languages", "draft/languages",
"draft/message-redaction", "draft/message-redaction",
"draft/metadata-2",
"draft/multiline", "draft/multiline",
"draft/no-implicit-names", "draft/no-implicit-names",
"draft/persistence", "draft/persistence",

View File

@ -7,6 +7,7 @@ package irc
import ( import (
"fmt" "fmt"
"iter"
"maps" "maps"
"strconv" "strconv"
"strings" "strings"
@ -55,6 +56,7 @@ type Channel struct {
dirtyBits uint dirtyBits uint
settings ChannelSettings settings ChannelSettings
uuid utils.UUID uuid utils.UUID
metadata map[string]string
// these caches are paired to allow iteration over channel members without holding the lock // these caches are paired to allow iteration over channel members without holding the lock
membersCache []*Client membersCache []*Client
memberDataCache []*memberData memberDataCache []*memberData
@ -126,6 +128,7 @@ func (channel *Channel) applyRegInfo(chanReg RegisteredChannel) {
channel.userLimit = chanReg.UserLimit channel.userLimit = chanReg.UserLimit
channel.settings = chanReg.Settings channel.settings = chanReg.Settings
channel.forward = chanReg.Forward channel.forward = chanReg.Forward
channel.metadata = chanReg.Metadata
for _, mode := range chanReg.Modes { for _, mode := range chanReg.Modes {
channel.flags.SetMode(mode, true) channel.flags.SetMode(mode, true)
@ -163,6 +166,7 @@ func (channel *Channel) ExportRegistration() (info RegisteredChannel) {
info.AccountToUMode = maps.Clone(channel.accountToUMode) info.AccountToUMode = maps.Clone(channel.accountToUMode)
info.Settings = channel.settings info.Settings = channel.settings
info.Metadata = channel.metadata
return return
} }
@ -726,6 +730,9 @@ func (channel *Channel) AddHistoryItem(item history.Item, account string) (err e
status, target, _ := channel.historyStatus(channel.server.Config()) status, target, _ := channel.historyStatus(channel.server.Config())
if status == HistoryPersistent { if status == HistoryPersistent {
err = channel.server.historyDB.AddChannelItem(target, item, account) err = channel.server.historyDB.AddChannelItem(target, item, account)
if err != nil {
channel.server.logger.Error("history", "could not add channel message to history", err.Error())
}
} else if status == HistoryEphemeral { } else if status == HistoryEphemeral {
channel.history.Add(item) channel.history.Add(item)
} }
@ -892,6 +899,10 @@ func (channel *Channel) Join(client *Client, key string, isSajoin bool, rb *Resp
rb.Add(nil, client.server.name, "MARKREAD", chname, client.GetReadMarker(chcfname)) rb.Add(nil, client.server.name, "MARKREAD", chname, client.GetReadMarker(chcfname))
} }
if rb.session.capabilities.Has(caps.Metadata) {
syncChannelMetadata(client.server, rb, channel)
}
if rb.session.client == client { if rb.session.client == client {
// don't send topic and names for a SAJOIN of a different client // don't send topic and names for a SAJOIN of a different client
channel.SendTopic(client, rb, false) channel.SendTopic(client, rb, false)
@ -1462,8 +1473,8 @@ func (channel *Channel) ShowMaskList(client *Client, mode modes.Mode, rb *Respon
rpllist = RPL_EXCEPTLIST rpllist = RPL_EXCEPTLIST
rplendoflist = RPL_ENDOFEXCEPTLIST rplendoflist = RPL_ENDOFEXCEPTLIST
} else if mode == modes.InviteMask { } else if mode == modes.InviteMask {
rpllist = RPL_INVITELIST rpllist = RPL_INVEXLIST
rplendoflist = RPL_ENDOFINVITELIST rplendoflist = RPL_ENDOFINVEXLIST
} }
nick := client.Nick() nick := client.Nick()
@ -1669,6 +1680,20 @@ func (channel *Channel) auditoriumFriends(client *Client) (friends []*Client) {
return return
} }
func (channel *Channel) sessionsWithCaps(capabs ...caps.Capability) iter.Seq[*Session] {
return func(yield func(*Session) bool) {
for _, member := range channel.Members() {
for _, sess := range member.Sessions() {
if sess.capabilities.HasAll(capabs...) {
if !yield(sess) {
return
}
}
}
}
}
}
// returns whether the client is visible to unprivileged users in the channel // returns whether the client is visible to unprivileged users in the channel
// (i.e., respecting auditorium mode). note that this assumes that the client // (i.e., respecting auditorium mode). note that this assumes that the client
// is a member; if the client is not, it may return true anyway // is a member; if the client is not, it may return true anyway

View File

@ -206,6 +206,10 @@ func (cm *ChannelManager) Cleanup(channel *Channel) {
} }
func (cm *ChannelManager) SetRegistered(channelName string, account string) (err error) { func (cm *ChannelManager) SetRegistered(channelName string, account string) (err error) {
if account == "" {
return errAuthRequired // this is already enforced by ChanServ, but do a final check
}
if cm.server.Defcon() <= 4 { if cm.server.Defcon() <= 4 {
return errFeatureDisabled return errFeatureDisabled
} }

View File

@ -63,6 +63,8 @@ type RegisteredChannel struct {
Invites map[string]MaskInfo Invites map[string]MaskInfo
// Settings are the chanserv-modifiable settings // Settings are the chanserv-modifiable settings
Settings ChannelSettings Settings ChannelSettings
// Metadata set using the METADATA command
Metadata map[string]string
} }
func (r *RegisteredChannel) Serialize() ([]byte, error) { func (r *RegisteredChannel) Serialize() ([]byte, error) {

View File

@ -41,8 +41,7 @@ const (
DefaultMaxLineLen = 512 DefaultMaxLineLen = 512
// IdentTimeout is how long before our ident (username) check times out. // IdentTimeout is how long before our ident (username) check times out.
IdentTimeout = time.Second + 500*time.Millisecond IdentTimeout = time.Second + 500*time.Millisecond
IRCv3TimestampFormat = utils.IRCv3TimestampFormat
// limit the number of device IDs a client can use, as a DoS mitigation // limit the number of device IDs a client can use, as a DoS mitigation
maxDeviceIDsPerClient = 64 maxDeviceIDsPerClient = 64
// maximum total read markers that can be stored // maximum total read markers that can be stored
@ -54,18 +53,16 @@ const (
pushQueueLengthPerClient = 16 pushQueueLengthPerClient = 16
) )
var (
// idle timeouts for client connections, set from the config
RegisterTimeout, PingTimeout, DisconnectTimeout time.Duration
)
const ( const (
// RegisterTimeout is how long clients have to register before we disconnect them
RegisterTimeout = time.Minute
// DefaultIdleTimeout is how long without traffic before we send the client a PING
DefaultIdleTimeout = time.Minute + 30*time.Second
// For Tor clients, we send a PING at least every 30 seconds, as a workaround for this bug // For Tor clients, we send a PING at least every 30 seconds, as a workaround for this bug
// (single-onion circuits will close unless the client sends data once every 60 seconds): // (single-onion circuits will close unless the client sends data once every 60 seconds):
// https://bugs.torproject.org/29665 // https://bugs.torproject.org/29665
TorIdleTimeout = time.Second * 30 TorPingTimeout = time.Second * 30
// This is how long a client gets without sending any message, including the PONG to our
// PING, before we disconnect them:
DefaultTotalTimeout = 2*time.Minute + 30*time.Second
// round off the ping interval by this much, see below: // round off the ping interval by this much, see below:
PingCoalesceThreshold = time.Second PingCoalesceThreshold = time.Second
@ -132,6 +129,8 @@ type Client struct {
clearablePushMessages map[string]time.Time clearablePushMessages map[string]time.Time
pushSubscriptionsExist atomic.Uint32 // this is a cache on len(pushSubscriptions) != 0 pushSubscriptionsExist atomic.Uint32 // this is a cache on len(pushSubscriptions) != 0
pushQueue pushQueue pushQueue pushQueue
metadata map[string]string
metadataThrottle connection_limits.ThrottleDetails
} }
type saslStatus struct { type saslStatus struct {
@ -189,6 +188,8 @@ type Session struct {
fakelag Fakelag fakelag Fakelag
deferredFakelagCount int deferredFakelagCount int
lastOperAttempt time.Time
certfp string certfp string
peerCerts []*x509.Certificate peerCerts []*x509.Certificate
sasl saslStatus sasl saslStatus
@ -215,6 +216,9 @@ type Session struct {
batch MultilineBatch batch MultilineBatch
webPushEndpoint string // goroutine-local: web push endpoint registered by the current session webPushEndpoint string // goroutine-local: web push endpoint registered by the current session
metadataSubscriptions utils.HashSet[string]
metadataPreregVals map[string]string
} }
// MultilineBatch tracks the state of a client-to-server multiline batch. // MultilineBatch tracks the state of a client-to-server multiline batch.
@ -410,6 +414,11 @@ func (server *Server) RunClient(conn IRCConn) {
session.certfp, session.peerCerts, _ = utils.GetCertFP(wConn.Conn, RegisterTimeout) session.certfp, session.peerCerts, _ = utils.GetCertFP(wConn.Conn, RegisterTimeout)
} }
if config.Server.InitialNotice != "" {
// send initial notice for HOPM to recognize
client.Send(nil, client.server.name, "NOTICE", "*", config.Server.InitialNotice)
}
if session.isTor { if session.isTor {
session.rawHostname = config.Server.TorListeners.Vhost session.rawHostname = config.Server.TorListeners.Vhost
client.rawHostname = session.rawHostname client.rawHostname = session.rawHostname
@ -424,7 +433,7 @@ func (server *Server) RunClient(conn IRCConn) {
client.run(session) client.run(session)
} }
func (server *Server) AddAlwaysOnClient(account ClientAccount, channelToStatus map[string]alwaysOnChannelStatus, lastSeen, readMarkers map[string]time.Time, uModes modes.Modes, realname string, pushSubscriptions []storedPushSubscription) { func (server *Server) AddAlwaysOnClient(account ClientAccount, channelToStatus map[string]alwaysOnChannelStatus, lastSeen, readMarkers map[string]time.Time, uModes modes.Modes, realname string, pushSubscriptions []storedPushSubscription, metadata map[string]string) {
now := time.Now().UTC() now := time.Now().UTC()
config := server.Config() config := server.Config()
if lastSeen == nil && account.Settings.AutoreplayMissed { if lastSeen == nil && account.Settings.AutoreplayMissed {
@ -509,6 +518,10 @@ func (server *Server) AddAlwaysOnClient(account ClientAccount, channelToStatus m
} }
} }
client.rebuildPushSubscriptionCache() client.rebuildPushSubscriptionCache()
if len(metadata) != 0 {
client.metadata = metadata
}
} }
func (client *Client) resizeHistory(config *Config) { func (client *Client) resizeHistory(config *Config) {
@ -675,7 +688,7 @@ func (client *Client) run(session *Session) {
isReattach := client.Registered() isReattach := client.Registered()
if isReattach { if isReattach {
client.Touch(session) client.Touch(session)
client.playReattachMessages(session) client.performReattach(session)
} }
firstLine := !isReattach firstLine := !isReattach
@ -775,7 +788,9 @@ func (client *Client) run(session *Session) {
} }
} }
func (client *Client) playReattachMessages(session *Session) { func (client *Client) performReattach(session *Session) {
client.applyPreregMetadata(session)
client.server.playRegistrationBurst(session) client.server.playRegistrationBurst(session)
hasHistoryCaps := session.HasHistoryCaps() hasHistoryCaps := session.HasHistoryCaps()
for _, channel := range session.client.Channels() { for _, channel := range session.client.Channels() {
@ -799,6 +814,34 @@ func (client *Client) playReattachMessages(session *Session) {
session.autoreplayMissedSince = time.Time{} session.autoreplayMissedSince = time.Time{}
} }
func (client *Client) applyPreregMetadata(session *Session) {
if session.metadataPreregVals == nil {
return
}
defer func() {
session.metadataPreregVals = nil
}()
updates := client.UpdateMetadataFromPrereg(session.metadataPreregVals, client.server.Config().Metadata.MaxKeys)
if len(updates) == 0 {
return
}
// TODO this is expensive
friends := client.FriendsMonitors(caps.Metadata)
for _, s := range client.Sessions() {
if s != session {
friends.Add(s)
}
}
target := client.Nick()
for k, v := range updates {
broadcastMetadataUpdate(client.server, maps.Keys(friends), session, target, k, v, true)
}
}
// //
// idle, quit, timers and timeouts // idle, quit, timers and timeouts
// //
@ -830,19 +873,19 @@ func (client *Client) updateIdleTimer(session *Session, now time.Time) {
session.pingSent = false session.pingSent = false
if session.idleTimer == nil { if session.idleTimer == nil {
pingTimeout := DefaultIdleTimeout pingTimeout := PingTimeout
if session.isTor { if session.isTor && TorPingTimeout < pingTimeout {
pingTimeout = TorIdleTimeout pingTimeout = TorPingTimeout
} }
session.idleTimer = time.AfterFunc(pingTimeout, session.handleIdleTimeout) session.idleTimer = time.AfterFunc(pingTimeout, session.handleIdleTimeout)
} }
} }
func (session *Session) handleIdleTimeout() { func (session *Session) handleIdleTimeout() {
totalTimeout := DefaultTotalTimeout totalTimeout := DisconnectTimeout
pingTimeout := DefaultIdleTimeout pingTimeout := PingTimeout
if session.isTor { if session.isTor && TorPingTimeout < pingTimeout {
pingTimeout = TorIdleTimeout pingTimeout = TorPingTimeout
} }
session.client.stateMutex.Lock() session.client.stateMutex.Lock()
@ -1130,6 +1173,7 @@ func (client *Client) SetNick(nick, nickCasefolded, skeleton string) (success bo
client.nickCasefolded = nickCasefolded client.nickCasefolded = nickCasefolded
client.skeleton = skeleton client.skeleton = skeleton
client.updateNickMaskNoMutex() client.updateNickMaskNoMutex()
return true return true
} }
@ -1363,7 +1407,7 @@ func (client *Client) destroy(session *Session) {
// alert monitors // alert monitors
if registered { if registered {
client.server.monitorManager.AlertAbout(details.nick, details.nickCasefolded, false) client.server.monitorManager.AlertAbout(details.nick, details.nickCasefolded, false, nil)
} }
// clean up channels // clean up channels
@ -1463,7 +1507,7 @@ func (session *Session) sendFromClientInternal(blocking bool, serverTime time.Ti
func composeMultilineBatch(batchID, fromNickMask, fromAccount string, isBot bool, tags map[string]string, command, target string, message utils.SplitMessage) (result []ircmsg.Message) { func composeMultilineBatch(batchID, fromNickMask, fromAccount string, isBot bool, tags map[string]string, command, target string, message utils.SplitMessage) (result []ircmsg.Message) {
batchStart := ircmsg.MakeMessage(tags, fromNickMask, "BATCH", "+"+batchID, caps.MultilineBatchType, target) batchStart := ircmsg.MakeMessage(tags, fromNickMask, "BATCH", "+"+batchID, caps.MultilineBatchType, target)
batchStart.SetTag("time", message.Time.Format(IRCv3TimestampFormat)) batchStart.SetTag("time", message.Time.Format(utils.IRCv3TimestampFormat))
batchStart.SetTag("msgid", message.Msgid) batchStart.SetTag("msgid", message.Msgid)
if fromAccount != "*" { if fromAccount != "*" {
batchStart.SetTag("account", fromAccount) batchStart.SetTag("account", fromAccount)
@ -1571,7 +1615,7 @@ func (session *Session) setTimeTag(msg *ircmsg.Message, serverTime time.Time) {
if serverTime.IsZero() { if serverTime.IsZero() {
serverTime = time.Now() serverTime = time.Now()
} }
msg.SetTag("time", serverTime.UTC().Format(IRCv3TimestampFormat)) msg.SetTag("time", serverTime.UTC().Format(utils.IRCv3TimestampFormat))
} }
} }
@ -1730,12 +1774,15 @@ func (client *Client) addHistoryItem(target *Client, item history.Item, details,
} }
if cStatus == HistoryPersistent || tStatus == HistoryPersistent { if cStatus == HistoryPersistent || tStatus == HistoryPersistent {
targetedItem.CfCorrespondent = "" targetedItem.CfCorrespondent = ""
client.server.historyDB.AddDirectMessage(details.nickCasefolded, details.account, tDetails.nickCasefolded, tDetails.account, targetedItem) err = client.server.historyDB.AddDirectMessage(details.nickCasefolded, details.account, tDetails.nickCasefolded, tDetails.account, targetedItem)
if err != nil {
client.server.logger.Error("history", "could not add direct message to history", err.Error())
}
} }
return nil return nil
} }
func (client *Client) listTargets(start, end history.Selector, limit int) (results []history.TargetListing, err error) { func (client *Client) listTargets(start, end time.Time, limit int) (results []history.TargetListing, err error) {
var base, extras []history.TargetListing var base, extras []history.TargetListing
var chcfnames []string var chcfnames []string
for _, channel := range client.Channels() { for _, channel := range client.Channels() {
@ -1756,27 +1803,35 @@ func (client *Client) listTargets(start, end history.Selector, limit int) (resul
} }
} }
persistentExtras, err := client.server.historyDB.ListChannels(chcfnames) persistentExtras, err := client.server.historyDB.ListChannels(chcfnames)
if err == nil && len(persistentExtras) != 0 { if err != nil {
client.server.logger.Error("history", "could not list persistent channels", err.Error())
} else if len(persistentExtras) != 0 {
extras = append(extras, persistentExtras...) extras = append(extras, persistentExtras...)
} }
_, cSeq, err := client.server.GetHistorySequence(nil, client, "") // get DM correspondents from the in-memory buffer or the database, as applicable
if err == nil && cSeq != nil { var cErr error
correspondents, err := cSeq.ListCorrespondents(start, end, limit) status, target := client.historyStatus(client.server.Config())
if err == nil { switch status {
base = correspondents case HistoryEphemeral:
} base, cErr = client.history.ListCorrespondents(start, end, limit)
case HistoryPersistent:
base, cErr = client.server.historyDB.ListCorrespondents(target, start, end, limit)
default:
// nothing to do
}
if cErr != nil {
base = nil
client.server.logger.Error("history", "could not list correspondents", cErr.Error())
} }
results = history.MergeTargets(base, extras, start.Time, end.Time, limit) results = history.MergeTargets(base, extras, start, end, limit)
return results, nil return results, nil
} }
// latest PRIVMSG from all DM targets // latest PRIVMSG from all DM targets
func (client *Client) privmsgsBetween(startTime, endTime time.Time, targetLimit, messageLimit int) (results []history.Item, err error) { func (client *Client) privmsgsBetween(startTime, endTime time.Time, targetLimit, messageLimit int) (results []history.Item, err error) {
start := history.Selector{Time: startTime} targets, err := client.listTargets(startTime, endTime, targetLimit)
end := history.Selector{Time: endTime}
targets, err := client.listTargets(start, end, targetLimit)
if err != nil { if err != nil {
return return
} }
@ -1786,7 +1841,7 @@ func (client *Client) privmsgsBetween(startTime, endTime time.Time, targetLimit,
} }
_, seq, err := client.server.GetHistorySequence(nil, client, target.CfName) _, seq, err := client.server.GetHistorySequence(nil, client, target.CfName)
if err == nil && seq != nil { if err == nil && seq != nil {
items, err := seq.Between(start, end, messageLimit) items, err := seq.Between(history.Selector{Time: startTime}, history.Selector{Time: endTime}, messageLimit)
if err == nil { if err == nil {
results = append(results, items...) results = append(results, items...)
} else { } else {
@ -1815,6 +1870,7 @@ const (
IncludeUserModes IncludeUserModes
IncludeRealname IncludeRealname
IncludePushSubscriptions IncludePushSubscriptions
IncludeMetadata
) )
func (client *Client) markDirty(dirtyBits uint) { func (client *Client) markDirty(dirtyBits uint) {
@ -1896,6 +1952,9 @@ func (client *Client) performWrite(additionalDirtyBits uint) {
if (dirtyBits & IncludePushSubscriptions) != 0 { if (dirtyBits & IncludePushSubscriptions) != 0 {
client.server.accounts.savePushSubscriptions(account, client.getPushSubscriptions(true)) client.server.accounts.savePushSubscriptions(account, client.getPushSubscriptions(true))
} }
if (dirtyBits & IncludeMetadata) != 0 {
client.server.accounts.saveMetadata(account, client.ListMetadata())
}
} }
// Blocking store; see Channel.Store and Socket.BlockingWrite // Blocking store; see Channel.Store and Socket.BlockingWrite

View File

@ -209,6 +209,11 @@ func init() {
handler: markReadHandler, handler: markReadHandler,
minParams: 0, // send FAIL instead of ERR_NEEDMOREPARAMS minParams: 0, // send FAIL instead of ERR_NEEDMOREPARAMS
}, },
"METADATA": {
handler: metadataHandler,
minParams: 2,
usablePreReg: true,
},
"MODE": { "MODE": {
handler: modeHandler, handler: modeHandler,
minParams: 1, minParams: 1,

View File

@ -45,6 +45,10 @@ import (
"github.com/ergochat/ergo/irc/webpush" "github.com/ergochat/ergo/irc/webpush"
) )
const (
defaultProxyDeadline = time.Minute
)
// here's how this works: exported (capitalized) members of the config structs // here's how this works: exported (capitalized) members of the config structs
// are defined in the YAML file and deserialized directly from there. They may // are defined in the YAML file and deserialized directly from there. They may
// be postprocessed and overwritten by LoadConfig. Unexported (lowercase) members // be postprocessed and overwritten by LoadConfig. Unexported (lowercase) members
@ -371,7 +375,7 @@ type AccountRegistrationConfig struct {
Mailto email.MailtoConfig Mailto email.MailtoConfig
} `yaml:"callbacks"` } `yaml:"callbacks"`
VerifyTimeout custime.Duration `yaml:"verify-timeout"` VerifyTimeout custime.Duration `yaml:"verify-timeout"`
BcryptCost uint `yaml:"bcrypt-cost"` BcryptCost int `yaml:"bcrypt-cost"`
} }
type VHostConfig struct { type VHostConfig struct {
@ -576,8 +580,14 @@ type Config struct {
CoerceIdent string `yaml:"coerce-ident"` CoerceIdent string `yaml:"coerce-ident"`
MOTD string MOTD string
motdLines []string motdLines []string
MOTDFormatting bool `yaml:"motd-formatting"` MOTDFormatting bool `yaml:"motd-formatting"`
Relaymsg struct { InitialNotice string `yaml:"initial-notice"`
IdleTimeouts struct {
Registration time.Duration
Ping time.Duration
Disconnect time.Duration
} `yaml:"idle-timeouts"`
Relaymsg struct {
Enabled bool Enabled bool
Separators string Separators string
AvailableToChanops bool `yaml:"available-to-chanops"` AvailableToChanops bool `yaml:"available-to-chanops"`
@ -599,6 +609,7 @@ type Config struct {
Cloaks cloaks.CloakConfig `yaml:"ip-cloaking"` Cloaks cloaks.CloakConfig `yaml:"ip-cloaking"`
SecureNetDefs []string `yaml:"secure-nets"` SecureNetDefs []string `yaml:"secure-nets"`
secureNets []net.IPNet secureNets []net.IPNet
OperThrottle time.Duration `yaml:"oper-throttle"`
supportedCaps *caps.Set supportedCaps *caps.Set
supportedCapsWithoutSTS *caps.Set supportedCapsWithoutSTS *caps.Set
capValues caps.Values capValues caps.Values
@ -723,6 +734,14 @@ type Config struct {
} `yaml:"tagmsg-storage"` } `yaml:"tagmsg-storage"`
} }
Metadata struct {
Enabled bool
MaxSubs int `yaml:"max-subs"`
MaxKeys int `yaml:"max-keys"`
MaxValueBytes int `yaml:"max-value-length"`
ClientThrottle ThrottleConfig `yaml:"client-throttle"`
}
WebPush struct { WebPush struct {
Enabled bool Enabled bool
Timeout time.Duration Timeout time.Duration
@ -979,7 +998,7 @@ func (conf *Config) prepareListeners() (err error) {
conf.Server.trueListeners = make(map[string]utils.ListenerConfig) conf.Server.trueListeners = make(map[string]utils.ListenerConfig)
for addr, block := range conf.Server.Listeners { for addr, block := range conf.Server.Listeners {
var lconf utils.ListenerConfig var lconf utils.ListenerConfig
lconf.ProxyDeadline = RegisterTimeout lconf.ProxyDeadline = defaultProxyDeadline
lconf.Tor = block.Tor lconf.Tor = block.Tor
lconf.STSOnly = block.STSOnly lconf.STSOnly = block.STSOnly
if lconf.STSOnly && !conf.Server.STS.Enabled { if lconf.STSOnly && !conf.Server.STS.Enabled {
@ -1229,6 +1248,23 @@ func LoadConfig(filename string) (config *Config, err error) {
} }
} }
if config.Server.IdleTimeouts.Registration <= 0 {
config.Server.IdleTimeouts.Registration = time.Minute
}
if config.Server.IdleTimeouts.Ping <= 0 {
config.Server.IdleTimeouts.Ping = time.Minute + 30*time.Second
}
if config.Server.IdleTimeouts.Disconnect <= 0 {
config.Server.IdleTimeouts.Disconnect = 2*time.Minute + 30*time.Second
}
if !(config.Server.IdleTimeouts.Ping < config.Server.IdleTimeouts.Disconnect) {
return nil, fmt.Errorf(
"ping timeout %v must be strictly less than disconnect timeout %v, to give the client time to respond",
config.Server.IdleTimeouts.Ping, config.Server.IdleTimeouts.Disconnect,
)
}
if config.Server.CoerceIdent != "" { if config.Server.CoerceIdent != "" {
if config.Server.CheckIdent { if config.Server.CheckIdent {
return nil, errors.New("Can't configure both check-ident and coerce-ident") return nil, errors.New("Can't configure both check-ident and coerce-ident")
@ -1473,6 +1509,10 @@ func LoadConfig(filename string) (config *Config, err error) {
config.Server.supportedCaps.Disable(caps.SASL) config.Server.supportedCaps.Disable(caps.SASL)
} }
if config.Server.OperThrottle <= 0 {
config.Server.OperThrottle = 10 * time.Second
}
if err := config.Accounts.OAuth2.Postprocess(); err != nil { if err := config.Accounts.OAuth2.Postprocess(); err != nil {
return nil, err return nil, err
} }
@ -1556,6 +1596,12 @@ func LoadConfig(filename string) (config *Config, err error) {
if config.Accounts.Registration.BcryptCost == 0 { if config.Accounts.Registration.BcryptCost == 0 {
config.Accounts.Registration.BcryptCost = passwd.DefaultCost config.Accounts.Registration.BcryptCost = passwd.DefaultCost
} }
if config.Accounts.Registration.BcryptCost < passwd.MinCost || config.Accounts.Registration.BcryptCost > passwd.MaxCost {
return nil, fmt.Errorf(
"invalid bcrypt-cost %d (require %d <= cost <= %d)",
config.Accounts.Registration.BcryptCost, passwd.MinCost, passwd.MaxCost,
)
}
if config.Channels.MaxChannelsPerClient == 0 { if config.Channels.MaxChannelsPerClient == 0 {
config.Channels.MaxChannelsPerClient = 100 config.Channels.MaxChannelsPerClient = 100
@ -1572,7 +1618,7 @@ func LoadConfig(filename string) (config *Config, err error) {
// in the current implementation, we disable history by creating a history buffer // in the current implementation, we disable history by creating a history buffer
// with zero capacity. but the `enabled` config option MUST be respected regardless // with zero capacity. but the `enabled` config option MUST be respected regardless
// of this detail // of this detail
if !config.History.Enabled { if !config.History.Enabled || config.History.ChathistoryMax == 0 {
config.History.ChannelLength = 0 config.History.ChannelLength = 0
config.History.ClientLength = 0 config.History.ClientLength = 0
config.Server.supportedCaps.Disable(caps.Chathistory) config.Server.supportedCaps.Disable(caps.Chathistory)
@ -1637,6 +1683,27 @@ func LoadConfig(filename string) (config *Config, err error) {
} }
} }
if !config.Metadata.Enabled {
config.Server.supportedCaps.Disable(caps.Metadata)
} else {
metadataValues := make([]string, 0, 4)
metadataValues = append(metadataValues, "before-connect")
// these are required for normal operation, so set sane defaults:
if config.Metadata.MaxSubs == 0 {
config.Metadata.MaxSubs = 10
}
metadataValues = append(metadataValues, fmt.Sprintf("max-subs=%d", config.Metadata.MaxSubs))
if config.Metadata.MaxKeys == 0 {
config.Metadata.MaxKeys = 10
}
metadataValues = append(metadataValues, fmt.Sprintf("max-keys=%d", config.Metadata.MaxKeys))
// this is not required since we enforce a hardcoded upper bound on key+value
if config.Metadata.MaxValueBytes > 0 {
metadataValues = append(metadataValues, fmt.Sprintf("max-value-bytes=%d", config.Metadata.MaxValueBytes))
}
config.Server.capValues[caps.Metadata] = strings.Join(metadataValues, ",")
}
err = config.processExtjwt() err = config.processExtjwt()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -33,6 +33,7 @@ var (
errAccountVerificationInvalidCode = errors.New("Invalid account verification code") errAccountVerificationInvalidCode = errors.New("Invalid account verification code")
errAccountUpdateFailed = errors.New(`Error while updating your account information`) errAccountUpdateFailed = errors.New(`Error while updating your account information`)
errAccountMustHoldNick = errors.New(`You must hold that nickname in order to register it`) errAccountMustHoldNick = errors.New(`You must hold that nickname in order to register it`)
errAuthRequired = errors.New("You must be logged into an account to do this")
errAuthzidAuthcidMismatch = errors.New(`authcid and authzid must be the same`) errAuthzidAuthcidMismatch = errors.New(`authcid and authzid must be the same`)
errCertfpAlreadyExists = errors.New(`An account already exists for your certificate fingerprint`) errCertfpAlreadyExists = errors.New(`An account already exists for your certificate fingerprint`)
errChannelNotOwnedByAccount = errors.New("Channel not owned by the specified account") errChannelNotOwnedByAccount = errors.New("Channel not owned by the specified account")

View File

@ -7,9 +7,11 @@ import (
"fmt" "fmt"
"maps" "maps"
"net" "net"
"slices"
"time" "time"
"github.com/ergochat/ergo/irc/caps" "github.com/ergochat/ergo/irc/caps"
"github.com/ergochat/ergo/irc/connection_limits"
"github.com/ergochat/ergo/irc/languages" "github.com/ergochat/ergo/irc/languages"
"github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/modes"
"github.com/ergochat/ergo/irc/utils" "github.com/ergochat/ergo/irc/utils"
@ -517,7 +519,7 @@ func (client *Client) GetReadMarker(cfname string) (result string) {
t, ok := client.readMarkers[cfname] t, ok := client.readMarkers[cfname]
client.stateMutex.RUnlock() client.stateMutex.RUnlock()
if ok { if ok {
return fmt.Sprintf("timestamp=%s", t.Format(IRCv3TimestampFormat)) return fmt.Sprintf("timestamp=%s", t.Format(utils.IRCv3TimestampFormat))
} }
return "*" return "*"
} }
@ -797,10 +799,12 @@ func (channel *Channel) Settings() (result ChannelSettings) {
} }
func (channel *Channel) SetSettings(settings ChannelSettings) { func (channel *Channel) SetSettings(settings ChannelSettings) {
defer channel.MarkDirty(IncludeSettings)
channel.stateMutex.Lock() channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
channel.settings = settings channel.settings = settings
channel.stateMutex.Unlock()
channel.MarkDirty(IncludeSettings)
} }
func (channel *Channel) setForward(forward string) { func (channel *Channel) setForward(forward string) {
@ -827,3 +831,262 @@ func (channel *Channel) UUID() utils.UUID {
defer channel.stateMutex.RUnlock() defer channel.stateMutex.RUnlock()
return channel.uuid return channel.uuid
} }
func (session *Session) isSubscribedTo(key string) bool {
session.client.stateMutex.RLock()
defer session.client.stateMutex.RUnlock()
return session.metadataSubscriptions.Has(key)
}
func (session *Session) SubscribeTo(keys ...string) ([]string, error) {
maxSubs := session.client.server.Config().Metadata.MaxSubs
session.client.stateMutex.Lock()
defer session.client.stateMutex.Unlock()
if session.metadataSubscriptions == nil {
session.metadataSubscriptions = make(utils.HashSet[string])
}
var added []string
for _, k := range keys {
if !session.metadataSubscriptions.Has(k) {
if len(session.metadataSubscriptions) > maxSubs {
return added, errMetadataTooManySubs
}
added = append(added, k)
session.metadataSubscriptions.Add(k)
}
}
return added, nil
}
func (session *Session) UnsubscribeFrom(keys ...string) []string {
session.client.stateMutex.Lock()
defer session.client.stateMutex.Unlock()
var removed []string
for k := range session.metadataSubscriptions {
if slices.Contains(keys, k) {
removed = append(removed, k)
session.metadataSubscriptions.Remove(k)
}
}
return removed
}
func (session *Session) MetadataSubscriptions() utils.HashSet[string] {
session.client.stateMutex.Lock()
defer session.client.stateMutex.Unlock()
return maps.Clone(session.metadataSubscriptions)
}
func (channel *Channel) GetMetadata(key string) (string, bool) {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
val, ok := channel.metadata[key]
return val, ok
}
func (channel *Channel) SetMetadata(key string, value string, limit int) (updated bool, err error) {
defer channel.MarkDirty(IncludeAllAttrs)
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
if channel.metadata == nil {
channel.metadata = make(map[string]string)
}
existing, ok := channel.metadata[key]
if !ok && len(channel.metadata) >= limit {
return false, errLimitExceeded
}
updated = !ok || value != existing
if updated {
channel.metadata[key] = value
}
return updated, nil
}
func (channel *Channel) ListMetadata() map[string]string {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
return maps.Clone(channel.metadata)
}
func (channel *Channel) DeleteMetadata(key string) (updated bool) {
defer channel.MarkDirty(IncludeAllAttrs)
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
_, updated = channel.metadata[key]
if updated {
delete(channel.metadata, key)
}
return updated
}
func (channel *Channel) ClearMetadata() map[string]string {
defer channel.MarkDirty(IncludeAllAttrs)
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
oldMap := channel.metadata
channel.metadata = nil
return oldMap
}
func (channel *Channel) CountMetadata() int {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
return len(channel.metadata)
}
func (client *Client) GetMetadata(key string) (string, bool) {
client.stateMutex.RLock()
defer client.stateMutex.RUnlock()
val, ok := client.metadata[key]
return val, ok
}
func (client *Client) SetMetadata(key string, value string, limit int) (updated bool, err error) {
var alwaysOn bool
defer func() {
if alwaysOn && updated {
client.markDirty(IncludeMetadata)
}
}()
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
alwaysOn = client.registered && client.alwaysOn
if client.metadata == nil {
client.metadata = make(map[string]string)
}
existing, ok := client.metadata[key]
if !ok && len(client.metadata) >= limit {
return false, errLimitExceeded
}
updated = !ok || value != existing
if updated {
client.metadata[key] = value
}
return updated, nil
}
func (client *Client) UpdateMetadataFromPrereg(preregData map[string]string, limit int) (updates map[string]string) {
var alwaysOn bool
defer func() {
if alwaysOn && len(updates) > 0 {
client.markDirty(IncludeMetadata)
}
}()
updates = make(map[string]string, len(preregData))
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
alwaysOn = client.registered && client.alwaysOn
if client.metadata == nil {
client.metadata = make(map[string]string)
}
for k, v := range preregData {
// do not overwrite any existing keys
_, ok := client.metadata[k]
if ok {
continue
}
if len(client.metadata) >= limit {
return // we know this is a new key
}
client.metadata[k] = v
updates[k] = v
}
return
}
func (client *Client) ListMetadata() map[string]string {
client.stateMutex.RLock()
defer client.stateMutex.RUnlock()
return maps.Clone(client.metadata)
}
func (client *Client) DeleteMetadata(key string) (updated bool) {
defer func() {
if updated {
client.markDirty(IncludeMetadata)
}
}()
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
_, updated = client.metadata[key]
if updated {
delete(client.metadata, key)
}
return updated
}
func (client *Client) ClearMetadata() (oldMap map[string]string) {
defer func() {
if len(oldMap) > 0 {
client.markDirty(IncludeMetadata)
}
}()
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
oldMap = client.metadata
client.metadata = nil
return oldMap
}
func (client *Client) CountMetadata() int {
client.stateMutex.RLock()
defer client.stateMutex.RUnlock()
return len(client.metadata)
}
func (client *Client) checkMetadataThrottle() (throttled bool, remainingTime time.Duration) {
config := client.server.Config()
if !config.Metadata.ClientThrottle.Enabled {
return false, 0
}
client.stateMutex.Lock()
defer client.stateMutex.Unlock()
// copy client.metadataThrottle locally and then back for processing
var throttle connection_limits.GenericThrottle
throttle.ThrottleDetails = client.metadataThrottle
throttle.Duration = config.Metadata.ClientThrottle.Duration
throttle.Limit = config.Metadata.ClientThrottle.MaxAttempts
throttled, remainingTime = throttle.Touch()
client.metadataThrottle = throttle.ThrottleDetails
return
}

View File

@ -9,11 +9,13 @@ package irc
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"maps"
"net" "net"
"os" "os"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"runtime/pprof" "runtime/pprof"
"slices"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -149,11 +151,8 @@ func (server *Server) sendLoginSnomask(nickMask, accountName string) {
// to indicate that it should be removed from the list // to indicate that it should be removed from the list
func acceptHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool { func acceptHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool {
for _, tNick := range strings.Split(msg.Params[0], ",") { for _, tNick := range strings.Split(msg.Params[0], ",") {
add := true tNick, negPrefix := strings.CutPrefix(tNick, "-")
if strings.HasPrefix(tNick, "-") { add := !negPrefix
add = false
tNick = strings.TrimPrefix(tNick, "-")
}
target := server.clients.Get(tNick) target := server.clients.Get(tNick)
if target == nil { if target == nil {
@ -701,11 +700,13 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
var channel *Channel var channel *Channel
var sequence history.Sequence var sequence history.Sequence
var err error var err error
var listTargets bool var disabled, listTargets bool
var targets []history.TargetListing var targets []history.TargetListing
defer func() { defer func() {
// errors are sent either without a batch, or in a draft/labeled-response batch as usual // errors are sent either without a batch, or in a draft/labeled-response batch as usual
if err == utils.ErrInvalidParams { if disabled {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "MESSAGE_ERROR", msg.Params[0], client.t("That feature is disabled"))
} else if err == utils.ErrInvalidParams {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_PARAMS", msg.Params[0], client.t("Invalid parameters")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_PARAMS", msg.Params[0], client.t("Invalid parameters"))
} else if !listTargets && sequence == nil { } else if !listTargets && sequence == nil {
rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_TARGET", msg.Params[0], utils.SafeErrorParam(target), client.t("Messages could not be retrieved")) rb.Add(nil, server.name, "FAIL", "CHATHISTORY", "INVALID_TARGET", msg.Params[0], utils.SafeErrorParam(target), client.t("Messages could not be retrieved"))
@ -719,7 +720,7 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
for _, target := range targets { for _, target := range targets {
name := server.UnfoldName(target.CfName) name := server.UnfoldName(target.CfName)
rb.Add(nil, server.name, "CHATHISTORY", "TARGETS", name, rb.Add(nil, server.name, "CHATHISTORY", "TARGETS", name,
target.Time.Format(IRCv3TimestampFormat)) target.Time.Format(utils.IRCv3TimestampFormat))
} }
} else if channel != nil { } else if channel != nil {
channel.replayHistoryItems(rb, items, true) channel.replayHistoryItems(rb, items, true)
@ -731,7 +732,8 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
config := server.Config() config := server.Config()
maxChathistoryLimit := config.History.ChathistoryMax maxChathistoryLimit := config.History.ChathistoryMax
if maxChathistoryLimit == 0 { if !config.History.Enabled || maxChathistoryLimit == 0 {
disabled = true
return return
} }
preposition := strings.ToLower(msg.Params[0]) preposition := strings.ToLower(msg.Params[0])
@ -754,7 +756,7 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
msgid, err = history.NormalizeMsgid(value), nil msgid, err = history.NormalizeMsgid(value), nil
return return
} else if identifier == "timestamp" { } else if identifier == "timestamp" {
timestamp, err = time.Parse(IRCv3TimestampFormat, value) timestamp, err = time.Parse(utils.IRCv3TimestampFormat, value)
if err == nil { if err == nil {
timestamp = timestamp.UTC() timestamp = timestamp.UTC()
if timestamp.Before(unixEpoch) { if timestamp.Before(unixEpoch) {
@ -842,7 +844,12 @@ func chathistoryHandler(server *Server, client *Client, msg ircmsg.Message, rb *
} }
if listTargets { if listTargets {
targets, err = client.listTargets(start, end, limit) // TARGETS must take time= selectors
if start.Time.IsZero() || end.Time.IsZero() {
err = utils.ErrInvalidParams
return
}
targets, err = client.listTargets(start.Time, end.Time, limit)
} else { } else {
channel, sequence, err = server.GetHistorySequence(nil, client, target) channel, sequence, err = server.GetHistorySequence(nil, client, target)
if err != nil || sequence == nil { if err != nil || sequence == nil {
@ -2546,8 +2553,19 @@ func operHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respons
return false return false
} }
config := server.Config()
now := time.Now()
nextAllowableAttempt := rb.session.lastOperAttempt.Add(config.Server.OperThrottle)
if now.Before(nextAllowableAttempt) {
timeLeft := nextAllowableAttempt.Sub(now).Round(time.Millisecond)
rb.Add(nil, server.name, ERR_NOOPERHOST, client.Nick(), fmt.Sprintf(client.t("You must wait %v before issuing OPER again"), timeLeft))
return false
}
rb.session.lastOperAttempt = now
// must pass at least one check, and all enabled checks // must pass at least one check, and all enabled checks
var checkPassed, checkFailed, passwordFailed bool var checkPassed, checkFailed, certFailed, passwordFailed bool
oper := server.GetOperator(msg.Params[0]) oper := server.GetOperator(msg.Params[0])
if oper != nil { if oper != nil {
if oper.Certfp != "" { if oper.Certfp != "" {
@ -2555,11 +2573,13 @@ func operHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respons
checkPassed = true checkPassed = true
} else { } else {
checkFailed = true checkFailed = true
certFailed = true
} }
} }
if !checkFailed && oper.Pass != nil { if !checkFailed && oper.Pass != nil {
if len(msg.Params) == 1 { if len(msg.Params) == 1 {
checkFailed = true checkFailed = true
passwordFailed = true
} else if bcrypt.CompareHashAndPassword(oper.Pass, []byte(msg.Params[1])) != nil { } else if bcrypt.CompareHashAndPassword(oper.Pass, []byte(msg.Params[1])) != nil {
checkFailed = true checkFailed = true
passwordFailed = true passwordFailed = true
@ -2570,14 +2590,21 @@ func operHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respons
} }
if !checkPassed || checkFailed { if !checkPassed || checkFailed {
rb.Add(nil, server.name, ERR_PASSWDMISMATCH, client.Nick(), client.t("Password incorrect")) rb.Add(nil, server.name, ERR_NOOPERHOST, client.Nick(), client.t("OPER failed; check the server logs for details."))
// #951: only disconnect them if we actually tried to check a password for them
if passwordFailed { // hopefully not too spammy given the throttling:
client.Quit(client.t("Password incorrect"), rb.session) if oper == nil {
return true server.logger.Info("opers", "OPER failed with invalid oper name", msg.Params[0])
} else if certFailed {
server.logger.Info("opers", "OPER attempt for", msg.Params[0], "failed with invalid certfp")
} else if passwordFailed {
server.logger.Info("opers", "OPER attempt for", msg.Params[0], "failed with invalid password")
} else { } else {
return false // should not be possible given config validation
server.logger.Info("opers", "OPER attempt for", msg.Params[0], "failed with invalid config")
} }
return false
} }
if oper != nil { if oper != nil {
@ -2770,8 +2797,10 @@ func redactHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respo
targetmsgid := msg.Params[1] targetmsgid := msg.Params[1]
//clientOnlyTags := msg.ClientOnlyTags() //clientOnlyTags := msg.ClientOnlyTags()
var reason string var reason string
var reasonPresent bool
if len(msg.Params) > 2 { if len(msg.Params) > 2 {
reason = msg.Params[2] reason = msg.Params[2]
reasonPresent = true
} }
var members []*Client // members of a channel, or both parties of a PM var members []*Client // members of a channel, or both parties of a PM
var canDelete CanDelete var canDelete CanDelete
@ -2784,7 +2813,7 @@ func redactHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respo
if target[0] == '#' { if target[0] == '#' {
channel := server.channels.Get(target) channel := server.channels.Get(target)
if channel == nil { if channel == nil {
rb.Add(nil, server.name, ERR_NOSUCHCHANNEL, client.Nick(), utils.SafeErrorParam(target), client.t("No such channel")) rb.Add(nil, server.name, "FAIL", "REDACT", "INVALID_TARGET", utils.SafeErrorParam(target), client.t("No such channel"))
return false return false
} }
members = channel.Members() members = channel.Members()
@ -2813,10 +2842,16 @@ func redactHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respo
} }
err := server.DeleteMessage(target, targetmsgid, accountName) err := server.DeleteMessage(target, targetmsgid, accountName)
if err == errNoop { switch err {
case history.ErrNotFound:
rb.Add(nil, server.name, "FAIL", "REDACT", "UNKNOWN_MSGID", utils.SafeErrorParam(target), utils.SafeErrorParam(targetmsgid), client.t("This message does not exist or is too old")) rb.Add(nil, server.name, "FAIL", "REDACT", "UNKNOWN_MSGID", utils.SafeErrorParam(target), utils.SafeErrorParam(targetmsgid), client.t("This message does not exist or is too old"))
return false return false
} else if err != nil { case history.ErrDisallowed:
rb.Add(nil, server.name, "FAIL", "REDACT", "REDACT_FORBIDDEN", utils.SafeErrorParam(target), utils.SafeErrorParam(targetmsgid), client.t("You are not authorized to delete this message"))
return false
case nil:
// OK
default:
isOper := client.HasRoleCapabs("history") isOper := client.HasRoleCapabs("history")
if isOper { if isOper {
rb.Add(nil, server.name, "FAIL", "REDACT", "REDACT_FORBIDDEN", utils.SafeErrorParam(target), utils.SafeErrorParam(targetmsgid), fmt.Sprintf(client.t("Error deleting message: %v"), err)) rb.Add(nil, server.name, "FAIL", "REDACT", "REDACT_FORBIDDEN", utils.SafeErrorParam(target), utils.SafeErrorParam(targetmsgid), fmt.Sprintf(client.t("Error deleting message: %v"), err))
@ -2831,7 +2866,8 @@ func redactHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respo
// now we have to remove it from the buffer of the client who sent the REDACT command // now we have to remove it from the buffer of the client who sent the REDACT command
err := server.DeleteMessage(client.Nick(), targetmsgid, accountName) err := server.DeleteMessage(client.Nick(), targetmsgid, accountName)
if err != nil { // ErrNotFound is expected if both clients are using persistent history
if err != nil && err != history.ErrNotFound {
client.server.logger.Error("internal", fmt.Sprintf("Private message %s is not deletable by %s from their own buffer's even though we just deleted it from %s's. This is a bug, please report it in details.", targetmsgid, client.Nick(), target), client.Nick()) client.server.logger.Error("internal", fmt.Sprintf("Private message %s is not deletable by %s from their own buffer's even though we just deleted it from %s's. This is a bug, please report it in details.", targetmsgid, client.Nick(), target), client.Nick())
isOper := client.HasRoleCapabs("history") isOper := client.HasRoleCapabs("history")
if isOper { if isOper {
@ -2845,7 +2881,11 @@ func redactHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respo
for _, member := range members { for _, member := range members {
for _, session := range member.Sessions() { for _, session := range member.Sessions() {
if session.capabilities.Has(caps.MessageRedaction) { if session.capabilities.Has(caps.MessageRedaction) {
session.sendFromClientInternal(false, time, msgid, details.nickMask, details.accountName, isBot, nil, "REDACT", target, targetmsgid, reason) if reasonPresent {
session.sendFromClientInternal(false, time, msgid, details.nickMask, details.accountName, isBot, nil, "REDACT", target, targetmsgid, reason)
} else {
session.sendFromClientInternal(false, time, msgid, details.nickMask, details.accountName, isBot, nil, "REDACT", target, targetmsgid)
}
} else { } else {
// If we wanted to send a fallback to clients which do not support // If we wanted to send a fallback to clients which do not support
// draft/message-redaction, we would do it from here. // draft/message-redaction, we would do it from here.
@ -2910,11 +2950,23 @@ func quitHandler(server *Server, client *Client, msg ircmsg.Message, rb *Respons
// REGISTER < account | * > < email | * > <password> // REGISTER < account | * > < email | * > <password>
func registerHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) (exiting bool) { func registerHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) (exiting bool) {
accountName := client.Nick() var accountName string
if accountName == "*" { if client.registered {
accountName = client.Nick()
} else {
accountName = client.preregNick accountName = client.preregNick
} }
config := server.Config()
if client.registered && config.Accounts.NickReservation.ForceGuestFormat {
matches := config.Accounts.NickReservation.guestRegexp.FindStringSubmatch(accountName)
if matches == nil || len(matches) < 2 {
rb.Add(nil, server.name, "FAIL", "REGISTER", "INVALID_USERNAME", utils.SafeErrorParam(accountName), client.t("Username invalid or not given"))
return
}
accountName = matches[1]
}
switch msg.Params[0] { switch msg.Params[0] {
case "*", accountName: case "*", accountName:
// ok // ok
@ -2931,7 +2983,6 @@ func registerHandler(server *Server, client *Client, msg ircmsg.Message, rb *Res
return return
} }
config := server.Config()
if !config.Accounts.Registration.Enabled { if !config.Accounts.Registration.Enabled {
rb.Add(nil, server.name, "FAIL", "REGISTER", "DISALLOWED", accountName, client.t("Account registration is disabled")) rb.Add(nil, server.name, "FAIL", "REGISTER", "DISALLOWED", accountName, client.t("Account registration is disabled"))
return return
@ -2977,7 +3028,7 @@ func registerHandler(server *Server, client *Client, msg ircmsg.Message, rb *Res
announcePendingReg(client, rb, accountName) announcePendingReg(client, rb, accountName)
} }
case errAccountAlreadyRegistered, errAccountAlreadyUnregistered, errAccountMustHoldNick: case errAccountAlreadyRegistered, errAccountAlreadyUnregistered, errAccountMustHoldNick:
rb.Add(nil, server.name, "FAIL", "REGISTER", "USERNAME_EXISTS", accountName, client.t("Username is already registered or otherwise unavailable")) rb.Add(nil, server.name, "FAIL", "REGISTER", "ACCOUNT_EXISTS", accountName, client.t("Username is already registered or otherwise unavailable"))
case errAccountBadPassphrase: case errAccountBadPassphrase:
rb.Add(nil, server.name, "FAIL", "REGISTER", "INVALID_PASSWORD", accountName, client.t("Password was invalid")) rb.Add(nil, server.name, "FAIL", "REGISTER", "INVALID_PASSWORD", accountName, client.t("Password was invalid"))
default: default:
@ -3055,14 +3106,14 @@ func markReadHandler(server *Server, client *Client, msg ircmsg.Message, rb *Res
// "MARKREAD client set command": MARKREAD <target> <timestamp> // "MARKREAD client set command": MARKREAD <target> <timestamp>
readTimestamp := strings.TrimPrefix(msg.Params[1], "timestamp=") readTimestamp := strings.TrimPrefix(msg.Params[1], "timestamp=")
readTime, err := time.Parse(IRCv3TimestampFormat, readTimestamp) readTime, err := time.Parse(utils.IRCv3TimestampFormat, readTimestamp)
if err != nil { if err != nil {
rb.Add(nil, server.name, "FAIL", "MARKREAD", "INVALID_PARAMS", utils.SafeErrorParam(readTimestamp), client.t("Invalid timestamp")) rb.Add(nil, server.name, "FAIL", "MARKREAD", "INVALID_PARAMS", utils.SafeErrorParam(readTimestamp), client.t("Invalid timestamp"))
return return
} }
readTime = readTime.UTC() readTime = readTime.UTC()
result := client.SetReadMarker(cftarget, readTime) result := client.SetReadMarker(cftarget, readTime)
readTimestamp = fmt.Sprintf("timestamp=%s", result.Format(IRCv3TimestampFormat)) readTimestamp = fmt.Sprintf("timestamp=%s", result.Format(utils.IRCv3TimestampFormat))
// inform the originating session whether it was a success or a no-op: // inform the originating session whether it was a success or a no-op:
rb.Add(nil, server.name, "MARKREAD", unfoldedTarget, readTimestamp) rb.Add(nil, server.name, "MARKREAD", unfoldedTarget, readTimestamp)
if result.Equal(readTime) { if result.Equal(readTime) {
@ -3089,6 +3140,287 @@ func markReadHandler(server *Server, client *Client, msg ircmsg.Message, rb *Res
return return
} }
// METADATA <target> <subcommand> [<and so on>...]
func metadataHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) (exiting bool) {
config := server.Config()
if !config.Metadata.Enabled {
rb.Add(nil, server.name, "FAIL", "METADATA", "FORBIDDEN", utils.SafeErrorParam(msg.Params[0]), "Metadata is disabled on this server")
return
}
subcommand := strings.ToLower(msg.Params[1])
needsKey := subcommand == "set" || subcommand == "get" || subcommand == "sub" || subcommand == "unsub"
if needsKey && len(msg.Params) < 3 {
rb.Add(nil, server.name, ERR_NEEDMOREPARAMS, client.Nick(), msg.Command, client.t("Not enough parameters"))
return
}
switch subcommand {
case "sub", "unsub", "subs":
// these are session-local and function the same whether or not the client is registered
return metadataSubsHandler(client, subcommand, msg.Params, rb)
case "get", "set", "list", "clear", "sync":
if client.registered {
return metadataRegisteredHandler(client, config, subcommand, msg.Params, rb)
} else {
return metadataUnregisteredHandler(client, config, subcommand, msg.Params, rb)
}
default:
rb.Add(nil, server.name, "FAIL", "METADATA", "SUBCOMMAND_INVALID", utils.SafeErrorParam(msg.Params[1]), client.t("Invalid subcommand"))
return
}
}
// metadataRegisteredHandler handles metadata-modifying commands from registered clients
func metadataRegisteredHandler(client *Client, config *Config, subcommand string, params []string, rb *ResponseBuffer) (exiting bool) {
server := client.server
target := params[0]
noKeyPerms := func(key string) {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_NO_PERMISSION", target, key, client.t("You do not have permission to perform this action"))
}
if target == "*" {
target = client.Nick()
}
var targetObj MetadataHaver
var targetClient *Client
var targetChannel *Channel
if strings.HasPrefix(target, "#") {
targetChannel = server.channels.Get(target)
if targetChannel != nil {
targetObj = targetChannel
target = targetChannel.Name() // canonicalize case
}
} else {
targetClient = server.clients.Get(target)
if targetClient != nil {
targetObj = targetClient
target = targetClient.Nick() // canonicalize case
}
}
if targetObj == nil {
rb.Add(nil, server.name, "FAIL", "METADATA", "INVALID_TARGET", target, client.t("Invalid metadata target"))
return
}
switch subcommand {
case "set":
key := params[2]
if metadataKeyIsEvil(key) {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_INVALID", utils.SafeErrorParam(key), client.t("Invalid key name"))
return
}
if !metadataCanIEditThisKey(client, targetObj, key) {
noKeyPerms(key)
return
}
// only rate limit clients changing their own metadata:
// channel metadata updates are not any more costly than a PRIVMSG
if client == targetClient {
if throttled, remainingTime := client.checkMetadataThrottle(); throttled {
retryAfter := strconv.Itoa(int(remainingTime.Seconds()) + 1)
rb.Add(nil, server.name, "FAIL", "METADATA", "RATE_LIMITED",
target, utils.SafeErrorParam(key), retryAfter,
fmt.Sprintf(client.t("Please wait at least %v and try again"), remainingTime.Round(time.Millisecond)))
return
}
}
if len(params) > 3 {
value := params[3]
config := client.server.Config()
if failMsg := metadataValueIsEvil(config, key, value); failMsg != "" {
rb.Add(nil, server.name, "FAIL", "METADATA", "VALUE_INVALID", client.t(failMsg))
return
}
updated, err := targetObj.SetMetadata(key, value, config.Metadata.MaxKeys)
if err != nil {
// errLimitExceeded is the only possible error
rb.Add(nil, server.name, "FAIL", "METADATA", "LIMIT_REACHED", client.t("Too many metadata keys"))
return
}
// echo the value to the client whether or not there was a real update
rb.Add(nil, server.name, RPL_KEYVALUE, client.Nick(), target, key, "*", value)
if updated {
notifySubscribers(server, rb.session, targetObj, target, key, value, true)
}
} else {
if updated := targetObj.DeleteMetadata(key); updated {
notifySubscribers(server, rb.session, targetObj, target, key, "", false)
rb.Add(nil, server.name, RPL_KEYNOTSET, client.Nick(), target, key, client.t("Key deleted"))
} else {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_NOT_SET", utils.SafeErrorParam(key), client.t("Metadata key not set"))
}
}
case "get":
if !metadataCanISeeThisTarget(client, targetObj) {
noKeyPerms("*")
return
}
batchId := rb.StartNestedBatch("metadata", target)
defer rb.EndNestedBatch(batchId)
for _, key := range params[2:] {
if metadataKeyIsEvil(key) {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_INVALID", utils.SafeErrorParam(key), client.t("Invalid key name"))
continue
}
val, ok := targetObj.GetMetadata(key)
if !ok {
rb.Add(nil, server.name, RPL_KEYNOTSET, client.Nick(), target, key, client.t("Key is not set"))
continue
}
visibility := "*"
rb.Add(nil, server.name, RPL_KEYVALUE, client.Nick(), target, key, visibility, val)
}
case "list":
playMetadataList(rb, client.Nick(), target, targetObj.ListMetadata())
case "clear":
if !metadataCanIEditThisTarget(client, targetObj) {
noKeyPerms("*")
return
}
values := targetObj.ClearMetadata()
playMetadataList(rb, client.Nick(), target, values)
case "sync":
if targetChannel != nil {
syncChannelMetadata(server, rb, targetChannel)
}
if targetClient != nil {
syncClientMetadata(server, rb, targetClient)
}
}
return
}
// metadataUnregisteredHandler handles metadata-modifying commands for pre-connection-registration
// clients. these operations act on a session-local buffer; if/when the client completes registration,
// they are applied to the final Client object (possibly a different client if there was a reattach)
// on a best-effort basis.
func metadataUnregisteredHandler(client *Client, config *Config, subcommand string, params []string, rb *ResponseBuffer) (exiting bool) {
server := client.server
if params[0] != "*" {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_NO_PERMISSION", utils.SafeErrorParam(params[0]), "*", client.t("You can only modify your own metadata before completing connection registration"))
return
}
switch subcommand {
case "set":
if rb.session.metadataPreregVals == nil {
rb.session.metadataPreregVals = make(map[string]string)
}
key := params[2]
if metadataKeyIsEvil(key) {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_INVALID", utils.SafeErrorParam(key), client.t("Invalid key name"))
return
}
if len(params) >= 4 {
value := params[3]
// enforce a sane limit on prereg keys. we don't need to enforce the exact limit,
// that will be done when applying the buffer after registration
if len(rb.session.metadataPreregVals) > config.Metadata.MaxKeys {
rb.Add(nil, server.name, "FAIL", "METADATA", "LIMIT_REACHED", client.t("Too many metadata keys"))
return
}
if failMsg := metadataValueIsEvil(config, key, value); failMsg != "" {
rb.Add(nil, server.name, "FAIL", "METADATA", "VALUE_INVALID", client.t(failMsg))
return
}
rb.session.metadataPreregVals[key] = value
rb.Add(nil, server.name, RPL_KEYVALUE, "*", "*", key, "*", value)
} else {
// unset
_, present := rb.session.metadataPreregVals[key]
if present {
delete(rb.session.metadataPreregVals, key)
rb.Add(nil, server.name, RPL_KEYNOTSET, "*", "*", key, client.t("Key deleted"))
} else {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_NOT_SET", utils.SafeErrorParam(key), client.t("Metadata key not set"))
}
}
case "list":
playMetadataList(rb, "*", "*", rb.session.metadataPreregVals)
case "clear":
oldMetadata := rb.session.metadataPreregVals
rb.session.metadataPreregVals = nil
playMetadataList(rb, "*", "*", oldMetadata)
case "sync":
rb.Add(nil, server.name, RPL_METADATASYNCLATER, "*", utils.SafeErrorParam(params[1]), "60") // lol
}
return false
}
// metadataSubsHandler handles subscription-related commands;
// these are handled the same whether the client is registered or not
func metadataSubsHandler(client *Client, subcommand string, params []string, rb *ResponseBuffer) (exiting bool) {
server := client.server
switch subcommand {
case "sub":
keys := params[2:]
for _, key := range keys {
if metadataKeyIsEvil(key) {
rb.Add(nil, server.name, "FAIL", "METADATA", "KEY_INVALID", utils.SafeErrorParam(key), client.t("Invalid key name"))
return
}
}
added, err := rb.session.SubscribeTo(keys...)
if err == errMetadataTooManySubs {
bad := keys[len(added)] // get the key that broke the camel's back
rb.Add(nil, server.name, "FAIL", "METADATA", "TOO_MANY_SUBS", utils.SafeErrorParam(bad), client.t("Too many subscriptions"))
}
lineLength := MaxLineLen - len(server.name) - len(RPL_METADATASUBOK) - len(client.Nick()) - 10
chunked := utils.ChunkifyParams(slices.Values(added), lineLength)
for _, line := range chunked {
params := append([]string{client.Nick()}, line...)
rb.Add(nil, server.name, RPL_METADATASUBOK, params...)
}
case "unsub":
keys := params[2:]
removed := rb.session.UnsubscribeFrom(keys...)
lineLength := MaxLineLen - len(server.name) - len(RPL_METADATAUNSUBOK) - len(client.Nick()) - 10
chunked := utils.ChunkifyParams(slices.Values(removed), lineLength)
for _, line := range chunked {
params := append([]string{client.Nick()}, line...)
rb.Add(nil, server.name, RPL_METADATAUNSUBOK, params...)
}
case "subs":
lineLength := MaxLineLen - len(server.name) - len(RPL_METADATASUBS) - len(client.Nick()) - 10
subs := rb.session.MetadataSubscriptions()
batchID := rb.StartNestedBatch("metadata-subs")
defer rb.EndNestedBatch(batchID)
chunked := utils.ChunkifyParams(maps.Keys(subs), lineLength)
for _, line := range chunked {
params := append([]string{client.Nick()}, line...)
rb.Add(nil, server.name, RPL_METADATASUBS, params...)
}
}
return false
}
// REHASH // REHASH
func rehashHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool { func rehashHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool {
nick := client.Nick() nick := client.Nick()
@ -3655,6 +3987,7 @@ func webpushHandler(server *Server, client *Client, msg ircmsg.Message, rb *Resp
if err := webpush.SanityCheckWebPushEndpoint(endpoint); err != nil { if err := webpush.SanityCheckWebPushEndpoint(endpoint); err != nil {
rb.Add(nil, server.name, "FAIL", "WEBPUSH", "INVALID_PARAMS", subcommand, client.t("Invalid web push URL")) rb.Add(nil, server.name, "FAIL", "WEBPUSH", "INVALID_PARAMS", subcommand, client.t("Invalid web push URL"))
return false
} }
switch subcommand { switch subcommand {
@ -4118,9 +4451,9 @@ func zncHandler(server *Server, client *Client, msg ircmsg.Message, rb *Response
// fake handler for unknown commands // fake handler for unknown commands
func unknownCommandHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool { func unknownCommandHandler(server *Server, client *Client, msg ircmsg.Message, rb *ResponseBuffer) bool {
var message string var message string
if strings.HasPrefix(msg.Command, "/") { if trimmedCmd, initialSlash := strings.CutPrefix(msg.Command, "/"); initialSlash {
message = fmt.Sprintf(client.t("Unknown command; if you are using /QUOTE, the correct syntax is /QUOTE %[1]s, not /QUOTE %[2]s"), message = fmt.Sprintf(client.t("Unknown command; if you are using /QUOTE, the correct syntax is /QUOTE %[1]s, not /QUOTE %[2]s"),
strings.TrimPrefix(msg.Command, "/"), msg.Command) trimmedCmd, msg.Command)
} else { } else {
message = client.t("Unknown command") message = client.t("Unknown command")
} }

View File

@ -238,11 +238,10 @@ Get an explanation of <argument>, or "index" for a list of help topics.`,
"history": { "history": {
text: `HISTORY <target> [limit] text: `HISTORY <target> [limit]
Replay message history. <target> can be a channel name, "me" to replay direct Replay message history. <target> can be a channel name or a nickname you have
message history, or a nickname to replay another client's direct message direct message history with. [limit] can be either an integer (the maximum
history (they must be logged into the same account as you). [limit] can be number of messages to replay), or a time duration like 10m or 1h (the time
either an integer (the maximum number of messages to replay), or a time window within which to replay messages).`,
duration like 10m or 1h (the time window within which to replay messages).`,
}, },
"info": { "info": {
text: `INFO text: `INFO
@ -339,6 +338,12 @@ command is processed by that server.`,
MARKREAD updates an IRCv3 read message marker. It is not intended for use by MARKREAD updates an IRCv3 read message marker. It is not intended for use by
end users. For more details, see the latest draft of the read-marker end users. For more details, see the latest draft of the read-marker
specification.`, specification.`,
},
"metadata": {
text: `METADATA <target> <subcommand> [<everything else>...]
Retrieve and meddle with metadata for the given target.
Have a look at https://ircv3.net/specs/extensions/metadata for interesting technical information.`,
}, },
"mode": { "mode": {
text: `MODE <target> [<modestring> [<mode arguments>...]] text: `MODE <target> [<modestring> [<mode arguments>...]]

126
irc/history/database.go Normal file
View File

@ -0,0 +1,126 @@
// Copyright (c) 2025 Shivaram Lingamneni
// released under the MIT license
package history
import (
"errors"
"io"
"time"
)
var (
ErrDisallowed = errors.New("disallowed")
ErrNotFound = errors.New("not found")
)
// Database is an interface for persistent history storage backends.
type Database interface {
// Close closes the database connection and releases resources.
io.Closer
// AddChannelItem adds a history item for a channel.
// target is the casefolded channel name.
// account is the sender's casefolded account name ("" for no account).
AddChannelItem(target string, item Item, account string) error
// AddDirectMessage adds a history item for a direct message.
// All identifiers are casefolded; account identifiers are "" for no account.
AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item Item) error
// DeleteMsgid deletes a message by its msgid.
// accountName is the unfolded account name, or "*" to skip
// account validation
DeleteMsgid(msgid, accountName string) error
// MakeSequence creates a Sequence for querying history.
// target is the primary target (channel or account), casefolded.
// correspondent is the casefolded DM correspondent (empty for channels).
// cutoff is the earliest time to include in results.
MakeSequence(target, correspondent string, cutoff time.Time) Sequence
// ListChannels returns the timestamp of the latest message in each
// of the given channels (specified as casefolded names).
ListChannels(cfchannels []string) (results []TargetListing, err error)
// ListCorrespondents lists the DM correspondents associated with an account,
// in order to implement CHATHISTORY TARGETS.
ListCorrespondents(cftarget string, start, end time.Time, limit int) ([]TargetListing, error)
// these are for theoretical GDPR compliance, not actual chat functionality,
// and are not essential:
// Forget enqueues an account (casefolded) for message deletion.
// This is used for GDPR-style "right to be forgotten" requests.
// The actual deletion happens asynchronously.
Forget(account string)
// Export exports all messages for an account (casefolded) to the given writer.
Export(account string, writer io.Writer)
}
type noopDatabase struct{}
// NewNoopDatabase returns a Database implementation that does nothing.
func NewNoopDatabase() Database {
return noopDatabase{}
}
func (n noopDatabase) Close() error {
return nil
}
func (n noopDatabase) AddChannelItem(target string, item Item, account string) error {
return nil
}
func (n noopDatabase) AddDirectMessage(sender, senderAccount, recipient, recipientAccount string, item Item) error {
return nil
}
func (n noopDatabase) DeleteMsgid(msgid, accountName string) error {
return nil
}
func (n noopDatabase) Forget(account string) {
// no-op
}
func (n noopDatabase) Export(account string, writer io.Writer) {
// no-op
}
func (n noopDatabase) ListChannels(cfchannels []string) (results []TargetListing, err error) {
return nil, nil
}
func (n noopDatabase) ListCorrespondents(target string, start, end time.Time, limit int) (results []TargetListing, err error) {
return nil, nil
}
func (n noopDatabase) MakeSequence(target, correspondent string, cutoff time.Time) Sequence {
return noopSequence{}
}
// noopSequence is a no-op implementation of Sequence.
// XXX: this should never be accessed, because if persistent history is disabled,
// we should always be working with a bufferSequence instead. But we might as well
// be defensive in case there's an edge case where (noopDatabase).MakeSequence ends
// up getting called.
type noopSequence struct{}
func (n noopSequence) Between(start, end Selector, limit int) (results []Item, err error) {
return nil, nil
}
func (n noopSequence) Around(start Selector, limit int) (results []Item, err error) {
return nil, nil
}
func (n noopSequence) Cutoff() time.Time {
return time.Time{}
}
func (n noopSequence) Ephemeral() bool {
return false // we're pretending to be an empty database
}

View File

@ -230,10 +230,8 @@ func (list *Buffer) allCorrespondents() (results []TargetListing) {
} }
// list DM correspondents, as one input to CHATHISTORY TARGETS // list DM correspondents, as one input to CHATHISTORY TARGETS
func (list *Buffer) listCorrespondents(start, end Selector, cutoff time.Time, limit int) (results []TargetListing, err error) { func (list *Buffer) ListCorrespondents(start, end time.Time, limit int) (results []TargetListing, err error) {
after := start.Time after, before, ascending := MinMaxAsc(start, end, time.Time{})
before := end.Time
after, before, ascending := MinMaxAsc(after, before, cutoff)
correspondents := list.allCorrespondents() correspondents := list.allCorrespondents()
if len(correspondents) == 0 { if len(correspondents) == 0 {
@ -300,10 +298,6 @@ func (seq *bufferSequence) Around(start Selector, limit int) (results []Item, er
return GenericAround(seq, start, limit) return GenericAround(seq, start, limit)
} }
func (seq *bufferSequence) ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error) {
return seq.list.listCorrespondents(start, end, seq.cutoff, limit)
}
func (seq *bufferSequence) Cutoff() time.Time { func (seq *bufferSequence) Cutoff() time.Time {
return seq.cutoff return seq.cutoff
} }

View File

@ -21,8 +21,6 @@ type Sequence interface {
Between(start, end Selector, limit int) (results []Item, err error) Between(start, end Selector, limit int) (results []Item, err error)
Around(start Selector, limit int) (results []Item, err error) Around(start Selector, limit int) (results []Item, err error)
ListCorrespondents(start, end Selector, limit int) (results []TargetListing, err error)
// this are weird hacks that violate the encapsulation of Sequence to some extent; // this are weird hacks that violate the encapsulation of Sequence to some extent;
// Cutoff() returns the cutoff time for other code to use (it returns the zero time // Cutoff() returns the cutoff time for other code to use (it returns the zero time
// if none is set), and Ephemeral() returns whether the backing store is in-memory // if none is set), and Ephemeral() returns whether the backing store is in-memory

View File

@ -0,0 +1,19 @@
// Copyright (c) 2020 Shivaram Lingamneni
// released under the MIT license
package history
import (
"encoding/json"
)
// 123 / '{' is the magic number that means JSON;
// if we want to do a binary encoding later, we just have to add different magic version numbers
func MarshalItem(item *Item) (result []byte, err error) {
return json.Marshal(item)
}
func UnmarshalItem(data []byte, result *Item) (err error) {
return json.Unmarshal(data, result)
}

View File

@ -164,7 +164,7 @@ func histservExportHandler(service *ircService, server *Server, client *Client,
config := server.Config() config := server.Config()
// don't include the account name in the filename because of escaping concerns // don't include the account name in the filename because of escaping concerns
filename := fmt.Sprintf("%s-%s.json", utils.GenerateSecretToken(), time.Now().UTC().Format(IRCv3TimestampFormat)) filename := fmt.Sprintf("%s-%s.json", utils.GenerateSecretToken(), time.Now().UTC().Format(utils.IRCv3TimestampFormat))
pathname := config.getOutputPath(filename) pathname := config.getOutputPath(filename)
outfile, err := os.Create(pathname) outfile, err := os.Create(pathname)
if err != nil { if err != nil {

View File

@ -45,7 +45,7 @@ type MessageCache struct {
func addAllTags(msg *ircmsg.Message, tags map[string]string, serverTime time.Time, msgid, accountName string, isBot bool) { func addAllTags(msg *ircmsg.Message, tags map[string]string, serverTime time.Time, msgid, accountName string, isBot bool) {
msg.UpdateTags(tags) msg.UpdateTags(tags)
msg.SetTag("time", serverTime.Format(IRCv3TimestampFormat)) msg.SetTag("time", serverTime.Format(utils.IRCv3TimestampFormat))
if accountName != "*" { if accountName != "*" {
msg.SetTag("account", accountName) msg.SetTag("account", accountName)
} }

174
irc/metadata.go Normal file
View File

@ -0,0 +1,174 @@
package irc
import (
"errors"
"iter"
"maps"
"regexp"
"unicode/utf8"
"github.com/ergochat/ergo/irc/caps"
"github.com/ergochat/ergo/irc/modes"
)
const (
// metadata key + value need to be relayable on a single IRC RPL_KEYVALUE line
maxCombinedMetadataLenBytes = 350
)
var (
errMetadataTooManySubs = errors.New("too many subscriptions")
errMetadataNotFound = errors.New("key not found")
)
type MetadataHaver interface {
SetMetadata(key string, value string, limit int) (updated bool, err error)
GetMetadata(key string) (string, bool)
DeleteMetadata(key string) (updated bool)
ListMetadata() map[string]string
ClearMetadata() map[string]string
CountMetadata() int
}
func notifySubscribers(server *Server, session *Session, targetObj MetadataHaver, targetName, key, value string, set bool) {
var recipientSessions iter.Seq[*Session]
switch target := targetObj.(type) {
case *Client:
// TODO this case is expensive and might warrant rate-limiting
friends := target.FriendsMonitors(caps.Metadata)
// broadcast metadata update to other connected sessions
for _, s := range target.Sessions() {
friends.Add(s)
}
recipientSessions = maps.Keys(friends)
case *Channel:
recipientSessions = target.sessionsWithCaps(caps.Metadata)
default:
return // impossible
}
broadcastMetadataUpdate(server, recipientSessions, session, targetName, key, value, set)
}
func broadcastMetadataUpdate(server *Server, sessions iter.Seq[*Session], originator *Session, target, key, value string, set bool) {
for s := range sessions {
// don't notify the session that made the change
if s == originator || !s.isSubscribedTo(key) {
continue
}
if set {
s.Send(nil, server.name, "METADATA", target, key, "*", value)
} else {
s.Send(nil, server.name, "METADATA", target, key, "*")
}
}
}
func syncClientMetadata(server *Server, rb *ResponseBuffer, target *Client) {
batchId := rb.StartNestedBatch("metadata", target.Nick())
defer rb.EndNestedBatch(batchId)
subs := rb.session.MetadataSubscriptions()
values := target.ListMetadata()
for k, v := range values {
if subs.Has(k) {
visibility := "*"
rb.Add(nil, server.name, "METADATA", target.Nick(), k, visibility, v)
}
}
}
func syncChannelMetadata(server *Server, rb *ResponseBuffer, channel *Channel) {
batchId := rb.StartNestedBatch("metadata", channel.Name())
defer rb.EndNestedBatch(batchId)
subs := rb.session.MetadataSubscriptions()
chname := channel.Name()
values := channel.ListMetadata()
for k, v := range values {
if subs.Has(k) {
visibility := "*"
rb.Add(nil, server.name, "METADATA", chname, k, visibility, v)
}
}
for _, client := range channel.Members() {
values := client.ListMetadata()
for k, v := range values {
if subs.Has(k) {
visibility := "*"
rb.Add(nil, server.name, "METADATA", client.Nick(), k, visibility, v)
}
}
}
}
func playMetadataList(rb *ResponseBuffer, nick, target string, values map[string]string) {
batchId := rb.StartNestedBatch("metadata", target)
defer rb.EndNestedBatch(batchId)
for key, val := range values {
visibility := "*"
rb.Add(nil, rb.session.client.server.name, RPL_KEYVALUE, nick, target, key, visibility, val)
}
}
func playMetadataVerbBatch(rb *ResponseBuffer, target string, values map[string]string) {
batchId := rb.StartNestedBatch("metadata", target)
defer rb.EndNestedBatch(batchId)
for key, val := range values {
visibility := "*"
rb.Add(nil, rb.session.client.server.name, "METADATA", target, key, visibility, val)
}
}
var validMetadataKeyRegexp = regexp.MustCompile("^[a-z0-9_./-]+$")
func metadataKeyIsEvil(key string) bool {
return !validMetadataKeyRegexp.MatchString(key)
}
func metadataValueIsEvil(config *Config, key, value string) (failMsg string) {
if !globalUtf8EnforcementSetting && !utf8.ValidString(value) {
return `METADATA values must be UTF-8`
}
if len(key)+len(value) > maxCombinedMetadataLenBytes ||
(config.Metadata.MaxValueBytes > 0 && len(value) > config.Metadata.MaxValueBytes) {
return `Value is too long`
}
return "" // success
}
func metadataCanIEditThisKey(client *Client, targetObj MetadataHaver, key string) bool {
// no key-specific logic as yet
return metadataCanIEditThisTarget(client, targetObj)
}
func metadataCanIEditThisTarget(client *Client, targetObj MetadataHaver) bool {
switch target := targetObj.(type) {
case *Client:
return client == target || client.HasRoleCapabs("metadata")
case *Channel:
return target.ClientIsAtLeast(client, modes.Operator) || client.HasRoleCapabs("metadata")
default:
return false // impossible
}
}
func metadataCanISeeThisTarget(client *Client, targetObj MetadataHaver) bool {
switch target := targetObj.(type) {
case *Client:
return true
case *Channel:
return target.hasClient(client) || client.HasRoleCapabs("metadata")
default:
return false // impossible
}
}

25
irc/metadata_test.go Normal file
View File

@ -0,0 +1,25 @@
package irc
import "testing"
func TestKeyCheck(t *testing.T) {
cases := []struct {
input string
isEvil bool
}{
{"ImNormalButIHaveCaps", true},
{"imnormalandidonthavecaps", false},
{"ergo.chat/vendor-extension", false},
{"", true},
{":imevil", true},
{"im:evil", true},
{"key£with$not%allowed^chars", true},
{"key.thats_completely/normal-and.fine", false},
}
for _, c := range cases {
if metadataKeyIsEvil(c.input) != c.isEvil {
t.Errorf("%s should have returned %v. but it didn't. so that's not great", c.input, c.isEvil)
}
}
}

View File

@ -251,9 +251,11 @@ func (channel *Channel) ApplyChannelModeChanges(client *Client, isSamode bool, c
switch change.Op { switch change.Op {
case modes.Add: case modes.Add:
val, err := strconv.Atoi(change.Arg) val, err := strconv.Atoi(change.Arg)
if err == nil { if err == nil && val > 0 {
channel.setUserLimit(val) channel.setUserLimit(val)
applied = append(applied, change) applied = append(applied, change)
} else {
rb.Add(nil, client.server.name, ERR_INVALIDMODEPARAM, details.nick, chname, string(change.Mode), utils.SafeErrorParam(change.Arg), client.t("+l user limit value must be an integer between 1 and 2147483647, expressed in base 10"))
} }
case modes.Remove: case modes.Remove:

View File

@ -28,17 +28,21 @@ func (mm *MonitorManager) Initialize() {
// AddMonitors adds clients using extended-monitor monitoring `client`'s nick to the passed user set. // AddMonitors adds clients using extended-monitor monitoring `client`'s nick to the passed user set.
func (manager *MonitorManager) AddMonitors(users utils.HashSet[*Session], cfnick string, capabs ...caps.Capability) { func (manager *MonitorManager) AddMonitors(users utils.HashSet[*Session], cfnick string, capabs ...caps.Capability) {
// technically, we should check extended-monitor here, but it's not really necessary
// since clients will ignore AWAY, ACCOUNT, CHGHOST, and SETNAME for users
// they're not tracking
manager.RLock() manager.RLock()
defer manager.RUnlock() defer manager.RUnlock()
for session := range manager.watchedby[cfnick] { for session := range manager.watchedby[cfnick] {
if session.capabilities.Has(caps.ExtendedMonitor) && session.capabilities.HasAll(capabs...) { if session.capabilities.HasAll(capabs...) {
users.Add(session) users.Add(session)
} }
} }
} }
// AlertAbout alerts everyone monitoring `client`'s nick that `client` is now {on,off}line. // AlertAbout alerts everyone monitoring `client`'s nick that `client` is now {on,off}line.
func (manager *MonitorManager) AlertAbout(nick, cfnick string, online bool) { func (manager *MonitorManager) AlertAbout(nick, cfnick string, online bool, client *Client) {
var watchers []*Session var watchers []*Session
// safely copy the list of clients watching our nick // safely copy the list of clients watching our nick
manager.RLock() manager.RLock()
@ -52,8 +56,21 @@ func (manager *MonitorManager) AlertAbout(nick, cfnick string, online bool) {
command = RPL_MONONLINE command = RPL_MONONLINE
} }
var metadata map[string]string
if online && client != nil {
metadata = client.ListMetadata()
}
for _, session := range watchers { for _, session := range watchers {
session.Send(nil, session.client.server.name, command, session.client.Nick(), nick) session.Send(nil, session.client.server.name, command, session.client.Nick(), nick)
if metadata != nil && session.capabilities.Has(caps.Metadata) {
for key := range session.MetadataSubscriptions() {
if val, ok := metadata[key]; ok {
session.Send(nil, client.server.name, "METADATA", nick, key, "*", val)
}
}
}
} }
} }

View File

@ -7,7 +7,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"runtime/debug" "runtime/debug"
@ -23,10 +22,6 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
) )
var (
ErrDisallowed = errors.New("disallowed")
)
const ( const (
// maximum length in bytes of any message target (nickname or channel name) in its // maximum length in bytes of any message target (nickname or channel name) in its
// canonicalized (i.e., casefolded) state: // canonicalized (i.e., casefolded) state:
@ -64,10 +59,16 @@ type MySQL struct {
trackAccountMessages atomic.Uint32 trackAccountMessages atomic.Uint32
} }
func (mysql *MySQL) Initialize(logger *logger.Manager, config Config) { var _ history.Database = (*MySQL)(nil)
func NewMySQLDatabase(logger *logger.Manager, config Config) (*MySQL, error) {
var mysql MySQL
mysql.logger = logger mysql.logger = logger
mysql.wakeForgetter = make(chan e, 1) mysql.wakeForgetter = make(chan e, 1)
mysql.SetConfig(config) mysql.SetConfig(config)
return &mysql, mysql.open()
} }
func (mysql *MySQL) SetConfig(config Config) { func (mysql *MySQL) SetConfig(config Config) {
@ -89,7 +90,7 @@ func (mysql *MySQL) getExpireTime() (expireTime time.Duration) {
return return
} }
func (m *MySQL) Open() (err error) { func (m *MySQL) open() (err error) {
var address string var address string
if m.config.SocketPath != "" { if m.config.SocketPath != "" {
address = fmt.Sprintf("unix(%s)", m.config.SocketPath) address = fmt.Sprintf("unix(%s)", m.config.SocketPath)
@ -128,7 +129,7 @@ func (m *MySQL) Open() (err error) {
func (mysql *MySQL) fixSchemas() (err error) { func (mysql *MySQL) fixSchemas() (err error) {
_, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata ( _, err = mysql.db.Exec(`CREATE TABLE IF NOT EXISTS metadata (
key_name VARCHAR(32) primary key, key_name VARCHAR(32) PRIMARY KEY,
value VARCHAR(32) NOT NULL value VARCHAR(32) NOT NULL
) CHARSET=ascii COLLATE=ascii_bin;`) ) CHARSET=ascii COLLATE=ascii_bin;`)
if err != nil { if err != nil {
@ -136,17 +137,17 @@ func (mysql *MySQL) fixSchemas() (err error) {
} }
var schema string var schema string
err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaVersion).Scan(&schema) err = mysql.db.QueryRow(`SELECT value FROM metadata WHERE key_name = ?;`, keySchemaVersion).Scan(&schema)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
err = mysql.createTables() err = mysql.createTables()
if err != nil { if err != nil {
return return
} }
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaVersion, latestDbSchema) _, err = mysql.db.Exec(`INSERT INTO metadata (key_name, value) VALUES (?, ?);`, keySchemaVersion, latestDbSchema)
if err != nil { if err != nil {
return return
} }
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion) _, err = mysql.db.Exec(`INSERT INTO metadata (key_name, value) VALUES (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
if err != nil { if err != nil {
return return
} }
@ -159,7 +160,7 @@ func (mysql *MySQL) fixSchemas() (err error) {
} }
var minorVersion string var minorVersion string
err = mysql.db.QueryRow(`select value from metadata where key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion) err = mysql.db.QueryRow(`SELECT value FROM metadata WHERE key_name = ?;`, keySchemaMinorVersion).Scan(&minorVersion)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// XXX for now, the only minor version upgrade is the account tracking tables // XXX for now, the only minor version upgrade is the account tracking tables
err = mysql.createComplianceTables() err = mysql.createComplianceTables()
@ -170,7 +171,7 @@ func (mysql *MySQL) fixSchemas() (err error) {
if err != nil { if err != nil {
return return
} }
_, err = mysql.db.Exec(`insert into metadata (key_name, value) values (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion) _, err = mysql.db.Exec(`INSERT INTO metadata (key_name, value) VALUES (?, ?);`, keySchemaMinorVersion, latestDbMinorVersion)
if err != nil { if err != nil {
return return
} }
@ -180,7 +181,7 @@ func (mysql *MySQL) fixSchemas() (err error) {
if err != nil { if err != nil {
return return
} }
_, err = mysql.db.Exec(`update metadata set value = ? where key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion) _, err = mysql.db.Exec(`UPDATE metadata SET value = ? WHERE key_name = ?;`, latestDbMinorVersion, keySchemaMinorVersion)
if err != nil { if err != nil {
return return
} }
@ -416,7 +417,7 @@ func (mysql *MySQL) deleteCorrespondents(ctx context.Context, threshold int64) {
} else { } else {
count, err := result.RowsAffected() count, err := result.RowsAffected()
if !mysql.logError("error deleting correspondents", err) { if !mysql.logError("error deleting correspondents", err) {
mysql.logger.Debug(fmt.Sprintf("deleted %d correspondents entries", count)) mysql.logger.Debug("mysql", fmt.Sprintf("deleted %d correspondents entries", count))
} }
} }
} }
@ -623,40 +624,46 @@ func (mysql *MySQL) AddChannelItem(target string, item history.Item, account str
func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) { func (mysql *MySQL) insertSequenceEntry(ctx context.Context, target string, messageTime int64, id int64) (err error) {
_, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id) _, err = mysql.insertSequence.ExecContext(ctx, target, messageTime, id)
mysql.logError("could not insert sequence entry", err) if err != nil {
return fmt.Errorf("could not insert sequence entry: %w", err)
}
return return
} }
func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) { func (mysql *MySQL) insertConversationEntry(ctx context.Context, target, correspondent string, messageTime int64, id int64) (err error) {
_, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id) _, err = mysql.insertConversation.ExecContext(ctx, target, correspondent, messageTime, id)
mysql.logError("could not insert conversations entry", err) if err != nil {
return fmt.Errorf("could not insert conversations entry: %w", err)
}
return return
} }
func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) { func (mysql *MySQL) insertCorrespondentsEntry(ctx context.Context, target, correspondent string, messageTime int64, historyId int64) (err error) {
_, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime) _, err = mysql.insertCorrespondent.ExecContext(ctx, target, correspondent, messageTime, messageTime)
mysql.logError("could not insert conversations entry", err) if err != nil {
return fmt.Errorf("could not insert correspondents entry: %w", err)
}
return return
} }
func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) { func (mysql *MySQL) insertBase(ctx context.Context, item history.Item) (id int64, err error) {
value, err := marshalItem(&item) value, err := history.MarshalItem(&item)
if mysql.logError("could not marshal item", err) { if err != nil {
return return 0, fmt.Errorf("could not marshal item: %w", err)
} }
msgidBytes, err := decodeMsgid(item.Message.Msgid) msgidBytes, err := utils.DecodeSecretToken(item.Message.Msgid)
if mysql.logError("could not decode msgid", err) { if err != nil {
return return 0, fmt.Errorf("could not decode msgid: %w", err)
} }
result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes) result, err := mysql.insertHistory.ExecContext(ctx, value, msgidBytes)
if mysql.logError("could not insert item", err) { if err != nil {
return return 0, fmt.Errorf("could not insert item: %w", err)
} }
id, err = result.LastInsertId() id, err = result.LastInsertId()
if mysql.logError("could not insert item", err) { if err != nil {
return return 0, fmt.Errorf("could not insert item: %w", err)
} }
return return
@ -667,7 +674,9 @@ func (mysql *MySQL) insertAccountMessageEntry(ctx context.Context, id int64, acc
return return
} }
_, err = mysql.insertAccountMessage.ExecContext(ctx, id, account) _, err = mysql.insertAccountMessage.ExecContext(ctx, id, account)
mysql.logError("could not insert account-message entry", err) if err != nil {
return fmt.Errorf("could not insert account-message entry: %w", err)
}
return return
} }
@ -735,20 +744,25 @@ func (mysql *MySQL) DeleteMsgid(msgid, accountName string) (err error) {
_, id, data, err := mysql.lookupMsgid(ctx, msgid, true) _, id, data, err := mysql.lookupMsgid(ctx, msgid, true)
if err != nil { if err != nil {
if err == sql.ErrNoRows {
return history.ErrNotFound
}
return return
} }
if accountName != "*" { if accountName != "*" {
var item history.Item var item history.Item
err = unmarshalItem(data, &item) err = history.UnmarshalItem(data, &item)
// delete if the entry is corrupt // delete if the entry is corrupt
if err == nil && item.AccountName != accountName { if err == nil && item.AccountName != accountName {
return ErrDisallowed return history.ErrDisallowed
} }
} }
err = mysql.deleteHistoryIDs(ctx, []uint64{id}) err = mysql.deleteHistoryIDs(ctx, []uint64{id})
mysql.logError("couldn't delete msgid", err) if err != nil {
return fmt.Errorf("couldn't delete msgid: %w", err)
}
return return
} }
@ -784,7 +798,7 @@ func (mysql *MySQL) Export(account string, writer io.Writer) {
if err != nil { if err != nil {
return return
} }
err = unmarshalItem(blob, &item) err = history.UnmarshalItem(blob, &item)
if err != nil { if err != nil {
return return
} }
@ -812,8 +826,11 @@ func (mysql *MySQL) Export(account string, writer io.Writer) {
} }
func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) { func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData bool) (result time.Time, id uint64, data []byte, err error) {
decoded, err := decodeMsgid(msgid) decoded, err := utils.DecodeSecretToken(msgid)
if err != nil { if err != nil {
// use sql.ErrNoRows internally for consistency, translate to history.ErrNotFound
// at the package boundary if necessary
err = sql.ErrNoRows
return return
} }
cols := `sequence.nanotime, conversations.nanotime` cols := `sequence.nanotime, conversations.nanotime`
@ -831,10 +848,10 @@ func (mysql *MySQL) lookupMsgid(ctx context.Context, msgid string, includeData b
} else { } else {
err = row.Scan(&nanoSeq, &nanoConv, &id, &data) err = row.Scan(&nanoSeq, &nanoConv, &id, &data)
} }
if err != sql.ErrNoRows {
mysql.logError("could not resolve msgid to time", err)
}
if err != nil { if err != nil {
if err != sql.ErrNoRows {
err = fmt.Errorf("could not resolve msgid to time: %w", err)
}
return return
} }
nanotime := extractNanotime(nanoSeq, nanoConv) nanotime := extractNanotime(nanoSeq, nanoConv)
@ -857,8 +874,8 @@ func extractNanotime(seq, conv sql.NullInt64) (result int64) {
func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) { func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...interface{}) (results []history.Item, err error) {
rows, err := mysql.db.QueryContext(ctx, query, args...) rows, err := mysql.db.QueryContext(ctx, query, args...)
if mysql.logError("could not select history items", err) { if err != nil {
return return nil, fmt.Errorf("could not select history items: %w", err)
} }
defer rows.Close() defer rows.Close()
@ -867,12 +884,12 @@ func (mysql *MySQL) selectItems(ctx context.Context, query string, args ...inter
var blob []byte var blob []byte
var item history.Item var item history.Item
err = rows.Scan(&blob) err = rows.Scan(&blob)
if mysql.logError("could not scan history item", err) { if err != nil {
return return nil, fmt.Errorf("could not scan history item: %w", err)
} }
err = unmarshalItem(blob, &item) err = history.UnmarshalItem(blob, &item)
if mysql.logError("could not unmarshal history item", err) { if err != nil {
return return nil, fmt.Errorf("could not unmarshal history item: %w", err)
} }
results = append(results, item) results = append(results, item)
} }
@ -949,7 +966,7 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin
rows, err := mysql.db.QueryContext(ctx, query, args...) rows, err := mysql.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return return nil, fmt.Errorf("could not query correspondents: %w", err)
} }
defer rows.Close() defer rows.Close()
var correspondent string var correspondent string
@ -957,7 +974,7 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin
for rows.Next() { for rows.Next() {
err = rows.Scan(&correspondent, &nanotime) err = rows.Scan(&correspondent, &nanotime)
if err != nil { if err != nil {
return return nil, fmt.Errorf("could not scan correspondents: %w", err)
} }
results = append(results, history.TargetListing{ results = append(results, history.TargetListing{
CfName: correspondent, CfName: correspondent,
@ -972,6 +989,19 @@ func (mysql *MySQL) listCorrespondentsInternal(ctx context.Context, target strin
return return
} }
func (mysql *MySQL) ListCorrespondents(cftarget string, start, end time.Time, limit int) (results []history.TargetListing, err error) {
ctx, cancel := context.WithTimeout(context.Background(), mysql.getTimeout())
defer cancel()
// TODO accept msgids here?
results, err = mysql.listCorrespondentsInternal(ctx, cftarget, start, end, time.Time{}, limit)
if err != nil {
return nil, fmt.Errorf("could not read correspondents: %w", err)
}
return
}
func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) { func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetListing, err error) {
if mysql.db == nil { if mysql.db == nil {
return return
@ -1000,8 +1030,8 @@ func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetL
queryBuf.WriteString(") GROUP BY sequence.target;") queryBuf.WriteString(") GROUP BY sequence.target;")
rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...) rows, err := mysql.db.QueryContext(ctx, queryBuf.String(), args...)
if mysql.logError("could not query channel listings", err) { if err != nil {
return return nil, fmt.Errorf("could not query channel listings: %w", err)
} }
defer rows.Close() defer rows.Close()
@ -1009,8 +1039,8 @@ func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetL
var nanotime int64 var nanotime int64
for rows.Next() { for rows.Next() {
err = rows.Scan(&target, &nanotime) err = rows.Scan(&target, &nanotime)
if mysql.logError("could not scan channel listings", err) { if err != nil {
return return nil, fmt.Errorf("could not scan channel listings: %w", err)
} }
results = append(results, history.TargetListing{ results = append(results, history.TargetListing{
CfName: target, CfName: target,
@ -1020,12 +1050,13 @@ func (mysql *MySQL) ListChannels(cfchannels []string) (results []history.TargetL
return return
} }
func (mysql *MySQL) Close() { func (mysql *MySQL) Close() (err error) {
// closing the database will close our prepared statements as well // closing the database will close our prepared statements as well
if mysql.db != nil { if mysql.db != nil {
mysql.db.Close() err = mysql.db.Close()
} }
mysql.db = nil mysql.db = nil
return
} }
// implements history.Sequence, emulating a single history buffer (for a channel, // implements history.Sequence, emulating a single history buffer (for a channel,
@ -1072,19 +1103,6 @@ func (s *mySQLHistorySequence) Around(start history.Selector, limit int) (result
return history.GenericAround(s, start, limit) return history.GenericAround(s, start, limit)
} }
func (seq *mySQLHistorySequence) ListCorrespondents(start, end history.Selector, limit int) (results []history.TargetListing, err error) {
ctx, cancel := context.WithTimeout(context.Background(), seq.mysql.getTimeout())
defer cancel()
// TODO accept msgids here?
startTime := start.Time
endTime := end.Time
results, err = seq.mysql.listCorrespondentsInternal(ctx, seq.target, startTime, endTime, seq.cutoff, limit)
seq.mysql.logError("could not read correspondents", err)
return
}
func (seq *mySQLHistorySequence) Cutoff() time.Time { func (seq *mySQLHistorySequence) Cutoff() time.Time {
return seq.cutoff return seq.cutoff
} }

View File

@ -1,23 +0,0 @@
package mysql
import (
"encoding/json"
"github.com/ergochat/ergo/irc/history"
"github.com/ergochat/ergo/irc/utils"
)
// 123 / '{' is the magic number that means JSON;
// if we want to do a binary encoding later, we just have to add different magic version numbers
func marshalItem(item *history.Item) (result []byte, err error) {
return json.Marshal(item)
}
func unmarshalItem(data []byte, result *history.Item) (err error) {
return json.Unmarshal(data, result)
}
func decodeMsgid(msgid string) ([]byte, error) {
return utils.B32Encoder.DecodeString(msgid)
}

View File

@ -128,9 +128,11 @@ func performNickChange(server *Server, client *Client, target *Client, session *
} }
newCfnick := target.NickCasefolded() newCfnick := target.NickCasefolded()
if newCfnick != details.nickCasefolded { // send MONITOR updates only for nick changes, not for new connection registration;
client.server.monitorManager.AlertAbout(details.nick, details.nickCasefolded, false) // defer MONITOR for new connection registration until pre-registration metadata is applied
client.server.monitorManager.AlertAbout(assignedNickname, newCfnick, true) if hadNick && newCfnick != details.nickCasefolded {
client.server.monitorManager.AlertAbout(details.nick, details.nickCasefolded, false, nil)
client.server.monitorManager.AlertAbout(assignedNickname, newCfnick, true, target)
} }
return nil return nil
} }

View File

@ -1055,10 +1055,10 @@ func nsSaregisterHandler(service *ircService, server *Server, client *Client, co
var failCode string var failCode string
if err == errAccountAlreadyRegistered || err == errAccountAlreadyVerified { if err == errAccountAlreadyRegistered || err == errAccountAlreadyVerified {
errMsg = client.t("Account already exists") errMsg = client.t("Account already exists")
failCode = "USERNAME_EXISTS" failCode = "ACCOUNT_EXISTS"
} else if err == errNameReserved { } else if err == errNameReserved {
errMsg = client.t(err.Error()) errMsg = client.t(err.Error())
failCode = "USERNAME_EXISTS" failCode = "ACCOUNT_EXISTS"
} else if err == errAccountBadPassphrase { } else if err == errAccountBadPassphrase {
errMsg = client.t("Passphrase contains forbidden characters or is otherwise invalid") errMsg = client.t("Passphrase contains forbidden characters or is otherwise invalid")
failCode = "INVALID_PASSWORD" failCode = "INVALID_PASSWORD"

View File

@ -79,8 +79,8 @@ const (
RPL_WHOISACTUALLY = "338" RPL_WHOISACTUALLY = "338"
RPL_INVITING = "341" RPL_INVITING = "341"
RPL_SUMMONING = "342" RPL_SUMMONING = "342"
RPL_INVITELIST = "346" RPL_INVEXLIST = "346"
RPL_ENDOFINVITELIST = "347" RPL_ENDOFINVEXLIST = "347"
RPL_EXCEPTLIST = "348" RPL_EXCEPTLIST = "348"
RPL_ENDOFEXCEPTLIST = "349" RPL_ENDOFEXCEPTLIST = "349"
RPL_VERSION = "351" RPL_VERSION = "351"
@ -183,6 +183,13 @@ const (
RPL_MONLIST = "732" RPL_MONLIST = "732"
RPL_ENDOFMONLIST = "733" RPL_ENDOFMONLIST = "733"
ERR_MONLISTFULL = "734" ERR_MONLISTFULL = "734"
RPL_WHOISKEYVALUE = "760" // metadata numerics
RPL_KEYVALUE = "761"
RPL_KEYNOTSET = "766"
RPL_METADATASUBOK = "770"
RPL_METADATAUNSUBOK = "771"
RPL_METADATASUBS = "772"
RPL_METADATASYNCLATER = "774" // end metadata numerics
RPL_LOGGEDIN = "900" RPL_LOGGEDIN = "900"
RPL_LOGGEDOUT = "901" RPL_LOGGEDOUT = "901"
ERR_NICKLOCKED = "902" ERR_NICKLOCKED = "902"

View File

@ -11,6 +11,7 @@ import (
const ( const (
MinCost = bcrypt.MinCost MinCost = bcrypt.MinCost
MaxCost = bcrypt.MaxCost
DefaultCost = 12 // ballpark: 250 msec on a modern Intel CPU DefaultCost = 12 // ballpark: 250 msec on a modern Intel CPU
) )

View File

@ -90,7 +90,8 @@ type Server struct {
snomasks SnoManager snomasks SnoManager
store *buntdb.DB store *buntdb.DB
dstore datastore.Datastore dstore datastore.Datastore
historyDB mysql.MySQL mysqlHistoryDB *mysql.MySQL
historyDB history.Database
torLimiter connection_limits.TorLimiter torLimiter connection_limits.TorLimiter
whoWas WhoWasList whoWas WhoWasList
stats Stats stats Stats
@ -153,7 +154,6 @@ func (server *Server) Shutdown() {
sdnotify.Stopping() sdnotify.Stopping()
server.logger.Info("server", "Stopping server") server.logger.Info("server", "Stopping server")
//TODO(dan): Make sure we disallow new nicks
for _, client := range server.clients.AllClients() { for _, client := range server.clients.AllClients() {
client.Notice("Server is shutting down") client.Notice("Server is shutting down")
} }
@ -162,10 +162,12 @@ func (server *Server) Shutdown() {
server.performAlwaysOnMaintenance(false, true) server.performAlwaysOnMaintenance(false, true)
if err := server.store.Close(); err != nil { if err := server.store.Close(); err != nil {
server.logger.Error("shutdown", fmt.Sprintln("Could not close datastore:", err)) server.logger.Error("shutdown", "Could not close datastore", err.Error())
} }
server.historyDB.Close() if err := server.historyDB.Close(); err != nil {
server.logger.Error("shutdown", "Could not close history database", err.Error())
}
server.logger.Info("server", fmt.Sprintf("%s exiting", Ver)) server.logger.Info("server", fmt.Sprintf("%s exiting", Ver))
} }
@ -428,6 +430,10 @@ func (server *Server) tryRegister(c *Client, session *Session) (exiting bool) {
c.SetMode(defaultMode, true) c.SetMode(defaultMode, true)
} }
c.applyPreregMetadata(session)
c.server.monitorManager.AlertAbout(c.Nick(), c.NickCasefolded(), true, c)
// this is not a reattach, so if the client is always-on, this is the first time // this is not a reattach, so if the client is always-on, this is the first time
// the Client object was created during the current server uptime. mark dirty in // the Client object was created during the current server uptime. mark dirty in
// order to persist the realname and the user modes: // order to persist the realname and the user modes:
@ -496,6 +502,9 @@ func (server *Server) playRegistrationBurst(session *Session) {
if !(rb.session.capabilities.Has(caps.ExtendedISupport) && rb.session.isupportSentPrereg) { if !(rb.session.capabilities.Has(caps.ExtendedISupport) && rb.session.isupportSentPrereg) {
server.RplISupport(c, rb) server.RplISupport(c, rb)
} }
if session.capabilities.Has(caps.Metadata) {
playMetadataVerbBatch(rb, d.nick, c.ListMetadata())
}
if d.account != "" && session.capabilities.Has(caps.Persistence) { if d.account != "" && session.capabilities.Has(caps.Persistence) {
reportPersistenceStatus(c, rb, false) reportPersistenceStatus(c, rb, false)
} }
@ -630,7 +639,6 @@ func (client *Client) getWhoisOf(target *Client, hasPrivs bool, rb *ResponseBuff
if target.HasMode(modes.Bot) { if target.HasMode(modes.Bot) {
rb.Add(nil, client.server.name, RPL_WHOISBOT, cnick, tnick, fmt.Sprintf(ircfmt.Unescape(client.t("is a $bBot$b on %s")), client.server.Config().Network.Name)) rb.Add(nil, client.server.name, RPL_WHOISBOT, cnick, tnick, fmt.Sprintf(ircfmt.Unescape(client.t("is a $bBot$b on %s")), client.server.Config().Network.Name))
} }
if client == target || oper.HasRoleCapab("ban") { if client == target || oper.HasRoleCapab("ban") {
for _, session := range target.Sessions() { for _, session := range target.Sessions() {
if session.certfp != "" { if session.certfp != "" {
@ -642,6 +650,11 @@ func (client *Client) getWhoisOf(target *Client, hasPrivs bool, rb *ResponseBuff
if away, awayMessage := target.Away(); away { if away, awayMessage := target.Away(); away {
rb.Add(nil, client.server.name, RPL_AWAY, cnick, tnick, awayMessage) rb.Add(nil, client.server.name, RPL_AWAY, cnick, tnick, awayMessage)
} }
if rb.session.capabilities.Has(caps.Metadata) {
for key, value := range target.ListMetadata() {
rb.Add(nil, client.server.name, RPL_WHOISKEYVALUE, cnick, tnick, key, "*", value)
}
}
} }
// rehash reloads the config and applies the changes from the config file. // rehash reloads the config and applies the changes from the config file.
@ -685,6 +698,9 @@ func (server *Server) applyConfig(config *Config) (err error) {
globalCasemappingSetting = config.Server.Casemapping globalCasemappingSetting = config.Server.Casemapping
globalUtf8EnforcementSetting = config.Server.EnforceUtf8 globalUtf8EnforcementSetting = config.Server.EnforceUtf8
MaxLineLen = config.Server.MaxLineLen MaxLineLen = config.Server.MaxLineLen
RegisterTimeout = config.Server.IdleTimeouts.Registration
PingTimeout = config.Server.IdleTimeouts.Ping
DisconnectTimeout = config.Server.IdleTimeouts.Disconnect
} else { } else {
// enforce configs that can't be changed after launch: // enforce configs that can't be changed after launch:
if server.name != config.Server.Name { if server.name != config.Server.Name {
@ -710,6 +726,8 @@ func (server *Server) applyConfig(config *Config) (err error) {
return fmt.Errorf("Cannot enable MySQL after launching the server, rehash aborted") return fmt.Errorf("Cannot enable MySQL after launching the server, rehash aborted")
} else if oldConfig.Server.MaxLineLen != config.Server.MaxLineLen { } else if oldConfig.Server.MaxLineLen != config.Server.MaxLineLen {
return fmt.Errorf("Cannot change max-line-len after launching the server, rehash aborted") return fmt.Errorf("Cannot change max-line-len after launching the server, rehash aborted")
} else if oldConfig.Server.IdleTimeouts != config.Server.IdleTimeouts {
return fmt.Errorf("Cannot change idle-timeouts after launching the server, rehash aborted")
} }
} }
@ -788,8 +806,10 @@ func (server *Server) applyConfig(config *Config) (err error) {
return err return err
} }
} else { } else {
if config.Datastore.MySQL.Enabled && config.Datastore.MySQL != oldConfig.Datastore.MySQL { if config.Datastore.MySQL.Enabled && server.mysqlHistoryDB != nil {
server.historyDB.SetConfig(config.Datastore.MySQL) if config.Datastore.MySQL != oldConfig.Datastore.MySQL {
server.mysqlHistoryDB.SetConfig(config.Datastore.MySQL)
}
} }
} }
@ -896,6 +916,9 @@ func (server *Server) applyConfig(config *Config) (err error) {
if config.Accounts.RequireSasl.Enabled && config.Accounts.Registration.Enabled { if config.Accounts.RequireSasl.Enabled && config.Accounts.Registration.Enabled {
server.logger.Warning("server", "Warning: although require-sasl is enabled, users can still register accounts. If your server is not intended to be public, you must set accounts.registration.enabled to false.") server.logger.Warning("server", "Warning: although require-sasl is enabled, users can still register accounts. If your server is not intended to be public, you must set accounts.registration.enabled to false.")
} }
if config.History.Enabled && config.History.ChathistoryMax == 0 {
server.logger.Warning("server", "Warning: for history to work correctly, you must set history.chathistory-maxmessages (see default.yaml for a recommendation).")
}
return err return err
} }
@ -996,12 +1019,14 @@ func (server *Server) loadFromDatastore(config *Config) (err error) {
server.accounts.Initialize(server) server.accounts.Initialize(server)
if config.Datastore.MySQL.Enabled { if config.Datastore.MySQL.Enabled {
server.historyDB.Initialize(server.logger, config.Datastore.MySQL) server.mysqlHistoryDB, err = mysql.NewMySQLDatabase(server.logger, config.Datastore.MySQL)
err = server.historyDB.Open()
if err != nil { if err != nil {
server.logger.Error("internal", "could not connect to mysql", err.Error()) server.logger.Error("internal", "could not connect to mysql", err.Error())
return err return err
} }
server.historyDB = server.mysqlHistoryDB
} else {
server.historyDB = history.NewNoopDatabase()
} }
return nil return nil
@ -1066,10 +1091,6 @@ func (server *Server) setupListeners(config *Config) (err error) {
// we may already know the channel we're querying, or we may have // we may already know the channel we're querying, or we may have
// to look it up via a string query. This function is responsible for // to look it up via a string query. This function is responsible for
// privilege checking. // privilege checking.
// XXX: call this with providedChannel==nil and query=="" to get a sequence
// suitable for ListCorrespondents (i.e., this function is still used to
// decide whether the ringbuf or mysql is authoritative about the client's
// message history).
func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, query string) (channel *Channel, sequence history.Sequence, err error) { func (server *Server) GetHistorySequence(providedChannel *Channel, client *Client, query string) (channel *Channel, sequence history.Sequence, err error) {
config := server.Config() config := server.Config()
// 4 cases: {persistent, ephemeral} x {normal, conversation} // 4 cases: {persistent, ephemeral} x {normal, conversation}
@ -1221,7 +1242,7 @@ func (server *Server) DeleteMessage(target, msgid, accountName string) (err erro
return item.Message.Msgid == msgid && (accountName == "*" || item.AccountName == accountName) return item.Message.Msgid == msgid && (accountName == "*" || item.AccountName == accountName)
}) })
if count == 0 { if count == 0 {
err = errNoop err = history.ErrNotFound
} }
} }

View File

@ -73,7 +73,7 @@ var globalCasemappingSetting Casemapping = CasemappingPRECIS
// XXX analogous unsynchronized global variable controlling utf8 validation // XXX analogous unsynchronized global variable controlling utf8 validation
// if this is off, you get the traditional IRC behavior (relaying any valid RFC1459 // if this is off, you get the traditional IRC behavior (relaying any valid RFC1459
// octets) and invalid utf8 messages are silently dropped for websocket clients only. // octets), and websocket listeners are disabled.
// if this is on, invalid utf8 inputs get a FAIL reply. // if this is on, invalid utf8 inputs get a FAIL reply.
var globalUtf8EnforcementSetting bool var globalUtf8EnforcementSetting bool

28
irc/utils/chunks.go Normal file
View File

@ -0,0 +1,28 @@
package utils
import "iter"
func ChunkifyParams(params iter.Seq[string], maxChars int) [][]string {
var chunked [][]string
var acc []string
var length = 0
for p := range params {
length = length + len(p) + 1 // (accounting for the space)
if length > maxChars {
chunked = append(chunked, acc)
acc = []string{}
length = 0
}
acc = append(acc, p)
}
if len(acc) != 0 {
chunked = append(chunked, acc)
}
return chunked
}

View File

@ -42,6 +42,11 @@ func GenerateSecretToken() string {
return B32Encoder.EncodeToString(buf[:]) return B32Encoder.EncodeToString(buf[:])
} }
// return a compact representation of a token generated by GenerateSecretToken()
func DecodeSecretToken(t string) ([]byte, error) {
return B32Encoder.DecodeString(t)
}
// securely check if a supplied token matches a stored token // securely check if a supplied token matches a stored token
func SecretTokensMatch(storedToken string, suppliedToken string) bool { func SecretTokensMatch(storedToken string, suppliedToken string) bool {
// XXX fix a potential gotcha: if the stored token is uninitialized, // XXX fix a potential gotcha: if the stored token is uninitialized,

View File

@ -7,7 +7,7 @@ import "fmt"
const ( const (
// SemVer is the semantic version of Ergo. // SemVer is the semantic version of Ergo.
SemVer = "2.16.0" SemVer = "2.18.0-unreleased"
) )
var ( var (

View File

@ -218,7 +218,7 @@ func zncPlayPrivmsgsFromAll(client *Client, rb *ResponseBuffer, start, end time.
// PRIVMSG *playback :list // PRIVMSG *playback :list
func zncPlaybackListHandler(client *Client, command string, params []string, rb *ResponseBuffer) { func zncPlaybackListHandler(client *Client, command string, params []string, rb *ResponseBuffer) {
limit := client.server.Config().History.ChathistoryMax limit := client.server.Config().History.ChathistoryMax
correspondents, err := client.listTargets(history.Selector{}, history.Selector{}, limit) correspondents, err := client.listTargets(time.Time{}, time.Time{}, limit)
if err != nil { if err != nil {
client.server.logger.Error("internal", "couldn't get history for ZNC list", err.Error()) client.server.logger.Error("internal", "couldn't get history for ZNC list", err.Error())
return return

@ -1 +1 @@
Subproject commit e9e37f5438bd5f02656b89dab0cd40ef113edac6 Subproject commit 959114063189eac3aaad82a229acbcaeb78fb184

View File

@ -154,6 +154,21 @@ server:
# if this is true, the motd is escaped using formatting codes like $c, $b, and $i # if this is true, the motd is escaped using formatting codes like $c, $b, and $i
motd-formatting: true motd-formatting: true
# send a configurable notice to clients immediately after they connect,
# as a way of detecting open proxies
#initial-notice: "*** Welcome to the Ergo IRC server"
# idle timeouts for inactive clients
idle-timeouts:
# give the client this long to complete connection registration (i.e. the initial
# IRC handshake, including capability negotiation and SASL)
registration: 60s
# if the client hasn't sent anything for this long, send them a PING
ping: 1m30s
# if the client hasn't sent anything for this long (including the PONG to the
# above PING), disconnect them
disconnect: 2m30s
# relaying using the RELAYMSG command # relaying using the RELAYMSG command
relaymsg: relaymsg:
# is relaymsg enabled at all? # is relaymsg enabled at all?
@ -330,6 +345,10 @@ server:
secure-nets: secure-nets:
# - "10.0.0.0/8" # - "10.0.0.0/8"
# allow attempts to OPER with a password at most this often. defaults to
# 10 seconds when unset.
oper-throttle: 10s
# Ergo will write files to disk under certain circumstances, e.g., # Ergo will write files to disk under certain circumstances, e.g.,
# CPU profiling or data export. by default, these files will be written # CPU profiling or data export. by default, these files will be written
# to the working directory. set this to customize: # to the working directory. set this to customize:
@ -494,7 +513,7 @@ accounts:
# 1. these nicknames cannot be registered or reserved # 1. these nicknames cannot be registered or reserved
# 2. if a client is automatically renamed by the server, # 2. if a client is automatically renamed by the server,
# this is the template that will be used (e.g., Guest-nccj6rgmt97cg) # this is the template that will be used (e.g., Guest-nccj6rgmt97cg)
# 3. if enforce-guest-format (see below) is enabled, clients without # 3. if force-guest-format (see below) is enabled, clients without
# a registered account will have this template applied to their # a registered account will have this template applied to their
# nicknames (e.g., 'katie' will become 'Guest-katie') # nicknames (e.g., 'katie' will become 'Guest-katie')
guest-nickname-format: "Guest-*" guest-nickname-format: "Guest-*"
@ -695,6 +714,7 @@ oper-classes:
- "history" # modify or delete history messages - "history" # modify or delete history messages
- "defcon" # use the DEFCON command (restrict server capabilities) - "defcon" # use the DEFCON command (restrict server capabilities)
- "massmessage" # message all users on the server - "massmessage" # message all users on the server
- "metadata" # modify arbitrary metadata on channels and users
# ircd operators # ircd operators
opers: opers:
@ -958,10 +978,12 @@ history:
# in your country and the countries of your users. # in your country and the countries of your users.
enabled: true enabled: true
# how many channel-specific events (messages, joins, parts) should be tracked per channel? # if the in-memory backend is enabled for a channel, how many channel-specific events
# (messages, joins, parts) should be retained?
channel-length: 2048 channel-length: 2048
# how many direct messages and notices should be tracked per user? # if the in-memory backend is enabled for a user, how many direct messages
# and notices should be retained?
client-length: 256 client-length: 256
# how long should we try to preserve messages? # how long should we try to preserve messages?
@ -1058,6 +1080,20 @@ history:
# e.g., ERGO__SERVER__MAX_SENDQ=128k. see the manual for more details. # e.g., ERGO__SERVER__MAX_SENDQ=128k. see the manual for more details.
allow-environment-overrides: true allow-environment-overrides: true
# metadata support for setting key/value data on channels and nicknames.
metadata:
# can clients store metadata?
enabled: true
# how many keys can a client subscribe to?
max-subs: 100
# how many keys can be stored per entity?
max-keys: 100
# rate limiting for client metadata updates, which are expensive to process
client-throttle:
enabled: true
duration: 2m
max-attempts: 10
# experimental support for mobile push notifications # experimental support for mobile push notifications
# see the manual for potential security, privacy, and performance implications. # see the manual for potential security, privacy, and performance implications.
# DO NOT enable if you are running a Tor or I2P hidden service (i.e. one # DO NOT enable if you are running a Tor or I2P hidden service (i.e. one

27
vendor/filippo.io/edwards25519/LICENSE generated vendored Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

14
vendor/filippo.io/edwards25519/README.md generated vendored Normal file
View File

@ -0,0 +1,14 @@
# filippo.io/edwards25519
```
import "filippo.io/edwards25519"
```
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.

20
vendor/filippo.io/edwards25519/doc.go generated vendored Normal file
View File

@ -0,0 +1,20 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package edwards25519 implements group logic for the twisted Edwards curve
//
// -x^2 + y^2 = 1 + -(121665/121666)*x^2*y^2
//
// This is better known as the Edwards curve equivalent to Curve25519, and is
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic.
//
// However, developers who do need to interact with low-level edwards25519
// operations can use this package, which is an extended version of
// crypto/internal/edwards25519 from the standard library repackaged as
// an importable module.
package edwards25519

427
vendor/filippo.io/edwards25519/edwards25519.go generated vendored Normal file
View File

@ -0,0 +1,427 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"errors"
"filippo.io/edwards25519/field"
)
// Point types.
type projP1xP1 struct {
X, Y, Z, T field.Element
}
type projP2 struct {
X, Y, Z field.Element
}
// Point represents a point on the edwards25519 curve.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is NOT valid, and it may be used only as a receiver.
type Point struct {
// Make the type not comparable (i.e. used with == or as a map key), as
// equivalent points can be represented by different Go values.
_ incomparable
// The point is internally represented in extended coordinates (X, Y, Z, T)
// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
x, y, z, t field.Element
}
type incomparable [0]func()
func checkInitialized(points ...*Point) {
for _, p := range points {
if p.x == (field.Element{}) && p.y == (field.Element{}) {
panic("edwards25519: use of uninitialized Point")
}
}
}
type projCached struct {
YplusX, YminusX, Z, T2d field.Element
}
type affineCached struct {
YplusX, YminusX, T2d field.Element
}
// Constructors.
func (v *projP2) Zero() *projP2 {
v.X.Zero()
v.Y.One()
v.Z.One()
return v
}
// identity is the point at infinity.
var identity, _ = new(Point).SetBytes([]byte{
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// NewIdentityPoint returns a new Point set to the identity.
func NewIdentityPoint() *Point {
return new(Point).Set(identity)
}
// generator is the canonical curve basepoint. See TestGenerator for the
// correspondence of this encoding with the values in RFC 8032.
var generator, _ = new(Point).SetBytes([]byte{
0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})
// NewGeneratorPoint returns a new Point set to the canonical generator.
func NewGeneratorPoint() *Point {
return new(Point).Set(generator)
}
func (v *projCached) Zero() *projCached {
v.YplusX.One()
v.YminusX.One()
v.Z.One()
v.T2d.Zero()
return v
}
func (v *affineCached) Zero() *affineCached {
v.YplusX.One()
v.YminusX.One()
v.T2d.Zero()
return v
}
// Assignments.
// Set sets v = u, and returns v.
func (v *Point) Set(u *Point) *Point {
*v = *u
return v
}
// Encoding.
// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
// Section 5.1.2.
func (v *Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytes(&buf)
}
func (v *Point) bytes(buf *[32]byte) []byte {
checkInitialized(v)
var zInv, x, y field.Element
zInv.Invert(&v.z) // zInv = 1 / Z
x.Multiply(&v.x, &zInv) // x = X / Z
y.Multiply(&v.y, &zInv) // y = Y / Z
out := copyFieldElement(buf, &y)
out[31] |= byte(x.IsNegative() << 7)
return out
}
var feOne = new(field.Element).One()
// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
// represent a valid point on the curve, SetBytes returns nil and an error and
// the receiver is unchanged. Otherwise, SetBytes returns v.
//
// Note that SetBytes accepts all non-canonical encodings of valid points.
// That is, it follows decoding rules that match most implementations in
// the ecosystem rather than RFC 8032.
func (v *Point) SetBytes(x []byte) (*Point, error) {
// Specifically, the non-canonical encodings that are accepted are
// 1) the ones where the field element is not reduced (see the
// (*field.Element).SetBytes docs) and
// 2) the ones where the x-coordinate is zero and the sign bit is set.
//
// Read more at https://hdevalence.ca/blog/2020-10-04-its-25519am,
// specifically the "Canonical A, R" section.
y, err := new(field.Element).SetBytes(x)
if err != nil {
return nil, errors.New("edwards25519: invalid point encoding length")
}
// -x² + y² = 1 + dx²y²
// x² + dx²y² = x²(dy² + 1) = y² - 1
// x² = (y² - 1) / (dy² + 1)
// u = y² - 1
y2 := new(field.Element).Square(y)
u := new(field.Element).Subtract(y2, feOne)
// v = dy² + 1
vv := new(field.Element).Multiply(y2, d)
vv = vv.Add(vv, feOne)
// x = +√(u/v)
xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
if wasSquare == 0 {
return nil, errors.New("edwards25519: invalid point encoding")
}
// Select the negative square root if the sign bit is set.
xxNeg := new(field.Element).Negate(xx)
xx = xx.Select(xxNeg, xx, int(x[31]>>7))
v.x.Set(xx)
v.y.Set(y)
v.z.One()
v.t.Multiply(xx, y) // xy = T / Z
return v, nil
}
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
copy(buf[:], v.Bytes())
return buf[:]
}
// Conversions.
func (v *projP2) FromP1xP1(p *projP1xP1) *projP2 {
v.X.Multiply(&p.X, &p.T)
v.Y.Multiply(&p.Y, &p.Z)
v.Z.Multiply(&p.Z, &p.T)
return v
}
func (v *projP2) FromP3(p *Point) *projP2 {
v.X.Set(&p.x)
v.Y.Set(&p.y)
v.Z.Set(&p.z)
return v
}
func (v *Point) fromP1xP1(p *projP1xP1) *Point {
v.x.Multiply(&p.X, &p.T)
v.y.Multiply(&p.Y, &p.Z)
v.z.Multiply(&p.Z, &p.T)
v.t.Multiply(&p.X, &p.Y)
return v
}
func (v *Point) fromP2(p *projP2) *Point {
v.x.Multiply(&p.X, &p.Z)
v.y.Multiply(&p.Y, &p.Z)
v.z.Square(&p.Z)
v.t.Multiply(&p.X, &p.Y)
return v
}
// d is a constant in the curve equation.
var d, _ = new(field.Element).SetBytes([]byte{
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})
var d2 = new(field.Element).Add(d, d)
func (v *projCached) FromP3(p *Point) *projCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.Z.Set(&p.z)
v.T2d.Multiply(&p.t, d2)
return v
}
func (v *affineCached) FromP3(p *Point) *affineCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.T2d.Multiply(&p.t, d2)
var invZ field.Element
invZ.Invert(&p.z)
v.YplusX.Multiply(&v.YplusX, &invZ)
v.YminusX.Multiply(&v.YminusX, &invZ)
v.T2d.Multiply(&v.T2d, &invZ)
return v
}
// (Re)addition and subtraction.
// Add sets v = p + q, and returns v.
func (v *Point) Add(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Add(p, qCached)
return v.fromP1xP1(result)
}
// Subtract sets v = p - q, and returns v.
func (v *Point) Subtract(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Sub(p, qCached)
return v.fromP1xP1(result)
}
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&ZZ2, &TT2d)
v.T.Subtract(&ZZ2, &TT2d)
return v
}
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&ZZ2, &TT2d) // flipped sign
v.T.Add(&ZZ2, &TT2d) // flipped sign
return v
}
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&Z2, &TT2d)
v.T.Subtract(&Z2, &TT2d)
return v
}
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&Z2, &TT2d) // flipped sign
v.T.Add(&Z2, &TT2d) // flipped sign
return v
}
// Doubling.
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
var XX, YY, ZZ2, XplusYsq field.Element
XX.Square(&p.X)
YY.Square(&p.Y)
ZZ2.Square(&p.Z)
ZZ2.Add(&ZZ2, &ZZ2)
XplusYsq.Add(&p.X, &p.Y)
XplusYsq.Square(&XplusYsq)
v.Y.Add(&YY, &XX)
v.Z.Subtract(&YY, &XX)
v.X.Subtract(&XplusYsq, &v.Y)
v.T.Subtract(&ZZ2, &v.Z)
return v
}
// Negation.
// Negate sets v = -p, and returns v.
func (v *Point) Negate(p *Point) *Point {
checkInitialized(p)
v.x.Negate(&p.x)
v.y.Set(&p.y)
v.z.Set(&p.z)
v.t.Negate(&p.t)
return v
}
// Equal returns 1 if v is equivalent to u, and 0 otherwise.
func (v *Point) Equal(u *Point) int {
checkInitialized(v, u)
var t1, t2, t3, t4 field.Element
t1.Multiply(&v.x, &u.z)
t2.Multiply(&u.x, &v.z)
t3.Multiply(&v.y, &u.z)
t4.Multiply(&u.y, &v.z)
return t1.Equal(&t2) & t3.Equal(&t4)
}
// Constant-time operations
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *projCached) Select(a, b *projCached, cond int) *projCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.Z.Select(&a.Z, &b.Z, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *projCached) CondNeg(cond int) *projCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *affineCached) CondNeg(cond int) *affineCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}

349
vendor/filippo.io/edwards25519/extra.go generated vendored Normal file
View File

@ -0,0 +1,349 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
// This file contains additional functionality that is not included in the
// upstream crypto/internal/edwards25519 package.
import (
"errors"
"filippo.io/edwards25519/field"
)
// ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
func (v *Point) ExtendedCoordinates() (X, Y, Z, T *field.Element) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap. Don't change the style without making
// sure it doesn't increase the inliner cost.
var e [4]field.Element
X, Y, Z, T = v.extendedCoordinates(&e)
return
}
func (v *Point) extendedCoordinates(e *[4]field.Element) (X, Y, Z, T *field.Element) {
checkInitialized(v)
X = e[0].Set(&v.x)
Y = e[1].Set(&v.y)
Z = e[2].Set(&v.z)
T = e[3].Set(&v.t)
return
}
// SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
//
// If the coordinates are invalid or don't represent a valid point on the curve,
// SetExtendedCoordinates returns nil and an error and the receiver is
// unchanged. Otherwise, SetExtendedCoordinates returns v.
func (v *Point) SetExtendedCoordinates(X, Y, Z, T *field.Element) (*Point, error) {
if !isOnCurve(X, Y, Z, T) {
return nil, errors.New("edwards25519: invalid point coordinates")
}
v.x.Set(X)
v.y.Set(Y)
v.z.Set(Z)
v.t.Set(T)
return v, nil
}
func isOnCurve(X, Y, Z, T *field.Element) bool {
var lhs, rhs field.Element
XX := new(field.Element).Square(X)
YY := new(field.Element).Square(Y)
ZZ := new(field.Element).Square(Z)
TT := new(field.Element).Square(T)
// -x² + y² = 1 + dx²y²
// -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
// -X² + Y² = Z² + dT²
lhs.Subtract(YY, XX)
rhs.Multiply(d, TT).Add(&rhs, ZZ)
if lhs.Equal(&rhs) != 1 {
return false
}
// xy = T/Z
// XY/Z² = T/Z
// XY = TZ
lhs.Multiply(X, Y)
rhs.Multiply(T, Z)
return lhs.Equal(&rhs) == 1
}
// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
//
// The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
// while every valid edwards25519 point has a unique u-coordinate Montgomery
// encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
// to any edwards25519 point, and every other X25519 input corresponds to two
// edwards25519 points.
func (v *Point) BytesMontgomery() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytesMontgomery(&buf)
}
func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
checkInitialized(v)
// RFC 7748, Section 4.1 provides the bilinear map to calculate the
// Montgomery u-coordinate
//
// u = (1 + y) / (1 - y)
//
// where y = Y / Z.
var y, recip, u field.Element
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
return copyFieldElement(buf, &u)
}
// MultByCofactor sets v = 8 * p, and returns v.
func (v *Point) MultByCofactor(p *Point) *Point {
checkInitialized(p)
result := projP1xP1{}
pp := (&projP2{}).FromP3(p)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
pp.FromP1xP1(&result)
result.Double(pp)
return v.fromP1xP1(&result)
}
// Given k > 0, set s = s**(2*i).
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Multiply(s, s)
}
}
// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert returns zero.
func (s *Scalar) Invert(t *Scalar) *Scalar {
// Uses a hardcoded sliding window of width 4.
var table [8]Scalar
var tt Scalar
tt.Multiply(t, t)
table[0] = *t
for i := 0; i < 7; i++ {
table[i+1].Multiply(&table[i], &tt)
}
// Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
// so t**k = t[k/2] for odd k
// To compute the sliding window digits, use the following Sage script:
// sage: import itertools
// sage: def sliding_window(w,k):
// ....: digits = []
// ....: while k > 0:
// ....: if k % 2 == 1:
// ....: kmod = k % (2**w)
// ....: digits.append(kmod)
// ....: k = k - kmod
// ....: else:
// ....: digits.append(0)
// ....: k = k // 2
// ....: return digits
// Now we can compute s roughly as follows:
// sage: s = 1
// sage: for coeff in reversed(sliding_window(4,l-2)):
// ....: s = s*s
// ....: if coeff > 0 :
// ....: s = s*t**coeff
// This works on one bit at a time, with many runs of zeros.
// The digits can be collapsed into [(count, coeff)] as follows:
// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
// Entries of the form (k, 0) turn into pow2k(k)
// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
*s = table[1/2]
s.pow2k(127 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[5/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[1/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(5 + 1)
s.Multiply(s, &table[11/2])
s.pow2k(9 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[3/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[13/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[7/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[9/2])
s.pow2k(3 + 1)
s.Multiply(s, &table[15/2])
s.pow2k(4 + 1)
s.Multiply(s, &table[11/2])
return s
}
// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called MultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Proceed as in the single-base case, but share doublings
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
for i := range digits {
digits[i] = scalars[i].signedRadix16()
}
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][63])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
for i := 62; i >= 0; i-- {
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
// Lookup-and-add the appropriate multiple of each input point
for j := range tables {
tables[j].SelectInto(multiple, digits[j][i])
tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
v.fromP1xP1(tmp1) // update v
}
tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
}
return v
}
// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
if len(scalars) != len(points) {
panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
}
checkInitialized(points...)
// Generalize double-base NAF computation to arbitrary sizes.
// Here all the points are dynamic, so we only use the smaller
// tables.
// Build lookup tables for each point
tables := make([]nafLookupTable5, len(points))
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute a NAF for each scalar
nafs := make([][256]int8, len(scalars))
for i := range nafs {
nafs[i] = scalars[i].nonAdjacentForm(5)
}
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
//
// Skip trying to find the first nonzero coefficent, because
// searching might be more work than a few extra doublings.
for i := 255; i >= 0; i-- {
tmp1.Double(tmp2)
for j := range nafs {
if nafs[j][i] > 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, nafs[j][i])
tmp1.Add(v, multiple)
} else if nafs[j][i] < 0 {
v.fromP1xP1(tmp1)
tables[j].SelectInto(multiple, -nafs[j][i])
tmp1.Sub(v, multiple)
}
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

420
vendor/filippo.io/edwards25519/field/fe.go generated vendored Normal file
View File

@ -0,0 +1,420 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"errors"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
// An element t represents the integer
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
//
// Between operations, all limbs are expected to be lower than 2^52.
l0 uint64
l1 uint64
l2 uint64
l3 uint64
l4 uint64
}
const maskLow51Bits uint64 = (1 << 51) - 1
var feZero = &Element{0, 0, 0, 0, 0}
// Zero sets v = 0, and returns v.
func (v *Element) Zero() *Element {
*v = *feZero
return v
}
var feOne = &Element{1, 0, 0, 0, 0}
// One sets v = 1, and returns v.
func (v *Element) One() *Element {
*v = *feOne
return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
c := (v.l0 + 19) >> 51
c = (v.l1 + c) >> 51
c = (v.l2 + c) >> 51
c = (v.l3 + c) >> 51
c = (v.l4 + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
v.l0 += 19 * c
v.l1 += v.l0 >> 51
v.l0 = v.l0 & maskLow51Bits
v.l2 += v.l1 >> 51
v.l1 = v.l1 & maskLow51Bits
v.l3 += v.l2 >> 51
v.l2 = v.l2 & maskLow51Bits
v.l4 += v.l3 >> 51
v.l3 = v.l3 & maskLow51Bits
// no additional carry
v.l4 = v.l4 & maskLow51Bits
return v
}
// Add sets v = a + b, and returns v.
func (v *Element) Add(a, b *Element) *Element {
v.l0 = a.l0 + b.l0
v.l1 = a.l1 + b.l1
v.l2 = a.l2 + b.l2
v.l3 = a.l3 + b.l3
v.l4 = a.l4 + b.l4
// Using the generic implementation here is actually faster than the
// assembly. Probably because the body of this function is so simple that
// the compiler can figure out better optimizations by inlining the carry
// propagation.
return v.carryPropagateGeneric()
}
// Subtract sets v = a - b, and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
return v.carryPropagate()
}
// Negate sets v = -a, and returns v.
func (v *Element) Negate(a *Element) *Element {
return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
return v.Multiply(&t, &z11) // 2^255 - 21
}
// Set sets v = a, and returns v.
func (v *Element) Set(a *Element) *Element {
*v = *a
return v
}
// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is
// not of the right length, SetBytes returns nil and an error, and the
// receiver is unchanged.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032, but
// consistent with most Ed25519 implementations.
func (v *Element) SetBytes(x []byte) (*Element, error) {
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid field element input size")
}
// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
v.l0 = binary.LittleEndian.Uint64(x[0:8])
v.l0 &= maskLow51Bits
// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
v.l1 &= maskLow51Bits
// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
v.l2 &= maskLow51Bits
// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
v.l3 &= maskLow51Bits
// Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51).
// Note: not bytes 25:33, shift 4, to avoid overread.
v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
v.l4 &= maskLow51Bits
return v, nil
}
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [32]byte
return v.bytes(&out)
}
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
for i, bb := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= bb
}
}
return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *Element) Equal(u *Element) int {
sa, sv := u.Bytes(), v.Bytes()
return subtle.ConstantTimeCompare(sa, sv)
}
// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise.
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *Element) Select(a, b *Element, cond int) *Element {
m := mask64Bits(cond)
v.l0 = (m & a.l0) | (^m & b.l0)
v.l1 = (m & a.l1) | (^m & b.l1)
v.l2 = (m & a.l2) | (^m & b.l2)
v.l3 = (m & a.l3) | (^m & b.l3)
v.l4 = (m & a.l4) | (^m & b.l4)
return v
}
// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v.
func (v *Element) Swap(u *Element, cond int) {
m := mask64Bits(cond)
t := m & (v.l0 ^ u.l0)
v.l0 ^= t
u.l0 ^= t
t = m & (v.l1 ^ u.l1)
v.l1 ^= t
u.l1 ^= t
t = m & (v.l2 ^ u.l2)
v.l2 ^= t
u.l2 ^= t
t = m & (v.l3 ^ u.l3)
v.l3 ^= t
u.l3 ^= t
t = m & (v.l4 ^ u.l4)
v.l4 ^= t
u.l4 ^= t
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *Element) IsNegative() int {
return int(v.Bytes()[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *Element) Absolute(u *Element) *Element {
return v.Select(new(Element).Negate(u), u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
x0lo, x0hi := mul51(x.l0, y)
x1lo, x1hi := mul51(x.l1, y)
x2lo, x2hi := mul51(x.l2, y)
x3lo, x3hi := mul51(x.l3, y)
x4lo, x4hi := mul51(x.l4, y)
v.l0 = x0lo + 19*x4hi // carried over per the reduction identity
v.l1 = x1lo + x0hi
v.l2 = x2lo + x1hi
v.l3 = x3lo + x2hi
v.l4 = x4lo + x3hi
// The hi portions are going to be only 32 bits, plus any previous excess,
// so we can skip the carry propagation.
return v
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
mh, ml := bits.Mul64(a, uint64(b))
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}
// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
t1.Square(&t1) // x^8
t1.Multiply(x, &t1) // x^9
t0.Multiply(&t0, &t1) // x^11
t0.Square(&t0) // x^22
t0.Multiply(&t1, &t0) // x^31
t1.Square(&t0) // x^62
for i := 1; i < 5; i++ { // x^992
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
t1.Square(&t0) // 2^11 - 2
for i := 1; i < 10; i++ { // 2^20 - 2^10
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^20 - 1
t2.Square(&t1) // 2^21 - 2
for i := 1; i < 20; i++ { // 2^40 - 2^20
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^40 - 1
t1.Square(&t1) // 2^41 - 2
for i := 1; i < 10; i++ { // 2^50 - 2^10
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^50 - 1
t1.Square(&t0) // 2^51 - 2
for i := 1; i < 50; i++ { // 2^100 - 2^50
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^100 - 1
t2.Square(&t1) // 2^101 - 2
for i := 1; i < 100; i++ { // 2^200 - 2^100
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^200 - 1
t1.Square(&t1) // 2^201 - 2
for i := 1; i < 50; i++ { // 2^250 - 2^50
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^250 - 1
t0.Square(&t0) // 2^251 - 2
t0.Square(&t0) // 2^252 - 4
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) {
t0 := new(Element)
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := new(Element).Square(v)
uv3 := new(Element).Multiply(u, t0.Multiply(v2, v))
uv7 := new(Element).Multiply(uv3, t0.Square(v2))
rr := new(Element).Multiply(uv3, t0.Pow22523(uv7))
check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2
uNeg := new(Element).Negate(u)
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1))
rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI)
r.Absolute(rr) // Choose the nonnegative square root.
return r, correctSignSqrt | flippedSignSqrt
}

16
vendor/filippo.io/edwards25519/field/fe_amd64.go generated vendored Normal file
View File

@ -0,0 +1,16 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
package field
// feMul sets out = a * b. It works like feMulGeneric.
//
//go:noescape
func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric.
//
//go:noescape
func feSquare(out *Element, a *Element)

379
vendor/filippo.io/edwards25519/field/fe_amd64.s generated vendored Normal file
View File

@ -0,0 +1,379 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
#include "textflag.h"
// func feMul(out *Element, a *Element, b *Element)
TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ a+8(FP), CX
MOVQ b+16(FP), BX
// r0 = a0×b0
MOVQ (CX), AX
MULQ (BX)
MOVQ AX, DI
MOVQ DX, SI
// r0 += 19×a1×b4
MOVQ 8(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a2×b3
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a3×b2
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a4×b1
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
// r1 = a0×b1
MOVQ (CX), AX
MULQ 8(BX)
MOVQ AX, R9
MOVQ DX, R8
// r1 += a1×b0
MOVQ 8(CX), AX
MULQ (BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a2×b4
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a3×b3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a4×b2
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
// r2 = a0×b2
MOVQ (CX), AX
MULQ 16(BX)
MOVQ AX, R11
MOVQ DX, R10
// r2 += a1×b1
MOVQ 8(CX), AX
MULQ 8(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += a2×b0
MOVQ 16(CX), AX
MULQ (BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a3×b4
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a4×b3
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
// r3 = a0×b3
MOVQ (CX), AX
MULQ 24(BX)
MOVQ AX, R13
MOVQ DX, R12
// r3 += a1×b2
MOVQ 8(CX), AX
MULQ 16(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a2×b1
MOVQ 16(CX), AX
MULQ 8(BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += a3×b0
MOVQ 24(CX), AX
MULQ (BX)
ADDQ AX, R13
ADCQ DX, R12
// r3 += 19×a4×b4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
// r4 = a0×b4
MOVQ (CX), AX
MULQ 32(BX)
MOVQ AX, R15
MOVQ DX, R14
// r4 += a1×b3
MOVQ 8(CX), AX
MULQ 24(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a2×b2
MOVQ 16(CX), AX
MULQ 16(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a3×b1
MOVQ 24(CX), AX
MULQ 8(BX)
ADDQ AX, R15
ADCQ DX, R14
// r4 += a4×b0
MOVQ 32(CX), AX
MULQ (BX)
ADDQ AX, R15
ADCQ DX, R14
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, DI, SI
SHLQ $0x0d, R9, R8
SHLQ $0x0d, R11, R10
SHLQ $0x0d, R13, R12
SHLQ $0x0d, R15, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Second reduction chain (carryPropagate)
MOVQ DI, SI
SHRQ $0x33, SI
MOVQ R9, R8
SHRQ $0x33, R8
MOVQ R11, R10
SHRQ $0x33, R10
MOVQ R13, R12
SHRQ $0x33, R12
MOVQ R15, R14
SHRQ $0x33, R14
ANDQ AX, DI
IMUL3Q $0x13, R14, R14
ADDQ R14, DI
ANDQ AX, R9
ADDQ SI, R9
ANDQ AX, R11
ADDQ R8, R11
ANDQ AX, R13
ADDQ R10, R13
ANDQ AX, R15
ADDQ R12, R15
// Store output
MOVQ out+0(FP), AX
MOVQ DI, (AX)
MOVQ R9, 8(AX)
MOVQ R11, 16(AX)
MOVQ R13, 24(AX)
MOVQ R15, 32(AX)
RET
// func feSquare(out *Element, a *Element)
TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ a+8(FP), CX
// r0 = l0×l0
MOVQ (CX), AX
MULQ (CX)
MOVQ AX, SI
MOVQ DX, BX
// r0 += 38×l1×l4
MOVQ 8(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
// r0 += 38×l2×l3
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
// r1 = 2×l0×l1
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 8(CX)
MOVQ AX, R8
MOVQ DX, DI
// r1 += 38×l2×l4
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
// r1 += 19×l3×l3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
// r2 = 2×l0×l2
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 16(CX)
MOVQ AX, R10
MOVQ DX, R9
// r2 += l1×l1
MOVQ 8(CX), AX
MULQ 8(CX)
ADDQ AX, R10
ADCQ DX, R9
// r2 += 38×l3×l4
MOVQ 24(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
// r3 = 2×l0×l3
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 24(CX)
MOVQ AX, R12
MOVQ DX, R11
// r3 += 2×l1×l2
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
// r3 += 19×l4×l4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
// r4 = 2×l0×l4
MOVQ (CX), AX
SHLQ $0x01, AX
MULQ 32(CX)
MOVQ AX, R14
MOVQ DX, R13
// r4 += 2×l1×l3
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
// r4 += l2×l2
MOVQ 16(CX), AX
MULQ 16(CX)
ADDQ AX, R14
ADCQ DX, R13
// First reduction chain
MOVQ $0x0007ffffffffffff, AX
SHLQ $0x0d, SI, BX
SHLQ $0x0d, R8, DI
SHLQ $0x0d, R10, R9
SHLQ $0x0d, R12, R11
SHLQ $0x0d, R14, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Second reduction chain (carryPropagate)
MOVQ SI, BX
SHRQ $0x33, BX
MOVQ R8, DI
SHRQ $0x33, DI
MOVQ R10, R9
SHRQ $0x33, R9
MOVQ R12, R11
SHRQ $0x33, R11
MOVQ R14, R13
SHRQ $0x33, R13
ANDQ AX, SI
IMUL3Q $0x13, R13, R13
ADDQ R13, SI
ANDQ AX, R8
ADDQ BX, R8
ANDQ AX, R10
ADDQ DI, R10
ANDQ AX, R12
ADDQ R9, R12
ANDQ AX, R14
ADDQ R11, R14
// Store output
MOVQ out+0(FP), AX
MOVQ SI, (AX)
MOVQ R8, 8(AX)
MOVQ R10, 16(AX)
MOVQ R12, 24(AX)
MOVQ R14, 32(AX)
RET

12
vendor/filippo.io/edwards25519/field/fe_amd64_noasm.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
package field
func feMul(v, x, y *Element) { feMulGeneric(v, x, y) }
func feSquare(v, x *Element) { feSquareGeneric(v, x) }

16
vendor/filippo.io/edwards25519/field/fe_arm64.go generated vendored Normal file
View File

@ -0,0 +1,16 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package field
//go:noescape
func carryPropagate(v *Element)
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}

42
vendor/filippo.io/edwards25519/field/fe_arm64.s generated vendored Normal file
View File

@ -0,0 +1,42 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20
LDP 0(R20), (R0, R1)
LDP 16(R20), (R2, R3)
MOVD 32(R20), R4
AND $0x7ffffffffffff, R0, R10
AND $0x7ffffffffffff, R1, R11
AND $0x7ffffffffffff, R2, R12
AND $0x7ffffffffffff, R3, R13
AND $0x7ffffffffffff, R4, R14
ADD R0>>51, R11, R11
ADD R1>>51, R12, R12
ADD R2>>51, R13, R13
ADD R3>>51, R14, R14
// R4>>51 * 19 + R10 -> R10
LSR $51, R4, R21
MOVD $19, R22
MADD R22, R10, R21, R10
STP (R10, R11), 0(R20)
STP (R12, R13), 16(R20)
MOVD R14, 32(R20)
RET

12
vendor/filippo.io/edwards25519/field/fe_arm64_noasm.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego
package field
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}

50
vendor/filippo.io/edwards25519/field/fe_extra.go generated vendored Normal file
View File

@ -0,0 +1,50 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "errors"
// This file contains additional functionality that is not included in the
// upstream crypto/ed25519/edwards25519/field package.
// SetWideBytes sets v to x, where x is a 64-byte little-endian encoding, which
// is reduced modulo the field order. If x is not of the right length,
// SetWideBytes returns nil and an error, and the receiver is unchanged.
//
// SetWideBytes is not necessary to select a uniformly distributed value, and is
// only provided for compatibility: SetBytes can be used instead as the chance
// of bias is less than 2⁻²⁵⁰.
func (v *Element) SetWideBytes(x []byte) (*Element, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetWideBytes input size")
}
// Split the 64 bytes into two elements, and extract the most significant
// bit of each, which is ignored by SetBytes.
lo, _ := new(Element).SetBytes(x[:32])
loMSB := uint64(x[31] >> 7)
hi, _ := new(Element).SetBytes(x[32:])
hiMSB := uint64(x[63] >> 7)
// The output we want is
//
// v = lo + loMSB * 2²⁵⁵ + hi * 2²⁵⁶ + hiMSB * 2⁵¹¹
//
// which applying the reduction identity comes out to
//
// v = lo + loMSB * 19 + hi * 2 * 19 + hiMSB * 2 * 19²
//
// l0 will be the sum of a 52 bits value (lo.l0), plus a 5 bits value
// (loMSB * 19), a 6 bits value (hi.l0 * 2 * 19), and a 10 bits value
// (hiMSB * 2 * 19²), so it fits in a uint64.
v.l0 = lo.l0 + loMSB*19 + hi.l0*2*19 + hiMSB*2*19*19
v.l1 = lo.l1 + hi.l1*2*19
v.l2 = lo.l2 + hi.l2*2*19
v.l3 = lo.l3 + hi.l3*2*19
v.l4 = lo.l4 + hi.l4*2*19
return v.carryPropagate(), nil
}

266
vendor/filippo.io/edwards25519/field/fe_generic.go generated vendored Normal file
View File

@ -0,0 +1,266 @@
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
// a4 a3 a2 a1 a0 x
// b4 b3 b2 b1 b0 =
// ------------------------
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a4b1 a3b1 a2b1 a1b1 a0b1 +
// a4b2 a3b2 a2b2 a1b2 a0b2 +
// a4b3 a3b3 a2b3 a1b3 a0b3 +
// a4b4 a3b4 a2b4 a1b4 a0b4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
// r0 < 2⁷ × 2⁵² × 2⁵²
// r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
// r4 < 5 × 2⁵² × 2⁵²
// r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as an Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
// l4 l3 l2 l1 l0 x
// l4 l3 l2 l1 l0 =
// ------------------------
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l4l1 l3l1 l2l1 l1l1 l0l1 +
// l4l2 l3l2 l2l2 l1l2 l0l2 +
// l4l3 l3l3 l2l3 l1l3 l0l3 +
// l4l4 l3l4 l2l4 l1l4 l0l4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
// the final l0 will be at most 52 bits. Similarly for the rest.
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
return v
}

343
vendor/filippo.io/edwards25519/scalar.go generated vendored Normal file
View File

@ -0,0 +1,343 @@
// Copyright (c) 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/binary"
"errors"
)
// A Scalar is an integer modulo
//
// l = 2^252 + 27742317777372353535851937790883648493
//
// which is the prime order of the edwards25519 group.
//
// This type works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
//
// The zero value is a valid zero element.
type Scalar struct {
// s is the scalar in the Montgomery domain, in the format of the
// fiat-crypto implementation.
s fiatScalarMontgomeryDomainFieldElement
}
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
// from a formally verified model.
//
// fiat-crypto code comes under the following license.
//
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// NewScalar returns a new zero Scalar.
func NewScalar() *Scalar {
return &Scalar{}
}
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
// using Multiply and then Add.
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
// Make a copy of z in case it aliases s.
zCopy := new(Scalar).Set(z)
return s.Multiply(x, y).Add(s, zCopy)
}
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
// s = 1 * x + y mod l
fiatScalarAdd(&s.s, &x.s, &y.s)
return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
// s = -1 * y + x mod l
fiatScalarSub(&s.s, &x.s, &y.s)
return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
// s = -1 * x + 0 mod l
fiatScalarOpp(&s.s, &x.s)
return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
// s = x * y + 0 mod l
fiatScalarMul(&s.s, &x.s, &y.s)
return s
}
// Set sets s = x, and returns s.
func (s *Scalar) Set(x *Scalar) *Scalar {
*s = *x
return s
}
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
// If x is not of the right length, SetUniformBytes returns nil and an error,
// and the receiver is unchanged.
//
// SetUniformBytes can be used to set s to a uniformly distributed value given
// 64 uniformly distributed random bytes.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
}
// We have a value x of 512 bits, but our fiatScalarFromBytes function
// expects an input lower than l, which is a little over 252 bits.
//
// Instead of writing a reduction function that operates on wider inputs, we
// can interpret x as the sum of three shorter values a, b, and c.
//
// x = a + b * 2^168 + c * 2^336 mod l
//
// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
// with two multiplications and two additions.
s.setShortBytes(x[:21])
t := new(Scalar).setShortBytes(x[21:42])
s.Add(s, t.Multiply(t, scalarTwo168))
t.setShortBytes(x[42:])
s.Add(s, t.Multiply(t, scalarTwo336))
return s, nil
}
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
// in the 2^256 Montgomery domain.
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
// than 32 bytes.
func (s *Scalar) setShortBytes(x []byte) *Scalar {
if len(x) >= 32 {
panic("edwards25519: internal error: setShortBytes called with a long string")
}
var buf [32]byte
copy(buf[:], x)
fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s
}
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 {
return nil, errors.New("invalid scalar length")
}
if !isReduced(x) {
return nil, errors.New("invalid scalar encoding")
}
fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s, nil
}
// scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool {
if len(s) != 32 {
return false
}
for i := len(s) - 1; i >= 0; i-- {
switch {
case s[i] > scalarMinusOneBytes[i]:
return false
case s[i] < scalarMinusOneBytes[i]:
return true
}
}
return true
}
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
// must be 32 bytes, and it is not modified. If x is not of the right length,
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
//
// Note that since Scalar values are always reduced modulo the prime order of
// the curve, the resulting value will not preserve any of the cofactor-clearing
// properties that clamping is meant to provide. It will however work as
// expected as long as it is applied to points on the prime order subgroup, like
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
// The description above omits the purpose of the high bits of the clamping
// for brevity, but those are also lost to reductions, and are also
// irrelevant to edwards25519 as they protect against a specific
// implementation bug that was once observed in a generic Montgomery ladder.
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
}
// We need to use the wide reduction from SetUniformBytes, since clamping
// sets the 2^254 bit, making the value higher than the order.
var wideBytes [64]byte
copy(wideBytes[:], x[:])
wideBytes[0] &= 248
wideBytes[31] &= 63
wideBytes[31] |= 64
return s.SetUniformBytes(wideBytes[:])
}
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var encoded [32]byte
return s.bytes(&encoded)
}
func (s *Scalar) bytes(out *[32]byte) []byte {
var ss fiatScalarNonMontgomeryDomainFieldElement
fiatScalarFromMontgomery(&ss, &s.s)
fiatScalarToBytes(out, (*[4]uint64)(&ss))
return out[:]
}
// Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int {
var diff fiatScalarMontgomeryDomainFieldElement
fiatScalarSub(&diff, &s.s, &t.s)
var nonzero uint64
fiatScalarNonzero(&nonzero, (*[4]uint64)(&diff))
nonzero |= nonzero >> 32
nonzero |= nonzero >> 16
nonzero |= nonzero >> 8
nonzero |= nonzero >> 4
nonzero |= nonzero >> 2
nonzero |= nonzero >> 1
return int(^nonzero) & 1
}
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
//
// w must be between 2 and 8, or nonAdjacentForm will panic.
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
// This implementation is adapted from the one
// in curve25519-dalek and is documented there:
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
if w < 2 {
panic("w must be at least 2 by the definition of NAF")
} else if w > 8 {
panic("NAF digits must fit in int8")
}
var naf [256]int8
var digits [5]uint64
for i := 0; i < 4; i++ {
digits[i] = binary.LittleEndian.Uint64(b[i*8:])
}
width := uint64(1 << w)
windowMask := uint64(width - 1)
pos := uint(0)
carry := uint64(0)
for pos < 256 {
indexU64 := pos / 64
indexBit := pos % 64
var bitBuf uint64
if indexBit < 64-w {
// This window's bits are contained in a single u64
bitBuf = digits[indexU64] >> indexBit
} else {
// Combine the current 64 bits with bits from the next 64
bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
}
// Add carry into the current window
window := carry + (bitBuf & windowMask)
if window&1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0,
// then the next carry should be 0
// If carry == 1 and window & 1 == 0,
// then bit_buf & 1 == 1 so the next carry should be 1
pos += 1
continue
}
if window < width/2 {
carry = 0
naf[pos] = int8(window)
} else {
carry = 1
naf[pos] = int8(window) - int8(width)
}
pos += w
}
return naf
}
func (s *Scalar) signedRadix16() [64]int8 {
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
var digits [64]int8
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(b[i] & 15)
digits[2*i+1] = int8((b[i] >> 4) & 15)
}
// Recenter coefficients:
for i := 0; i < 63; i++ {
carry := (digits[i] + 8) >> 4
digits[i] -= carry << 4
digits[i+1] += carry
}
return digits
}

1147
vendor/filippo.io/edwards25519/scalar_fiat.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

214
vendor/filippo.io/edwards25519/scalarmult.go generated vendored Normal file
View File

@ -0,0 +1,214 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from 256i * basepoint. It is precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
basepointTablePrecomp.initOnce.Do(func() {
p := NewGeneratorPoint()
for i := 0; i < 32; i++ {
basepointTablePrecomp.table[i].FromP3(p)
for j := 0; j < 8; j++ {
p.Add(p, p)
}
}
})
return &basepointTablePrecomp.table
}
var basepointTablePrecomp struct {
table [32]affineLookupTable
initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
basepointTable := basepointTable()
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
// Group even and odd coefficients
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
//
// We use a lookup table for each i to get x_i*16^(2*i)*B
// and do four doublings to multiply by 16.
digits := x.signedRadix16()
multiple := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Accumulate the odd components first
v.Set(NewIdentityPoint())
for i := 1; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
// Multiply by 16
tmp2.FromP3(v) // tmp2 = v in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
v.fromP1xP1(tmp1) // now v = 16*(odd components)
// Accumulate the even components
for i := 0; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
return v
}
// ScalarMult sets v = x * q, and returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
checkInitialized(q)
var table projLookupTable
table.FromP3(q)
// Write x = sum(x_i * 16^i)
// so x*Q = sum( Q*x_i*16^i )
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
// <------compute inside out---------
//
// We use the lookup table to get the x_i*Q values
// and do four doublings to compute 16*Q
digits := x.signedRadix16()
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
table.SelectInto(multiple, digits[63])
v.Set(NewIdentityPoint())
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
for i := 62; i >= 0; i-- {
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
table.SelectInto(multiple, digits[i])
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
}
v.fromP1xP1(tmp1)
return v
}
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
basepointNafTablePrecomp.initOnce.Do(func() {
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
})
return &basepointNafTablePrecomp.table
}
var basepointNafTablePrecomp struct {
table nafLookupTable8
initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
// generator, and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
checkInitialized(A)
// Similarly to the single variable-base approach, we compute
// digits and use them with a lookup table. However, because
// we are allowed to do variable-time operations, we don't
// need constant-time lookups or constant-time digit
// computations.
//
// So we use a non-adjacent form of some width w instead of
// radix 16. This is like a binary representation (one digit
// for each binary place) but we allow the digits to grow in
// magnitude up to 2^{w-1} so that the nonzero digits are as
// sparse as possible. Intuitively, this "condenses" the
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
basepointNafTable := basepointNafTable()
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
// corresponding to a bigger table.
aNaf := a.nonAdjacentForm(5)
bNaf := b.nonAdjacentForm(8)
// Find the first nonzero coefficient.
i := 255
for j := i; j >= 0; j-- {
if aNaf[j] != 0 || bNaf[j] != 0 {
break
}
}
multA := &projCached{}
multB := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
for ; i >= 0; i-- {
tmp1.Double(tmp2)
// Only update v if we have a nonzero coeff to add in.
if aNaf[i] > 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, aNaf[i])
tmp1.Add(v, multA)
} else if aNaf[i] < 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, -aNaf[i])
tmp1.Sub(v, multA)
}
if bNaf[i] > 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, bNaf[i])
tmp1.AddAffine(v, multB)
} else if bNaf[i] < 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, -bNaf[i])
tmp1.SubAffine(v, multB)
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}

129
vendor/filippo.io/edwards25519/tables.go generated vendored Normal file
View File

@ -0,0 +1,129 @@
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
points [8]affineCached
}
// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
points [64]affineCached
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to a projCached
// This is needlessly complicated because the API has explicit
// receivers instead of creating stack objects and relying on RVO
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to affineCached
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *Point) {
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *Point) {
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 63; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, return x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, return x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
*dest = v.points[x/2]
}

View File

@ -83,7 +83,12 @@ func parseHeaderParams(s string) (map[string]string, error) {
return params, errors.New("dkim: malformed header params") return params, errors.New("dkim: malformed header params")
} }
params[strings.TrimSpace(key)] = strings.TrimSpace(value) trimmedKey := strings.TrimSpace(key)
_, present := params[trimmedKey]
if present {
return params, errors.New("dkim: duplicate tag name")
}
params[trimmedKey] = strings.TrimSpace(value)
} }
return params, nil return params, nil
} }

View File

@ -100,7 +100,7 @@ func queryDNSTXT(domain, selector string, txtLookup txtLookupFunc) (*queryResult
func parsePublicKey(s string) (*queryResult, error) { func parsePublicKey(s string) (*queryResult, error) {
params, err := parseHeaderParams(s) params, err := parseHeaderParams(s)
if err != nil { if err != nil {
return nil, permFailError("key syntax error: " + err.Error()) return nil, permFailError("key record error: " + err.Error())
} }
res := new(queryResult) res := new(queryResult)

View File

@ -13,29 +13,41 @@
Aaron Hopkins <go-sql-driver at die.net> Aaron Hopkins <go-sql-driver at die.net>
Achille Roussel <achille.roussel at gmail.com> Achille Roussel <achille.roussel at gmail.com>
Aidan <aidan.liu at pingcap.com>
Alex Snast <alexsn at fb.com> Alex Snast <alexsn at fb.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com> Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com> Andrew Reid <andrew.reid at tixtrack.com>
Animesh Ray <mail.rayanimesh at gmail.com> Animesh Ray <mail.rayanimesh at gmail.com>
Arne Hormann <arnehormann at gmail.com> Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il> Ariel Mashraki <ariel at mashraki.co.il>
Artur Melanchyk <artur.melanchyk@gmail.com>
Asta Xie <xiemengjun at gmail.com> Asta Xie <xiemengjun at gmail.com>
B Lamarche <blam413 at gmail.com>
Bes Dollma <bdollma@thousandeyes.com>
Bogdan Constantinescu <bog.con.bc at gmail.com>
Brad Higgins <brad at defined.net>
Brian Hendriks <brian at dolthub.com>
Bulat Gaifullin <gaifullinbf at gmail.com> Bulat Gaifullin <gaifullinbf at gmail.com>
Caine Jette <jette at alum.mit.edu> Caine Jette <jette at alum.mit.edu>
Carlos Nieto <jose.carlos at menteslibres.net> Carlos Nieto <jose.carlos at menteslibres.net>
Chris Kirkland <chriskirkland at github.com> Chris Kirkland <chriskirkland at github.com>
Chris Moos <chris at tech9computers.com> Chris Moos <chris at tech9computers.com>
Craig Wilson <craiggwilson at gmail.com> Craig Wilson <craiggwilson at gmail.com>
Daemonxiao <735462752 at qq.com>
Daniel Montoya <dsmontoyam at gmail.com> Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com> Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl> Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com> Dave Protasowski <dprotaso at gmail.com>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me> DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com> Egor Smolyakov <egorsmkv at gmail.com>
Erwan Martin <hello at erwan.io> Erwan Martin <hello at erwan.io>
Evan Elias <evan at skeema.net>
Evan Shaw <evan at vendhq.com> Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com> Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com> Gustavo Kristic <gkristic at gmail.com>
Gusted <postmaster at gusted.xyz>
Hajime Nakagami <nakagami at gmail.com> Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com> Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com> Henri Yandell <flamefew at gmail.com>
@ -45,13 +57,18 @@ ICHINOSE Shogo <shogo82148 at gmail.com>
Ilia Cimpoes <ichimpoesh at gmail.com> Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com> INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com> Jacek Szwec <szwec.jacek at gmail.com>
Jakub Adamus <kratky at zobak.cz>
James Harr <james.harr at gmail.com> James Harr <james.harr at gmail.com>
Janek Vedock <janekvedock at comcast.net> Janek Vedock <janekvedock at comcast.net>
Jason Ng <oblitorum at gmail.com>
Jean-Yves Pellé <jy at pelle.link>
Jeff Hodges <jeff at somethingsimilar.com> Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com> Jeffrey Charles <jeffreycharles at gmail.com>
Jennifer Purevsuren <jennifer at dolthub.com>
Jerome Meyer <jxmeyer at gmail.com> Jerome Meyer <jxmeyer at gmail.com>
Jiajia Zhong <zhong2plus at gmail.com> Jiajia Zhong <zhong2plus at gmail.com>
Jian Zhen <zhenjl at gmail.com> Jian Zhen <zhenjl at gmail.com>
Joe Mann <contact at joemann.co.uk>
Joshua Prunier <joshua.prunier at gmail.com> Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com> Julien Lefevre <julien.lefevr at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com> Julien Schmidt <go-sql-driver at julienschmidt.com>
@ -72,17 +89,23 @@ Lunny Xiao <xiaolunwen at gmail.com>
Luke Scott <luke at webconnex.com> Luke Scott <luke at webconnex.com>
Maciej Zimnoch <maciej.zimnoch at codilime.com> Maciej Zimnoch <maciej.zimnoch at codilime.com>
Michael Woolnough <michael.woolnough at gmail.com> Michael Woolnough <michael.woolnough at gmail.com>
Nao Yokotsuka <yokotukanao at gmail.com>
Nathanial Murphy <nathanial.murphy at gmail.com> Nathanial Murphy <nathanial.murphy at gmail.com>
Nicola Peduzzi <thenikso at gmail.com> Nicola Peduzzi <thenikso at gmail.com>
Oliver Bone <owbone at github.com>
Olivier Mengué <dolmen at cpan.org> Olivier Mengué <dolmen at cpan.org>
oscarzhao <oscarzhaosl at gmail.com> oscarzhao <oscarzhaosl at gmail.com>
Paul Bonser <misterpib at gmail.com> Paul Bonser <misterpib at gmail.com>
Paulius Lozys <pauliuslozys at gmail.com>
Peter Schultz <peter.schultz at classmarkets.com> Peter Schultz <peter.schultz at classmarkets.com>
Phil Porada <philporada at gmail.com>
Minh Quang <minhquang4334 at gmail.com>
Rebecca Chin <rchin at pivotal.io> Rebecca Chin <rchin at pivotal.io>
Reed Allman <rdallman10 at gmail.com> Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com> Richard Wilkes <wilkes at me.com>
Robert Russell <robert at rrbrussell.com> Robert Russell <robert at rrbrussell.com>
Runrioter Wung <runrioter at gmail.com> Runrioter Wung <runrioter at gmail.com>
Samantha Frank <hello at entropy.cat>
Santhosh Kumar Tekuri <santhosh.tekuri at gmail.com> Santhosh Kumar Tekuri <santhosh.tekuri at gmail.com>
Sho Iizuka <sho.i518 at gmail.com> Sho Iizuka <sho.i518 at gmail.com>
Sho Ikeda <suicaicoca at gmail.com> Sho Ikeda <suicaicoca at gmail.com>
@ -93,6 +116,7 @@ Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com> Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk> Steven Hartland <steven.hartland at multiplay.co.uk>
Tan Jinhua <312841925 at qq.com> Tan Jinhua <312841925 at qq.com>
Tetsuro Aoki <t.aoki1130 at gmail.com>
Thomas Wodarek <wodarekwebpage at gmail.com> Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com> Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me> Tom Jenkinson <tom at tjenkinson.me>
@ -102,6 +126,7 @@ Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com> Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc> Xiuming Chen <cc at cxm.cc>
Xuehong Chan <chanxuehong at gmail.com> Xuehong Chan <chanxuehong at gmail.com>
Zhang Xiang <angwerzx at 126.com>
Zhenye Xie <xiezhenye at gmail.com> Zhenye Xie <xiezhenye at gmail.com>
Zhixin Wen <john.wenzhixin at gmail.com> Zhixin Wen <john.wenzhixin at gmail.com>
Ziheng Lyu <zihenglv at gmail.com> Ziheng Lyu <zihenglv at gmail.com>
@ -110,15 +135,21 @@ Ziheng Lyu <zihenglv at gmail.com>
Barracuda Networks, Inc. Barracuda Networks, Inc.
Counting Ltd. Counting Ltd.
Defined Networking Inc.
DigitalOcean Inc. DigitalOcean Inc.
Dolthub Inc.
dyves labs AG dyves labs AG
Facebook Inc. Facebook Inc.
GitHub Inc. GitHub Inc.
Google Inc. Google Inc.
InfoSum Ltd. InfoSum Ltd.
Keybase Inc. Keybase Inc.
Microsoft Corp.
Multiplay Ltd. Multiplay Ltd.
Percona LLC Percona LLC
PingCAP Inc.
Pivotal Inc. Pivotal Inc.
Shattered Silicon Ltd.
Stripe Inc. Stripe Inc.
ThousandEyes
Zendesk Inc. Zendesk Inc.

View File

@ -1,3 +1,110 @@
# Changelog
## v1.9.3 (2025-06-13)
* `tx.Commit()` and `tx.Rollback()` returned `ErrInvalidConn` always.
Now they return cached real error if present. (#1690)
* Optimize reading small resultsets to fix performance regression
introduced by compression protocol support. (#1707)
* Fix `db.Ping()` on compressed connection. (#1723)
## v1.9.2 (2025-04-07)
v1.9.2 is a re-release of v1.9.1 due to a release process issue; no changes were made to the content.
## v1.9.1 (2025-03-21)
### Major Changes
* Add Charset() option. (#1679)
### Bugfixes
* go.mod: fix go version format (#1682)
* Fix FormatDSN missing ConnectionAttributes (#1619)
## v1.9.0 (2025-02-18)
### Major Changes
- Implement zlib compression. (#1487)
- Supported Go version is updated to Go 1.21+. (#1639)
- Add support for VECTOR type introduced in MySQL 9.0. (#1609)
- Config object can have custom dial function. (#1527)
### Bugfixes
- Fix auth errors when username/password are too long. (#1625)
- Check if MySQL supports CLIENT_CONNECT_ATTRS before sending client attributes. (#1640)
- Fix auth switch request handling. (#1666)
### Other changes
- Add "filename:line" prefix to log in go-mysql. Custom loggers now show it. (#1589)
- Improve error handling. It reduces the "busy buffer" errors. (#1595, #1601, #1641)
- Use `strconv.Atoi` to parse max_allowed_packet. (#1661)
- `rejectReadOnly` option now handles ER_READ_ONLY_MODE (1290) error too. (#1660)
## Version 1.8.1 (2024-03-26)
Bugfixes:
- fix race condition when context is canceled in [#1562](https://github.com/go-sql-driver/mysql/pull/1562) and [#1570](https://github.com/go-sql-driver/mysql/pull/1570)
## Version 1.8.0 (2024-03-09)
Major Changes:
- Use `SET NAMES charset COLLATE collation`. by @methane in [#1437](https://github.com/go-sql-driver/mysql/pull/1437)
- Older go-mysql-driver used `collation_id` in the handshake packet. But it caused collation mismatch in some situation.
- If you don't specify charset nor collation, go-mysql-driver sends `SET NAMES utf8mb4` for new connection. This uses server's default collation for utf8mb4.
- If you specify charset, go-mysql-driver sends `SET NAMES <charset>`. This uses the server's default collation for `<charset>`.
- If you specify collation and/or charset, go-mysql-driver sends `SET NAMES charset COLLATE collation`.
- PathEscape dbname in DSN. by @methane in [#1432](https://github.com/go-sql-driver/mysql/pull/1432)
- This is backward incompatible in rare case. Check your DSN.
- Drop Go 1.13-17 support by @methane in [#1420](https://github.com/go-sql-driver/mysql/pull/1420)
- Use Go 1.18+
- Parse numbers on text protocol too by @methane in [#1452](https://github.com/go-sql-driver/mysql/pull/1452)
- When text protocol is used, go-mysql-driver passed bare `[]byte` to database/sql for avoid unnecessary allocation and conversion.
- If user specified `*any` to `Scan()`, database/sql passed the `[]byte` into the target variable.
- This confused users because most user doesn't know when text/binary protocol used.
- go-mysql-driver 1.8 converts integer/float values into int64/double even in text protocol. This doesn't increase allocation compared to `[]byte` and conversion cost is negatable.
- New options start using the Functional Option Pattern to avoid increasing technical debt in the Config object. Future version may introduce Functional Option for existing options, but not for now.
- Make TimeTruncate functional option by @methane in [1552](https://github.com/go-sql-driver/mysql/pull/1552)
- Add BeforeConnect callback to configuration object by @ItalyPaleAle in [#1469](https://github.com/go-sql-driver/mysql/pull/1469)
Other changes:
- Adding DeregisterDialContext to prevent memory leaks with dialers we don't need anymore by @jypelle in https://github.com/go-sql-driver/mysql/pull/1422
- Make logger configurable per connection by @frozenbonito in https://github.com/go-sql-driver/mysql/pull/1408
- Fix ColumnType.DatabaseTypeName for mediumint unsigned by @evanelias in https://github.com/go-sql-driver/mysql/pull/1428
- Add connection attributes by @Daemonxiao in https://github.com/go-sql-driver/mysql/pull/1389
- Stop `ColumnTypeScanType()` from returning `sql.RawBytes` by @methane in https://github.com/go-sql-driver/mysql/pull/1424
- Exec() now provides access to status of multiple statements. by @mherr-google in https://github.com/go-sql-driver/mysql/pull/1309
- Allow to change (or disable) the default driver name for registration by @dolmen in https://github.com/go-sql-driver/mysql/pull/1499
- Add default connection attribute '_server_host' by @oblitorum in https://github.com/go-sql-driver/mysql/pull/1506
- QueryUnescape DSN ConnectionAttribute value by @zhangyangyu in https://github.com/go-sql-driver/mysql/pull/1470
- Add client_ed25519 authentication by @Gusted in https://github.com/go-sql-driver/mysql/pull/1518
## Version 1.7.1 (2023-04-25)
Changes:
- bump actions/checkout@v3 and actions/setup-go@v3 (#1375)
- Add go1.20 and mariadb10.11 to the testing matrix (#1403)
- Increase default maxAllowedPacket size. (#1411)
Bugfixes:
- Use SET syntax as specified in the MySQL documentation (#1402)
## Version 1.7 (2022-11-29) ## Version 1.7 (2022-11-29)
Changes: Changes:
@ -149,7 +256,7 @@ New Features:
- Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249)
- Support for returning table alias on Columns() (#289, #359, #382) - Support for returning table alias on Columns() (#289, #359, #382)
- Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490) - Placeholder interpolation, can be activated with the DSN parameter `interpolateParams=true` (#309, #318, #490)
- Support for uint64 parameters with high bit set (#332, #345) - Support for uint64 parameters with high bit set (#332, #345)
- Cleartext authentication plugin support (#327) - Cleartext authentication plugin support (#327)
- Exported ParseDSN function and the Config struct (#403, #419, #429) - Exported ParseDSN function and the Config struct (#403, #419, #429)
@ -193,7 +300,7 @@ Changes:
- Also exported the MySQLWarning type - Also exported the MySQLWarning type
- mysqlConn.Close returns the first error encountered instead of ignoring all errors - mysqlConn.Close returns the first error encountered instead of ignoring all errors
- writePacket() automatically writes the packet size to the header - writePacket() automatically writes the packet size to the header
- readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets - readPacket() uses an iterative approach instead of the recursive approach to merge split packets
New Features: New Features:
@ -241,7 +348,7 @@ Bugfixes:
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
- Convert to DB timezone when inserting `time.Time` - Convert to DB timezone when inserting `time.Time`
- Splitted packets (more than 16MB) are now merged correctly - Split packets (more than 16MB) are now merged correctly
- Fixed false positive `io.EOF` errors when the data was fully read - Fixed false positive `io.EOF` errors when the data was fully read
- Avoid panics on reuse of closed connections - Avoid panics on reuse of closed connections
- Fixed empty string producing false nil values - Fixed empty string producing false nil values

View File

@ -38,17 +38,26 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
* Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support
* Optional `time.Time` parsing * Optional `time.Time` parsing
* Optional placeholder interpolation * Optional placeholder interpolation
* Supports zlib compression.
## Requirements ## Requirements
* Go 1.13 or higher. We aim to support the 3 latest versions of Go.
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) * Go 1.21 or higher. We aim to support the 3 latest versions of Go.
* MySQL (5.7+) and MariaDB (10.5+) are supported.
* [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP.
* Do not ask questions about TiDB in our issue tracker or forum.
* [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang)
* [Forum](https://ask.pingcap.com/)
* go-mysql would work with Percona Server, Google CloudSQL or Sphinx (2.2.3+).
* Maintainers won't support them. Do not expect issues are investigated and resolved by maintainers.
* Investigate issues yourself and please send a pull request to fix it.
--------------------------------------- ---------------------------------------
## Installation ## Installation
Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell:
```bash ```bash
$ go get -u github.com/go-sql-driver/mysql go get -u github.com/go-sql-driver/mysql
``` ```
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
@ -114,6 +123,12 @@ This has the same effect as an empty DSN string:
``` ```
`dbname` is escaped by [PathEscape()](https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes:
```
/dbname%2Fwithslash
```
Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct.
#### Password #### Password
@ -121,7 +136,7 @@ Passwords can consist of any character. Escaping is **not** necessary.
#### Protocol #### Protocol
See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available.
In general you should use an Unix domain socket if available and TCP otherwise for best performance. In general you should use a Unix domain socket if available and TCP otherwise for best performance.
#### Address #### Address
For TCP and UDP networks, addresses have the form `host[:port]`. For TCP and UDP networks, addresses have the form `host[:port]`.
@ -145,7 +160,7 @@ Default: false
``` ```
`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. `allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files.
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) [*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)
##### `allowCleartextPasswords` ##### `allowCleartextPasswords`
@ -194,10 +209,9 @@ Valid Values: <name>
Default: none Default: none
``` ```
Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset fails. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
Usage of the `charset` parameter is discouraged because it issues additional queries to the server. See also [Unicode Support](#unicode-support).
Unless you need the fallback behavior, please use `collation` instead.
##### `checkConnLiveness` ##### `checkConnLiveness`
@ -226,6 +240,7 @@ The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You s
Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)).
See also [Unicode Support](#unicode-support).
##### `clientFoundRows` ##### `clientFoundRows`
@ -253,6 +268,16 @@ SELECT u.id FROM users as u
will return `u.id` instead of just `id` if `columnsWithAlias=true`. will return `u.id` instead of just `id` if `columnsWithAlias=true`.
##### `compress`
```
Type: bool
Valid Values: true, false
Default: false
```
Toggles zlib compression. false by default.
##### `interpolateParams` ##### `interpolateParams`
``` ```
@ -279,13 +304,22 @@ Note that this sets the location for time.Time values but does not change MySQL'
Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
##### `timeTruncate`
```
Type: duration
Default: 0
```
[Truncate time values](https://pkg.go.dev/time#Duration.Truncate) to the specified duration. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### `maxAllowedPacket` ##### `maxAllowedPacket`
``` ```
Type: decimal number Type: decimal number
Default: 4194304 Default: 64*1024*1024
``` ```
Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*.
##### `multiStatements` ##### `multiStatements`
@ -295,9 +329,25 @@ Valid Values: true, false
Default: false Default: false
``` ```
Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded. Allow multiple statements in one query. This can be used to bach multiple queries. Use [Rows.NextResultSet()](https://pkg.go.dev/database/sql#Rows.NextResultSet) to get result of the second and subsequent queries.
When `multiStatements` is used, `?` parameters must only be used in the first statement. When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly.
It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:
```go
conn, _ := db.Conn(ctx)
conn.Raw(func(conn any) error {
ex := conn.(driver.Execer)
res, err := ex.Exec(`
UPDATE point SET x = 1 WHERE y = 2;
UPDATE point SET x = 2 WHERE y = 3;
`, nil)
// Both slices have 2 elements.
log.Print(res.(mysql.Result).AllRowsAffected())
log.Print(res.(mysql.Result).AllLastInsertIds())
})
```
##### `parseTime` ##### `parseTime`
@ -393,6 +443,15 @@ Default: 0
I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### `connectionAttributes`
```
Type: comma-delimited string of user-defined "key:value" pairs
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
Default: none
```
[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.
##### System Variables ##### System Variables
@ -465,12 +524,15 @@ user:password@/
The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively.
## `ColumnType` Support ## `ColumnType` Support
This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `BIGINT`. This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `MEDIUMINT`, `BIGINT`.
## `context.Context` Support ## `context.Context` Support
Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.
See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details.
> [!IMPORTANT]
> The `QueryContext`, `ExecContext`, etc. variants provided by `database/sql` will cause the connection to be closed if the provided context is cancelled or timed out before the result is received by the driver.
### `LOAD DATA LOCAL INFILE` support ### `LOAD DATA LOCAL INFILE` support
For this feature you need direct access to the package. Therefore you must change the import path (no `_`): For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
@ -478,7 +540,7 @@ For this feature you need direct access to the package. Therefore you must chang
import "github.com/go-sql-driver/mysql" import "github.com/go-sql-driver/mysql"
``` ```
Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)).
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
@ -496,9 +558,11 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v
### Unicode support ### Unicode support
Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default.
Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. Other charsets / collations can be set using the [`charset`](#charset) or [`collation`](#collation) DSN parameter.
Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default. - When only the `charset` is specified, the `SET NAMES <charset>` query is sent and the server's default collation is used.
- When both the `charset` and `collation` are specified, the `SET NAMES <charset> COLLATE <collation>` query is sent.
- When only the `collation` is specified, the collation is specified in the protocol handshake and the `SET NAMES` query is not sent. This can save one roundtrip, but note that the server may ignore the specified collation silently and use the server's default charset/collation instead.
See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support. See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support.

View File

@ -1,19 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build go1.19
// +build go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
type atomicBool = atomic.Bool

View File

@ -1,47 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !go1.19
// +build !go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
// atomicBool is an implementation of atomic.Bool for older version of Go.
// it is a wrapper around uint32 for usage as a boolean value with
// atomic access.
type atomicBool struct {
_ noCopy
value uint32
}
// Load returns whether the current boolean value is true
func (ab *atomicBool) Load() bool {
return atomic.LoadUint32(&ab.value) > 0
}
// Store sets the value of the bool regardless of the previous value
func (ab *atomicBool) Store(value bool) {
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}
// Swap sets the value of the bool and returns the old value.
func (ab *atomicBool) Swap(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) > 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}

View File

@ -13,10 +13,13 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"crypto/sha512"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"sync" "sync"
"filippo.io/edwards25519"
) )
// server pub keys registry // server pub keys registry
@ -33,7 +36,7 @@ var (
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver // Note: The provided rsa.PublicKey instance is exclusively owned by the driver
// after registering it and may not be modified. // after registering it and may not be modified.
// //
// data, err := ioutil.ReadFile("mykey.pem") // data, err := os.ReadFile("mykey.pem")
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
@ -225,6 +228,44 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte,
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
} }
// authEd25519 does ed25519 authentication used by MariaDB.
func authEd25519(scramble []byte, password string) ([]byte, error) {
// Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c
// Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207
h := sha512.Sum512([]byte(password))
s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32])
if err != nil {
return nil, err
}
A := (&edwards25519.Point{}).ScalarBaseMult(s)
mh := sha512.New()
mh.Write(h[32:])
mh.Write(scramble)
messageDigest := mh.Sum(nil)
r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest)
if err != nil {
return nil, err
}
R := (&edwards25519.Point{}).ScalarBaseMult(r)
kh := sha512.New()
kh.Write(R.Bytes())
kh.Write(A.Bytes())
kh.Write(scramble)
hramDigest := kh.Sum(nil)
k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest)
if err != nil {
return nil, err
}
S := k.MultiplyAdd(k, s, r)
return append(R.Bytes(), S.Bytes()...), nil
}
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
if err != nil { if err != nil {
@ -290,8 +331,14 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, err return enc, err
case "client_ed25519":
if len(authData) != 32 {
return nil, ErrMalformPkt
}
return authEd25519(authData, mc.cfg.Passwd)
default: default:
errLog.Print("unknown auth plugin:", plugin) mc.log("unknown auth plugin:", plugin)
return nil, ErrUnknownPlugin return nil, ErrUnknownPlugin
} }
} }
@ -338,7 +385,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
switch plugin { switch plugin {
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ // https://dev.mysql.com/blog-archive/preparing-your-community-connector-for-mysql-8-part-2-sha256/
case "caching_sha2_password": case "caching_sha2_password":
switch len(authData) { switch len(authData) {
case 0: case 0:
@ -346,7 +393,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case 1: case 1:
switch authData[0] { switch authData[0] {
case cachingSha2PasswordFastAuthSuccess: case cachingSha2PasswordFastAuthSuccess:
if err = mc.readResultOK(); err == nil { if err = mc.resultUnchanged().readResultOK(); err == nil {
return nil // auth successful return nil // auth successful
} }
@ -376,13 +423,13 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
} }
if data[0] != iAuthMoreData { if data[0] != iAuthMoreData {
return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication") return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication")
} }
// parse public key // parse public key
block, rest := pem.Decode(data[1:]) block, rest := pem.Decode(data[1:])
if block == nil { if block == nil {
return fmt.Errorf("No Pem data found, data: %s", rest) return fmt.Errorf("no pem data found, data: %s", rest)
} }
pkix, err := x509.ParsePKIXPublicKey(block.Bytes) pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
@ -397,7 +444,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return err return err
} }
} }
return mc.readResultOK() return mc.resultUnchanged().readResultOK()
default: default:
return ErrMalformPkt return ErrMalformPkt
@ -426,7 +473,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
if err != nil { if err != nil {
return err return err
} }
return mc.readResultOK() return mc.resultUnchanged().readResultOK()
} }
default: default:

View File

@ -10,54 +10,47 @@ package mysql
import ( import (
"io" "io"
"net"
"time"
) )
const defaultBufSize = 4096 const defaultBufSize = 4096
const maxCachedBufSize = 256 * 1024 const maxCachedBufSize = 256 * 1024
// readerFunc is a function that compatible with io.Reader.
// We use this function type instead of io.Reader because we want to
// just pass mc.readWithTimeout.
type readerFunc func([]byte) (int, error)
// A buffer which is used for both reading and writing. // A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous. // This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection. // In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish // The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case. // Also highly optimized for this particular use case.
// This buffer is backed by two byte slices in a double-buffering scheme
type buffer struct { type buffer struct {
buf []byte // buf is a byte buffer who's length and capacity are equal. buf []byte // read buffer.
nc net.Conn cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize.
idx int
length int
timeout time.Duration
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
flipcnt uint // flipccnt is the current buffer counter for double-buffering
} }
// newBuffer allocates and returns a new buffer. // newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer { func newBuffer() buffer {
fg := make([]byte, defaultBufSize)
return buffer{ return buffer{
buf: fg, cachedBuf: make([]byte, defaultBufSize),
nc: nc,
dbuf: [2][]byte{fg, nil},
} }
} }
// flip replaces the active buffer with the background buffer // busy returns true if the read buffer is not empty.
// this is a delayed flip that simply increases the buffer counter; func (b *buffer) busy() bool {
// the actual flip will be performed the next time we call `buffer.fill` return len(b.buf) > 0
func (b *buffer) flip() {
b.flipcnt += 1
} }
// fill reads into the buffer until at least _need_ bytes are in it // len returns how many bytes in the read buffer.
func (b *buffer) fill(need int) error { func (b *buffer) len() int {
n := b.length return len(b.buf)
// fill data into its double-buffering target: if we've called }
// flip on this buffer, we'll be copying to the background buffer,
// and then filling it with network data; otherwise we'll just move // fill reads into the read buffer until at least _need_ bytes are in it.
// the contents of the current buffer to the front before filling it func (b *buffer) fill(need int, r readerFunc) error {
dest := b.dbuf[b.flipcnt&1] // we'll move the contents of the current buffer to dest before filling it.
dest := b.cachedBuf
// grow buffer if necessary to fit the whole packet. // grow buffer if necessary to fit the whole packet.
if need > len(dest) { if need > len(dest) {
@ -67,64 +60,41 @@ func (b *buffer) fill(need int) error {
// if the allocated buffer is not too large, move it to backing storage // if the allocated buffer is not too large, move it to backing storage
// to prevent extra allocations on applications that perform large reads // to prevent extra allocations on applications that perform large reads
if len(dest) <= maxCachedBufSize { if len(dest) <= maxCachedBufSize {
b.dbuf[b.flipcnt&1] = dest b.cachedBuf = dest
} }
} }
// if we're filling the fg buffer, move the existing data to the start of it. // move the existing data to the start of the buffer.
// if we're filling the bg buffer, copy over the data n := len(b.buf)
if n > 0 { copy(dest[:n], b.buf)
copy(dest[:n], b.buf[b.idx:])
}
b.buf = dest
b.idx = 0
for { for {
if b.timeout > 0 { nn, err := r(dest[n:])
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
nn, err := b.nc.Read(b.buf[n:])
n += nn n += nn
switch err { if err == nil && n < need {
case nil: continue
if n < need {
continue
}
b.length = n
return nil
case io.EOF:
if n >= need {
b.length = n
return nil
}
return io.ErrUnexpectedEOF
default:
return err
} }
b.buf = dest[:n]
if err == io.EOF {
if n < need {
err = io.ErrUnexpectedEOF
} else {
err = nil
}
}
return err
} }
} }
// returns next N bytes from buffer. // returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read // The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) { func (b *buffer) readNext(need int) []byte {
if b.length < need { data := b.buf[:need:need]
// refill b.buf = b.buf[need:]
if err := b.fill(need); err != nil { return data
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
} }
// takeBuffer returns a buffer with the requested size. // takeBuffer returns a buffer with the requested size.
@ -132,18 +102,18 @@ func (b *buffer) readNext(need int) ([]byte, error) {
// Otherwise a bigger buffer is made. // Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time. // Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) ([]byte, error) { func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 { if b.busy() {
return nil, ErrBusyBuffer return nil, ErrBusyBuffer
} }
// test (cheap) general case first // test (cheap) general case first
if length <= cap(b.buf) { if length <= len(b.cachedBuf) {
return b.buf[:length], nil return b.cachedBuf[:length], nil
} }
if length < maxPacketSize { if length < maxCachedBufSize {
b.buf = make([]byte, length) b.cachedBuf = make([]byte, length)
return b.buf, nil return b.cachedBuf, nil
} }
// buffer is larger than we want to store. // buffer is larger than we want to store.
@ -154,10 +124,10 @@ func (b *buffer) takeBuffer(length int) ([]byte, error) {
// known to be smaller than defaultBufSize. // known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time. // Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 { if b.busy() {
return nil, ErrBusyBuffer return nil, ErrBusyBuffer
} }
return b.buf[:length], nil return b.cachedBuf[:length], nil
} }
// takeCompleteBuffer returns the complete existing buffer. // takeCompleteBuffer returns the complete existing buffer.
@ -165,18 +135,15 @@ func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
// cap and len of the returned buffer will be equal. // cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time. // Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() ([]byte, error) { func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 { if b.busy() {
return nil, ErrBusyBuffer return nil, ErrBusyBuffer
} }
return b.buf, nil return b.cachedBuf, nil
} }
// store stores buf, an updated buffer, if its suitable to do so. // store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error { func (b *buffer) store(buf []byte) {
if b.length > 0 { if cap(buf) <= maxCachedBufSize && cap(buf) > cap(b.cachedBuf) {
return ErrBusyBuffer b.cachedBuf = buf[:cap(buf)]
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
} }
return nil
} }

View File

@ -8,8 +8,8 @@
package mysql package mysql
const defaultCollation = "utf8mb4_general_ci" const defaultCollationID = 45 // utf8mb4_general_ci
const binaryCollation = "binary" const binaryCollationID = 63
// A list of available collations mapped to the internal ID. // A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query: // To update this map use the following MySQL query:

213
vendor/github.com/go-sql-driver/mysql/compress.go generated vendored Normal file
View File

@ -0,0 +1,213 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"
)
var (
zrPool *sync.Pool // Do not use directly. Use zDecompress() instead.
zwPool *sync.Pool // Do not use directly. Use zCompress() instead.
)
func init() {
zrPool = &sync.Pool{
New: func() any { return nil },
}
zwPool = &sync.Pool{
New: func() any {
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
if err != nil {
panic(err) // compress/zlib return non-nil error only if level is invalid
}
return zw
},
}
}
func zDecompress(src []byte, dst *bytes.Buffer) (int, error) {
br := bytes.NewReader(src)
var zr io.ReadCloser
var err error
if a := zrPool.Get(); a == nil {
if zr, err = zlib.NewReader(br); err != nil {
return 0, err
}
} else {
zr = a.(io.ReadCloser)
if err := zr.(zlib.Resetter).Reset(br, nil); err != nil {
return 0, err
}
}
n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again.
err = zr.Close() // zr.Close() may return chuecksum error.
zrPool.Put(zr)
return int(n), err
}
func zCompress(src []byte, dst io.Writer) error {
zw := zwPool.Get().(*zlib.Writer)
zw.Reset(dst)
if _, err := zw.Write(src); err != nil {
return err
}
err := zw.Close()
zwPool.Put(zw)
return err
}
type compIO struct {
mc *mysqlConn
buff bytes.Buffer
}
func newCompIO(mc *mysqlConn) *compIO {
return &compIO{
mc: mc,
}
}
func (c *compIO) reset() {
c.buff.Reset()
}
func (c *compIO) readNext(need int) ([]byte, error) {
for c.buff.Len() < need {
if err := c.readCompressedPacket(); err != nil {
return nil, err
}
}
data := c.buff.Next(need)
return data[:need:need], nil // prevent caller writes into c.buff
}
func (c *compIO) readCompressedPacket() error {
header, err := c.mc.readNext(7)
if err != nil {
return err
}
_ = header[6] // bounds check hint to compiler; guaranteed by readNext
// compressed header structure
comprLength := getUint24(header[0:3])
compressionSequence := header[3]
uncompressedLength := getUint24(header[4:7])
if debug {
fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
comprLength, uncompressedLength, compressionSequence, c.mc.sequence)
}
// Do not return ErrPktSync here.
// Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
// before receiving all packets from client. In this case, seqnr is younger than expected.
// NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
if debug && compressionSequence != c.mc.compressSequence {
fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v",
c.mc.compressSequence, compressionSequence)
}
c.mc.compressSequence = compressionSequence + 1
comprData, err := c.mc.readNext(comprLength)
if err != nil {
return err
}
// if payload is uncompressed, its length will be specified as zero, and its
// true length is contained in comprLength
if uncompressedLength == 0 {
c.buff.Write(comprData)
return nil
}
// use existing capacity in bytesBuf if possible
c.buff.Grow(uncompressedLength)
nread, err := zDecompress(comprData, &c.buff)
if err != nil {
return err
}
if nread != uncompressedLength {
return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d",
uncompressedLength, nread)
}
return nil
}
const minCompressLength = 150
const maxPayloadLen = maxPacketSize - 4
// writePackets sends one or some packets with compression.
// Use this instead of mc.netConn.Write() when mc.compress is true.
func (c *compIO) writePackets(packets []byte) (int, error) {
totalBytes := len(packets)
blankHeader := make([]byte, 7)
buf := &c.buff
for len(packets) > 0 {
payloadLen := min(maxPayloadLen, len(packets))
payload := packets[:payloadLen]
uncompressedLen := payloadLen
buf.Reset()
buf.Write(blankHeader) // Buffer.Write() never returns error
// If payload is less than minCompressLength, don't compress.
if uncompressedLen < minCompressLength {
buf.Write(payload)
uncompressedLen = 0
} else {
err := zCompress(payload, buf)
if debug && err != nil {
fmt.Printf("zCompress error: %v", err)
}
// do not compress if compressed data is larger than uncompressed data
// I intentionally miss 7 byte header in the buf; zCompress must compress more than 7 bytes.
if err != nil || buf.Len() >= uncompressedLen {
buf.Reset()
buf.Write(blankHeader)
buf.Write(payload)
uncompressedLen = 0
}
}
if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
// To allow returning ErrBadConn when sending really 0 bytes, we sum
// up compressed bytes that is returned by underlying Write().
return totalBytes - len(packets) + n, err
}
packets = packets[payloadLen:]
}
return totalBytes, nil
}
// writeCompressedPacket writes a compressed packet with header.
// data should start with 7 size space for header followed by payload.
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) {
mc := c.mc
comprLength := len(data) - 7
if debug {
fmt.Printf(
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v\n",
comprLength, uncompressedLen, mc.compressSequence)
}
// compression header
putUint24(data[0:3], comprLength)
data[3] = mc.compressSequence
putUint24(data[4:7], uncompressedLen)
mc.compressSequence++
return mc.writeWithTimeout(data)
}

View File

@ -13,28 +13,32 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net" "net"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
) )
type mysqlConn struct { type mysqlConn struct {
buf buffer buf buffer
netConn net.Conn netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection. rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64 result mysqlResult // managed by clearResult() and handleOkPacket().
insertId uint64 compIO *compIO
cfg *Config cfg *Config
connector *connector
maxAllowedPacket int maxAllowedPacket int
maxWriteSize int maxWriteSize int
writeTimeout time.Duration
flags clientFlag flags clientFlag
status statusFlag status statusFlag
sequence uint8 sequence uint8
compressSequence uint8
parseTime bool parseTime bool
reset bool // set when the Go SQL package calls ResetSession compress bool
// for context support (Go 1.8+) // for context support (Go 1.8+)
watching bool watching bool
@ -42,61 +46,92 @@ type mysqlConn struct {
closech chan struct{} closech chan struct{}
finished chan<- struct{} finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed closed atomic.Bool // set when conn is closed, before closech is closed
}
// Helper function to call per-connection logger.
func (mc *mysqlConn) log(v ...any) {
_, filename, lineno, ok := runtime.Caller(1)
if ok {
pos := strings.LastIndexByte(filename, '/')
if pos != -1 {
filename = filename[pos+1:]
}
prefix := fmt.Sprintf("%s:%d ", filename, lineno)
v = append([]any{prefix}, v...)
}
mc.cfg.Logger.Print(v...)
}
func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
to := mc.cfg.ReadTimeout
if to > 0 {
if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil {
return 0, err
}
}
return mc.netConn.Read(b)
}
func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
to := mc.cfg.WriteTimeout
if to > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(to)); err != nil {
return 0, err
}
}
return mc.netConn.Write(b)
}
func (mc *mysqlConn) resetSequence() {
mc.sequence = 0
mc.compressSequence = 0
}
// syncSequence must be called when finished writing some packet and before start reading.
func (mc *mysqlConn) syncSequence() {
// Syncs compressionSequence to sequence.
// This is not documented but done in `net_flush()` in MySQL and MariaDB.
// https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171
// https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293
if mc.compress {
mc.sequence = mc.compressSequence
mc.compIO.reset()
}
} }
// Handles parameters set in DSN after the connection is established // Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) { func (mc *mysqlConn) handleParams() (err error) {
var cmdSet strings.Builder var cmdSet strings.Builder
for param, val := range mc.cfg.Params {
switch param {
// Charset: character_set_connection, character_set_client, character_set_results
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// Other system vars accumulated in a single SET command for param, val := range mc.cfg.Params {
default: if cmdSet.Len() == 0 {
if cmdSet.Len() == 0 { // Heuristic: 29 chars for each other key=value to reduce reallocations
// Heuristic: 29 chars for each other key=value to reduce reallocations cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1))
cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) cmdSet.WriteString("SET ")
cmdSet.WriteString("SET ") } else {
} else { cmdSet.WriteString(", ")
cmdSet.WriteByte(',')
}
cmdSet.WriteString(param)
cmdSet.WriteByte('=')
cmdSet.WriteString(val)
} }
cmdSet.WriteString(param)
cmdSet.WriteString(" = ")
cmdSet.WriteString(val)
} }
if cmdSet.Len() > 0 { if cmdSet.Len() > 0 {
err = mc.exec(cmdSet.String()) err = mc.exec(cmdSet.String())
if err != nil {
return
}
} }
return return
} }
// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn.
// This function is used to return driver.ErrBadConn only when safe to retry.
func (mc *mysqlConn) markBadConn(err error) error { func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil { if err == errBadConnNoWrite {
return err return driver.ErrBadConn
} }
if err != errBadConnNoWrite { return err
return err
}
return driver.ErrBadConn
} }
func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) Begin() (driver.Tx, error) {
@ -105,7 +140,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
var q string var q string
@ -126,12 +160,16 @@ func (mc *mysqlConn) Close() (err error) {
if !mc.closed.Load() { if !mc.closed.Load() {
err = mc.writeCommandPacket(comQuit) err = mc.writeCommandPacket(comQuit)
} }
mc.close()
mc.cleanup()
return return
} }
// close closes the network connection and clear results without sending COM_QUIT.
func (mc *mysqlConn) close() {
mc.cleanup()
mc.clearResult()
}
// Closes the network connection and unsets internal variables. Do not call this // Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function // function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already // is called before auth or on auth failure because MySQL will have already
@ -143,12 +181,16 @@ func (mc *mysqlConn) cleanup() {
// Makes cleanup idempotent // Makes cleanup idempotent
close(mc.closech) close(mc.closech)
if mc.netConn == nil { conn := mc.rawConn
if conn == nil {
return return
} }
if err := mc.netConn.Close(); err != nil { if err := conn.Close(); err != nil {
errLog.Print(err) mc.log("closing connection:", err)
} }
// This function can be called from multiple goroutines.
// So we can not mc.clearResult() here.
// Caller should do it if they are in safe goroutine.
} }
func (mc *mysqlConn) error() error { func (mc *mysqlConn) error() error {
@ -163,14 +205,13 @@ func (mc *mysqlConn) error() error {
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query) err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil { if err != nil {
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here. // STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
errLog.Print(err) mc.log(err)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -204,8 +245,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf, err := mc.buf.takeCompleteBuffer() buf, err := mc.buf.takeCompleteBuffer()
if err != nil { if err != nil {
// can not take the buffer. Something must be wrong with the connection // can not take the buffer. Something must be wrong with the connection
errLog.Print(err) mc.cleanup()
return "", ErrInvalidConn // interpolateParams would be called before sending any query.
// So its safe to retry.
return "", driver.ErrBadConn
} }
buf = buf[:0] buf = buf[:0]
argPos := 0 argPos := 0
@ -246,7 +289,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf = append(buf, "'0000-00-00'"...) buf = append(buf, "'0000-00-00'"...)
} else { } else {
buf = append(buf, '\'') buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -296,7 +339,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
if len(args) != 0 { if len(args) != 0 {
@ -310,28 +352,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
} }
query = prepared query = prepared
} }
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query) err := mc.exec(query)
if err == nil { if err == nil {
return &mysqlResult{ copied := mc.result
affectedRows: int64(mc.affectedRows), return &copied, err
insertId: int64(mc.insertId),
}, err
} }
return nil, mc.markBadConn(err) return nil, mc.markBadConn(err)
} }
// Internal function to execute commands // Internal function to execute commands
func (mc *mysqlConn) exec(query string) error { func (mc *mysqlConn) exec(query string) error {
handleOk := mc.clearResult()
// Send command // Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil { if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err) return mc.markBadConn(err)
} }
// Read Result // Read Result
resLen, err := mc.readResultSetHeaderPacket() resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil { if err != nil {
return err return err
} }
@ -348,7 +387,7 @@ func (mc *mysqlConn) exec(query string) error {
} }
} }
return mc.discardResults() return handleOk.discardResults()
} }
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@ -356,8 +395,9 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
handleOk := mc.clearResult()
if mc.closed.Load() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
if len(args) != 0 { if len(args) != 0 {
@ -373,43 +413,47 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
} }
// Send command // Send command
err := mc.writeCommandPacketStr(comQuery, query) err := mc.writeCommandPacketStr(comQuery, query)
if err == nil { if err != nil {
// Read Result return nil, mc.markBadConn(err)
var resLen int }
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen == 0 { // Read Result
rows.rs.done = true var resLen int
resLen, err = handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
switch err := rows.NextResultSet(); err { rows := new(textRows)
case nil, io.EOF: rows.mc = mc
return rows, nil
default:
return nil, err
}
}
// Columns if resLen == 0 {
rows.rs.columns, err = mc.readColumns(resLen) rows.rs.done = true
return rows, err
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
} }
} }
return nil, mc.markBadConn(err)
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
} }
// Gets the value of the given MySQL System Variable // Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read // The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command // Send command
handleOk := mc.clearResult()
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err return nil, err
} }
// Read Result // Read Result
resLen, err := mc.readResultSetHeaderPacket() resLen, err := handleOk.readResultSetHeaderPacket()
if err == nil { if err == nil {
rows := new(textRows) rows := new(textRows)
rows.mc = mc rows.mc = mc
@ -430,7 +474,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
return nil, err return nil, err
} }
// finish is called when the query has canceled. // cancel is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) { func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err) mc.canceled.Set(err)
mc.cleanup() mc.cleanup()
@ -451,7 +495,6 @@ func (mc *mysqlConn) finish() {
// Ping implements driver.Pinger interface // Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) { func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.Load() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn return driver.ErrBadConn
} }
@ -460,11 +503,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
} }
defer mc.finish() defer mc.finish()
handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil { if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err) return mc.markBadConn(err)
} }
return mc.readResultOK() return handleOk.readResultOK()
} }
// BeginTx implements driver.ConnBeginTx interface // BeginTx implements driver.ConnBeginTx interface
@ -636,15 +680,42 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
// ResetSession implements driver.SessionResetter. // ResetSession implements driver.SessionResetter.
// (From Go 1.10) // (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error { func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.Load() { if mc.closed.Load() || mc.buf.busy() {
return driver.ErrBadConn return driver.ErrBadConn
} }
mc.reset = true
// Perform a stale connection check. We only perform this check for
// the first query on a connection that has been checked out of the
// connection pool: a fresh connection from the pool is more likely
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.cfg.CheckConnLiveness {
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
}
if err != nil {
mc.log("closing bad idle connection: ", err)
return driver.ErrBadConn
}
}
return nil return nil
} }
// IsValid implements driver.Validator interface // IsValid implements driver.Validator interface
// (From Go 1.15) // (From Go 1.15)
func (mc *mysqlConn) IsValid() bool { func (mc *mysqlConn) IsValid() bool {
return !mc.closed.Load() return !mc.closed.Load() && !mc.buf.busy()
} }
var _ driver.SessionResetter = &mysqlConn{}
var _ driver.Validator = &mysqlConn{}

View File

@ -11,11 +11,55 @@ package mysql
import ( import (
"context" "context"
"database/sql/driver" "database/sql/driver"
"fmt"
"net" "net"
"os"
"strconv"
"strings"
) )
type connector struct { type connector struct {
cfg *Config // immutable private copy. cfg *Config // immutable private copy.
encodedAttributes string // Encoded connection attributes.
}
func encodeConnectionAttributes(cfg *Config) string {
connAttrsBuf := make([]byte, 0)
// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
serverHost, _, _ := net.SplitHostPort(cfg.Addr)
if serverHost != "" {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost)
}
// user-defined connection attributes
for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
k, v, found := strings.Cut(connAttr, ":")
if !found {
continue
}
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, k)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
}
return string(connAttrsBuf)
}
func newConnector(cfg *Config) *connector {
encodedAttributes := encodeConnectionAttributes(cfg)
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}
} }
// Connect implements driver.Connector interface. // Connect implements driver.Connector interface.
@ -23,43 +67,56 @@ type connector struct {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error var err error
// Invoke beforeConnect if present, with a copy of the configuration
cfg := c.cfg
if c.cfg.beforeConnect != nil {
cfg = c.cfg.Clone()
err = c.cfg.beforeConnect(ctx, cfg)
if err != nil {
return nil, err
}
}
// New mysqlConn // New mysqlConn
mc := &mysqlConn{ mc := &mysqlConn{
maxAllowedPacket: maxPacketSize, maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1, maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}), closech: make(chan struct{}),
cfg: c.cfg, cfg: cfg,
connector: c,
} }
mc.parseTime = mc.cfg.ParseTime mc.parseTime = mc.cfg.ParseTime
// Connect to Server // Connect to Server
dialsLock.RLock() dctx := ctx
dial, ok := dials[mc.cfg.Net] if mc.cfg.Timeout > 0 {
dialsLock.RUnlock() var cancel context.CancelFunc
if ok { dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
dctx := ctx defer cancel()
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
} }
if c.cfg.DialFunc != nil {
mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr)
} else {
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
nd := net.Dialer{}
mc.netConn, err = nd.DialContext(dctx, mc.cfg.Net, mc.cfg.Addr)
}
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
mc.rawConn = mc.netConn
// Enable TCP Keepalives on TCP connections // Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok { if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil { if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake. c.cfg.Logger.Print(err)
mc.netConn.Close()
mc.netConn = nil
return nil, err
} }
} }
@ -71,11 +128,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
} }
defer mc.finish() defer mc.finish()
mc.buf = newBuffer(mc.netConn) mc.buf = newBuffer()
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet // Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket() authData, plugin, err := mc.readHandshakePacket()
@ -92,7 +145,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
authResp, err := mc.auth(authData, plugin) authResp, err := mc.auth(authData, plugin)
if err != nil { if err != nil {
// try the default auth plugin, if using the requested plugin failed // try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin plugin = defaultAuthPlugin
authResp, err = mc.auth(authData, plugin) authResp, err = mc.auth(authData, plugin)
if err != nil { if err != nil {
@ -114,6 +167,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err return nil, err
} }
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
mc.compress = true
mc.compIO = newCompIO(mc)
}
if mc.cfg.MaxAllowedPacket > 0 { if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else { } else {
@ -123,12 +180,36 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.Close() mc.Close()
return nil, err return nil, err
} }
mc.maxAllowedPacket = stringToInt(maxap) - 1 n, err := strconv.Atoi(string(maxap))
if err != nil {
mc.Close()
return nil, fmt.Errorf("invalid max_allowed_packet value (%q): %w", maxap, err)
}
mc.maxAllowedPacket = n - 1
} }
if mc.maxAllowedPacket < maxPacketSize { if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket mc.maxWriteSize = mc.maxAllowedPacket
} }
// Charset: character_set_connection, character_set_client, character_set_results
if len(mc.cfg.charsets) > 0 {
for _, cs := range mc.cfg.charsets {
// ignore errors here - a charset may not exist
if mc.cfg.Collation != "" {
err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation)
} else {
err = mc.exec("SET NAMES " + cs)
}
if err == nil {
break
}
}
if err != nil {
mc.Close()
return nil, err
}
}
// Handle DSN Params // Handle DSN Params
err = mc.handleParams() err = mc.handleParams()
if err != nil { if err != nil {

View File

@ -8,12 +8,27 @@
package mysql package mysql
import "runtime"
const ( const (
debug = false // for debugging. Set true only in development.
defaultAuthPlugin = "mysql_native_password" defaultAuthPlugin = "mysql_native_password"
defaultMaxAllowedPacket = 4 << 20 // 4 MiB defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355
minProtocolVersion = 10 minProtocolVersion = 10
maxPacketSize = 1<<24 - 1 maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05.999999" timeFormat = "2006-01-02 15:04:05.999999"
// Connection attributes
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
connAttrClientName = "_client_name"
connAttrClientNameValue = "Go-MySQL-Driver"
connAttrOS = "_os"
connAttrOSValue = runtime.GOOS
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
connAttrServerHost = "_server_host"
) )
// MySQL constants documentation: // MySQL constants documentation:
@ -112,7 +127,10 @@ const (
fieldTypeBit fieldTypeBit
) )
const ( const (
fieldTypeJSON fieldType = iota + 0xf5 fieldTypeVector fieldType = iota + 0xf2
fieldTypeInvalid
fieldTypeBool
fieldTypeJSON
fieldTypeNewDecimal fieldTypeNewDecimal
fieldTypeEnum fieldTypeEnum
fieldTypeSet fieldTypeSet

View File

@ -55,6 +55,15 @@ func RegisterDialContext(net string, dial DialContextFunc) {
dials[net] = dial dials[net] = dial
} }
// DeregisterDialContext removes the custom dial function registered with the given net.
func DeregisterDialContext(net string) {
dialsLock.Lock()
defer dialsLock.Unlock()
if dials != nil {
delete(dials, net)
}
}
// RegisterDial registers a custom dial function. It can then be used by the // RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network. // network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function. // addr is passed as a parameter to the dial function.
@ -74,14 +83,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
c := &connector{ c := newConnector(cfg)
cfg: cfg,
}
return c.Connect(context.Background()) return c.Connect(context.Background())
} }
// This variable can be replaced with -ldflags like below:
// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom"
var driverName = "mysql"
func init() { func init() {
sql.Register("mysql", &MySQLDriver{}) if driverName != "" {
sql.Register(driverName, &MySQLDriver{})
}
} }
// NewConnector returns new driver.Connector. // NewConnector returns new driver.Connector.
@ -92,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil { if err := cfg.normalize(); err != nil {
return nil, err return nil, err
} }
return &connector{cfg: cfg}, nil return newConnector(cfg), nil
} }
// OpenConnector implements driver.DriverContext. // OpenConnector implements driver.DriverContext.
@ -101,7 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &connector{ return newConnector(cfg), nil
cfg: cfg,
}, nil
} }

View File

@ -10,6 +10,7 @@ package mysql
import ( import (
"bytes" "bytes"
"context"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
"errors" "errors"
@ -34,22 +35,29 @@ var (
// If a new Config is created instead of being parsed from a DSN string, // If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values. // the NewConfig function should be used, which sets default values.
type Config struct { type Config struct {
User string // Username // non boolean fields
Passwd string // Password (requires User)
Net string // Network type User string // Username
Addr string // Network address (requires Net) Passwd string // Password (requires User)
DBName string // Database name Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Params map[string]string // Connection parameters Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
Collation string // Connection collation DBName string // Database name
Loc *time.Location // Location for time.Time values Params map[string]string // Connection parameters
MaxAllowedPacket int // Max packet size allowed ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
ServerPubKey string // Server public key name Collation string // Connection collation. When set, this will be set in SET NAMES <charset> COLLATE <collation> query
pubKey *rsa.PublicKey // Server public key Loc *time.Location // Location for time.Time values
TLSConfig string // TLS configuration name MaxAllowedPacket int // Max packet size allowed
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig ServerPubKey string // Server public key name
Timeout time.Duration // Dial timeout TLSConfig string // TLS configuration name
ReadTimeout time.Duration // I/O read timeout TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
WriteTimeout time.Duration // I/O write timeout Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
// DialFunc specifies the dial function for creating connections
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// boolean fields
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin AllowCleartextPasswords bool // Allows the cleartext client side plugin
@ -63,17 +71,83 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections RejectReadOnly bool // Reject read-only connections
// unexported fields. new options should be come here.
// boolean first. alphabetical order.
compress bool // Enable zlib compression
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
charsets []string // Connection charset. When set, this will be set in SET NAMES <charset> query
} }
// Functional Options Pattern
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
type Option func(*Config) error
// NewConfig creates a new Config and sets default values. // NewConfig creates a new Config and sets default values.
func NewConfig() *Config { func NewConfig() *Config {
return &Config{ cfg := &Config{
Collation: defaultCollation,
Loc: time.UTC, Loc: time.UTC,
MaxAllowedPacket: defaultMaxAllowedPacket, MaxAllowedPacket: defaultMaxAllowedPacket,
Logger: defaultLogger,
AllowNativePasswords: true, AllowNativePasswords: true,
CheckConnLiveness: true, CheckConnLiveness: true,
} }
return cfg
}
// Apply applies the given options to the Config object.
func (c *Config) Apply(opts ...Option) error {
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
return nil
}
// TimeTruncate sets the time duration to truncate time.Time values in
// query parameters.
func TimeTruncate(d time.Duration) Option {
return func(cfg *Config) error {
cfg.timeTruncate = d
return nil
}
}
// BeforeConnect sets the function to be invoked before a connection is established.
func BeforeConnect(fn func(context.Context, *Config) error) Option {
return func(cfg *Config) error {
cfg.beforeConnect = fn
return nil
}
}
// EnableCompress sets the compression mode.
func EnableCompression(yes bool) Option {
return func(cfg *Config) error {
cfg.compress = yes
return nil
}
}
// Charset sets the connection charset and collation.
//
// charset is the connection charset.
// collation is the connection collation. It can be null or empty string.
//
// When collation is not specified, `SET NAMES <charset>` command is sent when the connection is established.
// When collation is specified, `SET NAMES <charset> COLLATE <collation>` command is sent when the connection is established.
func Charset(charset, collation string) Option {
return func(cfg *Config) error {
cfg.charsets = []string{charset}
cfg.Collation = collation
return nil
}
} }
func (cfg *Config) Clone() *Config { func (cfg *Config) Clone() *Config {
@ -97,7 +171,7 @@ func (cfg *Config) Clone() *Config {
} }
func (cfg *Config) normalize() error { func (cfg *Config) normalize() error {
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { if cfg.InterpolateParams && cfg.Collation != "" && unsafeCollations[cfg.Collation] {
return errInvalidDSNUnsafeCollation return errInvalidDSNUnsafeCollation
} }
@ -153,6 +227,10 @@ func (cfg *Config) normalize() error {
} }
} }
if cfg.Logger == nil {
cfg.Logger = defaultLogger
}
return nil return nil
} }
@ -171,6 +249,8 @@ func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
// FormatDSN formats the given Config into a DSN string which can be passed to // FormatDSN formats the given Config into a DSN string which can be passed to
// the driver. // the driver.
//
// Note: use [NewConnector] and [database/sql.OpenDB] to open a connection from a [*Config].
func (cfg *Config) FormatDSN() string { func (cfg *Config) FormatDSN() string {
var buf bytes.Buffer var buf bytes.Buffer
@ -196,7 +276,7 @@ func (cfg *Config) FormatDSN() string {
// /dbname // /dbname
buf.WriteByte('/') buf.WriteByte('/')
buf.WriteString(cfg.DBName) buf.WriteString(url.PathEscape(cfg.DBName))
// [?param1=value1&...&paramN=valueN] // [?param1=value1&...&paramN=valueN]
hasParam := false hasParam := false
@ -230,7 +310,11 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
} }
if col := cfg.Collation; col != defaultCollation && len(col) > 0 { if charsets := cfg.charsets; len(charsets) > 0 {
writeDSNParam(&buf, &hasParam, "charset", strings.Join(charsets, ","))
}
if col := cfg.Collation; col != "" {
writeDSNParam(&buf, &hasParam, "collation", col) writeDSNParam(&buf, &hasParam, "collation", col)
} }
@ -238,6 +322,14 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
} }
if cfg.ConnectionAttributes != "" {
writeDSNParam(&buf, &hasParam, "connectionAttributes", url.QueryEscape(cfg.ConnectionAttributes))
}
if cfg.compress {
writeDSNParam(&buf, &hasParam, "compress", "true")
}
if cfg.InterpolateParams { if cfg.InterpolateParams {
writeDSNParam(&buf, &hasParam, "interpolateParams", "true") writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
} }
@ -254,6 +346,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "parseTime", "true") writeDSNParam(&buf, &hasParam, "parseTime", "true")
} }
if cfg.timeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
}
if cfg.ReadTimeout > 0 { if cfg.ReadTimeout > 0 {
writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String()) writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
} }
@ -358,7 +454,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
break break
} }
} }
cfg.DBName = dsn[i+1 : j]
dbname := dsn[i+1 : j]
if cfg.DBName, err = url.PathUnescape(dbname); err != nil {
return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
}
break break
} }
@ -378,13 +478,13 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
// Values must be url.QueryEscape'ed // Values must be url.QueryEscape'ed
func parseDSNParams(cfg *Config, params string) (err error) { func parseDSNParams(cfg *Config, params string) (err error) {
for _, v := range strings.Split(params, "&") { for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2) key, value, found := strings.Cut(v, "=")
if len(param) != 2 { if !found {
continue continue
} }
// cfg params // cfg params
switch value := param[1]; param[0] { switch key {
// Disable INFILE allowlist / enable all files // Disable INFILE allowlist / enable all files
case "allowAllFiles": case "allowAllFiles":
var isBool bool var isBool bool
@ -441,6 +541,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value) return errors.New("invalid bool value: " + value)
} }
// charset
case "charset":
cfg.charsets = strings.Split(value, ",")
// Collation // Collation
case "collation": case "collation":
cfg.Collation = value cfg.Collation = value
@ -454,7 +558,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
// Compression // Compression
case "compress": case "compress":
return errors.New("compression not implemented yet") var isBool bool
cfg.compress, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Enable client side placeholder substitution // Enable client side placeholder substitution
case "interpolateParams": case "interpolateParams":
@ -490,6 +598,13 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value) return errors.New("invalid bool value: " + value)
} }
// time.Time truncation
case "timeTruncate":
cfg.timeTruncate, err = time.ParseDuration(value)
if err != nil {
return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
}
// I/O read Timeout // I/O read Timeout
case "readTimeout": case "readTimeout":
cfg.ReadTimeout, err = time.ParseDuration(value) cfg.ReadTimeout, err = time.ParseDuration(value)
@ -554,13 +669,22 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil { if err != nil {
return return
} }
// Connection attributes
case "connectionAttributes":
connectionAttributes, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid connectionAttributes value: %v", err)
}
cfg.ConnectionAttributes = connectionAttributes
default: default:
// lazy init // lazy init
if cfg.Params == nil { if cfg.Params == nil {
cfg.Params = make(map[string]string) cfg.Params = make(map[string]string)
} }
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { if cfg.Params[key], err = url.QueryUnescape(value); err != nil {
return return
} }
} }

View File

@ -21,36 +21,42 @@ var (
ErrMalformPkt = errors.New("malformed packet") ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS") ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication.") ErrNativePassword = errors.New("this user requires mysql native password authentication")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported") ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now") ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
ErrBusyBuffer = errors.New("busy buffer") ErrBusyBuffer = errors.New("busy buffer")
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend. // to trigger a resend. Use mc.markBadConn(err) to do this.
// See https://github.com/go-sql-driver/mysql/pull/302 // See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection") errBadConnNoWrite = errors.New("bad connection")
) )
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime))
// Logger is used to log critical error messages. // Logger is used to log critical error messages.
type Logger interface { type Logger interface {
Print(v ...interface{}) Print(v ...any)
} }
// SetLogger is used to set the logger for critical errors. // NopLogger is a nop implementation of the Logger interface.
type NopLogger struct{}
// Print implements Logger interface.
func (nl *NopLogger) Print(_ ...any) {}
// SetLogger is used to set the default logger for critical errors.
// The initial logger is os.Stderr. // The initial logger is os.Stderr.
func SetLogger(logger Logger) error { func SetLogger(logger Logger) error {
if logger == nil { if logger == nil {
return errors.New("logger is nil") return errors.New("logger is nil")
} }
errLog = logger defaultLogger = logger
return nil return nil
} }

View File

@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeBit: case fieldTypeBit:
return "BIT" return "BIT"
case fieldTypeBLOB: case fieldTypeBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "TEXT" return "TEXT"
} }
return "BLOB" return "BLOB"
@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeGeometry: case fieldTypeGeometry:
return "GEOMETRY" return "GEOMETRY"
case fieldTypeInt24: case fieldTypeInt24:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED MEDIUMINT"
}
return "MEDIUMINT" return "MEDIUMINT"
case fieldTypeJSON: case fieldTypeJSON:
return "JSON" return "JSON"
@ -46,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "INT" return "INT"
case fieldTypeLongBLOB: case fieldTypeLongBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "LONGTEXT" return "LONGTEXT"
} }
return "LONGBLOB" return "LONGBLOB"
@ -56,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "BIGINT" return "BIGINT"
case fieldTypeMediumBLOB: case fieldTypeMediumBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "MEDIUMTEXT" return "MEDIUMTEXT"
} }
return "MEDIUMBLOB" return "MEDIUMBLOB"
@ -74,7 +77,12 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "SMALLINT" return "SMALLINT"
case fieldTypeString: case fieldTypeString:
if mf.charSet == collations[binaryCollation] { if mf.flags&flagEnum != 0 {
return "ENUM"
} else if mf.flags&flagSet != 0 {
return "SET"
}
if mf.charSet == binaryCollationID {
return "BINARY" return "BINARY"
} }
return "CHAR" return "CHAR"
@ -88,43 +96,47 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "TINYINT" return "TINYINT"
case fieldTypeTinyBLOB: case fieldTypeTinyBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != binaryCollationID {
return "TINYTEXT" return "TINYTEXT"
} }
return "TINYBLOB" return "TINYBLOB"
case fieldTypeVarChar: case fieldTypeVarChar:
if mf.charSet == collations[binaryCollation] { if mf.charSet == binaryCollationID {
return "VARBINARY" return "VARBINARY"
} }
return "VARCHAR" return "VARCHAR"
case fieldTypeVarString: case fieldTypeVarString:
if mf.charSet == collations[binaryCollation] { if mf.charSet == binaryCollationID {
return "VARBINARY" return "VARBINARY"
} }
return "VARCHAR" return "VARCHAR"
case fieldTypeYear: case fieldTypeYear:
return "YEAR" return "YEAR"
case fieldTypeVector:
return "VECTOR"
default: default:
return "" return ""
} }
} }
var ( var (
scanTypeFloat32 = reflect.TypeOf(float32(0)) scanTypeFloat32 = reflect.TypeOf(float32(0))
scanTypeFloat64 = reflect.TypeOf(float64(0)) scanTypeFloat64 = reflect.TypeOf(float64(0))
scanTypeInt8 = reflect.TypeOf(int8(0)) scanTypeInt8 = reflect.TypeOf(int8(0))
scanTypeInt16 = reflect.TypeOf(int16(0)) scanTypeInt16 = reflect.TypeOf(int16(0))
scanTypeInt32 = reflect.TypeOf(int32(0)) scanTypeInt32 = reflect.TypeOf(int32(0))
scanTypeInt64 = reflect.TypeOf(int64(0)) scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0)) scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0)) scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0)) scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) scanTypeString = reflect.TypeOf("")
scanTypeUnknown = reflect.TypeOf(new(interface{})) scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeBytes = reflect.TypeOf([]byte{})
scanTypeUnknown = reflect.TypeOf(new(any))
) )
type mysqlField struct { type mysqlField struct {
@ -187,12 +199,18 @@ func (mf *mysqlField) scanType() reflect.Type {
} }
return scanTypeNullFloat return scanTypeNullFloat
case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB,
fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeVector:
if mf.charSet == binaryCollationID {
return scanTypeBytes
}
fallthrough
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime:
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, if mf.flags&flagNotNULL != 0 {
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, return scanTypeString
fieldTypeTime: }
return scanTypeRawBytes return scanTypeNullString
case fieldTypeDate, fieldTypeNewDate, case fieldTypeDate, fieldTypeNewDate,
fieldTypeTimestamp, fieldTypeDateTime: fieldTypeTimestamp, fieldTypeDateTime:

View File

@ -1,25 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build gofuzz
// +build gofuzz
package mysql
import (
"database/sql"
)
func Fuzz(data []byte) int {
db, err := sql.Open("mysql", string(data))
if err != nil {
return 0
}
db.Close()
return 1
}

View File

@ -17,7 +17,7 @@ import (
) )
var ( var (
fileRegister map[string]bool fileRegister map[string]struct{}
fileRegisterLock sync.RWMutex fileRegisterLock sync.RWMutex
readerRegister map[string]func() io.Reader readerRegister map[string]func() io.Reader
readerRegisterLock sync.RWMutex readerRegisterLock sync.RWMutex
@ -37,10 +37,10 @@ func RegisterLocalFile(filePath string) {
fileRegisterLock.Lock() fileRegisterLock.Lock()
// lazy map init // lazy map init
if fileRegister == nil { if fileRegister == nil {
fileRegister = make(map[string]bool) fileRegister = make(map[string]struct{})
} }
fileRegister[strings.Trim(filePath, `"`)] = true fileRegister[strings.Trim(filePath, `"`)] = struct{}{}
fileRegisterLock.Unlock() fileRegisterLock.Unlock()
} }
@ -93,9 +93,8 @@ func deferredClose(err *error, closer io.Closer) {
const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
func (mc *mysqlConn) handleInFileRequest(name string) (err error) { func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader var rdr io.Reader
var data []byte
packetSize := defaultPacketSize packetSize := defaultPacketSize
if mc.maxWriteSize < packetSize { if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize packetSize = mc.maxWriteSize
@ -116,17 +115,17 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
defer deferredClose(&err, cl) defer deferredClose(&err, cl)
} }
} else { } else {
err = fmt.Errorf("Reader '%s' is <nil>", name) err = fmt.Errorf("reader '%s' is <nil>", name)
} }
} else { } else {
err = fmt.Errorf("Reader '%s' is not registered", name) err = fmt.Errorf("reader '%s' is not registered", name)
} }
} else { // File } else { // File
name = strings.Trim(name, `"`) name = strings.Trim(name, `"`)
fileRegisterLock.RLock() fileRegisterLock.RLock()
fr := fileRegister[name] _, exists := fileRegister[name]
fileRegisterLock.RUnlock() fileRegisterLock.RUnlock()
if mc.cfg.AllowAllFiles || fr { if mc.cfg.AllowAllFiles || exists {
var file *os.File var file *os.File
var fi os.FileInfo var fi os.FileInfo
@ -147,14 +146,16 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
} }
// send content packets // send content packets
var data []byte
// if packetSize == 0, the Reader contains no data // if packetSize == 0, the Reader contains no data
if err == nil && packetSize > 0 { if err == nil && packetSize > 0 {
data := make([]byte, 4+packetSize) data = make([]byte, 4+packetSize)
var n int var n int
for err == nil { for err == nil {
n, err = rdr.Read(data[4:]) n, err = rdr.Read(data[4:])
if n > 0 { if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
return ioErr return ioErr
} }
} }
@ -168,15 +169,16 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if data == nil { if data == nil {
data = make([]byte, 4) data = make([]byte, 4)
} }
if ioErr := mc.writePacket(data[:4]); ioErr != nil { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
return ioErr return ioErr
} }
mc.conn().syncSequence()
// read OK packet // read OK packet
if err == nil { if err == nil {
return mc.readResultOK() return mc.readResultOK()
} }
mc.readPacket() mc.conn().readPacket()
return err return err
} }

View File

@ -38,7 +38,7 @@ type NullTime sql.NullTime
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string), // The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails. // otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) { func (nt *NullTime) Scan(value any) (err error) {
if value == nil { if value == nil {
nt.Time, nt.Valid = time.Time{}, false nt.Time, nt.Valid = time.Time{}, false
return return
@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
} }
nt.Valid = false nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value) return fmt.Errorf("can't convert %T to time.Time", value)
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.

View File

@ -14,75 +14,108 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
"os"
"strconv"
"time" "time"
) )
// Packets documentation: // MySQL client/server protocol documentations.
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html // https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
// https://mariadb.com/kb/en/clientserver-protocol/
// read n bytes from mc.buf
func (mc *mysqlConn) readNext(n int) ([]byte, error) {
if mc.buf.len() < n {
err := mc.buf.fill(n, mc.readWithTimeout)
if err != nil {
return nil, err
}
}
return mc.buf.readNext(n), nil
}
// Read packet to buffer 'data' // Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) { func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte var prevData []byte
invalidSequence := false
readNext := mc.readNext
if mc.compress {
readNext = mc.compIO.readNext
}
for { for {
// read packet header // read packet header
data, err := mc.buf.readNext(4) data, err := readNext(4)
if err != nil { if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
errLog.Print(err) mc.log(err)
mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
// packet length [24 bit] // packet length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) pktLen := getUint24(data[:3])
seq := data[3]
// check packet sync [8 bit] // check packet sync [8 bit]
if data[3] != mc.sequence { if seq != mc.sequence {
if data[3] > mc.sequence { mc.log(fmt.Sprintf("[warn] unexpected sequence nr: expected %v, got %v", mc.sequence, seq))
return nil, ErrPktSyncMul // MySQL and MariaDB doesn't check packet nr in compressed packet.
if !mc.compress {
// For large packets, we stop reading as soon as sync error.
if len(prevData) > 0 {
mc.close()
return nil, ErrPktSyncMul
}
invalidSequence = true
} }
return nil, ErrPktSync
} }
mc.sequence++ mc.sequence = seq + 1
// packets with length 0 terminate a previous packet which is a // packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long // multiple of (2^24)-1 bytes long
if pktLen == 0 { if pktLen == 0 {
// there was no previous packet // there was no previous packet
if prevData == nil { if prevData == nil {
errLog.Print(ErrMalformPkt) mc.log(ErrMalformPkt)
mc.Close() mc.close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
return prevData, nil return prevData, nil
} }
// read packet body [pktLen bytes] // read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen) data, err = readNext(pktLen)
if err != nil { if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr return nil, cerr
} }
errLog.Print(err) mc.log(err)
mc.Close()
return nil, ErrInvalidConn return nil, ErrInvalidConn
} }
// return data if this was the last packet // return data if this was the last packet
if pktLen < maxPacketSize { if pktLen < maxPacketSize {
// zero allocations for non-split packets // zero allocations for non-split packets
if prevData == nil { if prevData != nil {
return data, nil data = append(prevData, data...)
} }
if invalidSequence {
return append(prevData, data...), nil mc.close()
// return sync error only for regular packet.
// error packets may have wrong sequence number.
if data[0] != iERR {
return nil, ErrPktSync
}
}
return data, nil
} }
prevData = append(prevData, data...) prevData = append(prevData, data...)
@ -92,88 +125,52 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// Write packet buffer 'data' // Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) writePacket(data []byte) error {
pktLen := len(data) - 4 pktLen := len(data) - 4
if pktLen > mc.maxAllowedPacket { if pktLen > mc.maxAllowedPacket {
return ErrPktTooLarge return ErrPktTooLarge
} }
// Perform a stale connection check. We only perform this check for writeFunc := mc.writeWithTimeout
// the first query on a connection that has been checked out of the if mc.compress {
// connection pool: a fresh connection from the pool is more likely writeFunc = mc.compIO.writePackets
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.reset {
mc.reset = false
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
if mc.cfg.CheckConnLiveness {
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
}
}
if err != nil {
errLog.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
} }
for { for {
var size int size := min(maxPacketSize, pktLen)
if pktLen >= maxPacketSize { putUint24(data[:3], size)
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = maxPacketSize
} else {
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
size = pktLen
}
data[3] = mc.sequence data[3] = mc.sequence
// Write packet // Write packet
if mc.writeTimeout > 0 { if debug {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { fmt.Fprintf(os.Stderr, "writePacket: size=%v seq=%v\n", size, mc.sequence)
return err
}
} }
n, err := mc.netConn.Write(data[:4+size]) n, err := writeFunc(data[:4+size])
if err == nil && n == 4+size { if err != nil {
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
continue
}
// Handle error
if err == nil { // n != len(data)
mc.cleanup() mc.cleanup()
errLog.Print(ErrMalformPkt)
} else {
if cerr := mc.canceled.Value(); cerr != nil { if cerr := mc.canceled.Value(); cerr != nil {
return cerr return cerr
} }
if n == 0 && pktLen == len(data)-4 { if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet // only for the first loop iteration when nothing was written yet
mc.log(err)
return errBadConnNoWrite return errBadConnNoWrite
} else {
return err
} }
mc.cleanup()
errLog.Print(err)
} }
return ErrInvalidConn if n != 4+size {
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
// The io.ErrShortWrite error is used to indicate that this rule has not been followed.
mc.cleanup()
return io.ErrShortWrite
}
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
} }
} }
@ -186,11 +183,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
data, err = mc.readPacket() data, err = mc.readPacket()
if err != nil { if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, "", driver.ErrBadConn
}
return return
} }
@ -234,12 +226,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos { if len(data) > pos {
// character set [1 byte] // character set [1 byte]
// status flags [2 bytes] // status flags [2 bytes]
pos += 3
// capability flags (upper 2 bytes) [2 bytes] // capability flags (upper 2 bytes) [2 bytes]
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2
// length of auth-plugin-data [1 byte] // length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes] // reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10 pos += 11
// second part of the password cipher [mininum 13 bytes], // second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8) // where len=MAX(13, length of auth-plugin-data - 8)
// //
// The web documentation is ambiguous about the length. However, // The web documentation is ambiguous about the length. However,
@ -285,12 +280,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles | clientLocalFiles |
clientPluginAuth | clientPluginAuth |
clientMultiResults | clientMultiResults |
mc.flags&clientConnectAttrs |
mc.flags&clientLongFlag mc.flags&clientLongFlag
sendConnectAttrs := mc.flags&clientConnectAttrs != 0
if mc.cfg.ClientFoundRows { if mc.cfg.ClientFoundRows {
clientFlags |= clientFoundRows clientFlags |= clientFoundRows
} }
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
clientFlags |= clientCompress
}
// To enable TLS / SSL // To enable TLS / SSL
if mc.cfg.TLS != nil { if mc.cfg.TLS != nil {
clientFlags |= clientSSL clientFlags |= clientSSL
@ -318,34 +318,38 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1 pktLen += n + 1
} }
// encode length of the connection attributes
var connAttrsLEI []byte
if sendConnectAttrs {
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
}
// Calculate packet length and get buffer with that size // Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection mc.cleanup()
errLog.Print(err) return err
return errBadConnNoWrite
} }
// ClientFlags [32 bit] // ClientFlags [32 bit]
data[4] = byte(clientFlags) binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
data[5] = byte(clientFlags >> 8)
data[6] = byte(clientFlags >> 16)
data[7] = byte(clientFlags >> 24)
// MaxPacketSize [32 bit] (none) // MaxPacketSize [32 bit] (none)
data[8] = 0x00 binary.LittleEndian.PutUint32(data[8:], 0)
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00
// Charset [1 byte] // Collation ID [1 byte]
var found bool data[12] = defaultCollationID
data[12], found = collations[mc.cfg.Collation] if cname := mc.cfg.Collation; cname != "" {
if !found { colID, ok := collations[cname]
// Note possibility for false negatives: if ok {
// could be triggered although the collation is valid if the data[12] = colID
// collations map does not contain entries the server supports. } else if len(mc.cfg.charsets) > 0 {
return errors.New("unknown collation") // When cfg.charset is set, the collation is set by `SET NAMES <charset> COLLATE <collation>`.
return fmt.Errorf("unknown collation: %q", cname)
}
} }
// Filler [23 bytes] (all 0x00) // Filler [23 bytes] (all 0x00)
@ -365,11 +369,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// Switch to TLS // Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.TLS) tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
return err return err
} }
mc.rawConn = mc.netConn
mc.netConn = tlsConn mc.netConn = tlsConn
mc.buf.nc = tlsConn
} }
// User [null terminated string] // User [null terminated string]
@ -394,6 +399,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data[pos] = 0x00 data[pos] = 0x00
pos++ pos++
// Connection Attributes
if sendConnectAttrs {
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
}
// Send Auth packet // Send Auth packet
return mc.writePacket(data[:pos]) return mc.writePacket(data[:pos])
} }
@ -401,11 +412,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData) pktLen := 4 + len(authData)
data, err := mc.buf.takeSmallBuffer(pktLen) data, err := mc.buf.takeBuffer(pktLen)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection mc.cleanup()
errLog.Print(err) return err
return errBadConnNoWrite
} }
// Add the auth data [EOF] // Add the auth data [EOF]
@ -419,32 +429,30 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence // Reset Packet Sequence
mc.sequence = 0 mc.resetSequence()
data, err := mc.buf.takeSmallBuffer(4 + 1) data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
errLog.Print(err)
return errBadConnNoWrite
} }
// Add command byte // Add command byte
data[4] = command data[4] = command
// Send CMD packet // Send CMD packet
return mc.writePacket(data) err = mc.writePacket(data)
mc.syncSequence()
return err
} }
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
// Reset Packet Sequence // Reset Packet Sequence
mc.sequence = 0 mc.resetSequence()
pktLen := 1 + len(arg) pktLen := 1 + len(arg)
data, err := mc.buf.takeBuffer(pktLen + 4) data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
errLog.Print(err)
return errBadConnNoWrite
} }
// Add command byte // Add command byte
@ -454,31 +462,30 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
copy(data[5:], arg) copy(data[5:], arg)
// Send CMD packet // Send CMD packet
return mc.writePacket(data) err = mc.writePacket(data)
mc.syncSequence()
return err
} }
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence // Reset Packet Sequence
mc.sequence = 0 mc.resetSequence()
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
errLog.Print(err)
return errBadConnNoWrite
} }
// Add command byte // Add command byte
data[4] = command data[4] = command
// Add arg [32 bit] // Add arg [32 bit]
data[5] = byte(arg) binary.LittleEndian.PutUint32(data[5:], arg)
data[6] = byte(arg >> 8)
data[7] = byte(arg >> 16)
data[8] = byte(arg >> 24)
// Send CMD packet // Send CMD packet
return mc.writePacket(data) err = mc.writePacket(data)
mc.syncSequence()
return err
} }
/****************************************************************************** /******************************************************************************
@ -495,7 +502,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
switch data[0] { switch data[0] {
case iOK: case iOK:
return nil, "", mc.handleOkPacket(data) // resultUnchanged, since auth happens before any queries or
// commands have been executed.
return nil, "", mc.resultUnchanged().handleOkPacket(data)
case iAuthMoreData: case iAuthMoreData:
return data[1:], "", err return data[1:], "", err
@ -511,6 +520,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
} }
plugin := string(data[1:pluginEndIndex]) plugin := string(data[1:pluginEndIndex])
authData := data[pluginEndIndex+1:] authData := data[pluginEndIndex+1:]
if len(authData) > 0 && authData[len(authData)-1] == 0 {
authData = authData[:len(authData)-1]
}
return authData, plugin, nil return authData, plugin, nil
default: // Error otherwise default: // Error otherwise
@ -518,9 +530,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
} }
} }
// Returns error if Packet is not an 'Result OK'-Packet // Returns error if Packet is not a 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error { func (mc *okHandler) readResultOK() error {
data, err := mc.readPacket() data, err := mc.conn().readPacket()
if err != nil { if err != nil {
return err return err
} }
@ -528,35 +540,37 @@ func (mc *mysqlConn) readResultOK() error {
if data[0] == iOK { if data[0] == iOK {
return mc.handleOkPacket(data) return mc.handleOkPacket(data)
} }
return mc.handleErrorPacket(data) return mc.conn().handleErrorPacket(data)
} }
// Result Set Header Packet // Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket() // handleOkPacket replaces both values; other cases leave the values unchanged.
if err == nil { mc.result.affectedRows = append(mc.result.affectedRows, 0)
switch data[0] { mc.result.insertIds = append(mc.result.insertIds, 0)
case iOK: data, err := mc.conn().readPacket()
return 0, mc.handleOkPacket(data) if err != nil {
return 0, err
case iERR:
return 0, mc.handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
num, _, n := readLengthEncodedInteger(data)
if n-len(data) == 0 {
return int(num), nil
}
return 0, ErrMalformPkt
} }
return 0, err
switch data[0] {
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.conn().handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
num, _, _ := readLengthEncodedInteger(data)
// ignore remaining data in the packet. see #1478.
return int(num), nil
} }
// Error Packet // Error Packet
@ -573,7 +587,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { // 1836: ER_READ_ONLY_MODE
if (errno == 1792 || errno == 1290 || errno == 1836) && mc.cfg.RejectReadOnly {
// Oops; we are connected to a read-only connection, and won't be able // Oops; we are connected to a read-only connection, and won't be able
// to issue any write statements. Since RejectReadOnly is configured, // to issue any write statements. Since RejectReadOnly is configured,
// we throw away this connection hoping this one would have write // we throw away this connection hoping this one would have write
@ -607,18 +622,61 @@ func readStatus(b []byte) statusFlag {
return statusFlag(b[0]) | statusFlag(b[1])<<8 return statusFlag(b[0]) | statusFlag(b[1])<<8
} }
// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't
// need to be cleared first (e.g. during authentication, or while additional
// resultsets are being fetched.)
func (mc *mysqlConn) resultUnchanged() *okHandler {
return (*okHandler)(mc)
}
// okHandler represents the state of the connection when mysqlConn.result has
// been prepared for processing of OK packets.
//
// To correctly populate mysqlConn.result (updated by handleOkPacket()), all
// callpaths must either:
//
// 1. first clear it using clearResult(), or
// 2. confirm that they don't need to (by calling resultUnchanged()).
//
// Both return an instance of type *okHandler.
type okHandler mysqlConn
// Exposes the underlying type's methods.
func (mc *okHandler) conn() *mysqlConn {
return (*mysqlConn)(mc)
}
// clearResult clears the connection's stored affectedRows and insertIds
// fields.
//
// It returns a handler that can process OK responses.
func (mc *mysqlConn) clearResult() *okHandler {
mc.result = mysqlResult{}
return (*okHandler)(mc)
}
// Ok Packet // Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error { func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int var n, m int
var affectedRows, insertId uint64
// 0x00 [1 byte] // 0x00 [1 byte]
// Affected rows [Length Coded Binary] // Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) affectedRows, _, n = readLengthEncodedInteger(data[1:])
// Insert id [Length Coded Binary] // Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) insertId, _, m = readLengthEncodedInteger(data[1+n:])
// Update for the current statement result (only used by
// readResultSetHeaderPacket).
if len(mc.result.affectedRows) > 0 {
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
}
if len(mc.result.insertIds) > 0 {
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
}
// server_status [2 bytes] // server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2]) mc.status = readStatus(data[1+n+m : 1+n+m+2])
@ -769,7 +827,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
for i := range dest { for i := range dest {
// Read bytes and convert to string // Read bytes and convert to string
dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) var buf []byte
buf, isNull, n, err = readLengthEncodedString(data[pos:])
pos += n pos += n
if err != nil { if err != nil {
@ -781,19 +840,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
continue continue
} }
if !mc.parseTime {
continue
}
// Parse time field
switch rows.rs.columns[i].fieldType { switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp, case fieldTypeTimestamp,
fieldTypeDateTime, fieldTypeDateTime,
fieldTypeDate, fieldTypeDate,
fieldTypeNewDate: fieldTypeNewDate:
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil { if mc.parseTime {
return err dest[i], err = parseDateTime(buf, mc.cfg.Loc)
} else {
dest[i] = buf
} }
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
case fieldTypeLongLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
} else {
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
}
case fieldTypeFloat:
var d float64
d, err = strconv.ParseFloat(string(buf), 32)
dest[i] = float32(d)
case fieldTypeDouble:
dest[i], err = strconv.ParseFloat(string(buf), 64)
default:
dest[i] = buf
}
if err != nil {
return err
} }
} }
@ -875,32 +955,26 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
pktLen = dataOffset + argLen pktLen = dataOffset + argLen
} }
stmt.mc.sequence = 0
// Add command byte [1 byte] // Add command byte [1 byte]
data[4] = comStmtSendLongData data[4] = comStmtSendLongData
// Add stmtID [32 bit] // Add stmtID [32 bit]
data[5] = byte(stmt.id) binary.LittleEndian.PutUint32(data[5:], stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// Add paramID [16 bit] // Add paramID [16 bit]
data[9] = byte(paramID) binary.LittleEndian.PutUint16(data[9:], uint16(paramID))
data[10] = byte(paramID >> 8)
// Send CMD packet // Send CMD packet
err := stmt.mc.writePacket(data[:4+pktLen]) err := stmt.mc.writePacket(data[:4+pktLen])
// Every COM_LONG_DATA packet reset Packet Sequence
stmt.mc.resetSequence()
if err == nil { if err == nil {
data = data[pktLen-dataOffset:] data = data[pktLen-dataOffset:]
continue continue
} }
return err return err
} }
// Reset Packet Sequence
stmt.mc.sequence = 0
return nil return nil
} }
@ -925,7 +999,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
} }
// Reset packet-sequence // Reset packet-sequence
mc.sequence = 0 mc.resetSequence()
var data []byte var data []byte
var err error var err error
@ -937,28 +1011,20 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In this case the len(data) == cap(data) which is used to optimise the flow below. // In this case the len(data) == cap(data) which is used to optimise the flow below.
} }
if err != nil { if err != nil {
// cannot take the buffer. Something must be wrong with the connection return err
errLog.Print(err)
return errBadConnNoWrite
} }
// command [1 byte] // command [1 byte]
data[4] = comStmtExecute data[4] = comStmtExecute
// statement_id [4 bytes] // statement_id [4 bytes]
data[5] = byte(stmt.id) binary.LittleEndian.PutUint32(data[5:], stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00 data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes] // iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01 binary.LittleEndian.PutUint32(data[10:], 1)
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
if len(args) > 0 { if len(args) > 0 {
pos := minPktLen pos := minPktLen
@ -1012,50 +1078,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
case int64: case int64:
paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x00 paramTypes[i+i+1] = 0x00
paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case uint64: case uint64:
paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x80 // type is unsigned paramTypes[i+i+1] = 0x80 // type is unsigned
paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v))
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case float64: case float64:
paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i] = byte(fieldTypeDouble)
paramTypes[i+i+1] = 0x00 paramTypes[i+i+1] = 0x00
paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v))
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
math.Float64bits(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(math.Float64bits(v))...,
)
}
case bool: case bool:
paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i] = byte(fieldTypeTiny)
@ -1116,7 +1149,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() { if v.IsZero() {
b = append(b, "0000-00-00"...) b = append(b, "0000-00-00"...)
} else { } else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc)) b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil { if err != nil {
return err return err
} }
@ -1136,20 +1169,21 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer // In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) { if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...) data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil { mc.buf.store(data) // allow this buffer to be reused
errLog.Print(err)
return errBadConnNoWrite
}
} }
pos += len(paramValues) pos += len(paramValues)
data = data[:pos] data = data[:pos]
} }
return mc.writePacket(data) err = mc.writePacket(data)
mc.syncSequence()
return err
} }
func (mc *mysqlConn) discardResults() error { // For each remaining resultset in the stream, discards its rows and updates
// mc.affectedRows and mc.insertIds.
func (mc *okHandler) discardResults() error {
for mc.status&statusMoreResultsExists != 0 { for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket() resLen, err := mc.readResultSetHeaderPacket()
if err != nil { if err != nil {
@ -1157,11 +1191,11 @@ func (mc *mysqlConn) discardResults() error {
} }
if resLen > 0 { if resLen > 0 {
// columns // columns
if err := mc.readUntilEOF(); err != nil { if err := mc.conn().readUntilEOF(); err != nil {
return err return err
} }
// rows // rows
if err := mc.readUntilEOF(); err != nil { if err := mc.conn().readUntilEOF(); err != nil {
return err return err
} }
} }
@ -1268,7 +1302,8 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
fieldTypeVector:
var isNull bool var isNull bool
var n int var n int
dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) dest[i], isNull, n, err = readLengthEncodedString(data[pos:])

View File

@ -8,15 +8,43 @@
package mysql package mysql
import "database/sql/driver"
// Result exposes data not available through *connection.Result.
//
// This is accessible by executing statements using sql.Conn.Raw() and
// downcasting the returned result:
//
// res, err := rawConn.Exec(...)
// res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result
// AllRowsAffected returns a slice containing the affected rows for each
// executed statement.
AllRowsAffected() []int64
// AllLastInsertIds returns a slice containing the last inserted ID for each
// executed statement.
AllLastInsertIds() []int64
}
type mysqlResult struct { type mysqlResult struct {
affectedRows int64 // One entry in both slices is created for every executed statement result.
insertId int64 affectedRows []int64
insertIds []int64
} }
func (res *mysqlResult) LastInsertId() (int64, error) { func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil return res.insertIds[len(res.insertIds)-1], nil
} }
func (res *mysqlResult) RowsAffected() (int64, error) { func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil return res.affectedRows[len(res.affectedRows)-1], nil
}
func (res *mysqlResult) AllLastInsertIds() []int64 {
return append([]int64{}, res.insertIds...) // defensive copy
}
func (res *mysqlResult) AllRowsAffected() []int64 {
return append([]int64{}, res.affectedRows...) // defensive copy
} }

View File

@ -111,19 +111,13 @@ func (rows *mysqlRows) Close() (err error) {
return err return err
} }
// flip the buffer for this connection if we need to drain it.
// note that for a successful query (i.e. one where rows.next()
// has been called until it returns false), `rows.mc` will be nil
// by the time the user calls `(*Rows).Close`, so we won't reach this
// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
mc.buf.flip()
// Remove unread packets from stream // Remove unread packets from stream
if !rows.rs.done { if !rows.rs.done {
err = mc.readUntilEOF() err = mc.readUntilEOF()
} }
if err == nil { if err == nil {
if err = mc.discardResults(); err != nil { handleOk := mc.clearResult()
if err = handleOk.discardResults(); err != nil {
return err return err
} }
} }
@ -160,7 +154,15 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
return 0, io.EOF return 0, io.EOF
} }
rows.rs = resultSet{} rows.rs = resultSet{}
return rows.mc.readResultSetHeaderPacket() // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
// nextResultSet.
resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
if err != nil {
// Clean up about multi-results flag
rows.rs.done = true
rows.mc.status = rows.mc.status & (^statusMoreResultsExists)
}
return resLen, err
} }
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {

View File

@ -24,11 +24,12 @@ type mysqlStmt struct {
func (stmt *mysqlStmt) Close() error { func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.closed.Load() { if stmt.mc == nil || stmt.mc.closed.Load() {
// driver.Stmt.Close can be called more than once, thus this function // driver.Stmt.Close could be called more than once, thus this function
// has to be idempotent. // had to be idempotent. See also Issue #450 and golang/go#16019.
// See also Issue #450 and golang/go#16019. // This bug has been fixed in Go 1.8.
//errLog.Print(ErrInvalidConn) // https://github.com/golang/go/commit/90b8a0ca2d0b565c7c7199ffcf77b15ea6b6db3a
return driver.ErrBadConn // But we keep this function idempotent because it is safer.
return nil
} }
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
@ -51,7 +52,6 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.Load() { if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command
@ -61,12 +61,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
} }
mc := stmt.mc mc := stmt.mc
handleOk := stmt.mc.clearResult()
mc.affectedRows = 0
mc.insertId = 0
// Read Result // Read Result
resLen, err := mc.readResultSetHeaderPacket() resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,14 +81,12 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
} }
} }
if err := mc.discardResults(); err != nil { if err := handleOk.discardResults(); err != nil {
return nil, err return nil, err
} }
return &mysqlResult{ copied := mc.result
affectedRows: int64(mc.affectedRows), return &copied, nil
insertId: int64(mc.insertId),
}, nil
} }
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
@ -99,7 +95,6 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.Load() { if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
// Send command // Send command
@ -111,7 +106,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
mc := stmt.mc mc := stmt.mc
// Read Result // Read Result
resLen, err := mc.readResultSetHeaderPacket() handleOk := stmt.mc.clearResult()
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -144,7 +140,7 @@ type converter struct{}
// implementation does not. This function should be kept in sync with // implementation does not. This function should be kept in sync with
// database/sql/driver defaultConverter.ConvertValue() except for that // database/sql/driver defaultConverter.ConvertValue() except for that
// deliberate difference. // deliberate difference.
func (c converter) ConvertValue(v interface{}) (driver.Value, error) { func (c converter) ConvertValue(v any) (driver.Value, error) {
if driver.IsValue(v) { if driver.IsValue(v) {
return v, nil return v, nil
} }

View File

@ -13,18 +13,32 @@ type mysqlTx struct {
} }
func (tx *mysqlTx) Commit() (err error) { func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.closed.Load() { if tx.mc == nil {
return ErrInvalidConn return ErrInvalidConn
} }
if tx.mc.closed.Load() {
err = tx.mc.error()
if err == nil {
err = ErrInvalidConn
}
return
}
err = tx.mc.exec("COMMIT") err = tx.mc.exec("COMMIT")
tx.mc = nil tx.mc = nil
return return
} }
func (tx *mysqlTx) Rollback() (err error) { func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.closed.Load() { if tx.mc == nil {
return ErrInvalidConn return ErrInvalidConn
} }
if tx.mc.closed.Load() {
err = tx.mc.error()
if err == nil {
err = ErrInvalidConn
}
return
}
err = tx.mc.exec("ROLLBACK") err = tx.mc.exec("ROLLBACK")
tx.mc = nil tx.mc = nil
return return

View File

@ -36,7 +36,7 @@ var (
// registering it. // registering it.
// //
// rootCertPool := x509.NewCertPool() // rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem") // pem, err := os.ReadFile("/path/ca-cert.pem")
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
@ -265,7 +265,11 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
return nil, fmt.Errorf("invalid DATETIME packet length %d", num) return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
} }
func appendDateTime(buf []byte, t time.Time) ([]byte, error) { func appendDateTime(buf []byte, t time.Time, timeTruncate time.Duration) ([]byte, error) {
if timeTruncate > 0 {
t = t.Truncate(timeTruncate)
}
year, month, day := t.Date() year, month, day := t.Date()
hour, min, sec := t.Clock() hour, min, sec := t.Clock()
nsec := t.Nanosecond() nsec := t.Nanosecond()
@ -486,17 +490,16 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
* Convert from and to bytes * * Convert from and to bytes *
******************************************************************************/ ******************************************************************************/
func uint64ToBytes(n uint64) []byte { // 24bit integer: used for packet headers.
return []byte{
byte(n), func putUint24(data []byte, n int) {
byte(n >> 8), data[2] = byte(n >> 16)
byte(n >> 16), data[1] = byte(n >> 8)
byte(n >> 24), data[0] = byte(n)
byte(n >> 32), }
byte(n >> 40),
byte(n >> 48), func getUint24(data []byte) int {
byte(n >> 56), return int(data[2])<<16 | int(data[1])<<8 | int(data[0])
}
} }
func uint64ToString(n uint64) []byte { func uint64ToString(n uint64) []byte {
@ -521,16 +524,6 @@ func uint64ToString(n uint64) []byte {
return a[i:] return a[i:]
} }
// treats string value as unsigned integer representation
func stringToInt(b []byte) int {
val := 0
for i := range b {
val *= 10
val += int(b[i] - 0x30)
}
return val
}
// returns the string read as a bytes slice, whether the value is NULL, // returns the string read as a bytes slice, whether the value is NULL,
// the number of bytes read and an error, in case the string is longer than // the number of bytes read and an error, in case the string is longer than
// the input slice // the input slice
@ -582,18 +575,15 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// 252: value of following 2 // 252: value of following 2
case 0xfc: case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3 return uint64(binary.LittleEndian.Uint16(b[1:])), false, 3
// 253: value of following 3 // 253: value of following 3
case 0xfd: case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 return uint64(getUint24(b[1:])), false, 4
// 254: value of following 8 // 254: value of following 8
case 0xfe: case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | return uint64(binary.LittleEndian.Uint64(b[1:])), false, 9
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
} }
// 0-250: value of first byte // 0-250: value of first byte
@ -607,13 +597,19 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
return append(b, byte(n)) return append(b, byte(n))
case n <= 0xffff: case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8)) b = append(b, 0xfc)
return binary.LittleEndian.AppendUint16(b, uint16(n))
case n <= 0xffffff: case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
} }
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), b = append(b, 0xfe)
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) return binary.LittleEndian.AppendUint64(b, n)
}
func appendLengthEncodedString(b []byte, s string) []byte {
b = appendLengthEncodedInteger(b, uint64(len(s)))
return append(b, s...)
} }
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.

View File

@ -155,7 +155,7 @@ stored in base64 encoded form, which was redundant with the information in the
type Token struct { type Token struct {
Raw string // Raw contains the raw token Raw string // Raw contains the raw token
Method SigningMethod // Method is the signing method used or to be used Method SigningMethod // Method is the signing method used or to be used
Header map[string]interface{} // Header is the first segment of the token in decoded form Header map[string]any // Header is the first segment of the token in decoded form
Claims Claims // Claims is the second segment of the token in decoded form Claims Claims // Claims is the second segment of the token in decoded form
Signature []byte // Signature is the third segment of the token in decoded form Signature []byte // Signature is the third segment of the token in decoded form
Valid bool // Valid specifies if the token is valid Valid bool // Valid specifies if the token is valid

View File

@ -55,7 +55,7 @@ func (m *SigningMethodECDSA) Alg() string {
// Verify implements token verification for the SigningMethod. // Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ecdsa.PublicKey struct // For this verify method, key must be an ecdsa.PublicKey struct
func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interface{}) error { func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key any) error {
// Get the key // Get the key
var ecdsaKey *ecdsa.PublicKey var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) { switch k := key.(type) {
@ -89,7 +89,7 @@ func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interf
// Sign implements token signing for the SigningMethod. // Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ecdsa.PrivateKey struct // For this signing method, key must be an ecdsa.PrivateKey struct
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte, error) { func (m *SigningMethodECDSA) Sign(signingString string, key any) ([]byte, error) {
// Get the key // Get the key
var ecdsaKey *ecdsa.PrivateKey var ecdsaKey *ecdsa.PrivateKey
switch k := key.(type) { switch k := key.(type) {

View File

@ -23,7 +23,7 @@ func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
} }
// Parse the key // Parse the key
var parsedKey interface{} var parsedKey any
if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil { if parsedKey, err = x509.ParseECPrivateKey(block.Bytes); err != nil {
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
return nil, err return nil, err
@ -50,7 +50,7 @@ func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
} }
// Parse the key // Parse the key
var parsedKey interface{} var parsedKey any
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil { if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
parsedKey = cert.PublicKey parsedKey = cert.PublicKey

View File

@ -33,7 +33,7 @@ func (m *SigningMethodEd25519) Alg() string {
// Verify implements token verification for the SigningMethod. // Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ed25519.PublicKey // For this verify method, key must be an ed25519.PublicKey
func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key interface{}) error { func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key any) error {
var ed25519Key ed25519.PublicKey var ed25519Key ed25519.PublicKey
var ok bool var ok bool
@ -55,7 +55,7 @@ func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key inte
// Sign implements token signing for the SigningMethod. // Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ed25519.PrivateKey // For this signing method, key must be an ed25519.PrivateKey
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]byte, error) { func (m *SigningMethodEd25519) Sign(signingString string, key any) ([]byte, error) {
var ed25519Key crypto.Signer var ed25519Key crypto.Signer
var ok bool var ok bool

View File

@ -24,7 +24,7 @@ func ParseEdPrivateKeyFromPEM(key []byte) (crypto.PrivateKey, error) {
} }
// Parse the key // Parse the key
var parsedKey interface{} var parsedKey any
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
return nil, err return nil, err
} }
@ -49,7 +49,7 @@ func ParseEdPublicKeyFromPEM(key []byte) (crypto.PublicKey, error) {
} }
// Parse the key // Parse the key
var parsedKey interface{} var parsedKey any
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
return nil, err return nil, err
} }

Some files were not shown because too many files have changed in this diff Show More