From d649f9f53617ab80ae34b36c9bdd9f62a0d5643e Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 24 May 2025 21:38:39 -0700 Subject: [PATCH 01/15] adding pre-commit and linting / formatting via ruff --- .pre-commit-config.yaml | 9 +++++++++ pyproject.toml | 2 ++ 2 files changed, 11 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..f65c7627 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.11.11 + hooks: + # Run the linter. + - id: ruff-check + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 99cc53ff..af40bc1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ bench = [ ] dev = [ + "pre-commit", + "ruff", "pytest", "pytest-check", "torch_geometric" From b8b1ae2d31d8b51f4a05b732bb0ed48849832bd8 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 24 May 2025 21:42:42 -0700 Subject: [PATCH 02/15] linting + remove transpose permutation in forward --- .../convolution/ConvolutionBase.py | 862 ++++++++++++------ 1 file changed, 589 insertions(+), 273 deletions(-) diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/implementations/convolution/ConvolutionBase.py index 612427bf..568832a8 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/implementations/convolution/ConvolutionBase.py @@ -1,57 +1,73 @@ import copy +import typing import numpy as np -import numpy.linalg as la -from openequivariance.extlib import * -from openequivariance.benchmark.random_buffer_utils import * -from openequivariance.implementations.TensorProductBase import * - -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance import extlib +from openequivariance.extlib import DeviceBuffer, GPUTimer +from openequivariance.benchmark.random_buffer_utils import ( + get_random_buffers_forward_conv, + get_random_buffers_backward_conv, +) +from openequivariance.implementations.TensorProductBase import TensorProductBase + +from openequivariance.benchmark.logging_utils import getLogger, bcolors from openequivariance.benchmark.correctness_utils import check_similiarity + logger = getLogger() + def flops_data_per_tp(config, direction): - ''' + """ Assumes all interactions are "uvu" for now Returns (flops_per_tp, data_per_tp, nnz) - ''' - bytes_per_word = np.dtype(config.irrep_dtype).itemsize + """ + bytes_per_word = np.dtype(config.irrep_dtype).itemsize - assert(not config.shared_weights) + assert not config.shared_weights L1, L2, L3 = config.irreps_in1, config.irreps_in2, config.irreps_out ops_per_nz, words_per_tp = None, None if direction == "forward": ops_per_nz = 3 - words_per_tp = L1.dim + L2.dim + L3.dim + config.weight_numel + words_per_tp = L1.dim + L2.dim + L3.dim + config.weight_numel elif direction == "backward": ops_per_nz = 9 - words_per_tp = L1.dim + L2.dim + L3.dim + config.weight_numel \ - + L1.dim + L2.dim + config.weight_numel # Output gradients + words_per_tp = ( + L1.dim + + L2.dim + + L3.dim + + config.weight_numel + + L1.dim + + L2.dim + + config.weight_numel + ) # Output gradients ops_per_tp = 0 nnz = 0 - for (u, v, w, connection_mode, *others) in config.instructions: + for u, v, w, connection_mode, *others in config.instructions: tensor = TensorProductBase.load_cg_tensor(L1[u].ir.l, L2[v].ir.l, L3[w].ir.l) local_nnz = np.count_nonzero(tensor) nnz += local_nnz - ops_per_tp += ops_per_nz * local_nnz * L1[u].mul * L2[v].mul # Assumes L3.mult(w) = L1.mult(u) * L2.mult(v) + ops_per_tp += ( + ops_per_nz * local_nnz * L1[u].mul * L2[v].mul + ) # Assumes L3.mult(w) = L1.mult(u) * L2.mult(v) if connection_mode == "uvu": - ops_per_tp += L3[w].mul * (2 * L3[w].ir.l + 1) + ops_per_tp += L3[w].mul * (2 * L3[w].ir.l + 1) elif connection_mode == "uvw": ops_per_tp += L1[u].mul * L2[v].mul * L3[w].ir.dim * L3[w].mul return ops_per_tp, words_per_tp * bytes_per_word, nnz + class CoordGraph: def __init__(self, coords, rows, cols, name): - ''' + """ Because graphs may change constantly, this class is designed to be as light as possible. A directed edge from node u to v is indicated by the presence of an index i such that rows[i] = u, rows[i] = v. - ''' - assert(len(rows) == len(cols)) - self.nnz = len(rows) # Counts every nonzero in the adjacency matrix + """ + assert len(rows) == len(cols) + self.nnz = len(rows) # Counts every nonzero in the adjacency matrix self.node_count = coords.shape[0] self.coords = coords self.name = name @@ -69,12 +85,17 @@ def __init__(self, coords, rows, cols, name): triples.sort(key=lambda x: (x[0], x[1])) self.transpose_perm = np.array([x[2] for x in triples], dtype=self.rows.dtype) + class ConvolutionBase: - next_conv_id = 0 # Used to assign unique IDs to each conv instance + next_conv_id = 0 # Used to assign unique IDs to each conv instance def __init__(self, config, idx_dtype, torch_op=False, deterministic=False): - self.config = config - self.L1, self.L2, self.L3 = config.irreps_in1, config.irreps_in2, config.irreps_out + self.config = config + self.L1, self.L2, self.L3 = ( + config.irreps_in1, + config.irreps_in2, + config.irreps_out, + ) self.internal = None self.torch_op = torch_op self.idx_dtype = idx_dtype @@ -93,7 +114,9 @@ def __init__(self, config, idx_dtype, torch_op=False, deterministic=False): def allocate_workspace(self, size_bytes): self.workspace_size = size_bytes if self.torch_op: - self.workspace_buffer = torch.zeros(size_bytes, dtype=torch.uint8, device='cuda') + self.workspace_buffer = torch.zeros( + size_bytes, dtype=torch.uint8, device="cuda" + ) else: self.workspace_buffer = DeviceBuffer(size_bytes) self.workspace_ptr = self.workspace_buffer.data_ptr() @@ -103,20 +126,23 @@ def allocate_workspace(self, size_bytes): def name(): raise NotImplementedError() - def forward_cpu(self, - L1_in, L2_in, weights, L3_out, - graph): + def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype - assert(graph.rows.dtype == self.idx_dtype) - assert(graph.cols.dtype == self.idx_dtype) - - weights_chunked = np.zeros_like(weights) + weights_chunked = np.zeros_like(weights) if self.reorder_weights_e3nn_to_oeq is not None: - self.reorder_weights_e3nn_to_oeq(weights, weights_chunked, not self.config.shared_weights) + self.reorder_weights_e3nn_to_oeq( + weights, weights_chunked, not self.config.shared_weights + ) else: weights_chunked = weights - L1_d, L2_d, weights_d = DeviceBuffer(L1_in), DeviceBuffer(L2_in), DeviceBuffer(weights_chunked) + L1_d, L2_d, weights_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(weights_chunked), + ) L3_d = DeviceBuffer(L3_out) rows_d = DeviceBuffer(graph.rows) @@ -131,20 +157,22 @@ def forward_cpu(self, cols_d.data_ptr(), graph.nnz, graph.node_count, - self.workspace_ptr) + self.workspace_ptr, + ) L3_d.copy_to_host() - def backward_cpu(self, - L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, - L3_grad, graph): - - assert(graph.rows.dtype == self.idx_dtype) - assert(graph.cols.dtype == self.idx_dtype) + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph + ): + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype - weights_chunked = np.zeros_like(weights) + weights_chunked = np.zeros_like(weights) if self.reorder_weights_e3nn_to_oeq is not None: - self.reorder_weights_e3nn_to_oeq(weights, weights_chunked, not self.config.shared_weights) + self.reorder_weights_e3nn_to_oeq( + weights, weights_chunked, not self.config.shared_weights + ) else: weights_chunked = weights @@ -154,26 +182,32 @@ def backward_cpu(self, L3_d = DeviceBuffer(L3_grad) rows_d = DeviceBuffer(graph.rows) cols_d = DeviceBuffer(graph.cols) - + L1_grad_d = DeviceBuffer(L1_grad) L2_grad_d = DeviceBuffer(L2_grad) weights_grad_d = DeviceBuffer(weights_grad) transpose_perm_d = None - transpose_perm_ptr = 0 + transpose_perm_ptr = 0 if self.deterministic: transpose_perm_d = DeviceBuffer(graph.transpose_perm) transpose_perm_ptr = transpose_perm_d.data_ptr() self.internal.backward_rawptrs( - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), L3_d.data_ptr(), - rows_d.data_ptr(), cols_d.data_ptr(), - graph.nnz, graph.node_count, + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, self.workspace_ptr, - transpose_perm_ptr) + transpose_perm_ptr, + ) L1_grad_d.copy_to_host() L2_grad_d.copy_to_host() @@ -181,54 +215,72 @@ def backward_cpu(self, if self.reorder_weights_oeq_to_e3nn is not None: weights_grad_copy = weights_grad.copy() - self.reorder_weights_oeq_to_e3nn(weights_grad_copy, weights_grad, not self.config.shared_weights) + self.reorder_weights_oeq_to_e3nn( + weights_grad_copy, weights_grad, not self.config.shared_weights + ) return L1_grad, L2_grad, weights_grad - def test_correctness_forward(self, - graph, thresh, prng_seed, reference_implementation=None, - check_reproducible=True, high_precision_ref=False): - + def test_correctness_forward( + self, + graph, + thresh, + prng_seed, + reference_implementation=None, + check_reproducible=True, + high_precision_ref=False, + ): if reference_implementation is None: from openequivariance.implementations.convolution.E3NNConv import E3NNConv + reference_implementation = E3NNConv result = {"thresh": thresh} - in1, in2, weights, out = get_random_buffers_forward_conv(self.config, - graph.node_count, graph.nnz, prng_seed) - ref_in1, ref_in2, ref_weights, ref_out = [buf.copy() for buf in [in1, in2, weights, out]] + in1, in2, weights, out = get_random_buffers_forward_conv( + self.config, graph.node_count, graph.nnz, prng_seed + ) + ref_in1, ref_in2, ref_weights, ref_out = [ + buf.copy() for buf in [in1, in2, weights, out] + ] - reference_config = self.config + reference_config = self.config if high_precision_ref: reference_config = copy.deepcopy(self.config) reference_config.irrep_dtype = np.float64 - reference_config.weight_dtype = np.float64 - ref_in1, ref_in2, ref_weights, ref_out = [np.array(el, dtype=np.float64) - for el in [ref_in1, ref_in2, ref_weights, ref_out]] - - args = {"L1_in": ref_in1, "L2_in": ref_in2, "weights": ref_weights, - "rows": graph.rows, "cols": graph.cols} + reference_config.weight_dtype = np.float64 + ref_in1, ref_in2, ref_weights, ref_out = [ + np.array(el, dtype=np.float64) + for el in [ref_in1, ref_in2, ref_weights, ref_out] + ] + + args = { + "L1_in": ref_in1, + "L2_in": ref_in2, + "weights": ref_weights, + "rows": graph.rows, + "cols": graph.cols, + } ref_tp = reference_implementation(reference_config) if ref_tp.deterministic: args["transpose_perm"] = graph.transpose_perm for key in args: - args[key] = torch.tensor(args[key], device='cuda') + args[key] = torch.tensor(args[key], device="cuda") ref_out[:] = ref_tp.forward(**args).cpu().numpy() test_out = out.copy() self.forward_cpu( - L1_in=in1.copy(), + L1_in=in1.copy(), L2_in=in2.copy(), weights=weights.copy(), L3_out=test_out, - graph=graph) + graph=graph, + ) - for name, to_check, ground_truth in [ - ("output", ref_out, test_out)]: + for name, to_check, ground_truth in [("output", ref_out, test_out)]: result[name] = check_similiarity(name, to_check, ground_truth, thresh) if check_reproducible: @@ -240,121 +292,170 @@ def test_correctness_forward(self, for i in range(num_trials): repeated_run = out.copy() self.forward_cpu( - L1_in=in1.copy(), + L1_in=in1.copy(), L2_in=in2.copy(), weights=weights.copy(), L3_out=repeated_run, - graph=graph) + graph=graph, + ) for name, to_check, ground_truth in [ - ("output", repeated_run, test_out)]: - result[name]["bitwise_reproducible"] = bool(result[name]["bitwise_reproducible"] - and np.all(repeated_run == test_out)) + ("output", repeated_run, test_out) + ]: + result[name]["bitwise_reproducible"] = bool( + result[name]["bitwise_reproducible"] + and np.all(repeated_run == test_out) + ) return result def benchmark_forward(self, num_warmup, num_iter, graph, prng_seed=12345): direction = "forward" - L1_in, L2_in, weights, L3_buffer = get_random_buffers_forward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + L1_in, L2_in, weights, L3_buffer = get_random_buffers_forward_conv( + self.config, graph.node_count, graph.nnz, prng_seed + ) - assert(graph.rows.dtype == self.idx_dtype) - assert(graph.cols.dtype == self.idx_dtype) + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype time_millis = np.zeros(num_iter, dtype=np.float32) timer = GPUTimer() if self.torch_op: - torch_L1_in = torch.tensor(L1_in, device='cuda') - torch_L2_in = torch.tensor(L2_in, device='cuda') - torch_weights = torch.tensor(weights, device='cuda') + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") + torch_weights = torch.tensor(weights, device="cuda") - torch_rows = torch.tensor(graph.rows, device='cuda') - torch_cols = torch.tensor(graph.cols, device='cuda') - torch_transpose_perm = torch.tensor(graph.transpose_perm, device='cuda') + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") + torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") if not self.deterministic: - for i in range(num_warmup): - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights, torch_rows, torch_cols) + for i in range(num_warmup): + self.forward( + torch_L1_in, torch_L2_in, torch_weights, torch_rows, torch_cols + ) for i in range(num_iter): timer.clear_L2_cache() timer.start() - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights, torch_rows, torch_cols) + self.forward( + torch_L1_in, torch_L2_in, torch_weights, torch_rows, torch_cols + ) time_millis[i] = timer.stop_clock_get_elapsed() else: - for i in range(num_warmup): - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights, torch_rows, - torch_cols, torch_transpose_perm) - + for i in range(num_warmup): + self.forward( + torch_L1_in, + torch_L2_in, + torch_weights, + torch_rows, + torch_cols, + torch_transpose_perm, + ) + for i in range(num_iter): timer.clear_L2_cache() timer.start() - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights, torch_rows, - torch_cols, torch_transpose_perm) + self.forward( + torch_L1_in, + torch_L2_in, + torch_weights, + torch_rows, + torch_cols, + torch_transpose_perm, + ) time_millis[i] = timer.stop_clock_get_elapsed() elif not self.torch_op: - L1_d, L2_d, weights_d = DeviceBuffer(L1_in), DeviceBuffer(L2_in), DeviceBuffer(weights) + L1_d, L2_d, weights_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(weights), + ) L3_d = DeviceBuffer(L3_buffer) rows_d = DeviceBuffer(graph.rows) cols_d = DeviceBuffer(graph.cols) - transpose_perm_d = None - transpose_perm_ptr = 0 - if self.deterministic: - transpose_perm_d = DeviceBuffer(graph.transpose_perm) - transpose_perm_ptr = transpose_perm_d.data_ptr() - for i in range(num_warmup): self.internal.exec_conv_rawptrs( - L1_d.data_ptr(), L2_d.data_ptr(), weights_d.data_ptr(), L3_d.data_ptr(), - rows_d.data_ptr(), cols_d.data_ptr(), graph.nnz, graph.node_count, - self.workspace_ptr) + L1_d.data_ptr(), + L2_d.data_ptr(), + weights_d.data_ptr(), + L3_d.data_ptr(), + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, + self.workspace_ptr, + ) for i in range(num_iter): timer.clear_L2_cache() timer.start() self.internal.exec_conv_rawptrs( - L1_d.data_ptr(), L2_d.data_ptr(), weights_d.data_ptr(), L3_d.data_ptr(), - rows_d.data_ptr(), cols_d.data_ptr(), graph.nnz, graph.node_count, - self.workspace_ptr) - time_millis[i] = timer.stop_clock_get_elapsed() + L1_d.data_ptr(), + L2_d.data_ptr(), + weights_d.data_ptr(), + L3_d.data_ptr(), + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, + self.workspace_ptr, + ) + time_millis[i] = timer.stop_clock_get_elapsed() ops_per_tp, data_per_tp, _ = flops_data_per_tp(self.config, direction) - ops_per_tp += self.config.irreps_out.dim - - return self.calculate_bench_stats(direction, ops_per_tp, data_per_tp, - time_millis, graph, num_warmup, num_iter, prng_seed) + ops_per_tp += self.config.irreps_out.dim + return self.calculate_bench_stats( + direction, + ops_per_tp, + data_per_tp, + time_millis, + graph, + num_warmup, + num_iter, + prng_seed, + ) def benchmark_backward(self, num_warmup, num_iter, graph, prng_seed=12345): direction = "backward" - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = ( + get_random_buffers_backward_conv( + self.config, graph.node_count, graph.nnz, prng_seed + ) + ) - assert(graph.rows.dtype == self.idx_dtype) - assert(graph.cols.dtype == self.idx_dtype) + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype time_millis = np.zeros(num_iter, dtype=np.float32) timer = GPUTimer() if self.torch_op: - torch_L1_in = torch.tensor(in1, device='cuda', requires_grad=True) - torch_L2_in = torch.tensor(in2, device='cuda', requires_grad=True) - torch_weights = torch.tensor(weights, device='cuda', requires_grad=True) + torch_L1_in = torch.tensor(in1, device="cuda", requires_grad=True) + torch_L2_in = torch.tensor(in2, device="cuda", requires_grad=True) + torch_weights = torch.tensor(weights, device="cuda", requires_grad=True) - torch_rows = torch.tensor(graph.rows, device='cuda').detach() - torch_cols = torch.tensor(graph.cols, device='cuda').detach() - torch_transpose_perm = torch.tensor(graph.transpose_perm, device='cuda') + torch_rows = torch.tensor(graph.rows, device="cuda").detach() + torch_cols = torch.tensor(graph.cols, device="cuda").detach() + torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") fwd_args = [torch_L1_in, torch_L2_in, torch_weights, torch_rows, torch_cols] if self.deterministic: fwd_args.append(torch_transpose_perm) torch_out = self.forward(*fwd_args) - torch_L3_grad = torch.tensor(out_grad, device='cuda') + torch_L3_grad = torch.tensor(out_grad, device="cuda") - for i in range(num_warmup): - torch_out.backward(torch_L3_grad, retain_graph=True, inputs=[torch_L1_in, torch_L2_in, torch_weights]) + for i in range(num_warmup): + torch_out.backward( + torch_L3_grad, + retain_graph=True, + inputs=[torch_L1_in, torch_L2_in, torch_weights], + ) for i in range(num_iter): torch_L1_in.grad.zero_() @@ -363,7 +464,11 @@ def benchmark_backward(self, num_warmup, num_iter, graph, prng_seed=12345): timer.clear_L2_cache() timer.start() - torch_out.backward(torch_L3_grad, retain_graph=True, inputs=[torch_L1_in, torch_L2_in, torch_weights]) + torch_out.backward( + torch_L3_grad, + retain_graph=True, + inputs=[torch_L1_in, torch_L2_in, torch_weights], + ) time_millis[i] = timer.stop_clock_get_elapsed() elif not self.torch_op: @@ -373,112 +478,163 @@ def benchmark_backward(self, num_warmup, num_iter, graph, prng_seed=12345): L3_d = DeviceBuffer(out_grad) rows_d = DeviceBuffer(graph.rows) cols_d = DeviceBuffer(graph.cols) - + L1_grad_d = DeviceBuffer(in1_grad) L2_grad_d = DeviceBuffer(in2_grad) weights_grad_d = DeviceBuffer(weights_grad) transpose_perm_d = None - transpose_perm_ptr = 0 + transpose_perm_ptr = 0 if self.deterministic: transpose_perm_d = DeviceBuffer(graph.transpose_perm) transpose_perm_ptr = transpose_perm_d.data_ptr() for i in range(num_warmup): self.internal.backward_rawptrs( - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), L3_d.data_ptr(), - rows_d.data_ptr(), cols_d.data_ptr(), - graph.nnz, graph.node_count, + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, self.workspace_ptr, - transpose_perm_ptr) + transpose_perm_ptr, + ) for i in range(num_iter): timer.clear_L2_cache() timer.start() self.internal.backward_rawptrs( - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), L3_d.data_ptr(), - rows_d.data_ptr(), cols_d.data_ptr(), - graph.nnz, graph.node_count, + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, self.workspace_ptr, - transpose_perm_ptr) - time_millis[i] = timer.stop_clock_get_elapsed() + transpose_perm_ptr, + ) + time_millis[i] = timer.stop_clock_get_elapsed() ops_per_tp, data_per_tp, _ = flops_data_per_tp(self.config, direction) ops_per_tp += self.config.irreps_out.dim - return self.calculate_bench_stats(direction, ops_per_tp, data_per_tp, - time_millis, graph, num_warmup, num_iter, prng_seed) - - def calculate_bench_stats(self, direction, ops_per_tp, data_per_tp, time_millis, - graph, num_warmup, num_iter, prng_seed): - throughputs_gflops = [float(el) for el in graph.nnz * ops_per_tp / (time_millis * 1e6)] - bandwidth_gbps = [float(el) for el in graph.nnz * data_per_tp / (time_millis * 1e6)] - time_millis = [float(el) for el in time_millis] + return self.calculate_bench_stats( + direction, + ops_per_tp, + data_per_tp, + time_millis, + graph, + num_warmup, + num_iter, + prng_seed, + ) + + def calculate_bench_stats( + self, + direction, + ops_per_tp, + data_per_tp, + time_millis, + graph, + num_warmup, + num_iter, + prng_seed, + ): + throughputs_gflops = [ + float(el) for el in graph.nnz * ops_per_tp / (time_millis * 1e6) + ] + bandwidth_gbps = [ + float(el) for el in graph.nnz * data_per_tp / (time_millis * 1e6) + ] + time_millis = [float(el) for el in time_millis] result = { "direction": direction, "flops_per_tp": ops_per_tp, "data_per_tp": data_per_tp, - "time_millis": list(time_millis), "throughputs_gflops": list(throughputs_gflops), "bandwidth_gbps": list(bandwidth_gbps), - "L1": str(self.config.irreps_in1), - "L2": str(self.config.irreps_in2), + "L2": str(self.config.irreps_in2), "L3": str(self.config.irreps_out), "graph_node_count": graph.node_count, "graph_adj_nnz": graph.nnz, "num_warmup": num_warmup, "num_iter": num_iter, "prng_seed": prng_seed, - "time_millis": time_millis, - "throughputs_gflops": throughputs_gflops, - "bandwidth_gbps": bandwidth_gbps } - logger.info(f"{bcolors.OKCYAN}Avg. Throughput: {bcolors.ENDC} {bcolors.OKGREEN}{np.mean(throughputs_gflops):.2f} ± {np.std(throughputs_gflops):.2f} GFLOPs{bcolors.ENDC}") - logger.info(f"{bcolors.OKCYAN}Avg. Bandwidth: {bcolors.ENDC} {bcolors.OKGREEN}{np.mean(bandwidth_gbps):.2f} ± {np.std(bandwidth_gbps):.2f} GBPs{bcolors.ENDC}") + logger.info( + f"{bcolors.OKCYAN}Avg. Throughput: {bcolors.ENDC} {bcolors.OKGREEN}{np.mean(throughputs_gflops):.2f} ± {np.std(throughputs_gflops):.2f} GFLOPs{bcolors.ENDC}" + ) + logger.info( + f"{bcolors.OKCYAN}Avg. Bandwidth: {bcolors.ENDC} {bcolors.OKGREEN}{np.mean(bandwidth_gbps):.2f} ± {np.std(bandwidth_gbps):.2f} GBPs{bcolors.ENDC}" + ) return result - def test_correctness_backward(self, graph, thresh, prng_seed, reference_implementation=None, high_precision_ref=False): - L1, L2, L3 = self.L1, self.L2, self.L3 - + def test_correctness_backward( + self, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, + ): if reference_implementation is None: from openequivariance.implementations.convolution.E3NNConv import E3NNConv + reference_implementation = E3NNConv result = {"thresh": thresh} - buffers = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + buffers = get_random_buffers_backward_conv( + self.config, graph.node_count, graph.nnz, prng_seed + ) reference_buffers = [buf.copy() for buf in buffers] reference_problem = self.config if high_precision_ref: reference_problem = copy.deepcopy(self.config) reference_problem.irrep_dtype = np.float64 - reference_problem.weight_dtype = np.float64 - reference_buffers = [np.array(el, dtype=np.float64) for el in reference_buffers] + reference_problem.weight_dtype = np.float64 + reference_buffers = [ + np.array(el, dtype=np.float64) for el in reference_buffers + ] ref_tp = reference_implementation(reference_problem) in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = buffers - ref_in1, ref_in2, ref_out_grad, ref_weights, ref_weights_grad, ref_in1_grad, ref_in2_grad = reference_buffers + ( + ref_in1, + ref_in2, + ref_out_grad, + ref_weights, + ref_weights_grad, + ref_in1_grad, + ref_in2_grad, + ) = reference_buffers ref_tp.backward_cpu( L1_in=ref_in1, L1_grad=ref_in1_grad, L2_in=ref_in2, - L2_grad=ref_in2_grad, + L2_grad=ref_in2_grad, L3_grad=ref_out_grad, weights=ref_weights, weights_grad=ref_weights_grad, - graph=graph) + graph=graph, + ) # run test version test_weights_grad = weights_grad.copy() @@ -488,34 +644,46 @@ def test_correctness_backward(self, graph, thresh, prng_seed, reference_implemen self.backward_cpu( L1_in=in1.copy(), L1_grad=test_in1_grad, - L2_in=in2.copy(), - L2_grad=test_in2_grad, + L2_in=in2.copy(), + L2_grad=test_in2_grad, L3_grad=out_grad.copy(), weights=weights.copy(), weights_grad=test_weights_grad, - graph=graph) + graph=graph, + ) for name, to_check, ground_truth, threshold in [ - ("weight_grad", test_weights_grad, ref_weights_grad, thresh), - ("in1_grad", test_in1_grad, ref_in1_grad, thresh), - ("in2_grad", test_in2_grad, ref_in2_grad, thresh)]: + ("weight_grad", test_weights_grad, ref_weights_grad, thresh), + ("in1_grad", test_in1_grad, ref_in1_grad, thresh), + ("in2_grad", test_in2_grad, ref_in2_grad, thresh), + ]: result[name] = check_similiarity(name, to_check, ground_truth, threshold) return result - def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_implementation=None, high_precision_ref=False): + def test_correctness_double_backward( + self, + graph, + thresh, + prng_seed, + reference_implementation=None, + high_precision_ref=False, + ): global torch import torch - assert(self.torch_op) - buffers = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + assert self.torch_op + buffers = get_random_buffers_backward_conv( + self.config, graph.node_count, graph.nnz, prng_seed + ) rng = np.random.default_rng(seed=prng_seed * 2) dummy_grad_value = rng.standard_normal(1)[0] if reference_implementation is None: - from openequivariance.implementations.convolution.E3NNConv import E3NNConv - reference_implementation = E3NNConv + from openequivariance.implementations.convolution.E3NNConv import E3NNConv + + reference_implementation = E3NNConv reference_problem = self.config if high_precision_ref: @@ -531,126 +699,186 @@ def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_i in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers] if i == 1 and high_precision_ref: - in1, in2, out_grad, weights, _, _, _ = [np.array(el, dtype=np.float64) for el in buffers] - - in1_torch = torch.tensor(in1, device='cuda', requires_grad=True) - in2_torch = torch.tensor(in2, device='cuda', requires_grad=True) + in1, in2, out_grad, weights, _, _, _ = [ + np.array(el, dtype=np.float64) for el in buffers + ] + + in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) + in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) - weights_reordered = np.zeros_like(weights) + weights_reordered = np.zeros_like(weights) if i == 0 and self.reorder_weights_e3nn_to_oeq is not None: - self.reorder_weights_e3nn_to_oeq(weights, weights_reordered, not self.config.shared_weights) + self.reorder_weights_e3nn_to_oeq( + weights, weights_reordered, not self.config.shared_weights + ) else: weights_reordered[:] = weights - weights_torch = torch.tensor(weights_reordered, device='cuda', requires_grad=True) + weights_torch = torch.tensor( + weights_reordered, device="cuda", requires_grad=True + ) - torch_rows = torch.tensor(graph.rows, device='cuda') - torch_cols = torch.tensor(graph.cols, device='cuda') - torch_transpose_perm = torch.tensor(graph.transpose_perm, device='cuda') + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") + torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") fwd_args = [in1_torch, in2_torch, weights_torch, torch_rows, torch_cols] if tp.deterministic: fwd_args.append(torch_transpose_perm) out_torch = tp.forward(*fwd_args) - out_grad_torch = torch.tensor(out_grad, device='cuda', requires_grad=True) + out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True) in1_grad, in2_grad, w_grad = torch.autograd.grad( outputs=[out_torch], inputs=[in1_torch, in2_torch, weights_torch], grad_outputs=[out_grad_torch], - create_graph=True) + create_graph=True, + ) dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) - dummy_grad = torch.tensor(float(dummy_grad_value), device='cuda', requires_grad=True) - dummy.backward(dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch]) - + dummy_grad = torch.tensor( + float(dummy_grad_value), device="cuda", requires_grad=True + ) + dummy.backward( + dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch] + ) + weights_grad = weights_torch.grad.detach().cpu().numpy() if i == 0 and self.reorder_weights_oeq_to_e3nn is not None: weights_grad_copy = weights_grad.copy() - self.reorder_weights_oeq_to_e3nn(weights_grad_copy, weights_grad, not self.config.shared_weights) - - tensors.append(( - out_grad_torch.grad.detach().cpu().numpy().copy(), - in1_torch.grad.detach().cpu().numpy().copy(), - in2_torch.grad.detach().cpu().numpy().copy(), - weights_grad.copy() - )) + self.reorder_weights_oeq_to_e3nn( + weights_grad_copy, weights_grad, not self.config.shared_weights + ) + + tensors.append( + ( + out_grad_torch.grad.detach().cpu().numpy().copy(), + in1_torch.grad.detach().cpu().numpy().copy(), + in2_torch.grad.detach().cpu().numpy().copy(), + weights_grad.copy(), + ) + ) for name, to_check, ground_truth in [ ("output_grad", tensors[0][0], tensors[1][0]), ("in1_grad", tensors[0][1], tensors[1][1]), ("in2_grad", tensors[0][2], tensors[1][2]), - ("weights_grad", tensors[0][3], tensors[1][3]) - ]: + ("weights_grad", tensors[0][3], tensors[1][3]), + ]: result[name] = check_similiarity(name, to_check, ground_truth, thresh) return result - def setup_torch_module(self): if not extlib.TORCH_COMPILE: self.setup_nocompile_ops() def setup_nocompile_ops(self): - ''' + """ Need two different functions depending on whether the convolution is deterministic. - ''' + """ if not self.deterministic: - @torch.library.custom_op(f"openequivariance::conv_forward{self.conv_id}", mutates_args=(), device_types="cuda") - def forward(L1_in : torch.Tensor, L2_in : torch.Tensor, - weights : torch.Tensor, rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = L1_in.contiguous(), L2_in.contiguous(), weights.contiguous() - L3_out = torch.zeros((L1_in_c.shape[0], self.L3.dim ), dtype=L1_in.dtype, device='cuda') - self.internal.exec_conv_rawptrs(L1_in_c.data_ptr(), L2_in_c.data_ptr(), - weights_c.data_ptr(), L3_out.data_ptr(), - rows.data_ptr(), cols.data_ptr(), - cols.shape[0], L1_in.shape[0], self.workspace_ptr) + @torch.library.custom_op( + f"openequivariance::conv_forward{self.conv_id}", + mutates_args=(), + device_types="cuda", + ) + def forward( + L1_in: torch.Tensor, + L2_in: torch.Tensor, + weights: torch.Tensor, + rows: torch.Tensor, + cols: torch.Tensor, + ) -> torch.Tensor: + L1_in_c, L2_in_c, weights_c = ( + L1_in.contiguous(), + L2_in.contiguous(), + weights.contiguous(), + ) + L3_out = torch.zeros( + (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device="cuda" + ) + + self.internal.exec_conv_rawptrs( + L1_in_c.data_ptr(), + L2_in_c.data_ptr(), + weights_c.data_ptr(), + L3_out.data_ptr(), + rows.data_ptr(), + cols.data_ptr(), + cols.shape[0], + L1_in.shape[0], + self.workspace_ptr, + ) return L3_out - + @forward.register_fake def _(L1_in, L2_in, weights, rows, cols): return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - + self.forward = forward - - @torch.library.custom_op(f"openequivariance::conv_backward{self.conv_id}", mutates_args=(), device_types="cuda") - def backward_helper( L1_in : torch.Tensor, L2_in : torch.Tensor, - weights : torch.Tensor, L3_grad : torch.Tensor, - rows: torch.Tensor, cols: torch.Tensor) -> typing.List[torch.Tensor]: + + @torch.library.custom_op( + f"openequivariance::conv_backward{self.conv_id}", + mutates_args=(), + device_types="cuda", + ) + def backward_helper( + L1_in: torch.Tensor, + L2_in: torch.Tensor, + weights: torch.Tensor, + L3_grad: torch.Tensor, + rows: torch.Tensor, + cols: torch.Tensor, + ) -> typing.List[torch.Tensor]: L1_grad = torch.zeros_like(L1_in) L2_grad = torch.empty_like(L2_in) weights_grad = torch.empty_like(weights) self.internal.backward_rawptrs( - L1_in.contiguous().data_ptr(), L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), L2_grad.data_ptr(), - weights.contiguous().data_ptr(), weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr(), - rows.data_ptr(), cols.data_ptr(), - rows.shape[0], L1_in.shape[0], - self.workspace_ptr, - 0) - + L1_in.contiguous().data_ptr(), + L1_grad.data_ptr(), + L2_in.contiguous().data_ptr(), + L2_grad.data_ptr(), + weights.contiguous().data_ptr(), + weights_grad.data_ptr(), + L3_grad.contiguous().data_ptr(), + rows.data_ptr(), + cols.data_ptr(), + rows.shape[0], + L1_in.shape[0], + self.workspace_ptr, + 0, + ) + return [L1_grad, L2_grad, weights_grad] - + @backward_helper.register_fake def _(L1_in, L2_in, weights, L3_grad, rows, cols): - return [L1_in.new_empty(*L1_in.shape), L2_in.new_empty(*L2_in.shape), weights.new_empty(*weights.shape)] + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + weights.new_empty(*weights.shape), + ] def setup_context(ctx, inputs, output): ctx.L1_in, ctx.L2_in, ctx.weights, ctx.rows, ctx.cols = inputs - + def backward(ctx, grad_output): - result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output, ctx.rows, ctx.cols) + result = backward_helper( + ctx.L1_in, ctx.L2_in, ctx.weights, grad_output, ctx.rows, ctx.cols + ) return result[0], result[1], result[2], None, None self.forward.register_autograd(backward, setup_context=setup_context) def setup_context_double_backward(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, ctx.rows, ctx.cols = inputs + ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, ctx.rows, ctx.cols = ( + inputs + ) def double_backward(ctx, grad_output): A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights @@ -660,73 +888,151 @@ def double_backward(ctx, grad_output): op1 = backward_helper(E, F, D, C, rows, cols) op2 = backward_helper(A, B, G, C, rows, cols) op3 = forward(E, B, D, rows, cols) - op4 = backward_helper(E, B, D, C, rows, cols) # op4 and op5 could be combined with op3 and op6 - op5 = backward_helper(A, F, D, C, rows, cols) + op4 = backward_helper( + E, B, D, C, rows, cols + ) # op4 and op5 could be combined with op3 and op6 + op5 = backward_helper(A, F, D, C, rows, cols) op6 = forward(A, F, D, rows, cols) op7 = forward(A, B, G, rows, cols) - return op1[0] + op2[0], op1[1] + op2[1], (op4[2] + op5[2]), (op3 + op6 + op7), None, None + return ( + op1[0] + op2[0], + op1[1] + op2[1], + (op4[2] + op5[2]), + (op3 + op6 + op7), + None, + None, + ) - backward_helper.register_autograd(double_backward, setup_context=setup_context_double_backward) + backward_helper.register_autograd( + double_backward, setup_context=setup_context_double_backward + ) else: - @torch.library.custom_op(f"openequivariance::conv_forward{self.conv_id}", mutates_args=(), device_types="cuda") - def forward(L1_in : torch.Tensor, L2_in : torch.Tensor, - weights : torch.Tensor, rows: torch.Tensor, cols: torch.Tensor, transpose_perm: torch.Tensor) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = L1_in.contiguous(), L2_in.contiguous(), weights.contiguous() - L3_out = torch.zeros((L1_in_c.shape[0], self.L3.dim ), dtype=L1_in.dtype, device='cuda') - self.internal.exec_conv_rawptrs(L1_in_c.data_ptr(), L2_in_c.data_ptr(), - weights_c.data_ptr(), L3_out.data_ptr(), - rows.data_ptr(), cols.data_ptr(), - rows.shape[0], L1_in.shape[0], self.workspace_ptr) + @torch.library.custom_op( + f"openequivariance::conv_forward{self.conv_id}", + mutates_args=(), + device_types="cuda", + ) + def forward( + L1_in: torch.Tensor, + L2_in: torch.Tensor, + weights: torch.Tensor, + rows: torch.Tensor, + cols: torch.Tensor, + transpose_perm: torch.Tensor, + ) -> torch.Tensor: + L1_in_c, L2_in_c, weights_c = ( + L1_in.contiguous(), + L2_in.contiguous(), + weights.contiguous(), + ) + L3_out = torch.zeros( + (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device="cuda" + ) + + self.internal.exec_conv_rawptrs( + L1_in_c.data_ptr(), + L2_in_c.data_ptr(), + weights_c.data_ptr(), + L3_out.data_ptr(), + rows.data_ptr(), + cols.data_ptr(), + rows.shape[0], + L1_in.shape[0], + self.workspace_ptr, + ) return L3_out - + @forward.register_fake def _(L1_in, L2_in, weights, rows, cols, transpose_perm): return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - + self.forward = forward - - @torch.library.custom_op(f"openequivariance::conv_backward{self.conv_id}", mutates_args=(), device_types="cuda") - def backward_helper( L1_in : torch.Tensor, L2_in : torch.Tensor, - weights : torch.Tensor, L3_grad : torch.Tensor, - rows: torch.Tensor, cols: torch.Tensor, transpose_perm: torch.Tensor) -> typing.List[torch.Tensor]: + + @torch.library.custom_op( + f"openequivariance::conv_backward{self.conv_id}", + mutates_args=(), + device_types="cuda", + ) + def backward_helper( + L1_in: torch.Tensor, + L2_in: torch.Tensor, + weights: torch.Tensor, + L3_grad: torch.Tensor, + rows: torch.Tensor, + cols: torch.Tensor, + transpose_perm: torch.Tensor, + ) -> typing.List[torch.Tensor]: L1_grad = torch.zeros_like(L1_in) L2_grad = torch.empty_like(L2_in) weights_grad = torch.empty_like(weights) if self.config.shared_weights: - weights_grad[:] = 0.0 + weights_grad[:] = 0.0 self.internal.backward_rawptrs( - L1_in.contiguous().data_ptr(), L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), L2_grad.data_ptr(), - weights.contiguous().data_ptr(), weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr(), - rows.data_ptr(), cols.data_ptr(), - rows.shape[0], L1_in.shape[0], - self.workspace_ptr, - transpose_perm.data_ptr()) - + L1_in.contiguous().data_ptr(), + L1_grad.data_ptr(), + L2_in.contiguous().data_ptr(), + L2_grad.data_ptr(), + weights.contiguous().data_ptr(), + weights_grad.data_ptr(), + L3_grad.contiguous().data_ptr(), + rows.data_ptr(), + cols.data_ptr(), + rows.shape[0], + L1_in.shape[0], + self.workspace_ptr, + transpose_perm.data_ptr(), + ) + return [L1_grad, L2_grad, weights_grad] - + @backward_helper.register_fake def _(L1_in, L2_in, weights, L3_grad, rows, cols, transpose_perm): - return [L1_in.new_empty(*L1_in.shape), L2_in.new_empty(*L2_in.shape), weights.new_empty(*weights.shape)] + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + weights.new_empty(*weights.shape), + ] def setup_context(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights, ctx.rows, ctx.cols, ctx.transpose_perm = inputs - + ( + ctx.L1_in, + ctx.L2_in, + ctx.weights, + ctx.rows, + ctx.cols, + ctx.transpose_perm, + ) = inputs + def backward(ctx, grad_output): - result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output, ctx.rows, ctx.cols, ctx.transpose_perm) + result = backward_helper( + ctx.L1_in, + ctx.L2_in, + ctx.weights, + grad_output, + ctx.rows, + ctx.cols, + ctx.transpose_perm, + ) return result[0], result[1], result[2], None, None, None self.forward.register_autograd(backward, setup_context=setup_context) def setup_context_double_backward(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, ctx.rows, ctx.cols, ctx.transpose_perm = inputs + ( + ctx.L1_in, + ctx.L2_in, + ctx.weights, + ctx.L3_grad, + ctx.rows, + ctx.cols, + ctx.transpose_perm, + ) = inputs def double_backward(ctx, grad_output): A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights @@ -737,10 +1043,20 @@ def double_backward(ctx, grad_output): op2 = backward_helper(A, B, G, C, rows, cols, transpose_perm) op3 = forward(E, B, D, rows, cols, transpose_perm) op4 = backward_helper(E, B, D, C, rows, cols, transpose_perm) - op5 = backward_helper(A, F, D, C, rows, cols, transpose_perm) + op5 = backward_helper(A, F, D, C, rows, cols, transpose_perm) op6 = forward(A, F, D, rows, cols, transpose_perm) op7 = forward(A, B, G, rows, cols, transpose_perm) - return op1[0] + op2[0], op1[1] + op2[1], (op4[2] + op5[2]), (op3 + op6 + op7), None, None, None - - backward_helper.register_autograd(double_backward, setup_context=setup_context_double_backward) \ No newline at end of file + return ( + op1[0] + op2[0], + op1[1] + op2[1], + (op4[2] + op5[2]), + (op3 + op6 + op7), + None, + None, + None, + ) + + backward_helper.register_autograd( + double_backward, setup_context=setup_context_double_backward + ) From bec7c161e17a0da409c5cb6c09917b47d5232fde Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Sat, 24 May 2025 21:45:45 -0700 Subject: [PATCH 03/15] readme tutorial formatted. Can exclude if preferred --- examples/readme_tutorial.py | 70 ++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 28 deletions(-) diff --git a/examples/readme_tutorial.py b/examples/readme_tutorial.py index 16d3a626..57a88dcf 100644 --- a/examples/readme_tutorial.py +++ b/examples/readme_tutorial.py @@ -1,26 +1,29 @@ -# Examples from the README +# ruff: noqa: E402 +# Examples from the README import logging from openequivariance.benchmark.logging_utils import getLogger + logger = getLogger() logger.setLevel(logging.ERROR) -# UVU Tensor Product +# UVU Tensor Product # =============================== import torch import e3nn.o3 as o3 -gen = torch.Generator(device='cuda') +gen = torch.Generator(device="cuda") batch_size = 1000 -X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") -X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen) -Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen) +X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") +X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) +Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) -instructions=[(0, 0, 0, "uvu", True)] +instructions = [(0, 0, 0, "uvu", True)] -tp_e3nn = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, - shared_weights=False, internal_weights=False).to('cuda') -W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen) +tp_e3nn = o3.TensorProduct( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False +).to("cuda") +W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen) Z = tp_e3nn(X, Y, W) print(torch.norm(Z)) @@ -29,10 +32,12 @@ # =============================== import openequivariance as oeq -problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) +problem = oeq.TPProblem( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False +) tp_fast = oeq.TensorProduct(problem, torch_op=True) -Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier +Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier print(torch.norm(Z)) # =============================== @@ -44,26 +49,35 @@ # Receiver, sender indices for message passing GNN edge_index = EdgeIndex( - [[0, 1, 1, 2], # Receiver - [1, 0, 2, 1]], # Sender - device='cuda', - dtype=torch.long) - -X = torch.rand(node_ct, X_ir.dim, device='cuda', generator=gen) -Y = torch.rand(nonzero_ct, Y_ir.dim, device='cuda', generator=gen) -W = torch.rand(nonzero_ct, problem.weight_numel, device='cuda', generator=gen) - -tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=False) # Reuse problem from earlier -Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) # Z has shape [node_ct, z_ir.dim] + [ + [0, 1, 1, 2], # Receiver + [1, 0, 2, 1], # Sender + ], + device="cuda", + dtype=torch.long, +) + +X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) +Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) +W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) + +tp_conv = oeq.TensorProductConv( + problem, torch_op=True, deterministic=False +) # Reuse problem from earlier +Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] +) # Z has shape [node_ct, z_ir.dim] print(torch.norm(Z)) # =============================== # =============================== -_, sender_perm = edge_index.sort_by("col") # Sort by sender index -edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index +_, sender_perm = edge_index.sort_by("col") # Sort by sender index +edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index # Now we can use the faster deterministic algorithm -tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) -Z = tp_conv.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm) +tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) +Z = tp_conv.forward( + X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm +) print(torch.norm(Z)) -# =============================== \ No newline at end of file +# =============================== From 7b8da85fee12c1c5ba17593a9862e037677c03fb Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Mon, 26 May 2025 14:20:06 -0700 Subject: [PATCH 04/15] ruff format and lint --- io/cif_to_graph.py | 38 +- io/load_nequip_configs.py | 24 +- openequivariance/__init__.py | 20 +- .../benchmark/ConvBenchmarkSuite.py | 149 +++-- .../benchmark/TestBenchmarkSuite.py | 206 +++---- .../benchmark/benchmark_configs.py | 159 +++-- .../benchmark_routines/paper_benchmark_uvw.py | 79 ++- openequivariance/benchmark/benchmark_utils.py | 424 +++++++------ .../benchmark/correctness_utils.py | 255 ++++---- openequivariance/benchmark/logging_utils.py | 23 +- .../benchmark/perf_metrics_utils.py | 72 ++- .../benchmark/plotting/__init__.py | 13 +- .../benchmark/plotting/plot_convolution.py | 93 ++- .../plotting/plot_double_backward.py | 107 +++- .../benchmark/plotting/plot_roofline.py | 103 +++- .../benchmark/plotting/plot_uvu.py | 128 +++- .../benchmark/plotting/plot_uvw.py | 142 +++-- .../benchmark/plotting/plotting_utils.py | 351 +++++++---- .../benchmark/random_buffer_utils.py | 149 +++-- .../benchmark/tpp_creation_utils.py | 49 +- openequivariance/extlib/__init__.py | 100 +++- .../implementations/CUETensorProduct.py | 281 +++++---- .../implementations/ComputationSchedule.py | 494 +++++++++++----- .../implementations/E3NNTensorProduct.py | 159 ++--- .../implementations/LoopUnrollTP.py | 295 +++++---- .../MultiplicityOuterProductTP.py | 149 +++-- .../implementations/TensorProduct.py | 15 +- .../implementations/TensorProductBase.py | 558 +++++++++++------- .../implementations/convolution/CUEConv.py | 44 +- .../implementations/convolution/E3NNConv.py | 71 ++- .../convolution/LoopUnrollConv.py | 342 ++++++++--- .../convolution/TensorProductConv.py | 134 +++-- .../implementations/convolution/scatter.py | 11 +- openequivariance/implementations/e3nn_lite.py | 186 ++++-- .../symmetric_contraction/STPOpt.py | 208 ++++--- openequivariance/implementations/utils.py | 90 +-- openequivariance/templates/jinja_utils.py | 22 +- tests/batch_test.py | 227 ++++--- tests/benchmark.py | 500 +++++++++++----- tests/conv_test.py | 202 ++++--- tests/export_test.py | 118 ++-- tests/import_test.py | 67 ++- tests/mace_driver.py | 195 +++--- 43 files changed, 4546 insertions(+), 2506 deletions(-) diff --git a/io/cif_to_graph.py b/io/cif_to_graph.py index 69cc2980..cb90b5dd 100644 --- a/io/cif_to_graph.py +++ b/io/cif_to_graph.py @@ -1,42 +1,44 @@ import pickle import numpy as np from sklearn.neighbors import radius_neighbors_graph -from scipy.io import mmwrite + def cif_to_molecular_graph(cif_file, cp, radii): - with open(f'../data/cif_files/{cif_file}', 'r') as f: + with open(f"../data/cif_files/{cif_file}", "r") as f: print("Started reading file...") lines = f.readlines() print("Finished reading file!") coords = [] for line in lines: - if line.startswith('ATOM'): + if line.startswith("ATOM"): parts = line.split() - coords.append([float(parts[cp[0]]), float(parts[cp[1]]), float(parts[cp[2]])]) + coords.append( + [float(parts[cp[0]]), float(parts[cp[1]]), float(parts[cp[2]])] + ) coords = np.array(coords) for radius in radii: print(f"Starting radius neighbors calculation, r={radius}") - A = radius_neighbors_graph(coords, radius, mode='connectivity', - include_self=False) - print(f"Finished radius neighbors calculation, found {A.nnz} nonzeros.") - + A = radius_neighbors_graph( + coords, radius, mode="connectivity", include_self=False + ) + print(f"Finished radius neighbors calculation, found {A.nnz} nonzeros.") + # mmwrite(f'../data/molecular_structures/{cif_file.split(".")[0]}.mtx', A) coo_mat = A.tocoo() - result = { - 'row': coo_mat.row, - 'col': coo_mat.col, - 'coords': coords - } + result = {"row": coo_mat.row, "col": coo_mat.col, "coords": coords} - with open(f'../data/molecular_structures/{cif_file.split(".")[0]}_radius{radius}.pickle', 'wb') as handle: + with open( + f"../data/molecular_structures/{cif_file.split('.')[0]}_radius{radius}.pickle", + "wb", + ) as handle: pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL) -if __name__=='__main__': - #cif_to_molecular_graph('hiv_capsid.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5]) - cif_to_molecular_graph('covid_spike.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5]) - cif_to_molecular_graph('1drf.cif', (10, 11, 12), radii=[6.0]) \ No newline at end of file +if __name__ == "__main__": + # cif_to_molecular_graph('hiv_capsid.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5]) + cif_to_molecular_graph("covid_spike.cif", (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5]) + cif_to_molecular_graph("1drf.cif", (10, 11, 12), radii=[6.0]) diff --git a/io/load_nequip_configs.py b/io/load_nequip_configs.py index 5e2ef50b..cd2fb4ad 100644 --- a/io/load_nequip_configs.py +++ b/io/load_nequip_configs.py @@ -1,27 +1,29 @@ -''' +""" This script parse the repository of Nequip input files at https://github.com/mir-group/nequip-input-files. We extract the node / edge hidden features representations. -''' +""" + +import os +import yaml -import os, yaml def process_nequip_configs(): nequip_files = [] - for root, dirs, files in os.walk('../data/nequip-input-files'): + for root, dirs, files in os.walk("../data/nequip-input-files"): for file in files: - if file.endswith('.yaml'): + if file.endswith(".yaml"): nequip_files.append(os.path.join(root, file)) - + irrep_pairs = [] configs = [] for file in nequip_files: - with open(file, 'r') as f: + with open(file, "r") as f: data = yaml.unsafe_load(f) filename = os.path.splitext(os.path.basename(file))[0] - feature_irreps_hidden = data['feature_irreps_hidden'] - irreps_edge_sh = data['irreps_edge_sh'] + feature_irreps_hidden = data["feature_irreps_hidden"] + irreps_edge_sh = data["irreps_edge_sh"] if (feature_irreps_hidden, irreps_edge_sh) not in irrep_pairs: irrep_pairs.append((feature_irreps_hidden, irreps_edge_sh)) configs.append((feature_irreps_hidden, irreps_edge_sh, filename)) @@ -30,5 +32,5 @@ def process_nequip_configs(): print(config) -if __name__ == '__main__': - process_nequip_configs() \ No newline at end of file +if __name__ == "__main__": + process_nequip_configs() diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 20120352..21488786 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -1,18 +1,32 @@ +# ruff: noqa: F401 import openequivariance.extlib from pathlib import Path from importlib.metadata import version from openequivariance.implementations.e3nn_lite import TPProblem, Irreps -from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.implementations.convolution.TensorProductConv import TensorProductConv +from openequivariance.implementations.TensorProduct import TensorProduct +from openequivariance.implementations.convolution.TensorProductConv import ( + TensorProductConv, +) from openequivariance.implementations.utils import torch_to_oeq_dtype +__all__ = [ + "TPProblem", + "Irreps", + "TensorProduct", + "TensorProductConv", + "torch_to_oeq_dtype", +] + __version__ = version("openequivariance") + def _check_package_editable(): import json from importlib.metadata import Distribution + direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json") return json.loads(direct_url).get("dir_info", {}).get("editable", False) -_editable_install_output_path = Path(__file__).parent.parent / "outputs" \ No newline at end of file + +_editable_install_output_path = Path(__file__).parent.parent / "outputs" diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/benchmark/ConvBenchmarkSuite.py index 35b17600..a4b7c982 100644 --- a/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -1,105 +1,132 @@ -import json, os, time, pickle, pathlib -import numpy as np -import numpy.linalg as la +import json import os +import time +import pickle +import pathlib +import numpy as np import openequivariance as oeq -from openequivariance.benchmark.logging_utils import * -from openequivariance.implementations.convolution.ConvolutionBase import * +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.implementations.convolution.ConvolutionBase import CoordGraph + logger = getLogger() + def load_graph(filename): coords, rows, cols = [None] * 3 name = pathlib.Path(filename).stem - with open(filename, 'rb') as f: + with open(filename, "rb") as f: logger.info(f"Loading {name} from pickle...") result = pickle.load(f) coords, rows, cols, name = result["coords"], result["row"], result["col"], name - logger.info(f"Graph {name} loaded with {len(coords)} nodes and {len(rows)} edges.") + logger.info( + f"Graph {name} loaded with {len(coords)} nodes and {len(rows)} edges." + ) return CoordGraph(coords, rows.astype(np.int64), cols.astype(np.int64), name) + class ConvBenchmarkSuite: - def __init__(self, configs, - num_warmup = 10, - num_iter = 30, - reference_impl=None, - test_name=None, - prng_seed = 12345, - correctness_threshold = 1e-5): + def __init__( + self, + configs, + num_warmup=10, + num_iter=30, + reference_impl=None, + test_name=None, + prng_seed=12345, + correctness_threshold=1e-5, + ): self.configs = configs self.num_warmup = num_warmup self.num_iter = num_iter self.reference_impl = reference_impl self.prng_seed = 12345 - self.correctness_threshold = correctness_threshold + self.correctness_threshold = correctness_threshold self.exp_count = 0 - self.test_name = test_name + self.test_name = test_name self.millis_since_epoch = round(time.time() * 1000) - def run(self, graph, implementations, direction, output_folder=None, - correctness=True, benchmark=True, high_precision_ref=False): + def run( + self, + graph, + implementations, + direction, + output_folder=None, + correctness=True, + benchmark=True, + high_precision_ref=False, + ): if output_folder is None: if oeq._check_package_editable(): - output_folder = oeq._editable_install_output_path / f"{self.millis_since_epoch}" + output_folder = ( + oeq._editable_install_output_path / f"{self.millis_since_epoch}" + ) else: - raise ValueError("output folder must be specified for non-editable installs.") + raise ValueError( + "output folder must be specified for non-editable installs." + ) else: - output_folder = pathlib.Path(output_folder) + output_folder = pathlib.Path(output_folder) output_folder.mkdir(parents=True, exist_ok=True) metadata = { - "test_name": self.test_name, - "configs": [str(config) for config in self.configs], + "test_name": self.test_name, + "configs": [str(config) for config in self.configs], "implementations": [impl.name() for impl in implementations], - "graph": graph.name + "graph": graph.name, } if self.exp_count == 0: - with open(os.path.join(output_folder,'metadata.json'), 'w') as f: - json.dump(metadata, f, indent=2) - - for config in self.configs: - L1_in, L2_in, weights, L3_out = get_random_buffers_forward_conv(config, graph.node_count, graph.nnz, self.prng_seed) + with open(os.path.join(output_folder, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2) + for config in self.configs: for impl in implementations: tc_name = f"{config}, {impl.name()}" - logger.info(f'Starting {tc_name}, graph {graph.name}, {direction}') - conv = impl(config) + logger.info(f"Starting {tc_name}, graph {graph.name}, {direction}") + conv = impl(config) if direction == "forward": if correctness: - correctness = conv.test_correctness_forward(graph, - thresh=self.correctness_threshold, - prng_seed=self.prng_seed, - reference_implementation=self.reference_impl, - high_precision_ref=high_precision_ref) + correctness = conv.test_correctness_forward( + graph, + thresh=self.correctness_threshold, + prng_seed=self.prng_seed, + reference_implementation=self.reference_impl, + high_precision_ref=high_precision_ref, + ) if benchmark: - benchmark = conv.benchmark_forward(self.num_warmup, - self.num_iter, graph, prng_seed=12345) - + benchmark = conv.benchmark_forward( + self.num_warmup, self.num_iter, graph, prng_seed=12345 + ) if direction == "backward": if correctness: - correctness = conv.test_correctness_backward(graph, - thresh=self.correctness_threshold, - prng_seed=self.prng_seed, - reference_implementation=self.reference_impl, - high_precision_ref=high_precision_ref) + correctness = conv.test_correctness_backward( + graph, + thresh=self.correctness_threshold, + prng_seed=self.prng_seed, + reference_implementation=self.reference_impl, + high_precision_ref=high_precision_ref, + ) if benchmark: - benchmark = conv.benchmark_backward(self.num_warmup, - self.num_iter, graph, prng_seed=12345) - - if direction == "double_backward": + benchmark = conv.benchmark_backward( + self.num_warmup, self.num_iter, graph, prng_seed=12345 + ) + + if direction == "double_backward": if correctness: - correctness = conv.test_correctness_double_backward(self.graph, - thresh=self.correctness_threshold, - prng_seed=self.prng_seed, - reference_implementation=self.reference_impl, - high_precision_ref=high_precision_ref) - + correctness = conv.test_correctness_double_backward( + self.graph, + thresh=self.correctness_threshold, + prng_seed=self.prng_seed, + reference_implementation=self.reference_impl, + high_precision_ref=high_precision_ref, + ) + assert not benchmark result = { @@ -111,14 +138,16 @@ def run(self, graph, implementations, direction, output_folder=None, "graph": graph.name, "name": impl.name(), "correctness": correctness, - "benchmark": benchmark + "benchmark": benchmark, } - - fname = pathlib.Path(f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json") - with open(fname, 'w') as f: + + fname = pathlib.Path( + f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json" + ) + with open(fname, "w") as f: json.dump(result, f, indent=2) self.exp_count += 1 - logger.info(f'Finished {tc_name}, graph {graph.name}') - + logger.info(f"Finished {tc_name}, graph {graph.name}") + return output_folder diff --git a/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/benchmark/TestBenchmarkSuite.py index 7b01079e..f28153f0 100644 --- a/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/benchmark/TestBenchmarkSuite.py @@ -11,52 +11,53 @@ from openequivariance.implementations.TensorProductBase import TensorProductBase from openequivariance.benchmark.logging_utils import getLogger, bcolors -from openequivariance.extlib import * from openequivariance.implementations.e3nn_lite import TPProblem from openequivariance.benchmark.correctness_utils import ( - correctness_forward, + correctness_forward, correctness_backward, - correctness_double_backward - ) + correctness_double_backward, +) from openequivariance.benchmark.benchmark_utils import ( - benchmark_forward, + benchmark_forward, benchmark_backward, - benchmark_double_backward - ) + benchmark_double_backward, +) logger = getLogger() -Direction = Literal['forward', 'backward', 'double_backward'] +Direction = Literal["forward", "backward", "double_backward"] + class TestDefinition(NamedTuple): - implementation : type[TensorProductBase] - problem : TPProblem - direction : Direction - correctness : bool = True - benchmark : bool = True + implementation: type[TensorProductBase] + problem: TPProblem + direction: Direction + correctness: bool = True + benchmark: bool = True + @dataclass(init=True, repr=False, eq=False) class TestBenchmarkSuite: - num_warmup : int = 10 - num_iter : int = 30 - correctness_batch_size : int = 10_000 - bench_batch_size : int = 10_000_000 - prng_seed : int = 12345 - reference_implementation : Optional[type[TensorProductBase]] = None - correctness_threshold_forward : float = 5e-7 - correctness_threshold_backward : float = 1e-4 - correctness_threshold_double_backward : float = 1e-4 - torch_op : bool = True + num_warmup: int = 10 + num_iter: int = 30 + correctness_batch_size: int = 10_000 + bench_batch_size: int = 10_000_000 + prng_seed: int = 12345 + reference_implementation: Optional[type[TensorProductBase]] = None + correctness_threshold_forward: float = 5e-7 + correctness_threshold_backward: float = 1e-4 + correctness_threshold_double_backward: float = 1e-4 + torch_op: bool = True test_name: Optional[str] = None metadata: Optional[dict] = None results: Optional[list] = None @staticmethod - def validate_inputs(test_list : list[TestDefinition]) -> None: + def validate_inputs(test_list: list[TestDefinition]) -> None: """ - Just does empty list and type checking to catch bad input + Just does empty list and type checking to catch bad input """ - assert isinstance(test_list, list) + assert isinstance(test_list, list) assert len(test_list) != 0 for test in test_list: assert isinstance(test, TestDefinition) @@ -67,51 +68,57 @@ def validate_inputs(test_list : list[TestDefinition]) -> None: assert isinstance(test.benchmark, bool) @staticmethod - def generate_metadata(test_list : list[TestDefinition]) -> dict[str, Any]: + def generate_metadata(test_list: list[TestDefinition]) -> dict[str, Any]: impls, tpps, directions, corectnesses, benchmarks = zip(*test_list) config_strs = list(dict.fromkeys([str(tpp) for tpp in tpps])) config_reprs = list(dict.fromkeys([repr(tpp) for tpp in tpps])) config_labels = list(dict.fromkeys([tpp.label for tpp in tpps])) - implementation_names = list(dict.fromkeys([impl.name() for impl in impls])) + implementation_names = list(dict.fromkeys([impl.name() for impl in impls])) directions = list(dict.fromkeys(directions)) did_correctness = any(corectnesses) did_benchmark = any(benchmarks) - + dp = DeviceProp(0) metadata = { - 'config_strs' : config_strs, - 'config_reprs': config_reprs, - 'config_labels' : config_labels, - 'implementations' : implementation_names, - 'directions' : directions, - 'did_correctness' : did_correctness, - 'did_benchmark' : did_benchmark, - 'gpu_name' : dp.name, - } - + "config_strs": config_strs, + "config_reprs": config_reprs, + "config_labels": config_labels, + "implementations": implementation_names, + "directions": directions, + "did_correctness": did_correctness, + "did_benchmark": did_benchmark, + "gpu_name": dp.name, + } + test_details = {} for test_ID, test in enumerate(test_list): test_details[test_ID] = { - 'implementation' : test.implementation.name(), - 'problem' : repr(test.problem), - 'direction' : test.direction, - 'correctness' : test.correctness, - 'benchmark' : test.benchmark, - } - - metadata['test details'] = test_details + "implementation": test.implementation.name(), + "problem": repr(test.problem), + "direction": test.direction, + "correctness": test.correctness, + "benchmark": test.benchmark, + } + + metadata["test details"] = test_details return metadata - def run(self, test_list : list[TestDefinition], output_folder=None, progressbar=False) -> pathlib.Path: + def run( + self, test_list: list[TestDefinition], output_folder=None, progressbar=False + ) -> pathlib.Path: self.results = [] millis_since_epoch = round(time.time() * 1000) if output_folder is None: if oeq._check_package_editable(): - output_folder = oeq._editable_install_output_path / f"{millis_since_epoch}" + output_folder = ( + oeq._editable_install_output_path / f"{millis_since_epoch}" + ) else: - raise ValueError("output folder must be specified for non-editable installs.") + raise ValueError( + "output folder must be specified for non-editable installs." + ) else: output_folder = pathlib.Path(output_folder) @@ -119,46 +126,48 @@ def run(self, test_list : list[TestDefinition], output_folder=None, progressbar= output_folder.mkdir(parents=True) metadata = TestBenchmarkSuite.generate_metadata(test_list) - metadata["test_name"] = self.test_name + metadata["test_name"] = self.test_name - with open(os.path.join(output_folder,'metadata.json'), 'w') as f: - json.dump(metadata, f, indent=2) + with open(os.path.join(output_folder, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2) target_iterable = enumerate(test_list) if progressbar: import tqdm - target_iterable = tqdm.tqdm(target_iterable, desc=self.test_name, - total=len(test_list)) - for test_ID, test in target_iterable: + target_iterable = tqdm.tqdm( + target_iterable, desc=self.test_name, total=len(test_list) + ) + + for test_ID, test in target_iterable: impl = test.implementation tpp = test.problem - logger.info(f'Starting Test ID: {test_ID}') - logger.info(f'Config: {str(tpp)}') - logger.info(f'Irrep dtype: {tpp.irrep_dtype.__name__}') - logger.info(f'Weight dtype: {tpp.weight_dtype.__name__}') - if(tpp.label): - logger.info(f'{bcolors.OKCYAN}{tpp.label}{bcolors.ENDC}') - logger.info(f'Implementation Name: {impl.name()}') - logger.info(f'Test Direction: {test.direction}') + logger.info(f"Starting Test ID: {test_ID}") + logger.info(f"Config: {str(tpp)}") + logger.info(f"Irrep dtype: {tpp.irrep_dtype.__name__}") + logger.info(f"Weight dtype: {tpp.weight_dtype.__name__}") + if tpp.label: + logger.info(f"{bcolors.OKCYAN}{tpp.label}{bcolors.ENDC}") + logger.info(f"Implementation Name: {impl.name()}") + logger.info(f"Test Direction: {test.direction}") logger.info(f"Torch Overhead Included: {self.torch_op}") result = { - "config_str" : str(tpp), - "config_repr" : repr(tpp), - "config_label" : tpp.label, - "direction" : test.direction, + "config_str": str(tpp), + "config_repr": repr(tpp), + "config_label": tpp.label, + "direction": test.direction, "implementation_name": impl.name(), "correctness": str(test.correctness), "benchmark": str(test.benchmark), - "torch_overhead_included": self.torch_op + "torch_overhead_included": self.torch_op, } - if test.direction == 'forward': + if test.direction == "forward": if test.correctness: logger.info("Starting correctness check...") - result['correctness results'] = correctness_forward( + result["correctness results"] = correctness_forward( problem=tpp, test_implementation=impl, reference_implementation=self.reference_implementation, @@ -168,73 +177,72 @@ def run(self, test_list : list[TestDefinition], output_folder=None, progressbar= ) logger.info("Finished correctness check...") if test.benchmark: - result['benchmark results'] = benchmark_forward( + result["benchmark results"] = benchmark_forward( problem=tpp, implementation=impl, batch_size=self.bench_batch_size, num_warmup=self.num_warmup, num_iter=self.num_iter, prng_seed=self.prng_seed, - torch_op=self.torch_op + torch_op=self.torch_op, ) - - if test.direction == 'backward': - if test.correctness: + if test.direction == "backward": + if test.correctness: logger.info("Starting correctness check...") - result ['correctness results'] = correctness_backward( + result["correctness results"] = correctness_backward( problem=tpp, test_implementation=impl, reference_implementation=self.reference_implementation, batch_size=self.correctness_batch_size, correctness_threshold=self.correctness_threshold_backward, - prng_seed=self.prng_seed + prng_seed=self.prng_seed, ) logger.info("Finished correctness check...") - if test.benchmark: - result ['benchmark results'] = benchmark_backward( + if test.benchmark: + result["benchmark results"] = benchmark_backward( problem=tpp, implementation=impl, batch_size=self.bench_batch_size, num_warmup=self.num_warmup, num_iter=self.num_iter, prng_seed=self.prng_seed, - torch_op=self.torch_op + torch_op=self.torch_op, ) - - if test.direction == 'double_backward': + + if test.direction == "double_backward": if test.correctness: logger.info("Starting correctness check...") - result ['correctness results'] = correctness_double_backward( + result["correctness results"] = correctness_double_backward( problem=tpp, test_implementation=impl, reference_implementation=self.reference_implementation, batch_size=self.correctness_batch_size, - correctness_threshold= self.correctness_threshold_double_backward, - prng_seed=self.prng_seed + correctness_threshold=self.correctness_threshold_double_backward, + prng_seed=self.prng_seed, ) logger.info("Finished correctness check...") - if test.benchmark: - result ['benchmark results'] = benchmark_double_backward( - problem=tpp, - implementation=impl, - batch_size=self.bench_batch_size, - num_warmup=self.num_warmup, + if test.benchmark: + result["benchmark results"] = benchmark_double_backward( + problem=tpp, + implementation=impl, + batch_size=self.bench_batch_size, + num_warmup=self.num_warmup, num_iter=self.num_iter, - prng_seed=self.prng_seed, - torch_op=self.torch_op, + prng_seed=self.prng_seed, + torch_op=self.torch_op, ) - + fname = pathlib.Path(f"{output_folder}/{test_ID}_{impl.name()}.json") - pretty_result = json.dumps(obj=result, indent=2).replace('\\n', '\n') + pretty_result = json.dumps(obj=result, indent=2).replace("\\n", "\n") logger.debug(pretty_result) - with open(fname, 'w') as f: + with open(fname, "w") as f: json.dump(result, f, indent=2) self.results.append(result) - logger.info(f'Finished Test ID: {test_ID}') + logger.info(f"Finished Test ID: {test_ID}") self.metadata = metadata - return output_folder \ No newline at end of file + return output_folder diff --git a/openequivariance/benchmark/benchmark_configs.py b/openequivariance/benchmark/benchmark_configs.py index 009bc1e2..9fec8f63 100644 --- a/openequivariance/benchmark/benchmark_configs.py +++ b/openequivariance/benchmark/benchmark_configs.py @@ -1,72 +1,141 @@ -from openequivariance.benchmark.tpp_creation_utils import FullyConnectedTPProblem as FCTPP +from openequivariance.benchmark.tpp_creation_utils import ( + FullyConnectedTPProblem as FCTPP, +) from openequivariance.benchmark.tpp_creation_utils import ChannelwiseTPP as CTPP -import numpy as np # source: https://github.com/e3nn/e3nn/blob/main/examples/tetris.py -# running tetris will output the layers. I've only extracted the fully connected layers here. +# running tetris will output the layers. I've only extracted the fully connected layers here. e3nn_torch_tetris = [ - # 0th Layer - FCTPP("1x0e", "1x0e", "150x0e + 50x1o + 50x2e"), #sc - FCTPP("1x0e", "1x0e", "1x0e"), #lin1 - FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "150x0e + 50x1o + 50x2e"), #lin2 - FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "1x0e"), #alpha - + # 0th Layer + FCTPP("1x0e", "1x0e", "150x0e + 50x1o + 50x2e"), # sc + FCTPP("1x0e", "1x0e", "1x0e"), # lin1 + FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "150x0e + 50x1o + 50x2e"), # lin2 + FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "1x0e"), # alpha # 1st Layer - FCTPP("50x0e + 50x1o + 50x2e", "1x0e", "250x0e + 50x1o + 50x1e + 50x2o + 50x2e"), #sc - FCTPP("50x0e + 50x1o + 50x2e", "1x0e", "50x0e + 50x1o + 50x2e"), #lin1 + FCTPP( + "50x0e + 50x1o + 50x2e", "1x0e", "250x0e + 50x1o + 50x1e + 50x2o + 50x2e" + ), # sc + FCTPP("50x0e + 50x1o + 50x2e", "1x0e", "50x0e + 50x1o + 50x2e"), # lin1 # FCTPP("50x0e + 50x1o + 50x2e", "1x0e + 1x1o + 1x2e", "150x0e + 200x1o + 100x1e + 100x2o + 200x2e"), #tp - FCTPP("150x0e + 200x1o + 100x1e + 100x2o + 200x2e", "1x0e", "250x0e + 50x1o + 50x1e + 50x2o + 50x2e"), #lin2 - FCTPP("150x0e + 200x1o + 100x1e + 100x2o + 200x2e", "1x0e", "1x0e"), #alpha - - # 2nd Layer - FCTPP("50x0e + 50x1o + 50x1e + 50x2o + 50x2e", "1x0e", "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e"), #sc - FCTPP("50x0e + 50x1o + 50x1e + 50x2o + 50x2e", "1x0e", "50x0e + 50x1o + 50x1e + 50x2o + 50x2e"), #lin1 - FCTPP("100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e", "1x0e", "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e"), #lin2 - FCTPP("100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e", "1x0e", "1x0e"), #alpha - - # 3rd Layer - FCTPP("50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", "1x0e", "1x0o + 6x0e"), #sc - FCTPP("50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", "1x0e", "50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e"), #lin1 - FCTPP("150x0o + 150x0e", "1x0e", "1x0o + 6x0e"), #lin2 - FCTPP("150x0o + 150x0e", "1x0e", "1x0e"), #alpha -] + FCTPP( + "150x0e + 200x1o + 100x1e + 100x2o + 200x2e", + "1x0e", + "250x0e + 50x1o + 50x1e + 50x2o + 50x2e", + ), # lin2 + FCTPP("150x0e + 200x1o + 100x1e + 100x2o + 200x2e", "1x0e", "1x0e"), # alpha + # 2nd Layer + FCTPP( + "50x0e + 50x1o + 50x1e + 50x2o + 50x2e", + "1x0e", + "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e", + ), # sc + FCTPP( + "50x0e + 50x1o + 50x1e + 50x2o + 50x2e", + "1x0e", + "50x0e + 50x1o + 50x1e + 50x2o + 50x2e", + ), # lin1 + FCTPP( + "100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e", + "1x0e", + "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e", + ), # lin2 + FCTPP( + "100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e", "1x0e", "1x0e" + ), # alpha + # 3rd Layer + FCTPP("50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", "1x0e", "1x0o + 6x0e"), # sc + FCTPP( + "50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", + "1x0e", + "50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", + ), # lin1 + FCTPP("150x0o + 150x0e", "1x0e", "1x0o + 6x0e"), # lin2 + FCTPP("150x0o + 150x0e", "1x0e", "1x0e"), # alpha +] -# jax version can be found here, but doesn't directly translate +# jax version can be found here, but doesn't directly translate # https://github.com/e3nn/e3nn-jax/blob/main/examples/tetris_point.py # source: https://github.com/e3nn/e3nn/blob/f95297952303347a8a3cfe971efe449c710c43b2/examples/tetris_polynomial.py#L66-L68 e3nn_torch_tetris_polynomial = [ - FCTPP("1x0e + 1x1o + 1x2e + 1x3o", "1x0e + 1x1o + 1x2e + 1x3o", "64x0e + 24x1e + 24x1o + 16x2e + 16x2o", label="tetris-poly-1"), #tp1 - FCTPP("64x0e + 24x1e + 24x1o + 16x2e + 16x2o", "1x0e + 1x1o + 1x2e", "0o + 6x0e", label="tetris-poly-2"), #tp2 + FCTPP( + "1x0e + 1x1o + 1x2e + 1x3o", + "1x0e + 1x1o + 1x2e + 1x3o", + "64x0e + 24x1e + 24x1o + 16x2e + 16x2o", + label="tetris-poly-1", + ), # tp1 + FCTPP( + "64x0e + 24x1e + 24x1o + 16x2e + 16x2o", + "1x0e + 1x1o + 1x2e", + "0o + 6x0e", + label="tetris-poly-2", + ), # tp2 ] # https://github.com/gcorso/DiffDock/blob/b4704d94de74d8cb2acbe7ec84ad234c09e78009/models/tensor_layers.py#L299 -# specific irreps come from vivek's communication with DiffDock team +# specific irreps come from Vivek's communication with DiffDock team diffdock_configs = [ - FCTPP("10x1o + 10x1e + 48x0e + 48x0o", "1x0e + 1x1o", "10x1o + 10x1e + 48x0e + 48x0o", shared_weights=False, label='DiffDock-L=1'), - FCTPP("10x1o + 10x1e + 48x0e + 48x0o", "1x0e + 1x1o + 1x2e", "10x1o + 10x1e + 48x0e + 48x0o", shared_weights=False, label='DiffDock-L=2'), + FCTPP( + "10x1o + 10x1e + 48x0e + 48x0o", + "1x0e + 1x1o", + "10x1o + 10x1e + 48x0e + 48x0o", + shared_weights=False, + label="DiffDock-L=1", + ), + FCTPP( + "10x1o + 10x1e + 48x0e + 48x0o", + "1x0e + 1x1o + 1x2e", + "10x1o + 10x1e + 48x0e + 48x0o", + shared_weights=False, + label="DiffDock-L=2", + ), ] mace_conv = [ - ("128x0e+128x1o+128x2e", "1x0e+1x1o+1x2e+1x3o", "128x0e+128x1o+128x2e+128x3o", "mace-large"), - ("128x0e+128x1o", "1x0e+1x1o+1x2e+1x3o", "128x0e+128x1o+128x2e", "mace-medium") + ( + "128x0e+128x1o+128x2e", + "1x0e+1x1o+1x2e+1x3o", + "128x0e+128x1o+128x2e+128x3o", + "mace-large", + ), + ("128x0e+128x1o", "1x0e+1x1o+1x2e+1x3o", "128x0e+128x1o+128x2e", "mace-medium"), ] nequip_conv = [ - ('32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e', '0e + 1o + 2e', '32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e', - 'nequip-lips'), - ('64x0o + 64x0e + 64x1o + 64x1e', '0e + 1o', '64x0o + 64x0e + 64x1o + 64x1e', - 'nequip-revmd17-aspirin'), - ('64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e', '0e + 1o + 2e', '64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e', - 'nequip-revmd17-toluene'), - ('64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e', '0e + 1o + 2e + 3o', '64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e', - 'nequip-revmd17-benzene'), - ('32x0o + 32x0e + 32x1o + 32x1e', '0e + 1o', '32x0o + 32x0e + 32x1o + 32x1e', - 'nequip-water'), + ( + "32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e", + "0e + 1o + 2e", + "32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e", + "nequip-lips", + ), + ( + "64x0o + 64x0e + 64x1o + 64x1e", + "0e + 1o", + "64x0o + 64x0e + 64x1o + 64x1e", + "nequip-revmd17-aspirin", + ), + ( + "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e", + "0e + 1o + 2e", + "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e", + "nequip-revmd17-toluene", + ), + ( + "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e", + "0e + 1o + 2e + 3o", + "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e", + "nequip-revmd17-benzene", + ), + ( + "32x0o + 32x0e + 32x1o + 32x1e", + "0e + 1o", + "32x0o + 32x0e + 32x1o + 32x1e", + "nequip-water", + ), ] mace_problems = [CTPP(*config) for config in mace_conv] mace_nequip_problems = [] for config in mace_conv + nequip_conv: - mace_nequip_problems.append(CTPP(*config)) \ No newline at end of file + mace_nequip_problems.append(CTPP(*config)) diff --git a/openequivariance/benchmark/benchmark_routines/paper_benchmark_uvw.py b/openequivariance/benchmark/benchmark_routines/paper_benchmark_uvw.py index 55397dd9..18d88ca3 100644 --- a/openequivariance/benchmark/benchmark_routines/paper_benchmark_uvw.py +++ b/openequivariance/benchmark/benchmark_routines/paper_benchmark_uvw.py @@ -1,58 +1,75 @@ -import itertools, sys, os, logging, copy, pathlib +import itertools +import logging +import copy +import pathlib +from typing import List import numpy as np +from torch._functorch import config + from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct, E3NNTensorProductCompiledCUDAGraphs +from openequivariance.implementations.E3NNTensorProduct import ( + E3NNTensorProductCompiledCUDAGraphs, +) from openequivariance.implementations.CUETensorProduct import CUETensorProduct from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.benchmark.TestBenchmarkSuite import TestBenchmarkSuite, TestDefinition, Direction -from openequivariance.benchmark.tpp_creation_utils import FullyConnectedTPProblem -from openequivariance.benchmark.benchmark_configs import e3nn_torch_tetris_polynomial, diffdock_configs +from openequivariance.benchmark.TestBenchmarkSuite import ( + TestBenchmarkSuite, + TestDefinition, +) +from openequivariance.benchmark.benchmark_configs import ( + e3nn_torch_tetris_polynomial, + diffdock_configs, +) logger = getLogger() -import torch -from torch._functorch import config + @config.patch("donated_buffer", False) def run_paper_uvw_benchmark(params) -> pathlib.Path: - FCTPP = FullyConnectedTPProblem - - problems = list(itertools.chain( - e3nn_torch_tetris_polynomial, - diffdock_configs - )) + problems = list(itertools.chain(e3nn_torch_tetris_polynomial, diffdock_configs)) float64_problems = copy.deepcopy(problems) - for problem in float64_problems: + for problem in float64_problems: problem.irrep_dtype = np.float64 problem.weight_dtype = np.float64 - + problems += float64_problems - implementations = [ + implementations: List[TensorProduct] = [ E3NNTensorProductCompiledCUDAGraphs, CUETensorProduct, - TensorProduct] + TensorProduct, + ] - tests = [TestDefinition(implementation, problem, direction, correctness=False, benchmark=True) - for problem, direction, implementation - in itertools.product(problems, params.directions, implementations)] + tests = [ + TestDefinition( + implementation, problem, direction, correctness=False, benchmark=True + ) + for problem, direction, implementation in itertools.product( + problems, params.directions, implementations + ) + ] bench_suite = TestBenchmarkSuite( - num_warmup=100, - num_iter=100, - bench_batch_size=params.batch_size, - prng_seed=11111, - torch_op=True, - test_name="uvw" - ) - + num_warmup=100, + num_iter=100, + bench_batch_size=params.batch_size, + prng_seed=11111, + torch_op=True, + test_name="uvw", + ) + logger.setLevel(logging.INFO) data_folder = bench_suite.run(tests, output_folder=params.output_folder) if params.plot: import openequivariance.benchmark.plotting as plotting - plotting.plot_uvw(data_folder) -if __name__ == '__main__': - run_paper_uvw_benchmark() \ No newline at end of file + plotting.plot_uvw(data_folder) + + return data_folder + + +if __name__ == "__main__": + run_paper_uvw_benchmark() diff --git a/openequivariance/benchmark/benchmark_utils.py b/openequivariance/benchmark/benchmark_utils.py index 4cba4c73..a243762a 100644 --- a/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/benchmark/benchmark_utils.py @@ -1,75 +1,80 @@ import numpy as np from openequivariance.benchmark.random_buffer_utils import ( - get_random_buffers_forward, + get_random_buffers_forward, get_random_buffers_backward, get_random_buffers_double_backward, - ) +) from openequivariance.benchmark.perf_metrics_utils import ( - calculate_minimum_flops_forward, - calculate_minimum_memory_streamed_forward, + calculate_minimum_flops_forward, + calculate_minimum_memory_streamed_forward, calculate_minimum_memory_streamed_backward, - ) +) from openequivariance.implementations.utils import calculate_total_nnz from openequivariance.implementations.TensorProductBase import TensorProductBase from openequivariance.implementations.e3nn_lite import TPProblem from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.benchmark.logging_utils import getLogger, bcolors logger = getLogger() -def calculate_performance_statistics( - problem : TPProblem, - batch_size : int, - total_flops : int, - total_memory_streamed : int, - time_millis : np.ndarray, - ) -> dict: - result = {} - - throughputs_gflops = [float(x) for x in total_flops / (time_millis * 1e6)] - bandwidth_gbps = [float(x) for x in total_memory_streamed / (time_millis * 1e6)] - nnz = calculate_total_nnz(problem) - time_millis = [float(x) for x in time_millis] - - result |= { - "total_cg_nnz": nnz, - "flops_per_tp": total_flops / batch_size, - "L1": str(problem.irreps_in1), - "L2": str(problem.irreps_in2), - "L3": str(problem.irreps_out), - - "L1_rep_len": problem.irreps_in1.dim, - "L2_rep_len": problem.irreps_in2.dim, - "L3_rep_len": problem.irreps_out.dim, - - "rep_dtype": str(problem.irrep_dtype), - "weight_dtype": str(problem.weight_dtype), - "arithmetic_intensity (FLOPs / byte)": total_flops / total_memory_streamed, - - "batch_size":batch_size, - "time_millis": time_millis, - "throughputs_gflops": throughputs_gflops, - "bandwidth_gbps": bandwidth_gbps, - } - - logger.info(f"{bcolors.OKCYAN}Avg. Throughput: {bcolors.ENDC} {bcolors.WARNING}{np.mean(throughputs_gflops):.2f} ± {np.std(throughputs_gflops):.2f} GFLOPs{bcolors.ENDC}") - logger.info(f"{bcolors.OKCYAN}Avg. Bandwidth : {bcolors.ENDC} {bcolors.WARNING}{np.mean(bandwidth_gbps) :.2f} ± {np.std(bandwidth_gbps) :.2f} GBPs{bcolors.ENDC}") - logger.info(f"{bcolors.OKCYAN}Avg. Walltime : {bcolors.ENDC} {bcolors.WARNING}{np.mean(time_millis) :.2f} ± {np.std(time_millis) :.2f} ms{bcolors.ENDC}") - return result + +def calculate_performance_statistics( + problem: TPProblem, + batch_size: int, + total_flops: int, + total_memory_streamed: int, + time_millis: np.ndarray, +) -> dict: + result = {} + + throughputs_gflops = [float(x) for x in total_flops / (time_millis * 1e6)] + bandwidth_gbps = [float(x) for x in total_memory_streamed / (time_millis * 1e6)] + nnz = calculate_total_nnz(problem) + time_millis = [float(x) for x in time_millis] + + result |= { + "total_cg_nnz": nnz, + "flops_per_tp": total_flops / batch_size, + "L1": str(problem.irreps_in1), + "L2": str(problem.irreps_in2), + "L3": str(problem.irreps_out), + "L1_rep_len": problem.irreps_in1.dim, + "L2_rep_len": problem.irreps_in2.dim, + "L3_rep_len": problem.irreps_out.dim, + "rep_dtype": str(problem.irrep_dtype), + "weight_dtype": str(problem.weight_dtype), + "arithmetic_intensity (FLOPs / byte)": total_flops / total_memory_streamed, + "batch_size": batch_size, + "time_millis": time_millis, + "throughputs_gflops": throughputs_gflops, + "bandwidth_gbps": bandwidth_gbps, + } + + logger.info( + f"{bcolors.OKCYAN}Avg. Throughput: {bcolors.ENDC} {bcolors.WARNING}{np.mean(throughputs_gflops):.2f} ± {np.std(throughputs_gflops):.2f} GFLOPs{bcolors.ENDC}" + ) + logger.info( + f"{bcolors.OKCYAN}Avg. Bandwidth : {bcolors.ENDC} {bcolors.WARNING}{np.mean(bandwidth_gbps):.2f} ± {np.std(bandwidth_gbps):.2f} GBPs{bcolors.ENDC}" + ) + logger.info( + f"{bcolors.OKCYAN}Avg. Walltime : {bcolors.ENDC} {bcolors.WARNING}{np.mean(time_millis):.2f} ± {np.std(time_millis):.2f} ms{bcolors.ENDC}" + ) + return result + def benchmark_forward( - problem : TPProblem, - implementation : type[TensorProductBase], - batch_size : int, - num_warmup : int, - num_iter : int, - prng_seed : int, - torch_op: bool, - ) -> dict: - ''' + problem: TPProblem, + implementation: type[TensorProductBase], + batch_size: int, + num_warmup: int, + num_iter: int, + prng_seed: int, + torch_op: bool, +) -> dict: + """ This function sets up the necessary materials and calls the internal benchmarker - ''' + """ result = { "tp_direction": "forward", "num_warmup": num_warmup, @@ -77,186 +82,215 @@ def benchmark_forward( "prng_seed": prng_seed, } - L1_in, L2_in, weights, L3_buffer = get_random_buffers_forward(problem, batch_size, prng_seed) + L1_in, L2_in, weights, L3_buffer = get_random_buffers_forward( + problem, batch_size, prng_seed + ) if problem.shared_weights and implementation == CUETensorProduct: weights = weights[np.newaxis, :] logger.info("Initialized input / output data.") tp = implementation(problem, torch_op=torch_op) - # BENCHMARK - try: + # BENCHMARK + try: time_millis = tp.benchmark_forward( - num_warmup=num_warmup, + num_warmup=num_warmup, num_iter=num_iter, L1_in=L1_in, L2_in=L2_in, weights=weights, - L3_buffer=L3_buffer - ) + L3_buffer=L3_buffer, + ) except NotImplementedError: - logger.warning("Benchmarking is not implemented, time millis replaced with -1's") - time_millis = np.full(shape=num_iter, fill_value=-1) - - # FLOPS + logger.warning( + "Benchmarking is not implemented, time millis replaced with -1's" + ) + time_millis = np.full(shape=num_iter, fill_value=-1) + + # FLOPS try: flops = tp.calculate_flops_forward(batch_size=batch_size) except NotImplementedError: - logger.warning("Actual flop count not calculated, so minimum values are being used") + logger.warning( + "Actual flop count not calculated, so minimum values are being used" + ) flops = calculate_minimum_flops_forward(problem, batch_size=batch_size) - + # DATA - try: + try: memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning("Actual memory streamed not calculated, so minimum values are being used") - memory_streamed = calculate_minimum_memory_streamed_forward(problem, batch_size=batch_size) - + except NotImplementedError: + logger.warning( + "Actual memory streamed not calculated, so minimum values are being used" + ) + memory_streamed = calculate_minimum_memory_streamed_forward( + problem, batch_size=batch_size + ) result |= calculate_performance_statistics( problem=problem, batch_size=batch_size, total_flops=flops["total"], total_memory_streamed=memory_streamed["total"], - time_millis=time_millis - ) + time_millis=time_millis, + ) + + return result - return result def benchmark_backward( - problem : TPProblem, - implementation : type[TensorProductBase], - batch_size : int, - num_warmup : int, - num_iter : int, - prng_seed : int, - torch_op: bool, - ) -> dict: - - result = { - "tp_direction": "backward", - "num_warmup": num_warmup, - "num_iter": num_iter, - "prng_seed": prng_seed, - } - - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = get_random_buffers_backward(problem, batch_size, prng_seed) - if problem.shared_weights and implementation == CUETensorProduct: - weights = weights[np.newaxis, :] - - logger.info("Initialized input / output data.") - tp = implementation(problem, torch_op=torch_op) + problem: TPProblem, + implementation: type[TensorProductBase], + batch_size: int, + num_warmup: int, + num_iter: int, + prng_seed: int, + torch_op: bool, +) -> dict: + result = { + "tp_direction": "backward", + "num_warmup": num_warmup, + "num_iter": num_iter, + "prng_seed": prng_seed, + } - try: - time_millis = tp.benchmark_backward( - num_warmup=num_warmup, - num_iter=num_iter, - L1_in=in1, - L2_in=in2, - L3_buffer=out_grad, - weights=weights, - L1_grad=in1_grad, - L2_grad=in2_grad, - weights_grad=weights_grad - ) - except NotImplementedError: - logger.warning("Benchmarking is not implemented, time millis replaced with -1's") - time_millis = np.full(shape=num_iter, fill_value=-1) + in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = ( + get_random_buffers_backward(problem, batch_size, prng_seed) + ) + if problem.shared_weights and implementation == CUETensorProduct: + weights = weights[np.newaxis, :] + logger.info("Initialized input / output data.") + tp = implementation(problem, torch_op=torch_op) + + try: + time_millis = tp.benchmark_backward( + num_warmup=num_warmup, + num_iter=num_iter, + L1_in=in1, + L2_in=in2, + L3_buffer=out_grad, + weights=weights, + L1_grad=in1_grad, + L2_grad=in2_grad, + weights_grad=weights_grad, + ) + except NotImplementedError: + logger.warning( + "Benchmarking is not implemented, time millis replaced with -1's" + ) + time_millis = np.full(shape=num_iter, fill_value=-1) + + try: + flops = tp.calculate_flops_backward(batch_size=batch_size) + except NotImplementedError: try: - flops = tp.calculate_flops_backward(batch_size=batch_size) + flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) + logger.warning( + "Actual flops was not calculated, so minimum values are being used" + ) except NotImplementedError: - try: - flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) - logger.warning("Actual flops was not calculated, so minimum values are being used") - except NotImplementedError: - logger.warning("Minimum Backwards flops calculations are not implemented, -1 is a placeholder") - flops = {"total" : -1} - - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning("Actual memory streamed was not calculated, so minimum values are being") - memory_streamed = calculate_minimum_memory_streamed_backward(tpp=problem, batch_size=batch_size) - - result |= calculate_performance_statistics( - problem=problem, - batch_size=batch_size, - total_flops=flops["total"], - total_memory_streamed=memory_streamed["total"], - time_millis=time_millis + logger.warning( + "Minimum Backwards flops calculations are not implemented, -1 is a placeholder" ) + flops = {"total": -1} + + try: + memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) + except NotImplementedError: + logger.warning( + "Actual memory streamed was not calculated, so minimum values are being" + ) + memory_streamed = calculate_minimum_memory_streamed_backward( + tpp=problem, batch_size=batch_size + ) + + result |= calculate_performance_statistics( + problem=problem, + batch_size=batch_size, + total_flops=flops["total"], + total_memory_streamed=memory_streamed["total"], + time_millis=time_millis, + ) + + return result - return result def benchmark_double_backward( - problem : TPProblem, - implementation : type[TensorProductBase], - batch_size : int, - num_warmup : int, - num_iter : int, - prng_seed : int, - torch_op: bool, - ) -> dict: - - result = { - "tp_direction": "double_backward", - "num_warmup": num_warmup, - "num_iter": num_iter, - "prng_seed": prng_seed, - } - - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad, out_double_grad = get_random_buffers_double_backward( - problem, - batch_size, - prng_seed - ) + problem: TPProblem, + implementation: type[TensorProductBase], + batch_size: int, + num_warmup: int, + num_iter: int, + prng_seed: int, + torch_op: bool, +) -> dict: + result = { + "tp_direction": "double_backward", + "num_warmup": num_warmup, + "num_iter": num_iter, + "prng_seed": prng_seed, + } - if problem.shared_weights and implementation == CUETensorProduct: - weights = weights[np.newaxis, :] + in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad, out_double_grad = ( + get_random_buffers_double_backward(problem, batch_size, prng_seed) + ) - logger.info("Initialized input / output data.") - tp = implementation(problem, torch_op=torch_op) + if problem.shared_weights and implementation == CUETensorProduct: + weights = weights[np.newaxis, :] - try: - time_millis = tp.benchmark_double_backward( - num_warmup=num_warmup, - num_iter=num_iter, - L1_in=in1, - L2_in=in2, - L3_buffer=out_grad, - weights=weights, - L1_grad=in1_grad, - L2_grad=in2_grad, - weights_grad=weights_grad, - L3_double_grad=out_double_grad, - ) - except NotImplementedError: - logger.warning("Benchmarking is not implemented, time millis replaced with -1's") - time_millis = np.full(shape=num_iter, fill_value=-1) + logger.info("Initialized input / output data.") + tp = implementation(problem, torch_op=torch_op) + try: + time_millis = tp.benchmark_double_backward( + num_warmup=num_warmup, + num_iter=num_iter, + L1_in=in1, + L2_in=in2, + L3_buffer=out_grad, + weights=weights, + L1_grad=in1_grad, + L2_grad=in2_grad, + weights_grad=weights_grad, + L3_double_grad=out_double_grad, + ) + except NotImplementedError: + logger.warning( + "Benchmarking is not implemented, time millis replaced with -1's" + ) + time_millis = np.full(shape=num_iter, fill_value=-1) + + try: + flops = tp.calculate_flops_backward(batch_size=batch_size) + except NotImplementedError: try: - flops = tp.calculate_flops_backward(batch_size=batch_size) + flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) + logger.warning( + "Actual flops was not calculated, so minimum values are being used" + ) except NotImplementedError: - try: - flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size) - logger.warning("Actual flops was not calculated, so minimum values are being used") - except NotImplementedError: - logger.warning("Minimum Backwards flops calculations are not implemented, -1 is a placeholder") - flops = {"total" : -1} - - try: - memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) - except NotImplementedError: - logger.warning("Actual memory streamed was not calculated, so minimum values are being") - memory_streamed = calculate_minimum_memory_streamed_backward(tpp=problem, batch_size=batch_size) - - result |= calculate_performance_statistics( - problem=problem, - batch_size=batch_size, - total_flops=flops["total"], - total_memory_streamed=memory_streamed["total"], - time_millis=time_millis + logger.warning( + "Minimum Backwards flops calculations are not implemented, -1 is a placeholder" ) + flops = {"total": -1} + + try: + memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size) + except NotImplementedError: + logger.warning( + "Actual memory streamed was not calculated, so minimum values are being" + ) + memory_streamed = calculate_minimum_memory_streamed_backward( + tpp=problem, batch_size=batch_size + ) + + result |= calculate_performance_statistics( + problem=problem, + batch_size=batch_size, + total_flops=flops["total"], + total_memory_streamed=memory_streamed["total"], + time_millis=time_millis, + ) - return result \ No newline at end of file + return result diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/benchmark/correctness_utils.py index c0a70f63..6284dbb2 100644 --- a/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/benchmark/correctness_utils.py @@ -1,64 +1,83 @@ from typing import Optional, Union from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.CUETensorProduct import CUETensorProduct +from openequivariance.implementations.CUETensorProduct import CUETensorProduct from openequivariance.implementations.e3nn_lite import TPProblem -from openequivariance.benchmark.random_buffer_utils import get_random_buffers_forward, get_random_buffers_backward -from openequivariance.benchmark.logging_utils import getLogger, bcolors -import numpy as np +from openequivariance.benchmark.random_buffer_utils import ( + get_random_buffers_forward, + get_random_buffers_backward, +) +from openequivariance.benchmark.logging_utils import getLogger, bcolors +import numpy as np import numpy.linalg as la logger = getLogger() -def check_similiarity(name : str, to_check : np.ndarray, ground_truth : np.ndarray, correctness_threshold : float): + +def check_similiarity( + name: str, + to_check: np.ndarray, + ground_truth: np.ndarray, + correctness_threshold: float, +): result = {} if to_check.shape != ground_truth.shape: result["shape_match"] = False result["diff_Linf_norm"] = np.inf result["pass"] = False - logger.error(f"{bcolors.FAIL}Ground truth {name} shape does not match input! {to_check.shape=}, {ground_truth.shape=} {bcolors.ENDC}") + logger.error( + f"{bcolors.FAIL}Ground truth {name} shape does not match input! {to_check.shape=}, {ground_truth.shape=} {bcolors.ENDC}" + ) else: - result["shape_match"] = True + result["shape_match"] = True diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf)) - result["diff_Linf_norm"] = diff_Linf_norm + result["diff_Linf_norm"] = diff_Linf_norm result["pass"] = bool(diff_Linf_norm < correctness_threshold) if result["pass"]: - logger.info(f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}") + logger.info( + f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" + ) else: - logger.error(f"{bcolors.FAIL}{name} correctness check fail! {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}") + logger.error( + f"{bcolors.FAIL}{name} correctness check fail! {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" + ) return result -def instantiate_implementation(implementation : Union[type[TensorProductBase], TensorProductBase], problem : TPProblem): + +def instantiate_implementation( + implementation: Union[type[TensorProductBase], TensorProductBase], + problem: TPProblem, +): if isinstance(implementation, type): test_tp = implementation(problem) else: test_tp = implementation if not isinstance(test_tp, TensorProductBase): - raise TypeError(f"test_implementation must be a TensorProductBase or a subclass, got {type(implementation)}") + raise TypeError( + f"test_implementation must be a TensorProductBase or a subclass, got {type(implementation)}" + ) return test_tp + def correctness_forward( - problem : TPProblem, - test_implementation : Union[type[TensorProductBase], TensorProductBase], - reference_implementation : Optional[type[TensorProductBase]], - batch_size : int, - correctness_threshold : float, - prng_seed : int, - ) -> dict: - + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +) -> dict: if reference_implementation is None: from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + reference_implementation = E3NNTensorProduct - result = { - "thresh": correctness_threshold, - "batch_size": batch_size - } - + result = {"thresh": correctness_threshold, "batch_size": batch_size} + in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed) # run reference @@ -66,54 +85,46 @@ def correctness_forward( ref_out = out.copy() ref_tp.forward_cpu( - L1_in=in1.copy(), - L2_in=in2.copy(), - L3_out=ref_out, - weights=weights.copy()) + L1_in=in1.copy(), L2_in=in2.copy(), L3_out=ref_out, weights=weights.copy() + ) weights_copy = weights.copy() if problem.shared_weights and test_implementation == CUETensorProduct: weights_copy = weights[np.newaxis, :] # run test - test_tp = instantiate_implementation(test_implementation, problem) + test_tp = instantiate_implementation(test_implementation, problem) test_out = out.copy() test_tp.forward_cpu( - L1_in=in1.copy(), - L2_in=in2.copy(), - L3_out=test_out, - weights=weights_copy) + L1_in=in1.copy(), L2_in=in2.copy(), L3_out=test_out, weights=weights_copy + ) + + for name, to_check, ground_truth in [("output", ref_out, test_out)]: + result[name] = check_similiarity( + name, to_check, ground_truth, correctness_threshold + ) - for name, to_check, ground_truth in [ - ("output", ref_out, test_out) - ]: - result[name] = check_similiarity(name, to_check, ground_truth, correctness_threshold) - return result + def correctness_backward( - problem : TPProblem, - test_implementation : Union[type[TensorProductBase], TensorProductBase], - reference_implementation : Optional[type[TensorProductBase]], - batch_size : int, - correctness_threshold : float, - prng_seed : int, - ) -> dict: - + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +) -> dict: if reference_implementation is None: from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + reference_implementation = E3NNTensorProduct - result = { - "thresh": correctness_threshold, - "batch_size": batch_size - } - + result = {"thresh": correctness_threshold, "batch_size": batch_size} + # run reference - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = get_random_buffers_backward( - problem, - batch_size, - prng_seed + in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = ( + get_random_buffers_backward(problem, batch_size, prng_seed) ) ref_tp = reference_implementation(problem) @@ -125,12 +136,12 @@ def correctness_backward( ref_tp.backward_cpu( L1_in=in1.copy(), L1_grad=ref_in1_grad, - L2_in=in2.copy(), - L2_grad=ref_in2_grad, - L3_grad=out_grad.copy(), - weights=weights.copy(), - weights_grad=ref_weights_grad - ) + L2_in=in2.copy(), + L2_grad=ref_in2_grad, + L3_grad=out_grad.copy(), + weights=weights.copy(), + weights_grad=ref_weights_grad, + ) # run test version test_weights_grad = weights_grad.copy() @@ -138,23 +149,27 @@ def correctness_backward( test_in2_grad = in2_grad.copy() weights_copy = weights.copy() - weights_grad_copy = weights_grad.copy() + if problem.shared_weights and test_implementation == CUETensorProduct: weights_copy = weights[np.newaxis, :] - test_weights_grad = test_weights_grad[np.newaxis, :] + test_weights_grad = test_weights_grad[np.newaxis, :] - test_tp = instantiate_implementation(test_implementation, problem) + test_tp = instantiate_implementation(test_implementation, problem) test_tp.backward_cpu( L1_in=in1.copy(), L1_grad=test_in1_grad, - L2_in=in2.copy(), - L2_grad=test_in2_grad, - L3_grad=out_grad.copy(), + L2_in=in2.copy(), + L2_grad=test_in2_grad, + L3_grad=out_grad.copy(), weights=weights_copy, - weights_grad=test_weights_grad - ) + weights_grad=test_weights_grad, + ) - weight_threshold = correctness_threshold * batch_size if problem.shared_weights else correctness_threshold + weight_threshold = ( + correctness_threshold * batch_size + if problem.shared_weights + else correctness_threshold + ) if problem.shared_weights: test_weights_grad = test_weights_grad.squeeze() @@ -163,90 +178,100 @@ def correctness_backward( ("weight_grad", test_weights_grad, ref_weights_grad, weight_threshold), ("in1_grad", test_in1_grad, ref_in1_grad, correctness_threshold), ("in2_grad", test_in2_grad, ref_in2_grad, correctness_threshold), - ]: + ]: result[name] = check_similiarity(name, to_check, ground_truth, threshold) return result -def correctness_double_backward( - problem : TPProblem, - test_implementation : Union[type[TensorProductBase], TensorProductBase], - reference_implementation : Optional[type[TensorProductBase]], - batch_size : int, - correctness_threshold : float, - prng_seed : int): +def correctness_double_backward( + problem: TPProblem, + test_implementation: Union[type[TensorProductBase], TensorProductBase], + reference_implementation: Optional[type[TensorProductBase]], + batch_size: int, + correctness_threshold: float, + prng_seed: int, +): global torch import torch - + in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward( - problem, - batch_size, - prng_seed + problem, batch_size, prng_seed ) rng = np.random.default_rng(seed=prng_seed * 2) - dummy_grad = rng.standard_normal(1)[0] - + dummy_grad = rng.standard_normal(1)[0] + if reference_implementation is None: from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + reference_implementation = E3NNTensorProduct - result = { - "thresh": correctness_threshold, - "batch_size": batch_size - } + result = {"thresh": correctness_threshold, "batch_size": batch_size} tensors = [] for i, impl in enumerate([test_implementation, reference_implementation]): - tp = instantiate_implementation(impl, problem) + tp = instantiate_implementation(impl, problem) - if impl == CUETensorProduct and problem.shared_weights : + if impl == CUETensorProduct and problem.shared_weights: weights = weights[np.newaxis, :] - weights_reordered = np.zeros_like(weights) + weights_reordered = np.zeros_like(weights) if tp.reorder_weights_e3nn_to_oeq is not None: - tp.reorder_weights_e3nn_to_oeq(weights, weights_reordered, not tp.config.shared_weights) + tp.reorder_weights_e3nn_to_oeq( + weights, weights_reordered, not tp.config.shared_weights + ) else: weights_reordered = weights - - in1_torch = torch.tensor(in1, device='cuda', requires_grad=True) - in2_torch = torch.tensor(in2, device='cuda', requires_grad=True) - weights_torch = torch.tensor(weights_reordered, device='cuda', requires_grad=True) + + in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) + in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) + weights_torch = torch.tensor( + weights_reordered, device="cuda", requires_grad=True + ) out_torch = tp.forward(in1_torch, in2_torch, weights_torch) - out_grad = out_torch.clone().detach().to(device='cuda').requires_grad_(True) + out_grad = out_torch.clone().detach().to(device="cuda").requires_grad_(True) in1_grad, in2_grad, w_grad = torch.autograd.grad( outputs=[out_torch], inputs=[in1_torch, in2_torch, weights_torch], grad_outputs=[out_grad], - create_graph=True) + create_graph=True, + ) dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) - dummy_grad = torch.tensor(float(dummy_grad), device='cuda', requires_grad=True) + dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True) - dummy.backward(dummy_grad, - retain_graph=True, - inputs=[out_grad, in1_torch, in2_torch, weights_torch]) + dummy.backward( + dummy_grad, + retain_graph=True, + inputs=[out_grad, in1_torch, in2_torch, weights_torch], + ) weights_grad = weights_torch.grad.detach().cpu().numpy() if tp.reorder_weights_oeq_to_e3nn is not None: weights_grad_copy = weights_grad.copy() - tp.reorder_weights_oeq_to_e3nn(weights_grad_copy, weights_grad, not tp.config.shared_weights) - - tensors.append(( - out_grad.grad.detach().cpu().numpy(), - in1_torch.grad.detach().cpu().numpy(), - in2_torch.grad.detach().cpu().numpy(), - weights_grad - )) + tp.reorder_weights_oeq_to_e3nn( + weights_grad_copy, weights_grad, not tp.config.shared_weights + ) + + tensors.append( + ( + out_grad.grad.detach().cpu().numpy(), + in1_torch.grad.detach().cpu().numpy(), + in2_torch.grad.detach().cpu().numpy(), + weights_grad, + ) + ) for name, to_check, ground_truth in [ ("output_double_grad", tensors[0][0], tensors[1][0]), ("in1_grad", tensors[0][1], tensors[1][1]), ("in2_grad", tensors[0][2], tensors[1][2]), - ("weights_grad", tensors[0][3], tensors[1][3]) - ]: - result[name] = check_similiarity(name, to_check, ground_truth, correctness_threshold) + ("weights_grad", tensors[0][3], tensors[1][3]), + ]: + result[name] = check_similiarity( + name, to_check, ground_truth, correctness_threshold + ) - return result \ No newline at end of file + return result diff --git a/openequivariance/benchmark/logging_utils.py b/openequivariance/benchmark/logging_utils.py index 08ce5fcf..6bfe811d 100644 --- a/openequivariance/benchmark/logging_utils.py +++ b/openequivariance/benchmark/logging_utils.py @@ -1,21 +1,24 @@ import logging + logger = logging.getLogger("ETP") logger.setLevel(logging.CRITICAL) ch = logging.StreamHandler() -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) logger.addHandler(ch) + def getLogger(): return logger + class bcolors: - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKCYAN = '\033[96m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" diff --git a/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/benchmark/perf_metrics_utils.py index 1209b0d4..edd314af 100644 --- a/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/benchmark/perf_metrics_utils.py @@ -1,6 +1,9 @@ import math -from openequivariance.implementations.utils import count_cg_non_zero, sparse_outer_product_work +from openequivariance.implementations.utils import ( + count_cg_non_zero, + sparse_outer_product_work, +) from openequivariance.implementations.TensorProductBase import TensorProductBase from openequivariance.implementations.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger @@ -8,7 +11,10 @@ logger = getLogger() -def calculate_minimum_memory_streamed_forward(tpp : TPProblem, batch_size : int) -> dict[str, int]: + +def calculate_minimum_memory_streamed_forward( + tpp: TPProblem, batch_size: int +) -> dict[str, int]: """ This represents an absolute minimum amount of memory that could be streamed on an ideal machine It returns the number of bytes streamed total and from each source @@ -17,37 +23,38 @@ def calculate_minimum_memory_streamed_forward(tpp : TPProblem, batch_size : int) irrep_word_size = np.dtype(tpp.irrep_dtype).itemsize weight_word_size = np.dtype(tpp.weight_dtype).itemsize - data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size - data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size - data_size["output"] = tpp.irreps_out.dim * batch_size * irrep_word_size - data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size + data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size + data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size + data_size["output"] = tpp.irreps_out.dim * batch_size * irrep_word_size + data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size data_size["total"] = sum(data_size.values()) return data_size -def calculate_minimum_memory_streamed_backward(tpp : TPProblem, batch_size : int) -> dict: + +def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict: """ - This represents an absolute minimum amount of memory that could be streamed on an ideal machine + This represents an absolute minimum amount of memory that could be streamed on an ideal machine It returns the number of bytes streamed total and from each source """ data_size = {} irrep_word_size = np.dtype(tpp.irrep_dtype).itemsize weight_word_size = np.dtype(tpp.weight_dtype).itemsize - data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size - data_size["input 1 grad"] = tpp.irreps_in1.dim * batch_size * irrep_word_size - data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size - data_size["input 2 grad"] = tpp.irreps_in2.dim * batch_size * irrep_word_size - data_size["output grad"] = tpp.irreps_out.dim * batch_size * irrep_word_size - data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size - data_size["weights grad"] = tpp.weight_numel * batch_size * weight_word_size + data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size + data_size["input 1 grad"] = tpp.irreps_in1.dim * batch_size * irrep_word_size + data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size + data_size["input 2 grad"] = tpp.irreps_in2.dim * batch_size * irrep_word_size + data_size["output grad"] = tpp.irreps_out.dim * batch_size * irrep_word_size + data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size + data_size["weights grad"] = tpp.weight_numel * batch_size * weight_word_size data_size["total"] = sum(data_size.values()) return data_size -def calculate_minimum_flops_forward(tpp : TPProblem, batch_size : int) -> dict: +def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict: """ - This is not actually calcuating the minimum value. - Ideally you might share the outer product values between two inputs across multiple inputs. + This is not actually calcuating the minimum value. + Ideally you might share the outer product values between two inputs across multiple inputs. This is assuming that you form those values and reuse them once per CG decomp. """ logger.warning("Minimum flops Calculation is not the true minimum") @@ -55,12 +62,22 @@ def calculate_minimum_flops_forward(tpp : TPProblem, batch_size : int) -> dict: flops_count["outer_products"] = 0 flops_count["CG_decomposition"] = 0 flops_count["linear_combination"] = 0 - for ins in tpp.instructions: # type : Instruction - l1, l2, l3 = tpp.irreps_in1[ins.i_in1].ir.l, tpp.irreps_in2[ins.i_in2].ir.l, tpp.irreps_out[ins.i_out].ir.l + for ins in tpp.instructions: # type : Instruction + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) - flops_count["outer_products"] += sparse_outer_product_work(TensorProductBase.load_cg_tensor(l1,l2,l3)) - flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (ins.path_shape[0] * ins.path_shape[1]) - flops_count["linear_combination"] += (2 * l3 + 1) * math.prod(ins.path_shape) if ins.has_weight else 0 + flops_count["outer_products"] += sparse_outer_product_work( + TensorProductBase.load_cg_tensor(l1, l2, l3) + ) + flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) + flops_count["linear_combination"] += ( + (2 * l3 + 1) * math.prod(ins.path_shape) if ins.has_weight else 0 + ) flops_count["outer_products"] *= batch_size flops_count["CG_decomposition"] *= 2 * batch_size @@ -69,10 +86,11 @@ def calculate_minimum_flops_forward(tpp : TPProblem, batch_size : int) -> dict: flops_count["total"] = sum(flops_count.values()) return flops_count -def calculate_minimum_flops_backward(tpp : TPProblem, batch_size : int) -> dict: + +def calculate_minimum_flops_backward(tpp: TPProblem, batch_size: int) -> dict: """ - This is not actually calcuating the minumum value. - Ideally you might share the outer product values between two inputs across multiple inputs. + This is not actually calcuating the minumum value. + Ideally you might share the outer product values between two inputs across multiple inputs. This is assuming that you form those values and reuse them once per CG decomp. """ - raise NotImplementedError("this needs to be implemented properly") \ No newline at end of file + raise NotImplementedError("this needs to be implemented properly") diff --git a/openequivariance/benchmark/plotting/__init__.py b/openequivariance/benchmark/plotting/__init__.py index f911bd04..3c0bf032 100644 --- a/openequivariance/benchmark/plotting/__init__.py +++ b/openequivariance/benchmark/plotting/__init__.py @@ -1,6 +1,15 @@ -from openequivariance.benchmark.plotting.plotting_utils import * from openequivariance.benchmark.plotting.plot_uvu import plot_uvu from openequivariance.benchmark.plotting.plot_uvw import plot_uvw from openequivariance.benchmark.plotting.plot_roofline import plot_roofline from openequivariance.benchmark.plotting.plot_convolution import plot_convolution -from openequivariance.benchmark.plotting.plot_double_backward import plot_double_backward \ No newline at end of file +from openequivariance.benchmark.plotting.plot_double_backward import ( + plot_double_backward, +) + +__all__ = [ + "plot_uvu", + "plot_uvw", + "plot_roofline", + "plot_convolution", + "plot_double_backward", +] diff --git a/openequivariance/benchmark/plotting/plot_convolution.py b/openequivariance/benchmark/plotting/plot_convolution.py index e2416b0f..f2f62070 100644 --- a/openequivariance/benchmark/plotting/plot_convolution.py +++ b/openequivariance/benchmark/plotting/plot_convolution.py @@ -1,24 +1,38 @@ +# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt -import os, json, pathlib, sys -from openequivariance.benchmark.plotting import * +import pathlib +from openequivariance.benchmark.plotting.plotting_utils import ( + set_grid, + colormap, + labelmap, + hatchmap, + dtypes, + directions, + dtype_labelmap, + grouped_barchart, + load_benchmarks, +) + def plot_convolution(data_folder): data_folder = pathlib.Path(data_folder) benchmarks, metadata = load_benchmarks(data_folder) - implementations = ["CUEConvolution", - "CUEConvolutionFused", - "LoopUnrollConvScatterSum", - "LoopUnrollConvAtomic", - "LoopUnrollConvDeterministic" - ] + implementations = [ + "CUEConvolution", + "CUEConvolutionFused", + "LoopUnrollConvScatterSum", + "LoopUnrollConvAtomic", + "LoopUnrollConvDeterministic", + ] graphs = ["1drf_radius6.0", "covid_spike_radius3.0", "carbon_lattice_radius6.0"] graph_lmap = { "covid_spike_radius3.0": "COVID spike", "1drf_radius6.0": "DHFR", - "carbon_lattice_radius6.0": "carbon-lattice"} + "carbon_lattice_radius6.0": "carbon-lattice", + } data = {} @@ -29,37 +43,56 @@ def plot_convolution(data_folder): for graph in graphs: data[direction][dtype][graph_lmap[graph]] = {} for impl in implementations: - exp = filter(benchmarks, { "graph": graph, - "direction": direction, - "name": impl, - "irrep_dtype": dtype - }, match_one=True) - - data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = np.mean(exp["benchmark"]["time_millis"]) - - + exp = filter( + benchmarks, + { + "graph": graph, + "direction": direction, + "name": impl, + "irrep_dtype": dtype, + }, + match_one=True, + ) + + data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = np.mean( + exp["benchmark"]["time_millis"] + ) + fig = plt.figure(figsize=(5, 5)) gs = fig.add_gridspec(2, 2, hspace=0, wspace=0) - axes = gs.subplots(sharex='col', sharey='row') - + axes = gs.subplots(sharex="col", sharey="row") + for i, direction in enumerate(directions): for j, dtype in enumerate(dtypes): for k, graph in enumerate(graphs): - normalizing_value = data[direction][dtype][graph_lmap[graph]]["cuE-scattersum"] + normalizing_value = data[direction][dtype][graph_lmap[graph]][ + "cuE-scattersum" + ] for impl in implementations: - data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = normalizing_value / data[direction][dtype][graph_lmap[graph]][labelmap[impl]] + data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = ( + normalizing_value + / data[direction][dtype][graph_lmap[graph]][labelmap[impl]] + ) - grouped_barchart(data[direction][dtype], axes[i][j], bar_height_fontsize=0, rotate_xlabels=True, colormap=colormap, hatchmap=hatchmap, group_spacing=6.0) + grouped_barchart( + data[direction][dtype], + axes[i][j], + bar_height_fontsize=0, + rotate_xlabels=True, + colormap=colormap, + hatchmap=hatchmap, + group_spacing=6.0, + ) axes[i][j].set_xlabel(dtype_labelmap[dtype]) axes[i][j].set_ylabel(direction) - axes[i][j].axhline(1.0, ls='--', c=colormap["cuE"]) + axes[i][j].axhline(1.0, ls="--", c=colormap["cuE"]) set_grid(axes[i][j]) - + axes[1][0].set_ylim(0, 3.8) for ax in fig.get_axes(): ax.label_outer() - + fig.supylabel("Speedup over cuE-scattersum", x=0.025, y=0.6) handles, labels = axes[0][0].get_legend_handles_labels() @@ -67,9 +100,11 @@ def plot_convolution(data_folder): if "fast" in l: labels[i] += " (ours)" - unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] - fig.legend(*zip(*unique), loc='upper center', bbox_to_anchor=(0.55, 0.01)) + unique = [ + (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + ] + fig.legend(*zip(*unique), loc="upper center", bbox_to_anchor=(0.55, 0.01)) fig.show() fig.tight_layout() - fig.savefig(str(data_folder / "kernel_fusion_speedup.pdf"), bbox_inches='tight') \ No newline at end of file + fig.savefig(str(data_folder / "kernel_fusion_speedup.pdf"), bbox_inches="tight") diff --git a/openequivariance/benchmark/plotting/plot_double_backward.py b/openequivariance/benchmark/plotting/plot_double_backward.py index 89096d90..a10193c5 100644 --- a/openequivariance/benchmark/plotting/plot_double_backward.py +++ b/openequivariance/benchmark/plotting/plot_double_backward.py @@ -1,7 +1,15 @@ +# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt -import os, json, pathlib, sys -from openequivariance.benchmark.plotting import * +import pathlib +from openequivariance.benchmark.plotting.plotting_utils import ( + set_grid, + colormap, + labelmap, + grouped_barchart, + load_benchmarks, +) + def plot_double_backward(data_folder): data_folder = pathlib.Path(data_folder) @@ -11,18 +19,29 @@ def plot_double_backward(data_folder): implementations = ["E3NNTensorProduct", "CUETensorProduct", "LoopUnrollTP"] def calculate_tp_per_sec(exp): - return exp["benchmark results"]["batch_size"] / (np.mean(exp["benchmark results"]["time_millis"]) * 0.001) + return exp["benchmark results"]["batch_size"] / ( + np.mean(exp["benchmark results"]["time_millis"]) * 0.001 + ) dataf32 = {"double_backward": {}} for i, desc in enumerate(configs): for direction in ["double_backward"]: dataf32[direction][desc] = {} for impl in implementations: - f32_benches = [b for b in benchmarks if b["benchmark results"]["rep_dtype"] == ""] - exp = filter(f32_benches, {"config_label": desc, - "direction": direction, - "implementation_name": impl - }, match_one=True) + f32_benches = [ + b + for b in benchmarks + if b["benchmark results"]["rep_dtype"] == "" + ] + exp = filter( + f32_benches, + { + "config_label": desc, + "direction": direction, + "implementation_name": impl, + }, + match_one=True, + ) dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) dataf64 = {"double_backward": {}} @@ -30,12 +49,21 @@ def calculate_tp_per_sec(exp): for direction in ["double_backward"]: dataf64[direction][desc] = {} for impl in implementations: - f64_benches = [b for b in benchmarks if 'float64' in b["benchmark results"]["rep_dtype"]] - - exp = filter(f64_benches, {"config_label": desc, - "direction": direction, - "implementation_name": impl - }, match_one=True) + f64_benches = [ + b + for b in benchmarks + if "float64" in b["benchmark results"]["rep_dtype"] + ] + + exp = filter( + f64_benches, + { + "config_label": desc, + "direction": direction, + "implementation_name": impl, + }, + match_one=True, + ) if exp is None: print(desc) @@ -46,10 +74,22 @@ def calculate_tp_per_sec(exp): fig = plt.figure(figsize=(7, 3)) gs = fig.add_gridspec(1, 2, hspace=0, wspace=0.1) - axs = gs.subplots(sharex='col', sharey='row') - - grouped_barchart(dataf32["double_backward"], axs[0], bar_height_fontsize=0, colormap=colormap, group_spacing=6.0) - grouped_barchart(dataf64["double_backward"], axs[1], bar_height_fontsize=0, colormap=colormap, group_spacing=6.0) + axs = gs.subplots(sharex="col", sharey="row") + + grouped_barchart( + dataf32["double_backward"], + axs[0], + bar_height_fontsize=0, + colormap=colormap, + group_spacing=6.0, + ) + grouped_barchart( + dataf64["double_backward"], + axs[1], + bar_height_fontsize=0, + colormap=colormap, + group_spacing=6.0, + ) for i in range(2): set_grid(axs[i]) @@ -59,7 +99,9 @@ def calculate_tp_per_sec(exp): axs[1].set_xlabel("float64") handles, labels = axs[0].get_legend_handles_labels() - unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] + unique = [ + (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + ] axs[0].legend(*zip(*unique)) for ax in fig.get_axes(): @@ -68,21 +110,32 @@ def calculate_tp_per_sec(exp): fig.supylabel("2nd Deriv. Throughput\n(# tensor products / s)", y=0.5) speedup_table = [] - for direction in ['double_backward']: - for impl in ['e3nn', 'cuE']: - for dtype_label, dtype_set in [('f32', dataf32), ('f64', dataf64)]: - speedups = [measurement['ours'] / measurement[impl] for _, measurement in dtype_set[direction].items() if impl in measurement] - stats = np.min(speedups), np.mean(speedups), np.median(speedups), np.max(speedups) + for direction in ["double_backward"]: + for impl in ["e3nn", "cuE"]: + for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]: + speedups = [ + measurement["ours"] / measurement[impl] + for _, measurement in dtype_set[direction].items() + if impl in measurement + ] + stats = ( + np.min(speedups), + np.mean(speedups), + np.median(speedups), + np.max(speedups), + ) stats = [f"{stat:.2f}" for stat in stats] dir_print = direction result = [dir_print, impl, dtype_label] + stats speedup_table.append(result) - print('\t\t'.join(['Direction', 'Base', 'dtype', 'min', 'mean', 'med', 'max'])) + print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"])) for row in speedup_table: - print('\t\t'.join(row)) + print("\t\t".join(row)) fig.show() fig.tight_layout() - fig.savefig(str(data_folder / "double_backward_throughput.pdf"), bbox_inches='tight') \ No newline at end of file + fig.savefig( + str(data_folder / "double_backward_throughput.pdf"), bbox_inches="tight" + ) diff --git a/openequivariance/benchmark/plotting/plot_roofline.py b/openequivariance/benchmark/plotting/plot_roofline.py index 3d70e143..42e08020 100644 --- a/openequivariance/benchmark/plotting/plot_roofline.py +++ b/openequivariance/benchmark/plotting/plot_roofline.py @@ -1,47 +1,94 @@ +# ruff: noqa: E741 import numpy as np -import matplotlib.pyplot as plt -import os, json, pathlib, sys -from openequivariance.benchmark.plotting import * +import pathlib +from openequivariance.benchmark.plotting.plotting_utils import ( + colormap, + labelmap, + load_benchmarks, + roofline_plot, +) + def plot_roofline(data_folder): data_folder = pathlib.Path(data_folder) benchmarks, metadata = load_benchmarks(data_folder) configs = metadata["config_labels"] - implementations = ["LoopUnrollTP", "CUETensorProduct"] + implementations = ["LoopUnrollTP", "CUETensorProduct"] data = {"forward": {}, "backward": {}} for i, desc in enumerate(configs): for direction in ["forward", "backward"]: data[direction][desc] = {} for impl in implementations: - exp = filter(benchmarks, {"config_label": desc, - "direction": direction, - "implementation_name": impl}, match_one=True) - data[direction][desc][labelmap[impl]] = (exp["benchmark results"]["arithmetic_intensity (FLOPs / byte)"], np.mean(exp["benchmark results"]["throughputs_gflops"])) + exp = filter( + benchmarks, + { + "config_label": desc, + "direction": direction, + "implementation_name": impl, + }, + match_one=True, + ) + data[direction][desc][labelmap[impl]] = ( + exp["benchmark results"]["arithmetic_intensity (FLOPs / byte)"], + np.mean(exp["benchmark results"]["throughputs_gflops"]), + ) roofline_data = [] - marker_map = {"forward-cuE": "+", "backward-cuE": "X", "forward-ours": "P", "backward-ours": "X"} + marker_map = { + "forward-cuE": "+", + "backward-cuE": "X", + "forward-ours": "P", + "backward-ours": "X", + } for i, desc in enumerate(configs): for direction in ["forward", "backward"]: ai, throughput = data[direction][desc][labelmap["LoopUnrollTP"]] label = f"{direction}-ours" - roofline_data.append({"AI": float(ai), "throughput": throughput / 1000, "label": label, "marker": marker_map[label], "color": colormap['ours'], "markersize": 80}) + roofline_data.append( + { + "AI": float(ai), + "throughput": throughput / 1000, + "label": label, + "marker": marker_map[label], + "color": colormap["ours"], + "markersize": 80, + } + ) label = f"{direction}-cuE" ai, throughput = data[direction][desc][labelmap["CUETensorProduct"]] - roofline_data.append({"AI": float(ai), "throughput": throughput / 1000, "label": label, "marker": marker_map[label], "color": colormap['cuE'], "markersize": 80}) + roofline_data.append( + { + "AI": float(ai), + "throughput": throughput / 1000, + "label": label, + "marker": marker_map[label], + "color": colormap["cuE"], + "markersize": 80, + } + ) - cpu_roofs = {"A100-SXM-80GB FP32 Peak": 19.5} mem_bottlenecks = {"HBM2": 2.039} AI_v = {"": 9.56} - + draw_bounds = {"xmin": 0.4, "xmax": 15, "ymin": 0.15, "ymax": 25} - fig, ax = roofline_plot(draw_bounds, cpu_roofs, mem_bottlenecks, AI_v, roofline_data, fig_ratio=1.8, fig_dimension=4) + fig, ax = roofline_plot( + draw_bounds, + cpu_roofs, + mem_bottlenecks, + AI_v, + roofline_data, + fig_ratio=1.8, + fig_dimension=4, + ) handles, labels = ax.get_legend_handles_labels() - unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] + unique = [ + (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + ] ax.legend(*zip(*unique)) fig.show() @@ -65,25 +112,35 @@ def plot_roofline(data_folder): for direction in ["forward", "backward"]: for impl in implementations: short_id, long_desc = desc.split("#") - long_desc = long_desc.replace("->", "$\\rightarrow$").replace(" x ", "$\ \\times\ $") - ai_ours, throughput_ours = data[direction][desc][labelmap["LoopUnrollTP"]] + long_desc = long_desc.replace("->", "$\\rightarrow$").replace( + " x ", "$\ \\times\ $" + ) + ai_ours, throughput_ours = data[direction][desc][ + labelmap["LoopUnrollTP"] + ] throughput_ours = f"{float(throughput_ours / 1000):.2f}" _, throughput_cue = data[direction][desc][labelmap["CUETensorProduct"]] throughput_cue = f"{float(throughput_cue / 1000):.2f}" - - result = ["\multirow{2}{*}{" + short_id + "}", "\multirow{2}{*}{" + long_desc + "}", dir_map[direction], f"{ai_ours:.1f}", throughput_cue, throughput_ours] + + result = [ + "\multirow{2}{*}{" + short_id + "}", + "\multirow{2}{*}{" + long_desc + "}", + dir_map[direction], + f"{ai_ours:.1f}", + throughput_cue, + throughput_ours, + ] if direction == "backward": result[0] = "" result[1] = "" rows.append(result) - + print(header) result = "" for i, row in enumerate(rows): - result += ' & '.join(row) + r"\\" + "\n" + result += " & ".join(row) + r"\\" + "\n" if row[2] == "B" and i < len(rows) - 1: result += "\cmidrule(r){3-6}" + "\n" - print(result.replace("[", "").replace("]", "").replace("uvu", "B")) - print("\\bottomrule\n\\end{tabular}") \ No newline at end of file + print("\\bottomrule\n\\end{tabular}") diff --git a/openequivariance/benchmark/plotting/plot_uvu.py b/openequivariance/benchmark/plotting/plot_uvu.py index 02b2ef9d..aa926534 100644 --- a/openequivariance/benchmark/plotting/plot_uvu.py +++ b/openequivariance/benchmark/plotting/plot_uvu.py @@ -1,16 +1,27 @@ +# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt -import os, json, pathlib, sys -from openequivariance.benchmark.plotting import * +import pathlib +from openequivariance.benchmark.plotting.plotting_utils import ( + set_grid, + colormap, + labelmap, + grouped_barchart, + load_benchmarks, +) + def plot_uvu(data_folder): data_folder = pathlib.Path(data_folder) benchmarks, metadata = load_benchmarks(data_folder) configs = metadata["config_labels"] - implementations = metadata["implementations"] + implementations = metadata["implementations"] for benchmark in benchmarks: - if benchmark["implementation_name"] == "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs": + if ( + benchmark["implementation_name"] + == "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs" + ): benchmark["implementation_name"] = "E3NNTensorProduct" for i, implementation in enumerate(implementations): @@ -18,33 +29,53 @@ def plot_uvu(data_folder): implementations[i] = "E3NNTensorProduct" def calculate_tp_per_sec(exp): - return exp["benchmark results"]["batch_size"] / (np.mean(exp["benchmark results"]["time_millis"]) * 0.001) + return exp["benchmark results"]["batch_size"] / ( + np.mean(exp["benchmark results"]["time_millis"]) * 0.001 + ) dataf32 = {"forward": {}, "backward": {}} for i, desc in enumerate(configs): for direction in ["forward", "backward"]: dataf32[direction][desc] = {} for impl in implementations: - f32_benches = [b for b in benchmarks if b["benchmark results"]["rep_dtype"] == ""] - exp = filter(f32_benches, {"config_label": desc, - "direction": direction, - "implementation_name": impl - }, match_one=True) + f32_benches = [ + b + for b in benchmarks + if b["benchmark results"]["rep_dtype"] == "" + ] + exp = filter( + f32_benches, + { + "config_label": desc, + "direction": direction, + "implementation_name": impl, + }, + match_one=True, + ) if exp is not None: dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) else: - dataf32[direction][desc][labelmap[impl]] = 0.0 - + dataf32[direction][desc][labelmap[impl]] = 0.0 + dataf64 = {"forward": {}, "backward": {}} for i, desc in enumerate(configs): for direction in ["forward", "backward"]: dataf64[direction][desc] = {} for impl in implementations: - f64_benches = [b for b in benchmarks if b["benchmark results"]["rep_dtype"] == ""] - exp = filter(f64_benches, {"config_label": desc, - "direction": direction, - "implementation_name": impl - }, match_one=True) + f64_benches = [ + b + for b in benchmarks + if b["benchmark results"]["rep_dtype"] == "" + ] + exp = filter( + f64_benches, + { + "config_label": desc, + "direction": direction, + "implementation_name": impl, + }, + match_one=True, + ) if exp is not None: dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) @@ -55,11 +86,37 @@ def calculate_tp_per_sec(exp): gs = fig.add_gridspec(2, 2) axs = gs.subplots(sharex=True) - grouped_barchart(dataf32["forward"], axs[0][0], bar_height_fontsize=0, xticklabel=False, colormap=colormap, group_spacing=6.0) - grouped_barchart(dataf32["backward"], axs[1][0], bar_height_fontsize=0, colormap=colormap, group_spacing=6.0) - - grouped_barchart(dataf64["forward"], axs[0][1], bar_height_fontsize=0, xticklabel=False, colormap=colormap, group_spacing=6.0) - grouped_barchart(dataf64["backward"], axs[1][1], bar_height_fontsize=0, colormap=colormap, group_spacing=6.0) + grouped_barchart( + dataf32["forward"], + axs[0][0], + bar_height_fontsize=0, + xticklabel=False, + colormap=colormap, + group_spacing=6.0, + ) + grouped_barchart( + dataf32["backward"], + axs[1][0], + bar_height_fontsize=0, + colormap=colormap, + group_spacing=6.0, + ) + + grouped_barchart( + dataf64["forward"], + axs[0][1], + bar_height_fontsize=0, + xticklabel=False, + colormap=colormap, + group_spacing=6.0, + ) + grouped_barchart( + dataf64["backward"], + axs[1][1], + bar_height_fontsize=0, + colormap=colormap, + group_spacing=6.0, + ) for i in range(2): for j in range(2): @@ -70,7 +127,9 @@ def calculate_tp_per_sec(exp): axs[1][0].set_ylabel("Backward") handles, labels = axs[0][0].get_legend_handles_labels() - unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] + unique = [ + (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + ] axs[0][0].legend(*zip(*unique)) fig.supylabel("Throughput (# tensor products / s)", x=0.036, y=0.605) @@ -83,11 +142,20 @@ def calculate_tp_per_sec(exp): fig.savefig(str(data_folder / "throughput_comparison.pdf")) speedup_table = [] - for direction in ['forward', 'backward']: - for impl in ['e3nn', 'cuE']: - for dtype_label, dtype_set in [('f32', dataf32), ('f64', dataf64)]: - speedups = [measurement['ours'] / measurement[impl] for _, measurement in dtype_set[direction].items() if impl in measurement] - stats = np.min(speedups), np.mean(speedups), np.median(speedups), np.max(speedups) + for direction in ["forward", "backward"]: + for impl in ["e3nn", "cuE"]: + for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]: + speedups = [ + measurement["ours"] / measurement[impl] + for _, measurement in dtype_set[direction].items() + if impl in measurement + ] + stats = ( + np.min(speedups), + np.mean(speedups), + np.median(speedups), + np.max(speedups), + ) stats = [f"{stat:.2f}" for stat in stats] dir_print = direction @@ -96,6 +164,6 @@ def calculate_tp_per_sec(exp): result = [dir_print, impl, dtype_label] + stats speedup_table.append(result) - print('\t\t'.join(['Direction', 'Base', 'dtype', 'min', 'mean', 'med', 'max'])) + print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"])) for row in speedup_table: - print('\t\t'.join(row)) \ No newline at end of file + print("\t\t".join(row)) diff --git a/openequivariance/benchmark/plotting/plot_uvw.py b/openequivariance/benchmark/plotting/plot_uvw.py index 1244a527..01eb6fa3 100644 --- a/openequivariance/benchmark/plotting/plot_uvw.py +++ b/openequivariance/benchmark/plotting/plot_uvw.py @@ -1,27 +1,46 @@ +# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt -import os, json, pathlib, sys -from openequivariance.benchmark.plotting import * +import pathlib +from openequivariance.benchmark.plotting.plotting_utils import ( + set_grid, + colormap, + labelmap, + grouped_barchart, + calculate_tp_per_sec, + load_benchmarks, +) + def plot_uvw(data_folder): data_folder = pathlib.Path(data_folder) benchmarks, metadata = load_benchmarks(data_folder) - configs = metadata['config_labels'] - implementations = metadata['implementations'] - directions = metadata['directions'] + configs = metadata["config_labels"] + implementations = metadata["implementations"] + metadata["directions"] dataf32 = {"forward": {}, "backward": {}} for i, desc in enumerate(configs): for direction in ["forward", "backward"]: dataf32[direction][desc] = {} for impl in implementations: - if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc: - f32_benches = [b for b in benchmarks if b["benchmark results"]["rep_dtype"] == ""] - exp = filter(f32_benches, {"config_label": desc, - "direction": direction, - "implementation_name": impl - }, match_one=True) + if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc: + f32_benches = [ + b + for b in benchmarks + if b["benchmark results"]["rep_dtype"] + == "" + ] + exp = filter( + f32_benches, + { + "config_label": desc, + "direction": direction, + "implementation_name": impl, + }, + match_one=True, + ) dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) dataf64 = {"forward": {}, "backward": {}} @@ -29,26 +48,64 @@ def plot_uvw(data_folder): for direction in ["forward", "backward"]: dataf64[direction][desc] = {} for impl in implementations: - if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc: - f64_benches = [b for b in benchmarks if b["benchmark results"]["rep_dtype"] == ""] - exp = filter(f64_benches, {"config_label": desc, - "direction": direction, - "implementation_name": impl - }, match_one=True) - dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) - - plt.rcParams['font.family'] = 'serif' - plt.rcParams.update({'font.size': 11}) - + if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc: + f64_benches = [ + b + for b in benchmarks + if b["benchmark results"]["rep_dtype"] + == "" + ] + exp = filter( + f64_benches, + { + "config_label": desc, + "direction": direction, + "implementation_name": impl, + }, + match_one=True, + ) + dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) + + plt.rcParams["font.family"] = "serif" + plt.rcParams.update({"font.size": 11}) + fig = plt.figure(figsize=(7, 7)) gs = fig.add_gridspec(2, 2) - axs = gs.subplots(sharex=True, sharey='row') - - grouped_barchart(dataf32["forward"], axs[0][0], bar_height_fontsize=0, xticklabel=False, colormap=colormap, group_spacing=6.0) - grouped_barchart(dataf32["backward"], axs[1][0], bar_height_fontsize=0,xticklabel=True, colormap=colormap, group_spacing=6.0) - - grouped_barchart(dataf64["forward"], axs[0][1], bar_height_fontsize=0, xticklabel=False, colormap=colormap, group_spacing=6.0) - grouped_barchart(dataf64["backward"], axs[1][1], bar_height_fontsize=0,xticklabel=True, colormap=colormap, group_spacing=6.0) + axs = gs.subplots(sharex=True, sharey="row") + + grouped_barchart( + dataf32["forward"], + axs[0][0], + bar_height_fontsize=0, + xticklabel=False, + colormap=colormap, + group_spacing=6.0, + ) + grouped_barchart( + dataf32["backward"], + axs[1][0], + bar_height_fontsize=0, + xticklabel=True, + colormap=colormap, + group_spacing=6.0, + ) + + grouped_barchart( + dataf64["forward"], + axs[0][1], + bar_height_fontsize=0, + xticklabel=False, + colormap=colormap, + group_spacing=6.0, + ) + grouped_barchart( + dataf64["backward"], + axs[1][1], + bar_height_fontsize=0, + xticklabel=True, + colormap=colormap, + group_spacing=6.0, + ) for i in range(2): for j in range(2): @@ -63,7 +120,9 @@ def plot_uvw(data_folder): axs[1][1].set_xlabel("float64") handles, labels = axs[0][1].get_legend_handles_labels() - unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] + unique = [ + (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + ] axs[0][1].legend(*zip(*unique)) fig.show() @@ -71,11 +130,20 @@ def plot_uvw(data_folder): fig.savefig(str(data_folder / "uvw_throughput_comparison.pdf")) speedup_table = [] - for direction in ['forward', 'backward']: - for impl in ['e3nn', 'cuE']: - for dtype_label, dtype_set in [('f32', dataf32), ('f64', dataf64)]: - speedups = [measurement['ours'] / measurement[impl] for label, measurement in dtype_set[direction].items() if impl in measurement and "DiffDock" in label] - stats = np.min(speedups), np.mean(speedups), np.median(speedups), np.max(speedups) + for direction in ["forward", "backward"]: + for impl in ["e3nn", "cuE"]: + for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]: + speedups = [ + measurement["ours"] / measurement[impl] + for label, measurement in dtype_set[direction].items() + if impl in measurement and "DiffDock" in label + ] + stats = ( + np.min(speedups), + np.mean(speedups), + np.median(speedups), + np.max(speedups), + ) stats = [f"{stat:.2f}" for stat in stats] dir_print = direction @@ -85,6 +153,6 @@ def plot_uvw(data_folder): speedup_table.append(result) print("DiffDock") - print('\t\t'.join(['Direction', 'Base', 'dtype', 'min', 'mean', 'med', 'max'])) + print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"])) for row in speedup_table: - print('\t\t'.join(row)) \ No newline at end of file + print("\t\t".join(row)) diff --git a/openequivariance/benchmark/plotting/plotting_utils.py b/openequivariance/benchmark/plotting/plotting_utils.py index acc16b27..fae6898c 100644 --- a/openequivariance/benchmark/plotting/plotting_utils.py +++ b/openequivariance/benchmark/plotting/plotting_utils.py @@ -1,63 +1,70 @@ -import json, os, pathlib +# ruff: noqa: E741 +import json +import os +import pathlib from typing import Literal import numpy as np import matplotlib.pyplot as plt -Project = Literal['e3nn', 'cuE', 'ours'] +Project = Literal["e3nn", "cuE", "ours"] -def impl_to_project_func(s : str) -> Project: - if 'E3NN' in s: - return 'e3nn' - elif 'CUE' in s: - return 'cuE' + +def impl_to_project_func(s: str) -> Project: + if "E3NN" in s: + return "e3nn" + elif "CUE" in s: + return "cuE" else: - return 'ours' + return "ours" + -project_to_color_map : dict[Project, str] = { - 'e3nn' : 'lightblue', - 'cuE' : 'orange', - 'ours' : 'green' +project_to_color_map: dict[Project, str] = { + "e3nn": "lightblue", + "cuE": "orange", + "ours": "green", } -project_to_display_order_map : dict[Project, int] = { - 'e3nn' : 0, - 'cuE' : 1, - 'ours' : 2, +project_to_display_order_map: dict[Project, int] = { + "e3nn": 0, + "cuE": 1, + "ours": 2, } + def calculate_tp_per_sec(exp): - return exp["benchmark results"]["batch_size"] / (np.mean(exp["benchmark results"]["time_millis"]) * 0.001) + return exp["benchmark results"]["batch_size"] / ( + np.mean(exp["benchmark results"]["time_millis"]) * 0.001 + ) -def sort_impls_by_display_order(implementations : list[str]) -> None : - implementations.sort(key=lambda x : project_to_display_order_map[impl_to_project_func(x)]) -def get_latest_experiment_path() -> pathlib.Path: - latest_experiment = max( - (folder for folder in BENCHMARK_FOLDER.iterdir() if folder.is_dir() and folder.name.isdigit()), - key=lambda x: int(x.name) +def sort_impls_by_display_order(implementations: list[str]) -> None: + implementations.sort( + key=lambda x: project_to_display_order_map[impl_to_project_func(x)] ) - return latest_experiment + # ============================================================= -def load_benchmarks(path : pathlib.Path): + +def load_benchmarks(path: pathlib.Path): benchmarks = [] metadata = None - + files = os.listdir(path) for file in files: - if not os.path.isdir(path / file) and str(file).endswith('.json'): - with open( path / file , "r") as f: + if not os.path.isdir(path / file) and str(file).endswith(".json"): + with open(path / file, "r") as f: if file != "metadata.json": benchmarks.append(json.load(f)) benchmarks[-1]["filename"] = file else: metadata = json.load(f) - metadata["path"] = path - + metadata["path"] = path + return benchmarks, metadata + def filter(benchmarks, base, match_one=True): filtered_results = [] for benchmark in benchmarks: @@ -65,40 +72,46 @@ def filter(benchmarks, base, match_one=True): for key in base: if benchmark[key] != base[key]: matched = False - + if matched: filtered_results.append(benchmark) - + if len(filtered_results) == 0: - #print("WARNING: Filter matched no experiments") + # print("WARNING: Filter matched no experiments") return None - + if len(filtered_results) > 1 and match_one: print("Error, matched more than one experiment:") for experiment in filtered_results: print(experiment["filename"]) - assert(False) - + assert False + if match_one: filtered_results = filtered_results[0] - + return filtered_results -def grouped_barchart(data: dict, ax, bar_width=1.0, group_spacing=3.0, - rotate_xlabels=True, - colormap=None, - hatchmap=None, - label=True, - edgecolor='k', - edgewidth=1.0, - bar_height_fontsize=7, - xticklabel=True): - ''' + +def grouped_barchart( + data: dict, + ax, + bar_width=1.0, + group_spacing=3.0, + rotate_xlabels=True, + colormap=None, + hatchmap=None, + label=True, + edgecolor="k", + edgewidth=1.0, + bar_height_fontsize=7, + xticklabel=True, +): + """ data is a dictionary with the following structure: xtick_label -> dict(bar_label->value) - + Example Use: - + data = { "Adelie": { 'Bill Depth': 18.35, 'Bill Length': 38.79, 'Flipper Length': 89.95 @@ -110,25 +123,25 @@ def grouped_barchart(data: dict, ax, bar_width=1.0, group_spacing=3.0, 'Bill Depth': 14.98, 'Bill Length': 47.50, 'Flipper Length': 217.19 } } - + fig, ax = plt.subplots() grouped_barchart(data, ax) ax.legend() - ''' + """ xtick_labels = list(data.keys()) - color_keys = {} # Maps bars to colors - hatch_keys = {} # Maps bars to hatch patterns - + color_keys = {} # Maps bars to colors + hatch_keys = {} # Maps bars to hatch patterns + coord = 0.0 xticks = [] - + if colormap is None: - colormap = plt.get_cmap('tab10') - + colormap = plt.get_cmap("tab10") + for bar_group in data: bars = data[bar_group] xticks.append(coord) - + for i, bar_label in enumerate(bars): is_first_label = False if bar_label not in color_keys: @@ -143,54 +156,95 @@ def grouped_barchart(data: dict, ax, bar_width=1.0, group_spacing=3.0, else: hatch_keys[bar_label] = None - rects = None offset = bar_width * len(bars) / 2.0 - if is_first_label: - rects = ax.bar(coord - offset + bar_width * (i + 0.5), bars[bar_label], label=bar_label, width=bar_width, - color=color_keys[bar_label], edgecolor=edgecolor, linewidth=edgewidth, hatch=hatch_keys[bar_label]) + if is_first_label: + rects = ax.bar( + coord - offset + bar_width * (i + 0.5), + bars[bar_label], + label=bar_label, + width=bar_width, + color=color_keys[bar_label], + edgecolor=edgecolor, + linewidth=edgewidth, + hatch=hatch_keys[bar_label], + ) else: - rects = ax.bar(coord - offset + bar_width * (i + 0.5), bars[bar_label], width=bar_width, - color=color_keys[bar_label], - edgecolor=edgecolor, linewidth=edgewidth, hatch=hatch_keys[bar_label]) + rects = ax.bar( + coord - offset + bar_width * (i + 0.5), + bars[bar_label], + width=bar_width, + color=color_keys[bar_label], + edgecolor=edgecolor, + linewidth=edgewidth, + hatch=hatch_keys[bar_label], + ) if bar_height_fontsize > 0: - ax.bar_label(rects, padding=3, fontsize=bar_height_fontsize) + ax.bar_label(rects, padding=3, fontsize=bar_height_fontsize) coord += group_spacing - if xticklabel: + if xticklabel: if rotate_xlabels: - ax.set_xticks(xticks, labels=xtick_labels, rotation=45, ha='right') + ax.set_xticks(xticks, labels=xtick_labels, rotation=45, ha="right") else: ax.set_xticks(xticks, labels=xtick_labels) - else: - ax.set_xticks(xticks) - - - -def barchart(xlabels, heights, ax, bar_width=1.0, spacing=3.0, rotate_xlabels=True, colormap=None, data_label="_", - edgecolor='k', edgewidth=1.0, bar_height_fontsize=10): - ''' + else: + ax.set_xticks(xticks) + + +def barchart( + xlabels, + heights, + ax, + bar_width=1.0, + spacing=3.0, + rotate_xlabels=True, + colormap=None, + data_label="_", + edgecolor="k", + edgewidth=1.0, + bar_height_fontsize=10, +): + """ Usage: fig, ax = plt.subplots() barchart(["alpha", "beta", "gamma"], [5, 7, 3], ax, data_label="Test") - ''' - assert(len(xlabels) == len(heights)) - + """ + assert len(xlabels) == len(heights) + data = {} for i, xlabel in enumerate(xlabels): data[xlabel] = {data_label: heights[i]} - - grouped_barchart(data, ax, bar_width, spacing, rotate_xlabels, colormap, edgecolor, edgewidth, bar_height_fontsize, label=True) -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.backends.backend_pdf import PdfPages + grouped_barchart( + data, + ax, + bar_width, + spacing, + rotate_xlabels, + colormap, + edgecolor, + edgewidth, + bar_height_fontsize, + label=True, + ) -def roofline_plot(draw_bounds, cpu_roofs, mem_bottlenecks, AI_v, datapoints, compute_unit="TFLOP/s", mem_unit="TB/s", fig_ratio=2, fig_dimension=7): - ''' + +def roofline_plot( + draw_bounds, + cpu_roofs, + mem_bottlenecks, + AI_v, + datapoints, + compute_unit="TFLOP/s", + mem_unit="TB/s", + fig_ratio=2, + fig_dimension=7, +): + """ Example Usage: # Architecture-specific roofs @@ -203,9 +257,14 @@ def roofline_plot(draw_bounds, cpu_roofs, mem_bottlenecks, AI_v, datapoints, com draw_bounds = {"xmin": 1.0, "xmax": 25, "ymin": 0.4, "ymax": 25} fig, ax = roofline_plot(draw_bounds, cpu_roofs, mem_bottlenecks, AI_v, datapoints, fig_ratio=1.8, fig_dimension=5) - ''' - xmin, xmax, ymin, ymax = draw_bounds["xmin"], draw_bounds["xmax"], draw_bounds["ymin"], draw_bounds["ymax"] - + """ + xmin, xmax, ymin, ymax = ( + draw_bounds["xmin"], + draw_bounds["xmax"], + draw_bounds["ymin"], + draw_bounds["ymax"], + ) + fig = plt.figure() ax = plt.subplot(1, 1, 1) @@ -216,8 +275,9 @@ def roofline_plot(draw_bounds, cpu_roofs, mem_bottlenecks, AI_v, datapoints, com ########################################################## # Set size for explicitly setting axes widths/heights def set_size(w, h, ax=None): - """ w, h: width, height in inches """ - if not ax: ax = plt.gca() + """w, h: width, height in inches""" + if not ax: + ax = plt.gca() l = ax.figure.subplotpars.left r = ax.figure.subplotpars.right t = ax.figure.subplotpars.top @@ -233,7 +293,9 @@ def set_size(w, h, ax=None): ylogsize = float(np.log10(ymax / ymin)) m = xlogsize / ylogsize - print(f"Axis limits: 10^[({np.log10(xmin)} -> {np.log10(xmax)}) x ({np.log10(ymin)} -> {np.log10(ymax)})]") + print( + f"Axis limits: 10^[({np.log10(xmin)} -> {np.log10(xmax)}) x ({np.log10(ymin)} -> {np.log10(ymax)})]" + ) print(f"Plot logarithmic ratio: {m}\n") max_roof = max([throughput for throughput in cpu_roofs.values()]) @@ -241,71 +303,102 @@ def set_size(w, h, ax=None): # Draw slopes for mem_roof, slope in mem_bottlenecks.items(): - print(f"slope\t\"{mem_roof}\"\t\t{slope} {mem_unit}") + print(f'slope\t"{mem_roof}"\t\t{slope} {mem_unit}') y = [0, max_roof] x = [float(yy) / slope for yy in y] - ax.loglog(x, y, linewidth=1.0, linestyle='-.', color="grey", zorder=10) + ax.loglog(x, y, linewidth=1.0, linestyle="-.", color="grey", zorder=10) # Label - xpos = xmin * (10 ** (xlogsize * 0.04)) + xpos = xmin * (10 ** (xlogsize * 0.04)) ypos = 1.05 * xpos * slope if ypos < ymin: ypos = ymin * (10 ** (ylogsize * 0.02)) xpos = ypos / slope ax.annotate( - f"{mem_roof}: {slope} {mem_unit}", (xpos, ypos), + f"{mem_roof}: {slope} {mem_unit}", + (xpos, ypos), rotation=np.arctan(m / fig_ratio) * 180 / np.pi, - fontsize=11, ha="left", va='bottom', color="grey" + fontsize=11, + ha="left", + va="bottom", + color="grey", ) # Draw roofs for roof, value in cpu_roofs.items(): - print(f"roof\t\"{roof}\"\t\t{value} {compute_unit}") + print(f'roof\t"{roof}"\t\t{value} {compute_unit}') x = [value / max_slope, xmax * 10] - ax.loglog(x, [value] * len(x), linewidth=1.0, linestyle='-.', color="grey", zorder=10) + ax.loglog( + x, [value] * len(x), linewidth=1.0, linestyle="-.", color="grey", zorder=10 + ) ax.text( - xmax / (10 ** (xlogsize * 0.01)), value * (10 ** (ylogsize * 0.01)), - f"{roof}: {value} {compute_unit}", ha="right", fontsize=11, color="grey" + xmax / (10 ** (xlogsize * 0.01)), + value * (10 ** (ylogsize * 0.01)), + f"{roof}: {value} {compute_unit}", + ha="right", + fontsize=11, + color="grey", ) # Benchmarks for benchmark, AI in AI_v.items(): - print(f"benchmark\t\"{benchmark}\"\t\t{AI} FLOPs/Byte") + print(f'benchmark\t"{benchmark}"\t\t{AI} FLOPs/Byte') plt.axvline(x=AI, dashes=[10, 10, 3, 10], linewidth=0.4, color="#aaaaaa") - ax.text(AI / 1.15, ymin * 1.24, benchmark, fontsize=12, rotation=90, va="bottom", color="#888888") + ax.text( + AI / 1.15, + ymin * 1.24, + benchmark, + fontsize=12, + rotation=90, + va="bottom", + color="#888888", + ) # Datapoints for point in datapoints: AI = point["AI"] if isinstance(AI, str): AI = AI_v[AI] - ax.scatter(AI, point["throughput"], label=point["label"], marker=point["marker"], zorder=100, c=point["color"], s=point["markersize"]) + ax.scatter( + AI, + point["throughput"], + label=point["label"], + marker=point["marker"], + zorder=100, + c=point["color"], + s=point["markersize"], + ) # Set axes limits and layout ax.set_xlim(xmin, xmax) ax.set_ylim(ymin, ymax) fig.tight_layout() set_size(fig_dimension * fig_ratio, fig_dimension) - + return fig, ax -# ============================================================= -plt.rcParams['font.family'] = 'serif' -plt.rcParams.update({'font.size': 11}) +# ============================================================= -labelmap = {"E3NNTensorProduct": "e3nn", "CUETensorProduct": "cuE", "LoopUnrollTP": "ours", - "E3NNTensorProductCompiledCUDAGraphs": "e3nn", - "LoopUnrollConvScatterSum": "fast-scattersum", - "CUEConvolution": "cuE-scattersum", - "CUEConvolutionFused": "cuE-fused", - "LoopUnrollConvDeterministic": "fast-fused-det", "LoopUnrollConvAtomic": "fast-fused-atomic" - } +plt.rcParams["font.family"] = "serif" +plt.rcParams.update({"font.size": 11}) + +labelmap = { + "E3NNTensorProduct": "e3nn", + "CUETensorProduct": "cuE", + "LoopUnrollTP": "ours", + "E3NNTensorProductCompiledCUDAGraphs": "e3nn", + "LoopUnrollConvScatterSum": "fast-scattersum", + "CUEConvolution": "cuE-scattersum", + "CUEConvolutionFused": "cuE-fused", + "LoopUnrollConvDeterministic": "fast-fused-det", + "LoopUnrollConvAtomic": "fast-fused-atomic", +} colormap = {"e3nn": "lightblue", "cuE": "orange", "ours": "g"} for key in ["fast-scattersum", "fast-fused-det", "fast-fused-atomic"]: colormap[key] = colormap["ours"] - + colormap["cuE-scattersum"] = colormap["cuE"] colormap["cuE-fused"] = colormap["cuE"] hatchmap = {"fast-fused-det": "oo", "fast-fused-atomic": "//", "cuE-fused": "//"} @@ -314,8 +407,30 @@ def set_size(w, h, ax=None): dtypes = ["", ""] dtype_labelmap = { "": "float32", - "": "float64"} + "": "float64", +} + def set_grid(ax): ax.set_axisbelow(True) - ax.grid(True) \ No newline at end of file + ax.grid(True) + + +__all__ = [ + "Project", + "impl_to_project_func", + "project_to_color_map", + "project_to_display_order_map", + "calculate_tp_per_sec", + "sort_impls_by_display_order", + "load_benchmarks", + "filter", + "grouped_barchart", + "barchart", + "roofline_plot", + "labelmap", + "colormap", + "directions", + "dtypes", + "dtype_labelmap", +] diff --git a/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/benchmark/random_buffer_utils.py index 286d58fb..41fb7cb6 100644 --- a/openequivariance/benchmark/random_buffer_utils.py +++ b/openequivariance/benchmark/random_buffer_utils.py @@ -2,7 +2,10 @@ from openequivariance.implementations.e3nn_lite import TPProblem -def get_random_buffers_forward(tpp : TPProblem, batch_size : int, prng_seed : int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + +def get_random_buffers_forward( + tpp: TPProblem, batch_size: int, prng_seed: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Return properly sized numpy arrays needed to execute a tensor product in the forward direction Supports shared vs non-shared weights @@ -10,94 +13,166 @@ def get_random_buffers_forward(tpp : TPProblem, batch_size : int, prng_seed : i assert isinstance(tpp, TPProblem) rng = np.random.default_rng(prng_seed) - in1 = np.array(rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype) - in2 = np.array(rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype) - - weights_size = tuple([tpp.weight_numel]) if tpp.shared_weights else tuple([batch_size, tpp.weight_numel]) + in1 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([batch_size, tpp.weight_numel]) + ) weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype) out = np.zeros(shape=(batch_size, tpp.irreps_out.dim), dtype=tpp.weight_dtype) - return in1, in2, weights, out + return in1, in2, weights, out + -def get_random_buffers_backward(tpp : TPProblem, batch_size : int, prng_seed : int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +def get_random_buffers_backward( + tpp: TPProblem, batch_size: int, prng_seed: int +) -> tuple[ + np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray +]: """ Return properly sized numpy arrays needed to execute a tensor product in the backward direction Supports shared vs non-shared weights """ assert isinstance(tpp, TPProblem) rng = np.random.default_rng(prng_seed) - - in1 = np.array(rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype) - in2 = np.array(rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype) - out_grad = np.array(rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype) - weights_size = tuple([tpp.weight_numel]) if tpp.shared_weights else tuple([batch_size, tpp.weight_numel]) + in1 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([batch_size, tpp.weight_numel]) + ) weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) weights_grad = np.zeros_like(weights) in1_grad = np.zeros_like(in1) - in2_grad = np.zeros_like(in2) + in2_grad = np.zeros_like(in2) return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad -def get_random_buffers_double_backward(tpp : TPProblem, batch_size : int, prng_seed : int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + +def get_random_buffers_double_backward( + tpp: TPProblem, batch_size: int, prng_seed: int +) -> tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +]: """ Return properly sized numpy arrays needed to execute a tensor product in the double backward direction Supports shared vs non-shared weights """ assert isinstance(tpp, TPProblem) rng = np.random.default_rng(prng_seed) - - in1 = np.array(rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype) - in2 = np.array(rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype) - out_grad = np.array(rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype) - weights_size = tuple([tpp.weight_numel]) if tpp.shared_weights else tuple([batch_size, tpp.weight_numel]) + in1 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([batch_size, tpp.weight_numel]) + ) weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) weights_grad = np.zeros_like(weights) in1_grad = np.zeros_like(in1) - in2_grad = np.zeros_like(in2) + in2_grad = np.zeros_like(in2) out_double_grad = np.zeros_like(out_grad) - return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad, out_double_grad + return ( + in1, + in2, + out_grad, + weights, + weights_grad, + in1_grad, + in2_grad, + out_double_grad, + ) def get_random_buffers_forward_conv( - tpp : TPProblem, - node_count: int, - edge_count: int, - prng_seed : int): + tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int +): rng = np.random.default_rng(prng_seed) - in1 = np.array(rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype) - in2 = np.array(rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype) - - weights_size = tuple([tpp.weight_numel]) if tpp.shared_weights else tuple([edge_count, tpp.weight_numel]) + in1 = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([edge_count, tpp.weight_numel]) + ) weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype) out = np.zeros(shape=(node_count, tpp.irreps_out.dim), dtype=tpp.weight_dtype) - return in1, in2, weights, out + return in1, in2, weights, out -def get_random_buffers_backward_conv(tpp : TPProblem, node_count: int, edge_count: int, prng_seed : int): +def get_random_buffers_backward_conv( + tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int +): """ Return properly sized numpy arrays needed to execute a tensor product in the backward direction Supports shared vs non-shared weights """ rng = np.random.default_rng(prng_seed) - - in1 = np.array(rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype) - in2 = np.array(rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype) - out_grad = np.array(rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype) - weights_size = tuple([tpp.weight_numel]) if tpp.shared_weights else tuple([edge_count, tpp.weight_numel]) + in1 = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([edge_count, tpp.weight_numel]) + ) weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) weights_grad = np.zeros_like(weights) in1_grad = np.zeros_like(in1) - in2_grad = np.zeros_like(in2) + in2_grad = np.zeros_like(in2) return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad diff --git a/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/benchmark/tpp_creation_utils.py index d3c1caf0..18f3a84c 100644 --- a/openequivariance/benchmark/tpp_creation_utils.py +++ b/openequivariance/benchmark/tpp_creation_utils.py @@ -13,9 +13,7 @@ class FullyConnectedTPProblem(TPProblem): - def __init__( - self, irreps_in1, irreps_in2, irreps_out, **kwargs - ) -> None: + def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs) -> None: irreps_in1 = Irreps(irreps_in1) irreps_in2 = Irreps(irreps_in2) irreps_out = Irreps(irreps_out) @@ -35,6 +33,7 @@ def __init__( **kwargs, ) + class ElementwiseTPProblem(TPProblem): def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None: irreps_in1 = Irreps(irreps_in1).simplify() @@ -43,7 +42,9 @@ def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None try: filter_ir_out = [Irrep(ir) for ir in filter_ir_out] except ValueError: - raise ValueError(f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep") + raise ValueError( + f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" + ) assert irreps_in1.num_irreps == irreps_in2.num_irreps @@ -93,7 +94,9 @@ def __init__( try: filter_ir_out = [Irrep(ir) for ir in filter_ir_out] except ValueError: - raise ValueError(f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep") + raise ValueError( + f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" + ) out = [] instr = [] @@ -110,15 +113,18 @@ def __init__( out = Irreps(out) out, p, _ = out.sort() - instr = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr] + instr = [ + (i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr + ] super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) class ChannelwiseTPP(TPProblem): - ''' + """ Modified from mace/mace/modules/irreps_tools.py. - ''' + """ + def __init__( self, irreps_in1: Irreps, @@ -126,8 +132,8 @@ def __init__( irreps_out: Irreps, label: Optional[str] = None, irrep_dtype=np.float32, - weight_dtype=np.float32): - + weight_dtype=np.float32, + ): trainable = True irreps1 = Irreps(irreps_in1) irreps2 = Irreps(irreps_in2) @@ -153,12 +159,18 @@ def __init__( ] instructions = sorted(instructions, key=lambda x: x[2]) - super().__init__(irreps1, irreps2, irreps_out, instructions, + super().__init__( + irreps1, + irreps2, + irreps_out, + instructions, internal_weights=False, shared_weights=False, label=label, irrep_dtype=irrep_dtype, - weight_dtype=weight_dtype) + weight_dtype=weight_dtype, + ) + class SingleInstruction(TPProblem): def __init__( @@ -167,15 +179,20 @@ def __init__( irreps_in2: Irreps, irreps_in3: Irreps, mode: str, - label: Optional[str] = None): - + label: Optional[str] = None, + ): trainable = True irreps1 = Irreps(irreps_in1) irreps2 = Irreps(irreps_in2) irreps3 = Irreps(irreps_in3) instructions = [(0, 0, 0, mode, trainable)] - super().__init__(irreps1, irreps2, irreps3, instructions, + super().__init__( + irreps1, + irreps2, + irreps3, + instructions, internal_weights=False, shared_weights=False, - label=label) \ No newline at end of file + label=label, + ) diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index 022c17ea..c9816142 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -1,76 +1,112 @@ -import os, warnings, tempfile, warnings +# ruff: noqa : F401, E402 +import os +import warnings from pathlib import Path from openequivariance.benchmark.logging_utils import getLogger oeq_root = str(Path(__file__).parent.parent) -build_ext = True -TORCH_COMPILE=True +build_ext = True +TORCH_COMPILE = True torch_module, generic_module = None, None -postprocess_kernel = lambda kernel: kernel - -if not build_ext: - from openequivariance.extlib.generic_module import * +postprocess_kernel = lambda kernel: kernel # noqa : E731 + +if not build_ext: + from openequivariance.extlib.generic_module import ( + GenericTensorProductImpl, + JITTPImpl, + ConvolutionImpl, + JITConvImpl, + GroupMM_F32, + GroupMM_F64, + DeviceProp, + DeviceBuffer, + GPUTimer, + ) else: - from setuptools import setup from torch.utils.cpp_extension import library_paths, include_paths global torch import torch - extra_cflags=["-O3"] - generic_sources = ['generic_module.cpp'] - torch_sources = ['torch_tp_jit.cpp'] + extra_cflags = ["-O3"] + generic_sources = ["generic_module.cpp"] + torch_sources = ["torch_tp_jit.cpp"] - include_dirs, extra_link_args = ['util'], None + include_dirs, extra_link_args = ["util"], None if torch.version.cuda: - extra_link_args = ['-Wl,--no-as-needed', '-lcuda', '-lcudart', '-lnvrtc'] + extra_link_args = ["-Wl,--no-as-needed", "-lcuda", "-lcudart", "-lnvrtc"] try: - cuda_libs = library_paths('cuda')[1] - extra_link_args.append('-L' + cuda_libs) - if os.path.exists(cuda_libs + '/stubs'): - extra_link_args.append('-L' + cuda_libs + '/stubs') + cuda_libs = library_paths("cuda")[1] + extra_link_args.append("-L" + cuda_libs) + if os.path.exists(cuda_libs + "/stubs"): + extra_link_args.append("-L" + cuda_libs + "/stubs") except Exception as e: getLogger().info(str(e)) extra_cflags.append("-DCUDA_BACKEND") elif torch.version.hip: - extra_link_args = [ '-Wl,--no-as-needed', '-lhiprtc'] + extra_link_args = ["-Wl,--no-as-needed", "-lhiprtc"] def postprocess(kernel): kernel = kernel.replace("__syncwarp();", "__threadfence_block();") kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") - return kernel + return kernel + postprocess_kernel = postprocess extra_cflags.append("-DHIP_BACKEND") - generic_sources = [oeq_root + '/extension/' + src for src in generic_sources] - torch_sources = [oeq_root + '/extension/' + src for src in torch_sources] - include_dirs = [oeq_root + '/extension/' + d for d in include_dirs] + include_paths('cuda') + generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] + torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] + include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths( + "cuda" + ) torch_compile_exception = None with warnings.catch_warnings(): warnings.simplefilter("ignore") try: - torch_module = torch.utils.cpp_extension.load("torch_tp_jit", - torch_sources, extra_cflags=extra_cflags, extra_include_paths=include_dirs, extra_ldflags=extra_link_args) + torch_module = torch.utils.cpp_extension.load( + "torch_tp_jit", + torch_sources, + extra_cflags=extra_cflags, + extra_include_paths=include_dirs, + extra_ldflags=extra_link_args, + ) torch.ops.load_library(torch_module.__file__) except Exception as e: # If compiling torch fails (e.g. low gcc version), we should fall back to the - # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). - TORCH_COMPILE=False + # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). + TORCH_COMPILE = False torch_compile_exception = e - generic_module = torch.utils.cpp_extension.load("generic_module", - generic_sources, extra_cflags=extra_cflags, extra_include_paths=include_dirs, extra_ldflags=extra_link_args) + generic_module = torch.utils.cpp_extension.load( + "generic_module", + generic_sources, + extra_cflags=extra_cflags, + extra_include_paths=include_dirs, + extra_ldflags=extra_link_args, + ) if not TORCH_COMPILE: - warnings.warn("Could not compile integrated PyTorch wrapper. Falling back to Pybind11" + - f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}") - -from generic_module import * \ No newline at end of file + warnings.warn( + "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" + + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}" + ) + + from generic_module import ( + GenericTensorProductImpl, + JITTPImpl, + ConvolutionImpl, + JITConvImpl, + GroupMM_F32, + GroupMM_F64, + DeviceProp, + DeviceBuffer, + GPUTimer, + ) diff --git a/openequivariance/implementations/CUETensorProduct.py b/openequivariance/implementations/CUETensorProduct.py index 9030220a..4011eec4 100644 --- a/openequivariance/implementations/CUETensorProduct.py +++ b/openequivariance/implementations/CUETensorProduct.py @@ -1,19 +1,28 @@ import numpy as np -import tempfile, json, os +import tempfile +import json +import os +import itertools +from typing import Iterator from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import * +from openequivariance.implementations.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.benchmark.tpp_creation_utils import * -from openequivariance.extlib import * +from openequivariance.benchmark.tpp_creation_utils import ( + ChannelwiseTPP, + FullyConnectedTPProblem, + SingleInstruction, +) +from openequivariance.extlib import GPUTimer from openequivariance.implementations.utils import count_cg_non_zero os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1" logger = getLogger() + class CUETensorProduct(TensorProductBase): - def __init__(self, config : TPProblem, torch_op=True): + def __init__(self, config: TPProblem, torch_op=True): super().__init__(config, torch_op=torch_op) global torch @@ -23,20 +32,17 @@ def __init__(self, config : TPProblem, torch_op=True): import e3nn.o3 as o3 # To-do: abstract and place into the TensorProduct class - self.is_uvw = (config.instructions[0].connection_mode == "uvw") + self.is_uvw = config.instructions[0].connection_mode == "uvw" supported_tpp_types = [ ChannelwiseTPP, FullyConnectedTPProblem, - SingleInstruction + SingleInstruction, ] - assert(config.irrep_dtype == config.weight_dtype) + assert config.irrep_dtype == config.weight_dtype - np_to_torch_dtype = { - np.float32: torch.float32, - np.float64: torch.float64 - } + np_to_torch_dtype = {np.float32: torch.float32, np.float64: torch.float64} class O3_e3nn(cue.O3): def __mul__( # pylint: disable=no-self-argument @@ -64,14 +70,17 @@ def __lt__( # pylint: disable=no-self-argument @classmethod def iterator(cls) -> Iterator["O3_e3nn"]: - for l in itertools.count(0): + for l in itertools.count(0): # noqa : E741 yield O3_e3nn(l=l, p=1 * (-1) ** l) yield O3_e3nn(l=l, p=-1 * (-1) ** l) self.cue_tp = None torch_dtype = np_to_torch_dtype[config.irrep_dtype] - assert(any([isinstance(config, supported_ttp_type)] for supported_ttp_type in supported_tpp_types)) + assert any( + [isinstance(config, supported_ttp_type)] + for supported_ttp_type in supported_tpp_types + ) if isinstance(config, ChannelwiseTPP) or isinstance(config, SingleInstruction): self.cue_tp = cuet.ChannelWiseTensorProduct( cue.Irreps(O3_e3nn, str(config.irreps_in1)), @@ -81,17 +90,17 @@ def iterator(cls) -> Iterator["O3_e3nn"]: shared_weights=config.shared_weights, internal_weights=config.internal_weights, dtype=torch_dtype, - math_dtype=torch_dtype + math_dtype=torch_dtype, ) torch._dynamo.config.cache_size_limit = 64 - self.cue_tp.to('cuda') + self.cue_tp.to("cuda") # self.cue_tp = torch.compile(self.cue_tp, fullgraph=True, mode="default") self.tp_correctness = self.cue_tp self.forward = self.cue_tp.__call__ self.forward_correctness = lambda x, y, W: self.tp_correctness(x, y, W) - + if isinstance(config, FullyConnectedTPProblem): e = cue.descriptors.fully_connected_tensor_product( cue.Irreps("O3", str(config.irreps_in1)), @@ -99,65 +108,65 @@ def iterator(cls) -> Iterator["O3_e3nn"]: cue.Irreps("O3", str(config.irreps_out)), ) - assert(config.weight_numel == e.inputs[0].irreps.dim) - self.cue_tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, - math_dtype=np_to_torch_dtype[config.irrep_dtype]) + assert config.weight_numel == e.inputs[0].irreps.dim + self.cue_tp = cuet.EquivariantTensorProduct( + e, layout=cue.ir_mul, math_dtype=np_to_torch_dtype[config.irrep_dtype] + ) - self.cue_tp.to('cuda') + self.cue_tp.to("cuda") self.forward = lambda x, y, W: self.cue_tp(W, x, y) - # Only used for correctness - self.tp_correctness = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, - math_dtype=np_to_torch_dtype[config.irrep_dtype]) - self.tp_correctness.to('cuda') + # Only used for correctness + self.tp_correctness = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, math_dtype=np_to_torch_dtype[config.irrep_dtype] + ) + self.tp_correctness.to("cuda") self.forward_correctness = lambda x, y, W: self.tp_correctness(W, x, y) def forward_cpu( - self, - L1_in : np.ndarray, - L2_in : np.ndarray, - L3_out : np.ndarray, - weights : np.ndarray, - ) -> None: - torch_L1_in = torch.tensor(L1_in, device='cuda') - torch_L2_in = torch.tensor(L2_in, device='cuda') - torch_weights = torch.tensor(weights, device='cuda') - + self, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_out: np.ndarray, + weights: np.ndarray, + ) -> None: + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") + torch_weights = torch.tensor(weights, device="cuda") + torch_L3_out = self.forward_correctness(torch_L1_in, torch_L2_in, torch_weights) L3_out[:] = torch_L3_out.numpy(force=True) def backward_cpu( - self, - L1_in : np.ndarray, - L1_grad : np.ndarray, - L2_in : np.ndarray, - L2_grad : np.ndarray, - L3_grad : np.ndarray, - weights : np.ndarray, - weights_grad : np.ndarray, - ) -> None: - - torch_L1_in = torch.tensor(L1_in, requires_grad=True, device='cuda') - torch_L2_in = torch.tensor(L2_in, requires_grad=True, device='cuda') - torch_weights = torch.tensor(weights, requires_grad=True, device='cuda') + self, + L1_in: np.ndarray, + L1_grad: np.ndarray, + L2_in: np.ndarray, + L2_grad: np.ndarray, + L3_grad: np.ndarray, + weights: np.ndarray, + weights_grad: np.ndarray, + ) -> None: + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor(weights, requires_grad=True, device="cuda") torch_out = self.forward_correctness(torch_L1_in, torch_L2_in, torch_weights) - torch_L3_grad_in = torch.tensor(L3_grad, device='cuda') + torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") torch_out.backward(gradient=torch_L3_grad_in) - + L1_grad[:] = torch_L1_in.grad.numpy(force=True) L2_grad[:] = torch_L2_in.grad.numpy(force=True) weights_grad[:] = torch_weights.grad.numpy(force=True) - def analyze_trace(self, trace_file): - ''' + """ Need to update this function for the uvw case. - ''' - assert not self.is_uvw + """ + assert not self.is_uvw trace = None with open(trace_file, "r") as f: @@ -169,127 +178,171 @@ def analyze_trace(self, trace_file): for event in trace["traceEvents"]: if "args" in event and "stream" in event["args"]: event_time_ms = event["dur"] / 1000 - total += event_time_ms + total += event_time_ms - if "TensorProductUniform1dKernel" in event["name"] \ - or "channelwise_kernel_fwd" in event["name"] \ - or "channelwise_kernel_bwd" in event["name"]: - tp_time += event_time_ms + if ( + "TensorProductUniform1dKernel" in event["name"] + or "channelwise_kernel_fwd" in event["name"] + or "channelwise_kernel_bwd" in event["name"] + ): + tp_time += event_time_ms - return tp_time + return tp_time def benchmark_forward( - self, - num_warmup : int, - num_iter : int, - L1_in : np.ndarray, - L2_in : np.ndarray, - L3_buffer : np.ndarray, - weights : np.ndarray) -> np.ndarray: - ''' + self, + num_warmup: int, + num_iter: int, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_buffer: np.ndarray, + weights: np.ndarray, + ) -> np.ndarray: + """ When we don't want to include torch overhead, we use the Pytorch profiler to extract the device time that the kernel takes. - ''' + """ if self.torch_op: - return super().benchmark_forward(num_warmup, num_iter, L1_in, L2_in, L3_buffer, weights) + return super().benchmark_forward( + num_warmup, num_iter, L1_in, L2_in, L3_buffer, weights + ) else: from torch.profiler import profile, record_function, ProfilerActivity + time_millis = np.zeros(num_iter, dtype=np.float32) - torch_L1_in = torch.tensor(L1_in).to(device='cuda').detach() - torch_L2_in = torch.tensor(L2_in).to(device='cuda').detach() - torch_weights = torch.tensor(weights).to(device='cuda').detach() + torch_L1_in = torch.tensor(L1_in).to(device="cuda").detach() + torch_L2_in = torch.tensor(L2_in).to(device="cuda").detach() + torch_weights = torch.tensor(weights).to(device="cuda").detach() timer = GPUTimer() for i in range(num_warmup): - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) + self.forward(torch_L1_in, torch_L2_in, torch_weights) trace_file = tempfile.NamedTemporaryFile().name for i in range(num_iter): timer.clear_L2_cache() - with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True + ) as prof: with record_function("cue_forward"): - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) + self.forward(torch_L1_in, torch_L2_in, torch_weights) - prof.export_chrome_trace(trace_file) + prof.export_chrome_trace(trace_file) time_millis[i] = self.analyze_trace(trace_file) return time_millis def benchmark_backward( - self, - num_warmup : int, - num_iter : int, - L1_in : np.ndarray, - L2_in : np.ndarray, - L3_buffer : np.ndarray, - weights : np.ndarray, - L1_grad : np.ndarray, - L2_grad : np.ndarray, - weights_grad : np.ndarray - ) -> np.ndarray: + self, + num_warmup: int, + num_iter: int, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_buffer: np.ndarray, + weights: np.ndarray, + L1_grad: np.ndarray, + L2_grad: np.ndarray, + weights_grad: np.ndarray, + ) -> np.ndarray: if self.torch_op: return super().benchmark_backward( - num_warmup, num_iter, - L1_in, L2_in, - L3_buffer, weights, - L1_grad, L2_grad, - weights_grad) + num_warmup, + num_iter, + L1_in, + L2_in, + L3_buffer, + weights, + L1_grad, + L2_grad, + weights_grad, + ) else: from torch.profiler import profile, record_function, ProfilerActivity + time_millis = np.zeros(num_iter, dtype=np.float32) - torch_L1_in = torch.tensor(L1_in, requires_grad=True, device='cuda') - torch_L2_in = torch.tensor(L2_in, requires_grad=True, device='cuda') - torch_weights = torch.tensor(weights, requires_grad=True, device='cuda') + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor(weights, requires_grad=True, device="cuda") torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) - torch_L3_grad_in = torch.tensor(L3_buffer, device='cuda') + torch_L3_grad_in = torch.tensor(L3_buffer, device="cuda") timer = GPUTimer() for i in range(num_warmup): - torch_out.backward(gradient=torch_L3_grad_in, retain_graph=True, inputs=[torch_L1_in, torch_L2_in, torch_weights]) + torch_out.backward( + gradient=torch_L3_grad_in, + retain_graph=True, + inputs=[torch_L1_in, torch_L2_in, torch_weights], + ) trace_file = tempfile.NamedTemporaryFile().name for i in range(num_iter): timer.clear_L2_cache() - with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True + ) as prof: with record_function("cue_backward"): - torch_out.backward(gradient=torch_L3_grad_in, retain_graph=True, inputs=[torch_L1_in, torch_L2_in, torch_weights]) + torch_out.backward( + gradient=torch_L3_grad_in, + retain_graph=True, + inputs=[torch_L1_in, torch_L2_in, torch_weights], + ) prof.export_chrome_trace(trace_file) time_millis[i] = self.analyze_trace(trace_file) return time_millis - # Copied over from loop unroller to match arithmetic intensity on roofline plots - def calculate_flops_forward(self, batch_size : int) -> dict: + # Copied over from loop unroller to match arithmetic intensity on roofline plots + def calculate_flops_forward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_forward(batch_size) else: tpp = self.config - flop_count = {'CG_decomposition': 0, 'linear_combination': 0, 'outer_products': 0} - for ins in tpp.instructions: - l1, l2, l3 = tpp.irreps_in1[ins.i_in1].ir.l, tpp.irreps_in2[ins.i_in2].ir.l, tpp.irreps_out[ins.i_out].ir.l - flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (ins.path_shape[0] * ins.path_shape[1]) - flop_count["linear_combination"] += (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 + flop_count = { + "CG_decomposition": 0, + "linear_combination": 0, + "outer_products": 0, + } + for ins in tpp.instructions: + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) + flop_count["linear_combination"] += ( + (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 + ) flop_count["CG_decomposition"] *= 3 * batch_size - flop_count["linear_combination"] *= batch_size # Weights do not require FMA here + flop_count["linear_combination"] *= ( + batch_size # Weights do not require FMA here + ) flop_count["total"] = sum(flop_count.values()) return flop_count - def calculate_flops_backward(self, batch_size : int) -> dict: + def calculate_flops_backward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_backward(batch_size) else: tpp = self.config - flop_count = {'backward': 0} - for ins in tpp.instructions: - l1, l2, l3 = tpp.irreps_in1[ins.i_in1].ir.l, tpp.irreps_in2[ins.i_in2].ir.l, tpp.irreps_out[ins.i_out].ir.l - flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * (ins.path_shape[0] * ins.path_shape[1]) + flop_count = {"backward": 0} + for ins in tpp.instructions: + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) flop_count["backward"] *= 9 * batch_size flop_count["total"] = sum(flop_count.values()) @@ -297,4 +350,4 @@ def calculate_flops_backward(self, batch_size : int) -> dict: @staticmethod def name(): - return "CUETensorProduct" \ No newline at end of file + return "CUETensorProduct" diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/implementations/ComputationSchedule.py index 54ae5d9c..4f4fe120 100644 --- a/openequivariance/implementations/ComputationSchedule.py +++ b/openequivariance/implementations/ComputationSchedule.py @@ -1,19 +1,23 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import * +from openequivariance.implementations.e3nn_lite import Irreps, TPProblem from itertools import accumulate -from openequivariance.benchmark.logging_utils import * -from openequivariance.implementations.TensorProductBase import * +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.implementations.TensorProductBase import TensorProductBase + logger = getLogger() + class SMEMCapacityException(Exception): def __init__(self, message): self.message = message super().__init__(self.message) + class IrrepMapping: - ''' + """ Maps irreps from a source to a destination set. - ''' + """ + def __init__(self, src_irreps, idxs): self.src_irreps = src_irreps self.idxs = sorted(list(idxs)) @@ -51,22 +55,29 @@ def __init__(self, src_irreps, idxs): self.storeback_procedure = {idx: "write" for idx in self.idxs} + class CGTensor: def __init__(self, l1, l2, l3, normalization_factor, dtype): - suffix_map = { - np.float32: "f", - np.float64: "L" - } + suffix_map = {np.float32: "f", np.float64: "L"} tensor = TensorProductBase.load_cg_tensor(l1, l2, l3) - coord1, coord2, coord3 = [arr.astype(np.int32).copy() for arr in np.nonzero(tensor)] - float_values = tensor[np.nonzero(tensor)].astype(dtype).copy() * normalization_factor - values = [str(float.hex(float(val))) + suffix_map[dtype] for val in float_values] - - self.tuples = [(coord1[i], coord2[i], coord3[i], values[i]) for i in range(len(values))] + coord1, coord2, coord3 = [ + arr.astype(np.int32).copy() for arr in np.nonzero(tensor) + ] + float_values = ( + tensor[np.nonzero(tensor)].astype(dtype).copy() * normalization_factor + ) + values = [ + str(float.hex(float(val))) + suffix_map[dtype] for val in float_values + ] + + self.tuples = [ + (coord1[i], coord2[i], coord3[i], values[i]) for i in range(len(values)) + ] self.tuples.sort(key=lambda tup: (tup[1], tup[0], tup[2])) self.nnz = len(values) + class ComputationSegment: def __init__(self, L1Map, L2Map, L3Map, problem, smem, weight_offset, irrep_dtype): self.L1Map = L1Map @@ -76,23 +87,37 @@ def __init__(self, L1Map, L2Map, L3Map, problem, smem, weight_offset, irrep_dtyp self.problem = problem self.smem = smem - self.weight_offset = weight_offset # Starting point for weights in overall problem. + self.weight_offset = ( + weight_offset # Starting point for weights in overall problem. + ) self.L1 = problem.irreps_in1 self.L2 = problem.irreps_in2 self.L3 = problem.irreps_out - self.interactions = [(u, v, w, - CGTensor(self.L1[u].ir.l, self.L2[v].ir.l, self.L3[w].ir.l, path_weight, irrep_dtype)) - for (u, v, w, _, _, path_weight, _) in problem.instructions] - - #self.interactions.sort(key=lambda x: (x[2], x[0], x[1])) + self.interactions = [ + ( + u, + v, + w, + CGTensor( + self.L1[u].ir.l, + self.L2[v].ir.l, + self.L3[w].ir.l, + path_weight, + irrep_dtype, + ), + ) + for (u, v, w, _, _, path_weight, _) in problem.instructions + ] + + # self.interactions.sort(key=lambda x: (x[2], x[0], x[1])) def create_schedule_case2(instructions, memory_per_warp, calculate_smem, direction): segments = [] - cL1 = set([inst[0] for inst in instructions]) - cL2 = set([inst[1] for inst in instructions]) + cL1 = set([inst[0] for inst in instructions]) + cL2 = set([inst[1] for inst in instructions]) cL3, cinst = set(), [] inst_idx = 0 @@ -100,7 +125,7 @@ def create_schedule_case2(instructions, memory_per_warp, calculate_smem, directi smem_required = None if inst_idx < len(instructions): u, v, w, *others = instructions[inst_idx] - smem_required = calculate_smem(cL1, cL2, cL3 | {w}, cinst + [inst_idx]) + smem_required = calculate_smem(cL1, cL2, cL3 | {w}, cinst + [inst_idx]) else: inst_idx += 1 @@ -109,7 +134,9 @@ def create_schedule_case2(instructions, memory_per_warp, calculate_smem, directi segments.append((cL1, cL2, cL3, cinst)) cL3, cinst = set(), [] else: - raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!") + raise SMEMCapacityException( + f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!" + ) else: cL3.add(w) cinst.append(inst_idx) @@ -117,6 +144,7 @@ def create_schedule_case2(instructions, memory_per_warp, calculate_smem, directi return segments + def create_schedule_case3(instructions, memory_per_warp, calculate_smem, direction): segments = [] cL1, cL2, cL3, cinst = set(), set(), set(), [] @@ -126,7 +154,9 @@ def create_schedule_case3(instructions, memory_per_warp, calculate_smem, directi smem_required = None if inst_idx < len(instructions): u, v, w, *others = instructions[inst_idx] - smem_required = calculate_smem(cL1 | {u}, cL2 | {v}, cL3 | {w}, cinst + [inst_idx]) + smem_required = calculate_smem( + cL1 | {u}, cL2 | {v}, cL3 | {w}, cinst + [inst_idx] + ) else: inst_idx += 1 @@ -135,7 +165,9 @@ def create_schedule_case3(instructions, memory_per_warp, calculate_smem, directi segments.append((cL1, cL2, cL3, cinst)) cL1, cL2, cL3, cinst = set(), set(), set(), [] else: - raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!") + raise SMEMCapacityException( + f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!" + ) else: cL1.add(u) cL2.add(v) @@ -145,20 +177,22 @@ def create_schedule_case3(instructions, memory_per_warp, calculate_smem, directi return segments + class ProblemSplitter: - ''' - Chunks an input problem to produce an output where all + """ + Chunks an input problem to produce an output where all multiplicities lie below a provided threshold. The irreps of the output are "ChildIrreps", and the new - instructions are "ChildInstructions". - ''' + instructions are "ChildInstructions". + """ + class ChildIrrep: def __init__(self, mul_ir, parent_idx, mul_start): - self.mul_ir, self.parent_idx, self.mul_start = mul_ir, parent_idx, mul_start + self.mul_ir, self.parent_idx, self.mul_start = mul_ir, parent_idx, mul_start class ChildInstruction: def __init__(self, instruction_tup, parent_idx): - self.instruction_tup, self.parent_idx = instruction_tup, parent_idx + self.instruction_tup, self.parent_idx = instruction_tup, parent_idx def __init__(self, input, mult_threshold): self.input = input @@ -167,46 +201,86 @@ def __init__(self, input, mult_threshold): input_reps = [input.irreps_in1, input.irreps_in2, input.irreps_out] child_reps = [[], [], []] - self.irrep_maps = {} # Maps a (input_rep_idx #, mul_ir_idx) to a lst[ir_idx] + self.irrep_maps = {} # Maps a (input_rep_idx #, mul_ir_idx) to a lst[ir_idx] - for input_rep_idx, input_rep in enumerate(input_reps): # Loop over L1, L2, L3 - for mul_ir_idx, mul_ir in enumerate(input_rep): # Loop over mul_ir's in each representation + for input_rep_idx, input_rep in enumerate(input_reps): # Loop over L1, L2, L3 + for mul_ir_idx, mul_ir in enumerate( + input_rep + ): # Loop over mul_ir's in each representation self.irrep_maps[input_rep_idx, mul_ir_idx] = [] - for mul_start in range(0, mul_ir.mul, mult_threshold): - mul = min(mult_threshold, mul_ir.mul - mul_start) - child_reps[input_rep_idx] += [self.ChildIrrep((mul, mul_ir.ir), input_rep_idx, mul_start)] - self.irrep_maps[input_rep_idx, mul_ir_idx].append(len(child_reps[input_rep_idx]) - 1) + for mul_start in range(0, mul_ir.mul, mult_threshold): + mul = min(mult_threshold, mul_ir.mul - mul_start) + child_reps[input_rep_idx] += [ + self.ChildIrrep((mul, mul_ir.ir), input_rep_idx, mul_start) + ] + self.irrep_maps[input_rep_idx, mul_ir_idx].append( + len(child_reps[input_rep_idx]) - 1 + ) new_instructions = [] - for inst_idx, (u, v, w, connection_mode, has_weight, path_weight, path_shape) in enumerate(input.instructions): + for inst_idx, ( + u, + v, + w, + connection_mode, + has_weight, + path_weight, + path_shape, + ) in enumerate(input.instructions): if connection_mode == "uvu": for i, idx1 in enumerate(self.irrep_maps[0, u]): for idx2 in self.irrep_maps[1, v]: new_instructions.append( - self.ChildInstruction((idx1, idx2, self.irrep_maps[2, w][i], connection_mode, has_weight, path_weight ** 2), - inst_idx)) + self.ChildInstruction( + ( + idx1, + idx2, + self.irrep_maps[2, w][i], + connection_mode, + has_weight, + path_weight**2, + ), + inst_idx, + ) + ) elif connection_mode == "uvw": for idx1 in self.irrep_maps[0, u]: for idx2 in self.irrep_maps[1, v]: for idx3 in self.irrep_maps[2, w]: new_instructions.append( - self.ChildInstruction( (idx1, idx2, idx3, - connection_mode, has_weight, - path_weight ** 2), inst_idx)) - - self.L1, self.L2, self.L3 = [ Irreps([child.mul_ir for child in reps]) - for reps in child_reps] - self.output = TPProblem(self.L1, self.L2, self.L3, - [child.instruction_tup for child in new_instructions], - irrep_normalization="none", path_normalization="none", - internal_weights=False, shared_weights=input.shared_weights) - - assert(self.output.weight_numel == input.weight_numel) + self.ChildInstruction( + ( + idx1, + idx2, + idx3, + connection_mode, + has_weight, + path_weight**2, + ), + inst_idx, + ) + ) + + self.L1, self.L2, self.L3 = [ + Irreps([child.mul_ir for child in reps]) for reps in child_reps + ] + self.output = TPProblem( + self.L1, + self.L2, + self.L3, + [child.instruction_tup for child in new_instructions], + irrep_normalization="none", + path_normalization="none", + internal_weights=False, + shared_weights=input.shared_weights, + ) + + assert self.output.weight_numel == input.weight_numel # For each new instruction, calculate the subrange of original weights - # that it maps to + # that it maps to for child_inst in new_instructions: u, v, w, connection_mode, _, _ = child_inst.instruction_tup @@ -220,68 +294,83 @@ def __init__(self, input, mult_threshold): w_end = w_start + child_reps[2][w].mul_ir[0] if connection_mode == "uvw": - child_inst.weights_subrange = [slice(u_start, u_end), slice(v_start, v_end), slice(w_start, w_end)] + child_inst.weights_subrange = [ + slice(u_start, u_end), + slice(v_start, v_end), + slice(w_start, w_end), + ] elif connection_mode == "uvu": - child_inst.weights_subrange = [slice(u_start, u_end), slice(v_start, v_end)] + child_inst.weights_subrange = [ + slice(u_start, u_end), + slice(v_start, v_end), + ] elif connection_mode == "uuu": child_inst.weights_subrange = [slice(u_start, u_end)] - child_inst.parent_weights_start, child_inst.parent_weights_end, child_inst.parent_weights_shape = \ - input.weight_range_and_shape_for_instruction(child_inst.parent_idx) + ( + child_inst.parent_weights_start, + child_inst.parent_weights_end, + child_inst.parent_weights_shape, + ) = input.weight_range_and_shape_for_instruction(child_inst.parent_idx) self.new_instructions = new_instructions + class LaunchConfig: def __init__(self, num_blocks, num_threads, warp_size, smem): - self.num_blocks = num_blocks - self.num_threads = num_threads - self.warp_size = warp_size + self.num_blocks = num_blocks + self.num_threads = num_threads + self.warp_size = warp_size self.smem = smem + class ComputationSchedule: - def __init__(self, - config, - smem_limit, - warps_per_block, - warp_size, - block_count, - direction, - irrep_dtype, - weight_dtype, - include_scratch=False, - stream_weights=False, - schedule_type=2, - kahan=False): - ''' - smem_limit: size of available shared memory in bytes - ''' + def __init__( + self, + config, + smem_limit, + warps_per_block, + warp_size, + block_count, + direction, + irrep_dtype, + weight_dtype, + include_scratch=False, + stream_weights=False, + schedule_type=2, + kahan=False, + ): + """ + smem_limit: size of available shared memory in bytes + """ self.kahan = kahan if kahan: assert irrep_dtype == weight_dtype == np.float32 - # Note: does not work with variances for irreps; easy to add that in + # Note: does not work with variances for irreps; easy to add that in self.total_warps = warps_per_block * block_count - dtype_to_str_map = { - np.float32: "float", - np.double: "double" - } + dtype_to_str_map = {np.float32: "float", np.double: "double"} self.irrep_dtype_cstr = dtype_to_str_map[irrep_dtype] self.weight_dtype_cstr = dtype_to_str_map[weight_dtype] - # Stream weights on the fly before pre-loading - self.stream_weights = stream_weights + # Stream weights on the fly before pre-loading + self.stream_weights = stream_weights - # Step 1: Break the irreps and the instructions into chunks + # Step 1: Break the irreps and the instructions into chunks chunk_size = warp_size - if include_scratch: # There is at least one UVW computation if this flag is set. Cap the chunk size to 32. + if include_scratch: # There is at least one UVW computation if this flag is set. Cap the chunk size to 32. chunk_size = 32 self.problem_splitter = ProblemSplitter(config, chunk_size) self.updated_config = self.problem_splitter.output - self.L1, self.L2, self.L3 = self.updated_config.irreps_in1, self.updated_config.irreps_in2, self.updated_config.irreps_out + self.L1, self.L2, self.L3 = ( + self.updated_config.irreps_in1, + self.updated_config.irreps_in2, + self.updated_config.irreps_out, + ) self.new_instructions = self.updated_config.instructions smem_limit -= 1 @@ -289,22 +378,31 @@ def __init__(self, self.memory_per_warp -= self.memory_per_warp % 8 # ===================================================================== - # Shared memory partitioning functions + # Shared memory partitioning functions - def calculate_forward_smem(L1_set, L2_set, L3_set, inst_idxs): + def calculate_forward_smem(L1_set, L2_set, L3_set, inst_idxs): irrep_itemsize = np.dtype(irrep_dtype).itemsize weight_itemsize = np.dtype(weight_dtype).itemsize smem = { - "L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, - "L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, - "L3": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, + "L1": { + "size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, + "L2": { + "size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, + "L3": { + "size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, "L3_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr}, "weights": {"size": 0, "dtype": self.weight_dtype_cstr}, - "scratch": {"size": 0, "dtype": self.weight_dtype_cstr} + "scratch": {"size": 0, "dtype": self.weight_dtype_cstr}, } if kahan: - smem["L3_kahan"]["size"] = smem["L3"]["size"] + smem["L3_kahan"]["size"] = smem["L3"]["size"] else: smem.pop("L3_kahan") @@ -316,47 +414,70 @@ def calculate_forward_smem(L1_set, L2_set, L3_set, inst_idxs): if inst.connection_mode == "uvu": weights_smem += np.prod(inst.path_shape) - smem["weights"]["size"] = weights_smem * weight_itemsize + smem["weights"]["size"] = weights_smem * weight_itemsize - if include_scratch: + if include_scratch: smem["weights"]["size"] = 32 * 32 * weight_itemsize - # Max irrep size of 10 -> dim = 21 - smem["scratch"]["size"] = (32 * 21) * weight_itemsize + # Max irrep size of 10 -> dim = 21 + smem["scratch"]["size"] = (32 * 21) * weight_itemsize - range_offsets = list(accumulate([smem[name]["size"] for name in smem], initial=0)) + range_offsets = list( + accumulate([smem[name]["size"] for name in smem], initial=0) + ) for i, name in enumerate(smem): smem[name]["offset"] = range_offsets[i] - # Pad for alignment, in case we want to perform vectorized loads later + # Pad for alignment, in case we want to perform vectorized loads later smem["total"] = sum([smem[name]["size"] for name in smem]) return smem - - def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, - L2_dgrad=False # Used for double-backward pass - ): + def calculate_backward_smem( + L1_set, + L2_set, + L3_set, + inst_idxs, + L2_dgrad=False, # Used for double-backward pass + ): irrep_itemsize = np.dtype(irrep_dtype).itemsize weight_itemsize = np.dtype(weight_dtype).itemsize smem = { - "L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, - "L1_grad": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, + "L1": { + "size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, + "L1_grad": { + "size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, "L1_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr}, - "L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, - "L2_grad": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, - "L3_grad": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, + "L2": { + "size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, + "L2_grad": { + "size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, + "L3_grad": { + "size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, + "dtype": self.irrep_dtype_cstr, + }, "weights": {"size": 0, "dtype": self.weight_dtype_cstr}, "weights_grad": {"size": 0, "dtype": self.weight_dtype_cstr}, - "scratch": {"size": 0, "dtype": self.weight_dtype_cstr} + "scratch": {"size": 0, "dtype": self.weight_dtype_cstr}, } if kahan: - smem["L1_kahan"]["size"] = smem["L1"]["size"] + smem["L1_kahan"]["size"] = smem["L1"]["size"] else: smem.pop("L1_kahan") if L2_dgrad: - smem["L2_dgrad"] = {"size": smem["L2"]["size"], "dtype": self.irrep_dtype_cstr} + smem["L2_dgrad"] = { + "size": smem["L2"]["size"], + "dtype": self.irrep_dtype_cstr, + } weights_smem = 0 for inst_idx in inst_idxs: @@ -367,44 +488,58 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, weights_smem += np.prod(inst.path_shape) smem["weights"]["size"] = weights_smem * np.dtype(weight_dtype).itemsize - smem["weights_grad"]["size"] = weights_smem * np.dtype(weight_dtype).itemsize + smem["weights_grad"]["size"] = ( + weights_smem * np.dtype(weight_dtype).itemsize + ) - if include_scratch: + if include_scratch: smem["weights"]["size"] = 32 * 32 * weight_itemsize - # We can reuse the weight buffer to accumulate the gradient in shared memory - smem["weights_grad"]["size"] = 0 - # Max irrep size of 10 -> dim = 21 - smem["scratch"]["size"] = (32 * 21) * weight_itemsize - - range_offsets = list(accumulate([smem[name]["size"] for name in smem], initial=0)) + # We can reuse the weight buffer to accumulate the gradient in shared memory + smem["weights_grad"]["size"] = 0 + # Max irrep size of 10 -> dim = 21 + smem["scratch"]["size"] = (32 * 21) * weight_itemsize + + range_offsets = list( + accumulate([smem[name]["size"] for name in smem], initial=0) + ) for i, name in enumerate(smem): smem[name]["offset"] = range_offsets[i] - smem["total"] = sum([smem[name]["size"] for name in smem]) + smem["total"] = sum([smem[name]["size"] for name in smem]) return smem # ===================================================================== # Step 2: Loop through the instructions, assigning them to segments that fit into shared memory - # for a single warp. Could be replaced by a more powerful algorithm. + # for a single warp. Could be replaced by a more powerful algorithm. if direction == "forward": calculate_smem = calculate_forward_smem elif direction == "backward": calculate_smem = calculate_backward_smem elif direction == "double_backward": - calculate_smem = lambda *args, **kwargs: calculate_backward_smem(*args, L2_dgrad=True, **kwargs) + calculate_smem = lambda *args, **kwargs: calculate_backward_smem( # noqa : E731 + *args, L2_dgrad=True, **kwargs + ) schedule2_succeeded = False try: if schedule_type != 2: raise SMEMCapacityException("Asked for schedule case 3.") - self.segments = create_schedule_case2(self.new_instructions, self.memory_per_warp, calculate_smem, direction) - logger.info(f"{direction.title()} case 2 scheduling succeeded with {len(self.segments)} segments.") + self.segments = create_schedule_case2( + self.new_instructions, self.memory_per_warp, calculate_smem, direction + ) + logger.info( + f"{direction.title()} case 2 scheduling succeeded with {len(self.segments)} segments." + ) schedule2_succeeded = True - except SMEMCapacityException as e: - self.segments = create_schedule_case3(self.new_instructions, self.memory_per_warp, calculate_smem, direction) - logger.info(f"{direction.title()} case 3 scheduling succeeded with {len(self.segments)} segments.") + except SMEMCapacityException: + self.segments = create_schedule_case3( + self.new_instructions, self.memory_per_warp, calculate_smem, direction + ) + logger.info( + f"{direction.title()} case 3 scheduling succeeded with {len(self.segments)} segments." + ) for i in range(len(self.segments)): L1_idxs, L2_idxs, L3_idxs, inst_idxs = self.segments[i] @@ -414,36 +549,57 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, L3Map = IrrepMapping(self.L3, L3_idxs) instructions = [ - (L1Map.src_dst_map[inst.i_in1], - L2Map.src_dst_map[inst.i_in2], - L3Map.src_dst_map[inst.i_out], - inst.connection_mode, inst.has_weight, inst.path_weight ** 2) - for inst in [self.new_instructions[idx] for idx in inst_idxs] + ( + L1Map.src_dst_map[inst.i_in1], + L2Map.src_dst_map[inst.i_in2], + L3Map.src_dst_map[inst.i_out], + inst.connection_mode, + inst.has_weight, + inst.path_weight**2, + ) + for inst in [self.new_instructions[idx] for idx in inst_idxs] ] - problem = TPProblem(L1Map.dst_irreps, L2Map.dst_irreps, L3Map.dst_irreps, instructions, - irrep_normalization="none", path_normalization="none", - internal_weights=False, shared_weights=config.shared_weights) + problem = TPProblem( + L1Map.dst_irreps, + L2Map.dst_irreps, + L3Map.dst_irreps, + instructions, + irrep_normalization="none", + path_normalization="none", + internal_weights=False, + shared_weights=config.shared_weights, + ) weight_offset = 0 if i > 0: - weight_offset = self.segments[i-1].weight_offset + self.segments[i-1].problem.weight_numel - - self.segments[i] = ComputationSegment(L1Map, L2Map, L3Map, problem, - calculate_smem(L1_idxs, L2_idxs, L3_idxs, inst_idxs), weight_offset, irrep_dtype) + weight_offset = ( + self.segments[i - 1].weight_offset + + self.segments[i - 1].problem.weight_numel + ) + + self.segments[i] = ComputationSegment( + L1Map, + L2Map, + L3Map, + problem, + calculate_smem(L1_idxs, L2_idxs, L3_idxs, inst_idxs), + weight_offset, + irrep_dtype, + ) for ir_idx, ir in enumerate([self.L1, self.L2, self.L3]): for i in range(len(ir)): irrep_used = False - for seg in self.segments: + for seg in self.segments: if i in seg.maps[ir_idx].idxs: if irrep_used: seg.maps[ir_idx].storeback_procedure[i] = "accumulate" irrep_used = True if schedule2_succeeded: - # Allow L1 and L2 irreps to persist in shared memory + # Allow L1 and L2 irreps to persist in shared memory for i, seg in enumerate(self.segments): for ir_map in [seg.L1Map, seg.L2Map]: if i > 0: @@ -461,28 +617,36 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, num_blocks=block_count, num_threads=warps_per_block * warp_size, warp_size=warp_size, - smem=self.memory_per_warp * warps_per_block) - + smem=self.memory_per_warp * warps_per_block, + ) def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim): - ''' + """ Reorders weights from the canonical e3nn form to the - form that LoopUnrollTP can ingest. Can also reorder the parameters + form that LoopUnrollTP can ingest. Can also reorder the parameters of a dense neural network layer that produces the weight matrix. If has_batch_dim is true, the first dimension of the input weight matrix - is treated as the batch dimension. - ''' + is treated as the batch dimension. + """ weights_out *= 0.0 - assert(direction in ["forward", "backward"]) + assert direction in ["forward", "backward"] for i, child_inst in enumerate(self.problem_splitter.new_instructions): - parent_start, parent_end = child_inst.parent_weights_start, child_inst.parent_weights_end + parent_start, parent_end = ( + child_inst.parent_weights_start, + child_inst.parent_weights_end, + ) parent_shape = list(child_inst.parent_weights_shape) - child_start, child_end, child_shape = self.updated_config.weight_range_and_shape_for_instruction(i) + child_start, child_end, child_shape = ( + self.updated_config.weight_range_and_shape_for_instruction(i) + ) - parent_range, child_range = [slice(parent_start, parent_end)], [slice(child_start, child_end)] - weights_subrange = child_inst.weights_subrange + parent_range, child_range = ( + [slice(parent_start, parent_end)], + [slice(child_start, child_end)], + ) + weights_subrange = child_inst.weights_subrange batch_dim = weights_in.shape[0] reshape_size = [-1] transpose_perm = None @@ -495,17 +659,27 @@ def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim): if has_batch_dim: child_range = [slice(0, batch_dim)] + child_range - parent_range = [slice(0, batch_dim)] + parent_range - parent_shape = [batch_dim] + parent_shape + parent_range = [slice(0, batch_dim)] + parent_range + parent_shape = [batch_dim] + parent_shape child_shape = [batch_dim] + list(child_shape) weights_subrange = [slice(0, batch_dim)] + child_inst.weights_subrange reshape_size = [batch_dim] + reshape_size transpose_perm = [0] + [i + 1 for i in transpose_perm] if direction == "forward": - sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[tuple(weights_subrange)] - weights_out[tuple(child_range)] = sliced_weights.transpose(transpose_perm).reshape(reshape_size) + sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[ + tuple(weights_subrange) + ] + weights_out[tuple(child_range)] = sliced_weights.transpose( + transpose_perm + ).reshape(reshape_size) elif direction == "backward": transpose_child_shape = [child_shape[i] for i in transpose_perm] - sliced_weights = weights_in[tuple(child_range)].reshape(transpose_child_shape).transpose(transpose_perm) - weights_out[tuple(parent_range)].reshape(parent_shape)[tuple(weights_subrange)] = sliced_weights.flatten().reshape(child_shape) \ No newline at end of file + sliced_weights = ( + weights_in[tuple(child_range)] + .reshape(transpose_child_shape) + .transpose(transpose_perm) + ) + weights_out[tuple(parent_range)].reshape(parent_shape)[ + tuple(weights_subrange) + ] = sliced_weights.flatten().reshape(child_shape) diff --git a/openequivariance/implementations/E3NNTensorProduct.py b/openequivariance/implementations/E3NNTensorProduct.py index 179f2451..334ba65c 100644 --- a/openequivariance/implementations/E3NNTensorProduct.py +++ b/openequivariance/implementations/E3NNTensorProduct.py @@ -1,6 +1,11 @@ -__all__ = ['E3NNTensorProduct', 'E3NNTensorProductCompiled', 'E3NNTensorProductCompiledCUDAGraphs', 'E3NNTensorProductCompiledMaxAutotuneCUDAGraphs'] - -import os +__all__ = [ + "E3NNTensorProduct", + "E3NNTensorProductCompiled", + "E3NNTensorProductCompiledCUDAGraphs", + "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs", +] + +import os import pathlib import numpy as np @@ -8,80 +13,82 @@ from openequivariance.implementations.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger -TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path('triton_autotuning') +TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") logger = getLogger() + class E3NNTensorProduct(TensorProductBase): - def __init__(self, config : TPProblem, torch_op=True): + def __init__(self, config: TPProblem, torch_op=True): super().__init__(config, torch_op=torch_op) - assert(self.torch_op) - + assert self.torch_op + global torch global e3nn import torch - import e3nn + import e3nn from e3nn import o3 + e3nn.set_optimization_defaults(jit_script_fx=False) - assert(config.irrep_dtype == config.weight_dtype) + assert config.irrep_dtype == config.weight_dtype if config.irrep_dtype == np.float64: torch.set_default_dtype(torch.float64) self.e3nn_tp = o3.TensorProduct( - config.irreps_in1, - config.irreps_in2, - config.irreps_out, - config.instructions_raw, - in1_var=config.in1_var, - in2_var=config.in2_var, - out_var=config.out_var, - irrep_normalization=config.irrep_normalization, - path_normalization=config.path_normalization, - internal_weights=config.internal_weights, - shared_weights=config.shared_weights).to(device='cuda') + config.irreps_in1, + config.irreps_in2, + config.irreps_out, + config.instructions_raw, + in1_var=config.in1_var, + in2_var=config.in2_var, + out_var=config.out_var, + irrep_normalization=config.irrep_normalization, + path_normalization=config.path_normalization, + internal_weights=config.internal_weights, + shared_weights=config.shared_weights, + ).to(device="cuda") if config.irrep_dtype == np.float64: torch.set_default_dtype(torch.float32) # Reset to default - self.forward = self.e3nn_tp.__call__ + self.forward = self.e3nn_tp.__call__ def forward_cpu( - self, - L1_in : np.ndarray, - L2_in : np.ndarray, - L3_out : np.ndarray, - weights : np.ndarray, - ) -> None: - torch_L1_in = torch.tensor(L1_in, device='cuda') - torch_L2_in = torch.tensor(L2_in, device='cuda') - torch_weights = torch.tensor(weights, device='cuda') - + self, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_out: np.ndarray, + weights: np.ndarray, + ) -> None: + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") + torch_weights = torch.tensor(weights, device="cuda") + torch_L3_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) L3_out[:] = torch_L3_out.numpy(force=True) def backward_cpu( - self, - L1_in : np.ndarray, - L1_grad : np.ndarray, - L2_in : np.ndarray, - L2_grad : np.ndarray, - L3_grad : np.ndarray, - weights : np.ndarray, - weights_grad : np.ndarray, - ) -> None: - - torch_L1_in = torch.tensor(L1_in, requires_grad=True, device='cuda') - torch_L2_in = torch.tensor(L2_in, requires_grad=True, device='cuda') - torch_weights = torch.tensor(weights, requires_grad=True, device='cuda') + self, + L1_in: np.ndarray, + L1_grad: np.ndarray, + L2_in: np.ndarray, + L2_grad: np.ndarray, + L3_grad: np.ndarray, + weights: np.ndarray, + weights_grad: np.ndarray, + ) -> None: + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor(weights, requires_grad=True, device="cuda") torch_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) - torch_L3_grad_in = torch.tensor(L3_grad, device='cuda') + torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") torch_out.backward(gradient=torch_L3_grad_in) - + L1_grad[:] = torch_L1_in.grad.numpy(force=True) L2_grad[:] = torch_L2_in.grad.numpy(force=True) weights_grad[:] = torch_weights.grad.numpy(force=True) @@ -92,52 +99,56 @@ def name(cls): class E3NNTensorProductCompiled(E3NNTensorProduct): - def __init__(self, config : TPProblem, torch_compile_kwargs : dict, torch_op : bool = True, ): - super().__init__(config, torch_op = torch_op) + def __init__( + self, + config: TPProblem, + torch_compile_kwargs: dict, + torch_op: bool = True, + ): + super().__init__(config, torch_op=torch_op) self.torch_compile_kwargs = torch_compile_kwargs - - logger.debug('Torch compiling e3nn TP') - logger.debug(msg=f'{torch_compile_kwargs}') - self.e3nn_tp = torch.compile(self.e3nn_tp, - **self.torch_compile_kwargs) - logger.debug('e3nn TP torch compiled') + + logger.debug("Torch compiling e3nn TP") + logger.debug(msg=f"{torch_compile_kwargs}") + self.e3nn_tp = torch.compile(self.e3nn_tp, **self.torch_compile_kwargs) + logger.debug("e3nn TP torch compiled") self.forward = self.e3nn_tp.__call__ - + + class E3NNTensorProductCompiledCUDAGraphs(E3NNTensorProductCompiled): - def __init__(self, config : TPProblem, torch_op=True): - + def __init__(self, config: TPProblem, torch_op=True): global torch import torch - + torch._dynamo.config.cache_size_limit = 64 - + torch_compile_kwargs = { - 'fullgraph':True, - 'backend': 'inductor', - 'options': { 'triton.cudagraphs': True} + "fullgraph": True, + "backend": "inductor", + "options": {"triton.cudagraphs": True}, } super().__init__(config, torch_compile_kwargs, torch_op=torch_op) + class E3NNTensorProductCompiledMaxAutotuneCUDAGraphs(E3NNTensorProductCompiled): - def __init__(self, config : TPProblem, torch_op=True): + def __init__(self, config: TPProblem, torch_op=True): global torch import torch TORCH_COMPILE_AUTOTUNING_DIR.mkdir(exist_ok=True) - os.environ['TORCHINDUCTOR_CACHE_DIR'] = str(TORCH_COMPILE_AUTOTUNING_DIR) - os.environ['TRITON_CACHE_DIR'] = str(TORCH_COMPILE_AUTOTUNING_DIR) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(TORCH_COMPILE_AUTOTUNING_DIR) + os.environ["TRITON_CACHE_DIR"] = str(TORCH_COMPILE_AUTOTUNING_DIR) torch._dynamo.config.cache_size_limit = 64 torch_compile_kwargs = { - 'fullgraph':True, - 'backend': 'inductor', - 'options': - { - 'max_autotune':True, - 'triton.cudagraphs':True, - 'triton.unique_kernel_names':False, - 'coordinate_descent_tuning':False, + "fullgraph": True, + "backend": "inductor", + "options": { + "max_autotune": True, + "triton.cudagraphs": True, + "triton.unique_kernel_names": False, + "coordinate_descent_tuning": False, }, } - super().__init__(config, torch_compile_kwargs, torch_op=torch_op) \ No newline at end of file + super().__init__(config, torch_compile_kwargs, torch_op=torch_op) diff --git a/openequivariance/implementations/LoopUnrollTP.py b/openequivariance/implementations/LoopUnrollTP.py index 39930d05..b92f8d0c 100644 --- a/openequivariance/implementations/LoopUnrollTP.py +++ b/openequivariance/implementations/LoopUnrollTP.py @@ -1,18 +1,22 @@ import numpy as np import openequivariance.extlib as extlib -from openequivariance.templates.jinja_utils import * -from openequivariance.implementations.ComputationSchedule import ComputationSchedule +from openequivariance.templates.jinja_utils import get_jinja_environment +from openequivariance.implementations.ComputationSchedule import ComputationSchedule + +from openequivariance.implementations.TensorProductBase import TensorProductBase +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.implementations.utils import ( + filter_and_analyze_problem, + count_cg_non_zero, +) -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.utils import filter_and_analyze_problem, count_cg_non_zero logger = getLogger() + class LoopUnrollTP(TensorProductBase): def __init__(self, config, torch_op=True): super().__init__(config, torch_op=torch_op) - L1, L2, L3 = self.L1, self.L2, self.L3 env = get_jinja_environment() template = env.get_template("loop_unroll_batch.cuh") @@ -22,43 +26,53 @@ def __init__(self, config, torch_op=True): self.is_uvw = analysis["is_uvw"] def generate_forward_schedule(warps_per_block): - self.forward_schedule = ComputationSchedule(self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount * 4, - direction = "forward", - irrep_dtype = config.irrep_dtype, - weight_dtype = config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw) + self.forward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount * 4, + direction="forward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + ) def generate_backward_schedule(warps_per_block): - self.backward_schedule = ComputationSchedule(self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount * 4, - direction = "backward", - irrep_dtype = config.irrep_dtype, - weight_dtype = config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw) + self.backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount * 4, + direction="backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + ) def generate_double_backward_schedule(warps_per_block): - self.double_backward_schedule = ComputationSchedule(self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount, - direction = "double_backward", - irrep_dtype = config.irrep_dtype, - weight_dtype = config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - schedule_type=3) - - scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule] + self.double_backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount, + direction="double_backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + schedule_type=3, + ) + + scheduler_generators = [ + generate_forward_schedule, + generate_backward_schedule, + generate_double_backward_schedule, + ] for generate_schedule in scheduler_generators: warp_count = 8 @@ -66,17 +80,22 @@ def generate_double_backward_schedule(warps_per_block): try: generate_schedule(warp_count) break - except Exception as e: + except Exception: warp_count -= 2 if warp_count == 0: - raise RuntimeError("Tensor product schedule generation failed, shared memory inadequate!") - - self.jit_kernel = extlib.postprocess_kernel(template.render( - forward_schedule=self.forward_schedule, - backward_schedule=self.backward_schedule, - double_backward_schedule=self.double_backward_schedule)) - - #with open("scratch.txt", "w") as f: + raise RuntimeError( + "Tensor product schedule generation failed, shared memory inadequate!" + ) + + self.jit_kernel = extlib.postprocess_kernel( + template.render( + forward_schedule=self.forward_schedule, + backward_schedule=self.backward_schedule, + double_backward_schedule=self.double_backward_schedule, + ) + ) + + # with open("scratch.txt", "w") as f: # f.write(self.jit_kernel) internal_cls = None @@ -89,13 +108,17 @@ def generate_double_backward_schedule(warps_per_block): internal_cls = extlib.JITTPImpl logger.info("Starting kernel compiler...") - self.internal = internal_cls(self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - {"L3_dim": self.L3.dim, - "shared_weights": int(self.config.shared_weights), - "is_uvw": int(self.is_uvw)}) + self.internal = internal_cls( + self.jit_kernel, + vars(self.forward_schedule.launch_config), + vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), + { + "L3_dim": self.L3.dim, + "shared_weights": int(self.config.shared_weights), + "is_uvw": int(self.is_uvw), + }, + ) logger.info("Kernel compiled!") logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") @@ -103,10 +126,16 @@ def generate_double_backward_schedule(warps_per_block): if self.torch_op: self.setup_torch_custom_op() - self.reorder_weights_e3nn_to_oeq = lambda input, output, has_batch_dim: \ - self.forward_schedule.reorder_weights(input, output, "forward", has_batch_dim) - self.reorder_weights_oeq_to_e3nn = lambda input, output, has_batch_dim: \ - self.forward_schedule.reorder_weights(input, output, "backward", has_batch_dim) + self.reorder_weights_e3nn_to_oeq = ( + lambda input, output, has_batch_dim: self.forward_schedule.reorder_weights( + input, output, "forward", has_batch_dim + ) + ) + self.reorder_weights_oeq_to_e3nn = ( + lambda input, output, has_batch_dim: self.forward_schedule.reorder_weights( + input, output, "backward", has_batch_dim + ) + ) @classmethod def register_torch_fakes(cls): @@ -115,12 +144,27 @@ def register_torch_fakes(cls): @torch._library.register_fake_class("torch_tp_jit::TorchJITProduct") class TorchJITProduct: - def __init__(self, kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int]) -> None: - self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, dbl_bwd_config, kernel_dims + def __init__( + self, + kernel_plaintext: str, + fwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], + kernel_dims: dict[str, int], + ) -> None: + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = ( + kernel_plaintext, + fwd_config, + bwd_config, + dbl_bwd_config, + kernel_dims, + ) @classmethod def __obj_unflatten__(cls, flattened_product): @@ -128,98 +172,129 @@ def __obj_unflatten__(cls, flattened_product): def __len__(self): return 0 - + def __setstate__(self, state): self.kernel_plaintext = state["kernel_plaintext"] self.fwd_config = state["fwd_config"] self.bwd_config = state["bwd_config"] self.dbl_bwd_config = state["dbl_bwd_config"] - self.kernel_dims = state["kernel_dims"] - - def exec_tensor_product_rawptr(self, - batch : int, - L1_in: int, L2_in: int, L3_out: int, - weights: int) -> None: + self.kernel_dims = state["kernel_dims"] + + def exec_tensor_product_rawptr( + self, batch: int, L1_in: int, L2_in: int, L3_out: int, weights: int + ) -> None: pass - def backward_rawptr(self, batch_size: int, - L1_in: int, L1_grad: int, - L2_in: int, L2_grad: int, - weights: int, weights_grad: int, - L3_grad: int): + def backward_rawptr( + self, + batch_size: int, + L1_in: int, + L1_grad: int, + L2_in: int, + L2_grad: int, + weights: int, + weights_grad: int, + L3_grad: int, + ): pass @torch.library.register_fake("torch_tp_jit::jit_tp_forward") def fake_forward(jit, L1_in, L2_in, W): - return L1_in.new_empty(L1_in.shape[0], jit.wrapped_obj.kernel_dims["L3_dim"]) + return L1_in.new_empty( + L1_in.shape[0], jit.wrapped_obj.kernel_dims["L3_dim"] + ) @torch.library.register_fake("torch_tp_jit::jit_tp_backward") def fake_backward(jit, L1_in, L2_in, W, L3_grad): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) @classmethod def register_autograd(cls): - forward_op = torch.ops.torch_tp_jit.jit_tp_forward backward_op = torch.ops.torch_tp_jit.jit_tp_backward def setup_context(ctx, inputs, output): ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs - + def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad= backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output) - return None, L1_grad, L2_grad, W_grad + L1_grad, L2_grad, W_grad = backward_op( + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output + ) + return None, L1_grad, L2_grad, W_grad - torch.library.register_autograd("torch_tp_jit::jit_tp_forward", backward, setup_context=setup_context) + torch.library.register_autograd( + "torch_tp_jit::jit_tp_forward", backward, setup_context=setup_context + ) def setup_context_double_backward(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs def double_backward(ctx, E, F, G): - result = torch.ops.torch_tp_jit.jit_tp_double_backward(ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.weights, - ctx.L3_grad, - E, F, G) + result = torch.ops.torch_tp_jit.jit_tp_double_backward( + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G + ) return None, result[0], result[1], result[2], result[3] - torch.library.register_autograd("torch_tp_jit::jit_tp_backward", double_backward, setup_context=setup_context_double_backward) - + torch.library.register_autograd( + "torch_tp_jit::jit_tp_backward", + double_backward, + setup_context=setup_context_double_backward, + ) @staticmethod def name(): return "LoopUnrollTP" - - def calculate_flops_forward(self, batch_size : int) -> dict: + + def calculate_flops_forward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_forward(batch_size) else: tpp = self.config - flop_count = {'CG_decomposition': 0, 'linear_combination': 0, 'outer_products': 0} - for ins in tpp.instructions: - l1, l2, l3 = tpp.irreps_in1[ins.i_in1].ir.l, tpp.irreps_in2[ins.i_in2].ir.l, tpp.irreps_out[ins.i_out].ir.l - flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (ins.path_shape[0] * ins.path_shape[1]) - flop_count["linear_combination"] += (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 + flop_count = { + "CG_decomposition": 0, + "linear_combination": 0, + "outer_products": 0, + } + for ins in tpp.instructions: + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) + flop_count["linear_combination"] += ( + (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0 + ) flop_count["CG_decomposition"] *= 3 * batch_size - flop_count["linear_combination"] *= batch_size # Weights do not require FMA here + flop_count["linear_combination"] *= ( + batch_size # Weights do not require FMA here + ) flop_count["total"] = sum(flop_count.values()) return flop_count - def calculate_flops_backward(self, batch_size : int) -> dict: + def calculate_flops_backward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_backward(batch_size) else: tpp = self.config - flop_count = {'backward': 0} - for ins in tpp.instructions: - l1, l2, l3 = tpp.irreps_in1[ins.i_in1].ir.l, tpp.irreps_in2[ins.i_in2].ir.l, tpp.irreps_out[ins.i_out].ir.l - flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * (ins.path_shape[0] * ins.path_shape[1]) + flop_count = {"backward": 0} + for ins in tpp.instructions: + l1, l2, l3 = ( + tpp.irreps_in1[ins.i_in1].ir.l, + tpp.irreps_in2[ins.i_in2].ir.l, + tpp.irreps_out[ins.i_out].ir.l, + ) + flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * ( + ins.path_shape[0] * ins.path_shape[1] + ) flop_count["backward"] *= 9 * batch_size flop_count["total"] = sum(flop_count.values()) return flop_count - -if extlib.TORCH_COMPILE: - LoopUnrollTP.register_torch_fakes() - LoopUnrollTP.register_autograd() \ No newline at end of file + + +if extlib.TORCH_COMPILE: + LoopUnrollTP.register_torch_fakes() + LoopUnrollTP.register_autograd() diff --git a/openequivariance/implementations/MultiplicityOuterProductTP.py b/openequivariance/implementations/MultiplicityOuterProductTP.py index 63de8328..8cc3a298 100644 --- a/openequivariance/implementations/MultiplicityOuterProductTP.py +++ b/openequivariance/implementations/MultiplicityOuterProductTP.py @@ -1,20 +1,27 @@ import numpy as np from openequivariance.implementations.utils import calc_weight_offsets -from openequivariance.implementations.e3nn_lite import Irrep, _MulIr, Irreps, TPProblem, Instruction -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.implementations.e3nn_lite import ( + Irreps, + TPProblem, + Instruction, +) +from openequivariance.implementations.TensorProductBase import TensorProductBase +from openequivariance.benchmark.logging_utils import getLogger from jinja2 import Environment, PackageLoader from openequivariance.extlib import KernelLaunchConfig, JITTPImpl, DeviceProp logger = getLogger() + def raise_helper(msg): raise Exception(msg) + def divide(numerator, denominator): - return numerator // denominator + return numerator // denominator + def sizeof(dtype): if dtype in ["float", "int", "unsigned int"]: @@ -22,39 +29,42 @@ def sizeof(dtype): else: raise Exception("Provided undefined datatype to sizeof!") + class MultiplicityOuterProductTP(TensorProductBase): - def __init__(self, config : TPProblem, torch_op : bool = False): + def __init__(self, config: TPProblem, torch_op: bool = False): super().__init__(config, torch_op) - for ins in config.instructions: # type : Instruction + for ins in config.instructions: # type : Instruction assert isinstance(ins, Instruction) - assert ins.connection_mode == 'uvw' + assert ins.connection_mode == "uvw" assert ins.path_shape[0] <= 32 assert ins.path_shape[1] <= 32 assert ins.path_shape[2] <= 32 irreps_in1 = config.irreps_in1 irreps_in2 = config.irreps_in2 - irreps_out = config.irreps_out + irreps_out = config.irreps_out # ================================================================================== - env = Environment(loader=PackageLoader("openequivariance"), extensions=['jinja2.ext.do']) - env.globals['raise'] = raise_helper - env.globals['divide'] = divide - env.globals['sizeof'] = sizeof - env.globals['range'] = range - env.globals['enumerate'] = enumerate - env.globals['len'] = len + env = Environment( + loader=PackageLoader("openequivariance"), extensions=["jinja2.ext.do"] + ) + env.globals["raise"] = raise_helper + env.globals["divide"] = divide + env.globals["sizeof"] = sizeof + env.globals["range"] = range + env.globals["enumerate"] = enumerate + env.globals["len"] = len main_template = env.get_template("subkernel_per_interaction_multirep.cuh") # forward_subkernel_template = env.get_template("subkernel_forward_thread.cu.jinja2") # backward_subkernel_template = env.get_template("subkernel_backward_thread.cu.jinja2") - + # ===================================================================== # Updated to work with TensorProductProblem - + class RepData: - def __init__(self, irreps : Irreps): + def __init__(self, irreps: Irreps): assert isinstance(irreps, Irreps) self.rep_len = irreps.dim self.irrep_lengths = [mul_irrep.ir.dim for mul_irrep in irreps] @@ -62,7 +72,7 @@ def __init__(self, irreps : Irreps): offset = 0 self.offsets = [] - for mul_irrep in irreps: + for mul_irrep in irreps: self.offsets.append(offset) offset += mul_irrep.dim @@ -72,65 +82,96 @@ def __init__(self, irreps : Irreps): class CGTensor: def __init__(self, l1, l2, l3): tensor = load_cg_tensor(l1, l2, l3) - coord1, coord2, coord3 = [arr.astype(np.int32).copy() for arr in np.nonzero(tensor)] + coord1, coord2, coord3 = [ + arr.astype(np.int32).copy() for arr in np.nonzero(tensor) + ] float_values = tensor[np.nonzero(tensor)].astype(np.float32).copy() values = [str(float.hex(float(val))) + "f" for val in float_values] - self.tuples = [(coord1[i], coord2[i], coord3[i], values[i]) for i in range(len(values))] - # self.tuples.sort(key=lambda tup: (tup[1], tup[0], tup[2])) + self.tuples = [ + (coord1[i], coord2[i], coord3[i], values[i]) + for i in range(len(values)) + ] + # self.tuples.sort(key=lambda tup: (tup[1], tup[0], tup[2])) self.nnz = len(values) # ===================================================================== - # FORWARD MEMORY ANALYSIS - forward_thread_blocks_per_SM = 24 + # FORWARD MEMORY ANALYSIS + forward_thread_blocks_per_SM = 24 forward_threads_per_thread_block = 32 # ===================================================================== dp = DeviceProp(0) forward_launch_config = KernelLaunchConfig() - forward_launch_config.num_blocks = dp.multiprocessorCount * forward_thread_blocks_per_SM + forward_launch_config.num_blocks = ( + dp.multiprocessorCount * forward_thread_blocks_per_SM + ) forward_launch_config.num_threads = forward_threads_per_thread_block - # IMPORTANT! + # IMPORTANT! smem_gemm_max_n = forward_threads_per_thread_block - smem_gemm_L3_scratch = smem_gemm_max_n * max(RepData(config.irreps_out).irrep_lengths) # this has space for the largest output size * 32 - smem_gemm_weights_scratch = max(RepData(config.irreps_out).mults) * smem_gemm_max_n + smem_gemm_L3_scratch = smem_gemm_max_n * max( + RepData(config.irreps_out).irrep_lengths + ) # this has space for the largest output size * 32 + smem_gemm_weights_scratch = ( + max(RepData(config.irreps_out).mults) * smem_gemm_max_n + ) smem_gemm_info = { - 'n' : smem_gemm_max_n, - 'L3_scratch_elems' : smem_gemm_L3_scratch, - 'weight_scratch_elems' : smem_gemm_weights_scratch, + "n": smem_gemm_max_n, + "L3_scratch_elems": smem_gemm_L3_scratch, + "weight_scratch_elems": smem_gemm_weights_scratch, } logger.debug(smem_gemm_info) # END OF IMPORTANT forward_launch_config.smem = ( - (irreps_in1.dim + irreps_in2.dim + irreps_out.dim + smem_gemm_L3_scratch + smem_gemm_weights_scratch) - * sizeof("float") - * forward_launch_config.num_threads // forward_launch_config.warp_size - ) + ( + irreps_in1.dim + + irreps_in2.dim + + irreps_out.dim + + smem_gemm_L3_scratch + + smem_gemm_weights_scratch + ) + * sizeof("float") + * forward_launch_config.num_threads + // forward_launch_config.warp_size + ) - logger.info(f"Forward pass needs {forward_launch_config.smem} bytes of shared memory.") + logger.info( + f"Forward pass needs {forward_launch_config.smem} bytes of shared memory." + ) if forward_launch_config.smem > dp.maxSharedMemPerBlock: - raise Exception(f"Error, requested shared memory {forward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !") - + raise Exception( + f"Error, requested shared memory {forward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !" + ) + # ===================================================================== backward_launch_config = KernelLaunchConfig() backward_launch_config.num_blocks = dp.multiprocessorCount * 1 backward_launch_config.num_threads = 32 - backward_launch_config.smem = (2 * irreps_in1.dim + 2 * irreps_in2.dim + 2 * + irreps_out.dim) * sizeof("float") * backward_launch_config.num_threads // backward_launch_config.warp_size - logger.info(f"Backward pass needs {backward_launch_config.smem} bytes of shared memory.") + backward_launch_config.smem = ( + (2 * irreps_in1.dim + 2 * irreps_in2.dim + 2 * +irreps_out.dim) + * sizeof("float") + * backward_launch_config.num_threads + // backward_launch_config.warp_size + ) + logger.info( + f"Backward pass needs {backward_launch_config.smem} bytes of shared memory." + ) if backward_launch_config.smem > dp.maxSharedMemPerBlock: - raise Exception(f"Error, requested shared memory {backward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !") + raise Exception( + f"Error, requested shared memory {backward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !" + ) - # ===================================================================== + # ===================================================================== self.forward_config = forward_launch_config - self.backward_config = backward_launch_config + self.backward_config = backward_launch_config load_cg_tensor = self.load_cg_tensor # ===================================================================== @@ -141,23 +182,27 @@ def __init__(self, l1, l2, l3): # ===================================================================== # tranform "e3nn instructions" into "interactions" - instructions : list[Instruction] = config.instructions + instructions: list[Instruction] = config.instructions interactions = [] for ins in instructions: u = ins.i_in1 v = ins.i_in2 w = ins.i_out - interaction = (u, v, w, CGTensor(irreps_in1[u].ir.l, irreps_in2[v].ir.l, irreps_out[w].ir.l)) + interaction = ( + u, + v, + w, + CGTensor(irreps_in1[u].ir.l, irreps_in2[v].ir.l, irreps_out[w].ir.l), + ) interactions.append(interaction) # interactions.sort(key=lambda x: (x[2], x[0], x[1])) - assert len(interactions) != 0 # ===================================================================== kernel_text = main_template.render( - L1=RepData(config.irreps_in1), - L2=RepData(config.irreps_in2), + L1=RepData(config.irreps_in1), + L2=RepData(config.irreps_in2), L3=RepData(config.irreps_out), weight_numel=config.weight_numel, weight_offsets=weight_offsets, @@ -165,15 +210,17 @@ def __init__(self, l1, l2, l3): interactions=interactions, smem_gemm_info=smem_gemm_info, forward_config=forward_launch_config, - backward_config=backward_launch_config + backward_config=backward_launch_config, ) self.jit_kernel = kernel_text - + logger.debug(kernel_text) logger.info("Starting NVRTC") - self.internal = JITTPImpl(self.jit_kernel, self.forward_config, self.backward_config) + self.internal = JITTPImpl( + self.jit_kernel, self.forward_config, self.backward_config + ) logger.info("Kernel compiled!") if self.torch_op: diff --git a/openequivariance/implementations/TensorProduct.py b/openequivariance/implementations/TensorProduct.py index 2068378c..6ae08b5e 100644 --- a/openequivariance/implementations/TensorProduct.py +++ b/openequivariance/implementations/TensorProduct.py @@ -1,12 +1,13 @@ -from openequivariance import extlib from openequivariance.implementations.LoopUnrollTP import LoopUnrollTP import torch + class TensorProduct(torch.nn.Module, LoopUnrollTP): - ''' + """ PyTorch-specialized dispatcher class that selects the right implementation based on problem - configuration. - ''' + configuration. + """ + def __init__(self, problem, torch_op=True): torch.nn.Module.__init__(self) LoopUnrollTP.__init__(self, problem, torch_op) @@ -16,5 +17,7 @@ def __init__(self, problem, torch_op=True): def name(): return LoopUnrollTP.name() - def forward(self, L1: torch.Tensor, L2: torch.Tensor, W: torch.Tensor) -> torch.Tensor: - return torch.ops.torch_tp_jit.jit_tp_forward(self.internal, L1, L2, W) \ No newline at end of file + def forward( + self, L1: torch.Tensor, L2: torch.Tensor, W: torch.Tensor + ) -> torch.Tensor: + return torch.ops.torch_tp_jit.jit_tp_forward(self.internal, L1, L2, W) diff --git a/openequivariance/implementations/TensorProductBase.py b/openequivariance/implementations/TensorProductBase.py index dee71413..17ff7634 100644 --- a/openequivariance/implementations/TensorProductBase.py +++ b/openequivariance/implementations/TensorProductBase.py @@ -1,35 +1,39 @@ -import pickle, pathlib, typing -from math import prod +import typing import numpy as np -import numpy.linalg as la import openequivariance.extlib as extlib -from openequivariance.extlib import * +from openequivariance.extlib import DeviceBuffer, GPUTimer from openequivariance.implementations.e3nn_lite import TPProblem, wigner_3j -from openequivariance.benchmark.logging_utils import getLogger, bcolors +from openequivariance.benchmark.logging_utils import getLogger + logger = getLogger() + class TensorProductBase: - next_tp_id = 0 # Assign unique IDs to each TP instance + next_tp_id = 0 # Assign unique IDs to each TP instance @staticmethod def load_cg_tensor(l1, l2, l3): - return wigner_3j(l1, l2, l3) + return wigner_3j(l1, l2, l3) - ''' + """ Each class implementation of a TensorProduct uses a different internal representation, which it can initialize uniquely. - ''' - def __init__(self, config : TPProblem, - torch_op : bool = False): + """ + + def __init__(self, config: TPProblem, torch_op: bool = False): assert isinstance(config, TPProblem) assert isinstance(torch_op, bool) self.config, self.torch_op = config, torch_op - self.L1, self.L2, self.L3 = config.irreps_in1, config.irreps_in2, config.irreps_out - self.irrep_dtype, self.weight_dtype = config.irrep_dtype, config.weight_dtype - self.reorder_weights_e3nn_to_oeq, self.reorder_weights_oeq_to_e3nn = None, None + self.L1, self.L2, self.L3 = ( + config.irreps_in1, + config.irreps_in2, + config.irreps_out, + ) + self.irrep_dtype, self.weight_dtype = config.irrep_dtype, config.weight_dtype + self.reorder_weights_e3nn_to_oeq, self.reorder_weights_oeq_to_e3nn = None, None self.tp_id = TensorProductBase.next_tp_id TensorProductBase.next_tp_id += 1 @@ -38,42 +42,46 @@ def __init__(self, config : TPProblem, global torch import torch - def __call__(self, L1_in, L2_in, weights): + def __call__(self, L1_in, L2_in, weights): return self.forward(L1_in, L2_in, weights) def forward_raw( - self, - batch : np.uint64, - L1_in: np.uint64, - L2_in: np.uint64, - L3_out: np.uint64, - weights: np.uint64 - ) -> None: - self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) - - def backward_raw(self, batch_size: np.uint64, - L1_in: np.uint64, L1_grad: np.uint64, - L2_in: np.uint64, L2_grad: np.uint64, - weights: np.uint64, weights_grad: np.uint64, - L3_grad: np.uint64): + self, + batch: np.uint64, + L1_in: np.uint64, + L2_in: np.uint64, + L3_out: np.uint64, + weights: np.uint64, + ) -> None: + self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) + + def backward_raw( + self, + batch_size: np.uint64, + L1_in: np.uint64, + L1_grad: np.uint64, + L2_in: np.uint64, + L2_grad: np.uint64, + weights: np.uint64, + weights_grad: np.uint64, + L3_grad: np.uint64, + ): self.internal.backward_rawptr( - batch_size, - L1_in, L1_grad, - L2_in, L2_grad, - weights, weights_grad, - L3_grad) + batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad + ) def forward_cpu( - self, - L1_in: np.ndarray, - L2_in: np.ndarray, - L3_out: np.ndarray, - weights: np.ndarray - ) -> None: - - weights_chunked = np.zeros_like(weights) + self, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_out: np.ndarray, + weights: np.ndarray, + ) -> None: + weights_chunked = np.zeros_like(weights) if self.reorder_weights_e3nn_to_oeq is not None: - self.reorder_weights_e3nn_to_oeq(weights, weights_chunked, not self.config.shared_weights) + self.reorder_weights_e3nn_to_oeq( + weights, weights_chunked, not self.config.shared_weights + ) else: weights_chunked = weights @@ -82,102 +90,146 @@ def forward_cpu( L2_d = DeviceBuffer(L2_in) L3_d = DeviceBuffer(L3_out) weights_d = DeviceBuffer(weights_chunked) - self.internal.exec_tensor_product_rawptr(batch, L1_d.data_ptr(), L2_d.data_ptr(), L3_d.data_ptr(), weights_d.data_ptr()) + self.internal.exec_tensor_product_rawptr( + batch, + L1_d.data_ptr(), + L2_d.data_ptr(), + L3_d.data_ptr(), + weights_d.data_ptr(), + ) L3_d.copy_to_host() - def backward_cpu(self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad) -> None: - weights_chunked = np.zeros_like(weights) + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad + ) -> None: + weights_chunked = np.zeros_like(weights) if self.reorder_weights_e3nn_to_oeq is not None: - self.reorder_weights_e3nn_to_oeq(weights, weights_chunked, not self.config.shared_weights) + self.reorder_weights_e3nn_to_oeq( + weights, weights_chunked, not self.config.shared_weights + ) else: weights_chunked = weights batch = L1_in.shape[0] - L1_d, L2_d, L3_d = DeviceBuffer(L1_in), DeviceBuffer(L2_in), DeviceBuffer(L3_grad) + L1_d, L2_d, L3_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(L3_grad), + ) L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) - weights_d, weights_grad_d = DeviceBuffer(weights_chunked), DeviceBuffer(weights_grad) + weights_d, weights_grad_d = ( + DeviceBuffer(weights_chunked), + DeviceBuffer(weights_grad), + ) self.internal.backward_rawptr( - batch, - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), - L3_d.data_ptr()) - + batch, + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + ) + L1_grad_d.copy_to_host() L2_grad_d.copy_to_host() weights_grad_d.copy_to_host() if self.reorder_weights_oeq_to_e3nn is not None: weights_grad_copy = weights_grad.copy() - self.reorder_weights_oeq_to_e3nn(weights_grad_copy, weights_grad, not self.config.shared_weights) + self.reorder_weights_oeq_to_e3nn( + weights_grad_copy, weights_grad, not self.config.shared_weights + ) def benchmark_forward( - self, - num_warmup : int, - num_iter : int, - L1_in : np.ndarray, - L2_in : np.ndarray, - L3_buffer : np.ndarray, - weights : np.ndarray) -> np.ndarray: + self, + num_warmup: int, + num_iter: int, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_buffer: np.ndarray, + weights: np.ndarray, + ) -> np.ndarray: time_millis = np.zeros(num_iter, dtype=np.float32) # GPUTimer introduces significantly less overhead when kernel runtime < 1ms timer = GPUTimer() if self.torch_op: - torch_L1_in = torch.tensor(L1_in).to(device='cuda').detach() - torch_L2_in = torch.tensor(L2_in).to(device='cuda').detach() - torch_weights = torch.tensor(weights).to(device='cuda').detach() + torch_L1_in = torch.tensor(L1_in).to(device="cuda").detach() + torch_L2_in = torch.tensor(L2_in).to(device="cuda").detach() + torch_weights = torch.tensor(weights).to(device="cuda").detach() - for i in range(num_warmup): - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) + for i in range(num_warmup): + self.forward(torch_L1_in, torch_L2_in, torch_weights) for i in range(num_iter): timer.clear_L2_cache() timer.start() - torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) - time_millis[i] = timer.stop_clock_get_elapsed() + self.forward(torch_L1_in, torch_L2_in, torch_weights) + time_millis[i] = timer.stop_clock_get_elapsed() else: batch = L1_in.shape[0] - L1_d, L2_d, L3_d = DeviceBuffer(L1_in), DeviceBuffer(L2_in), DeviceBuffer(L3_buffer) + L1_d, L2_d, L3_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(L3_buffer), + ) weights_d = DeviceBuffer(weights) for i in range(num_warmup): - self.internal.exec_tensor_product_rawptr(batch, L1_d.data_ptr(), L2_d.data_ptr(), L3_d.data_ptr(), weights_d.data_ptr()) + self.internal.exec_tensor_product_rawptr( + batch, + L1_d.data_ptr(), + L2_d.data_ptr(), + L3_d.data_ptr(), + weights_d.data_ptr(), + ) for i in range(num_iter): timer.clear_L2_cache() timer.start() - self.internal.exec_tensor_product_rawptr(batch, L1_d.data_ptr(), L2_d.data_ptr(), L3_d.data_ptr(), weights_d.data_ptr()) - time_millis[i] = timer.stop_clock_get_elapsed() - + self.internal.exec_tensor_product_rawptr( + batch, + L1_d.data_ptr(), + L2_d.data_ptr(), + L3_d.data_ptr(), + weights_d.data_ptr(), + ) + time_millis[i] = timer.stop_clock_get_elapsed() + return time_millis - + def benchmark_backward( - self, - num_warmup : int, - num_iter : int, - L1_in : np.ndarray, - L2_in : np.ndarray, - L3_buffer : np.ndarray, - weights : np.ndarray, - L1_grad : np.ndarray, - L2_grad : np.ndarray, - weights_grad : np.ndarray - ) -> np.ndarray: + self, + num_warmup: int, + num_iter: int, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_buffer: np.ndarray, + weights: np.ndarray, + L1_grad: np.ndarray, + L2_grad: np.ndarray, + weights_grad: np.ndarray, + ) -> np.ndarray: time_millis = np.zeros(num_iter, dtype=np.float32) timer = GPUTimer() - if self.torch_op: - torch_L1_in = torch.tensor(L1_in, requires_grad=True, device='cuda') - torch_L2_in = torch.tensor(L2_in, requires_grad=True, device='cuda') - torch_weights = torch.tensor(weights, requires_grad=True, device='cuda') + if self.torch_op: + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor(weights, requires_grad=True, device="cuda") torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) - torch_L3_grad_in = torch.tensor(L3_buffer, device='cuda') + torch_L3_grad_in = torch.tensor(L3_buffer, device="cuda") - for i in range(num_warmup): - torch_out.backward(gradient=torch_L3_grad_in, retain_graph=True, inputs=[torch_L1_in, torch_L2_in, torch_weights]) + for i in range(num_warmup): + torch_out.backward( + gradient=torch_L3_grad_in, + retain_graph=True, + inputs=[torch_L1_in, torch_L2_in, torch_weights], + ) for i in range(num_iter): torch_L1_in.grad.zero_() @@ -186,7 +238,11 @@ def benchmark_backward( timer.clear_L2_cache() timer.start() - torch_out.backward(gradient=torch_L3_grad_in, retain_graph=True, inputs=[torch_L1_in, torch_L2_in, torch_weights]) + torch_out.backward( + gradient=torch_L3_grad_in, + retain_graph=True, + inputs=[torch_L1_in, torch_L2_in, torch_weights], + ) time_millis[i] = timer.stop_clock_get_elapsed() L1_grad[:] = torch_L1_in.grad.numpy(force=True) @@ -194,93 +250,132 @@ def benchmark_backward( weights_grad[:] = torch_weights.grad.numpy(force=True) else: batch = L1_in.shape[0] - L1_d, L2_d, L3_d = DeviceBuffer(L1_in), DeviceBuffer(L2_in), DeviceBuffer(L3_buffer) + L1_d, L2_d, L3_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(L3_buffer), + ) L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) - weights_d, weights_grad_d = DeviceBuffer(weights), DeviceBuffer(weights_grad) + weights_d, weights_grad_d = ( + DeviceBuffer(weights), + DeviceBuffer(weights_grad), + ) for i in range(num_warmup): self.internal.backward_rawptr( - batch, - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), - L3_d.data_ptr()) + batch, + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + ) for i in range(num_iter): timer.clear_L2_cache() timer.start() self.internal.backward_rawptr( - batch, - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), - L3_d.data_ptr()) + batch, + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + ) time_millis[i] = timer.stop_clock_get_elapsed() - + return time_millis - + def benchmark_double_backward( - self, - num_warmup : int, - num_iter : int, - L1_in : np.ndarray, - L2_in : np.ndarray, - L3_buffer : np.ndarray, - weights : np.ndarray, - L1_grad : np.ndarray, - L2_grad : np.ndarray, - weights_grad : np.ndarray, - L3_double_grad : np.ndarray, - ) -> np.ndarray: + self, + num_warmup: int, + num_iter: int, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_buffer: np.ndarray, + weights: np.ndarray, + L1_grad: np.ndarray, + L2_grad: np.ndarray, + weights_grad: np.ndarray, + L3_double_grad: np.ndarray, + ) -> np.ndarray: time_millis = np.zeros(num_iter, dtype=np.float32) timer = GPUTimer() - if self.torch_op: - torch_L1_in = torch.tensor(L1_in, requires_grad=True, device='cuda') - torch_L2_in = torch.tensor(L2_in, requires_grad=True, device='cuda') - torch_weights = torch.tensor(weights, requires_grad=True, device='cuda') + if self.torch_op: + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor(weights, requires_grad=True, device="cuda") torch_out = self(torch_L1_in, torch_L2_in, torch_weights) - torch_out_grad = torch_out.clone().detach().to(device='cuda').requires_grad_(True) + torch_out_grad = ( + torch_out.clone().detach().to(device="cuda").requires_grad_(True) + ) (torch_L1_grad, torch_L2_grad, torch_weights_grad) = torch.autograd.grad( outputs=torch_out, inputs=[torch_L1_in, torch_L2_in, torch_weights], grad_outputs=torch_out_grad, create_graph=True, - retain_graph=True + retain_graph=True, ) - dummy = torch.norm(torch_L1_grad) + torch.norm(torch_L2_grad) + torch.norm(torch_weights_grad) - dummy_grad = torch.tensor(float(dummy), device='cuda', requires_grad=True) + dummy = ( + torch.norm(torch_L1_grad) + + torch.norm(torch_L2_grad) + + torch.norm(torch_weights_grad) + ) + dummy_grad = torch.tensor(float(dummy), device="cuda", requires_grad=True) - torch_L1_grad = torch.tensor(L1_in, requires_grad=True, device='cuda') - torch_L2_grad = torch.tensor(L2_in, requires_grad=True, device='cuda') - torch_weights_grad = torch.tensor(weights_grad, requires_grad=True, device='cuda') - torch_L3_double_grad = torch.tensor(L3_double_grad, device='cuda', requires_grad=True) + torch_L1_grad = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_grad = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights_grad = torch.tensor( + weights_grad, requires_grad=True, device="cuda" + ) + torch_L3_double_grad = torch.tensor( + L3_double_grad, device="cuda", requires_grad=True + ) - - (torch_L1_grad, torch_L2_grad, torch_weights_grad, torch_L3_double_grad) = torch.autograd.grad( - outputs = dummy, - inputs = [torch_L1_in, torch_L2_in, torch_weights, torch_out_grad], - grad_outputs = dummy_grad, - retain_graph=True) + (torch_L1_grad, torch_L2_grad, torch_weights_grad, torch_L3_double_grad) = ( + torch.autograd.grad( + outputs=dummy, + inputs=[torch_L1_in, torch_L2_in, torch_weights, torch_out_grad], + grad_outputs=dummy_grad, + retain_graph=True, + ) + ) - for i in range(num_warmup): - (torch_L1_grad, torch_L2_grad, torch_weights_grad, torch_L3_double_grad) = torch.autograd.grad( - outputs = dummy, - inputs = [torch_L1_in, torch_L2_in, torch_weights, torch_out_grad], - grad_outputs = dummy_grad, - retain_graph=True) + for i in range(num_warmup): + ( + torch_L1_grad, + torch_L2_grad, + torch_weights_grad, + torch_L3_double_grad, + ) = torch.autograd.grad( + outputs=dummy, + inputs=[torch_L1_in, torch_L2_in, torch_weights, torch_out_grad], + grad_outputs=dummy_grad, + retain_graph=True, + ) for i in range(num_iter): timer.clear_L2_cache() timer.start() - (torch_L1_grad, torch_L2_grad, torch_weights_grad, torch_L3_double_grad) = torch.autograd.grad( - outputs = dummy, - inputs = [torch_L1_in, torch_L2_in, torch_weights, torch_out_grad], - grad_outputs = dummy_grad, - retain_graph=True) + ( + torch_L1_grad, + torch_L2_grad, + torch_weights_grad, + torch_L3_double_grad, + ) = torch.autograd.grad( + outputs=dummy, + inputs=[torch_L1_in, torch_L2_in, torch_weights, torch_out_grad], + grad_outputs=dummy_grad, + retain_graph=True, + ) time_millis[i] = timer.stop_clock_get_elapsed() L1_grad[:] = torch_L1_grad.numpy(force=True) @@ -289,49 +384,63 @@ def benchmark_double_backward( L3_double_grad[:] = torch_L3_double_grad.numpy(force=True) else: batch = L1_in.shape[0] - L1_d, L2_d, L3_d = DeviceBuffer(L1_in), DeviceBuffer(L2_in), DeviceBuffer(L3_buffer) + L1_d, L2_d, L3_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(L3_buffer), + ) L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) - weights_d, weights_grad_d = DeviceBuffer(weights), DeviceBuffer(weights_grad) + weights_d, weights_grad_d = ( + DeviceBuffer(weights), + DeviceBuffer(weights_grad), + ) for i in range(num_warmup): self.internal.double_backward( - batch, - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), - L3_d.data_ptr()) + batch, + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + ) for i in range(num_iter): timer.clear_L2_cache() timer.start() self.internal.double_backward( - batch, - L1_d.data_ptr(), L1_grad_d.data_ptr(), - L2_d.data_ptr(), L2_grad_d.data_ptr(), - weights_d.data_ptr(), weights_grad_d.data_ptr(), - L3_d.data_ptr()) + batch, + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + ) time_millis[i] = timer.stop_clock_get_elapsed() - + return time_millis - def calculate_memory_streamed_forward(self, batch_size : int) -> dict: + def calculate_memory_streamed_forward(self, batch_size: int) -> dict: raise NotImplementedError("This needs to be implemented in your class") - - def calculate_memory_streamed_backward(self, batch_size : int) -> dict: - raise NotImplementedError("This needs to be implemented in your class") - - def calculate_memory_streamed_double_backward(self, batch_size : int) -> dict: + + def calculate_memory_streamed_backward(self, batch_size: int) -> dict: raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_forward(self, batch_size : int) -> dict: + + def calculate_memory_streamed_double_backward(self, batch_size: int) -> dict: raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_backward(self, batch_size : int) -> dict: + + def calculate_flops_forward(self, batch_size: int) -> dict: raise NotImplementedError("This needs to be implemented in your class") - - def calculate_flops_double_backward(self, batch_size : int) -> dict: + + def calculate_flops_backward(self, batch_size: int) -> dict: raise NotImplementedError("This needs to be implemented in your class") + def calculate_flops_double_backward(self, batch_size: int) -> dict: + raise NotImplementedError("This needs to be implemented in your class") def setup_torch_custom_op(self): if not extlib.TORCH_COMPILE: @@ -339,23 +448,49 @@ def setup_torch_custom_op(self): def setup_nocompile_ops(self): # ----------------- Forward pass ----------------- - @torch.library.custom_op(f"openequivariance::tp_forward{self.tp_id}", mutates_args=(), device_types="cuda") - def forward(L1_in : torch.Tensor, L2_in : torch.Tensor, weights : torch.Tensor) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = L1_in.contiguous(), L2_in.contiguous(), weights.contiguous() - L3_out = torch.empty((L1_in_c.shape[0], self.L3.dim ), dtype=L1_in.dtype, device='cuda') - self.forward_raw(L1_in_c.shape[0], L1_in_c.data_ptr(), L2_in_c.data_ptr(), L3_out.data_ptr(), weights_c.data_ptr()) + @torch.library.custom_op( + f"openequivariance::tp_forward{self.tp_id}", + mutates_args=(), + device_types="cuda", + ) + def forward( + L1_in: torch.Tensor, L2_in: torch.Tensor, weights: torch.Tensor + ) -> torch.Tensor: + L1_in_c, L2_in_c, weights_c = ( + L1_in.contiguous(), + L2_in.contiguous(), + weights.contiguous(), + ) + L3_out = torch.empty( + (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device="cuda" + ) + self.forward_raw( + L1_in_c.shape[0], + L1_in_c.data_ptr(), + L2_in_c.data_ptr(), + L3_out.data_ptr(), + weights_c.data_ptr(), + ) return L3_out - + @forward.register_fake def _(L1_in, L2_in, weights): return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - + self.forward = forward - + # ---------------- Backward pass ----------------- - @torch.library.custom_op(f"openequivariance::tp_grad_helper{self.tp_id}", mutates_args=(), device_types="cuda") - def backward_helper( L1_in : torch.Tensor, L2_in : torch.Tensor, - weights : torch.Tensor, L3_grad : torch.Tensor ) -> typing.List[torch.Tensor]: + @torch.library.custom_op( + f"openequivariance::tp_grad_helper{self.tp_id}", + mutates_args=(), + device_types="cuda", + ) + def backward_helper( + L1_in: torch.Tensor, + L2_in: torch.Tensor, + weights: torch.Tensor, + L3_grad: torch.Tensor, + ) -> typing.List[torch.Tensor]: L1_grad = torch.empty_like(L1_in) L2_grad = torch.empty_like(L2_in) weights_grad = torch.empty_like(weights) @@ -363,24 +498,30 @@ def backward_helper( L1_in : torch.Tensor, L2_in : torch.Tensor, if self.config.shared_weights: weights_grad[:] = 0.0 - self.backward_raw( L1_in.shape[0], - L1_in.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), - L2_grad.data_ptr(), - weights.contiguous().data_ptr(), - weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr() ) - - return [L1_grad, L2_grad, weights_grad] - + self.backward_raw( + L1_in.shape[0], + L1_in.contiguous().data_ptr(), + L1_grad.data_ptr(), + L2_in.contiguous().data_ptr(), + L2_grad.data_ptr(), + weights.contiguous().data_ptr(), + weights_grad.data_ptr(), + L3_grad.contiguous().data_ptr(), + ) + + return [L1_grad, L2_grad, weights_grad] + @backward_helper.register_fake def _(L1_in, L2_in, weights, L3_grad): - return [L1_in.new_empty(*L1_in.shape), L2_in.new_empty(*L2_in.shape), weights.new_empty(*weights.shape)] + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + weights.new_empty(*weights.shape), + ] def setup_context(ctx, inputs, output): ctx.L1_in, ctx.L2_in, ctx.weights = inputs - + def backward(ctx, grad_output): result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output) return result[0], result[1], result[2] @@ -388,7 +529,7 @@ def backward(ctx, grad_output): self.forward.register_autograd(backward, setup_context=setup_context) def setup_context_double_backward(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs def double_backward(ctx, grad_output): A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights @@ -397,11 +538,20 @@ def double_backward(ctx, grad_output): op1 = backward_helper(E, F, D, C) op2 = backward_helper(A, B, G, C) op3 = forward(E, B, D) - op4 = backward_helper(E, B, D, C) # op4 and op5 could be combined with op3 and op6 - op5 = backward_helper(A, F, D, C) + op4 = backward_helper( + E, B, D, C + ) # op4 and op5 could be combined with op3 and op6 + op5 = backward_helper(A, F, D, C) op6 = forward(A, F, D) op7 = forward(A, B, G) - return op1[0] + op2[0], op1[1] + op2[1], (op4[2] + op5[2]), (op3 + op6 + op7) + return ( + op1[0] + op2[0], + op1[1] + op2[1], + (op4[2] + op5[2]), + (op3 + op6 + op7), + ) - backward_helper.register_autograd(double_backward, setup_context=setup_context_double_backward) \ No newline at end of file + backward_helper.register_autograd( + double_backward, setup_context=setup_context_double_backward + ) diff --git a/openequivariance/implementations/convolution/CUEConv.py b/openequivariance/implementations/convolution/CUEConv.py index 9e70fc76..2ef6ccf4 100644 --- a/openequivariance/implementations/convolution/CUEConv.py +++ b/openequivariance/implementations/convolution/CUEConv.py @@ -1,10 +1,10 @@ import numpy as np -import numpy.linalg as la import itertools +from typing import Iterator from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.implementations.convolution.ConvolutionBase import * -from openequivariance.benchmark.tpp_creation_utils import * +from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase + class CUEConv(ConvolutionBase): def __init__(self, config, idx_dtype=np.int64, torch_op=True): @@ -17,16 +17,20 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True): self.cue_tp = self.reference_tp.cue_tp from openequivariance.implementations.convolution.scatter import scatter_sum + self.scatter_sum = scatter_sum def forward(self, L1_in, L2_in, weights, rows, cols): tp_outputs = self.cue_tp(L1_in[cols], L2_in, weights) - return self.scatter_sum(src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]) + return self.scatter_sum( + src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0] + ) @staticmethod def name(): return "CUEConvolution" + class CUEConvFused(ConvolutionBase): def __init__(self, config, idx_dtype=np.int64, torch_op=True): super().__init__(config, idx_dtype, torch_op) @@ -35,14 +39,12 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True): import torch import e3nn.o3 as o3 - np_to_torch_dtype = { - np.float32: torch.float32, - np.float64: torch.float64 - } + np_to_torch_dtype = {np.float32: torch.float32, np.float64: torch.float64} import cuequivariance as cue - import cuequivariance_torch as cuet - from cuequivariance_torch.primitives.tensor_product import TensorProductUniform4x1dIndexed + from cuequivariance_torch.primitives.tensor_product import ( + TensorProductUniform4x1dIndexed, + ) class O3_e3nn(cue.O3): def __mul__( # pylint: disable=no-self-argument @@ -70,21 +72,29 @@ def __lt__( # pylint: disable=no-self-argument @classmethod def iterator(cls) -> Iterator["O3_e3nn"]: - for l in itertools.count(0): + for l in itertools.count(0): # noqa : E741 yield O3_e3nn(l=l, p=1 * (-1) ** l) yield O3_e3nn(l=l, p=-1 * (-1) ** l) - descriptor = (cue.descriptors.channelwise_tensor_product( + descriptor = ( + cue.descriptors.channelwise_tensor_product( cue.Irreps(O3_e3nn, str(config.irreps_in1)), cue.Irreps(O3_e3nn, str(config.irreps_in2)), - cue.Irreps(O3_e3nn, str(config.irreps_out)) - ).squeeze_modes().flatten_coefficient_modes()) - - self.tp = TensorProductUniform4x1dIndexed(descriptor.polynomial.operations[0][1], 'cuda', math_dtype=np_to_torch_dtype[config.irrep_dtype]) + cue.Irreps(O3_e3nn, str(config.irreps_out)), + ) + .squeeze_modes() + .flatten_coefficient_modes() + ) + + self.tp = TensorProductUniform4x1dIndexed( + descriptor.polynomial.operations[0][1], + "cuda", + math_dtype=np_to_torch_dtype[config.irrep_dtype], + ) def forward(self, L1_in, L2_in, weights, rows, cols): return self.tp(weights, L1_in, L2_in, None, rows, None, cols, L1_in.shape[0]) @staticmethod def name(): - return "CUEConvolutionFused" \ No newline at end of file + return "CUEConvolutionFused" diff --git a/openequivariance/implementations/convolution/E3NNConv.py b/openequivariance/implementations/convolution/E3NNConv.py index 31f56a04..216aa699 100644 --- a/openequivariance/implementations/convolution/E3NNConv.py +++ b/openequivariance/implementations/convolution/E3NNConv.py @@ -1,31 +1,33 @@ import numpy as np -import numpy.linalg as la -from openequivariance.implementations.convolution.ConvolutionBase import * -from openequivariance.implementations.E3NNTensorProduct import * +from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase +from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + class E3NNConv(ConvolutionBase): def __init__(self, config, idx_dtype=np.int64, torch_op=True): - assert(torch_op) + assert torch_op super().__init__(config, idx_dtype, torch_op) - import e3nn from e3nn import o3 + import torch + if config.irrep_dtype == np.float64: torch.set_default_dtype(torch.float64) self.e3nn_tp = o3.TensorProduct( - config.irreps_in1, - config.irreps_in2, - config.irreps_out, - config.instructions_raw, - in1_var=config.in1_var, - in2_var=config.in2_var, - out_var=config.out_var, - irrep_normalization=config.irrep_normalization, - path_normalization=config.path_normalization, - internal_weights=config.internal_weights, - shared_weights=config.shared_weights).to(device='cuda') + config.irreps_in1, + config.irreps_in2, + config.irreps_out, + config.instructions_raw, + in1_var=config.in1_var, + in2_var=config.in2_var, + out_var=config.out_var, + irrep_normalization=config.irrep_normalization, + path_normalization=config.path_normalization, + internal_weights=config.internal_weights, + shared_weights=config.shared_weights, + ).to(device="cuda") self.reference_tp = E3NNTensorProduct(config) @@ -33,15 +35,18 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True): torch.set_default_dtype(torch.float32) # Reset to default from openequivariance.implementations.convolution.scatter import scatter_sum + self.scatter_sum = scatter_sum def forward(self, L1_in, L2_in, weights, rows, cols): tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights) - return self.scatter_sum(src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]) + return self.scatter_sum( + src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0] + ) @staticmethod def name(): - return "E3NNConvolution" + return "E3NNConvolution" def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): tp_outputs = np.zeros((graph.nnz, self.L3.dim), dtype=L3_out.dtype) @@ -49,16 +54,24 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): np.add.at(L3_out, graph.rows, tp_outputs) def backward_cpu( - self, - L1_in : np.ndarray, - L1_grad : np.ndarray, - L2_in : np.ndarray, - L2_grad : np.ndarray, - L3_grad : np.ndarray, - weights : np.ndarray, - weights_grad : np.ndarray, - graph): + self, + L1_in: np.ndarray, + L1_grad: np.ndarray, + L2_in: np.ndarray, + L2_grad: np.ndarray, + L3_grad: np.ndarray, + weights: np.ndarray, + weights_grad: np.ndarray, + graph, + ): L1_grad_bcast = np.zeros((graph.nnz, self.L1.dim), dtype=L1_grad.dtype) self.reference_tp.backward_cpu( - L1_in[graph.cols], L1_grad_bcast, L2_in, L2_grad, L3_grad[graph.rows], weights, weights_grad) - np.add.at(L1_grad, graph.cols, L1_grad_bcast) \ No newline at end of file + L1_in[graph.cols], + L1_grad_bcast, + L2_in, + L2_grad, + L3_grad[graph.rows], + weights, + weights_grad, + ) + np.add.at(L1_grad, graph.cols, L1_grad_bcast) diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 51167640..a345cf85 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -1,16 +1,33 @@ -from openequivariance.implementations.convolution.ConvolutionBase import * -from openequivariance.implementations.ComputationSchedule import ComputationSchedule, SMEMCapacityException -from openequivariance.implementations.TensorProduct import * -from openequivariance.templates.jinja_utils import * -from openequivariance.extlib import * +import numpy as np + +from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase +from openequivariance.implementations.ComputationSchedule import ( + ComputationSchedule, + SMEMCapacityException, +) + + +from openequivariance.templates.jinja_utils import get_jinja_environment +from openequivariance import extlib +from openequivariance.extlib import JITConvImpl, postprocess_kernel, DeviceProp + +from openequivariance.implementations.utils import filter_and_analyze_problem + +from openequivariance.benchmark.logging_utils import getLogger + +logger = getLogger() -from openequivariance.implementations.utils import filter_and_analyze_problem class LoopUnrollConv(ConvolutionBase): - def __init__(self, config, idx_dtype=np.int64, - torch_op=False, deterministic=False, kahan=False): + def __init__( + self, + config, + idx_dtype=np.int64, + torch_op=False, + deterministic=False, + kahan=False, + ): super().__init__(config, idx_dtype, torch_op, deterministic) - L1, L2, L3 = self.L1, self.L2, self.L3 if kahan: assert deterministic @@ -23,7 +40,9 @@ def __init__(self, config, idx_dtype=np.int64, self.is_uvw = analysis["is_uvw"] if config.shared_weights: - assert not deterministic, "Deterministic convolution does not support shared weights" + assert not deterministic, ( + "Deterministic convolution does not support shared weights" + ) forward_schedule_type = 3 backward_schedule_type = 2 @@ -32,46 +51,58 @@ def __init__(self, config, idx_dtype=np.int64, template = env.get_template("loop_unroll_conv_det.cuh") def generate_forward_schedule(warps_per_block): - self.forward_schedule = ComputationSchedule(self.config, - smem_limit=dp.maxSharedMemPerBlock // 4 * 3, warps_per_block=warps_per_block, - block_count=dp.multiprocessorCount, - direction = "forward", - irrep_dtype = config.irrep_dtype, - weight_dtype = config.weight_dtype, - schedule_type=forward_schedule_type, - warp_size=dp.warpsize, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - kahan=kahan) + self.forward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock // 4 * 3, + warps_per_block=warps_per_block, + block_count=dp.multiprocessorCount, + direction="forward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + schedule_type=forward_schedule_type, + warp_size=dp.warpsize, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + kahan=kahan, + ) def generate_backward_schedule(warps_per_block): - self.backward_schedule = ComputationSchedule(self.config, - smem_limit=dp.maxSharedMemPerBlock, warps_per_block=warps_per_block, - block_count=dp.multiprocessorCount * 2, - direction = "backward", - irrep_dtype = config.irrep_dtype, - weight_dtype = config.weight_dtype, - schedule_type=backward_schedule_type, - warp_size=dp.warpsize, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - kahan=kahan) - + self.backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + block_count=dp.multiprocessorCount * 2, + direction="backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + schedule_type=backward_schedule_type, + warp_size=dp.warpsize, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + kahan=kahan, + ) + def generate_double_backward_schedule(warps_per_block): - self.double_backward_schedule = ComputationSchedule(self.config, - smem_limit=dp.maxSharedMemPerBlock, - warps_per_block=warps_per_block, - warp_size=dp.warpsize, - block_count=dp.multiprocessorCount, - direction = "double_backward", - irrep_dtype = config.irrep_dtype, - weight_dtype = config.weight_dtype, - include_scratch=self.is_uvw, - stream_weights=self.is_uvw, - schedule_type=3, - kahan=kahan) - - scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule] + self.double_backward_schedule = ComputationSchedule( + self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount, + direction="double_backward", + irrep_dtype=config.irrep_dtype, + weight_dtype=config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + schedule_type=3, + kahan=kahan, + ) + + scheduler_generators = [ + generate_forward_schedule, + generate_backward_schedule, + generate_double_backward_schedule, + ] for generate_schedule in scheduler_generators: warp_count = 6 @@ -79,10 +110,12 @@ def generate_double_backward_schedule(warps_per_block): try: generate_schedule(warp_count) break - except SMEMCapacityException as e: + except SMEMCapacityException: warp_count -= 1 if warp_count == 0: - raise SMEMCapacityException("Tensor product schedule generation failed, shared memory inadequate!") + raise SMEMCapacityException( + "Tensor product schedule generation failed, shared memory inadequate!" + ) if not deterministic: for segment in self.forward_schedule.segments: @@ -97,7 +130,6 @@ def generate_double_backward_schedule(warps_per_block): for key in segment.L1Map.storeback_procedure: segment.L1Map.storeback_procedure[key] = "atomic_accumulate" - idx_type_map = {np.int32: "int", np.int64: "long"} if self.torch_op: @@ -109,19 +141,47 @@ def generate_double_backward_schedule(warps_per_block): workspace_size = 1 if deterministic: - destination_index_bytes = 32 # Add extra to account for padding + destination_index_bytes = 32 # Add extra to account for padding workspace_size = max( - (self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.forward_schedule.total_warps, - (self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps, - (self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.double_backward_schedule.total_warps + ( + self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + + destination_index_bytes + ) + * self.forward_schedule.total_warps, + ( + self.backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + + destination_index_bytes + ) + * self.backward_schedule.total_warps, + ( + self.double_backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + + destination_index_bytes + ) + * self.double_backward_schedule.total_warps, ) - self.forward_workspace_offset = self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize * self.forward_schedule.total_warps - self.backward_workspace_offset = self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.backward_schedule.total_warps - self.double_backwardB_offset = self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.double_backward_schedule.total_warps + self.forward_workspace_offset = ( + self.forward_schedule.L3.dim + * np.dtype(config.irrep_dtype).itemsize + * self.forward_schedule.total_warps + ) + self.backward_workspace_offset = ( + self.backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + * self.backward_schedule.total_warps + ) + self.double_backwardB_offset = ( + self.double_backward_schedule.L1.dim + * np.dtype(config.irrep_dtype).itemsize + * self.double_backward_schedule.total_warps + ) self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8 - self.backward_workspace_offset = (self.backward_workspace_offset + 7) // 8 * 8 + self.backward_workspace_offset = ( + (self.backward_workspace_offset + 7) // 8 * 8 + ) self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8 self.allocate_workspace(workspace_size) @@ -133,7 +193,8 @@ def generate_double_backward_schedule(warps_per_block): idx_type=idx_type_map[idx_dtype], forward_workspace_offset=self.forward_workspace_offset, backward_workspace_offset=self.backward_workspace_offset, - double_backwardB_offset=self.double_backwardB_offset) + double_backwardB_offset=self.double_backwardB_offset, + ) self.jit_kernel = postprocess_kernel(self.jit_kernel) if self.torch_op and extlib.TORCH_COMPILE: @@ -142,26 +203,34 @@ def generate_double_backward_schedule(warps_per_block): internal_cls = torch.classes.torch_tp_jit.TorchJITConv else: - internal_cls = JITConvImpl + internal_cls = JITConvImpl logger.info("Starting kernel compiler...") - self.internal = internal_cls(self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - {"L3_dim": self.L3.dim, - "is_uvw": int(self.is_uvw), - "shared_weights": int(config.shared_weights)}) + self.internal = internal_cls( + self.jit_kernel, + vars(self.forward_schedule.launch_config), + vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), + { + "L3_dim": self.L3.dim, + "is_uvw": int(self.is_uvw), + "shared_weights": int(config.shared_weights), + }, + ) logger.info("Kernel compiled!") - #with open("scratch.txt", "w") as f: + # with open("scratch.txt", "w") as f: # f.write(self.jit_kernel) def reorder_weights_e3nn_to_oeq(self, input, output, has_batch_dim): - return self.forward_schedule.reorder_weights(input, output, "forward", has_batch_dim) - + return self.forward_schedule.reorder_weights( + input, output, "forward", has_batch_dim + ) + def reorder_weights_oeq_to_e3nn(self, input, output, has_batch_dim): - return self.forward_schedule.reorder_weights(input, output, "backward", has_batch_dim) + return self.forward_schedule.reorder_weights( + input, output, "backward", has_batch_dim + ) @staticmethod def name(): @@ -174,12 +243,27 @@ def register_torch_fakes(cls): @torch._library.register_fake_class("torch_tp_jit::TorchJITConv") class TorchJITConv: - def __init__(self, kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int]) -> None: - self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, dbl_bwd_config, kernel_dims + def __init__( + self, + kernel_plaintext: str, + fwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], + kernel_dims: dict[str, int], + ) -> None: + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = ( + kernel_plaintext, + fwd_config, + bwd_config, + dbl_bwd_config, + kernel_dims, + ) @classmethod def __obj_unflatten__(cls, flattened_product): @@ -187,43 +271,113 @@ def __obj_unflatten__(cls, flattened_product): def __len__(self): return 0 - + def __setstate__(self, state): - self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = state - + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = state + @torch.library.register_fake("torch_tp_jit::jit_conv_forward") - def fake_forward(jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm): - return L1_in.new_empty(L1_in.shape[0], jit.wrapped_obj.kernel_dims["L3_dim"]) + def fake_forward( + jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm + ): + return L1_in.new_empty( + L1_in.shape[0], jit.wrapped_obj.kernel_dims["L3_dim"] + ) @torch.library.register_fake("torch_tp_jit::jit_conv_backward") - def fake_backward(jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + def fake_backward( + jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm + ): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) @classmethod def register_autograd(cls): - forward_op = torch.ops.torch_tp_jit.jit_conv_forward backward_op = torch.ops.torch_tp_jit.jit_conv_backward double_backward_op = torch.ops.torch_tp_jit.jit_conv_double_backward def setup_context(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm = inputs - + ( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad= backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, grad_output, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm) + L1_grad, L2_grad, W_grad = backward_op( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) return None, L1_grad, L2_grad, W_grad, None, None, None, None - torch.library.register_autograd("torch_tp_jit::jit_conv_forward", backward, setup_context=setup_context) + torch.library.register_autograd( + "torch_tp_jit::jit_conv_forward", backward, setup_context=setup_context + ) def setup_context_double_backward(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm = inputs + ( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs ctx.inputs = inputs def double_backward(ctx, E, F, G): - result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm) - return None, result[0], result[1], result[2], result[3], None, None, None, None + result = double_backward_op( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + E, + F, + G, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return ( + None, + result[0], + result[1], + result[2], + result[3], + None, + None, + None, + None, + ) + + torch.library.register_autograd( + "torch_tp_jit::jit_conv_backward", + double_backward, + setup_context=setup_context_double_backward, + ) - torch.library.register_autograd("torch_tp_jit::jit_conv_backward", double_backward, setup_context=setup_context_double_backward) if extlib.TORCH_COMPILE: LoopUnrollConv.register_torch_fakes() - LoopUnrollConv.register_autograd() \ No newline at end of file + LoopUnrollConv.register_autograd() diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py index dca2a479..fc77bc3a 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/implementations/convolution/TensorProductConv.py @@ -1,46 +1,87 @@ +from typing import Optional +import types + +import numpy as np +import torch + from openequivariance import extlib -from openequivariance.implementations.convolution.LoopUnrollConv import * +from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase +from openequivariance.implementations.convolution.LoopUnrollConv import LoopUnrollConv from openequivariance.implementations.TensorProduct import TensorProduct -import numpy as np -from typing import Optional -import types class TensorProductConv(torch.nn.Module, LoopUnrollConv): - ''' + """ PyTorch-specialized dispatcher class. - ''' - def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=False, kahan=False): + """ + + def __init__( + self, + config, + idx_dtype=np.int64, + torch_op=True, + deterministic=False, + kahan=False, + ): torch.nn.Module.__init__(self) - LoopUnrollConv.__init__(self, config, idx_dtype=np.int64, - torch_op=torch_op, deterministic=deterministic, kahan=kahan) - - self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device='cuda') + LoopUnrollConv.__init__( + self, + config, + idx_dtype=np.int64, + torch_op=torch_op, + deterministic=deterministic, + kahan=kahan, + ) + + self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda") self.weight_numel = self.config.weight_numel if not extlib.TORCH_COMPILE: - self.forward = types.MethodType(LoopUnrollConv.forward, self) - - def forward(self, L1_in: torch.Tensor, L2_in: - torch.Tensor, W: torch.Tensor, - rows: torch.Tensor, cols: torch.Tensor, sender_perm: Optional[torch.Tensor]=None) -> torch.Tensor: + self.forward = types.MethodType(LoopUnrollConv.forward, self) + + def forward( + self, + L1_in: torch.tensor, + L2_in: torch.tensor, + W: torch.tensor, + rows: torch.tensor, + cols: torch.tensor, + sender_perm: Optional[torch.tensor] = None, + ) -> torch.tensor: if sender_perm is None: - return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, self.dummy_transpose_perm) + return torch.ops.torch_tp_jit.jit_conv_forward( + self.internal, + L1_in, + L2_in, + W, + rows, + cols, + self.workspace_buffer, + self.dummy_transpose_perm, + ) else: - return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, sender_perm) + return torch.ops.torch_tp_jit.jit_conv_forward( + self.internal, + L1_in, + L2_in, + W, + rows, + cols, + self.workspace_buffer, + sender_perm, + ) @staticmethod def name(): return LoopUnrollConv.name() - + # ================================================================== # Reference implementations for benchmarking + class TensorProductConvKahan(TensorProductConv): - def __init__(self, config, - idx_dtype=np.int64, - torch_op=True): + def __init__(self, config, idx_dtype=np.int64, torch_op=True): super().__init__(config, idx_dtype, torch_op, deterministic=True, kahan=True) @staticmethod @@ -49,28 +90,26 @@ def name(): class TensorProductConvDeterministic(TensorProductConv): - def __init__(self, config, - idx_dtype=np.int64, - torch_op=True): + def __init__(self, config, idx_dtype=np.int64, torch_op=True): super().__init__(config, idx_dtype, torch_op, deterministic=True) @staticmethod def name(): return "LoopUnrollConvDeterministic" + class TensorProductConvAtomic(TensorProductConv): - def __init__(self, config, - idx_dtype=np.int64, - torch_op=True): + def __init__(self, config, idx_dtype=np.int64, torch_op=True): super().__init__(config, idx_dtype, torch_op, deterministic=False) @staticmethod def name(): return "LoopUnrollConvAtomic" + class TensorProductConvScatterSum(ConvolutionBase): def __init__(self, config, idx_dtype=np.int64, torch_op=True): - assert(torch_op) + assert torch_op global torch import torch @@ -78,32 +117,43 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True): self.reference_tp = TensorProduct(config, torch_op=torch_op) from openequivariance.implementations.convolution.scatter import scatter_sum + self.scatter_sum = scatter_sum def forward(self, L1_in, L2_in, weights, rows, cols): tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights) - return self.scatter_sum(src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]) - + return self.scatter_sum( + src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0] + ) + def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): tp_outputs = np.zeros((graph.nnz, self.L3.dim), dtype=L3_out.dtype) self.reference_tp.forward_cpu(L1_in[graph.cols], L2_in, tp_outputs, weights) np.add.at(L3_out, graph.rows, tp_outputs) def backward_cpu( - self, - L1_in : np.ndarray, - L1_grad : np.ndarray, - L2_in : np.ndarray, - L2_grad : np.ndarray, - L3_grad : np.ndarray, - weights : np.ndarray, - weights_grad : np.ndarray, - graph): + self, + L1_in: np.ndarray, + L1_grad: np.ndarray, + L2_in: np.ndarray, + L2_grad: np.ndarray, + L3_grad: np.ndarray, + weights: np.ndarray, + weights_grad: np.ndarray, + graph, + ): L1_grad_bcast = np.zeros((graph.nnz, self.L1.dim), dtype=L1_grad.dtype) self.reference_tp.backward_cpu( - L1_in[graph.cols], L1_grad_bcast, L2_in, L2_grad, L3_grad[graph.rows], weights, weights_grad) + L1_in[graph.cols], + L1_grad_bcast, + L2_in, + L2_grad, + L3_grad[graph.rows], + weights, + weights_grad, + ) np.add.at(L1_grad, graph.cols, L1_grad_bcast) @staticmethod def name(): - return "LoopUnrollConvScatterSum" + return "LoopUnrollConvScatterSum" diff --git a/openequivariance/implementations/convolution/scatter.py b/openequivariance/implementations/convolution/scatter.py index b2f365fe..a71198c7 100644 --- a/openequivariance/implementations/convolution/scatter.py +++ b/openequivariance/implementations/convolution/scatter.py @@ -1,13 +1,15 @@ import torch -from typing import Optional +from typing import Optional -''' +""" Scatter sum operator from MACE. basic scatter_sum operations from torch_scatter from https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. -''' +""" + + def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): if dim < 0: dim = other.dim() + dim @@ -19,6 +21,7 @@ def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): src = src.expand_as(other) return src + def scatter_sum( src: torch.Tensor, index: torch.Tensor, @@ -40,4 +43,4 @@ def scatter_sum( out = torch.zeros(size, dtype=src.dtype, device=src.device) return out.scatter_add_(dim, index, src) else: - return out.scatter_add_(dim, index, src) \ No newline at end of file + return out.scatter_add_(dim, index, src) diff --git a/openequivariance/implementations/e3nn_lite.py b/openequivariance/implementations/e3nn_lite.py index 4220869e..df708071 100644 --- a/openequivariance/implementations/e3nn_lite.py +++ b/openequivariance/implementations/e3nn_lite.py @@ -1,4 +1,5 @@ -''' +# ruff: noqa: E741, E743 +""" This file contains lightly modified code from E3NN. The code has been modified to remove all dependency on Pytorch. @@ -12,14 +13,14 @@ MIT License for e3nn: Euclidean neural networks (e3nn) Copyright (c) 2020, The Regents of the University of California, through Lawrence Berkeley National Laboratory -(subject to receipt of any required approvals from the U.S. Dept. of Energy), -Ecole Polytechnique Federale de Lausanne (EPFL), Free University of Berlin +(subject to receipt of any required approvals from the U.S. Dept. of Energy), +Ecole Polytechnique Federale de Lausanne (EPFL), Free University of Berlin and Kostiantyn Lapchevskyi. All rights reserved. -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, -copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the +copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: @@ -31,18 +32,18 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import itertools from typing import Tuple, NamedTuple, Union, List, Any, Optional from math import sqrt, prod import collections -import sys import numpy as np import numpy.linalg as la -import functools, math +import functools + def perm_inverse(p): r""" @@ -50,6 +51,7 @@ def perm_inverse(p): """ return tuple(p.index(i) for i in range(len(p))) + class Irrep(tuple): def __new__(cls, l: Union[int, "Irrep", str, tuple], p=None): if p is None: @@ -154,6 +156,7 @@ def __contains__(self, _object): def __len__(self): raise NotImplementedError + class _MulIr(tuple): def __new__(cls, mul, ir=None): if ir is None: @@ -295,7 +298,9 @@ def __mul__(self, other) -> "Irreps": 6x1e """ if isinstance(other, Irreps): - raise NotImplementedError("Use o3.TensorProduct for this, see the documentation") + raise NotImplementedError( + "Use o3.TensorProduct for this, see the documentation" + ) return Irreps(super().__mul__(other)) def __rmul__(self, other) -> "Irreps": @@ -362,12 +367,12 @@ class Instruction(NamedTuple): path_shape: tuple -class TPProblem: +class TPProblem: instructions: List[Any] shared_weights: bool internal_weights: bool weight_numel: int - label : str + label: str _profiling_str: str _in1_dim: int _in2_dim: int @@ -378,17 +383,17 @@ def __init__( irreps_in2: Irreps, irreps_out: Irreps, instructions: List[tuple], - in1_var: Optional[List[float]] = None, - in2_var: Optional[List[float]] = None, - out_var: Optional[List[float]] = None, + in1_var: Optional[List[float]] = None, + in2_var: Optional[List[float]] = None, + out_var: Optional[List[float]] = None, irrep_normalization: str = "component", path_normalization: str = "element", internal_weights: Optional[bool] = None, shared_weights: Optional[bool] = None, - label: Optional[str] = None, - irrep_dtype : type[np.generic] = np.float32, - weight_dtype : type[np.generic] = np.float32) -> None: - + label: Optional[str] = None, + irrep_dtype: type[np.generic] = np.float32, + weight_dtype: type[np.generic] = np.float32, + ) -> None: # === Setup === super().__init__() @@ -396,7 +401,7 @@ def __init__( assert path_normalization in ["element", "path", "none"] assert issubclass(irrep_dtype, np.generic) assert issubclass(weight_dtype, np.generic) - + self.irreps_in1 = Irreps(irreps_in1) self.irreps_in2 = Irreps(irreps_in2) self.irreps_out = Irreps(irreps_out) @@ -420,14 +425,27 @@ def __init__( has_weight=has_weight, path_weight=path_weight, path_shape={ - "uvw": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul, self.irreps_out[i_out].mul), + "uvw": ( + self.irreps_in1[i_in1].mul, + self.irreps_in2[i_in2].mul, + self.irreps_out[i_out].mul, + ), "uvu": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul), "uvv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul), "uuw": (self.irreps_in1[i_in1].mul, self.irreps_out[i_out].mul), "uuu": (self.irreps_in1[i_in1].mul,), "uvuv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul), - "uvu str: result += f"{self.in2_var = }\n" result += f"{self.out_var = }\n" result += f"num weights {self.weight_numel} \n" - result += f"| index | l | m | mode | weights | \n" - result += f"| in1 | in2 | out | in1 | in2 | out | in1 | in2 | out | | exist | path | \n" - for ins in self.instructions: # type : Instruction + result += "| index | l | m | mode | weights | \n" + result += "| in1 | in2 | out | in1 | in2 | out | in1 | in2 | out | | exist | path | \n" + for ins in self.instructions: # type : Instruction mul_irrep_in1 = self.irreps_in1[ins.i_in1] mul_irrep_in2 = self.irreps_in2[ins.i_in2] mul_irrep_out = self.irreps_out[ins.i_out] @@ -561,16 +618,19 @@ def __repr__(self) -> str: result += f" {str(ins.has_weight):<5} |" result += f" {ins.path_weight:<4.2f} | " result += "\n" - result = result.replace("self.","") - return result - - def weight_range_and_shape_for_instruction(self, instruction: int) -> Tuple[int, int, tuple]: + result = result.replace("self.", "") + return result + + def weight_range_and_shape_for_instruction( + self, instruction: int + ) -> Tuple[int, int, tuple]: if not self.instructions[instruction].has_weight: raise ValueError(f"Instruction {instruction} has no weights.") offset = sum(prod(ins.path_shape) for ins in self.instructions[:instruction]) ins = self.instructions[instruction] return offset, offset + prod(ins.path_shape), ins.path_shape + def change_basis_real_to_complex(l: int, dtype=None) -> np.ndarray: # https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form q = np.zeros((2 * l + 1, 2 * l + 1), dtype=np.complex128) @@ -581,16 +641,17 @@ def change_basis_real_to_complex(l: int, dtype=None) -> np.ndarray: for m in range(1, l + 1): q[l + m, l + abs(m)] = (-1) ** m / 2**0.5 q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5 - q = (-1j) ** l * q # Added factor of 1j**l to make the Clebsch-Gordan coefficients real + q = ( + -1j + ) ** l * q # Added factor of 1j**l to make the Clebsch-Gordan coefficients real - dtype = { - np.float32: np.complex64, - np.float64: np.complex128, - None: np.complex128 - }[dtype] + dtype = {np.float32: np.complex64, np.float64: np.complex128, None: np.complex128}[ + dtype + ] return q.astype(dtype) + def wigner_3j(l1: int, l2: int, l3: int, dtype=np.float64) -> np.ndarray: r"""Wigner 3j symbols :math:`C_{lmn}`. @@ -618,7 +679,7 @@ def wigner_3j(l1: int, l2: int, l3: int, dtype=np.float64) -> np.ndarray: :math:`l_3` dtype : np.dtype or None - ``dtype`` of the returned tensor. Default is np.float64 + ``dtype`` of the returned tensor. Default is np.float64 Returns ------- @@ -630,11 +691,11 @@ def wigner_3j(l1: int, l2: int, l3: int, dtype=np.float64) -> np.ndarray: C = _so3_clebsch_gordan(l1, l2, l3) # make sure we always get a copy so mutation doesn't ruin the stored tensors - return C.copy().astype(dtype) + return C.copy().astype(dtype) @functools.lru_cache(maxsize=None) -def _so3_clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: +def _so3_clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: Q1 = change_basis_real_to_complex(l1, dtype=np.float64) Q2 = change_basis_real_to_complex(l2, dtype=np.float64) Q3 = change_basis_real_to_complex(l3, dtype=np.float64) @@ -687,7 +748,9 @@ def _so3_clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: @functools.lru_cache(maxsize=None) -def _su2_clebsch_gordan(j1: Union[int, float], j2: Union[int, float], j3: Union[int, float]) -> np.ndarray: +def _su2_clebsch_gordan( + j1: Union[int, float], j2: Union[int, float], j3: Union[int, float] +) -> np.ndarray: """Calculates the Clebsch-Gordon matrix for SU(2) coupling j1 and j2 to give j3. Parameters @@ -706,13 +769,15 @@ def _su2_clebsch_gordan(j1: Union[int, float], j2: Union[int, float], j3: Union[ assert isinstance(j1, (int, float)) assert isinstance(j2, (int, float)) assert isinstance(j3, (int, float)) - mat = np.zeros((int(2 * j1 + 1), int(2 * j2 + 1), int(2 * j3 + 1)), dtype=np.float64) + mat = np.zeros( + (int(2 * j1 + 1), int(2 * j2 + 1), int(2 * j3 + 1)), dtype=np.float64 + ) if int(2 * j3) in range(int(2 * abs(j1 - j2)), int(2 * (j1 + j2)) + 1, 2): for m1 in (x / 2 for x in range(-int(2 * j1), int(2 * j1) + 1, 2)): for m2 in (x / 2 for x in range(-int(2 * j2), int(2 * j2) + 1, 2)): if abs(m1 + m2) <= j3: - mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2)] = _su2_clebsch_gordan_coeff( - (j1, m1), (j2, m2), (j3, m1 + m2) + mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2)] = ( + _su2_clebsch_gordan_coeff((j1, m1), (j2, m2), (j3, m1 + m2)) ) return mat @@ -758,7 +823,11 @@ def f(n: int) -> int: C = ( (2.0 * j3 + 1.0) * Fraction( - f(j3 + j1 - j2) * f(j3 - j1 + j2) * f(j1 + j2 - j3) * f(j3 + m3) * f(j3 - m3), + f(j3 + j1 - j2) + * f(j3 - j1 + j2) + * f(j1 + j2 - j3) + * f(j3 + m3) + * f(j3 - m3), f(j1 + j2 + j3 + 1) * f(j1 - m1) * f(j1 + m1) * f(j2 - m2) * f(j2 + m2), ) ) ** 0.5 @@ -766,7 +835,8 @@ def f(n: int) -> int: S = 0 for v in range(vmin, vmax + 1): S += (-1) ** int(v + j2 + m2) * Fraction( - f(j2 + j3 + m1 - v) * f(j1 - m1 + v), f(v) * f(j3 - j1 + j2 - v) * f(j3 + m3 - v) * f(v + j1 - j2 - m3) + f(j2 + j3 + m1 - v) * f(j1 - m1 + v), + f(v) * f(j3 - j1 + j2 - v) * f(j3 + m3 - v) * f(v + j1 - j2 - m3), ) C = C * S - return C \ No newline at end of file + return C diff --git a/openequivariance/implementations/symmetric_contraction/STPOpt.py b/openequivariance/implementations/symmetric_contraction/STPOpt.py index f9800ce7..890ca7c0 100644 --- a/openequivariance/implementations/symmetric_contraction/STPOpt.py +++ b/openequivariance/implementations/symmetric_contraction/STPOpt.py @@ -1,24 +1,36 @@ +# ruff: noqa : E402 import torch -from openequivariance.extlib import * +from openequivariance.extlib import GroupMM_F32, GroupMM_F64 + class GroupMM: next_id = 0 + def __init__(self, dtype, num_elements, batch_size): self.id = GroupMM.next_id self.num_elements = num_elements GroupMM.next_id += 1 - if dtype==torch.float32: - self.internal = GroupMM_F32(num_elements, batch_size) + if dtype == torch.float32: + self.internal = GroupMM_F32(num_elements, batch_size) else: self.internal = GroupMM_F64(num_elements, batch_size) - - @torch.library.custom_op(f"openequivariance::group_gemm{self.id}", mutates_args=(), device_types="cuda") - def group_gemm(A: torch.Tensor, B: torch.Tensor, - ragged_counts: torch.Tensor, M: int, K: int, ragged_inner: int) -> torch.Tensor: - ''' + @torch.library.custom_op( + f"openequivariance::group_gemm{self.id}", + mutates_args=(), + device_types="cuda", + ) + def group_gemm( + A: torch.Tensor, + B: torch.Tensor, + ragged_counts: torch.Tensor, + M: int, + K: int, + ragged_inner: int, + ) -> torch.Tensor: + """ If ragged_inner == 0: A is 3D, num_weights x num_features x M x K B is batch_size x num_features x K @@ -26,19 +38,24 @@ def group_gemm(A: torch.Tensor, B: torch.Tensor, If ragged_inner == 1: (needed for the backward pass) A is batch_size x num_features x M B is batch_size x num_features K - C is 3D, num_weights x num_features M x K - ''' + C is 3D, num_weights x num_features M x K + """ shape = None if ragged_inner == 0: shape = (B.shape[0], B.shape[1], M) elif ragged_inner == 1: shape = (num_elements, B.shape[1], M, K) - C = torch.zeros(shape, device='cuda', dtype=A.dtype) - self.internal.group_gemm(A.contiguous().data_ptr(), - B.contiguous().data_ptr(), - C.data_ptr(), ragged_counts.data_ptr(), - M, K, ragged_inner) + C = torch.zeros(shape, device="cuda", dtype=A.dtype) + self.internal.group_gemm( + A.contiguous().data_ptr(), + B.contiguous().data_ptr(), + C.data_ptr(), + ragged_counts.data_ptr(), + M, + K, + ragged_inner, + ) return C @group_gemm.register_fake @@ -57,28 +74,49 @@ def backward(ctx, grad_output): grad_A, grad_B = None, None if ctx.ragged_inner == 0: - grad_A = group_gemm(grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 1) - grad_B = group_gemm(ctx.A.transpose(2, 3), grad_output, ctx.ragged_counts, ctx.K, ctx.M, 0) + grad_A = group_gemm( + grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 1 + ) + grad_B = group_gemm( + ctx.A.transpose(2, 3), + grad_output, + ctx.ragged_counts, + ctx.K, + ctx.M, + 0, + ) elif ctx.ragged_inner == 1: - grad_A = group_gemm(grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 0) - grad_B = group_gemm(grad_output.transpose(2, 3), ctx.A, ctx.ragged_counts, ctx.K, ctx.M, 0) + grad_A = group_gemm( + grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 0 + ) + grad_B = group_gemm( + grad_output.transpose(2, 3), + ctx.A, + ctx.ragged_counts, + ctx.K, + ctx.M, + 0, + ) - return grad_A, grad_B, None, None, None, None + return grad_A, grad_B, None, None, None, None self.group_gemm.register_autograd(backward, setup_context=setup_context) def forward(self, weights, vectors, bincounts): - return self.group_gemm(weights, vectors, bincounts, weights.shape[2], weights.shape[3], 0) + return self.group_gemm( + weights, vectors, bincounts, weights.shape[2], weights.shape[3], 0 + ) + # -------------------------------------------------------------------------- from typing import Dict, Optional, Union -import torch from e3nn import o3 from e3nn.util.codegen import CodeGenMixin from e3nn.util.jit import compile_mode from mace.tools.cg import U_matrix_real + @compile_mode("script") class Contraction(torch.nn.Module): def __init__( @@ -112,14 +150,14 @@ def __init__( # Create weight for product basis self.weights = torch.nn.ParameterList([]) - self.groupMM = GroupMM(torch.get_default_dtype(), num_elements, self.num_features) + self.groupMM = GroupMM( + torch.get_default_dtype(), num_elements, self.num_features + ) self.num_equivariance = 2 * irrep_out.lmax + 1 - + for i in range(correlation, 0, -1): # Shapes defining num_params = self.U_tensors(i).size()[-1] - num_equivariance = 2 * irrep_out.lmax + 1 - num_ell = self.U_tensors(i).size()[-2] if i == correlation: # Parameters for the product basis @@ -135,7 +173,7 @@ def __init__( / num_params ) self.weights.append(w) - + if not internal_weights: self.weights = weights[:-1] self.weights_max = weights[-1] @@ -148,34 +186,48 @@ def __init__( U_permuted = U.permute(permutation).reshape(num_params, -1) self.register_buffer(f"U_permuted_{i}", U_permuted) - def forward(self, x: torch.Tensor, bincount: torch.Tensor, sorted_indices: torch.Tensor): + def forward( + self, x: torch.Tensor, bincount: torch.Tensor, sorted_indices: torch.Tensor + ): U = self.U_tensors(self.correlation) num_params = U.shape[-1] num_ell = U.shape[-2] - U_weights = self.weights_max.transpose(1, 2).reshape(-1, num_params) @ self.U_permuted(self.correlation) - - out = self.groupMM.forward(U_weights.view(self.num_elements, self.num_features, -1, num_ell), x, bincount) + U_weights = self.weights_max.transpose(1, 2).reshape( + -1, num_params + ) @ self.U_permuted(self.correlation) + + out = self.groupMM.forward( + U_weights.view(self.num_elements, self.num_features, -1, num_ell), + x, + bincount, + ) out = out.view([x.shape[0], self.num_features] + list(U.shape[:-2])) - - for i, weight in enumerate(self.weights): + + for i, weight in enumerate(self.weights): U = self.U_tensors(self.correlation - i - 1) U_perm = self.U_permuted(self.correlation - i - 1) c_tensor = weight.transpose(1, 2).reshape(-1, weight.shape[1]) @ U_perm - c_tensor = c_tensor.view([weight.shape[0], weight.shape[2]] + list(U.shape[:-1])) + c_tensor = c_tensor.view( + [weight.shape[0], weight.shape[2]] + list(U.shape[:-1]) + ) c_tensor = c_tensor[sorted_indices] + out - + s = c_tensor.shape - out = torch.sum(c_tensor.view(s[0] * s[1], -1, s[-1]) * x.view(s[0] * s[1], 1, s[-1]), dim=2).view(s[:-1]) - #out = torch.bmm(c_tensor.view(s[0] * s[1], -1, s[-1]), x.view(s[0] * s[1], s[-1], 1)).view(s[:-1]) + out = torch.sum( + c_tensor.view(s[0] * s[1], -1, s[-1]) * x.view(s[0] * s[1], 1, s[-1]), + dim=2, + ).view(s[:-1]) + # out = torch.bmm(c_tensor.view(s[0] * s[1], -1, s[-1]), x.view(s[0] * s[1], s[-1], 1)).view(s[:-1]) return out.view(out.shape[0], -1) - + def U_tensors(self, nu: int): return dict(self.named_buffers())[f"U_matrix_{nu}"] def U_permuted(self, nu: int): return dict(self.named_buffers())[f"U_permuted_{nu}"] + @compile_mode("script") class SymmetricContraction(CodeGenMixin, torch.nn.Module): def __init__( @@ -186,7 +238,7 @@ def __init__( irrep_normalization: str = "component", path_normalization: str = "element", internal_weights: Optional[bool] = None, - shared_weights: Optional[bool ] = None, + shared_weights: Optional[bool] = None, num_elements: Optional[int] = None, ) -> None: super().__init__() @@ -237,51 +289,56 @@ def __init__( def forward(self, x: torch.Tensor, y: torch.Tensor): indices = torch.argmax(y, dim=1) - bincount = torch.bincount(indices, minlength=self.num_elements).to('cpu') + bincount = torch.bincount(indices, minlength=self.num_elements).to("cpu") permutation = torch.argsort(indices) inverse_perm = torch.argsort(permutation) sorted_indices = indices[permutation] x = x[permutation] - outs = [contraction(x, bincount, sorted_indices) for contraction in self.contractions] + outs = [ + contraction(x, bincount, sorted_indices) + for contraction in self.contractions + ] outs_cat = torch.cat(outs, dim=-1)[inverse_perm] return outs_cat - # -------------------------------------------------------------------------- + def test_group_matmul(): torch.manual_seed(0) num_elements = 10 - vpe= 30 # Vectors per element, uniform just for testing + vpe = 30 # Vectors per element, uniform just for testing num_features = 20 - M = 64 + M = 64 K = 123 - ragged_counts = torch.zeros(num_elements, dtype=torch.int64, device='cpu') + ragged_counts = torch.zeros(num_elements, dtype=torch.int64, device="cpu") for i in range(num_elements): - ragged_counts[i] = vpe + ragged_counts[i] = vpe def test_backward_0(): group_mm = GroupMM(torch.float32, num_elements, num_features) - A = torch.randn(num_elements, num_features, M, K).to('cuda') - B = torch.randn(num_elements * vpe, num_features, K).to('cuda') + A = torch.randn(num_elements, num_features, M, K).to("cuda") + B = torch.randn(num_elements * vpe, num_features, K).to("cuda") A.requires_grad = True B.requires_grad = True - ground_truth = torch.zeros(num_elements * vpe, num_features, M, device='cuda') + ground_truth = torch.zeros(num_elements * vpe, num_features, M, device="cuda") - # Test the forward pass + # Test the forward pass for i in range(num_elements): - B_slice = B[vpe * i:vpe * (i+1)] - ground_truth[vpe * i:vpe * (i+1)] = (A[i] @ B_slice.permute(1, 2, 0)).permute(2, 0, 1) + B_slice = B[vpe * i : vpe * (i + 1)] + ground_truth[vpe * i : vpe * (i + 1)] = ( + A[i] @ B_slice.permute(1, 2, 0) + ).permute(2, 0, 1) - C_g = torch.randn(num_elements * vpe, num_features, M).to('cuda') + C_g = torch.randn(num_elements * vpe, num_features, M).to("cuda") C_g.requires_grad = True - ground_truth.backward(C_g, inputs=[A, B]) + ground_truth.backward(C_g, inputs=[A, B]) A_grad_gt = A.grad.detach().clone() B_grad_gt = B.grad.detach().clone() @@ -301,24 +358,24 @@ def test_backward_1(): print("TESTING BACKWARD_1!") group_mm = GroupMM(torch.float32, num_elements, num_features) - A = torch.zeros(num_elements * vpe, num_features, M, device='cuda') - B = torch.randn(num_elements * vpe, num_features, K).to('cuda') + A = torch.zeros(num_elements * vpe, num_features, M, device="cuda") + B = torch.randn(num_elements * vpe, num_features, K).to("cuda") A.requires_grad = True B.requires_grad = True - ground_truth = torch.zeros(num_elements, num_features, M, K).to('cuda') + ground_truth = torch.zeros(num_elements, num_features, M, K).to("cuda") for i in range(num_elements): - A_slice = A[vpe * i:vpe * (i+1)] - B_slice = B[vpe * i:vpe * (i+1)] + A_slice = A[vpe * i : vpe * (i + 1)] + B_slice = B[vpe * i : vpe * (i + 1)] ground_truth[i] = A_slice.permute(1, 2, 0) @ B_slice.permute(1, 0, 2) C = group_mm.group_gemm(A, B, ragged_counts, M, K, 1) - print(torch.norm(C - ground_truth)) + print(torch.norm(C - ground_truth)) - C_g = torch.randn(num_elements, num_features, M, K).to('cuda') + C_g = torch.randn(num_elements, num_features, M, K).to("cuda") C_g.requires_grad = True ground_truth.backward(C_g, inputs=[A, B]) @@ -334,28 +391,28 @@ def test_backward_1(): print(torch.norm(A.grad - A_grad_gt)) print(torch.norm(B.grad - B_grad_gt)) - def test_double_backward(): torch.autograd.set_detect_anomaly(True) - group_mm = GroupMM(torch.float32, num_elements, num_features) - A = torch.randn(num_elements, num_features, M, K).to('cuda') - B = torch.randn(num_elements * vpe, num_features, K).to('cuda') + GroupMM(torch.float32, num_elements, num_features) + A = torch.randn(num_elements, num_features, M, K).to("cuda") + B = torch.randn(num_elements * vpe, num_features, K).to("cuda") A.requires_grad = True B.requires_grad = True - ground_truth = torch.zeros(num_elements * vpe, num_features, M, device='cuda') - #ground_truth = group_mm.group_gemm(A, B, ragged_counts, M, K, 0) + ground_truth = torch.zeros(num_elements * vpe, num_features, M, device="cuda") - # Test the forward pass + # Test the forward pass for i in range(num_elements): - B_slice = B[vpe * i:vpe * (i+1)] - ground_truth[vpe * i:vpe * (i+1)] = (A[i] @ B_slice.permute(1, 2, 0)).permute(2, 0, 1) + B_slice = B[vpe * i : vpe * (i + 1)] + ground_truth[vpe * i : vpe * (i + 1)] = ( + A[i] @ B_slice.permute(1, 2, 0) + ).permute(2, 0, 1) - C_g = torch.randn(num_elements * vpe, num_features, M).to('cuda') - C_g.requires_grad = True + C_g = torch.randn(num_elements * vpe, num_features, M).to("cuda") + C_g.requires_grad = True - ground_truth.backward(C_g, inputs=[A, B], create_graph=True, retain_graph=True) + ground_truth.backward(C_g, inputs=[A, B], create_graph=True, retain_graph=True) dummy = torch.norm(A.grad) + torch.norm(B.grad) dummy_grad = torch.randn_like(dummy) @@ -369,11 +426,10 @@ def test_double_backward(): print(torch.norm(B_grad_gt)) print(torch.norm(C_grad_gt)) - test_backward_0() test_backward_1() test_double_backward() -if __name__=='__main__': - test_group_matmul() \ No newline at end of file +if __name__ == "__main__": + test_group_matmul() diff --git a/openequivariance/implementations/utils.py b/openequivariance/implementations/utils.py index aafb9715..44813021 100644 --- a/openequivariance/implementations/utils.py +++ b/openequivariance/implementations/utils.py @@ -1,80 +1,92 @@ import functools -import warnings import math -import numpy as np +import numpy as np from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import Irrep, _MulIr, Irreps, Instruction, TPProblem +from openequivariance.implementations.e3nn_lite import ( + Irreps, + Instruction, + TPProblem, +) -def sparse_outer_product_work(cg : np.ndarray) -> int: + +def sparse_outer_product_work(cg: np.ndarray) -> int: return np.sum(np.max(cg != 0, axis=2)) -def convenience_namer(L1 : Irreps, L2 : Irreps, L3 : Irreps): + +def convenience_namer(L1: Irreps, L2: Irreps, L3: Irreps): return f"({L1}x{L2}->{L3})" -# Non Zeros + +# Non Zeros @functools.lru_cache(typed=True) def count_cg_non_zero(l1, l2, l3) -> int: return np.count_nonzero(TensorProductBase.load_cg_tensor(l1, l2, l3)) -def calculate_total_nnz(tpp : TPProblem) -> int: - """ - To make sure you don't over count repeat CGs which get used multiple times - """ - nnz_by_l_combo = {} - for ins in tpp.instructions: # type : Instruction - l1 = tpp.irreps_in1[ins.i_in1].ir.l - l2 = tpp.irreps_in2[ins.i_in2].ir.l - l3 = tpp.irreps_out[ins.i_out].ir.l - assert isinstance(l1, int) - assert isinstance(l2, int) - assert isinstance(l3, int) - nnz_by_l_combo[(l1,l2,l3)] = count_cg_non_zero(l1,l2,l3) - return sum(nnz_by_l_combo.values()) - -def calc_weight_offsets(tpp : TPProblem) -> list[int]: + +def calculate_total_nnz(tpp: TPProblem) -> int: + """ + To make sure you don't over count repeat CGs which get used multiple times + """ + nnz_by_l_combo = {} + for ins in tpp.instructions: # type : Instruction + l1 = tpp.irreps_in1[ins.i_in1].ir.l + l2 = tpp.irreps_in2[ins.i_in2].ir.l + l3 = tpp.irreps_out[ins.i_out].ir.l + assert isinstance(l1, int) + assert isinstance(l2, int) + assert isinstance(l3, int) + nnz_by_l_combo[(l1, l2, l3)] = count_cg_non_zero(l1, l2, l3) + return sum(nnz_by_l_combo.values()) + + +def calc_weight_offsets(tpp: TPProblem) -> list[int]: """ - Returns a list of weight offsets for every instruction. + Returns a list of weight offsets for every instruction. """ assert isinstance(tpp, TPProblem) offset = 0 offsets = [] for ins in tpp.instructions: - assert isinstance(ins, Instruction) + assert isinstance(ins, Instruction) offsets.append(offset) if ins.has_weight: flatsize = math.prod(ins.path_shape) offset += flatsize - return offsets + return offsets + - def filter_and_analyze_problem(problem): - ''' + """ Centralized function that stops unhandled problem configurations, - returns a dictionary of useful information about the problem. - ''' + returns a dictionary of useful information about the problem. + """ for i, inst in enumerate(problem.instructions): - assert inst.connection_mode == problem.instructions[0].connection_mode, \ + assert inst.connection_mode == problem.instructions[0].connection_mode, ( f"All instructions must have the same connection mode, got {inst.connection_mode} and {problem.instructions[0].connection_mode}" + ) - assert inst.has_weight, \ + assert inst.has_weight, ( f"All instructions must have trainable weights, got {inst.has_weight} at index {i}" + ) - assert problem.instructions[0].connection_mode in ["uvu", "uvw"], \ - f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}" + assert problem.instructions[0].connection_mode in ["uvu", "uvw"], ( + f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}" + ) - assert problem.irrep_dtype == problem.weight_dtype, \ - f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}" + assert problem.irrep_dtype == problem.weight_dtype, ( + f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}" + ) - assert len(problem.instructions) > 0, \ - "Tensor product has no valid instructions!" + assert len(problem.instructions) > 0, "Tensor product has no valid instructions!" result = { - "is_uvw": problem.instructions[0].connection_mode == "uvw", + "is_uvw": problem.instructions[0].connection_mode == "uvw", } return result + def torch_to_oeq_dtype(torch_dtype): global torch import torch @@ -87,4 +99,4 @@ def torch_to_oeq_dtype(torch_dtype): elif torch_dtype == torch.float64: return np.float64 else: - raise ValueError("Unsupported torch dtype!") \ No newline at end of file + raise ValueError("Unsupported torch dtype!") diff --git a/openequivariance/templates/jinja_utils.py b/openequivariance/templates/jinja_utils.py index acb13289..bb326f27 100644 --- a/openequivariance/templates/jinja_utils.py +++ b/openequivariance/templates/jinja_utils.py @@ -1,10 +1,13 @@ -from jinja2 import Environment, PackageLoader +from jinja2 import Environment, PackageLoader + def raise_helper(msg): raise Exception(msg) + def divide(numerator, denominator): - return numerator // denominator + return numerator // denominator + def sizeof(dtype): if dtype in ["float", "int", "unsigned int"]: @@ -12,10 +15,13 @@ def sizeof(dtype): else: raise Exception("Provided undefined datatype to sizeof!") + def get_jinja_environment(): - env = Environment(loader=PackageLoader("openequivariance"), extensions=['jinja2.ext.do']) - env.globals['raise'] = raise_helper - env.globals['divide'] = divide - env.globals['sizeof'] = sizeof - env.globals['enumerate'] = enumerate - return env \ No newline at end of file + env = Environment( + loader=PackageLoader("openequivariance"), extensions=["jinja2.ext.do"] + ) + env.globals["raise"] = raise_helper + env.globals["divide"] = divide + env.globals["sizeof"] = sizeof + env.globals["enumerate"] = enumerate + return env diff --git a/tests/batch_test.py b/tests/batch_test.py index 2879a4d6..fc02a863 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -1,27 +1,30 @@ import pytest from pytest_check import check -import numpy as np +import numpy as np import openequivariance as oeq from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.benchmark.correctness_utils import correctness_forward, correctness_backward, correctness_double_backward +from openequivariance.benchmark.correctness_utils import ( + correctness_forward, + correctness_backward, + correctness_double_backward, +) from itertools import chain, product + class TPCorrectness: def thresh(self, direction): - return { - "fwd": 1e-5, - "bwd": 3e-4, - "double_bwd": 3e-4 - }[direction] + return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] def check_result(self, result, fieldname): with check: error = result[fieldname]["diff_Linf_norm"] thresh = result["thresh"] - assert result[fieldname]["pass"], f"{fieldname} observed error={error:.5f} >= {thresh}" + assert result[fieldname]["pass"], ( + f"{fieldname} observed error={error:.5f} >= {thresh}" + ) - @pytest.fixture(params=[np.float32, np.float64], ids=['F32', 'F64'], scope="class") + @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") def dtype(self, request): return request.param @@ -30,27 +33,29 @@ def tp_and_problem(self, problem): tp = TensorProduct(problem) return tp, problem - def test_tp_fwd(self, tp_and_problem): + def test_tp_fwd(self, tp_and_problem): tp, problem = tp_and_problem result = correctness_forward( problem=problem, test_implementation=tp, - reference_implementation=None, + reference_implementation=None, batch_size=1000, correctness_threshold=self.thresh("fwd"), - prng_seed=12345) + prng_seed=12345, + ) self.check_result(result, "output") - def test_tp_bwd(self, tp_and_problem): + def test_tp_bwd(self, tp_and_problem): tp, problem = tp_and_problem result = correctness_backward( problem=problem, test_implementation=tp, - reference_implementation=None, + reference_implementation=None, batch_size=1000, correctness_threshold=self.thresh("bwd"), - prng_seed=12345) + prng_seed=12345, + ) self.check_result(result, "weight_grad") self.check_result(result, "in1_grad") @@ -59,100 +64,178 @@ def test_tp_bwd(self, tp_and_problem): def test_tp_double_bwd(self, tp_and_problem): tp, problem = tp_and_problem result = correctness_double_backward( - problem = problem, + problem=problem, test_implementation=tp, - reference_implementation = None, - batch_size = 200, - correctness_threshold=self.thresh("double_bwd"), - prng_seed = 12345) + reference_implementation=None, + batch_size=200, + correctness_threshold=self.thresh("double_bwd"), + prng_seed=12345, + ) self.check_result(result, "output_double_grad") self.check_result(result, "in1_grad") self.check_result(result, "in2_grad") self.check_result(result, "weights_grad") + class TestProductionModels(TPCorrectness): - from openequivariance.benchmark.benchmark_configs \ - import e3nn_torch_tetris_polynomial, diffdock_configs, mace_nequip_problems - production_model_tpps = list(chain( - mace_nequip_problems, - e3nn_torch_tetris_polynomial, - diffdock_configs)) - - @pytest.fixture(params=production_model_tpps, ids = lambda x : x.label, scope="class") + from openequivariance.benchmark.benchmark_configs import ( + e3nn_torch_tetris_polynomial, + diffdock_configs, + mace_nequip_problems, + ) + + production_model_tpps = list( + chain(mace_nequip_problems, e3nn_torch_tetris_polynomial, diffdock_configs) + ) + + @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype return request.param + class TestUVUSingleIrrep(TPCorrectness): muls = [ - (1, 1, 1), (2, 1, 2), (4, 1, 4), (8, 1, 8), (16, 1, 16), - (32, 1, 32), (5, 1, 5), (13, 1, 13), (19, 1, 19), - (33, 1, 33), (49, 1, 49), (50, 1, 50), (123, 1, 123), - (128, 1, 128), (256, 1, 256), (512, 1, 512), - (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), - (16, 3, 16), (16, 9, 16), (24, 24, 24), (32, 32, 32) + (1, 1, 1), + (2, 1, 2), + (4, 1, 4), + (8, 1, 8), + (16, 1, 16), + (32, 1, 32), + (5, 1, 5), + (13, 1, 13), + (19, 1, 19), + (33, 1, 33), + (49, 1, 49), + (50, 1, 50), + (123, 1, 123), + (128, 1, 128), + (256, 1, 256), + (512, 1, 512), + (1, 2, 1), + (1, 4, 1), + (1, 16, 1), + (1, 32, 1), + (16, 3, 16), + (16, 9, 16), + (24, 24, 24), + (32, 32, 32), ] - - irs = [ (0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), - (2, 0, 2), (2, 2, 4), (2, 2, 2), (5, 3, 5), (7, 2, 5) ] - - def id_func(m, i): + + irs = [ + (0, 0, 0), + (1, 1, 1), + (1, 0, 1), + (1, 2, 1), + (2, 0, 2), + (2, 2, 4), + (2, 2, 2), + (5, 3, 5), + (7, 2, 5), + ] + + def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - @pytest.fixture(params=product(muls, irs), - ids = lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), - scope="class") + @pytest.fixture( + params=product(muls, irs), + ids=lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), + scope="class", + ) def problem(self, request, dtype): m, i = request.param[0], request.param[1] - instructions=[(0, 0, 0, "uvu", True)] - return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", - instructions, shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) - + instructions = [(0, 0, 0, "uvu", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + class TestUVWSingleIrrep(TPCorrectness): muls = [ - (1, 1, 1), (2, 1, 2), (4, 1, 4), (8, 1, 8), (16, 1, 16), - (32, 1, 32), (5, 1, 5), (13, 1, 13), (19, 1, 19), - (33, 1, 33), (49, 1, 49), (50, 1, 50), (64, 1, 64), - (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), - (16, 3, 16), (16, 9, 16), (24, 24, 24), (32, 32, 32) + (1, 1, 1), + (2, 1, 2), + (4, 1, 4), + (8, 1, 8), + (16, 1, 16), + (32, 1, 32), + (5, 1, 5), + (13, 1, 13), + (19, 1, 19), + (33, 1, 33), + (49, 1, 49), + (50, 1, 50), + (64, 1, 64), + (1, 2, 1), + (1, 4, 1), + (1, 16, 1), + (1, 32, 1), + (16, 3, 16), + (16, 9, 16), + (24, 24, 24), + (32, 32, 32), ] - - irs = [ (0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), - (2, 0, 2), (2, 2, 4), (2, 2, 2), (5, 3, 5), (7, 2, 5) ] - def id_func(m, i): + irs = [ + (0, 0, 0), + (1, 1, 1), + (1, 0, 1), + (1, 2, 1), + (2, 0, 2), + (2, 2, 4), + (2, 2, 2), + (5, 3, 5), + (7, 2, 5), + ] + + def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - @pytest.fixture(params=product(muls, irs), - ids = lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), - scope="class") + @pytest.fixture( + params=product(muls, irs), + ids=lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), + scope="class", + ) def problem(self, request, dtype): m, i = request.param[0], request.param[1] - instructions=[(0, 0, 0, "uvw", True)] - return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", - instructions, shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) - + instructions = [(0, 0, 0, "uvw", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + class TestSharedWeights(TPCorrectness): - from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs - problems = [mace_problems[0], diffdock_configs[0]] + from openequivariance.benchmark.benchmark_configs import ( + mace_problems, + diffdock_configs, + ) + + problems = [mace_problems[0], diffdock_configs[0]] def thresh(self, direction): return { "fwd": 1e-5, - "bwd": 5e-4, # Expect higher errors for shared weights - "double_bwd": 5e-4 + "bwd": 5e-4, # Expect higher errors for shared weights + "double_bwd": 5e-4, }[direction] - @pytest.fixture(params=problems, ids = lambda x : x.label, scope="class") + @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): problem = request.param problem.irrep_dtype, problem.weight_dtype = dtype, dtype problem.shared_weights = True - return problem \ No newline at end of file + return problem diff --git a/tests/benchmark.py b/tests/benchmark.py index 7cabe628..fe765792 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -1,22 +1,46 @@ -import numpy as np -import numpy.linalg as la - -import itertools, logging, argparse, os, copy, gc +import itertools +import logging +import argparse +import os +import copy +import json +import pathlib from pathlib import Path import urllib.request +import numpy as np + from openequivariance.benchmark.logging_utils import getLogger from openequivariance.extlib import DeviceProp -from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct, E3NNTensorProductCompiledCUDAGraphs, E3NNTensorProductCompiledMaxAutotuneCUDAGraphs -from openequivariance.implementations.TensorProduct import TensorProduct +from openequivariance.implementations.E3NNTensorProduct import ( + E3NNTensorProduct, + E3NNTensorProductCompiledMaxAutotuneCUDAGraphs, +) +from openequivariance.implementations.TensorProduct import TensorProduct from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.benchmark.TestBenchmarkSuite import TestBenchmarkSuite, TestDefinition, Direction -from openequivariance.benchmark.tpp_creation_utils import ChannelwiseTPP, FullyConnectedTPProblem, SingleInstruction -from openequivariance.benchmark.benchmark_routines.paper_benchmark_uvw import run_paper_uvw_benchmark - -from openequivariance.implementations.convolution.TensorProductConv import * -from openequivariance.implementations.convolution.CUEConv import * -from openequivariance.benchmark.ConvBenchmarkSuite import * +from openequivariance.benchmark.TestBenchmarkSuite import ( + TestBenchmarkSuite, + TestDefinition, + Direction, +) +from openequivariance.benchmark.tpp_creation_utils import ( + ChannelwiseTPP, + FullyConnectedTPProblem, + SingleInstruction, +) +from openequivariance.benchmark.benchmark_routines.paper_benchmark_uvw import ( + run_paper_uvw_benchmark, +) + +from openequivariance.implementations.convolution.TensorProductConv import ( + TensorProductConvAtomic, + TensorProductConvDeterministic, + TensorProductConvKahan, + TensorProductConvScatterSum, +) + +from openequivariance.implementations.convolution.CUEConv import CUEConv, CUEConvFused +from openequivariance.benchmark.ConvBenchmarkSuite import ConvBenchmarkSuite, load_graph logger = getLogger() @@ -24,89 +48,127 @@ FCTPP = FullyConnectedTPProblem implementation_map = { - 'e3nn': E3NNTensorProductCompiledMaxAutotuneCUDAGraphs, - 'e3nn_uncompiled': E3NNTensorProduct, - 'cue': CUETensorProduct, - 'oeq': TensorProduct + "e3nn": E3NNTensorProductCompiledMaxAutotuneCUDAGraphs, + "e3nn_uncompiled": E3NNTensorProduct, + "cue": CUETensorProduct, + "oeq": TensorProduct, } -datatype_map = { - 'float32': np.float32, - 'float64': np.float64 -} +datatype_map = {"float32": np.float32, "float64": np.float64} roofline_configs = [ - SingleInstruction(L1, L2, L3, cm, f"[{i+1}]#{L1} x {L2} -> {L3} ({cm})") - for i, (L1, L2, L3, cm) in enumerate([ - ("128x1e", "1x1e", "128x1e", "uvu"), - ("128x2e", "1x1e", "128x2e", "uvu"), - ("128x3e", "1x3e", "128x3e", "uvu"), - ("128x5e", "1x5e", "128x3e", "uvu"), - ("128x5e", "1x3e", "128x5e", "uvu"), - ("128x6e", "1x3e", "128x6e", "uvu"), - ("128x7e", "1x4e", "128x7e", "uvu"), - ("128x7e", "1x7e", "128x7e", "uvu"), - ]) + SingleInstruction(L1, L2, L3, cm, f"[{i + 1}]#{L1} x {L2} -> {L3} ({cm})") + for i, (L1, L2, L3, cm) in enumerate( + [ + ("128x1e", "1x1e", "128x1e", "uvu"), + ("128x2e", "1x1e", "128x2e", "uvu"), + ("128x3e", "1x3e", "128x3e", "uvu"), + ("128x5e", "1x5e", "128x3e", "uvu"), + ("128x5e", "1x3e", "128x5e", "uvu"), + ("128x6e", "1x3e", "128x6e", "uvu"), + ("128x7e", "1x4e", "128x7e", "uvu"), + ("128x7e", "1x7e", "128x7e", "uvu"), + ] + ) ] + def benchmark_uvu(params): - from openequivariance.benchmark.benchmark_configs \ - import mace_nequip_problems + from openequivariance.benchmark.benchmark_configs import mace_nequip_problems float64_problems = copy.deepcopy(mace_nequip_problems) - for problem in float64_problems: + for problem in float64_problems: problem.irrep_dtype = np.float64 - problem.weight_dtype = np.float64 + problem.weight_dtype = np.float64 problems = mace_nequip_problems + float64_problems - implementations = [ - implementation_map[impl] for impl in params.implementations - ] + implementations = [implementation_map[impl] for impl in params.implementations] directions = params.directions - datatypes = [datatype_map[dt] for dt in params.datatypes] - tests = [TestDefinition(implementation, problem, direction, correctness=False, benchmark=True) - for implementation, problem, direction - in itertools.product(implementations, problems, directions)] + tests = [ + TestDefinition( + implementation, problem, direction, correctness=False, benchmark=True + ) + for implementation, problem, direction in itertools.product( + implementations, problems, directions + ) + ] # Handle the float64 Benzene case, since we run out of memory with torch compile - tests = [test for test in tests - if 'benzene' not in test.problem.label - or test.implementation != E3NNTensorProductCompiledMaxAutotuneCUDAGraphs - or test.problem.irrep_dtype != np.float64] - - if 'e3nn' in params.implementations and 'float64' in params.datatypes: - tests.extend([TestDefinition(E3NNTensorProduct, - CTPP('64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e', '0e + 1o + 2e + 3o', '64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e', - 'nequip-revmd17-benzene', irrep_dtype=np.float64, weight_dtype=np.float64), direction, correctness=False, benchmark=True) - for direction in ['forward', 'backward']]) - - # Remove some more configurations for GPUs with limited memory + tests = [ + test + for test in tests + if "benzene" not in test.problem.label + or test.implementation != E3NNTensorProductCompiledMaxAutotuneCUDAGraphs + or test.problem.irrep_dtype != np.float64 + ] + + if "e3nn" in params.implementations and "float64" in params.datatypes: + tests.extend( + [ + TestDefinition( + E3NNTensorProduct, + CTPP( + "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e", + "0e + 1o + 2e + 3o", + "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e", + "nequip-revmd17-benzene", + irrep_dtype=np.float64, + weight_dtype=np.float64, + ), + direction, + correctness=False, + benchmark=True, + ) + for direction in ["forward", "backward"] + ] + ) + + # Remove some more configurations for GPUs with limited memory if params.limited_memory: - tests = [test for test in tests if - (test.implementation == TensorProduct and 'benzene' not in test.problem.label) - or (test.implementation == CUETensorProduct and 'benzene' not in test.problem.label) - or ('benzene' not in test.problem.label and test.problem.irrep_dtype != np.float64)] + tests = [ + test + for test in tests + if ( + test.implementation == TensorProduct + and "benzene" not in test.problem.label + ) + or ( + test.implementation == CUETensorProduct + and "benzene" not in test.problem.label + ) + or ( + "benzene" not in test.problem.label + and test.problem.irrep_dtype != np.float64 + ) + ] bench_suite = TestBenchmarkSuite( num_warmup=100, num_iter=100, bench_batch_size=params.batch_size, prng_seed=11111, - test_name="uvu") + test_name="uvu", + ) data_folder = bench_suite.run(tests, params.output_folder) if params.plot: plot({"data_folder": data_folder}) -def benchmark_roofline(params): - implementations = [TensorProduct, CUETensorProduct] - directions = ['forward', 'backward'] - tests = [TestDefinition(implementation, problem, direction, correctness=False, benchmark=True) - for implementation, problem, direction - in itertools.product(implementations, roofline_configs, directions)] +def benchmark_roofline(params): + implementations = [TensorProduct, CUETensorProduct] + directions = ["forward", "backward"] + + tests = [ + TestDefinition( + implementation, problem, direction, correctness=False, benchmark=True + ) + for implementation, problem, direction in itertools.product( + implementations, roofline_configs, directions + ) + ] bench_suite = TestBenchmarkSuite( num_warmup=100, @@ -114,13 +176,15 @@ def benchmark_roofline(params): bench_batch_size=200000, prng_seed=11111, torch_op=False, - test_name="roofline") + test_name="roofline", + ) data_folder = bench_suite.run(tests, params.output_folder) if params.plot: plot({"data_folder": data_folder}) + def download_graphs(params, filenames): download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" @@ -129,82 +193,107 @@ def download_graphs(params, filenames): graphs = [] for filename in filenames: - target_path = Path(params.data) / filename + target_path = Path(params.data) / filename if not target_path.exists(): if params.disable_download: logging.critical(f"Error, {target_path} does not exist.") exit(1) else: logging.info(f"Downloading {download_prefix + filename}...") - urllib.request.urlretrieve(download_prefix + filename, target_path) + urllib.request.urlretrieve(download_prefix + filename, target_path) graphs.append(load_graph(str(target_path))) return graphs + def benchmark_convolution(params): - filenames = [ "covid_spike_radius3.0.pickle", - "1drf_radius6.0.pickle", - "carbon_lattice_radius6.0.pickle"] + filenames = [ + "covid_spike_radius3.0.pickle", + "1drf_radius6.0.pickle", + "carbon_lattice_radius6.0.pickle", + ] graphs = download_graphs(params, filenames) if not params.disable_bench: - configs = [ ChannelwiseTPP("128x0e+128x1o+128x2e", - "1x0e+1x1o+1x2e+1x3o", - "128x0e+128x1o+128x2e+128x3o"), - ChannelwiseTPP("128x0e+128x1o+128x2e", - "1x0e+1x1o+1x2e+1x3o", - "128x0e+128x1o+128x2e+128x3o"), - ] # MACE-large + configs = [ + ChannelwiseTPP( + "128x0e+128x1o+128x2e", + "1x0e+1x1o+1x2e+1x3o", + "128x0e+128x1o+128x2e+128x3o", + ), + ChannelwiseTPP( + "128x0e+128x1o+128x2e", + "1x0e+1x1o+1x2e+1x3o", + "128x0e+128x1o+128x2e+128x3o", + ), + ] # MACE-large configs[1].irrep_dtype = np.float64 configs[1].weight_dtype = np.float64 - bench = ConvBenchmarkSuite(configs, test_name="convolution") + bench = ConvBenchmarkSuite(configs, test_name="convolution") - implementations = [ TensorProductConvScatterSum, - CUEConv, - CUEConvFused, - TensorProductConvDeterministic, - TensorProductConvAtomic] + implementations = [ + TensorProductConvScatterSum, + CUEConv, + CUEConvFused, + TensorProductConvDeterministic, + TensorProductConvAtomic, + ] if params.limited_memory: - implementations = [impl for impl in implementations - if impl != TensorProductConvScatterSum - and impl != CUEConv] + implementations = [ + impl + for impl in implementations + if impl != TensorProductConvScatterSum and impl != CUEConv + ] output_folder = None - for graph in graphs: + for graph in graphs: for direction in ["forward", "backward"]: output_folder = bench.run( - implementations = implementations, - graph = graph, - direction=direction, - correctness=False, - benchmark=True, - output_folder=params.output_folder) + implementations=implementations, + graph=graph, + direction=direction, + correctness=False, + benchmark=True, + output_folder=params.output_folder, + ) if params.plot: if not params.limited_memory: plot({"data_folder": output_folder}) else: - logger.critical("Cannot plot convolution speedups over cuE with --limited-memory flag enabled.") + logger.critical( + "Cannot plot convolution speedups over cuE with --limited-memory flag enabled." + ) + def benchmark_double_backward(params): - from openequivariance.benchmark.benchmark_configs import mace_nequip_problems, diffdock_configs + from openequivariance.benchmark.benchmark_configs import ( + mace_nequip_problems, + diffdock_configs, + ) implementations = [E3NNTensorProduct, CUETensorProduct, TensorProduct] problems = diffdock_configs + mace_nequip_problems float64_problems = copy.deepcopy(problems) - for problem in float64_problems: + for problem in float64_problems: problem.irrep_dtype = np.float64 - problem.weight_dtype = np.float64 - - directions : list[Direction] = ['double_backward'] - tests = [TestDefinition(implementation, problem, direction, correctness=False, benchmark=True) - for problem, direction, implementation in itertools.product(problems + float64_problems, directions, implementations)] + problem.weight_dtype = np.float64 + + directions: list[Direction] = ["double_backward"] + tests = [ + TestDefinition( + implementation, problem, direction, correctness=False, benchmark=True + ) + for problem, direction, implementation in itertools.product( + problems + float64_problems, directions, implementations + ) + ] logger = getLogger() logger.setLevel(logging.INFO) @@ -215,6 +304,7 @@ def benchmark_double_backward(params): if params.plot: plot({"data_folder": data_folder}) + def benchmark_kahan_accuracy(params): from openequivariance.benchmark.benchmark_configs import mace_problems @@ -223,40 +313,44 @@ def benchmark_kahan_accuracy(params): implementations = [TensorProductConvAtomic, TensorProductConvKahan] problems = [mace_problems[0]] - bench = ConvBenchmarkSuite(problems, test_name="kahan_convolution_accuracy", correctness_threshold=1e-4) - directions = ['forward', 'backward'] + bench = ConvBenchmarkSuite( + problems, test_name="kahan_convolution_accuracy", correctness_threshold=1e-4 + ) + directions = ["forward", "backward"] if params.double_backward: - directions.append('double_backward') - - for graph in graphs: - for direction in directions: - output_folder = bench.run( - implementations = implementations, - graph = graph, - direction=direction, - correctness=True, - benchmark=False, - output_folder=params.output_folder, - high_precision_ref=True) - + directions.append("double_backward") + + for graph in graphs: + for direction in directions: + bench.run( + implementations=implementations, + graph=graph, + direction=direction, + correctness=True, + benchmark=False, + output_folder=params.output_folder, + high_precision_ref=True, + ) + def plot(params): import openequivariance.benchmark.plotting as plotting + data_folder, test_name = None, None if isinstance(params, dict): data_folder = params["data_folder"] else: data_folder = params.data_folder - with open(pathlib.Path(data_folder) / "metadata.json", 'r') as f: + with open(pathlib.Path(data_folder) / "metadata.json", "r") as f: metadata = json.load(f) test_name = metadata["test_name"] - if test_name == "uvu": + if test_name == "uvu": plotting.plot_uvu(data_folder) - elif test_name == "uvw": + elif test_name == "uvw": plotting.plot_uvw(data_folder) - elif test_name == "roofline": + elif test_name == "roofline": plotting.plot_roofline(data_folder) elif test_name == "convolution": plotting.plot_convolution(data_folder) @@ -265,65 +359,151 @@ def plot(params): else: raise ValueError(f"Unknown test name: {test_name}. Cannot plot results.") -if __name__=='__main__': + +if __name__ == "__main__": logger.setLevel(logging.INFO) dp = DeviceProp(0) paper_benchmark_gpu = "NVIDIA A100-SXM4-80GB" if dp.name != paper_benchmark_gpu: - logger.warning(msg=f"Current GPU ({dp.name}) is not the {paper_benchmark_gpu} used in the paper. Runtime benchmarks may differ from our reported results.") - parser = argparse.ArgumentParser(description='Benchmark openequivariance kernels') - parser.add_argument("--output_folder", "-o", type=str, default=None, help="Output folder for benchmark results") - - subparsers = parser.add_subparsers(help='subcommand help', required=True) - parser_uvu = subparsers.add_parser('uvu', help='Run the UVU kernel benchmark without fusion') - parser_uvu.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark") - parser_uvu.add_argument("--implementations", "-i", type=str, nargs='+', - default=['e3nn', 'cue', 'oeq'], help="Implementations to benchmark", - choices=['e3nn', 'e3nn_uncompiled', 'cue', 'oeq']) - parser_uvu.add_argument("--directions", "-d", type=str, nargs='+', - default=['forward', 'backward'], help="Directions to benchmark", - choices=['forward', 'backward']) - parser_uvu.add_argument("--datatypes", "-t", type=str, nargs='+', - default=['float32', 'float64'], help="Data types to benchmark", - choices=['float32', 'float64']) - parser_uvu.add_argument("--limited-memory", action="store_true", help="Disable tests requiring large amounts of memory.") + logger.warning( + msg=f"Current GPU ({dp.name}) is not the {paper_benchmark_gpu} used in the paper. Runtime benchmarks may differ from our reported results." + ) + parser = argparse.ArgumentParser(description="Benchmark openequivariance kernels") + parser.add_argument( + "--output_folder", + "-o", + type=str, + default=None, + help="Output folder for benchmark results", + ) + + subparsers = parser.add_subparsers(help="subcommand help", required=True) + parser_uvu = subparsers.add_parser( + "uvu", help="Run the UVU kernel benchmark without fusion" + ) + parser_uvu.add_argument( + "--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark" + ) + parser_uvu.add_argument( + "--implementations", + "-i", + type=str, + nargs="+", + default=["e3nn", "cue", "oeq"], + help="Implementations to benchmark", + choices=["e3nn", "e3nn_uncompiled", "cue", "oeq"], + ) + parser_uvu.add_argument( + "--directions", + "-d", + type=str, + nargs="+", + default=["forward", "backward"], + help="Directions to benchmark", + choices=["forward", "backward"], + ) + parser_uvu.add_argument( + "--datatypes", + "-t", + type=str, + nargs="+", + default=["float32", "float64"], + help="Data types to benchmark", + choices=["float32", "float64"], + ) + parser_uvu.add_argument( + "--limited-memory", + action="store_true", + help="Disable tests requiring large amounts of memory.", + ) parser_uvu.add_argument("--plot", action="store_true", help="Plot the results.") parser_uvu.set_defaults(func=benchmark_uvu) - parser_roofline = subparsers.add_parser('roofline', help='Run the roofline comparison') - parser_roofline.add_argument("--plot", action="store_true", help="Plot the results.") + parser_roofline = subparsers.add_parser( + "roofline", help="Run the roofline comparison" + ) + parser_roofline.add_argument( + "--plot", action="store_true", help="Plot the results." + ) parser_roofline.set_defaults(func=benchmark_roofline) - parser_conv = subparsers.add_parser('conv', help='Run the fused convolution kernel benchmark') - parser_conv.add_argument("--data", type=str, help="Folder containing graph data", required=True) - parser_conv.add_argument("--disable_download", action='store_true', help="Disable downloading data files if they do not exist") - parser_conv.add_argument("--disable_bench", action='store_true', help="Disable benchmark (downloads data if needed)") - parser_conv.add_argument("--limited-memory", action="store_true", help="Disable tests requiring large amounts of memory.") + parser_conv = subparsers.add_parser( + "conv", help="Run the fused convolution kernel benchmark" + ) + parser_conv.add_argument( + "--data", type=str, help="Folder containing graph data", required=True + ) + parser_conv.add_argument( + "--disable_download", + action="store_true", + help="Disable downloading data files if they do not exist", + ) + parser_conv.add_argument( + "--disable_bench", + action="store_true", + help="Disable benchmark (downloads data if needed)", + ) + parser_conv.add_argument( + "--limited-memory", + action="store_true", + help="Disable tests requiring large amounts of memory.", + ) parser_conv.add_argument("--plot", action="store_true", help="Plot the results.") parser_conv.set_defaults(func=benchmark_convolution) - parser_uvw = subparsers.add_parser('uvw', help='Run the UVW kernel benchmark without fusion') - parser_uvw.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark") - parser_uvw.add_argument("--directions", "-d", type=str, nargs='+', - default=['forward', 'backward'], help="Directions to benchmark", - choices=['forward', 'backward']) + parser_uvw = subparsers.add_parser( + "uvw", help="Run the UVW kernel benchmark without fusion" + ) + parser_uvw.add_argument( + "--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark" + ) + parser_uvw.add_argument( + "--directions", + "-d", + type=str, + nargs="+", + default=["forward", "backward"], + help="Directions to benchmark", + choices=["forward", "backward"], + ) parser_uvw.add_argument("--plot", action="store_true", help="Plot the results.") parser_uvw.set_defaults(func=run_paper_uvw_benchmark) - parser_double_bwd = subparsers.add_parser('double_backward', help='Run the higher derivative kernel benchmark') - parser_double_bwd.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark") + parser_double_bwd = subparsers.add_parser( + "double_backward", help="Run the higher derivative kernel benchmark" + ) + parser_double_bwd.add_argument( + "--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark" + ) parser_double_bwd.set_defaults(func=benchmark_double_backward) - parser_kahan = subparsers.add_parser('kahan_conv', help='Run the Kahan convolution accuracy benchmark') - parser_kahan.add_argument("--data", type=str, help="Folder to download graph data to (or already containing graphs)", required=True) - parser_kahan.add_argument("--disable_download", action='store_true', help="Disable downloading data files if they do not exist") - parser_kahan.add_argument("--double_backward", action='store_true', help="Run double backward test (high memory usage)") + parser_kahan = subparsers.add_parser( + "kahan_conv", help="Run the Kahan convolution accuracy benchmark" + ) + parser_kahan.add_argument( + "--data", + type=str, + help="Folder to download graph data to (or already containing graphs)", + required=True, + ) + parser_kahan.add_argument( + "--disable_download", + action="store_true", + help="Disable downloading data files if they do not exist", + ) + parser_kahan.add_argument( + "--double_backward", + action="store_true", + help="Run double backward test (high memory usage)", + ) parser_kahan.set_defaults(func=benchmark_kahan_accuracy) - parser_plot = subparsers.add_parser('plot', help="Generate a plot for a folder of benchmarks.") + parser_plot = subparsers.add_parser( + "plot", help="Generate a plot for a folder of benchmarks." + ) parser_plot.add_argument("data_folder", type=str) parser_plot.set_defaults(func=plot) args = parser.parse_args() - args.func(args) \ No newline at end of file + args.func(args) diff --git a/tests/conv_test.py b/tests/conv_test.py index 7f8496ad..24e09c34 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -1,33 +1,35 @@ -import pytest, tempfile, urllib +import pytest +import tempfile +import urllib from pytest_check import check -import numpy as np +import numpy as np import openequivariance as oeq -from openequivariance.benchmark.ConvBenchmarkSuite import load_graph +from openequivariance.benchmark.ConvBenchmarkSuite import load_graph from itertools import chain, product + class ConvCorrectness: def thresh(self, direction): - return { - "fwd": 1e-5, - "bwd": 3e-4, - "double_bwd": 3e-4 - }[direction] - + return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] def check_result(self, result, fieldname): with check: error = result[fieldname]["diff_Linf_norm"] thresh = result["thresh"] - assert result[fieldname]["pass"], f"{fieldname} observed error={error:.5f} >= {thresh}" - - @pytest.fixture(params=[np.float32, np.float64], ids=['F32', 'F64'], scope='class') + assert result[fieldname]["pass"], ( + f"{fieldname} observed error={error:.5f} >= {thresh}" + ) + + @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") def dtype(self, request): return request.param - @pytest.fixture(params=["1drf_radius3.5.pickle"], ids=['1drf'], scope='class') + @pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="class") def graph(self, request): - download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" + download_prefix = ( + "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" + ) filename = request.param graph = None @@ -35,16 +37,16 @@ def graph(self, request): urllib.request.urlretrieve(download_prefix + filename, temp_file.name) graph = load_graph(temp_file.name) - #graph = load_graph("data/1drf_radius3.5.pickle") + # graph = load_graph("data/1drf_radius3.5.pickle") return graph - @pytest.fixture(params=['atomic', 'deterministic', 'kahan'], scope='class') + @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class") def conv_object(self, request, problem): - if request.param == 'atomic': + if request.param == "atomic": return oeq.TensorProductConv(problem, deterministic=False) - elif request.param == 'deterministic': + elif request.param == "deterministic": return oeq.TensorProductConv(problem, deterministic=True) - elif request.param == 'kahan': + elif request.param == "kahan": if problem.irrep_dtype == np.float32: return oeq.TensorProductConv(problem, deterministic=True, kahan=True) else: @@ -55,10 +57,12 @@ def test_tp_fwd(self, conv_object, graph): assert True return - result = conv_object.test_correctness_forward(graph, - thresh=self.thresh("fwd"), - prng_seed=12345, - reference_implementation=None) + result = conv_object.test_correctness_forward( + graph, + thresh=self.thresh("fwd"), + prng_seed=12345, + reference_implementation=None, + ) self.check_result(result, "output") @@ -67,10 +71,12 @@ def test_tp_bwd(self, conv_object, graph): assert True return - result = conv_object.test_correctness_backward(graph, - thresh=self.thresh("bwd"), - prng_seed=12345, - reference_implementation=None) + result = conv_object.test_correctness_backward( + graph, + thresh=self.thresh("bwd"), + prng_seed=12345, + reference_implementation=None, + ) self.check_result(result, "weight_grad") self.check_result(result, "in1_grad") @@ -81,24 +87,28 @@ def test_tp_double_bwd(self, conv_object, graph): assert True return - result = conv_object.test_correctness_double_backward(graph, - thresh=self.thresh("double_bwd"), - prng_seed=12345, - reference_implementation=None) + result = conv_object.test_correctness_double_backward( + graph, + thresh=self.thresh("double_bwd"), + prng_seed=12345, + reference_implementation=None, + ) self.check_result(result, "output_grad") self.check_result(result, "in1_grad") self.check_result(result, "in2_grad") self.check_result(result, "weights_grad") + class TestProductionModels(ConvCorrectness): - from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs - production_model_tpps = list(chain( - mace_problems, - diffdock_configs - )) + from openequivariance.benchmark.benchmark_configs import ( + mace_problems, + diffdock_configs, + ) + + production_model_tpps = list(chain(mace_problems, diffdock_configs)) - @pytest.fixture(params=production_model_tpps, ids = lambda x : x.label, scope="class") + @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype return request.param @@ -106,69 +116,113 @@ def problem(self, request, dtype): class TestUVUSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), (8, 1, 8), (16, 1, 16), - (32, 1, 32), (5, 1, 5), (13, 1, 13), (19, 1, 19), - (33, 1, 33), (49, 1, 49), (128, 1, 128), (1, 2, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + (1, 1, 1), + (8, 1, 8), + (16, 1, 16), + (32, 1, 32), + (5, 1, 5), + (13, 1, 13), + (19, 1, 19), + (33, 1, 33), + (49, 1, 49), + (128, 1, 128), + (1, 2, 1), + (1, 16, 1), + (1, 32, 1), + (16, 3, 16), ] - - irs = [ (0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (2, 0, 2), (5, 3, 5), (7, 2, 5) ] - def id_func(m, i): + irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (2, 0, 2), (5, 3, 5), (7, 2, 5)] + + def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - @pytest.fixture(params=product(muls, irs), - ids = lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), - scope="class") + @pytest.fixture( + params=product(muls, irs), + ids=lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), + scope="class", + ) def problem(self, request, dtype): m, i = request.param[0], request.param[1] - instructions=[(0, 0, 0, "uvu", True)] - return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", - instructions, shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) + instructions = [(0, 0, 0, "uvu", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + - class TestUVWSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), - (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + (1, 1, 1), + (4, 1, 4), + (8, 1, 8), + (16, 1, 16), + (32, 1, 32), + (5, 1, 5), + (13, 1, 13), + (33, 1, 33), + (49, 1, 49), + (64, 1, 64), + (1, 2, 1), + (1, 4, 1), + (1, 16, 1), + (1, 32, 1), + (16, 3, 16), ] - + irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5)] - def id_func(m, i): + def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" - @pytest.fixture(params=product(muls, irs), - ids = lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), - scope="class") + @pytest.fixture( + params=product(muls, irs), + ids=lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), + scope="class", + ) def problem(self, request, dtype): m, i = request.param[0], request.param[1] - instructions=[(0, 0, 0, "uvw", True)] - return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", - instructions, shared_weights=False, - internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) - + instructions = [(0, 0, 0, "uvw", True)] + return oeq.TPProblem( + f"{m[0]}x{i[0]}e", + f"{m[1]}x{i[1]}e", + f"{m[2]}x{i[2]}e", + instructions, + shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, + weight_dtype=dtype, + ) + class TestAtomicSharedWeights(ConvCorrectness): - from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs - problems = [mace_problems[0], diffdock_configs[0]] + from openequivariance.benchmark.benchmark_configs import ( + mace_problems, + diffdock_configs, + ) + + problems = [mace_problems[0], diffdock_configs[0]] def thresh(self, direction): return { "fwd": 1e-5, - "bwd": 5e-2, # Expect higher errors for shared weights - "double_bwd": 5e-2 + "bwd": 5e-2, # Expect higher errors for shared weights + "double_bwd": 5e-2, }[direction] - @pytest.fixture(params=problems, ids = lambda x : x.label, scope="class") + @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") def problem(self, request, dtype): problem = request.param problem.irrep_dtype, problem.weight_dtype = dtype, dtype problem.shared_weights = True - return problem - - @pytest.fixture(scope='class') + return problem + + @pytest.fixture(scope="class") def conv_object(self, request, problem): - return oeq.TensorProductConv(problem, deterministic=False) \ No newline at end of file + return oeq.TensorProductConv(problem, deterministic=False) diff --git a/tests/export_test.py b/tests/export_test.py index 82166b2d..ea966bdc 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1,77 +1,100 @@ import torch -import pytest, tempfile +import pytest +import tempfile import numpy as np import openequivariance as oeq from torch_geometric import EdgeIndex -from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.benchmark.correctness_utils import correctness_forward, correctness_backward, correctness_double_backward -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def problem_and_irreps(): X_ir, Y_ir, Z_ir = oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e") - problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, - [(0, 0, 0, "uvu", True)], - shared_weights=False, internal_weights=False, - irrep_dtype=np.float32, weight_dtype=np.float32) - - gen = torch.Generator(device='cuda') + problem = oeq.TPProblem( + X_ir, + Y_ir, + Z_ir, + [(0, 0, 0, "uvu", True)], + shared_weights=False, + internal_weights=False, + irrep_dtype=np.float32, + weight_dtype=np.float32, + ) + + gen = torch.Generator(device="cuda") gen.manual_seed(0) - return problem, X_ir, Y_ir, Z_ir, + return ( + problem, + X_ir, + Y_ir, + Z_ir, + ) + -@pytest.fixture(params=['batch', 'conv_det', 'conv_atomic'], scope='session') +@pytest.fixture(params=["batch", "conv_det", "conv_atomic"], scope="session") def tp_and_inputs(request, problem_and_irreps): problem, X_ir, Y_ir, _ = problem_and_irreps - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") gen.manual_seed(0) - if request.param == 'batch': + if request.param == "batch": batch_size = 1000 - X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen) - Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen) - W = torch.rand(batch_size, problem.weight_numel, device='cuda', generator=gen) + X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) + W = torch.rand(batch_size, problem.weight_numel, device="cuda", generator=gen) return oeq.TensorProduct(problem), (X, Y, W) else: node_ct, nonzero_ct = 3, 4 # Receiver, sender indices for message passing GNN edge_index = EdgeIndex( - [[0, 1, 1, 2], - [1, 0, 2, 1]], - device='cuda', - dtype=torch.long) + [[0, 1, 1, 2], [1, 0, 2, 1]], device="cuda", dtype=torch.long + ) - _, sender_perm = edge_index.sort_by("col") - edge_index, receiver_perm = edge_index.sort_by("row") + _, sender_perm = edge_index.sort_by("col") + edge_index, receiver_perm = edge_index.sort_by("row") edge_index = [edge_index[0].detach(), edge_index[1].detach()] - X = torch.rand(node_ct, X_ir.dim, device='cuda', generator=gen) - Y = torch.rand(nonzero_ct, Y_ir.dim, device='cuda', generator=gen) - W = torch.rand(nonzero_ct, problem.weight_numel, device='cuda', generator=gen) - - if request.param == 'conv_atomic': - return oeq.TensorProductConv(problem, torch_op=True, deterministic=False), (X, Y, W, edge_index[0], edge_index[1]) - elif request.param == 'conv_det': - return oeq.TensorProductConv(problem, torch_op=True, deterministic=True), (X, Y, W, edge_index[0], edge_index[1], sender_perm) + X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) + W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) + + if request.param == "conv_atomic": + return oeq.TensorProductConv(problem, torch_op=True, deterministic=False), ( + X, + Y, + W, + edge_index[0], + edge_index[1], + ) + elif request.param == "conv_det": + return oeq.TensorProductConv(problem, torch_op=True, deterministic=True), ( + X, + Y, + W, + edge_index[0], + edge_index[1], + sender_perm, + ) def test_jitscript(tp_and_inputs): - tp, inputs = tp_and_inputs + tp, inputs = tp_and_inputs uncompiled_result = tp.forward(*inputs) scripted_tp = torch.jit.script(tp) loaded_tp = None with tempfile.NamedTemporaryFile(suffix=".pt") as tmp_file: - scripted_tp.save(tmp_file.name) + scripted_tp.save(tmp_file.name) loaded_tp = torch.jit.load(tmp_file.name) - + compiled_result = loaded_tp(*inputs) assert torch.allclose(uncompiled_result, compiled_result, atol=1e-5) def test_export(tp_and_inputs): - tp, inputs = tp_and_inputs + tp, inputs = tp_and_inputs uncompiled_result = tp.forward(*inputs) exported_tp = torch.export.export(tp, args=inputs, strict=False) @@ -80,25 +103,26 @@ def test_export(tp_and_inputs): def test_aoti(tp_and_inputs): - tp, inputs = tp_and_inputs + tp, inputs = tp_and_inputs uncompiled_result = tp.forward(*inputs) exported_tp = torch.export.export(tp, args=inputs, strict=False) aoti_model = None with tempfile.NamedTemporaryFile(suffix=".pt2") as tmp_file: try: - output_path = torch._inductor.aoti_compile_and_package( - exported_tp, - package_path=tmp_file.name) + output_path = torch._inductor.aoti_compile_and_package( + exported_tp, package_path=tmp_file.name + ) except Exception as e: - err_msg = \ - "AOTI compile_and_package failed. NOTE: OpenEquivariance only supports AOTI for " + \ - "PyTorch version >= 2.8.0.dev20250410+cu126 due to incomplete TorchBind support " + \ - "in prior versions. " + \ - f"{e}" - assert False, err_msg - + err_msg = ( + "AOTI compile_and_package failed. NOTE: OpenEquivariance only supports AOTI for " + + "PyTorch version >= 2.8.0.dev20250410+cu126 due to incomplete TorchBind support " + + "in prior versions. " + + f"{e}" + ) + assert False, err_msg + aoti_model = torch._inductor.aoti_load_package(output_path) aoti_result = aoti_model(*inputs) - assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5) \ No newline at end of file + assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5) diff --git a/tests/import_test.py b/tests/import_test.py index 447b77e9..f86958e7 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -1,25 +1,27 @@ def test_import(): import openequivariance - assert openequivariance.__version__ is not None + assert openequivariance.__version__ is not None assert openequivariance.__version__ != "0.0.0" + def test_tutorial(): import torch import e3nn.o3 as o3 - gen = torch.Generator(device='cuda') + gen = torch.Generator(device="cuda") batch_size = 1000 - X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") - X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen) - Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen) + X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") + X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) - instructions=[(0, 0, 0, "uvu", True)] + instructions = [(0, 0, 0, "uvu", True)] - tp_e3nn = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, - shared_weights=False, internal_weights=False).to('cuda') - W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen) + tp_e3nn = o3.TensorProduct( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ).to("cuda") + W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen) Z = tp_e3nn(X, Y, W) print(torch.norm(Z)) @@ -28,10 +30,12 @@ def test_tutorial(): # =============================== import openequivariance as oeq - problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) + problem = oeq.TPProblem( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ) tp_fast = oeq.TensorProduct(problem, torch_op=True) - Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier + Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier print(torch.norm(Z)) # =============================== @@ -43,27 +47,36 @@ def test_tutorial(): # Receiver, sender indices for message passing GNN edge_index = EdgeIndex( - [[0, 1, 1, 2], # Receiver - [1, 0, 2, 1]], # Sender - device='cuda', - dtype=torch.long) - - X = torch.rand(node_ct, X_ir.dim, device='cuda', generator=gen) - Y = torch.rand(nonzero_ct, Y_ir.dim, device='cuda', generator=gen) - W = torch.rand(nonzero_ct, problem.weight_numel, device='cuda', generator=gen) - - tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=False) # Reuse problem from earlier - Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) # Z has shape [node_ct, z_ir.dim] + [ + [0, 1, 1, 2], # Receiver + [1, 0, 2, 1], + ], # Sender + device="cuda", + dtype=torch.long, + ) + + X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) + W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) + + tp_conv = oeq.TensorProductConv( + problem, torch_op=True, deterministic=False + ) # Reuse problem from earlier + Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] + ) # Z has shape [node_ct, z_ir.dim] print(torch.norm(Z)) # =============================== # =============================== - _, sender_perm = edge_index.sort_by("col") # Sort by sender index - edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index + _, sender_perm = edge_index.sort_by("col") # Sort by sender index + edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index # Now we can use the faster deterministic algorithm - tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) - Z = tp_conv.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm) + tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) + Z = tp_conv.forward( + X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm + ) print(torch.norm(Z)) # =============================== - assert(True) \ No newline at end of file + assert True diff --git a/tests/mace_driver.py b/tests/mace_driver.py index 4d172ea0..0312923c 100644 --- a/tests/mace_driver.py +++ b/tests/mace_driver.py @@ -1,33 +1,30 @@ -import sys, json, time, pathlib - +import json +import pathlib import argparse -import logging -from pathlib import Path import ase.io -import numpy as np import torch from e3nn import o3 from mace import data, modules, tools -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.tools import torch_geometric from torch.utils.benchmark import Timer -from mace.calculators import mace_mp from torch.profiler import profile, record_function, ProfilerActivity -from mace.tools import compile as mace_compile from mace.modules.wrapper_ops import OEQConfig, CuEquivarianceConfig import warnings + warnings.filterwarnings("ignore") try: - import cuequivariance as cue # pylint: disable=unused-import + import cuequivariance as cue # noqa F401 + CUET_AVAILABLE = True except ImportError: CUET_AVAILABLE = False + def analyze_trace(trace_file): trace = None with open(trace_file, "r") as f: @@ -42,11 +39,13 @@ def analyze_trace(trace_file): if "args" in event and "stream" in event["args"]: total += event["dur"] - if "forward" in event["name"] \ - or "backward" in event["name"] \ - or "TensorProductUniform1d" in event["name"] \ - or "channelwise_kernel_fwd" in event["name"] \ - or "channelwise_kernel_bwd" in event["name"]: + if ( + "forward" in event["name"] + or "backward" in event["name"] + or "TensorProductUniform1d" in event["name"] + or "channelwise_kernel_fwd" in event["name"] + or "channelwise_kernel_bwd" in event["name"] + ): cgtp_fwd_bwd += event["dur"] elif "_scatter_gather_elementwise_kernel" in event["name"]: @@ -54,13 +53,14 @@ def analyze_trace(trace_file): else: other_kernels += event["dur"] - return { - "total_cuda_ms": total / 1000., - "cgtp_fwd_bwd_ms": cgtp_fwd_bwd / 1000., - "reduce_by_key_ms": reduce_by_key / 1000., - "other_kernels_ms": other_kernels / 1000. + return { + "total_cuda_ms": total / 1000.0, + "cgtp_fwd_bwd_ms": cgtp_fwd_bwd / 1000.0, + "reduce_by_key_ms": reduce_by_key / 1000.0, + "other_kernels_ms": other_kernels / 1000.0, } + def create_model(hidden_irreps, max_ell, device, cueq_config=None, oeq_config=None): table = tools.AtomicNumberTable([6, 7, 8, 1, 11, 13, 15, 18]) model_config = { @@ -68,8 +68,12 @@ def create_model(hidden_irreps, max_ell, device, cueq_config=None, oeq_config=No "num_bessel": 8, "num_polynomial_cutoff": 6, "max_ell": max_ell, - "interaction_cls": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], - "interaction_cls_first": modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], "num_interactions": 2, "num_elements": len(table), "hidden_irreps": o3.Irreps(hidden_irreps), @@ -87,9 +91,12 @@ def create_model(hidden_irreps, max_ell, device, cueq_config=None, oeq_config=No } return modules.ScaleShiftMACE(**model_config).to(device) -def benchmark_model(model, batch, num_iterations=100, warmup=100, label=None, output_folder=None): + +def benchmark_model( + model, batch, num_iterations=100, warmup=100, label=None, output_folder=None +): def run_inference(): - out = model(batch,training=True) + out = model(batch, training=True) torch.cuda.synchronize() return out @@ -104,36 +111,43 @@ def run_inference(): "run_inference": run_inference, }, ) - warm_up_measurement = timer.timeit(num_iterations) + timer.timeit(num_iterations) # warmup measurement = timer.timeit(num_iterations) with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: with record_function("model_inference"): - run_inference() + run_inference() trace_file = str(output_folder / f"traces/{label}_trace.json") prof.export_chrome_trace(trace_file) with open(output_folder / f"{label}.json", "w") as f: - json.dump({ - "time_ms_mean": measurement.mean * 1000, - "label": label, - "cuda_time_profile": analyze_trace(trace_file) - }, f, indent=4) + json.dump( + { + "time_ms_mean": measurement.mean * 1000, + "label": label, + "cuda_time_profile": analyze_trace(trace_file), + }, + f, + indent=4, + ) - #print(run_inference()) + # print(run_inference()) return measurement + def create_model_oeq(hidden_irreps, max_ell, device, cueq_config=None): source_model = create_model(hidden_irreps, max_ell, device, cueq_config) from mace.tools.scripts_utils import extract_config_mace_model + config = extract_config_mace_model(source_model) config["oeq_config"] = OEQConfig( enabled=True, optimize_channelwise=True, optimize_symmetric=True, - conv_fusion="deterministic") + conv_fusion="deterministic", + ) target_model = source_model.__class__(**config).to(device) @@ -148,22 +162,23 @@ def create_model_oeq(hidden_irreps, max_ell, device, cueq_config=None): target_model.load_state_dict(target_dict) return target_model.to(device) + def create_model_hybrid(hidden_irreps, max_ell, device, cueq_config=None): cueq_config = CuEquivarianceConfig( - enabled=True, - layout="mul_ir", - group="O3_e3nn", - optimize_all=False, - optimize_linear=True, - optimize_channelwise=False, - optimize_symmetric=True, - optimize_fctp=True, - fuse_convolution=True) - + enabled=True, + layout="mul_ir", + group="O3_e3nn", + optimize_all=False, + optimize_linear=True, + optimize_channelwise=False, + optimize_symmetric=True, + optimize_fctp=True, + fuse_convolution=True, + ) + oeq_config = OEQConfig( - enabled=True, - optimize_channelwise=True, - conv_fusion="deterministic") + enabled=True, optimize_channelwise=True, conv_fusion="deterministic" + ) model = create_model(hidden_irreps, max_ell, device, cueq_config, oeq_config) return model.to(device) @@ -171,11 +186,12 @@ def create_model_hybrid(hidden_irreps, max_ell, device, cueq_config=None): def create_model_cueq(hidden_irreps, max_ell, device, cueq_config=None): cueq_config = CuEquivarianceConfig( - enabled=True, - layout="ir_mul", - group="O3_e3nn", - optimize_all=True, - fuse_convolution=True) + enabled=True, + layout="ir_mul", + group="O3_e3nn", + optimize_all=True, + fuse_convolution=True, + ) model_cueq = create_model(hidden_irreps, max_ell, device, cueq_config) return model_cueq.to(device) @@ -190,33 +206,38 @@ def main(): parser.add_argument("--max_ell", type=int, default=3) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden_irreps", type=str, default="128x0e + 128x1o + 128x2e") - parser.add_argument("--output_folder", '-o', type=str, default=None) - parser.add_argument("--implementations", "-i", type=str, nargs='+', - default=['e3nn', 'cue', 'oeq', 'hybrid'], help="Implementations to benchmark", - choices=['e3nn', 'cue', 'oeq', 'hybrid']) + parser.add_argument("--output_folder", "-o", type=str, default=None) + parser.add_argument( + "--implementations", + "-i", + type=str, + nargs="+", + default=["e3nn", "cue", "oeq", "hybrid"], + help="Implementations to benchmark", + choices=["e3nn", "cue", "oeq", "hybrid"], + ) args = parser.parse_args() output_folder = args.output_folder output_folder = pathlib.Path(output_folder) - for dtype_str, dtype in [ ("f32", torch.float32), - ("f64", torch.float64) - ]: + for dtype_str, dtype in [("f32", torch.float32), ("f64", torch.float64)]: torch.set_default_dtype(dtype) device = torch.device(args.device) hidden_irreps = o3.Irreps(args.hidden_irreps) # Create dataset atoms_list = ase.io.read(args.xyz_file, index=":") - #table = tools.AtomicNumberTable(list(set(np.concatenate([atoms.numbers for atoms in atoms_list])))) + # table = tools.AtomicNumberTable(list(set(np.concatenate([atoms.numbers for atoms in atoms_list])))) table = tools.AtomicNumberTable([6, 7, 8, 1, 11, 13, 15, 18]) data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data.AtomicData.from_config( - data.config_from_atoms(atoms), - z_table=table, - cutoff=6.0 - ) for atoms in atoms_list], + dataset=[ + data.AtomicData.from_config( + data.config_from_atoms(atoms), z_table=table, cutoff=6.0 + ) + for atoms in atoms_list + ], batch_size=min(len(atoms_list), args.batch_size), shuffle=False, drop_last=False, @@ -227,7 +248,7 @@ def main(): output_folder.mkdir(parents=True, exist_ok=True) traces_folder = output_folder / "traces" - traces_folder.mkdir(parents=True, exist_ok=True) + traces_folder.mkdir(parents=True, exist_ok=True) # Compile is still not working for MACE and cueq; turned off for now print("\nBenchmarking Configuration:") @@ -238,28 +259,52 @@ def main(): print(f"Hidden irreps: {hidden_irreps}") print(f"Number of iterations: {args.num_iters}\n") - if 'e3nn' in args.implementations: + if "e3nn" in args.implementations: model_e3nn = create_model(hidden_irreps, args.max_ell, device) - measurement_e3nn = benchmark_model(model_e3nn, batch_dict, args.num_iters, label=f"e3nn_{dtype_str}", output_folder=output_folder) + measurement_e3nn = benchmark_model( + model_e3nn, + batch_dict, + args.num_iters, + label=f"e3nn_{dtype_str}", + output_folder=output_folder, + ) print(f"E3NN Measurement:\n{measurement_e3nn}") - if 'oeq' in args.implementations: + if "oeq" in args.implementations: model_oeq = create_model_oeq(hidden_irreps, args.max_ell, device) - measurement_oeq = benchmark_model(model_oeq, batch_dict, args.num_iters, label=f"ours_{dtype_str}", output_folder=output_folder) + measurement_oeq = benchmark_model( + model_oeq, + batch_dict, + args.num_iters, + label=f"ours_{dtype_str}", + output_folder=output_folder, + ) print(f"\nOpenEquivariance Measurement:\n{measurement_oeq}") - #print(f"\nSpeedup: {measurement_e3nn.mean / measurement_oeq.mean:.2f}x") + # print(f"\nSpeedup: {measurement_e3nn.mean / measurement_oeq.mean:.2f}x") - if 'hybrid' in args.implementations: + if "hybrid" in args.implementations: model_hybrid = create_model_hybrid(hidden_irreps, args.max_ell, device) - measurement_hybrid = benchmark_model(model_hybrid, batch_dict, args.num_iters, label=f"hybrid_{dtype_str}", output_folder=output_folder) + measurement_hybrid = benchmark_model( + model_hybrid, + batch_dict, + args.num_iters, + label=f"hybrid_{dtype_str}", + output_folder=output_folder, + ) print(f"\nHybrid Measurement:\n{measurement_hybrid}") - if 'cue' in args.implementations: + if "cue" in args.implementations: model_cueq = create_model_cueq(hidden_irreps, args.max_ell, device) - measurement_cueq = benchmark_model(model_cueq, batch_dict, args.num_iters, label=f"cuE_{dtype_str}", output_folder=output_folder) + measurement_cueq = benchmark_model( + model_cueq, + batch_dict, + args.num_iters, + label=f"cuE_{dtype_str}", + output_folder=output_folder, + ) print(f"\nCUET Measurement:\n{measurement_cueq}") - #print(f"\nSpeedup: {measurement_e3nn.mean / measurement_cueq.mean:.2f}x") + # print(f"\nSpeedup: {measurement_e3nn.mean / measurement_cueq.mean:.2f}x") if __name__ == "__main__": - main() \ No newline at end of file + main() From ad81a523d60f5612a925d0e5576a75515ccfc42d Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Mon, 26 May 2025 21:39:26 -0700 Subject: [PATCH 05/15] only test deterministic if shared_weights != true --- tests/conv_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/conv_test.py b/tests/conv_test.py index 24e09c34..85e31331 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -45,7 +45,10 @@ def conv_object(self, request, problem): if request.param == "atomic": return oeq.TensorProductConv(problem, deterministic=False) elif request.param == "deterministic": - return oeq.TensorProductConv(problem, deterministic=True) + if not problem.shared_weights: + return oeq.TensorProductConv(problem, deterministic=True) + else: + return None elif request.param == "kahan": if problem.irrep_dtype == np.float32: return oeq.TensorProductConv(problem, deterministic=True, kahan=True) From f22ec078ce892ca1736e3c413b4e52fbfddf917d Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Mon, 26 May 2025 21:40:21 -0700 Subject: [PATCH 06/15] revert Tensor -> tensor change (Tensor is a type, tensor the constructor) --- .../convolution/TensorProductConv.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py index fc77bc3a..d7f0a06a 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/implementations/convolution/TensorProductConv.py @@ -41,13 +41,13 @@ def __init__( def forward( self, - L1_in: torch.tensor, - L2_in: torch.tensor, - W: torch.tensor, - rows: torch.tensor, - cols: torch.tensor, - sender_perm: Optional[torch.tensor] = None, - ) -> torch.tensor: + L1_in: torch.Tensor, + L2_in: torch.Tensor, + W: torch.Tensor, + rows: torch.Tensor, + cols: torch.Tensor, + sender_perm: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if sender_perm is None: return torch.ops.torch_tp_jit.jit_conv_forward( self.internal, From 276c8bd3c23ddab4e684d326b41f54eccb4fb159 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Mon, 26 May 2025 21:41:03 -0700 Subject: [PATCH 07/15] add ruff and pre-commit to the CI requirements .txt (exact version to promote caching) --- .github/workflows/requirements_cuda_ci.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt index da9fde80..ef35e340 100644 --- a/.github/workflows/requirements_cuda_ci.txt +++ b/.github/workflows/requirements_cuda_ci.txt @@ -1,4 +1,6 @@ numpy==2.2.5 torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 pytest==8.3.5 -ninja==1.11.1.4 \ No newline at end of file +ninja==1.11.1.4 +ruff==0.11.11 +pre-commit==4.2.0 \ No newline at end of file From d730f8be57a8d4a25148e5fd5b0fa2802ce8e343 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 27 May 2025 14:52:20 -0700 Subject: [PATCH 08/15] skips and looser thresholds --- tests/conv_test.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/conv_test.py b/tests/conv_test.py index 85e31331..ae5f8ca2 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -11,7 +11,7 @@ class ConvCorrectness: def thresh(self, direction): - return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] + return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] def check_result(self, result, fieldname): with check: @@ -48,17 +48,21 @@ def conv_object(self, request, problem): if not problem.shared_weights: return oeq.TensorProductConv(problem, deterministic=True) else: - return None + pytest.skip("Shared weights not supported with deterministic") elif request.param == "kahan": if problem.irrep_dtype == np.float32: - return oeq.TensorProductConv(problem, deterministic=True, kahan=True) + if not problem.shared_weights: + return oeq.TensorProductConv( + problem, deterministic=True, kahan=True + ) + else: + pytest.skip("Shared weights not supported with kahan") else: - return None + pytest.skip("Only Float32 supported with kahan") def test_tp_fwd(self, conv_object, graph): if conv_object is None: - assert True - return + pytest.skip("'conv_object' fixture returned None, skipping") result = conv_object.test_correctness_forward( graph, @@ -71,8 +75,7 @@ def test_tp_fwd(self, conv_object, graph): def test_tp_bwd(self, conv_object, graph): if conv_object is None: - assert True - return + pytest.skip("'conv_object' fixture returned None, skipping") result = conv_object.test_correctness_backward( graph, @@ -87,8 +90,7 @@ def test_tp_bwd(self, conv_object, graph): def test_tp_double_bwd(self, conv_object, graph): if conv_object is None: - assert True - return + pytest.skip("'conv_object' fixture returned None, skipping") result = conv_object.test_correctness_double_backward( graph, From 11fb61ae1aa006ed2bd1e0e8781e7d8f5f62535f Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 27 May 2025 16:18:34 -0700 Subject: [PATCH 09/15] save failed tensors --- openequivariance/benchmark/correctness_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/benchmark/correctness_utils.py index 6284dbb2..df629198 100644 --- a/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/benchmark/correctness_utils.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Optional, Union from openequivariance.implementations.TensorProductBase import TensorProductBase @@ -33,7 +34,11 @@ def check_similiarity( diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf)) result["diff_Linf_norm"] = diff_Linf_norm result["pass"] = bool(diff_Linf_norm < correctness_threshold) - + path = Path("testing_failures") + path.mkdir(exist_ok=True) + np.save(path / "to_check", to_check) + np.save(path / "ground_truth", ground_truth) + logger.debug(print(correctness_threshold)) if result["pass"]: logger.info( f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" From 95c9f3ce9a3f10559d9191ad29900a39aa85ea39 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 27 May 2025 16:40:31 -0700 Subject: [PATCH 10/15] remove tensor saving --- openequivariance/benchmark/correctness_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/benchmark/correctness_utils.py index df629198..62133112 100644 --- a/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/benchmark/correctness_utils.py @@ -1,4 +1,3 @@ -from pathlib import Path from typing import Optional, Union from openequivariance.implementations.TensorProductBase import TensorProductBase @@ -34,11 +33,6 @@ def check_similiarity( diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf)) result["diff_Linf_norm"] = diff_Linf_norm result["pass"] = bool(diff_Linf_norm < correctness_threshold) - path = Path("testing_failures") - path.mkdir(exist_ok=True) - np.save(path / "to_check", to_check) - np.save(path / "ground_truth", ground_truth) - logger.debug(print(correctness_threshold)) if result["pass"]: logger.info( f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" From 4bbf1a78802b28c2760c01637a45f8016f0a34c3 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Tue, 27 May 2025 17:12:20 -0700 Subject: [PATCH 11/15] make E741 (lowercase L) lint a package level preference --- openequivariance/benchmark/plotting/plot_convolution.py | 1 - openequivariance/benchmark/plotting/plot_double_backward.py | 1 - openequivariance/benchmark/plotting/plot_roofline.py | 1 - openequivariance/benchmark/plotting/plot_uvu.py | 1 - openequivariance/benchmark/plotting/plot_uvw.py | 1 - openequivariance/benchmark/plotting/plotting_utils.py | 1 - openequivariance/implementations/CUETensorProduct.py | 2 +- openequivariance/implementations/convolution/CUEConv.py | 2 +- openequivariance/implementations/e3nn_lite.py | 2 +- pyproject.toml | 5 ++++- 10 files changed, 7 insertions(+), 10 deletions(-) diff --git a/openequivariance/benchmark/plotting/plot_convolution.py b/openequivariance/benchmark/plotting/plot_convolution.py index f2f62070..bf1d7fff 100644 --- a/openequivariance/benchmark/plotting/plot_convolution.py +++ b/openequivariance/benchmark/plotting/plot_convolution.py @@ -1,4 +1,3 @@ -# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt import pathlib diff --git a/openequivariance/benchmark/plotting/plot_double_backward.py b/openequivariance/benchmark/plotting/plot_double_backward.py index a10193c5..bf880aaf 100644 --- a/openequivariance/benchmark/plotting/plot_double_backward.py +++ b/openequivariance/benchmark/plotting/plot_double_backward.py @@ -1,4 +1,3 @@ -# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt import pathlib diff --git a/openequivariance/benchmark/plotting/plot_roofline.py b/openequivariance/benchmark/plotting/plot_roofline.py index 42e08020..bb7022b7 100644 --- a/openequivariance/benchmark/plotting/plot_roofline.py +++ b/openequivariance/benchmark/plotting/plot_roofline.py @@ -1,4 +1,3 @@ -# ruff: noqa: E741 import numpy as np import pathlib from openequivariance.benchmark.plotting.plotting_utils import ( diff --git a/openequivariance/benchmark/plotting/plot_uvu.py b/openequivariance/benchmark/plotting/plot_uvu.py index aa926534..1fe04b40 100644 --- a/openequivariance/benchmark/plotting/plot_uvu.py +++ b/openequivariance/benchmark/plotting/plot_uvu.py @@ -1,4 +1,3 @@ -# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt import pathlib diff --git a/openequivariance/benchmark/plotting/plot_uvw.py b/openequivariance/benchmark/plotting/plot_uvw.py index 01eb6fa3..11648bc4 100644 --- a/openequivariance/benchmark/plotting/plot_uvw.py +++ b/openequivariance/benchmark/plotting/plot_uvw.py @@ -1,4 +1,3 @@ -# ruff: noqa: E741 import numpy as np import matplotlib.pyplot as plt import pathlib diff --git a/openequivariance/benchmark/plotting/plotting_utils.py b/openequivariance/benchmark/plotting/plotting_utils.py index fae6898c..80f959c2 100644 --- a/openequivariance/benchmark/plotting/plotting_utils.py +++ b/openequivariance/benchmark/plotting/plotting_utils.py @@ -1,4 +1,3 @@ -# ruff: noqa: E741 import json import os import pathlib diff --git a/openequivariance/implementations/CUETensorProduct.py b/openequivariance/implementations/CUETensorProduct.py index 4011eec4..bf6d10a0 100644 --- a/openequivariance/implementations/CUETensorProduct.py +++ b/openequivariance/implementations/CUETensorProduct.py @@ -70,7 +70,7 @@ def __lt__( # pylint: disable=no-self-argument @classmethod def iterator(cls) -> Iterator["O3_e3nn"]: - for l in itertools.count(0): # noqa : E741 + for l in itertools.count(0): yield O3_e3nn(l=l, p=1 * (-1) ** l) yield O3_e3nn(l=l, p=-1 * (-1) ** l) diff --git a/openequivariance/implementations/convolution/CUEConv.py b/openequivariance/implementations/convolution/CUEConv.py index 2ef6ccf4..a75a3b65 100644 --- a/openequivariance/implementations/convolution/CUEConv.py +++ b/openequivariance/implementations/convolution/CUEConv.py @@ -72,7 +72,7 @@ def __lt__( # pylint: disable=no-self-argument @classmethod def iterator(cls) -> Iterator["O3_e3nn"]: - for l in itertools.count(0): # noqa : E741 + for l in itertools.count(0): yield O3_e3nn(l=l, p=1 * (-1) ** l) yield O3_e3nn(l=l, p=-1 * (-1) ** l) diff --git a/openequivariance/implementations/e3nn_lite.py b/openequivariance/implementations/e3nn_lite.py index df708071..01b5cb01 100644 --- a/openequivariance/implementations/e3nn_lite.py +++ b/openequivariance/implementations/e3nn_lite.py @@ -1,4 +1,4 @@ -# ruff: noqa: E741, E743 +# ruff: noqa: E743 """ This file contains lightly modified code from E3NN. The code has been modified to remove all dependency on Pytorch. diff --git a/pyproject.toml b/pyproject.toml index af40bc1d..911a48d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,4 +44,7 @@ include = ["openequivariance*"] [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", -] \ No newline at end of file +] + +[tool.ruff] +lint.ignore = ["E741"] \ No newline at end of file From 22ab2b01c901d724c979e7955bb4ce0fca93b2d2 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Wed, 28 May 2025 16:17:26 -0700 Subject: [PATCH 12/15] add pre-commit github action --- .github/workflows/pre-commit.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 .github/workflows/pre-commit.yaml diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 00000000..2f21b389 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,13 @@ +name: Pre-Commit Checks + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: pre-commit/action@v3.0.1 \ No newline at end of file From 68c2bbb775a9834fefad00ac4555310b3bcbc88a Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Wed, 28 May 2025 16:23:35 -0700 Subject: [PATCH 13/15] lint and format --- tests/multidevice_test.py | 48 ++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/tests/multidevice_test.py b/tests/multidevice_test.py index 7db89277..7c017e77 100644 --- a/tests/multidevice_test.py +++ b/tests/multidevice_test.py @@ -1,37 +1,49 @@ -import textwrap, torch, subprocess, os -import numpy as np +import textwrap +import torch +import subprocess +import os + def test_multidevice(): - result = subprocess.run([ - "python", "-m", "torch.distributed.run", - "--standalone", "--nnodes=1", "--nproc-per-node=gpu", - __file__], + result = subprocess.run( + [ + "python", + "-m", + "torch.distributed.run", + "--standalone", + "--nnodes=1", + "--nproc-per-node=gpu", + __file__, + ], capture_output=True, - check=False) - + check=False, + ) + if result.returncode != 0: - error_string = f''' - Invocation: {' '.join(result.args)} + error_string = f""" + Invocation: {" ".join(result.args)} Test failed with return code {result.returncode}. \nOutput:\n\n{result.stdout.decode()} \nError:\n\n{result.stderr.decode()} - ''' - assert False, textwrap.dedent(error_string) + """ + assert False, textwrap.dedent(error_string) assert True + if __name__ == "__main__": import openequivariance as oeq - # Use MACE-large to test >64KB shared memory allocation - from openequivariance.benchmark.benchmark_configs import mace_problems - problem = mace_problems[0] + # Use MACE-large to test >64KB shared memory allocation + from openequivariance.benchmark.benchmark_configs import mace_problems + + problem = mace_problems[0] local_rank = int(os.environ["LOCAL_RANK"]) - device = f'cuda:{local_rank}' + device = f"cuda:{local_rank}" torch.set_default_device(device) - X_ir, Y_ir, Z_ir = problem.irreps_in1, problem.irreps_in2, problem.irreps_out + X_ir, Y_ir, Z_ir = problem.irreps_in1, problem.irreps_in2, problem.irreps_out tp = oeq.TensorProduct(problem) batch_size = 1000 @@ -42,4 +54,4 @@ def test_multidevice(): W = torch.rand(batch_size, problem.weight_numel, device=device, generator=gen) with torch.cuda.device(device): - result = tp.forward(X, Y, W) \ No newline at end of file + result = tp.forward(X, Y, W) From 88a4a8ddbef89899c142ba5075b7bccedf8b1da1 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Wed, 28 May 2025 18:19:39 -0700 Subject: [PATCH 14/15] attempt at faster CI --- .github/workflows/requirements_cuda_ci.txt | 4 +--- .github/workflows/verify_extension_build.yml | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt index ef35e340..da9fde80 100644 --- a/.github/workflows/requirements_cuda_ci.txt +++ b/.github/workflows/requirements_cuda_ci.txt @@ -1,6 +1,4 @@ numpy==2.2.5 torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 pytest==8.3.5 -ninja==1.11.1.4 -ruff==0.11.11 -pre-commit==4.2.0 \ No newline at end of file +ninja==1.11.1.4 \ No newline at end of file diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index db48af7b..83f1e894 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -13,8 +13,10 @@ permissions: jobs: verify_cuda_extension: if: ${{ github.event.label.name == 'ci-ready' || github.event_name != 'pull_request' }} - runs-on: ubuntu-latest - + runs-on: ubuntu-24.04 + + container: + image: nvcr.io/nvidia/cuda:12.9.0-devel-ubuntu24.04 # Pre installed CUDA toolkit steps: - uses: actions/checkout@v4 - name: Set up Python @@ -26,8 +28,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - sudo apt-get update - sudo apt install nvidia-cuda-toolkit pip install -r .github/workflows/requirements_cuda_ci.txt pip install -e . From b34962a5e867cba94ff5ec86a5f6eec468690664 Mon Sep 17 00:00:00 2001 From: Austin Glover Date: Wed, 28 May 2025 18:23:36 -0700 Subject: [PATCH 15/15] remove these as they are covered by pre-commit action --- .github/workflows/requirements_cuda_ci.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt index ef35e340..da9fde80 100644 --- a/.github/workflows/requirements_cuda_ci.txt +++ b/.github/workflows/requirements_cuda_ci.txt @@ -1,6 +1,4 @@ numpy==2.2.5 torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 pytest==8.3.5 -ninja==1.11.1.4 -ruff==0.11.11 -pre-commit==4.2.0 \ No newline at end of file +ninja==1.11.1.4 \ No newline at end of file