[4/n] Migrate pos_encoding sampler and fused_qknorm_rope to libtorch stable ABI #31842
base: main
Conversation
Signed-off-by: Mikayla Gawarecki <[email protected]>
Code Review
This pull request is a significant, well-executed effort to migrate numerous CUDA kernels to PyTorch's stable ABI, which will improve forward compatibility. The refactoring is extensive, touching many files to adopt stable tensor APIs, dispatch macros, and header-only includes. However, there is a recurring critical issue across several files: input.get_device() is incorrectly used to obtain the CUDA stream. This will cause a compilation failure, because the stable ABI function get_current_cuda_stream requires a device index (int32_t), which should be retrieved with input.get_device_index(). Specific comments and suggestions are provided for each occurrence; addressing these compilation errors should put the PR in excellent shape.
torch::stable::accelerator::DeviceGuard device_guard(input.get_device()); \
cudaStream_t stream = get_current_cuda_stream(input.get_device()); \
The function get_current_cuda_stream expects a device index of type int32_t, but input.get_device() returns a torch::stable::Device object. This will cause a compilation error. You should use input.get_device_index() instead. While DeviceGuard can accept a Device object, using get_device_index() for both is more consistent.
torch::stable::accelerator::DeviceGuard device_guard(input.get_device_index());
cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
const torch::stable::accelerator::DeviceGuard device_guard(
    input.get_device());
const cudaStream_t stream = get_current_cuda_stream(input.get_device());
Same issue as above: input.get_device() returns a torch::stable::Device object, but get_current_cuda_stream expects an int32_t device index, so this will not compile. Use input.get_device_index() instead.
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
Force-pushed 07f7e6d to 5eee0b9
Purpose
Stacked on #31547
Test Plan
pytest tests/kernels/core/test_pos_encoding.py -v
pytest tests/kernels/core/test_rotary_embedding.py -v
pytest tests/kernels/core/test_apply_rotary_emb.py -v
pytest tests/kernels/core/test_fused_qk_norm_rope.py -v
pytest tests/kernels/test_apply_repetition_penalties.py -v
pytest tests/kernels/test_top_k_per_row.py -v
Test Result
Essential Elements of an Effective PR Description Checklist
Update supported_models.md and examples for a new model.