package webpush import ( "bytes" "context" "crypto/aes" "crypto/cipher" "crypto/ecdh" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/binary" "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "strings" "time" "golang.org/x/crypto/hkdf" ) const MaxRecordSize uint32 = 4096 var ( ErrRecordSizeTooSmall = errors.New("record size too small for message") invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)") defaultHTTPClient = &http.Client{} ) // HTTPClient is an interface for sending the notification HTTP request / testing type HTTPClient interface { Do(*http.Request) (*http.Response, error) } // Options are config and extra params needed to send a notification type Options struct { HTTPClient HTTPClient // Will replace with *http.Client by default if not included RecordSize uint32 // Limit the record size Subscriber string // Sub in VAPID JWT token Topic string // Set the Topic header to collapse a pending messages (Optional) TTL int // Set the TTL on the endpoint POST request, in seconds Urgency Urgency // Set the Urgency header to change a message priority (Optional) VAPIDKeys *VAPIDKeys // VAPID public-private keypair to generate the VAPID Authorization header VapidExpiration time.Time // optional expiration for VAPID JWT token (defaults to now + 12 hours) } // Keys represents a subscription's keys (its ECDH public key on the P-256 curve // and its 16-byte authentication secret). type Keys struct { Auth [16]byte P256dh *ecdh.PublicKey } // Equal compares two Keys for equality. func (k *Keys) Equal(o Keys) bool { return k.Auth == o.Auth && k.P256dh.Equal(o.P256dh) } var _ json.Marshaler = (*Keys)(nil) var _ json.Unmarshaler = (*Keys)(nil) type marshaledKeys struct { Auth string `json:"auth"` P256dh string `json:"p256dh"` } // MarshalJSON implements json.Marshaler, allowing serialization to JSON. func (k *Keys) MarshalJSON() ([]byte, error) { m := marshaledKeys{ Auth: base64.RawStdEncoding.EncodeToString(k.Auth[:]), P256dh: base64.RawStdEncoding.EncodeToString(k.P256dh.Bytes()), } return json.Marshal(&m) } // MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON. func (k *Keys) UnmarshalJSON(b []byte) (err error) { var m marshaledKeys if err := json.Unmarshal(b, &m); err != nil { return err } authBytes, err := decodeSubscriptionKey(m.Auth) if err != nil { return err } if len(authBytes) != 16 { return fmt.Errorf("invalid auth bytes length %d (must be 16)", len(authBytes)) } copy(k.Auth[:], authBytes) rawDHKey, err := decodeSubscriptionKey(m.P256dh) if err != nil { return err } k.P256dh, err = ecdh.P256().NewPublicKey(rawDHKey) return err } // DecodeSubscriptionKeys decodes and validates a base64-encoded pair of subscription keys // (the authentication secret and ECDH public key). func DecodeSubscriptionKeys(auth, p256dh string) (keys Keys, err error) { authBytes, err := decodeSubscriptionKey(auth) if err != nil { return } if len(authBytes) != 16 { err = invalidAuthKeyLength return } copy(keys.Auth[:], authBytes) dhBytes, err := decodeSubscriptionKey(p256dh) if err != nil { return } keys.P256dh, err = ecdh.P256().NewPublicKey(dhBytes) if err != nil { return } return } // Subscription represents a PushSubscription object from the Push API type Subscription struct { Endpoint string `json:"endpoint"` Keys Keys `json:"keys"` } // SendNotification sends a push notification to a subscription's endpoint, // applying encryption (RFC 8291) and adding a VAPID header (RFC 8292). func SendNotification(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) { // Compose message body (RFC8291 encryption of the message) body, err := EncryptNotification(message, s.Keys, options.RecordSize) if err != nil { return nil, err } // Get VAPID Authorization header vapidAuthHeader, err := getVAPIDAuthorizationHeader( s.Endpoint, options.Subscriber, options.VAPIDKeys, options.VapidExpiration, ) if err != nil { return nil, err } // Compose and send the HTTP request return sendNotification(ctx, s.Endpoint, options, vapidAuthHeader, body) } // EncryptNotification implements the encryption algorithm specified by RFC 8291 for web push // (RFC 8188's aes128gcm content-encoding, with the key material derived from // elliptic curve Diffie-Hellman over the P-256 curve). func EncryptNotification(message []byte, keys Keys, recordSize uint32) ([]byte, error) { // Get the record size if recordSize == 0 { recordSize = MaxRecordSize } else if recordSize < 128 { return nil, ErrRecordSizeTooSmall } // Allocate buffer to hold the eventual message // [ header block ] [ ciphertext ] [ 16 byte AEAD tag ], totaling RecordSize bytes // the ciphertext is the encryption of: [ message ] [ \x02 ] [ 0 or more \x00 as needed ] recordBuf := make([]byte, recordSize) // remainingBuf tracks our current writing position in recordBuf: remainingBuf := recordBuf // Application server key pairs (single use) localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { return nil, err } localPublicKey := localPrivateKey.PublicKey() // Encryption Content-Coding Header // +-----------+--------+-----------+---------------+ // | salt (16) | rs (4) | idlen (1) | keyid (idlen) | // +-----------+--------+-----------+---------------+ // in our case the keyid is localPublicKey.Bytes(), so 65 bytes // First, generate the salt _, err = rand.Read(remainingBuf[:16]) if err != nil { return nil, err } salt := remainingBuf[:16] remainingBuf = remainingBuf[16:] binary.BigEndian.PutUint32(remainingBuf[:], recordSize) remainingBuf = remainingBuf[4:] localPublicKeyBytes := localPublicKey.Bytes() remainingBuf[0] = byte(len(localPublicKeyBytes)) remainingBuf = remainingBuf[1:] copy(remainingBuf[:], localPublicKeyBytes) remainingBuf = remainingBuf[len(localPublicKeyBytes):] // Combine application keys with receiver's EC public key to derive ECDH shared secret sharedECDHSecret, err := localPrivateKey.ECDH(keys.P256dh) if err != nil { return nil, fmt.Errorf("deriving shared secret: %w", err) } // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) prkInfoBuf.Write(keys.P256dh.Bytes()) prkInfoBuf.Write(localPublicKey.Bytes()) prkHKDF := hkdf.New(sha256.New, sharedECDHSecret, keys.Auth[:], prkInfoBuf.Bytes()) ikm, err := getHKDFKey(prkHKDF, 32) if err != nil { return nil, err } // Derive Content Encryption Key contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00") contentHKDF := hkdf.New(sha256.New, ikm, salt, contentEncryptionKeyInfo) contentEncryptionKey, err := getHKDFKey(contentHKDF, 16) if err != nil { return nil, err } // Derive the Nonce nonceInfo := []byte("Content-Encoding: nonce\x00") nonceHKDF := hkdf.New(sha256.New, ikm, salt, nonceInfo) nonce, err := getHKDFKey(nonceHKDF, 12) if err != nil { return nil, err } // Cipher c, err := aes.NewCipher(contentEncryptionKey) if err != nil { return nil, err } gcm, err := cipher.NewGCM(c) if err != nil { return nil, err } // need 1 byte for the 0x02 delimiter, 16 bytes for the AEAD tag if len(remainingBuf) < len(message)+17 { return nil, ErrRecordSizeTooSmall } // Copy the message plaintext into the buffer copy(remainingBuf[:], message[:]) // The plaintext to be encrypted will include the padding delimiter and the padding; // cut off the final 16 bytes that are reserved for the AEAD tag plaintext := remainingBuf[:len(remainingBuf)-16] remainingBuf = remainingBuf[len(message):] // Add padding delimiter remainingBuf[0] = '\x02' remainingBuf = remainingBuf[1:] // The rest of the buffer is already zero-padded // Encipher the plaintext in place, then add the AEAD tag at the end. // "To reuse plaintext's storage for the encrypted output, use plaintext[:0] // as dst. Otherwise, the remaining capacity of dst must not overlap plaintext." gcm.Seal(plaintext[:0], nonce, plaintext, nil) return recordBuf, nil } func sendNotification(ctx context.Context, endpoint string, options *Options, vapidAuthHeader string, body []byte) (*http.Response, error) { // POST request req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(body)) if err != nil { return nil, err } if ctx != nil { req = req.WithContext(ctx) } req.Header.Set("Content-Encoding", "aes128gcm") req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("TTL", strconv.Itoa(options.TTL)) // Сheck the optional headers if len(options.Topic) > 0 { req.Header.Set("Topic", options.Topic) } if isValidUrgency(options.Urgency) { req.Header.Set("Urgency", string(options.Urgency)) } req.Header.Set("Authorization", vapidAuthHeader) // Send the request var client HTTPClient if options.HTTPClient != nil { client = options.HTTPClient } else { client = defaultHTTPClient } return client.Do(req) } // decodeSubscriptionKey decodes a base64 subscription key. func decodeSubscriptionKey(key string) ([]byte, error) { key = strings.TrimRight(key, "=") if strings.IndexByte(key, '+') != -1 || strings.IndexByte(key, '/') != -1 { return base64.RawStdEncoding.DecodeString(key) } return base64.RawURLEncoding.DecodeString(key) } // Returns a key of length "length" given an hkdf function func getHKDFKey(hkdf io.Reader, length int) ([]byte, error) { key := make([]byte, length) n, err := io.ReadFull(hkdf, key) if n != len(key) || err != nil { return key, err } return key, nil }