Skip to content

Commit 24771fc

Browse files
committed
...
1 parent 0329be3 commit 24771fc

2 files changed

Lines changed: 12 additions & 0 deletions

File tree

benchmark/examples/benchmark_moe.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,12 @@ def parse_args():
138138
],
139139
help="MoE fusion mode selector",
140140
)
141+
parser.add_argument(
142+
"--gemm_sms",
143+
type=int,
144+
default=None,
145+
help="Override GEMM_SMS for WG-specialized variant (default: auto)",
146+
)
141147
return parser.parse_args()
142148

143149

@@ -163,6 +169,7 @@ def _run_dist_once(
163169
n_expts_act,
164170
shmem,
165171
fusion_config,
172+
gemm_sms=None,
166173
):
167174
return mixture_of_expt_epsharded(
168175
x_dp_local,
@@ -173,6 +180,7 @@ def _run_dist_once(
173180
n_expts_act,
174181
shmem,
175182
fusion_config=fusion_config,
183+
gemm_sms=gemm_sms,
176184
)
177185

178186

@@ -249,6 +257,7 @@ def _worker(rank: int, world_size: int, init_url: str, args):
249257
args.n_expts_act,
250258
shmem,
251259
fusion_config,
260+
args.gemm_sms,
252261
)
253262

254263
if args.validate or args.compare_single_gpu:
@@ -275,6 +284,7 @@ def _worker(rank: int, world_size: int, init_url: str, args):
275284
shmem,
276285
fusion_config=fusion_config,
277286
timing_dict=td,
287+
gemm_sms=args.gemm_sms,
278288
)
279289
if rank == 0:
280290
for j in range(1, len(td)):

examples/31_expert_sharded_moe/moe.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,7 @@ def mixture_of_expt_epsharded(
209209
shmem,
210210
fusion_config: MoeFusionConfig | None = None,
211211
timing_dict: dict | None = None,
212+
gemm_sms: int | None = None,
212213
):
213214
"""Expert-parallel MoE forward using iris symmetric heap.
214215
@@ -342,6 +343,7 @@ def _tick(label):
342343
combine_indx,
343344
shmem,
344345
ragged_metadata=y_ep_local_metadata,
346+
gemm_sms=gemm_sms,
345347
)
346348
_tick("wg_fused_matmul_scatter")
347349
else:

0 commit comments

Comments
 (0)