Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,18 @@ We do not (yet) support:

If you have a use case for any of the unsupported features above, let us know.

## Multidevice / Stream Support
To use OpenEquivariance on multiple GPUs of a single compute node,
we currently require that all GPUs share the same compute
capability. This is because our kernels are compiled for the shared
memory capacity of the first visible GPU (device 0). On a
heterogeneous system, you can still use OpenEquivariance on every
GPU that matches the compute capability of that first visible
device.

We are working on support for CUDA streams!

## Citation and Acknowledgements
If you find this code useful, please cite our paper:

Expand Down
38 changes: 27 additions & 11 deletions openequivariance/extension/util/backend_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ class __attribute__((visibility("default"))) KernelLaunchConfig {
smem(smem)
{ }


KernelLaunchConfig(int64_t num_blocks_i, int64_t num_threads_i, int64_t smem_i) :
KernelLaunchConfig( static_cast<uint32_t>(num_blocks_i),
static_cast<uint32_t>(num_threads_i),
Expand All @@ -156,8 +155,8 @@ class __attribute__((visibility("default"))) CUJITKernel {

bool compiled = false;
char* code = nullptr;
int cu_major, cu_minor;

CUdevice dev;
CUlibrary library;

vector<string> kernel_names;
Expand Down Expand Up @@ -185,6 +184,10 @@ class __attribute__((visibility("default"))) CUJITKernel {
}

void compile(vector<string> kernel_names_i, vector<vector<int>> template_param_list, int opt_level=3) {
DeviceProp dp(0); // We only query the first device on the system at the moment
cu_major = dp.major;
cu_minor = dp.minor;

if(compiled) {
throw std::logic_error("JIT object has already been compiled!");
}
Expand Down Expand Up @@ -215,8 +218,7 @@ class __attribute__((visibility("default"))) CUJITKernel {

}

DeviceProp dp(0); // TODO: We only query the first device at the moment
std::string sm = "-arch=sm_" + std::to_string(dp.major) + std::to_string(dp.minor);
std::string sm = "-arch=sm_" + std::to_string(cu_major) + std::to_string(cu_minor);

std::vector<const char*> opts = {
"--std=c++17",
Expand Down Expand Up @@ -268,19 +270,29 @@ class __attribute__((visibility("default"))) CUJITKernel {
kernels.emplace_back();
CUDA_SAFE_CALL(cuLibraryGetKernel(&(kernels[i]), library, name));
}

CUDA_SAFE_CALL(cuDeviceGet(&dev, 0));
}

void set_max_smem(int kernel_id, uint32_t max_smem_bytes) {
if(!compiled)
throw std::logic_error("JIT object has not been compiled!");
if(kernel_id >= kernels.size())
throw std::logic_error("Kernel index out of range!");

CUDA_SAFE_CALL(cuKernelSetAttribute(
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
max_smem_bytes,
kernels[kernel_id],
dev));
int device_count;
CUDA_SAFE_CALL(cuDeviceGetCount(&device_count));

for(int i = 0; i < device_count; i++) {
DeviceProp dp(i);
if(dp.major == cu_major && dp.minor == cu_minor) {
CUdevice dev;
CUDA_SAFE_CALL(cuDeviceGet(&dev, i));
CUDA_SAFE_CALL(cuKernelSetAttribute(
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
max_smem_bytes,
kernels[kernel_id],
dev));
}
}
}

void execute(int kernel_id, void* args[], KernelLaunchConfig config) {
Expand All @@ -291,6 +303,10 @@ class __attribute__((visibility("default"))) CUJITKernel {
CUDA_SAFE_CALL(cuCtxGetCurrent(&pctx));

if(pctx == NULL) {
int device_id;
CUdevice dev;
CUDA_ERRCHK(cudaGetDevice(&device_id));
CUDA_SAFE_CALL(cuDeviceGet(&dev, device_id));
CUDA_SAFE_CALL(cuDevicePrimaryCtxRetain(&pctx, dev));
CUDA_SAFE_CALL(cuCtxSetCurrent(pctx));
}
Expand Down
87 changes: 50 additions & 37 deletions openequivariance/extension/util/backend_hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <vector>
#include <string>
#include <iostream>
#include <memory>

using namespace std;

Expand Down Expand Up @@ -146,20 +147,54 @@ class __attribute__((visibility("default"))) KernelLaunchConfig {
{ }
};

/*
 * RAII owner of a hipModule_t loaded from a JIT-compiled code object,
 * plus handles to the kernels it contains.
 *
 * The module is loaded on whichever device is current at construction
 * time; that device id is recorded in `device` so callers can detect a
 * device switch and rebuild the library. The module is unloaded in the
 * destructor.
 */
class __attribute__((visibility("default"))) KernelLibrary {
    hipModule_t library;
    vector<hipFunction_t> kernels;

public:
    int device;  // Device that was current when the module was loaded.

    KernelLibrary(hiprtcProgram &prog, vector<char> &kernel_binary, vector<string> &kernel_names) {
        // Check the result like every other HIP call in this file.
        HIP_ERRCHK(hipGetDevice(&device));
        HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data()));

        for (size_t i = 0; i < kernel_names.size(); i++) {
            const char *name;

            // Translate the C++ name expression into its lowered
            // (mangled) symbol name before looking it up in the module.
            HIPRTC_SAFE_CALL(hiprtcGetLoweredName(
                prog,
                kernel_names[i].c_str(), // name expression
                &name                    // lowered name
            ));

            kernels.emplace_back();
            HIP_ERRCHK(hipModuleGetFunction(&(kernels[i]), library, name));
        }
    }

    // Owns a raw module handle: copying would lead to a double unload.
    KernelLibrary(const KernelLibrary&) = delete;
    KernelLibrary& operator=(const KernelLibrary&) = delete;

    /*
     * Returns the function handle for kernel_id.
     * Throws std::logic_error when kernel_id is out of range.
     */
    hipFunction_t operator[](int kernel_id) {
        if(kernel_id < 0 || static_cast<size_t>(kernel_id) >= kernels.size())
            throw std::logic_error("Kernel index out of range!");

        return kernels[kernel_id];
    }

    ~KernelLibrary() {
        HIP_ERRCHK(hipModuleUnload(library));
    }
};

class __attribute__((visibility("default"))) HIPJITKernel {
private:
hiprtcProgram prog;

bool compiled = false;
char* code = nullptr;

hipModule_t library;

vector<string> kernel_names;
vector<hipFunction_t> kernels;
unique_ptr<KernelLibrary> kernels;

public:
string kernel_plaintext;
vector<char> kernel_binary;

HIPJITKernel(string plaintext) :
kernel_plaintext(plaintext) {

Expand Down Expand Up @@ -247,42 +282,24 @@ class __attribute__((visibility("default"))) HIPJITKernel {

size_t codeSize;
HIPRTC_SAFE_CALL(hiprtcGetCodeSize(prog, &codeSize));
code = new char[codeSize];
HIPRTC_SAFE_CALL(hiprtcGetCode(prog, code));

vector<char> kernel_binary(codeSize);
hiprtcGetCode(prog, kernel_binary.data());

//HIP_SAFE_CALL(cuInit(0));
HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data()));

for (size_t i = 0; i < kernel_names.size(); i++) {
const char *name;

HIPRTC_SAFE_CALL(hiprtcGetLoweredName(
prog,
kernel_names[i].c_str(), // name expression
&name // lowered name
));

kernels.emplace_back();
HIP_ERRCHK(hipModuleGetFunction(&(kernels[i]), library, name));
}
kernel_binary.resize(codeSize);
hiprtcGetCode(prog, kernel_binary.data());
kernels.reset(new KernelLibrary(prog, kernel_binary, kernel_names));
}

void set_max_smem(int kernel_id, uint32_t max_smem_bytes) {
if(kernel_id >= kernels.size())
throw std::logic_error("Kernel index out of range!");

// Ignore for AMD GPUs
// Ignore for AMD GPUs
}

void execute(int kernel_id, void* args[], KernelLaunchConfig config) {
if(kernel_id >= kernels.size())
throw std::logic_error("Kernel index out of range!");
int device_id; HIP_ERRCHK(hipGetDevice(&device_id));
if(device_id != kernels->device) {
kernels.reset();
kernels.reset(new KernelLibrary(prog, kernel_binary, kernel_names));
}

HIP_ERRCHK(
hipModuleLaunchKernel( (kernels[kernel_id]),
hipModuleLaunchKernel( ((*kernels)[kernel_id]),
config.num_blocks, 1, 1, // grid dim
config.num_threads, 1, 1, // block dim
config.smem, config.hStream, // shared mem and stream
Expand All @@ -291,10 +308,6 @@ class __attribute__((visibility("default"))) HIPJITKernel {
}

~HIPJITKernel() {
if(compiled) {
HIP_ERRCHK(hipModuleUnload(library));
delete[] code;
}
HIPRTC_SAFE_CALL(hiprtcDestroyProgram(&prog));
}
};
5 changes: 0 additions & 5 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import numpy as np
import openequivariance as oeq
from torch_geometric import EdgeIndex
from openequivariance.implementations.TensorProduct import TensorProduct
from openequivariance.benchmark.correctness_utils import correctness_forward, correctness_backward, correctness_double_backward

@pytest.fixture(scope='session')
def problem_and_irreps():
Expand All @@ -15,9 +13,6 @@ def problem_and_irreps():
shared_weights=False, internal_weights=False,
irrep_dtype=np.float32, weight_dtype=np.float32)

gen = torch.Generator(device='cuda')
gen.manual_seed(0)

return problem, X_ir, Y_ir, Z_ir,

@pytest.fixture(params=['batch', 'conv_det', 'conv_atomic'], scope='session')
Expand Down
45 changes: 45 additions & 0 deletions tests/multidevice_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import textwrap, torch, subprocess, os
import numpy as np

def test_multidevice():
    """Smoke-test OpenEquivariance across all visible GPUs.

    Relaunches this file under ``torch.distributed.run`` with one process
    per GPU; the ``__main__`` block below performs a forward pass on each
    rank. On failure, the subprocess's stdout/stderr are surfaced in the
    assertion message.
    """
    import sys

    # Use the interpreter running this test rather than whatever "python"
    # resolves to on PATH, so the subprocess shares the same environment.
    result = subprocess.run([
        sys.executable, "-m", "torch.distributed.run",
        "--standalone", "--nnodes=1", "--nproc-per-node=gpu",
        __file__],
        capture_output=True,
        check=False)

    if result.returncode != 0:
        error_string = f'''
        Invocation: {' '.join(result.args)}
        Test failed with return code {result.returncode}.
        \nOutput:\n\n{result.stdout.decode()}
        \nError:\n\n{result.stderr.decode()}
        '''
        assert False, textwrap.dedent(error_string)

if __name__ == "__main__":
import openequivariance as oeq

# Use MACE-large to test >64KB shared memory allocation
from openequivariance.benchmark.benchmark_configs import mace_problems
problem = mace_problems[0]

local_rank = int(os.environ["LOCAL_RANK"])
device = f'cuda:{local_rank}'
torch.set_default_device(device)

X_ir, Y_ir, Z_ir = problem.irreps_in1, problem.irreps_in2, problem.irreps_out
tp = oeq.TensorProduct(problem)

batch_size = 1000
gen = torch.Generator(device=device)
gen.manual_seed(0)
X = torch.rand(batch_size, X_ir.dim, device=device, generator=gen)
Y = torch.rand(batch_size, Y_ir.dim, device=device, generator=gen)
W = torch.rand(batch_size, problem.weight_numel, device=device, generator=gen)

with torch.cuda.device(device):
result = tp.forward(X, Y, W)