
Flash Attention

Definition

An IO-aware, exact attention algorithm that reduces memory usage and speeds up transformer training and inference by minimizing GPU memory reads/writes through tiling and recomputation.

Why It Matters

Flash Attention is one of the most important optimizations in modern LLM infrastructure. Standard attention implementations treat GPU memory as if it were unlimited: they materialize the full attention matrix, which grows quadratically with sequence length. Flash Attention rewrites this computation to be IO-aware, dramatically reducing memory traffic and enabling longer contexts at lower cost.

For AI engineers, Flash Attention knowledge matters because it’s now the default in most production LLM deployments. Understanding it explains why some models handle long contexts efficiently while others don’t, and helps you make informed decisions about inference optimization. It’s also the foundation for more advanced techniques like Flash Attention 2 and 3.

The impact is significant: 2-4x faster training, 50-70% memory reduction, and the ability to process much longer sequences on the same hardware. This is how modern models achieve 128K+ context windows without prohibitive costs.

How It Works

The Memory Bottleneck Problem

Standard attention computes the full N×N attention matrix (where N is sequence length), stores it in GPU high-bandwidth memory (HBM), applies softmax, then multiplies with values. For long sequences, this matrix becomes massive:

  • 4K tokens: 64MB for the attention matrix alone
  • 32K tokens: 4GB for the attention matrix
  • 128K tokens: 64GB, larger than most GPU memory
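
These figures assume 32-bit attention scores and count a single head and batch element. A quick back-of-the-envelope check (the helper below is illustrative, not part of any library):

  def attn_matrix_bytes(n_tokens: int, bytes_per_score: int = 4) -> int:
      # One N x N score matrix, per attention head and per batch element
      return n_tokens ** 2 * bytes_per_score

  for n in (4_096, 32_768, 131_072):
      print(f"{n:>7} tokens -> {attn_matrix_bytes(n) / 2**30:.2f} GiB")
  # 4096 -> 0.06 GiB (64 MiB), 32768 -> 4.00 GiB, 131072 -> 64.00 GiB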

The bottleneck isn’t compute; it’s memory bandwidth. Moving data between HBM and the GPU’s compute units is slow and expensive relative to the arithmetic itself.

Flash Attention’s Solution: Tiling

Instead of materializing the full attention matrix, Flash Attention:

  1. Divides Q, K, V into small tiles that fit in fast SRAM
  2. Computes attention for each tile pair
  3. Accumulates results using online softmax (no full matrix needed)
  4. Writes only the final output to HBM

This trades extra computation for dramatically fewer memory accesses, a favorable tradeoff because attention on modern GPUs is memory-bound, not compute-bound.
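
A minimal PyTorch sketch of the tiled, online-softmax idea (illustrative only: the real Flash Attention fuses this loop into a single CUDA kernel and keeps the tiles in on-chip SRAM; function and variable names here are made up):

  import math
  import torch

  def tiled_attention(q, k, v, block=64):
      # Single-head attention computed tile by tile with an online softmax;
      # the full N x N score matrix is never materialized.
      n, d = q.shape
      scale = 1.0 / math.sqrt(d)
      out = torch.zeros_like(q)
      for i in range(0, n, block):
          qi = q[i:i + block]                              # query tile
          m = torch.full((qi.shape[0],), float("-inf"))    # running row max
          l = torch.zeros(qi.shape[0])                     # running softmax denominator
          acc = torch.zeros_like(qi)                       # running weighted sum of V
          for j in range(0, n, block):
              s = (qi @ k[j:j + block].T) * scale          # scores for this tile pair
              m_new = torch.maximum(m, s.max(dim=1).values)
              correction = torch.exp(m - m_new)            # rescale old stats to the new max
              p = torch.exp(s - m_new[:, None])
              l = l * correction + p.sum(dim=1)
              acc = acc * correction[:, None] + p @ v[j:j + block]
              m = m_new
          out[i:i + block] = acc / l[:, None]              # final normalization
      return out

  # Matches a naive reference implementation up to floating-point error:
  q, k, v = (torch.randn(256, 64) for _ in range(3))
  ref = torch.softmax(q @ k.T / math.sqrt(64), dim=-1) @ v
  assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)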

Key Techniques

  • Kernel fusion: Combines multiple operations into single GPU kernels to avoid intermediate writes
  • Tiling: Processes attention in blocks that fit in fast on-chip memory
  • Recomputation: Recomputes attention weights during backward pass instead of storing them
  • Online softmax: Updates softmax incrementally without seeing all scores

Implementation Basics

Flash Attention is typically used through libraries rather than implemented directly:

PyTorch Integration

PyTorch 2.0+ includes torch.nn.functional.scaled_dot_product_attention, which automatically dispatches to Flash Attention kernels when the hardware, dtypes, and attention mask allow it.
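
A minimal usage sketch (assumes a CUDA-capable GPU; shapes and sizes are illustrative):

  import torch
  import torch.nn.functional as F

  # (batch, heads, seq_len, head_dim) layout expected by scaled_dot_product_attention
  q, k, v = (torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16) for _ in range(3))

  # PyTorch picks the fastest available backend (Flash Attention on supported GPUs);
  # the caller never materializes an attention matrix.
  out = F.scaled_dot_product_attention(q, k, v, is_causal=True)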

Hugging Face Transformers

Pass attn_implementation="flash_attention_2" to from_pretrained (requires the flash-attn package), or rely on the default SDPA attention, which routes through PyTorch’s Flash Attention kernels; the older model.to_bettertransformer() path is deprecated in favor of these options. Many recent models (Llama 2/3, Mistral, Falcon) ship with Flash Attention support.
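
For example (checkpoint name and dtype are illustrative; assumes a recent transformers release with flash-attn installed):

  import torch
  from transformers import AutoModelForCausalLM

  model = AutoModelForCausalLM.from_pretrained(
      "meta-llama/Meta-Llama-3-8B",        # illustrative checkpoint
      torch_dtype=torch.bfloat16,          # Flash Attention requires fp16/bf16
      attn_implementation="flash_attention_2",
  )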

Direct Usage

The flash-attn library provides direct access with additional features like sliding window attention and cross-attention variants.
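
A sketch with the flash-attn package (argument names as of the 2.x releases; shapes are illustrative, and fp16/bf16 CUDA tensors are required):

  import torch
  from flash_attn import flash_attn_func  # pip install flash-attn

  # flash-attn expects (batch, seq_len, n_heads, head_dim) tensors in fp16/bf16 on GPU
  q, k, v = (torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))

  # Causal masking is fused into the kernel; sliding-window and variable-length
  # variants are exposed through additional arguments and functions in recent releases.
  out = flash_attn_func(q, k, v, causal=True)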

When to Use

Flash Attention provides the most benefit when:

  • Processing long sequences (1K+ tokens)
  • Running on modern GPUs (A100, H100, or consumer Ampere/Ada cards)
  • Memory is the limiting factor for batch size or context length

For very short sequences, the overhead may not be worth it. Profile your specific use case to verify gains.
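
One way to profile, assuming PyTorch 2.3+ (the sdpa_kernel context manager pins a specific attention backend; sizes are illustrative):

  import torch
  import torch.nn.functional as F
  from torch.nn.attention import sdpa_kernel, SDPBackend

  q, k, v = (torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.float16) for _ in range(3))

  def time_backend(backend, iters=10):
      # Time scaled_dot_product_attention with a single backend forced on
      with sdpa_kernel(backend):
          for _ in range(3):                                # warm-up
              F.scaled_dot_product_attention(q, k, v, is_causal=True)
          torch.cuda.synchronize()
          start = torch.cuda.Event(enable_timing=True)
          end = torch.cuda.Event(enable_timing=True)
          start.record()
          for _ in range(iters):
              F.scaled_dot_product_attention(q, k, v, is_causal=True)
          end.record()
          torch.cuda.synchronize()
          return start.elapsed_time(end) / iters            # milliseconds per call

  print("flash attention:", time_backend(SDPBackend.FLASH_ATTENTION), "ms")
  print("naive (math):   ", time_backend(SDPBackend.MATH), "ms")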

Source

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness achieves 2-4x speedup over standard attention by reducing HBM accesses through kernel fusion and tiling.

https://arxiv.org/abs/2205.14135