In Part 2, we explored the fundamental CUDA building blocks: tensor core operations (mma) and efficient memory transfers (cp.async and ldmatrix). Now we're ready to assemble these pieces into our first complete Flash Attention kernel. This base implementation will serve as our foundation, giving us a working kernel that we'll iteratively optimize throughout the series.
We’ll tackle this in three main steps: first, we’ll figure out how to split up the work between different levels (CTAs, warps, and threads). Then we’ll build higher-level operations from our basic instructions. Finally, we’ll assemble everything into a complete kernel.
Here’s what we’re building towards:
- forward pass only
- non-causal attention
- head dimension = 128
- no dropout or KV caching
- equal query/key/value sequence lengths
- sequence lengths divisible by the block sizes (64-128 in our implementation, in line with the block sizes used in the paper)
- 16-bit (bf16/fp16) input and output tensors, softmax calculation in fp32
By the end of this part, our kernel will reach about 49% of the performance of the official implementation on the RTX 3090.
Kernel Architecture Overview
Before diving into implementation details, let’s establish the overall structure of our kernel. Flash Attention kernels follow a standard three-phase pattern that we’ll use throughout this series:
Kernel Phases
- Prologue: Setup and initialization
  - Calculate memory addresses for each tensor and warp
  - Load Q from GMEM → SMEM → RF (happens once)
  - Initialize the softmax statistics (m and ℓ) and the output accumulator
- Mainloop: Iterative attention computation (repeated for each K/V block)
  - Load K: GMEM → SMEM → RF
  - Compute attention scores: S = QKᵀ
  - Apply online softmax and update statistics
  - Load V: GMEM → SMEM → RF
  - Compute output contribution: O += P·V
- Epilogue: Finalization and output
  - Complete softmax normalization
  - Convert O from fp32 to fp16/bf16
  - Write final output: RF → SMEM → GMEM
Implementation Challenges
With our kernel structure established, we need to solve three main technical challenges:
- Data Movement: Getting tensors efficiently through the memory hierarchy (GMEM → SMEM → RF) while handling different layouts, access patterns, and synchronization requirements
- Mathematical Operations: Implementing GEMM operations and online softmax using our mma instructions and warp primitives
- Synchronization: Coordinating between threads and warps to avoid race conditions
The bulk of our complexity comes from data movement, since each tensor has its own layout and access-pattern requirements.
Let’s tackle these one by one, starting with the fundamental decisions that will shape our entire implementation.
Kernel Configuration
Now that we understand the overall structure, let’s nail down the specific parameters that will define our kernel’s behavior.
From now through kernel 7, we'll fix our block sizes at B_r = B_c = 64. They're multiples of the m16n8k16 instruction shape, and they're small enough that the tiles fit snugly into SMEM and the RF.
We’ll also use 128 threads (4 warps) per CTA, following CUDA’s recommended starting point.
In kernel 7, we'll benchmark the kernel with different configurations to find the most performant one.
CTA Work Distribution
We need to think about how to distribute work across three levels: CTAs, warps, and threads. Let’s start with the highest level: how we’ll distribute work across CTAs in our grid.
Our Flash Attention kernel needs to process tensors shaped like (batch_size, seq_len, n_heads, d_head)
by tiling them into smaller blocks that fit into SM memory. Given our configuration, we’ll partition the tensors as follows:
- Q and O: split into (B_r, d_head) = (64, 128) tiles, denoted Q_i and O_i
- K and V: split into (B_c, d_head) = (64, 128) tiles, denoted K_j and V_j
Each CTA is assigned one query block Q_i of one specific (sample, head) pair. Processing it involves loading and computing against every K_j/V_j block of that same (sample, head) pair, where j runs over all seq_len / B_c key/value blocks.
This gives us n_samples * n_heads * T_r total query blocks to process, where T_r = seq_len / B_r. We'll launch exactly that many CTAs to handle all the work.
Kernel Arguments
Here are the arguments our kernel will receive:
struct FAForwardArgs {
using index_t = int64_t;
void *__restrict__ Q;
void *__restrict__ K;
void *__restrict__ V;
void *__restrict__ O;
const index_t batch_stride;
const index_t seq_stride;
const index_t head_stride;
const index_t seq_len;
const index_t n_heads;
const int n_Q_blocks;
const int n_KV_blocks;
};
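To make the strides concrete, here's a hedged sketch of how these arguments could be filled for a contiguous (batch_size, seq_len, n_heads, d_head) tensor; the stride values follow from that layout, but the helper itself (and its name) is illustrative rather than the series' actual launcher code:
FAForwardArgs make_args(void *Q, void *K, void *V, void *O,
                        int64_t seq_len, int64_t n_heads, int64_t d_head,
                        int B_r, int B_c) {
  // Strides are in elements, matching a contiguous (B, S, H, D) layout.
  return FAForwardArgs{
      Q, K, V, O,
      /*batch_stride=*/seq_len * n_heads * d_head,
      /*seq_stride=*/n_heads * d_head,
      /*head_stride=*/d_head,
      /*seq_len=*/seq_len,
      /*n_heads=*/n_heads,
      /*n_Q_blocks=*/static_cast<int>(seq_len) / B_r,
      /*n_KV_blocks=*/static_cast<int>(seq_len) / B_c};
}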
Grid Mapping
Each CTA handles a (sample, head)
pair, so we’ll set the kernel grid to be some permutation of the shape (sample, query_block, head)
. How should we map CTAs to query blocks?
Let's consider a single sample for a particular head. The CTAs for that sample each read a different block of Q, but they all read the exact same K and V blocks.
Here's the key insight: we want CTAs processing the same sample and head to run around the same time so they can share cached K/V data. When the first CTA loads a K/V block from GMEM, it lands in the L2 cache; CTAs running concurrently can then read it from L2 instead of going back out to GMEM.
CTAs launch in order of their ID: blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z
Since CTAs with the same y
and z
but different x
have consecutive IDs, they launch together. So we map query blocks to the x
dimension: (x, y, z) -> (Q_block, head, batch)
.
If our mapping were (x, y, z) -> (head, Q_block, batch) instead, concurrently running CTAs would mostly be working on different heads (and therefore different K/V blocks), and our L2 hit rate would drop significantly.
So we’ll set up our variables like this:
// ...
const int sample = blockIdx.z;
const int head = blockIdx.y;
const int q_seq_block = blockIdx.x;
// ...
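For completeness, a hypothetical host-side launch that matches this mapping might look like the following (the Kernel template type, batch_size, and smem_bytes are assumptions; the real launch code isn't shown in this part):
// One CTA per (query block, head, sample), with query blocks on the x axis so
// CTAs sharing K/V data launch with consecutive IDs.
dim3 grid(args.n_Q_blocks, args.n_heads, batch_size);
dim3 block(128);  // 4 warps per CTA
flash_forward_kernel<Kernel><<<grid, block, smem_bytes>>>(args);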
How does the grid mapping affect performance? Each K/V block is read by every CTA working on the same (sample, head) pair, so the L2 hit rate depends on which of those CTAs are resident at the same time. Here's how the grid mapping affects cache performance on our Flash Attention workload:
GPU | L2 Cache Size | Unoptimized Hit Rate | Optimized Hit Rate | Performance Impact |
---|---|---|---|---|
RTX 3090 | 6MB | ~2% | ~98% | ~3% performance hit |
A100 | 40MB | ~25.6% | ~92.6% | ~1% performance hit |
The impact is modest, but since getting the mapping right is essentially free, it’s worth implementing.
Configuration Template Parameters
Even though we're using a fixed configuration for now, we want to make it easy to test different configurations later (which we'll do in kernel 7). We'll add a struct template parameter to support different block sizes, thread counts per CTA, and data types. This is the template parameter for the kernel function:
struct FlashForwardKernelConfig {
// This gets statically converted into either half or nv_bfloat16
const torch::ScalarType dtype;
const int d_head; // [64, 128]
const int B_r; // [64, 128]
const int B_c; // [32, 64, 128]
const int n_warps; // [4, 8]. 8 only when B_r = 128
};
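For reference, a hypothetical instantiation matching the configuration we just described (bf16, d_head = 128, B_r = B_c = 64, 4 warps) could look like this; the constant's name and spelling are assumptions:
constexpr FlashForwardKernelConfig kBaseConfig{
    /*dtype=*/torch::kBFloat16,
    /*d_head=*/128,
    /*B_r=*/64,
    /*B_c=*/64,
    /*n_warps=*/4};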
Warp-Level Work Distribution
Now that we've established how CTAs distribute work, let's zoom in to the warp level. From here on, we'll focus on warp-to-CTA and thread-to-warp interactions, so we'll drop the sample, head, and block indices and talk about a single block.
Remember from Part 2 that tensor operations must execute in lockstep across all threads in a warp. Our mma
instruction operates on 16×16 tiles, so we’ll use this as the basis for dividing work between warps.
Here’s how the math works out: we have 64-row blocks and 4 warps per CTA, so each warp gets 64÷4 = 16 rows. Combined with our 128-column head dimension, each warp handles (16, 128) sub-tiles.
We'll treat how we load Q and O differently from how we load K and V.
Why the different strategies? Each warp's 16-row slice of the output depends only on its own 16 rows of Q, but on the entire K and V blocks, so Q and O can be handled per warp while K and V have to be shared across the CTA.
Q and O Processing (Independent per Warp)
- The 64-row tile is split into 4 independent (16, 128) slices
- Each warp operates completely independently on its slice for:
  - Loading/storing operations
  - GEMM computations
K and V Processing (Cooperative across Warps)
- Loading phase: Each warp loads its own (16, 128) slice from GMEM → SMEM
- Synchronization: All warps wait for the complete (64, 128) block to be loaded
- Copy phase: Each warp independently copies the entire block from SMEM → RF
Per-Warp Workload Summary
Each of the 4 warps in a CTA processes:
Independent Operations (per warp):
- Load its (16, 128) slice of Q: GMEM → SMEM → RF
- Compute S = QKᵀ, with shape (16, 64), in RF
- Compute the online softmax of S (producing P) in RF
- Compute O += P·V in RF
- Store its (16, 128) slice of O: accumulated in RF, then RF → SMEM → GMEM
Cooperative Operations (across all warps):
- Load the (64, 128) blocks of K and V: cooperatively GMEM → SMEM, then independently SMEM → RF
The diagram below shows how each warp handles its slice of the work:
Every tile with a color background is stored in the RF.
Future optimization
It's actually slightly more efficient for warps to cooperate when loading the Q and O blocks too. We'll cover this later in part 5 (kernel 9).
Data Movement
Now that we understand the work distribution strategy, let’s tackle the implementation. Here’s where things get interesting - and complex. Moving data efficiently turns out to be the hairiest part of writing this kernel. We’re dealing with different layouts, access patterns, and synchronization needs all at once.
Each tensor has different requirements, which we'll summarize in the tables below.
Our strategy is to build this complexity in layers, from low-level memory operations up to a clean abstraction that hides all the mess:
- Core memory operations: Generic functions for GMEM ⟷ SMEM transfers and specialized SMEM → RF functions that handle transposition
- Address management: Calculate the right pointers for each tensor and warp
- Unified interface: A MatrixLDST class that wraps everything into a simple, tensor-specific API
Let’s build this step by step.
Configuration Structs
To efficiently move data through the memory hierarchy (GMEM → SMEM → RF), we need operations that can handle different tensor requirements. Here are the configuration structs that encode all the LD/ST requirements:
struct TileLayout {
const int row_fragments;
const int col_fragments;
};
// constexpr non-type template parameter containing parameters for LD/ST for a
// block (Q, K, V, or O) from GMEM to SMEM and vice versa, and also loading from
// SMEM to the RF.
struct TensorLDSTConfig {
// Tile layout for shared memory and RF.
const TileLayout GSM;
const TileLayout RF;
// Block specific properties
const bool transposed;
const int block_size;
const int smem_cols;
  // # of rows a warp in a thread-block independently loads/stores. It is
  // equivalent to GSM.row_fragments * 8.
  const int warp_ldst_rows;
  // Whether or not the warp will compute over the entire block.
  // This is false for Q, O, and S, and true for K and V.
  const bool compute_over_entire_block;
};
Storage Layout
We use different storage layouts for each memory level and tensor:
- SMEM: All tensors stored row-major for efficient loading from GMEM
- RF: Most tensors stay row-major, but V is stored transposed (column-major)
Here’s a table to help you keep in mind what we need to transfer and the shapes at each memory level:
Tensor | Element size (bytes) | GMEM+SMEM Majorness | GMEM↔SMEM Shape (per warp) | SMEM Shape (SMEM → RF) | RF Majorness | RF Shape (Registers) |
---|---|---|---|---|---|---|
Q | 2 | Row major | (16, 128) | (64, 128) | Row major | (16, 128) |
K | 2 | Row major | (16, 128) | (64, 128) | Row major | (64, 128) |
V | 2 | Row major | (16, 128) | (64, 128) | Column major | (64, 128) |
S (accum) | 4 | N/A | N/A | N/A | Row major | (16, 64) |
P | 2 | N/A | N/A | N/A | Row major | (16, 64) |
O (accum) | 4 | N/A | N/A | N/A | Row major | (16, 128) |
O | 2 | Row major | (16, 128) | (64, 128) | Row major | (16, 128) |
And here are the LD/ST operations we’ll use to transfer these.
From | To | Blocks | PTX Instr. / C++ | Thr. ID Mapping Order | Notes |
---|---|---|---|---|---|
GMEM | SMEM | Q, K, V | cp.async | Row-major | 16B per thread |
SMEM | RF | Q, K, V | ldmatrix.x4 | Column-major | transposed variant used for V |
RF | SMEM | O | standard store (4B) | Row-major | |
SMEM | GMEM | O | standard store (16B) | Row-major | |
Copying Between GMEM ⟷ SMEM
Now let's implement the actual data movement. Recall that each warp-wide operation accesses a (4, 64) chunk, and our per-warp tile size is (16, 128), so we need a (4, 2) grid of warp-wide cp.async() operations to fully copy each block.
Here’s the code for copying between GMEM and SMEM. The key design choices:
- Bidirectional: Works for both GMEM→SMEM and SMEM→GMEM by passing different operation functors
- Thread coordination: Each thread calculates its offset in row-major order to avoid conflicts
- Template-based: Uses the TensorLDSTConfig template parameter to handle different tensor layouts
The operation functors define the actual copy behavior. The following implementation handles both GMEM→SMEM and SMEM→GMEM transfers using template-based operation functors:
#define ROWS_PER_FRAGMENT 8
#define COLS_PER_FRAGMENT 8
#define GSM_LDST_ROWS_PER_ITER 4
#define BYTES_PER_VEC4_ACCESS 16
template <typename T>
struct GM2SM_async {
__device__ constexpr void operator()(T *gmem, T *smem) {
cp_async<BYTES_PER_VEC4_ACCESS>(smem, gmem);
}
};
template <typename T>
struct SM2GM {
__device__ constexpr void operator()(T *gmem, T *smem) {
reinterpret_cast<uint4 *>(gmem)[0] = reinterpret_cast<uint4 *>(smem)[0];
}
};
template <typename op, /* either GM2SM_async or SM2GM */
TensorLDSTConfig CFG, typename value_t, typename index_t = int64_t>
__forceinline__ __device__ constexpr void copy_block_GSM(value_t *gmem, value_t *smem,
index_t gmem_seq_stride,
const int lane_id) {
constexpr int n_row_iters =
CFG.GSM.row_fragments * ROWS_PER_FRAGMENT / GSM_LDST_ROWS_PER_ITER;
constexpr int col_fragments_per_iter = WARP_SIZE / GSM_LDST_ROWS_PER_ITER;
constexpr int col_fragments_per_row = CFG.smem_cols / COLS_PER_FRAGMENT;
const int thread_row = lane_id / col_fragments_per_iter;
const int thread_col_fragment = lane_id % col_fragments_per_iter;
#pragma unroll
for (int r = 0; r < n_row_iters; ++r) {
const int cur_row = r * GSM_LDST_ROWS_PER_ITER + thread_row;
#pragma unroll
for (int c = 0; c < col_fragments_per_row;
c += col_fragments_per_iter) {
const int col_fragment = c + thread_col_fragment;
op()(&gmem[cur_row * gmem_seq_stride +
col_fragment * COLS_PER_FRAGMENT],
&smem[cur_row * CFG.smem_cols +
col_fragment * COLS_PER_FRAGMENT]);
}
}
}
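As a usage sketch, loading a warp's slice of Q into SMEM could look like the call below, assuming a constexpr Q_cfg TensorLDSTConfig and warp-offset pointers (both hypothetical here; the MatrixLDST class later in this part computes them for real):
copy_block_GSM<GM2SM_async<nv_bfloat16>, Q_cfg>(
    gmem_Q_warp, smem_Q_warp, gmem_seq_stride, lane_id);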
A Note on reinterpret_cast
You'll notice reinterpret_cast is used frequently. This is a common and safe pattern in high-performance CUDA for vectorized memory access. By re-interpreting pointers to larger types like uint4, we can load or store 16 bytes in a single instruction, provided the memory is correctly aligned, which it is in our case.
Additional Requirements for cp.async
When using cp.async(), we need to commit the copies with cp_async_commit() and wait on them with cp_async_wait<N>() (wrappers around cp.async.commit_group and cp.async.wait_group). We'll cover the synchronization details in Synchronization.
SMEM → RF Operations
You might recall from How Fragments Are Laid Out Across Threads that V is stored row-major (sequence, d_head) in GMEM and SMEM, but the second GEMM consumes it as a B operand whose reduction (k) dimension is the sequence dimension.
This layout difference requires transposing during the SMEM → RF copy: SMEM (row, col) gets transposed into RF (col, row). The transposition affects both the fragment arrangement and thread-level element storage, which is why we call ldmatrix_transpose() instead of ldmatrix() for V.
An iteration for V therefore loads its fragments transposed, while an equivalent iteration for Q and K loads them as-is. Here's the non-transposed version first:
#define ROWS_PER_FRAGMENT 8
#define COLS_PER_FRAGMENT 8
#define ELEMS_PER_VEC4_ACCESS 8
template <TensorLDSTConfig CFG, typename value_t>
__forceinline__ __device__ constexpr void copy_warp_fragment_SM2RF(
uint32_t (®s)[CFG.RF.row_fragments][CFG.RF.col_fragments], value_t *smem,
const int lane_id, const int col_fragment_offset = 0) {
constexpr int row_fragments_per_iter = 2;
constexpr int rows_per_iter = ROWS_PER_FRAGMENT * row_fragments_per_iter;
constexpr int col_fragments = CFG.smem_cols / ELEMS_PER_VEC4_ACCESS;
constexpr int col_fragments_per_iter = WARP_SIZE / rows_per_iter;
const int thread_row = lane_id % rows_per_iter;
const int thread_col_fragment = lane_id / rows_per_iter;
#pragma unroll
for (int r = 0; r < CFG.RF.row_fragments; r += row_fragments_per_iter) {
const int cur_row = thread_row + r * ROWS_PER_FRAGMENT;
#pragma unroll
for (int c = 0; c < CFG.RF.col_fragments; c += col_fragments_per_iter) {
const int smem_col_fragment = thread_col_fragment + c + col_fragment_offset;
ldmatrix_x4(&smem[cur_row * CFG.smem_cols +
smem_col_fragment * ELEMS_PER_VEC4_ACCESS],
regs[r][c], regs[r + 1][c], regs[r][c + 1],
regs[r + 1][c + 1]);
}
}
}
The transposed variant for V differs in a few ways:
- we loop over SMEM in row → col order, and vice versa for the RF
- instead of swapping the SMEM pointers to transpose the fragments, we swap the RF indices
- this allows the kernel to share the same SMEM offset calculations with the other tiles
template <TensorLDSTConfig CFG, typename value_t>
__forceinline__ __device__ constexpr void copy_warp_fragment_transposed_SM2RF(
uint32_t (®s)[CFG.RF.row_fragments][CFG.RF.col_fragments], value_t *smem,
const int lane_id, const int row_fragment_offset = 0) {
constexpr int row_fragments_per_iter = 2;
constexpr int rows_per_iter = ROWS_PER_FRAGMENT * row_fragments_per_iter;
constexpr int col_fragments = CFG.smem_cols / ELEMS_PER_VEC4_ACCESS;
constexpr int col_fragments_per_iter = WARP_SIZE / rows_per_iter;
const int thread_row = lane_id % rows_per_iter;
const int thread_col_fragment = lane_id / rows_per_iter;
#pragma unroll
for (int r = 0; r < CFG.RF.col_fragments; r += row_fragments_per_iter) {
const int cur_row =
thread_row + (r + row_fragment_offset) * ROWS_PER_FRAGMENT;
#pragma unroll
for (int c = 0; c < CFG.RF.row_fragments; c += col_fragments_per_iter) {
const int smem_col_fragment = thread_col_fragment + c;
ldmatrix_x4_transpose(
&smem[cur_row * CFG.smem_cols +
smem_col_fragment * ELEMS_PER_VEC4_ACCESS],
regs[c][r], regs[c][r + 1], regs[c + 1][r], regs[c + 1][r + 1]);
}
}
}
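For reference, the ldmatrix_x4 / ldmatrix_x4_transpose helpers used above were introduced in Part 2; a minimal sketch of what such wrappers could look like is below (an assumption, not necessarily the series' exact code):
__device__ __forceinline__ void ldmatrix_x4(const void *smem_src, uint32_t &r0,
                                            uint32_t &r1, uint32_t &r2,
                                            uint32_t &r3) {
  const uint32_t addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_src));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
      : "r"(addr));
}

__device__ __forceinline__ void ldmatrix_x4_transpose(const void *smem_src,
                                                      uint32_t &r0, uint32_t &r1,
                                                      uint32_t &r2, uint32_t &r3) {
  const uint32_t addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_src));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
      : "r"(addr));
}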
Copying from RF → SMEM
The only block we copy from RF → SMEM is O. Since there is no store counterpart to ldmatrix, we use standard 4B smem[dst] = rf[src]; stores.
Each iteration of the loop stores a single 8×8 fragment, one 32-bit register (two 16-bit elements) per thread:
template <TensorLDSTConfig CFG, typename value_t>
__forceinline__ __device__ constexpr void copy_warp_fragment_RF2SM(
uint32_t (®s)[CFG.RF.row_fragments][CFG.RF.col_fragments], value_t *smem,
const int lane_id) {
constexpr int rows_per_iter = ROWS_PER_FRAGMENT;
constexpr int col_fragments_per_iter = 1;
constexpr int col_fragments = CFG.smem_cols / ELEMS_PER_VEC4_ACCESS;
constexpr int elems_per_store = 2;
const int thread_row = lane_id / 4;
const int thread_inner_col = (lane_id % 4) * elems_per_store;
#pragma unroll
for (int r = 0; r < CFG.RF.row_fragments; ++r) {
const int cur_row = thread_row + r * rows_per_iter;
#pragma unroll
for (int c = 0; c < CFG.RF.col_fragments; c += col_fragments_per_iter) {
const int smem_col_fragment = c;
reinterpret_cast<uint32_t *>(
&smem[cur_row * CFG.smem_cols +
(smem_col_fragment * ELEMS_PER_VEC4_ACCESS +
thread_inner_col)])[0] = regs[r][c];
}
}
}
GMEM Address Calculations for Tensors
The address calculations for each block are fairly straightforward. This code calculates the pointers into the blocks for a particular sample, head, and query block. Since the warp-specific addresses differ between tensors, we'll put that logic in a tensor class that wraps all the operations we need.
const int sample = ...;
const int head = ...;
const int q_seq_block = blockIdx.x;
// ....
using value_t = nv_bfloat16; // or half
const index_t gmem_seq_stride = args.seq_stride;
const index_t sample_head_offset =
sample * args.batch_stride + head * args.head_stride;
// We only read/write one block for Q and O.
// These offsets are the same for the whole thread-block.
const index_t QO_gmem_block_offset =
sample_head_offset + q_seq_block * 64 * gmem_seq_stride;
// We read the entire key sequence.
const index_t KV_gmem_block_offset = sample_head_offset;
value_t *gmem_Q = &static_cast<value_t *>(args.Q)[QO_gmem_block_offset];
value_t *gmem_O = &static_cast<value_t *>(args.O)[QO_gmem_block_offset];
value_t *gmem_K = &static_cast<value_t *>(args.K)[KV_gmem_block_offset];
value_t *gmem_V = &static_cast<value_t *>(args.V)[KV_gmem_block_offset];
Tensor Abstraction Layer
Matrix LDST Class
The MatrixLDST
class wraps all load and store operations, spanning all levels of the memory hierarchy. This abstraction handles the complexity of different tensor layouts and access patterns in one place.
Key features:
- Unified interface for GMEM ⟷ SMEM ⟷ RF operations
- Warp-specific address calculations
- Support for both independent and cooperative loading patterns
- Automatic handling of transposed layouts (for V)
template <TensorLDSTConfig ldst, typename value_t, typename index_t = int64_t>
struct MatrixLDST {
// Static configuration
using matrix_storage_t =
RFMatrix<value_t, ldst.mma_load_stages, ldst.RF.row_fragments,
ldst.RF.col_fragments>;
using GM2SM_op = std::conditional_t<ldst.Common.async_copy,
GM2SM_async<value_t>, GM2SM<value_t>>;
using SM2GM_op = SM2GM<value_t>;
static constexpr int mma_load_stages = ldst.mma_load_stages;
static constexpr bool load_entire_block_into_rf =
ldst.load_entire_block_into_rf;
static constexpr bool transposed = ldst.transposed;
// Runtime properties
value_t *gmem_ptr;
index_t gmem_seq_stride;
// The location in memory used to load fragments from SMEM to RF.
value_t *smem_srm_ptr;
// The location in memory that the warp writes to for Q, K, V from GMEM to
// smem and O for SMEM to GMEM.
value_t *smem_gsm_ptr;
const int lane_id;
matrix_storage_t storage;
__forceinline__ __device__ MatrixLDST(value_t *gmem_block_ptr, index_t _gmem_seq_stride,
value_t *_smem_ptr)
: lane_id(threadIdx.x % WARP_SIZE) {
const int warp_rank = threadIdx.x / WARP_SIZE;
const index_t warp_seq = ldst.warp_ldst_rows * warp_rank;
gmem_seq_stride = _gmem_seq_stride;
gmem_ptr = gmem_block_ptr + warp_seq * gmem_seq_stride;
smem_gsm_ptr = _smem_ptr + warp_seq * ldst.smem_cols;
smem_srm_ptr =
ldst.compute_over_entire_block ? _smem_ptr : smem_gsm_ptr;
}
__forceinline__ __device__ constexpr void zero() { storage.zero(); }
__forceinline__ __device__ constexpr typename matrix_storage_t::storage_t (&data(
const int stage = 0))[matrix_storage_t::rows][matrix_storage_t::cols] {
return storage.data(stage);
}
__forceinline__ __device__ constexpr void advance_gmem_block() {
gmem_ptr += ldst.block_size * gmem_seq_stride;
}
__forceinline__ __device__ constexpr void copy_GM2SM() {
copy_block_GSM<GM2SM_op, ldst>(gmem_ptr, smem_gsm_ptr, gmem_seq_stride,
lane_id);
}
__forceinline__ __device__ constexpr void copy_SM2GM() {
copy_block_GSM<SM2GM_op, ldst>(gmem_ptr, smem_gsm_ptr, gmem_seq_stride,
lane_id);
}
__forceinline__ __device__ constexpr void copy_SM2RF(int stage = 0, int tile_offset = 0) {
if constexpr (!transposed) {
copy_warp_fragment_SM2RF<ldst, value_t>(
storage.data(stage), smem_srm_ptr, lane_id, tile_offset);
} else {
copy_warp_fragment_transposed_SM2RF<ldst, value_t>(
storage.data(stage), smem_srm_ptr, lane_id, tile_offset);
}
}
__forceinline__ __device__ constexpr void copy_RF2SM() {
copy_warp_fragment_RF2SM<ldst, value_t>(data(), smem_srm_ptr, lane_id);
}
};
Register Storage Classes
template <typename value_t, int N>
struct RFVector {
static constexpr int size = N;
value_t regs[N];
__forceinline__ __device__ constexpr value_t &operator[](int idx) { return regs[idx]; }
};
template <typename value_t, int stages, int row_fragments, int col_fragments>
struct RFMatrix {
  using storage_t = std::conditional_t<sizeof(value_t) == 4, float, uint32_t>;
  static constexpr int regs_per_fragment = sizeof(value_t) / 2;
  static constexpr int rows = row_fragments;
  static constexpr int cols = col_fragments * regs_per_fragment;
  // One [rows][cols] register tile per stage (MatrixLDST passes the stage count in).
  storage_t regs[stages][rows][cols];
  __forceinline__ __device__ constexpr storage_t (&data(const int stage = 0))[rows][cols] {
    return regs[stage];
  }
  __forceinline__ __device__ constexpr void zero() {
#pragma unroll
    for (int i = 0; i < stages; ++i) {
#pragma unroll
      for (int j = 0; j < rows; ++j) {
#pragma unroll
        for (int k = 0; k < cols; ++k) {
          regs[i][j][k] = 0;
        }
      }
    }
  }
};
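For example, with our configuration the two accumulators could be declared as follows (assuming the stage-count parameter shown above, which MatrixLDST passes in; the aliases are illustrative):
// S accumulator: fp32, one stage, (16, 64) warp tile -> float regs[1][2][16]
using S_accum_storage_t = RFMatrix<float, /*stages=*/1, /*row_fragments=*/2, /*col_fragments=*/8>;
// O accumulator: fp32, one stage, (16, 128) warp tile -> float regs[1][2][32]
using O_accum_storage_t = RFMatrix<float, /*stages=*/1, /*row_fragments=*/2, /*col_fragments=*/16>;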
Type Conversion
There are a couple of tensor tiles we'll need to convert from 32-bit to 16-bit: the attention matrix P (the exponentiated scores), before it's fed into the second GEMM, and the output accumulator O, before it's written back to GMEM.
template <typename value_t, int M_fragments, int N_fragments>
__forceinline__ __device__ constexpr void
convert_to_16_bit_dtype(float (&src_float)[M_fragments][N_fragments * 2],
uint32_t (&dest_uint)[M_fragments][N_fragments]) {
using value2_t =
std::conditional_t<std::is_same_v<value_t, half>, half2, nv_bfloat162>;
float2(&src)[M_fragments][N_fragments] =
reinterpret_cast<float2(&)[M_fragments][N_fragments]>(src_float);
value2_t(&dest)[M_fragments][N_fragments] =
reinterpret_cast<value2_t(&)[M_fragments][N_fragments]>(dest_uint);
#pragma unroll
for (int m = 0; m < M_fragments; ++m) {
#pragma unroll
for (int n = 0; n < N_fragments; ++n) {
if constexpr (std::is_same_v<value_t, half>) {
dest[m][n] = __float22half2_rn(src[m][n]);
} else {
dest[m][n] = __float22bfloat162_rn(src[m][n]);
}
}
}
}
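As a usage sketch, converting a warp's S accumulator into the 16-bit P operand might look like this (the array shapes assume our configuration; the variable names are illustrative):
float S_accum_regs[2][8 * 2];  // (16, 64) fp32 accumulator fragments for one warp
uint32_t P_regs[2][8];         // packed bf16 pairs, the second GEMM's A operand
convert_to_16_bit_dtype<nv_bfloat16>(S_accum_regs, P_regs);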
Computing Operations
With our data movement machinery in place, we can turn to the mathematical operations. The two main operations we need are GEMM (for S = QKᵀ and O = P·V) and the online softmax. While the GEMMs map directly onto our mma primitives, the online softmax involves some clever thread coordination and statistical updates that are worth walking through carefully.
GEMM Implementation
MMA Operation Overview
A GEMM is composed of multiple mma instructions that together span all the fragments of the warp's tile.
As you might remember from part 2, each individual mma instruction performs D = A·B + C, where:
- A has shape (m, k) = (16, 16)
- B has shape (n, k) = (8, 16)
- C and D have shape (m, n) = (16, 8)
Here are the operand shapes and iteration patterns:
GEMM | A | A Shape (Registers) | B | B Shape (Registers) | Iteration Shape (k, m, n) |
---|---|---|---|---|---|
S = QKᵀ | Q | (2, 16) | K | (8, 16) | (16, 2, 8) |
O += P·V | P | (2, 8) | V | (16, 8) | (8, 2, 16) |
Each GEMM iteration covers (2, 2, 1) fragments in the (k, m, n) dimensions, with operand register shapes of (2, 2) for the A matrix and (1, 2) for B.
Here’s a diagram of a single iteration
Implementation
To calculate the results across all fragments, we’ll create a triple nested loop in K -> M -> N
order.
// Fragment-step constants implied by the m16n8k16 shape and the fp32 accumulator
// layout (defined elsewhere in the real code; the values follow from the shape):
#define N_REGS_PER_F32_ACCUM_FRAGMENT 2
#define MMA_K_FRAGMENTS_PER_ITER 2
#define MMA_M_FRAGMENTS_PER_ITER 2
#define MMA_N_FRAGMENTS_PER_ITER 1
template <typename value_t, const int M_fragments, const int N_fragments,
const int K_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void warp_fragment_mma_f32_accum(
uint32_t (®s_A)[M_fragments][K_fragments],
uint32_t (®s_B)[N_fragments][K_fragments],
accum_t (®s_C)[M_fragments][N_fragments * N_REGS_PER_F32_ACCUM_FRAGMENT]) {
#pragma unroll
for (int k = 0; k < K_fragments; k += MMA_K_FRAGMENTS_PER_ITER) {
#pragma unroll
for (int m = 0; m < M_fragments; m += MMA_M_FRAGMENTS_PER_ITER) {
#pragma unroll
for (int n = 0; n < N_fragments; n += MMA_N_FRAGMENTS_PER_ITER) {
mma_m16n8k16_f32_accum<value_t>(
regs_C[m][n * 2],
regs_C[m][n * 2 + 1],
regs_C[m + 1][n * 2],
regs_C[m + 1][n * 2 + 1],
regs_A[m][k],
regs_A[m + 1][k],
regs_A[m][k + 1],
regs_A[m + 1][k + 1],
regs_B[n][k],
regs_B[n][k + 1],
regs_C[m][n * 2],
regs_C[m][n * 2 + 1],
regs_C[m + 1][n * 2],
regs_C[m + 1][n * 2 + 1]);
}
}
}
}
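As a usage sketch, the S = QKᵀ GEMM with our tile sizes would be a call along these lines (fragment counts follow from a 16-row warp slice, B_c = 64, and d_head = 128; the variable names are illustrative):
uint32_t Q_regs[2][16];     // A operand: (16, 128) bf16 fragments
uint32_t K_regs[8][16];     // B operand: (64, 128) bf16 fragments
float    S_regs[2][8 * 2];  // C/D accumulator: (16, 64) fp32 fragments
warp_fragment_mma_f32_accum<nv_bfloat16, 2, 8, 16>(Q_regs, K_regs, S_regs);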
Softmax
Thread-Level Perspective
For all the operations we covered until now, we viewed each warp as a single cohesive unit and looked at how it cooperates with its sibling warps. Now, we'll look at the softmax operations from the perspective of a thread and how each thread cooperates with its sibling threads in the same warp.
Each warp and thread operates on the data it already stores, with the workload distributed evenly and no extra LD/ST required before computing.
We compute the softmax in 32-bit; it involves both element-wise and row-wise operations.
Initialize Row Statistics
For both the row max (m) and the running sum of exponentials (ℓ), each thread keeps one value per query-row fragment it owns. We initialize m to −∞ and ℓ to 0:
// ...
constexpr accum_t neg_inf = -cuda::std::numeric_limits<float>::infinity();
accum_t m[N::QO_fragments_per_warp];
accum_t l[N::QO_fragments_per_warp];
#pragma unroll
for (int q = 0; q < N::QO_fragments_per_warp; ++q) {
m[q] = neg_inf;
l[q] = 0.0;
}
// ...
Dot-Product Scaling
We'll perform the dot-product scaling of S by 1/√d_head right after the QKᵀ GEMM, before taking the row max and exponentiating:
// ...
const float softmax_scale = rsqrt(static_cast<float>(CFG.d_head));
// ...
template <int QO_fragments, int KV_accum_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void
scale_S_accum(accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
const accum_t &softmax_scale) {
#pragma unroll
for (int q = 0; q < QO_fragments; ++q) {
#pragma unroll
for (int k = 0; k < KV_accum_fragments; ++k) {
S_accum[q][k] *= softmax_scale;
}
}
}
Reductions using Warp Shuffles
Since each thread stores different parts of the fragments, we use warp shuffles to avoid using shared memory to communicate between threads when performing reductions. For each shuffle instruction, a thread can receive and send a value to another thread in the same warp.
Since we want every thread to have the same final value at the end of each reduction, the warp shuffle we use is __shfl_xor_sync(WARP_MASK, val_to_share, xor_offset):
- it sends the calling thread's val_to_share to thread tid ^ xor_offset
- it returns the value shared by that same thread, where tid is the calling thread's ID
We reduce over the values within each thread and then warp shuffle twice to reduce between the threads in the same row quartet.
This is the code for the row max reduction:
// This mask indicates that every thread in the warp participates in the shuffle
#define SHFL_ENTIRE_WARP_MASK 0xffffffff
template <int QO_fragments, int KV_accum_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void
calc_row_max(accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
accum_t (&m_next)[QO_fragments], accum_t (&m_cur)[QO_fragments]) {
#pragma unroll
for (int q = 0; q < QO_fragments; ++q) {
m_next[q] = m_cur[q];
// Calculate max for row across all in-thread registers.
#pragma unroll
for (int k = 0; k < KV_accum_fragments; ++k) {
m_next[q] = max(m_next[q], S_accum[q][k]);
}
// Group reduction
m_next[q] = max(__shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, m_next[q], 2),
m_next[q]);
m_next[q] = max(__shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, m_next[q], 1),
m_next[q]);
}
}
In the reduction, after each call of __shfl_xor_sync(), each thread doubles the amount of information it has (a standalone illustration follows below):
- first, xor with offset 2
- then, xor with offset 1
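As a standalone illustration of the butterfly pattern (not part of the kernel), here's how the four threads of a row quartet all end up with the row max:
// Each of the 4 lanes in a quartet starts with one partial max `v`.
__device__ float quartet_max(float v) {
  v = max(v, __shfl_xor_sync(0xffffffff, v, 2));  // pairs {t, t^2} now agree
  v = max(v, __shfl_xor_sync(0xffffffff, v, 1));  // all four lanes hold the max
  return v;
}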
ℓ and O Rescaling
This function rescales the previous iteration's accumulators for ℓ and O by exp(m_cur − m_next) whenever the running row max increases, and then updates m_cur:
template <int QO_fragments, int d_head_accum_fragments,
typename accum_t = float>
__forceinline__ __device__ constexpr void
scale_l_O(accum_t (&m_next)[QO_fragments], accum_t (&m_cur)[QO_fragments],
accum_t (&l)[QO_fragments],
accum_t (&O_accum)[QO_fragments][d_head_accum_fragments]) {
#pragma unroll
for (int q = 0; q < QO_fragments; ++q) {
accum_t scale;
scale = expf(m_cur[q] - m_next[q]);
m_cur[q] = m_next[q];
l[q] *= scale;
for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
O_accum[q][d_head] *= scale;
}
}
}
Exponentiating
Next, we exponentiate the scaled scores against the updated row max, turning S into P in place:
template <int QO_fragments, int KV_accum_fragments,
typename accum_t = float>
__forceinline__ __device__ constexpr void
scaled_exp(accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
accum_t (&m)[QO_fragments]) {
#pragma unroll
for (int q = 0; q < QO_fragments; ++q) {
#pragma unroll
for (int k = 0; k < KV_accum_fragments; ++k) {
S_accum[q][k] = expf(S_accum[q][k] - m[q]);
}
}
}
Partial Reduction
We row-wise sum the local block of P into ℓ. The structure mirrors calc_row_max, except that we sum instead of taking the max and we skip the shuffles (the cross-thread part of the reduction is deferred to the epilogue).
Note: P_accum has the same underlying storage as S_accum.
template <int QO_fragments, int d_head_accum_fragments,
typename accum_t = float>
__forceinline__ __device__ constexpr void
update_row_exp_sum(accum_t (&P_accum)[QO_fragments][d_head_accum_fragments],
accum_t (&l)[QO_fragments]) {
#pragma unroll
for (int q = 0; q < QO_fragments; ++q) {
#pragma unroll
for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
l[q] += P_accum[q][d_head];
}
}
}
Softmax Epilogue
In the epilogue (after the mainloop), we shuffle the values between threads to get the final value of ℓ for each row, and then divide the output accumulator by it:
template <int QO_fragments, int d_head_accum_fragments,
typename accum_t = float>
__forceinline__ __device__ constexpr void final_softmax_normalization(
accum_t (&O_accum)[QO_fragments][d_head_accum_fragments],
accum_t (&l)[QO_fragments]) {
// Finish summing row_sums across all threads in the same row.
#pragma unroll
for (int q = 0; q < QO_fragments; ++q) {
l[q] += __shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, l[q], 2);
l[q] += __shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, l[q], 1);
}
// Final row-wise O softmax normalization.
#pragma unroll
for (int q = 0; q < QO_fragments; ++q) {
#pragma unroll
for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
O_accum[q][d_head] /= l[q];
}
}
}
Synchronization
We’ve covered data movement and computation, but there’s one more critical piece: making sure threads don’t step on each other. Without proper barriers, we’d have race conditions where threads read stale data or overwrite data that other threads still need.
Required Barriers
Let’s think about what minimal amount of synchronization we need for a single tensor to avoid race conditions. Let’s assume for now that each block gets its own slice of SMEM.
At the warp level,
- each GMEM ↔ SMEM warp-wide operation transfers a (4, 64) tile of elements,
- each SMEM → RF operation (ldmatrix.x4) copies a (16, 16) tile of elements, and
- each RF → SMEM operation copies an (8, 8) tile of elements.
The critical synchronization points occur between these memory transfers. Between the GMEM → SMEM and SMEM → RF operations, threads in the same warp will access values that sibling threads stored. This becomes more complex for K and V, where each warp reads the entire block, including rows that other warps stored.
To handle these dependencies, we need different barrier types depending on the tensor. We require a warp-wide barrier (__syncwarp()) for Q and a CTA-wide barrier (__syncthreads()) for K and V.
Additionally, before any barrier, we need to call cp.async.wait_group (our cp_async_wait) to ensure the data from the async copies has actually landed in SMEM.
Finally, between the RF → SMEM and SMEM → GMEM operations, threads in the same warp will access values that sibling threads stored, so we need a __syncwarp() for the output tensor O.
SMEM to RF Communications
Once the tensors are finally in SMEM, threads will only communicate with other threads in the same warp.
The instructions executed are either warp-synchronous (like ldmatrix, mma, and the warp shuffles) or don't involve communication between threads at all (like the softmax operations, excluding the warp shuffles).
Synchronization Between Loop Iterations
Since we'll copy multiple K and V tiles into the same SMEM slices across iterations, we have to make sure the ldmatrix() operations for all warps have finished executing before overwriting any tensor in SMEM. So, we'll need a __syncthreads() barrier in between iterations.
Summary
Tensor | Barrier Needed In Between GMEM → SMEM & SMEM → RF | Barrier Needed In Between Iterations |
---|---|---|
Q | __syncwarp() | N/A, only copied once |
K | __syncthreads() | __syncthreads() |
V | __syncthreads() | __syncthreads() |
O | __syncwarp() (after RF → SMEM) | N/A, only copied once |
Note: Synchronization Scope for Q and O
Specifying a finer-granularity scope here doesn't actually provide much of a benefit in our case, for 2 reasons:
- At the moment, Q and O are only copied to/from SMEM once, whereas K and V get copied multiple times.
- In the first iteration, we are bottlenecked by the __syncthreads() needed for K anyway. In kernel 8, we will modify the layout in a way where we will also need CTA-wide synchronization here.
Kernel 1: Base Implementation
Now that we’ve covered all the building blocks - data movement, GEMM operations, softmax computations, and synchronization requirements - we can assemble them into our first complete flash attention kernel.
Code Structure
We'll divide the kernel into 3 sections, following standard CUTLASS terminology:
- Prologue: This includes boilerplate setup like initializing the correct memory addresses, and copying Q from GMEM → SMEM → RF.
- Mainloop: This is where most of the logic goes, handling the iterative attention computation.
- Epilogue: We'll end by normalizing the softmax and writing O back to GMEM.
Prologue
The prologue contains a lot of boilerplate setup.
template <typename Kernel>
__global__ void
flash_forward_kernel(__grid_constant__ const FAForwardArgs args) {
using accum_t = float;
using index_t = int64_t;
using value_t = typename Kernel::value_t;
// ...
// We initialize a CTA for each sample, seq tile, and head.
const int sample = blockIdx.z;
const int head = blockIdx.y;
const int q_seq_block = blockIdx.x;
const index_t gmem_seq_stride = args.seq_stride;
const index_t sample_head_offset =
sample * args.batch_stride + head * args.head_stride;
// We only read/write one block for Q and O.
// These offsets are the same for the whole thread-block.
const index_t QO_gmem_block_offset =
sample_head_offset + q_seq_block * CFG.B_r * gmem_seq_stride;
// We read the entire key sequence.
const index_t KV_gmem_block_offset = sample_head_offset;
value_t *gmem_Q = &static_cast<value_t *>(args.Q)[QO_gmem_block_offset];
value_t *gmem_O = &static_cast<value_t *>(args.O)[QO_gmem_block_offset];
value_t *gmem_K = &static_cast<value_t *>(args.K)[KV_gmem_block_offset];
value_t *gmem_V = &static_cast<value_t *>(args.V)[KV_gmem_block_offset];
  extern __shared__ __align__(16) char ch_smem[];
  // Q and O share the same slice of SMEM (Q is fully in the RF before O is
  // written back), while K and V each get their own slice: 3 x 16KiB = 48KiB.
  value_t *smem_Q = reinterpret_cast<value_t *>(ch_smem);
  value_t *smem_O = smem_Q;
  value_t *smem_K = smem_Q + CFG.B_r * CFG.d_head;
  value_t *smem_V = smem_K + CFG.B_c * CFG.d_head;
// MatrixLDST types
Q_t Q(gmem_Q, gmem_seq_stride, smem_Q);
K_t K(gmem_K, gmem_seq_stride, smem_K);
V_t V(gmem_V, gmem_seq_stride, smem_V);
// S is only stored in registers.
S_accum_t S_accum(nullptr, -1, nullptr);
// P is only stored in registers.
P_value_t P_b16(nullptr, -1, nullptr);
// The accumulator for O is only kept in registers. At the end of the kernel, it is then converted into a 16-bit type and then copied into gmem.
O_accum_t O_accum(nullptr, -1, nullptr);
O_value_t O_b16(gmem_O, gmem_seq_stride, smem_O);
// ...
  // Start the async copy of the Q tile.
Q.copy_GM2SM();
cp_async_commit();
O_accum.zero();
// Initialize softmax_scale, m, and l.
const accum_t softmax_scale = rsqrt(static_cast<accum_t>(CFG.d_head));
constexpr accum_t neg_inf = -cuda::std::numeric_limits<float>::infinity();
accum_t m[N::QO_fragments_per_warp];
accum_t l[N::QO_fragments_per_warp];
#pragma unroll
for (int q = 0; q < N::QO_fragments_per_warp; ++q) {
m[q] = neg_inf;
l[q] = 0.0;
}
cp_async_wait<0>();
__syncwarp();
Q.copy_SM2RF();
// ...
Mainloop
The mainloop implements the heart of the algorithm, which involves
- copying K from GMEM → SMEM → RF
- computing S = QKᵀ
- computing the online softmax and rescaling ℓ and O
- copying V from GMEM → SMEM → RF
- computing O += P·V
The barriers are on double duty here.
- Barrier 1 ensures that the GMEM->SMEM copy of K is complete before any warp starts the SMEM->RF copy. At the same time, it ensures the SMEM->RF copy of V from the previous iteration is complete before its space in SMEM is overwritten by the new V tile.
- Barrier 2 does the same for the opposite tensors: it ensures V is fully in SMEM before being read into the RF, and that K is fully in the RF (for every warp) before its SMEM slice is overwritten by the next K tile.
// ...
for (int j = 0; j < args.n_KV_blocks; ++j) {
K.copy_GM2SM();
K.advance_gmem_block();
cp_async_commit();
S_accum.zero();
cp_async_wait<0>();
__syncthreads(); // <---- Barrier 1
K.copy_SM2RF();
matmul<Kernel::S_QK_GEMM>(Q, K, S_accum);
// Online softmax
accum_t m_next[N::QO_fragments_per_warp];
scale_S_accum(S_accum.data(), softmax_scale);
calc_row_max(S_accum.data(), m_next, m);
scale_l_O(m_next, m, l, O_accum.data());
scaled_exp(S_accum.data(), m_next);
update_row_exp_sum(S_accum.data(), l);
// Convert the S accumulator block into P bf16/fp16 input block.
convert_to_16_bit_dtype<value_t>(S_accum.data(), P_b16.data());
V.copy_GM2SM();
V.advance_gmem_block();
cp_async_commit();
cp_async_wait<0>();
__syncthreads(); // <---- Barrier 2
V.copy_SM2RF();
matmul<typename Kernel::O_PV_GEMM>(P_b16, V, O_accum);
}
// ...
Epilogue
The epilogue handles the final steps: normalizing the output with the final softmax sums, converting O to the 16-bit output type, and writing it back to GMEM through SMEM.
// ...
final_softmax_normalization(O_accum.data(), l);
convert_to_16_bit_dtype<value_t>(O_accum.data(), O_b16.data());
O_b16.copy_RF2SM();
__syncwarp();
// Copy the final O tile from SMEM to GMEM.
O_b16.copy_SM2GM();
}
Occupancy
Our block configuration is 128 threads (4 warps) per CTA, with B_r = B_c = 64 and d_head = 128. Let's see how many CTAs that lets us keep resident on each SM.
Threads per CTA
SM_86 supports up to 1536 resident threads per SM. With 128 threads per CTA, we can have up to 1536 / 128 = 12 CTAs (48 warps) per SM based on the thread limit alone.
Registers per Thread
On an RTX 3090 (SM_86), we have access to 65536 registers per SM. Each thread can use up to 255 registers, and anything needed beyond that gets spilled to local memory. You might think we're nowhere close to using 255 (and you'd be right in many cases), but matrix multiplication can be very register heavy.
To get a sense of how heavy, let’s take a look at the RF storage required for each tensor.
Tensor | Element size (bytes) | mma matrix variable | Storage Shape | Register Count |
---|---|---|---|---|
Q | 2 | A | (16, 128) | 32 |
O (accum) | 4 | C/D | (16, 128) | 64 |
K | 2 | B | (64, 128) | 128 |
V | 2 | B | (64, 128) | 128 |
S (accum) | 4 | C/D | (16, 64) | 32 |
P | 2 | A | (16, 64) | 16 |
m | 4 | N/A | (16) | 2 |
ℓ | 4 | N/A | (16) | 2 |
That's a total of 404 registers! And that doesn't even include registers used for other purposes, like memory addressing. Thankfully we don't have to store all of these at the same time. In fact, the only tensors that persist in their entirety across every mainloop iteration (barring register spills) are Q, the O accumulator, and the softmax statistics m and ℓ.
After compiling, the kernel uses a maximum of 202 registers per thread. This limits us to 2 CTAs (8 warps) per SM: 256 threads × 202 registers ≈ 51.7K fits in the 64K register file, but a third CTA would push us past it.
12 Warps per SM
The threshold for hitting 12 warps per SM is 170 registers per thread. nvcc has the option to artificially limit the number of registers per thread, but going from 202 to 170 would incur significant register spilling. This is a classic performance trade-off: limiting registers can increase occupancy (more concurrent warps), which helps hide memory latency, but it can also decrease single-thread performance if the compiler is forced to spill registers to slow local memory. For a compute-bound kernel like ours, the cost of spilling outweighs the benefit of higher occupancy. We'll revisit register pressure in kernel 3.
Shared Memory per CTA
SM_86 supports up to 99KiB of SMEM per CTA (out of roughly 100KiB per SM). For our block configuration, the Q, K, V, and O tiles each take up 64 × 128 × 2 bytes = 16KiB.
Giving every tensor its own slice of SMEM would therefore cost 64KiB per CTA, but to reach our goal of 2 CTAs per SM, we need to reduce that to 48KiB.
Fortunately, we can reduce SMEM usage by recognizing that Q only needs to live in SMEM until it's copied into the RF during the prologue, so O can reuse Q's slice in the epilogue. That takes us from four slices down to three, or 48KiB.
More generally, the SMEM needed per CTA is closely tied to how we order and synchronize accesses to the various slices of SMEM. While I won't dive deep here, other tensors can also share their slice of SMEM if we add extra synchronization barriers. For example, K and V could share a slice if we inserted a CTA-wide barrier between K's SMEM → RF copy and V's GMEM → SMEM copy.
Summary
With 202 registers per thread and 48KB SMEM per CTA, our kernel will be able to have 8 resident warps on each SM.
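As a quick sanity check, here's the occupancy arithmetic from this section in one place:
// Occupancy limits per SM on SM_86 (RTX 3090), with 128 threads per CTA:
//   threads:    1536 / 128            = 12 CTAs   (not the limiter)
//   registers:  65536 / (202 * 128)  ~=  2.5 CTAs -> 2 CTAs
//   SMEM:       ~100KiB / 48KiB      ~=  2.08     -> 2 CTAs
// => 2 CTAs x 4 warps = 8 resident warps per SM.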
Performance and What’s Next
Phew, that was a lot of buildup. Now, the moment you’ve been waiting for: how fast is the kernel?
We’re hitting 16.83 TFLOPS compared to the reference’s 34.04 TFLOPS - about 49% of reference performance. For a first kernel with no optimizations, this is pretty decent! We’ve got all the core functionality working correctly, and we’re already approaching half the performance of a highly optimized implementation. But there’s clearly room for improvement.
In the next part, we’ll start diagnosing our kernel and iteratively improve it with techniques like swizzling, double buffering, and instruction fusion to close the performance gap on the RTX 3090.
Footnotes
- This is necessary since Volta because we can no longer assume that the threads in a warp are converged. See Using CUDA Warp-Level Primitives | NVIDIA Technical Blog for more details. Having said that, in many cases the compiler will not actually emit a corresponding SASS instruction for __syncwarp() calls. ↩