3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-12-27 05:02:35 +01:00
ergo/vendor/github.com/xdg-go/scram/client_conv.go

150 lines
4.0 KiB
Go
Raw Normal View History

2021-07-30 18:20:13 +02:00
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package scram
import (
"crypto/hmac"
"encoding/base64"
"errors"
"fmt"
"strings"
)
// clientState identifies how far a ClientConversation has progressed.
type clientState int

const (
	// clientStarting: nothing has been sent; the next Step produces the
	// client-first message.
	clientStarting = clientState(iota)
	// clientFirst: the client-first message has been produced; the next
	// Step consumes the server-first message.
	clientFirst
	// clientFinal: the client-final message has been produced; the next
	// Step consumes the server-final message.
	clientFinal
	// clientDone: terminal state; any further Step is rejected.
	clientDone
)
// ClientConversation holds the client-side state of a single
// authentication exchange with a server. Conversations are single-use:
// create a fresh one for every authentication attempt.
type ClientConversation struct {
	// client supplies the username, authorization ID, nonce generator and
	// derived keys used while building messages.
	client *Client
	// nonceGen is not read by the methods visible in this file, which call
	// client.nonceGen instead.
	// NOTE(review): presumably populated by the constructor — confirm.
	nonceGen NonceGeneratorFcn
	// hashGen is handed to computeHMAC when signing the auth message.
	hashGen HashGeneratorFcn
	// minIters is the lowest iteration count accepted from the server.
	minIters int

	// state is the current position in the conversation state machine.
	state clientState
	// valid becomes true only after the server's signature has been
	// verified.
	valid bool

	// gs2 caches the GS2 header sent with the client-first message; it is
	// re-encoded into the "c=" attribute of the client-final message.
	gs2 string
	// nonce starts as the client nonce and is replaced by the combined
	// nonce returned in the server-first message.
	nonce string
	// c1b caches the client-first-message-bare for the auth message.
	c1b string
	// serveSig is the server signature computed while building the
	// client-final message, compared against the server-final verifier.
	serveSig []byte
}
// Step takes a string provided from a server (or just an empty string for the
// very first conversation step) and attempts to move the authentication
// conversation forward. It returns a string to be sent to the server or an
// error if the server message is invalid.
//
// Any error is terminal: the conversation is marked done, Valid keeps
// reporting false, and every later call fails with a "Conversation already
// completed" error. Calling Step after a conversation completes is likewise
// an error. On error the returned response is always empty.
func (cc *ClientConversation) Step(challenge string) (response string, err error) {
	switch cc.state {
	case clientStarting:
		cc.state = clientFirst
		response, err = cc.firstMsg()
	case clientFirst:
		cc.state = clientFinal
		response, err = cc.finalMsg(challenge)
	case clientFinal:
		cc.state = clientDone
		response, err = cc.validateServer(challenge)
	default:
		return "", errors.New("Conversation already completed")
	}

	if err != nil {
		// A failed step must not leave the conversation resumable. In
		// particular, finalMsg returns early without caching the server
		// signature, so a later challenge would be compared against an
		// empty signature, and hmac.Equal treats two empty slices as equal.
		// NOTE(review): whether an empty verifier can get through
		// parseServerFinal is not visible from here — confirm.
		cc.state = clientDone
		return "", err
	}
	return response, nil
}
// Done reports whether the conversation has reached its terminal state and
// will accept no further steps. Reaching that state does not imply success;
// use Valid to learn whether authentication succeeded.
func (cc *ClientConversation) Done() bool {
	terminal := cc.state == clientDone
	return terminal
}
// Valid reports whether the conversation authenticated successfully. It is
// true only once the signature sent in the server's final message has been
// checked against the one computed locally, which shows that the server
// actually holds the user's stored credentials.
func (cc *ClientConversation) Valid() bool {
	authenticated := cc.valid
	return authenticated
}
// firstMsg builds the client-first message: the GS2 header followed by the
// client-first-message-bare ("n=<user>,r=<nonce>"). The header, nonce and
// bare message are stored on the conversation because finalMsg needs them.
// It never returns a non-nil error.
func (cc *ClientConversation) firstMsg() (string, error) {
	header := cc.gs2Header()
	cc.gs2 = header

	// NOTE(review): this reads the client's nonce generator rather than the
	// conversation's own nonceGen field — confirm that is intended.
	clientNonce := cc.client.nonceGen()
	cc.nonce = clientNonce

	bare := fmt.Sprintf("n=%s,r=%s", encodeName(cc.client.username), clientNonce)
	cc.c1b = bare

	return header + bare, nil
}
// finalMsg consumes the server-first message s1 and returns the
// client-final message ("c=<gs2>,r=<nonce>,p=<proof>").
//
// It fails if s1 cannot be parsed, if the server's nonce does not begin with
// the nonce sent in the client-first message, or if the server asks for
// fewer iterations than minIters. On success the conversation's nonce is
// replaced with the server's combined nonce and the expected server
// signature is cached for validateServer.
func (cc *ClientConversation) finalMsg(s1 string) (string, error) {
	serverFirst, err := parseServerFirst(s1)
	if err != nil {
		return "", err
	}

	// The server must echo our nonce as a prefix of the combined nonce.
	if !strings.HasPrefix(serverFirst.nonce, cc.nonce) {
		return "", errors.New("server nonce did not extend client nonce")
	}
	cc.nonce = serverFirst.nonce

	// Refuse iteration counts below the configured minimum.
	if serverFirst.iters < cc.minIters {
		return "", fmt.Errorf("server requested too few iterations (%d)", serverFirst.iters)
	}

	// client-final-message-without-proof
	encodedGS2 := base64.StdEncoding.EncodeToString([]byte(cc.gs2))
	withoutProof := "c=" + encodedGS2 + ",r=" + cc.nonce

	// The auth message ties together all three messages exchanged so far.
	authMessage := cc.c1b + "," + s1 + "," + withoutProof

	// Keys for the salt and iteration count chosen by the server.
	keys := cc.client.getDerivedKeys(KeyFactors{
		Salt:  string(serverFirst.salt),
		Iters: serverFirst.iters,
	})

	// The proof is the client key XORed with the client signature.
	signature := computeHMAC(cc.hashGen, keys.StoredKey, []byte(authMessage))
	encodedProof := base64.StdEncoding.EncodeToString(xorBytes(keys.ClientKey, signature))

	// Remember what the server's signature must be.
	cc.serveSig = computeHMAC(cc.hashGen, keys.ServerKey, []byte(authMessage))

	return withoutProof + ",p=" + encodedProof, nil
}
// validateServer consumes the server-final message s2. It returns an error
// if s2 cannot be parsed, if the server reported an error, or if the
// server's verifier differs from the signature cached by finalMsg. Only when
// all checks pass is the conversation marked valid. The returned string is
// always empty because the client has nothing further to send.
func (cc *ClientConversation) validateServer(s2 string) (string, error) {
	serverFinal, parseErr := parseServerFinal(s2)

	switch {
	case parseErr != nil:
		return "", parseErr
	case len(serverFinal.err) != 0:
		return "", fmt.Errorf("server error: %s", serverFinal.err)
	case !hmac.Equal(serverFinal.verifier, cc.serveSig):
		// hmac.Equal compares without leaking timing information.
		return "", errors.New("server validation failed")
	}

	cc.valid = true
	return "", nil
}
// gs2Header returns the GS2 header that prefixes the client-first message.
// With no authorization ID configured it is "n,,"; otherwise the encoded
// authorization ID is placed between the two commas.
func (cc *ClientConversation) gs2Header() string {
	if authz := cc.client.authzID; authz != "" {
		return fmt.Sprintf("n,%s,", encodeName(authz))
	}
	return "n,,"
}