import os
from collections import OrderedDict
from shutil import copy, which

import libtbx.load_env

# =============================================================================
#
# Setting the environment to build kokkos and kokkos-kernels
#
# =============================================================================

# libkokkos.a
# The build is configured through environment variables: every variable the
# caller left unset receives a default below.
if 'KOKKOS_DEVICES' not in os.environ:
  os.environ['KOKKOS_DEVICES'] = "OpenMP"

# Backend switches are plain substring tests on the KOKKOS_DEVICES string
requested_devices = os.environ['KOKKOS_DEVICES']
use_openmp = "OpenMP" in requested_devices
use_cuda = "Cuda" in requested_devices
use_hip = "HIP" in requested_devices
use_sycl = "SYCL" in requested_devices

# Source trees default to locations relative to the simtbx distribution
if 'KOKKOS_PATH' not in os.environ:
  os.environ['KOKKOS_PATH'] = libtbx.env.under_dist('simtbx', '../../kokkos')
if 'KOKKOSKERNELS_PATH' not in os.environ:
  os.environ['KOKKOSKERNELS_PATH'] = libtbx.env.under_dist('simtbx', '../../kokkos-kernels')
os.environ.setdefault('KOKKOS_ARCH', "HSW")

# CXXFLAGS is always overwritten, regardless of any value set by the caller
os.environ['CXXFLAGS'] = '-O3 -fPIC -DCUDAREAL=double'

# Library search paths: CUDA extends the default, SYCL uses none
if use_cuda:
  search_paths = ["-Llib", "-L$(CUDA_HOME)/lib64", "-L$(CUDA_HOME)/lib64/stubs"]
  if os.getenv('CUDA_COMPATIBILITY'):
    search_paths.append("-L$(CUDA_HOME)/compat")
  library_flags = " ".join(search_paths)
elif use_sycl:
  library_flags = ""
else:
  library_flags = "-Llib"
os.environ['LDFLAGS'] = library_flags

# Libraries to link: CUDA extends the default, SYCL uses none
if use_cuda:
  linked_libraries = "-lkokkoscontainers -lkokkoscore -ldl -lcudart -lcuda"
elif use_sycl:
  linked_libraries = ""
else:
  linked_libraries = "-lkokkoscontainers -lkokkoscore -ldl"
os.environ['LDLIBS'] = linked_libraries

# SYCL selects C++17; every other backend selects C++14
cxx_standard = '17' if use_sycl else '14'


# Summarize the selected devices and architecture between two rules
rule = "-" * 40
print(rule)
print("         Kokkos configuration\n")
for label, variable in (("  Devices: ", 'KOKKOS_DEVICES'), ("     Arch: ", 'KOKKOS_ARCH')):
  print(label + os.getenv(variable))
print(rule)


kokkos_lib = 'libkokkos.a'
kokkos_cxxflags = None

# remove conda/lib path from ld_path
original_ld_path = os.getenv('LD_LIBRARY_PATH')
if original_ld_path is not None:
  ld_path = [
    entry for entry in original_ld_path.split(':')
    if 'opt/mamba/envs/psana_env/lib' not in entry
  ]
  os.environ['LD_LIBRARY_PATH'] = ":".join(ld_path)

# Remember the caller's CXX (None when it was not set) before overriding it
original_cxx = os.environ.get('CXX')

# Pick the compiler for the requested backend; without a GPU backend the
# caller's CXX is kept, falling back to g++
if use_cuda:
  selected_cxx = os.path.join(os.environ['KOKKOS_PATH'], 'bin', 'nvcc_wrapper')
elif use_hip:
  selected_cxx = 'hipcc'
elif use_sycl:
  selected_cxx = 'icpx'
else:
  selected_cxx = os.environ.get('CXX', 'g++')
os.environ['CXX'] = selected_cxx

# =============================================================================
#
# Building the Kokkos library
#
# =============================================================================

print('Getting environment variables')
print('-'*79)
# KOKKOS_CXXFLAGS must be set; it is split on whitespace into separate flags.
# If the value starts with the word "echo", that word is dropped and the
# double quotes around the remaining flags are stripped.
flag_tokens = os.environ['KOKKOS_CXXFLAGS'].split()
if flag_tokens[:1] == ['echo']:
  flag_tokens = [token.strip('"') for token in flag_tokens[1:]]
kokkos_cxxflags = flag_tokens
print('KOKKOS_CXXFLAGS:', kokkos_cxxflags)
print('='*79)

# =============================================================================
#
# Building libsimtbx_kokkos.so
#
# =============================================================================
Import("env", "env_etc")

# Configure a clone of the imported environment and remove the -ffast-math
# compile option from its shared-object flags
kokkos_env = env.Clone()
ccflags = kokkos_env['SHCCFLAGS']
o = '-ffast-math'
if o in ccflags:
  ccflags.remove(o)
kokkos_env.Replace(SHCCFLAGS=ccflags)

# Compile and link with the compiler selected for the requested backend
kokkos_env.Replace(CXX=os.environ['CXX'])
kokkos_env.Replace(SHCXX=os.environ['CXX'])
kokkos_env.Replace(SHLINK=os.environ['CXX'])
kokkos_env.Prepend(CXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
kokkos_env.Prepend(CPPFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
kokkos_env.Prepend(CPPPATH=[os.environ['KOKKOS_PATH']])
kokkos_env.Append(LIBS=['kokkoscontainers','kokkoscore'])
if use_hip:
  # SHLINK is already set above; HIP builds also replace LINK
  kokkos_env.Replace(LINK=os.environ['CXX'])

if use_sycl:
  # Need -fsycl and additional compiler flags at link stage for Intel compiler to generate dev code
  kokkos_env.Append(SHLINKFLAGS=kokkos_cxxflags)
  kokkos_env.Replace(SHLINKFLAGS=[val.replace('c++11','c++17') for val in kokkos_env['SHLINKFLAGS']])
  # Prefer c++17 for compiler as well otherwise we pass both c++11 and c++17 to compiler
  kokkos_env.Replace(CXXFLAGS=[val.replace('c++11','c++17') for val in kokkos_env['CXXFLAGS']])
  kokkos_env.Prepend(SHCXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  # Remove duplicate flags, keeping the first occurrence of each in order.
  # The result is stored as a flat list of flags.
  kokkos_env.Replace(CXXFLAGS=list(OrderedDict.fromkeys(kokkos_env['CXXFLAGS'])))
  kokkos_env.Replace(SHLINKFLAGS=list(OrderedDict.fromkeys(kokkos_env['SHLINKFLAGS'])))

if True: # diffuse code in exascale API
  # "if True" block is intended to demarcate the linkage against the eigen library
  # which is needed to reuse the diffBragg function calc_diffuse_at_hkl() within kokkos/simulation.
  # This block can be removed in the future if the eigen dependency is replaced by kokkoslib.
  env_etc.eigen_dist = os.path.abspath(os.path.join(libtbx.env.dist_path("simtbx"),"../../eigen"))
  if not os.path.isdir(env_etc.eigen_dist) and hasattr(env_etc, "conda_cpppath"):
    # Fall back to an eigen3 directory under the conda include paths; when
    # several candidates contain one, the last candidate wins
    for candidate in env_etc.conda_cpppath:
      if os.path.isdir(os.path.join(candidate, "eigen3")):
        env_etc.eigen_dist = os.path.abspath(os.path.join(candidate, "eigen3"))
  if os.path.isdir(env_etc.eigen_dist):
    env_etc.eigen_include = env_etc.eigen_dist
  # NOTE(review): when no eigen directory is found, env_etc.eigen_include is
  # not assigned here, so it presumably has to be set elsewhere -- confirm
  env_etc.include_registry.append(
    env=kokkos_env,
    paths=[env_etc.eigen_include])

simtbx_kokkos_lib = kokkos_env.SharedLibrary(
  target="#lib/libsimtbx_kokkos.so",
  source=[
    'detector.cpp',
    'kokkos_instance.cpp',
    '../../kokkostbx/kokkos_utils.cpp',
    'simulation.cpp',
    'structure_factors.cpp'
  ]
)

print("kokkos_env CXXFLAGS:", kokkos_env['CXXFLAGS'])
print("kokkos_env SHLINKFLAGS:", kokkos_env['SHLINKFLAGS'])

# =============================================================================
#
# Building simtbx_kokkos_ext.so
#
# =============================================================================
if not env_etc.no_boost_python:
  Import("env_no_includes_boost_python_ext")
  kokkos_ext_env = env_no_includes_boost_python_ext.Clone()

  env_etc.include_registry.append(
    env=kokkos_ext_env,
    paths=env_etc.simtbx_common_includes + [env_etc.python_include])

  # Compile and link with the compiler selected for the requested backend
  kokkos_ext_env.Replace(CXX=os.environ['CXX'])
  kokkos_ext_env.Replace(SHCXX=os.environ['CXX'])
  kokkos_ext_env.Replace(SHLINK=os.environ['CXX'])
  if use_sycl:
    kokkos_ext_env.Prepend(SHCXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  kokkos_ext_env.Prepend(CXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  kokkos_ext_env.Prepend(CPPFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  kokkos_ext_env.Prepend(CPPPATH=[os.environ['KOKKOS_PATH']])

  # Libraries linked for every backend; backend runtimes are added below
  default_libs = [
    "simtbx_kokkos",
    "scitbx_boost_python",
    env_etc.boost_python_lib,
    "cctbx",
    "kokkoscontainers",
    "kokkoscore"]
  if use_cuda:
    # CUDA_HOME must be set in the environment for CUDA builds
    kokkos_ext_env.Append(LIBPATH=[os.path.join(os.environ['CUDA_HOME'], 'lib64')])
    kokkos_ext_env.Append(LIBPATH=[os.path.join(os.environ['CUDA_HOME'], 'compat')])
    kokkos_ext_env.Append(LIBPATH=[os.path.join(os.environ['CUDA_HOME'], 'lib64/stubs')])
    kokkos_ext_env.Append(LIBS=env_etc.libm + default_libs + ["cudart", "cuda"])
  elif use_hip:
    # ROCM_PATH must be set in the environment for HIP builds
    kokkos_ext_env.Append(LIBPATH=[os.path.join(os.environ['ROCM_PATH'], 'lib')])
    kokkos_ext_env.Append(LIBS=env_etc.libm + default_libs + ["amdhip64", "hsa-runtime64"])
  elif use_sycl:
    # Prefer c++17 for compiler as well otherwise we pass both c++11 and c++17 to compiler
    kokkos_ext_env.Replace(CXXFLAGS=[val.replace('c++11','c++17') for val in kokkos_ext_env['CXXFLAGS']])
    kokkos_ext_env.Append(LIBS=env_etc.libm + default_libs)
    # Need -fsycl and additional compiler flags at link stage for Intel compiler to generate dev code
    kokkos_ext_env.Append(SHLINKFLAGS=kokkos_cxxflags)
    kokkos_ext_env.Replace(SHLINKFLAGS=[val.replace('c++11','c++17') for val in kokkos_ext_env['SHLINKFLAGS']])
    # Remove duplicate flags, keeping the first occurrence of each in order.
    # The result is stored as a flat list of flags.
    kokkos_ext_env.Replace(CXXFLAGS=list(OrderedDict.fromkeys(kokkos_ext_env['CXXFLAGS'])))
    kokkos_ext_env.Replace(SHLINKFLAGS=list(OrderedDict.fromkeys(kokkos_ext_env['SHLINKFLAGS'])))
  else:
    kokkos_ext_env.Append(LIBS=env_etc.libm + default_libs)

  simtbx_kokkos_ext = kokkos_ext_env.SharedLibrary(
    target="#lib/simtbx_kokkos_ext.so",
    source=['kokkos_ext.cpp']
  )

# reset CXX
# Put back the caller's value. When CXX was not set before this script
# changed it, remove the variable again so the compiler chosen here does not
# stay in the process environment.
if original_cxx is not None:
  os.environ['CXX'] = original_cxx
else:
  os.environ.pop('CXX', None)
