Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions src/cudecomp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ static void initNvshmemFromMPIComm(MPI_Comm mpi_comm) {
}
#endif

static bool checkEnvVar(const char* env_var_str) {
const char* env_var_val_str = std::getenv(env_var_str);
bool result = false;
if (env_var_val_str) { result = std::strtol(env_var_val_str, nullptr, 10) == 1; }
return result;
}

static void checkTransposeCommBackend(cudecompTransposeCommBackend_t comm_backend) {
switch (comm_backend) {
case CUDECOMP_TRANSPOSE_COMM_NCCL:
Expand Down Expand Up @@ -199,7 +206,11 @@ static void gatherGlobalMPIInfo(cudecompHandle_t& handle) {
CHECK_NVML(nvmlDeviceGetHandleByPciBusId(pciBusId, &nvml_dev));
#if NVML_API_VERSION >= 12 && CUDART_VERSION >= 12040
nvmlGpuFabricInfoV_t fabricInfo = {.version = nvmlGpuFabricInfo_v2};
if (nvmlHasFabricSupport()) {

// Check CUDECOMP_DISABLE_MNNVL (debug setting to disable MNNVL topology detection)
bool disable_mnnvl = checkEnvVar("CUDECOMP_DISABLE_MNNVL");

if (nvmlHasFabricSupport() && !disable_mnnvl) {
handle->rank_to_mnnvl_info.resize(handle->nranks);

// Gather MNNVL information (clusterUuid, cliqueId) by rank
Expand Down Expand Up @@ -263,12 +274,10 @@ static void gatherGlobalMPIInfo(cudecompHandle_t& handle) {

static void getCudecompEnvVars(cudecompHandle_t& handle) {
// Check CUDECOMP_ENABLE_NCCL_UBR (NCCL user buffer registration)
const char* nccl_enable_ubr_str = std::getenv("CUDECOMP_ENABLE_NCCL_UBR");
if (nccl_enable_ubr_str) { handle->nccl_enable_ubr = std::strtol(nccl_enable_ubr_str, nullptr, 10) == 1; }
handle->nccl_enable_ubr = checkEnvVar("CUDECOMP_ENABLE_NCCL_UBR");

// Check CUDECOMP_ENABLE_CUMEM (CUDA VMM allocations for work buffers)
const char* cumem_enable_str = std::getenv("CUDECOMP_ENABLE_CUMEM");
if (cumem_enable_str) { handle->cuda_cumem_enable = std::strtol(cumem_enable_str, nullptr, 10) == 1; }
handle->cuda_cumem_enable = checkEnvVar("CUDECOMP_ENABLE_CUMEM");
if (handle->cuda_cumem_enable) {
#if CUDART_VERSION < 12030
if (handle->rank == 0) {
Expand Down Expand Up @@ -305,8 +314,7 @@ static void getCudecompEnvVars(cudecompHandle_t& handle) {
}

// Check CUDECOMP_ENABLE_CUDA_GRAPHS (CUDA Graphs usage in pipelined backends)
const char* graphs_enable_str = std::getenv("CUDECOMP_ENABLE_CUDA_GRAPHS");
if (graphs_enable_str) { handle->cuda_graphs_enable = std::strtol(graphs_enable_str, nullptr, 10) == 1; }
handle->cuda_graphs_enable = checkEnvVar("CUDECOMP_ENABLE_CUDA_GRAPHS");
if (handle->cuda_graphs_enable) {
#if CUDART_VERSION < 11010
if (handle->rank == 0) {
Expand All @@ -318,10 +326,7 @@ static void getCudecompEnvVars(cudecompHandle_t& handle) {
}

// Check CUDECOMP_ENABLE_PERFORMANCE_REPORT (Performance reporting)
const char* performance_report_str = std::getenv("CUDECOMP_ENABLE_PERFORMANCE_REPORT");
if (performance_report_str) {
handle->performance_report_enable = std::strtol(performance_report_str, nullptr, 10) == 1;
}
handle->performance_report_enable = checkEnvVar("CUDECOMP_ENABLE_PERFORMANCE_REPORT");

// Check CUDECOMP_PERFORMANCE_REPORT_DETAIL (Performance report detail level)
const char* performance_detail_str = std::getenv("CUDECOMP_PERFORMANCE_REPORT_DETAIL");
Expand Down Expand Up @@ -363,8 +368,7 @@ static void getCudecompEnvVars(cudecompHandle_t& handle) {
if (performance_write_dir_str) { handle->performance_report_write_dir = std::string(performance_write_dir_str); }

// Check CUDECOMP_USE_COL_MAJOR_RANK_ORDER (Column-major rank assignment)
const char* col_major_rank_str = std::getenv("CUDECOMP_USE_COL_MAJOR_RANK_ORDER");
if (col_major_rank_str) { handle->use_col_major_rank_order = std::strtol(col_major_rank_str, nullptr, 10) == 1; }
handle->use_col_major_rank_order = checkEnvVar("CUDECOMP_USE_COL_MAJOR_RANK_ORDER");
}

#ifdef ENABLE_NVSHMEM
Expand Down Expand Up @@ -439,12 +443,12 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) {
CHECK_CUTENSOR(cutensorInit(&handle->cutensor_handle));
#endif

// Gather extra MPI info from all communicator ranks
gatherGlobalMPIInfo(handle);

// Gather cuDecomp environment variable settings
getCudecompEnvVars(handle);

// Gather extra MPI info from all communicator ranks
gatherGlobalMPIInfo(handle);

// Determine P2P CE count
int dev;
CUdevice cu_dev;
Expand Down