diff --git a/src/cudecomp.cc b/src/cudecomp.cc index dd82575..fb43822 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -655,8 +655,16 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes setCommInfo(handle, grid_desc, col_comm, CUDECOMP_COMM_COL); // Create local NCCL communicator if row or column communicator uses it - if ((grid_desc->row_comm_info.ngroups == 1 && grid_desc->row_comm_info.nranks > 1) || - (grid_desc->col_comm_info.ngroups == 1 && grid_desc->col_comm_info.nranks > 1)) { + int need_local_nccl_comm = + static_cast((grid_desc->row_comm_info.ngroups == 1 && grid_desc->row_comm_info.nranks > 1) || + (grid_desc->col_comm_info.ngroups == 1 && grid_desc->col_comm_info.nranks > 1)); + + // Local comm can include ranks in other rows/columns, need additional check for those cases. + CHECK_MPI(MPI_Allreduce(MPI_IN_PLACE, &need_local_nccl_comm, 1, MPI_INT, MPI_LOR, + handle->mpi_clique_comm != MPI_COMM_NULL ? handle->mpi_clique_comm + : handle->mpi_local_comm)); + + if (need_local_nccl_comm) { handle->nccl_local_comm = ncclCommFromMPIComm( handle->mpi_clique_comm != MPI_COMM_NULL ? handle->mpi_clique_comm : handle->mpi_local_comm); }