Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ build of PyTorch past 4/10/2025 due to incomplete support for
TorchBind in earlier versions.

## Running MACE
**NOTE**: If you're revisiting this page, the repo containing
our up-to-date MACE integration has changed! See the instructions
below; we use a branch off a fork of MACE to facilitate
PRs into the main codebase.

We have modified MACE to use our accelerated kernels instead
of the standard e3nn backend. Here are the steps to replicate
our MACE benchmark:
Expand All @@ -237,7 +242,7 @@ our MACE benchmark:
```bash
pip uninstall mace-torch
pip install git+https://github.com/PASSIONLab/OpenEquivariance
pip install git+https://github.com/vbharadwaj-bk/mace_oeq
pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration.git@oeq_experimental
```

2. Download the `carbon.xyz` data file, available at <https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/>.
Expand All @@ -260,11 +265,11 @@ python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i e3nn cue oeq

| Operation | CUDA | HIP |
|--------------------------|----------|-----|
| UVU Batch | ✅ | ✅ |
| UVW Batch | ✅ | ✅ |
| UVU Convolution | ✅ | ✅ |
| UVW Convolution | ✅ | ✅ |
| Symmetric Tensor Product | ✅ (beta) | 🚧🔨 |
| UVU | ✅ | ✅ |
| UVW | ✅ | ✅ |
| UVU + Convolution | ✅ | ✅ |
| UVW + Convolution | ✅ | ✅ |
| Symmetric Tensor Product | ✅ (beta) | ✅ (beta) |

e3nn supports a variety of connection modes for CG tensor products. We support
two that are commonly used in equivariant graph neural networks:
Expand All @@ -290,6 +295,19 @@ We do not (yet) support:

If you have a use case for any of the unsupported features above, let us know.

We have recently added beta support for symmetric
contraction acceleration. Because this kernel is
specific to MACE, we require e3nn as a dependency
to run it; and since there is currently no support
for compile / export (coming soon!), we do not
expose it at the package top level. You can test
our implementation by running

```python
from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
```

## Multidevice / Stream Support
To use OpenEquivariance on multiple GPUs of a single
compute node, we currently require that all GPUs
Expand Down
3 changes: 1 addition & 2 deletions openequivariance/extension/group_mm_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ using namespace std;

template<typename T>
class GroupMMCUDA {
cudaError_t cudaStat;
cublasStatus_t stat;
cublasHandle_t handle;

Expand Down Expand Up @@ -117,7 +116,7 @@ class GroupMMCUDA {
batch_size);
}
else {
throw std::logic_error("Double precision support in progress.");
throw std::logic_error("Unsupported datatype for grouped GEMM!");
}
if (stat != CUBLAS_STATUS_SUCCESS) {
throw std::logic_error("Grouped GEMM failed!");
Expand Down
105 changes: 100 additions & 5 deletions openequivariance/extension/group_mm_hip.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
#pragma once

#include "rocblas/rocblas.h"
#include <hip/hip_runtime.h>
#include <stdexcept>
#include <iostream>


template<typename T>
class GroupMMHIP {
rocblas_status stat;
rocblas_handle handle;

int num_W;
int batch_size;

Expand All @@ -13,22 +22,108 @@ class GroupMMHIP {
num_W(num_W),
batch_size(batch_size),
alpha(1.0),
beta(0.0) {
// TODO: To implement.
beta(0.0) {
if(rocblas_create_handle(&handle) != rocblas_status_success) {
throw std::logic_error("rocBLAS initialization failed");
}
}

/*
 * Executes num_W grouped (ragged) GEMMs via rocBLAS strided-batched calls.
 *
 * A_raw:          weight matrices (num_W groups, each m x k x batch_size).
 * B_raw:          input vectors, ragged along the row dimension.
 * C_raw:          output buffer.
 * ragged_counts:  per-group row counts (length num_W); groups with count 0
 *                 are skipped.
 * ragged_inner:   selects the GEMM configuration: 0 applies the transposed
 *                 weights to the ragged input; otherwise the roles of the
 *                 operands are swapped and the result is accumulated into
 *                 the weight-shaped output (presumably the backward pass —
 *                 confirm against the CUDA implementation).
 *
 * Throws std::logic_error on an unsupported datatype or a rocBLAS failure.
 */
void group_gemm(void* A_raw, void* B_raw, void* C_raw,
        int64_t* ragged_counts, int m, int k, int ragged_inner) {

    T* A_base = reinterpret_cast<T*>(A_raw);
    T* B_base = reinterpret_cast<T*>(B_raw);
    T* C_base = reinterpret_cast<T*>(C_raw);

    int64_t ragged_offset = 0;   // running sum of ragged_counts[0..i-1]
    for(int i = 0; i < num_W; i++) {
        int M, K, N, lda, ldb, ldc;
        T *A, *B, *C;

        // rocblas_stride is 64-bit; keep strides in int64_t so large
        // problems do not overflow 32-bit arithmetic.
        int64_t strideA, strideB, strideC;
        rocblas_operation transa, transb;

        // Widened copies so pointer-offset products are computed in 64 bits.
        const int64_t m64 = m, k64 = k, b64 = batch_size;

        if(ragged_inner == 0) {
            M = m;
            K = k;
            N = static_cast<int>(ragged_counts[i]);

            A = A_base + m64 * k64 * b64 * i;
            lda = k; strideA = static_cast<int64_t>(M) * K;

            B = B_base + k64 * b64 * ragged_offset;
            ldb = K * batch_size; strideB = K;

            C = C_base + m64 * b64 * ragged_offset;
            ldc = M * batch_size; strideC = M;

            transa = rocblas_operation_transpose;
            transb = rocblas_operation_none;
        }
        else {
            M = k;
            K = static_cast<int>(ragged_counts[i]);
            N = m;

            A = B_base + k64 * b64 * ragged_offset;
            lda = k * batch_size; strideA = M;

            B = A_base + m64 * b64 * ragged_offset;
            ldb = m * batch_size; strideB = N;

            C = C_base + m64 * k64 * b64 * i;
            ldc = k; strideC = static_cast<int64_t>(M) * N;

            transa = rocblas_operation_none;
            transb = rocblas_operation_transpose;
        }
        ragged_offset += ragged_counts[i];

        if(ragged_counts[i] > 0) {
            if(std::is_same<T, float>::value) {
                stat = rocblas_sgemm_strided_batched(handle,
                    transa, transb,
                    M, N, K,
                    reinterpret_cast<float*>(&alpha),
                    reinterpret_cast<float*>(A), lda, strideA,
                    reinterpret_cast<float*>(B), ldb, strideB,
                    reinterpret_cast<float*>(&beta),
                    reinterpret_cast<float*>(C), ldc, strideC,
                    batch_size);
            }
            else if(std::is_same<T, double>::value) {
                stat = rocblas_dgemm_strided_batched(handle,
                    transa, transb,
                    M, N, K,
                    reinterpret_cast<double*>(&alpha),
                    reinterpret_cast<double*>(A), lda, strideA,
                    reinterpret_cast<double*>(B), ldb, strideB,
                    reinterpret_cast<double*>(&beta),
                    reinterpret_cast<double*>(C), ldc, strideC,
                    batch_size);
            }
            else {
                // Matches the error message used by the CUDA backend.
                throw std::logic_error("Unsupported datatype for grouped GEMM!");
            }
            if (stat != rocblas_status_success) {
                throw std::logic_error("Grouped GEMM failed!");
            }
        }
    }
}

// Python-facing wrapper: device pointers arrive as raw uint64_t integers
// (as marshalled across the extension boundary) and are reinterpreted
// before forwarding to group_gemm(). weights/vectors/output map to the
// A/B/C operands respectively.
void group_gemm_intptr(uint64_t weights,
    uint64_t vectors, uint64_t output,
    uint64_t ragged_counts, int m, int k, int ragged_inner) {
    group_gemm(
        reinterpret_cast<void*>(weights),
        reinterpret_cast<void*>(vectors),
        reinterpret_cast<void*>(output),
        reinterpret_cast<int64_t*>(ragged_counts),
        m, k, ragged_inner);
}

~GroupMMHIP() {
    // Release the rocBLAS handle created in the constructor.
    rocblas_destroy_handle(handle);
}
};
2 changes: 1 addition & 1 deletion openequivariance/extension/util/backend_hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class __attribute__((visibility("default"))) KernelLibrary {
public:
int device;
KernelLibrary(hiprtcProgram &prog, vector<char> &kernel_binary, vector<string> &kernel_names) {
hipGetDevice(&device);
HIP_ERRCHK(hipGetDevice(&device));
HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data()));

for (size_t i = 0; i < kernel_names.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Re-export SymmetricContraction at the subpackage level so users can write
# `from openequivariance.implementations.symmetric_contraction import
# SymmetricContraction` (it is deliberately not exposed at the package
# top level while the feature is in beta).
from openequivariance.implementations.symmetric_contraction.symmetric_contraction import (
    SymmetricContraction,
)

__all__ = ["SymmetricContraction"]
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,124 @@ def forward(self, weights, vectors, bincounts):


# --------------------------------------------------------------------------
from typing import Dict, Optional, Union
# The following segment of code was copied from MACE's repo at https://github.com/ACEsuit/mace/blob/b5faaa076c49778fc17493edfecebcabeb960155/mace/tools/cg.py#L106

import collections
from typing import Dict, Optional, Union, List

from e3nn import o3
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode
from mace.tools.cg import U_matrix_real

# Path-description records used by _wigner_nj:
#   _TP    — an internal coupling node: op = (ir_left, ir, ir_out),
#            args = (left sub-path, _INPUT leaf) recording how it was formed.
#   _INPUT — a leaf selecting slice [start:stop) of input number `tensor`.
_TP = collections.namedtuple("_TP", "op, args")
_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop")


def _wigner_nj(
    irrepss: List[o3.Irreps],
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Recursively build generalized Wigner coupling tensors for a list of irreps.

    Copied verbatim from MACE (mace/tools/cg.py) — keep in sync with upstream.

    Args:
        irrepss: list of ``o3.Irreps``; the recursion couples the first
            ``len(irrepss) - 1`` entries, then couples the result with the last.
        normalization: ``"component"`` scales each block by ``ir_out.dim**0.5``;
            ``"norm"`` scales by ``ir_left.dim**0.5 * ir.dim**0.5``.
        filter_ir_mid: optional iterable of irreps; intermediate outputs not in
            this set are skipped.
        dtype: torch dtype for the generated tensors.

    Returns:
        A list of ``(ir_out, path, tensor)`` triples sorted by ``ir_out``, where
        ``path`` is a nested ``_TP``/``_INPUT`` record of how the coupling was
        built and ``tensor`` has shape
        ``(ir_out.dim, *(irreps.dim for irreps in left), last_irreps.dim)``.
    """
    irrepss = [o3.Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    # Base case: a single Irreps couples to itself via identity blocks,
    # one _INPUT leaf per multiplicity slice.
    if len(irrepss) == 1:
        (irreps,) = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype)
        i = 0
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
                i += ir.dim
        return ret

    # Recursive case: couple everything to the left, then combine each left
    # result with each irrep of the rightmost Irreps via Wigner 3j symbols.
    *irrepss_left, irreps_right = irrepss
    ret = []
    for ir_left, path_left, C_left in _wigner_nj(
        irrepss_left,
        normalization=normalization,
        filter_ir_mid=filter_ir_mid,
        dtype=dtype,
    ):
        i = 0
        for mul, ir in irreps_right:
            # ir_left * ir enumerates all output irreps of the product.
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue

                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
                if normalization == "component":
                    C *= ir_out.dim**0.5
                if normalization == "norm":
                    C *= ir_left.dim**0.5 * ir.dim**0.5

                # Contract the left coupling tensor into the 3j symbol, then
                # restore the per-input axes of the left operands.
                C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
                C = C.reshape(
                    ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
                )
                # Embed one copy per multiplicity u into the full-width
                # rightmost axis (zero outside the u-th slice).
                for u in range(mul):
                    E = torch.zeros(
                        ir_out.dim,
                        *(irreps.dim for irreps in irrepss_left),
                        irreps_right.dim,
                        dtype=dtype,
                    )
                    sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(
                                    path_left,
                                    _INPUT(len(irrepss_left), sl.start, sl.stop),
                                ),
                            ),
                            E,
                        )
                    ]
            i += mul * ir.dim
    return sorted(ret, key=lambda x: x[0])


def U_matrix_real(
    irreps_in: Union[str, o3.Irreps],
    irreps_out: Union[str, o3.Irreps],
    correlation: int,
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Assemble the U matrices for symmetric contractions of order ``correlation``.

    Copied from MACE (mace/tools/cg.py) — keep in sync with upstream.
    NOTE(review): this local definition shadows the ``U_matrix_real`` imported
    from ``mace.tools.cg`` earlier in the module — confirm this is intentional.

    Args:
        irreps_in: input irreps; the coupling list is ``[irreps_in] * correlation``.
        irreps_out: irreps to keep in the output.
        correlation: contraction order; for ``correlation == 4`` the
            intermediate filter is overridden with alternating-parity irreps
            up to l = 11.
        normalization, filter_ir_mid, dtype: forwarded to ``_wigner_nj``.

    Returns:
        A flat list alternating ``[ir, stacked_tensor, ir, stacked_tensor, ...]``,
        where each stacked tensor concatenates (along a new last axis) the
        coupling tensors of consecutive ``_wigner_nj`` results sharing that irrep.
    """
    irreps_out = o3.Irreps(irreps_out)
    irrepss = [o3.Irreps(irreps_in)] * correlation

    if correlation == 4:
        filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)]

    wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)

    current_ir = wigners[0][0]
    out = []
    stack = torch.tensor([])

    # wigners is sorted by irrep, so equal irreps are contiguous; stack the
    # tensors of each run of irreps that belong to irreps_out.
    # NOTE(review): if no wigner irrep is ever in irreps_out, ``last_ir`` is
    # never bound and the final ``out += [last_ir, stack]`` would raise — this
    # mirrors upstream MACE behavior; presumably callers guarantee a match.
    for ir, _, base_o3 in wigners:
        if ir in irreps_out and ir == current_ir:
            stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
            last_ir = current_ir
        elif ir in irreps_out and ir != current_ir:
            if len(stack) != 0:
                out += [last_ir, stack]
            stack = base_o3.squeeze().unsqueeze(-1)
            current_ir, last_ir = ir, ir
        else:
            current_ir = ir
    out += [last_ir, stack]
    return out


@compile_mode("script")
Expand Down