Fix bugs in permutation custom partitioning #2617
…rd workspace correctly based on 2nd dim of routing_map/row_id map Signed-off-by: DoubleCheeseCheetos <hanhdp99@gmail.com>
/te_ci jax
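Greptile Summary
Fixed two critical bugs in permutation custom partitioning that caused incorrect behavior for larger test cases:
- The abstract rules of the row-id-map primitives now size the workspace from the block_size parameter instead of DEFAULT_BLOCK_SIZE.
- The partition rules now shard the workspace as PartitionSpec(None, tokens_axis) instead of PartitionSpec(None, None).
The workspace shape is (experts, cdiv(tokens, block_size)). Test case sizes were reduced from 256 to 64 experts to fit within GPU memory constraints on L40 and A100 hardware in CI.
Confidence Score: 5/5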
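The workspace sizing that the summary refers to can be sketched in plain Python. This is a minimal illustration, not the code from the PR: `cdiv` and `workspace_shape` are hypothetical helper names, and the shape formula `(experts, cdiv(tokens, block_size))` and the default of 1024 are taken from the sequence diagram.

```python
# Illustrative sketch only; helper names are hypothetical, not the PR's API.
DEFAULT_BLOCK_SIZE = 1024  # value quoted in the sequence diagram


def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def workspace_shape(num_experts: int, num_tokens: int, block_size: int) -> tuple:
    """Workspace shape derived from the block_size actually passed in."""
    return (num_experts, cdiv(num_tokens, block_size))


# With the default block size the two computations agree ...
shape_default = workspace_shape(64, 4096, DEFAULT_BLOCK_SIZE)

# ... but a caller-supplied block size changes the second dimension, which a
# hard-coded DEFAULT_BLOCK_SIZE would not reflect.
shape_custom = workspace_shape(64, 4096, 512)

print(shape_default, shape_custom)
```

The point of the sketch is only that the second workspace dimension depends on whichever `block_size` is passed in, so using a constant in its place gives the wrong shape whenever the two differ.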
Sequence Diagram
sequenceDiagram
participant Caller
participant MakeRowIdMap
participant Pass1 as RowIdMapPass1Primitive
participant Pass2 as RowIdMapPass2Primitive
participant Abstract1 as Pass1.abstract
participant Abstract2 as Pass2.abstract
participant Partition1 as Pass1.partition
participant Partition2 as Pass2.partition
Caller->>MakeRowIdMap: make_row_id_map(routing_map)
Note over MakeRowIdMap: block_size = DEFAULT_BLOCK_SIZE (1024)
MakeRowIdMap->>Pass1: bind(routing_map, block_size)
Pass1->>Abstract1: abstract(routing_map_aval, block_size)
Note over Abstract1: FIX 1: Use block_size param<br/>instead of DEFAULT_BLOCK_SIZE<br/>workspace_shape = (experts, cdiv(tokens, block_size))
Abstract1-->>Pass1: (row_id_map_aval, workspace_aval)
Pass1->>Partition1: partition(block_size, mesh, arg_infos)
Note over Partition1: FIX 2: Workspace sharding changed<br/>from PartitionSpec(None, None)<br/>to PartitionSpec(None, tokens_axis)
Partition1-->>Pass1: (mesh, sharded_impl, out_shardings)
Pass1-->>MakeRowIdMap: (row_id_map, workspace)
MakeRowIdMap->>Pass2: bind(row_id_map, workspace, block_size)
Pass2->>Abstract2: abstract(row_id_map_aval, workspace_aval, block_size)
Note over Abstract2: FIX 1: Use block_size param<br/>workspace_shape = (experts, cdiv(tokens, block_size))
Abstract2-->>Pass2: (row_id_map_aval, workspace_aval)
Pass2->>Partition2: partition(block_size, mesh, arg_infos)
Note over Partition2: FIX 2: Workspace sharding changed<br/>from PartitionSpec(None, None)<br/>to PartitionSpec(None, tokens_axis)
Partition2-->>Pass2: (mesh, sharded_impl, out_shardings)
Pass2-->>MakeRowIdMap: (row_id_map, workspace)
MakeRowIdMap-->>Caller: row_id_map
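The effect of the sharding change labelled FIX 2 in the diagram can be illustrated with shape arithmetic alone. The sketch below is an assumption-laden model, not the PR's implementation: partition specs are represented as plain tuples standing in for `PartitionSpec`, the mesh axis name `tokens_axis` and its size of 4 are hypothetical, and it assumes each device runs the kernel on its own slice of the tokens.

```python
# Illustrative model only: tuples stand in for PartitionSpec, and the mesh
# axis name and size are hypothetical.

def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def local_shape(global_shape, spec, mesh_axis_sizes):
    """Per-device shape implied by a partition spec.

    Each spec entry is either None (dimension replicated) or the name of a
    mesh axis the dimension is split across.
    """
    out = []
    for dim, axis in zip(global_shape, spec):
        if axis is None:
            out.append(dim)
        else:
            n = mesh_axis_sizes[axis]
            if dim % n != 0:
                raise ValueError(f"dimension {dim} not divisible by {n}")
            out.append(dim // n)
    return tuple(out)


experts, tokens, block_size = 64, 8192, 1024
mesh_axis_sizes = {"tokens_axis": 4}  # hypothetical mesh

# Global workspace shape, following the formula in the diagram.
global_ws = (experts, cdiv(tokens, block_size))

# What one device would produce if it only sees its own slice of the tokens.
per_device_ws = (experts, cdiv(tokens // mesh_axis_sizes["tokens_axis"], block_size))

# Old spec: workspace declared fully replicated.
replicated = local_shape(global_ws, (None, None), mesh_axis_sizes)

# New spec: second dimension split along the tokens axis.
sharded = local_shape(global_ws, (None, "tokens_axis"), mesh_axis_sizes)

print(per_device_ws, replicated, sharded)
```

Under these assumptions the replicated spec implies a per-device workspace of the full global shape, while the spec sharded along the tokens axis matches what each device computes from its local tokens.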
jberchtold-nvidia
left a comment
LGTM, thanks!
… and A100 in CI line up Signed-off-by: tdophung <hanhdp99@gmail.com>
/te_ci jax
/te-ci jax
jberchtold-nvidia
left a comment
LGTM, thanks!
Description
Two things were wrong in the custom partitioning of permutation, and neither was detected by the L0 tests because the test cases were too small.