mirror of
https://github.com/ergochat/ergo.git
synced 2025-01-22 02:04:10 +01:00
324 lines
9.4 KiB
Go
324 lines
9.4 KiB
Go
|
package webpush
|
|||
|
|
|||
|
import (
|
|||
|
"bytes"
|
|||
|
"context"
|
|||
|
"crypto/aes"
|
|||
|
"crypto/cipher"
|
|||
|
"crypto/ecdh"
|
|||
|
"crypto/rand"
|
|||
|
"crypto/sha256"
|
|||
|
"encoding/base64"
|
|||
|
"encoding/binary"
|
|||
|
"encoding/json"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"io"
|
|||
|
"net/http"
|
|||
|
"strconv"
|
|||
|
"strings"
|
|||
|
"time"
|
|||
|
|
|||
|
"golang.org/x/crypto/hkdf"
|
|||
|
)
|
|||
|
|
|||
|
// MaxRecordSize is the record size EncryptNotification falls back to when it
// is given a record size of 0. Because a single record is produced, it is
// also the default length of an encrypted notification body.
const MaxRecordSize uint32 = 4096

var (
	// ErrRecordSizeTooSmall is returned by EncryptNotification when the
	// requested record size is below 128 bytes or cannot hold the message.
	ErrRecordSizeTooSmall = errors.New("record size too small for message")

	// invalidAuthKeyLength is returned by DecodeSubscriptionKeys when the
	// decoded authentication secret is not exactly 16 bytes long.
	invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)")

	// defaultHTTPClient is used by sendNotification when Options.HTTPClient
	// is nil. It is a zero-valued client with no timeout of its own.
	// NOTE(review): callers presumably bound requests via ctx — confirm.
	defaultHTTPClient = new(http.Client)
)
|
|||
|
|
|||
|
// HTTPClient is the subset of *http.Client used to deliver a notification.
// Supplying a custom implementation through Options.HTTPClient lets callers
// substitute their own transport, e.g. a fake in tests.
type HTTPClient interface {
	// Do sends req and returns the push service's response.
	Do(req *http.Request) (*http.Response, error)
}
|
|||
|
|
|||
|
// Options are config and extra params needed to send a notification
|
|||
|
type Options struct {
|
|||
|
HTTPClient HTTPClient // Will replace with *http.Client by default if not included
|
|||
|
RecordSize uint32 // Limit the record size
|
|||
|
Subscriber string // Sub in VAPID JWT token
|
|||
|
Topic string // Set the Topic header to collapse a pending messages (Optional)
|
|||
|
TTL int // Set the TTL on the endpoint POST request, in seconds
|
|||
|
Urgency Urgency // Set the Urgency header to change a message priority (Optional)
|
|||
|
VAPIDKeys *VAPIDKeys // VAPID public-private keypair to generate the VAPID Authorization header
|
|||
|
VapidExpiration time.Time // optional expiration for VAPID JWT token (defaults to now + 12 hours)
|
|||
|
}
|
|||
|
|
|||
|
// Keys represents a subscription's keys (its ECDH public key on the P-256 curve
// and its 16-byte authentication secret).
type Keys struct {
	Auth   [16]byte
	P256dh *ecdh.PublicKey
}

// Equal reports whether k and o hold the same authentication secret and the
// same ECDH public key.
//
// A nil P256dh is a valid value here rather than a crash: two Keys whose
// P256dh are both nil compare equal (given equal Auth), and a nil key never
// equals a non-nil one.
func (k *Keys) Equal(o Keys) bool {
	if k.Auth != o.Auth {
		return false
	}
	// (*ecdh.PublicKey).Equal is not nil-safe, so nil keys are handled here.
	if k.P256dh == nil || o.P256dh == nil {
		return k.P256dh == o.P256dh
	}
	return k.P256dh.Equal(o.P256dh)
}
|
|||
|
|
|||
|
var _ json.Marshaler = (*Keys)(nil)
|
|||
|
var _ json.Unmarshaler = (*Keys)(nil)
|
|||
|
|
|||
|
type marshaledKeys struct {
|
|||
|
Auth string `json:"auth"`
|
|||
|
P256dh string `json:"p256dh"`
|
|||
|
}
|
|||
|
|
|||
|
// MarshalJSON implements json.Marshaler, allowing serialization to JSON.
|
|||
|
func (k *Keys) MarshalJSON() ([]byte, error) {
|
|||
|
m := marshaledKeys{
|
|||
|
Auth: base64.RawStdEncoding.EncodeToString(k.Auth[:]),
|
|||
|
P256dh: base64.RawStdEncoding.EncodeToString(k.P256dh.Bytes()),
|
|||
|
}
|
|||
|
return json.Marshal(&m)
|
|||
|
}
|
|||
|
|
|||
|
// MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON.
|
|||
|
func (k *Keys) UnmarshalJSON(b []byte) (err error) {
|
|||
|
var m marshaledKeys
|
|||
|
if err := json.Unmarshal(b, &m); err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
authBytes, err := decodeSubscriptionKey(m.Auth)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
if len(authBytes) != 16 {
|
|||
|
return fmt.Errorf("invalid auth bytes length %d (must be 16)", len(authBytes))
|
|||
|
}
|
|||
|
copy(k.Auth[:], authBytes)
|
|||
|
rawDHKey, err := decodeSubscriptionKey(m.P256dh)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
k.P256dh, err = ecdh.P256().NewPublicKey(rawDHKey)
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
// DecodeSubscriptionKeys decodes and validates a base64-encoded pair of subscription keys
|
|||
|
// (the authentication secret and ECDH public key).
|
|||
|
func DecodeSubscriptionKeys(auth, p256dh string) (keys Keys, err error) {
|
|||
|
authBytes, err := decodeSubscriptionKey(auth)
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
if len(authBytes) != 16 {
|
|||
|
err = invalidAuthKeyLength
|
|||
|
return
|
|||
|
}
|
|||
|
copy(keys.Auth[:], authBytes)
|
|||
|
dhBytes, err := decodeSubscriptionKey(p256dh)
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
keys.P256dh, err = ecdh.P256().NewPublicKey(dhBytes)
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
// Subscription represents a PushSubscription object from the Push API
|
|||
|
type Subscription struct {
|
|||
|
Endpoint string `json:"endpoint"`
|
|||
|
Keys Keys `json:"keys"`
|
|||
|
}
|
|||
|
|
|||
|
// SendNotification sends a push notification to a subscription's endpoint,
|
|||
|
// applying encryption (RFC 8291) and adding a VAPID header (RFC 8292).
|
|||
|
func SendNotification(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) {
|
|||
|
// Compose message body (RFC8291 encryption of the message)
|
|||
|
body, err := EncryptNotification(message, s.Keys, options.RecordSize)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
// Get VAPID Authorization header
|
|||
|
vapidAuthHeader, err := getVAPIDAuthorizationHeader(
|
|||
|
s.Endpoint,
|
|||
|
options.Subscriber,
|
|||
|
options.VAPIDKeys,
|
|||
|
options.VapidExpiration,
|
|||
|
)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
// Compose and send the HTTP request
|
|||
|
return sendNotification(ctx, s.Endpoint, options, vapidAuthHeader, body)
|
|||
|
}
|
|||
|
|
|||
|
// EncryptNotification implements the encryption algorithm specified by RFC 8291 for web push
|
|||
|
// (RFC 8188's aes128gcm content-encoding, with the key material derived from
|
|||
|
// elliptic curve Diffie-Hellman over the P-256 curve).
|
|||
|
func EncryptNotification(message []byte, keys Keys, recordSize uint32) ([]byte, error) {
|
|||
|
// Get the record size
|
|||
|
if recordSize == 0 {
|
|||
|
recordSize = MaxRecordSize
|
|||
|
} else if recordSize < 128 {
|
|||
|
return nil, ErrRecordSizeTooSmall
|
|||
|
}
|
|||
|
|
|||
|
// Allocate buffer to hold the eventual message
|
|||
|
// [ header block ] [ ciphertext ] [ 16 byte AEAD tag ], totaling RecordSize bytes
|
|||
|
// the ciphertext is the encryption of: [ message ] [ \x02 ] [ 0 or more \x00 as needed ]
|
|||
|
recordBuf := make([]byte, recordSize)
|
|||
|
// remainingBuf tracks our current writing position in recordBuf:
|
|||
|
remainingBuf := recordBuf
|
|||
|
|
|||
|
// Application server key pairs (single use)
|
|||
|
localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
localPublicKey := localPrivateKey.PublicKey()
|
|||
|
|
|||
|
// Encryption Content-Coding Header
|
|||
|
// +-----------+--------+-----------+---------------+
|
|||
|
// | salt (16) | rs (4) | idlen (1) | keyid (idlen) |
|
|||
|
// +-----------+--------+-----------+---------------+
|
|||
|
// in our case the keyid is localPublicKey.Bytes(), so 65 bytes
|
|||
|
// First, generate the salt
|
|||
|
_, err = rand.Read(remainingBuf[:16])
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
salt := remainingBuf[:16]
|
|||
|
remainingBuf = remainingBuf[16:]
|
|||
|
binary.BigEndian.PutUint32(remainingBuf[:], recordSize)
|
|||
|
remainingBuf = remainingBuf[4:]
|
|||
|
localPublicKeyBytes := localPublicKey.Bytes()
|
|||
|
remainingBuf[0] = byte(len(localPublicKeyBytes))
|
|||
|
remainingBuf = remainingBuf[1:]
|
|||
|
copy(remainingBuf[:], localPublicKeyBytes)
|
|||
|
remainingBuf = remainingBuf[len(localPublicKeyBytes):]
|
|||
|
|
|||
|
// Combine application keys with receiver's EC public key to derive ECDH shared secret
|
|||
|
sharedECDHSecret, err := localPrivateKey.ECDH(keys.P256dh)
|
|||
|
if err != nil {
|
|||
|
return nil, fmt.Errorf("deriving shared secret: %w", err)
|
|||
|
}
|
|||
|
|
|||
|
// ikm
|
|||
|
prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00"))
|
|||
|
prkInfoBuf.Write(keys.P256dh.Bytes())
|
|||
|
prkInfoBuf.Write(localPublicKey.Bytes())
|
|||
|
|
|||
|
prkHKDF := hkdf.New(sha256.New, sharedECDHSecret, keys.Auth[:], prkInfoBuf.Bytes())
|
|||
|
ikm, err := getHKDFKey(prkHKDF, 32)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
// Derive Content Encryption Key
|
|||
|
contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00")
|
|||
|
contentHKDF := hkdf.New(sha256.New, ikm, salt, contentEncryptionKeyInfo)
|
|||
|
contentEncryptionKey, err := getHKDFKey(contentHKDF, 16)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
// Derive the Nonce
|
|||
|
nonceInfo := []byte("Content-Encoding: nonce\x00")
|
|||
|
nonceHKDF := hkdf.New(sha256.New, ikm, salt, nonceInfo)
|
|||
|
nonce, err := getHKDFKey(nonceHKDF, 12)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
// Cipher
|
|||
|
c, err := aes.NewCipher(contentEncryptionKey)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
gcm, err := cipher.NewGCM(c)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
// need 1 byte for the 0x02 delimiter, 16 bytes for the AEAD tag
|
|||
|
if len(remainingBuf) < len(message)+17 {
|
|||
|
return nil, ErrRecordSizeTooSmall
|
|||
|
}
|
|||
|
// Copy the message plaintext into the buffer
|
|||
|
copy(remainingBuf[:], message[:])
|
|||
|
// The plaintext to be encrypted will include the padding delimiter and the padding;
|
|||
|
// cut off the final 16 bytes that are reserved for the AEAD tag
|
|||
|
plaintext := remainingBuf[:len(remainingBuf)-16]
|
|||
|
remainingBuf = remainingBuf[len(message):]
|
|||
|
// Add padding delimiter
|
|||
|
remainingBuf[0] = '\x02'
|
|||
|
remainingBuf = remainingBuf[1:]
|
|||
|
// The rest of the buffer is already zero-padded
|
|||
|
|
|||
|
// Encipher the plaintext in place, then add the AEAD tag at the end.
|
|||
|
// "To reuse plaintext's storage for the encrypted output, use plaintext[:0]
|
|||
|
// as dst. Otherwise, the remaining capacity of dst must not overlap plaintext."
|
|||
|
gcm.Seal(plaintext[:0], nonce, plaintext, nil)
|
|||
|
|
|||
|
return recordBuf, nil
|
|||
|
}
|
|||
|
|
|||
|
func sendNotification(ctx context.Context, endpoint string, options *Options, vapidAuthHeader string, body []byte) (*http.Response, error) {
|
|||
|
// POST request
|
|||
|
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(body))
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
if ctx != nil {
|
|||
|
req = req.WithContext(ctx)
|
|||
|
}
|
|||
|
|
|||
|
req.Header.Set("Content-Encoding", "aes128gcm")
|
|||
|
req.Header.Set("Content-Type", "application/octet-stream")
|
|||
|
req.Header.Set("TTL", strconv.Itoa(options.TTL))
|
|||
|
|
|||
|
// Сheck the optional headers
|
|||
|
if len(options.Topic) > 0 {
|
|||
|
req.Header.Set("Topic", options.Topic)
|
|||
|
}
|
|||
|
|
|||
|
if isValidUrgency(options.Urgency) {
|
|||
|
req.Header.Set("Urgency", string(options.Urgency))
|
|||
|
}
|
|||
|
|
|||
|
req.Header.Set("Authorization", vapidAuthHeader)
|
|||
|
|
|||
|
// Send the request
|
|||
|
var client HTTPClient
|
|||
|
if options.HTTPClient != nil {
|
|||
|
client = options.HTTPClient
|
|||
|
} else {
|
|||
|
client = defaultHTTPClient
|
|||
|
}
|
|||
|
|
|||
|
return client.Do(req)
|
|||
|
}
|
|||
|
|
|||
|
// decodeSubscriptionKey decodes a base64 subscription key. Trailing '='
// padding is optional. The alphabet is picked from the content: if the key
// contains '+' or '/' it is decoded as standard base64, otherwise as URL-safe
// base64 ('-' and '_').
func decodeSubscriptionKey(key string) ([]byte, error) {
	unpadded := strings.TrimRight(key, "=")

	encoding := base64.RawURLEncoding
	if strings.ContainsAny(unpadded, "+/") {
		encoding = base64.RawStdEncoding
	}
	return encoding.DecodeString(unpadded)
}
|
|||
|
|
|||
|
// getHKDFKey reads exactly length bytes of key material from r, normally an
// HKDF output stream created with hkdf.New.
//
// If r yields fewer than length bytes, it returns a nil key together with the
// error from io.ReadFull (io.EOF or io.ErrUnexpectedEOF). A partially filled
// key is never returned.
func getHKDFKey(r io.Reader, length int) ([]byte, error) {
	key := make([]byte, length)
	// io.ReadFull returns a non-nil error exactly when it reads fewer than
	// len(key) bytes, so no separate length check is needed.
	if _, err := io.ReadFull(r, key); err != nil {
		return nil, err
	}
	return key, nil
}
|