AI Usage

What AI was used for

  • writing most of the code in the Python scripts
  • editing the blog posts
  • C++ code, sparingly

Benchmarking and Profiling

Devices were tested with:

  • CUDA 12.8
  • Driver version 570.133.20
  • torch==2.5.1+cu124
  • a fork of Flash Attention v2, pinned to commit b36ad4e (closest tagged version is v2.7.4), which includes the following changes:
    • removing irrelevant kernels from the build to speed up compilation
    • adding CUDA timing inside the Flash Attention C++ API
    • removing the allocation, computation, and store back to GMEM of LSE (log-sum-exp), since the kernels we build don't write LSE back; the impact of this was minimal, amounting to a ~0.1% difference
    • link to repo

Benchmarking

We benchmark the following sequence lengths:

  • 512
  • 1024
  • 2048
  • 4096
  • 8192
  • 16384

When benchmarking, each kernel was run 32 times for warmup and then timed 128 times for TFLOPs calculation. The practices in this blog post were followed for more accurate benchmarks.
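
As a sketch of what that timing loop looks like (this is not the exact harness, which lives in tools/benchmark/run_kernels.py):

import torch

def bench_ms(fn, n_warmup=32, n_timed=128):
    # Warmup so clocks, caches, and any lazy initialization settle before timing.
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_timed):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / n_timed  # mean milliseconds per run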

When benchmarking, devices were run with the following clocks:

Device         | SM Clock | DRAM Clock
RTX 3090       | 1680 MHz | 9501 MHz
A100 PCIe 80GB | 1110 MHz | 1512 MHz

Note: 1512 MHz is the only valid DRAM clock on the A100.

You can get the set of valid clocks for a GPU with this command:

nvidia-smi --query-supported-clocks=gr,mem --format=csv
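
To lock the clocks for a run, nvidia-smi can pin them to values from that list; for example, for the RTX 3090 clocks in the table above (requires root, and may differ from the exact setup used here):

sudo nvidia-smi --lock-gpu-clocks=1680,1680
sudo nvidia-smi --lock-memory-clocks=9501,9501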

Profiling

For the Nsight Compute profile metrics, we run each kernel 32 times at sequence length 4096 to reduce metric variance. The metrics reported in the blog are the mean values recorded, unless the metric comes from a screenshot of Nsight Compute.

The command used to profile kernels:

ncu \
  --config-file off \
  --export /path/to/profile \
  --force-overwrite \
  --target-processes application-only \
  --kernel-name regex:device|flash \
  --warp-sampling-interval 1 \
  --warp-sampling-max-passes 1000 \
  --warp-sampling-buffer-size 536870912 \
  --set full \
  --apply-rules no \
  --import-source no \
  /path/to/python /path/to/repo/tools/benchmark/run_kernels.py 4096 128 --kernels $KERNEL_CONFIG_NAME \
  --n_runs 32

Calculating Arithmetic Intensity

Here are the Python functions used to calculate arithmetic intensity. tile_flop and gmem_transfer_size cover a single tile pair; arithmetic_intensity extends this over the full KV sequence length.

ELEM_SIZE = 2  # bytes
 
def softmax_flop(B_r, B_c, d_head) -> int:
    return B_r * (5 * B_c + d_head + 4)
 
def tile_flop(B_r, B_c, d_head) -> int:
    QK_flops = B_r * d_head * B_c
    PV_flops = B_r * B_c * d_head
 
    softmax_flops = softmax_flop(B_r, B_c, d_head)
    return QK_flops + PV_flops + softmax_flops
 
def gmem_transfer_size(B_r, B_c, d_head) -> int:
    return d_head * 2 * (B_r + B_c) * ELEM_SIZE
 
def arithmetic_intensity(B_r, B_c, kv_seq_len, d_head) -> float:
    return (
        tile_flop(B_r, B_c, d_head) * (kv_seq_len // B_c)
    ) / gmem_transfer_size(B_r, kv_seq_len, d_head)
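
For example, with d_head = 128 and an assumed block configuration of B_r = B_c = 64 (the actual kernels may use different block sizes), at a KV sequence length of 4096:

B_r, B_c, d_head, kv_seq_len = 64, 64, 128, 4096

print(tile_flop(B_r, B_c, d_head))                         # FLOPs for one (Q, KV) tile pair
print(arithmetic_intensity(B_r, B_c, kv_seq_len, d_head))  # ≈32 FLOPs per byte of GMEM traffic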
 

Compute Capability

Different CUDA GPUs have different compute capabilities; the term distinguishes the features and instructions each GPU supports. This includes properties like the maximum amount of shared memory per SM and support for specific PTX instructions. It's essentially hardware versioning.

The CUDA C++ Programming Guide has a table detailing these specifications. I include the ones for the most relevant compute capabilities here.

Device   | Compute Capability | Max Shared Memory Per CTA/SM | Max # 32b Registers Per Thread | # 32b Registers Per SM
A100     | 8.0                | 163KB/164KB                  | 255                            | 65536
RTX 3090 | 8.6                | 99KB/100KB                   | 255                            | 65536
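
These limits can also be queried at runtime; a minimal sketch using the CUDA runtime API:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    printf("compute capability: %d.%d\n", prop.major, prop.minor);
    printf("max smem per CTA (opt-in): %zu bytes\n", prop.sharedMemPerBlockOptin);
    printf("smem per SM: %zu bytes\n", prop.sharedMemPerMultiprocessor);
    printf("32-bit registers per SM: %d\n", prop.regsPerMultiprocessor);
    return 0;
}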

Glossary

CUDA/GPU Programming Terms

  • GMEM: global memory, stored in DRAM
  • SMEM: shared memory, stored in SRAM
  • LMEM: local memory
  • RF: register file / register memory
  • CTA: cooperative thread array; same as thread-block. I’ll use CTA to avoid confusion with matrix blocks
  • tiling: dividing a matrix into tiles, each processed by a CTA, warp, or thread
  • fragment: a tile of a matrix stored in registers. In this series, this specifically refers to an 8×8 tile in RMEM
    • Within a warp, each thread holds 2 values of the fragment
    • This requires a single 32-bit register per thread for 16-bit values, and two registers for 32-bit values
  • lane_id: thread index within a warp (tid % 32)
  • LD/ST: load/store operations
  • mnk variables: standard naming convention for GEMM dimensions:
    • For D = AB^T + C:
      • A is (m, k)
      • B is (n, k)
      • C and D are (m, n)
    • In our context, k corresponds to:
      • d_head in S = QK^T
      • B_c in O = PV
  • Arithmetic Intensity: # of fp operations performed divided by the # of bytes loaded. This can differ across levels of the memory hierarchy; for instance, the # of bytes loaded from the L1 cache can be different from the # loaded from the L2 cache.

Flash Attention Terms

Following the notation from the paper, with a few simplifications:

  • Q_i, O_i: the query and output tiles handled by the current CTA
  • K_j, V_j: the j-th key/value tile
  • m: row-wise max of the attention scores
  • l: row-wise sum of the exponentiated attention scores
  • B_r: number of query rows in the Q_i block
  • B_c: number of key and value rows in the K_j and V_j blocks
  • d_head: head dimension

Kernel Specification

  • forward pass only
  • non-causal attention
  • head dimension = 128
  • no dropout or KV caching
  • equal query/key/value sequence lengths
  • sequence lengths divisible by the block sizes B_r and B_c (as defined in the paper; typically 64-128 in our implementation)
  • 16-bit (bf16/fp16) input and output tensors, softmax calculation in fp32
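
A PyTorch sketch of the reference computation this spec corresponds to (the (batch, heads, seq, d_head) layout and the 1/sqrt(d_head) scaling are assumptions here):

import torch

def attention_ref(q, k, v):
    # q, k, v: (batch, n_heads, seq_len, d_head), bf16/fp16
    scale = q.shape[-1] ** -0.5
    s = (q.float() @ k.float().transpose(-2, -1)) * scale  # scores in fp32
    p = torch.softmax(s, dim=-1)                           # softmax in fp32
    return (p @ v.float()).to(q.dtype)                     # 16-bit output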

nvcc Flags for Register Spilling

We can configure nvcc with these settings to warn us if we’re spilling registers and how much local memory we’re using.

  • -Xptxas=-warn-spills: this will output a warning if a kernel is spilling registers into local memory
    - ptxas warning : Registers are spilled to local memory in function 'kernel_name', 20 bytes spill stores, 20 bytes spill loads
  • -Xptxas=-warn-lmem-usage: this will output a warning if a kernel uses local memory
    - Local memory used for function 'kernel_name', size of stack frame: 8 bytes
  • --resource-usage: will give information about the resources your kernel uses
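
For example, combining these flags in a single compile invocation (the sm_80 target and file names are placeholders):

nvcc -arch=sm_80 -Xptxas=-warn-spills -Xptxas=-warn-lmem-usage --resource-usage -c kernel.cu -o kernel.o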

If your kernel doesn't spill any registers, the output will look something like this:

ptxas info    : Compiling entry function '...' for 'sm_80'
ptxas info    : Function properties for ...
    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 163 registers, used 1 barriers, 448 bytes cmem[0]
ptxas info    : Compile time = 24.428 ms

And if it does spill, then something like this:

ptxas info    : Function properties for ...
    456 bytes stack frame, 668 bytes spill stores, 580 bytes spill loads
ptxas info    : Used 255 registers, used 1 barriers, 456 bytes cumulative stack size, 448 bytes cmem[0]
ptxas info    : Compile time = 69.922 ms

Instructions

mma instruction

The PTX instruction that we use for mma is

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 d, a, b, c

This instruction has a fair number of modifiers, so let’s digest this piece by piece.

  • mma means matrix multiply + accumulate
  • .sync.aligned means that threads in the same warp cooperate with each other
    • .sync means threads will block until all threads execute the instruction
    • .aligned means all threads must execute the same instruction
  • m16n8k16 is the shape of our operation
    • where
      • A has shape (m, k) = (16, 16)
      • B has shape (n, k) = (8, 16)
      • C and D have shape (m, n) = (16, 8)
    • This means that we have
      • 4 fragments of A = 8 elements / 4 registers per thread
      • 2 fragments of B = 4 elements / 2 registers per thread
      • 2 fragments of C = 4 elements / 4 registers per thread (FP32 accumulator, so each element needs its own register)
      • 2 fragments of D = 4 elements / 4 registers per thread (same as C)
  • .row.col means A is in row-major format and B is in column-major format
  • .f32.f16.f16.f32 are the data types of D, A, B, C respectively.

The d, a, b, c components expand to {d1_1, d1_2, d2_1, d2_2}, {a1_1, a2_1, a3_1, a4_1}, {b1_1, b2_1}, {c1_1, c1_2, c2_1, c2_2}, where each variable represents a register named {matrix}{fragment}_{register #}.
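
As a sketch, this is how the instruction is typically issued from CUDA C++ via inline PTX; the wrapper name and the packing of pairs of 16-bit values into 32-bit registers are assumptions here, not code from this series:

#include <cstdint>

// One m16n8k16 tile: d = a @ b^T + c, fp16 inputs, fp32 accumulate.
// a: 4 registers (8 fp16 values), b: 2 registers (4 fp16 values), c/d: 4 fp32 values.
__device__ void mma_m16n8k16_f16f32(float d[4], const uint32_t a[4],
                                    const uint32_t b[2], const float c[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}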

cp.async vs Traditional Loads

You might be wondering: aren’t “normal” loads still asynchronous? If they weren’t, wouldn’t GPUs be very slow? Yes, you’d be right: the GPU won’t initiate a load and then sit there idly waiting for it to complete. Instead, it’ll initiate the load and continue executing until it encounters an instruction that depends on the load. Only then will it block. Therefore, with enough instruction-level parallelism (ILP), we can hide the latency.

So what’s the difference then? The way standard loads copy data from gmem to smem is:

  • The data is first copied from gmem to the register file
  • Then from the register file it gets copied to smem

cp.async, on the other hand, completely bypasses the RF and copies directly into smem. In addition, when our accesses are 16 bytes and aligned, we can bypass the L1 cache with the .cg modifier, which reduces cache pollution. When using the other modifier .ca, L1 will not be bypassed.

As we’ll see later on, our register budget will be fairly constrained, so this can alleviate some of that pressure. cp.async also includes the benefits of checkpointing, as discussed earlier.

The typical path when loading data from gmem on a cache miss is gmem -> L2 -> L1 -> registers -> smem. However, with optimized async loads, it looks like gmem -> L2 -> smem. Even if we were to bypass the L1 using cache hints, we would still have to load through L2 -> registers -> smem.


Source

If we take a look at the corresponding SASS,

standard_copy_example
    for (int i = 0; i < N_COPIES; ++i) {
        smem[i] = gmem[i];
    }

compiles down to

 // Copies from GMEM -> RF via L2 & L1 cache
 LDG.E.128 R4, [R2.64] 
 LDG.E.128 R8, [R2.64+0x10] 
 LDG.E.128 R12, [R2.64+0x20] 
 LDG.E.128 R16, [R2.64+0x30] 
 
 // Copies from RF -> SMEM
 STS.128 [RZ], R4 
 STS.128 [0x10], R8 
 STS.128 [0x20], R12 
 STS.128 [0x30], R16 

while

cp_async_example
    for (int i = 0; i < 4; ++i) {
        cp_async<16>(smem + 16 * i, gmem);
    }
 
    cp_async_commit();
    cp_async_wait<0>();
 

compiles down to

 LDGSTS.E.BYPASS.LTC128B.128 [RZ], [R2.64] 
 LDGSTS.E.BYPASS.LTC128B.128 [0x10], [R2.64] 
 LDGSTS.E.BYPASS.LTC128B.128 [0x20], [R2.64] 
 LDGSTS.E.BYPASS.LTC128B.128 [0x30], [R2.64] 
 LDGDEPBAR          // commit
 DEPBAR.LE SB0, 0x0 // wait 0

Committing

Another benefit of cp.async compared to standard loads is related to committing.
With commit checkpointing, cp.async has a clean way of grouping multiple loads together into distinct operations that can be waited on separately. This allows us to, for instance:

  • Separate the loads for K and V into different groups, but still initiate them at the same time
  • Only wait on the K transfer while V is still in flight

When using standard loads, we’d need to either wait on both, or delay the load for V until after K has been loaded. A sketch of this pattern is shown below.
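
Using the cp_async helpers from the SASS example above (pointer and size names here are placeholders), the pattern looks roughly like this:

// Issue the K loads and commit them as one group.
for (int i = 0; i < K_TILE_BYTES / 16; ++i)
    cp_async<16>(smem_K + 16 * i, gmem_K + 16 * i);
cp_async_commit();

// Issue the V loads right away and commit them as a second group.
for (int i = 0; i < V_TILE_BYTES / 16; ++i)
    cp_async<16>(smem_V + 16 * i, gmem_V + 16 * i);
cp_async_commit();

cp_async_wait<1>();  // at most one group still in flight: K is ready, V may not be
// ... compute QK^T with the K tile while V is still loading ...
cp_async_wait<0>();  // now V is ready as well
// ... compute PV ...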

wmma API

Compared to mma, wmma is a higher-level API that handles loading and writing values between smem and rmem, where the layout of the matrix fragments is opaque. This means we don’t know which threads in the warp contain which components of the matrix, or which registers contain which values.[1] As a result, we can’t operate on the registers directly and must read and write from shared memory instead. There are two downsides to this:

  1. We’ll need to reserve extra shared memory for m and l (to perform the row-wise operations); since shared memory is already limited, this may slow our kernel down, either through lower occupancy or because more optimal block-size configurations become infeasible.

  2. Furthermore, we will incur extra reads and writes to shared memory when we:

    • Perform the softmax (read and write S)
    • Compute PV and rescale O (read P and read/write O)

This results in higher latency and increased pressure on the MIO. This is the unit in each SM partition that handles requests for shared memory and various other pipelines like special math functions.

mma gives us control over the matrix fragment layouts so we can directly operate on the output values in registers and avoid unnecessary SMEM operations.
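
For contrast, here is a minimal wmma sketch for a single 16x16x16 tile (row-major A, column-major B, and leading dimensions of 16 are assumptions); note that the accumulator can only be inspected after a round trip through shared memory:

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

__device__ void wmma_tile(const half* a_smem, const half* b_smem, float* c_smem) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, a_smem, 16);  // 16 = leading dimension in elements
    wmma::load_matrix_sync(b_frag, b_smem, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

    // The thread-to-element mapping of c_frag is opaque, so row-wise softmax
    // statistics can't be computed from registers directly; the tile has to be
    // spilled to SMEM and re-read.
    wmma::store_matrix_sync(c_smem, c_frag, 16, wmma::mem_row_major);
}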

Shapes

Generalized Tensor Shapes

element size (bytes) | GMEM+SMEM Majorness | GMEM↔SMEM Shape (Elements, pre kernel 9) | SMEM Shape (SMEM → RMEM, Elements) | RMEM Majorness | RMEM Shape (Registers) | mma matrix variable
2 | Row major | | | Row major | | A
2 | Row major | | | Row major | | B
2 | Row major | | | Column major | | B
4 | N/A | N/A | N/A | Row major | | C/D
2 | N/A | N/A | N/A | Row major | | A
4 | N/A | N/A | N/A | Row major | | C/D
2 | Row major | | | Row major | | N/A

LD/ST Operation Table

From | To | Blocks | PTX Instr. / C++ | Warp-Wide Op Size | Thr. Op Size | Thr. ID Mapping Order | Register Shape | Notes
GMEM | SMEM | | cp.async | | | Row-major | |
SMEM | RF | | ldmatrix.x4 | | | Column-major transpose | |
RF | SMEM | | standard (4B) | | | Row-major | |
SMEM | GMEM | | standard (16B) | | | Row-major | |

mma.m16n8k16 Instruction Shapes

Operand | DType | Shape (Variables) | Shape (Elements) | Shape (Fragments) | Shape (Registers)
A   | BF16/FP16 | (m, k) | (16, 16) | (2, 2) | (2, 2)
B   | BF16/FP16 | (n, k) | (8, 16)  | (1, 2) | (1, 2)
C+D | FP32      | (m, n) | (16, 8)  | (2, 1) | (2, 2)

Warp-Wide Thread to (row, col) Mapping Kernels 1-8

Operation                | Row            | Column
mma fragment / RF → SMEM | (tid % 32) / 4 | (tid % 4) * 2
SMEM → RF (ldmatrix)     | tid % 16       | ((tid % 32) / 16) * 8
GMEM ↔ SMEM              | (tid % 32) / 8 | tid % 8

Warp-Wide Thread to (row, col) Mapping Kernels 9-16

Let lid = threadIdx.x % 32;

Operation                | Tensors Affected | Row                              | Column
mma fragment / RF → SMEM |                  | lid / 4                          | (lid % 4) * 2
SMEM → RF (ldmatrix)     |                  | lid % 16                         | (lid / 16) * 8
SMEM → RF (ldmatrix)     |                  | (lid % 8) + 8 * (lid % 32) / 16  | lid % 8
GMEM ↔ SMEM              | all              | lid / 8                          | lid % 8

Kernel 1-6 Tensor Shapes

Tensor Shapes

element size (bytes) | GMEM+SMEM Majorness | GMEM↔SMEM Shape | SMEM Shape (SMEM → RF) | RF Majorness | RF Shape (Registers)
2 | Row major | | | Row major |
2 | Row major | | | Row major |
2 | Row major | | | Column major |
4 | N/A | N/A | N/A | Row major |
2 | N/A | N/A | N/A | Row major |
4 | N/A | N/A | N/A | Row major |
2 | Row major | | | Row major |

Tensor RF Storage Shapes

Tensor | element size (bytes) | mma matrix variable | Storage Shape | Register Count
       | 2 | A   | | 32
       | 4 | C/D | | 64
       | 2 | B   | | 128
       | 2 | B   | | 128
       | 4 | C/D | | 32
       | 2 | A   | | 16
       | 4 |     | | 2
       | 4 |     | | 2

GEMM Shapes

A | A Shape (Registers) | B | B Shape (Registers) | Iteration Shape (k, m, n)

Footnotes

  1. The layout has been reverse engineered, but it’s unclear how reliable it is for production code.