In this 7-part series, we’re going to implement Flash Attention 2 from scratch on Ampere GPUs. We’ll build an initial implementation and optimize it over 16 kernel iterations, all without importing any external libraries. By the final kernel, we’ll reach 99.2% of the performance of the official implementation on the A100 and 102.9% on the RTX 3090 (at sequence length 4096).
You can find the code here.
Prerequisites
Coming in, you should have a fair bit of experience with CUDA programming and be familiar with flash attention. This isn’t an exhaustive list, but should give a decent idea of what you should be familiar with.
CUDA Programming (Intermediate Level):
- How warps are scheduled and executed
- The GPU memory hierarchy (registers, shared memory, L1/L2 cache, DRAM)
- Occupancy
- Tiling
- Parallel reductions
- What bank conflicts are and why they matter
- What row and column major are
Deep Learning Concepts:
- Self-attention mechanisms
- Flash Attention algorithm
If you need a refresher on Flash Attention, check out the Flash Attention 2 paper or this ELI5 blog post.
Why Flash Attention?
Flash Attention is worth implementing from scratch for two compelling reasons:
- It’s one of the most impactful innovations in ML engineering. Attention scales quadratically in compute and memory with sequence length, and this bottleneck becomes increasingly critical as demand for longer sequences grows. Flash Attention made huge leaps toward addressing it by:
  - Completely solving the memory scaling issue by reducing memory complexity from $O(N^2)$ to $O(N)$
  - Mitigating the compute issue by significantly increasing data reuse in fast, on-chip memory
- Flash Attention is also a complex algorithm with unique optimization challenges beyond standard GEMM kernels. It combines two back-to-back GEMMs with additional state management and a non-trivial FP workload (see the sketch after this list). This complexity makes it an excellent advanced GPU programming exercise.
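To make the “two back-to-back GEMMs plus state management” point concrete, here’s a minimal CPU-side sketch of the non-causal forward pass for a single head, streaming over K/V tiles with an online softmax. The function name and scalar loops are just for illustration, and it only tiles along the key/value dimension; the real kernels also tile queries and run everything on tensor cores.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Q, K, V, O: [N, d] row-major, single head. Bc = key/value tile size.
void flash_attention_forward_ref(const float* Q, const float* K, const float* V,
                                 float* O, int N, int d, int Bc) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    for (int i = 0; i < N; ++i) {                            // each query row
        std::vector<float> acc(d, 0.0f);                     // unnormalized output row
        float m = -std::numeric_limits<float>::infinity();   // running max
        float l = 0.0f;                                      // running softmax denominator
        for (int j0 = 0; j0 < N; j0 += Bc) {                 // stream over K/V tiles
            int jEnd = std::min(j0 + Bc, N);
            // First GEMM: S = Q K^T (restricted to this tile), track the new max
            std::vector<float> s(jEnd - j0);
            float m_tile = m;
            for (int j = j0; j < jEnd; ++j) {
                float dot = 0.0f;
                for (int k = 0; k < d; ++k) dot += Q[i * d + k] * K[j * d + k];
                s[j - j0] = dot * scale;
                m_tile = std::max(m_tile, s[j - j0]);
            }
            // State management: rescale previous accumulator/denominator to the new max
            float correction = std::exp(m - m_tile);
            l *= correction;
            for (int k = 0; k < d; ++k) acc[k] *= correction;
            // Second GEMM: P = exp(S - m), then O += P V
            for (int j = j0; j < jEnd; ++j) {
                float p = std::exp(s[j - j0] - m_tile);
                l += p;
                for (int k = 0; k < d; ++k) acc[k] += p * V[j * d + k];
            }
            m = m_tile;
        }
        for (int k = 0; k < d; ++k) O[i * d + k] = acc[k] / l;  // deferred normalization
    }
}
```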
Why Ampere?
Ampere is the most recent generation where consumer and HPC GPUs share the exact same CUDA APIs. Newer generations have diverged significantly (Hopper vs. Ada, and again with Blackwell): HPC accelerators get exclusive tensor pipeline features and APIs that aren’t available on consumer cards.
This makes Ampere an interesting target for GPU optimization because we can write identical code and compare performance across very different hardware configurations on level ground. Ampere’s HPC and consumer lineups have dramatically different performance characteristics:
- A100 (HPC): 312 tensor TFLOPs / 19.5 FP32 TFLOPs peak throughput
- RTX 3090 (Consumer): 71 tensor TFLOPs / 35.6 FP32 TFLOPs peak throughput
We’ll see how this performance disparity plays out as we make progress through the different kernels.
Since we’re targeting Ampere, we’ll exclude newer Hopper/Blackwell features like:
- TMA-related instructions (`cp.async.bulk`)
- async warp group MMAs (`wgmma`) (Hopper)
- 5th gen tensor operations (Blackwell)
Kernel Specification
To keep this series focused and manageable, we’ll narrow our focus to a well-defined slice of features:
- forward pass only
- non-causal attention
- head dimension = 128
- no dropout or KV caching
- equal query/key/value sequence lengths
- sequence lengths divisible by block sizes (typically 64-128 in our implementation, as defined in the paper)
- 16-bit (bf16/fp16) input and output tensors, softmax calculation in fp32
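To give a rough sense of what this spec pins down, a kernel interface consistent with it might look like the sketch below. The names, template parameters, and layout here are hypothetical, not the actual interface we’ll build in later parts.

```cuda
#include <cuda_bf16.h>

// Hypothetical interface reflecting the spec: bf16 in/out (an fp16 variant would
// mirror it), fp32 softmax state kept internally, head dimension fixed at 128,
// non-causal, no dropout / KV cache, equal sequence lengths for Q/K/V.
constexpr int kHeadDim = 128;

// Q, K, V, O: [batch, num_heads, seq_len, kHeadDim], row-major.
// BlockM / BlockN are the query / key-value tile sizes (e.g. 64 or 128);
// seq_len is assumed to be divisible by both.
template <int BlockM, int BlockN>
__global__ void flash_attention_fwd(const __nv_bfloat16* Q,
                                    const __nv_bfloat16* K,
                                    const __nv_bfloat16* V,
                                    __nv_bfloat16* O,
                                    int batch, int num_heads, int seq_len,
                                    float softmax_scale);

// A typical launch shape: one CTA per (query tile, head, batch) triple, e.g.
//   dim3 grid(seq_len / BlockM, num_heads, batch);
//   flash_attention_fwd<128, 64><<<grid, block>>>(...);
```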
Full Overview
Over seven blog posts, we’ll implement Flash Attention 2 from scratch, starting with a basic kernel that achieves just 49.5% of reference performance on the RTX 3090 and systematically optimizing it across 16 kernel iterations.
The story has an interesting twist: we’ll develop kernels 1 → 7 primarily on an RTX 3090. By the end of kernel 7, our kernel will perform at around 101% of the reference implementation on the RTX 3090. However, running the same kernel on an A100 shows a huge drop to ~80%. To close this gap, we’ll need to explore advanced optimization techniques beyond standard GEMM approaches, such as low-level assembly analysis. This will lead us through kernels 8 → 16, where we’ll land within 1% of the reference implementation with kernel 16.
- Intro and Overview (this post)
- Building Blocks
  - We’ll cover the key CUDA operations and primitives we’ll use, including tensor core operations (`mma`) and memory transfers (`cp.async` & `ldmatrix`); a small preview sketch follows this outline.
  - These are the most performant instructions available on the Ampere architecture.
- RTX 3090: Base Implementation
  - We’ll implement a basic kernel using the instructions covered in the previous part and get a working version running on an RTX 3090.
- RTX 3090: Kernels 2 to 7
  - We’ll build on top of the base implementation by iteratively optimizing on an RTX 3090, using techniques from CUTLASS and various other methods.
  - By the end, we’ll surpass reference performance by ~1%. However, when we compare our kernel with the reference on an A100, we’ll find significant performance gaps.
- A100: Kernels 8 to 11
  - We’ll start by analyzing the cause of the significant performance gap between the RTX 3090 and the A100.
  - Then, we’ll optimize our kernel for the A100, primarily by reducing instruction count. This will involve a lot of low-level assembly analysis, so you’ll want to get comfortable reading SASS code.
  - We’ll also start taking a closer look at the Ampere microarchitecture.
- A100: Kernels 12 to 16
  - We’ll make final optimizations to reach within 1% of reference performance.
- A100: Kernel Analysis
  - We’ll do a deep-dive analysis of the final kernel (kernel 16) and kernel 10, guided by previous research on the Ampere microarchitecture.
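As a small preview of those Part 2 primitives, here are minimal inline-PTX wrappers around them. The wrapper names and signatures are ours; the instructions themselves (`cp.async`, `ldmatrix`, `mma`) are the real Ampere (sm_80) ones, and the fragment layouts and swizzling they require are exactly what Part 2 covers.

```cuda
#include <cstdint>

// 16-byte asynchronous copy from global to shared memory (cp.async).
// Pairs with cp.async.commit_group / cp.async.wait_group for synchronization.
__device__ inline void cp_async_16B(void* smem_dst, const void* gmem_src) {
    uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
                 :: "r"(dst), "l"(gmem_src));
}

// Load four 8x8 tiles of 16-bit elements from shared memory into registers (ldmatrix).
__device__ inline void ldmatrix_x4(uint32_t (&r)[4], const void* smem_src) {
    uint32_t src = static_cast<uint32_t>(__cvta_generic_to_shared(smem_src));
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
                 : "r"(src));
}

// One m16n8k16 tensor core matrix-multiply-accumulate: bf16 inputs, fp32 accumulators (mma).
__device__ inline void mma_m16n8k16_bf16(float (&d)[4], const uint32_t (&a)[4],
                                         const uint32_t (&b)[2]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
        : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
}
```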
Kernels
Foundation, CUTLASS Optimizations, and FP Fusion (Kernels 1-7)
- Base Implementation
- Avoiding Bank Conflicts
- Eagerly Loading K & V Blocks
- Interleaving On-Chip LD/ST with Computation
- Double Buffering Shared Memory to Register File Loads
- Improving FP32 Throughput *
- Auto-Tuning
A100-Focused Instruction-Level Optimizations (Kernels 8-11)
- Reducing `IADD3`, `LOP3`, and `SHF` instructions
- Reducing `IMAD.MOV.U32` and `MOV` instructions
- Removing `CSRZ` Instructions + Optimizing Initial Softmax Iteration
- Encoded Swizzling from the RF (register file) to SMEM
A100 Final Tuning (Kernels 12-16)
- Miscellaneous Code Changes
- Iterating Backwards *
- Cache Configuration
- Tiling along `d_head` *
- Static GMEM Stride
*Optimizations inspired by the official implementation
Performance
For GEMM kernels, the standard reference for comparison is cuBLAS. For Flash Attention, we’ll benchmark against the official implementation.
We’ll focus on sequence length 4096 for most of our benchmarking. While this is longer than typical for bidirectional transformers, it strikes a good balance: shorter sequences don’t fully utilize the GPU and achieve lower TFLOPs, while longer sequences take too long to benchmark iteratively but achieve higher TFLOPs.
You can find more details on benchmarking setup and methodology in the appendix here.
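For context on the TFLOPs figures in the table below: they’re derived from the standard FLOP count for a non-causal attention forward pass, which only counts the two large matmuls (2·N²·d FLOPs for Q·Kᵀ plus 2·N²·d for P·V, per head and per batch). The helper below is an illustrative sketch, not our benchmark harness; the shape and runtime in the example are placeholders.

```cpp
#include <cstdio>

double attention_fwd_tflops(int batch, int heads, int seq_len, int head_dim,
                            double runtime_ms) {
    double flops = 4.0 * batch * heads
                 * static_cast<double>(seq_len) * seq_len * head_dim;
    return flops / (runtime_ms * 1e-3) / 1e12;  // FLOPs per second -> TFLOPs
}

int main() {
    // Placeholder shape and runtime (not our benchmark config), at seq_len = 4096,
    // d_head = 128:
    std::printf("%.1f TFLOPs\n", attention_fwd_tflops(4, 32, 4096, 128, 10.0));
    return 0;
}
```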
Table (TFLOPs Relative to Reference)
| Kernel Iteration | A100 (seq_len = 4096) | A100 (harm. mean) | RTX 3090 (seq_len = 4096) | RTX 3090 (harm. mean) |
|---|---|---|---|---|
| 1. Base Implementation | 15.8% | 16.6% | 49.5% | 49.8% |
| 2. Avoiding Bank Conflicts | 72.6% | 72.4% | 98.3% | 98.6% |
| 3. Eagerly Loading K & V Blocks | 77.6% | 79.9% | 99.4% | 100.0% |
| 4. Interleaving On-Chip LD/ST with Computation | 77.6% | 80.0% | 100.0% | 100.4% |
| 5. Double Buffering Shared Memory to Register File Loads | 76.8% | 79.1% | 99.7% | 100.3% |
| 6. Improving FP32 Throughput | 78.1% | 80.4% | 99.9% | 100.4% |
| 7. Auto-Tuning | 80.3% | 82.3% | 101.5% | 101.8% |
| 8. Reducing `IADD3`, `LOP3`, and `SHF` instructions | 87.8% | 88.9% | 101.7% | 101.2% |
| 9. Reducing `IMAD.MOV.U32` and `MOV` instructions | 95.3% | 96.3% | 97.5% | 97.4% |
| 10. Removing `CSRZ` Instructions + Optimizing Initial Softmax Iteration | 93.9% | 95.0% | 102.9% | 102.3% |
| 11. Encoded Swizzling from the RF to SMEM | 95.2% | 96.7% | 102.8% | 102.3% |
| 12. Miscellaneous Code Changes | 95.3% | 97.0% | 102.8% | 102.3% |
| 13. Iterating Backwards | 97.6% | 98.8% | 101.5% | 101.2% |
| 14. Cache Configuration | 97.7% | 99.1% | 101.5% | 101.2% |
| 15. Tiling along `d_head` | 97.9% | 99.5% | 101.5% | 101.3% |
| 16. Static GMEM Stride | 99.2% | 100.4% | 100.9% | 100.7% |
| 0. Reference (TFLOPs) | 94.28 | 88.01 | 34.04 | 33.48 |
The harmonic mean is taken over sequence lengths 512, 1024, 2048, 4096, 8192, 16384.
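For completeness, a tiny sketch of that harmonic mean (the sample values below are made up):

```cpp
#include <cstdio>
#include <vector>

double harmonic_mean(const std::vector<double>& xs) {
    double inv_sum = 0.0;
    for (double x : xs) inv_sum += 1.0 / x;
    return static_cast<double>(xs.size()) / inv_sum;
}

int main() {
    // one relative-performance value per seq_len in {512, 1024, ..., 16384}
    std::printf("%.3f\n", harmonic_mean({0.98, 0.99, 1.00, 1.01, 1.00, 0.99}));
    return 0;
}
```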
Notation & Terminology
I’ll mostly stick to standard CUDA terminology and the notation from the Flash Attention papers. I’ll include the terms and acronyms I use most frequently here, but you can find a full reference in the Glossary.
- GMEM: global memory
- SMEM: shared memory
- LMEM: local memory
- RF: register file / register memory
- $Q_i$, $O_i$: query and output tensors handled by the current CTA
- $K_j$, $V_j$: the $j$-th key/value tile
- $B_r$: query rows in the $Q_i$ block
- $B_c$: key and value rows in the $K_j$ and $V_j$ blocks
Up Next
In Part 2, we’ll dive into the CUDA building blocks that enable high-performance Flash Attention on Ampere architecture: tensor core operations, asynchronous memory transfers, shared memory banking strategies, and warp-level programming patterns. These primitives will form the foundation for all our subsequent kernel optimizations.