package webpush import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "net/url" "strings" "time" jwt "github.com/golang-jwt/jwt/v5" ) // VAPIDKeys is a public-private keypair for use in VAPID. // It marshals to a JSON string containing the PEM of the PKCS8 // of the private key. type VAPIDKeys struct { privateKey *ecdsa.PrivateKey publicKey string // raw bytes encoding in urlsafe base64, as per RFC } // PublicKeyString returns the base64url-encoded uncompressed public key of the keypair, // as defined in RFC8292. func (v *VAPIDKeys) PublicKeyString() string { return v.publicKey } // PrivateKey returns the private key of the keypair. func (v *VAPIDKeys) PrivateKey() *ecdsa.PrivateKey { return v.privateKey } // Equal compares two VAPIDKeys for equality. func (v *VAPIDKeys) Equal(o *VAPIDKeys) bool { return v.privateKey.Equal(o.privateKey) } var _ json.Marshaler = (*VAPIDKeys)(nil) var _ json.Unmarshaler = (*VAPIDKeys)(nil) // MarshalJSON implements json.Marshaler, allowing serialization to JSON. func (v *VAPIDKeys) MarshalJSON() ([]byte, error) { pkcs8bytes, err := x509.MarshalPKCS8PrivateKey(v.privateKey) if err != nil { return nil, err } pemBlock := pem.Block{ Type: "PRIVATE KEY", Bytes: pkcs8bytes, } pemBytes := pem.EncodeToMemory(&pemBlock) if pemBytes == nil { return nil, fmt.Errorf("could not encode VAPID keys as PEM") } return json.Marshal(string(pemBytes)) } // MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON. func (v *VAPIDKeys) UnmarshalJSON(b []byte) error { var pemKey string if err := json.Unmarshal(b, &pemKey); err != nil { return err } pemBlock, _ := pem.Decode([]byte(pemKey)) if pemBlock == nil { return fmt.Errorf("could not decode PEM block with VAPID keys") } privKey, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) if err != nil { return err } privateKey, ok := privKey.(*ecdsa.PrivateKey) if !ok { return fmt.Errorf("Invalid type of private key %T", privateKey) } if privateKey.Curve != elliptic.P256() { return fmt.Errorf("Invalid curve for private key %v", privateKey.Curve) } publicKeyStr, err := makePublicKeyString(privateKey) if err != nil { return err // should not be possible since we confirmed P256 already } // success v.privateKey = privateKey v.publicKey = publicKeyStr return nil } // GenerateVAPIDKeys generates a VAPID keypair (an ECDSA keypair on // the P-256 curve). func GenerateVAPIDKeys() (result *VAPIDKeys, err error) { private, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return } pubKeyECDH, err := private.PublicKey.ECDH() if err != nil { return } publicKey := base64.RawURLEncoding.EncodeToString(pubKeyECDH.Bytes()) return &VAPIDKeys{ privateKey: private, publicKey: publicKey, }, nil } // ECDSAToVAPIDKeys wraps an existing ecdsa.PrivateKey in VAPIDKeys for use in // VAPID header signing. func ECDSAToVAPIDKeys(privKey *ecdsa.PrivateKey) (result *VAPIDKeys, err error) { if privKey.Curve != elliptic.P256() { return nil, fmt.Errorf("Invalid curve for private key %v", privKey.Curve) } publicKeyString, err := makePublicKeyString(privKey) if err != nil { return nil, err } return &VAPIDKeys{ privateKey: privKey, publicKey: publicKeyString, }, nil } func makePublicKeyString(privKey *ecdsa.PrivateKey) (result string, err error) { // to get the raw bytes we have to convert the public key to *ecdh.PublicKey // this type assertion (from the crypto.PublicKey returned by (*ecdsa.PrivateKey).Public() // to *ecdsa.PublicKey) cannot fail: publicKey, err := privKey.Public().(*ecdsa.PublicKey).ECDH() if err != nil { return // should not be possible if we confirmed P256 already } return base64.RawURLEncoding.EncodeToString(publicKey.Bytes()), nil } // getVAPIDAuthorizationHeader func getVAPIDAuthorizationHeader( endpoint string, subscriber string, vapidKeys *VAPIDKeys, expiration time.Time, ) (string, error) { if expiration.IsZero() { expiration = time.Now().Add(time.Hour * 12) } // Create the JWT token subURL, err := url.Parse(endpoint) if err != nil { return "", err } // Unless subscriber is an HTTPS URL, assume an e-mail address if !strings.HasPrefix(subscriber, "https:") && !strings.HasPrefix(subscriber, "mailto:") { subscriber = "mailto:" + subscriber } token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ "aud": subURL.Scheme + "://" + subURL.Host, "exp": expiration.Unix(), "sub": subscriber, }) // Sign token with private key jwtString, err := token.SignedString(vapidKeys.privateKey) if err != nil { return "", err } return "vapid t=" + jwtString + ", k=" + vapidKeys.publicKey, nil }