
Conversation

@mikaylagawarecki commented Jan 6, 2026

Purpose

Stacked on #31547

Test Plan

pytest tests/kernels/core/test_pos_encoding.py -v
pytest tests/kernels/core/test_rotary_embedding.py -v
pytest tests/kernels/core/test_apply_rotary_emb.py -v
pytest tests/kernels/core/test_fused_qk_norm_rope.py -v
pytest tests/kernels/test_apply_repetition_penalties.py -v
pytest tests/kernels/test_top_k_per_row.py -v

Test Result



@mikaylagawarecki changed the title from "Stable abi phase3" to "[4/n] Migrate pos_encoding sampler and fused_qknorm_rope" Jan 6, 2026
@mergify bot added the ci/build, nvidia, and cpu (Related to CPU backends) labels Jan 6, 2026
@mikaylagawarecki changed the title from "[4/n] Migrate pos_encoding sampler and fused_qknorm_rope" to "[4/n] Migrate pos_encoding sampler and fused_qknorm_rope to libtorch stable ABI" Jan 6, 2026
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request is a significant and well-executed effort to migrate numerous CUDA kernels to PyTorch's stable ABI, which will improve forward compatibility. The refactoring is extensive, touching many files to adopt stable tensor APIs, dispatch macros, and header-only includes. However, I've identified a recurring critical issue across several files: input.get_device() is incorrectly used to obtain the CUDA stream. This will lead to a compilation failure because the stable ABI function get_current_cuda_stream requires a device index (int32_t), which should be retrieved with input.get_device_index(). I have provided specific comments and suggestions for each occurrence. Addressing these compilation errors should put the PR in excellent shape.
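To make the recurring fix concrete, here is the failing pattern next to the corrected one (a minimal sketch; it assumes, as the comments below state, that get_current_cuda_stream is this repo's helper taking an int32_t device index):

  // Fails to compile: input.get_device() returns a torch::stable::Device,
  // but get_current_cuda_stream expects an int32_t device index.
  torch::stable::accelerator::DeviceGuard device_guard(input.get_device());
  cudaStream_t stream = get_current_cuda_stream(input.get_device());

  // Compiles: input.get_device_index() returns the int32_t index that
  // both the guard and the stream helper accept.
  torch::stable::accelerator::DeviceGuard device_guard(input.get_device_index());
  cudaStream_t stream = get_current_cuda_stream(input.get_device_index());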

Comment on lines +123 to +124
torch::stable::accelerator::DeviceGuard device_guard(input.get_device()); \
cudaStream_t stream = get_current_cuda_stream(input.get_device()); \

critical

The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.

  torch::stable::accelerator::DeviceGuard device_guard(input.get_device_index());
  cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
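
For context, the type mismatch can be spelled out as declarations; these signatures are assumptions inferred from this comment, not copied from the stable ABI headers:

  // Hypothetical declarations illustrating the mismatch described above.
  torch::stable::Device get_device() const;         // returns a Device object
  int32_t get_device_index() const;                 // returns the raw device index
  cudaStream_t get_current_cuda_stream(int32_t device_index);  // needs the index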

Comment on lines +289 to +290
torch::stable::accelerator::DeviceGuard device_guard(input.get_device()); \
cudaStream_t stream = get_current_cuda_stream(input.get_device()); \

critical

The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.

  torch::stable::accelerator::DeviceGuard device_guard(input.get_device_index());
  cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

Comment on lines +304 to +305
torch::stable::accelerator::DeviceGuard device_guard(input.get_device()); \
cudaStream_t stream = get_current_cuda_stream(input.get_device()); \

critical

The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.

  torch::stable::accelerator::DeviceGuard device_guard(input.get_device_index());
  cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

Comment on lines +380 to +381
torch::stable::accelerator::DeviceGuard device_guard(input.get_device()); \
cudaStream_t stream = get_current_cuda_stream(input.get_device()); \

critical

The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.

  torch::stable::accelerator::DeviceGuard device_guard(input.get_device_index());
  cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

Comment on lines +271 to +273
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device());
const cudaStream_t stream = get_current_cuda_stream(input.get_device());

critical

The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.

  const torch::stable::accelerator::DeviceGuard device_guard(
      input.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

Comment on lines +146 to +148
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device());
const cudaStream_t stream = get_current_cuda_stream(input.get_device());

critical

The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.

  const torch::stable::accelerator::DeviceGuard device_guard(
      input.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

Comment on lines +219 to +221
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device());
const cudaStream_t stream = get_current_cuda_stream(input.get_device());

critical

The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.

  const torch::stable::accelerator::DeviceGuard device_guard(
      input.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());

