diff --git a/README.md b/README.md index 9506937d..6100f7bd 100644 --- a/README.md +++ b/README.md @@ -229,6 +229,11 @@ build of PyTorch past 4/10/2025 due to incomplete support for TorchBind in earlier versions. ## Running MACE +**NOTE**: If you're revisiting this page, the repo containing +our up-to-date MACE integration has changed! See the instructions +below; we use a branch off a fork of MACE to facilitate +PRs into the main codebase. + We have modified MACE to use our accelerated kernels instead of the standard e3nn backend. Here are the steps to replicate our MACE benchmark: @@ -237,7 +242,7 @@ our MACE benchmark: ```bash pip uninstall mace-torch pip install git+https://github.com/PASSIONLab/OpenEquivariance -pip install git+https://github.com/vbharadwaj-bk/mace_oeq +pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration.git@oeq_experimental ``` 2. Download the `carbon.xyz` data file, available at . @@ -260,11 +265,11 @@ python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i e3nn cue oeq | Operation | CUDA | HIP | |--------------------------|----------|-----| -| UVU Batch | ✅ | ✅ | -| UVW Batch | ✅ | ✅ | -| UVU Convolution | ✅ | ✅ | -| UVW Convolution | ✅ | ✅ | -| Symmetric Tensor Product | ✅ (beta) | 🚧🔨 | +| UVU | ✅ | ✅ | +| UVW | ✅ | ✅ | +| UVU + Convolution | ✅ | ✅ | +| UVW + Convolution | ✅ | ✅ | +| Symmetric Tensor Product | ✅ (beta) | ✅ (beta) | e3nn supports a variety of connection modes for CG tensor products. We support two that are commonly used in equivariant graph neural networks: @@ -290,6 +295,19 @@ We do not (yet) support: If you have a use case for any of the unsupported features above, let us know. +We have recently added beta support for symmetric +contraction acceleration. Because this is a kernel +specific to MACE, we require e3nn as dependency +to run it, and there is currently no support for +compile / export (coming soon!), we +do not expose it in the package +toplevel. You can test out our implementation by +running + +```python +from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction +``` + ## Multidevice / Stream Support To use OpenEquivariance on multiple GPUs of a single compute node, we currently require that all GPUs diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/extension/group_mm_cuda.hpp index 58f2a3e3..95f1412f 100644 --- a/openequivariance/extension/group_mm_cuda.hpp +++ b/openequivariance/extension/group_mm_cuda.hpp @@ -9,7 +9,6 @@ using namespace std; template class GroupMMCUDA { - cudaError_t cudaStat; cublasStatus_t stat; cublasHandle_t handle; @@ -117,7 +116,7 @@ class GroupMMCUDA { batch_size); } else { - throw std::logic_error("Double precision support in progress."); + throw std::logic_error("Unsupported datatype for grouped GEMM!"); } if (stat != CUBLAS_STATUS_SUCCESS) { throw std::logic_error("Grouped GEMM failed!"); diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/extension/group_mm_hip.hpp index 0076a95e..2e713cf4 100644 --- a/openequivariance/extension/group_mm_hip.hpp +++ b/openequivariance/extension/group_mm_hip.hpp @@ -1,7 +1,16 @@ #pragma once +#include "rocblas/rocblas.h" +#include +#include +#include + + template class GroupMMHIP { + rocblas_status stat; + rocblas_handle handle; + int num_W; int batch_size; @@ -13,22 +22,108 @@ class GroupMMHIP { num_W(num_W), batch_size(batch_size), alpha(1.0), - beta(0.0) { - // TODO: To implement. + beta(0.0) { + if(rocblas_create_handle(&handle) != rocblas_status_success) { + throw std::logic_error("rocBLAS initialization failed"); + } } void group_gemm(void* A_raw, void* B_raw, void* C_raw, int64_t* ragged_counts, int m, int k, int ragged_inner) { - // TODO: To implement. + + T* A_base = reinterpret_cast(A_raw); + T* B_base = reinterpret_cast(B_raw); + T* C_base = reinterpret_cast(C_raw); + + int64_t ragged_offset = 0; + for(int i = 0; i < num_W; i++) { + int M, K, N, lda, ldb, ldc; + T *A, *B, *C; + + int strideA, strideB, strideC; + rocblas_operation transa, transb; + + if(ragged_inner == 0) { + M = m; + K = k; + N = static_cast(ragged_counts[i]); + + A = A_base + (m * k * batch_size * i); + lda = k; strideA = M * K; + + B = B_base + (k * batch_size * ragged_offset); + ldb = K * batch_size; strideB = K; + + C = C_base + (m * batch_size * ragged_offset); + ldc = M * batch_size; strideC = M; + + transa = rocblas_operation_transpose; + transb = rocblas_operation_none; + } + else { + M = k; + K = static_cast(ragged_counts[i]); + N = m; + + A = B_base + (k * batch_size * ragged_offset); + lda = k * batch_size; strideA = M; + + B = A_base + (m * batch_size * ragged_offset); + ldb = m * batch_size; strideB = N; + + C = C_base + (m * k * batch_size * i); + ldc = k; strideC = M * N; + + transa = rocblas_operation_none; + transb = rocblas_operation_transpose; + } + ragged_offset += ragged_counts[i]; + + if(ragged_counts[i] > 0) { + if(std::is_same::value) { + stat = rocblas_sgemm_strided_batched(handle, + transa, transb, + M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } + else if(std::is_same::value) { + stat = rocblas_dgemm_strided_batched(handle, + transa, transb, + M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } + else { + throw std::logic_error("Unsupported datatype for grouped GEMM!"); + } + if (stat != rocblas_status_success) { + throw std::logic_error("Grouped GEMM failed!"); + } + } + } } void group_gemm_intptr(uint64_t weights, uint64_t vectors, uint64_t output, uint64_t ragged_counts, int m, int k, int ragged_inner) { - // TODO: To implement. + group_gemm( + reinterpret_cast(weights), + reinterpret_cast(vectors), + reinterpret_cast(output), + reinterpret_cast(ragged_counts), + m, k, ragged_inner); } ~GroupMMHIP() { - // TODO: To implement. + rocblas_destroy_handle(handle); } }; \ No newline at end of file diff --git a/openequivariance/extension/util/backend_hip.hpp b/openequivariance/extension/util/backend_hip.hpp index c9360189..91d0f6ee 100644 --- a/openequivariance/extension/util/backend_hip.hpp +++ b/openequivariance/extension/util/backend_hip.hpp @@ -154,7 +154,7 @@ class __attribute__((visibility("default"))) KernelLibrary { public: int device; KernelLibrary(hiprtcProgram &prog, vector &kernel_binary, vector &kernel_names) { - hipGetDevice(&device); + HIP_ERRCHK(hipGetDevice(&device)); HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data())); for (size_t i = 0; i < kernel_names.size(); i++) { diff --git a/openequivariance/implementations/symmetric_contraction/__init__.py b/openequivariance/implementations/symmetric_contraction/__init__.py new file mode 100644 index 00000000..75ac6cc8 --- /dev/null +++ b/openequivariance/implementations/symmetric_contraction/__init__.py @@ -0,0 +1,5 @@ +from openequivariance.implementations.symmetric_contraction.symmetric_contraction import ( + SymmetricContraction, +) + +__all__ = ["SymmetricContraction"] diff --git a/openequivariance/implementations/symmetric_contraction/STPOpt.py b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py similarity index 78% rename from openequivariance/implementations/symmetric_contraction/STPOpt.py rename to openequivariance/implementations/symmetric_contraction/symmetric_contraction.py index 890ca7c0..9790c2a2 100644 --- a/openequivariance/implementations/symmetric_contraction/STPOpt.py +++ b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py @@ -109,12 +109,124 @@ def forward(self, weights, vectors, bincounts): # -------------------------------------------------------------------------- -from typing import Dict, Optional, Union +# The following segment of code was copied from MACE's repo at https://github.com/ACEsuit/mace/blob/b5faaa076c49778fc17493edfecebcabeb960155/mace/tools/cg.py#L106 + +import collections +from typing import Dict, Optional, Union, List from e3nn import o3 from e3nn.util.codegen import CodeGenMixin from e3nn.util.jit import compile_mode -from mace.tools.cg import U_matrix_real + +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim**0.5 + if normalization == "norm": + C *= ir_left.dim**0.5 * ir.dim**0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + + if correlation == 4: + filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)] + + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in wigners: + if ir in irreps_out and ir == current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + last_ir = current_ir + elif ir in irreps_out and ir != current_ir: + if len(stack) != 0: + out += [last_ir, stack] + stack = base_o3.squeeze().unsqueeze(-1) + current_ir, last_ir = ir, ir + else: + current_ir = ir + out += [last_ir, stack] + return out @compile_mode("script")