// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package scram

import (
	"encoding/base64"
	"errors"
	"fmt"
	"strconv"
	"strings"
)

// c1Msg holds the pieces of a client-first message as extracted by
// parseClientFirst:
//
//   - gs2Header: the gs2 flag, a comma, the parsed authzID value (possibly
//     empty) and a trailing comma.
//   - authzID: value of the optional "a" field; left empty when the second
//     field of the message is empty.
//   - username: value of the "n" field.
//   - nonce: value of the "r" field.
//   - c1b: every field from the third one onward, rejoined with commas.
type c1Msg struct {
	gs2Header string
	authzID   string
	username  string
	nonce     string
	c1b       string
}

// c2Msg holds the pieces of a client-final message as extracted by
// parseClientFinal:
//
//   - cbind: base64-decoded value of the leading "c" field.
//   - nonce: value of the "r" field.
//   - proof: base64-decoded value of the "p" field, read from the last field
//     of the message.
//   - c2wop: the message text up to, but not including, its last comma.
type c2Msg struct {
	cbind []byte
	nonce string
	proof []byte
	c2wop string
}

// s1Msg holds the pieces of a server-first message as extracted by
// parseServerFirst:
//
//   - nonce: value of the "r" field.
//   - salt: base64-decoded value of the "s" field.
//   - iters: integer value of the "i" field.
type s1Msg struct {
	nonce string
	salt  []byte
	iters int
}

// s2Msg holds the result of parsing a server-final message with
// parseServerFinal:
//
//   - verifier: base64-decoded value of the "v" field, when the first field
//     parses as one.
//   - err: value of the "e" field, which is only consulted when the first
//     field could not be parsed as a "v" field.
type s2Msg struct {
	verifier []byte
	err      string
}

// parseField extracts the value of a "k=value" attribute. It returns the
// text of s that follows the leading "k=" marker, or an error when s does
// not begin with that marker. An empty value ("k=") is accepted and yields
// the empty string.
func parseField(s, k string) (string, error) {
	marker := k + "="
	if !strings.HasPrefix(s, marker) {
		return "", fmt.Errorf("error parsing '%s' for field '%s'", s, k)
	}
	return s[len(marker):], nil
}

// parseGS2Flag validates the channel-binding flag that opens a client-first
// message. It returns the flag unchanged when it is exactly "n" or "y".
// Any flag starting with 'p' is rejected because channel binding is not
// supported, and every other value, including the empty string, is
// reported as a parse error.
func parseGS2Flag(s string) (string, error) {
	// The length guard matters: the flag comes from splitting an untrusted
	// message on commas, so it is empty whenever the message starts with
	// a comma, and indexing s[0] unguarded would panic.
	if len(s) > 0 && s[0] == 'p' {
		return "", fmt.Errorf("channel binding requested but not supported")
	}

	if s == "n" || s == "y" {
		return s, nil
	}

	return "", fmt.Errorf("error parsing '%s' for gs2 flag", s)
}

// parseFieldBase64 extracts the value of a "k=value" attribute with
// parseField and decodes it as standard (padded) base64. On any failure the
// returned slice is nil and the underlying error is passed through as is.
func parseFieldBase64(s, k string) ([]byte, error) {
	var (
		encoded string
		decoded []byte
		err     error
	)

	if encoded, err = parseField(s, k); err != nil {
		return nil, err
	}

	if decoded, err = base64.StdEncoding.DecodeString(encoded); err != nil {
		// Drop whatever was decoded before the failure; callers only
		// ever see nil alongside an error.
		return nil, err
	}

	return decoded, nil
}

// parseFieldInt extracts the value of a "k=value" attribute with parseField
// and converts it to an int using strconv.Atoi. A missing "k=" marker is
// reported with parseField's error; a value that is not a valid integer is
// reported with an error naming the field. In both cases the returned
// number is 0.
func parseFieldInt(s, k string) (int, error) {
	text, err := parseField(s, k)
	if err != nil {
		return 0, err
	}

	value, convErr := strconv.Atoi(text)
	if convErr != nil {
		return 0, fmt.Errorf("error parsing field '%s': %v", k, convErr)
	}

	return value, nil
}

// parseClientFirst parses a client-first message of the form
//
//	gs2-flag,[a=authzid],n=username,r=nonce[,extensions...]
//
// It splits c1 on commas and requires at least four fields. The first field
// is validated with parseGS2Flag, the second (if non-empty) must be an "a"
// field, the third must be an "n" field and the fourth an "r" field. A third
// field starting with "m=" is rejected because message extensions are not
// supported.
//
// On success the returned c1Msg carries the rebuilt gs2 header, the authzID,
// username and nonce values, and the fields from the third onward rejoined
// with commas in c1b. On error the returned message may be partially
// populated and must not be used.
func parseClientFirst(c1 string) (msg c1Msg, err error) {
	fields := strings.Split(c1, ",")
	if len(fields) < 4 {
		return msg, errors.New("not enough fields in first client message")
	}

	gs2flag, err := parseGS2Flag(fields[0])
	if err != nil {
		return msg, err
	}

	// 'a' field is optional
	if len(fields[1]) > 0 {
		msg.authzID, err = parseField(fields[1], "a")
		if err != nil {
			return msg, err
		}
	}

	// Recombine and save the gs2 header.
	// NOTE(review): this uses the parsed authzID value, which no longer
	// carries its "a=" marker — confirm that consumers of gs2Header
	// expect this form.
	msg.gs2Header = gs2flag + "," + msg.authzID + ","

	// Check for unsupported extensions field "m".
	if strings.HasPrefix(fields[2], "m=") {
		return msg, errors.New("SCRAM message extensions are not supported")
	}

	msg.username, err = parseField(fields[2], "n")
	if err != nil {
		return msg, err
	}

	msg.nonce, err = parseField(fields[3], "r")
	if err != nil {
		return msg, err
	}

	// Everything after the gs2 header, including any trailing extension
	// fields, is kept verbatim.
	msg.c1b = strings.Join(fields[2:], ",")

	return msg, nil
}

// parseClientFinal parses a client-final message of the form
//
//	c=base64-binding,r=nonce[,extensions...],p=base64-proof
//
// It splits c2 on commas and requires at least three fields. The first field
// must be a base64 "c" field, the second an "r" field, and the last a base64
// "p" field; anything in between is ignored here.
//
// On success the returned c2Msg carries the decoded binding data, the nonce,
// the decoded proof, and in c2wop the message text preceding its last comma
// (that is, the message without its final field). On error the returned
// message may be partially populated and must not be used.
func parseClientFinal(c2 string) (msg c2Msg, err error) {
	fields := strings.Split(c2, ",")
	if len(fields) < 3 {
		return msg, errors.New("not enough fields in final client message")
	}

	msg.cbind, err = parseFieldBase64(fields[0], "c")
	if err != nil {
		return msg, err
	}

	msg.nonce, err = parseField(fields[1], "r")
	if err != nil {
		return msg, err
	}

	// Extension fields may come between nonce and proof, so we
	// grab the *last* field as proof.
	msg.proof, err = parseFieldBase64(fields[len(fields)-1], "p")
	if err != nil {
		return msg, err
	}

	// At least three fields means at least two commas, so LastIndex
	// cannot return -1 here.
	msg.c2wop = c2[:strings.LastIndex(c2, ",")]

	return msg, nil
}

// parseServerFirst parses a server-first message of the form
//
//	r=nonce,s=base64-salt,i=iterations[,extensions...]
//
// A message starting with "m=" is rejected up front because message
// extensions are not supported. Otherwise the message is split on commas,
// must contain at least three fields, and the first three are read as the
// "r", "s" (base64) and "i" (integer) fields respectively. Fields after the
// third are ignored. On error the returned message may be partially
// populated.
func parseServerFirst(s1 string) (msg s1Msg, err error) {
	if strings.HasPrefix(s1, "m=") {
		return msg, errors.New("SCRAM message extensions are not supported")
	}

	parts := strings.Split(s1, ",")
	if len(parts) < 3 {
		return msg, errors.New("not enough fields in first server message")
	}

	if msg.nonce, err = parseField(parts[0], "r"); err != nil {
		return msg, err
	}

	if msg.salt, err = parseFieldBase64(parts[1], "s"); err != nil {
		return msg, err
	}

	msg.iters, err = parseFieldInt(parts[2], "i")
	return msg, err
}

// parseServerFinal parses a server-final message, looking only at the text
// before the first comma. That field is first tried as a base64 "v" field;
// if that succeeds the decoded bytes are returned in verifier. Otherwise it
// is parsed as an "e" field whose value is returned in err, and any failure
// of that second attempt is the error reported to the caller.
func parseServerFinal(s2 string) (msg s2Msg, err error) {
	// Split always yields at least one element, so indexing is safe even
	// for an empty message.
	head := strings.Split(s2, ",")[0]

	if msg.verifier, err = parseFieldBase64(head, "v"); err == nil {
		return msg, nil
	}

	msg.err, err = parseField(head, "e")
	return msg, err
}