In Part 2, we explored the fundamental CUDA building blocks - tensor core operations (mma) and efficient memory transfers (cp.async & ldmatrix). Now we’re ready to assemble these pieces into our first complete Flash Attention kernel. This base implementation will serve as our foundation, giving us a working kernel that we’ll iteratively optimize throughout the series.

We’ll tackle this in three main steps: first, we’ll figure out how to split up the work between different levels (CTAs, warps, and threads). Then we’ll build higher-level operations from our basic instructions. Finally, we’ll assemble everything into a complete kernel.

Here’s what we’re building towards:

  • forward pass only
  • non-causal attention
  • head dimension = 128
  • no dropout or KV caching
  • equal query/key/value sequence lengths
  • sequence lengths divisible by block sizes (typically 64-128 in our implementation, as defined in the paper)
  • 16-bit (bf16/fp16) input and output tensors, softmax calculation in fp32

By the end of this part, our kernel will reach 49% of the performance of the official implementation on the RTX 3090.

Kernel Architecture Overview

Before diving into implementation details, let’s establish the overall structure of our kernel. Flash Attention kernels follow a standard three-phase pattern that we’ll use throughout this series:

Kernel Phases

  1. Prologue: Setup and initialization
  • Calculate memory addresses for each tensor and warp
  • Load Q from GMEM → SMEM → RF (happens once)
  • Initialize softmax statistics (m and l) and output accumulator O
  2. Mainloop: Iterative attention computation (repeated for each K/V block)
  • Load K: GMEM → SMEM → RF
  • Compute attention scores: S = QKᵀ
  • Apply online softmax and update statistics
  • Load V: GMEM → SMEM → RF
  • Compute output contribution: O += PV
  3. Epilogue: Finalization and output
  • Complete softmax normalization
  • Convert O from fp32 to fp16/bf16
  • Write final output: RF → SMEM → GMEM

Implementation Challenges

With our kernel structure established, we need to solve three main technical challenges:

  1. Data Movement: Getting tensors efficiently through the memory hierarchy (GMEM → SMEM → RF) while handling different layouts, access patterns, and synchronization requirements
  2. Mathematical Operations: Implementing GEMM operations and online softmax using our mma instructions and warp primitives
  3. Synchronization: Coordinating between threads and warps to avoid race conditions

The bulk of our complexity comes from data movement. Each tensor has different requirements: Q and O are handled independently per warp, while K and V require cooperation across warps. Some tensors stay row-major throughout the memory hierarchy, others need transposition. Once we get the data where it needs to be, the actual math operations are relatively straightforward.

Let’s tackle these one by one, starting with the fundamental decisions that will shape our entire implementation.

Kernel Configuration

Now that we understand the overall structure, let’s nail down the specific parameters that will define our kernel’s behavior.

From now through kernel 7, we’ll use block sizes of B_r = 64 and B_c = 64. These dimensions work well together since they’re nice multiples of our m16n8k16 instruction, and they’re small enough that the tiles fit snugly into SMEM and the RF.

We’ll also use 128 threads (4 warps) per CTA, following CUDA’s recommended starting point.

In kernel 7, we’ll test and benchmark the performance of the kernel with different configurations to find the most performant.

CTA Work Distribution

We need to think about how to distribute work across three levels: CTAs, warps, and threads. Let’s start with the highest level: how we’ll distribute work across CTAs in our grid.

Our Flash Attention kernel needs to process tensors shaped like (batch_size, seq_len, n_heads, d_head) by tiling them into smaller blocks that fit into SM memory. Given our configuration, we’ll partition the tensors as follows:

  • Q and O: Split into (B_r, d_head) = (64, 128) tiles
  • K and V: Split into (B_c, d_head) = (64, 128) tiles, denoted K_j and V_j

Each CTA is assigned to process one specific Q and O block for a given (sample, head) pair. This will involve loading and computing with all blocks K_j and V_j of K and V for that same (sample, head) pair, where j indexes the T_c = seq_len / B_c key/value blocks.

This gives us n_samples * n_heads * T_r total query blocks to process, where T_r = seq_len / B_r. We’ll launch exactly that many CTAs to handle all the work.

Kernel Arguments

Here are the arguments our kernel will receive:

forward_kernel.cuh
struct FAForwardArgs {
    using index_t = int64_t;
 
    void *__restrict__ Q;
    void *__restrict__ K;
    void *__restrict__ V;
    void *__restrict__ O;
 
    const index_t batch_stride;
    const index_t seq_stride;
    const index_t head_stride;
 
    const index_t seq_len;
    const index_t n_heads;
 
    const int n_Q_blocks;
    const int n_KV_blocks;
};
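
To make the stride fields concrete, here is a sketch of how the host side might fill these arguments for contiguous (batch_size, seq_len, n_heads, d_head) tensors. The helper name and signature are my own, and strides are in elements rather than bytes:

// Hypothetical helper (not part of the kernel sources) showing the intended
// meaning of each stride for a contiguous (batch, seq, heads, d_head) layout.
inline FAForwardArgs make_forward_args(void *Q, void *K, void *V, void *O,
                                       int64_t seq_len, int64_t n_heads,
                                       int64_t d_head, int B_r, int B_c) {
    return FAForwardArgs{
        Q, K, V, O,
        /* batch_stride */ seq_len * n_heads * d_head,  // step to the next sample
        /* seq_stride   */ n_heads * d_head,            // step to the next token
        /* head_stride  */ d_head,                      // step to the next head
        seq_len,
        n_heads,
        /* n_Q_blocks  */ static_cast<int>(seq_len / B_r),
        /* n_KV_blocks */ static_cast<int>(seq_len / B_c),
    };
}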

Grid Mapping

Each CTA handles a query tile for a (sample, head) pair, so we’ll set the kernel grid to be some permutation of the shape (sample, query_block, head). How should we map CTAs to query blocks?

Let’s consider a single sample for a particular head. The CTAs for that sample may read different blocks of Q and write different blocks of O, but they will all load the same K and V blocks.

Here’s the key insight: we want CTAs processing the same sample and head to run around the same time so they can share cached K/V data. When the first CTA loads a block, it gets cached in L2. If other CTAs for the same sample/head launch soon after, they’ll hit the cache instead of going to DRAM.

CTAs launch in order of their ID: blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z

Since CTAs with the same y and z but different x have consecutive IDs, they launch together. So we map query blocks to the x dimension: (x, y, z) -> (Q_block, head, batch).

If our mapping were (x, y, z) -> (head, Q_block, batch), consecutive CTAs would work on different heads (and therefore different K and V blocks), so our L2 hit rate would drop significantly.

So we’ll set up our variables like this:

forward_kernel.cuh
    // ...
    const int sample = blockIdx.z;
    const int head = blockIdx.y;
    const int q_seq_block = blockIdx.x;
    // ...

How does the grid mapping affect performance? Each K and V tile gets read by multiple CTAs, and at 16KiB each, the L2 can fit multiple tiles. On an A100, L2 hits have 200-400 cycles of latency compared to DRAM’s ~550 cycles, according to Luo, et al. For Flash Attention, this cache reuse provides a small but measurable performance benefit for our compute-intensive workload.

Here’s how the grid mapping affects cache performance on our Flash Attention workload:

| GPU      | L2 Cache Size | Unoptimized Hit Rate | Optimized Hit Rate | Performance Impact  |
|----------|---------------|----------------------|--------------------|---------------------|
| RTX 3090 | 6MB           | ~2%                  | ~98%               | ~3% performance hit |
| A100     | 40MB          | ~25.6%               | ~92.6%             | ~1% performance hit |

The impact is modest, but since getting the mapping right is essentially free, it’s worth implementing.
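
Putting the mapping together, a host-side launch might look like the following sketch. The launcher name and the Kernel template parameter are assumptions here, not the project's actual API:

// Hypothetical launch sketch: query blocks go on x so CTAs that share K/V data
// (same head and sample) get consecutive IDs and launch close together.
template <typename Kernel>
void launch_flash_forward(const FAForwardArgs &args, int batch_size,
                          size_t smem_bytes, cudaStream_t stream) {
    dim3 grid(args.n_Q_blocks,                 // x: query block
              static_cast<int>(args.n_heads),  // y: head
              batch_size);                     // z: sample
    dim3 block(128);                           // 4 warps per CTA
    flash_forward_kernel<Kernel><<<grid, block, smem_bytes, stream>>>(args);
}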

Configuration Template Parameters

Even though we have a fixed kernel configuration, we’ll want to make it easy to generalize for testing different configurations (which we’ll do in kernel 7). We’ll add a struct template parameter to support different block sizes, numbers of threads per CTA, and data types. This is the template parameter for the kernel function:

flash_attention.cuh
struct FlashForwardKernelConfig {
	// This gets statically converted into either half or nv_bfloat16
    const torch::ScalarType dtype;
    
    const int d_head;  // [64, 128]
    const int B_r;     // [64, 128]
    const int B_c;     // [32, 64, 128]
    const int n_warps; // [4, 8]. 8 only when B_r = 128
};
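
As a concrete example, the configuration we’ll use for the rest of this part might be instantiated like this (the variable name is mine):

// bf16 inputs, d_head = 128, B_r = B_c = 64, 4 warps (128 threads) per CTA.
constexpr FlashForwardKernelConfig kBaseConfig{
    torch::kBFloat16,  // dtype
    128,               // d_head
    64,                // B_r
    64,                // B_c
    4                  // n_warps
};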

Warp-Level Work Distribution

Now that we’ve established how CTAs distribute work, let’s zoom in to the warp level. From here on, we’ll focus on warp-to-CTA and thread-to-warp interactions, so we’ll drop the block superscripts from K and V for simplicity.

Remember from Part 2 that tensor operations must execute in lockstep across all threads in a warp. Our mma instruction operates on 16×16 tiles, so we’ll use this as the basis for dividing work between warps.

Here’s how the math works out: we have 64-row blocks and 4 warps per CTA, so each warp gets 64÷4 = 16 rows. Combined with our 128-column head dimension, each warp handles (16, 128) sub-tiles.

We’ll handle loading Q and O differently from K and V.
Why the different strategies for Q/O vs. K/V? While each warp is responsible for a unique slice of the Q and O tensors, it must access the entire block of K and V to compute the attention scores for its slice. This means Q and O can be handled largely independently per warp, while K and V require cooperation across all warps in the CTA.

Q and O Processing (Independent per Warp)

  • The 64-row tile is split into 4 independent (16, 128) slices, one per warp
  • Each warp operates completely independently on its slice for:
    • Loading/storing operations
    • GEMM computations

K and V Processing (Cooperative across Warps)

  • Loading phase: Each warp loads its own (16, 128) slice from GMEM → SMEM
  • Synchronization: All warps wait for the complete (64, 128) block to be loaded
  • Copy phase: Each warp independently copies the entire block from SMEM → RF

Per-Warp Workload Summary

Each of the 4 warps in a CTA processes:

Independent Operations (per warp):

  • Load a (16, 128) slice of Q: GMEM → SMEM → RF
  • Compute S = QKᵀ with shape (16, 64) in RF
  • Compute P (the online softmax of S) in RF
  • Compute O += PV in RF
  • Store a (16, 128) slice of O: accumulated in RF, then RF → SMEM → GMEM

Cooperative Operations (across all warps):

  • Load (64, 128) blocks of K and V: cooperatively GMEM → SMEM, then independently SMEM → RF

The diagram below shows how each warp handles its slice of the work:

Every tile with a color background is stored in the RF.

Future optimization

It’s actually slightly more efficient for warps to cooperate when loading and blocks too. We’ll cover this later in part 5 (kernel 9).

Data Movement

Now that we understand the work distribution strategy, let’s tackle the implementation. Here’s where things get interesting - and complex. Moving data efficiently turns out to be the hairiest part of writing this kernel. We’re dealing with different layouts, access patterns, and synchronization needs all at once.

Each tensor has different requirements: Q and O are handled independently per warp, while K and V require cooperation across warps. Some tensors stay row-major throughout the memory hierarchy, others need transposition. It’s a lot to juggle.

Our strategy is to build this complexity in layers, from low-level memory operations up to a clean abstraction that hides all the mess:

  1. Core memory operations: Generic functions for GMEM ⟷ SMEM transfers and specialized SMEM → RF functions that handle transposition
  2. Address management: Calculate the right pointers for each tensor and warp
  3. Unified interface: A MatrixLDST class that wraps everything into a simple, tensor-specific API

Let’s build this step by step.

Configuration Structs

To efficiently move data through the memory hierarchy (GMEM → SMEM → RF), we need operations that can handle different tensor requirements. Here are the configuration structs that encode all the LD/ST requirements:

load_store.cuh
struct TileLayout {
    const int row_fragments;
    const int col_fragments;
};
 
// constexpr non-type template parameter containing parameters for LD/ST for a
// block (Q, K, V, or O) from GMEM to SMEM and vice versa, and also loading from
// SMEM to the RF.
struct TensorLDSTConfig {
    // Tile layout for shared memory and RF.
    const TileLayout GSM;
    const TileLayout RF;
 
	// Block specific properties
    const bool transposed;
    const int block_size;
    const int smem_cols;
 
    // # of rows a warp in a thread-block independently loads/stores. it is equivalent to GSM.row_fragments * 8.
    const int warp_ldst_rows;
    // Whether or not the warp will compute over the entire block.
    // This is false for (Q&O&S) and true for (K&V).
    const bool compute_over_entire_block;
};
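
To make this concrete, here is roughly what the configs for Q and K could look like with our (B_r, B_c, d_head) = (64, 64, 128), 4-warp setup. The names and exact values are illustrative and may differ from the repo; fragments are 8×8 elements:

// Hypothetical example configs. A warp's (16, 128) GMEM<->SMEM slice is
// (2, 16) fragments of 8x8 elements each.
constexpr TensorLDSTConfig Q_ldst{
    /* GSM */ {2, 16},  // each warp moves only its own (16, 128) slice
    /* RF  */ {2, 16},  // and keeps just that slice in registers
    /* transposed  */ false,
    /* block_size  */ 64,
    /* smem_cols   */ 128,
    /* warp_ldst_rows */ 16,
    /* compute_over_entire_block */ false,
};

constexpr TensorLDSTConfig K_ldst{
    /* GSM */ {2, 16},  // each warp still loads a (16, 128) slice into SMEM
    /* RF  */ {8, 16},  // but copies the whole (64, 128) block into registers
    /* transposed  */ false,
    /* block_size  */ 64,
    /* smem_cols   */ 128,
    /* warp_ldst_rows */ 16,
    /* compute_over_entire_block */ true,
};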

Storage Layout

We use different storage layouts for each memory level and tensor:

  • SMEM: All tensors stored row-major for efficient loading from GMEM
  • RF: Most tensors stay row-major, but V is stored transposed (column-major)

Here’s a table to help you keep in mind what we need to transfer and the shapes at each memory level:

| Tensor  | element size (bytes) | GMEM+SMEM Majorness | GMEM↔SMEM Shape (per warp) | SMEM Shape (SMEM → RF) | RF Majorness | RF Shape (Registers) |
|---------|----------------------|---------------------|----------------------------|------------------------|--------------|----------------------|
| Q       | 2                    | Row major           | (16, 128)                  | (16, 128)              | Row major    | (2, 16)              |
| K       | 2                    | Row major           | (16, 128)                  | (64, 128)              | Row major    | (8, 16)              |
| V       | 2                    | Row major           | (16, 128)                  | (64, 128)              | Column major | (16, 8)              |
| S_accum | 4                    | N/A                 | N/A                        | N/A                    | Row major    | (2, 16)              |
| P       | 2                    | N/A                 | N/A                        | N/A                    | Row major    | (2, 8)               |
| O_accum | 4                    | N/A                 | N/A                        | N/A                    | Row major    | (2, 32)              |
| O       | 2                    | Row major           | (16, 128)                  | (16, 128)              | Row major    | (2, 16)              |

And here are the LD/ST operations we’ll use to transfer these.

| From | To   | Blocks  | PTX Instr. / C++ | Warp-Wide Op Size | Thr. Op Size | Thr. ID Mapping Order | Notes                          |
|------|------|---------|------------------|-------------------|--------------|-----------------------|--------------------------------|
| GMEM | SMEM | Q, K, V | cp.async         | (4, 64)           | 16B          | Row-major             |                                |
| SMEM | RF   | Q, K, V | ldmatrix.x4      | (16, 16)          | 16B          | Column-major          | V uses the transposing variant |
| RF   | SMEM | O       | standard (4B)    | (8, 8)            | 4B           | Row-major             |                                |
| SMEM | GMEM | O       | standard (16B)   | (4, 64)           | 16B          | Row-major             |                                |

Copying Between GMEM ⟷ SMEM

Now let’s implement the actual data movement. Recall that each warp-wide operation accesses a (4, 64) chunk, and our tile size is (16, 128), so we need (4, 2) cp.async() instructions to fully copy each block.

Here’s the code for copying between GMEM and SMEM. The key design choices:

  • Bidirectional: Works for both GMEM→SMEM and SMEM→GMEM by passing different operation functors
  • Thread coordination: Each thread calculates its offset in row-major order to avoid conflicts
  • Template-based: Uses the TensorLDSTConfig template parameter to handle different tensor layouts

The operation functors define the actual copy behavior in each direction; the implementation below takes one as a template parameter and handles both GMEM→SMEM and SMEM→GMEM transfers:

load_store.cuh
#define ROWS_PER_FRAGMENT 8
#define COLS_PER_FRAGMENT 8
#define GSM_LDST_ROWS_PER_ITER 4
#define BYTES_PER_VEC4_ACCESS 16
 
template <typename T>
struct GM2SM_async {
    __device__ constexpr void operator()(T *gmem, T *smem) {
        cp_async<BYTES_PER_VEC4_ACCESS>(smem, gmem);
    }
};
 
template <typename T>
struct SM2GM {
    __device__ constexpr void operator()(T *gmem, T *smem) {
        reinterpret_cast<uint4 *>(gmem)[0] = reinterpret_cast<uint4 *>(smem)[0];
    }
};
 
template <typename op, /* either GM2SM_async or SM2GM */
          TensorLDSTConfig CFG, typename value_t, typename index_t = int64_t>
__forceinline__ __device__ constexpr void copy_block_GSM(value_t *gmem, value_t *smem,
                                        index_t gmem_seq_stride,
                                        const int lane_id) {
    constexpr int n_row_iters =
        CFG.GSM.row_fragments * ROWS_PER_FRAGMENT / GSM_LDST_ROWS_PER_ITER;
 
    constexpr int col_fragments_per_iter = WARP_SIZE / GSM_LDST_ROWS_PER_ITER;
    constexpr int col_fragments_per_row = CFG.smem_cols / COLS_PER_FRAGMENT;
 
    const int thread_row = lane_id / col_fragments_per_iter;
    const int thread_col_fragment = lane_id % col_fragments_per_iter;
 
    #pragma unroll
    for (int r = 0; r < n_row_iters; ++r) {
        const int cur_row = r * GSM_LDST_ROWS_PER_ITER + thread_row;
        #pragma unroll
        for (int c = 0; c < col_fragments_per_row;
             c += col_fragments_per_iter) {
            const int col_fragment = c + thread_col_fragment;
 
            op()(&gmem[cur_row * gmem_seq_stride +
                       col_fragment * COLS_PER_FRAGMENT],
                 &smem[cur_row * CFG.smem_cols +
                       col_fragment * COLS_PER_FRAGMENT]);
        }
    }
}

A Note on reinterpret_cast

You’ll notice reinterpret_cast is used frequently. This is a common and safe pattern in high-performance CUDA for vectorized memory access. By re-interpreting pointers to larger types like uint4, we can load or store 16 bytes in a single instruction, provided the memory is correctly aligned, which it is in our case.

Additional Requirements for cp.async

When using cp.async(), we need to commit the copy group with cp_async_commit() and wait for it to complete with cp_async_wait<N>() (wrappers around the cp.async.commit_group and cp.async.wait_group PTX instructions). We’ll cover the synchronization details in Synchronization.
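
For reference, here is a minimal sketch of what these helpers (introduced in Part 2) boil down to; the exact signatures in the repo may differ:

// cp.async copies 16B from GMEM to SMEM without staging through registers.
template <int Bytes, typename T>
__forceinline__ __device__ void cp_async(T *smem_dst, const T *gmem_src) {
    static_assert(Bytes == 16, "the .cg variant only supports 16-byte copies");
    const uint32_t smem_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr),
                 "l"(gmem_src), "n"(Bytes));
}

// Close the current batch of cp.async operations.
__forceinline__ __device__ void cp_async_commit() {
    asm volatile("cp.async.commit_group;\n");
}

// Block until at most N committed groups are still in flight (N = 0 waits for all).
template <int N>
__forceinline__ __device__ void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}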

SMEM → RF Operations

You might recall from How Fragments Are Laid Out Across Threads that we store and load Q and K differently from V in the RF. The key difference is that while SMEM is stored row-major for all tensors, V is stored col-major in the RF.

This layout difference requires transposing V during the SMEM → RF copy: SMEM (row, col) gets transposed into RF (col, row). The transposition affects both the fragment arrangement and thread-level element storage, which is why we call ldmatrix_x4_transpose() instead of ldmatrix_x4() for V. To handle these differences, we’ll create two separate helper functions.

An iteration for V looks like this:

while an equivalent iteration for Q or K looks like this:

and here is the code, starting with the non-transposed version used for Q and K:

load_store.cuh
#define ROWS_PER_FRAGMENT 8
#define COLS_PER_FRAGMENT 8
#define ELEMS_PER_VEC4_ACCESS 8
 
template <TensorLDSTConfig CFG, typename value_t>
__forceinline__ __device__ constexpr void copy_warp_fragment_SM2RF(
    uint32_t (&regs)[CFG.RF.row_fragments][CFG.RF.col_fragments], value_t *smem,
    const int lane_id, const int col_fragment_offset = 0) {
    constexpr int row_fragments_per_iter = 2;
    constexpr int rows_per_iter = ROWS_PER_FRAGMENT * row_fragments_per_iter;
 
    constexpr int col_fragments = CFG.smem_cols / ELEMS_PER_VEC4_ACCESS;
    constexpr int col_fragments_per_iter = WARP_SIZE / rows_per_iter;
 
    const int thread_row = lane_id % rows_per_iter;
    const int thread_col_fragment = lane_id / rows_per_iter;
 
    #pragma unroll
    for (int r = 0; r < CFG.RF.row_fragments; r += row_fragments_per_iter) {
        const int cur_row = thread_row + r * ROWS_PER_FRAGMENT;
        #pragma unroll
        for (int c = 0; c < CFG.RF.col_fragments; c += col_fragments_per_iter) {
            const int smem_col_fragment = thread_col_fragment + c + col_fragment_offset;
 
            ldmatrix_x4(&smem[cur_row * CFG.smem_cols +
                              smem_col_fragment * ELEMS_PER_VEC4_ACCESS],
                        regs[r][c], regs[r + 1][c], regs[r][c + 1],
                        regs[r + 1][c + 1]);
        }
    }
}

For the transposed version used for V:

  • we loop over SMEM in row → col order, and vice versa for RF
  • instead of swapping the SMEM pointers to transpose the fragments, we’ll swap the RF indices
    • this allows the kernel to share the same SMEM offset calculations with the other tiles
load_store.cuh
template <TensorLDSTConfig CFG, typename value_t>
__forceinline__ __device__ constexpr void copy_warp_fragment_transposed_SM2RF(
    uint32_t (&regs)[CFG.RF.row_fragments][CFG.RF.col_fragments], value_t *smem,
    const int lane_id, const int row_fragment_offset = 0) {
    constexpr int row_fragments_per_iter = 2;
    constexpr int rows_per_iter = ROWS_PER_FRAGMENT * row_fragments_per_iter;
 
    constexpr int col_fragments = CFG.smem_cols / ELEMS_PER_VEC4_ACCESS;
    constexpr int col_fragments_per_iter = WARP_SIZE / rows_per_iter;
 
    const int thread_row = lane_id % rows_per_iter;
    const int thread_col_fragment = lane_id / rows_per_iter;
 
    #pragma unroll
    for (int r = 0; r < CFG.RF.col_fragments; r += row_fragments_per_iter) {
        const int cur_row =
            thread_row + (r + row_fragment_offset) * ROWS_PER_FRAGMENT;
        #pragma unroll
        for (int c = 0; c < CFG.RF.row_fragments; c += col_fragments_per_iter) {
            const int smem_col_fragment = thread_col_fragment + c;
 
            ldmatrix_x4_transpose(
                &smem[cur_row * CFG.smem_cols +
                      smem_col_fragment * ELEMS_PER_VEC4_ACCESS],
                regs[c][r], regs[c][r + 1], regs[c + 1][r], regs[c + 1][r + 1]);
        }
    }
}

Copying from RF → SMEM

The only block we copy from RF → SMEM is O. Instead of ldmatrix, we use standard 4B smem[dst] = rf[src]; stores.

Each iteration of the loop stores a single (8, 8) tile back to SMEM. Since O is (16, 128) for a single warp, this means we’ll need (2, 16) iterations to fully copy O.

load_store.cuh
template <TensorLDSTConfig CFG, typename value_t>
__forceinline__ __device__ constexpr void copy_warp_fragment_RF2SM(
    uint32_t (&regs)[CFG.RF.row_fragments][CFG.RF.col_fragments], value_t *smem,
    const int lane_id) {
    constexpr int rows_per_iter = ROWS_PER_FRAGMENT;
    constexpr int col_fragments_per_iter = 1;
    constexpr int col_fragments = CFG.smem_cols / ELEMS_PER_VEC4_ACCESS;
 
    constexpr int elems_per_store = 2;
    const int thread_row = lane_id / 4;
    const int thread_inner_col = (lane_id % 4) * elems_per_store;
 
    #pragma unroll
    for (int r = 0; r < CFG.RF.row_fragments; ++r) {
        const int cur_row = thread_row + r * rows_per_iter;
        #pragma unroll
        for (int c = 0; c < CFG.RF.col_fragments; c += col_fragments_per_iter) {
            const int smem_col_fragment = c;
 
            reinterpret_cast<uint32_t *>(
                &smem[cur_row * CFG.smem_cols +
                      (smem_col_fragment * ELEMS_PER_VEC4_ACCESS +
                       thread_inner_col)])[0] = regs[r][c];
        }
    }
}

GMEM Address Calculations for Tensors

The address calculations for each block are fairly straightforward. This code calculates the pointers into the Q, K, V, and O blocks for a particular sample, head, and query block. Since the warp-specific addresses will be different for different tensors, we’ll add that logic instead to a tensor class that’ll wrap all the operations we need.

forward_kernel.cuh
    const int sample = ...;
    const int head = ...;
    const int q_seq_block = blockIdx.x;
 
	// ....
 
	using value_t = nv_bfloat16; // or half
 
    const index_t gmem_seq_stride = args.seq_stride;
 
    const index_t sample_head_offset =
        sample * args.batch_stride + head * args.head_stride;
    // We only read/write one block for Q and O.
    // These offsets are the same for the whole thread-block.
    const index_t QO_gmem_block_offset =
        sample_head_offset + q_seq_block * 64 * gmem_seq_stride;
    // We read the entire key sequence.
    const index_t KV_gmem_block_offset = sample_head_offset;
 
    value_t *gmem_Q = &static_cast<value_t *>(args.Q)[QO_gmem_block_offset];
    value_t *gmem_O = &static_cast<value_t *>(args.O)[QO_gmem_block_offset];
    value_t *gmem_K = &static_cast<value_t *>(args.K)[KV_gmem_block_offset];
    value_t *gmem_V = &static_cast<value_t *>(args.V)[KV_gmem_block_offset];

Tensor Abstraction Layer

Matrix LDST Class

The MatrixLDST class wraps all load and store operations, spanning all levels of the memory hierarchy. This abstraction handles the complexity of different tensor layouts and access patterns in one place.

Key features:

  • Unified interface for GMEM ⟷ SMEM ⟷ RF operations
  • Warp-specific address calculations
  • Support for both independent and cooperative loading patterns
  • Automatic handling of transposed layouts (for V)
tensor.cuh
template <TensorLDSTConfig ldst, typename value_t, typename index_t = int64_t>
struct MatrixLDST {
    // Static configuration
    using matrix_storage_t =
        RFMatrix<value_t, ldst.mma_load_stages, ldst.RF.row_fragments,
                 ldst.RF.col_fragments>;
    using GM2SM_op = std::conditional_t<ldst.Common.async_copy,
                                        GM2SM_async<value_t>, GM2SM<value_t>>;
 
    using SM2GM_op = SM2GM<value_t>;
    static constexpr int mma_load_stages = ldst.mma_load_stages;
    static constexpr bool load_entire_block_into_rf =
        ldst.load_entire_block_into_rf;
    static constexpr bool transposed = ldst.transposed;
 
    // Runtime properties
    value_t *gmem_ptr;
    index_t gmem_seq_stride;
    // The location in memory used to load fragments from SMEM to RF.
    value_t *smem_srm_ptr;
    // The location in memory that the warp writes to for Q, K, V from GMEM to
    // smem and O for SMEM to GMEM.
    value_t *smem_gsm_ptr;
 
    const int lane_id;
 
    matrix_storage_t storage;
 
    __forceinline__ __device__ MatrixLDST(value_t *gmem_block_ptr, index_t _gmem_seq_stride,
                     value_t *_smem_ptr)
        : lane_id(threadIdx.x % WARP_SIZE) {
        const int warp_rank = threadIdx.x / WARP_SIZE;
 
        const index_t warp_seq = ldst.warp_ldst_rows * warp_rank;
 
        gmem_seq_stride = _gmem_seq_stride;
        gmem_ptr = gmem_block_ptr + warp_seq * gmem_seq_stride;
 
        smem_gsm_ptr = _smem_ptr + warp_seq * ldst.smem_cols;
        smem_srm_ptr =
            ldst.compute_over_entire_block ? _smem_ptr : smem_gsm_ptr;
    }
 
    __forceinline__ __device__ constexpr void zero() { storage.zero(); }
 
    __forceinline__ __device__ constexpr typename matrix_storage_t::storage_t (&data(
        const int stage = 0))[matrix_storage_t::rows][matrix_storage_t::cols] {
        return storage.data(stage);
    }
 
    __forceinline__ __device__ constexpr void advance_gmem_block() {
        gmem_ptr += ldst.block_size * gmem_seq_stride;
    }
 
    __forceinline__ __device__ constexpr void copy_GM2SM() {
        copy_block_GSM<GM2SM_op, ldst>(gmem_ptr, smem_gsm_ptr, gmem_seq_stride,
                                       lane_id);
    }
 
    __forceinline__ __device__ constexpr void copy_SM2GM() {
        copy_block_GSM<SM2GM_op, ldst>(gmem_ptr, smem_gsm_ptr, gmem_seq_stride,
                                       lane_id);
    }
 
    __forceinline__ __device__ constexpr void copy_SM2RF(int stage = 0, int tile_offset = 0) {
        if constexpr (!transposed) {
            copy_warp_fragment_SM2RF<ldst, value_t>(
                storage.data(stage), smem_srm_ptr, lane_id, tile_offset);
        } else {
            copy_warp_fragment_transposed_SM2RF<ldst, value_t>(
                storage.data(stage), smem_srm_ptr, lane_id, tile_offset);
        }
    }
 
    __forceinline__ __device__ constexpr void copy_RF2SM() {
        copy_warp_fragment_RF2SM<ldst, value_t>(data(), smem_srm_ptr, lane_id);
    }
};
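
As a usage sketch, the kernel can then tie a per-tensor config to a concrete type once and reuse it everywhere. The alias names below are chosen to match the prologue later in this post, and the config objects are the hypothetical ones sketched earlier:

// Hypothetical type aliases: one MatrixLDST instantiation per tensor.
using value_t = nv_bfloat16;  // or half

using Q_t       = MatrixLDST<Q_ldst, value_t>;  // per-warp slice, row-major
using K_t       = MatrixLDST<K_ldst, value_t>;  // whole block in RF, row-major
using V_t       = MatrixLDST<V_ldst, value_t>;  // whole block in RF, transposed
using O_value_t = MatrixLDST<O_ldst, value_t>;  // 16-bit output tile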
 

Register Storage Classes

tensor.cuh
template <typename value_t, int N>
struct RFVector {
    static constexpr int size = N;
    value_t regs[N];
 
    __forceinline__ __device__ constexpr value_t &operator[](int idx) { return regs[idx]; }
};
 
template <typename value_t, int row_fragments, int col_fragments>
struct RFMatrix {
    using storage_t = std::conditional_t<sizeof(value_t) == 4, float, uint32_t>;
    static constexpr int regs_per_fragment = sizeof(value_t) / 2;
    static constexpr int rows = row_fragments;
    static constexpr int cols = col_fragments * regs_per_fragment;
 
    storage_t regs[rows][cols];
 
    __forceinline__ __device__ constexpr storage_t (&data(const int stage = 0))[rows][cols] {
        return reinterpret_cast<storage_t(&)[rows][cols]>(regs[stage]);
    }
 
    __forceinline__ __device__ constexpr void zero() {
		#pragma unroll
		for (int j = 0; j < rows; ++j) {
			#pragma unroll
			for (int k = 0; k < cols; ++k) {
				regs[j][k] = 0;
			}
		}
    }
};

Type Conversion

There are a couple of tensor tiles we’ll need to convert from 32-bit to 16-bit. We’ll need to convert the attention matrix S once per iteration (into P) so we can attend to the value vectors, and the output accumulator O once at the end before writing it back to GMEM. Converting between 32-bit and 16-bit is straightforward, but there’s a bit of bookkeeping with the data types:

template <typename value_t, int M_fragments, int N_fragments>
__forceinline__ __device__ constexpr void
convert_to_16_bit_dtype(float (&src_float)[M_fragments][N_fragments * 2],
                        uint32_t (&dest_uint)[M_fragments][N_fragments]) {
    using value2_t =
        std::conditional_t<std::is_same_v<value_t, half>, half2, nv_bfloat162>;
 
    float2(&src)[M_fragments][N_fragments] =
        reinterpret_cast<float2(&)[M_fragments][N_fragments]>(src_float);
    value2_t(&dest)[M_fragments][N_fragments] =
        reinterpret_cast<value2_t(&)[M_fragments][N_fragments]>(dest_uint);
    #pragma unroll
    for (int m = 0; m < M_fragments; ++m) {
        #pragma unroll
        for (int n = 0; n < N_fragments; ++n) {
            if constexpr (std::is_same_v<value_t, half>) {
                dest[m][n] = __float22half2_rn(src[m][n]);
            } else {
                dest[m][n] = __float22bfloat162_rn(src[m][n]);
            }
        }
    }
}

Computing Operations

With our data movement machinery in place, we can turn to the mathematical operations. The two main operations we need are GEMM (for S = QKᵀ and O = PV) and softmax. While GEMM builds directly on our mma primitives, the online softmax involves some clever thread coordination and statistical updates that are worth walking through carefully.

GEMM Implementation

MMA Operation Overview

A GEMM is composed of mma instructions that span all the fragments

As you might remember from part 2, each individual mma instruction performs D = A·B + C, where for our specific m16n8k16 instruction:

  • A has shape (m, k) = (16, 16)
  • B has shape (n, k) = (8, 16), and
  • C & D have shape (m, n) = (16, 8)

Here are the operand shapes and iteration patterns:

| A Shape (Elements) | A Shape (Registers) | B Shape (Elements) | B Shape (Registers) | Iteration Shape (k, m, n) |
|--------------------|---------------------|--------------------|---------------------|---------------------------|
| (16, 16)           | (2, 2)              | (8, 16)            | (1, 2)              | (2, 2, 1)                 |

Each GEMM iteration covers (2, 2, 1) registers in the (k, m, n) dimensions, with operand register shapes of (2, 2) for the A matrix and (1, 2) for B.

Here’s a diagram of a single iteration

Implementation

To calculate the results across all fragments, we’ll create a triple nested loop in K -> M -> N order.

gemm.cuh
template <typename value_t, const int M_fragments, const int N_fragments,
          const int K_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void warp_fragment_mma_f32_accum(
    uint32_t (&regs_A)[M_fragments][K_fragments],
    uint32_t (&regs_B)[N_fragments][K_fragments],
    accum_t (&regs_C)[M_fragments][N_fragments * N_REGS_PER_F32_ACCUM_FRAGMENT]) {
    #pragma unroll
    for (int k = 0; k < K_fragments; k += MMA_K_FRAGMENTS_PER_ITER) {
        #pragma unroll
        for (int m = 0; m < M_fragments; m += MMA_M_FRAGMENTS_PER_ITER) {
            #pragma unroll
            for (int n = 0; n < N_fragments; n += MMA_N_FRAGMENTS_PER_ITER) {
                mma_m16n8k16_f32_accum<value_t>(
                    regs_C[m][n * 2],
                    regs_C[m][n * 2 + 1],
                    regs_C[m + 1][n * 2],
                    regs_C[m + 1][n * 2 + 1],
                    regs_A[m][k],
                    regs_A[m + 1][k],
                    regs_A[m][k + 1],
                    regs_A[m + 1][k + 1],
                    regs_B[n][k],
                    regs_B[n][k + 1],
                    regs_C[m][n * 2],
                    regs_C[m][n * 2 + 1],
                    regs_C[m + 1][n * 2],
                    regs_C[m + 1][n * 2 + 1]);
            }
        }
    }
}
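
The mainloop later calls this through a thin matmul<>() wrapper. Its exact definition isn't shown in this post, but it is essentially a forwarding shim along these lines (the GEMM trait members here are assumptions):

// Hypothetical sketch: a GEMM trait bundles the fragment counts for one of our
// two GEMMs, and matmul<>() just forwards the operands' register arrays.
template <typename GEMM, typename A_t, typename B_t, typename C_t>
__forceinline__ __device__ void matmul(A_t &A, B_t &B, C_t &C_accum) {
    warp_fragment_mma_f32_accum<typename GEMM::value_t, GEMM::M_fragments,
                                GEMM::N_fragments, GEMM::K_fragments>(
        A.data(), B.data(), C_accum.data());
}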

Softmax

Thread-Level Perspective

For all the operations we covered until now, we viewed each warp as its own single cohesive unit and determined how each individual warp cooperates with its sibling warps. Now, we’ll look at the softmax operations from the perspective of a thread and how each thread cooperates with its sibling threads in the same warp.

Each warp and thread operates on the data they store, with the workload distributed evenly and no extra LD/ST required to begin operating.

We compute softmax in 32-bit; it involves both element-wise and row-wise operations.

Initialize Row Statistics

For both the row max (m) and row exponentiated sum (l), each row will take up a single 32-bit register. We initialize them to -∞ and 0.0 respectively.

forward_kernel.cuh
	// ...
    constexpr accum_t neg_inf = -cuda::std::numeric_limits<float>::infinity();
    accum_t m[N::QO_fragments_per_warp];
    accum_t l[N::QO_fragments_per_warp];
    #pragma unroll
    for (int q = 0; q < N::QO_fragments_per_warp; ++q) {
        m[q] = neg_inf;
        l[q] = 0.0;
    }
	// ...

Dot-Product Scaling

We’ll perform the dot-product scaling of S by multiplying each element by 1/√d_head.

softmax.cuh
   // ...
const float softmax_scale = rsqrt(static_cast<float>(CFG.d_head));
   // ...
 
template <int QO_fragments, int KV_accum_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void
scale_S_accum(accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
              const accum_t &softmax_scale) {
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        #pragma unroll
        for (int k = 0; k < KV_accum_fragments; ++k) {
            S_accum[q][k] *= softmax_scale;
        }
    }
}

Reductions using Warp Shuffles

Since each thread stores different parts of the fragments, we use warp shuffles to avoid using shared memory to communicate between threads when performing reductions. For each shuffle instruction, a thread can receive and send a value to another thread in the same warp.

Since we want every thread to have the same final value at the end of each reduction, the warp shuffle we use is __shfl_xor_sync(WARP_MASK, val_to_share, xor_offset).

  • this will send the value val_to_share
  • it will read the value from lane tid ^ xor_offset, where tid is the calling thread’s ID

We reduce over the values within a thread and then warp shuffle twice to reduce between threads in the thread row quartet.

This is the code for the row max reduction:

softmax.cuh
// This mask indicates that every thread in the warp participates in the shuffle
#define SHFL_ENTIRE_WARP_MASK 0xffffffff
 
template <int QO_fragments, int KV_accum_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void
calc_row_max(accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
             accum_t (&m_next)[QO_fragments], accum_t (&m_cur)[QO_fragments]) {
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        m_next[q] = m_cur[q];
 
        // Calculate max for row across all in-thread registers.
        #pragma unroll
        for (int k = 0; k < KV_accum_fragments; ++k) {
            m_next[q] = max(m_next[q], S_accum[q][k]);
        }
 
        // Group reduction
        m_next[q] = max(__shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, m_next[q], 2),
                        m_next[q]);
        m_next[q] = max(__shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, m_next[q], 1),
                        m_next[q]);
    }
}

In the reduction, after each call of __shfl_xor_sync(), each thread doubles the amount of information it has.

  • First xor with offset 2: each thread now holds the max over a pair of lanes
  • Second xor with offset 1: each thread now holds the max over its full four-lane quartet

For example, if the four lanes responsible for a row start with {5, 9, 2, 7}, the offset-2 shuffle leaves them with {5, 9, 5, 9}, and the offset-1 shuffle leaves all four with 9.

l and O Rescaling

This function scales the previous iteration’s accumulators for l and O to the current block’s row max:

softmax.cuh
template <int QO_fragments, int d_head_accum_fragments,
          typename accum_t = float>
__forceinline__ __device__ constexpr void
scale_l_O(accum_t (&m_next)[QO_fragments], accum_t (&m_cur)[QO_fragments],
          accum_t (&l)[QO_fragments],
          accum_t (&O_accum)[QO_fragments][d_head_accum_fragments]) {
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        accum_t scale;
        scale = expf(m_cur[q] - m_next[q]);
        m_cur[q] = m_next[q];
        l[q] *= scale;
        for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
            O_accum[q][d_head] *= scale;
        }
    }
}

Exponentiating

softmax.cuh
template <int QO_fragments, int KV_accum_fragments,
          typename accum_t = float>
__forceinline__ __device__ constexpr void
scaled_exp(accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
           accum_t (&m)[QO_fragments]) {
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        #pragma unroll
        for (int k = 0; k < KV_accum_fragments; ++k) {
            S_accum[q][k] = expf(S_accum[q][k] - m[q]);
        }
    }
}

Partial Reduction

We row-wise sum the local block of P into l. Since l isn’t actually used until the end to normalize O, we don’t need to warp shuffle until we reach that point. Since warp shuffles are a relatively expensive operation, we avoid doing so here. This is pretty much the same code as calc_row_max, except that we sum instead of taking the max and we skip the shuffles.

Note: P_accum has the same underlying storage as S_accum.

softmax.cuh
template <int QO_fragments, int d_head_accum_fragments,
          typename accum_t = float>
__forceinline__ __device__ constexpr void
update_row_exp_sum(accum_t (&P_accum)[QO_fragments][d_head_accum_fragments],
                   accum_t (&l)[QO_fragments]) {
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        #pragma unroll
        for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
            l[q] += P_accum[q][d_head];
        }
    }
}

Softmax Epilogue

In the epilogue (after the mainloop), we shuffle the l values between threads to get the final value of l and normalize O:

softmax.cuh
template <int QO_fragments, int d_head_accum_fragments,
          typename accum_t = float>
__forceinline__ __device__ constexpr void final_softmax_normalization(
    accum_t (&O_accum)[QO_fragments][d_head_accum_fragments],
    accum_t (&l)[QO_fragments]) {
    // Finish summing row_sums across all threads in the same row.
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        l[q] += __shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, l[q], 2);
        l[q] += __shfl_xor_sync(SHFL_ENTIRE_WARP_MASK, l[q], 1);
    }
 
    // Final row-wise O softmax normalization.
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        #pragma unroll
        for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
            O_accum[q][d_head] /= l[q];
        }
    }
}

Synchronization

We’ve covered data movement and computation, but there’s one more critical piece: making sure threads don’t step on each other. Without proper barriers, we’d have race conditions where threads read stale data or overwrite data that other threads still need.

Required Barriers

Let’s think about what minimal amount of synchronization we need for a single tensor to avoid race conditions. Let’s assume for now that each block gets its own slice of SMEM.

At the warp level,

  • each GMEM ↔ SMEM operation transfers a (16, 128) tile of elements,
  • each SMEM → RF operation copies a (16, 128) tile (for Q and O) or a (64, 128) tile (for K and V) of elements, and
  • each RF → SMEM operation copies a (16, 128) tile of elements.

The critical synchronization points occur between these memory transfers. Between GMEM → SMEM and SMEM → RF operations, threads in the same warp will access values that sibling threads store. This becomes more complex for K and V, where warps will also access values that other sibling warps store.

To handle these dependencies, we need different barrier types depending on the tensor. We require a warp-wide barrier (__syncwarp()) for Q1 and a CTA-wide barrier (__syncthreads()) for K and V.

Additionally, before any barrier, we need to call cp_async_wait() to ensure the data from the async copy operations has actually arrived in SMEM.

Finally, between RF → SMEM and SMEM → GMEM operations, threads in the same warp will access values that sibling threads store (O), so we also need a __syncwarp() for the output tensor.

SMEM to RF Communications

Once the tensors are finally in SMEM, threads will only communicate with other threads in the same warp.

The instructions executed are either warp synchronous (like ldmatrix, mma, and warp shuffle) or don’t involve communication between other threads (like softmax operations excluding warp shuffle).

Synchronization Between Loop Iterations

Since we’ll copy multiple tiles of K and V into the same locations in SMEM, we need to ensure all ldmatrix() operations for all warps have finished executing before overwriting any tensor in SMEM. So, we’ll need a __syncthreads() barrier in between iterations.

Summary

| Tensor | Barrier Needed In Between GMEM → SMEM & SMEM → RF | Barrier Needed In Between Iterations |
|--------|---------------------------------------------------|--------------------------------------|
| Q      | __syncwarp()                                      | N/A, only copied once                |
| K      | __syncthreads()                                   | __syncthreads()                      |
| V      | __syncthreads()                                   | __syncthreads()                      |
| O      | __syncwarp() (after RF → SMEM)                    | N/A, only copied once                |

Note: Synchronization Scope for Q and O

Specifying a scope at a finer granularity here doesn’t actually provide us much of a benefit in our case for 2 reasons:

  1. At the moment, Q and O are only copied to/from SMEM once, whereas K and V get copied multiple times.
  2. In the first iteration, we are bottlenecked by the __syncthreads() needed for K anyway.

In kernel 8, we will modify the layout in a way where we will also need CTA-wide synchronization.

Kernel 1: Base Implementation

Now that we’ve covered all the building blocks - data movement, GEMM operations, softmax computations, and synchronization requirements - we can assemble them into our first complete flash attention kernel.

Code Structure

We’ll divide the kernel into 3 sections following standard Cutlass terminology:

  1. Prologue: This includes boilerplate setup like initializing the correct memory addresses, and copying from GMEM → SMEM.

  2. Mainloop: This is where most of the logic goes, handling the iterative attention computation.

  3. Epilogue: We’ll end with normalizing softmax and writing back to GMEM.

Prologue

The prologue contains a lot of boilerplate setup.

forward_kernel.cuh
template <typename Kernel>
__global__ void
flash_forward_kernel(__grid_constant__ const FAForwardArgs args) {
    using accum_t = float;
    using index_t = int64_t;
    using value_t = typename Kernel::value_t;
 
	// ...
    // We initialize a CTA for each sample, seq tile, and head.
    const int sample = blockIdx.z;
    const int head = blockIdx.y;
    const int q_seq_block = blockIdx.x;
 
    const index_t gmem_seq_stride = args.seq_stride;
 
    const index_t sample_head_offset =
        sample * args.batch_stride + head * args.head_stride;
    // We only read/write one block for Q and O.
    // These offsets are the same for the whole thread-block.
    const index_t QO_gmem_block_offset =
        sample_head_offset + q_seq_block * CFG.B_r * gmem_seq_stride;
    // We read the entire key sequence.
    const index_t KV_gmem_block_offset = sample_head_offset;
 
    value_t *gmem_Q = &static_cast<value_t *>(args.Q)[QO_gmem_block_offset];
    value_t *gmem_O = &static_cast<value_t *>(args.O)[QO_gmem_block_offset];
    value_t *gmem_K = &static_cast<value_t *>(args.K)[KV_gmem_block_offset];
    value_t *gmem_V = &static_cast<value_t *>(args.V)[KV_gmem_block_offset];
 
    extern __shared__ __align__(16) char ch_smem[];
    value_t *smem_Q = reinterpret_cast<value_t *>(ch_smem);
    value_t *smem_O = smem_Q;  // O reuses Q's slice of SMEM
    value_t *smem_K = smem_Q + CFG.B_r * CFG.d_head;
    value_t *smem_V = smem_K + CFG.B_c * CFG.d_head;
 
	// MatrixLDST types
    Q_t Q(gmem_Q, gmem_seq_stride, smem_Q);
    K_t K(gmem_K, gmem_seq_stride, smem_K);
    V_t V(gmem_V, gmem_seq_stride, smem_V);
    // S is only stored in registers.
    S_accum_t S_accum(nullptr, -1, nullptr);
    // P is only stored in registers.
    P_value_t P_b16(nullptr, -1, nullptr);
    // The accumulator for O is only kept in registers. At the end of the kernel, it is then converted into a 16-bit type and then copied into gmem.
    O_accum_t O_accum(nullptr, -1, nullptr);
    O_value_t O_b16(gmem_O, gmem_seq_stride, smem_O);
 
	// ...
 
    // Start the async copy of the Q tile.
    Q.copy_GM2SM();
    cp_async_commit();
    O_accum.zero();
 
    // Initialize softmax_scale, m, and l.
    const accum_t softmax_scale = rsqrt(static_cast<accum_t>(CFG.d_head));
    constexpr accum_t neg_inf = -cuda::std::numeric_limits<float>::infinity();
    accum_t m[N::QO_fragments_per_warp];
    accum_t l[N::QO_fragments_per_warp];
    #pragma unroll
    for (int q = 0; q < N::QO_fragments_per_warp; ++q) {
        m[q] = neg_inf;
        l[q] = 0.0;
    }
 
	cp_async_wait<0>();
	__syncwarp();
	Q.copy_SM2RF();
	
	// ...

Mainloop

The mainloop implements the heart of the algorithm, which involves

  1. copying K from GMEM → SMEM → RF
  2. computing S = QKᵀ
  3. computing the online softmax and rescaling l and O
  4. copying V from GMEM → SMEM → RF
  5. computing O += PV

The barriers are doing double duty here.

  • Barrier 1 ensures that the GMEM → SMEM copy of K is complete before any warp starts the SMEM → RF copy. At the same time, it ensures the SMEM → RF copy of V from the previous iteration is complete before its space in SMEM is overwritten by the new V tile.
  • Barrier 2 does the same for the opposite tensors: it ensures V is fully in SMEM before being read into the RF, and that K is fully in RF before its SMEM slice is overwritten by the next K tile.
forward_kernel.cuh
	// ...
	
	for (int j = 0; j < args.n_KV_blocks; ++j) {
		K.copy_GM2SM();
		K.advance_gmem_block();
		cp_async_commit();
        S_accum.zero();
        cp_async_wait<0>();
        __syncthreads(); // <---- Barrier 1
        
		K.copy_SM2RF();
 
        matmul<typename Kernel::S_QK_GEMM>(Q, K, S_accum);
 
        // Online softmax
        accum_t m_next[N::QO_fragments_per_warp];
		scale_S_accum(S_accum.data(), softmax_scale);
        calc_row_max(S_accum.data(), m_next, m);
        scale_l_O(m_next, m, l, O_accum.data());
        scaled_exp(S_accum.data(), m_next);
        update_row_exp_sum(S_accum.data(), l);
 
        // Convert the S accumulator block into P bf16/fp16 input block.
        convert_to_16_bit_dtype<value_t>(S_accum.data(), P_b16.data());
 
		V.copy_GM2SM();
		V.advance_gmem_block();
		cp_async_commit();
		cp_async_wait<0>();
		__syncthreads(); // <---- Barrier 2
		V.copy_SM2RF();
 
        matmul<typename Kernel::O_PV_GEMM>(P_b16, V, O_accum);
    }
    
	// ...

Epilogue

The epilogue handles the final steps: normalizing the output with the final softmax sum l, converting O from fp32 to the 16-bit output data type, and writing it back RF → SMEM → GMEM.

forward_kernel.cuh
    // ...
    
    final_softmax_normalization(O_accum.data(), l);
 
    convert_to_16_bit_dtype<value_t>(O_accum.data(), O_b16.data());
 
    O_b16.copy_RF2SM();
 
    __syncwarp();
 
    // Copy the final O tile from SMEM to GMEM.
    O_b16.copy_SM2GM();
}

Occupancy

Our block configuration is (B_r, B_c, d_head) = (64, 64, 128), and we decided to use 4 warps for each CTA for a total of 128 threads. Let’s determine the number of warps that can be active on an SM at once.

Threads per CTA

SM_86 supports up to 1536 resident threads per SM. With 128 threads per CTA, we can have up to 1536 / 128 = 12 CTAs per SM. This shouldn’t be a limiting factor.

Registers per Thread

On an RTX 3090 (SM_86), we have access to 65,536 registers per SM. Each thread can use up to 255 registers; anything needed above that gets spilled to local memory. You might think that we’re nowhere close to using 255 (and you’d be right in many cases), but matrix multiplication can be very register heavy.

To get a sense of how heavy, let’s take a look at the RF storage required for each tensor.

| Tensor  | element size (bytes) | mma matrix variable | Storage Shape (per warp) | Register Count (per thread) |
|---------|----------------------|---------------------|--------------------------|-----------------------------|
| Q       | 2                    | A                   | (16, 128)                | 32                          |
| O_accum | 4                    | C/D                 | (16, 128)                | 64                          |
| K       | 2                    | B                   | (64, 128)                | 128                         |
| V       | 2                    | B                   | (64, 128)                | 128                         |
| S_accum | 4                    | C/D                 | (16, 64)                 | 32                          |
| P       | 2                    | A                   | (16, 64)                 | 16                          |
| m       | 4                    | N/A                 | (16,)                    | 2                           |
| l       | 4                    | N/A                 | (16,)                    | 2                           |

That’s a total of 404 registers! And that doesn’t even include registers used for other purposes, like address calculations for memory accesses. Thankfully we don’t have to store all of these at the same time. In fact, the only values that’ll persist in their entirety across every mainloop iteration (barring register spills) are Q, O_accum, m, and l. This sums up to 100 registers. The rest will be loaded on demand.

After compiling it down, the kernel will use a maximum of 202 registers per thread. This limits us to 2 CTAs, or 8 warps, per SM.

12 Warps per SM

The threshold for hitting 12 warps per SM is 170 registers. nvcc has the option to artificially limit the number of registers per thread, but going from 202 to 170 would incur significant register spilling. This is a classic performance trade-off: limiting registers can increase occupancy (more concurrent warps), which helps hide memory latency, but it can also decrease single-thread performance if the compiler is forced to spill registers to slow local memory. For a compute-bound kernel like ours, the cost of spilling outweighs the benefit of higher occupancy. We’ll revisit register pressure in kernel 3.

Shared Memory per CTA

SM_86 supports up to 99KiB of SMEM per CTA. For our block configuration of (B_r, B_c, d_head) = (64, 64, 128):

  • Q and O take up 64 × 128 × 2 bytes = 16KiB each
  • K and V take up 64 × 128 × 2 bytes = 16KiB each

Giving every tensor their own slice of SMEM would result in 64KiB, but to reach our goal of 2 CTAs per SM, we need to reduce that to 48KiB.

Fortunately, we can reduce SMEM usage by recognizing that Q and O accesses never overlap, allowing them to share the same memory space. This gets us right to the threshold of 48KiB — perfect!

More generally, the SMEM needed per CTA is closely tied to how we order and synchronize accesses to the various slices of SMEM. While I won’t dive deep here, other tensors can also share their slice of SMEM with additional synchronization barriers. For example, Q can share the same slice with K, and V could potentially share as well, since we’re currently holding the entirety of each block in the RF.

Summary

With 202 registers per thread and 48KiB of SMEM per CTA, our kernel will be able to have 8 resident warps on each SM.

Performance and What’s Next

Phew, that was a lot of buildup. Now, the moment you’ve been waiting for: how fast is the kernel?

We’re hitting 16.83 TFLOPS compared to the reference’s 34.04 TFLOPS - about 49% of reference performance. For a first kernel with no optimizations, this is pretty decent! We’ve got all the core functionality working correctly, and we’re already approaching half the performance of a highly optimized implementation. But there’s clearly room for improvement.

In the next part, we’ll start diagnosing our kernel and iteratively improve it with techniques like swizzling, double buffering, and instruction fusion to close the performance gap on the RTX 3090.

Footnotes

  1. This is necessary since Volta because we can no longer make assumptions about when threads in a warp reconverge. See Using CUDA Warp-Level Primitives | NVIDIA Technical Blog for more details. Having said that, in many cases the compiler will not actually emit a corresponding SASS instruction for __syncwarp() calls.