//==---------------- <complex> wrapper around STL --------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// STL's <complex> includes <cmath> which, in turn, pollutes global namespace.
// As such, we cannot include <complex> from SYCL headers unconditionally and
// have to provide support for std::complex only when the customer included
// <complex> explicitly. Do that by providing our own <complex> that is
// implemented as a wrapper around the STL header using "#include_next"
// functionality.

#pragma once

// Include real STL <complex> header - the next one from the include search
// directories.
#if defined(__has_include_next)
// GCC and Clang support "#include_next", so they take this path.
#include_next <complex>
#else
// MSVC doesn't support "#include_next", so we have to be creative.
// Our header is located in "stl_wrappers/complex" so it won't be picked by the
// following include. MSVC's installation, on the other hand, has the layout
// where the following would result in the <complex> we want. This is obviously
// hacky, but the best we can do...
#include <../include/complex>
#endif

// Now that we have std::complex available, implement SYCL functionality related
// to it.

#include <type_traits>

#include <CL/__spirv/spirv_ops.hpp> // for __SYCL_CONVERGENT__
#include <sycl/half_type.hpp>       // for half

// We provide std::complex specializations here for the following:
// select_cl_scalar_complex_or_T:
#include <sycl/detail/generic_type_traits.hpp>
// sycl::detail::GroupOpTag:
#include <sycl/ext/oneapi/functional.hpp>
// sycl::detail::is_complex:
#include <sycl/group_algorithm.hpp>
// sycl::detail::isComplex:
#include <sycl/known_identity.hpp>

namespace __spv {
// Plain pair of floats (real part, imaginary part) that mirrors
// std::complex<float>. It converts implicitly from and to std::complex<float>
// so values can be passed to and received from the __spirv_GroupCMulINTEL
// overloads declared in this file.
struct complex_float {
  complex_float() = default;
  // Implicit on purpose: copies the real and imaginary parts out of x.
  complex_float(std::complex<float> x) : real(x.real()), imag(x.imag()) {}
  // Const-qualified so that const objects (and objects reached through a
  // const reference) can be converted back to std::complex<float> as well.
  operator std::complex<float>() const { return {real, imag}; }
  float real, imag;
};

// Plain pair of doubles (real part, imaginary part) that mirrors
// std::complex<double>. It converts implicitly from and to
// std::complex<double> so values can be passed to and received from the
// __spirv_GroupCMulINTEL overloads declared in this file.
struct complex_double {
  complex_double() = default;
  // Implicit on purpose: copies the real and imaginary parts out of x.
  complex_double(std::complex<double> x) : real(x.real()), imag(x.imag()) {}
  // Const-qualified so that const objects (and objects reached through a
  // const reference) can be converted back to std::complex<double> as well.
  operator std::complex<double>() const { return {real, imag}; }
  double real, imag;
};

// Plain pair of sycl::half values (real part, imaginary part) that mirrors
// std::complex<sycl::half>. It converts implicitly from and to
// std::complex<sycl::half> so values can be passed to and received from the
// __spirv_GroupCMulINTEL overloads declared in this file.
struct complex_half {
  complex_half() = default;
  // Implicit on purpose: copies the real and imaginary parts out of x.
  complex_half(std::complex<sycl::half> x) : real(x.real()), imag(x.imag()) {}
  // Const-qualified so that const objects (and objects reached through a
  // const reference) can be converted back to std::complex<sycl::half> too.
  operator std::complex<sycl::half>() const { return {real, imag}; }
  sycl::half real, imag;
};
} // namespace __spv

// Declarations (no definitions in this header) of the __spirv_GroupCMulINTEL
// overloads, one per __spv complex struct: half, float and double. Each
// overload takes two unsigned int arguments followed by a complex value and
// returns a complex value of the same type. All three are noexcept and are
// marked with the __SYCL_CONVERGENT__, __DPCPP_SYCL_EXTERNAL and
// __SYCL_EXPORT macros.
// NOTE(review): the two unsigned int parameters are presumably the execution
// scope and the group operation, as for other SPIR-V group built-ins - confirm
// against the built-in's specification.
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
    __SYCL_EXPORT __spv::complex_half
    __spirv_GroupCMulINTEL(unsigned int, unsigned int,
                           __spv::complex_half) noexcept;
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
    __SYCL_EXPORT __spv::complex_float
    __spirv_GroupCMulINTEL(unsigned int, unsigned int,
                           __spv::complex_float) noexcept;
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
    __SYCL_EXPORT __spv::complex_double
    __spirv_GroupCMulINTEL(unsigned int, unsigned int,
                           __spv::complex_double) noexcept;

namespace sycl {
inline namespace _V1 {
namespace detail {
// isComplex specialization: reports true for std::complex<float> and
// std::complex<double>. std::complex<half> is intentionally not listed here
// (it is only covered by is_complex below).
template <typename T>
struct isComplex<
    T, std::enable_if_t<std::disjunction_v<
           std::is_same<T, std::complex<float>>,
           std::is_same<T, std::complex<double>>>>> : std::true_type {};

// is_complex specialization: reports true for the std::complex instantiations
// over half, float and double.
// NOTE: std::complex<long double> not yet supported by group algorithms.
template <typename T>
struct is_complex<
    T, std::enable_if_t<std::disjunction_v<
           std::is_same<T, std::complex<half>>,
           std::is_same<T, std::complex<float>>,
           std::is_same<T, std::complex<double>>>>> : std::true_type {};

#ifdef __SYCL_DEVICE_ONLY__
// Only compiled when __SYCL_DEVICE_ONLY__ is defined: selects GroupOpC as the
// group-operation tag for std::complex over half, float and double.
template <typename T>
struct GroupOpTag<T, std::enable_if_t<std::is_same_v<T, std::complex<half>> ||
                                      std::is_same_v<T, std::complex<float>> ||
                                      std::is_same_v<T, std::complex<double>>>> {
  using type = GroupOpC;
};
#endif // __SYCL_DEVICE_ONLY__

// For every T accepted by is_complex, picks the matching __spv struct:
//   std::complex<float>  -> __spv::complex_float
//   std::complex<double> -> __spv::complex_double
//   anything else        -> __spv::complex_half
template <typename T>
struct select_cl_scalar_complex_or_T<T,
                                     std::enable_if_t<is_complex<T>::value>> {
  using type = typename std::conditional<
      std::is_same<T, std::complex<float>>::value, __spv::complex_float,
      typename std::conditional<std::is_same<T, std::complex<double>>::value,
                                __spv::complex_double,
                                __spv::complex_half>::type>::type;
};
} // namespace detail
} // namespace _V1
} // namespace sycl
