| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package |
| // |
| // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. |
| // |
| // This Source Code Form is subject to the terms of the Mozilla Public |
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, |
| // You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| package mysql |
| |
| import ( |
| "crypto/sha1" |
| "crypto/tls" |
| "database/sql/driver" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "io" |
| "net/url" |
| "strings" |
| "time" |
| ) |
| |
var (
	// tlsConfigRegister maps the names accepted by the DSN "tls" parameter to
	// custom tls.Configs added with RegisterTLSConfig. It is initialized as part
	// of package variable initialization, which runs before any init function,
	// so it is never observed as a nil map.
	tlsConfigRegister = make(map[string]*tls.Config)

	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
	errInvalidDSNAddr      = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
	errInvalidDSNNoSlash   = errors.New("Invalid DSN: Missing the slash separating the database name")
)
| |
| // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. |
| // Use the key as a value in the DSN where tls=value. |
| // |
| // rootCertPool := x509.NewCertPool() |
| // pem, err := ioutil.ReadFile("/path/ca-cert.pem") |
| // if err != nil { |
| // log.Fatal(err) |
| // } |
| // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { |
| // log.Fatal("Failed to append PEM.") |
| // } |
| // clientCert := make([]tls.Certificate, 0, 1) |
| // certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") |
| // if err != nil { |
| // log.Fatal(err) |
| // } |
| // clientCert = append(clientCert, certs) |
| // mysql.RegisterTLSConfig("custom", &tls.Config{ |
| // RootCAs: rootCertPool, |
| // Certificates: clientCert, |
| // }) |
| // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") |
| // |
| func RegisterTLSConfig(key string, config *tls.Config) error { |
| if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { |
| return fmt.Errorf("Key '%s' is reserved", key) |
| } |
| |
| tlsConfigRegister[key] = config |
| return nil |
| } |
| |
| // DeregisterTLSConfig removes the tls.Config associated with key. |
| func DeregisterTLSConfig(key string) { |
| delete(tlsConfigRegister, key) |
| } |
| |
| // parseDSN parses the DSN string to a config |
| func parseDSN(dsn string) (cfg *config, err error) { |
| // New config with some default values |
| cfg = &config{ |
| loc: time.UTC, |
| collation: defaultCollation, |
| } |
| |
| // TODO: use strings.IndexByte when we can depend on Go 1.2 |
| |
| // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] |
| // Find the last '/' (since the password or the net addr might contain a '/') |
| foundSlash := false |
| for i := len(dsn) - 1; i >= 0; i-- { |
| if dsn[i] == '/' { |
| foundSlash = true |
| var j, k int |
| |
| // left part is empty if i <= 0 |
| if i > 0 { |
| // [username[:password]@][protocol[(address)]] |
| // Find the last '@' in dsn[:i] |
| for j = i; j >= 0; j-- { |
| if dsn[j] == '@' { |
| // username[:password] |
| // Find the first ':' in dsn[:j] |
| for k = 0; k < j; k++ { |
| if dsn[k] == ':' { |
| cfg.passwd = dsn[k+1 : j] |
| break |
| } |
| } |
| cfg.user = dsn[:k] |
| |
| break |
| } |
| } |
| |
| // [protocol[(address)]] |
| // Find the first '(' in dsn[j+1:i] |
| for k = j + 1; k < i; k++ { |
| if dsn[k] == '(' { |
| // dsn[i-1] must be == ')' if an address is specified |
| if dsn[i-1] != ')' { |
| if strings.ContainsRune(dsn[k+1:i], ')') { |
| return nil, errInvalidDSNUnescaped |
| } |
| return nil, errInvalidDSNAddr |
| } |
| cfg.addr = dsn[k+1 : i-1] |
| break |
| } |
| } |
| cfg.net = dsn[j+1 : k] |
| } |
| |
| // dbname[?param1=value1&...¶mN=valueN] |
| // Find the first '?' in dsn[i+1:] |
| for j = i + 1; j < len(dsn); j++ { |
| if dsn[j] == '?' { |
| if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { |
| return |
| } |
| break |
| } |
| } |
| cfg.dbname = dsn[i+1 : j] |
| |
| break |
| } |
| } |
| |
| if !foundSlash && len(dsn) > 0 { |
| return nil, errInvalidDSNNoSlash |
| } |
| |
| // Set default network if empty |
| if cfg.net == "" { |
| cfg.net = "tcp" |
| } |
| |
| // Set default address if empty |
| if cfg.addr == "" { |
| switch cfg.net { |
| case "tcp": |
| cfg.addr = "127.0.0.1:3306" |
| case "unix": |
| cfg.addr = "/tmp/mysql.sock" |
| default: |
| return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") |
| } |
| |
| } |
| |
| return |
| } |
| |
| // parseDSNParams parses the DSN "query string" |
| // Values must be url.QueryEscape'ed |
| func parseDSNParams(cfg *config, params string) (err error) { |
| for _, v := range strings.Split(params, "&") { |
| param := strings.SplitN(v, "=", 2) |
| if len(param) != 2 { |
| continue |
| } |
| |
| // cfg params |
| switch value := param[1]; param[0] { |
| |
| // Disable INFILE whitelist / enable all files |
| case "allowAllFiles": |
| var isBool bool |
| cfg.allowAllFiles, isBool = readBool(value) |
| if !isBool { |
| return fmt.Errorf("Invalid Bool value: %s", value) |
| } |
| |
| // Use old authentication mode (pre MySQL 4.1) |
| case "allowOldPasswords": |
| var isBool bool |
| cfg.allowOldPasswords, isBool = readBool(value) |
| if !isBool { |
| return fmt.Errorf("Invalid Bool value: %s", value) |
| } |
| |
| // Switch "rowsAffected" mode |
| case "clientFoundRows": |
| var isBool bool |
| cfg.clientFoundRows, isBool = readBool(value) |
| if !isBool { |
| return fmt.Errorf("Invalid Bool value: %s", value) |
| } |
| |
| // Collation |
| case "collation": |
| collation, ok := collations[value] |
| if !ok { |
| // Note possibility for false negatives: |
| // could be triggered although the collation is valid if the |
| // collations map does not contain entries the server supports. |
| err = errors.New("unknown collation") |
| return |
| } |
| cfg.collation = collation |
| break |
| |
| // Time Location |
| case "loc": |
| if value, err = url.QueryUnescape(value); err != nil { |
| return |
| } |
| cfg.loc, err = time.LoadLocation(value) |
| if err != nil { |
| return |
| } |
| |
| // Dial Timeout |
| case "timeout": |
| cfg.timeout, err = time.ParseDuration(value) |
| if err != nil { |
| return |
| } |
| |
| // TLS-Encryption |
| case "tls": |
| boolValue, isBool := readBool(value) |
| if isBool { |
| if boolValue { |
| cfg.tls = &tls.Config{} |
| } |
| } else { |
| if strings.ToLower(value) == "skip-verify" { |
| cfg.tls = &tls.Config{InsecureSkipVerify: true} |
| } else if tlsConfig, ok := tlsConfigRegister[value]; ok { |
| cfg.tls = tlsConfig |
| } else { |
| return fmt.Errorf("Invalid value / unknown config name: %s", value) |
| } |
| } |
| |
| default: |
| // lazy init |
| if cfg.params == nil { |
| cfg.params = make(map[string]string) |
| } |
| |
| if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { |
| return |
| } |
| } |
| } |
| |
| return |
| } |
| |
// readBool interprets input as a boolean DSN value.
// Only the exact spellings "1", "true", "TRUE", "True" (true) and
// "0", "false", "FALSE", "False" (false) are recognized. valid reports
// whether input was one of them; when it is false, value is also false.
func readBool(input string) (value bool, valid bool) {
	switch input {
	case "1", "true", "TRUE", "True":
		value, valid = true, true
	case "0", "false", "FALSE", "False":
		value, valid = false, true
	default:
		// Unrecognized spelling: both results stay false.
	}
	return value, valid
}
| |
| /****************************************************************************** |
| * Authentication * |
| ******************************************************************************/ |
| |
// scramblePassword computes the MySQL 4.1+ authentication token:
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// It returns nil for an empty password.
func scramblePassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	// hash.Hash.Write never returns an error, so results are not checked.
	h := sha1.New()

	// stage1 = SHA1(password)
	h.Write(password)
	stage1 := h.Sum(nil)

	// stage2 = SHA1(stage1)
	h.Reset()
	h.Write(stage1)
	stage2 := h.Sum(nil)

	// token = SHA1(scramble + stage2) XOR stage1
	h.Reset()
	h.Write(scramble)
	h.Write(stage2)
	token := h.Sum(nil)
	for i, b := range stage1 {
		token[i] ^= b
	}
	return token
}
| |
// myRnd is the pseudo random number generator used to encrypt passwords with
// the pre 4.1 (old password) method, following
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
	seed1, seed2 uint32
}

// myRndMaxVal is the modulus both seeds are reduced by.
const myRndMaxVal = 0x3FFFFFFF

// newMyRnd returns a pseudo random number generator whose seeds are reduced
// modulo myRndMaxVal.
func newMyRnd(seed1, seed2 uint32) *myRnd {
	r := new(myRnd)
	r.seed1 = seed1 % myRndMaxVal
	r.seed2 = seed2 % myRndMaxVal
	return r
}

// NextByte advances the generator and returns a value in [0, 30].
//
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
	// Both seeds stay below myRndMaxVal, so 3*seed1+seed2 still fits in uint32.
	s1 := (3*r.seed1 + r.seed2) % myRndMaxVal
	s2 := (s1 + r.seed2 + 33) % myRndMaxVal
	r.seed1, r.seed2 = s1, s2

	// Widen before multiplying so seed1*31 cannot overflow.
	return byte(uint64(s1) * 31 / myRndMaxVal)
}
| |
// pwHash computes the insecure pre 4.1 two-word hash of password.
// Spaces and tabs in the input are ignored, and the top bit of each
// resulting word is cleared.
func pwHash(password []byte) (result [2]uint32) {
	nr := uint32(1345345333)
	nr2 := uint32(0x12345671)
	add := uint32(7)

	for _, c := range password {
		if c == ' ' || c == '\t' {
			continue
		}
		v := uint32(c)
		nr ^= (((nr & 63) + add) * v) + (nr << 8)
		nr2 += (nr2 << 8) ^ nr
		add += v
	}

	// Clear the sign bit of both words: (1<<31)-1 == 0x7FFFFFFF.
	return [2]uint32{nr & 0x7FFFFFFF, nr2 & 0x7FFFFFFF}
}
| |
| // Encrypt password using insecure pre 4.1 method |
| func scrambleOldPassword(scramble, password []byte) []byte { |
| if len(password) == 0 { |
| return nil |
| } |
| |
| scramble = scramble[:8] |
| |
| hashPw := pwHash(password) |
| hashSc := pwHash(scramble) |
| |
| r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) |
| |
| var out [8]byte |
| for i := range out { |
| out[i] = r.NextByte() + 64 |
| } |
| |
| mask := r.NextByte() |
| for i := range out { |
| out[i] ^= mask |
| } |
| |
| return out[:] |
| } |
| |
| /****************************************************************************** |
| * Time related utils * |
| ******************************************************************************/ |
| |
| // NullTime represents a time.Time that may be NULL. |
| // NullTime implements the Scanner interface so |
| // it can be used as a scan destination: |
| // |
| // var nt NullTime |
| // err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) |
| // ... |
| // if nt.Valid { |
| // // use nt.Time |
| // } else { |
| // // NULL value |
| // } |
| // |
| // This NullTime implementation is not driver-specific |
| type NullTime struct { |
| Time time.Time |
| Valid bool // Valid is true if Time is not NULL |
| } |
| |
| // Scan implements the Scanner interface. |
| // The value type must be time.Time or string / []byte (formatted time-string), |
| // otherwise Scan fails. |
| func (nt *NullTime) Scan(value interface{}) (err error) { |
| if value == nil { |
| nt.Time, nt.Valid = time.Time{}, false |
| return |
| } |
| |
| switch v := value.(type) { |
| case time.Time: |
| nt.Time, nt.Valid = v, true |
| return |
| case []byte: |
| nt.Time, err = parseDateTime(string(v), time.UTC) |
| nt.Valid = (err == nil) |
| return |
| case string: |
| nt.Time, err = parseDateTime(v, time.UTC) |
| nt.Valid = (err == nil) |
| return |
| } |
| |
| nt.Valid = false |
| return fmt.Errorf("Can't convert %T to time.Time", value) |
| } |
| |
| // Value implements the driver Valuer interface. |
| func (nt NullTime) Value() (driver.Value, error) { |
| if !nt.Valid { |
| return nil, nil |
| } |
| return nt.Time, nil |
| } |
| |
| func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { |
| base := "0000-00-00 00:00:00.0000000" |
| switch len(str) { |
| case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" |
| if str == base[:len(str)] { |
| return |
| } |
| t, err = time.Parse(timeFormat[:len(str)], str) |
| default: |
| err = fmt.Errorf("Invalid Time-String: %s", str) |
| return |
| } |
| |
| // Adjust location |
| if err == nil && loc != time.UTC { |
| y, mo, d := t.Date() |
| h, mi, s := t.Clock() |
| t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil |
| } |
| |
| return |
| } |
| |
// parseBinaryDateTime decodes a binary-protocol DATE/DATETIME value of num
// bytes from data into a time.Time in loc.
//
//	0 bytes:  zero time.Time
//	4 bytes:  year (2 bytes LE), month, day
//	7 bytes:  ... plus hour, minute, second
//	11 bytes: ... plus microseconds (4 bytes LE)
//
// Any other num is an error. data must hold at least num bytes.
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
	switch num {
	case 0:
		return time.Time{}, nil
	case 4, 7, 11:
		// decoded below
	default:
		return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
	}

	year := int(binary.LittleEndian.Uint16(data[:2]))
	month := time.Month(data[2])
	day := int(data[3])

	var hour, minute, second, nsec int
	if num >= 7 {
		hour, minute, second = int(data[4]), int(data[5]), int(data[6])
	}
	if num == 11 {
		// microseconds on the wire, nanoseconds in time.Time
		nsec = int(binary.LittleEndian.Uint32(data[7:11])) * 1000
	}

	return time.Date(year, month, day, hour, minute, second, nsec, loc), nil
}
| |
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
// if the DATE or DATETIME has the zero value.
// It must never be changed.
// The current behavior depends on database/sql copying the result.
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")

// formatBinaryDateTime renders a binary-protocol DATE, DATETIME or TIME value
// as text.
//
// For dates, src holds 0, 4, 7 or 11 bytes: year (2 bytes LE), month, day,
// then hour, minute, second, then microseconds (4 bytes LE). For TIME
// (justTime), src holds 0, 8 or 12 bytes: sign (1 = negative), day count
// (4 bytes LE), hour, minute, second, then microseconds (4 bytes LE).
//
// length is the width of the zero value: 10 (DATE), 19 or 21-26 (DATETIME
// with 0 or 1-6 fractional digits), 8 or 10-15 (TIME with 0 or 1-6 fractional
// digits). For TIME, a leading '-' and hour fields wider than two digits are
// added on top of length. An empty src yields the matching slice of
// zeroDateTime without allocating.
func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) {
	// length expects the deterministic length of the zero value,
	// negative time and 100+ hours are automatically added if needed
	const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
	const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
	if len(src) == 0 {
		if justTime {
			return zeroDateTime[11 : 11+length], nil
		}
		return zeroDateTime[:length], nil
	}
	var dst []byte          // return value
	var pt, p1, p2, p3 byte // current digit pair
	var zOffs byte          // offset of value in zeroDateTime
	if justTime {
		switch length {
		case
			8,                      // time (can be up to 10 when negative and 100+ hours)
			10, 11, 12, 13, 14, 15: // time with fractional seconds
		default:
			return nil, fmt.Errorf("illegal TIME length %d", length)
		}
		switch len(src) {
		case 8, 12:
		default:
			return nil, fmt.Errorf("Invalid TIME-packet length %d", len(src))
		}
		// +2 to enable negative time and 100+ hours
		dst = make([]byte, 0, length+2)
		if src[0] == 1 {
			dst = append(dst, '-')
		}
		// The day count and the hour of day are shown as one total-hours field.
		hours := uint64(binary.LittleEndian.Uint32(src[1:5]))*24 + uint64(src[5])
		if hours >= 100 {
			// Emit every digit above the last two; those two are written
			// below together with minutes and seconds.
			var buf [20]byte
			i := len(buf)
			for hi := hours / 100; hi > 0; hi /= 10 {
				i--
				buf[i] = digits01[hi%10]
			}
			dst = append(dst, buf[i:]...)
		}
		p1 = byte(hours % 100)
		zOffs = 11
		src = src[6:]
	} else {
		switch length {
		case 10, 19, 21, 22, 23, 24, 25, 26:
		default:
			t := "DATE"
			if length > 10 {
				t += "TIME"
			}
			return nil, fmt.Errorf("illegal %s length %d", t, length)
		}
		switch len(src) {
		case 4, 7, 11:
		default:
			t := "DATE"
			if length > 10 {
				t += "TIME"
			}
			return nil, fmt.Errorf("illegal %s-packet length %d", t, len(src))
		}
		dst = make([]byte, 0, length)
		// start with the date
		year := binary.LittleEndian.Uint16(src[:2])
		pt = byte(year / 100)
		p1 = byte(year - 100*uint16(pt))
		p2, p3 = src[2], src[3]
		dst = append(dst,
			digits10[pt], digits01[pt],
			digits10[p1], digits01[p1], '-',
			digits10[p2], digits01[p2], '-',
			digits10[p3], digits01[p3],
		)
		if length == 10 {
			return dst, nil
		}
		if len(src) == 4 {
			return append(dst, zeroDateTime[10:length]...), nil
		}
		dst = append(dst, ' ')
		p1 = src[4] // hour
		src = src[5:]
	}
	// p1 is 2-digit hour, src is after hour
	p2, p3 = src[0], src[1]
	dst = append(dst,
		digits10[p1], digits01[p1], ':',
		digits10[p2], digits01[p2], ':',
		digits10[p3], digits01[p3],
	)
	// No fractional part requested: TIME with length 8 or DATETIME with
	// length 19. Decide this from length, not len(dst), because a sign or a
	// 3-digit hour makes dst longer than the nominal width.
	if zOffs+length < 20 {
		return dst, nil
	}
	src = src[2:]
	if len(src) == 0 {
		return append(dst, zeroDateTime[19:zOffs+length]...), nil
	}
	microsecs := binary.LittleEndian.Uint32(src[:4])
	p1 = byte(microsecs / 10000)
	microsecs -= 10000 * uint32(p1)
	p2 = byte(microsecs / 100)
	microsecs -= 100 * uint32(p2)
	p3 = byte(microsecs)
	switch decimals := zOffs + length - 20; decimals {
	default:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2], digits01[p2],
			digits10[p3], digits01[p3],
		), nil
	case 1:
		return append(dst, '.',
			digits10[p1],
		), nil
	case 2:
		return append(dst, '.',
			digits10[p1], digits01[p1],
		), nil
	case 3:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2],
		), nil
	case 4:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2], digits01[p2],
		), nil
	case 5:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2], digits01[p2],
			digits10[p3],
		), nil
	}
}
| |
| /****************************************************************************** |
| * Convert from and to bytes * |
| ******************************************************************************/ |
| |
// uint64ToBytes returns n as 8 bytes in little-endian order.
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	for i := range out {
		out[i] = byte(n >> (8 * uint(i)))
	}
	return out
}
| |
// uint64ToString returns the decimal ASCII representation of n.
// A uint64 has at most 20 decimal digits, so a fixed buffer suffices.
func uint64ToString(n uint64) []byte {
	var buf [20]byte
	pos := len(buf)
	for {
		pos--
		buf[pos] = byte('0' + n%10)
		n /= 10
		if n == 0 {
			break
		}
	}
	return buf[pos:]
}
| |
// stringToInt interprets b as the ASCII decimal digits of an unsigned integer.
// Input is not validated: every byte is taken as (byte - '0').
func stringToInt(b []byte) int {
	n := 0
	for _, c := range b {
		n = n*10 + int(c-'0')
	}
	return n
}
| |
| // returns the string read as a bytes slice, wheter the value is NULL, |
| // the number of bytes read and an error, in case the string is longer than |
| // the input slice |
| func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { |
| // Get length |
| num, isNull, n := readLengthEncodedInteger(b) |
| if num < 1 { |
| return b[n:n], isNull, n, nil |
| } |
| |
| n += int(num) |
| |
| // Check data length |
| if len(b) >= n { |
| return b[n-int(num) : n], false, n, nil |
| } |
| return nil, false, n, io.EOF |
| } |
| |
| // returns the number of bytes skipped and an error, in case the string is |
| // longer than the input slice |
| func skipLengthEncodedString(b []byte) (int, error) { |
| // Get length |
| num, _, n := readLengthEncodedInteger(b) |
| if num < 1 { |
| return n, nil |
| } |
| |
| n += int(num) |
| |
| // Check data length |
| if len(b) >= n { |
| return n, nil |
| } |
| return n, io.EOF |
| } |
| |
// readLengthEncodedInteger decodes the length-encoded integer at the start of
// b. It returns the value, whether it denotes NULL, and the number of bytes
// consumed, selected by the first byte:
//
//	0xfb: NULL (1 byte)
//	0xfc: 2-byte little-endian value (3 bytes)
//	0xfd: 3-byte little-endian value (4 bytes)
//	0xfe: 8-byte little-endian value (9 bytes)
//	else: the first byte itself (1 byte)
//
// An empty b is reported as NULL consuming 1 byte instead of panicking with
// an index out of range.
// NOTE(review): the multi-byte forms still assume b holds the full encoding.
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
	if len(b) == 0 {
		return 0, true, 1
	}

	switch b[0] {

	// 251: NULL
	case 0xfb:
		return 0, true, 1

	// 252: value of following 2
	case 0xfc:
		return uint64(b[1]) | uint64(b[2])<<8, false, 3

	// 253: value of following 3
	case 0xfd:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4

	// 254: value of following 8
	case 0xfe:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
				uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
				uint64(b[7])<<48 | uint64(b[8])<<56,
			false, 9
	}

	// 0-250: value of first byte
	return uint64(b[0]), false, 1
}
| |
// appendLengthEncodedInteger appends n to b as a length-encoded integer and
// returns the extended slice. Values up to 250 take a single byte; larger
// values get a 0xfc, 0xfd or 0xfe prefix followed by 2, 3 or 8
// little-endian bytes.
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
	if n <= 250 {
		return append(b, byte(n))
	}
	if n <= 0xffff {
		return append(b, 0xfc, byte(n), byte(n>>8))
	}
	if n <= 0xffffff {
		return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
	}
	return append(b, 0xfe,
		byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}