From 37cc75eb4211cec2c9839bc7017af9634b173142 Mon Sep 17 00:00:00 2001 From: David Bayer Date: Mon, 4 May 2026 09:36:06 +0200 Subject: [PATCH] [cudax] Add support for generic thread groups within warp and cluster --- .../cuda/experimental/__group/group.cuh | 18 ++-- .../synchronizer/lane_synchronizer.cuh | 5 ++ cudax/test/group/cooperative_algorithm.cu | 82 +++++++++++++------ 3 files changed, 72 insertions(+), 33 deletions(-) diff --git a/cudax/include/cuda/experimental/__group/group.cuh b/cudax/include/cuda/experimental/__group/group.cuh index 2e82c242e1b..7f6f11c2f5f 100644 --- a/cudax/include/cuda/experimental/__group/group.cuh +++ b/cudax/include/cuda/experimental/__group/group.cuh @@ -120,12 +120,13 @@ class group } public: - using unit_type = _Unit; - using level_type = typename _ParentGroup::level_type; - using hierarchy_type = typename _ParentGroup::hierarchy_type; - using mapping_type = _Mapping; - using __mapping_result_type = _MappingResult; - using synchronizer_type = _Synchronizer; + using unit_type = _Unit; + using level_type = typename _ParentGroup::level_type; + using hierarchy_type = typename _ParentGroup::hierarchy_type; + using mapping_type = _Mapping; + using __mapping_result_type = _MappingResult; + using synchronizer_type = _Synchronizer; + using __synchronizer_instance_type = _SynchronizerInstance; _CCCL_DEVICE_API explicit group( const _Unit& __unit, @@ -162,6 +163,11 @@ public: return __synchronizer_; } + [[nodiscard]] _CCCL_DEVICE_API const _SynchronizerInstance& __synchronizer_instance() const noexcept + { + return __synchronizer_instance_; + } + // todo(dabayer): Do we want to expose .arrive() and .wait()? Do we want to implement .sync() using them? Do we want // aligned/unaligned variants? _CCCL_DEVICE_API void sync() noexcept diff --git a/cudax/include/cuda/experimental/__group/synchronizer/lane_synchronizer.cuh b/cudax/include/cuda/experimental/__group/synchronizer/lane_synchronizer.cuh index 736432d5254..973efcd29c0 100644 --- a/cudax/include/cuda/experimental/__group/synchronizer/lane_synchronizer.cuh +++ b/cudax/include/cuda/experimental/__group/synchronizer/lane_synchronizer.cuh @@ -55,6 +55,11 @@ public: return {0u}; } + [[nodiscard]] _CCCL_DEVICE_API unsigned __lane_mask() const noexcept + { + return __lane_mask_; + } + template _CCCL_DEVICE_API void do_sync(const _MappingResult&, const lane_synchronizer&) const noexcept { diff --git a/cudax/test/group/cooperative_algorithm.cu b/cudax/test/group/cooperative_algorithm.cu index c256b4eeded..3dae756aad8 100644 --- a/cudax/test/group/cooperative_algorithm.cu +++ b/cudax/test/group/cooperative_algorithm.cu @@ -116,46 +116,60 @@ __device__ cuda::std::optional sum(cudax::this_cluster group, T (& return (cuda::gpu_thread.is_root_rank(group)) ? cuda::std::optional{result} : cuda::std::nullopt; } -// todo(dabayer): Add support for warp and cluster levels. template __device__ cuda::std::optional sum(Group group, T (&array)[N]) { - using Unit = typename Group::unit_type; - using MappingResult = typename Group::__mapping_result_type; - - constexpr auto ngroups = MappingResult::static_group_count(); - static_assert(ngroups != cuda::std::dynamic_extent, "group count must be statically known"); - - __shared__ T group_sums[ngroups]; + using Level = typename Group::level_type; - if (!Unit{}.is_part_of(group)) + if constexpr (cuda::std::is_same_v) { - return cuda::std::nullopt; + const auto result_unit = sum(cudax::this_thread{group.hierarchy()}, array); + + // todo(dabayer): Implement fallback for cc < 80. + T result; + NV_IF_TARGET(NV_PROVIDES_SM_80, + ({ result = __reduce_add_sync(group.__synchronizer_instance().__lane_mask(), result_unit.value()); })) + return (cuda::gpu_thread.is_root_rank(group)) ? cuda::std::optional{result} : cuda::std::nullopt; } + else + { + using Unit = typename Group::unit_type; + using MappingResult = typename Group::__mapping_result_type; - // todo(dabayer): Replace by group.rank(level) once this query is available. - const auto group_rank = group.__mapping_result().group_rank(); + constexpr auto ngroups = MappingResult::static_group_count(); + static_assert(ngroups != cuda::std::dynamic_extent, "group count must be statically known"); - if (cuda::gpu_thread.is_root_rank(group)) - { - group_sums[group_rank] = 0; - } + __shared__ T group_sums[ngroups]; - const auto unit_group = cudax::make_this_group(Unit{}, group.hierarchy()); - const auto result_unit = sum(unit_group, array); + if (!Unit{}.is_part_of(group)) + { + return cuda::std::nullopt; + } - // Wait until group_sums are are filled with 0. - group.sync_aligned(); + const auto group_rank = group.rank(Level{}); - if (cuda::gpu_thread.is_root_rank(unit_group)) - { - cuda::atomic_ref{group_sums[group_rank]} += result_unit.value(); - } + if (cuda::gpu_thread.is_root_rank(group)) + { + group_sums[group_rank] = 0; + } + + const auto unit_group = cudax::make_this_group(Unit{}, group.hierarchy()); + const auto result_unit = sum(unit_group, array); + + // Wait until group_sums are are filled with 0. + group.sync_aligned(); - // Wait until all unit_group roots add the intermediate sum to the shared memory. - group.sync_aligned(); + if (cuda::gpu_thread.is_root_rank(unit_group)) + { + constexpr auto min_thread_scope = cudax::__minimum_required_scope_for(); + cuda::atomic_ref{group_sums[group_rank]} += result_unit.value(); + } - return (cuda::gpu_thread.is_root_rank(group)) ? cuda::std::optional{group_sums[group_rank]} : cuda::std::nullopt; + // Wait until all unit_group roots add the intermediate sum to the shared memory. + group.sync_aligned(); + + return (cuda::gpu_thread.is_root_rank(group)) ? cuda::std::optional{group_sums[group_rank]} : cuda::std::nullopt; + } } template @@ -190,10 +204,24 @@ struct TestKernel test_cooperative_algorithm(cudax::this_block{config}); test_cooperative_algorithm(cudax::this_cluster{config}); + // todo(dabayer): Enable these once cc < 75 fallback for __reduce_add is implemented. + NV_IF_TARGET( + NV_PROVIDES_SM_80, ({ + test_cooperative_algorithm( + cudax::group{cuda::gpu_thread, cudax::this_warp{config}, cudax::group_by<2>{}, cudax::lane_synchronizer{}}); + test_cooperative_algorithm( + cudax::group{cuda::gpu_thread, cudax::this_warp{config}, cudax::group_by<16>{}, cudax::lane_synchronizer{}}); + })) + test_cooperative_algorithm( cudax::group{cuda::gpu_thread, cudax::this_block{config}, cudax::group_by<2>{}, cudax::lane_synchronizer{}}); test_cooperative_algorithm( cudax::group{cuda::gpu_thread, cudax::this_block{config}, cudax::group_by<16>{}, cudax::lane_synchronizer{}}); + + test_cooperative_algorithm( + cudax::group{cuda::gpu_thread, cudax::this_cluster{config}, cudax::group_by<2>{}, cudax::lane_synchronizer{}}); + test_cooperative_algorithm( + cudax::group{cuda::gpu_thread, cudax::this_cluster{config}, cudax::group_by<16>{}, cudax::lane_synchronizer{}}); } }; } // namespace