import os
import subprocess
from collections import OrderedDict
from shutil import copy, which

import libtbx.load_env

# =============================================================================
#
# Setting the environment to build kokkos and kokkos-kernels
#
# =============================================================================

# libkokkos.a
# The kokkos build system is called directly and is configured through
# environment variables, so fill in a default for every setting that the
# caller has not already provided.
_devices = os.environ.setdefault('KOKKOS_DEVICES', "OpenMP")
# A backend is selected when its name occurs anywhere in KOKKOS_DEVICES.
use_openmp, use_cuda, use_hip, use_sycl = (
  backend in _devices for backend in ("OpenMP", "Cuda", "HIP", "SYCL"))

# Source checkouts are looked up relative to the simtbx distribution.
for _variable, _checkout in (('KOKKOS_PATH', '../../kokkos'),
                             ('KOKKOSKERNELS_PATH', '../../kokkos-kernels')):
  if _variable not in os.environ:
    os.environ[_variable] = libtbx.env.under_dist('simtbx', _checkout)

os.environ.setdefault('KOKKOS_ARCH', "HSW")
if use_cuda:
  os.environ.setdefault('KOKKOS_CUDA_OPTIONS', "enable_lambda,force_uvm")
# CXXFLAGS is always overwritten, even if the caller had set it.
os.environ['CXXFLAGS'] = '-O3 -fPIC -DCUDAREAL=double'

# Linker search paths: CUDA takes precedence over SYCL when both are enabled.
if use_cuda:
  _search_paths = ["-Llib", "-L$(CUDA_HOME)/lib64"]
  if os.getenv('CUDA_COMPATIBILITY'):
    _search_paths.append("-L$(CUDA_HOME)/compat")
  library_flags = " ".join(_search_paths)
elif use_sycl:
  library_flags = ""
else:
  library_flags = "-Llib"
os.environ['LDFLAGS'] = library_flags

# Libraries to link, with the same backend precedence as the search paths.
if use_cuda:
  linked_libraries = "-lkokkoscontainers -lkokkoscore -ldl -lcudart -lcuda"
elif use_sycl:
  linked_libraries = ""
else:
  linked_libraries = "-lkokkoscontainers -lkokkoscore -ldl"
os.environ['LDLIBS'] = linked_libraries

# C++ standard handed to CMake further down; SYCL builds use C++17.
cxx_standard = '17' if use_sycl else '14'


_rule = "-"*40
print(_rule)
print("         Kokkos configuration\n")
print("  Devices: " + os.environ['KOKKOS_DEVICES'])
print("     Arch: " + os.environ['KOKKOS_ARCH'])
print(_rule)


kokkos_lib = 'libkokkos.a'
kokkos_cxxflags = None

# Remember the caller's compiler (None when CXX is unset) so that it can be
# put back at the end of this script, then pick the compiler for the backend.
original_cxx = os.environ.get('CXX')
if use_cuda:
  os.environ['CXX'] = os.path.join(os.environ['KOKKOS_PATH'], 'bin', 'nvcc_wrapper')
elif use_hip:
  os.environ['CXX'] = 'hipcc'
elif use_sycl:
  os.environ['CXX'] = 'icpx'
else:
  # Host-only build: keep the caller's compiler, falling back to g++.
  os.environ.setdefault('CXX', 'g++')

print('='*79)
print('Building Kokkos')
print('-'*79)
# The exit status is recorded but not acted upon; a failed build shows up
# below as a missing library file.
returncode = subprocess.call(
  ['make', '-f', 'Makefile.kokkos', kokkos_lib],
  cwd=os.environ['KOKKOS_PATH'])
print()

print('Copying Kokkos library')
print('-'*79)
src = os.path.join(os.environ['KOKKOS_PATH'], kokkos_lib)
dst = os.path.join(libtbx.env.under_build('lib'), kokkos_lib)
if not os.path.isfile(src):
  print('Error: {src} does not exist'.format(src=src))
else:
  copy(src, dst)
  print('Copied')
  print('  source:     ', src)
  print('  destination:', dst)
print()

# =============================================================================
#
# Building the Kokkos library
#
# =============================================================================



# =============================================================================
# Build kokkos with CMake
# The build needs to be in a directory not in the build directory, otherwise
# kokkos-kernels will find that build directory instead of build/lib/cmake/Kokkos
# The error will look like
#
#   -- The project name is: KokkosKernels
#   CMake Error at /dev/shm/bkpoon/software/xfel/build/kokkos/KokkosConfig.cmake:48 (INCLUDE):
#     INCLUDE could not find requested file:
#
#       /dev/shm/bkpoon/software/xfel/build/kokkos/KokkosTargets.cmake
#   Call Stack (most recent call first):
#     CMakeLists.txt:107 (FIND_PACKAGE)
#
# kokkos will be installed in build with the libraries in lib, not lib64
#
# TODO:
# - Change over from Makefile version to cmake version
# - Change system configuration classes to select backends with cmake
#   flags instead of environment variables
# - Verify that the libraries from the cmake version work the same as libkokkos.a

OnOff = {True:'ON', False:'OFF'}
supported_architectures = ['Native', 'HSW', 'Zen', 'Zen2', 'Zen3', 'Volta70', 'Ampere80', 'Vega908', 'Vega90A', 'XeHP', 'PVC']

# Architectures are looked up in KOKKOS_ARCH by substring. A name only counts
# when no longer supported name that contains it is requested as well;
# otherwise asking for 'Zen2' or 'Zen3' would also switch on Kokkos_ARCH_ZEN.
requested_architectures = os.getenv('KOKKOS_ARCH') or ''

def _architecture_is_requested(name):
  """Return True if name is the most specific supported match in KOKKOS_ARCH."""
  if name not in requested_architectures:
    return False
  shadowed = any(
    other != name and name in other and other in requested_architectures
    for other in supported_architectures)
  return not shadowed

def _run_build_step(description, command, cwd):
  """Run command in cwd, print a warning on a non-zero exit, return the status.

  A failure does not stop the script: the remaining steps still run, as they
  did before, but the failing step is now reported instead of being silent.
  """
  status = subprocess.call(command, cwd=cwd)
  if status != 0:
    print('Warning: {} failed with exit code {}'.format(description, status))
  return status

architectures = {}
for arch in supported_architectures:
  architectures[arch] = OnOff[_architecture_is_requested(arch)]

cmake_is_available = which('cmake')
kokkos_home_is_set = os.getenv('KOKKOS_HOME') is not None
if cmake_is_available and not kokkos_home_is_set:
  # Both packages are installed into the libtbx build directory.
  install_prefix = libtbx.env.under_build('.')

  # ---------------------------------------------------------------------------
  # Build kokkos with CMake.
  print('='*79)
  print('Building kokkos with cmake')
  print('-'*79)
  kokkos_build_dir = libtbx.env.under_dist('simtbx', '../../kokkos_build')
  if not os.path.isdir(kokkos_build_dir):
    os.mkdir(kokkos_build_dir)
  returncode = _run_build_step('kokkos cmake configuration', [
      'cmake',
      os.environ['KOKKOS_PATH'],
      '-DCMAKE_CXX_STANDARD={}'.format(cxx_standard),
      '-DCMAKE_INSTALL_PREFIX={}'.format(install_prefix),
      '-DCMAKE_INSTALL_LIBDIR=lib',
      '-DBUILD_SHARED_LIBS={}'.format(OnOff[use_sycl]),
      '-DKokkos_ARCH_NATIVE={}'.format(architectures['Native']),
      '-DKokkos_ARCH_HSW={}'.format(architectures['HSW']),
      '-DKokkos_ARCH_ZEN={}'.format(architectures['Zen']),
      '-DKokkos_ARCH_ZEN2={}'.format(architectures['Zen2']),
      '-DKokkos_ARCH_ZEN3={}'.format(architectures['Zen3']),
      '-DKokkos_ARCH_VOLTA70={}'.format(architectures['Volta70']),
      '-DKokkos_ARCH_AMPERE80={}'.format(architectures['Ampere80']),
      '-DKokkos_ARCH_VEGA908={}'.format(architectures['Vega908']),
      '-DKokkos_ARCH_VEGA90A={}'.format(architectures['Vega90A']),
      '-DKokkos_ARCH_INTEL_XEHP={}'.format(architectures['XeHP']),
      '-DKokkos_ARCH_INTEL_PVC={}'.format(architectures['PVC']),
      '-DKokkos_ENABLE_SERIAL=ON',
      '-DKokkos_ENABLE_OPENMP={}'.format(OnOff[use_openmp]),
      '-DKokkos_ENABLE_CUDA={}'.format(OnOff[use_cuda]),
      '-DKokkos_ENABLE_CUDA_UVM={}'.format(OnOff[use_cuda]),
      '-DKokkos_ENABLE_HIP={}'.format(OnOff[use_hip]),
      '-DKokkos_ENABLE_SYCL={}'.format(OnOff[use_sycl])
    ],
    cwd=kokkos_build_dir)
  returncode = _run_build_step(
    'kokkos install', ['make', '-j', '4', 'install'], cwd=kokkos_build_dir)

  # ---------------------------------------------------------------------------
  # Build kokkos-kernels with CMake.
  # Turn off all ETI builds for now, until needed, for maximum machine compatibility
  print('='*79)
  print('Building kokkos_kernels')
  print('-'*79)
  kokkos_kernels_build_dir = libtbx.env.under_dist('simtbx', '../../kokkos-kernels/build')
  if not os.path.isdir(kokkos_kernels_build_dir):
    os.mkdir(kokkos_kernels_build_dir)
  returncode = _run_build_step('kokkos-kernels cmake configuration', [
      'cmake',
      os.environ['KOKKOSKERNELS_PATH'],
      '-DCMAKE_INSTALL_PREFIX={}'.format(install_prefix),
      '-DCMAKE_INSTALL_LIBDIR=lib',
      '-DBUILD_SHARED_LIBS={}'.format(OnOff[use_sycl]),
      # Point kokkos-kernels at the kokkos that was just installed.
      '-DKokkos_ROOT={}'.format(install_prefix),
      '-DKokkosKernels_ADD_DEFAULT_ETI={}'.format(OnOff[False]),
      '-DKokkosKernels_INST_LAYOUTLEFT:BOOL={}'.format(OnOff[False]),
      '-DKokkosKernels_INST_LAYOUTRIGHT:BOOL={}'.format(OnOff[False]),
      '-DKokkosKernels_ENABLE_TPL_CUBLAS={}'.format(OnOff[False]),
      '-DKokkosKernels_ENABLE_TPL_CUSPARSE={}'.format(OnOff[False])
    ],
    cwd=kokkos_kernels_build_dir)
  returncode = _run_build_step(
    'kokkos-kernels build', ['make', '-j', '4'], cwd=kokkos_kernels_build_dir)
  returncode = _run_build_step(
    'kokkos-kernels install', ['make', '-j', '4', 'install'],
    cwd=kokkos_kernels_build_dir)
else:
  # Report the actual reason the CMake builds are skipped.
  print('*'*79)
  if not cmake_is_available:
    print('cmake was not found')
  else:
    print('KOKKOS_HOME is set')
  print('Skipping builds of kokkos and kokkos-kernels')
  print('*'*79)
  
# =============================================================================

print('Getting environment variables')
print('-'*79)
# Ask the kokkos Makefile for its compiler flags. Only the second line of the
# captured output is parsed; it is split on whitespace into individual flags.
_make_output = subprocess.check_output(
  ['make', '-f', 'Makefile.kokkos', 'print-cxx-flags'],
  cwd=os.environ['KOKKOS_PATH'])
_flag_line = _make_output.split(b'\n')[1]
kokkos_cxxflags = _flag_line.decode('utf8').split()
# If that line begins with the word "echo", drop it and strip the double
# quotes from the remaining words.
if kokkos_cxxflags[:1] == ['echo']:
  kokkos_cxxflags = [flag.strip('"') for flag in kokkos_cxxflags[1:]]
print('KOKKOS_CXXFLAGS:', kokkos_cxxflags)
print('='*79)

# =============================================================================
#
# Building libsimtbx_kokkos.so
#
# =============================================================================
Import("env", "env_etc")

# Clone the shared SCons environment and compile with the backend compiler
# chosen earlier in this script (stored in the CXX environment variable).
kokkos_env = env.Clone()
kokkos_env.Replace(CXX=os.environ['CXX'])
kokkos_env.Replace(SHCXX=os.environ['CXX'])
kokkos_env.Prepend(CXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
kokkos_env.Prepend(CPPFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
kokkos_env.Prepend(CPPPATH=[os.environ['KOKKOS_PATH']])
kokkos_env.Append(LIBS=['kokkoscontainers','kokkoscore'])

if use_sycl:
  kokkos_env.Replace(SHLINK=os.environ['CXX'])
  kokkos_env.Prepend(SHCXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  # Need -fsycl and additional compiler flags at link stage for Intel compiler
  # to generate dev code
  kokkos_env.Append(SHLINKFLAGS=kokkos_cxxflags)
  # Prefer c++17 for the compiler and the linker, otherwise both c++11 and
  # c++17 are passed. Duplicate flags are removed while keeping the order of
  # first occurrence, and the result is stored as a flat list of strings.
  sycl_cxxflags = [val.replace('c++11','c++17') for val in kokkos_env['CXXFLAGS']]
  sycl_shlinkflags = [val.replace('c++11','c++17') for val in kokkos_env['SHLINKFLAGS']]
  kokkos_env.Replace(CXXFLAGS=list(OrderedDict.fromkeys(sycl_cxxflags)))
  kokkos_env.Replace(SHLINKFLAGS=list(OrderedDict.fromkeys(sycl_shlinkflags)))

simtbx_kokkos_lib = kokkos_env.SharedLibrary(
  target="#lib/libsimtbx_kokkos.so",
  source=[
    'detector.cpp',
    'kokkos_instance.cpp',
    'kokkos_utils.cpp',
    'simulation.cpp',
    'structure_factors.cpp'
  ]
)

print("kokkos_env CXXFLAGS:", kokkos_env['CXXFLAGS'])
print("kokkos_env SHLINKFLAGS:", kokkos_env['SHLINKFLAGS'])

# =============================================================================
#
# Building simtbx_kokkos_ext.so
#
# =============================================================================
if not env_etc.no_boost_python:
  Import("env_no_includes_boost_python_ext")
  kokkos_ext_env = env_no_includes_boost_python_ext.Clone()

  env_etc.include_registry.append(
    env=kokkos_ext_env,
    paths=env_etc.simtbx_common_includes + [env_etc.python_include])
  # Compile with the backend compiler chosen earlier in this script.
  kokkos_ext_env.Replace(CXX=os.environ['CXX'])
  kokkos_ext_env.Replace(SHCXX=os.environ['CXX'])
  if use_sycl:
    kokkos_ext_env.Replace(SHLINK=os.environ['CXX'])
    kokkos_ext_env.Prepend(SHCXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  kokkos_ext_env.Prepend(CXXFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  kokkos_ext_env.Prepend(CPPFLAGS=['-DCUDAREAL=double'] + kokkos_cxxflags)
  kokkos_ext_env.Prepend(CPPPATH=[os.environ['KOKKOS_PATH']])
  default_libs = [
    "simtbx_kokkos",
    "scitbx_boost_python",
    env_etc.boost_python_lib,
    "cctbx",
    "kokkoscontainers",
    "kokkoscore"]

  # Backend-specific library search paths and runtime libraries. The use_*
  # flags are the ones derived from KOKKOS_DEVICES at the top of this script;
  # the precedence is Cuda, then HIP, then SYCL.
  backend_libs = []
  if use_cuda:
    # CUDA_HOME must be set; a missing variable raises KeyError here.
    kokkos_ext_env.Append(LIBPATH=[os.path.join(os.environ['CUDA_HOME'], 'lib64')])
    kokkos_ext_env.Append(LIBPATH=[os.path.join(os.environ['CUDA_HOME'], 'compat')])
    backend_libs = ["cudart", "cuda"]
  elif use_hip:
    # ROCM_PATH must be set; a missing variable raises KeyError here.
    kokkos_ext_env.Append(LIBPATH=[os.path.join(os.environ['ROCM_PATH'], 'lib')])
    backend_libs = ["amdhip64", "hsa-runtime64"]
  elif use_sycl:
    # Need -fsycl and additional compiler flags at link stage for Intel
    # compiler to generate dev code
    kokkos_ext_env.Append(SHLINKFLAGS=kokkos_cxxflags)
    # Prefer c++17 for the compiler and the linker, otherwise both c++11 and
    # c++17 are passed. Duplicate flags are removed while keeping the order of
    # first occurrence, and the result is stored as a flat list of strings.
    sycl_cxxflags = [val.replace('c++11','c++17') for val in kokkos_ext_env['CXXFLAGS']]
    sycl_shlinkflags = [val.replace('c++11','c++17') for val in kokkos_ext_env['SHLINKFLAGS']]
    kokkos_ext_env.Replace(CXXFLAGS=list(OrderedDict.fromkeys(sycl_cxxflags)))
    kokkos_ext_env.Replace(SHLINKFLAGS=list(OrderedDict.fromkeys(sycl_shlinkflags)))
  kokkos_ext_env.Append(LIBS=env_etc.libm + default_libs + backend_libs)

  simtbx_kokkos_ext = kokkos_ext_env.SharedLibrary(
    target="#lib/simtbx_kokkos_ext.so",
    source=['kokkos_ext.cpp']
  )

# Put back the compiler the caller had exported. If CXX was not set when this
# script started, the value assigned above is left in place.
if original_cxx is not None:
  os.environ.update(CXX=original_cxx)
