/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER


#include "CustomKernels.hpp"

#include "kernels/axpby_kernel.hpp"

namespace custom {


    // CG search-direction update, fused into a single ESIMD kernel:
    //   beta = rtz_dev[0] / oldrtz_dev[0]
    //   p    = beta * p + z
    //
    // Parameters:
    //   q            queue the kernel is submitted to.
    //   rtz_dev      numerator of beta; dereferenced inside the kernel, so it must be
    //                device-accessible (presumably a USM allocation -- confirm at caller).
    //   oldrtz_dev   denominator of beta; same accessibility requirement as rtz_dev.
    //   nrow         number of entries updated; must be a multiple of HPCG_BLOCK_SIZE.
    //   p            vector updated by the kernel (the "y" operand of axpby_body).
    //   z            vector read by the kernel (the "x" operand of axpby_body).
    //   dependencies events the kernel waits on before it starts.
    // Returns the event of the submitted kernel.
    //
    // beta is evaluated on the device by every work-item instead of on the host.
    sycl::event SpAXPBY_ker1(sycl::queue &q, double *rtz_dev, double *oldrtz_dev,
                        local_int_t nrow, Vector &p, Vector &z, const std::vector<sycl::event>&dependencies)
    {
        // ESIMD AXPBY launch configuration.
        constexpr local_int_t block_size = HPCG_BLOCK_SIZE;
        constexpr local_int_t uroll = 4;
        const local_int_t nWG = 8;   // work-group size

        // The AXPBY kernels do no remainder handling, so nrow must be block-aligned.
        assert(nrow % block_size == 0);

        // Number of (uroll * block_size)-row blocks. The global range launches at least
        // that many work-items, rounded up to a whole number of work-groups.
        const local_int_t nBlocks = ceil_div(nrow, uroll * block_size);
        const auto global_items = ceil_div(nBlocks, nWG) * nWG;

        // Hand the kernel raw data pointers rather than copies of the Vector objects.
        const auto z_vals = z.values;
        const auto p_vals = p.values;

        return q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(dependencies);
            cgh.parallel_for<class CG_axpby_1>(
                sycl::nd_range<1>(global_items, nWG),
                [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL {
                    const double beta = rtz_dev[0] / oldrtz_dev[0]; // beta <- rtz / old_rtz
                    axpby_body<block_size, uroll>(item, z_vals, p_vals, 1.0, beta, nrow, nBlocks);
                });
        });
    }

    // CG residual update, fused into a single ESIMD kernel:
    //   alpha = rtz_dev[0] / pAp_dev[0]
    //   r     = r - alpha * Ap
    //
    // Parameters:
    //   q            queue the kernel is submitted to.
    //   rtz_dev      numerator of alpha; dereferenced inside the kernel, so it must be
    //                device-accessible (presumably a USM allocation -- confirm at caller).
    //   pAp_dev      denominator of alpha; same accessibility requirement as rtz_dev.
    //   nrow         number of entries updated; must be a multiple of HPCG_BLOCK_SIZE.
    //   r            vector updated by the kernel (the "y" operand of axpby_body).
    //   Ap           vector read by the kernel (the "x" operand of axpby_body).
    //   dependencies events the kernel waits on before it starts.
    // Returns the event of the submitted kernel.
    sycl::event SpAXPBY_ker2(sycl::queue &q, double *rtz_dev, double *pAp_dev,
                        local_int_t nrow, Vector &r, Vector &Ap, const std::vector<sycl::event>&dependencies)
    {
        // ESIMD AXPBY launch configuration.
        constexpr local_int_t block_size = HPCG_BLOCK_SIZE;
        constexpr local_int_t uroll = 4;
        const local_int_t nWG = 8;   // work-group size

        // The AXPBY kernels do no remainder handling, so nrow must be block-aligned.
        assert(nrow % block_size == 0);

        // Number of (uroll * block_size)-row blocks. The global range is rounded up to a
        // whole number of work-groups.
        const local_int_t nBlocks = ceil_div(nrow, uroll * block_size);
        const auto global_items = ceil_div(nBlocks, nWG) * nWG;

        // Hand the kernel raw data pointers rather than copies of the Vector objects.
        const auto Ap_vals = Ap.values;
        const auto r_vals = r.values;

        return q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(dependencies);
            cgh.parallel_for<class CG_axpby_2>(
                sycl::nd_range<1>(global_items, nWG),
                [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL {
                    const double alpha = rtz_dev[0] / pAp_dev[0]; // alpha <- rtz / pAp
                    axpby_body<block_size, uroll>(item, Ap_vals, r_vals, -alpha, 1.0, nrow, nBlocks);
                });
        });
    }

    // CG solution update, fused into a single ESIMD kernel:
    //   alpha = rtz_dev[0] / pAp_dev[0]
    //   x     = x + alpha * p
    //
    // Parameters:
    //   q            queue the kernel is submitted to.
    //   rtz_dev      numerator of alpha; dereferenced inside the kernel, so it must be
    //                device-accessible (presumably a USM allocation -- confirm at caller).
    //   pAp_dev      denominator of alpha; same accessibility requirement as rtz_dev.
    //   nrow         number of entries updated; must be a multiple of HPCG_BLOCK_SIZE.
    //   x            vector updated by the kernel (the "y" operand of axpby_body).
    //   p            vector read by the kernel (the "x" operand of axpby_body).
    //   dependencies events the kernel waits on before it starts.
    // Returns the event of the submitted kernel.
    sycl::event SpAXPBY_ker3(sycl::queue &q, double *rtz_dev, double *pAp_dev,
                        local_int_t nrow, Vector &x, Vector &p, const std::vector<sycl::event>&dependencies)
    {
        // ESIMD AXPBY launch configuration.
        constexpr local_int_t block_size = HPCG_BLOCK_SIZE;
        constexpr local_int_t uroll = 4;
        const local_int_t nWG = 8;   // work-group size

        // The AXPBY kernels do no remainder handling, so nrow must be block-aligned.
        assert(nrow % block_size == 0);

        // Number of (uroll * block_size)-row blocks. The global range is rounded up to a
        // whole number of work-groups.
        const local_int_t nBlocks = ceil_div(nrow, uroll * block_size);
        const auto global_items = ceil_div(nBlocks, nWG) * nWG;

        // Hand the kernel raw data pointers rather than copies of the Vector objects.
        const auto p_vals = p.values;
        const auto x_vals = x.values;

        return q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(dependencies);
            cgh.parallel_for<class CG_axpby_3>(
                sycl::nd_range<1>(global_items, nWG),
                [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL {
                    const double alpha = rtz_dev[0] / pAp_dev[0]; // alpha <- rtz / pAp
                    axpby_body<block_size, uroll>(item, p_vals, x_vals, alpha, 1.0, nrow, nBlocks);
                });
        });
    }



} // namespace custom
