/*
 * Copyright (C) 2020-2023 Intel Corporation
 *
 * SPDX-License-Identifier: MIT
 *
 */

#pragma once

#include <level_zero/ze_api.h>

#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>

// Two-level stringification helpers.
// QTR(token) turns its argument into a string literal verbatim, without expanding it.
// TOSTR(expr) passes its argument through one more macro layer first, so a macro
// argument is expanded before being stringized (e.g. TOSTR(__LINE__) -> "<line number>").
#define QTR(token) #token
#define TOSTR(expr) QTR(expr)

// Global verbosity switch, defined outside this header. The inline helpers below read it
// to decide whether to print diagnostics to std::cerr. Presumably it is set from the
// command line via isVerbose() (declared below) — TODO confirm against the definition.
extern bool verbose;

/// Report the outcome of a Level Zero call (or a negated boolean check) and optionally abort.
///
/// @tparam TerminateOnFailure  true: a failure is labelled "ERROR" and the process is ended
///                             with std::terminate(); false: a failure is labelled "WARNING"
///                             and execution continues.
/// @param result   value compared against ZE_RESULT_SUCCESS
/// @param message  text printed next to the outcome (the macros below pass the stringified call)
///
/// Diagnostics are written to std::cerr, and only while `verbose` is set. Termination on
/// failure happens regardless of `verbose`.
template <bool TerminateOnFailure, typename ResulT>
inline void validate(ResulT result, const char *message) {
    const bool succeeded = (result == ZE_RESULT_SUCCESS);

    if (verbose) {
        if (succeeded) {
            std::cerr << "SUCCESS : " << message << std::endl;
        } else {
            const char *severity = TerminateOnFailure ? "ERROR : " : "WARNING : ";
            std::cerr << severity << message << " : " << result << std::endl;
        }
    }

    if (!succeeded && TerminateOnFailure) {
        std::terminate();
    }
}

// Wrappers around validate<>() that use the stringified argument as the message.
// *_TERMINATE aborts on failure; *_WARNING only reports it.
// *_BOOL variants take a condition that is expected to be true. It is negated so that
// `true` becomes `false`, which compares equal to ZE_RESULT_SUCCESS (assuming that
// enumerator is 0, as in the Level Zero headers).
#define SUCCESS_OR_TERMINATE(call) validate<true>(call, #call)
#define SUCCESS_OR_TERMINATE_BOOL(flag) validate<true>(!(flag), #flag)
#define SUCCESS_OR_WARNING(call) validate<false>(call, #call)
#define SUCCESS_OR_WARNING_BOOL(flag) validate<false>(!(flag), #flag)

// ---------------------------------------------------------------------------
// Command-line helpers. Each takes the program's argc/argv. shortName/longName are
// presumably the two spellings of one option (e.g. "-v" / "--verbose") — TODO confirm.
// These functions are only declared here; their definitions live outside this header.
// ---------------------------------------------------------------------------
bool isParamEnabled(int argc, char *argv[], const char *shortName, const char *longName);

int getParamValue(int argc, char *argv[], const char *shortName, const char *longName, int defaultValue);

const char *getParamValue(int argc, char *argv[], const char *shortName, const char *longName, const char *defaultString);

bool isCircularDepTest(int argc, char *argv[]);

bool isVerbose(int argc, char *argv[]);

bool isSyncQueueEnabled(int argc, char *argv[]);

bool isAsyncQueueEnabled(int argc, char *argv[]);

bool isAubMode(int argc, char *argv[]);

bool isCommandListShared(int argc, char *argv[]);

bool isImmediateFirst(int argc, char *argv[]);

// NOTE(review): returns bool while taking an int default. If this is meant to yield
// allocation flag bits, the bool return would collapse them to 0/1 — verify against the definition.
bool getAllocationFlag(int argc, char *argv[], int defaultValue);

// Configures `desc` for synchronous or asynchronous queue mode depending on `useSync`
// (presumably via desc.mode — TODO confirm).
void selectQueueMode(ze_command_queue_desc_t &desc, bool useSync);

uint32_t getBufferLength(int argc, char *argv[], uint32_t defaultLength);

// Result reporting for the black-box samples.
void printResult(bool aubMode, bool outputValidationSuccessful, const std::string &blackBoxName, const std::string &currentTest);

void printResult(bool aubMode, bool outputValidationSuccessful, const std::string &blackBoxName);

// ---------------------------------------------------------------------------
// Level Zero object creation / setup helpers.
// ---------------------------------------------------------------------------
uint32_t getCommandQueueOrdinal(ze_device_handle_t &device);

uint32_t getCopyOnlyCommandQueueOrdinal(ze_device_handle_t &device);

// `ordinal` is passed by pointer; whether it is an input, an output, or may be null is
// not visible here — TODO confirm against the definition.
ze_command_queue_handle_t createCommandQueue(ze_context_handle_t &context, ze_device_handle_t &device,
                                             uint32_t *ordinal, ze_command_queue_mode_t mode,
                                             ze_command_queue_priority_t priority);

ze_command_queue_handle_t createCommandQueue(ze_context_handle_t &context, ze_device_handle_t &device, uint32_t *ordinal);

ze_result_t createCommandList(ze_context_handle_t &context, ze_device_handle_t &device, ze_command_list_handle_t &cmdList);
ze_result_t createCommandList(ze_context_handle_t &context, ze_device_handle_t &device, ze_command_list_handle_t &cmdList, uint32_t ordinal);

// Creates `eventPool` and presumably fills `events` with `poolSize` handles
// (the caller must provide storage) — TODO confirm.
void createEventPoolAndEvents(ze_context_handle_t &context,
                              ze_device_handle_t &device,
                              ze_event_pool_handle_t &eventPool,
                              ze_event_pool_flag_t poolFlag,
                              uint32_t poolSize,
                              ze_event_handle_t *events,
                              ze_event_scope_flag_t signalScope,
                              ze_event_scope_flag_t waitScope);

std::vector<ze_device_handle_t> zelloGetSubDevices(ze_device_handle_t &device, uint32_t &subDevCount);

std::vector<ze_device_handle_t> zelloInitContextAndGetDevices(ze_context_handle_t &context, ze_driver_handle_t &driverHandle);

std::vector<ze_device_handle_t> zelloInitContextAndGetDevices(ze_context_handle_t &context);

void initialize(ze_driver_handle_t &driver, ze_context_handle_t &context, ze_device_handle_t &device, ze_command_queue_handle_t &cmdQueue, uint32_t &ordinal);

void teardown(ze_context_handle_t context, ze_command_queue_handle_t cmdQueue);

// ---------------------------------------------------------------------------
// Property printers.
// ---------------------------------------------------------------------------
void printDeviceProperties(const ze_device_properties_t &props);

void printCacheProperties(uint32_t index, const ze_device_cache_properties_t &props);

void printP2PProperties(const ze_device_p2p_properties_t &props, bool canAccessPeer, uint32_t device0Index, uint32_t device1Index);

void printKernelProperties(const ze_kernel_properties_t &props, const char *kernelName);

void printCommandQueueGroupsProperties(ze_device_handle_t &device);

// Directory prefixes tried, in order, by readBinaryFile()/readTextFile() below. Each prefix
// is concatenated directly with the file name, so it must carry its own trailing separator.
const std::vector<const char *> &getResourcesSearchLocations();

// Read a binary file into a buffer that is NOT NULL-terminated.
//
// Each prefix from getResourcesSearchLocations() is tried in order. The prefix is
// concatenated directly with `name`, and the first location that opens and reads fully wins.
// Locations whose size cannot be determined, or which yield a short read, are skipped
// rather than returned as a partially filled buffer.
//
// @param name     file name appended to each search location
// @param outSize  receives the number of bytes in the returned buffer, or 0 if nothing was read
// @return owning buffer holding the file contents, or nullptr if no location could be read
template <typename SizeT>
inline std::unique_ptr<char[]> readBinaryFile(const std::string &name, SizeT &outSize) {
    for (const char *base : getResourcesSearchLocations()) {
        std::ifstream file(std::string(base) + name, std::ios_base::in | std::ios_base::binary);
        if (false == file.good()) {
            continue;
        }

        file.seekg(0, file.end);
        const std::streamoff end = file.tellg();
        if (end < 0) {
            // tellg() reports -1 on failure; casting that to size_t would request a huge allocation.
            continue;
        }
        file.seekg(0, file.beg);

        const size_t length = static_cast<size_t>(end);
        auto storage = std::make_unique<char[]>(length);
        file.read(storage.get(), static_cast<std::streamsize>(length));
        if (static_cast<size_t>(file.gcount()) != length) {
            // Short read: don't report `length` bytes when fewer were actually filled in.
            continue;
        }

        outSize = static_cast<SizeT>(length);
        return storage;
    }
    outSize = 0;
    return nullptr;
}

// Read a text file into a NULL-terminated buffer.
//
// Each prefix from getResourcesSearchLocations() is tried in order. The prefix is
// concatenated directly with `name`, and the first location that opens and reads without
// a stream error wins.
//
// @param name     file name appended to each search location
// @param outSize  receives the number of characters actually read (excluding the
//                 terminator), or 0 if nothing was read
// @return owning buffer with the text plus a trailing '\0', or nullptr if not found
template <typename SizeT>
inline std::unique_ptr<char[]> readTextFile(const std::string &name, SizeT &outSize) {
    for (const char *base : getResourcesSearchLocations()) {
        std::ifstream file(std::string(base) + name, std::ios_base::in);
        if (false == file.good()) {
            continue;
        }

        file.seekg(0, file.end);
        const std::streamoff end = file.tellg();
        if (end < 0) {
            // tellg() reports -1 on failure; casting that to size_t would request a huge allocation.
            continue;
        }
        file.seekg(0, file.beg);

        // The on-disk size is only an upper bound. The stream is in text mode, so the platform
        // may translate line endings (e.g. CRLF -> LF on Windows), and fewer characters can
        // come back. Hitting EOF early is therefore expected; only a hard stream error is rejected.
        const size_t capacity = static_cast<size_t>(end);
        auto storage = std::make_unique<char[]>(capacity + 1);
        file.read(storage.get(), static_cast<std::streamsize>(capacity));
        if (file.bad()) {
            continue;
        }
        const size_t length = static_cast<size_t>(file.gcount());
        storage[length] = '\0';

        outSize = static_cast<SizeT>(length);
        return storage;
    }
    outSize = 0;
    return nullptr;
}

/// Compare two buffers element by element, interpreting both as arrays of T.
///
/// @tparam T        element type used for the comparison (defaults to uint8_t)
/// @param expected  reference data
/// @param tested    data under test
/// @param len       number of T elements to compare. The loop indexes T elements, so this
///                  equals a byte count only when T is one byte wide.
/// @return true when all `len` elements compare equal
///
/// Without `verbose`, stops at the first mismatch. With `verbose`, every mismatch is printed
/// to std::cerr, and comparison stops after 20 of them.
template <typename T = uint8_t>
inline bool validate(const void *expected, const void *tested, size_t len) {
    const auto *reference = static_cast<const T *>(expected);
    const auto *candidate = static_cast<const T *>(tested);

    constexpr uint32_t maxReportedMismatches = 20;
    uint32_t mismatches = 0;
    bool allMatch = true;

    for (size_t i = 0; i < len; ++i) {
        if (reference[i] != candidate[i]) {
            allMatch = false;
            if (!verbose) {
                break;
            }

            std::cerr << "Data mismatch expectedU8[" << i << "] != testedU8[" << i
                      << "]   ->    " << +reference[i] << " != " << +candidate[i]
                      << std::endl;
            if (++mismatches >= maxReportedMismatches) {
                std::cerr << "Found " << mismatches
                          << " data mismatches - skipping further comparison " << std::endl;
                break;
            }
        }
    }

    return allMatch;
}

// Wraps the command list (and, for regular lists, the command queue) used to submit work,
// so a sample can switch between regular and immediate submission with a single flag.
// Lifetime is explicit: call create() before use and destroy() afterwards. There is no
// destructor, so copies of this struct share (and must not both destroy) the same handles.
struct CommandHandler {
    ze_command_queue_handle_t cmdQueue = nullptr; // only created when !isImmediate
    ze_command_list_handle_t cmdList = nullptr;

    bool isImmediate = false;

    // Creates the handles on `device`, using the ordinal from getCommandQueueOrdinal().
    //  - immediate == true : a single synchronous immediate command list; no queue.
    //  - immediate == false: an asynchronous command queue plus a regular command list.
    // Returns the first failing Level Zero result. In that case no handle created by this
    // call is left alive, and the affected members are reset to nullptr.
    ze_result_t create(ze_context_handle_t context, ze_device_handle_t device, bool immediate) {
        isImmediate = immediate;
        ze_command_queue_desc_t cmdQueueDesc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC};
        cmdQueueDesc.ordinal = getCommandQueueOrdinal(device);
        cmdQueueDesc.index = 0;

        if (isImmediate) {
            cmdQueueDesc.mode = ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS;
            const ze_result_t result = zeCommandListCreateImmediate(context, device, &cmdQueueDesc, &cmdList);
            if (result != ZE_RESULT_SUCCESS) {
                cmdList = nullptr;
            }
            return result;
        }

        cmdQueueDesc.mode = ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS;
        ze_result_t result = zeCommandQueueCreate(context, device, &cmdQueueDesc, &cmdQueue);
        if (result != ZE_RESULT_SUCCESS) {
            cmdQueue = nullptr;
            return result;
        }

        result = createCommandList(context, device, cmdList);
        if (result != ZE_RESULT_SUCCESS) {
            // Don't leak the queue when its command list could not be created.
            zeCommandQueueDestroy(cmdQueue);
            cmdQueue = nullptr;
            cmdList = nullptr;
        }
        return result;
    }

    // Appends a launch of `kernel` with the given group counts. `event`, if non-null, is
    // passed as the signal event; no wait events are used.
    ze_result_t appendKernel(ze_kernel_handle_t kernel, const ze_group_count_t &dispatchTraits, ze_event_handle_t event = nullptr) {
        return zeCommandListAppendLaunchKernel(cmdList, kernel, &dispatchTraits,
                                               event, 0, nullptr);
    }

    // For regular lists: closes the list and submits it to the queue without a fence.
    // For immediate lists this is a no-op that returns success.
    ze_result_t execute() {
        auto result = ZE_RESULT_SUCCESS;

        if (!isImmediate) {
            result = zeCommandListClose(cmdList);
            if (result == ZE_RESULT_SUCCESS) {
                result = zeCommandQueueExecuteCommandLists(cmdQueue, 1, &cmdList, nullptr);
            }
        }
        return result;
    }

    // For regular lists: waits on the queue with the maximum (UINT64_MAX) timeout.
    // Immediate lists are created synchronous in create(), so nothing is waited on here.
    ze_result_t synchronize() {
        if (!isImmediate) {
            return zeCommandQueueSynchronize(cmdQueue, std::numeric_limits<uint64_t>::max());
        }

        return ZE_RESULT_SUCCESS;
    }

    // Releases whatever create() produced. Both the list and the queue are destroyed even if
    // the first destroy fails, and the first failing result is returned. Handles are reset to
    // nullptr, so a repeated destroy(), or one without a successful create(), is a harmless no-op.
    ze_result_t destroy() {
        ze_result_t result = ZE_RESULT_SUCCESS;
        if (cmdList != nullptr) {
            result = zeCommandListDestroy(cmdList);
            cmdList = nullptr;
        }
        if (!isImmediate && cmdQueue != nullptr) {
            const ze_result_t queueResult = zeCommandQueueDestroy(cmdQueue);
            if (result == ZE_RESULT_SUCCESS) {
                result = queueResult;
            }
            cmdQueue = nullptr;
        }
        return result;
    }
};
