Interpretability & Efficient Architectures

Chapter 10: Sequence Models & Attention

Peeking inside transformers · The quadratic wall · FlashAttention · KV cache · Speculative decoding

I avoided looking inside transformers for longer than I care to admit. I'd train a model, watch the loss go down, see the outputs get better, and think: good enough. But "good enough" has a shelf life. One day someone asked me why my model kept attending to punctuation marks in a sentiment task, and I had nothing. Another day I tried feeding it a 20-page document and watched my GPU memory explode. The discomfort of not understanding what was happening inside — or how to make it scale — finally grew too great. Here is that dive.

This section covers two separate but intertwined questions about sequence models, particularly transformers. First: what did the model learn, and can we peek inside without breaking it? That's interpretability. Second: standard self-attention scales quadratically with sequence length, which means it falls apart for long documents. That's the efficiency problem. Both questions became urgent as transformers went from research curiosities to production workhorses.

Before we start, a heads-up. We'll be getting into GPU memory hierarchies, matrix multiplication tricks, and some linear algebra. You don't need any of that beforehand. We'll add what we need, one piece at a time.

This isn't a short journey, but I hope you'll be glad you came.

What We'll Cover

Attention heatmaps and why they lie
Probing classifiers — an MRI for neural networks
BERTology and the layer-by-layer map
Mechanistic interpretability — induction heads and superposition
Rest stop
The quadratic wall
Sparse attention — Longformer, BigBird, and friends
Linear attention and the kernel trick
FlashAttention — the trick that actually won
The KV cache and why it dominates inference memory
Multi-query and grouped-query attention
Speculative decoding — two models, one answer
Wrap-up

Let's ground all of this in a concrete scenario. Imagine we're building a customer support chatbot for a small online bookstore. The chatbot reads a customer's message — "I ordered The Great Gatsby last Tuesday and it still hasn't arrived" — and needs to generate a helpful reply. Our model is a small transformer with 4 layers, 4 attention heads per layer, and a vocabulary of 10,000 tokens. We'll keep coming back to this bookstore bot as we peel back layers of both interpretability and efficiency.

Attention Heatmaps and Why They Lie

The first thing everyone tries — myself included — is plotting the attention weights. Each attention head produces a matrix where the value in row i, column j tells us how much token i "looked at" token j when computing its output. You can draw this as a heatmap, and the results look intoxicating.

For our bookstore bot, imagine feeding in the sentence "The book arrived damaged." We pull out the attention matrix from head 2 of layer 3. It's a 5×5 grid (one row and column per token). When we visualize it, we see that the token "damaged" strongly attends to "book." Satisfying. We feel like we've understood something.
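To make that 5×5 grid concrete, here is a minimal sketch of how one head's attention matrix gets computed. The random Q and K are stand-ins for the real projections (an assumption, since we don't have the bookstore bot's weights), and no causal mask is applied, so the numbers themselves mean nothing. Only the shapes and the row normalization carry over.

```python
import numpy as np

tokens = ["The", "book", "arrived", "damaged", "."]
d_head = 8  # assumed head dimension for this toy example

rng = np.random.default_rng(0)
Q = rng.normal(size=(len(tokens), d_head))  # stand-in for the head's queries
K = rng.normal(size=(len(tokens), d_head))  # stand-in for the head's keys

scores = Q @ K.T / np.sqrt(d_head)            # one score per (row i, column j) pair
scores -= scores.max(axis=1, keepdims=True)   # subtract row max for numerical stability
attn = np.exp(scores)
attn /= attn.sum(axis=1, keepdims=True)       # each row becomes a distribution over tokens

# Row i of the heatmap: how token i spreads its attention over all tokens.
row = tokens.index("damaged")
for tok, weight in zip(tokens, attn[row]):
    print(f"{tok:>8s}  {weight:.2f}")
```

Each row sums to 1, which is exactly why the heatmap is so tempting to read as "importance."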

Here's the problem: we haven't. In 2019, Jain and Wallace ran an experiment that shook the interpretability community. They showed you could often find very different attention distributions, ones that looked nothing like the original, that produced essentially the same model predictions. If two different "explanations" lead to the same answer, neither one is really an explanation. Wiegreffe and Pinter pushed back later that year with important nuance, but the consensus settled somewhere uncomfortable: attention weights show you where information was routed, not what information mattered for the final decision.

I'll be honest — this tripped me up for a while. I wanted attention heatmaps to be the X-ray that shows you the broken bone. They're more like a traffic map that shows which roads had cars on them. Useful for debugging routing, but it doesn't tell you why someone drove to that particular destination.

So if heatmaps can't tell us what the model actually learned, what can?

Probing Classifiers — An MRI for Neural Networks

Here's a sharper question: does a specific layer of our bookstore bot "know" the part of speech of each token? Does it know which tokens refer to people versus products? Does it understand sentence structure? Probing classifiers are the tool built to answer these questions.

The idea is elegant. We freeze the model — not a single weight changes. Then we extract the hidden state vectors from a particular layer for a bunch of labeled examples. Finally, we train a tiny classifier — typically a linear one, meaning it can only draw straight-line decision boundaries — on those hidden states to predict some property. If a linear classifier can predict part-of-speech tags from layer 2's hidden states with 95% accuracy, that information must be sitting there in an accessible form.

Let's walk through this with our bookstore bot. We have the sentence "I ordered The Great Gatsby last Tuesday." After running it through the model, layer 1 produces a hidden state vector for each token — seven vectors, each maybe 256 dimensions. We collect thousands of such sentences with part-of-speech labels. Then we train a linear classifier: given a 256-dimensional vector, predict whether the token is a noun, verb, adjective, or something else. If this trivial classifier hits high accuracy, the model's first layer has already separated tokens by their grammatical role.
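Here is a minimal sketch of that probing recipe. Since we don't have the bookstore bot's actual hidden states, the vectors below are synthetic: each tag gets a made-up direction plus noise (a loud assumption, and the reason the numbers prove nothing about a real model). The probe is a plain least-squares linear classifier, and the helper name `linear_probe_accuracy` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, n_tags = 2000, 256, 4  # tags: noun, verb, adjective, other

# Synthetic stand-in for frozen layer-1 hidden states (assumption, not real model output).
tag_directions = rng.normal(size=(n_tags, d_model))
labels = rng.integers(0, n_tags, size=n_tokens)
hidden = tag_directions[labels] + 2.0 * rng.normal(size=(n_tokens, d_model))

def linear_probe_accuracy(X, y, n_classes, n_train):
    """Fit a linear map from hidden states to one-hot tags, score on held-out tokens."""
    Xb = np.hstack([X, np.ones((len(X), 1))])        # add a bias column
    targets = np.eye(n_classes)[y[:n_train]]         # one-hot training labels
    W, *_ = np.linalg.lstsq(Xb[:n_train], targets, rcond=None)
    predictions = (Xb[n_train:] @ W).argmax(axis=1)
    return float((predictions == y[n_train:]).mean())

probe_acc = linear_probe_accuracy(hidden, labels, n_tags, n_train=1500)
# Control: same vectors, shuffled labels. A probe that still scores well here is memorizing.
control_acc = linear_probe_accuracy(hidden, rng.permutation(labels), n_tags, n_train=1500)
print(f"probe accuracy:   {probe_acc:.2f}")
print(f"control accuracy: {control_acc:.2f}")
```

The shuffled-label control is a cheap sanity check: the gap between the two accuracies is what suggests the information lives in the representation rather than in the probe.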

Researchers have done this systematically, layer by layer, for BERT and similar models. The pattern is remarkably consistent. Early layers — layers 1 and 2 — capture surface features: word identity, part-of-speech tags, morphology. Middle layers — around layers 4 through 8 in a 12-layer BERT — capture syntax: dependency relations, phrase structure, whether two tokens are in the same clause. Later layers capture semantics: coreference (knowing "it" refers to "the book"), sentiment, and entity types.

Think of it like the layers of a medical imaging pipeline. The early layers are like the raw pixel detector — they see edges and shapes. The middle layers assemble those into anatomical structures. The later layers are the radiologist reading the image and saying "this is abnormal."

There's a catch, though. If you make your probe too powerful — say, a deep neural network instead of a linear classifier — it might learn the property from scratch rather than detecting it in the hidden states. A sufficiently complex probe can extract patterns from random noise. That defeats the whole purpose. The discipline of keeping probes simple is what makes this technique meaningful. If a linear classifier can find it, the representation must be doing real work.

Probing tells us what information lives in each layer. What it doesn't tell us is how the model uses that information — which specific heads and neurons work together to produce a prediction. For that, we need something more invasive.

BERTology — The Layer-by-Layer Map

The probing results I described above didn't come from one paper. They came from hundreds. The field got so prolific that Rogers, Kovaleva, and Rumshisky published a survey in 2020 called "A Primer in BERTology" that synthesized the findings. The name stuck: BERTology is the study of what BERT-like models learn and how they organize that knowledge.

A few findings stand out. First, not all attention heads matter equally. You can prune (remove) a large fraction of heads with minimal impact on performance. Some heads are specialists — one might focus on direct objects, another on punctuation, another on coreference. Second, the knowledge is distributed: no single neuron or head is solely responsible for any linguistic feature. It's a property of the collective. Third, representations evolve as they move through layers. A word's embedding at layer 1 looks nothing like its embedding at layer 12. The model progressively refines from surface to syntax to semantics.

For our bookstore bot, this means something concrete. If the bot is failing at understanding that "it" in "I returned it yesterday" refers to the book, the problem likely lives in the later layers. If it's confused about word boundaries in a misspelled query, the issue is probably in the early layers. BERTology gives us a rough map of where to look.

But BERTology is still observational. We're looking at the map, not the terrain. Can we go deeper — can we reverse-engineer the actual computations?

Mechanistic Interpretability — Reverse-Engineering the Circuits

This is where things get genuinely ambitious. Mechanistic interpretability tries to identify the specific computations that individual neurons, attention heads, and combinations of components perform. Not "this layer knows syntax" but "these two attention heads, working together, implement a copy-paste operation."

The most famous discovery is the induction head. To understand what an induction head does, let's go back to our bookstore bot. Suppose earlier in the conversation, the customer wrote "I'm looking for Harry Potter." Now, later in the same conversation, the customer writes "I need another Harry" and stops. An induction circuit is a pair of attention heads in different layers that work together to complete this pattern. A previous-token head in an earlier layer tags each position with the identity of the token just before it, so "Potter" carries the note "I came right after Harry." The induction head in a later layer then looks back from the current "Harry" for a position carrying that note, finds "Potter," and copies it forward as the prediction. It's a pattern-matching circuit baked into the weights.

Researchers at Anthropic and elsewhere have shown that induction heads emerge at a specific point during training — there's almost a phase transition where the model's ability to learn from context suddenly improves, and that improvement coincides with the formation of these heads. This is one of the cleanest examples of a neural network developing an identifiable algorithm through training.
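The input-output behavior of an induction circuit fits in a few lines. To be clear, this is a sketch of the rule the circuit implements, not of how attention heads implement it in the weights. The function name `induction_guess` is made up for illustration.

```python
def induction_guess(tokens):
    """Return the token that followed the most recent earlier copy of the last token."""
    current = tokens[-1]
    # Scan backward over earlier positions, skipping the current one.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]   # copy whatever came next last time
    return None                    # no earlier occurrence: the rule has nothing to say

conversation = ["I'm", "looking", "for", "Harry", "Potter", ".",
                "I", "need", "another", "Harry"]
guess = induction_guess(conversation)
print(guess)
```

A real model only raises the probability of the copied token rather than returning it outright, and it blends this signal with everything else it knows.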

The second major discovery is superposition. I'll be honest — this one took me several readings to internalize. The intuition is this: a model with, say, 256 neurons per layer doesn't store 256 concepts. It stores far more than 256 concepts by encoding them as overlapping, nearly-orthogonal directions in that 256-dimensional space. One neuron might respond to both "academic writing" and "the color blue" — not because those concepts are related, but because the model learned to pack them into the same space efficiently. It's a compression trick, and it's why looking at individual neurons in isolation gives you confusing results.

Anthropic's more recent work uses sparse autoencoders, a form of dictionary learning, to disentangle these superimposed features. They've been able to isolate specific concepts, like the "Golden Gate Bridge" concept, as individual features in Claude's activations and then manipulate them. It's extraordinary and a little unsettling.

I'm still developing my intuition for why superposition works so well. The high-level explanation — "there are more concepts than neurons, so the model compresses" — is satisfying in the way a one-sentence summary of a novel is satisfying. The full picture involves interesting geometry about how many nearly-orthogonal vectors you can pack into a space, and the field is actively working through the details.
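One piece of that geometry is easy to check numerically: random directions in a high-dimensional space are nearly orthogonal to each other. The sketch below packs 1,000 random unit vectors into 256 dimensions and measures how much they overlap. It uses random directions as a stand-in for learned features, which is an assumption; real models choose their directions far more carefully.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_features = 256, 1000  # roughly four times more "concepts" than dimensions

# Random unit vectors standing in for feature directions.
directions = rng.normal(size=(n_features, d_model))
directions /= np.linalg.norm(directions, axis=1, keepdims=True)

# Cosine similarity between every pair of distinct directions.
cosines = directions @ directions.T
off_diagonal = np.abs(cosines[~np.eye(n_features, dtype=bool)])

mean_overlap = off_diagonal.mean()
max_overlap = off_diagonal.max()
print(f"mean |cos| = {mean_overlap:.3f}, max |cos| = {max_overlap:.3f}")
```

The typical overlap is small but not zero, and that leftover interference is the price the model pays for packing in more features than it has dimensions.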

Mechanistic interpretability is the particle physics of deep learning. The findings are striking, potentially critical for AI safety and model auditing. But the work is painstaking, often manual, and hasn't yet scaled to the largest production models. If you're a practitioner, know that these tools exist and are improving fast. If you're asked about them in an interview, know the two key concepts — induction heads and superposition — and why they matter for understanding in-context learning.

Rest Stop 🛑

Congratulations on making it this far. You can stop here if you want.

You now have a mental model of the interpretability toolkit: attention heatmaps show routing but not reasoning, probing classifiers reveal what information each layer stores, BERTology maps the progression from surface features to semantics, and mechanistic interpretability is beginning to reverse-engineer specific circuits like induction heads.

That's a solid foundation. It won't tell you the full story — interpretability is a young field and the tools are still fragile. But if someone asks you "what does this model actually learn?", you have concrete answers with appropriate caveats.

What comes next is different terrain entirely: the engineering problem of making attention scale to long sequences. This is where the bookstore bot starts trying to read entire order histories and shipping logs, and the GPU starts sweating. If the discomfort of not knowing how that gets solved is nagging at you, read on.

The Quadratic Wall

Let's go back to our bookstore bot. So far, customer messages have been short — 20 or 30 tokens. At those lengths, attention is fast. But now we want the bot to read the customer's entire order history before responding. That's a page of text. Maybe ten pages. Standard self-attention computes a score between every pair of tokens. For a sequence of length n, that's n × n scores.

Let's make the pain concrete with our bookstore scenario. A short customer message is about 50 tokens: 50 × 50 = 2,500 attention scores. Comfortable. A full conversation thread is 2,000 tokens: 2,000 × 2,000 = 4 million scores. Fine. An entire order history with all correspondence is 32,000 tokens: 32,000 × 32,000 = about 1 billion scores. Now we have a problem. A full product catalog the bot might need to search through hits 100,000 tokens: 100,000 × 100,000 = 10 billion scores. The GPU throws up its hands.
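The arithmetic above is simple enough to script. This sketch counts scores for each bookstore scenario and, assuming 2 bytes per score (FP16) and a score matrix that is fully materialized, estimates the memory for a single head in a single layer. Both helper names are hypothetical.

```python
def attention_scores(n_tokens):
    """Number of pairwise scores in full self-attention."""
    return n_tokens * n_tokens

def score_matrix_gib(n_tokens, bytes_per_score=2):
    """Memory for one materialized n x n score matrix, assuming FP16 by default."""
    return attention_scores(n_tokens) * bytes_per_score / 2**30

scenarios = {
    "short message": 50,
    "conversation thread": 2_000,
    "order history": 32_000,
    "product catalog": 100_000,
}
for name, n in scenarios.items():
    print(f"{name:<20s} n={n:>7,d}  scores={attention_scores(n):>14,d}  "
          f"{score_matrix_gib(n):8.3f} GiB per head")
```

Multiply the last column by the number of heads and layers to see why materializing the matrix stops being an option.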

This is the quadratic wall, and every technique in the rest of this section attacks it from a different angle. Think of it like a traffic problem: a small town doesn't need highways, but a growing city needs multiple strategies — wider roads, bypasses, express lanes, smarter traffic lights. No single solution handles everything.

Sparse Attention — Don't Look at Everything

Here's the first observation: most entries in the attention matrix are near zero anyway. When our bookstore bot processes a 2,000-token conversation, the token "Gatsby" at position 47 probably has no meaningful relationship with the token "thank" at position 1,983. We're computing a score between them and the answer is effectively zero. That's wasted work.

Sparse attention says: skip the pairs that don't matter. Instead of computing all n² scores, define a pattern that picks out the important ones.

The simplest pattern is a local window. Each token attends to its w nearest neighbors — the tokens directly before and after it. If w is 512, each token sees 512 neighbors instead of all n tokens. The cost becomes n × w instead of n × n. For our 32,000-token order history with a window of 512, that's about 16 million scores instead of 1 billion. A 60× reduction.

But wait — doesn't this mean distant tokens can never interact? Not quite. Information propagates through layers. If layer 1 lets token 1 see tokens 1–512, and layer 2 lets it see those same neighbors (who themselves saw their 512 neighbors in layer 1), then after 2 layers, token 1 has indirect access to tokens 1–1024. Stack 12 layers and information can travel across 6,000 tokens. It's like a game of telephone, except the information is encoded in high-dimensional vectors and surprisingly little gets lost.

A local window handles nearby context well, but what about genuinely long-range dependencies? What if the customer's complaint at the end of the conversation references an order number from the beginning? Three additional sparse patterns help.

Global tokens are a handful of designated positions — often the [CLS] token or a start-of-sequence marker — that attend to every token and that every token attends to. They act as information hubs: a global token can pick up the order number from position 5 and make it available to the complaint at position 1,900. The cost is small because there are only a few global tokens.

Strided attention has each token attend to every k-th token, giving a coarse bird's-eye view of the whole sequence. If k is 64, a token at position 1,000 sees positions 0, 64, 128, 192, and so on — a sparse but wide perspective.

Random connections are exactly what they sound like: for each token, randomly select a small number of other tokens to attend to. This sounds reckless, but random connections create shortcuts in the information graph, similar to how random edges in a social network shrink the "six degrees of separation" distance.

Longformer (2020) combines local window attention with global tokens. Every token sees its neighbors, and a few special tokens see everything. This handles tasks like document classification and question answering over long texts. BigBird (2020) adds random connections on top, and proved something theoretically striking: the combination of local + global + random connections gives the same theoretical expressiveness as full attention, as measured by the ability to simulate Turing machines. Both achieve near-full-attention quality at a fraction of the cost.
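The three patterns are easy to picture as a boolean mask over the attention matrix. The sketch below builds a local band, a few global tokens, and some random links, then reports what fraction of the n × n pairs survive. It is an illustration only: the function name and parameters are hypothetical, and real Longformer and BigBird implementations use blocked kernels rather than a dense mask, since a dense mask saves no computation by itself.

```python
import numpy as np

def sparse_mask(n, window, n_global, n_random, seed=0):
    """Boolean n x n mask: True where token i is allowed to attend to token j."""
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    # Local window: window // 2 neighbors on each side, plus the token itself.
    mask = np.abs(idx[:, None] - idx[None, :]) <= window // 2
    # Global tokens (the first n_global positions) see everything and are seen by everything.
    mask[:n_global, :] = True
    mask[:, :n_global] = True
    # Random links: each token picks n_random arbitrary positions.
    random_cols = rng.integers(0, n, size=(n, n_random))
    mask[idx[:, None], random_cols] = True
    return mask

mask = sparse_mask(n=2000, window=64, n_global=2, n_random=3)
density = mask.mean()
print(f"{mask.sum():,d} of {mask.size:,d} pairs kept ({density:.1%})")
```

Dropping `n_random` to 0 gives a Longformer-style pattern, and keeping it gives a BigBird-style one, at least at this level of abstraction.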

The limitation? Sparse patterns are hand-designed. Someone has to decide the window size, the number of global tokens, the stride. Different tasks might want different patterns. And for tasks where the truly important attention connections are neither local nor random — where the model needs to discover which pairs matter — fixed patterns can miss things.

Linear Attention — Rearranging the Math

Instead of deciding which pairs to skip, what if we could change the math so the quadratic cost disappears entirely? That's the idea behind linear attention.

Standard attention computes softmax(QKᵀ)V (the scores are also scaled by 1/√d, which we'll leave out of the notation here). The bottleneck is QKᵀ, which produces the n × n attention matrix. But here's a matrix algebra observation. Matrix multiplication is associative: (AB)C = A(BC). If we could skip the softmax, or approximate it, we could compute KᵀV first. That product has shape d × d (where d is the head dimension, typically 64 or 128). Then we multiply Q by that d × d result. The cost goes from O(n²d) to O(nd²). Since d is small and fixed, that's linear in sequence length.

The catch is the softmax. It exponentiates every score and then normalizes each row of the n × n matrix, and that nonlinearity sits between QKᵀ and V, so you can't regroup the multiplication around it. Performer (2020) gets around this by replacing the softmax with a kernel approximation: a technique called FAVOR+ (Fast Attention Via positive Orthogonal Random features) approximates the softmax kernel using random feature maps. Linformer (2020) takes a different approach: it keeps the softmax but projects the sequence length dimension of K and V down to a smaller fixed size k using learned projection matrices, reducing the attention matrix from n × n to n × k.

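import numpy as np

rng = np.random.default_rng(0)
n, d, m = 8, 4, 64                       # toy sizes; m = number of random features (assumed)
Q, K, V = rng.normal(size=(3, n, d))

# Standard attention: O(n²d), builds the n × n matrix first
scores = Q @ K.T / np.sqrt(d)            # shape: n × n  ← this is the problem
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)   # row-wise softmax, still n × n
output = weights @ V                     # shape: n × d

# Linear attention (Performer-style sketch): O(nmd), never builds n × n
W = rng.normal(size=(m, d))              # random projections (plain Gaussian, not orthogonalized)
def phi(X):                              # positive random features: E[phi(q)·phi(k)] = exp(q·k)
    return np.exp(X @ W.T - (X**2).sum(axis=1, keepdims=True) / 2) / np.sqrt(m)
Q_prime, K_prime = phi(Q / d**0.25), phi(K / d**0.25)
kv = K_prime.T @ V                       # shape: m × d  ← independent of n
z = K_prime.sum(axis=0)                  # shape: m, plays the role of the softmax denominator
output_approx = (Q_prime @ kv) / (Q_prime @ z)[:, None]   # shape: n × d, approximates output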

The promise is beautiful: linear scaling means a 100,000-token sequence costs 100× more than 1,000 tokens, not 10,000× more. The reality is messier. The kernel approximation introduces error. On tasks where the exact pattern of token-to-token attention matters — where the model needs to attend very sharply to one or two specific tokens — the approximation smooths things out and quality drops. The Performer and Linformer showed impressive results on certain benchmarks but never achieved the same dominance as the technique we'll look at next.

I think of sparse and linear attention like two different strategies for handling a crowded library. Sparse attention says "I'll only look at nearby shelves and a few random ones across the room." Linear attention says "I'll build a compressed summary of every shelf so I don't have to visit each one." Both work, both have blind spots, and neither is what most production systems ended up using.

FlashAttention — The Trick That Actually Won

Here's where I need to correct a misconception I held for an embarrassingly long time. When I first heard about "efficient attention," I assumed the solution was mathematical: find a clever approximation. FlashAttention is nothing of the sort. It computes exact standard attention. Same math, same outputs up to floating-point rounding (the operations run in a different order, so the results are not guaranteed to match bit for bit). The entire trick is about how the computation is organized on the GPU hardware.

To understand FlashAttention, we need to understand one thing about GPUs. A modern GPU has two kinds of memory. High-bandwidth memory (HBM) is large — 40 to 80 gigabytes on an A100 — but relatively slow to access. Think of it as a warehouse: lots of space, but it takes time to walk to the shelf and back. SRAM (the on-chip memory, sometimes called shared memory) is tiny — around 20 megabytes total across all streaming multiprocessors — but extremely fast. Think of it as your workbench: everything on it is within arm's reach.

The naive attention computation goes like this. We compute Q × KT, getting the full n × n attention matrix. We write that matrix to HBM. We read it back from HBM to compute softmax. We write the softmax result to HBM. We read it back to multiply by V. We write the final output to HBM. Each of those reads and writes is a trip to the warehouse. For long sequences, we're spending more time moving data between HBM and SRAM than actually doing arithmetic. The computation is memory-bound, not compute-bound.

FlashAttention, introduced by Tri Dao and collaborators in 2022, eliminates those round trips using tiling. Instead of computing the entire n × n attention matrix at once, it processes the matrix in small blocks, called tiles, that fit in SRAM. The algorithm loads a tile of Q and a tile of K into SRAM, computes their attention scores, applies softmax (partially), multiplies by the corresponding tile of V, and accumulates the result. Then it moves on to the next tile. The full n × n matrix is never materialized in HBM at all.

The clever part is the online softmax trick. Softmax needs a denominator — the sum of all exponentials — which normally requires seeing all values first. How do you compute softmax in tiles without seeing every value? The answer: maintain running statistics. For each row, keep track of the maximum value seen so far and the running sum of exponentials. As each new tile arrives, update these statistics incrementally. When all tiles are processed, the running statistics give you the exact softmax. No approximation. No error.

Let me trace through a tiny example. Suppose our bookstore bot has a 4-token sentence and we tile it into blocks of 2. Q has 4 rows and Kᵀ has 4 columns. Instead of computing the full 4 × 4 attention matrix at once, we process it as four 2 × 2 tiles, two for each block of query rows. Tile 1: Q rows [1,2] × Kᵀ columns [1,2], a 2 × 2 block. We compute these 4 scores, find the max per row, compute partial exponentials, and start accumulating weighted V. Tile 2: Q rows [1,2] × Kᵀ columns [3,4], another 2 × 2 block. We compute these scores, update our running max (if we found a larger value, we rescale the previous partial sums), update the running sum, and add more weighted V. After tile 2, we divide the accumulated result by the running sum and have the exact answer for rows 1 and 2. Then we repeat for Q rows [3,4]. The full 4 × 4 matrix was never stored anywhere.
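The same trace can be written as a short NumPy sketch. It only reproduces the arithmetic of tiling with running statistics. NumPy has no notion of SRAM versus HBM, so none of the speed benefit shows up here, and the function names are my own rather than anything from the FlashAttention codebase.

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])              # full n x n score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=2):
    n, d = Q.shape
    out = np.zeros((n, V.shape[1]))
    for qs in range(0, n, block):                  # one block of query rows at a time
        q = Q[qs:qs + block]
        row_max = np.full(q.shape[0], -np.inf)     # running max per row
        row_sum = np.zeros(q.shape[0])             # running sum of exponentials per row
        acc = np.zeros((q.shape[0], V.shape[1]))   # running weighted sum of V
        for ks in range(0, K.shape[0], block):     # stream over key/value tiles
            s = q @ K[ks:ks + block].T / np.sqrt(d)
            new_max = np.maximum(row_max, s.max(axis=1))
            rescale = np.exp(row_max - new_max)    # shrink old partial sums if the max grew
            p = np.exp(s - new_max[:, None])
            row_sum = row_sum * rescale + p.sum(axis=1)
            acc = acc * rescale[:, None] + p @ V[ks:ks + block]
            row_max = new_max
        out[qs:qs + block] = acc / row_sum[:, None]
    return out

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 4, 8))               # 4 tokens, head dimension 8
matches = np.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V))
print("tiled == naive:", matches)
```

The largest intermediate inside the loops is a 2 × 2 tile, yet the final output agrees with the version that built the whole matrix.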

The results are dramatic. FlashAttention delivers a 2–4× wall-clock speedup over naive attention and uses far less memory, because it only stores O(n) values in HBM instead of O(n²). FlashAttention-2 (2023) improved the parallelism further. This is what runs under the hood of essentially every modern large language model. If you remember one technique from this entire section, make it FlashAttention.

I want to sit with this for a moment, because the insight is profound. The math didn't change. The outputs didn't change. What changed was respecting the hardware. So much of high-performance computing is about understanding that data movement costs more than arithmetic, and designing algorithms that keep data close to where the computation happens. FlashAttention is a masterclass in this principle.

The library analogy from earlier still applies. Sparse attention says "visit fewer shelves." Linear attention says "make a summary map." FlashAttention says "bring the shelves to my workbench in small batches." Same library, same books, same reading — but faster, because we stopped walking back and forth.

The KV Cache — Why Inference Memory Blows Up

We've been talking about making attention efficient during training, where we process the whole sequence at once. Inference is a different beast. When our bookstore bot generates a response — one token at a time, autoregressively — it runs the full attention computation at every step. Generate token 1, run attention. Generate token 2, run attention over tokens 1–2. Generate token 100, run attention over all 100 tokens so far.

Here's the wasteful part. When generating token 100, the model computes new key and value vectors for every previous token — but tokens 1 through 99 haven't changed. Their keys and values are exactly the same as they were when we generated token 99. We're recomputing the same thing over and over.

The KV cache is the fix. After computing the key and value vectors for each token, we store them. When generating the next token, we only compute Q, K, and V for the single new token, then look up the cached keys and values for all previous tokens. The attention computation goes from processing n tokens' worth of K and V at step n to processing 1 new token plus a cache lookup.

Let's trace this for our bookstore bot generating "Your order has shipped." At step 1, we compute K₁ and V₁ for "Your" and cache them. At step 2, we compute K₂ and V₂ for "order", cache them, and attend over [K₁, K₂] and [V₁, V₂] using only the new query Q₂. At step 3, we compute K₃ and V₃ for "has", cache them, attend over [K₁, K₂, K₃] using Q₃. Each step adds one entry to the cache instead of recomputing everything.
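Here is that trace as a toy single-head, single-layer sketch. The projection matrices and token embeddings are random placeholders (assumptions, not the bot's weights). The point is only that attending over cached keys and values gives the same result as recomputing them from scratch at every step.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
W_q, W_k, W_v = rng.normal(size=(3, d_model, d_model)) / np.sqrt(d_model)
embeddings = rng.normal(size=(4, d_model))   # stand-ins for "Your", "order", "has", "shipped"

def attend(q, K, V):
    scores = K @ q / np.sqrt(d_model)
    weights = np.exp(scores - scores.max())
    return (weights / weights.sum()) @ V

def step_without_cache(x_so_far):
    # Recompute K and V for every token seen so far: the wasteful version.
    return attend(x_so_far[-1] @ W_q, x_so_far @ W_k, x_so_far @ W_v)

k_cache, v_cache = [], []

def step_with_cache(x_new):
    # Project only the new token, then attend over everything stored so far.
    k_cache.append(x_new @ W_k)
    v_cache.append(x_new @ W_v)
    return attend(x_new @ W_q, np.array(k_cache), np.array(v_cache))

all_match = True
for t in range(len(embeddings)):
    cached = step_with_cache(embeddings[t])
    recomputed = step_without_cache(embeddings[: t + 1])
    all_match = all_match and np.allclose(cached, recomputed)
print("cache entries:", len(k_cache), "| outputs match:", all_match)
```

In a multi-layer model the same idea holds per layer, because causal masking means earlier tokens' hidden states never depend on later ones.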

The KV cache is what makes LLM inference feasible at all. Without it, generating each new token would be as expensive as processing the entire sequence from scratch. With it, each step runs the projections and feed-forward layers for one token only, and the attention cost grows just linearly with the number of cached tokens. That's the difference between a chatbot that responds in milliseconds and one that takes minutes.

The tradeoff: the cache consumes memory. For each layer, for each attention head, we store a key vector and a value vector for every token generated so far. A rough calculation for a 70-billion-parameter model with 80 layers, 64 attention heads, and a head dimension of 128: the KV cache for a single sequence of 4,096 tokens requires about 80 × 64 × 4,096 × 128 × 2 (K and V) × 2 bytes (FP16) ≈ 10.7 gigabytes. For a single user's conversation. Now imagine serving hundreds of users simultaneously — the KV cache alone can dwarf the model weights in memory consumption.
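That back-of-the-envelope formula is worth keeping as a small helper, since you will want to rerun it for different models and context lengths. The function name and parameter names are my own, and the default of 2 bytes per value assumes FP16.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence."""
    keys_and_values = 2   # one K vector and one V vector per token, head, and layer
    return keys_and_values * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value

per_sequence = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=4096)
print(f"{per_sequence / 1e9:.1f} GB per sequence")          # 10.7 GB per sequence
print(f"{100 * per_sequence / 1e9:.0f} GB for 100 concurrent sequences")
```

For comparison, the weights of a 70-billion-parameter model in FP16 take about 140 GB, so a few dozen concurrent long conversations are enough for the cache to overtake them.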

This is why the KV cache became the dominant bottleneck in LLM serving. Not compute, not model weights — the per-request cache for ongoing conversations.

Multi-Query and Grouped-Query Attention — Shrinking the Cache

The KV cache problem has a structural cause. In standard multi-head attention, every attention head has its own Q, K, and V projection matrices. If we have 64 heads, we store 64 separate K tensors and 64 separate V tensors per layer. That's what makes the cache so enormous.

In 2019, Noam Shazeer proposed multi-query attention (MQA): keep 64 separate query heads, but share a single K head and a single V head across all of them. Every query head still gets its own perspective (its own Q projection), but they all attend over the same keys and values. The KV cache shrinks by a factor of 64 — from 64 K/V pairs to 1.

Let me trace through what this looks like for our bookstore bot. In standard multi-head attention with 4 heads, generating one token produces 4 key vectors and 4 value vectors to cache. After 100 tokens, the cache holds 800 vectors per layer. With MQA, generating one token produces 1 key vector and 1 value vector, shared across all 4 heads. After 100 tokens, the cache holds 200 vectors per layer. A 4× reduction.

The concern with MQA is quality loss. Different heads are forced to attend over identical keys and values, which might limit the model's ability to attend to different kinds of information simultaneously. In practice, the quality hit is small for many tasks but noticeable for others.

Grouped-query attention (GQA), introduced by Ainslie and colleagues at Google in 2023 and adopted that same year in the larger Llama 2 models, is the compromise that stuck. Instead of all heads sharing one K/V pair, you group the heads. With 64 query heads and 8 groups, you have 8 K/V pairs, each serving 8 query heads. The KV cache shrinks by 8× (instead of 64×), but each group of queries gets its own, somewhat specialized K and V. GQA is used in Llama 2 70B, Llama 3, Mistral 7B, and Gemma 2. It's the current sweet spot: most of MQA's memory savings with very little measurable quality loss.
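A small NumPy sketch shows how the three variants are really one function with a different number of K/V heads. It uses the bookstore bot's 4 query heads, random tensors as placeholders, and no causal mask (all simplifying assumptions). The function name is hypothetical.

```python
import numpy as np

def grouped_query_attention(Q, K, V):
    """Q: (n_q_heads, n, d). K, V: (n_kv_heads, n, d), with n_kv_heads dividing n_q_heads."""
    n_q_heads, n, d = Q.shape
    group_size = n_q_heads // K.shape[0]
    # Each stored K/V head is shared by group_size consecutive query heads.
    K_shared = np.repeat(K, group_size, axis=0)
    V_shared = np.repeat(V, group_size, axis=0)
    scores = Q @ K_shared.transpose(0, 2, 1) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V_shared

rng = np.random.default_rng(0)
n_q_heads, n_tokens, head_dim = 4, 100, 8
Q = rng.normal(size=(n_q_heads, n_tokens, head_dim))

cached_vectors = {}
for n_kv_heads, name in [(4, "multi-head"), (2, "grouped-query"), (1, "multi-query")]:
    K, V = rng.normal(size=(2, n_kv_heads, n_tokens, head_dim))
    out = grouped_query_attention(Q, K, V)
    cached_vectors[n_kv_heads] = 2 * n_kv_heads * n_tokens   # K and V vectors per layer
    print(f"{name:<14s} output {out.shape}, cached vectors per layer: {cached_vectors[n_kv_heads]}")
```

The output shape never changes. Only the amount of K and V that has to be stored does, which is the whole appeal.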

I think of it like filing cabinets in an office. Standard multi-head attention gives every employee their own filing cabinet. MQA puts everyone around one shared cabinet. GQA creates team filing cabinets: small groups share, but not the whole company. It's a practical trade-off between personalization and space efficiency.

There's a complementary technique worth mentioning: PagedAttention, used in the vLLM serving framework. Even with GQA shrinking the cache, managing the memory for thousands of concurrent users is messy. Traditional implementations allocate a large contiguous memory block for each user's cache, even if the conversation is short. PagedAttention borrows an idea from operating systems: it divides the KV cache into fixed-size pages and maps them dynamically, like virtual memory. Short conversations use few pages; long ones grow as needed. No wasted memory from over-allocation. If you ever deploy LLMs at scale, you'll spend more time thinking about KV cache management than about model architecture.
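The bookkeeping behind the paging idea can be sketched in a few lines. This toy class is not vLLM's API: the class name, method names, and sizes are all invented for illustration. It tracks only which pages belong to which conversation, not the tensors themselves.

```python
class PagedKVCache:
    """Toy page table: maps each sequence to the fixed-size pages it occupies."""

    def __init__(self, n_pages, page_size):
        self.page_size = page_size             # tokens per page
        self.free_pages = list(range(n_pages))
        self.block_tables = {}                 # sequence id -> list of physical page ids
        self.lengths = {}                      # sequence id -> tokens stored so far

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length == len(table) * self.page_size:   # no page yet, or the last one is full
            if not self.free_pages:
                raise MemoryError("no free KV pages")
            table.append(self.free_pages.pop())
        self.lengths[seq_id] = length + 1

    def release(self, seq_id):
        # Conversation finished: hand its pages back to the pool.
        self.free_pages.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

cache = PagedKVCache(n_pages=8, page_size=16)
for _ in range(20):
    cache.append_token("alice")   # 20 tokens need 2 pages
for _ in range(5):
    cache.append_token("bob")     # 5 tokens need 1 page
pages_alice = len(cache.block_tables["alice"])
pages_bob = len(cache.block_tables["bob"])
free_before = len(cache.free_pages)
cache.release("alice")
free_after = len(cache.free_pages)
print(pages_alice, pages_bob, free_before, free_after)
```

The pages a conversation holds need not be contiguous, which is what lets the pool be shared tightly across many users.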

Speculative Decoding — Two Models, One Answer

Even with FlashAttention for efficient computation and GQA for a lean cache, autoregressive generation has a fundamental limitation: it's sequential. We generate one token, wait for it, use it to generate the next, wait, and so on. The GPU, a machine designed for massive parallelism, is badly underused at each step: it has to stream the model's weights from memory just to produce a single token, so most of its arithmetic units sit waiting. It's like having a factory with a thousand workers, and only one of them can work at a time.

Speculative decoding breaks this sequential bottleneck with a clever trick. We use two models: a small, fast draft model and the large, accurate target model. The draft model guesses the next several tokens — say 8 at a time. Then the target model verifies all 8 guesses in a single forward pass (which can be done in parallel, since we have all the inputs). If the first 5 guesses match what the target model would have generated, we accept them all and move on. If guess 6 diverges, we reject it and everything after it, keeping only the first 5.

Let's trace this for our bookstore bot. The customer asks "When will my order arrive?" The draft model, a small and cheap one, quickly generates 8 candidate tokens: "Your", "order", "is", "expected", "to", "arrive", "by", "Friday". The target model (our big, accurate model) takes these 8 tokens and, in a single parallel forward pass, checks: would I have generated each of these? It agrees with "Your", "order", "is", "expected", "to", "arrive" but at position 7, it would have said "on" instead of "by". We accept the first 6 tokens, discard the draft's "by" and "Friday", keep the target model's own "on" as the seventh token (the verification pass already computed it), and start a new draft from there.

The math works because verifying several tokens at once costs about the same as generating one. At small batch sizes, a forward pass through the big model is dominated by reading its weights from memory, so processing 8 tokens in parallel takes about as long as processing 1. So instead of 8 sequential forward passes through the big model (slow), we do 8 cheap passes through the draft model (fast) plus 1 forward pass through the big model (about the same cost as generating 1 token). If the draft model is right most of the time, and a well-chosen draft model can be right 70–80% of the time, we get a 2–3× speedup with zero quality loss. With greedy decoding, the output is token for token what the target model would have produced on its own. With sampling, a rejection-sampling acceptance rule keeps the output distribution identical to the target model's.
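The draft-then-verify loop is short enough to sketch end to end. The two "models" below are lookup tables standing in for real networks (a loud assumption), decoding is greedy only, and the verification list comprehension stands in for what would be a single parallel forward pass. All names are hypothetical.

```python
TARGET = ["Your", "order", "is", "expected", "to", "arrive", "on", "Friday", "."]
DRAFT = ["Your", "order", "is", "expected", "to", "arrive", "by", "Friday", "."]

def target_next(context):
    return TARGET[len(context)]   # toy stand-in for the large model's greedy choice

def draft_next(context):
    return DRAFT[len(context)]    # toy stand-in for the small model's greedy choice

def speculative_decode(n_tokens, k=8):
    out, target_passes = [], 0
    while len(out) < n_tokens:
        # 1. The draft model proposes up to k tokens, one cheap step at a time.
        draft = []
        for _ in range(min(k, n_tokens - len(out))):
            draft.append(draft_next(out + draft))
        # 2. The target model scores every drafted position (one pass in a real system).
        target_passes += 1
        verified = [target_next(out + draft[:i]) for i in range(len(draft))]
        # 3. Accept the longest agreeing prefix, then take the target's token at the mismatch.
        n_ok = 0
        while n_ok < len(draft) and draft[n_ok] == verified[n_ok]:
            n_ok += 1
        out += draft[:n_ok]
        if n_ok < len(draft):
            out.append(verified[n_ok])
    return out, target_passes

tokens, target_passes = speculative_decode(n_tokens=len(TARGET))
print(" ".join(tokens))
print(f"target forward passes: {target_passes} instead of {len(TARGET)}")
```

Production implementations also take a bonus token from the target model when every draft token is accepted, and replace the equality check with a probabilistic acceptance test when sampling. Both are left out here for brevity.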

I still find it surprising that this works so well. The key insight is that most tokens in a well-trained LLM are highly predictable — articles, prepositions, common phrases. A small model gets those right. The big model is only truly needed for the harder decisions, and speculative decoding lets us skip ahead past the easy ones.

Wrap-Up

If you're still with me, thank you. I hope it was worth the trek.

We started by trying to peek inside our bookstore bot — first with attention heatmaps that turned out to be misleading, then with probing classifiers that revealed a clean layer-by-layer progression from surface features to syntax to semantics, and finally with mechanistic interpretability's discovery of specific circuits like induction heads. Then we hit the quadratic wall: the bot needed to handle longer inputs, and standard attention couldn't scale. We fought that wall with sparse patterns (Longformer, BigBird), tried to dissolve it with linear attention (Performer, Linformer), and then discovered that the real winner — FlashAttention — didn't change the math at all but rearranged it to respect the GPU's memory hierarchy. On the inference side, we saw how the KV cache makes generation feasible, how GQA shrinks that cache, and how speculative decoding breaks the sequential bottleneck.

My hope is that the next time you see a transformer model behaving strangely, instead of treating it as an inscrutable black box, you'll know where to look inside — and you'll have a pretty good mental model of the tricks that make it fast enough to look inside at all.

Resources and Credits

The landscape here moves fast, but these are the ones I keep coming back to.