Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openequivariance/benchmark/correctness_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def check_similiarity(name : str, to_check : np.ndarray, ground_truth : np.nda
result["shape_match"] = True
diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf))
result["diff_Linf_norm"] = diff_Linf_norm
result["pass"] = bool(diff_Linf_norm < correctness_threshold)
result["pass"] = bool(diff_Linf_norm < correctness_threshold)

if result["pass"]:
logger.info(f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}")
Expand Down
13 changes: 12 additions & 1 deletion openequivariance/extension/torch_tp_jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class TorchJITConv : public torch::CustomClassHolder {
Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims;
JITConvImpl<JITKernel> internal;
int64_t L3_dim;
int shared_weights;

TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) :
fwd_dict(fwd_dict_i.copy()),
Expand All @@ -230,7 +231,8 @@ class TorchJITConv : public torch::CustomClassHolder {
to_map(dbl_bwd_dict_i),
to_map(kernel_dims_i)
),
L3_dim(kernel_dims.at("L3_dim")) { }
L3_dim(kernel_dims.at("L3_dim")),
shared_weights(kernel_dims.at("shared_weights")) { }

tuple<tuple<string, string>,
tuple<string, Map_t>,
Expand Down Expand Up @@ -341,6 +343,11 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_backward(
torch::Tensor cols_contig = cols.contiguous();
torch::Tensor workspace_contig = workspace.contiguous();
torch::Tensor transpose_perm_contig = transpose_perm.contiguous();

if(jit_instance->shared_weights == 1) {
W_grad.zero_();
}

jit_instance->internal.backward(
data_ptr(L1_in_contig), data_ptr(L1_grad),
data_ptr(L2_in_contig), data_ptr(L2_grad),
Expand Down Expand Up @@ -388,6 +395,10 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_doubl
torch::Tensor workspace_contig = workspace.contiguous();
torch::Tensor transpose_perm_contig = transpose_perm.contiguous();

if(jit_instance->shared_weights == 1) {
W_grad.zero_();
}

jit_instance->internal.double_backward(
data_ptr(L1_in_contig), data_ptr(L2_in_contig),
data_ptr(W_contig), data_ptr(L3_grad_contig),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ def backward_helper( L1_in : torch.Tensor, L2_in : torch.Tensor,
L2_grad = torch.empty_like(L2_in)
weights_grad = torch.empty_like(weights)

if self.config.shared_weights:
weights_grad[:] = 0.0

self.internal.backward_rawptrs(
L1_in.contiguous().data_ptr(), L1_grad.data_ptr(),
L2_in.contiguous().data_ptr(), L2_grad.data_ptr(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def __init__(self, config, idx_dtype=np.int64,

analysis = filter_and_analyze_problem(config)
self.is_uvw = analysis["is_uvw"]
assert not config.shared_weights, "LoopUnrollConv does not yet support shared weights"

if config.shared_weights:
assert not deterministic, "Deterministic convolution does not support shared weights"

forward_schedule_type = 3
backward_schedule_type = 2
Expand Down Expand Up @@ -148,7 +150,8 @@ def generate_double_backward_schedule(warps_per_block):
vars(self.backward_schedule.launch_config),
vars(self.double_backward_schedule.launch_config),
{"L3_dim": self.L3.dim,
"is_uvw": int(self.is_uvw)})
"is_uvw": int(self.is_uvw),
"shared_weights": int(config.shared_weights)})
logger.info("Kernel compiled!")

#with open("scratch.txt", "w") as f:
Expand Down
20 changes: 14 additions & 6 deletions openequivariance/templates/loop_unroll_batch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ __global__ void backward(
IRREP_T* l3_shft = L3_grad + i * {{backward_schedule.L3.dim}} + lane_id;

{%- if not tpp.shared_weights %}
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
{%- else %}
WEIGHT_T* w = weights;
WEIGHT_T* wgrad = weights_grad;
WEIGHT_T* w = weights;
WEIGHT_T* wgrad = weights_grad;
{%- endif %}
WEIGHT_T* weights_shft = w + lane_id;

Expand Down Expand Up @@ -128,7 +128,11 @@ __global__ void backward(
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}

{%- if not backward_schedule.stream_weights%}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{%- if not tpp.shared_weights %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{%- else %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
{%- endif %}
{%- endif %}
} {%- endfor %}
}
Expand Down Expand Up @@ -295,7 +299,11 @@ __global__ void double_backward_B(
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}

{% if not schedule.stream_weights%}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{%- if not tpp.shared_weights %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{%- else %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
{%- endif %}
{% endif %}
}
} {%- endfor %}
Expand Down
30 changes: 21 additions & 9 deletions openequivariance/templates/loop_unroll_conv_atomic.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ __global__ void forward(
IRREP_T* l1 = L1_in + col * {{forward_schedule.L1.dim}} + lane_id;
IRREP_T* l2 = L2_in + i * {{forward_schedule.L2.dim}} + lane_id;
IRREP_T* l3 = L3_out + row * {{forward_schedule.L3.dim}} + lane_id;
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
{%- if not tpp.shared_weights %}
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
{%- else %}
WEIGHT_T* w = weights;
{%- endif %}

__syncwarp();
{{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }}
Expand Down Expand Up @@ -115,11 +119,11 @@ __global__ void backward(
IRREP_T* l3_shft = L3_grad + row * {{backward_schedule.L3.dim}} + lane_id;

{%- if not tpp.shared_weights %}
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
{%- else %}
WEIGHT_T* w = weights;
WEIGHT_T* wgrad = weights_grad;
WEIGHT_T* w = weights;
WEIGHT_T* wgrad = weights_grad;
{%- endif %}
WEIGHT_T* weights_shft = w + lane_id;

Expand Down Expand Up @@ -155,8 +159,12 @@ __global__ void backward(
{{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }}
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}

{%- if not backward_schedule.stream_weights%}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{%- if not backward_schedule.stream_weights %}
{%- if not tpp.shared_weights %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{%- else %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
{%- endif %}
{%- endif %}
} {%- endfor %}
}
Expand Down Expand Up @@ -332,8 +340,12 @@ __global__ void double_backward_B(
{{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }}
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}

{% if not schedule.stream_weights%}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{% if not schedule.stream_weights %}
{%- if not tpp.shared_weights %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
{%- else %}
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
{%- endif %}
{% endif %}
}
} {%- endfor %}
Expand Down
4 changes: 2 additions & 2 deletions openequivariance/templates/loop_unroll_tp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
scratch1[{{i % num_scratch_reg}}] = l3_grad[{{coord3}}] * {{value}};

{%- if double_bwd %}
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_original[{{coord2}}] * l1_vec[{{coord1}}];
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_original[{{coord2}}] * l1_vec[{{coord1}}];
{%- else %}
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_vec[{{coord2}}] * l1_vec[{{coord1}}];
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_vec[{{coord2}}] * l1_vec[{{coord1}}];
{%- endif %}

scratch2[{{i % num_scratch_reg}}] = scratch1[{{i % num_scratch_reg}}] * weight;
Expand Down
34 changes: 30 additions & 4 deletions tests/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
from itertools import chain, product

class TPCorrectness:
def thresh(self, direction):
return {
"fwd": 1e-5,
"bwd": 3e-4,
"double_bwd": 3e-4
}[direction]

def check_result(self, result, fieldname):
with check:
error = result[fieldname]["diff_Linf_norm"]
Expand All @@ -30,7 +37,7 @@ def test_tp_fwd(self, tp_and_problem):
test_implementation=tp,
reference_implementation=None,
batch_size=1000,
correctness_threshold=1e-5,
correctness_threshold=self.thresh("fwd"),
prng_seed=12345)

self.check_result(result, "output")
Expand All @@ -42,7 +49,7 @@ def test_tp_bwd(self, tp_and_problem):
test_implementation=tp,
reference_implementation=None,
batch_size=1000,
correctness_threshold=3e-4,
correctness_threshold=self.thresh("bwd"),
prng_seed=12345)

self.check_result(result, "weight_grad")
Expand All @@ -56,7 +63,7 @@ def test_tp_double_bwd(self, tp_and_problem):
test_implementation=tp,
reference_implementation = None,
batch_size = 200,
correctness_threshold = 3e-4,
correctness_threshold=self.thresh("double_bwd"),
prng_seed = 12345)

self.check_result(result, "output_double_grad")
Expand Down Expand Up @@ -129,4 +136,23 @@ def problem(self, request, dtype):
return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e",
instructions, shared_weights=False,
internal_weights=False,
irrep_dtype=dtype, weight_dtype=dtype)
irrep_dtype=dtype, weight_dtype=dtype)


class TestSharedWeights(TPCorrectness):
    """Batch tensor-product correctness tests with shared_weights=True.

    Reuses the fwd/bwd/double-bwd tests from TPCorrectness, loosening the
    backward tolerances because shared weights accumulate gradient
    contributions from the whole batch into a single weight buffer.
    """
    from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs
    problems = [mace_problems[0], diffdock_configs[0]]

    def thresh(self, direction):
        # Expect higher errors on the backward passes for shared weights.
        return {
            "fwd": 1e-5,
            "bwd": 5e-4,
            "double_bwd": 5e-4,
        }[direction]

    @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class")
    def problem(self, request, dtype):
        """Yield a private copy of the benchmark config with shared weights on.

        Bug fix: the original mutated the module-level config object from
        benchmark_configs in place, leaking shared_weights=True (and the
        dtype override) into every other test class that reuses the same
        config entry. Deep-copying keeps the parametrization isolated.
        """
        import copy  # stdlib; local import, used only by this fixture
        problem = copy.deepcopy(request.param)
        problem.irrep_dtype, problem.weight_dtype = dtype, dtype
        problem.shared_weights = True
        return problem
43 changes: 37 additions & 6 deletions tests/conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@
from itertools import chain, product

class ConvCorrectness:
def thresh(self, direction):
return {
"fwd": 1e-5,
"bwd": 3e-4,
"double_bwd": 3e-4
}[direction]


# Asserts (non-fatally, via pytest-check's `check` context) that a named
# result field passed its correctness comparison, reporting the observed
# L-inf error and threshold on failure.
# NOTE(review): this span is scraped diff residue — the two assert lines
# below are the pre-image (error formatted with :.2f) and post-image
# (:.5f) of the same statement; the merged file keeps only the second.
def check_result(self, result, fieldname):
with check:
error = result[fieldname]["diff_Linf_norm"]
thresh = result["thresh"]
assert result[fieldname]["pass"], f"{fieldname} observed error={error:.2f} >= {thresh}"

assert result[fieldname]["pass"], f"{fieldname} observed error={error:.5f} >= {thresh}"
# Class-scoped fixture: parametrizes every test in the class over both
# float32 and float64 precision (labelled F32/F64 in test ids).
@pytest.fixture(params=[np.float32, np.float64], ids=['F32', 'F64'], scope='class')
def dtype(self, request):
# Hand back the numpy dtype selected for this parametrization.
return request.param
Expand Down Expand Up @@ -48,7 +56,7 @@ def test_tp_fwd(self, conv_object, graph):
return

result = conv_object.test_correctness_forward(graph,
thresh=3e-05,
thresh=self.thresh("fwd"),
prng_seed=12345,
reference_implementation=None)

Expand All @@ -60,7 +68,7 @@ def test_tp_bwd(self, conv_object, graph):
return

result = conv_object.test_correctness_backward(graph,
thresh=3e-04,
thresh=self.thresh("bwd"),
prng_seed=12345,
reference_implementation=None)

Expand All @@ -74,7 +82,7 @@ def test_tp_double_bwd(self, conv_object, graph):
return

result = conv_object.test_correctness_double_backward(graph,
thresh=3e-04,
thresh=self.thresh("double_bwd"),
prng_seed=12345,
reference_implementation=None)

Expand Down Expand Up @@ -140,4 +148,27 @@ def problem(self, request, dtype):
return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e",
instructions, shared_weights=False,
internal_weights=False,
irrep_dtype=dtype, weight_dtype=dtype)
irrep_dtype=dtype, weight_dtype=dtype)


class TestAtomicSharedWeights(ConvCorrectness):
    """Convolution correctness tests with shared weights (atomic kernel).

    Runs the fwd/bwd/double-bwd checks from ConvCorrectness on benchmark
    configs with shared_weights=True, forcing deterministic=False since
    the atomic-accumulation kernel is the one that supports shared weights.
    """
    from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs
    problems = [mace_problems[0], diffdock_configs[0]]

    def thresh(self, direction):
        # Expect higher errors on the backward passes for shared weights:
        # the whole graph accumulates into one weight-gradient buffer.
        return {
            "fwd": 1e-5,
            "bwd": 5e-2,
            "double_bwd": 5e-2,
        }[direction]

    @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class")
    def problem(self, request, dtype):
        """Yield a private copy of the benchmark config with shared weights on.

        Bug fix: the original mutated the module-level config object from
        benchmark_configs in place, leaking shared_weights=True (and the
        dtype override) into every other test class that reuses the same
        config entry. Deep-copying keeps the parametrization isolated.
        """
        import copy  # stdlib; local import, used only by this fixture
        problem = copy.deepcopy(request.param)
        problem.irrep_dtype, problem.weight_dtype = dtype, dtype
        problem.shared_weights = True
        return problem

    @pytest.fixture(scope='class')
    def conv_object(self, request, problem):
        # deterministic=False selects the atomic convolution kernel.
        return oeq.TensorProductConv(problem, deterministic=False)