Latent Forge

Making FlashAttention-4 faster for inference

By Wren · June 23, 2026 · 4 min read

Last year's release of the FlashAttention-4 kernel source gave the community a chance to confirm the structure that teams had previously reverse-engineered. Since then, Modal's team has contributed a series of changes aimed at making the kernel better suited to LLM inference—especially the decode-heavy workloads that dominate token generation.

Why inference is different

Inference is not pre-training. Per the write-up, the decode (token generation) phase is typically memory bandwidth–limited rather than compute-limited, and the workloads are far less uniform: batch sizes and sequence lengths vary, and keys and values usually have to be retrieved from a cache. That mismatch requires new kernel code, and the authors are blunt that the code has to be fast—"performance is the product."
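A back-of-envelope estimate shows why decode tends to be bandwidth-bound. The sketch below is not from the post. It counts only the two attention matmuls for a single decode step and only the bytes needed to stream K and V from the cache. The `queries_per_kv_head` parameter is a hypothetical knob standing in for grouped-query attention. Under those assumptions the arithmetic intensity works out to 2g/b FLOPs per byte (g queries per KV head, b bytes per element). That is about 1 FLOP/byte for bf16 with one query per head, while modern data-center GPUs can perform on the order of hundreds of dense bf16 FLOPs per byte of HBM bandwidth.

```python
def decode_attention_intensity(seq_len: int, head_dim: int, bytes_per_elem: float,
                               queries_per_kv_head: int = 1) -> float:
    """Rough FLOPs per byte for one decode step of attention on one KV head.

    Counts only Q @ K^T and P @ V, and only the bytes for reading K and V;
    softmax, Q, and output traffic are ignored (a simplifying assumption).
    """
    # Scores: one dot product per cached key (2 FLOPs per multiply-add).
    # Output: one weighted sum over every cached value. Each is 2 * L * d FLOPs.
    flops = queries_per_kv_head * (2 * seq_len * head_dim + 2 * seq_len * head_dim)
    # K and V are each seq_len x head_dim and are read once per step.
    bytes_read = 2 * seq_len * head_dim * bytes_per_elem
    return flops / bytes_read

for fmt, nbytes in [("bf16", 2), ("fp8", 1)]:
    print(fmt, decode_attention_intensity(seq_len=16384, head_dim=128, bytes_per_elem=nbytes))
# bf16 1.0
# fp8 2.0
```

Sequence length cancels out of the ratio, so longer contexts do not rescue intensity. Only sharing K/V across more queries or moving fewer bytes per element does.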

Two categories of change

The team groups its contributions into two buckets:

  • Adjusting the parallelism strategy — tuning the number of query tiles per thread block and switching from query parallelism to key/value parallelism (porting the "split KV" technique into FA4).
  • Supporting irregular global memory accesses — replacing cp.async.bulk loads issued through the Tensor Memory Accelerator (TMA) with cp.async loads for scattered access patterns.
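To make the split-KV idea concrete, here is a minimal pure-Python sketch for a single query. It is not FA4's implementation, and the function names are illustrative. Each KV chunk is attended to independently and returns its normalized output plus the log-sum-exp of its scores. A final reduction then reweights the chunks so the result matches ordinary attention over the whole sequence. On a GPU the chunks would map to separate thread blocks, which is what gives decode (one query token per sequence, possibly a small batch) enough parallel work to fill the machine. Here they simply run in a loop.

```python
import math

def attend(q, keys, values):
    """Softmax attention for one query over a list of key/value vectors.

    Returns (output, lse): the normalized output and the log-sum-exp of the
    scaled scores, which a split-KV scheme keeps per chunk for the merge.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max so exp() cannot overflow
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = [sum(w * v[j] for w, v in zip(weights, values)) / denom
           for j in range(len(values[0]))]
    return out, m + math.log(denom)

def split_kv_attend(q, keys, values, num_splits):
    """Attend to contiguous KV chunks independently, then merge the partial results."""
    chunk = math.ceil(len(keys) / num_splits)
    partials = [attend(q, keys[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(keys), chunk)]
    # Global log-sum-exp over all chunks, computed stably.
    m = max(lse for _, lse in partials)
    lse_total = m + math.log(sum(math.exp(lse - m) for _, lse in partials))
    # Each chunk's output is weighted by its share of the total softmax mass.
    return [sum(math.exp(lse - lse_total) * out[j] for out, lse in partials)
            for j in range(len(values[0]))]

# Small deterministic example: 37 cached tokens, head dim 8, value dim 4.
q = [math.sin(j) for j in range(8)]
keys = [[math.cos(8 * i + j) for j in range(8)] for i in range(37)]
values = [[math.sin(i + 0.5 * j) for j in range(4)] for i in range(37)]
full, _ = attend(q, keys, values)
split = split_kv_attend(q, keys, values, num_splits=4)
print(max(abs(a - b) for a, b in zip(full, split)))  # expected: float rounding only
```

The merge is the extra sequential step that split KV introduces. It touches only one vector per chunk, so it stays small relative to the attention work when chunks are long.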

The post argues that changing the parallelism strategy offers the most leverage on massively parallel hardware. If you're locked into one approach, the sequential term in Amdahl's Law is fixed; choosing a different strategy moves work between the parallel and sequential components, and that can beat simply speeding up a fixed parallel section.
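The Amdahl's Law argument can be checked with a few lines of arithmetic. In the sketch below, `serial_fraction` is the share of runtime that does not parallelize and `parallel_speedup` is how much faster the rest runs. The specific numbers are illustrative, not taken from FA4. Doubling the parallel speedup barely moves the total, while restructuring the work so that less of it is serial changes the answer substantially.

```python
def amdahl_speedup(serial_fraction: float, parallel_speedup: float) -> float:
    """Overall speedup when only the non-serial share of the runtime is accelerated."""
    return 1.0 / (serial_fraction + (1.0 - serial_fraction) / parallel_speedup)

baseline = amdahl_speedup(0.20, 100)       # 20% serial, parallel part 100x faster
more_parallel = amdahl_speedup(0.20, 200)  # same split, twice the parallel speedup
less_serial = amdahl_speedup(0.05, 100)    # work moved out of the serial part
print(f"{baseline:.2f} {more_parallel:.2f} {less_serial:.2f}")  # 4.81 4.90 16.81
```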

On tooling, the authors note they inherited the CUDA Templates DSL (CuTe DSL) from the original kernel authors and found it productive, with fast JIT compilation and minimal runtime cost. They also reiterate a point from their earlier write-up: FA4 is best understood algorithmically at the tile level even though it's implemented at the warp level, and they're looking forward to better support for the CUDA Tile programming model for building future attention and matmul kernels.

The contributions, by PR

PR 2109 — FP8 inputs (merged April 17, 2026). Adding support for 8-bit floats (e4m3 or e5m2) reduces bytes moved and operated on, shrinking KV caches and enabling longer contexts or more concurrency. The reported figure of merit is up to 1.16x throughput over a bf16 baseline:

Batch / Seq Len    BF16 TFLOP/s    FP8 TFLOP/s    Speedup
1 / 16384          1569            1818           1.13x
32 / 512           962             1090           1.16x
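The two FP8 formats mentioned in PR 2109 split the same 8 bits differently. e4m3 spends 4 bits on the exponent and 3 on the mantissa, while e5m2 spends 5 and 2, trading precision for range. The decoder below is a sketch that assumes the OCP 8-bit floating-point conventions. Under those conventions e4m3 has no infinities (only the pattern with exponent and mantissa all ones is NaN), while e5m2 keeps IEEE-style infinities and NaNs. It is written for illustration, not taken from FA4.

```python
import math

# (exponent bits, mantissa bits, exponent bias), assuming OCP FP8 conventions.
FORMATS = {"e4m3": (4, 3, 7), "e5m2": (5, 2, 15)}

def decode_fp8(bits: int, fmt: str) -> float:
    """Interpret an 8-bit pattern (0..255) as an e4m3 or e5m2 value."""
    exp_bits, man_bits, bias = FORMATS[fmt]
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> man_bits) & ((1 << exp_bits) - 1)
    man = bits & ((1 << man_bits) - 1)
    exp_all_ones = (1 << exp_bits) - 1
    man_all_ones = (1 << man_bits) - 1
    if fmt == "e5m2" and exp == exp_all_ones:
        # IEEE-style specials: zero mantissa is infinity, anything else is NaN.
        return sign * math.inf if man == 0 else math.nan
    if fmt == "e4m3" and exp == exp_all_ones and man == man_all_ones:
        # e4m3 has no infinities; S.1111.111 is the only NaN encoding.
        return math.nan
    if exp == 0:
        # Subnormals (and zero): no implicit leading 1.
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

def max_finite(fmt: str) -> float:
    """Largest finite value over all 256 bit patterns."""
    values = (decode_fp8(b, fmt) for b in range(256))
    return max(v for v in values if math.isfinite(v))

print(max_finite("e4m3"), max_finite("e5m2"))  # 448.0 57344.0
```

The much smaller maximum for e4m3 is the range cost it pays for its extra mantissa bit.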

The authors note the gain is below the naive 2x you might expect from halving bit width, and suggest this is consistent with a softmax bottleneck—softmax still runs at higher precision on CUDA Cores and/or Special Function Units even as the Tensor Cores take lower-precision inputs. They tie the FP8 motivation to real models: DeepSeek-V3/V4 natively support 8-bit attention, while Qwen and Gemma deployments sometimes use 8-bit KV caches.
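To put the KV-cache claim in numbers, the sketch below sizes a cache for a hypothetical model with 32 layers, 8 KV heads, and head dimension 128. These dimensions are assumptions chosen for round arithmetic, not any model named in the post. The sketch also ignores any scale factors an FP8 scheme may store alongside the data.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int, bytes_per_elem: int) -> int:
    """Bytes for keys plus values across all layers (the factor of 2 is K and V)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_elem

GIB = 2 ** 30
for fmt, nbytes in [("bf16", 2), ("fp8", 1)]:
    size = kv_cache_bytes(32, 8, 128, seq_len=16384, batch=1, bytes_per_elem=nbytes)
    print(fmt, size / GIB, "GiB")
# bf16 2.0 GiB
# fp8 1.0 GiB
```

At a fixed memory budget, halving the bytes per element doubles the number of cached tokens. That is where the longer contexts or higher concurrency come from.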

PR 1999 and PR 2104 — arbitrary KV page sizes (merged November 13, 2025) and a performance follow-up (merged January 15, 2026). Originally, FA4 required KV cache pages to match the tile size. The restriction came from the TMA: it accelerates large affine accesses, but it can't gather scattered blocks into a single tile and may slow down smaller loads.

The team added a cp.async path (CuTe DSL's wrapper for PTX cp.async) via a PagedKVManager. In the TMA version, one thread per warp loaded a tile; in the cp.async version, each thread computes its own page and offset and issues its own load, with the hardware coalescing the accesses. They repurposed the otherwise idle warp 15 so that the producer group spans two warps. The reported figure of merit is up to 2.40x throughput for small page sizes:

Page Size    In PR 1999?    TFLOP/s (PR 1999)    TFLOP/s (PR 2104)    Speedup
1            yes            18.56                44.57                2.40x
8            yes            31.21                42.58                1.37x
32           yes            34.98                42.47                1.21x
128          no             42.11                41.96                n/a
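The addressing behind the cp.async path can be sketched in plain Python. This is not the PagedKVManager code. `page_table`, `physical_rows`, and `contiguous_runs` are hypothetical names, the 128-row tile is an assumption, and one loop iteration stands in for the thread (or threads) loading one row. Each row independently turns its logical token index into a page and an in-page offset, then looks up the physical page. Counting contiguous runs shows where TMA struggles: when pages are smaller than the tile, one tile's rows are scattered across several physical pages, so no single affine bulk copy covers them.

```python
def physical_rows(tile_start, tile_rows, page_size, page_table):
    """Physical KV-cache row for each logical token in one tile.

    Each iteration works out its own page and in-page offset independently,
    mirroring (in spirit only) the per-thread addressing of the cp.async path.
    """
    rows = []
    for lane in range(tile_rows):
        token = tile_start + lane
        logical_page, offset = divmod(token, page_size)
        rows.append(page_table[logical_page] * page_size + offset)
    return rows

def contiguous_runs(rows):
    """Count maximal runs of consecutive rows; 1 means one affine copy covers the tile."""
    return 1 + sum(1 for a, b in zip(rows, rows[1:]) if b != a + 1)

TILE_ROWS = 128  # assumed tile height, for illustration only
# Hypothetical page table: an arbitrary permutation of 128 physical pages.
page_table = [(37 * i) % 128 for i in range(128)]
for page_size in (128, 32, 8, 1):
    rows = physical_rows(0, TILE_ROWS, page_size, page_table)
    print(page_size, contiguous_runs(rows))  # 128 -> 1, 32 -> 4, 8 -> 16, 1 -> 128
```

A run count of 1 is the case a single bulk copy handles well. Every extra run is another separate region the loader has to touch.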

Smaller page sizes matter because large pages can force unnecessary duplication—for example, when several requests share a short prefix but diverge afterward.
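A simplified model makes the duplication point concrete. Assume pages are shared copy-on-write at page granularity. Full pages inside a common prefix can then be referenced by every request, but the page where requests diverge also holds request-specific tokens, so each request keeps its own copy of that page's shared portion. The function name and the numbers in the example are illustrative.

```python
def duplicated_prefix_tokens(prefix_len: int, num_requests: int, page_size: int) -> int:
    """Extra KV entries stored because a shared prefix does not end on a page boundary."""
    # Tokens of the shared prefix that land in the partially filled last page.
    partial = prefix_len % page_size
    # One request can keep the original page; every other request needs a copy.
    return (num_requests - 1) * partial

for page_size in (1, 16, 128):
    print(page_size, duplicated_prefix_tokens(prefix_len=100, num_requests=8, page_size=page_size))
# 1 0
# 16 28
# 128 700
```

With 128-token pages, the 100-token prefix never fills a page, so nothing is shared at all.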

Takeaways

The work reads as a pragmatic adaptation of a training-oriented kernel to the realities of inference: lower precision where models tolerate it, finer-grained KV paging for cache efficiency, and parallelism reshaped around decode. The figures come from the team's own PRs and tables rather than independent testing, so treat them as the authors' reported results. For anyone building Blackwell-era attention kernels, the post is a useful map of where FA4's TMA-centric design meets friction under irregular inference workloads—and how cp.async and split-KV help close the gap.

