mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-22 20:09:41 +01:00
869 lines
21 KiB
Go
869 lines
21 KiB
Go
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
|
//
|
|
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Registry for custom tls.Configs
|
|
var (
|
|
tlsConfigLock sync.RWMutex
|
|
tlsConfigRegistry map[string]*tls.Config
|
|
)
|
|
|
|
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
|
|
// Use the key as a value in the DSN where tls=value.
|
|
//
|
|
// Note: The provided tls.Config is exclusively owned by the driver after
|
|
// registering it.
|
|
//
|
|
// rootCertPool := x509.NewCertPool()
|
|
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
|
|
// if err != nil {
|
|
// log.Fatal(err)
|
|
// }
|
|
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
|
// log.Fatal("Failed to append PEM.")
|
|
// }
|
|
// clientCert := make([]tls.Certificate, 0, 1)
|
|
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
|
|
// if err != nil {
|
|
// log.Fatal(err)
|
|
// }
|
|
// clientCert = append(clientCert, certs)
|
|
// mysql.RegisterTLSConfig("custom", &tls.Config{
|
|
// RootCAs: rootCertPool,
|
|
// Certificates: clientCert,
|
|
// })
|
|
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
|
|
//
|
|
func RegisterTLSConfig(key string, config *tls.Config) error {
|
|
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
|
|
return fmt.Errorf("key '%s' is reserved", key)
|
|
}
|
|
|
|
tlsConfigLock.Lock()
|
|
if tlsConfigRegistry == nil {
|
|
tlsConfigRegistry = make(map[string]*tls.Config)
|
|
}
|
|
|
|
tlsConfigRegistry[key] = config
|
|
tlsConfigLock.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// DeregisterTLSConfig removes the tls.Config associated with key.
|
|
func DeregisterTLSConfig(key string) {
|
|
tlsConfigLock.Lock()
|
|
if tlsConfigRegistry != nil {
|
|
delete(tlsConfigRegistry, key)
|
|
}
|
|
tlsConfigLock.Unlock()
|
|
}
|
|
|
|
func getTLSConfigClone(key string) (config *tls.Config) {
|
|
tlsConfigLock.RLock()
|
|
if v, ok := tlsConfigRegistry[key]; ok {
|
|
config = v.Clone()
|
|
}
|
|
tlsConfigLock.RUnlock()
|
|
return
|
|
}
|
|
|
|
// Returns the bool value of the input.
|
|
// The 2nd return value indicates if the input was a valid bool value
|
|
func readBool(input string) (value bool, valid bool) {
|
|
switch input {
|
|
case "1", "true", "TRUE", "True":
|
|
return true, true
|
|
case "0", "false", "FALSE", "False":
|
|
return false, true
|
|
}
|
|
|
|
// Not a valid bool value
|
|
return
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Time related utils *
|
|
******************************************************************************/
|
|
|
|
func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
|
|
const base = "0000-00-00 00:00:00.000000"
|
|
switch len(b) {
|
|
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
|
|
if string(b) == base[:len(b)] {
|
|
return time.Time{}, nil
|
|
}
|
|
|
|
year, err := parseByteYear(b)
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
if year <= 0 {
|
|
year = 1
|
|
}
|
|
|
|
if b[4] != '-' {
|
|
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
|
|
}
|
|
|
|
m, err := parseByte2Digits(b[5], b[6])
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
if m <= 0 {
|
|
m = 1
|
|
}
|
|
month := time.Month(m)
|
|
|
|
if b[7] != '-' {
|
|
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
|
|
}
|
|
|
|
day, err := parseByte2Digits(b[8], b[9])
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
if day <= 0 {
|
|
day = 1
|
|
}
|
|
if len(b) == 10 {
|
|
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
|
|
}
|
|
|
|
if b[10] != ' ' {
|
|
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
|
|
}
|
|
|
|
hour, err := parseByte2Digits(b[11], b[12])
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
if b[13] != ':' {
|
|
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
|
|
}
|
|
|
|
min, err := parseByte2Digits(b[14], b[15])
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
if b[16] != ':' {
|
|
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
|
|
}
|
|
|
|
sec, err := parseByte2Digits(b[17], b[18])
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
if len(b) == 19 {
|
|
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
|
|
}
|
|
|
|
if b[19] != '.' {
|
|
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
|
|
}
|
|
nsec, err := parseByteNanoSec(b[20:])
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
|
|
default:
|
|
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
|
|
}
|
|
}
|
|
|
|
func parseByteYear(b []byte) (int, error) {
|
|
year, n := 0, 1000
|
|
for i := 0; i < 4; i++ {
|
|
v, err := bToi(b[i])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
year += v * n
|
|
n = n / 10
|
|
}
|
|
return year, nil
|
|
}
|
|
|
|
func parseByte2Digits(b1, b2 byte) (int, error) {
|
|
d1, err := bToi(b1)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
d2, err := bToi(b2)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return d1*10 + d2, nil
|
|
}
|
|
|
|
func parseByteNanoSec(b []byte) (int, error) {
|
|
ns, digit := 0, 100000 // max is 6-digits
|
|
for i := 0; i < len(b); i++ {
|
|
v, err := bToi(b[i])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
ns += v * digit
|
|
digit /= 10
|
|
}
|
|
// nanoseconds has 10-digits. (needs to scale digits)
|
|
// 10 - 6 = 4, so we have to multiple 1000.
|
|
return ns * 1000, nil
|
|
}
|
|
|
|
func bToi(b byte) (int, error) {
|
|
if b < '0' || b > '9' {
|
|
return 0, errors.New("not [0-9]")
|
|
}
|
|
return int(b - '0'), nil
|
|
}
|
|
|
|
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
|
|
switch num {
|
|
case 0:
|
|
return time.Time{}, nil
|
|
case 4:
|
|
return time.Date(
|
|
int(binary.LittleEndian.Uint16(data[:2])), // year
|
|
time.Month(data[2]), // month
|
|
int(data[3]), // day
|
|
0, 0, 0, 0,
|
|
loc,
|
|
), nil
|
|
case 7:
|
|
return time.Date(
|
|
int(binary.LittleEndian.Uint16(data[:2])), // year
|
|
time.Month(data[2]), // month
|
|
int(data[3]), // day
|
|
int(data[4]), // hour
|
|
int(data[5]), // minutes
|
|
int(data[6]), // seconds
|
|
0,
|
|
loc,
|
|
), nil
|
|
case 11:
|
|
return time.Date(
|
|
int(binary.LittleEndian.Uint16(data[:2])), // year
|
|
time.Month(data[2]), // month
|
|
int(data[3]), // day
|
|
int(data[4]), // hour
|
|
int(data[5]), // minutes
|
|
int(data[6]), // seconds
|
|
int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
|
|
loc,
|
|
), nil
|
|
}
|
|
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
|
|
}
|
|
|
|
func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
|
|
year, month, day := t.Date()
|
|
hour, min, sec := t.Clock()
|
|
nsec := t.Nanosecond()
|
|
|
|
if year < 1 || year > 9999 {
|
|
return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap
|
|
}
|
|
year100 := year / 100
|
|
year1 := year % 100
|
|
|
|
var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape
|
|
localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1]
|
|
localBuf[4] = '-'
|
|
localBuf[5], localBuf[6] = digits10[month], digits01[month]
|
|
localBuf[7] = '-'
|
|
localBuf[8], localBuf[9] = digits10[day], digits01[day]
|
|
|
|
if hour == 0 && min == 0 && sec == 0 && nsec == 0 {
|
|
return append(buf, localBuf[:10]...), nil
|
|
}
|
|
|
|
localBuf[10] = ' '
|
|
localBuf[11], localBuf[12] = digits10[hour], digits01[hour]
|
|
localBuf[13] = ':'
|
|
localBuf[14], localBuf[15] = digits10[min], digits01[min]
|
|
localBuf[16] = ':'
|
|
localBuf[17], localBuf[18] = digits10[sec], digits01[sec]
|
|
|
|
if nsec == 0 {
|
|
return append(buf, localBuf[:19]...), nil
|
|
}
|
|
nsec100000000 := nsec / 100000000
|
|
nsec1000000 := (nsec / 1000000) % 100
|
|
nsec10000 := (nsec / 10000) % 100
|
|
nsec100 := (nsec / 100) % 100
|
|
nsec1 := nsec % 100
|
|
localBuf[19] = '.'
|
|
|
|
// milli second
|
|
localBuf[20], localBuf[21], localBuf[22] =
|
|
digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
|
|
// micro second
|
|
localBuf[23], localBuf[24], localBuf[25] =
|
|
digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
|
|
// nano second
|
|
localBuf[26], localBuf[27], localBuf[28] =
|
|
digits01[nsec100], digits10[nsec1], digits01[nsec1]
|
|
|
|
// trim trailing zeros
|
|
n := len(localBuf)
|
|
for n > 0 && localBuf[n-1] == '0' {
|
|
n--
|
|
}
|
|
|
|
return append(buf, localBuf[:n]...), nil
|
|
}
|
|
|
|
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
|
|
// if the DATE or DATETIME has the zero value.
|
|
// It must never be changed.
|
|
// The current behavior depends on database/sql copying the result.
|
|
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
|
|
|
|
const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
|
|
const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
|
|
|
|
func appendMicrosecs(dst, src []byte, decimals int) []byte {
|
|
if decimals <= 0 {
|
|
return dst
|
|
}
|
|
if len(src) == 0 {
|
|
return append(dst, ".000000"[:decimals+1]...)
|
|
}
|
|
|
|
microsecs := binary.LittleEndian.Uint32(src[:4])
|
|
p1 := byte(microsecs / 10000)
|
|
microsecs -= 10000 * uint32(p1)
|
|
p2 := byte(microsecs / 100)
|
|
microsecs -= 100 * uint32(p2)
|
|
p3 := byte(microsecs)
|
|
|
|
switch decimals {
|
|
default:
|
|
return append(dst, '.',
|
|
digits10[p1], digits01[p1],
|
|
digits10[p2], digits01[p2],
|
|
digits10[p3], digits01[p3],
|
|
)
|
|
case 1:
|
|
return append(dst, '.',
|
|
digits10[p1],
|
|
)
|
|
case 2:
|
|
return append(dst, '.',
|
|
digits10[p1], digits01[p1],
|
|
)
|
|
case 3:
|
|
return append(dst, '.',
|
|
digits10[p1], digits01[p1],
|
|
digits10[p2],
|
|
)
|
|
case 4:
|
|
return append(dst, '.',
|
|
digits10[p1], digits01[p1],
|
|
digits10[p2], digits01[p2],
|
|
)
|
|
case 5:
|
|
return append(dst, '.',
|
|
digits10[p1], digits01[p1],
|
|
digits10[p2], digits01[p2],
|
|
digits10[p3],
|
|
)
|
|
}
|
|
}
|
|
|
|
func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
|
|
// length expects the deterministic length of the zero value,
|
|
// negative time and 100+ hours are automatically added if needed
|
|
if len(src) == 0 {
|
|
return zeroDateTime[:length], nil
|
|
}
|
|
var dst []byte // return value
|
|
var p1, p2, p3 byte // current digit pair
|
|
|
|
switch length {
|
|
case 10, 19, 21, 22, 23, 24, 25, 26:
|
|
default:
|
|
t := "DATE"
|
|
if length > 10 {
|
|
t += "TIME"
|
|
}
|
|
return nil, fmt.Errorf("illegal %s length %d", t, length)
|
|
}
|
|
switch len(src) {
|
|
case 4, 7, 11:
|
|
default:
|
|
t := "DATE"
|
|
if length > 10 {
|
|
t += "TIME"
|
|
}
|
|
return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
|
|
}
|
|
dst = make([]byte, 0, length)
|
|
// start with the date
|
|
year := binary.LittleEndian.Uint16(src[:2])
|
|
pt := year / 100
|
|
p1 = byte(year - 100*uint16(pt))
|
|
p2, p3 = src[2], src[3]
|
|
dst = append(dst,
|
|
digits10[pt], digits01[pt],
|
|
digits10[p1], digits01[p1], '-',
|
|
digits10[p2], digits01[p2], '-',
|
|
digits10[p3], digits01[p3],
|
|
)
|
|
if length == 10 {
|
|
return dst, nil
|
|
}
|
|
if len(src) == 4 {
|
|
return append(dst, zeroDateTime[10:length]...), nil
|
|
}
|
|
dst = append(dst, ' ')
|
|
p1 = src[4] // hour
|
|
src = src[5:]
|
|
|
|
// p1 is 2-digit hour, src is after hour
|
|
p2, p3 = src[0], src[1]
|
|
dst = append(dst,
|
|
digits10[p1], digits01[p1], ':',
|
|
digits10[p2], digits01[p2], ':',
|
|
digits10[p3], digits01[p3],
|
|
)
|
|
return appendMicrosecs(dst, src[2:], int(length)-20), nil
|
|
}
|
|
|
|
func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
|
|
// length expects the deterministic length of the zero value,
|
|
// negative time and 100+ hours are automatically added if needed
|
|
if len(src) == 0 {
|
|
return zeroDateTime[11 : 11+length], nil
|
|
}
|
|
var dst []byte // return value
|
|
|
|
switch length {
|
|
case
|
|
8, // time (can be up to 10 when negative and 100+ hours)
|
|
10, 11, 12, 13, 14, 15: // time with fractional seconds
|
|
default:
|
|
return nil, fmt.Errorf("illegal TIME length %d", length)
|
|
}
|
|
switch len(src) {
|
|
case 8, 12:
|
|
default:
|
|
return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
|
|
}
|
|
// +2 to enable negative time and 100+ hours
|
|
dst = make([]byte, 0, length+2)
|
|
if src[0] == 1 {
|
|
dst = append(dst, '-')
|
|
}
|
|
days := binary.LittleEndian.Uint32(src[1:5])
|
|
hours := int64(days)*24 + int64(src[5])
|
|
|
|
if hours >= 100 {
|
|
dst = strconv.AppendInt(dst, hours, 10)
|
|
} else {
|
|
dst = append(dst, digits10[hours], digits01[hours])
|
|
}
|
|
|
|
min, sec := src[6], src[7]
|
|
dst = append(dst, ':',
|
|
digits10[min], digits01[min], ':',
|
|
digits10[sec], digits01[sec],
|
|
)
|
|
return appendMicrosecs(dst, src[8:], int(length)-9), nil
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Convert from and to bytes *
|
|
******************************************************************************/
|
|
|
|
func uint64ToBytes(n uint64) []byte {
|
|
return []byte{
|
|
byte(n),
|
|
byte(n >> 8),
|
|
byte(n >> 16),
|
|
byte(n >> 24),
|
|
byte(n >> 32),
|
|
byte(n >> 40),
|
|
byte(n >> 48),
|
|
byte(n >> 56),
|
|
}
|
|
}
|
|
|
|
func uint64ToString(n uint64) []byte {
|
|
var a [20]byte
|
|
i := 20
|
|
|
|
// U+0030 = 0
|
|
// ...
|
|
// U+0039 = 9
|
|
|
|
var q uint64
|
|
for n >= 10 {
|
|
i--
|
|
q = n / 10
|
|
a[i] = uint8(n-q*10) + 0x30
|
|
n = q
|
|
}
|
|
|
|
i--
|
|
a[i] = uint8(n) + 0x30
|
|
|
|
return a[i:]
|
|
}
|
|
|
|
// treats string value as unsigned integer representation
|
|
func stringToInt(b []byte) int {
|
|
val := 0
|
|
for i := range b {
|
|
val *= 10
|
|
val += int(b[i] - 0x30)
|
|
}
|
|
return val
|
|
}
|
|
|
|
// returns the string read as a bytes slice, wheter the value is NULL,
|
|
// the number of bytes read and an error, in case the string is longer than
|
|
// the input slice
|
|
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
|
|
// Get length
|
|
num, isNull, n := readLengthEncodedInteger(b)
|
|
if num < 1 {
|
|
return b[n:n], isNull, n, nil
|
|
}
|
|
|
|
n += int(num)
|
|
|
|
// Check data length
|
|
if len(b) >= n {
|
|
return b[n-int(num) : n : n], false, n, nil
|
|
}
|
|
return nil, false, n, io.EOF
|
|
}
|
|
|
|
// returns the number of bytes skipped and an error, in case the string is
|
|
// longer than the input slice
|
|
func skipLengthEncodedString(b []byte) (int, error) {
|
|
// Get length
|
|
num, _, n := readLengthEncodedInteger(b)
|
|
if num < 1 {
|
|
return n, nil
|
|
}
|
|
|
|
n += int(num)
|
|
|
|
// Check data length
|
|
if len(b) >= n {
|
|
return n, nil
|
|
}
|
|
return n, io.EOF
|
|
}
|
|
|
|
// returns the number read, whether the value is NULL and the number of bytes read
|
|
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
|
|
// See issue #349
|
|
if len(b) == 0 {
|
|
return 0, true, 1
|
|
}
|
|
|
|
switch b[0] {
|
|
// 251: NULL
|
|
case 0xfb:
|
|
return 0, true, 1
|
|
|
|
// 252: value of following 2
|
|
case 0xfc:
|
|
return uint64(b[1]) | uint64(b[2])<<8, false, 3
|
|
|
|
// 253: value of following 3
|
|
case 0xfd:
|
|
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
|
|
|
|
// 254: value of following 8
|
|
case 0xfe:
|
|
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
|
|
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
|
|
uint64(b[7])<<48 | uint64(b[8])<<56,
|
|
false, 9
|
|
}
|
|
|
|
// 0-250: value of first byte
|
|
return uint64(b[0]), false, 1
|
|
}
|
|
|
|
// encodes a uint64 value and appends it to the given bytes slice
|
|
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
|
|
switch {
|
|
case n <= 250:
|
|
return append(b, byte(n))
|
|
|
|
case n <= 0xffff:
|
|
return append(b, 0xfc, byte(n), byte(n>>8))
|
|
|
|
case n <= 0xffffff:
|
|
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
|
|
}
|
|
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
|
|
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
|
|
}
|
|
|
|
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
|
|
// If cap(buf) is not enough, reallocate new buffer.
|
|
func reserveBuffer(buf []byte, appendSize int) []byte {
|
|
newSize := len(buf) + appendSize
|
|
if cap(buf) < newSize {
|
|
// Grow buffer exponentially
|
|
newBuf := make([]byte, len(buf)*2+appendSize)
|
|
copy(newBuf, buf)
|
|
buf = newBuf
|
|
}
|
|
return buf[:newSize]
|
|
}
|
|
|
|
// escapeBytesBackslash escapes []byte with backslashes (\)
|
|
// This escapes the contents of a string (provided as []byte) by adding backslashes before special
|
|
// characters, and turning others into specific escape sequences, such as
|
|
// turning newlines into \n and null bytes into \0.
|
|
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
|
|
func escapeBytesBackslash(buf, v []byte) []byte {
|
|
pos := len(buf)
|
|
buf = reserveBuffer(buf, len(v)*2)
|
|
|
|
for _, c := range v {
|
|
switch c {
|
|
case '\x00':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '0'
|
|
pos += 2
|
|
case '\n':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = 'n'
|
|
pos += 2
|
|
case '\r':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = 'r'
|
|
pos += 2
|
|
case '\x1a':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = 'Z'
|
|
pos += 2
|
|
case '\'':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '\''
|
|
pos += 2
|
|
case '"':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '"'
|
|
pos += 2
|
|
case '\\':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '\\'
|
|
pos += 2
|
|
default:
|
|
buf[pos] = c
|
|
pos++
|
|
}
|
|
}
|
|
|
|
return buf[:pos]
|
|
}
|
|
|
|
// escapeStringBackslash is similar to escapeBytesBackslash but for string.
|
|
func escapeStringBackslash(buf []byte, v string) []byte {
|
|
pos := len(buf)
|
|
buf = reserveBuffer(buf, len(v)*2)
|
|
|
|
for i := 0; i < len(v); i++ {
|
|
c := v[i]
|
|
switch c {
|
|
case '\x00':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '0'
|
|
pos += 2
|
|
case '\n':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = 'n'
|
|
pos += 2
|
|
case '\r':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = 'r'
|
|
pos += 2
|
|
case '\x1a':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = 'Z'
|
|
pos += 2
|
|
case '\'':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '\''
|
|
pos += 2
|
|
case '"':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '"'
|
|
pos += 2
|
|
case '\\':
|
|
buf[pos] = '\\'
|
|
buf[pos+1] = '\\'
|
|
pos += 2
|
|
default:
|
|
buf[pos] = c
|
|
pos++
|
|
}
|
|
}
|
|
|
|
return buf[:pos]
|
|
}
|
|
|
|
// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
|
|
// This escapes the contents of a string by doubling up any apostrophes that
|
|
// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
|
|
// effect on the server.
|
|
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
|
|
func escapeBytesQuotes(buf, v []byte) []byte {
|
|
pos := len(buf)
|
|
buf = reserveBuffer(buf, len(v)*2)
|
|
|
|
for _, c := range v {
|
|
if c == '\'' {
|
|
buf[pos] = '\''
|
|
buf[pos+1] = '\''
|
|
pos += 2
|
|
} else {
|
|
buf[pos] = c
|
|
pos++
|
|
}
|
|
}
|
|
|
|
return buf[:pos]
|
|
}
|
|
|
|
// escapeStringQuotes is similar to escapeBytesQuotes but for string.
|
|
func escapeStringQuotes(buf []byte, v string) []byte {
|
|
pos := len(buf)
|
|
buf = reserveBuffer(buf, len(v)*2)
|
|
|
|
for i := 0; i < len(v); i++ {
|
|
c := v[i]
|
|
if c == '\'' {
|
|
buf[pos] = '\''
|
|
buf[pos+1] = '\''
|
|
pos += 2
|
|
} else {
|
|
buf[pos] = c
|
|
pos++
|
|
}
|
|
}
|
|
|
|
return buf[:pos]
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Sync utils *
|
|
******************************************************************************/
|
|
|
|
// noCopy may be embedded into structs which must not be copied
|
|
// after the first use.
|
|
//
|
|
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
|
|
// for details.
|
|
type noCopy struct{}
|
|
|
|
// Lock is a no-op used by -copylocks checker from `go vet`.
|
|
func (*noCopy) Lock() {}
|
|
|
|
// atomicBool is a wrapper around uint32 for usage as a boolean value with
|
|
// atomic access.
|
|
type atomicBool struct {
|
|
_noCopy noCopy
|
|
value uint32
|
|
}
|
|
|
|
// IsSet returns whether the current boolean value is true
|
|
func (ab *atomicBool) IsSet() bool {
|
|
return atomic.LoadUint32(&ab.value) > 0
|
|
}
|
|
|
|
// Set sets the value of the bool regardless of the previous value
|
|
func (ab *atomicBool) Set(value bool) {
|
|
if value {
|
|
atomic.StoreUint32(&ab.value, 1)
|
|
} else {
|
|
atomic.StoreUint32(&ab.value, 0)
|
|
}
|
|
}
|
|
|
|
// TrySet sets the value of the bool and returns whether the value changed
|
|
func (ab *atomicBool) TrySet(value bool) bool {
|
|
if value {
|
|
return atomic.SwapUint32(&ab.value, 1) == 0
|
|
}
|
|
return atomic.SwapUint32(&ab.value, 0) > 0
|
|
}
|
|
|
|
// atomicError is a wrapper for atomically accessed error values
|
|
type atomicError struct {
|
|
_noCopy noCopy
|
|
value atomic.Value
|
|
}
|
|
|
|
// Set sets the error value regardless of the previous value.
|
|
// The value must not be nil
|
|
func (ae *atomicError) Set(value error) {
|
|
ae.value.Store(value)
|
|
}
|
|
|
|
// Value returns the current error value
|
|
func (ae *atomicError) Value() error {
|
|
if v := ae.value.Load(); v != nil {
|
|
// this will panic if the value doesn't implement the error interface
|
|
return v.(error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
|
|
dargs := make([]driver.Value, len(named))
|
|
for n, param := range named {
|
|
if len(param.Name) > 0 {
|
|
// TODO: support the use of Named Parameters #561
|
|
return nil, errors.New("mysql: driver does not support the use of Named Parameters")
|
|
}
|
|
dargs[n] = param.Value
|
|
}
|
|
return dargs, nil
|
|
}
|
|
|
|
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
|
|
switch sql.IsolationLevel(level) {
|
|
case sql.LevelRepeatableRead:
|
|
return "REPEATABLE READ", nil
|
|
case sql.LevelReadCommitted:
|
|
return "READ COMMITTED", nil
|
|
case sql.LevelReadUncommitted:
|
|
return "READ UNCOMMITTED", nil
|
|
case sql.LevelSerializable:
|
|
return "SERIALIZABLE", nil
|
|
default:
|
|
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
|
|
}
|
|
}
|