mirror of
https://github.com/ergochat/ergo.git
synced 2025-05-03 13:17:36 +02:00
463 lines
12 KiB
Go
463 lines
12 KiB
Go
package dkim
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
)
|
|
|
|
// permFailError is the error type used for permanent verification failures,
// such as a missing required tag or a malformed signature header.
type permFailError string

// Error implements the error interface, prefixing the message with "dkim: ".
func (err permFailError) Error() string {
	return "dkim: " + string(err)
}

// IsPermFail returns true if the error returned by Verify is a permanent
// failure. A permanent failure is for instance a missing required field or a
// malformed header.
//
// The error chain is inspected with errors.As, so a permanent failure that
// has been wrapped (for example with fmt.Errorf and %w) is still recognised.
// A nil error is never a permanent failure.
func IsPermFail(err error) bool {
	var target permFailError
	return errors.As(err, &target)
}
|
|
|
|
// tempFailError is the error type used for temporary verification failures.
type tempFailError string

// Error implements the error interface, prefixing the message with "dkim: ".
func (err tempFailError) Error() string {
	return "dkim: " + string(err)
}

// IsTempFail returns true if the error returned by Verify is a temporary
// failure.
//
// The error chain is inspected with errors.As, so a temporary failure that
// has been wrapped (for example with fmt.Errorf and %w) is still recognised.
// A nil error is never a temporary failure.
func IsTempFail(err error) bool {
	var target tempFailError
	return errors.As(err, &target)
}
|
|
|
|
// failError is the error type used when a signature is well-formed but does
// not verify (for example a body hash or signature mismatch).
type failError string

// Error implements the error interface, prefixing the message with "dkim: ".
func (err failError) Error() string {
	return "dkim: " + string(err)
}

// isFail returns true if the error returned by Verify is a signature error.
//
// Like IsPermFail and IsTempFail it walks the error chain with errors.As, so
// wrapped signature errors are recognised too. A nil error is not a failure.
func isFail(err error) bool {
	var target failError
	return errors.As(err, &target)
}
|
|
|
|
// ErrTooManySignatures is returned by VerifyWithOptions when the message
// carries more signatures than VerifyOptions.MaxVerifications allows. The
// verifications of the signatures that were checked are returned alongside
// it.
var ErrTooManySignatures = errors.New("dkim: too many signatures")

// requiredTags lists the signature tags that verify requires to be present;
// a signature lacking any of them is rejected as a permanent failure.
var requiredTags = []string{"v", "a", "b", "bh", "d", "h", "s"}
|
|
|
|
// A Verification is produced by Verify when it checks if one signature is
// valid. If the signature is valid, Err is nil.
//
// When Err is non-nil the other fields are only populated as far as the
// signature could be parsed before the failure was detected.
type Verification struct {
	// The SDID claiming responsibility for an introduction of a message into the
	// mail stream. Taken from the "d" tag with whitespace removed.
	Domain string
	// The Agent or User Identifier (AUID) on behalf of which the SDID is taking
	// responsibility. Taken from the "i" tag with whitespace removed; when the
	// tag is absent it defaults to "@" followed by Domain.
	Identifier string

	// The list of signed header fields, as listed in the "h" tag.
	HeaderKeys []string

	// The time that this signature was created ("t" tag). If unknown, it's set
	// to zero.
	Time time.Time
	// The expiration time ("x" tag). If the signature doesn't expire, it's set
	// to zero.
	Expiration time.Time

	// Err is nil if the signature is valid.
	Err error
}
|
|
|
|
// signature records one signature header field found while scanning the
// message header.
type signature struct {
	// i is the index of the signature field within the parsed header.
	i int
	// v is the value part of the field, as returned by parseHeaderField.
	v string
}
|
|
|
|
// VerifyOptions allows to customize the default signature verification
// behavior.
type VerifyOptions struct {
	// LookupTXT returns the DNS TXT records for the given domain name. If nil,
	// net.LookupTXT is used.
	LookupTXT func(domain string) ([]string, error)
	// MaxVerifications controls the maximum number of signature verifications
	// to perform. If more signatures are present, the first MaxVerifications
	// signatures are verified, the rest are ignored and ErrTooManySignatures
	// is returned. If zero (or negative), there is no maximum.
	MaxVerifications int
}
|
|
|
|
// Verify checks if a message's signatures are valid. It returns one
// verification per signature.
//
// It is equivalent to calling VerifyWithOptions with nil options, so no limit
// is placed on the number of signatures that are verified.
//
// There is no guarantee that the reader will be completely consumed.
func Verify(r io.Reader) ([]*Verification, error) {
	return VerifyWithOptions(r, nil)
}
|
|
|
|
// VerifyWithOptions performs the same task as Verify, but allows specifying
|
|
// verification options.
|
|
func VerifyWithOptions(r io.Reader, options *VerifyOptions) ([]*Verification, error) {
|
|
// Read header
|
|
bufr := bufio.NewReader(r)
|
|
h, err := readHeader(bufr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Scan header fields for signatures
|
|
var signatures []*signature
|
|
for i, kv := range h {
|
|
k, v := parseHeaderField(kv)
|
|
if strings.EqualFold(k, headerFieldName) {
|
|
signatures = append(signatures, &signature{i, v})
|
|
}
|
|
}
|
|
|
|
tooManySignatures := false
|
|
if options != nil && options.MaxVerifications > 0 && len(signatures) > options.MaxVerifications {
|
|
tooManySignatures = true
|
|
signatures = signatures[:options.MaxVerifications]
|
|
}
|
|
|
|
var verifs []*Verification
|
|
if len(signatures) == 1 {
|
|
// If there is only one signature - just verify it.
|
|
v, err := verify(h, bufr, h[signatures[0].i], signatures[0].v, options)
|
|
if err != nil && !IsTempFail(err) && !IsPermFail(err) && !isFail(err) {
|
|
return nil, err
|
|
}
|
|
v.Err = err
|
|
verifs = []*Verification{v}
|
|
} else {
|
|
verifs, err = parallelVerify(bufr, h, signatures, options)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if tooManySignatures {
|
|
return verifs, ErrTooManySignatures
|
|
}
|
|
return verifs, nil
|
|
}
|
|
|
|
func parallelVerify(r io.Reader, h header, signatures []*signature, options *VerifyOptions) ([]*Verification, error) {
|
|
pipeWriters := make([]*io.PipeWriter, len(signatures))
|
|
// We can't pass pipeWriter to io.MultiWriter directly,
|
|
// we need a slice of io.Writer, but we also need *io.PipeWriter
|
|
// to call Close on it.
|
|
writers := make([]io.Writer, len(signatures))
|
|
chans := make([]chan *Verification, len(signatures))
|
|
|
|
for i, sig := range signatures {
|
|
// Be careful with loop variables and goroutines.
|
|
i, sig := i, sig
|
|
|
|
chans[i] = make(chan *Verification, 1)
|
|
|
|
pr, pw := io.Pipe()
|
|
writers[i] = pw
|
|
pipeWriters[i] = pw
|
|
|
|
go func() {
|
|
v, err := verify(h, pr, h[sig.i], sig.v, options)
|
|
|
|
// Make sure we consume the whole reader, otherwise io.Copy on
|
|
// other side can block forever.
|
|
io.Copy(ioutil.Discard, pr)
|
|
|
|
v.Err = err
|
|
chans[i] <- v
|
|
}()
|
|
}
|
|
|
|
if _, err := io.Copy(io.MultiWriter(writers...), r); err != nil {
|
|
return nil, err
|
|
}
|
|
for _, wr := range pipeWriters {
|
|
wr.Close()
|
|
}
|
|
|
|
verifications := make([]*Verification, len(signatures))
|
|
for i, ch := range chans {
|
|
verifications[i] = <-ch
|
|
}
|
|
|
|
// Return unexpected failures as a separate error.
|
|
for _, v := range verifications {
|
|
err := v.Err
|
|
if err != nil && !IsTempFail(err) && !IsPermFail(err) && !isFail(err) {
|
|
v.Err = nil
|
|
return verifications, err
|
|
}
|
|
}
|
|
return verifications, nil
|
|
}
|
|
|
|
// verify checks a single signature against the message.
//
// h is the full parsed message header, r yields the message body, sigField is
// the complete signature header field (it is hashed with its "b" value
// removed) and sigValue is that field's value, which is parsed for tags.
// options may be nil; when set, its LookupTXT is handed to the key query.
//
// A non-nil *Verification is returned on every path, filled in as far as the
// signature could be parsed. The error is nil when the signature is valid.
// Failures detected here are reported as permFailError or failError; errors
// coming from the key query, from reading r or from hashing are returned
// unchanged.
func verify(h header, r io.Reader, sigField, sigValue string, options *VerifyOptions) (*Verification, error) {
	verif := new(Verification)

	params, err := parseHeaderParams(sigValue)
	if err != nil {
		return verif, permFailError("malformed signature tags: " + err.Error())
	}

	// The version is checked before the required-tag scan, so a signature
	// without any "v" tag is reported as an incompatible version.
	if params["v"] != "1" {
		return verif, permFailError("incompatible signature version")
	}

	// Recorded before the required-tag check so that the domain is available
	// to the caller even when the signature is rejected below.
	verif.Domain = stripWhitespace(params["d"])

	for _, tag := range requiredTags {
		if _, ok := params[tag]; !ok {
			return verif, permFailError("signature missing required tag")
		}
	}

	// The identifier must end in "@<domain>" or ".<domain>", i.e. belong to
	// the signing domain or one of its subdomains.
	if i, ok := params["i"]; ok {
		verif.Identifier = stripWhitespace(i)
		if !strings.HasSuffix(verif.Identifier, "@"+verif.Domain) && !strings.HasSuffix(verif.Identifier, "."+verif.Domain) {
			return verif, permFailError("domain mismatch")
		}
	} else {
		verif.Identifier = "@" + verif.Domain
	}

	// The From field must be among the signed header fields (compared
	// case-insensitively).
	headerKeys := parseTagList(params["h"])
	ok := false
	for _, k := range headerKeys {
		if strings.EqualFold(k, "from") {
			ok = true
			break
		}
	}
	if !ok {
		return verif, permFailError("From field not signed")
	}
	verif.HeaderKeys = headerKeys

	// Optional creation ("t") and expiration ("x") timestamps.
	if timeStr, ok := params["t"]; ok {
		t, err := parseTime(timeStr)
		if err != nil {
			return verif, permFailError("malformed time: " + err.Error())
		}
		verif.Time = t
	}
	if expiresStr, ok := params["x"]; ok {
		t, err := parseTime(expiresStr)
		if err != nil {
			return verif, permFailError("malformed expiration time: " + err.Error())
		}
		verif.Expiration = t
		if now().After(t) {
			return verif, permFailError("signature has expired")
		}
	}

	// Query public key
	// TODO: compute hash in parallel
	methods := []string{string(QueryMethodDNSTXT)}
	if methodsStr, ok := params["q"]; ok {
		methods = parseTagList(methodsStr)
	}
	// Only the first method that has a registered query function is tried;
	// its result, successful or not, is final.
	var res *queryResult
	for _, method := range methods {
		if query, ok := queryMethods[QueryMethod(method)]; ok {
			if options != nil {
				res, err = query(verif.Domain, stripWhitespace(params["s"]), options.LookupTXT)
			} else {
				res, err = query(verif.Domain, stripWhitespace(params["s"]), nil)
			}
			break
		}
	}
	if err != nil {
		// The query error is passed through as-is.
		return verif, err
	} else if res == nil {
		// No listed method had a registered query function.
		return verif, permFailError("unsupported public key query method")
	}

	// Parse algos
	// The "a" tag has the form "<key algo>-<hash algo>".
	keyAlgo, hashAlgo, ok := strings.Cut(stripWhitespace(params["a"]), "-")
	if !ok {
		return verif, permFailError("malformed algorithm name")
	}

	// Check hash algo
	// A nil HashAlgos on the key record places no restriction on the hash.
	if res.HashAlgos != nil {
		ok := false
		for _, algo := range res.HashAlgos {
			if algo == hashAlgo {
				ok = true
				break
			}
		}
		if !ok {
			return verif, permFailError("inappropriate hash algorithm")
		}
	}
	var hash crypto.Hash
	switch hashAlgo {
	case "sha1":
		// RFC 8301 section 3.1: rsa-sha1 MUST NOT be used for signing or
		// verifying.
		return verif, permFailError(fmt.Sprintf("hash algorithm too weak: %v", hashAlgo))
	case "sha256":
		hash = crypto.SHA256
	default:
		return verif, permFailError("unsupported hash algorithm")
	}

	// Check key algo
	if res.KeyAlgo != keyAlgo {
		return verif, permFailError("inappropriate key algorithm")
	}

	// A nil Services on the key record places no restriction; otherwise the
	// key must be usable for the "email" service.
	if res.Services != nil {
		ok := false
		for _, s := range res.Services {
			if s == "email" {
				ok = true
				break
			}
		}
		if !ok {
			return verif, permFailError("inappropriate service")
		}
	}

	// Both the header and the body canonicalization named by the "c" tag
	// must be known.
	headerCan, bodyCan := parseCanonicalization(params["c"])
	if _, ok := canonicalizers[headerCan]; !ok {
		return verif, permFailError("unsupported header canonicalization algorithm")
	}
	if _, ok := canonicalizers[bodyCan]; !ok {
		return verif, permFailError("unsupported body canonicalization algorithm")
	}

	// The body length "l" parameter is insecure, because it allows parts of
	// the message body to not be signed. Reject messages which have it set.
	if _, ok := params["l"]; ok {
		// TODO: technically should be policyError
		return verif, failError("message contains an insecure body length tag")
	}

	// Parse body hash and signature
	bodyHashed, err := decodeBase64String(params["bh"])
	if err != nil {
		return verif, permFailError("malformed body hash: " + err.Error())
	}
	sig, err := decodeBase64String(params["b"])
	if err != nil {
		return verif, permFailError("malformed signature: " + err.Error())
	}

	// Check body hash
	// The whole of r is streamed through the body canonicalizer into the
	// hasher; the canonicalizer is closed before the digest is read.
	hasher := hash.New()
	wc := canonicalizers[bodyCan].CanonicalizeBody(hasher)
	if _, err := io.Copy(wc, r); err != nil {
		return verif, err
	}
	if err := wc.Close(); err != nil {
		return verif, err
	}
	// Constant-time comparison against the "bh" value.
	if subtle.ConstantTimeCompare(hasher.Sum(nil), bodyHashed) != 1 {
		return verif, failError("body hash did not verify")
	}

	// Compute data hash
	// The same hasher is reused: first the signed header fields in "h"
	// order, then the signature field itself.
	hasher.Reset()
	picker := newHeaderPicker(h)
	for _, key := range headerKeys {
		kv := picker.Pick(key)
		if kv == "" {
			// The field MAY contain names of header fields that do not exist
			// when signed; nonexistent header fields do not contribute to the
			// signature computation
			continue
		}

		kv = canonicalizers[headerCan].CanonicalizeHeader(kv)
		if _, err := hasher.Write([]byte(kv)); err != nil {
			return verif, err
		}
	}
	// The signature field is hashed last, with its "b" value blanked out and
	// without its trailing line break.
	canSigField := removeSignature(sigField)
	canSigField = canonicalizers[headerCan].CanonicalizeHeader(canSigField)
	canSigField = strings.TrimRight(canSigField, "\r\n")
	if _, err := hasher.Write([]byte(canSigField)); err != nil {
		return verif, err
	}
	hashed := hasher.Sum(nil)

	// Check signature
	if err := res.Verifier.Verify(hash, hashed, sig); err != nil {
		return verif, failError("signature did not verify: " + err.Error())
	}

	return verif, nil
}
|
|
|
|
func parseTagList(s string) []string {
|
|
tags := strings.Split(s, ":")
|
|
for i, t := range tags {
|
|
tags[i] = stripWhitespace(t)
|
|
}
|
|
return tags
|
|
}
|
|
|
|
func parseCanonicalization(s string) (headerCan, bodyCan Canonicalization) {
|
|
headerCan = CanonicalizationSimple
|
|
bodyCan = CanonicalizationSimple
|
|
|
|
cans := strings.SplitN(stripWhitespace(s), "/", 2)
|
|
if cans[0] != "" {
|
|
headerCan = Canonicalization(cans[0])
|
|
}
|
|
if len(cans) > 1 {
|
|
bodyCan = Canonicalization(cans[1])
|
|
}
|
|
return
|
|
}
|
|
|
|
func parseTime(s string) (time.Time, error) {
|
|
sec, err := strconv.ParseInt(stripWhitespace(s), 10, 64)
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
return time.Unix(sec, 0), nil
|
|
}
|
|
|
|
func decodeBase64String(s string) ([]byte, error) {
|
|
return base64.StdEncoding.DecodeString(stripWhitespace(s))
|
|
}
|
|
|
|
// stripWhitespace returns s with every rune for which unicode.IsSpace reports
// true removed; all other runes are kept in order.
func stripWhitespace(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	for _, r := range s {
		if unicode.IsSpace(r) {
			continue
		}
		out.WriteRune(r)
	}
	return out.String()
}
|
|
|
|
// sigRegex matches the value of the "b" tag of a signature header field.
//
// The tag name is only recognised at a tag boundary: at the very start of the
// string, right after the header field name and its colon, or right after a
// ";" separator. This keeps a "b=" sequence that merely occurs inside another
// tag's value (for instance base64 padding at the end of "bh") from being
// mistaken for the signature tag. Capture group 1 holds everything up to and
// including the "=", which is preserved.
var sigRegex = regexp.MustCompile(`((?:\A(?:[^:]*:)?|;)\s*b\s*=)[^;]+`)

// removeSignature returns the signature header field s with the value of its
// "b" tag (including surrounding whitespace) replaced by the empty string, as
// needed before hashing the field. All other tags are left untouched.
func removeSignature(s string) string {
	return sigRegex.ReplaceAllString(s, "$1")
}
|