Multi-Head Attention
Definition
An extension of attention that runs multiple attention operations in parallel with different learned projections, allowing the model to capture different types of relationships simultaneously.
Why It Matters
Multi-head attention is why LLMs can track multiple types of relationships simultaneously. A single attention mechanism might focus on one thing, say, grammatical structure. Multiple heads can track grammar, semantics, entity relationships, and more, all at the same time.
Research shows different attention heads specialize in different tasks: some track syntax, others track coreference, others identify named entities. This parallelism is key to the rich understanding modern LLMs demonstrate.
For AI engineers, understanding multi-head attention helps explain model behavior. It’s also relevant when working with attention visualizations for debugging or interpretability.
Implementation Basics
How It Works
- Project input into h different (Query, Key, Value) sets
- Run attention independently on each set
- Concatenate all outputs
- Project concatenated result to final dimension
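These four steps map directly onto the formula in the next subsection. Below is a minimal PyTorch sketch, assuming no causal masking, dropout, or other production details; the class and variable names are illustrative only:

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must divide evenly across heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Each Linear bundles all h per-head projections (W_Q1..W_Qh, etc.) into one matmul.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # final output projection W_O

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, d_model = x.shape

        # 1. Project the input into h (Query, Key, Value) sets: (batch, heads, seq, head_dim)
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.w_q(x))
        k = split_heads(self.w_k(x))
        v = split_heads(self.w_v(x))

        # 2. Run scaled dot-product attention independently on every head
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        weights = scores.softmax(dim=-1)
        per_head = weights @ v  # (batch, heads, seq, head_dim)

        # 3. Concatenate all head outputs back to (batch, seq, d_model)
        concat = per_head.transpose(1, 2).reshape(batch, seq_len, d_model)

        # 4. Project the concatenated result to the final dimension
        return self.w_o(concat)

# Example with GPT-2-small-like dimensions: 768-dim model, 12 heads, 64 dims per head.
mha = MultiHeadAttention(d_model=768, num_heads=12)
print(mha(torch.randn(2, 10, 768)).shape)  # torch.Size([2, 10, 768])
```

Note that the three large linear layers each fuse all h per-head projection matrices into a single matrix multiply; this is mathematically equivalent to applying the per-head W_Qᵢ, W_Kᵢ, W_Vᵢ separately but is the more common (and faster) way to write it.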
Architecture Details
MultiHead(Q, K, V) = Concat(head₁, ..., headₕ) × W_O
where headᵢ = Attention(Q × W_Qᵢ, K × W_Kᵢ, V × W_Vᵢ)
Typical Configurations
- GPT-2 (small, 124M): 12 heads per layer
- GPT-3 (175B): 96 heads per layer
- Llama 2 (70B): 64 heads per layer
Why Multiple Heads?
- Specialization: Heads can learn different patterns
- Robustness: Multiple perspectives on the same input
- Expressiveness: Captures complex relationships
- Parallelization: Heads run independently
Head Specialization Research
Studies have found heads that specialize in:
- Previous token attention (position-based)
- Rare word attention
- Duplicate token detection
- Syntactic dependency tracking
- Coreference resolution
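One way to explore this kind of specialization yourself is to pull per-head attention maps out of a pretrained model, as mentioned above in the context of debugging and interpretability. The sketch below uses the Hugging Face transformers library; the choice of model (gpt2), example sentence, layer, and head index is arbitrary and purely illustrative:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2")

inputs = tokenizer("The cat sat on the mat because it was tired", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# outputs.attentions is a tuple with one tensor per layer,
# each shaped (batch, num_heads, seq_len, seq_len).
layer, head = 5, 1  # arbitrary layer/head to inspect
attn = outputs.attentions[layer][0, head]  # (seq_len, seq_len)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
for i, tok in enumerate(tokens):
    top = attn[i].argmax().item()  # position this token attends to most strongly
    print(f"{tok:>12} -> attends most to {tokens[top]}")
```

Looping the same inspection over layers and heads is how much of the specialization research above identifies previous-token heads, duplicate-token heads, and so on.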
Computational Considerations
- Total attention dimension split across heads
- Head dimension = model dimension / number of heads
- More heads = finer-grained attention
- Fewer heads = larger dimension per head (total parameter count is unchanged)
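A quick arithmetic check of the head-dimension rule, using the commonly cited configurations listed earlier (the model dimensions are assumptions based on published specs, included only for illustration):

```python
# Head dimension = model dimension / number of heads
configs = {
    "GPT-2 (small)": {"d_model": 768,   "heads": 12},
    "GPT-3 (175B)":  {"d_model": 12288, "heads": 96},
    "Llama 2 (70B)": {"d_model": 8192,  "heads": 64},
}
for name, cfg in configs.items():
    head_dim = cfg["d_model"] // cfg["heads"]
    print(f"{name}: {cfg['d_model']} / {cfg['heads']} = {head_dim} dims per head")
# GPT-2 small: 768 / 12 = 64; GPT-3: 12288 / 96 = 128; Llama 2 70B: 8192 / 64 = 128
```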
Grouped Query Attention (GQA)
A modern optimization where multiple query heads share key-value pairs:
- Reduces memory for KV cache
- Enables longer context with same memory
- Used in Llama 2 (70B) and many newer models
This is one reason newer models can handle longer contexts more efficiently than older architectures.
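A minimal sketch of the GQA idea, with illustrative head counts (8 query heads sharing 2 key/value heads) that do not correspond to any particular model. Real implementations keep only the small K and V tensors in the KV cache and expand them at compute time:

```python
import math
import torch
import torch.nn as nn

d_model, num_q_heads, num_kv_heads = 512, 8, 2   # illustrative sizes
head_dim = d_model // num_q_heads
group_size = num_q_heads // num_kv_heads          # query heads per shared K/V head

w_q = nn.Linear(d_model, num_q_heads * head_dim)
w_k = nn.Linear(d_model, num_kv_heads * head_dim)  # far fewer K/V activations to cache
w_v = nn.Linear(d_model, num_kv_heads * head_dim)

x = torch.randn(1, 16, d_model)                    # (batch, seq, d_model)
b, s, _ = x.shape
q = w_q(x).view(b, s, num_q_heads, head_dim).transpose(1, 2)   # (1, 8, 16, 64)
k = w_k(x).view(b, s, num_kv_heads, head_dim).transpose(1, 2)  # (1, 2, 16, 64)
v = w_v(x).view(b, s, num_kv_heads, head_dim).transpose(1, 2)

# Expand each K/V head so its group of query heads can attend to it;
# only the smaller k and v need to be stored in the KV cache.
k = k.repeat_interleave(group_size, dim=1)         # (1, 8, 16, 64)
v = v.repeat_interleave(group_size, dim=1)

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
out = scores.softmax(dim=-1) @ v                   # (1, 8, 16, 64)
print(out.shape)
```

With 2 KV heads instead of 8, the cached keys and values shrink by 4x, which is where the longer-context memory savings come from.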
Source
Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions; with a single attention head, averaging inhibits this.
https://arxiv.org/abs/1706.03762