diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/benchmark/correctness_utils.py index 7f721a9c..c0a70f63 100644 --- a/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/benchmark/correctness_utils.py @@ -21,7 +21,7 @@ def check_similiarity(name : str, to_check : np.ndarray, ground_truth : np.nda result["shape_match"] = True diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf)) result["diff_Linf_norm"] = diff_Linf_norm - result["pass"] = bool(diff_Linf_norm < correctness_threshold) + result["pass"] = bool(diff_Linf_norm < correctness_threshold) if result["pass"]: logger.info(f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}") diff --git a/openequivariance/extension/torch_tp_jit.cpp b/openequivariance/extension/torch_tp_jit.cpp index 6711f3f1..f4a5cbe2 100644 --- a/openequivariance/extension/torch_tp_jit.cpp +++ b/openequivariance/extension/torch_tp_jit.cpp @@ -218,6 +218,7 @@ class TorchJITConv : public torch::CustomClassHolder { Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims; JITConvImpl internal; int64_t L3_dim; + int shared_weights; TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) : fwd_dict(fwd_dict_i.copy()), @@ -230,7 +231,8 @@ class TorchJITConv : public torch::CustomClassHolder { to_map(dbl_bwd_dict_i), to_map(kernel_dims_i) ), - L3_dim(kernel_dims.at("L3_dim")) { } + L3_dim(kernel_dims.at("L3_dim")), + shared_weights(kernel_dims.at("shared_weights")) { } tuple, tuple, @@ -341,6 +343,11 @@ tuple jit_conv_backward( torch::Tensor cols_contig = cols.contiguous(); torch::Tensor workspace_contig = workspace.contiguous(); torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); + + if(jit_instance->shared_weights == 1) { + W_grad.zero_(); + } + jit_instance->internal.backward( data_ptr(L1_in_contig), data_ptr(L1_grad), data_ptr(L2_in_contig), data_ptr(L2_grad), @@ -388,6 +395,10 @@ tuple jit_conv_doubl torch::Tensor workspace_contig = workspace.contiguous(); torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); + if(jit_instance->shared_weights == 1) { + W_grad.zero_(); + } + jit_instance->internal.double_backward( data_ptr(L1_in_contig), data_ptr(L2_in_contig), data_ptr(W_contig), data_ptr(L3_grad_contig), diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/implementations/convolution/ConvolutionBase.py index b645f896..612427bf 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/implementations/convolution/ConvolutionBase.py @@ -697,6 +697,9 @@ def backward_helper( L1_in : torch.Tensor, L2_in : torch.Tensor, L2_grad = torch.empty_like(L2_in) weights_grad = torch.empty_like(weights) + if self.config.shared_weights: + weights_grad[:] = 0.0 + self.internal.backward_rawptrs( L1_in.contiguous().data_ptr(), L1_grad.data_ptr(), L2_in.contiguous().data_ptr(), L2_grad.data_ptr(), diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 676397a0..51167640 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -21,7 +21,9 @@ def __init__(self, config, idx_dtype=np.int64, analysis = filter_and_analyze_problem(config) self.is_uvw = analysis["is_uvw"] - assert not config.shared_weights, "LoopUnrollConv does not yet support shared weights" + + if config.shared_weights: + assert not deterministic, "Deterministic convolution does not support shared weights" forward_schedule_type = 3 backward_schedule_type = 2 @@ -148,7 +150,8 @@ def generate_double_backward_schedule(warps_per_block): vars(self.backward_schedule.launch_config), vars(self.double_backward_schedule.launch_config), {"L3_dim": self.L3.dim, - "is_uvw": int(self.is_uvw)}) + "is_uvw": int(self.is_uvw), + "shared_weights": int(config.shared_weights)}) logger.info("Kernel compiled!") #with open("scratch.txt", "w") as f: diff --git a/openequivariance/templates/loop_unroll_batch.cuh b/openequivariance/templates/loop_unroll_batch.cuh index 533385af..5a55cfd7 100644 --- a/openequivariance/templates/loop_unroll_batch.cuh +++ b/openequivariance/templates/loop_unroll_batch.cuh @@ -81,11 +81,11 @@ __global__ void backward( IRREP_T* l3_shft = L3_grad + i * {{backward_schedule.L3.dim}} + lane_id; {%- if not tpp.shared_weights %} - WEIGHT_T* w = weights + i * {{tpp.weight_numel}}; - WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}}; + WEIGHT_T* w = weights + i * {{tpp.weight_numel}}; + WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}}; {%- else %} - WEIGHT_T* w = weights; - WEIGHT_T* wgrad = weights_grad; + WEIGHT_T* w = weights; + WEIGHT_T* wgrad = weights_grad; {%- endif %} WEIGHT_T* weights_shft = w + lane_id; @@ -128,7 +128,11 @@ __global__ void backward( {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} {%- if not backward_schedule.stream_weights%} - ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {%- if not tpp.shared_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {%- else %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);) + {%- endif %} {%- endif %} } {%- endfor %} } @@ -295,7 +299,11 @@ __global__ void double_backward_B( {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} {% if not schedule.stream_weights%} - ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {%- if not tpp.shared_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {%- else %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);) + {%- endif %} {% endif %} } } {%- endfor %} diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/templates/loop_unroll_conv_atomic.cuh index c73dbbc5..21f389d2 100644 --- a/openequivariance/templates/loop_unroll_conv_atomic.cuh +++ b/openequivariance/templates/loop_unroll_conv_atomic.cuh @@ -68,7 +68,11 @@ __global__ void forward( IRREP_T* l1 = L1_in + col * {{forward_schedule.L1.dim}} + lane_id; IRREP_T* l2 = L2_in + i * {{forward_schedule.L2.dim}} + lane_id; IRREP_T* l3 = L3_out + row * {{forward_schedule.L3.dim}} + lane_id; - WEIGHT_T* w = weights + i * {{tpp.weight_numel}}; + {%- if not tpp.shared_weights %} + WEIGHT_T* w = weights + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = weights; + {%- endif %} __syncwarp(); {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} @@ -115,11 +119,11 @@ __global__ void backward( IRREP_T* l3_shft = L3_grad + row * {{backward_schedule.L3.dim}} + lane_id; {%- if not tpp.shared_weights %} - WEIGHT_T* w = weights + i * {{tpp.weight_numel}}; - WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}}; + WEIGHT_T* w = weights + i * {{tpp.weight_numel}}; + WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}}; {%- else %} - WEIGHT_T* w = weights; - WEIGHT_T* wgrad = weights_grad; + WEIGHT_T* w = weights; + WEIGHT_T* wgrad = weights_grad; {%- endif %} WEIGHT_T* weights_shft = w + lane_id; @@ -155,8 +159,12 @@ __global__ void backward( {{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }} {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} - {%- if not backward_schedule.stream_weights%} - ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {%- if not backward_schedule.stream_weights %} + {%- if not tpp.shared_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {%- else %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);) + {%- endif %} {%- endif %} } {%- endfor %} } @@ -332,8 +340,12 @@ __global__ void double_backward_B( {{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }} {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} - {% if not schedule.stream_weights%} - ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {% if not schedule.stream_weights %} + {%- if not tpp.shared_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {%- else %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);) + {%- endif %} {% endif %} } } {%- endfor %} diff --git a/openequivariance/templates/loop_unroll_tp.cuh b/openequivariance/templates/loop_unroll_tp.cuh index 8de105f5..eab95c57 100644 --- a/openequivariance/templates/loop_unroll_tp.cuh +++ b/openequivariance/templates/loop_unroll_tp.cuh @@ -204,9 +204,9 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__ scratch1[{{i % num_scratch_reg}}] = l3_grad[{{coord3}}] * {{value}}; {%- if double_bwd %} - weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_original[{{coord2}}] * l1_vec[{{coord1}}]; + weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_original[{{coord2}}] * l1_vec[{{coord1}}]; {%- else %} - weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_vec[{{coord2}}] * l1_vec[{{coord1}}]; + weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_vec[{{coord2}}] * l1_vec[{{coord1}}]; {%- endif %} scratch2[{{i % num_scratch_reg}}] = scratch1[{{i % num_scratch_reg}}] * weight; diff --git a/tests/batch_test.py b/tests/batch_test.py index fef70cf2..2879a4d6 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -8,6 +8,13 @@ from itertools import chain, product class TPCorrectness: + def thresh(self, direction): + return { + "fwd": 1e-5, + "bwd": 3e-4, + "double_bwd": 3e-4 + }[direction] + def check_result(self, result, fieldname): with check: error = result[fieldname]["diff_Linf_norm"] @@ -30,7 +37,7 @@ def test_tp_fwd(self, tp_and_problem): test_implementation=tp, reference_implementation=None, batch_size=1000, - correctness_threshold=1e-5, + correctness_threshold=self.thresh("fwd"), prng_seed=12345) self.check_result(result, "output") @@ -42,7 +49,7 @@ def test_tp_bwd(self, tp_and_problem): test_implementation=tp, reference_implementation=None, batch_size=1000, - correctness_threshold=3e-4, + correctness_threshold=self.thresh("bwd"), prng_seed=12345) self.check_result(result, "weight_grad") @@ -56,7 +63,7 @@ def test_tp_double_bwd(self, tp_and_problem): test_implementation=tp, reference_implementation = None, batch_size = 200, - correctness_threshold = 3e-4, + correctness_threshold=self.thresh("double_bwd"), prng_seed = 12345) self.check_result(result, "output_double_grad") @@ -129,4 +136,23 @@ def problem(self, request, dtype): return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", instructions, shared_weights=False, internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) \ No newline at end of file + irrep_dtype=dtype, weight_dtype=dtype) + + +class TestSharedWeights(TPCorrectness): + from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs + problems = [mace_problems[0], diffdock_configs[0]] + + def thresh(self, direction): + return { + "fwd": 1e-5, + "bwd": 5e-4, # Expect higher errors for shared weights + "double_bwd": 5e-4 + }[direction] + + @pytest.fixture(params=problems, ids = lambda x : x.label, scope="class") + def problem(self, request, dtype): + problem = request.param + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + problem.shared_weights = True + return problem \ No newline at end of file diff --git a/tests/conv_test.py b/tests/conv_test.py index a7e592f6..7f8496ad 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -7,12 +7,20 @@ from itertools import chain, product class ConvCorrectness: + def thresh(self, direction): + return { + "fwd": 1e-5, + "bwd": 3e-4, + "double_bwd": 3e-4 + }[direction] + + def check_result(self, result, fieldname): with check: error = result[fieldname]["diff_Linf_norm"] thresh = result["thresh"] - assert result[fieldname]["pass"], f"{fieldname} observed error={error:.2f} >= {thresh}" - + assert result[fieldname]["pass"], f"{fieldname} observed error={error:.5f} >= {thresh}" + @pytest.fixture(params=[np.float32, np.float64], ids=['F32', 'F64'], scope='class') def dtype(self, request): return request.param @@ -48,7 +56,7 @@ def test_tp_fwd(self, conv_object, graph): return result = conv_object.test_correctness_forward(graph, - thresh=3e-05, + thresh=self.thresh("fwd"), prng_seed=12345, reference_implementation=None) @@ -60,7 +68,7 @@ def test_tp_bwd(self, conv_object, graph): return result = conv_object.test_correctness_backward(graph, - thresh=3e-04, + thresh=self.thresh("bwd"), prng_seed=12345, reference_implementation=None) @@ -74,7 +82,7 @@ def test_tp_double_bwd(self, conv_object, graph): return result = conv_object.test_correctness_double_backward(graph, - thresh=3e-04, + thresh=self.thresh("double_bwd"), prng_seed=12345, reference_implementation=None) @@ -140,4 +148,27 @@ def problem(self, request, dtype): return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", instructions, shared_weights=False, internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) \ No newline at end of file + irrep_dtype=dtype, weight_dtype=dtype) + + +class TestAtomicSharedWeights(ConvCorrectness): + from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs + problems = [mace_problems[0], diffdock_configs[0]] + + def thresh(self, direction): + return { + "fwd": 1e-5, + "bwd": 5e-2, # Expect higher errors for shared weights + "double_bwd": 5e-2 + }[direction] + + @pytest.fixture(params=problems, ids = lambda x : x.label, scope="class") + def problem(self, request, dtype): + problem = request.param + problem.irrep_dtype, problem.weight_dtype = dtype, dtype + problem.shared_weights = True + return problem + + @pytest.fixture(scope='class') + def conv_object(self, request, problem): + return oeq.TensorProductConv(problem, deterministic=False) \ No newline at end of file