cmake_minimum_required(VERSION 3.16)

# z/OS builds must use IBM's clang-based compilers (ibm-clang64 / ibm-clang++64).
# Hard-coding compilers in a CMakeLists.txt is normally discouraged, but doing it
# here keeps a z/OS build from silently picking up the system default compiler.
# NOTE: this runs before project() and before option(WITH_ZOS) is declared, so
# WITH_ZOS only has an effect here when passed on the command line
# (-DWITH_ZOS=ON) at the first configure of a fresh build directory.
if(WITH_ZOS)
    set(CMAKE_C_COMPILER   "ibm-clang64")
    set(CMAKE_CXX_COMPILER "ibm-clang++64")
endif()

project(src)
enable_testing()

# Build options: ON/OFF switches, set with -D<NAME>=ON|OFF.
# NOTE: WITH_ZOS is also read at the top of this file, before project(), to
# select the z/OS compilers. WITH_S390X_ONLY may additionally be overridden by
# the environment variable of the same name (see the test registration below).
option(WITH_GPU "WITH_GPU" ON)
option(WITH_MPI "WITH_MPI" OFF)
option(WITH_ZFLAGS "WITH_ZFLAGS" OFF)
option(WITH_NATIVE "WITH_NATIVE" OFF)
option(WITH_NUMA "WITH_NUMA" OFF)
option(WITH_WIN "WITH_WIN" OFF)
option(WITH_MAC "WITH_MAC" OFF)
option(WITH_ZOS "WITH_ZOS" OFF)
option(WITH_S390X_ONLY "WITH_S390X_ONLY" OFF)

# Switches read by the compiler-flag section. Declared here with default OFF
# (the same result as leaving them unset) so they show up in the cache and in
# cmake-gui/ccmake.
option(WITH_AVX2 "x86: compile with -mavx2 (ignored with WITH_NATIVE, WITH_ZFLAGS or WITH_POWER)" OFF)
option(WITH_AVX512 "x86: compile with AVX-512 flags; takes precedence over WITH_AVX2" OFF)
option(WITH_POWER "Non-x86 (POWER) build: do not add the x86 AVX2/AVX-512 flags" OFF)

message(">> -------------------------------------------------------")
message(">> Snap ML Compilation Options")
message(">> GPU support (WITH_GPU):                  ${WITH_GPU}")
message(">> MPI support (WITH_MPI):                  ${WITH_MPI}")
message(">> NUMA support (WITH_NUMA):                ${WITH_NUMA}")
message(">> Z-specific compiler flags (WITH_ZFLAGS): ${WITH_ZFLAGS}")
message(">> Native compiler flags (WITH_NATIVE):     ${WITH_NATIVE}")
message(">> AVX2 flags (WITH_AVX2):                  ${WITH_AVX2}")
message(">> AVX-512 flags (WITH_AVX512):             ${WITH_AVX512}")
message(">> POWER build (WITH_POWER):                ${WITH_POWER}")
message(">> Compile with Visual Studio (WITH_WIN):   ${WITH_WIN}")
message(">> Mac OS compilation (WITH_MAC):           ${WITH_MAC}")
message(">> z/OS compilation (WITH_ZOS):             ${WITH_ZOS}")
message(">> -------------------------------------------------------")

# Project-specific CMake modules live two levels above the top-level source
# directory (presumably including the module behind find_package(Numa) below).
# Prepend instead of overwriting, so module paths passed in by the caller are
# kept while the project's own modules still take precedence.
list(PREPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/../../cmake")

if(WITH_ZOS)
    # z/OS: skip FindPython and point directly at the fixed Python 3.11 install.
    set(Python_ROOT_DIR /usr/lpp/IBM/cyp/v3r11/pyz)
    set(Python_LIBRARIES ${Python_ROOT_DIR}/lib/python3.11/config-3.11/libpython3.11.x)
    set(Python_INCLUDE_DIRS ${Python_ROOT_DIR}/include/python3.11)
else()
    # NOTE(review): not REQUIRED. If the Development component is not found,
    # Python_LIBRARIES / Python_INCLUDE_DIRS stay empty, and the tests that
    # link against them (PreProcessingTest, GfpEncoderTest) will likely fail
    # at link time rather than at configure time.
    find_package(Python COMPONENTS Development)
endif()

#if(NOT(WITH_MAC))
#    set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-undefined" )
#else ()
#    set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-undefined,error" )
#endif()

#find_package(BLAS REQUIRED)

# Compiler / platform flag selection. Exactly one branch runs:
#   - WITH_ZOS:       z/OS UNIX (USS) build with ibm-clang; appends to CMAKE_CXX_FLAGS.
#   - NOT WITH_WIN:   GCC/Clang-style build for Linux or macOS (WITH_MAC).
#   - otherwise:      Visual Studio (MSVC-style flags).
# The order of the set(CMAKE_CXX_FLAGS ...) statements matters: some append
# and some prepend to the accumulated string.
if(WITH_ZOS)
    message("build for USS")

    # add_definitions() forwards non -D arguments as plain compile options,
    # so this passes "--config /etc/clang.cfg" to the compiler.
    add_definitions(--config /etc/clang.cfg)
    add_definitions(-D_ALL_SOURCE)
    add_definitions(-D_OPEN_SYS_MUTEX_EXT)
    add_definitions(-DWITH_ZOS)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -Wall -Werror -Wno-sign-compare -Wno-unused-lambda-capture -march=arch12 -fzos-le-char-mode=ascii -std=c++11 -fzvector -fvisibility=default -fPIC")

    add_definitions(-DGLM_INLINE=inline)
    # Parentheses are backslash-escaped so CMake passes them through literally.
    add_definitions(-DUNUSED=__attribute__\(\(unused\)\))
    add_definitions(-DZ14_SIMD)
elseif(NOT(WITH_WIN))
    add_definitions(-DGLM_INLINE=inline)
    add_definitions(-DUNUSED=__attribute__\(\(unused\)\))

    # Base flags. NOTE(review): unlike the z/OS branch, these overwrite
    # CMAKE_CXX_FLAGS, discarding any flags supplied by the user.
    # NOTE(review): -lpthread -ldl -lutil are linker inputs placed in the
    # compile flags; they presumably reach the link line only because
    # CMAKE_CXX_FLAGS is also passed to the link step.
    if (WITH_MAC)
        set (CMAKE_CXX_FLAGS "-O3 -Wall -Werror -Wno-sign-compare -fPIC")
    else ()
        set (CMAKE_CXX_FLAGS "-O3 -lpthread -ldl -lutil -Wall -Werror -Wno-sign-compare -fPIC")
    endif ()

    # OpenMP: macOS uses the imported OpenMP target, linked into every target
    # created afterwards (link_libraries); Linux just adds -fopenmp.
    if (WITH_MAC)
        set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-private-field")
        add_definitions(-DWITH_MAC)
        find_package(OpenMP REQUIRED)
        link_libraries(OpenMP::OpenMP_CXX)
    else ()
        # Linux
        set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
    endif ()

    # normal behaviour (x86): optional SIMD flags, only when none of
    # WITH_NATIVE / WITH_ZFLAGS / WITH_POWER is set. WITH_AVX512 wins over
    # WITH_AVX2; with neither set (both default to OFF when unset), no SIMD
    # flags are added here.
    if(NOT(WITH_NATIVE) AND NOT(WITH_ZFLAGS) AND NOT (WITH_POWER))
        if(WITH_AVX512)
            set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfma -mavx512f -mavx512dq -mtune='corei7'")
            add_compile_definitions(X86_AVX512)
        else()
            if(WITH_AVX2)
                set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mtune='corei7'")
                add_compile_definitions(X86_AVX2)
            endif()
        endif()
    endif()

    # Tune for the build machine (prepended to the accumulated flags).
    if(WITH_NATIVE)
        set(CMAKE_CXX_FLAGS "-march=native -mtune=native ${CMAKE_CXX_FLAGS}")
    endif()

    # Language standard and target-specific optimisation flags, prepended to
    # everything accumulated above.
    if(WITH_ZFLAGS)
        # Let compiler make Z-specific optimizations
        set(CMAKE_CXX_FLAGS "-std=gnu++11 -mvx -march=z14 -mzvector -fstack-protector-all -funroll-loops -Wno-attributes -Wno-unused-variable ${CMAKE_CXX_FLAGS}")
        add_compile_definitions(Z14_SIMD)
    else()
        set(CMAKE_CXX_FLAGS "-std=c++11 -ffast-math -fno-finite-math-only -flto -fno-operator-names ${CMAKE_CXX_FLAGS}")
    endif()
else()
    # Visual Studio build.
    add_definitions(-DWITH_VS)
    add_definitions(-D_USE_MATH_DEFINES)
    add_definitions(-DGLM_INLINE=__forceinline)
    # No attribute equivalent is used here: UNUSED expands to nothing.
    add_definitions(-DUNUSED=)
    set(CMAKE_CXX_STANDARD 11)
    set(CMAKE_CXX_STANDARD_REQUIRED ON)
    set(CMAKE_CXX_FLAGS "/WX /MD /Zi /O2 /Oi /Oy /GT /Gy /Qpar /arch:AVX /fp:fast /openmp /EHsc")
    # NOTE(review): CMake combines CMAKE_CXX_FLAGS with CMAKE_CXX_FLAGS_<CONFIG>,
    # so in Release/RelWithDebInfo these flags appear twice on the command line.
    set(CMAKE_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS})
    set(CMAKE_CXX_FLAGS_RELWITHDEBINFO ${CMAKE_CXX_FLAGS})
endif()

if(WITH_GPU)
    # Compile with GPU support through the (legacy) FindCUDA module.
    find_package(CUDA REQUIRED)

    # nvcc -gencode pairs: compute capability 6.0 and 7.0 for every toolkit,
    # plus 7.5 from CUDA 10 and 8.0 from CUDA 11 onward. CUDA_VERSION is a
    # dotted "X.Y" version string, so compare it with VERSION_* operators.
    set(CUDA_ARCH_TARGETS
        "-gencode arch=compute_60,code=sm_60"
        "-gencode arch=compute_70,code=sm_70")
    if(CUDA_VERSION VERSION_GREATER_EQUAL 10)
        list(APPEND CUDA_ARCH_TARGETS "-gencode arch=compute_75,code=sm_75")
    endif()
    if(CUDA_VERSION VERSION_GREATER_EQUAL 11)
        list(APPEND CUDA_ARCH_TARGETS "-gencode arch=compute_80,code=sm_80")
    endif()

    set(CUDA_HOST_COMPILATION_CPP "ON")
    # Do not forward CMAKE_CXX_FLAGS to the host compiler; the host flags are
    # given explicitly via -Xcompiler below.
    set(CUDA_PROPAGATE_HOST_FLAGS "FALSE")
    # NOTE(review): -mtune='corei7' is an x86-only option, yet this list is used
    # for every WITH_GPU build regardless of the host architecture.
    set(CUDA_HOST_FLAGS "-O3,-Wall,-Werror,-Wno-sign-compare,-fPIC,-fopenmp,-mtune='corei7',-ffast-math,-fno-finite-math-only,-fno-operator-names")

    # NOTE(review): this overwrites any CUDA_NVCC_FLAGS supplied by the caller.
    set(CUDA_NVCC_FLAGS "-std=c++14;-use_fast_math;-Werror cross-execution-space-call,deprecated-declarations;-Xcompiler ${CUDA_HOST_FLAGS};${CUDA_ARCH_TARGETS}")
    add_definitions(-DWITH_CUDA)
    # Optional, currently disabled:
    #add_definitions(-DTWO_GPUS)
else()
    # CPU-only build: define the CUDA qualifiers away so that code annotated
    # with __host__/__device__ compiles as plain C++.
    add_definitions(-D__host__=)
    add_definitions(-D__device__=)
endif()

# Optional NUMA support: locate the library with find_package(Numa), link it
# into every target created after this point, and enable the WITH_NUMA code.
if(WITH_NUMA)
    find_package(Numa REQUIRED)
    add_definitions(-DWITH_NUMA)
    link_libraries(${NUMA_LIBRARY})
endif()

message("${CMAKE_CURRENT_SOURCE_DIR}")

# Header search paths for the targets in this directory and in the
# subdirectories added afterwards (directory-scoped include_directories()).
# include_directories() accepts only directories (plus AFTER/BEFORE/SYSTEM).
# It has no target-name or PUBLIC/PRIVATE arguments; any such words would be
# treated as extra relative include directories.
include_directories(
    "${PROJECT_SOURCE_DIR}/../include"
    "${PROJECT_SOURCE_DIR}/include"
    "${PROJECT_SOURCE_DIR}/../src/preprocessing/src"
    "${PROJECT_SOURCE_DIR}/third-party/cub"
    "${PROJECT_SOURCE_DIR}/third-party/eigen"
    "${PROJECT_SOURCE_DIR}/third-party/rapidjson/include"
    "${Python_INCLUDE_DIRS}"
)

# macOS and z/OS builds use the bundled pthread-barrier-macos headers,
# presumably because pthread barriers are not provided natively there.
if(WITH_MAC OR WITH_ZOS)
    include_directories("${PROJECT_SOURCE_DIR}/third-party/pthread-barrier-macos/inc")
endif()

# The graph preprocessor is skipped on z/OS; the test suite is always built.
if(NOT WITH_ZOS)
    add_subdirectory(graph-preprocessor)
endif()
add_subdirectory(test)

# Test executables (created under test/) that are linked against the zDNN
# library. The library location differs between the z/OS build (WITH_ZOS)
# and the Z-flags build (WITH_ZFLAGS); if both are ON, both are linked.
# MBITBoosterOptTest is currently disabled. Under WITH_ZFLAGS it used
# "${PROJECT_SOURCE_DIR}/../libzdnninternal.so" instead.
set(snapml_zdnn_tests
    TreeForestTest
    TreeBoosterTest
    MixBoosterTest
    CompressedTreesTest
    MBITBoosterTest
    MBITBoosterThrdTest
)

if(WITH_ZOS)
    foreach(zdnn_test IN LISTS snapml_zdnn_tests)
        target_link_libraries(${zdnn_test} /usr/lpp/aie/zdnn/lib/libzdnn.x)
    endforeach()
endif()

if(WITH_ZFLAGS)
    foreach(zdnn_test IN LISTS snapml_zdnn_tests)
        target_link_libraries(${zdnn_test} "/usr/lib64/libzdnn.so")
    endforeach()
endif()

# ManyLinux containers provide no shared libpython, so Python is linked
# statically there. An executable that embeds a statically linked Python must
# export its symbols (-export-dynamic); otherwise extension modules such as
# numpy's multiarray fail to import. Background:
# https://stackoverflow.com/questions/49784583/numpy-import-fails-on-multiarray-extension-library-when-called-from-embedded-pyt/66163572#66163572
get_filename_component(python_lib_ext "${Python_LIBRARIES}" LAST_EXT)
if("${python_lib_ext}" STREQUAL ".a")
    target_link_options(PreProcessingTest PUBLIC "LINKER:-export-dynamic")
endif()

# Tests that embed Python.
foreach(python_test PreProcessingTest GfpEncoderTest)
    target_link_libraries(${python_test} ${Python_LIBRARIES})
endforeach()

# CTest registration. Each test runs the executable test/<name>, relative to
# the build directory. Tests are registered in the order listed below.

# WITH_S390X_ONLY may be overridden from the environment at configure time.
if(DEFINED ENV{WITH_S390X_ONLY})
    set(WITH_S390X_ONLY "$ENV{WITH_S390X_ONLY}")
endif()

# Linear models and data chunking.
foreach(snapml_test
        RidgeRegressionTest
        LogisticRegressionTest
        SupportVectorMachineTest
        SparseLogisticRegressionTest
        LassoRegressionTest
        ChunkingTest)
    add_test(${snapml_test} test/${snapml_test})
endforeach()

# LoadersTest is skipped for z/OS, Z-flags and s390x-only builds.
if(NOT (WITH_ZOS OR WITH_ZFLAGS OR WITH_S390X_ONLY))
    add_test(LoadersTest test/LoadersTest)
endif()

foreach(snapml_test
        CocoaTest
        DatasetTest
        MultiThreadingTest
        LoadBalancingTest
        SGDTest
        PrivacyTest
        TreeLearnerTest
        TreeForestTest
        TreeBoosterTest
        MixBoosterTest
        RBFSamplerTest
        CompressedTreesTest
        MulticlassBoosterTest
        PreProcessingTest
        GfpEncoderTest)
    add_test(${snapml_test} test/${snapml_test})
endforeach()

# MBIT booster tests exist only for Z builds (MBITBoosterOptTest is disabled).
if(WITH_ZFLAGS OR WITH_ZOS)
    foreach(snapml_test MBITBoosterTest MBITBoosterThrdTest)
        add_test(${snapml_test} test/${snapml_test})
    endforeach()
endif()

# z/OS: extend LIBPATH for the test runs so libraries in ./test (relative to
# the test working directory) are found. The Python-embedding tests also get
# the Python library directory. $ENV{LIBPATH} is captured at configure time.
if(WITH_ZOS)
    set_property(TEST
                     TreeLearnerTest
                     TreeForestTest
                     TreeBoosterTest
                     MixBoosterTest
                     CompressedTreesTest
                     MulticlassBoosterTest
                     MBITBoosterTest
                     MBITBoosterThrdTest
                 PROPERTY ENVIRONMENT "LIBPATH=$ENV{LIBPATH}:./test")
    set_property(TEST
                     PreProcessingTest
                     GfpEncoderTest
                 PROPERTY ENVIRONMENT "LIBPATH=$ENV{LIBPATH}:./test:/usr/lpp/IBM/cyp/v3r11/pyz/lib")
endif()
