I recommend reading the [excellent blog by Yifan Yang](https://yang-yifan.github.io/blogs/mma_swizzle/mma_swizzle.html#6-how-transposed-input-is-handled) to understand swizzling. This section is basically a quick summary of that blog.
Swizzling refers to arranging the data in SMEM in a manner that avoids bank conflicts while reading or writing data to and from SMEM.
<img src="https://yang-yifan.github.io/blogs/mma_swizzle/figures/swizzle_none_k.png" alt="Swizzle pattern for shared memory" style="max-width: 600px; display: block; margin: 0 auto;">
The Tensor core instruction requires 8 x 16B of data from GMEM, which can be loaded into SMEM using threads // check if any other way// and fed into the tensor core using the above `ldmatrix` instruction. This works well since there are no bank conflicts in this case: thread 0 loads from 32-bit bank 0, thread 1 from bank 1, and so on up to thread 31 from bank 31.
This works well, but notice that while loading from GMEM we load 8 chunks of 16B contiguous memory, which means 8 load instructions, whereas GPUs support up to 128B contiguous loads. Also, since loading from GMEM has a much higher latency than loading from L2, SMEM, or registers, we would like to load larger chunks.
But if we load 8 chunks of 32B contiguous memory and store them in SMEM contiguously, we will have bank conflicts while reading from SMEM.
Now we will have multiple 2-way bank conflicts, since both $\text{thread}\_0$ and $\text{thread}\_{16}$ will read from bank 0, and the same holds for every $\text{thread}\_i$ and $\text{thread}\_{i+16}$. If we add a 32B swizzle, we avoid bank conflicts.
<img src="https://yang-yifan.github.io/blogs/mma_swizzle/figures/why_swizzle.png" alt="Swizzle pattern for shared memory" style="max-width: 800px; display: block; margin: 0 auto;">
A new concept here is 16B atomicity: a 16B chunk that is contiguous in GMEM stays together as one 16B unit in SMEM; the swizzle only permutes whole 16B chunks.
- [CUDA Mode Video on Tensor Cores](https://www.youtube.com/watch?v=hQ9GPnV0-50&t=3968s)
<!--### Ping-Pong-->
<!-- For Ping-Pong, each warp group takes on a specialized role of either Data producer or Data consumer. The producer warp group focuses on producing data movement to fill the shared memory buffers (via TMA). Two other warp groups are dedicated consumers that process the math (MMA) portion with tensor cores, and then do any follow up work and write their results back to global memory (epilogue)
The producer can feed data to Tensor cores of Consumers. While one consumer is using the Tensor cores for Main Loop (MMA), the other can work on Epilogue which uses the CUDA cores. Thereby maximizing the utilization of Tensor cores -->
## GEMM flow in Blackwell
```
Full GEMM: (Gemm_M × Gemm_N) output, iterating over Gemm_K
│
Cluster Tile: Multiple CTAs in a cluster TOGETHER compute a larger tile
// Each CTA in the cluster handles: 128 × 256 (half the M dimension)
```

The cluster doesn't work on ONE MMA tile together - rather, multiple CTAs in a cluster each handle their own MMA tile, but they can share data via distributed shared memory and synchronize.
### Loading from SMEM in Blackwell
The GPU memory controller can issue up to a 128B load from SMEM in a single cycle.
### Load tiles into SMEM using TMA
Load BMxBK and BKxBN tiles of 64x64 fp16 (8192B) from GMEM to SMEM. TMA and the tensor cores operate on "core matrices", which are 8 x 16B of data; for half precision that is an 8x8 tile of elements. This means we need to load (64/8) x (64/8) = 8x8 core matrices. While loading data into SMEM we need to keep in mind that it will be fed to the tensor cores (`tcgen05`), which expect the data in a certain format. TMA can load a column of 8 core matrices (1024B) at a time, which means loading 8192B takes 8 loads. Use the `tcgen05.mma` instruction and store the results in TMEM, then move the results from TMEM to registers and finally to GMEM.
## References & Recommended resources
- [Articles by Colfax Research](https://research.colfax-intl.com/blog/)
- [CUDA Training Series by NVIDIA and OLCF](https://www.olcf.ornl.gov/cuda-training-series/)
- [CUDA Training Series YouTube Playlist](https://www.youtube.com/playlist?app=desktop&list=PL6RdenZrxrw-zNX7uuGppWETdxt_JxdMj)
- [CUDA Training Exercises](https://github.com/olcf/cuda-training-series/tree/master/exercises)
INCLUDES = -I/opt/cutlass/include -I/opt/cutlass/examples/cute/tutorial/blackwell

mma: mma.o
	$(NVCC) $(FLAGS) -o $@ $^

mma.o: $(SRC)
	$(NVCC) $(FLAGS) $(INCLUDES) -dc -o $@ $<

clean:
	rm -f *.o mma
```
Run the command
```bash
make && ./mma
```
`nvcc` does code generation in two stages:
| Stage | What nvcc produces | Option that drives it |
|-------|--------------------|-----------------------|
| 1. Front-end | PTX for a virtual architecture (`compute_XX`) | `arch=` inside `-gencode` (or `--gpu-architecture`) |
| 2. Back-end | SASS (cubin) for a real architecture (`sm_XX`) | `code=` inside `-gencode` (or `--gpu-code`) |
This will generate PTX code for compute capability 3.0, 5.2, and 7.0.
159
194
160
195
`compute_XX` refers to a PTX version and `sm_XX` refers to a cubin version. The `arch=` clause must always be a PTX version, while the `code=` clause can be cubin, PTX, or both.
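As an illustration (the file names here are hypothetical), a build that embeds SASS for `sm_80` plus PTX for `compute_80`, so that newer GPUs can JIT-compile the PTX, might look like:

```bash
# SASS for sm_80, plus PTX for compute_80 as a forward-compatibility fallback
nvcc -gencode arch=compute_80,code=sm_80 \
     -gencode arch=compute_80,code=compute_80 \
     -o app app.cu
```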
# CUTLASS/CuTe Development
Use clangd for IDE features like Go to Definition, hover documentation, auto-completion, and error diagnostics. clangd is part of the LLVM/Clang project and understands C++ deeply. Unlike simple syntax highlighting, clangd actually compiles the code in the background to understand types, templates, and symbols.
- Install the `clangd` extension in VS Code. If prompted, let it download the clangd binary.
- If you have Microsoft's "C/C++" extension installed, disable its IntelliSense to avoid conflicts by: Settings → search `C_Cpp.intelliSenseEngine` → set to `disabled`
- Create a `.clangd` file in the project root. This file tells clangd how to compile your code.
---
### Example Configuration
```yaml
CompileFlags:
  Add:
    - "-xc++"
    - "-std=c++17"
    - "-I/path/to/cutlass/include"
    - "-I/path/to/cutlass/tools/util/include"
  Remove:
    - "-forward-unknown-to-host-compiler"
    - "--generate-code*"
    - "-gencode*"
```
### Flag Explanations
| Flag | Purpose |
|------|---------|
| `-xc++` | Treat `.cu` files as C++ (clangd doesn't understand CUDA natively) |
| `-std=c++17` | Use C++17 standard (CUTLASS requires C++17) |
| `-I/path/to/include` | Include paths - where to find headers |
### Remove Flags
These are nvcc-specific flags that clang doesn't understand:
- `-forward-unknown-to-host-compiler`
- `--generate-code*`
- `-gencode*`
---
- Finally, restart clangd using `Cmd+Shift+P` → `clangd: Restart language server`
---
## Optional: Thrust/CUB Support
If you want Thrust headers to work (for `thrust::device_vector`, etc.), download them separately:
**nsys**: CLI for Nsight Systems, which supports system-wide profiling.
**nvprof**: CLI for the NVIDIA Visual Profiler which supports profiling and tracing of CUDA applications. It is deprecated in CUDA 11.0 and will be removed in a future release.