Nice to Know — The Transformer Variant Zoo
I'll be honest: for a long time I avoided reading about transformer variants. Every month a new paper would land with a clever name (XLNet, Reformer, RWKV), each claiming to fix something fundamental about the original architecture. I'd skim the abstract, feel a pang of guilt, and close the tab. The zoo kept growing and I kept pretending I'd catch up "next weekend." Eventually the discomfort of nodding along in conversations about segment-level recurrence and permutation language models, without understanding a word of it, grew too great, and I finally took the dive. This is what I found.
What follows is a tour through the most important ideas that sit adjacent to the standard Transformer. Some of these became stepping stones to modern LLMs. Others were brilliant dead ends. A few are still actively competing for the throne. None of them are required to understand the core chapter material, but recognizing them on sight will save you hours and make you dangerous in any architecture discussion.
Before we start, a heads-up: we'll touch on attention mechanics, memory addressing, and routing networks. You don't need to be an expert on any of them beforehand. We'll build up what we need as we go, one idea at a time.
This isn't a short journey, but I hope you'll be glad you came.
The Context Window Problem — Transformer-XL
The Masking Dilemma — XLNet
The Quadratic Cost — Reformer and LSH Attention
Rest Stop
Giving Networks Scratch Paper — Neural Turing Machines
The Differentiable Everything Insight
Not Every Expert Needs to Show Up — Mixture of Experts
The Post-Attention Generation — RetNet and RWKV
Wrap-Up
Resources and Credits
The Context Window Problem — Transformer-XL
Imagine you're building a system that summarizes long documents (legal contracts, say). Your transformer has a context window of 512 tokens. The contract is 4,000 tokens long. The standard approach is to chop the contract into eight segments of up to 512 tokens each and process them independently. Segment one gets summarized. Segment two gets summarized. But segment two has no idea what was in segment one. If segment five mentions "the party defined in Section 1," it has never seen Section 1. This problem has a name: context fragmentation.
It's like reading a novel but someone tears out each chapter after you finish it. You can understand individual paragraphs fine, but you can't track plot threads across chapters. That's where your model is stuck.
Transformer-XL (Dai et al., 2019) solved this with a deceptively elegant idea called segment-level recurrence. When the model finishes processing segment one, it doesn't throw away the hidden states. It caches them, as fixed values that gradients don't flow back into. When segment two arrives, the attention mechanism can reach back into that cache and attend to segment one's representations. The effective context window grows, not by processing longer segments, but by carrying memory forward.
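To make the caching concrete, here is a deliberately tiny, single-layer sketch in plain Python. It rests on loud simplifying assumptions: there are no learned query, key, or value projections (each hidden vector serves as all three), there is a single head, and `attend` and `process_segment` are hypothetical names, not Transformer-XL's code. In the real model a separate cache is kept for every layer and treated as a constant during backpropagation.

```python
import math

def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]

def process_segment(segment, memory, mem_len):
    """Attend over [cached states + current segment], then refresh the cache.

    segment: hidden vectors for the current segment
    memory:  cached hidden vectors from earlier segments (read-only here)
    mem_len: how many of the most recent states to keep for the next segment
    """
    context = memory + segment
    outputs = []
    for t, hidden in enumerate(segment):
        # Causal visibility: every cached state, plus current-segment tokens up to t.
        visible = context[: len(memory) + t + 1]
        outputs.append(attend(hidden, visible, visible))
    new_memory = context[-mem_len:] if mem_len > 0 else []
    return outputs, new_memory

# Usage: two segments of two 2-d "hidden states", with room to cache two states.
seg1 = [[1.0, 0.0], [0.0, 1.0]]
seg2 = [[0.5, 0.5], [1.0, 1.0]]
out1, mem = process_segment(seg1, [], mem_len=2)
out2, mem = process_segment(seg2, mem, mem_len=2)  # seg2 can now see seg1's states
```

The second token of `seg2` produces a different output than it would without the cache, because it can now weigh `seg1`'s states as well as its own segment's.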
There's a wrinkle, though. Standard transformers use absolute positional encodings: the token at position 47 always gets the same positional signal. If every segment numbers its tokens from 0 to 511, then the fifth token of segment one (now sitting in the cache) and the fifth token of segment two carry identical positional signals, even though they are 512 positions apart in the document. The model gets confused about distances. Transformer-XL introduced relative positional encoding to fix this. Instead of encoding "I am at position 47," each attention computation encodes "this key is 3 positions to my left." The model reasons about distances, not absolute coordinates.
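One way to see why relative encodings survive the jump between segments is to list the offsets the model actually conditions on. The sketch below is plain Python with a hypothetical helper, `relative_offsets`; the real Transformer-XL feeds such offsets through sinusoidal embeddings and learned projections rather than using raw integers. Absolute query positions differ from segment to segment, but the matrix of query-minus-key offsets is identical.

```python
def relative_offsets(segment_index, seg_len, mem_len):
    """Return (absolute query positions, matrix of query-minus-key offsets).

    Queries are the tokens of the given segment; keys are the mem_len cached
    positions immediately before it, followed by the segment itself.
    """
    start = segment_index * seg_len
    queries = list(range(start, start + seg_len))
    keys = list(range(start - mem_len, start + seg_len))
    offsets = [[q - k for k in keys] for q in queries]
    return queries, offsets

q1, off1 = relative_offsets(segment_index=1, seg_len=4, mem_len=4)
q5, off5 = relative_offsets(segment_index=5, seg_len=4, mem_len=4)
print(q1, q5)         # different absolute positions
print(off1 == off5)   # True: same relative geometry
```

Negative offsets correspond to future tokens, which a causal mask would hide.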
Think of it this way: absolute encoding is like having a GPS coordinate stamped on every word. Relative encoding is like knowing "that word was two sentences ago." The second one works even when you move to a new part of the document.
The combination of cached hidden states and relative positions lets the effective context grow linearly with the number of layers: each layer can reach one cached segment further back than the layer below it. A 16-layer Transformer-XL whose cache holds one segment can, in principle, draw on information roughly 16 segments back, about 16 times the raw segment length. The extra cost is modest. Each query now scores against the cached keys as well as its own segment's, so per-segment attention work grows from S² to S · (S + M) for segment length S and cache length M, but the cached states never have to be recomputed.
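The arithmetic of that paragraph, written out as a small Python sketch. It assumes the cache holds exactly one previous segment per layer (cache length equal to segment length), the setting in which the reachable range grows by one segment per layer; the function names are illustrative.

```python
def max_dependency_length(n_layers, seg_len):
    # Each layer can reach one cached segment further back than the layer below it.
    return n_layers * seg_len

def scores_per_segment(seg_len, mem_len):
    # Every query in the segment scores against the cached keys plus the segment's own keys.
    return seg_len * (mem_len + seg_len)

print(max_dependency_length(16, 512))  # 8192 tokens, 16x the segment length
print(scores_per_segment(512, 0))      # 262144: vanilla attention within one segment
print(scores_per_segment(512, 512))    # 524288: twice the scores, cache reused as-is
```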
The limitation? The model still processes segments sequentially during training (you can't cache what hasn't been computed yet), which makes it slower to train than a model that can parallelize over the full sequence. And the cache size is fixed — you choose how many previous segments to remember, and everything older is gone. It's a sliding window, not infinite memory.
Still, Transformer-XL was the direct precursor to XLNet, and its relative positional encoding helped popularize an idea that most modern LLMs now rely on in some form. When you hear about "extending the context window" of GPT-style models, part of the intellectual DNA traces back here.
The Masking Dilemma — XLNet
With Transformer-XL handling long contexts, a different problem came into focus. By 2019, the field had split into two camps for pretraining language models, and both camps had a blind spot.
In the GPT camp, you train left-to-right: predict each word from the words that came before it. This is autoregressive training. It's clean and natural — the model generates text the same way it was trained. But it only ever sees leftward context. If the word "bank" appears and the disambiguation depends on a word to its right, the model is out of luck.
In the BERT camp, you mask random words and predict them from the remaining context. This is masked language modeling. The model sees both left and right context. But there's a catch: those [MASK] tokens that appear during training never appear during actual use. The model is trained on slightly alien input. And when multiple tokens are masked simultaneously, BERT assumes they're independent of each other, which is often false. If "New" and "York" are both masked, BERT predicts each separately, ignoring that they're tightly correlated.
I'll be honest — when I first read the XLNet paper and they proposed predicting tokens in random order, I didn't believe it would help. It sounded like adding chaos for no reason. It took me three reads to see why it was clever.
XLNet (Yang et al., 2019) introduced permutation language modeling. The idea: take a sentence like "The cat sat on the mat." Normally you'd predict each word left-to-right. Instead, randomly shuffle the prediction order — maybe predict position 4 first, then position 1, then position 5, then position 3, and so on. For each position, you predict it using all the other positions that come before it in this particular random order.
Let's trace through a tiny example. Suppose our sentence is five tokens: [A, B, C, D, E]. One random permutation of prediction order might be [3, 1, 5, 2, 4]. That means we first predict token C (from no context), then predict A (from C), then E (from C and A), then B (from C, A, and E), then D (from everything). A different training step might give order [2, 4, 1, 5, 3], where B comes first and the contexts shift completely.
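The walk-through above is easy to replay in code. This plain-Python helper (a hypothetical name, using the same 1-indexed positions as the text) never reorders the input; it only records which tokens are visible when each target is predicted.

```python
def permutation_contexts(tokens, order):
    """For a 1-indexed factorization order, list (target, visible context) pairs.

    The input sequence itself is never reordered; only the prediction order changes.
    """
    seen = []
    steps = []
    for position in order:
        target = tokens[position - 1]
        steps.append((target, list(seen)))  # copy: context at this moment only
        seen.append(target)
    return steps

for target, context in permutation_contexts(list("ABCDE"), [3, 1, 5, 2, 4]):
    print(f"predict {target} from {context}")
# predict C from []
# predict A from ['C']
# predict E from ['C', 'A']
# predict B from ['C', 'A', 'E']
# predict D from ['C', 'A', 'E', 'B']
```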
Over many training steps, each token gets predicted from many different subsets of the other tokens. Position 3 sometimes has left context and right context. Sometimes it has everything except one token. Sometimes it has almost nothing. The model learns to predict from any partial view of the sentence, which means it's effectively bidirectional, like BERT, but without the [MASK] token hack.
The key insight: the actual input tokens never change. The words stay in their original order in the input. What changes is the factorization order — which token gets predicted when. The model needs a mechanism to know "I'm predicting position 4, and I've already seen positions 3 and 1 in this particular permutation." XLNet uses two-stream self-attention for this: a content stream that carries the actual token information, and a query stream that knows which position is being predicted but can't peek at the token there. I still find this the trickiest part of the architecture. The two streams are necessary to prevent the model from cheating — seeing the answer to its own question.
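The two streams differ only in what they are allowed to see, which is easiest to show as attention masks. In this sketch (plain Python, hypothetical helper name), both masks index the original positions: the content stream at a position sees everything earlier in the factorization order plus itself, while the query stream sees only the earlier positions. As I understand the paper, both streams draw their keys and values from the content stream; this sketch builds only the masks, not the attention itself.

```python
def two_stream_masks(order):
    """Build XLNet-style attention masks for a 1-indexed factorization order.

    mask[i][j] is True when position i+1 may attend to position j+1.
    content stream: earlier positions in the order AND itself.
    query stream:   earlier positions only, never its own token.
    """
    n = len(order)
    rank = {pos: step for step, pos in enumerate(order)}  # when each position is predicted
    content = [[rank[j + 1] <= rank[i + 1] for j in range(n)] for i in range(n)]
    query = [[rank[j + 1] < rank[i + 1] for j in range(n)] for i in range(n)]
    return content, query

content, query = two_stream_masks([3, 1, 5, 2, 4])
print(query[2])  # position 3 is predicted first, so its query stream sees nothing
print(query[3])  # position 4 is predicted last: sees every position except itself
```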
XLNet was built on top of Transformer-XL (it reuses segment-level recurrence and relative positional encoding), so it also handles long contexts. At the time of its release, it outperformed BERT on 20 tasks. It was one of the first demonstrations that you could get bidirectional context without resorting to masking.
The limitation? Permutation training is expensive. Each training step samples a single factorization order per sequence, so covering the space of orders takes many steps, and the two-stream attention adds computation on top. In practice, only a subset of positions is predicted per permutation (the last few in the factorization order), which makes optimization easier but adds complexity. The field eventually moved toward scaling simple autoregressive models rather than elaborate pretraining objectives. But the insight, that the factorization order of predictions is a design choice rather than a fixed constraint, remains influential.
The Quadratic Cost — Reformer and LSH Attention
So far we've been solving the "how much can the model see" problem. Transformer-XL extended context with recurrence. XLNet enriched context with permutations. But there's a more brutal problem lurking underneath all of this: the attention mechanism itself scales quadratically with sequence length.
For a sequence of 512 tokens, the attention matrix has 512 × 512 = about 262,000 entries. Manageable. For 4,096 tokens, it's roughly 16 million entries. For 65,000 tokens (a long document), it's over 4 billion entries. Per layer. Per attention head. The math gets unfriendly fast.
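The same arithmetic in a few lines of Python, with one added assumption: each score is stored as a 4-byte float and the full matrix is materialized at once, which is what a naive implementation does.

```python
def attention_matrix_cost(n_tokens, bytes_per_entry=4):
    """Entries in one n x n attention matrix and its size in GB (one head, one layer)."""
    entries = n_tokens * n_tokens
    return entries, entries * bytes_per_entry / 1e9

for n in (512, 4_096, 65_000):
    entries, gb = attention_matrix_cost(n)
    print(f"{n:>6} tokens: {entries:>13,} entries, ~{gb:.3f} GB at 4 bytes each")
```

At 65,000 tokens that comes to roughly 17 GB for a single head in a single layer.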
The Reformer (Kitaev et al., 2020) asked a provocative question: do we actually need every token to attend to every other token? In most trained attention matrices, the pattern is sparse — each token pays strong attention to a handful of other tokens and ignores the rest. What if we could figure out which tokens would attend to each other before computing the full attention matrix?
The answer was locality-sensitive hashing (LSH) attention. Here's the core idea. LSH is a technique from the nearest-neighbor search literature. You design a hash function with a special property: similar vectors get hashed to the same bucket with high probability, and dissimilar vectors get different buckets. It's like sorting a library not alphabetically, but by topic — books about the same subject end up on the same shelf, even if their titles start with different letters.
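In Reformer, queries and keys are hashed into buckets using LSH. To make this consistent, Reformer uses the same projection for queries and keys (a "shared-QK" setup), so a token and the tokens it would attend to strongly tend to hash the same way. Attention is only computed between tokens that land in the same bucket, the ones that would have had high attention weights anyway. Tokens in different buckets are assumed to have negligible attention and are skipped entirely. This drops the complexity from O(n²) to roughly O(n log n).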
To reduce the risk of missing important connections (two genuinely related tokens landing in different buckets), Reformer runs multiple rounds of hashing with different random hash functions and combines the results. It's a bet on probability: if two tokens are truly similar, they'll collide in at least one round.
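Here is a small plain-Python sketch of the hashing step, following the random-rotation scheme the Reformer paper describes: project a vector with a random Gaussian matrix R and take the argmax over the concatenation [xR, −xR]. Everything else is a simplification of mine: `candidate_pairs` and the other names are hypothetical, and the real model sorts tokens by bucket and attends within fixed-size chunks instead of enumerating pairs. Because the hash depends only on direction, a vector and a positively scaled copy of it always share a bucket.

```python
import random
from itertools import combinations

def random_rotation(dim, n_buckets, rng):
    """A dim x (n_buckets / 2) matrix of Gaussian entries, used for one hash round."""
    if n_buckets % 2:
        raise ValueError("n_buckets must be even")
    return [[rng.gauss(0.0, 1.0) for _ in range(n_buckets // 2)] for _ in range(dim)]

def lsh_bucket(x, rotation):
    """Angular LSH: project x, then take the argmax over [xR, -xR]."""
    cols = len(rotation[0])
    projected = [sum(x[d] * rotation[d][j] for d in range(len(x))) for j in range(cols)]
    scores = projected + [-p for p in projected]
    return max(range(len(scores)), key=scores.__getitem__)

def candidate_pairs(vectors, n_buckets=8, n_rounds=4, seed=0):
    """Pairs of token indices that share a bucket in at least one hashing round.

    Only these pairs would get attention scores; every other pair is skipped.
    """
    rng = random.Random(seed)
    dim = len(vectors[0])
    pairs = set()
    for _ in range(n_rounds):
        rotation = random_rotation(dim, n_buckets, rng)  # fresh hash function per round
        buckets = {}
        for idx, v in enumerate(vectors):
            buckets.setdefault(lsh_bucket(v, rotation), []).append(idx)
        for members in buckets.values():
            pairs.update(combinations(members, 2))
    return pairs

vecs = [[1.0, 0.2, 0.0], [2.0, 0.4, 0.0], [-1.0, 0.0, 0.3], [0.0, -1.0, 1.0]]
print(candidate_pairs(vecs))  # always contains (0, 1): vector 1 is vector 0 scaled by 2
```

More rounds make a missed pair less likely, at the price of hashing every token once more per round.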
The Reformer had a second innovation: reversible layers. Normal transformers store every layer's activations for backpropagation — that's a lot of memory for a deep model. Reversible layers (borrowed from RevNets) let you recompute any layer's activations from the layer above it, trading compute for memory. Instead of storing activations for all L layers (O(n·L) memory), you store activations for one layer and recompute on the fly (O(n) memory).
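The reversibility trick fits in a few lines. This sketch uses the RevNet-style coupling that Reformer adopts, y₁ = x₁ + F(x₂) and y₂ = x₂ + G(y₁), where in Reformer F is the attention sublayer and G the feed-forward sublayer. Here both are arbitrary stand-in functions (an assumption for illustration), since the inversion works for any F and G.

```python
import math

def reversible_forward(x1, x2, f, g):
    """RevNet-style coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2

def reversible_inverse(y1, y2, f, g):
    """Recover the inputs from the outputs alone: x2 = y2 - G(y1), then x1 = y1 - F(x2)."""
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2

# Stand-ins for the attention (F) and feed-forward (G) sublayers; any functions work.
f = lambda v: [math.tanh(3 * t) for t in v]
g = lambda v: [t * t - 1 for t in v]

x1, x2 = [0.5, -1.0], [2.0, 0.25]
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_inverse(y1, y2, f, g)  # matches x1, x2 up to rounding
```

Because the inputs can be rebuilt from the outputs, backpropagation can regenerate each layer's activations on the way down instead of keeping them all in memory.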
I have to be honest about the outcome, though. The Reformer was a brilliant engineering effort, but the constant factors in LSH attention turned out to be large. The hashing, sorting, and bucket management added overhead that ate into the theoretical complexity gains. Then Flash Attention arrived and attacked the same problem from a completely different angle, rearranging memory access patterns on GPUs rather than approximating the attention computation, and it largely won. Flash Attention computes exact attention with no approximation error. Its arithmetic is still quadratic in sequence length, but its memory use is linear, and in practice it runs considerably faster than standard attention.
Reformer is historically important. It forced the community to ask the right questions about attention efficiency. But it's rarely used in production today. When you see it cited, it's usually in the lineage section of papers that took its insights further.
Rest Stop
This is a good place to pause. You now have a genuinely useful mental model: if someone mentions Transformer-XL, XLNet, or Reformer in a meeting or a paper review, you'll know what they're talking about and, more importantly, why each one exists. You could stop here and be ahead of most practitioners.
What we haven't covered yet is a fundamentally different question: what if the model could write things down? What if, instead of cramming everything into hidden states, the network had actual external scratch paper it could read from and write to? And after that, we'll look at models that questioned whether attention itself is the right mechanism. Those are wilder ideas, and they take us further from the standard Transformer playbook.
But if the discomfort of not knowing what Neural Turing Machines, Mixture of Experts, or RWKV actually do is nagging at you, read on.
Giving Networks Scratch Paper — Neural Turing Machines
Everything we've discussed so far stores information in the same place: the hidden states of the network. Transformer-XL caches hidden states. XLNet shuffles the order of predictions but still computes hidden states. Even the attention mechanism is really a way to route information between hidden states. The network's "memory" and its "computation" live in the same substrate.
But think about how you solve a complex problem. You don't keep everything in your head. You write things down. You make lists. You cross out items. You flip back to page one. Your working memory is tiny; your notebook is huge. The notebook is external storage — you read from it, write to it, and your thinking process is separate from the storage medium.
In 2014, Graves, Wayne, and Danihelka at DeepMind asked: what if we gave a neural network a notebook?
A Neural Turing Machine (NTM) is a neural network coupled with an external memory matrix — think of it as a table with N rows and M columns, where each row is a memory slot. A controller network (usually an LSTM or a feedforward net) reads the input, decides what to read from memory, what to write to memory, and produces the output. At every time step, the controller issues read and write instructions.
The critical trick is how the controller addresses memory. A regular computer addresses memory with hard indices: "read slot 47." But hard indexing isn't differentiable — you can't compute a gradient through "go to slot 47." Without gradients, you can't train the system end-to-end with backpropagation.
NTMs solve this with soft addressing. Instead of pointing to one slot, the controller outputs a weight vector over all N slots — a distribution that says "pay 60% attention to slot 12, 30% to slot 13, and sprinkle the rest." Reading from memory becomes a weighted sum across all rows: r = Σ wᵢ · Mᵢ. Writing works similarly — erase a little from every row (weighted by the attention), then add new content (also weighted). Every operation is continuous and differentiable. Gradients flow through the memory, through the addressing, through the controller, all the way back to the input.
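The read and write rules translate almost line for line into code. In this plain-Python sketch (hypothetical function names), memory is a list of rows and the weighting is a list with one entry per row. The write follows the erase-then-add form from the NTM paper: each row is first scaled elementwise by (1 − wᵢ · e), and then wᵢ · a is added.

```python
def ntm_read(memory, weights):
    """r = sum_i w_i * M_i: a weighted blend of every memory row."""
    cols = len(memory[0])
    return [sum(w * row[j] for w, row in zip(weights, memory)) for j in range(cols)]

def ntm_write(memory, weights, erase, add):
    """Erase, then add, each scaled by that row's weight.

    erase: vector with entries in [0, 1]; add: vector of new content.
    Returns a new memory; every step is smooth in weights, erase, and add.
    """
    new_memory = []
    for w, row in zip(weights, memory):
        erased = [m * (1 - w * e) for m, e in zip(row, erase)]
        new_memory.append([m + w * a for m, a in zip(erased, add)])
    return new_memory

memory = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(ntm_read(memory, [0.6, 0.3, 0.1]))  # mostly row 0, some row 1, a little row 2
memory = ntm_write(memory, [0.0, 1.0, 0.0], erase=[1.0, 1.0], add=[9.0, 9.0])
print(memory[1])                           # [9.0, 9.0]: row 1 fully overwritten
```

With a one-hot weighting these reduce to an ordinary indexed read and write; with a spread-out weighting every row is touched a little, which is what lets gradients reach the addressing.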
I'll admit — I still find the addressing mechanism hard to visualize. The NTM actually combines two types of addressing: content-based (find the memory row that looks most like this query vector — sound familiar? It's attention) and location-based (shift the focus left or right from where you were last reading, like moving a read head along a tape). The combination lets the network do things like "find the row containing X, then read the next three rows in order." That's remarkably close to how a traditional computer traverses a data structure.
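Here is a sketch of both addressing modes chained together, in plain Python with hypothetical names. The content step is a softmax over cosine similarity, sharpened by a key strength β. The location step then interpolates with the previous weighting through a gate g, rotates the result with a small shift distribution (a circular convolution), and sharpens it with an exponent γ ≥ 1. This follows the structure described in the NTM paper, simplified to a single head with hand-picked parameters.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-8)  # small epsilon guards against all-zero rows

def content_weights(memory, key, beta):
    """Softmax over beta * cosine(key, row): find the row that looks like the key."""
    scores = [beta * cosine(key, row) for row in memory]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def location_weights(w_content, w_prev, gate, shift, gamma):
    """Interpolate with last step's focus, rotate by a shift distribution, then sharpen.

    shift maps an offset (e.g. -1, 0, +1) to its probability.
    """
    n = len(w_content)
    w_g = [gate * c + (1 - gate) * p for c, p in zip(w_content, w_prev)]
    # Circular convolution: weight at row j moves to row j + offset (mod n).
    w_s = [sum(w_g[(i - off) % n] * prob for off, prob in shift.items()) for i in range(n)]
    sharpened = [w ** gamma for w in w_s]
    total = sum(sharpened)
    return [w / total for w in sharpened]

memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
w_c = content_weights(memory, key=[0.0, 1.0], beta=20.0)  # concentrates on row 1
w = location_weights(w_c, w_prev=[0.25] * 4, gate=1.0, shift={1: 1.0}, gamma=2.0)
```

The key matches row 1, and the +1 shift moves the focus to row 2: "find the row containing X, then step to the next one."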
The results were striking. NTMs learned to copy sequences, sort lists, and perform associative recall — all from input-output examples, without being explicitly programmed to do these tasks. The network figured out the algorithm on its own, using the memory as scratch space.
In 2016, the same team released the Differentiable Neural Computer (DNC), an improved version. The DNC added usage tracking (knowing which memory slots have been used and which are free), temporal links (remembering the order in which slots were written, so you can replay sequences), and dynamic memory allocation. It could answer questions about family trees and navigate the London Underground map by storing facts in memory and reasoning over them.
The limitation? NTMs and DNCs are hard to train. The soft addressing over large memories creates optimization landscapes that are noisy and full of plateaus. They never scaled to the sizes where transformers thrive. But their intellectual contribution is profound: they showed that memory and computation can be separated in a neural network, and that the boundary between "learned program" and "neural network" is blurrier than anyone thought.
The Differentiable Everything Insight
The NTM's trick — making memory operations differentiable so you can train them with gradient descent — is actually an instance of a much bigger idea: differentiable programming.
Here's the core insight. Traditional programming is about writing explicit instructions: if-else, loops, array indexing. Traditional deep learning is about stacking layers and training weights. Differentiable programming says: what if the entire program — including the control flow, the memory access, the data structures — was made of differentiable operations? Then you could train the whole program end-to-end with gradient descent.
The NTM's external memory is a differentiable data structure. Attention is a differentiable dictionary lookup. The softmax router in a Mixture of Experts (which we'll get to next) is a mostly differentiable if-else: the choice of experts is discrete, but the gate weights that blend their outputs receive gradients. Learned positional embeddings are differentiable representations of sequence position. Piece after piece of a modern transformer is a differentiable stand-in for something that used to be a hard-coded algorithmic step.
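To make "attention is a differentiable dictionary lookup" concrete, this Python sketch compares a hard lookup (return the value whose key matches best, which gives no useful gradient) with a softmax-weighted lookup whose output varies smoothly with the query. The `temperature` parameter is an illustrative knob, not part of any specific model: as it shrinks, the soft lookup approaches the hard one.

```python
import math

def hard_lookup(query, keys, values):
    """Non-differentiable: pick the single best-matching key's value."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    return values[scores.index(max(scores))]

def soft_lookup(query, keys, values, temperature=1.0):
    """Differentiable: blend all values by softmax(query . key / temperature)."""
    scores = [sum(q * k for q, k in zip(query, key)) / temperature for key in keys]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total

keys = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
values = [10.0, 20.0, 30.0]
query = [0.9, 0.1]
print(hard_lookup(query, keys, values))        # 10.0
print(soft_lookup(query, keys, values, 1.0))   # a blend of all three values
print(soft_lookup(query, keys, values, 0.01))  # very close to 10.0
```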
In 2018 Yann LeCun suggested that "differentiable programming" is a better name for what deep learning has become, and I think he was right. The trajectory of the field isn't toward more layers or more data (though both help). It's toward making more and more of the computational pipeline differentiable, so the optimizer can tune all of it.
Neural ODEs (Chen et al., 2018) made depth itself continuous: instead of discrete layers, the hidden state evolves according to a learned differential equation, and an ODE solver decides how many evaluation steps to take, so you can trade accuracy for speed at inference time by adjusting its tolerance. Differentiable rendering lets you train 3D scene representations from 2D images by making the rendering pipeline itself differentiable. Differentiable sorting lets you include sorting operations in a neural network and backpropagate through them.
My favorite thing about differentiable programming is that nobody is quite sure where its limits are. Can you make an entire database engine differentiable? A compiler? We don't know yet. But every year the boundary moves.
Not Every Expert Needs to Show Up — Mixture of Experts
Let's come back to transformers and a very practical problem. You want a bigger model because bigger models learn better representations. But bigger also means slower — more parameters to evaluate on every forward pass. Is there a way to have a huge model without paying the full computational cost on every input?
The answer is conditional computation: only activate the parts of the model that are relevant for each input. Don't make every parameter work on every token. Let the model decide, per token, which parameters to use.
The oldest incarnation of this idea is the Mixture of Experts (MoE), which dates back to Jacobs et al. in 1991 — decades before transformers. The concept: instead of one big neural network, you have several smaller "expert" networks and a "gating" network that decides which experts to consult for each input. The gating network looks at the input and produces weights: "For this input, Expert 2 gets 70% weight and Expert 5 gets 30%."
In the transformer context, the MoE layer typically replaces the feed-forward network (FFN) in each transformer block. Instead of one FFN that every token passes through, you have, say, 64 expert FFNs and a router. The router examines each token's hidden state and sends it to the top-k experts (often k = 1 or k = 2). The experts process the token, and their outputs are combined using the router's weights.
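A single-token sketch of top-k routing in plain Python. The names are hypothetical, the experts are toy functions rather than FFNs, and the gate weights are renormalized with a softmax over only the chosen experts' logits, which is one common variant (Switch Transformer, with k = 1, instead keeps the selected expert's probability from a softmax over all experts).

```python
import math

def softmax(xs):
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def moe_layer(x, router_weights, experts, k=2):
    """Route one token through its top-k experts and mix their outputs.

    x:              the token's hidden vector
    router_weights: one weight vector per expert (router logit = dot product with x)
    experts:        callables, each mapping a vector to a vector
    """
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in router_weights]
    top = sorted(range(len(experts)), key=lambda i: logits[i], reverse=True)[:k]
    gates = softmax([logits[i] for i in top])  # renormalize over the chosen experts only
    out = [0.0] * len(x)
    for gate, i in zip(gates, top):
        expert_out = experts[i](x)  # only the chosen experts run at all
        out = [o + gate * e for o, e in zip(out, expert_out)]
    return out, top, gates

# Four toy "experts" that just scale their input by 1, 2, 3, or 4.
experts = [lambda v, s=s: [s * t for t in v] for s in (1.0, 2.0, 3.0, 4.0)]
router = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, -1.0]]
out, chosen, gates = moe_layer([0.5, 0.2], router, experts, k=2)
print(chosen)  # [2, 0]: the two highest router logits
```

Experts 1 and 3 are never evaluated for this token, which is where the compute savings come from.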
Here's where it gets interesting. A model can have, say, 400 billion total parameters spread across all its experts while each token activates only around 12 billion of them, the ones in its assigned experts. Training still reaches every parameter, because different tokens go to different experts. Inference is cheap in compute because each token evaluates only a small fraction of the full model, although all the experts still have to be held in memory.
GShard (Lepikhin et al., 2020) was Google's framework for training enormous MoE models across thousands of TPUs. Each expert lives on a different device, and the router decides which device each token gets shipped to. Switch Transformer (Fedus et al., 2021) simplified the routing to top-1: each token goes to exactly one expert. This sounds restrictive, but it made the system dramatically easier to scale and turned out to work remarkably well. Switch Transformers reached trillion-parameter territory with training costs comparable to much smaller dense models.
Mixtral (Mistral AI, 2023) brought MoE to the open-source world with 8 experts and 2 active per token, showing that the approach works at practical scales on accessible hardware. GPT-4 is widely believed (though not officially confirmed) to use an MoE architecture, and Google has stated that Gemini 1.5 uses one.
The hardest problem in MoE isn't the architecture — it's load balancing. If the router sends all tokens to Expert 1 and none to Expert 7, you've wasted parameters on the idle experts and overloaded the popular one. Models use auxiliary losses that penalize uneven expert utilization — a tax on imbalance. Getting this tax right is an art. Too weak and experts collapse (one expert handles everything). Too strong and the router makes random, unhelpful assignments to keep things balanced.
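Here is the auxiliary loss from the Switch Transformer paper as a small Python function: α · N · Σᵢ fᵢ · Pᵢ, where N is the number of experts, fᵢ is the fraction of tokens whose top-1 choice is expert i, and Pᵢ is the average router probability given to expert i. The default α = 0.01 is an assumption for illustration. The loss equals α when routing is perfectly uniform and climbs toward α · N as tokens collapse onto one expert.

```python
def switch_balance_loss(router_probs, alpha=0.01):
    """Switch-Transformer-style load-balancing loss.

    router_probs: one probability vector (over N experts) per token.
    f_i: fraction of tokens whose top-1 choice is expert i.
    P_i: average router probability assigned to expert i.
    loss = alpha * N * sum_i f_i * P_i
    """
    n_tokens = len(router_probs)
    n_experts = len(router_probs[0])
    counts = [0] * n_experts
    for probs in router_probs:
        counts[probs.index(max(probs))] += 1  # top-1 assignment
    f = [c / n_tokens for c in counts]
    p = [sum(probs[i] for probs in router_probs) / n_tokens for i in range(n_experts)]
    return alpha * n_experts * sum(fi * pi for fi, pi in zip(f, p))

balanced = [[0.7, 0.1, 0.1, 0.1], [0.1, 0.7, 0.1, 0.1],
            [0.1, 0.1, 0.7, 0.1], [0.1, 0.1, 0.1, 0.7]]
collapsed = [[0.7, 0.1, 0.1, 0.1] for _ in range(4)]
print(switch_balance_loss(balanced))   # about 0.01
print(switch_balance_loss(collapsed))  # about 0.028: every token picked expert 0
```

The fᵢ term comes from a hard argmax and carries no gradient; the Pᵢ term is what lets the penalty push the router toward spreading its probability mass.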
I'll confess that no one fully understands why sparse routing works as well as it does. The intuition is that different experts specialize in different types of tokens (some become "syntax experts," others handle "rare words," others specialize in "numbers"). But when researchers inspect trained routers, the specialization is often messier and less interpretable than the stories we tell about it. It works. We're still figuring out exactly why.
The Post-Attention Generation — RetNet and RWKV
Every transformer variant we've discussed accepts a fundamental premise: attention, the quadratic every-token-looks-at-every-token mechanism, is the right foundation for sequence modeling. Reformer tried to make attention cheaper. Flash Attention made it faster on hardware. MoE decoupled the FFN's parameter count from its per-token cost. But attention itself remained the core.
What if attention isn't the answer?
The RNN era had one beautiful property that transformers threw away: constant memory during inference. An LSTM processes a sequence token by token, maintaining a fixed-size hidden state. The thousandth token costs the same as the first token. Transformers, by contrast, must store and attend to the entire KV cache, which grows with every new token. For long sequences and real-time applications, this is painful.
Two architectures — RetNet and RWKV — asked whether we can get transformer-level quality with RNN-level efficiency. They attack this from slightly different angles, but share a common insight: replace attention with a mechanism that can be computed either in parallel (like attention, for fast training) or recurrently (like an RNN, for efficient inference).
Retentive Networks (RetNet) (Sun et al., 2023) replace attention with a retention mechanism. The key idea is an exponentially decaying weighted sum over past tokens. In simplified form, the output at position t is output_t = Σ_{i ≤ t} γ^(t−i) · (query_t · key_i) · value_i, where γ is a decay factor between 0 and 1. Recent tokens contribute more; distant tokens fade exponentially.
The beautiful thing about this formulation is that it has a dual form. You can write it as a matrix multiplication (the "parallel" form, which looks like attention with a triangular decay mask — great for training on GPUs), or you can write it as a recurrence (the "recurrent" form — great for inference, where you maintain a fixed-size state and update it token by token). Same math, two implementations, each optimal for its use case.
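The duality is easy to check numerically. The sketch below (plain Python, a single head, hypothetical names) drops the parts of RetNet that don't affect the argument: the rotary-style position factor, normalization, and gating. The parallel form sums over all earlier positions directly; the recurrent form keeps one fixed-size state S and updates it as S_t = γ · S_(t−1) + k_tᵀ v_t, reading out o_t = q_t · S_t. Both should produce the same outputs up to rounding.

```python
import random

def retention_parallel(qs, ks, vs, gamma):
    """o_t = sum_{i<=t} gamma^(t-i) * (q_t . k_i) * v_i  (the attention-like form)."""
    outputs = []
    for t, q in enumerate(qs):
        out = [0.0] * len(vs[0])
        for i in range(t + 1):
            coef = gamma ** (t - i) * sum(a * b for a, b in zip(q, ks[i]))
            out = [o + coef * v for o, v in zip(out, vs[i])]
        outputs.append(out)
    return outputs

def retention_recurrent(qs, ks, vs, gamma):
    """Same outputs from a fixed-size d_k x d_v state S, updated once per token."""
    d_k, d_v = len(ks[0]), len(vs[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    outputs = []
    for q, k, v in zip(qs, ks, vs):
        # S_t = gamma * S_{t-1} + outer(k_t, v_t)
        state = [[gamma * state[a][b] + k[a] * v[b] for b in range(d_v)] for a in range(d_k)]
        # o_t = q_t S_t
        outputs.append([sum(q[a] * state[a][b] for a in range(d_k)) for b in range(d_v)])
    return outputs

rng = random.Random(0)
qs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(6)]
ks = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(6)]
vs = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(6)]
par = retention_parallel(qs, ks, vs, gamma=0.9)
rec = retention_recurrent(qs, ks, vs, gamma=0.9)
```

The parallel loop does work proportional to t at step t, while the recurrent loop does the same fixed amount of work at every step, no matter how long the sequence gets.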
The decay factor γ acts like a knob controlling the model's memory horizon. Different heads can have different decay rates: heads with γ close to 1 remember far back, while heads with smaller γ focus on recent context. It's reminiscent of Clockwork RNNs, in which different groups of hidden units update at different rates, but it is derived from a different mathematical starting point.
RWKV (Peng et al., 2023) takes a more radical approach. The name comes from its key parameter matrices: R (receptance), W (weight), K (key), V (value). RWKV replaces attention with a mechanism based on token-shift mixing and channel-wise weighted sums. Each token's representation is a learned mix of its own embedding and the previous token's embedding (the "token shift"), followed by a recurrent update that looks like a weighted moving average.
The recurrence in RWKV is per-channel: each dimension of the state is updated independently, so the recurrent step needs no matrix multiply. Each channel decays and accumulates on its own, which makes the recurrent update O(d) per token, where d is the hidden dimension, and keeps the state a fixed size. (The linear projections that produce R, K, and V still cost O(d²) per token, as in any layer.) Cross-channel interaction comes from those projections and from the channel-mixing block in each layer, not from the recurrence itself.
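A per-channel sketch of the RWKV recurrence, in plain Python. It follows the WKV formula as I understand it from the RWKV paper (the RWKV-4 version): a per-channel decay w, a bonus u for the current token, and a state of just two running sums per channel. It leaves out the receptance gate (the real layer multiplies the result by σ(r) and applies an output projection), the learned projections that produce k and v from the token-shifted input, and the numerically stabilized form that production kernels use; all names are hypothetical. The direct version computes the same weighted average explicitly, as a check.

```python
import math
import random

def token_shift(xs, mix):
    """x'_t = mix * x_t + (1 - mix) * x_{t-1} per channel, with zeros before the first token."""
    prev = [0.0] * len(xs[0])
    shifted = []
    for x in xs:
        shifted.append([m * a + (1 - m) * b for m, a, b in zip(mix, x, prev)])
        prev = x
    return shifted

def wkv_recurrent(ks, vs, w, u):
    """Per-channel WKV recurrence; the state is two numbers per channel."""
    d = len(ks[0])
    a = [0.0] * d  # decayed sum of exp(k_i) * v_i over past tokens
    b = [0.0] * d  # decayed sum of exp(k_i) over past tokens
    outs = []
    for k, v in zip(ks, vs):
        bonus = [math.exp(u[c] + k[c]) for c in range(d)]  # extra weight for the current token
        outs.append([(a[c] + bonus[c] * v[c]) / (b[c] + bonus[c]) for c in range(d)])
        a = [math.exp(-w[c]) * a[c] + math.exp(k[c]) * v[c] for c in range(d)]
        b = [math.exp(-w[c]) * b[c] + math.exp(k[c]) for c in range(d)]
    return outs

def wkv_direct(ks, vs, w, u):
    """The same quantity as an explicit weighted average over all past tokens."""
    d = len(ks[0])
    outs = []
    for t in range(len(ks)):
        row = []
        for c in range(d):
            num = math.exp(u[c] + ks[t][c]) * vs[t][c]
            den = math.exp(u[c] + ks[t][c])
            for i in range(t):
                weight = math.exp(-(t - 1 - i) * w[c] + ks[i][c])
                num += weight * vs[i][c]
                den += weight
            row.append(num / den)
        outs.append(row)
    return outs

rng = random.Random(1)
T, d = 5, 3
xs = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(T)]
mix = [0.5, 0.7, 0.9]
shifted = token_shift(xs, mix)
ks = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(T)]
vs = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(T)]
w, u = [0.5, 1.0, 2.0], [0.1, 0.0, -0.3]
rec_out = wkv_recurrent(ks, vs, w, u)
direct_out = wkv_direct(ks, vs, w, u)
```

The recurrent version folds each past token into the running sums once, when it arrives; the direct version revisits every past token at every step, which is exactly the cost the recurrence avoids.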
I'm still developing my intuition for when these alternatives genuinely beat transformers versus when they're trading expressiveness for speed. The empirical results are promising: RWKV models up to 14 billion parameters have been trained and show competitive perplexity with similar-size transformers, at a fraction of the inference cost for long sequences. RetNet claims to achieve "the impossible triangle" of parallel training, efficient inference, and strong performance. But "competitive with transformers" is a moving target, and transformers keep getting better too.
What I find most exciting about RetNet and RWKV isn't whether they'll "beat" transformers. It's that they forced the field to articulate what attention actually provides — and what it costs. The fact that exponential decay and channel-wise recurrences can approach attention-level performance suggests that the quadratic all-pairs comparison might be doing less than we assumed. Or it suggests that the value of attention is concentrated in a few important interactions, and a well-designed decay can approximate those interactions at a fraction of the price. Either way, our understanding of why transformers work is sharper because of these alternatives.
Wrap-Up
If you're still with me, thank you. I hope it was worth the tour.
We started with a simple frustration — the transformer's fixed context window — and watched Transformer-XL stitch segments together with cached hidden states. We saw XLNet scramble the prediction order to get bidirectional context without masking. We watched Reformer try to dodge the quadratic cost with hashing, and learned why Flash Attention ultimately won that battle. We gave neural networks scratch paper (Neural Turing Machines) and discovered that memory, computation, and even sorting can all be made differentiable. We saw Mixture of Experts split the work among specialists so a model can be massive without being slow. And we followed RetNet and RWKV as they questioned whether attention itself is the right mechanism — arriving at recurrence-attention hybrids that might reshape inference.
My hope is that the next time you encounter an unfamiliar architecture name in a paper or a conversation — Transformer-XL, XLNet, NTM, Switch Transformer, RWKV — instead of that familiar pang of guilt and the impulse to close the tab, you'll have a genuine mental model of what it does, why it exists, and where it fits in the larger story of sequence modeling.
Resources and Credits
"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (Dai et al., 2019) — the original paper. Clear writing for a research paper, and the relative positional encoding section is especially well-explained.
"XLNet: Generalized Autoregressive Pretraining for Language Understanding" (Yang et al., 2019) — dense but rewarding. The two-stream attention diagrams are essential for understanding the architecture.
"Reformer: The Efficient Transformer" (Kitaev et al., 2020) — worth reading for the LSH attention idea alone, even though the architecture isn't widely used today.
"Neural Turing Machines" (Graves et al., 2014) — the O.G. paper on differentiable external memory. Still one of the most mind-expanding reads in deep learning.
"Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity" (Fedus et al., 2021) — the clearest exposition of MoE in transformers. The load-balancing discussion is particularly insightful.
"RWKV: Reinventing RNNs for the Transformer Era" (Peng et al., 2023) — a bold paper that makes you rethink whether attention is necessary. The dual-form derivation is elegant.
"Retentive Network: A Successor to Transformer for Large Language Models" (Sun et al., 2023) — the claim in the title is ambitious, but the retention mechanism and its parallel-recurrent duality are beautifully motivated.
"Attention and Augmented Recurrent Neural Networks" (Olah & Carter, Distill, 2016) — if you want interactive, visual explanations of NTMs and attention-augmented memory, this is unforgettable. The interactive diagrams make the addressing mechanisms click in a way that static figures can't.