diff --git a/README.md b/README.md
index 9506937d..6100f7bd 100644
--- a/README.md
+++ b/README.md
@@ -229,6 +229,11 @@ build of PyTorch past 4/10/2025 due to incomplete support for
TorchBind in earlier versions.
## Running MACE
+**NOTE**: If you're revisiting this page, the repo containing
+our up-to-date MACE integration has changed! See the instructions
+below; we use a branch off a fork of MACE to facilitate
+PRs into the main codebase.
+
We have modified MACE to use our accelerated kernels instead
of the standard e3nn backend. Here are the steps to replicate
our MACE benchmark:
@@ -237,7 +242,7 @@ our MACE benchmark:
```bash
pip uninstall mace-torch
pip install git+https://github.com/PASSIONLab/OpenEquivariance
-pip install git+https://github.com/vbharadwaj-bk/mace_oeq
+pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration.git@oeq_experimental
```
2. Download the `carbon.xyz` data file, available at .
@@ -260,11 +265,11 @@ python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i e3nn cue oeq
| Operation | CUDA | HIP |
|--------------------------|----------|-----|
-| UVU Batch | ✅ | ✅ |
-| UVW Batch | ✅ | ✅ |
-| UVU Convolution | ✅ | ✅ |
-| UVW Convolution | ✅ | ✅ |
-| Symmetric Tensor Product | ✅ (beta) | 🚧🔨 |
+| UVU | ✅ | ✅ |
+| UVW | ✅ | ✅ |
+| UVU + Convolution | ✅ | ✅ |
+| UVW + Convolution | ✅ | ✅ |
+| Symmetric Tensor Product | ✅ (beta) | ✅ (beta) |
e3nn supports a variety of connection modes for CG tensor products. We support
two that are commonly used in equivariant graph neural networks:
@@ -290,6 +295,19 @@ We do not (yet) support:
If you have a use case for any of the unsupported features above, let us know.
+We have recently added beta support for symmetric
+contraction acceleration. Because this kernel is
+specific to MACE, we require e3nn as a dependency
+to run it. Since there is currently no support for
+compile / export (coming soon!), we
+do not expose it in the package
+toplevel. You can test out our implementation by
+running
+
+```python
+from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
+```
+
## Multidevice / Stream Support
To use OpenEquivariance on multiple GPUs of a single
compute node, we currently require that all GPUs
diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/extension/group_mm_cuda.hpp
index 58f2a3e3..95f1412f 100644
--- a/openequivariance/extension/group_mm_cuda.hpp
+++ b/openequivariance/extension/group_mm_cuda.hpp
@@ -9,7 +9,6 @@ using namespace std;
template<typename T>
class GroupMMCUDA {
- cudaError_t cudaStat;
cublasStatus_t stat;
cublasHandle_t handle;
@@ -117,7 +116,7 @@ class GroupMMCUDA {
batch_size);
}
else {
- throw std::logic_error("Double precision support in progress.");
+ throw std::logic_error("Unsupported datatype for grouped GEMM!");
}
if (stat != CUBLAS_STATUS_SUCCESS) {
throw std::logic_error("Grouped GEMM failed!");
diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/extension/group_mm_hip.hpp
index 0076a95e..2e713cf4 100644
--- a/openequivariance/extension/group_mm_hip.hpp
+++ b/openequivariance/extension/group_mm_hip.hpp
@@ -1,7 +1,16 @@
#pragma once
+#include "rocblas/rocblas.h"
+#include <stdexcept>
+#include <type_traits>
+#include <cstdint>
+
+
template<typename T>
class GroupMMHIP {
+ rocblas_status stat;
+ rocblas_handle handle;
+
int num_W;
int batch_size;
@@ -13,22 +22,108 @@ class GroupMMHIP {
num_W(num_W),
batch_size(batch_size),
alpha(1.0),
- beta(0.0) {
- // TODO: To implement.
+ beta(0.0) {
+ if(rocblas_create_handle(&handle) != rocblas_status_success) {
+ throw std::logic_error("rocBLAS initialization failed");
+ }
}
void group_gemm(void* A_raw, void* B_raw, void* C_raw,
int64_t* ragged_counts, int m, int k, int ragged_inner) {
- // TODO: To implement.
+
+ T* A_base = reinterpret_cast<T*>(A_raw);
+ T* B_base = reinterpret_cast<T*>(B_raw);
+ T* C_base = reinterpret_cast<T*>(C_raw);
+
+ int64_t ragged_offset = 0;
+ for(int i = 0; i < num_W; i++) {
+ int M, K, N, lda, ldb, ldc;
+ T *A, *B, *C;
+
+ int strideA, strideB, strideC;
+ rocblas_operation transa, transb;
+
+ if(ragged_inner == 0) {
+ M = m;
+ K = k;
+ N = static_cast<int>(ragged_counts[i]);
+
+ A = A_base + (m * k * batch_size * i);
+ lda = k; strideA = M * K;
+
+ B = B_base + (k * batch_size * ragged_offset);
+ ldb = K * batch_size; strideB = K;
+
+ C = C_base + (m * batch_size * ragged_offset);
+ ldc = M * batch_size; strideC = M;
+
+ transa = rocblas_operation_transpose;
+ transb = rocblas_operation_none;
+ }
+ else {
+ M = k;
+ K = static_cast<int>(ragged_counts[i]);
+ N = m;
+
+ A = B_base + (k * batch_size * ragged_offset);
+ lda = k * batch_size; strideA = M;
+
+ B = A_base + (m * batch_size * ragged_offset);
+ ldb = m * batch_size; strideB = N;
+
+ C = C_base + (m * k * batch_size * i);
+ ldc = k; strideC = M * N;
+
+ transa = rocblas_operation_none;
+ transb = rocblas_operation_transpose;
+ }
+ ragged_offset += ragged_counts[i];
+
+ if(ragged_counts[i] > 0) {
+ if(std::is_same<T, float>::value) {
+ stat = rocblas_sgemm_strided_batched(handle,
+ transa, transb,
+ M, N, K,
+ reinterpret_cast<float*>(&alpha),
+ reinterpret_cast<float*>(A), lda, strideA,
+ reinterpret_cast<float*>(B), ldb, strideB,
+ reinterpret_cast<float*>(&beta),
+ reinterpret_cast<float*>(C), ldc, strideC,
+ batch_size);
+ }
+ else if(std::is_same<T, double>::value) {
+ stat = rocblas_dgemm_strided_batched(handle,
+ transa, transb,
+ M, N, K,
+ reinterpret_cast<double*>(&alpha),
+ reinterpret_cast<double*>(A), lda, strideA,
+ reinterpret_cast<double*>(B), ldb, strideB,
+ reinterpret_cast<double*>(&beta),
+ reinterpret_cast<double*>(C), ldc, strideC,
+ batch_size);
+ }
+ else {
+ throw std::logic_error("Unsupported datatype for grouped GEMM!");
+ }
+ if (stat != rocblas_status_success) {
+ throw std::logic_error("Grouped GEMM failed!");
+ }
+ }
+ }
}
void group_gemm_intptr(uint64_t weights,
uint64_t vectors, uint64_t output,
uint64_t ragged_counts, int m, int k, int ragged_inner) {
- // TODO: To implement.
+ group_gemm(
+ reinterpret_cast<void*>(weights),
+ reinterpret_cast<void*>(vectors),
+ reinterpret_cast<void*>(output),
+ reinterpret_cast<int64_t*>(ragged_counts),
+ m, k, ragged_inner);
}
~GroupMMHIP() {
- // TODO: To implement.
+ rocblas_destroy_handle(handle);
}
};
\ No newline at end of file
diff --git a/openequivariance/extension/util/backend_hip.hpp b/openequivariance/extension/util/backend_hip.hpp
index c9360189..91d0f6ee 100644
--- a/openequivariance/extension/util/backend_hip.hpp
+++ b/openequivariance/extension/util/backend_hip.hpp
@@ -154,7 +154,7 @@ class __attribute__((visibility("default"))) KernelLibrary {
public:
int device;
KernelLibrary(hiprtcProgram &prog, vector<char> &kernel_binary, vector<string> &kernel_names) {
- hipGetDevice(&device);
+ HIP_ERRCHK(hipGetDevice(&device));
HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data()));
for (size_t i = 0; i < kernel_names.size(); i++) {
diff --git a/openequivariance/implementations/symmetric_contraction/__init__.py b/openequivariance/implementations/symmetric_contraction/__init__.py
new file mode 100644
index 00000000..75ac6cc8
--- /dev/null
+++ b/openequivariance/implementations/symmetric_contraction/__init__.py
@@ -0,0 +1,5 @@
+from openequivariance.implementations.symmetric_contraction.symmetric_contraction import (
+ SymmetricContraction,
+)
+
+__all__ = ["SymmetricContraction"]
diff --git a/openequivariance/implementations/symmetric_contraction/STPOpt.py b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py
similarity index 78%
rename from openequivariance/implementations/symmetric_contraction/STPOpt.py
rename to openequivariance/implementations/symmetric_contraction/symmetric_contraction.py
index 890ca7c0..9790c2a2 100644
--- a/openequivariance/implementations/symmetric_contraction/STPOpt.py
+++ b/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py
@@ -109,12 +109,124 @@ def forward(self, weights, vectors, bincounts):
# --------------------------------------------------------------------------
-from typing import Dict, Optional, Union
+# The following segment of code was copied from MACE's repo at https://github.com/ACEsuit/mace/blob/b5faaa076c49778fc17493edfecebcabeb960155/mace/tools/cg.py#L106
+
+import collections
+from typing import Dict, Optional, Union, List
from e3nn import o3
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode
-from mace.tools.cg import U_matrix_real
+
+_TP = collections.namedtuple("_TP", "op, args")
+_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop")
+
+
+def _wigner_nj(
+ irrepss: List[o3.Irreps],
+ normalization: str = "component",
+ filter_ir_mid=None,
+ dtype=None,
+):
+ irrepss = [o3.Irreps(irreps) for irreps in irrepss]
+ if filter_ir_mid is not None:
+ filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]
+
+ if len(irrepss) == 1:
+ (irreps,) = irrepss
+ ret = []
+ e = torch.eye(irreps.dim, dtype=dtype)
+ i = 0
+ for mul, ir in irreps:
+ for _ in range(mul):
+ sl = slice(i, i + ir.dim)
+ ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
+ i += ir.dim
+ return ret
+
+ *irrepss_left, irreps_right = irrepss
+ ret = []
+ for ir_left, path_left, C_left in _wigner_nj(
+ irrepss_left,
+ normalization=normalization,
+ filter_ir_mid=filter_ir_mid,
+ dtype=dtype,
+ ):
+ i = 0
+ for mul, ir in irreps_right:
+ for ir_out in ir_left * ir:
+ if filter_ir_mid is not None and ir_out not in filter_ir_mid:
+ continue
+
+ C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
+ if normalization == "component":
+ C *= ir_out.dim**0.5
+ if normalization == "norm":
+ C *= ir_left.dim**0.5 * ir.dim**0.5
+
+ C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
+ C = C.reshape(
+ ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
+ )
+ for u in range(mul):
+ E = torch.zeros(
+ ir_out.dim,
+ *(irreps.dim for irreps in irrepss_left),
+ irreps_right.dim,
+ dtype=dtype,
+ )
+ sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
+ E[..., sl] = C
+ ret += [
+ (
+ ir_out,
+ _TP(
+ op=(ir_left, ir, ir_out),
+ args=(
+ path_left,
+ _INPUT(len(irrepss_left), sl.start, sl.stop),
+ ),
+ ),
+ E,
+ )
+ ]
+ i += mul * ir.dim
+ return sorted(ret, key=lambda x: x[0])
+
+
+def U_matrix_real(
+ irreps_in: Union[str, o3.Irreps],
+ irreps_out: Union[str, o3.Irreps],
+ correlation: int,
+ normalization: str = "component",
+ filter_ir_mid=None,
+ dtype=None,
+):
+ irreps_out = o3.Irreps(irreps_out)
+ irrepss = [o3.Irreps(irreps_in)] * correlation
+
+ if correlation == 4:
+ filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)]
+
+ wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)
+
+ current_ir = wigners[0][0]
+ out = []
+ stack = torch.tensor([])
+
+ for ir, _, base_o3 in wigners:
+ if ir in irreps_out and ir == current_ir:
+ stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
+ last_ir = current_ir
+ elif ir in irreps_out and ir != current_ir:
+ if len(stack) != 0:
+ out += [last_ir, stack]
+ stack = base_o3.squeeze().unsqueeze(-1)
+ current_ir, last_ir = ir, ir
+ else:
+ current_ir = ir
+ out += [last_ir, stack]
+ return out
@compile_mode("script")