AI Usage
What AI was used for
- creating most of the code in the Python scripts
- editing the blog posts
- C++ code (sparingly)
Benchmarking and Profiling
Devices were tested with:
- CUDA 12.8
- Driver version 570.133.20
- torch==2.5.1+cu124
- a fork of flash attention v2, fixed to commit b36ad4e (closest tagged version is v2.7.4), which includes the following changes:
  - removing non-relevant kernels from the build to speed up build time
  - adding CUDA timing inside the flash attention C++ API
  - removing the allocation, computation, and store back to gmem of lse, since the kernel we build doesn't write lse back. That said, the impact of this was very minimal, amounting to a ~0.1% difference
- link to repo
Benchmarking
We benchmark the following sequence lengths:
- 512
- 1024
- 2048
- 4096
- 8192
- 16384
When benchmarking, each kernel was run 32 times for warmup and then timed 128 times for TFLOPs calculation. The practices in this blog post were followed for more accurate benchmarks.
When benchmarking, devices were run with the following clocks:
Device | SM Clock | DRAM Clock |
---|---|---|
RTX 3090 | 1680 MHz | 9501 MHz |
A100 PCIe 80GB | 1110 MHz | 1512 MHz |

Note: 1512 MHz is the only valid DRAM clock on the A100.
You can get the set of valid clocks for a GPU with the following command:
nvidia-smi --query-supported-clocks=gr,mem --format=csv
Profiling
For Nsight Compute profile metrics, we run each kernel 32 times on sequence length 4096 to reduce metric variance. The metrics posted in the blog are the mean values recorded, unless the metric is from a screenshot of Nsight Compute.
The command used to profile kernels:
ncu \
--config-file off \
--export /path/to/profile \
--force-overwrite \
--target-processes application-only \
--kernel-name regex:device|flash \
--warp-sampling-interval 1 \
--warp-sampling-max-passes 1000 \
--warp-sampling-buffer-size 536870912 \
--set full \
--apply-rules no \
--import-source no \
/path/to/python /path/to/repo/tools/benchmark/run_kernels.py 4096 128 --kernels $KERNEL_CONFIG_NAME \
--n_runs 32
Calculating Arithmetic Intensity
Here are the Python functions used to calculate arithmetic intensity for a single tile pair.
ELEM_SIZE = 2  # bytes per bf16/fp16 element

def softmax_flop(B_r, B_c, d_head) -> int:
    # FLOPs for the softmax and rescaling work on one (B_r, B_c) score tile
    return B_r * (5 * B_c + d_head + 4)

def tile_flop(B_r, B_c, d_head) -> int:
    QK_flops = B_r * d_head * B_c
    PV_flops = B_r * B_c * d_head
    softmax_flops = softmax_flop(B_r, B_c, d_head)
    return QK_flops + PV_flops + softmax_flops

def gmem_transfer_size(B_r, B_c, d_head) -> int:
    # Bytes moved for Q and O (B_r rows each) plus K and V (B_c rows each)
    return d_head * 2 * (B_r + B_c) * ELEM_SIZE

def arithmetic_intensity(B_r, B_c, kv_seq_len, d_head) -> float:
    # FLOPs across all KV tiles for one Q block, divided by the bytes
    # transferred for that block (Q and O once, the full K and V once)
    return (
        tile_flop(B_r, B_c, d_head) * (kv_seq_len // B_c)
    ) / gmem_transfer_size(B_r, kv_seq_len, d_head)
Compute Capability
Different CUDA GPUs have different compute capabilities, which is a term to distinguish the features and instructions they support. This includes properties like the maximum amount of shared memory supported per SM and support for specific PTX instructions. It’s essentially hardware versioning.
The CUDA C++ Programming Guide has a table detailing these specifications. I include the ones most relevant to our compute capabilities here.
Device | Compute Capability | Max Shared Memory Per CTA/SM | Max # 32b Registers Per Thread | # 32b Registers Per SM |
---|---|---|---|---|
A100 | 8.0 | 163KB/164KB | 255 | 65536 |
RTX 3090 | 8.6 | 99KB/100KB | 255 | 65536 |
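These limits can also be queried at runtime with cudaGetDeviceProperties; here is a small sketch (assuming device index 0) that prints the fields corresponding to the table above. The per-thread register limit of 255 is an architectural constant and isn't exposed through this struct.

#include <cstdio>
#include <cuda_runtime.h>

// Query the compute capability and the SMEM/register limits for device 0.
int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, /*device=*/0);
    printf("compute capability: %d.%d\n", prop.major, prop.minor);
    printf("max smem per CTA (opt-in): %zu bytes\n", prop.sharedMemPerBlockOptin);
    printf("smem per SM: %zu bytes\n", prop.sharedMemPerMultiprocessor);
    printf("32-bit registers per SM: %d\n", prop.regsPerMultiprocessor);
    return 0;
}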
Glossary
CUDA/GPU Programming Terms
- GMEM: global memory, stored in DRAM
- SMEM: shared memory, stored in SRAM
- LMEM: local memory
- RF: register file / register memory
- CTA: cooperative thread array; same as thread-block. I’ll use CTA to avoid confusion with matrix blocks
- tiling: dividing a matrix into tiles, each processed by a CTA, warp, or thread
- fragment: a tile of a matrix stored in registers. In this series, this specifically refers to an 8×8 tile in rmem
  - Within a warp, each thread holds 2 values
  - This requires a single 32-bit register for 16-bit values, and two 32-bit registers for 32-bit values
- lane_id: thread index within a warp (tid % 32)
- LD/ST: load/store operations
- mnk variables: standard naming convention for GEMM dimensions
  - For D = AB^T + C:
    - A is (m, k)
    - B is (n, k)
    - C and D are (m, n)
  - In our context, k corresponds to d_head in S = QK^T and B_c in O = PV
- Arithmetic Intensity: # fp operations performed divided by the # bytes loaded. This can be different for different levels in the memory hierarchy. For instance, the # bytes loaded from the L1 cache can be different from the L2 cache.
Flash Attention Terms
Following the notation from the paper, with a few simplifications:
- Q_i, O_i: the query and output tensors handled by the current CTA
- K_j, V_j: the j-th key/value tile
- m: row-wise max of attention scores
- l: row-wise sum of exponentiated attention scores
- B_r: query rows in the Q block
- B_c: key and value rows in the K and V blocks
- d_head: head dimension
Kernel Specification
- forward pass only
- non-causal attention
- head dimension = 128
- no dropout or KV caching
- equal query/key/value sequence lengths
- sequence lengths divisible by block sizes (typically 64-128 in our implementation, as defined in the paper)
- 16-bit (bf16/fp16) input and output tensors, softmax calculation in fp32
nvcc Flags for Register Spilling
We can configure nvcc with these settings to warn us if we're spilling registers and how much local memory we're using.
- -Xptxas=-warn-spills: this will output a warning if a kernel is spilling registers into local memory
  ptxas warning : Registers are spilled to local memory in function 'kernel_name', 20 bytes spill stores, 20 bytes spill loads
- -Xptxas=-warn-lmem-usage: this will output a warning if a kernel uses local memory
  Local memory used for function 'kernel_name', size of stack frame: 8 bytes
- --resource-usage: this will give information about the resources your kernel uses
If your kernel doesn't spill any registers, the output will look something like this:
ptxas info : Compiling entry function '...' for 'sm_80'
ptxas info : Function properties for ...
0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 163 registers, used 1 barriers, 448 bytes cmem[0]
ptxas info : Compile time = 24.428 ms
And if it does spill, then something like this:
ptxas info : Function properties for ...
456 bytes stack frame, 668 bytes spill stores, 580 bytes spill loads
ptxas info : Used 255 registers, used 1 barriers, 456 bytes cumulative stack size, 448 bytes cmem[0]
ptxas info : Compile time = 69.922 ms
Instructions
mma instruction
The PTX instruction that we use for mma is

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 d, a, b, c

This instruction has a fair number of modifiers, so let's digest this piece by piece.
- mma means matrix multiply + accumulate
- .sync.aligned means that threads in the same warp cooperate with each other
  - .sync means threads will block until all threads execute the instruction
  - .aligned means all threads must execute the same instruction
- m16n8k16 is the shape of our operation
  - where A has shape (m, k) = (16, 16), B has shape (n, k) = (8, 16), and C/D have shape (m, n) = (16, 8)
  - This means that we have:
    - 4 fragments of A = 4 registers / 8 elements
    - 2 fragments of B = 2 registers / 4 elements
    - 2 fragments of C (2x2 = 4 registers / 4 elements because FP32 accumulator)
    - 2 fragments of D (2x2 = 4 registers because FP32 accumulator)
- .row.col means A is in row major format and B is in col major format
- .f32.f16.f16.f32 are the data types of D, A, B, C respectively
The d, a, b, c components expand to {d1_1, d1_2, d2_1, d2_2}, {a1_1, a2_1, a3_1, a4_1}, {b1_1, b2_1}, {c1_1, c1_2, c2_1, c2_2}, where each variable represents a register named {matrix}{fragment}_{register #}.
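As a concrete illustration, here is a minimal sketch of how this instruction can be issued from CUDA C++ via inline PTX. The wrapper name and argument packing are illustrative assumptions, not the exact helper used in our kernels.

#include <cstdint>

// Sketch: issue mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32.
// Per thread: a holds 4 registers (8 f16 elements), b holds 2 registers
// (4 f16 elements), and c/d each hold 4 fp32 accumulator registers.
__device__ __forceinline__ void mma_m16n8k16_fp16_fp32(
    float d[4], const uint32_t a[4], const uint32_t b[2], const float c[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}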
cp.async vs Traditional Loads
You might be wondering: aren't "normal" loads still asynchronous? If they weren't, wouldn't GPUs be very slow? Yes, you'd be right: the GPU won't initiate a load and then sit there idly waiting for the load to complete. Instead, it'll initiate the load and continue executing until it encounters an instruction that depends on the load. Only then will it block. Therefore, with enough instruction-level parallelism (ILP), we can hide the latency.
So what’s the difference then? The way standard loads copy data from gmem to smem is:
- The data is first copied from gmem to the register file
- Then from the register file it gets copied to smem
cp.async, on the other hand, completely bypasses the RF and copies directly into smem. In addition, when our accesses are 16 bytes and aligned, we can bypass the L1 cache with the .cg modifier, which reduces cache pollution. When using the other modifier, .ca, L1 will not be bypassed.
As we'll see later on, our register budget will be fairly constrained, so this can alleviate some of that pressure. cp.async also includes the benefits of checkpointing, as discussed earlier.
The typical path when loading data from gmem on a cache miss is gmem -> L2 -> L1 -> registers -> smem. However, with optimized async loads, it looks like gmem -> L2 -> smem. Even if we were to bypass the L1 using cache hints, we would still have to load through L2 -> registers -> smem.
If we take a look at the corresponding SASS,
for (int i = 0; i < N_COPIES; ++i) {
smem[i] = gmem[i];
}
compiles down to
// Copies from GMEM -> RF via L2 & L1 cache
LDG.E.128 R4, [R2.64]
LDG.E.128 R8, [R2.64+0x10]
LDG.E.128 R12, [R2.64+0x20]
LDG.E.128 R16, [R2.64+0x30]
// Copies from RF -> SMEM
STS.128 [RZ], R4
STS.128 [0x10], R8
STS.128 [0x20], R12
STS.128 [0x30], R16
while
for (int i = 0; i < 4; ++i) {
cp_async<16>(smem + 16 * i, gmem);
}
cp_async_commit();
cp_async_wait<0>();
compiles down to
LDGSTS.E.BYPASS.LTC128B.128 [RZ], [R2.64]
LDGSTS.E.BYPASS.LTC128B.128 [0x10], [R2.64]
LDGSTS.E.BYPASS.LTC128B.128 [0x20], [R2.64]
LDGSTS.E.BYPASS.LTC128B.128 [0x30], [R2.64]
LDGDEPBAR // commit
DEPBAR.LE SB0, 0x0 // wait 0
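For reference, here is a rough sketch of what the cp_async<16>, cp_async_commit, and cp_async_wait helpers used above might look like; this is an assumed shape, not the exact code from the repo.

#include <cstdint>

// Sketch of the async-copy helpers: a 16-byte cp.async.cg copy that bypasses
// L1, plus the commit/wait wrappers used to group and drain the copies.
template <int Bytes>
__device__ __forceinline__ void cp_async(void* smem_dst, const void* gmem_src) {
    static_assert(Bytes == 16, "the .cg variant requires 16-byte copies");
    uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"
                 :: "r"(smem_addr), "l"(gmem_src), "n"(Bytes) : "memory");
}

__device__ __forceinline__ void cp_async_commit() {
    asm volatile("cp.async.commit_group;\n" ::: "memory");
}

// Block until at most N committed groups are still in flight.
template <int N>
__device__ __forceinline__ void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N) : "memory");
}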
Committing
Another benefit of cp.async compared to standard loads is related to committing.

With commit checkpointing, cp.async has a clean way of grouping multiple loads together into distinct operations that can be waited on separately. This allows us to, for instance:
- Separate the loads for K and V into different groups, but still initiate them at the same time
- Only wait on the K transfer while V is still in flight

When using standard loads, we'd need to either wait on both, or delay the load for V.
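As a sketch of this pattern, assuming the hypothetical cp_async helpers from above and leaving the per-thread addressing abstract:

#include <cuda_fp16.h>

// Sketch: issue the K and V tile copies back to back but as separate commit
// groups, then wait only for K while V may still be in flight.
__device__ void load_k_then_v(half* smem_K, half* smem_V,
                              const half* gmem_K, const half* gmem_V,
                              int byte_offset /* this thread's 16B chunk */) {
    // Group 0: K tile (each thread copies one 16-byte chunk here for brevity).
    cp_async<16>(reinterpret_cast<char*>(smem_K) + byte_offset,
                 reinterpret_cast<const char*>(gmem_K) + byte_offset);
    cp_async_commit();

    // Group 1: V tile.
    cp_async<16>(reinterpret_cast<char*>(smem_V) + byte_offset,
                 reinterpret_cast<const char*>(gmem_V) + byte_offset);
    cp_async_commit();

    // Allow at most one group (V) to remain pending: K is now usable in smem,
    // while the V copy may still be in flight.
    cp_async_wait<1>();
    __syncthreads();
}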
wmma API
Compared to mma, wmma is a higher-level API that handles loading and writing values between smem and rmem, where the layout of the matrix fragments is opaque. This means we don't know which threads in the warp contain which components of the matrix, or which registers contain which values. 1 As a result, we can't operate on the registers directly and must read and write from shared memory instead. There are two downsides to this:
- We'll need to reserve extra shared memory for S and O (to perform row-wise operations); since shared memory will already be limited, this may slow our kernel either due to lower occupancy or because certain configurations of block sizes that are more optimal become infeasible.
- Furthermore, we will incur extra reads and writes to shared memory when we:
  - Perform softmax (read S and write P)
  - Compute PV and rescale O (read P and read/write O)
This results in higher latency and increased pressure on the MIO. This is the unit in each SM partition that handles requests for shared memory and various other pipelines like special math functions.
mma gives us control over the matrix fragment layouts, so we can directly operate on the output values in registers and avoid unnecessary SMEM operations.
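To make the opacity concrete, here is a minimal wmma sketch for a single 16x16x16 tile (an illustrative shape, not our kernel's block sizes); the accumulator's thread/register layout is unspecified, so row-wise work has to round-trip through shared memory.

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// One 16x16x16 tile with the wmma API. The mapping of c_frag elements to
// threads/registers is opaque, so the result is spilled to smem before any
// row-wise operation (e.g. softmax) can be applied.
__device__ void wmma_tile(const half* smem_A, const half* smem_B, float* smem_C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, smem_A, /*ldm=*/16);
    wmma::load_matrix_sync(b_frag, smem_B, /*ldm=*/16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

    // Round-trip through shared memory because we can't address rows in c_frag.
    wmma::store_matrix_sync(smem_C, c_frag, /*ldm=*/16, wmma::mem_row_major);
}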
Shapes
Generalized Tensor Shapes
element size (bytes) | GMEM+SMEM Majorness | GMEM↔SMEM Shape (Elements, pre kernel 9) | SMEM Shape (SMEM → RMEM, Elements) | RMEM Majorness | RMEM Shape (Registers) | mma matrix variable | ||
---|---|---|---|---|---|---|---|---|
2 | Row major | Row major | A | |||||
2 | Row major | Row major | B | |||||
2 | Row major | Column major | B | |||||
4 | N/A | N/A | N/A | Row major | C/D | |||
2 | N/A | N/A | N/A | Row major | A | |||
4 | N/A | N/A | N/A | Row major | C/D | |||
2 | Row major | Row major | N/A |
LD/ST Operation Table
From | To | Blocks | PTX Instr. / C++ | Warp-Wide Op Size | Thr. Op Size | Thr. ID Mapping Order | Register Shape | Notes |
---|---|---|---|---|---|---|---|---|
GMEM | SMEM | cp.async | Row-major | |||||
SMEM | RF | ldmatrix.x4 | Column-major | |||||
RF | SMEM | standard (4B) | Row-major | |||||
SMEM | GMEM | standard (16B) | Row-major |
mma.m16n8k16 Instruction Shapes
Operand | DType | Shape (Variables) | Shape (Elements) | Shape (Fragments) | Shape (Registers) |
---|---|---|---|---|---|
A | BF16/FP16 | (m, k) | (16, 16) | (2, 2) | (2, 2) |
B | BF16/FP16 | (n, k) | (8, 16) | (1, 2) | (1, 2) |
C+D | FP32 | (m, n) | (16, 8) | (2, 1) | (2, 2) |
Warp-Wide Thread to (row, col) Mapping Kernels 1-8
Operation | Row | Column |
---|---|---|
mma fragment /RF → SMEM | (tid % 32) / 4 | (tid % 4) * 2 |
SMEM → RF (ldmatrix ) | tid % 16 | ((tid % 32) / 16) * 8 |
GMEM ↔ SMEM | (tid % 32) / 8 | tid % 8 |
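Written out as code (with lane_id = tid % 32, as in the glossary), one way to express the kernel 1-8 mappings above as helper functions:

// The kernel 1-8 warp-wide thread -> (row, col) mappings from the table above.
__device__ __forceinline__ int lane_id() { return threadIdx.x % 32; }

// mma fragment / RF -> SMEM
__device__ void mma_frag_coords(int &row, int &col) {
    row = lane_id() / 4;
    col = (lane_id() % 4) * 2;
}

// SMEM -> RF (ldmatrix)
__device__ void ldmatrix_coords(int &row, int &col) {
    row = lane_id() % 16;
    col = (lane_id() / 16) * 8;
}

// GMEM <-> SMEM
__device__ void gmem_smem_coords(int &row, int &col) {
    row = lane_id() / 8;
    col = lane_id() % 8;
}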
Warp-Wide Thread to (row, col) Mapping Kernels 9-16
Let lid = threadIdx.x % 32;
Operation | Tensors Affected | Row | Column |
---|---|---|---|
mma fragment /RF → SMEM | lid / 4 | (lid % 4) * 2 | |
SMEM → RF (ldmatrix ) | lid % 16 | (lid / 16) * 8 | |
SMEM → RF (ldmatrix ) | (lid % 8) + 8 * (lid / 16) | lid % 8 |
GMEM ↔ SMEM | all | lid / 8 | lid % 8 |
Kernel 1-6 Tensor Shapes
Tensor Shapes
element size (bytes) | GMEM+SMEM Majorness | GMEM↔SMEM Shape | SMEM Shape (SMEM → RF) | RF Majorness | RF Shape (Registers) | ||
---|---|---|---|---|---|---|---|
2 | Row major | Row major | |||||
2 | Row major | Row major | |||||
2 | Row major | Column major | |||||
4 | N/A | N/A | N/A | Row major | |||
2 | N/A | N/A | N/A | Row major | |||
4 | N/A | N/A | N/A | Row major | |||
2 | Row major | Row major |
Tensor RF Storage Shapes
Tensor | element size (bytes) | mma matrix variable | Storage Shape | Register Count |
---|---|---|---|---|
2 | A | 32 | ||
4 | C/D | 64 | ||
2 | B | 128 | ||
2 | B | 128 | ||
4 | C/D | 32 | ||
2 | A | 16 | ||
4 | 2 | |||
4 | 2 |
GEMM Shapes
A | A Shape (Registers) | B | B Shape (Registers) | Iteration Shape (k, m, n) |
---|---|---|---|---|
Footnotes
-
The layout has been reverse engineered, but it’s unclear how reliable it is for production code. ↩