// SPDX-License-Identifier: BSD-2-Clause
/*
 * Copyright (c) 2018-2020, Linaro Limited
 */

#include <assert.h>
#include <compiler.h>
#include <mbedtls/nist_kw.h>
#include <tee_api_defines.h>
#include <tee_internal_api.h>
#include <tee_internal_api_extensions.h>

#include "attributes.h"
#include "pkcs11_helpers.h"
#include "pkcs11_token.h"
#include "processing.h"
#include "serializer.h"

/*
 * Tell whether @proc_id is one of the asymmetric (RSA or EC) mechanisms
 * whose processing is implemented by this file on top of the GPD TEE
 * asymmetric operations.
 */
bool processing_is_tee_asymm(uint32_t proc_id)
{
	static const uint32_t tee_asymm_mechanisms[] = {
		/* RSA flavors */
		PKCS11_CKM_RSA_AES_KEY_WRAP,
		PKCS11_CKM_RSA_PKCS,
		PKCS11_CKM_RSA_PKCS_OAEP,
		PKCS11_CKM_RSA_PKCS_PSS,
		PKCS11_CKM_MD5_RSA_PKCS,
		PKCS11_CKM_SHA1_RSA_PKCS,
		PKCS11_CKM_SHA224_RSA_PKCS,
		PKCS11_CKM_SHA256_RSA_PKCS,
		PKCS11_CKM_SHA384_RSA_PKCS,
		PKCS11_CKM_SHA512_RSA_PKCS,
		PKCS11_CKM_SHA1_RSA_PKCS_PSS,
		PKCS11_CKM_SHA224_RSA_PKCS_PSS,
		PKCS11_CKM_SHA256_RSA_PKCS_PSS,
		PKCS11_CKM_SHA384_RSA_PKCS_PSS,
		PKCS11_CKM_SHA512_RSA_PKCS_PSS,
		/* EC flavors */
		PKCS11_CKM_EDDSA,
		PKCS11_CKM_ECDSA,
		PKCS11_CKM_ECDH1_DERIVE,
		PKCS11_CKM_ECDSA_SHA1,
		PKCS11_CKM_ECDSA_SHA224,
		PKCS11_CKM_ECDSA_SHA256,
		PKCS11_CKM_ECDSA_SHA384,
		PKCS11_CKM_ECDSA_SHA512,
	};
	size_t idx = 0;

	for (idx = 0; idx < ARRAY_SIZE(tee_asymm_mechanisms); idx++)
		if (tee_asymm_mechanisms[idx] == proc_id)
			return true;

	return false;
}

static enum pkcs11_rc
pkcs2tee_algorithm(uint32_t *tee_id, uint32_t *tee_hash_id,
		   enum processing_func function __unused,
		   struct pkcs11_attribute_head *proc_params,
		   struct pkcs11_object *obj)
{
	static const struct {
		enum pkcs11_mechanism_id mech_id;
		uint32_t tee_id;
		uint32_t tee_hash_id;
	} pkcs2tee_algo[] = {
		/* RSA flavors */
		{ PKCS11_CKM_RSA_AES_KEY_WRAP, 1, 0 },
		{ PKCS11_CKM_RSA_PKCS, TEE_ALG_RSAES_PKCS1_V1_5, 0 },
		{ PKCS11_CKM_RSA_PKCS_OAEP, 1, 0 },
		{ PKCS11_CKM_RSA_PKCS_PSS, 1, 0 },
		{ PKCS11_CKM_MD5_RSA_PKCS, TEE_ALG_RSASSA_PKCS1_V1_5_MD5,
		  TEE_ALG_MD5 },
		{ PKCS11_CKM_SHA1_RSA_PKCS, TEE_ALG_RSASSA_PKCS1_V1_5_SHA1,
		  TEE_ALG_SHA1 },
		{ PKCS11_CKM_SHA224_RSA_PKCS, TEE_ALG_RSASSA_PKCS1_V1_5_SHA224,
		  TEE_ALG_SHA224 },
		{ PKCS11_CKM_SHA256_RSA_PKCS, TEE_ALG_RSASSA_PKCS1_V1_5_SHA256,
		  TEE_ALG_SHA256 },
		{ PKCS11_CKM_SHA384_RSA_PKCS, TEE_ALG_RSASSA_PKCS1_V1_5_SHA384,
		  TEE_ALG_SHA384 },
		{ PKCS11_CKM_SHA512_RSA_PKCS, TEE_ALG_RSASSA_PKCS1_V1_5_SHA512,
		  TEE_ALG_SHA512 },
		{ PKCS11_CKM_SHA1_RSA_PKCS_PSS,
		  TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1, TEE_ALG_SHA1 },
		{ PKCS11_CKM_SHA224_RSA_PKCS_PSS,
		  TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224, TEE_ALG_SHA224 },
		{ PKCS11_CKM_SHA256_RSA_PKCS_PSS,
		  TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256, TEE_ALG_SHA256 },
		{ PKCS11_CKM_SHA384_RSA_PKCS_PSS,
		  TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384, TEE_ALG_SHA384 },
		{ PKCS11_CKM_SHA512_RSA_PKCS_PSS,
		  TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512, TEE_ALG_SHA512 },
		/* EC flavors (Must find key size from the object) */
		{ PKCS11_CKM_ECDSA, 1, 0 },
		{ PKCS11_CKM_ECDSA_SHA1, 1, TEE_ALG_SHA1 },
		{ PKCS11_CKM_ECDSA_SHA224, 1, TEE_ALG_SHA224 },
		{ PKCS11_CKM_ECDSA_SHA256, 1, TEE_ALG_SHA256 },
		{ PKCS11_CKM_ECDSA_SHA384, 1, TEE_ALG_SHA384 },
		{ PKCS11_CKM_ECDSA_SHA512, 1, TEE_ALG_SHA512 },
		{ PKCS11_CKM_ECDH1_DERIVE, 1, 0 },
		{ PKCS11_CKM_EDDSA, TEE_ALG_ED25519, 0 },
	};
	size_t n = 0;
	enum pkcs11_rc rc = PKCS11_CKR_GENERAL_ERROR;

	for (n = 0; n < ARRAY_SIZE(pkcs2tee_algo); n++) {
		if (pkcs2tee_algo[n].mech_id == proc_params->id) {
			*tee_id = pkcs2tee_algo[n].tee_id;
			*tee_hash_id = pkcs2tee_algo[n].tee_hash_id;
			break;
		}
	}

	if (n == ARRAY_SIZE(pkcs2tee_algo))
		return PKCS11_RV_NOT_IMPLEMENTED;

	switch (proc_params->id) {
	case PKCS11_CKM_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA1_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA224_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA256_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA384_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA512_RSA_PKCS_PSS:
		rc = pkcs2tee_algo_rsa_pss(tee_id, proc_params);
		break;
	case PKCS11_CKM_RSA_PKCS_OAEP:
		rc = pkcs2tee_algo_rsa_oaep(tee_id, tee_hash_id, proc_params);
		break;
	case PKCS11_CKM_RSA_AES_KEY_WRAP:
		rc = pkcs2tee_algo_rsa_aes_wrap(tee_id, tee_hash_id,
						proc_params);
		break;
	case PKCS11_CKM_ECDSA:
	case PKCS11_CKM_ECDSA_SHA1:
	case PKCS11_CKM_ECDSA_SHA224:
	case PKCS11_CKM_ECDSA_SHA256:
	case PKCS11_CKM_ECDSA_SHA384:
	case PKCS11_CKM_ECDSA_SHA512:
		rc = pkcs2tee_algo_ecdsa(tee_id, proc_params, obj);
		break;
	case PKCS11_CKM_ECDH1_DERIVE:
		rc = pkcs2tee_algo_ecdh(tee_id, proc_params, obj);
		break;
	default:
		rc = PKCS11_CKR_OK;
		break;
	}

	/*
	 * PKCS#11 uses single mechanism CKM_RSA_PKCS for both ciphering and
	 * authentication whereas GPD TEE expects TEE_ALG_RSAES_PKCS1_V1_5 for
	 * ciphering and TEE_ALG_RSASSA_PKCS1_V1_5 for authentication.
	 */
	if (*tee_id == TEE_ALG_RSAES_PKCS1_V1_5 &&
	    (function == PKCS11_FUNCTION_SIGN ||
	     function == PKCS11_FUNCTION_VERIFY))
		*tee_id = TEE_ALG_RSASSA_PKCS1_V1_5;

	return rc;
}

/*
 * Select the GPD TEE object type used to load the key held by @obj.
 *
 * Only public and private key objects are accepted; any other class, or
 * an unsupported key type, panics the TA. EC keys are loaded as ECDH
 * objects when @function is a derivation and as ECDSA objects otherwise.
 * Always returns PKCS11_CKR_OK.
 */
static enum pkcs11_rc pkcs2tee_key_type(uint32_t *tee_type,
					struct pkcs11_object *obj,
					enum processing_func function)
{
	enum pkcs11_class_id class = get_class(obj->attributes);
	enum pkcs11_key_type type = get_key_type(obj->attributes);
	bool is_private = false;

	if (class != PKCS11_CKO_PUBLIC_KEY && class != PKCS11_CKO_PRIVATE_KEY)
		TEE_Panic(class);

	is_private = (class == PKCS11_CKO_PRIVATE_KEY);

	switch (type) {
	case PKCS11_CKK_EC:
		if (function == PKCS11_FUNCTION_DERIVE)
			*tee_type = is_private ? TEE_TYPE_ECDH_KEYPAIR :
						 TEE_TYPE_ECDH_PUBLIC_KEY;
		else
			*tee_type = is_private ? TEE_TYPE_ECDSA_KEYPAIR :
						 TEE_TYPE_ECDSA_PUBLIC_KEY;
		break;
	case PKCS11_CKK_RSA:
		*tee_type = is_private ? TEE_TYPE_RSA_KEYPAIR :
					 TEE_TYPE_RSA_PUBLIC_KEY;
		break;
	case PKCS11_CKK_EC_EDWARDS:
		*tee_type = is_private ? TEE_TYPE_ED25519_KEYPAIR :
					 TEE_TYPE_ED25519_PUBLIC_KEY;
		break;
	default:
		TEE_Panic(type);
		break;
	}

	return PKCS11_CKR_OK;
}

/*
 * Allocate the TEE operation(s) needed to run mechanism @params for
 * @function with key @obj.
 *
 * When the mechanism hashes its input inside the TA, a digest operation is
 * allocated in processing->tee_op_handle2 and its algorithm recorded in
 * processing->tee_hash_algo. The asymmetric operation itself is allocated
 * in processing->tee_op_handle.
 *
 * On failure, any digest operation allocated here is released again so the
 * processing state is left as it was found. Unsupported algorithms are
 * reported as PKCS11_CKR_MECHANISM_INVALID.
 */
static enum pkcs11_rc
allocate_tee_operation(struct pkcs11_session *session,
		       enum processing_func function,
		       struct pkcs11_attribute_head *params,
		       struct pkcs11_object *obj)
{
	uint32_t size = (uint32_t)get_object_key_bit_size(obj);
	uint32_t algo = 0;
	uint32_t hash_algo = 0;
	uint32_t mode = 0;
	uint32_t hash_mode = 0;
	TEE_Result res = TEE_ERROR_GENERIC;
	struct active_processing *processing = session->processing;

	assert(processing->tee_op_handle == TEE_HANDLE_NULL);
	assert(processing->tee_op_handle2 == TEE_HANDLE_NULL);

	if (pkcs2tee_algorithm(&algo, &hash_algo, function, params, obj))
		return PKCS11_CKR_FUNCTION_FAILED;

	pkcs2tee_mode(&mode, function);

	if (hash_algo) {
		pkcs2tee_mode(&hash_mode, PKCS11_FUNCTION_DIGEST);

		res = TEE_AllocateOperation(&processing->tee_op_handle2,
					    hash_algo, hash_mode, 0);
		if (res) {
			EMSG("TEE_AllocateOp. failed %#"PRIx32" %#"PRIx32,
			     hash_algo, hash_mode);

			if (res == TEE_ERROR_NOT_SUPPORTED)
				return PKCS11_CKR_MECHANISM_INVALID;
			return tee2pkcs_error(res);
		}
		processing->tee_hash_algo = hash_algo;
	}

	res = TEE_AllocateOperation(&processing->tee_op_handle,
				    algo, mode, size);
	if (res) {
		EMSG("TEE_AllocateOp. failed %#"PRIx32" %#"PRIx32" %#"PRIx32,
		     algo, mode, size);

		/*
		 * Undo the digest operation allocated above for every failure
		 * cause, including TEE_ERROR_NOT_SUPPORTED.
		 */
		if (processing->tee_op_handle2 != TEE_HANDLE_NULL) {
			TEE_FreeOperation(processing->tee_op_handle2);
			processing->tee_op_handle2 = TEE_HANDLE_NULL;
			processing->tee_hash_algo = 0;
		}

		if (res == TEE_ERROR_NOT_SUPPORTED)
			return PKCS11_CKR_MECHANISM_INVALID;
	}

	return tee2pkcs_error(res);
}

static enum pkcs11_rc load_tee_key(struct pkcs11_session *session,
				   struct pkcs11_object *obj,
				   enum processing_func function)
{
	TEE_Attribute *tee_attrs = NULL;
	size_t tee_attrs_count = 0;
	size_t object_size = 0;
	enum pkcs11_rc rc = PKCS11_CKR_GENERAL_ERROR;
	TEE_Result res = TEE_ERROR_GENERIC;
	enum pkcs11_class_id __maybe_unused class = get_class(obj->attributes);
	enum pkcs11_key_type type = get_key_type(obj->attributes);

	assert(class == PKCS11_CKO_PUBLIC_KEY ||
	       class == PKCS11_CKO_PRIVATE_KEY);

	if (obj->key_handle != TEE_HANDLE_NULL) {
		switch (type) {
		case PKCS11_CKK_RSA:
			/* RSA loaded keys can be reused */
			assert((obj->key_type == TEE_TYPE_RSA_PUBLIC_KEY &&
				class == PKCS11_CKO_PUBLIC_KEY) ||
			       (obj->key_type == TEE_TYPE_RSA_KEYPAIR &&
				class == PKCS11_CKO_PRIVATE_KEY));
			goto key_ready;
		case PKCS11_CKK_EC:
			/* Reuse EC TEE key only if already DSA or DH */
			switch (obj->key_type) {
			case TEE_TYPE_ECDSA_PUBLIC_KEY:
			case TEE_TYPE_ECDSA_KEYPAIR:
				if (function != PKCS11_FUNCTION_DERIVE)
					goto key_ready;
				break;
			case TEE_TYPE_ECDH_PUBLIC_KEY:
			case TEE_TYPE_ECDH_KEYPAIR:
				if (function == PKCS11_FUNCTION_DERIVE)
					goto key_ready;
				break;
			default:
				assert(0);
				break;
			}
			break;
		default:
			assert(0);
			break;
		}

		TEE_CloseObject(obj->key_handle);
		obj->key_handle = TEE_HANDLE_NULL;
	}

	rc = pkcs2tee_key_type(&obj->key_type, obj, function);
	if (rc)
		return rc;

	object_size = get_object_key_bit_size(obj);
	if (!object_size)
		return PKCS11_CKR_GENERAL_ERROR;

	switch (type) {
	case PKCS11_CKK_RSA:
		rc = load_tee_rsa_key_attrs(&tee_attrs, &tee_attrs_count, obj);
		break;
	case PKCS11_CKK_EC:
		rc = load_tee_ec_key_attrs(&tee_attrs, &tee_attrs_count, obj);
		break;
	case PKCS11_CKK_EC_EDWARDS:
		rc = load_tee_eddsa_key_attrs(&tee_attrs, &tee_attrs_count,
					      obj);
		break;
	default:
		break;
	}
	if (rc)
		return rc;

	res = TEE_AllocateTransientObject(obj->key_type, object_size,
					  &obj->key_handle);
	if (res) {
		DMSG("TEE_AllocateTransientObject failed, %#"PRIx32, res);

		return tee2pkcs_error(res);
	}

	res = TEE_PopulateTransientObject(obj->key_handle,
					  tee_attrs, tee_attrs_count);

	TEE_Free(tee_attrs);

	if (res) {
		DMSG("TEE_PopulateTransientObject failed, %#"PRIx32, res);

		goto error;
	}

key_ready:
	res = TEE_SetOperationKey(session->processing->tee_op_handle,
				  obj->key_handle);
	if (res) {
		DMSG("TEE_SetOperationKey failed, %#"PRIx32, res);

		goto error;
	}

	return PKCS11_CKR_OK;

error:
	TEE_FreeTransientObject(obj->key_handle);
	obj->key_handle = TEE_HANDLE_NULL;
	return tee2pkcs_error(res);
}

static enum pkcs11_rc
init_tee_operation(struct pkcs11_session *session,
		   struct pkcs11_attribute_head *proc_params,
		   struct pkcs11_object *obj)
{
	enum pkcs11_rc rc = PKCS11_CKR_OK;
	struct active_processing *proc = session->processing;

	switch (proc_params->id) {
	case PKCS11_CKM_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA1_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA224_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA256_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA384_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA512_RSA_PKCS_PSS:
		rc = pkcs2tee_proc_params_rsa_pss(proc, proc_params);
		if (rc)
			break;

		rc = pkcs2tee_validate_rsa_pss(proc, obj);
		break;
	case PKCS11_CKM_RSA_PKCS_OAEP:
		rc = pkcs2tee_proc_params_rsa_oaep(proc, proc_params);
		break;
	case PKCS11_CKM_EDDSA:
		rc = pkcs2tee_proc_params_eddsa(proc, proc_params);
		break;
	case PKCS11_CKM_RSA_AES_KEY_WRAP:
		rc = pkcs2tee_proc_params_rsa_aes_wrap(proc, proc_params);
		break;
	default:
		break;
	}

	return rc;
}

enum pkcs11_rc init_asymm_operation(struct pkcs11_session *session,
				    enum processing_func function,
				    struct pkcs11_attribute_head *proc_params,
				    struct pkcs11_object *obj)
{
	enum pkcs11_rc rc = PKCS11_CKR_GENERAL_ERROR;

	assert(processing_is_tee_asymm(proc_params->id));

	rc = allocate_tee_operation(session, function, proc_params, obj);
	if (rc)
		return rc;

	rc = load_tee_key(session, obj, function);
	if (rc)
		return rc;

	rc = init_tee_operation(session, proc_params, obj);
	if (!rc)
		session->processing->mecha_type = proc_params->id;

	return rc;
}

/*
 * step_asymm_operation - step (update/oneshot/final) on an asymmetric
 * crypto operation
 *
 * @session - current session
 * @function - processing function (encrypt, decrypt, sign, ...)
 * @step - step ID in the processing (oneshot, update, final)
 * @ptypes - invocation parameter types
 * @params - invocation parameter references
 *
 * Return PKCS11_CKR_OK on success or a PKCS11_CKR_* error code. For
 * output-producing functions the output length is reported back in
 * params[2].memref.size on success and on PKCS11_CKR_BUFFER_TOO_SMALL.
 */
enum pkcs11_rc step_asymm_operation(struct pkcs11_session *session,
				    enum processing_func function,
				    enum processing_step step,
				    uint32_t ptypes, TEE_Param *params)
{
	enum pkcs11_rc rc = PKCS11_CKR_GENERAL_ERROR;
	TEE_Result res = TEE_ERROR_GENERIC;
	void *in_buf = NULL;
	void *in2_buf = NULL;
	void *out_buf = NULL;
	void *hash_buf = NULL;
	uint32_t in_size = 0;
	uint32_t in2_size = 0;
	size_t out_size = 0;
	size_t hash_size = 0;
	TEE_Attribute *tee_attrs = NULL;
	size_t tee_attrs_count = 0;
	bool output_data = false;
	struct active_processing *proc = session->processing;
	struct rsa_aes_key_wrap_processing_ctx *rsa_aes_ctx = NULL;
	struct rsa_oaep_processing_ctx *rsa_oaep_ctx = NULL;
	struct rsa_pss_processing_ctx *rsa_pss_ctx = NULL;
	struct eddsa_processing_ctx *eddsa_ctx = NULL;
	size_t sz = 0;

	if (TEE_PARAM_TYPE_GET(ptypes, 1) == TEE_PARAM_TYPE_MEMREF_INPUT) {
		in_buf = params[1].memref.buffer;
		in_size = params[1].memref.size;
		if (in_size && !in_buf)
			return PKCS11_CKR_ARGUMENTS_BAD;
	}
	if (TEE_PARAM_TYPE_GET(ptypes, 2) == TEE_PARAM_TYPE_MEMREF_INPUT) {
		in2_buf = params[2].memref.buffer;
		in2_size = params[2].memref.size;
		if (in2_size && !in2_buf)
			return PKCS11_CKR_ARGUMENTS_BAD;
	}
	if (TEE_PARAM_TYPE_GET(ptypes, 2) == TEE_PARAM_TYPE_MEMREF_OUTPUT) {
		out_buf = params[2].memref.buffer;
		out_size = params[2].memref.size;
		if (out_size && !out_buf)
			return PKCS11_CKR_ARGUMENTS_BAD;
	}
	if (TEE_PARAM_TYPE_GET(ptypes, 3) != TEE_PARAM_TYPE_NONE)
		return PKCS11_CKR_ARGUMENTS_BAD;

	switch (step) {
	case PKCS11_FUNC_STEP_ONESHOT:
	case PKCS11_FUNC_STEP_UPDATE:
	case PKCS11_FUNC_STEP_FINAL:
		break;
	default:
		return PKCS11_CKR_GENERAL_ERROR;
	}

	/* TEE attribute(s) required by the operation */
	switch (proc->mecha_type) {
	case PKCS11_CKM_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA1_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA224_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA256_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA384_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA512_RSA_PKCS_PSS:
		tee_attrs = TEE_Malloc(sizeof(TEE_Attribute),
				       TEE_USER_MEM_HINT_NO_FILL_ZERO);
		if (!tee_attrs) {
			rc = PKCS11_CKR_DEVICE_MEMORY;
			goto out;
		}

		rsa_pss_ctx = proc->extra_ctx;

		TEE_InitValueAttribute(&tee_attrs[tee_attrs_count],
				       TEE_ATTR_RSA_PSS_SALT_LENGTH,
				       rsa_pss_ctx->salt_len, 0);
		tee_attrs_count++;
		break;
	case PKCS11_CKM_EDDSA:
		eddsa_ctx = proc->extra_ctx;

		tee_attrs = TEE_Malloc(2 * sizeof(TEE_Attribute),
				       TEE_USER_MEM_HINT_NO_FILL_ZERO);
		if (!tee_attrs) {
			rc = PKCS11_CKR_DEVICE_MEMORY;
			goto out;
		}

		if (eddsa_ctx->flag) {
			TEE_InitValueAttribute(&tee_attrs[tee_attrs_count],
					       TEE_ATTR_EDDSA_PREHASH, 0, 0);
			tee_attrs_count++;
		}

		if (eddsa_ctx->ctx_len > 0) {
			TEE_InitRefAttribute(&tee_attrs[tee_attrs_count],
					     TEE_ATTR_EDDSA_CTX, eddsa_ctx->ctx,
					     eddsa_ctx->ctx_len);
			tee_attrs_count++;
		}
		break;
	case PKCS11_CKM_RSA_PKCS_OAEP:
		rsa_oaep_ctx = proc->extra_ctx;

		if (!rsa_oaep_ctx->source_data_len)
			break;

		tee_attrs = TEE_Malloc(sizeof(TEE_Attribute),
				       TEE_USER_MEM_HINT_NO_FILL_ZERO);
		if (!tee_attrs) {
			rc = PKCS11_CKR_DEVICE_MEMORY;
			goto out;
		}

		TEE_InitRefAttribute(&tee_attrs[tee_attrs_count],
				     TEE_ATTR_RSA_OAEP_LABEL,
				     rsa_oaep_ctx->source_data,
				     rsa_oaep_ctx->source_data_len);
		tee_attrs_count++;
		break;
	case PKCS11_CKM_RSA_AES_KEY_WRAP:
		rsa_aes_ctx = proc->extra_ctx;

		if (!rsa_aes_ctx->source_data_len)
			break;

		tee_attrs = TEE_Malloc(sizeof(TEE_Attribute),
				       TEE_USER_MEM_HINT_NO_FILL_ZERO);
		if (!tee_attrs) {
			rc = PKCS11_CKR_DEVICE_MEMORY;
			goto out;
		}

		TEE_InitRefAttribute(&tee_attrs[tee_attrs_count],
				     TEE_ATTR_RSA_OAEP_LABEL,
				     rsa_aes_ctx->source_data,
				     rsa_aes_ctx->source_data_len);
		tee_attrs_count++;
		break;
	default:
		break;
	}

	/*
	 * Handle multi stage update step for mechas needing hash
	 * calculation
	 */
	if (step == PKCS11_FUNC_STEP_UPDATE) {
		switch (proc->mecha_type) {
		case PKCS11_CKM_ECDSA_SHA1:
		case PKCS11_CKM_ECDSA_SHA224:
		case PKCS11_CKM_ECDSA_SHA256:
		case PKCS11_CKM_ECDSA_SHA384:
		case PKCS11_CKM_ECDSA_SHA512:
		case PKCS11_CKM_MD5_RSA_PKCS:
		case PKCS11_CKM_SHA1_RSA_PKCS:
		case PKCS11_CKM_SHA224_RSA_PKCS:
		case PKCS11_CKM_SHA256_RSA_PKCS:
		case PKCS11_CKM_SHA384_RSA_PKCS:
		case PKCS11_CKM_SHA512_RSA_PKCS:
		case PKCS11_CKM_SHA1_RSA_PKCS_PSS:
		case PKCS11_CKM_SHA224_RSA_PKCS_PSS:
		case PKCS11_CKM_SHA256_RSA_PKCS_PSS:
		case PKCS11_CKM_SHA384_RSA_PKCS_PSS:
		case PKCS11_CKM_SHA512_RSA_PKCS_PSS:
			assert(proc->tee_op_handle2 != TEE_HANDLE_NULL);

			TEE_DigestUpdate(proc->tee_op_handle2, in_buf, in_size);
			rc = PKCS11_CKR_OK;
			break;
		default:
			/*
			 * Other mechanism do not expect multi stage
			 * operation
			 */
			rc = PKCS11_CKR_GENERAL_ERROR;
			break;
		}

		goto out;
	}

	/*
	 * Handle multi stage one shot and final steps for mechas needing hash
	 * calculation
	 */
	switch (proc->mecha_type) {
	case PKCS11_CKM_ECDSA_SHA1:
	case PKCS11_CKM_ECDSA_SHA224:
	case PKCS11_CKM_ECDSA_SHA256:
	case PKCS11_CKM_ECDSA_SHA384:
	case PKCS11_CKM_ECDSA_SHA512:
	case PKCS11_CKM_MD5_RSA_PKCS:
	case PKCS11_CKM_SHA1_RSA_PKCS:
	case PKCS11_CKM_SHA224_RSA_PKCS:
	case PKCS11_CKM_SHA256_RSA_PKCS:
	case PKCS11_CKM_SHA384_RSA_PKCS:
	case PKCS11_CKM_SHA512_RSA_PKCS:
	case PKCS11_CKM_SHA1_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA224_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA256_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA384_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA512_RSA_PKCS_PSS:
		assert(proc->tee_op_handle2 != TEE_HANDLE_NULL);

		hash_size = TEE_ALG_GET_DIGEST_SIZE(proc->tee_hash_algo);
		hash_buf = TEE_Malloc(hash_size, 0);
		if (!hash_buf) {
			/* Go through "out" so tee_attrs is released too */
			rc = PKCS11_CKR_DEVICE_MEMORY;
			goto out;
		}

		res = TEE_DigestDoFinal(proc->tee_op_handle2, in_buf, in_size,
					hash_buf, &hash_size);

		rc = tee2pkcs_error(res);
		if (rc != PKCS11_CKR_OK)
			goto out;

		break;
	default:
		break;
	}

	/*
	 * Finalize either provided hash or calculated hash with signing
	 * operation
	 */

	/* First determine amount of bytes for signing operation */
	switch (proc->mecha_type) {
	case PKCS11_CKM_ECDSA:
		sz = ecdsa_get_input_max_byte_size(proc->tee_op_handle);
		if (!in_size || !sz) {
			rc = PKCS11_CKR_FUNCTION_FAILED;
			goto out;
		}

		/*
		 * Note 3) Input the entire raw digest. Internally, this will
		 * be truncated to the appropriate number of bits.
		 */
		if (in_size > sz)
			in_size = sz;

		if (function == PKCS11_FUNCTION_VERIFY && in2_size != 2 * sz) {
			rc = PKCS11_CKR_SIGNATURE_LEN_RANGE;
			goto out;
		}
		break;
	case PKCS11_CKM_ECDSA_SHA1:
	case PKCS11_CKM_ECDSA_SHA224:
	case PKCS11_CKM_ECDSA_SHA256:
	case PKCS11_CKM_ECDSA_SHA384:
	case PKCS11_CKM_ECDSA_SHA512:
		/* Get key size in bytes */
		sz = ecdsa_get_input_max_byte_size(proc->tee_op_handle);
		if (!sz) {
			rc = PKCS11_CKR_FUNCTION_FAILED;
			goto out;
		}

		if (function == PKCS11_FUNCTION_VERIFY &&
		    in2_size != 2 * sz) {
			rc = PKCS11_CKR_SIGNATURE_LEN_RANGE;
			goto out;
		}
		break;
	case PKCS11_CKM_RSA_PKCS:
	case PKCS11_CKM_MD5_RSA_PKCS:
	case PKCS11_CKM_SHA1_RSA_PKCS:
	case PKCS11_CKM_SHA224_RSA_PKCS:
	case PKCS11_CKM_SHA256_RSA_PKCS:
	case PKCS11_CKM_SHA384_RSA_PKCS:
	case PKCS11_CKM_SHA512_RSA_PKCS:
	case PKCS11_CKM_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA1_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA224_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA256_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA384_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA512_RSA_PKCS_PSS:
		/* Get key size in bytes */
		sz = rsa_get_input_max_byte_size(proc->tee_op_handle);
		if (!sz) {
			rc = PKCS11_CKR_FUNCTION_FAILED;
			goto out;
		}

		if (function == PKCS11_FUNCTION_VERIFY && in2_size != sz) {
			rc = PKCS11_CKR_SIGNATURE_LEN_RANGE;
			goto out;
		}
		break;
	default:
		break;
	}

	/* Next perform actual signing operation */
	switch (proc->mecha_type) {
	case PKCS11_CKM_ECDSA:
	case PKCS11_CKM_EDDSA:
	case PKCS11_CKM_RSA_PKCS:
	case PKCS11_CKM_RSA_PKCS_OAEP:
	case PKCS11_CKM_RSA_PKCS_PSS:
		/* For operations using provided input data */
		switch (function) {
		case PKCS11_FUNCTION_ENCRYPT:
			res = TEE_AsymmetricEncrypt(proc->tee_op_handle,
						    tee_attrs, tee_attrs_count,
						    in_buf, in_size,
						    out_buf, &out_size);
			output_data = true;
			rc = tee2pkcs_error(res);
			if (rc == PKCS11_CKR_ARGUMENTS_BAD)
				rc = PKCS11_CKR_DATA_LEN_RANGE;
			break;

		case PKCS11_FUNCTION_DECRYPT:
			res = TEE_AsymmetricDecrypt(proc->tee_op_handle,
						    tee_attrs, tee_attrs_count,
						    in_buf, in_size,
						    out_buf, &out_size);
			output_data = true;
			rc = tee2pkcs_error(res);
			if (rc == PKCS11_CKR_ARGUMENTS_BAD)
				rc = PKCS11_CKR_ENCRYPTED_DATA_LEN_RANGE;
			break;

		case PKCS11_FUNCTION_SIGN:
			res = TEE_AsymmetricSignDigest(proc->tee_op_handle,
						       tee_attrs,
						       tee_attrs_count,
						       in_buf, in_size,
						       out_buf, &out_size);
			output_data = true;
			rc = tee2pkcs_error(res);
			break;

		case PKCS11_FUNCTION_VERIFY:
			res = TEE_AsymmetricVerifyDigest(proc->tee_op_handle,
							 tee_attrs,
							 tee_attrs_count,
							 in_buf, in_size,
							 in2_buf, in2_size);
			rc = tee2pkcs_error(res);
			break;

		default:
			TEE_Panic(function);
			break;
		}
		break;
	case PKCS11_CKM_ECDSA_SHA1:
	case PKCS11_CKM_ECDSA_SHA224:
	case PKCS11_CKM_ECDSA_SHA256:
	case PKCS11_CKM_ECDSA_SHA384:
	case PKCS11_CKM_ECDSA_SHA512:
	case PKCS11_CKM_MD5_RSA_PKCS:
	case PKCS11_CKM_SHA1_RSA_PKCS:
	case PKCS11_CKM_SHA224_RSA_PKCS:
	case PKCS11_CKM_SHA256_RSA_PKCS:
	case PKCS11_CKM_SHA384_RSA_PKCS:
	case PKCS11_CKM_SHA512_RSA_PKCS:
	case PKCS11_CKM_SHA1_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA224_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA256_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA384_RSA_PKCS_PSS:
	case PKCS11_CKM_SHA512_RSA_PKCS_PSS:
		/* For operations having hash operation use calculated hash */
		switch (function) {
		case PKCS11_FUNCTION_SIGN:
			res = TEE_AsymmetricSignDigest(proc->tee_op_handle,
						       tee_attrs,
						       tee_attrs_count,
						       hash_buf, hash_size,
						       out_buf, &out_size);
			output_data = true;
			rc = tee2pkcs_error(res);
			break;

		case PKCS11_FUNCTION_VERIFY:
			res = TEE_AsymmetricVerifyDigest(proc->tee_op_handle,
							 tee_attrs,
							 tee_attrs_count,
							 hash_buf, hash_size,
							 in2_buf, in2_size);
			rc = tee2pkcs_error(res);
			break;

		default:
			TEE_Panic(function);
			break;
		}
		break;
	default:
		TEE_Panic(proc->mecha_type);
		break;
	}

out:
	if (output_data &&
	    (rc == PKCS11_CKR_OK || rc == PKCS11_CKR_BUFFER_TOO_SMALL)) {
		switch (TEE_PARAM_TYPE_GET(ptypes, 2)) {
		case TEE_PARAM_TYPE_MEMREF_OUTPUT:
		case TEE_PARAM_TYPE_MEMREF_INOUT:
			params[2].memref.size = out_size;
			break;
		default:
			rc = PKCS11_CKR_GENERAL_ERROR;
			break;
		}
	}

	TEE_Free(hash_buf);
	TEE_Free(tee_attrs);

	return rc;
}

enum pkcs11_rc do_asymm_derivation(struct pkcs11_session *session,
				   struct pkcs11_attribute_head *proc_params,
				   struct obj_attrs **head)
{
	enum pkcs11_rc rc = PKCS11_CKR_GENERAL_ERROR;
	TEE_ObjectHandle out_handle = TEE_HANDLE_NULL;
	TEE_Result res = TEE_ERROR_GENERIC;
	TEE_Attribute tee_attrs[2] = { };
	size_t tee_attrs_count = 0;
	uint32_t key_byte_size = 0;
	uint32_t key_bit_size = 0;
	void *a_ptr = NULL;
	size_t a_size = 0;

	/* Remove default attribute set at template sanitization */
	if (remove_empty_attribute(head, PKCS11_CKA_VALUE))
		return PKCS11_CKR_FUNCTION_FAILED;

	rc = get_u32_attribute(*head, PKCS11_CKA_VALUE_LEN, &key_bit_size);
	if (rc)
		return rc;

	key_bit_size *= 8;
	key_byte_size = (key_bit_size + 7) / 8;

	res = TEE_AllocateTransientObject(TEE_TYPE_GENERIC_SECRET,
					  key_byte_size * 8, &out_handle);
	if (res) {
		DMSG("TEE_AllocateTransientObject failed, %#"PRIx32, res);
		return tee2pkcs_error(res);
	}

	switch (proc_params->id) {
	case PKCS11_CKM_ECDH1_DERIVE:
		rc = pkcs2tee_param_ecdh(proc_params, &a_ptr, &a_size);
		if (rc)
			goto out;

		TEE_InitRefAttribute(&tee_attrs[tee_attrs_count],
				     TEE_ATTR_ECC_PUBLIC_VALUE_X,
				     a_ptr, a_size / 2);
		tee_attrs_count++;
		TEE_InitRefAttribute(&tee_attrs[tee_attrs_count],
				     TEE_ATTR_ECC_PUBLIC_VALUE_Y,
				     (char *)a_ptr + a_size / 2,
				     a_size / 2);
		tee_attrs_count++;
		break;
	default:
		TEE_Panic(proc_params->id);
		break;
	}

	TEE_DeriveKey(session->processing->tee_op_handle, &tee_attrs[0],
		      tee_attrs_count, out_handle);

	rc = alloc_get_tee_attribute_data(out_handle, TEE_ATTR_SECRET_VALUE,
					  &a_ptr, &a_size);
	if (rc)
		goto out;

	if (a_size * 8 < key_bit_size)
		rc = PKCS11_CKR_KEY_SIZE_RANGE;
	else
		rc = add_attribute(head, PKCS11_CKA_VALUE, a_ptr,
				   key_byte_size);
	TEE_Free(a_ptr);
out:
	release_active_processing(session);
	TEE_FreeTransientObject(out_handle);

	return rc;
}

/*
 * Wrap @data with CKM_RSA_AES_KEY_WRAP.
 *
 * A random temporary AES key of ctx->aes_key_bits is generated and
 * encrypted with the session's RSA operation. @data is then wrapped with
 * that AES key using NIST KW with padding (KWP, RFC 5649). The output is
 * the RSA ciphertext followed by the KWP ciphertext.
 *
 * @out_sz holds the capacity of @out_buf on entry and the produced size on
 * success. When the buffer is too small, it receives the required size and
 * a BUFFER_TOO_SMALL-class error is returned. The temporary AES key is
 * wiped before returning.
 */
static enum pkcs11_rc wrap_rsa_aes_key(struct active_processing *proc,
				       void *data, uint32_t data_sz,
				       void *out_buf, uint32_t *out_sz)
{
	enum pkcs11_rc rc = PKCS11_CKR_OK;
	TEE_Result res = TEE_ERROR_GENERIC;
	int mbedtls_rc = 0;
	struct rsa_aes_key_wrap_processing_ctx *ctx = proc->extra_ctx;
	mbedtls_nist_kw_context kw_ctx = { };
	uint8_t aes_key_value[32] = { };
	uint32_t aes_key_size = ctx->aes_key_bits / 8;
	size_t aes_wrapped_size = *out_sz;
	uint32_t expected_size = 0;
	size_t target_key_size = 0;
	size_t padded_data_size = 0;
	const size_t kw_semiblock_len = 8;

	if (ctx->aes_key_bits != 128 &&
	    ctx->aes_key_bits != 192 &&
	    ctx->aes_key_bits != 256)
		return PKCS11_CKR_ARGUMENTS_BAD;

	/*
	 * KWP pads its input up to a whole number of semiblocks, then adds
	 * one more semiblock.
	 */
	padded_data_size = ((size_t)data_sz + kw_semiblock_len - 1) /
			   kw_semiblock_len * kw_semiblock_len;

	mbedtls_nist_kw_init(&kw_ctx);
	TEE_GenerateRandom(aes_key_value, aes_key_size);
	res = TEE_AsymmetricEncrypt(proc->tee_op_handle,
				    NULL, 0,
				    aes_key_value, aes_key_size,
				    out_buf, &aes_wrapped_size);
	expected_size = aes_wrapped_size + padded_data_size +
			kw_semiblock_len;
	if (res) {
		if (res == TEE_ERROR_SHORT_BUFFER)
			*out_sz = expected_size;

		rc = tee2pkcs_error(res);
		goto out;
	}

	if (*out_sz < expected_size) {
		rc = PKCS11_CKR_BUFFER_TOO_SMALL;
		*out_sz = expected_size;
		goto out;
	}

	mbedtls_rc = mbedtls_nist_kw_setkey(&kw_ctx, MBEDTLS_CIPHER_ID_AES,
					    aes_key_value, ctx->aes_key_bits,
					    true);
	if (mbedtls_rc) {
		if (mbedtls_rc == MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA)
			rc = PKCS11_CKR_KEY_SIZE_RANGE;
		else
			rc = PKCS11_CKR_FUNCTION_FAILED;

		goto out;
	}

	mbedtls_rc = mbedtls_nist_kw_wrap(&kw_ctx, MBEDTLS_KW_MODE_KWP,
					  data, data_sz,
					  (uint8_t *)out_buf + aes_wrapped_size,
					  &target_key_size,
					  *out_sz - aes_wrapped_size);
	if (mbedtls_rc) {
		rc = PKCS11_CKR_ARGUMENTS_BAD;
		goto out;
	}

	assert(*out_sz >= target_key_size + aes_wrapped_size);
	*out_sz = target_key_size + aes_wrapped_size;

out:
	mbedtls_nist_kw_free(&kw_ctx);
	TEE_MemFill(aes_key_value, 0, sizeof(aes_key_value));
	return rc;
}

/*
 * Unwrap a key wrapped with CKM_RSA_AES_KEY_WRAP.
 *
 * @data is expected to hold the RSA ciphertext of the temporary AES key
 * (one RSA modulus worth of bytes), followed by the KWP (RFC 5649)
 * ciphertext of the target key.
 *
 * On success *out_buf is a TEE_Malloc'd buffer holding the unwrapped key
 * and *out_sz its size. *out_buf may also be set on some failure paths;
 * this function does not free it there, leaving that to the caller.
 *
 * @data is caller-controlled, so its length and the recovered AES key size
 * are checked before use. The temporary AES key buffer is wiped before
 * returning.
 */
static enum pkcs11_rc unwrap_rsa_aes_key(struct active_processing *proc,
					 void *data, uint32_t data_sz,
					 void **out_buf, uint32_t *out_sz)
{
	enum pkcs11_rc rc = PKCS11_CKR_OK;
	int mbedtls_rc = 0;
	TEE_Result res = TEE_ERROR_GENERIC;
	TEE_OperationInfo info = { };
	struct rsa_aes_key_wrap_processing_ctx *ctx = proc->extra_ctx;
	mbedtls_nist_kw_context kw_ctx = { };
	uint8_t aes_key_value[32] = { };
	size_t aes_key_size = ctx->aes_key_bits / 8;
	uint32_t wrapped_key_size = 0;
	uint32_t rsa_key_size = 0;
	size_t target_key_size = 0;
	const size_t kw_semiblock_len = 8;

	if (ctx->aes_key_bits != 128 &&
	    ctx->aes_key_bits != 192 &&
	    ctx->aes_key_bits != 256)
		return PKCS11_CKR_ARGUMENTS_BAD;

	TEE_GetOperationInfo(proc->tee_op_handle, &info);
	rsa_key_size = info.keySize / 8;

	/*
	 * The input must hold the whole RSA block plus a KWP output, which is
	 * at least two semiblocks. Without this check the subtractions below
	 * underflow, and the RSA decryption reads past the end of @data.
	 */
	if (data_sz < rsa_key_size ||
	    data_sz - rsa_key_size < 2 * kw_semiblock_len)
		return PKCS11_CKR_WRAPPED_KEY_INVALID;

	wrapped_key_size = data_sz - rsa_key_size;
	target_key_size = wrapped_key_size - kw_semiblock_len;

	*out_buf = TEE_Malloc(target_key_size, TEE_MALLOC_FILL_ZERO);
	if (!*out_buf)
		return PKCS11_CKR_DEVICE_MEMORY;

	mbedtls_nist_kw_init(&kw_ctx);
	res = TEE_AsymmetricDecrypt(proc->tee_op_handle,
				    NULL, 0,
				    data, rsa_key_size,
				    aes_key_value, &aes_key_size);
	if (res) {
		rc = tee2pkcs_error(res);
		goto out;
	}

	/* The recovered AES key must match the announced key size */
	if (aes_key_size != ctx->aes_key_bits / 8) {
		rc = PKCS11_CKR_WRAPPED_KEY_INVALID;
		goto out;
	}

	mbedtls_rc = mbedtls_nist_kw_setkey(&kw_ctx, MBEDTLS_CIPHER_ID_AES,
					    aes_key_value, ctx->aes_key_bits,
					    false);
	if (mbedtls_rc) {
		rc = PKCS11_CKR_WRAPPED_KEY_INVALID;
		goto out;
	}

	mbedtls_rc = mbedtls_nist_kw_unwrap(&kw_ctx, MBEDTLS_KW_MODE_KWP,
					    (uint8_t *)data + rsa_key_size,
					    wrapped_key_size, *out_buf,
					    &target_key_size, target_key_size);
	if (mbedtls_rc) {
		rc = PKCS11_CKR_WRAPPED_KEY_INVALID;
		goto out;
	}

	*out_sz = target_key_size;
out:
	/*
	 * Wipe the whole local buffer. aes_key_size may have been rewritten
	 * by TEE_AsymmetricDecrypt (e.g. to the required size on a short
	 * buffer error), so it cannot bound the wipe.
	 */
	TEE_MemFill(aes_key_value, 0, sizeof(aes_key_value));
	mbedtls_nist_kw_free(&kw_ctx);
	return rc;
}

/*
 * Wrap @data using the session's active asymmetric mechanism.
 *
 * Only CKM_RSA_AES_KEY_WRAP is supported; any other active mechanism
 * yields PKCS11_CKR_MECHANISM_INVALID.
 */
enum pkcs11_rc wrap_data_by_asymm_enc(struct pkcs11_session *session,
				      void *data, uint32_t data_sz,
				      void *out_buf, uint32_t *out_sz)
{
	struct active_processing *active = session->processing;

	if (active->mecha_type != PKCS11_CKM_RSA_AES_KEY_WRAP)
		return PKCS11_CKR_MECHANISM_INVALID;

	return wrap_rsa_aes_key(active, data, data_sz, out_buf, out_sz);
}

/*
 * Unwrap key material @data using the session's active asymmetric
 * mechanism.
 *
 * Only CKM_RSA_AES_KEY_WRAP is supported; any other active mechanism
 * yields PKCS11_CKR_MECHANISM_INVALID.
 */
enum pkcs11_rc unwrap_key_by_asymm(struct pkcs11_session *session,
				   void *data, uint32_t data_sz,
				   void **out_buf, uint32_t *out_sz)
{
	struct active_processing *active = session->processing;

	if (active->mecha_type != PKCS11_CKM_RSA_AES_KEY_WRAP)
		return PKCS11_CKR_MECHANISM_INVALID;

	return unwrap_rsa_aes_key(active, data, data_sz, out_buf, out_sz);
}
