
Introduce benchmark framework using CUDA events#157

Open

mcgibbon wants to merge 8 commits into NVIDIA:main from mcgibbon:feature/benchmark-framework

Conversation

@mcgibbon (Contributor) commented Mar 12, 2026

This PR adds timing for the SHT and for the torch implementation of DISCO convolution through a new benchmarking framework, run through python -m torch_harmonics.benchmark.

This is largely taken from the implementation we used/I authored in https://github.com/ai2cm/ace

mcgibbon and others added 3 commits March 12, 2026 15:30
Introduce a torch_harmonics.benchmark subpackage with:
- Timer infrastructure (CUDATimer, NullTimer, CPUEventPair) for GPU
  event-based and CPU wall-clock timing
- BenchmarkABC base class with registry via @register_benchmark
- CLI runner (python -m torch_harmonics.benchmark) that saves JSON results
- RealSHT and InverseRealSHT benchmarks at 1-degree resolution

Also add benchmark_results to .gitignore.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
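The registry described in this commit message might be structured roughly as follows. This is a sketch, not the actual torch_harmonics.benchmark code: the names `BENCHMARKS` and the decorator/ABC signatures are assumptions beyond what the commit message states.

```python
import abc

# Hypothetical registry mapping benchmark names to classes; the real
# torch_harmonics.benchmark implementation may differ in detail.
BENCHMARKS: dict = {}

def register_benchmark(name):
    """Class decorator that adds a BenchmarkABC subclass to the registry."""
    def decorator(cls):
        BENCHMARKS[name] = cls
        return cls
    return decorator

class BenchmarkABC(abc.ABC):
    @abc.abstractmethod
    def run_instance(self, timer):
        """Run one timed instance of the benchmark."""

@register_benchmark("real_sht_1deg")
class RealSHTBenchmark(BenchmarkABC):
    def run_instance(self, timer):
        pass  # the real benchmark would run the RealSHT forward here
```

The CLI runner can then iterate over `BENCHMARKS` to discover everything that has been registered, without a hand-maintained list.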
Register a disco_conv_s2_torch_1deg benchmark at 1-degree resolution
(B=4, 4 channels, 180x360) using the non-optimized torch contraction
path, which does not require the custom CUDA extension.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce hardware.py with a device-name-to-scale-factor lookup table
so benchmark batch sizes adapt to different GPUs. Base batch sizes are
tuned for Tesla T4 (factor 1.0). Unknown devices default to 1.0 with
a warning to add an entry for their hardware.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
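The device-name lookup with a warning fallback described above could look like this minimal sketch. The table entries other than Tesla T4 are illustrative placeholders, not tuned values from the PR, and the function name is an assumption.

```python
import warnings

# Hypothetical lookup table; the real hardware.py entries may differ.
# Factors are relative to a Tesla T4 (factor 1.0).
DEVICE_SCALE_FACTORS = {
    "Tesla T4": 1.0,
    "NVIDIA A100-SXM4-80GB": 8.0,  # illustrative value, not tuned
}

def get_scale_factor(device_name: str) -> float:
    """Return the batch-size scale factor for a GPU, defaulting to 1.0."""
    try:
        return DEVICE_SCALE_FACTORS[device_name]
    except KeyError:
        warnings.warn(
            f"No batch size scale factor for {device_name!r}; defaulting to 1.0. "
            "Consider adding an entry for your hardware."
        )
        return 1.0
```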
@mcgibbon (Contributor, Author)

@azrael417 you may find this helpful to check the SHT timings on your hardware, for #155. You'll want to insert new batch size scaling factors to fully occupy the hardware. I tried to make it straightforward to add new benchmarks.

The entrypoint will create git-tag-labelled JSON files under benchmark_results/ in the directory you run it from (location modifiable by a flag).

@azrael417 (Collaborator)

Hello Jeremy, thanks for putting this together.
I have added multiple things to the MR. This is what I added:

  • backward benchmark
  • device selection support
  • batch size override

Can you please have a look and see if you are OK with it?

@mcgibbon (Contributor, Author)

Hello Jeremy, thanks for putting this together. I have added multiple things to the MR. This is what I added:

  • backward benchmark
  • device selection support
  • batch size override

  Can you please have a look and see if you are OK with it?

Thanks @azrael417. I don't see these commits in the history on this PR; can you link me to where I can check them out?

@azrael417 (Collaborator) commented Mar 20, 2026

Oh, I seemingly cannot push them automatically to your branch. Those ended up in branch pr-157. Feel free to review and potentially merge those into this PR.

```python
torch.cuda.synchronize()

@classmethod
def run_forward_benchmark(cls, iters=10, warmup=1) -> BenchmarkResult:
```
@mcgibbon (Contributor, Author)

I don't think we should separate out forward and backward benchmarks like this. Rather, run_instance is free to define a timer.context("forward") and timer.context("backward") block as separate blocks if it so chooses, without being required to do so. That way the backward benchmark can also take advantage of the work from the forward benchmark, instead of repeating it.

I'll refactor the existing benchmarks so they time the backward pass, and remove the "backward" framework infrastructure.
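The structure being proposed — a single `run_instance` with optional named sub-blocks — could be sketched like this. The `Timer.context` API follows the wording of this comment; the wall-clock implementation here is a stdlib stand-in for the PR's CUDA-event timers, and the workload is a placeholder.

```python
import time
from contextlib import contextmanager

class Timer:
    """Wall-clock stand-in for the PR's CUDA-event timer (sketch only)."""
    def __init__(self):
        self.durations = {}

    @contextmanager
    def context(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.durations[name] = (
                self.durations.get(name, 0.0) + time.perf_counter() - start
            )

def run_instance(timer):
    # A forward-only benchmark simply omits the "backward" block; when a
    # backward block exists, it reuses the forward result rather than
    # recomputing it.
    with timer.context("forward"):
        result = sum(i * i for i in range(1000))  # placeholder workload
    with timer.context("backward"):
        _ = result * 2  # placeholder "backward" reusing the forward output
```

A benchmark with no backward pass would simply never open a "backward" block, so nothing is reported for it.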

@mcgibbon (Contributor, Author)

Done.

@azrael417 (Collaborator)

I added this so that you can run forward and backward independently. For example, when we implement a new kernel, we first implement and optimize the forward pass. In this case there is no backward defined and we do not want to run it.

@mcgibbon (Contributor, Author)

Thanks for explaining - this is already handled by the existing framework. When you implement a new kernel, you write a benchmark with only a forward pass, but you can put it inside a child timer block for "forward". The backward block will not get reported. When you add a backward pass, you can add it into the existing benchmark with the backward child timer. Yes the total benchmark time will change on that commit, but the commit includes the benchmark update, and the forward child timings are still directly comparable before and after.

Replace run_instance_forward/run_instance_backward with a single
run_instance that uses timer.child("forward") and timer.child("backward")
to time phases within one call. This lets the backward pass reuse the
forward computation and keeps the benchmark API simpler. Also fix
gradient accumulation across iterations and remove unnecessary
retain_graph=True.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mcgibbon (Contributor, Author) commented Mar 23, 2026

What is the use case for the batch size override? I was hoping the GPU-dependent factors would handle this, and was thinking the benchmark code should set its problem size.

I'm a little worried that CLI-set batch sizes will result in different benchmark runs/output files using different batch sizes, which doesn't show up in the filename or in the result. That means we can no longer be confident the output directory contains directly comparable benchmarks. At least when batch sizes get changed by modifying the benchmark code, this is reflected in the git sha (or the -dirty suffix) changing.

Comment on lines +69 to +71:

    If a global override has been set via set_batch_size(), that value is
    returned directly. Otherwise the base is scaled by the hardware factor
    (tuned relative to a Tesla T4). Always returns at least 1.
@mcgibbon (Contributor, Author)

This is a bit counter-intuitive, I wouldn't expect a helper function scale_batch_size to access globals or do this kind of behavior. We should override at a higher level in the code where it's more appropriate.

Comment on lines +67 to +68:

```python
cpu_timer = CPUEventPair()
cpu_timer.record_start()
```
@mcgibbon (Contributor, Author)

Probably we should use with cpu_timer: instead of record_start/record_end.
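One way to support the suggested `with cpu_timer:` usage while keeping the explicit calls is shown below. The `record_start`/`record_end` names come from the diff under review; the `__enter__`/`__exit__` wiring and the millisecond field are assumptions for illustration.

```python
import time

class CPUEventPair:
    """Sketch of a CPU wall-clock event pair usable either via explicit
    record calls or as a context manager, as suggested in the review."""
    def __init__(self):
        self._start = None
        self.elapsed_ms = None

    def record_start(self):
        self._start = time.perf_counter()

    def record_end(self):
        self.elapsed_ms = (time.perf_counter() - self._start) * 1000.0

    def __enter__(self):
        self.record_start()
        return self

    def __exit__(self, *exc):
        self.record_end()
        return False  # do not suppress exceptions from the timed block

# Usage matching the suggestion:
# with CPUEventPair() as cpu_timer:
#     ...  # timed workload
```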

```python
) -> int:
    set_device(device)
    if batch_size is not None:
        set_batch_size(batch_size)
```
@mcgibbon (Contributor, Author)

We can avoid a lot of indirection if we explicitly pass cls.run_benchmark(iters=iters, batch_size=batch_size) with a default batch_size=None on that method. Let's refactor to do that, and delete this global state.
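The suggested refactor — passing `batch_size` explicitly with a `None` default instead of routing it through global state — could look like this sketch (class and attribute names here are hypothetical):

```python
# Sketch of explicit parameter passing replacing set_batch_size() globals;
# ExampleBenchmark and default_batch_size are illustrative names only.
class ExampleBenchmark:
    default_batch_size = 4

    @classmethod
    def run_benchmark(cls, iters=10, batch_size=None):
        if batch_size is None:
            batch_size = cls.default_batch_size
        # the real method would construct inputs of this size and time them
        return {"iters": iters, "batch_size": batch_size}
```

The CLI can then forward its `batch_size` argument directly, and a `None` means "use the benchmark's own size" with no hidden state involved.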

```python
)

@abc.abstractmethod
def run_instance(self: Self, timer: Timer) -> TensorDict:
```
@mcgibbon (Contributor, Author) commented Mar 23, 2026

This would also need to be modified to take in an optional batch_size. I'm wondering now though what the use case is for the batch size needing to be changed in the CLI rather than changing the benchmark. Overriding the batch size breaks the promise the output file makes that it's giving the timings for "this benchmark". e.g. if a user runs two executions on different git shas with different batch size arguments, the timings could be different even if the code is the same, and there's no way in the output files to tell whether this is because of random machine noise or because two different args were passed. This situation is a lot worse when some of the code did change.

Clearly batch size needs to be tuned on different hardware, but that's why the hardware-specific scaling exists, and the hardware is included in the benchmark filename.

@azrael417 (Collaborator)

I disagree with making out GPU utilization because in a realistic case you do not do that memory wise. Therefore, it is good to understand how a kernel scales with batch size. Having a new config entry just for a different batch size is very cumbersome.

@mcgibbon (Contributor, Author)

I disagree with making out GPU utilization because in a realistic case you do not do that memory wise.

I don't quite understand this sentence, but the next one makes sense and I think might be saying the same thing.

Therefore, it is good to understand how a kernel scales with batch size. Having a new config entry just for a different batch size is very cumbersome.

That's a great use case, thanks for explaining it. I think it would make sense for the benchmark to automatically run with more than one batch size, perhaps a base value and then double that value. I'll update the code to do that, in a way that the framework does it for each benchmark.

For now I'll have the benchmarks take in "batch size factor" argument to the init function, and have this included in the filename. By default the benchmarks will run with 1x and 2x factors, but these factors will be configurable as a list from the command line.
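The factor-per-run scheme described here, with the factor recorded in the output filename, might be sketched as follows. The `--batch-size-factors` flag and the 1x/2x defaults come from the discussion; the filename format and function names are assumptions.

```python
# Hypothetical runner loop; the filename scheme is an assumption made for
# illustration, not the PR's actual output format.
def result_filename(benchmark_name, git_tag, device_name, factor):
    device = device_name.replace(" ", "-")
    return f"{benchmark_name}_{git_tag}_{device}_bsf{factor}.json"

def run_all(benchmarks, git_tag, device_name, factors=(1, 2)):
    """Run every benchmark once per batch-size factor (default 1x and 2x)."""
    filenames = []
    for name in benchmarks:
        for factor in factors:
            # each benchmark would be constructed with this factor and run;
            # here we only record where its results would be written
            filenames.append(result_filename(name, git_tag, device_name, factor))
    return filenames
```

Because the factor appears in the filename, two result files produced with different factors can never be mistaken for directly comparable runs.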

@mcgibbon (Contributor, Author) commented Mar 24, 2026

Ah I think you meant "maxing out". You could use this batch size factor feature, or another option is to use benchmarks that explicitly are tuned not to maximize gpu utilization. For the batch size factor, perhaps we'd retune the base case to be a low utilization case, and then use two factors that will each be in the maximized regime.

I would suggest not using these isolated benchmarks as a way to measure "realistic" performance, in the sense of what you'll get when running FCN3 on top-end hardware. I don't think they can accomplish this well. The purpose of these isolated benchmarks is to, well, isolate the code. At sufficiently small problem sizes, overhead dominates and all versions of the code run in the same amount of time, in a way that also doesn't reflect realistic use.

When I say I'm maxing out GPU occupancy, I just mean that I'm increasing the problem size until (diff_run_time/diff_problem_size) asymptotes, though I'm doing this manually and not very well - you could likely do better. This by far does not mean maximum memory utilization. On my T4 I'm kind of targeting 30ms+ run-times. I find when the run times are 1-7ms, the execution time is quite insensitive to significant changes in memory ordering.

You can write a benchmark that uses child timers to fully map out timings within for example FCN3, and that code can tell you how much time is spent in each section in that realistic case. We've done this in our SFNO, where the block takes in a timer argument. I'm considering refactoring this to use our GlobalTimer singleton class instead so it doesn't show up in method signatures. Some leakage does occur when memory ordering changes, e.g. you may get one step faster by making a change that delays a contiguous call into a later block. But this is better for including also the impact different kernels have on one another during execution (e.g. these changes in memory ordering do affect the execution time).
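A GlobalTimer singleton of the kind mentioned could look roughly like this. Only the name comes from the comment; the implementation below is a stdlib sketch of the idea that timed code no longer needs a timer in its method signatures.

```python
import time
from contextlib import contextmanager

class GlobalTimer:
    """Process-wide timer singleton (sketch of the idea in the comment):
    any module can open named timing blocks without a timer argument
    threaded through method signatures."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.durations = {}
        return cls._instance

    @contextmanager
    def context(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.durations[name] = (
                self.durations.get(name, 0.0) + time.perf_counter() - start
            )

# Any module can do:
# with GlobalTimer().context("sfno_block"):
#     ...  # timed section
```

The trade-off, as the comment notes, is that timings leak across sections when memory-ordering changes shift work between blocks.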

@azrael417 (Collaborator)

I am not sure I follow. Yes, you have overhead when launching a single layer/kernel but you can compare that kernel to other variants since the overhead is similar. For the true kernel performance, one can also collect a profile and look at the raw timings in the profile.

Address review feedback to use `with cpu_timer:` instead of explicit
record_start/record_end calls.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mcgibbon (Contributor, Author) commented Mar 23, 2026

For now I'm going to remove the batch_size override from this PR, because of the concerns I have about it. It breaks the main purpose of this code, which is to compare performance across different git sha as the code evolves. But if there's a specific use case you need it for, I can add it back in or revert the commit removing it.

Ready for another look @azrael417 . I will be out for a little over 2 weeks starting this weekend.

mcgibbon and others added 2 commits March 23, 2026 18:23
The batch size override breaks reproducibility — output files can't
distinguish whether timing differences across runs come from code
changes or different CLI arguments. Each benchmark should define its
own appropriate batch size via hardware scaling.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move hardware-specific batch size scaling out of scale_batch_size and
into run.py, so individual benchmarks no longer implicitly read global
state. Each benchmark now runs with multiple user-configurable batch
size factors (default 1x and 2x), with the factor included in the
output filename. The --batch-size-factors CLI arg controls which
factors to use.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mcgibbon (Contributor, Author)

The batch size scaling factor is added; I could definitely tweak the base batch sizes and the default list of scale factors at your request. If you comment what you'd like on the PR I'll have Claude make those updates, or I can pull changes from a branch again. Alternatively, feel free to tweak after merge.

@azrael417 (Collaborator) left a comment

I think having a batch size override is very important for quick kernel profiling and optimization iterations. We never run batch sizes > 1 and I do not want to have to edit a config and do a PR when I change some benchmark settings. Since this is not a competition benchmark and is just for quick testing, I do not see a problem with having some flexibility here. One can write all the benchmark parameters into the benchmark file as well so that the results are comparable.

```python
    "Defaults to 'cuda' if available, otherwise 'cpu'.",
)
parser.add_argument(
    "--batch-size-factors",
```
@azrael417 (Collaborator)

Can we make this explicit batch size override instead of a factor?

@mcgibbon (Contributor, Author)

Alright, I will add this when I get back in a couple weeks.
