diff --git a/src/cudecomp.cc b/src/cudecomp.cc index fb43822..e67446f 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -86,6 +86,13 @@ static void initNvshmemFromMPIComm(MPI_Comm mpi_comm) { } #endif +static bool checkEnvVar(const char* env_var_str) { + const char* env_var_val_str = std::getenv(env_var_str); + bool result = false; + if (env_var_val_str) { result = std::strtol(env_var_val_str, nullptr, 10) == 1; } + return result; +} + static void checkTransposeCommBackend(cudecompTransposeCommBackend_t comm_backend) { switch (comm_backend) { case CUDECOMP_TRANSPOSE_COMM_NCCL: @@ -199,7 +206,11 @@ static void gatherGlobalMPIInfo(cudecompHandle_t& handle) { CHECK_NVML(nvmlDeviceGetHandleByPciBusId(pciBusId, &nvml_dev)); #if NVML_API_VERSION >= 12 && CUDART_VERSION >= 12040 nvmlGpuFabricInfoV_t fabricInfo = {.version = nvmlGpuFabricInfo_v2}; - if (nvmlHasFabricSupport()) { + + // Check CUDECOMP_DISABLE_MNNVL (debug setting to disable MNNVL topology detection) + bool disable_mnnvl = checkEnvVar("CUDECOMP_DISABLE_MNNVL"); + + if (nvmlHasFabricSupport() && !disable_mnnvl) { handle->rank_to_mnnvl_info.resize(handle->nranks); // Gather MNNVL information (clusterUuid, cliqueId) by rank @@ -263,12 +274,10 @@ static void gatherGlobalMPIInfo(cudecompHandle_t& handle) { static void getCudecompEnvVars(cudecompHandle_t& handle) { // Check CUDECOMP_ENABLE_NCCL_UBR (NCCL user buffer registration) - const char* nccl_enable_ubr_str = std::getenv("CUDECOMP_ENABLE_NCCL_UBR"); - if (nccl_enable_ubr_str) { handle->nccl_enable_ubr = std::strtol(nccl_enable_ubr_str, nullptr, 10) == 1; } + handle->nccl_enable_ubr = checkEnvVar("CUDECOMP_ENABLE_NCCL_UBR"); // Check CUDECOMP_ENABLE_CUMEM (CUDA VMM allocations for work buffers) - const char* cumem_enable_str = std::getenv("CUDECOMP_ENABLE_CUMEM"); - if (cumem_enable_str) { handle->cuda_cumem_enable = std::strtol(cumem_enable_str, nullptr, 10) == 1; } + handle->cuda_cumem_enable = checkEnvVar("CUDECOMP_ENABLE_CUMEM"); if (handle->cuda_cumem_enable) { #if CUDART_VERSION < 12030 if (handle->rank == 0) { @@ -305,8 +314,7 @@ static void getCudecompEnvVars(cudecompHandle_t& handle) { } // Check CUDECOMP_ENABLE_CUDA_GRAPHS (CUDA Graphs usage in pipelined backends) - const char* graphs_enable_str = std::getenv("CUDECOMP_ENABLE_CUDA_GRAPHS"); - if (graphs_enable_str) { handle->cuda_graphs_enable = std::strtol(graphs_enable_str, nullptr, 10) == 1; } + handle->cuda_graphs_enable = checkEnvVar("CUDECOMP_ENABLE_CUDA_GRAPHS"); if (handle->cuda_graphs_enable) { #if CUDART_VERSION < 11010 if (handle->rank == 0) { @@ -318,10 +326,7 @@ static void getCudecompEnvVars(cudecompHandle_t& handle) { } // Check CUDECOMP_ENABLE_PERFORMANCE_REPORT (Performance reporting) - const char* performance_report_str = std::getenv("CUDECOMP_ENABLE_PERFORMANCE_REPORT"); - if (performance_report_str) { - handle->performance_report_enable = std::strtol(performance_report_str, nullptr, 10) == 1; - } + handle->performance_report_enable = checkEnvVar("CUDECOMP_ENABLE_PERFORMANCE_REPORT"); // Check CUDECOMP_PERFORMANCE_REPORT_DETAIL (Performance report detail level) const char* performance_detail_str = std::getenv("CUDECOMP_PERFORMANCE_REPORT_DETAIL"); @@ -363,8 +368,7 @@ static void getCudecompEnvVars(cudecompHandle_t& handle) { if (performance_write_dir_str) { handle->performance_report_write_dir = std::string(performance_write_dir_str); } // Check CUDECOMP_USE_COL_MAJOR_RANK_ORDER (Column-major rank assignment) - const char* col_major_rank_str = std::getenv("CUDECOMP_USE_COL_MAJOR_RANK_ORDER"); - if (col_major_rank_str) { handle->use_col_major_rank_order = std::strtol(col_major_rank_str, nullptr, 10) == 1; } + handle->use_col_major_rank_order = checkEnvVar("CUDECOMP_USE_COL_MAJOR_RANK_ORDER"); } #ifdef ENABLE_NVSHMEM @@ -439,12 +443,12 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) { CHECK_CUTENSOR(cutensorInit(&handle->cutensor_handle)); #endif - // Gather extra MPI info from all communicator ranks - gatherGlobalMPIInfo(handle); - // Gather cuDecomp environment variable settings getCudecompEnvVars(handle); + // Gather extra MPI info from all communicator ranks + gatherGlobalMPIInfo(handle); + // Determine P2P CE count int dev; CUdevice cu_dev;