Attention Mechanism Math, Part 4 (Multi-Head Attention) through Part 6 (Complexity and Efficient Attention)
4. Multi-Head Attention
Multi-Head Attention explains how transformer layers route information across sequence positions using differentiable, mask-aware retrieval.
4.1 Head dimensions
Purpose. Head dimensions focuses on splitting representation into multiple subspaces. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Multi-head attention runs several attention mechanisms in parallel using different learned projections.
Worked reading.
Each head has width $d_k = d_{\text{model}}/h$ in the standard design; the $h$ head outputs are then concatenated and projected back to width $d_{\text{model}}$.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- syntax-like heads.
- copy heads.
- multi-query attention.
Non-examples:
- one monolithic attention map only.
- duplicating the same head without learned projections.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
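The head split above can be made concrete with a minimal NumPy sketch. The sizes and random weights are assumptions for illustration; the point is that one wide projection is just *viewed* as $h$ heads of width $d_k = d_{\text{model}}/h$ by reshaping, with no extra math.

```python
import numpy as np

# Toy sizes (assumed): split one d_model-wide projection into h heads.
n, d_model, h = 3, 8, 2
d_k = d_model // h                       # each head's width: 8 // 2 = 4

rng = np.random.default_rng(0)
X = rng.standard_normal((n, d_model))    # hidden states entering the layer
W_q = rng.standard_normal((d_model, d_model))

Q = X @ W_q                              # (n, d_model): all heads at once
Q_heads = Q.reshape(n, h, d_k).transpose(1, 0, 2)  # (h, n, d_k)

# Head 0 owns the first d_k columns of Q, head 1 the next d_k, and so on.
print(Q_heads.shape)                     # (2, 3, 4)
```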
4.2 Parallel heads
Purpose. Parallel heads focuses on different learned projections over the same sequence. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Multi-head attention runs several attention mechanisms in parallel using different learned projections.
Worked reading.
Each head has width $d_k = d_{\text{model}}/h$ in the standard design; the $h$ head outputs are then concatenated and projected back to width $d_{\text{model}}$.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- syntax-like heads.
- copy heads.
- multi-query attention.
Non-examples:
- one monolithic attention map only.
- duplicating the same head without learned projections.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
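Running the heads in parallel can be sketched with one batched matrix multiply over a leading head axis; toy sizes are assumed, and no Python loop over heads is needed.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # stabilized softmax
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Toy sizes (assumed): h heads computed together via batched matmul.
h, n, d_k = 2, 4, 3
rng = np.random.default_rng(1)
Q = rng.standard_normal((h, n, d_k))          # per-head queries
K = rng.standard_normal((h, n, d_k))          # per-head keys
V = rng.standard_normal((h, n, d_k))          # per-head values

S = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n) scores, every head at once
A = softmax(S)                                # (h, n, n) attention weights
O = A @ V                                     # (h, n, d_k) per-head outputs
```

Every attention row, in every head, still sums to one over its visible keys.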
4.3 Concatenation and output projection
Purpose. Concatenation and output projection focuses on returning to model width. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
After the heads run, their outputs are concatenated along the feature axis and multiplied by a learned output projection $W_O$ that returns the result to model width $d_{\text{model}}$.
Worked reading.
The implementation habit is to write shapes, scores, masks, softmax, and value aggregation explicitly.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- self-attention.
- decoder attention.
- attention over retrieved context.
Non-examples:
- independent token processing.
- fixed averaging with no learned scores.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
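The concatenate-then-project step can be sketched in a few lines; the sizes and random matrices are assumptions, and the check at the end shows how each token's concatenated row is laid out.

```python
import numpy as np

# Toy sizes (assumed): per-head outputs (h, n, d_k) are concatenated
# back to width h * d_k = d_model, then mixed by a learned W_O.
h, n, d_k = 2, 3, 4
d_model = h * d_k
rng = np.random.default_rng(2)
O_heads = rng.standard_normal((h, n, d_k))       # output of each head
W_O = rng.standard_normal((d_model, d_model))    # learned output projection

concat = O_heads.transpose(1, 0, 2).reshape(n, d_model)  # (n, h*d_k)
out = concat @ W_O                                       # back to model width

# Token i's concatenated row is head 0's row i followed by head 1's row i.
print(out.shape)   # (3, 8)
```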
4.4 Head specialization and redundancy
Purpose. Head specialization and redundancy focuses on why heads can be interpretable or unused. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Multi-head attention runs several attention mechanisms in parallel using different learned projections.
Worked reading.
Each head has width $d_k = d_{\text{model}}/h$ in the standard design; the $h$ head outputs are then concatenated and projected back to width $d_{\text{model}}$.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- syntax-like heads.
- copy heads.
- multi-query attention.
Non-examples:
- one monolithic attention map only.
- duplicating the same head without learned projections.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
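One way to probe redundancy is head ablation. The sketch below is illustrative only: it uses random weights and toy sizes, whereas in practice the probe is run on a trained model with a task metric. A head whose removal barely moves the layer output is a redundancy candidate.

```python
import numpy as np

# Toy ablation probe (assumed sizes, random weights): zero out one head's
# output before the output projection and measure how much the layer
# output moves.
h, n, d_k = 4, 3, 2
d_model = h * d_k
rng = np.random.default_rng(3)
O_heads = rng.standard_normal((h, n, d_k))
W_O = rng.standard_normal((d_model, d_model))

def layer_out(heads):
    # concatenate heads, then apply the output projection
    return heads.transpose(1, 0, 2).reshape(n, d_model) @ W_O

base = layer_out(O_heads)
drifts = []
for i in range(h):
    ablated = O_heads.copy()
    ablated[i] = 0.0                  # remove head i's contribution
    drifts.append(np.linalg.norm(layer_out(ablated) - base))

least_needed = int(np.argmin(drifts))  # smallest drift when removed
```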
4.5 Grouped query and multi-query attention
Purpose. Grouped query and multi-query attention focuses on sharing K/V for inference efficiency. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Multi-head attention runs several attention mechanisms in parallel using different learned projections.
Worked reading.
Each head has width $d_k = d_{\text{model}}/h$ in the standard design; the $h$ head outputs are then concatenated and projected back to width $d_{\text{model}}$.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- syntax-like heads.
- copy heads.
- multi-query attention.
Non-examples:
- one monolithic attention map only.
- duplicating the same head without learned projections.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
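Multi-query attention can be sketched by dropping the head axis from K and V; toy sizes are assumed. Grouped-query attention sits in between, sharing one K/V pair per group of query heads.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Multi-query sketch (assumed toy sizes): h query heads share ONE K/V
# pair, so the KV cache shrinks by a factor of h.
h, n, d_k = 8, 4, 2
rng = np.random.default_rng(4)
Q = rng.standard_normal((h, n, d_k))   # per-head queries, as in standard MHA
K = rng.standard_normal((n, d_k))      # shared keys: no head axis
V = rng.standard_normal((n, d_k))      # shared values: no head axis

S = Q @ K.T / np.sqrt(d_k)             # (h, n, n): K.T broadcasts across heads
O = softmax(S) @ V                     # (h, n, d_k)

kv_floats_mha = 2 * h * n * d_k        # K+V floats cached per layer, full MHA
kv_floats_mqa = 2 * n * d_k            # shared K/V: h times smaller
```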
5. Decoder Attention in LLMs
Decoder Attention in LLMs explains how transformer layers route information across sequence positions using differentiable, mask-aware retrieval.
5.1 Autoregressive causal attention
Purpose. Autoregressive causal attention focuses on preventing future-token leakage. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
A mask changes which key positions a query is allowed to see by adding large negative values to forbidden logits before softmax.
Worked reading.
In decoder-only language modeling, token $t$ may attend to positions $j \le t$ but not to future positions $j > t$.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- causal masks.
- padding masks.
- structured prompt masks.
Non-examples:
- zeroing output after softmax.
- trusting data order without a mask.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
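The mask-before-softmax rule can be sketched directly; sizes and scores are assumed toy values. Masked logits get a large negative value, so after softmax each row renormalizes over visible keys only.

```python
import numpy as np

# Causal-mask sketch: forbid future keys BEFORE softmax.
n = 4
rng = np.random.default_rng(5)
S = rng.standard_normal((n, n))                     # raw scores
future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True strictly above diagonal
S = np.where(future, -1e9, S)                       # large negative = invisible

S = S - S.max(axis=-1, keepdims=True)               # stabilize, then softmax
A = np.exp(S)
A /= A.sum(axis=-1, keepdims=True)

# Row 0 can only see position 0, so all of its weight lands there.
print(A[0, 0])   # 1.0
```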
5.2 KV cache
Purpose. KV cache focuses on reusing past keys and values during generation. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
During autoregressive generation, past keys and values can be cached so each new token computes attention against old K/V instead of recomputing the entire prefix.
Worked reading.
Prefill processes the whole prompt; decode appends one token at a time and reuses cached K/V tensors.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- LLM serving.
- streaming decode.
- multi-query attention.
Non-examples:
- training with full parallel sequence processing.
- recomputing all previous keys every token.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
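A single decode step with a KV cache can be sketched as follows; shapes are assumed toy values for one head. Note the score row is $1 \times t$, not $n \times n$: old keys and values are appended, never recomputed.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Decode-step sketch (assumed toy shapes, single head).
d_k = 4
rng = np.random.default_rng(6)
K_cache = rng.standard_normal((5, d_k))   # keys for 5 already-seen tokens
V_cache = rng.standard_normal((5, d_k))

q_new = rng.standard_normal((1, d_k))     # only the newest token's query
k_new = rng.standard_normal((1, d_k))
v_new = rng.standard_normal((1, d_k))

K_cache = np.concatenate([K_cache, k_new])   # append, never recompute
V_cache = np.concatenate([V_cache, v_new])

s = q_new @ K_cache.T / np.sqrt(d_k)      # (1, 6): one score row, not n x n
o = softmax(s) @ V_cache                  # (1, d_k): output for the new token
```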
5.3 Prefill versus decode
Purpose. Prefill versus decode focuses on two different attention workloads. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Prefill attention scores all prompt queries against all prompt keys in one parallel pass; decode attention scores a single new query against the cached keys and values, one token per step.
Worked reading.
The implementation habit is to write shapes, scores, masks, softmax, and value aggregation explicitly.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- self-attention.
- decoder attention.
- attention over retrieved context.
Non-examples:
- independent token processing.
- fixed averaging with no learned scores.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
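The two workloads can be compared with back-of-envelope arithmetic; the prompt and generation lengths below are assumed examples.

```python
# Back-of-envelope sketch (assumed lengths): prefill scores one full
# n x n matrix; decode scores one growing row per generated token.
n_prompt, n_gen = 2048, 256

prefill_scores = n_prompt * n_prompt                            # one parallel pass
decode_scores = sum(n_prompt + t for t in range(1, n_gen + 1))  # one row per step

per_token_prefill = prefill_scores / n_prompt   # 2048 score entries per token
per_token_decode = decode_scores / n_gen        # > 2048: rows keep growing

# Decode computes few entries per step, but each step must stream the
# whole KV cache for a single query, which is why decode tends to be
# memory-bound while prefill is compute-bound.
```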
5.4 Attention with positional encodings
Purpose. Attention with positional encodings focuses on how RoPE and ALiBi modify scores. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Positional encodings modify the query-key scores so attention can depend on position: RoPE rotates query and key vectors by position-dependent angles, while ALiBi adds a distance-proportional bias directly to the logits.
Worked reading.
The implementation habit is to write shapes, scores, masks, softmax, and value aggregation explicitly.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- self-attention.
- decoder attention.
- attention over retrieved context.
Non-examples:
- independent token processing.
- fixed averaging with no learned scores.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
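RoPE's key property can be checked on a single 2-D pair (a deliberately tiny illustration; real RoPE applies many such rotations at different frequencies across the head dimension): rotating $q$ and $k$ by angles proportional to their positions makes the score depend only on the relative offset.

```python
import numpy as np

# RoPE sketch, one 2-D pair: position-dependent rotation of q and k.
def rotate(v, pos, theta=1.0):
    c, s = np.cos(pos * theta), np.sin(pos * theta)
    return np.array([c * v[0] - s * v[1], s * v[0] + c * v[1]])

q = np.array([0.3, -1.2])
k = np.array([0.7, 0.5])

score_a = rotate(q, 10) @ rotate(k, 7)   # positions (10, 7): offset 3
score_b = rotate(q, 23) @ rotate(k, 20)  # positions (23, 20): offset 3

# Same relative offset => same score, regardless of absolute positions.
```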
5.5 Cross-attention preview
Purpose. Cross-attention preview focuses on encoder-decoder and retrieval-conditioned variants. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Cross-attention takes its queries from one sequence (the decoder states) and its keys and values from another (the encoder outputs or retrieved context), so the score matrix is rectangular.
Worked reading.
The implementation habit is to write shapes, scores, masks, softmax, and value aggregation explicitly.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- self-attention.
- decoder attention.
- attention over retrieved context.
Non-examples:
- independent token processing.
- fixed averaging with no learned scores.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
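The shape story of cross-attention fits in a short sketch with assumed toy sizes: queries come from one sequence, keys and values from another, and the attention matrix is rectangular rather than square.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Cross-attention sketch (assumed toy sizes): decoder queries attend
# over encoder (or retrieved-context) keys and values.
n_dec, n_enc, d_k = 3, 5, 4
rng = np.random.default_rng(7)
Q = rng.standard_normal((n_dec, d_k))   # decoder-side queries
K = rng.standard_normal((n_enc, d_k))   # encoder-side keys
V = rng.standard_normal((n_enc, d_k))   # encoder-side values

A = softmax(Q @ K.T / np.sqrt(d_k))     # (3, 5): rectangular, not square
O = A @ V                               # (3, 4): one mixed vector per decoder token
```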
6. Complexity and Efficient Attention
Complexity and Efficient Attention explains how transformer layers route information across sequence positions using differentiable, mask-aware retrieval.
6.1 Quadratic token cost
Purpose. Quadratic token cost focuses on why score matrices dominate long context. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Standard attention forms all pairwise query-key scores, so score memory grows quadratically with sequence length.
Worked reading.
Doubling context length roughly quadruples the score matrix size, even before considering layer count and batch size.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- long-context training.
- KV-cache sizing.
- FlashAttention kernels.
Non-examples:
- linear cost assumptions.
- ignoring memory traffic.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
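The quadratic growth is easy to verify numerically. The sketch below assumes fp16 (2 bytes per element) and counts the score matrix for one head in one layer:

```python
# Quadratic-cost sketch: fp16 score-matrix bytes, one head, one layer.
def score_bytes(n, elt=2):
    return n * n * elt            # one n x n score matrix

sizes = {n: score_bytes(n) for n in (1024, 2048, 4096)}

# Doubling the context quadruples score memory: 2 MiB -> 8 MiB -> 32 MiB,
# before multiplying by head count, layer count, and batch size.
```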
6.2 Memory layout and IO
Purpose. Memory layout and IO focuses on why exact attention can be slow despite simple formulas. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Exact attention can be slow even with simple math because the $n \times n$ score matrix must travel between high-bandwidth memory and on-chip memory; the memory traffic, not the floating-point work, often sets the runtime.
Worked reading.
Doubling context length roughly quadruples the score matrix size, even before considering layer count and batch size.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- long-context training.
- KV-cache sizing.
- FlashAttention kernels.
Non-examples:
- linear cost assumptions.
- ignoring memory traffic.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
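A rough traffic count makes the IO argument concrete. The model below is a simplification with assumed values (fp16, one head, $n = 4096$, $d_k = 64$): a naive kernel writes the $n \times n$ score matrix to high-bandwidth memory and reads it back, so the quadratic term dominates the linear Q/K/V/O streams.

```python
# Rough HBM-traffic sketch (assumed fp16, one head, one layer).
def hbm_bytes(n, d_k, elt=2):
    linear = (3 * n * d_k + n * d_k) * elt   # stream Q, K, V in and O out
    quadratic = 2 * n * n * elt              # write S out, read it back
    return linear, quadratic

linear, quadratic = hbm_bytes(4096, 64)
share = quadratic / (linear + quadratic)     # score traffic dominates
```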
6.3 FlashAttention intuition
Purpose. FlashAttention intuition focuses on tiling exact attention to reduce memory traffic. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
FlashAttention computes exact attention while avoiding materializing the full attention matrix in high-bandwidth memory.
Worked reading.
It tiles Q, K, and V blocks and maintains online softmax statistics so memory traffic is lower even though the mathematical result is exact.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- long-context training.
- GPU attention kernels.
- memory-efficient exact attention.
Non-examples:
- approximate sparse attention.
- changing the attention formula.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
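The online-softmax core of the tiling idea can be sketched for a single query row with assumed toy sizes: process keys tile by tile, carrying a running max $m$ and running normalizer $l$, so the full score row never has to exist at once. The result is exact, not approximate.

```python
import numpy as np

# Online-softmax sketch (one query row, toy sizes).
rng = np.random.default_rng(8)
d_k, n, tile = 4, 12, 4
q = rng.standard_normal(d_k)
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_k))

m, l, acc = -np.inf, 0.0, np.zeros(d_k)   # running max, normalizer, output
for start in range(0, n, tile):
    s = K[start:start + tile] @ q / np.sqrt(d_k)  # scores for this tile only
    m_new = max(m, s.max())
    scale = np.exp(m - m_new)             # rescale earlier statistics
    p = np.exp(s - m_new)
    l = l * scale + p.sum()
    acc = acc * scale + p @ V[start:start + tile]
    m = m_new
out_tiled = acc / l

# Reference: materialize the whole score row at once.
s_full = K @ q / np.sqrt(d_k)
a = np.exp(s_full - s_full.max())
a /= a.sum()
out_full = a @ V
```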
6.4 Sparse and local attention preview
Purpose. Sparse and local attention preview focuses on approximating visibility patterns. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Sparse and local attention restrict each query's visible keys to a structured subset, such as a sliding window or strided pattern, so fewer query-key pairs are scored.
Worked reading.
The implementation habit is to write shapes, scores, masks, softmax, and value aggregation explicitly.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- self-attention.
- decoder attention.
- attention over retrieved context.
Non-examples:
- independent token processing.
- fixed averaging with no learned scores.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
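A sliding-window variant can be sketched as just a different mask; toy sizes are assumed. Each query sees itself plus the previous $w$ keys, combined with the causal constraint, and the softmax itself is unchanged.

```python
import numpy as np

# Sliding-window mask sketch: causal AND within a window of width w.
n, w = 6, 2
i = np.arange(n)[:, None]           # query positions
j = np.arange(n)[None, :]           # key positions
visible = (j <= i) & (j >= i - w)   # self plus the previous w keys

S = np.zeros((n, n))                # toy scores: all equal when visible
S = np.where(visible, S, -1e9)      # mask before softmax, as always
A = np.exp(S - S.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)

# Strictly fewer visible pairs than full causal attention, rows still normalize.
```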
6.5 Long-context diagnostics
Purpose. Long-context diagnostics focuses on checking quality cost and position behavior together. This is a core part of how transformer layers turn a sequence of embeddings into context-aware hidden states.
Operational definition.
Attention diagnostics inspect weights, entropy, masks, and head importance, but they do not by themselves prove causal explanations.
Worked reading.
A low-entropy row means one or a few keys dominate; a high-entropy row means information is mixed broadly.
| Object | Shape | Meaning |
|---|---|---|
| $X$ | $n \times d_{\text{model}}$ | hidden states entering the layer |
| $Q = XW_Q$, $K = XW_K$ | $n \times d_k$ | query and key address vectors |
| $V = XW_V$ | $n \times d_v$ | value payload vectors |
| $S = QK^\top/\sqrt{d_k}$ | $n \times n$ | compatibility scores |
| $A = \operatorname{softmax}(S)$ | $n \times n$ | attention weights |
| $O = AV$ | $n \times d_v$ | mixed output values |
Examples:
- attention heatmaps.
- head ablations.
- entropy dashboards.
Non-examples:
- claiming attention weight equals explanation.
- inspecting only one prompt.
Derivation habit.
- Write the shapes of $X$, $Q$, $K$, $V$, $S$, $A$, and $O$.
- Add masks before softmax, not after.
- Check every attention row sums to one over visible keys.
- Separate mathematical attention from kernel implementation details.
- For LLM serving, distinguish prefill attention from decode attention with a KV cache.
Implementation lens.
A correct attention implementation is mostly a shape and masking discipline. The bug that hurts language modeling most is often not the matrix multiplication; it is allowing a token to see future positions or padding tokens.
For efficient inference, the formula stays the same but the workload changes. During prefill, the model processes a full prompt. During decode, it adds one query at a time while reading cached keys and values from previous tokens.
For interpretation, attention weights are useful traces of information flow, but they are not the whole model explanation. Residual connections, MLPs, layer norms, and later layers can change or override what a single attention map appears to show.
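The entropy diagnostic can be sketched on two hand-picked rows (assumed illustrative values): a nearly one-hot row has low entropy, and a uniform row over $k$ keys reaches the maximum, $\log k$.

```python
import numpy as np

# Row-entropy sketch: measure how concentrated each attention row is.
def row_entropy(A, eps=1e-12):
    # eps guards log(0) for exactly-zero weights
    return -(A * np.log(A + eps)).sum(axis=-1)

peaked = np.array([[0.97, 0.01, 0.01, 0.01]])   # almost one-hot row
mixed = np.full((1, 4), 0.25)                    # uniform row over 4 keys

h_peaked = row_entropy(peaked)[0]                # low: one key dominates
h_mixed = row_entropy(mixed)[0]                  # maximal: log(4)
```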