mirror of
https://github.com/ergochat/ergo.git
synced 2025-01-22 02:04:10 +01:00
324 lines
9.4 KiB
Go
324 lines
9.4 KiB
Go
package webpush
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/aes"
|
||
"crypto/cipher"
|
||
"crypto/ecdh"
|
||
"crypto/rand"
|
||
"crypto/sha256"
|
||
"encoding/base64"
|
||
"encoding/binary"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"golang.org/x/crypto/hkdf"
|
||
)
|
||
|
||
// MaxRecordSize is the record size, in bytes, that EncryptNotification uses
// when the caller passes a record size of 0.
const MaxRecordSize uint32 = 4096

// ErrRecordSizeTooSmall is returned by EncryptNotification when the requested
// record size is below the accepted minimum or cannot hold the message.
var ErrRecordSizeTooSmall = errors.New("record size too small for message")

var (
	// invalidAuthKeyLength is returned by DecodeSubscriptionKeys when the
	// decoded authentication secret is not exactly 16 bytes long.
	invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)")

	// defaultHTTPClient is used to send notifications when Options.HTTPClient
	// is nil.
	defaultHTTPClient = &http.Client{}
)
|
||
|
||
// HTTPClient is the minimal interface used to send the notification HTTP
// request. *http.Client satisfies it, and tests can supply their own
// implementation.
type HTTPClient interface {
	// Do performs req and returns the resulting response or an error.
	Do(req *http.Request) (*http.Response, error)
}
|
||
|
||
// Options are config and extra params needed to send a notification
|
||
type Options struct {
|
||
HTTPClient HTTPClient // Will replace with *http.Client by default if not included
|
||
RecordSize uint32 // Limit the record size
|
||
Subscriber string // Sub in VAPID JWT token
|
||
Topic string // Set the Topic header to collapse a pending messages (Optional)
|
||
TTL int // Set the TTL on the endpoint POST request, in seconds
|
||
Urgency Urgency // Set the Urgency header to change a message priority (Optional)
|
||
VAPIDKeys *VAPIDKeys // VAPID public-private keypair to generate the VAPID Authorization header
|
||
VapidExpiration time.Time // optional expiration for VAPID JWT token (defaults to now + 12 hours)
|
||
}
|
||
|
||
// Keys represents a subscription's keys (its ECDH public key on the P-256 curve
// and its 16-byte authentication secret).
type Keys struct {
	Auth   [16]byte
	P256dh *ecdh.PublicKey
}

// Equal compares two Keys for equality.
//
// Two Keys are equal when their authentication secrets match and their public
// keys are either both nil or represent the same key. Nil public keys (as found
// in a zero-value Keys) are compared safely instead of being dereferenced.
func (k *Keys) Equal(o Keys) bool {
	if k.Auth != o.Auth {
		return false
	}
	// (*ecdh.PublicKey).Equal dereferences both its receiver and its argument,
	// so nil keys have to be handled before calling it.
	if k.P256dh == nil || o.P256dh == nil {
		return k.P256dh == o.P256dh
	}
	return k.P256dh.Equal(o.P256dh)
}
|
||
|
||
var _ json.Marshaler = (*Keys)(nil)
|
||
var _ json.Unmarshaler = (*Keys)(nil)
|
||
|
||
// marshaledKeys is the JSON representation of Keys: both members are carried
// as base64-encoded strings.
type marshaledKeys struct {
	Auth   string `json:"auth"`
	P256dh string `json:"p256dh"`
}
|
||
|
||
// MarshalJSON implements json.Marshaler, allowing serialization to JSON.
|
||
func (k *Keys) MarshalJSON() ([]byte, error) {
|
||
m := marshaledKeys{
|
||
Auth: base64.RawStdEncoding.EncodeToString(k.Auth[:]),
|
||
P256dh: base64.RawStdEncoding.EncodeToString(k.P256dh.Bytes()),
|
||
}
|
||
return json.Marshal(&m)
|
||
}
|
||
|
||
// MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON.
|
||
func (k *Keys) UnmarshalJSON(b []byte) (err error) {
|
||
var m marshaledKeys
|
||
if err := json.Unmarshal(b, &m); err != nil {
|
||
return err
|
||
}
|
||
authBytes, err := decodeSubscriptionKey(m.Auth)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(authBytes) != 16 {
|
||
return fmt.Errorf("invalid auth bytes length %d (must be 16)", len(authBytes))
|
||
}
|
||
copy(k.Auth[:], authBytes)
|
||
rawDHKey, err := decodeSubscriptionKey(m.P256dh)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
k.P256dh, err = ecdh.P256().NewPublicKey(rawDHKey)
|
||
return err
|
||
}
|
||
|
||
// DecodeSubscriptionKeys decodes and validates a base64-encoded pair of subscription keys
|
||
// (the authentication secret and ECDH public key).
|
||
func DecodeSubscriptionKeys(auth, p256dh string) (keys Keys, err error) {
|
||
authBytes, err := decodeSubscriptionKey(auth)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if len(authBytes) != 16 {
|
||
err = invalidAuthKeyLength
|
||
return
|
||
}
|
||
copy(keys.Auth[:], authBytes)
|
||
dhBytes, err := decodeSubscriptionKey(p256dh)
|
||
if err != nil {
|
||
return
|
||
}
|
||
keys.P256dh, err = ecdh.P256().NewPublicKey(dhBytes)
|
||
if err != nil {
|
||
return
|
||
}
|
||
return
|
||
}
|
||
|
||
// Subscription represents a PushSubscription object from the Push API
|
||
type Subscription struct {
|
||
Endpoint string `json:"endpoint"`
|
||
Keys Keys `json:"keys"`
|
||
}
|
||
|
||
// SendNotification sends a push notification to a subscription's endpoint,
|
||
// applying encryption (RFC 8291) and adding a VAPID header (RFC 8292).
|
||
func SendNotification(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) {
|
||
// Compose message body (RFC8291 encryption of the message)
|
||
body, err := EncryptNotification(message, s.Keys, options.RecordSize)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Get VAPID Authorization header
|
||
vapidAuthHeader, err := getVAPIDAuthorizationHeader(
|
||
s.Endpoint,
|
||
options.Subscriber,
|
||
options.VAPIDKeys,
|
||
options.VapidExpiration,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Compose and send the HTTP request
|
||
return sendNotification(ctx, s.Endpoint, options, vapidAuthHeader, body)
|
||
}
|
||
|
||
// EncryptNotification implements the encryption algorithm specified by RFC 8291 for web push
|
||
// (RFC 8188's aes128gcm content-encoding, with the key material derived from
|
||
// elliptic curve Diffie-Hellman over the P-256 curve).
|
||
func EncryptNotification(message []byte, keys Keys, recordSize uint32) ([]byte, error) {
|
||
// Get the record size
|
||
if recordSize == 0 {
|
||
recordSize = MaxRecordSize
|
||
} else if recordSize < 128 {
|
||
return nil, ErrRecordSizeTooSmall
|
||
}
|
||
|
||
// Allocate buffer to hold the eventual message
|
||
// [ header block ] [ ciphertext ] [ 16 byte AEAD tag ], totaling RecordSize bytes
|
||
// the ciphertext is the encryption of: [ message ] [ \x02 ] [ 0 or more \x00 as needed ]
|
||
recordBuf := make([]byte, recordSize)
|
||
// remainingBuf tracks our current writing position in recordBuf:
|
||
remainingBuf := recordBuf
|
||
|
||
// Application server key pairs (single use)
|
||
localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
localPublicKey := localPrivateKey.PublicKey()
|
||
|
||
// Encryption Content-Coding Header
|
||
// +-----------+--------+-----------+---------------+
|
||
// | salt (16) | rs (4) | idlen (1) | keyid (idlen) |
|
||
// +-----------+--------+-----------+---------------+
|
||
// in our case the keyid is localPublicKey.Bytes(), so 65 bytes
|
||
// First, generate the salt
|
||
_, err = rand.Read(remainingBuf[:16])
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
salt := remainingBuf[:16]
|
||
remainingBuf = remainingBuf[16:]
|
||
binary.BigEndian.PutUint32(remainingBuf[:], recordSize)
|
||
remainingBuf = remainingBuf[4:]
|
||
localPublicKeyBytes := localPublicKey.Bytes()
|
||
remainingBuf[0] = byte(len(localPublicKeyBytes))
|
||
remainingBuf = remainingBuf[1:]
|
||
copy(remainingBuf[:], localPublicKeyBytes)
|
||
remainingBuf = remainingBuf[len(localPublicKeyBytes):]
|
||
|
||
// Combine application keys with receiver's EC public key to derive ECDH shared secret
|
||
sharedECDHSecret, err := localPrivateKey.ECDH(keys.P256dh)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("deriving shared secret: %w", err)
|
||
}
|
||
|
||
// ikm
|
||
prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00"))
|
||
prkInfoBuf.Write(keys.P256dh.Bytes())
|
||
prkInfoBuf.Write(localPublicKey.Bytes())
|
||
|
||
prkHKDF := hkdf.New(sha256.New, sharedECDHSecret, keys.Auth[:], prkInfoBuf.Bytes())
|
||
ikm, err := getHKDFKey(prkHKDF, 32)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Derive Content Encryption Key
|
||
contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00")
|
||
contentHKDF := hkdf.New(sha256.New, ikm, salt, contentEncryptionKeyInfo)
|
||
contentEncryptionKey, err := getHKDFKey(contentHKDF, 16)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Derive the Nonce
|
||
nonceInfo := []byte("Content-Encoding: nonce\x00")
|
||
nonceHKDF := hkdf.New(sha256.New, ikm, salt, nonceInfo)
|
||
nonce, err := getHKDFKey(nonceHKDF, 12)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Cipher
|
||
c, err := aes.NewCipher(contentEncryptionKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
gcm, err := cipher.NewGCM(c)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// need 1 byte for the 0x02 delimiter, 16 bytes for the AEAD tag
|
||
if len(remainingBuf) < len(message)+17 {
|
||
return nil, ErrRecordSizeTooSmall
|
||
}
|
||
// Copy the message plaintext into the buffer
|
||
copy(remainingBuf[:], message[:])
|
||
// The plaintext to be encrypted will include the padding delimiter and the padding;
|
||
// cut off the final 16 bytes that are reserved for the AEAD tag
|
||
plaintext := remainingBuf[:len(remainingBuf)-16]
|
||
remainingBuf = remainingBuf[len(message):]
|
||
// Add padding delimiter
|
||
remainingBuf[0] = '\x02'
|
||
remainingBuf = remainingBuf[1:]
|
||
// The rest of the buffer is already zero-padded
|
||
|
||
// Encipher the plaintext in place, then add the AEAD tag at the end.
|
||
// "To reuse plaintext's storage for the encrypted output, use plaintext[:0]
|
||
// as dst. Otherwise, the remaining capacity of dst must not overlap plaintext."
|
||
gcm.Seal(plaintext[:0], nonce, plaintext, nil)
|
||
|
||
return recordBuf, nil
|
||
}
|
||
|
||
func sendNotification(ctx context.Context, endpoint string, options *Options, vapidAuthHeader string, body []byte) (*http.Response, error) {
|
||
// POST request
|
||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(body))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if ctx != nil {
|
||
req = req.WithContext(ctx)
|
||
}
|
||
|
||
req.Header.Set("Content-Encoding", "aes128gcm")
|
||
req.Header.Set("Content-Type", "application/octet-stream")
|
||
req.Header.Set("TTL", strconv.Itoa(options.TTL))
|
||
|
||
// Сheck the optional headers
|
||
if len(options.Topic) > 0 {
|
||
req.Header.Set("Topic", options.Topic)
|
||
}
|
||
|
||
if isValidUrgency(options.Urgency) {
|
||
req.Header.Set("Urgency", string(options.Urgency))
|
||
}
|
||
|
||
req.Header.Set("Authorization", vapidAuthHeader)
|
||
|
||
// Send the request
|
||
var client HTTPClient
|
||
if options.HTTPClient != nil {
|
||
client = options.HTTPClient
|
||
} else {
|
||
client = defaultHTTPClient
|
||
}
|
||
|
||
return client.Do(req)
|
||
}
|
||
|
||
// decodeSubscriptionKey decodes a base64 subscription key. Trailing '='
// padding is optional, and both the standard and the URL-safe alphabets are
// accepted: a key containing '+' or '/' is decoded with the standard alphabet,
// anything else with the URL-safe one.
func decodeSubscriptionKey(key string) ([]byte, error) {
	unpadded := strings.TrimRight(key, "=")

	encoding := base64.RawURLEncoding
	if strings.ContainsAny(unpadded, "+/") {
		encoding = base64.RawStdEncoding
	}
	return encoding.DecodeString(unpadded)
}
|
||
|
||
// getHKDFKey reads exactly length bytes of key material from r (an HKDF
// output stream) and returns them. If the reader cannot supply that many
// bytes, the error from io.ReadFull is returned and the key is nil, so a
// partially filled key can never be used by mistake.
func getHKDFKey(r io.Reader, length int) ([]byte, error) {
	key := make([]byte, length)
	// io.ReadFull returns a nil error only when the whole buffer was filled,
	// so no separate byte-count check is needed.
	if _, err := io.ReadFull(r, key); err != nil {
		return nil, err
	}
	return key, nil
}
|