Real kernels

You have the machine. This is the software that squeezes it. Three famous kernels, each just coalescing, tiling, and tensor cores aimed at a real piece of a model.

From the machine to the model

Every trick on the index was general: coalesce the loads, tile for reuse, pipeline the copies, shrink the numbers. Real kernels are those tricks aimed at one operation a model actually runs. Nothing new to learn, just the same moves put to real work.

And they all start with the same question from the roofline: is this operation memory-bound or compute-bound? The answer decides which trick matters. Attention (the operation that lets each word weigh every other word) starts memory-bound, so the win is moving fewer bytes. A big matmul is compute-bound, so the win is feeding the tensor cores in the lowest precision that still works.
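To make the roofline question concrete, here is a minimal sketch that computes arithmetic intensity (FLOPs per byte) and compares it to the ridge point where the bandwidth roof meets the compute roof. The peak numbers are round, hypothetical figures chosen only for illustration, not the specs of any particular GPU, and the function names are invented for this example.

```python
# Hypothetical accelerator: round numbers assumed purely for illustration.
PEAK_FLOPS = 1000e12       # 1000 TFLOP/s of tensor-core math (assumed)
PEAK_BANDWIDTH = 3e12      # 3 TB/s of HBM bandwidth (assumed)
RIDGE = PEAK_FLOPS / PEAK_BANDWIDTH   # FLOPs per byte where the two roofs meet


def classify(flops, bytes_moved):
    """Return (arithmetic intensity, which roof the op sits under)."""
    intensity = flops / bytes_moved
    bound = "compute-bound" if intensity >= RIDGE else "memory-bound"
    return intensity, bound


def square_matmul(n, bytes_per_element=2):
    """C = A @ B with n-by-n matrices, assuming each matrix crosses HBM once."""
    flops = 2 * n**3                           # one multiply and one add per term
    bytes_moved = 3 * n * n * bytes_per_element
    return classify(flops, bytes_moved)


def elementwise_add(n, bytes_per_element=2):
    """c = a + b over n elements: read two values, write one, one FLOP each."""
    return classify(n, 3 * n * bytes_per_element)


print("ridge point:", round(RIDGE, 1), "FLOP/byte")
print("matmul 8192:", square_matmul(8192))
print("elementwise:", elementwise_add(1 << 20))
```

Under these assumptions the large matmul lands far to the right of the ridge and the elementwise add far to the left, which is why the two get different treatment.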

Start with attention, the kernel that made this way of thinking famous.

Interactive scene: online softmax over 4 key tiles. Initial state, before any tile is processed: running max m = −∞, denominator l = 0.00, output O / l empty.

Online-softmax state after each key tile (4 keys per tile)

| Tile | Tile max | m | Rescale | l | Output |
|------|----------|-----|---------|-------|--------|
| 0 | 3.0 | 3.0 | 0.000 | 1.591 | 1.390 |
| 1 | 5.0 | 5.0 | 0.135 | 1.321 | 3.175 |
| 2 | 4.0 | 5.0 | 1.000 | 1.891 | 2.880 |
| 3 | 6.0 | 6.0 | 0.368 | 1.721 | 4.054 |

FlashAttention: never store the whole score matrix

Attention is how a model decides which earlier tokens (chunks of the input, roughly words) each token should look at. It scores every query (Q) against every key (K), turns those scores into weights with softmax (which squashes a row of numbers into positive fractions that add up to 1), then blends the values (V) by those weights. In symbols that is softmax(Q Kᵀ) V. The naive way builds the full score matrix first: for a sequence of length N, that is an N by N grid of numbers written to memory, softmaxed, then read back to multiply by V. For long sequences that matrix is enormous, and shuttling it in and out of memory is the whole cost. Attention is memory-bound.
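A quick back-of-the-envelope sketch shows how fast the score matrix outgrows everything else. The head dimension of 128 and 2-byte elements are assumptions for illustration, the sizes are per attention head, and the helper name is hypothetical.

```python
def attention_bytes(seq_len, head_dim=128, bytes_per_element=2):
    """Bytes for one head: the N-by-N score matrix versus Q, K, V, and the output."""
    score_matrix = seq_len * seq_len * bytes_per_element
    qkvo = 4 * seq_len * head_dim * bytes_per_element
    return score_matrix, qkvo


for n in (1024, 16384, 131072):
    scores, io = attention_bytes(n)
    print(f"N={n:>6}: score matrix {scores / 2**20:>6.0f} MiB, "
          f"Q+K+V+O {io / 2**20:>4.0f} MiB")
```

The inputs and output grow linearly with N while the score matrix grows with N squared, so at long sequence lengths nearly all of the memory traffic in the naive method is the score matrix itself.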

FlashAttention refuses to store it. It tiles over the keys (Phase 2) and keeps just three running quantities per query: the max m, the softmax denominator l, and the output accumulator O (a vector the size of one value row). Each key tile updates them. The only subtlety is the online softmax: when a tile brings a larger score, the earlier denominator and accumulator are rescaled by exp(m_old − m_new) so everything stays on one scale. You stepped through exactly that above.
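The update rule can be sketched in a few lines of plain Python for a single query. This is a scalar illustration, not the real kernel: values are single numbers rather than vectors, the tile size of 4 matches the scene, and the function names and sample data are made up for this example.

```python
import math


def naive_attention(scores, values):
    """Reference: softmax over all scores at once, then blend the values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def online_attention(scores, values, tile=4):
    """Same result, one key tile at a time, holding only m, l, and acc."""
    m = float("-inf")   # running max
    l = 0.0             # running softmax denominator
    acc = 0.0           # running (unnormalized) output accumulator
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        rescale = math.exp(m - m_new)      # 0.0 on the first tile, 1.0 if max unchanged
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * rescale + sum(p)
        acc = acc * rescale + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = m_new
    return acc / l


scores = [1.0, 3.0, 0.5, 2.0, 5.0, 4.0, 1.5, 0.0, 4.0, 2.5, 3.5, 1.0, 6.0, 2.0, 5.5, 3.0]
values = [float(i) for i in range(16)]
print(naive_attention(scores, values), online_attention(scores, values))
```

Both functions print the same number up to floating-point rounding, even though the online version never holds more than one tile of scores at a time.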

The score matrix never touches main memory. On real hardware the tiles stream in by TMA and the matmuls run on tensor cores while the next tile loads (warp specialization). Attention stops being memory-bound and starts keeping the tensor cores busy.

Quick check: FlashAttention does the same FLOPs as naive attention, sometimes more. Why is it faster?

FlashAttention 1 to 4

FlashAttention is not one kernel but a lineage. The math above never changes; each version re-choreographs it for the newest hardware and wrings out more overlap, exactly the "keep the tensor cores fed" story the hardware phases told.

  1. FlashAttention 1 (2022). The original: tile over the keys, keep the running max, denominator, and output, and never write the N by N score matrix. Attention stops being memory-bound. That is the scene you just stepped through.
  2. FlashAttention 2 (2023). Same math, better use of the GPU. It cuts the costly non-matmul work (the rescaling), parallelizes across the sequence dimension so long sequences fill the whole chip, and repartitions work across warps to cut shared-memory traffic. Roughly 2x faster than v1.
  3. FlashAttention 3 (2024). Rewritten for Hopper. It uses TMA to move tiles, warp specialization to split producers from consumers, and a ping-pong schedule that overlaps the softmax (on the special-function units) with the matmuls (on the tensor cores) so neither waits. Adds FP8 attention, and pushes the H100 far closer to peak.
  4. FlashAttention 4 (2025). The Blackwell-generation iteration, targeting the 5th-gen tensor cores, their single-thread async MMA (tcgen05), and FP4 microscaling. The same overlap-everything playbook, aimed at the newest machine.

The attention family

FlashAttention spawned an ecosystem. A few names you will meet, all built on the same tiled core:

  • SDPA (scaled_dot_product_attention). PyTorch's standard attention call, and a dispatcher: it picks a fused FlashAttention-style backend (flash, memory-efficient, or cuDNN) for your shapes and dtype automatically, so you rarely call the raw kernel yourself. PyTorch 2.10 added a variable-length (varlen) path, so a batch of different-length sequences skips the padding it used to waste, and 2.11 added native sliding-window attention.
  • FlexAttention. The escape hatch for custom attention. You supply two little Python functions: a score_mod that edits each score before the softmax (for relative-position bias, ALiBi, soft-capping) and a mask_mod that says which query-key pairs are even allowed (causal, sliding window, document or packed masking, prefix-LM). torch.compile fuses them into one FlashAttention-style kernel, forward and backward, and turns the mask into a block mask so fully-masked blocks are skipped instead of computed. Nearly any attention variant in a few lines, at fused-kernel speed, with no CUDA. (See the PyTorch team's FlexAttention blog post.)
  • SageAttention. Quantized attention: run the QKᵀ and PV matmuls in INT8 or FP4 on tensor cores, smoothing the few outlier channels first so the low precision does not wreck the result. It is the W4A4 idea aimed squarely at attention.
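The block-mask idea behind FlexAttention can be illustrated without PyTorch. The sketch below is not the library's implementation: it only borrows the mask_mod calling convention (batch, head, query index, key index) and uses a hypothetical helper, classify_blocks, to count which tiles a kernel could skip outright, which need per-element masking, and which need no masking at all.

```python
def sliding_window_causal(b, h, q_idx, kv_idx, window=8):
    """mask_mod-style rule: a query sees itself and the previous window-1 keys."""
    return kv_idx <= q_idx and q_idx - kv_idx < window


def classify_blocks(mask_mod, seq_len, block):
    """Count query-key tiles that are fully masked, partly masked, or fully allowed."""
    counts = {"skipped": 0, "partial": 0, "full": 0}
    n_blocks = seq_len // block
    for qb in range(n_blocks):
        for kb in range(n_blocks):
            allowed = sum(
                mask_mod(0, 0, q, k)
                for q in range(qb * block, (qb + 1) * block)
                for k in range(kb * block, (kb + 1) * block)
            )
            if allowed == 0:
                counts["skipped"] += 1      # never computed
            elif allowed == block * block:
                counts["full"] += 1         # computed with no masking
            else:
                counts["partial"] += 1      # computed, mask applied per element
    return counts


counts = classify_blocks(sliding_window_causal, seq_len=16, block=4)
print(counts)
```

For this toy shape, 7 of the 16 tiles are skipped entirely, and the skipped share grows as the sequence gets longer relative to the window.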

DeepGEMM: a matmul in FP8, safely

A transformer is mostly big matmuls, and matmul (a GEMM, general matrix multiply) is compute-bound, so the lever is precision. DeepGEMM (from DeepSeek) is a lean library that runs these matmuls in FP8 on Hopper, using the same tiling, TMA loads, and warpgroup tensor-core instructions you already met.

The catch is the accumulator, the running sum. Pick FP8 for it below and watch the dot product freeze far short of the truth; only a wide accumulator tracks the exact line.

Interactive scene: accumulator width. With the FP8 accumulator selected, accumulating 64 terms gives a running sum of 32.0 versus the exact 63.9, an error of 49.9 percent. Fast but swamps: small terms vanish once the sum grows.

Final sum and error by accumulator width, for 64 terms

| Accumulator | Mantissa bits | Final sum | Error |
|-------------|---------------|-----------|-------|
| FP8 | 3 | 32.0 | 49.9% |
| FP16 | 10 | 63.9 | 0.0% |
| FP32 | 23 | 63.9 | 0.0% |
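The swamping effect is easy to reproduce with a toy model of a narrow accumulator. The sketch below rounds the running sum to a given number of mantissa bits after every add; it ignores exponent range and saturation, uses round-to-nearest-even, and sums 64 terms of exactly 1.0, so its numbers will not match the scene above, which uses different terms.

```python
import math


def round_to_mantissa(x, mantissa_bits):
    """Round x to a float with this many explicit mantissa bits (range ignored)."""
    if x == 0.0:
        return 0.0
    significand, exponent = math.frexp(x)    # x = significand * 2**exponent
    digits = mantissa_bits + 1               # explicit bits plus the implicit leading 1
    return math.ldexp(round(significand * 2**digits), exponent - digits)


def accumulate(terms, mantissa_bits):
    """Sum the terms, rounding the running total after every addition."""
    total = 0.0
    for t in terms:
        total = round_to_mantissa(total + t, mantissa_bits)
    return total


terms = [1.0] * 64
for name, bits in (("3 mantissa bits (FP8-like)", 3), ("10 (FP16-like)", 10), ("23 (FP32-like)", 23)):
    print(f"{name:>28}: {accumulate(terms, bits)}")
```

With 3 mantissa bits the sum stalls at 16.0: the next representable values are 16 and 18, so 16 + 1 rounds straight back to 16 and every later term is lost. The wider accumulators reach the exact 64.0.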

The trick is staying accurate at 8 bits. DeepGEMM uses fine-grained scaling, a separate scale per small block of the matrix (the scaling ladder again), and it accumulates in FP32. The tensor cores multiply in FP8 for speed, but the running sum is promoted to full precision so rounding does not pile up over thousands of steps.
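Why a separate scale per block helps can be sketched with a simplified model of an E4M3-style format. Everything here is an illustration under stated assumptions: the rounding function is a toy model (3 mantissa bits, maximum 448, fixed step below 2**-6), the block size of 4 and the sample values are invented, and none of it is DeepGEMM's actual code.

```python
import math

E4M3_MAX = 448.0        # largest finite value in the modeled format
MIN_NORMAL_EXP = -6     # smallest normal value is 2**-6
MANTISSA_BITS = 3


def quantize_e4m3(x):
    """Round x to the nearest value of the simplified E4M3-style model."""
    if x == 0.0:
        return 0.0
    magnitude = min(abs(x), E4M3_MAX)              # saturate instead of overflowing
    exponent = math.frexp(magnitude)[1] - 1        # floor(log2(magnitude))
    exponent = max(exponent, MIN_NORMAL_EXP)       # below this the step stops shrinking
    step = 2.0 ** (exponent - MANTISSA_BITS)
    return math.copysign(round(magnitude / step) * step, x)


def quantize_blocks(values, block):
    """Quantize with one scale per block of `block` consecutive values."""
    out = []
    for start in range(0, len(values), block):
        chunk = values[start:start + block]
        scale = (max(abs(v) for v in chunk) / E4M3_MAX) or 1.0   # largest value maps to 448
        out.extend(quantize_e4m3(v / scale) * scale for v in chunk)
    return out


def mean_relative_error(exact, approx):
    return sum(abs(a - e) / abs(e) for e, a in zip(exact, approx)) / len(exact)


values = [0.001, 0.002, -0.0015, 0.0005, 3000.0, -1200.0, 800.0, 2500.0]
per_tensor = quantize_blocks(values, block=len(values))   # one scale for everything
per_block = quantize_blocks(values, block=4)              # one scale per 4 values
print("per-tensor error:", mean_relative_error(values, per_tensor))
print("per-block error: ", mean_relative_error(values, per_block))
```

With a single scale, the large values set the scale and the small ones round to zero. With a scale per block, each block uses the format's full range and the small values survive.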

Mixture of experts: many small matmuls

A mixture of experts (MoE) layer replaces one big feed-forward network with many smaller experts, and a router sends each token to just a few of them. The model gets far more parameters without doing more math per token, because most experts sit idle for any given token.

Route the tokens below. Try skewed routing, or push the experts to 8, and watch the tensor-core utilization fall as tiles go half-empty or one expert becomes a straggler.

Interactive scene: MoE routing, with controls for Experts, Routing, Top-k, and Capacity. 32 tokens, tensor cores work in tiles of 8. Aim for tall, full columns and no red.

Balanced routing, top-1, 4 experts: tensor-core utilization 100% (32 token-slots run to do 32 real ones), busiest expert (straggler) 8 tokens, padding waste 0 slots, tokens dropped 0.

Per-expert token counts, tiles, and padding (tile size 8)

| Expert | Kept tokens | Tiles | Padded slots | Dropped |
|--------|-------------|-------|--------------|---------|
| E0 | 8 | 1 | 0 | 0 |
| E1 | 8 | 1 | 0 | 0 |
| E2 | 8 | 1 | 0 | 0 |
| E3 | 8 | 1 | 0 | 0 |

On the GPU that turns into a grouped GEMM: instead of one large matmul, you run one matmul per expert, each on the handful of tokens routed to it. The challenge is imbalance. If everyone picks the same expert, one matmul is huge and the rest are empty, so the tensor cores stall. Libraries like ScatterMoE and SonicMoE pack the routed tokens tightly into contiguous groups so every expert's matmul is a full, efficient tile instead of a ragged one.
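The cost of imbalance can be put in numbers with the same bookkeeping the scene uses: each expert's tokens are rounded up to whole tiles of 8, and utilization is real tokens divided by padded slots. This is a counting sketch with a hypothetical helper name, not a grouped-GEMM implementation, and the skewed token counts are invented for illustration.

```python
import math


def grouped_gemm_stats(tokens_per_expert, tile=8):
    """Tiles, padding, utilization, and straggler size for one routed batch."""
    tiles = [math.ceil(t / tile) for t in tokens_per_expert]
    slots = sum(tiles) * tile                  # token-slots the tensor cores actually run
    real = sum(tokens_per_expert)
    return {
        "tiles": tiles,
        "padded_slots": slots - real,
        "utilization": real / slots if slots else 0.0,
        "straggler": max(tokens_per_expert),
    }


balanced = grouped_gemm_stats([8, 8, 8, 8])
skewed = grouped_gemm_stats([20, 6, 4, 2])
print("balanced:", balanced)
print("skewed:  ", skewed)
```

The same 32 tokens cost 32 slots when balanced but 48 when skewed, and the busiest expert's matmul is more than twice as long as in the balanced case.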

The other lever is a capacity limit: cap how many tokens any one expert will accept, and drop the overflow. That bounds the straggler and the memory it needs, at the cost of ignoring some tokens. Flip capacity on in the scene with skewed routing and you can watch the busiest expert shrink while the dropped-token counter climbs. That trade is the heart of MoE serving.
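A capacity limit is a few lines of bookkeeping. The sketch below assumes tokens are admitted in arrival order and overflow is simply dropped; real systems may choose which tokens to drop differently, and the function name and the capacity of 12 are made up for this example.

```python
def route_with_capacity(assignments, num_experts, capacity):
    """Admit tokens to their expert in order; drop any that exceed the cap."""
    kept = [0] * num_experts
    dropped = 0
    for expert in assignments:
        if kept[expert] < capacity:
            kept[expert] += 1
        else:
            dropped += 1
    return kept, dropped


# 32 tokens with skewed routing: expert 0 is heavily favored.
assignments = [0] * 20 + [1] * 6 + [2] * 4 + [3] * 2

uncapped = route_with_capacity(assignments, num_experts=4, capacity=len(assignments))
capped = route_with_capacity(assignments, num_experts=4, capacity=12)
print("uncapped:", uncapped)
print("capped:  ", capped)
```

Capping at 12 shrinks the straggler from 20 tokens to 12 and drops the other 8, which is the trade the paragraph above describes.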

Which rung is each kernel on

These libraries are not written the same way. The authoring ladder runs from least effort to most control, and each kernel sits where its need for control puts it.

  1. PyTorch / torch.compile. Plain tensor code; the compiler fuses it and emits Triton. Fusing matters because each separate operation would otherwise write its result to HBM and the next would read it back; a fused kernel keeps the data on-chip and touches slow memory once, the Phase 0 lesson at the level of a whole graph. Where most MoE routing and elementwise glue lives.
  2. Triton. A Python DSL for block-level kernels. FlexAttention and many fused kernels compile through here.
  3. CUTLASS / CuTe. C++ templates for peak tensor-core GEMM, driven by the layout algebra from Phase 6. DeepGEMM and grouped-GEMM kernels live at this level.
  4. CUDA / PTX. The low-level instruction layer, hand-placed threads and address math. The FlashAttention core and DeepEP's transfers drop here for the last bit of speed.

That is the whole journey. From a single warp coalescing a load, up through tiling and tensor cores, out to the numbers themselves, and finally to the kernels that run the models you use every day. The one idea underneath all of it has not changed since Phase 0: move the bytes you need, and no more.

Thread
The smallest unit of work on a GPU. One thread runs one instance of the kernel on one lane.
Warp
A group of 32 threads that execute the same instruction together, in lockstep. The scheduling unit of the GPU.
SIMT
Single Instruction, Multiple Threads. All 32 threads of a warp run one shared instruction over their own data.
Lockstep
All threads in a warp advance together on the same instruction at the same time.
Coalescing
When a warp reads neighbouring addresses so the hardware serves them in as few memory transactions as possible.
Sector
The 32-byte unit the hardware fetches from global memory. A warp wants its data packed into as few sectors as possible.
HBM
High Bandwidth Memory, the large off-chip global memory. Biggest and slowest tier, hundreds of cycles away.
Shared memory
Fast on-chip scratchpad private to one thread block. Split into 32 banks.
Register
Per-thread on-chip storage. The fastest memory, about one cycle to access.
L2 cache
On-chip cache shared by all SMs, sitting between the per-SM L1 caches and global HBM.
SM
Streaming Multiprocessor. A core building block of the GPU that runs thread blocks. An A100 has 108 of them.
Sub-partition
One of the four processing blocks inside an SM. Each has its own warp scheduler and execution units.
Warp scheduler
The unit that picks one eligible warp each cycle and issues its next instruction. Four per SM.
Eligible
A warp that is ready to issue this cycle, not waiting on memory or a dependency.
Stalled
A warp that cannot issue yet because it is waiting, usually on a memory load.
Latency hiding
Keeping the machine busy during long waits by running other ready warps while one warp stalls.
Occupancy
The number of active warps on an SM divided by the maximum it can hold. More warps give the scheduler more to switch to.
Bank
One of the 32 slots shared memory is split into. Consecutive 4-byte words map to consecutive banks (word w lands in bank w mod 32), and each bank serves one word per cycle.
Bank conflict
When two or more threads in a warp want different words in the same bank. Their reads serialize.
Tiling
Loading a small block of a matrix into shared memory once so every thread in the block reuses it many times.
Data reuse
Using a value staged in fast memory many times before fetching new data, so slow global memory is touched as little as possible.
GEMM
General matrix multiply, C = A times B. The workhorse operation behind neural networks and the main thing GPUs are tuned for.
cp.async
An asynchronous copy from global memory straight into shared memory, without stalling the thread or passing through registers.
Software pipelining
Overlapping the load of the next tile with the compute on the current one, so memory latency hides behind useful work.
Double buffering
Using two shared-memory buffers that take turns: one is being computed on while the other is being filled by the next load.
Prefetch
Starting a load early, before the data is needed, so it has arrived by the time you use it.
Thread block
A group of threads that run on one SM and share its shared memory. Also called a CTA.
Register file
The pool of registers on an SM, shared out among all resident threads. About 256 KB on an A100.
Register pressure
How many registers a kernel needs per thread. High pressure means fewer warps fit, which lowers occupancy.
Register spilling
When a thread needs more registers than it has, the extra values spill to local memory, which actually lives in slow global memory.
Local memory
Per-thread memory that, despite the name, lives in slow off-chip global memory. Registers spill here when they run out.
TMA
Tensor Memory Accelerator. A Hopper copy engine that moves whole tensor tiles between global and shared memory from a single descriptor, so one thread issues the load.
Tensor map descriptor
A small host-built struct (128 bytes) that tells TMA the tensor base, shape, strides, tile size, element type, and swizzle. One thread passes it to issue a bulk copy.
mbarrier
An asynchronous barrier in shared memory. TMA signals it when a tile lands and waiting threads wake, handing each buffer stage from producer to consumer.
wgmma
Warpgroup matrix multiply. A Hopper instruction where 128 threads issue one asynchronous tensor-core matmul that reads its operands from shared memory.
Warpgroup
Four contiguous warps, 128 threads, the granularity Hopper wgmma operates on.
Warp specialization
Giving different warps different jobs: producer warps issue TMA loads while consumer warps run wgmma, overlapping load and compute.
Thread block cluster
A Hopper group of blocks co-scheduled on one GPC that can read each other’s shared memory (distributed shared memory).
Multicast
A TMA mode that broadcasts one global load into several blocks’ shared memory in a cluster, so a shared operand crosses the bus only once.
Tensor Memory
A dedicated on-SM memory on Blackwell (256 KB) that holds the MMA accumulator, so the register file no longer has to feed the tensor cores at FP4 rates.
tcgen05
Blackwell’s 5th-generation tensor-core MMA. A single thread issues the matmul for the whole block, reading operands from shared memory and TMEM.
Microscaling
Storing a block of low-precision values (say FP4) with one shared low-precision block scale (E8M0 for MXFP4, FP8 for NVFP4), so tiny formats stay accurate. NVFP4 uses a block of 16.
FP4
A 4-bit floating-point format (E2M1). Packs eight values per 32 bits and roughly quadruples tensor-core peak over FP16, at the cost of range.
Accumulator
The running sum D in D = A × B that a matmul builds up across its K steps. Where it lives (registers or TMEM) is a recurring bottleneck.
CuTe
The layout layer under CUTLASS. Expresses thread-to-data mappings as Shape ⊗ Stride and composes, tiles, and swizzles them at compile time.
CUTLASS
NVIDIA’s open template library for peak-performance GEMM and related kernels, built on CuTe layouts.
Layout
A CuTe object, Shape ⊗ Stride, that maps a logical coordinate to a linear memory offset. Change the stride to re-lay-out data without moving it.
Stride
How far apart, in memory, consecutive elements along an axis sit. A stride of 1 means contiguous, which is what makes a read coalesce.
Swizzle
A layout that permutes shared-memory addresses so a tile reads back bank-conflict-free and in the order the tensor cores want.
Roofline
A plot of attainable throughput against arithmetic intensity. An op is memory-bound under a bandwidth roof until enough reuse lifts it to the flat compute roof.
Arithmetic intensity
FLOPs performed per byte moved from memory. Low intensity is memory-bound; tiling raises it until the op becomes compute-bound.
Exponent
A float's exponent bits set how big or small it can get. More exponent means more dynamic range.
Mantissa
A float's mantissa bits set how finely it resolves values between powers of two. More mantissa means more precision.
BF16
Brain float 16: 8 exponent bits and 7 mantissa bits. Same range as FP32 with less precision, which is why it is the training default.
FP8
An 8-bit float in two flavors: E4M3 (more precision, forward pass) and E5M2 (more range, gradients). Introduced for tensor cores on Hopper.
Quantization
Storing weights or activations in fewer bits than they were trained in, usually with a shared scale factor to recover the real magnitudes.
INT8
An 8-bit integer format with evenly spaced steps and a shared scale. Cheap and tight when a tensor has no wild outliers.
Ternary (BitNet b1.58)
Weights restricted to three values, -1, 0, and +1, about 1.58 bits each. The matmul becomes addition and subtraction with no multiplies; trained from scratch, not compressed after.
NVFP4
NVIDIA's 4-bit float (E2M1) with one FP8 scale per 16 values plus a per-tensor FP32 scale. Finer blocks than MXFP4 for better accuracy.
MXFP4
The open OCP microscaling 4-bit float: one power-of-two scale (E8M0) shared across every block of 32 values.
GGUF
The llama.cpp file format that packs a model plus metadata for local inference. Holds k-quant tensor types like Q4_K_M.
k-quant
A GGUF quantization scheme: weights in super-blocks of 256, split into sub-blocks of 32, with a two-level (super-block and sub-block) scale.
Attention
The operation that lets each token weigh every other token: score queries against keys, softmax the scores, then blend the values.
Softmax
Turns a row of numbers into positive weights that sum to 1, so they can act as attention weights or class probabilities.
Token
One chunk of a model’s input or output, roughly a word or word-piece. Sequences are measured in tokens.
Gradient
The correction signal used to update a model during training. Gradients span a huge range of magnitudes, which is why their number format needs range.
Tensor core
A dedicated unit inside the SM that multiplies a small matrix in one instruction, far faster than ordinary threads doing it multiply by multiply. First shipped on Volta.
MMA
Matrix multiply-accumulate: the tensor core’s core operation, D = A × B + C. wgmma (Hopper) and tcgen05.mma (Blackwell) are MMA instructions.
FLOP
One floating-point operation, a single multiply or add. Throughput is measured in FLOPs per second (FLOP/s).
Activation
The data flowing through a model (the inputs and intermediate results), as opposed to the fixed weights. Activations carry outliers, which makes them harder to quantize.
W4A16
A quantization recipe: 4-bit weights, 16-bit activations. Weight-only, so it saves memory but computes in 16-bit. 4-bit GPTQ, AWQ, and GGUF quantizations are typically W4A16.
W4A4
A quantization recipe: 4-bit weights and 4-bit activations. The matmul runs on low-precision tensor cores, saving compute too, but activation outliers make it hard. NVFP4 is W4A4.
SDPA
PyTorch's scaled_dot_product_attention: a dispatcher that auto-picks a fused FlashAttention-style backend (flash, memory-efficient, or cuDNN) for your shapes and dtype.
FlexAttention
A PyTorch API for writing custom attention masks and score modifications as a small function that still compiles to one fused FlashAttention-style kernel.
SageAttention
Quantized attention: runs the attention matmuls in INT8 or FP4 on tensor cores, smoothing outlier channels first. The W4A4 idea applied to attention.
SonicMoE
A mixture-of-experts kernel library that packs routed tokens into contiguous groups so each expert’s grouped-GEMM tile is full and efficient.
Data parallel
Replicate the whole model on every GPU, split the batch, and average gradients with an all-reduce. DDP is the efficient PyTorch version.
FSDP / ZeRO
Sharded data parallel: split the batch AND shard params, gradients, and optimizer state across GPUs, gathering each layer just in time. Trades communication for memory.
Tensor parallel
Split each layer's weight matrices across GPUs (Megatron splits heads and MLP columns) and all-reduce the activations. Chatty, so it wants NVLink.
Pipeline parallel
Put different layers on different GPUs as stages; activations flow stage to stage. The idle time while the pipeline fills and drains is the "bubble."
Expert parallel
Spread a mixture-of-experts layer’s experts across GPUs and all-to-all the tokens to wherever their expert lives. DeepEP accelerates the shuffle.
Context / sequence parallel
Split the sequence across GPUs. Sequence parallel saves activation memory on the non-matmul parts; context parallel (ring attention) splits attention over ultra-long contexts.
All-reduce
A collective that sums a value across all GPUs and hands every GPU the total. Used to average gradients (data parallel) and combine activations (tensor parallel).
All-to-all
A collective where every GPU sends a different piece to every other GPU. Used to route tokens to their expert in expert parallelism.
KV cache
The stored Keys and Values of every past token, kept so they are not recomputed each step. It dominates long-context memory and grows with the number of KV heads, head size, layers, and sequence length.
RoPE
Rotary position embedding: encodes position by rotating the query and key vectors by an angle proportional to position, so the attention score depends on relative distance. Applied every layer, not added to embeddings.
NoPE
No positional encoding: a decoder-only causal model can infer position from the causal mask alone (a counting signal), with no explicit position input. Works only in the causal setting and has a finite usable range.
Sliding window attention
Each token attends only to the last W tokens (a local band), capping cost and local KV cache. Stacking layers still compounds the reach, so the model is not limited to W.
Attention sink
A learned bias that keeps the always-important first tokens in the softmax denominator, letting a very small sliding window (like gpt-oss’s 128) stay stable.
MHA
Multi-head attention: every query head has its own Key and Value head. Best quality, largest KV cache.
MQA
Multi-query attention: all query heads share a single Key/Value head. Smallest KV cache, but the hard sharing can cost quality.
GQA
Grouped-query attention: query heads are split into groups, each sharing one Key/Value head. Near-MHA quality at close to MQA memory, and the mainstream default.
MLA
Multi-head latent attention (DeepSeek): compress every head's K and V into one small shared latent, cache only that, and reconstruct per-head K/V on the fly. GQA-level cache at MHA quality, with a small decoupled RoPE key.
DSA
DeepSeek Sparse Attention (V3.2): a lightning indexer scores past tokens and each query attends only to its top-k (2048), turning attention from order N squared into order N times k. Built on MLA.
YaRN
A RoPE context-extension method: scale the rotation frequencies (NTK-style) and add an attention-temperature correction, so a model trained at one length works at a longer one.