
Conversation

@tdophung
Collaborator

Description

Two things were wrong in the custom partitioning of permutation that the L0 tests did not detect (the cases were too small to expose them):

  1. The block size needs to propagate through to the primitive instead of falling back to DEFAULT_BLOCK_SIZE (see the sketch after this list).
  2. The workspace (an intermediate result of the 3 passes that create row_id_map) also needs to be sharded according to the routing_map/row_id_map, not just replicated.
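
A minimal sketch of the block-size issue, with hypothetical names (workspace_shape and cdiv here are illustrative helpers, not the actual TransformerEngine functions):

```python
# Illustrative sketch only; names are hypothetical.
DEFAULT_BLOCK_SIZE = 1024

def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return (a + b - 1) // b

def workspace_shape(num_experts: int, num_tokens: int, block_size: int) -> tuple:
    # Buggy variant: (num_experts, cdiv(num_tokens, DEFAULT_BLOCK_SIZE)) silently
    # disagrees with the kernel whenever the caller binds a different block_size.
    return (num_experts, cdiv(num_tokens, block_size))

# e.g. 4096 tokens with block_size=256 need 16 workspace columns,
# not the 4 that DEFAULT_BLOCK_SIZE would produce.
assert workspace_shape(8, 4096, 256) == (8, 16)
```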

Type of change

  • Documentation change (changes only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  1. Pass block_size through to the primitives
  2. Shard the workspace according to the token dimension of routing_map/row_id_map

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…rd workspace correctly based on 2nd dim of routing_map/row_id map

Signed-off-by: DoubleCheeseCheetos <hanhdp99@gmail.com>
@tdophung
Collaborator Author

/te_ci jax

@greptile-apps
Contributor

greptile-apps bot commented Jan 22, 2026

Greptile Summary

Fixed two critical bugs in permutation custom partitioning that caused incorrect behavior for larger test cases:

  • Block size propagation: The abstract methods in RowIdMapPass1Primitive and RowIdMapPass2Primitive now use the block_size parameter instead of the hardcoded DEFAULT_BLOCK_SIZE when calculating the workspace shape
  • Workspace sharding: The workspace tensor is now correctly sharded on the token dimension (PartitionSpec(None, tokens_axis)) instead of being replicated (PartitionSpec(None, None)), ensuring proper distribution across devices

The workspace shape is (num_experts, cdiv(num_tokens, block_size)) where the second dimension depends on num_tokens, so it must be sharded consistently with the token dimension to ensure each shard processes its local tokens correctly.
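
As a standalone illustration of that sharding relationship (this is not TransformerEngine code; the mesh axis name "tp" and all sizes are assumptions for the example), the workspace's second dimension is derived from the token dimension, so it takes the same PartitionSpec(None, tokens_axis) as row_id_map instead of the replicated PartitionSpec(None, None):

```python
# Standalone JAX sketch; axis name "tp" and sizes are made up for the example.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

num_experts, num_tokens, block_size = 8, 8192, 64
mesh = Mesh(jax.devices(), axis_names=("tp",))

# routing_map / row_id_map: the token dimension (2nd dim here) is sharded over "tp".
row_id_map = jax.device_put(
    jnp.zeros((num_experts, num_tokens), jnp.int32),
    NamedSharding(mesh, PartitionSpec(None, "tp")),
)

# workspace: its 2nd dim, cdiv(num_tokens, block_size), is derived from the token
# dimension, so it follows the same axis (the fix) rather than being replicated
# with PartitionSpec(None, None) (the bug).
workspace = jax.device_put(
    jnp.zeros((num_experts, cdiv(num_tokens, block_size)), jnp.int32),
    NamedSharding(mesh, PartitionSpec(None, "tp")),
)

# Per-shard view: with the token axis split over, say, 4 devices, each shard holds
# num_tokens // 4 tokens and cdiv(num_tokens, block_size) // 4 workspace columns,
# i.e. exactly cdiv(local_tokens, block_size), so the local shapes stay aligned.
local_tokens = num_tokens // 4
assert cdiv(num_tokens, block_size) // 4 == cdiv(local_tokens, block_size)
```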

The largest test case was reduced from 256 to 64 experts to fit within GPU memory constraints on L40 and A100 hardware in CI.

Confidence Score: 5/5

  • This PR is safe to merge - it fixes well-documented bugs with clear, minimal changes
  • The changes are straightforward bug fixes that address two specific issues: (1) using the correct block_size parameter instead of the default constant, and (2) properly sharding the workspace tensor. The fixes are minimal, well-commented, and consistent across all affected methods. The test adjustments ensure CI can run on available hardware.
  • No files require special attention

Important Files Changed

  • transformer_engine/jax/triton_extensions/permutation.py: Fixed block_size propagation in abstract methods and corrected workspace sharding from replicated to token-dimension sharded
  • tests/jax/test_permutation.py: Reduced largest test case from 256 to 64 experts to fit within L40 and A100 GPU memory constraints in CI

Sequence Diagram

sequenceDiagram
    participant Caller
    participant MakeRowIdMap
    participant Pass1 as RowIdMapPass1Primitive
    participant Pass2 as RowIdMapPass2Primitive
    participant Abstract1 as Pass1.abstract
    participant Abstract2 as Pass2.abstract
    participant Partition1 as Pass1.partition
    participant Partition2 as Pass2.partition

    Caller->>MakeRowIdMap: make_row_id_map(routing_map)
    Note over MakeRowIdMap: block_size = DEFAULT_BLOCK_SIZE (1024)
    
    MakeRowIdMap->>Pass1: bind(routing_map, block_size)
    Pass1->>Abstract1: abstract(routing_map_aval, block_size)
    Note over Abstract1: FIX 1: Use block_size param<br/>instead of DEFAULT_BLOCK_SIZE<br/>workspace_shape = (experts, cdiv(tokens, block_size))
    Abstract1-->>Pass1: (row_id_map_aval, workspace_aval)
    
    Pass1->>Partition1: partition(block_size, mesh, arg_infos)
    Note over Partition1: FIX 2: Workspace sharding changed<br/>from PartitionSpec(None, None)<br/>to PartitionSpec(None, tokens_axis)
    Partition1-->>Pass1: (mesh, sharded_impl, out_shardings)
    Pass1-->>MakeRowIdMap: (row_id_map, workspace)
    
    MakeRowIdMap->>Pass2: bind(row_id_map, workspace, block_size)
    Pass2->>Abstract2: abstract(row_id_map_aval, workspace_aval, block_size)
    Note over Abstract2: FIX 1: Use block_size param<br/>workspace_shape = (experts, cdiv(tokens, block_size))
    Abstract2-->>Pass2: (row_id_map_aval, workspace_aval)
    
    Pass2->>Partition2: partition(block_size, mesh, arg_infos)
    Note over Partition2: FIX 2: Workspace sharding changed<br/>from PartitionSpec(None, None)<br/>to PartitionSpec(None, tokens_axis)
    Partition2-->>Pass2: (mesh, sharded_impl, out_shardings)
    Pass2-->>MakeRowIdMap: (row_id_map, workspace)
    
    MakeRowIdMap-->>Caller: row_id_map

Collaborator

@jberchtold-nvidia left a comment

LGTM, thanks!

… and A100 in CI line up

Signed-off-by: tdophung <hanhdp99@gmail.com>
@tdophung
Collaborator Author

/te_ci jax

@tdophung
Collaborator Author

/te-ci jax

Collaborator

@jberchtold-nvidia left a comment

LGTM, thanks!

@tdophung merged commit 52ee5ea into NVIDIA:main Jan 23, 2026
22 of 26 checks passed