package claims

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"log"
)

// SignDefault signs plaintext with the supplied private key and returns the
// signature encoded with standard base64 (base64.StdEncoding).
//
// privateKey is passed to New, which parses it with
// x509.ParsePKCS1PrivateKey (PKCS #1, ASN.1 DER form); the signature itself
// is produced by Client.Sign. Every failure is logged and returned, and
// signature is empty in that case.
//
// New skips key parsing when privateKey is nil, which leaves the client
// without a signing key. That case is reported as an error here rather than
// being handed to the signer, where a nil key would panic.
func SignDefault(plaintext, privateKey []byte) (signature string, err error) {
	client, err := New(privateKey, nil)
	if err != nil {
		log.Println(err)
		return "", err
	}
	if client.PrivateKey == nil {
		err = errors.New("claims: private key is required for signing")
		log.Println(err)
		return "", err
	}

	raw, err := client.Sign(plaintext)
	if err != nil {
		log.Println(err)
		return "", err
	}
	return base64.StdEncoding.EncodeToString(raw), nil
}

// VerifyDefault checks that signature, given as standard base64, is a valid
// signature of plaintext under publicKey.
//
// The key is registered with New under the name "default" and the decoded
// signature is then checked with Client.Verify against that name. It returns
// nil when verification succeeds; a key that New rejects, malformed base64,
// or a failed verification is logged and returned as the error.
func VerifyDefault(plaintext, publicKey []byte, signature string) (err error) {
	const keyName = "default"

	client, err := New(nil, map[string][]byte{keyName: publicKey})
	if err != nil {
		log.Println(err)
		return err
	}

	rawSignature, err := base64.StdEncoding.DecodeString(signature)
	if err != nil {
		log.Println(err)
		return err
	}

	if err = client.Verify(plaintext, rawSignature, keyName); err != nil {
		log.Println(err)
	}
	return err
}

// Sign returns an RSASSA-PSS signature over the SHA-256 digest of plaintext,
// made with c.PrivateKey and randomness from crypto/rand.
//
// The salt length is rsa.PSSSaltLengthAuto. An error is returned when the
// client is nil or holds no private key (instead of panicking inside
// crypto/rsa), or when rsa.SignPSS itself fails.
func (c *Client) Sign(plaintext []byte) (signature []byte, err error) {
	if c == nil || c.PrivateKey == nil {
		return nil, errors.New("claims: client has no private key to sign with")
	}

	alg := crypto.SHA256
	digest := alg.New()
	digest.Write(plaintext) // hash.Hash documents that Write never returns an error

	opts := rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthAuto}
	return rsa.SignPSS(rand.Reader, c.PrivateKey, alg, digest.Sum(nil), &opts)
}

// Verify checks that signature is a valid RSASSA-PSS signature over the
// SHA-256 digest of plaintext, using the public key stored in c.PublicKeys
// under target.
//
// It returns nil when the signature is valid. An error is returned when the
// client is nil or has no key for target (instead of panicking inside
// crypto/rsa on a nil key), or when rsa.VerifyPSS rejects the signature.
func (c *Client) Verify(plaintext, signature []byte, target string) (err error) {
	var key *rsa.PublicKey
	if c != nil {
		// Indexing a nil map or an unknown name yields a nil key.
		key = c.PublicKeys[target]
	}
	if key == nil {
		return errors.New("claims: no public key for target \"" + target + "\"")
	}

	alg := crypto.SHA256
	digest := alg.New()
	digest.Write(plaintext) // hash.Hash documents that Write never returns an error

	opts := rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthAuto}
	return rsa.VerifyPSS(key, alg, digest.Sum(nil), signature, &opts)
}

// init configures the standard logger so each entry carries the date, the
// time, and the file name and line number of the logging call. The flags are
// set on the logger shared by the package-level log functions.
func init() {
	const flags = log.LstdFlags | log.Lshortfile
	log.SetFlags(flags)
}

// Client bundles the RSA key material used for signing and verification.
type Client struct {
	// PrivateKey is the key Sign uses to produce signatures. New leaves it
	// nil when it is given no private key bytes.
	PrivateKey *rsa.PrivateKey
	// PublicKeys holds verification keys by name; Verify picks one using its
	// target argument.
	PublicKeys map[string]*rsa.PublicKey
}

// New builds a Client from encoded RSA keys.
//
// privateKey, when non-nil, is parsed with x509.ParsePKCS1PrivateKey and
// stored in Client.PrivateKey. publicKeys, when non-nil, maps a key name to
// bytes parsed with x509.ParsePKCS1PublicKey; the parsed keys are stored
// under the same names in Client.PublicKeys. Either argument may be nil, in
// which case the corresponding field is left nil.
//
// The first parse failure is logged and returned. Public keys are visited in
// map iteration order, so with several bad keys the reported one is not
// fixed. On any error the returned client is nil, so callers never receive a
// half-initialised Client.
func New(privateKey []byte, publicKeys map[string][]byte) (client *Client, err error) {
	built := &Client{}

	if privateKey != nil {
		key, parseErr := x509.ParsePKCS1PrivateKey(privateKey)
		if parseErr != nil {
			log.Println(parseErr)
			return nil, parseErr
		}
		built.PrivateKey = key
	}

	if publicKeys != nil {
		parsed := make(map[string]*rsa.PublicKey, len(publicKeys))
		for name, encoded := range publicKeys {
			key, parseErr := x509.ParsePKCS1PublicKey(encoded)
			if parseErr != nil {
				log.Println(parseErr)
				return nil, parseErr
			}
			if key == nil {
				// Defensive guard: never store a nil key, since verifying
				// against one would panic.
				err = errors.New("Invalid Public Key Type")
				log.Println(err)
				return nil, err
			}
			parsed[name] = key
		}
		built.PublicKeys = parsed
	}

	return built, nil
}