//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Tests/Unit/PyBinding/EmbeddedTest.cpp
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Base/Const/Units.h"
#include "PyCore/Embed/PyInterpreter.h"
#include "PyCore/Embed/PyObjectPtr.h"
#include "PyCore/Sample/ImportMultiLayer.h"
#include "Tests/GTestWrapper/google_test.h"
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

//! Initializes the embedded Python interpreter and its Numpy support, then calls the
//! Python-path and runtime-info helpers of PyInterpreter.

TEST(Embedded, PyInterpreterTest)
{
    // path string handed to both path helpers below
    const char* const extra_path = "/Some/Extra/Python/Path/";

    // start the interpreter; it must report itself as initialized
    PyInterpreter::initialize();
    EXPECT_TRUE(PyInterpreter::isInitialized());

    // same check for the Numpy support
    PyInterpreter::Numpy::initialize();
    EXPECT_TRUE(PyInterpreter::Numpy::isInitialized());

    // after adding a Python path, checkError() is expected to return false
    PyInterpreter::addPythonPath(extra_path);
    EXPECT_FALSE(PyInterpreter::checkError());

    // fetch the runtime info and print it
    const std::string info = PyInterpreter::runtimeInfo();
    std::cout << "Python runtime info:\n" << info << std::endl;

    // finally, set the Python path (the result is not checked)
    PyInterpreter::setPythonPath(extra_path);
}


//! Creates Numpy arrays from C++ data through the PyInterpreter::Numpy helpers and checks
//! that every returned Python object is valid.

TEST(Embedded, PyInterpreterNumpyTest)
{
    // initialize Python interpreter
    PyInterpreter::initialize();
    EXPECT_TRUE(PyInterpreter::isInitialized());

    PyInterpreter::Numpy::initialize();
    EXPECT_TRUE(PyInterpreter::Numpy::isInitialized());

    // import Numpy
    PyObjectPtr numpy_module = PyInterpreter::import("numpy");
    EXPECT_TRUE(numpy_module.valid());

    // initialize Numpy a second time; it must still report itself as initialized
    PyInterpreter::Numpy::initialize();
    EXPECT_TRUE(PyInterpreter::Numpy::isInitialized());

    // C-array holding the values 1, 2, ..., a_size; used as source for the 1D and 2D arrays
    constexpr int n_rows = 3;
    constexpr int n_cols = 4;
    constexpr int a_size = n_rows * n_cols;
    double c_array[a_size];
    for (int ii = 0; ii < a_size; ++ii)
        c_array[ii] = ii + 1;

    // 1D array over all a_size elements
    PyObjectPtr np_array1d = PyInterpreter::Numpy::createArray1DfromC(
        c_array, static_cast<PyInterpreter::Numpy::np_size_t>(a_size));
    EXPECT_TRUE(np_array1d.valid());

    // 2D array with dimensions n_rows x n_cols
    PyInterpreter::Numpy::np_size_t dims[2] = {n_rows, n_cols};
    PyObjectPtr np_array2d = PyInterpreter::Numpy::createArray2DfromC(c_array, dims);
    // fatal check: the raw object pointer of this array is passed to getDataPtr below,
    // so the test must stop here if the array could not be created
    ASSERT_TRUE(np_array2d.valid());

    PyObjectPtr np_as_array2d = PyInterpreter::Numpy::CArrayAsNpy2D(c_array, dims);
    EXPECT_TRUE(np_as_array2d.valid());

    // the data pointer of the 2D array must not be null
    double* np_array2d_ptr = PyInterpreter::Numpy::getDataPtr(np_array2d.get());
    EXPECT_TRUE(np_array2d_ptr != nullptr);

    // create an N-dimensional Numpy array with dimensions 1 x 2 x 3
    std::vector<std::size_t> dimensions{1, 2, 3};
    PyObjectPtr np_arrayNd = PyInterpreter::Numpy::arrayND(dimensions);
    EXPECT_TRUE(np_arrayNd.valid());
}


//! Runs a Python script which uses the BornAgain Python package and compares the string
//! returned by its `get_sample` function with the same value formatted on the C++ side.

TEST(Embedded, BornAgainPyFunctionTest)
{
    // initialize Python interpreter
    PyInterpreter::initialize();
    EXPECT_TRUE(PyInterpreter::isInitialized());

    // The script defines `get_sample` as a lambda which yields the formatted string
    // only if `ba.deg` and the directly imported `deg` compare equal.
    const std::string script{"import bornagain as ba; from bornagain import deg;"
                             "d0 = ba.deg; d1 = deg; "
                             "get_sample = lambda: (d0 == d1 and 'BornAgain.deg = %.3f' % d1)"};
    const std::string functionName{"get_sample"};

    // call the `get_sample` function of the script; the third argument is left empty
    PyObjectPtr ret{PyInterpreter::BornAgain::callScriptFunction(functionName, script, "")};

    // fatal check: the returned object is converted to a string below
    ASSERT_TRUE(ret.valid()) << "Failed executing Python function '" << functionName << "'";

    const std::string return_str = PyInterpreter::pyStrtoString(ret.get());

    // format Units::deg like the Python '%.3f' conversion used in the script
    std::ostringstream expected;
    expected << "BornAgain.deg = " << std::fixed << std::setprecision(3) << Units::deg;

    // the returned string must be exactly equal to the expected one
    EXPECT_EQ(return_str, expected.str());
}
