/*
 * Copyright (C) by
 *   MetraLabs GmbH (MLAB), GERMANY
 * and
 *   Neuroinformatics and Cognitive Robotics Labs (NICR) at TU Ilmenau, GERMANY
 * All rights reserved.
 *
 * Contact: info@mira-project.org
 *
 * Commercial Usage:
 *   Licensees holding valid commercial licenses may use this file in
 *   accordance with the commercial license agreement provided with the
 *   software or, alternatively, in accordance with the terms contained in
 *   a written agreement between you and MLAB or NICR.
 *
 * GNU General Public License Usage:
 *   Alternatively, this file may be used under the terms of the GNU
 *   General Public License version 3.0 as published by the Free Software
 *   Foundation and appearing in the file LICENSE.GPL3 included in the
 *   packaging of this file. Please review the following information to
 *   ensure the GNU General Public License version 3.0 requirements will be
 *   met: http://www.gnu.org/copyleft/gpl.html.
 *   Alternatively you may (at your option) use any later version of the GNU
 *   General Public License if such license has been publicly approved by
 *   MLAB and NICR (or its successors, if any).
 *
 * IN NO EVENT SHALL "MLAB" OR "NICR" BE LIABLE TO ANY PARTY FOR DIRECT,
 * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
 * THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF "MLAB" OR
 * "NICR" HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * "MLAB" AND "NICR" SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
 * BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
 * ON AN "AS IS" BASIS, AND "MLAB" AND "NICR" HAVE NO OBLIGATION TO
 * PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS OR MODIFICATIONS.
 */

/**
 * @file RSAKey.C
 *    Implementation of RSAKey.h
 *
 * @author Christian Martin
 * @date   2023/12/xx
 */

#include <security/RSAKey.h>

#include <vector>

#include <boost/format.hpp>
#include <boost/random.hpp>

#include <utils/StringAlgorithms.h>
#include <utils/Time.h>

#include <error/Exceptions.h>

#include "../OpenSSLHelper.h"

#include <openssl/core_names.h>
#include <openssl/encoder.h>
#include <openssl/decoder.h>
#include <openssl/param_build.h>
#include <openssl/rand.h>

using namespace std;

namespace mira {

///////////////////////////////////////////////////////////////////////////////
// some internal helper functions

namespace impl {

///////////////////////////////////////////////////////////////////////////////

/// Names of the RSA parameters that are copied when duplicating the public
/// part of a key. The list is terminated by a null pointer.
static const char* OSSL_PRMS_PUBLIC[] = {
	OSSL_PKEY_PARAM_RSA_N, // modulus
	OSSL_PKEY_PARAM_RSA_E, // public exponent
	nullptr
};

/// Names of the RSA parameters that are copied when duplicating a complete
/// key pair. The list is terminated by a null pointer.
static const char* OSSL_PRMS_PRIVATE[] = {
	OSSL_PKEY_PARAM_RSA_N,            // modulus
	OSSL_PKEY_PARAM_RSA_E,            // public exponent
	OSSL_PKEY_PARAM_RSA_D,            // private exponent
	OSSL_PKEY_PARAM_RSA_FACTOR1,      // first prime factor
	OSSL_PKEY_PARAM_RSA_FACTOR2,      // second prime factor
	OSSL_PKEY_PARAM_RSA_EXPONENT1,    // first CRT exponent
	OSSL_PKEY_PARAM_RSA_EXPONENT2,    // second CRT exponent
	OSSL_PKEY_PARAM_RSA_COEFFICIENT1, // first CRT coefficient
	nullptr
};

/**
 * Creates a new RSA key that contains only the components listed in
 * @p prmList, copied from @p pkey.
 *
 * @param pkey         The source key.
 * @param keySelection The OpenSSL key selection (e.g. EVP_PKEY_PUBLIC_KEY or
 *                     EVP_PKEY_KEYPAIR) used for export and import.
 * @param prmList      NULL-terminated list of BIGNUM parameter names to copy.
 * @return The newly created key. The caller owns it and has to release it
 *         with EVP_PKEY_free().
 * @throw XSystemCall if any of the OpenSSL calls fails.
 */
EVP_PKEY* duplicateKey(EVP_PKEY* pkey, int keySelection, const char* prmList[])
{
	ERR_clear_error();

	////////////////////////////////////////////////////////////////
	// extract the required key components

	// The exported array is never read; the call only verifies that the
	// requested selection can be exported from the source key.
	OSSL_PARAM* params = NULL;
	if (!EVP_PKEY_todata(pkey, keySelection, &params))
		MIRA_THROW(XSystemCall, "EVP_PKEY_todata failed.");
	OSSL_PARAM_free(params);

	OSSL_PARAM_BLD* bld = OSSL_PARAM_BLD_new();
	if (bld == NULL) {
		auto errNo = ERR_get_error();
		MIRA_THROW(XSystemCall, "OSSL_PARAM_BLD_new failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	// EVP_PKEY_get_bn_param() hands out newly allocated BIGNUMs. The builder
	// only references them, so they have to stay alive until
	// OSSL_PARAM_BLD_to_param() was called and must be released afterwards.
	std::vector<BIGNUM*> components;
	auto releaseComponents = [&components]() {
		for(BIGNUM* bn : components)
			BN_clear_free(bn);
		components.clear();
	};

	for(int i = 0; prmList[i] != NULL; i++) {
		BIGNUM* bn = NULL;
		if (!EVP_PKEY_get_bn_param(pkey, prmList[i], &bn)) {
			auto errNo = ERR_get_error();
			releaseComponents();
			OSSL_PARAM_BLD_free(bld);
			MIRA_THROW(XSystemCall, "EVP_PKEY_get_bn_param failed for '" <<
			           prmList[i] << "': " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}
		components.push_back(bn);

		if (!OSSL_PARAM_BLD_push_BN(bld, prmList[i], bn)) {
			auto errNo = ERR_get_error();
			releaseComponents();
			OSSL_PARAM_BLD_free(bld);
			MIRA_THROW(XSystemCall, "OSSL_PARAM_BLD_push_BN failed for '" <<
			           prmList[i] << "': " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}
	}

	// build new parameter list
	OSSL_PARAM* newPrms = OSSL_PARAM_BLD_to_param(bld);
	unsigned long buildErrNo = 0;
	if (newPrms == NULL)
		buildErrNo = ERR_get_error();

	// some first cleanup (the parameter array holds its own copy of the data)
	OSSL_PARAM_BLD_free(bld);
	releaseComponents();

	if (newPrms == NULL) {
		MIRA_THROW(XSystemCall, "OSSL_PARAM_BLD_to_param failed: " <<
		           OpenSSLErrorString::instance().err2str(buildErrNo));
	}

	////////////////////////////////////////////////////////////////
	// create the new key from the parameter list

	EVP_PKEY_CTX* ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL);
	if (!ctx) {
		OSSL_PARAM_free(newPrms);
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_new_id failed:" <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (EVP_PKEY_fromdata_init(ctx) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		OSSL_PARAM_free(newPrms);
		MIRA_THROW(XSystemCall, "EVP_PKEY_fromdata_init failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
	EVP_PKEY* newKey = NULL;
	if (EVP_PKEY_fromdata(ctx, &newKey, keySelection, newPrms) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		OSSL_PARAM_free(newPrms);
		MIRA_THROW(XSystemCall, "EVP_PKEY_fromdata failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	// final cleanup
	EVP_PKEY_CTX_free(ctx);
	OSSL_PARAM_free(newPrms);

	return newKey;
}

/**
 * Reads a BIGNUM parameter of a key and converts it to a hexadecimal string.
 *
 * @param key   The key to query.
 * @param param Name of the BIGNUM parameter.
 * @return The hexadecimal representation produced by BN_bn2hex(), or an
 *         empty string if the key does not provide the parameter.
 * @throw XSystemCall if the conversion into a string fails.
 */
string getKeyParamAsString(EVP_PKEY* key, const char* param)
{
	BIGNUM* value = NULL;
	EVP_PKEY_get_bn_param(key, param, &value);

	// parameter is not available for this key
	if (!value)
		return "";

	char* hex = BN_bn2hex(value);
	if (!hex) {
		BN_clear_free(value);
		MIRA_THROW(XSystemCall, "Unable to convert BIGNUM for '" <<
		           param << "' to string.");
	}

	string result(hex);

	// wipe the temporary copies before releasing them
	const size_t hexLen = strlen(hex);
	OPENSSL_clear_free(hex, hexLen);
	BN_clear_free(value);

	return result;
}

///////////////////////////////////////////////////////////////////////////////

}

///////////////////////////////////////////////////////////////////////////////
// Implementation of RSAKey

/// Creates a key object that holds a fresh, empty EVP_PKEY.
RSAKey::RSAKey()
{
	// touch the OpenSSL cleanup singleton before any OpenSSL object is created
	OpenSSLCleanup::instance();

	auto* wrapper = new OpenSSLRSAWrapper();
	wrapper->key = EVP_PKEY_new();
	mKey = wrapper;
}

/**
 * Copy constructor. Duplicates the underlying OpenSSL key of @p key.
 *
 * @param key The key to copy.
 * @throw XSystemCall if the key could not be duplicated.
 */
RSAKey::RSAKey(const RSAKey& key)
{
	OpenSSLCleanup::instance();

	ERR_clear_error();

	mKey = new OpenSSLRSAWrapper();
	mKey->key = EVP_PKEY_dup(key.mKey->key);

	if (mKey->key == NULL) {
		auto errNo = ERR_get_error();
		// The destructor is not executed when a constructor throws, so the
		// wrapper has to be released here to avoid leaking it.
		delete mKey;
		mKey = NULL;
		MIRA_THROW(XSystemCall, "Failed to duplicate RSA key: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
}

/**
 * Creates a key from its components given as hexadecimal strings.
 *
 * If @p d is not empty a key pair is created, otherwise a public key.
 * A component whose string cannot be parsed by BN_hex2bn() is set to zero.
 *
 * @param n The modulus as hex string.
 * @param e The public exponent as hex string.
 * @param d The private exponent as hex string (may be empty).
 * @throw XSystemCall if one of the OpenSSL calls fails.
 */
RSAKey::RSAKey(const string& n, const string& e, const string& d)
{
	OpenSSLCleanup::instance();

	mKey = new OpenSSLRSAWrapper();
	mKey->key = EVP_PKEY_new();

	// All temporary resources are declared up front, so that one single
	// handler can release them if anything below throws.
	BIGNUM* key_n = NULL;
	BIGNUM* key_e = NULL;
	BIGNUM* key_d = NULL;
	OSSL_PARAM_BLD* bld = NULL;
	OSSL_PARAM* params = NULL;
	EVP_PKEY_CTX* ctx = NULL;

	try {
		// make sure that reported errors belong to the calls below
		ERR_clear_error();

		////////////////////////////////////////////////////////////////
		// convert strings to BIGNUMs

		// Returns NULL for an empty string; an unparsable string yields zero.
		auto hexToBN = [](const string& hex) -> BIGNUM* {
			if (hex.empty())
				return NULL;
			BIGNUM* bn = BN_secure_new();
			if (bn == NULL) {
				auto errNo = ERR_get_error();
				MIRA_THROW(XSystemCall, "BN_secure_new failed: " <<
				           OpenSSLErrorString::instance().err2str(errNo));
			}
			if (BN_hex2bn(&bn, hex.c_str()) == 0)
				BN_zero(bn);
			return bn;
		};

		key_n = hexToBN(n);
		key_e = hexToBN(e);
		key_d = hexToBN(d);

		const int keyType = key_d ? EVP_PKEY_KEYPAIR : EVP_PKEY_PUBLIC_KEY;

		////////////////////////////////////////////////////////////////
		// build OSSL_PARAM array

		bld = OSSL_PARAM_BLD_new();
		if (bld == NULL) {
			auto errNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "OSSL_PARAM_BLD_new failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}
		OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_N, key_n);
		OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_E, key_e);
		OSSL_PARAM_BLD_push_BN(bld, OSSL_PKEY_PARAM_RSA_D, key_d);

		params = OSSL_PARAM_BLD_to_param(bld);
		if (params == NULL) {
			auto errNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "OSSL_PARAM_BLD_to_param failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}

		// some first cleanup (the parameter array has its own copy of the data)
		OSSL_PARAM_BLD_free(bld);
		bld = NULL;
		BN_clear_free(key_n);
		key_n = NULL;
		BN_clear_free(key_e);
		key_e = NULL;
		BN_clear_free(key_d);
		key_d = NULL;

		////////////////////////////////////////////////////////////////
		// create key from OSSL_PARAM array

		ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL);
		if (!ctx) {
			MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_new_id failed:" <<
			           OpenSSLErrorString::instance().err2str(ERR_get_error()));
		}
		if (EVP_PKEY_fromdata_init(ctx) <= 0) {
			auto errNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_PKEY_fromdata_init failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}
		if (EVP_PKEY_fromdata(ctx, &mKey->key, keyType, params) <= 0) {
			auto errNo = ERR_get_error();
			MIRA_THROW(XSystemCall, "EVP_PKEY_fromdata failed: " <<
			           OpenSSLErrorString::instance().err2str(errNo));
		}
	} catch(...) {
		// Release everything that was allocated so far. The destructor is not
		// executed when a constructor throws, so the wrapper and its key have
		// to be released here as well. All free functions accept NULL.
		EVP_PKEY_CTX_free(ctx);
		OSSL_PARAM_free(params);
		OSSL_PARAM_BLD_free(bld);
		BN_clear_free(key_n);
		BN_clear_free(key_e);
		BN_clear_free(key_d);
		EVP_PKEY_free(mKey->key);
		delete mKey;
		mKey = NULL;
		throw;
	}

	////////////////////////////////////////////////////////////////
	// final cleanup

	EVP_PKEY_CTX_free(ctx);
	OSSL_PARAM_free(params);
}

/// Releases the OpenSSL key and the wrapper object that holds it.
RSAKey::~RSAKey()
{
	auto* const wrapper = mKey;

	EVP_PKEY_free(wrapper->key);
	delete wrapper;

	mKey = nullptr;
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Creates a key object as a copy of the key stored in an OpenSSL wrapper.
 *
 * @param key The wrapper that holds the key to duplicate. Must not be NULL.
 * @throw XInvalidParameter if @p key is NULL.
 * @throw XSystemCall if the key could not be duplicated.
 */
RSAKey::RSAKey(const OpenSSLRSAWrapper* key)
{
	OpenSSLCleanup::instance();

	// Validate the argument before allocating anything: the destructor is not
	// executed when a constructor throws, so a wrapper allocated earlier
	// would be leaked.
	if (key == NULL)
		MIRA_THROW(XInvalidParameter, "Key must not be NULL.");

	mKey = new OpenSSLRSAWrapper();

	ERR_clear_error();

	mKey->key = EVP_PKEY_dup(key->key);
	if (mKey->key == NULL) {
		unsigned long tErrNo = ERR_get_error();
		// release the wrapper, since the destructor will not be called
		delete mKey;
		mKey = NULL;
		MIRA_THROW(XSystemCall, "Failed to duplicate RSA key: " <<
		           OpenSSLErrorString::instance().err2str(tErrNo));
	}
}

///////////////////////////////////////////////////////////////////////////////

bool RSAKey::isValid() const
{
	BIGNUM* key_n = NULL;
	BIGNUM* key_e = NULL;
	BIGNUM* key_d = NULL;

	EVP_PKEY_get_bn_param(mKey->key, OSSL_PKEY_PARAM_RSA_N, &key_n);
	EVP_PKEY_get_bn_param(mKey->key, OSSL_PKEY_PARAM_RSA_E, &key_e);
	EVP_PKEY_get_bn_param(mKey->key, OSSL_PKEY_PARAM_RSA_D, &key_d);

	bool res =
		(key_n != NULL) && (!BN_is_zero(key_n)) &&
		(((key_e != NULL) && (!BN_is_zero(key_e))) ||
		 ((key_d != NULL) && (!BN_is_zero(key_d))));

	BN_clear_free(key_n);
	BN_clear_free(key_e);
	BN_clear_free(key_d);

	return res;
}

bool RSAKey::isPublicKey() const
{
	BIGNUM* key_n = NULL;
	BIGNUM* key_e = NULL;

	EVP_PKEY_get_bn_param(mKey->key, OSSL_PKEY_PARAM_RSA_N, &key_n);
	EVP_PKEY_get_bn_param(mKey->key, OSSL_PKEY_PARAM_RSA_E, &key_e);

	bool res =
			(key_n != NULL) && (!BN_is_zero(key_n)) &&
			(key_e != NULL) && (!BN_is_zero(key_e));

	BN_clear_free(key_n);
	BN_clear_free(key_e);

	return res;
}

bool RSAKey::isPrivateKey() const
{
	BIGNUM* key_n = NULL;
	BIGNUM* key_d = NULL;

	EVP_PKEY_get_bn_param(mKey->key, OSSL_PKEY_PARAM_RSA_N, &key_n);
	EVP_PKEY_get_bn_param(mKey->key, OSSL_PKEY_PARAM_RSA_D, &key_d);

	bool res =
		(key_n != NULL) && (!BN_is_zero(key_n)) &&
		(key_d != NULL) && (!BN_is_zero(key_d));

	BN_clear_free(key_n);
	BN_clear_free(key_d);

	return res;
}

bool RSAKey::clear()
{
	EVP_PKEY_free(mKey->key);
	mKey->key = EVP_PKEY_new();

	return(true);
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Assignment operator. Replaces the key of this object by a duplicate of the
 * key stored in @p key.
 *
 * If the duplication fails, this object keeps its previous key.
 *
 * @param key The key to copy.
 * @return Reference to this object.
 * @throw XSystemCall if the key could not be duplicated.
 */
RSAKey& RSAKey::operator=(const RSAKey& key)
{
	// Self-assignment: freeing our key first would also destroy the source.
	if (this == &key)
		return(*this);

	ERR_clear_error();

	// Duplicate first and release the old key only after the copy succeeded,
	// so a failure does not leave this object without a key.
	EVP_PKEY* copy = EVP_PKEY_dup(key.mKey->key);
	if (copy == NULL) {
		auto errNo = ERR_get_error();
		MIRA_THROW(XSystemCall, "Failed to duplicate RSA key: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	EVP_PKEY_free(mKey->key);
	mKey->key = copy;

	return(*this);
}

/**
 * Compares two keys.
 *
 * @param key The key to compare with.
 * @return true, if both keys are valid and EVP_PKEY_eq() reports them as
 *         equal.
 */
bool RSAKey::operator==(const RSAKey& key)
{
	// an invalid key never equals anything
	if (!isValid() || !key.isValid())
		return false;

	const int cmp = EVP_PKEY_eq(mKey->key, key.mKey->key);
	return cmp == 1;
}

/// Negation of operator==.
bool RSAKey::operator!=(const RSAKey& key)
{
	const bool equal = operator==(key);
	return !equal;
}

///////////////////////////////////////////////////////////////////////////////

/// Returns the modulus n as hexadecimal string (empty if not available).
string RSAKey::getNStr() const
{
	EVP_PKEY* const pkey = mKey->key;
	return impl::getKeyParamAsString(pkey, OSSL_PKEY_PARAM_RSA_N);
}

/// Returns the exponent e as hexadecimal string (empty if not available).
string RSAKey::getEStr() const
{
	EVP_PKEY* const pkey = mKey->key;
	return impl::getKeyParamAsString(pkey, OSSL_PKEY_PARAM_RSA_E);
}

/// Returns the exponent d as hexadecimal string (empty if not available).
string RSAKey::getDStr() const
{
	EVP_PKEY* const pkey = mKey->key;
	return impl::getKeyParamAsString(pkey, OSSL_PKEY_PARAM_RSA_D);
}

///////////////////////////////////////////////////////////////////////////////

void RSAKey::generateKey(unsigned int iKeyBitLength,
                         RSAKey &oPublicKey, RSAKey &oPrivateKey)
{
	// We should use at least 1024 bits!
	if (iKeyBitLength < 1024)
		MIRA_THROW(XInvalidParameter,
		          "Key bit length should be at least 1024 bits.");

	feedRandomNumberGenerator(iKeyBitLength);
	ERR_clear_error();

	// prepare context
	EVP_PKEY_CTX* ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL);
	if (!ctx) {
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_new_id failed:" <<
		           OpenSSLErrorString::instance().err2str(ERR_get_error()));
	}
	if (EVP_PKEY_keygen_init(ctx) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		MIRA_THROW(XSystemCall, "EVP_PKEY_keygen_init failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
	if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx, iKeyBitLength) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		MIRA_THROW(XSystemCall, "EVP_PKEY_CTX_set_rsa_keygen_bits failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	// generate a new key
	EVP_PKEY* pkey = NULL;
	if (EVP_PKEY_keygen(ctx, &pkey) <= 0) {
		auto errNo = ERR_get_error();
		EVP_PKEY_CTX_free(ctx);
		MIRA_THROW(XSystemCall, "EVP_PKEY_keygen failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	// extract the public key
	oPublicKey.clear();
	EVP_PKEY_free(oPublicKey.mKey->key);
	oPublicKey.mKey->key = impl::duplicateKey(pkey, EVP_PKEY_PUBLIC_KEY, impl::OSSL_PRMS_PUBLIC);

	// extract the private key
	oPrivateKey.clear();
	EVP_PKEY_free(oPrivateKey.mKey->key);
	oPrivateKey.mKey->key = impl::duplicateKey(pkey, EVP_PKEY_KEYPAIR, impl::OSSL_PRMS_PRIVATE);

	// cleanup
	EVP_PKEY_free(pkey);
	EVP_PKEY_CTX_free(ctx);
}

void RSAKey::feedRandomNumberGenerator(size_t count)
{
	boost::mt19937 generator(Time::now().toUnixTimestamp());
	boost::uniform_int<> t256(0, 255);
	boost::variate_generator<boost::mt19937&, boost::uniform_int<> >
		tRand(generator, t256);

	// Feed OpenSSL's pseudo random number generator with interesting data :-)
	uint8* randData = new uint8[count];
	for(unsigned int i = 0; i < count; i++)
		randData[i] = (uint8)tRand();
	RAND_seed(randData, count);
	delete [] randData;
}

///////////////////////////////////////////////////////////////////////////////

/**
 * Writes a key to a stream using the format
 * {PUBLIC|PRIVATE}:Len:HexData;
 * where Len is the number of DER encoded bytes and HexData contains these
 * bytes as two hex digits each.
 *
 * @param stream The output stream.
 * @param key    The key to write.
 * @return The output stream.
 * @throw XSystemCall if the key could not be encoded.
 */
ostream& operator<<(ostream& stream, const RSAKey& key)
{
	// the key type prefix is written before the key gets encoded
	const bool hasPrivatePart = key.isPrivateKey();
	const int selection =
		hasPrivatePart ? EVP_PKEY_KEYPAIR : EVP_PKEY_PUBLIC_KEY;
	stream << (hasPrivatePart ? "PRIVATE:" : "PUBLIC:");

	////////////////////////////////////////////////////////////////
	// encode the key into the binary DER format

	OSSL_ENCODER_CTX* encoder =
		OSSL_ENCODER_CTX_new_for_pkey(key.mKey->key, selection, "DER",
		                              "type-specific", NULL);
	if (!encoder) {
		auto errNo = ERR_get_error();
		MIRA_THROW(XSystemCall, "OSSL_ENCODER_CTX_new_for_pkey failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}
	if (OSSL_ENCODER_CTX_get_num_encoders(encoder) == 0) {
		auto errNo = ERR_get_error();
		OSSL_ENCODER_CTX_free(encoder);
		MIRA_THROW(XSystemCall, "OSSL_ENCODER_CTX_get_num_encoders failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	unsigned char* der = NULL;
	size_t derLen = 0;
	if (!OSSL_ENCODER_to_data(encoder, &der, &derLen)) {
		auto errNo = ERR_get_error();
		OSSL_ENCODER_CTX_free(encoder);
		MIRA_THROW(XSystemCall, "OSSL_ENCODER_to_data failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	////////////////////////////////////////////////////////////////
	// write the length followed by the bytes as hex digits

	stream << derLen << ":";
	for(const unsigned char* p = der; p != der + derLen; ++p)
		stream << boost::format("%02x") % int(*p);
	stream << ";";

	// wipe and release the encoded key, then release the encoder
	OPENSSL_clear_free(der, derLen);
	OSSL_ENCODER_CTX_free(encoder);

	return stream;
}

/**
 * Reads a key from a stream.
 *
 * One whitespace delimited token in the following format is expected:
 * {PUBLIC|PRIVATE}:Len:HexData;
 * where Len is the number of DER encoded bytes and HexData contains these
 * bytes as two hex digits each.
 *
 * The key stored in @p key is only replaced if the data was decoded
 * successfully.
 *
 * @param stream The input stream.
 * @param key    Receives the decoded key.
 * @return The input stream.
 * @throw XInvalidParameter if the token does not match the expected format.
 * @throw XSystemCall if the key data could not be decoded.
 */
istream& operator>>(istream& stream, RSAKey& key)
{
	string strData;
	stream >> strData;

	// Expected format: {PUBLIC|PRIVATE}:Len:HexData;
	vector<string> parts;
	boost::algorithm::split(parts, strData, boost::is_from_range(':',':'));

	///////////////////////////////////////////////////////////////////////////
	// check content of the stream

	if (parts.size() != 3) {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. "
		           "Format should be: {PUBLIC|PRIVATE}:Len:HexData;");
	}
	if ((parts[0] != "PUBLIC") && (parts[0] != "PRIVATE")) {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. "
		           "Key type must be PUBLIC or PRIVATE.");
	}
	const bool isPrivate = (parts[0] == "PRIVATE");

	// a negative length must not be converted into a huge unsigned value
	const int declaredLen = boost::lexical_cast<int>(parts[1]);
	if (declaredLen < 0) {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. Invalid number of data bytes.");
	}
	const size_t len = static_cast<size_t>(declaredLen);

	// two hex digits per byte plus the terminating ';'
	const string& hexData = parts[2];
	if (hexData.size() != (2*len+1)) {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. Invalid number of data bytes.");
	}
	if (hexData[2*len] != ';') {
		MIRA_THROW(XInvalidParameter,
		           "Unexpected stream data. Missing terminating ';'.");
	}

	// Returns the value of a hex digit or -1 for any other character.
	auto hexValue = [](char c) -> int {
		if ((c >= '0') && (c <= '9'))
			return c - '0';
		if ((c >= 'a') && (c <= 'f'))
			return c - 'a' + 10;
		if ((c >= 'A') && (c <= 'F'))
			return c - 'A' + 10;
		return -1;
	};

	// validate all digits before any memory is allocated
	for(size_t i = 0; i < 2*len; i++) {
		if (hexValue(hexData[i]) < 0) {
			MIRA_THROW(XInvalidParameter,
			           "Unexpected stream data. Invalid hex digit.");
		}
	}

	////////////////////////////////////////////////////////////////
	// convert string data into binary buffer

	unsigned char* buffer = (unsigned char*)OPENSSL_malloc(len);
	if ((buffer == NULL) && (len > 0))
		MIRA_THROW(XSystemCall, "OPENSSL_malloc failed.");

	for(size_t i = 0; i < len; i++) {
		const int hi = hexValue(hexData[2*i]);
		const int lo = hexValue(hexData[2*i+1]);
		buffer[i] = static_cast<unsigned char>((hi << 4) | lo);
	}

	////////////////////////////////////////////////////////////////
	// convert DER binary format into a EVP_PKEY

	ERR_clear_error();

	EVP_PKEY *pkey = NULL;
	const char *structure = "type-specific";

	int selection = isPrivate ? EVP_PKEY_KEYPAIR : EVP_PKEY_PUBLIC_KEY;

	OSSL_DECODER_CTX* ctx =
		OSSL_DECODER_CTX_new_for_pkey(&pkey, "DER", structure, "RSA",
		                              selection, NULL, NULL);
	if (ctx == NULL) {
		auto errNo = ERR_get_error();
		OPENSSL_clear_free(buffer, len);
		MIRA_THROW(XSystemCall, "OSSL_DECODER_CTX_new_for_pkey failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	// The decoder advances the data pointer and updates the length to the
	// number of remaining bytes. Separate variables are used for that, since
	// 'buffer' and 'len' are still needed to wipe the whole buffer.
	const unsigned char* dataPtr = buffer;
	size_t remaining = len;
	if (!OSSL_DECODER_from_data(ctx, &dataPtr, &remaining) || (pkey == NULL)) {
		auto errNo = ERR_get_error();
		OSSL_DECODER_CTX_free(ctx);
		OPENSSL_clear_free(buffer, len);
		MIRA_THROW(XSystemCall, "OSSL_DECODER_from_data failed: " <<
		           OpenSSLErrorString::instance().err2str(errNo));
	}

	// move key to object
	EVP_PKEY_free(key.mKey->key);
	key.mKey->key = pkey;

	// cleanup
	OSSL_DECODER_CTX_free(ctx);
	OPENSSL_clear_free(buffer, len);

	////////////////////////////////////////////////////////////////

	return(stream);
}

///////////////////////////////////////////////////////////////////////////////

} // namespace
