From 41fc23c0ddbd1051807477341b03b55e4da9a3a2 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 30 May 2025 21:04:19 -0700 Subject: [PATCH 1/8] Updated README. --- README.md | 10 +- openequivariance/extension/group_mm_cuda.hpp | 3 +- openequivariance/extension/group_mm_hip.hpp | 105 ++++++++++++++++++- 3 files changed, 106 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 9506937d..1481f684 100644 --- a/README.md +++ b/README.md @@ -260,11 +260,11 @@ python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i e3nn cue oeq | Operation | CUDA | HIP | |--------------------------|----------|-----| -| UVU Batch | ✅ | ✅ | -| UVW Batch | ✅ | ✅ | -| UVU Convolution | ✅ | ✅ | -| UVW Convolution | ✅ | ✅ | -| Symmetric Tensor Product | ✅ (beta) | 🚧🔨 | +| UVU | ✅ | ✅ | +| UVW | ✅ | ✅ | +| UVU + Convolution | ✅ | ✅ | +| UVW + Convolution | ✅ | ✅ | +| Symmetric Tensor Product | ✅ (beta) | ✅ (beta) | e3nn supports a variety of connection modes for CG tensor products. We support two that are commonly used in equivariant graph neural networks: diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/extension/group_mm_cuda.hpp index 58f2a3e3..95f1412f 100644 --- a/openequivariance/extension/group_mm_cuda.hpp +++ b/openequivariance/extension/group_mm_cuda.hpp @@ -9,7 +9,6 @@ using namespace std; template class GroupMMCUDA { - cudaError_t cudaStat; cublasStatus_t stat; cublasHandle_t handle; @@ -117,7 +116,7 @@ class GroupMMCUDA { batch_size); } else { - throw std::logic_error("Double precision support in progress."); + throw std::logic_error("Unsupported datatype for grouped GEMM!"); } if (stat != CUBLAS_STATUS_SUCCESS) { throw std::logic_error("Grouped GEMM failed!"); diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/extension/group_mm_hip.hpp index 0076a95e..1f2d52ed 100644 --- a/openequivariance/extension/group_mm_hip.hpp +++ b/openequivariance/extension/group_mm_hip.hpp @@ -1,7 +1,16 @@ #pragma once 
+#include "rocblas.h" +#include +#include +#include + + template class GroupMMHIP { + rocblas_status stat; + rocblas_handle handle; + int num_W; int batch_size; @@ -13,22 +22,108 @@ class GroupMMHIP { num_W(num_W), batch_size(batch_size), alpha(1.0), - beta(0.0) { - // TODO: To implement. + beta(0.0) { + if(rocblas_create_handle(&handle) != rocblas_status_success) { + throw std::logic_error("rocBLAS initialization failed"); + } } void group_gemm(void* A_raw, void* B_raw, void* C_raw, int64_t* ragged_counts, int m, int k, int ragged_inner) { - // TODO: To implement. + + T* A_base = reinterpret_cast(A_raw); + T* B_base = reinterpret_cast(B_raw); + T* C_base = reinterpret_cast(C_raw); + + int64_t ragged_offset = 0; + for(int i = 0; i < num_W; i++) { + int M, K, N, lda, ldb, ldc; + T *A, *B, *C; + + int strideA, strideB, strideC; + rocblas_operation transa, transb; + + if(ragged_inner == 0) { + M = m; + K = k; + N = static_cast(ragged_counts[i]); + + A = A_base + (m * k * batch_size * i); + lda = k; strideA = M * K; + + B = B_base + (k * batch_size * ragged_offset); + ldb = K * batch_size; strideB = K; + + C = C_base + (m * batch_size * ragged_offset); + ldc = M * batch_size; strideC = M; + + transa = CUBLAS_OP_T; + transb = CUBLAS_OP_N; + } + else { + M = k; + K = static_cast(ragged_counts[i]); + N = m; + + A = B_base + (k * batch_size * ragged_offset); + lda = k * batch_size; strideA = M; + + B = A_base + (m * batch_size * ragged_offset); + ldb = m * batch_size; strideB = N; + + C = C_base + (m * k * batch_size * i); + ldc = k; strideC = M * N; + + transa = CUBLAS_OP_N; + transb = CUBLAS_OP_T; + } + ragged_offset += ragged_counts[i]; + + if(ragged_counts[i] > 0) { + if(std::is_same::value) { + stat = rocblas_sgemm_strided_batched(handle, + transa, transb, + M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } + else 
if(std::is_same::value) { + stat = rocblas_dgemm_strided_batched(handle, + transa, transb, + M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } + else { + throw std::logic_error("Double precision support in progress."); + } + if (stat != rocblas_status_success) { + throw std::logic_error("Grouped GEMM failed!"); + } + } + } } void group_gemm_intptr(uint64_t weights, uint64_t vectors, uint64_t output, uint64_t ragged_counts, int m, int k, int ragged_inner) { - // TODO: To implement. + group_gemm( + reinterpret_cast(weights), + reinterpret_cast(vectors), + reinterpret_cast(output), + reinterpret_cast(ragged_counts), + m, k, ragged_inner); } ~GroupMMHIP() { - // TODO: To implement. + rocblas_destroy_handle(handle) } }; \ No newline at end of file From 32722b42f1ec174602b5727dffbc03fd6067d5d3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 31 May 2025 01:04:40 -0400 Subject: [PATCH 2/8] Cleaned up RocM call. 
--- openequivariance/extension/group_mm_hip.hpp | 12 ++++++------ openequivariance/extension/util/backend_hip.hpp | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/extension/group_mm_hip.hpp index 1f2d52ed..abd974bf 100644 --- a/openequivariance/extension/group_mm_hip.hpp +++ b/openequivariance/extension/group_mm_hip.hpp @@ -1,6 +1,6 @@ #pragma once -#include "rocblas.h" +#include "rocblas/rocblas.h" #include #include #include @@ -57,8 +57,8 @@ class GroupMMHIP { C = C_base + (m * batch_size * ragged_offset); ldc = M * batch_size; strideC = M; - transa = CUBLAS_OP_T; - transb = CUBLAS_OP_N; + transa = rocblas_operation_transpose; + transb = rocblas_operation_none; } else { M = k; @@ -74,8 +74,8 @@ class GroupMMHIP { C = C_base + (m * k * batch_size * i); ldc = k; strideC = M * N; - transa = CUBLAS_OP_N; - transb = CUBLAS_OP_T; + transa = rocblas_operation_none; + transb = rocblas_operation_transpose; } ragged_offset += ragged_counts[i]; @@ -124,6 +124,6 @@ class GroupMMHIP { } ~GroupMMHIP() { - rocblas_destroy_handle(handle) + rocblas_destroy_handle(handle); } }; \ No newline at end of file diff --git a/openequivariance/extension/util/backend_hip.hpp b/openequivariance/extension/util/backend_hip.hpp index c9360189..91d0f6ee 100644 --- a/openequivariance/extension/util/backend_hip.hpp +++ b/openequivariance/extension/util/backend_hip.hpp @@ -154,7 +154,7 @@ class __attribute__((visibility("default"))) KernelLibrary { public: int device; KernelLibrary(hiprtcProgram &prog, vector &kernel_binary, vector &kernel_names) { - hipGetDevice(&device); + HIP_ERRCHK(hipGetDevice(&device)); HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data())); for (size_t i = 0; i < kernel_names.size(); i++) { From ecd9b87575d683abd9408821b1570e3b646112b3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 30 May 2025 22:40:56 -0700 Subject: [PATCH 3/8] Exposed Symmetric Tensor Product in toplevel. 
--- openequivariance/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 21488786..20dc0461 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -9,6 +9,7 @@ TensorProductConv, ) from openequivariance.implementations.utils import torch_to_oeq_dtype +from openequivariance.implementations.symmetric_contraction.STPOpt import SymmetricContraction __all__ = [ "TPProblem", @@ -16,6 +17,7 @@ "TensorProduct", "TensorProductConv", "torch_to_oeq_dtype", + "SymmetricContraction" ] __version__ = version("openequivariance") From 2421007222bc4be2222e60123fc1b13e648f5543 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 30 May 2025 23:46:11 -0700 Subject: [PATCH 4/8] Updated README. --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1481f684..2646a2af 100644 --- a/README.md +++ b/README.md @@ -229,6 +229,11 @@ build of PyTorch past 4/10/2025 due to incomplete support for TorchBind in earlier versions. ## Running MACE +**NOTE**: If you're revisiting this page, the repo containing +our up-to-date MACE integration has changed! See the instructions +below; our new repo is a fork of MACE to facilitate a PR into the +main codebase. + We have modified MACE to use our accelerated kernels instead of the standard e3nn backend. Here are the steps to replicate our MACE benchmark: @@ -237,7 +242,7 @@ our MACE benchmark: ```bash pip uninstall mace-torch pip install git+https://github.com/PASSIONLab/OpenEquivariance -pip install git+https://github.com/vbharadwaj-bk/mace_oeq +pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration/tree/oeq_experimental ``` 2. Download the `carbon.xyz` data file, available at . 
From 22fe0e9bbe5fb0389e1ad6c1fd2374f766544492 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 31 May 2025 03:32:24 -0400 Subject: [PATCH 5/8] Updated README with correct installation instructions. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2646a2af..a2ddfd00 100644 --- a/README.md +++ b/README.md @@ -231,8 +231,8 @@ TorchBind in earlier versions. ## Running MACE **NOTE**: If you're revisiting this page, the repo containing our up-to-date MACE integration has changed! See the instructions -below; our new repo is a fork of MACE to facilitate a PR into the -main codebase. +below; we use a branch off a fork of MACE to facilitate +PRs into the main codebase. We have modified MACE to use our accelerated kernels instead of the standard e3nn backend. Here are the steps to replicate @@ -242,7 +242,7 @@ our MACE benchmark: ```bash pip uninstall mace-torch pip install git+https://github.com/PASSIONLab/OpenEquivariance -pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration/tree/oeq_experimental +pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration.git@oeq_experimental ``` 2. Download the `carbon.xyz` data file, available at . From 313d8e1c4836b74878ecad265422d81ed9ffd678 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 31 May 2025 04:20:38 -0400 Subject: [PATCH 6/8] Avoid exposing STP to toplevel and remove MACE dependencies from STPOpt. 
--- openequivariance/__init__.py | 2 - openequivariance/extension/group_mm_hip.hpp | 2 +- .../symmetric_contraction/__init__.py | 1 + .../{STPOpt.py => symmetric_contraction.py} | 124 +++++++++++++++++- 4 files changed, 123 insertions(+), 6 deletions(-) create mode 100644 openequivariance/implementations/symmetric_contraction/__init__.py rename openequivariance/implementations/symmetric_contraction/{STPOpt.py => symmetric_contraction.py} (77%) diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 20dc0461..21488786 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -9,7 +9,6 @@ TensorProductConv, ) from openequivariance.implementations.utils import torch_to_oeq_dtype -from openequivariance.implementations.symmetric_contraction.STPOpt import SymmetricContraction __all__ = [ "TPProblem", @@ -17,7 +16,6 @@ "TensorProduct", "TensorProductConv", "torch_to_oeq_dtype", - "SymmetricContraction" ] __version__ = version("openequivariance") diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/extension/group_mm_hip.hpp index abd974bf..2e713cf4 100644 --- a/openequivariance/extension/group_mm_hip.hpp +++ b/openequivariance/extension/group_mm_hip.hpp @@ -103,7 +103,7 @@ class GroupMMHIP { batch_size); } else { - throw std::logic_error("Double precision support in progress."); + throw std::logic_error("Unsupported datatype for grouped GEMM!"); } if (stat != rocblas_status_success) { throw std::logic_error("Grouped GEMM failed!"); diff --git a/openequivariance/implementations/symmetric_contraction/__init__.py b/openequivariance/implementations/symmetric_contraction/__init__.py new file mode 100644 index 00000000..36d90ade --- /dev/null +++ b/openequivariance/implementations/symmetric_contraction/__init__.py @@ -0,0 +1 @@ +from openequivariance.implementations.symmetric_contraction.symmetric_contraction import SymmetricContraction \ No newline at end of file diff --git 
a/openequivariance/implementations/symmetric_contraction/STPOpt.py b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py similarity index 77% rename from openequivariance/implementations/symmetric_contraction/STPOpt.py rename to openequivariance/implementations/symmetric_contraction/symmetric_contraction.py index 890ca7c0..2fea7573 100644 --- a/openequivariance/implementations/symmetric_contraction/STPOpt.py +++ b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py @@ -3,7 +3,6 @@ from openequivariance.extlib import GroupMM_F32, GroupMM_F64 - class GroupMM: next_id = 0 @@ -109,13 +108,132 @@ def forward(self, weights, vectors, bincounts): # -------------------------------------------------------------------------- -from typing import Dict, Optional, Union +# The following segment of code was copied from MACE's repo at https://github.com/ACEsuit/mace/blob/b5faaa076c49778fc17493edfecebcabeb960155/mace/tools/cg.py#L106 + +import collections +from typing import Dict, Optional, Union, List from e3nn import o3 from e3nn.util.codegen import CodeGenMixin from e3nn.util.jit import compile_mode -from mace.tools.cg import U_matrix_real +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + 
filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim**0.5 + if normalization == "norm": + C *= ir_left.dim**0.5 * ir.dim**0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + + if correlation == 4: + filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)] + + try: + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + except NotImplementedError as e: + if CUET_AVAILABLE: + return compute_U_cueq( + irreps_in, irreps_out=irreps_out, correlation=correlation + ) + raise NotImplementedError( + "The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance" + ) from e + + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in wigners: + if ir in irreps_out and ir == current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + last_ir = current_ir + elif ir in irreps_out and ir != current_ir: + if 
len(stack) != 0: + out += [last_ir, stack] + stack = base_o3.squeeze().unsqueeze(-1) + current_ir, last_ir = ir, ir + else: + current_ir = ir + out += [last_ir, stack] + return out @compile_mode("script") class Contraction(torch.nn.Module): From 9e92ea77e892ff6ee79f7f3393db5d2da81448a9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 31 May 2025 04:24:09 -0400 Subject: [PATCH 7/8] Updated README. --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index a2ddfd00..6100f7bd 100644 --- a/README.md +++ b/README.md @@ -295,6 +295,19 @@ We do not (yet) support: If you have a use case for any of the unsupported features above, let us know. +We have recently added beta support for symmetric +contraction acceleration. Because this kernel is +specific to MACE, requires e3nn as a dependency +to run, and does not yet support +compile / export (coming soon!), we +do not expose it in the package +top level. You can test out our implementation by +running + +```python +from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction +``` + ## Multidevice / Stream Support To use OpenEquivariance on multiple GPUs of a single compute node, we currently require that all GPUs From ccc457fba2d540c6d9c97d9c3990636538b3642d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 31 May 2025 04:26:48 -0400 Subject: [PATCH 8/8] Linting.
--- .../symmetric_contraction/__init__.py | 6 +++++- .../symmetric_contraction.py | 16 +++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/openequivariance/implementations/symmetric_contraction/__init__.py b/openequivariance/implementations/symmetric_contraction/__init__.py index 36d90ade..75ac6cc8 100644 --- a/openequivariance/implementations/symmetric_contraction/__init__.py +++ b/openequivariance/implementations/symmetric_contraction/__init__.py @@ -1 +1,5 @@ -from openequivariance.implementations.symmetric_contraction.symmetric_contraction import SymmetricContraction \ No newline at end of file +from openequivariance.implementations.symmetric_contraction.symmetric_contraction import ( + SymmetricContraction, +) + +__all__ = ["SymmetricContraction"] diff --git a/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py index 2fea7573..9790c2a2 100644 --- a/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py +++ b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py @@ -3,6 +3,7 @@ from openequivariance.extlib import GroupMM_F32, GroupMM_F64 + class GroupMM: next_id = 0 @@ -120,6 +121,7 @@ def forward(self, weights, vectors, bincounts): _TP = collections.namedtuple("_TP", "op, args") _INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + def _wigner_nj( irrepss: List[o3.Irreps], normalization: str = "component", @@ -198,7 +200,7 @@ def U_matrix_real( correlation: int, normalization: str = "component", filter_ir_mid=None, - dtype=None + dtype=None, ): irreps_out = o3.Irreps(irreps_out) irrepss = [o3.Irreps(irreps_in)] * correlation @@ -206,16 +208,7 @@ def U_matrix_real( if correlation == 4: filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)] - try: - wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) - except NotImplementedError as e: - 
if CUET_AVAILABLE: - return compute_U_cueq( - irreps_in, irreps_out=irreps_out, correlation=correlation - ) - raise NotImplementedError( - "The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance" - ) from e + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) current_ir = wigners[0][0] out = [] @@ -235,6 +228,7 @@ def U_matrix_real( out += [last_ir, stack] return out + @compile_mode("script") class Contraction(torch.nn.Module): def __init__(