Probabilistic Models & Inference
I avoided graphical models for longer than I’d like to admit. Every time I saw one of those diagrams with circles and arrows and Greek letters stuffed inside rectangles, I’d nod along in the meeting and then Google the concept again thirty minutes later. The nodes, the edges, the plates, the d-separation rules: it all felt like an elaborate notation system designed to keep outsiders out. Eventually the discomfort of not knowing what was actually going on under the hood grew too great, and I decided to dive in. This is the result of that dive.
Probabilistic graphical models are a way to represent joint probability distributions using graphs. They were introduced in earnest through Judea Pearl’s work in the 1980s and became the dominant framework for structured reasoning under uncertainty long before deep learning took over. Today they’re the backbone of causal inference, medical diagnosis systems, probabilistic programming languages, and even the latent-variable structure inside VAEs.
Before we start, a heads-up. We’re going to be working with probability distributions, conditional independence, and a bit of calculus when we get to variational inference. You don’t need to know any of it beforehand. We’ll add the concepts we need one at a time, with explanation.
This isn’t a short journey, but I hope you’ll be glad you came.
Drawing the story: Bayesian networks
When there’s no arrow direction: Markov random fields
Plate notation: handling repetition without losing your mind
D-separation: reading independence off the graph
Asking questions: exact inference
Rest stop and an off-ramp
When exact answers are too expensive: the approximation fork
The sampling path: MCMC
Metropolis-Hastings: the random walk
Gibbs sampling: one variable at a time
Hamiltonian Monte Carlo: physics to the rescue
The optimization path: variational inference
The ELBO: the quantity that makes it all work
Mean-field: the crudest possible approximation
Amortized inference: teaching a neural network to do inference
Choosing your path: when to use which
Resources and credits
The Problem with Giant Probability Tables
Let’s start with something concrete. Imagine you’re building a diagnostic system for a small clinic. The clinic sees patients who might have one of three conditions: a cold, the flu, or allergies. The doctor observes four symptoms: fever, cough, runny nose, and fatigue. Each variable is binary — present or absent.
The honest, brute-force way to represent every possible combination of diseases and symptoms is to write down a joint probability table. Seven binary variables means 2⁷ = 128 possible combinations, each needing its own probability. Subtract one (because probabilities sum to one) and you need 127 numbers.
That’s manageable. But think about what happens as the clinic grows. Expand to ten diseases and twenty symptoms. Now you have 30 binary variables and 2³⁰ − 1 ≈ one billion entries. For a realistic hospital system with hundreds of variables, the table has more entries than there are atoms in the observable universe.
This is the fundamental problem. Joint probability distributions are exponentially large. We need a way to compress them. And the key insight is that most variables in the world don’t directly interact with most other variables. Your fever depends on whether you have the flu. It doesn’t directly depend on whether the hospital parking lot is full. If we can write down which variables directly influence which other variables, we can factor this giant table into many small tables.
That’s what graphical models are. They’re a map — a literal picture — of which variables talk to which. And once you have that map, the math becomes tractable.
Drawing the Story: Bayesian Networks
A Bayesian network is a directed acyclic graph (a DAG) in which each node is a random variable and each arrow points from a variable to one it directly influences, typically from cause to effect. “Directed” means the arrows have direction. “Acyclic” means you can’t follow the arrows in a circle back to where you started. Think of it as a family tree for variables: in the generative story, parents influence children and never the reverse (although, as we’ll see, observing a child does change what we believe about its parents).
Back to our clinic. A cold can cause a cough and a runny nose. The flu can cause a fever, a cough, and fatigue. Allergies can cause a runny nose and fatigue. We draw that as a graph:
Cold       → Cough, Runny Nose
Flu        → Fever, Cough, Fatigue
Allergies  → Runny Nose, Fatigue
The power of this picture is in what it lets us write down. The full joint distribution over all seven variables factors into a product of small conditional probability tables, one per node:
P(Cold, Flu, Allergies, Fever, Cough, Runny, Fatigue)
= P(Cold) × P(Flu) × P(Allergies)
× P(Fever | Flu)
× P(Cough | Cold, Flu)
× P(Runny | Cold, Allergies)
× P(Fatigue | Flu, Allergies)
Each node’s probability depends only on its parents in the graph. That’s the factoring rule for Bayesian networks. A binary node with k binary parents needs one number per parent configuration, 2ᵏ in all. So instead of 127 numbers for the full joint table, we need: 1 number each for P(Cold), P(Flu), and P(Allergies), 2 for P(Fever | Flu), and 4 each for P(Cough | Cold, Flu), P(Runny | Cold, Allergies), and P(Fatigue | Flu, Allergies). That’s 17 numbers instead of 127. As long as each node has only a few parents, the savings grow exponentially as you add more variables.
I’ll be honest: when I first saw this factoring trick, I didn’t trust it. How can throwing away all those joint probabilities be okay? The answer is conditional independence. By drawing the arrows the way we did, we’re asserting that, for instance, once you know whether someone has the flu, their fever is independent of their allergy status. That’s a modeling assumption, and it had better be roughly right, but when it is, the compression is dramatic.
Each little table, such as P(Fever | Flu) or P(Cough | Cold, Flu), is called a conditional probability table, or CPT. The CPT for each node lists the probability of each outcome for every combination of its parents’ values. Nodes with no parents (Cold, Flu, Allergies) have an even simpler table: a single probability.
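To make the factorization concrete, here is a minimal Python sketch of the clinic network. The CPT numbers are invented for illustration; only the structure (which node has which parents) comes from the graph above. The code counts the free parameters and checks that the product of small tables still sums to one over all 128 states.

```python
from itertools import product

# Hypothetical CPT values (made up for illustration).
# Each entry is P(child = 1 | parent values); P(child = 0 | ...) is one minus it.
PRIORS = {"Cold": 0.2, "Flu": 0.1, "Allergies": 0.15}
CPTS = {
    "Fever":   (("Flu",), {(0,): 0.05, (1,): 0.9}),
    "Cough":   (("Cold", "Flu"), {(0, 0): 0.05, (0, 1): 0.8, (1, 0): 0.7, (1, 1): 0.9}),
    "Runny":   (("Cold", "Allergies"), {(0, 0): 0.05, (0, 1): 0.8, (1, 0): 0.85, (1, 1): 0.95}),
    "Fatigue": (("Flu", "Allergies"), {(0, 0): 0.1, (0, 1): 0.3, (1, 0): 0.85, (1, 1): 0.9}),
}
VARIABLES = list(PRIORS) + list(CPTS)

def bernoulli(p_one, value):
    return p_one if value == 1 else 1.0 - p_one

def joint(assignment):
    """P(all seven variables) as a product of one small factor per node."""
    prob = 1.0
    for var, p in PRIORS.items():
        prob *= bernoulli(p, assignment[var])
    for var, (parents, table) in CPTS.items():
        parent_values = tuple(assignment[par] for par in parents)
        prob *= bernoulli(table[parent_values], assignment[var])
    return prob

# Free parameters: one per root node, one per parent configuration for each child.
n_params = len(PRIORS) + sum(len(table) for _, table in CPTS.values())

# Sanity check: the factored joint still sums to one over all 2**7 = 128 states.
total = sum(joint(dict(zip(VARIABLES, values)))
            for values in product([0, 1], repeat=len(VARIABLES)))
print(n_params, round(total, 10))  # → 17 1.0
```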
When There’s No Arrow Direction: Markov Random Fields
Bayesian networks work beautifully when you can tell a causal story — diseases cause symptoms, rain causes wet grass, a burglar causes an alarm. But sometimes relationships are symmetric. Neighboring pixels in an image tend to have similar colors, but neither pixel “causes” the other. Amino acids that sit near each other in a protein interact, but there’s no direction to that interaction.
For these situations, we use a Markov random field (MRF), also called an undirected graphical model. The edges have no arrows. Instead of conditional probability tables, MRFs use potential functions — also called factors — that score how “compatible” neighboring variable assignments are. A high score means those values like being together. A low score means they don’t.
The joint distribution in an MRF looks like this:
P(X₁, ..., Xₙ) = (1/Z) × ∏_{cliques c} ψ_c(X_c)
Each ψ_c is a nonnegative potential function over a clique (a group of variables that are all connected to each other), and X_c denotes the variables in that clique. Z is the partition function, the normalizing constant that makes all the probabilities sum to one.
Z is where the trouble starts. To compute Z, you’d have to sum the product of all potentials over every possible configuration of every variable. For n binary variables, that’s 2ⁿ terms. This is why computing exact probabilities in an MRF is generally intractable, and why we’ll need the approximate inference methods we’ll build up to later.
Think of the difference this way. A Bayesian network tells a generative story: first the disease happens, then the symptoms appear. You can simulate data by following the arrows forward. An MRF describes correlations: these variables tend to agree with each other, without saying who caused whom. Bayesian networks are natural for diagnosis and causal reasoning. MRFs are natural for images, spatial data, and physics. The Ising model from statistical mechanics — the granddaddy of all undirected models — is an MRF.
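To see where the 2ⁿ comes from, here is a small brute-force sketch for an Ising-style chain of n spins (values ±1). It assumes the cliques are neighboring pairs with potential ψ(sᵢ, sᵢ₊₁) = exp(J·sᵢ·sᵢ₊₁) and an arbitrarily chosen coupling J = 1. Computing Z means visiting every one of the 2ⁿ configurations.

```python
import numpy as np
from itertools import product

def ising_chain_unnormalized(spins, coupling=1.0):
    """Product of pairwise potentials exp(J * s_i * s_{i+1}) along the chain."""
    s = np.asarray(spins)
    return np.exp(coupling * np.sum(s[:-1] * s[1:]))

def partition_function(n, coupling=1.0):
    """Z by brute force: sum the unnormalized score over all 2**n spin configurations."""
    return sum(ising_chain_unnormalized(config, coupling)
               for config in product([-1, 1], repeat=n))

n = 10
Z = partition_function(n)
aligned = ising_chain_unnormalized([1] * n) / Z                 # every spin agrees
alternating = ising_chain_unnormalized([1, -1] * (n // 2)) / Z  # every neighbor disagrees
print(f"Z = {Z:.1f}, P(all +1) = {aligned:.4f}, P(alternating) = {alternating:.2e}")
```

For an open chain with no external field, Z also has the closed form 2(2 cosh J)ⁿ⁻¹, which the brute-force sum should reproduce. Chains are easy because they are trees; the exact-inference section below explains why loopy graphs are not.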
Plate Notation: Handling Repetition Without Losing Your Mind
Let’s go back to our clinic. Suppose we see not one patient but a thousand. Each patient has their own diseases and symptoms, but they all share the same underlying disease prevalence rates and the same symptom-disease relationships. Drawing a thousand copies of the same graph would be absurd.
Plate notation solves this with a rectangle. You draw the variables for one patient, wrap them in a box (the “plate”), and write “N = 1000” in the corner. Everything inside the plate gets repeated N times. Everything outside the plate is shared across all repetitions.
           θ  ← disease prevalence (shared across all patients)
           │
           ↓
┌──────────────────────┐
│      Disease_i       │
│          │           │  N = 1000
│          ↓           │
│      Symptoms_i      │
└──────────────────────┘
The parameter θ sits outside the plate — there’s only one copy, shared by everyone. Each patient’s disease and symptoms sit inside — they get their own independent copy. This is how you represent a model with a thousand data points using a diagram the size of a napkin sketch.
Plates can nest. If your clinic has multiple departments, each with its own patients, you’d have an outer plate for departments and an inner plate for patients within each department. The notation scales to arbitrarily complex hierarchical models — which is exactly the kind of model that probabilistic programming languages like PyMC and Stan are built to handle.
When I first encountered plate notation, I kept confusing “inside the plate” with “outside the plate.” The rule is: if a variable is specific to each data point, it goes inside. If it’s shared across data points (a parameter you’re trying to learn), it goes outside. Shaded nodes mean observed data. Unshaded nodes mean latent (hidden) variables. That’s the entire visual vocabulary.
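In code, a plate is just a loop, or equivalently an array dimension of length N. Here is a minimal generative sketch of the diagram above, assuming one binary disease and one binary symptom per patient. The value of θ and the symptom probabilities are hypothetical. θ is set once, outside the plate, and each patient’s variables are sampled N times.

```python
import numpy as np

rng = np.random.default_rng(0)

# Outside the plate: one shared parameter (hypothetical value).
theta = 0.1        # disease prevalence shared by every patient
N = 1000           # the plate's repetition count

# Inside the plate: each patient i gets an independent copy of Disease_i and Symptoms_i.
# Hypothetical symptom model: P(symptom | disease) = 0.8, P(symptom | no disease) = 0.1.
disease = rng.random(N) < theta
symptom_prob = np.where(disease, 0.8, 0.1)
symptoms = rng.random(N) < symptom_prob

print(disease.shape, symptoms.shape)   # one entry per patient
print(disease.mean(), symptoms.mean())
```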
D-Separation: Reading Independence Off the Graph
Here’s the part that makes graphical models genuinely powerful, not merely pretty diagrams. Given a graph, you can determine which variables are conditionally independent of which other variables by looking at the picture. No math required. The algorithm for doing this is called d-separation.
There are three fundamental connection patterns, and everything else is built from these three. Let’s walk through each one using our clinic.
The Chain. Flu → Fever → HospitalVisit. The flu causes a fever, and the fever causes a hospital visit. Without knowing anything else, learning someone has the flu tells you something about whether they went to the hospital. But once you know they have a fever, learning about the flu doesn’t give you any additional information about the hospital visit. The fever “screens off” the flu from the hospital visit. Observing the middle variable blocks the flow of information.
The Fork. Cough ← Cold → RunnyNose. The cold is a common cause of both symptoms. If you see someone coughing, you might guess they also have a runny nose (because both suggest a cold). But if you already know they have a cold, the cough tells you nothing new about the runny nose. The common cause, once known, makes its effects independent.
The Collider. This is the one that trips everyone up. Cold → Cough ← Flu. Both a cold and the flu can cause a cough. Normally, having a cold tells you nothing about whether someone has the flu — they’re independent. But now suppose you observe that the patient is coughing. If you then learn they don’t have a cold, that increases your belief that they have the flu. Observing the common effect creates a dependency between causes that wasn’t there before. This is called the explaining away effect.
I still occasionally get tripped up by colliders. The intuition is: when you see an effect with multiple possible causes, ruling out one cause makes the other causes more likely. It’s the reason why conditioning on a collider (or any descendant of a collider) can introduce bias in statistical analyses — a critical insight for causal inference.
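Explaining away is easy to verify by brute-force enumeration. This sketch uses only the collider Cold → Cough ← Flu, with hypothetical probabilities, and computes P(Flu = 1) under different evidence.

```python
from itertools import product

# Hypothetical numbers: P(Cold), P(Flu), and P(Cough = 1 | Cold, Flu).
P_COLD, P_FLU = 0.2, 0.1
P_COUGH = {(0, 0): 0.05, (0, 1): 0.8, (1, 0): 0.7, (1, 1): 0.9}

def joint(cold, flu, cough):
    p = (P_COLD if cold else 1 - P_COLD) * (P_FLU if flu else 1 - P_FLU)
    p_c = P_COUGH[(cold, flu)]
    return p * (p_c if cough else 1 - p_c)

def prob_flu(**evidence):
    """P(Flu = 1 | evidence), enumerating all 8 states of the collider."""
    num = den = 0.0
    for cold, flu, cough in product([0, 1], repeat=3):
        state = {"cold": cold, "flu": flu, "cough": cough}
        if any(state[k] != v for k, v in evidence.items()):
            continue  # inconsistent with the evidence
        p = joint(cold, flu, cough)
        den += p
        num += p * flu
    return num / den

print(prob_flu())                  # ≈ 0.10, the prior
print(prob_flu(cold=0))            # ≈ 0.10: Cold alone says nothing about Flu
print(prob_flu(cough=1))           # ≈ 0.34: coughing raises belief in flu
print(prob_flu(cough=1, cold=0))   # ≈ 0.64: ruling out a cold raises it further
```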
The full d-separation algorithm amounts to checking whether every path between two variables is “blocked” by the set of observed variables, using these three patterns as building blocks. If every path is blocked, the variables are conditionally independent given the observations, in every distribution that factors according to the graph. If even one path is unblocked, the graph no longer guarantees independence, and in general the variables will be dependent.
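Here is a minimal sketch of that path-blocking check for the clinic graph. It enumerates every undirected path and applies the three rules at each middle node. Path enumeration is exponential in general, so this is only sensible for small graphs; practical implementations use a linear-time reachability procedure such as Bayes-ball instead. The function names are illustrative.

```python
# Clinic DAG as parent -> children.
CLINIC = {
    "Cold": ["Cough", "Runny"],
    "Flu": ["Fever", "Cough", "Fatigue"],
    "Allergies": ["Runny", "Fatigue"],
}

def build(edges):
    """Return parent and child sets for every node."""
    parents, children = {}, {}
    for p, cs in edges.items():
        parents.setdefault(p, set())
        children.setdefault(p, set())
        for c in cs:
            children[p].add(c)
            parents.setdefault(c, set()).add(p)
            children.setdefault(c, set())
    return parents, children

def descendants(node, children):
    seen, stack = set(), [node]
    while stack:
        for c in children[stack.pop()]:
            if c not in seen:
                seen.add(c)
                stack.append(c)
    return seen

def undirected_paths(x, y, parents, children):
    """Every simple path from x to y, ignoring arrow direction."""
    def walk(path):
        node = path[-1]
        if node == y:
            yield list(path)
            return
        for nb in parents[node] | children[node]:
            if nb not in path:
                path.append(nb)
                yield from walk(path)
                path.pop()
    yield from walk([x])

def path_is_active(path, observed, parents, children):
    for prev, mid, nxt in zip(path, path[1:], path[2:]):
        if prev in parents[mid] and nxt in parents[mid]:
            # Collider: blocked unless it, or one of its descendants, is observed.
            if mid not in observed and not (descendants(mid, children) & observed):
                return False
        elif mid in observed:
            # Chain or fork: observing the middle node blocks the path.
            return False
    return True

def d_separated(x, y, observed, edges=CLINIC):
    parents, children = build(edges)
    observed = set(observed)
    return not any(path_is_active(path, observed, parents, children)
                   for path in undirected_paths(x, y, parents, children))

print(d_separated("Cold", "Flu", []))          # True: the collider at Cough is unobserved
print(d_separated("Cold", "Flu", ["Cough"]))   # False: explaining away
```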
Asking Questions: Exact Inference
A graphical model is a map. Inference is using that map to answer questions. The most common question: given some observed evidence (the patient has a fever and a cough), what is the probability of each unobserved variable (do they have the flu)?
For small or nicely structured graphs, we can answer this exactly.
Variable elimination works by summing out (marginalizing) one variable at a time. The idea is simple: if you don’t care about a variable, sum over all its possible values. Let’s trace through a tiny example with our clinic.
Say we want P(Flu | Cough=yes). We need to eliminate (sum out) Cold, Allergies, Fever, RunnyNose, and Fatigue. We pick an elimination order and process each variable. For each one, we multiply together all the factors that mention it, sum over its values, and produce a new, smaller factor. When all the unwanted variables are gone, we normalize what’s left.
The catch: the order in which you eliminate variables matters enormously. A bad order can produce intermediate factors that are exponentially large. Finding the optimal elimination order is NP-hard. In practice, heuristics like “eliminate the variable that creates the smallest new factor” work well enough.
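The procedure is compact enough to sketch directly. This version represents a factor as a tuple of variable names plus a table of numbers, enters the evidence Cough = yes as an indicator factor (so Cough is summed out along with the others), and uses a hand-picked elimination order. The CPT values are hypothetical; with them, the answer can be checked by hand as P(Flu = 1 | Cough = 1) = 0.082 / 0.244.

```python
from itertools import product

# A factor is (variables, table); table maps a 0/1 assignment tuple to a number.
def make_factor(variables, fn):
    return (tuple(variables),
            {a: fn(dict(zip(variables, a))) for a in product([0, 1], repeat=len(variables))})

def multiply(f, g):
    (fv, ft), (gv, gt) = f, g
    variables = fv + tuple(v for v in gv if v not in fv)
    table = {}
    for a in product([0, 1], repeat=len(variables)):
        asg = dict(zip(variables, a))
        table[a] = ft[tuple(asg[v] for v in fv)] * gt[tuple(asg[v] for v in gv)]
    return (variables, table)

def sum_out(f, var):
    fv, ft = f
    keep = tuple(v for v in fv if v != var)
    table = {}
    for a, value in ft.items():
        key = tuple(x for v, x in zip(fv, a) if v != var)
        table[key] = table.get(key, 0.0) + value
    return (keep, table)

def eliminate(factors, order):
    for var in order:
        related = [f for f in factors if var in f[0]]
        factors = [f for f in factors if var not in f[0]]
        if not related:
            continue
        combined = related[0]
        for f in related[1:]:
            combined = multiply(combined, f)
        factors.append(sum_out(combined, var))   # new, smaller factor
    result = factors[0]
    for f in factors[1:]:
        result = multiply(result, f)
    return result

def bern(p_one, x):
    return p_one if x == 1 else 1.0 - p_one

# Hypothetical CPTs for the clinic network.
COUGH = {(0, 0): 0.05, (0, 1): 0.8, (1, 0): 0.7, (1, 1): 0.9}
RUNNY = {(0, 0): 0.05, (0, 1): 0.8, (1, 0): 0.85, (1, 1): 0.95}
FATIGUE = {(0, 0): 0.1, (0, 1): 0.3, (1, 0): 0.85, (1, 1): 0.9}
factors = [
    make_factor(["Cold"], lambda a: bern(0.2, a["Cold"])),
    make_factor(["Flu"], lambda a: bern(0.1, a["Flu"])),
    make_factor(["Allergies"], lambda a: bern(0.15, a["Allergies"])),
    make_factor(["Flu", "Fever"], lambda a: bern(0.9 if a["Flu"] else 0.05, a["Fever"])),
    make_factor(["Cold", "Flu", "Cough"],
                lambda a: bern(COUGH[(a["Cold"], a["Flu"])], a["Cough"])),
    make_factor(["Cold", "Allergies", "Runny"],
                lambda a: bern(RUNNY[(a["Cold"], a["Allergies"])], a["Runny"])),
    make_factor(["Flu", "Allergies", "Fatigue"],
                lambda a: bern(FATIGUE[(a["Flu"], a["Allergies"])], a["Fatigue"])),
    make_factor(["Cough"], lambda a: 1.0 if a["Cough"] == 1 else 0.0),  # evidence: Cough = yes
]

# Eliminate everything except the query variable Flu, then normalize.
order = ["Fever", "Runny", "Fatigue", "Allergies", "Cold", "Cough"]
variables, table = eliminate(factors, order)
z = sum(table.values())
posterior = {a[0]: value / z for a, value in table.items()}
print(variables, posterior)
```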
Belief propagation takes a different approach: message passing. Each node sends “messages” to its neighbors summarizing what it believes. On trees — graphs with no loops — this is exact. Messages flow from leaves to root, then from root back to leaves. Two passes, and every node knows its marginal probability. This is the sum-product algorithm, and it’s elegant: each node only needs to talk to its neighbors, and the whole thing converges in linear time.
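The simplest tree is a chain, where “leaves to root and back” becomes one left-to-right sweep and one right-to-left sweep. This sketch runs sum-product on a chain MRF with hypothetical potentials and checks the result against brute-force enumeration of all joint states.

```python
import numpy as np
from itertools import product

def chain_marginals(unary, pairwise):
    """Sum-product on a chain MRF: one sweep in each direction.

    unary:    (n, k) nonnegative node potentials phi_i(x_i)
    pairwise: (k, k) edge potential psi(x_i, x_{i+1}), shared by every edge
    """
    n, k = unary.shape
    from_left = np.ones((n, k))    # message arriving at node i from node i-1
    from_right = np.ones((n, k))   # message arriving at node i from node i+1
    for i in range(1, n):
        m = (unary[i - 1] * from_left[i - 1]) @ pairwise   # sum over x_{i-1}
        from_left[i] = m / m.sum()                          # rescale to avoid underflow
    for i in range(n - 2, -1, -1):
        m = pairwise @ (unary[i + 1] * from_right[i + 1])   # sum over x_{i+1}
        from_right[i] = m / m.sum()
    beliefs = unary * from_left * from_right
    return beliefs / beliefs.sum(axis=1, keepdims=True)

def brute_force_marginals(unary, pairwise):
    """Reference answer: enumerate all k**n joint states."""
    n, k = unary.shape
    marg = np.zeros((n, k))
    for xs in product(range(k), repeat=n):
        w = np.prod([unary[i, x] for i, x in enumerate(xs)])
        w *= np.prod([pairwise[xs[i], xs[i + 1]] for i in range(n - 1)])
        for i, x in enumerate(xs):
            marg[i, x] += w
    return marg / marg.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
unary = rng.uniform(0.5, 2.0, size=(6, 2))     # hypothetical node potentials
pairwise = np.array([[2.0, 0.5], [0.5, 2.0]])  # neighbors prefer to agree
bp = chain_marginals(unary, pairwise)
exact = brute_force_marginals(unary, pairwise)
print(np.allclose(bp, exact))  # → True
```

The message pass costs O(n·k²), linear in the chain length, while the brute-force check costs O(kⁿ).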
On graphs with loops (which describes most real-world graphs), belief propagation isn’t exact anymore. But “loopy belief propagation” (running the same message-passing rules on a loopy graph and hoping for the best) often works surprisingly well. It has no general convergence guarantee, and sometimes it oscillates or diverges. But in practice, for things like error-correcting codes (turbo codes, LDPC codes) and image segmentation, loopy BP has been wildly successful.
There’s also the junction tree algorithm, which transforms any graph into a tree of cliques and then runs exact message passing on that tree. It’s a general-purpose exact inference algorithm. The cost is exponential in the treewidth of the graph, a measure of how far the graph is from being a tree. For chain-shaped or tree-shaped graphs, treewidth is small and inference is fast. For densely connected graphs like grid-shaped MRFs over images, treewidth is too large for practical use.
And that’s the fundamental tension. Exact inference is beautiful when it works. But for most interesting models, the graph is too big, too loopy, or too densely connected. We need approximations.
Rest Stop and an Off-Ramp
Congratulations on making it this far. If you need to stop, you can.
You now have a mental model for the entire first half of probabilistic modeling. You know that graphical models — Bayesian networks and Markov random fields — represent joint distributions compactly by encoding which variables directly influence which others. You know how to read conditional independence from the graph using d-separation. You know that exact inference works by eliminating variables or passing messages, and that it’s tractable for tree-structured graphs but breaks down when the graph gets too complex.
That’s a genuinely useful mental model. If someone mentions a “Bayesian network” or “belief propagation” or “explaining away,” you know what they’re talking about.
The short version of what comes next: when exact inference is too expensive (which is most of the time in real applications), you have two families of approximation. One draws samples from the posterior (MCMC). The other finds the closest simple distribution to the posterior by optimizing a clever objective (variational inference). MCMC is accurate but slow. VI is fast but biased. There. You’re 80% of the way there.
But if the discomfort of not knowing what’s underneath those approximations is nagging at you, read on.
When Exact Answers Are Too Expensive: The Approximation Fork
For our clinic model with seven variables, exact inference is instant. For a densely connected hospital network with a few hundred variables, it’s intractable. For the continuous latent variables in a VAE, or in a Gaussian process with a non-Gaussian likelihood (as in GP classification), exact inference isn’t even on the table.
This is where the field splits into two philosophies, as it has for decades, and understanding both is a core skill in probabilistic ML.
The sampling philosophy says: we can’t write down the posterior in closed form, but we can draw samples from it. If we draw enough samples, we can estimate any quantity we care about — the mean, the variance, the probability that a parameter is positive, anything. This family of methods is called Markov chain Monte Carlo (MCMC).
The optimization philosophy says: forget about getting the true posterior. Pick a family of simple distributions, and find the one that’s closest to the true posterior. Turn an inference problem into an optimization problem. This family of methods is called variational inference (VI).
Think of it like finding a mountain’s height. The sampling approach is like sending a thousand hikers with altimeters to wander the mountain and averaging their readings. Slow but accurate. The optimization approach is like fitting a nice smooth surface to satellite images of the mountain. Fast but it misses the craggy details.
Let’s walk through each family, starting with sampling.
The Sampling Path: MCMC
The idea behind MCMC is one of the most beautiful in computational statistics. We construct a Markov chain — a sequence of random states where each state depends only on the previous one — whose long-run behavior (its stationary distribution) is exactly the posterior we want to sample from. Run the chain long enough, and the samples we collect will be distributed according to p(θ|D), the true posterior.
The word “Monte Carlo” refers to using random sampling to estimate quantities. “Markov chain” describes the mechanism that generates those samples. Together, MCMC means: “we’re generating random samples by running a carefully designed chain of random steps.”
Let’s make this concrete. Back to our clinic, but now with a continuous parameter: we want to estimate θ, the true prevalence of the flu in our patient population, given that we observed 12 flu cases out of 100 patients.
Metropolis-Hastings: The Random Walk
Metropolis-Hastings is the foundational MCMC method. Its original form, the Metropolis algorithm, was published in 1953, making it older than most of computer science; W. K. Hastings generalized it in 1970. Here’s how it works.
Imagine you’re standing on a landscape where the height at every point represents the posterior probability. You want to explore this landscape in a way that spends more time in the high regions (likely parameter values) and less time in the low regions (unlikely values). You can’t see the whole landscape at once. You can only measure the height directly under your feet.
The algorithm goes like this. You’re standing at some point θ. You propose a random step — say, move a small random distance in a random direction. That gives you a candidate point θ′. Now you check: is the proposed point higher or lower than where you are?
If θ′ is higher (more probable), you move there. Always. If θ′ is lower, you might still move there, with probability equal to the ratio of the new height to the old one. If it’s only slightly lower, you’ll probably accept. If it’s dramatically lower, you’ll almost certainly reject and stay where you are.
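The acceptance probability is α = min(1, p(θ′|D) / p(θ|D)). (This is the form for a symmetric proposal like our random step; for an asymmetric proposal q, Metropolis-Hastings multiplies the ratio by q(θ|θ′) / q(θ′|θ).) And here’s the beautiful part: computing this ratio cancels the normalizing constant. You never need to compute the intractable denominator in Bayes’ theorem. The ratio of unnormalized posteriors is enough.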
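We can check the cancellation on the flu-prevalence example (12 flu cases out of 100 patients). This sketch assumes a uniform prior on θ, which makes the exact posterior Beta(13, 89). The unnormalized log posterior drops every constant (the binomial coefficient and the evidence), yet it produces the same log acceptance ratio as the fully normalized Beta density. A function like log_unnormalized_posterior is the kind of log_target a sampler such as the one below expects.

```python
import math

N_FLU, N_PATIENTS = 12, 100

def log_unnormalized_posterior(theta):
    """log[p(D | theta) p(theta)] up to a constant, assuming a uniform prior."""
    if not 0.0 < theta < 1.0:
        return -math.inf  # outside the support: zero probability
    return N_FLU * math.log(theta) + (N_PATIENTS - N_FLU) * math.log(1.0 - theta)

def log_normalized_posterior(theta):
    """Exact log density of Beta(13, 89), the conjugate posterior under that prior."""
    a, b = N_FLU + 1, N_PATIENTS - N_FLU + 1
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return log_norm + (a - 1) * math.log(theta) + (b - 1) * math.log(1.0 - theta)

current, proposed = 0.10, 0.14
ratio_unnorm = log_unnormalized_posterior(proposed) - log_unnormalized_posterior(current)
ratio_norm = log_normalized_posterior(proposed) - log_normalized_posterior(current)
print(ratio_unnorm, ratio_norm)  # equal up to floating-point rounding
```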
```python
import numpy as np

def metropolis_hastings(log_target, initial, n_samples, proposal_std):
    """MH sampler: the random walk through (1-D) posterior space.

    The Gaussian proposal is symmetric, so the Hastings correction cancels
    and the acceptance test only needs the ratio of target densities.
    """
    samples = np.zeros(n_samples)
    current = initial
    current_lp = log_target(current)
    accepted = 0

    for i in range(n_samples):
        # Propose: take a random step
        proposed = current + np.random.normal(0, proposal_std)
        proposed_lp = log_target(proposed)

        # Accept/reject: compare heights
        # (Working in log space for numerical stability)
        if np.log(np.random.uniform()) < (proposed_lp - current_lp):
            current = proposed
            current_lp = proposed_lp
            accepted += 1

        samples[i] = current  # a rejection records the current state again

    print(f"Acceptance rate: {accepted / n_samples:.1%}")
    return samples
```
That’s the entire algorithm. Propose, compare, accept or reject, repeat. After enough iterations, the histogram of collected samples approximates the posterior distribution.
The problem is that “enough iterations” can be a very large number. In one dimension, a random walk explores the space reasonably fast. In ten dimensions, it’s sluggish. In a hundred dimensions, it’s nearly useless. In high dimensions, almost every direction away from a high-probability point leads downhill, so a large random step is almost always rejected. To keep the acceptance rate reasonable, the step size has to shrink as the dimension grows, and a chain of tiny steps takes a very long time to cross the posterior.
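A quick experiment shows the effect. This sketch runs random-walk Metropolis on a standard Gaussian target (a stand-in for a well-behaved posterior) with a fixed, arbitrarily chosen step size of 0.5 and reports the acceptance rate as the dimension grows.

```python
import numpy as np

def rw_acceptance_rate(dim, step=0.5, n_steps=5000, seed=0):
    """Acceptance rate of random-walk Metropolis on a standard Gaussian in `dim` dimensions."""
    rng = np.random.default_rng(seed)

    def log_target(x):
        return -0.5 * np.dot(x, x)  # standard Gaussian, up to a constant

    x = rng.standard_normal(dim)  # start from a draw of the target itself
    lp = log_target(x)
    accepted = 0
    for _ in range(n_steps):
        prop = x + step * rng.standard_normal(dim)
        prop_lp = log_target(prop)
        if np.log(rng.uniform()) < prop_lp - lp:
            x, lp = prop, prop_lp
            accepted += 1
    return accepted / n_steps

rates = {d: rw_acceptance_rate(d) for d in (1, 10, 100)}
print(rates)
```

With the step size held fixed, acceptance collapses as the dimension rises; restoring it requires shrinking the step roughly like 1/√d, which is exactly the “tiny steps” problem.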
We need smarter ways to move through the space.
Gibbs Sampling: One Variable at a Time
One idea: instead of proposing all parameters at once, update them one at a time. For each parameter θᵢ, sample its new value from the conditional distribution p(θᵢ | everything else). If these conditional distributions have nice closed forms (which they do in many conjugate models), every proposal is automatically accepted. No rejection, no wasted steps.
This is Gibbs sampling, and for decades it was the workhorse of Bayesian computation. Software packages like BUGS (Bayesian inference Using Gibbs Sampling) and JAGS (Just Another Gibbs Sampler) were built around it, and collapsed Gibbs sampling became one of the most popular ways to fit Latent Dirichlet Allocation, the classic topic model.
The weakness shows up when variables are strongly correlated. Imagine the posterior has a long, narrow diagonal ridge — like a tilted ellipse. Gibbs can only move horizontally or vertically, one variable at a time. It takes many tiny steps to crawl along that diagonal. It’s like trying to walk diagonally across a city that only has a north-south/east-west grid. You can get there, but it takes a lot of zigzagging.
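The zigzag problem can be measured. This sketch assumes a standard bivariate Gaussian target with correlation ρ as a stand-in for the tilted ellipse. Its full conditionals are x₁ | x₂ ~ N(ρx₂, 1 − ρ²) and symmetrically for x₂, so every Gibbs draw is exact and always accepted. In theory, the x₁ chain then has lag-1 autocorrelation ρ², so as ρ approaches 1, consecutive samples become nearly identical.

```python
import numpy as np

def gibbs_bivariate_normal(rho, n_samples=20000, seed=0):
    """Gibbs sampling from a standard bivariate Gaussian with correlation rho."""
    rng = np.random.default_rng(seed)
    sd = np.sqrt(1.0 - rho**2)      # conditional std of one coordinate given the other
    x1, x2 = 0.0, 0.0
    samples = np.empty((n_samples, 2))
    for t in range(n_samples):
        x1 = rng.normal(rho * x2, sd)   # exact draw of x1 | x2: always accepted
        x2 = rng.normal(rho * x1, sd)   # then x2 | x1
        samples[t] = x1, x2
    return samples

def lag1_autocorr(x):
    x = x - x.mean()
    return np.dot(x[:-1], x[1:]) / np.dot(x, x)

weak = gibbs_bivariate_normal(rho=0.0)
strong = gibbs_bivariate_normal(rho=0.99)
print(lag1_autocorr(weak[:, 0]), lag1_autocorr(strong[:, 0]))
```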
For models with many strongly correlated parameters (which includes most interesting models), we need something that can take big, informed steps along the contours of the posterior.
Hamiltonian Monte Carlo: Physics to the Rescue
This is the method that made modern Bayesian inference actually work on real problems, and I’ll be honest — the physics connection took me a while to internalize.
Here’s the key insight. Instead of a random walk, we give our explorer a physical body. The negative log-posterior becomes a landscape — specifically, a potential energy surface. High-probability regions are valleys (low potential energy) and low-probability regions are peaks (high potential energy). We place a ball on this surface and give it a random kick. The ball rolls, gaining speed in the valleys and slowing down on the hills, following the laws of Hamiltonian mechanics.
The ball needs a momentum, which is a velocity in parameter space. At each step of HMC, we do two things: randomly sample a new momentum (a random kick), then simulate the ball rolling along the surface for a fixed number of steps. The final position becomes our proposed sample, which is then accepted or rejected with a Metropolis test that corrects for errors in the numerical simulation. Because the ball’s motion is driven by the gradient of the log-posterior, pulling it downhill in potential energy toward high-probability regions, the proposals tend to be far away and in high-probability areas.
The physics:
Position θ ↔ parameter values
Potential energy U(θ) ↔ -log p(θ|D)
Momentum p ↔ auxiliary variable (sampled fresh each step)
Total energy H(θ,p) = U(θ) + K(p) ↔ potential + kinetic, where K(p) = ½ pᵀp
Hamilton's equations:
dθ/dt = p (position changes according to momentum)
dp/dt = -∇U(θ) (momentum changes according to gradient)
= ∇ log p(θ|D) ← this is the crucial ingredient
The gradient — the slope of the posterior surface — is what makes HMC dramatically more efficient than random-walk methods. In 100 dimensions, a random walk doesn’t know which direction to go. HMC follows the gradient, rolling along the posterior’s contours like a marble in a bowl.
But HMC has two hyperparameters that need tuning: the step size (how finely we simulate the physics) and the number of leapfrog steps (how long we let the ball roll before stopping). Get these wrong, and HMC can be worse than a random walk.
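Here is a minimal HMC sketch with a leapfrog integrator, assuming an identity mass matrix (so K(p) = ½ pᵀp) and a hypothetical, strongly correlated 2-D Gaussian target. The arguments step and n_steps are the two hyperparameters just described, set by hand for this target. It is plain HMC, not NUTS, with no warmup or adaptation.

```python
import numpy as np

# Hypothetical target: 2-D Gaussian with strong correlation (a tilted ellipse).
COV = np.array([[1.0, 0.95], [0.95, 1.0]])
PREC = np.linalg.inv(COV)

def U(theta):
    """Potential energy: -log p(theta), up to a constant."""
    return 0.5 * theta @ PREC @ theta

def grad_U(theta):
    return PREC @ theta

def leapfrog(theta, p, step, n_steps):
    """Simulate Hamilton's equations with the leapfrog integrator."""
    p = p - 0.5 * step * grad_U(theta)      # half step for momentum
    for i in range(n_steps):
        theta = theta + step * p            # full step for position
        if i < n_steps - 1:
            p = p - step * grad_U(theta)    # full step for momentum
    p = p - 0.5 * step * grad_U(theta)      # final half step for momentum
    return theta, p

def hmc(initial, n_samples, step=0.15, n_steps=20, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.asarray(initial, dtype=float)
    samples = np.empty((n_samples, theta.size))
    accepted = 0
    for t in range(n_samples):
        p0 = rng.standard_normal(theta.size)   # fresh random kick
        new_theta, new_p = leapfrog(theta, p0, step, n_steps)
        # Metropolis correction for the integrator's energy error.
        current_H = U(theta) + 0.5 * p0 @ p0
        proposed_H = U(new_theta) + 0.5 * new_p @ new_p
        if np.log(rng.uniform()) < current_H - proposed_H:
            theta = new_theta
            accepted += 1
        samples[t] = theta
    return samples, accepted / n_samples

samples, acc = hmc(initial=[0.0, 0.0], n_samples=3000)
print(f"acceptance {acc:.2f}, sample correlation {np.corrcoef(samples.T)[0, 1]:.3f}")
```

If the step size is pushed past the integrator’s stability limit for the stiffest direction of the target, the energy error blows up and almost every proposal is rejected, which is one concrete way HMC “gets worse than a random walk”.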
NUTS (the No-U-Turn Sampler) eliminates most of this tuning. It runs the simulation forward and backward in time, building a tree of possible trajectories, and stops when the trajectory starts to double back on itself (a “U-turn”), so you never have to choose the number of steps. The step size is adapted automatically during a warmup phase. NUTS is what makes Stan and PyMC practical for working scientists. It’s HMC without the hand-tuning, and it works remarkably well across a wide range of models.
One important requirement: HMC needs the gradient of the log-posterior. Your model must be differentiable. This is exactly why modern probabilistic programming languages are built on automatic differentiation engines — PyMC on Aesara/PyTensor, NumPyro on JAX, Stan on its own AD system. They compute the gradient for free, so you write the model and the sampler handles the physics.
I’m still developing my intuition for why the physics simulation preserves the right distribution. The formal answer involves Hamiltonian dynamics being volume-preserving (Liouville’s theorem) and time-reversible, which guarantees detailed balance. But the visceral understanding — a ball rolling on the posterior surface — is the mental image that made it click for me.
The Optimization Path: Variational Inference
MCMC gives you the true posterior — eventually. But “eventually” might mean hours or days of sampling. If you have millions of data points, or need an answer in milliseconds, or want to train a deep generative model end-to-end with gradient descent, sampling is too slow.
Variational inference takes a completely different approach. Instead of sampling from the true posterior, we approximate it. We choose a family of simple, tractable distributions — maybe Gaussians, maybe something else — and find the member of that family that’s closest to the true posterior. We’ve turned an inference problem into an optimization problem, and optimization is something we know how to do fast.
The analogy from earlier: MCMC sends hikers to explore every ridge and valley of the mountain. VI takes a satellite photo and fits the smoothest surface that matches the data. It’s faster, but it might smooth over the interesting craggy bits.
“Closest” is measured by KL divergence, which quantifies how different one probability distribution is from another. We want to minimize KL(q || p) — the divergence from our simple approximation q(θ) to the true posterior p(θ|D). But computing this KL divergence requires knowing p(θ|D), which is the thing we can’t compute. We seem stuck.
This is where the ELBO comes in.
The ELBO: The Quantity That Makes It All Work
There’s a mathematical identity that saves us. The log-evidence — log p(D), the total probability of the observed data — can be decomposed as:
log p(D) = ELBO(q) + KL(q(θ) || p(θ|D))
The left side is a fixed constant (it doesn’t depend on q). The KL divergence on the right is always non-negative. That means the ELBO is always less than or equal to log p(D) — it’s a lower bound on the log-evidence. ELBO stands for Evidence Lower BOund.
Here’s the trick: since log p(D) is constant, maximizing the ELBO is exactly equivalent to minimizing KL(q(θ) || p(θ|D)). We can’t compute that KL divergence directly, but we can maximize the ELBO, and it achieves the same thing.
The ELBO itself decomposes into two terms we can actually compute:
ELBO(q) = E_q[log p(D|θ)]  −  KL(q(θ) || p(θ))
          ───────────────     ────────────────
          "How well does θ    "How far is q from
           explain the data?"  the prior?"
The first term rewards q for concentrating on parameter values that explain the data well. The second term penalizes q for straying too far from the prior. Sound familiar? It’s the same fit-versus-regularization tradeoff that appears everywhere in machine learning, expressed in the language of probability.
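Both identities can be checked on the smallest possible model: one binary latent z (flu or not) and one observation D (the patient coughs). Because z is discrete, every expectation is an exact sum. The prior and likelihood values below are hypothetical.

```python
import numpy as np

prior = np.array([0.9, 0.1])         # p(z) for z = (no flu, flu)
likelihood = np.array([0.18, 0.82])  # p(D | z), D = "patient is coughing"

evidence = np.sum(prior * likelihood)       # p(D)
posterior = prior * likelihood / evidence   # p(z | D)

def kl(a, b):
    """KL(a || b) for discrete distributions with no zero entries."""
    return np.sum(a * np.log(a / b))

def elbo(q):
    """ELBO(q) = E_q[log p(D|z)] - KL(q(z) || p(z))."""
    return np.sum(q * np.log(likelihood)) - kl(q, prior)

for q_flu in (0.1, 0.5, 0.9):
    q = np.array([1 - q_flu, q_flu])
    # ELBO + KL(q || posterior) recovers log p(D) for every q.
    print(f"q(flu)={q_flu}: ELBO={elbo(q):.4f}, "
          f"ELBO+KL={elbo(q) + kl(q, posterior):.4f}, log p(D)={np.log(evidence):.4f}")

# The bound is tight exactly when q equals the true posterior.
print(elbo(posterior), np.log(evidence))
```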
I won’t pretend the algebra that derives the ELBO is the insightful part. The insight is this: we wanted to do something impossible (minimize KL to the true posterior), and we found a surrogate objective (the ELBO) that achieves the same thing and that we can compute. That’s the entire intellectual move.
Mean-Field: The Crudest Possible Approximation
We need to choose a family of distributions for q. The most dramatic simplification is the mean-field approximation: assume all latent variables are independent. The variational distribution factors as:
q(θ₁, θ₂, ..., θₙ) = q₁(θ₁) × q₂(θ₂) × ... × qₙ(θₙ)
Each variable gets its own little distribution, with no correlation allowed between factors. The optimization algorithm, coordinate ascent variational inference (CAVI), cycles through the variables one at a time, updating each qᵢ while holding the others fixed, until convergence.
This is crude. The true posterior might have strong correlations between parameters: maybe θ₁ and θ₂ are tightly coupled, so knowing one tells you a lot about the other. Mean-field ignores this entirely. It treats them as independent, which typically underestimates the posterior uncertainty. The marginal distributions might be centered in the right place, but they’re too narrow and too confident.
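A Gaussian target makes the overconfidence easy to quantify. For a Gaussian with precision matrix Λ, the optimal mean-field factors are Gaussian with variance 1/Λⱼⱼ, and CAVI updates each factor’s mean in turn given the others. This sketch uses a hypothetical 2-D posterior with correlation 0.9.

```python
import numpy as np

mu = np.array([1.0, -1.0])                  # hypothetical posterior mean
cov = np.array([[1.0, 0.9], [0.9, 1.0]])    # strongly correlated posterior
prec = np.linalg.inv(cov)

# Mean-field family: q(z1, z2) = N(z1; m1, s1^2) * N(z2; m2, s2^2).
# CAVI: update each factor's mean holding the other fixed.
m = np.array([5.0, 5.0])                    # deliberately poor initialization
for _ in range(200):
    m[0] = mu[0] - prec[0, 1] / prec[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - prec[1, 0] / prec[1, 1] * (m[0] - mu[0])
q_var = 1.0 / np.diag(prec)                 # optimal factor variances

print("q means:", m, " true means:", mu)
print("q variances:", q_var, " true marginal variances:", np.diag(cov))
```

The means converge to the right place, but each factor’s variance comes out as 1 − 0.9² = 0.19 instead of the true marginal variance of 1: centered correctly, far too confident.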
There’s a deeper problem. KL(q || p) is mode-seeking. When the true posterior has multiple peaks (modes), q tends to collapse onto a single peak and pretend the others don’t exist. It finds one good explanation and ignores alternatives. For a multimodal posterior — common in mixture models, multi-layer neural networks, and other models with symmetries — mean-field VI gives a misleadingly simple picture of the uncertainty.
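Mode-seeking can be seen numerically on a 1-D grid. This sketch assumes a hypothetical bimodal target (an equal mix of N(−3, 1) and N(3, 1)) and compares two candidate Gaussians: one sitting on a single mode, and one matching the target’s overall mean and variance. KL(q || p) prefers the single-mode candidate; the reverse direction, KL(p || q), prefers the broad one.

```python
import numpy as np

x = np.linspace(-15.0, 15.0, 30001)
dx = x[1] - x[0]

def normal_logpdf(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std * np.sqrt(2 * np.pi))

# Hypothetical bimodal "posterior": equal mixture of N(-3, 1) and N(3, 1).
log_p = np.logaddexp(normal_logpdf(x, -3, 1), normal_logpdf(x, 3, 1)) + np.log(0.5)

def kl(log_a, log_b):
    """KL(a || b) from log densities on the grid (Riemann sum)."""
    return np.sum(np.exp(log_a) * (log_a - log_b)) * dx

log_q_mode = normal_logpdf(x, 3, 1)              # sits on one mode
log_q_wide = normal_logpdf(x, 0, np.sqrt(10.0))  # matches p's mean and variance

print("KL(q||p):", kl(log_q_mode, log_p), kl(log_q_wide, log_p))
print("KL(p||q):", kl(log_p, log_q_mode), kl(log_p, log_q_wide))
```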
Despite these limitations, mean-field is fast, GPU-friendly, and often gives reasonable point estimates. For many applications — especially when you need a quick answer and the uncertainty isn’t critical — it’s good enough. The key is knowing what it gets wrong.
Modern approaches improve on mean-field by using richer variational families. Full-rank Gaussian approximations capture correlations. Normalizing flows can represent complex, multimodal distributions. When a richer family contains the simpler one, its best achievable ELBO is at least as tight, which usually means a better approximation, at the cost of more computation.
Amortized Inference: Teaching a Neural Network to Do Inference
Standard VI optimizes variational parameters separately for each data point. If you have a million images and each one has its own latent variable z, that’s a million separate optimization problems. Slow.
Amortized inference solves this by training a neural network — the encoder or recognition network — to predict the variational parameters directly from the data. Instead of optimizing q(z|x) for each x individually, you learn a function that maps any x to the parameters of q(z|x) in a single forward pass.
This is exactly what a Variational Autoencoder does. The VAE encoder takes an image x and outputs the mean μ(x) and variance σ²(x) of a Gaussian q(z|x). The decoder takes a sample from that Gaussian and reconstructs the image. The training loss is the negative ELBO:
VAE loss = -ELBO
         = reconstruction error + KL(q(z|x) || p(z))
           ────────────────────   ──────────────────
           "How good is the       "How far is the approximate
            reconstruction?"       posterior from the prior?"
Every time you train a VAE, you’re doing amortized variational inference. The encoder has learned to perform approximate inference in a single forward pass — no iterative optimization at test time. The “amortization” is that the cost of learning to do inference is spread across the entire training set, and once learned, inference for a new data point is nearly free.
There’s one technical detail that makes this trainable: the reparameterization trick. We can’t backpropagate through a sampling operation directly. Drawing z from q(z|x) is a random step, and it offers no gradient path from the sampled value back to the encoder’s outputs. The trick is to rewrite the sampling as z = μ(x) + σ(x) ⊙ ε, where ε is noise drawn from a standard Gaussian N(0, I) and ⊙ is elementwise multiplication. Now ε is independent of the network parameters, and z is a deterministic, differentiable function of μ(x) and σ(x), so we can backpropagate through them. The randomness is factored out. The gradient flows through.
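Here is a minimal numerical check of the trick. It assumes a toy objective f(z) = z², so the exact gradients of E_q[f(z)] = μ² + σ² are known. An autodiff framework would apply the chain rule for you. Here it is written out by hand with NumPy: since z = μ + σε, we have ∂z/∂μ = 1 and ∂z/∂σ = ε.

```python
import numpy as np

# Toy objective (assumed): f(z) = z^2, so E_q[f(z)] = mu^2 + sigma^2 for q = N(mu, sigma^2),
# with exact gradients d/dmu = 2*mu and d/dsigma = 2*sigma.
mu, sigma = 1.5, 0.5


def f_prime(z):
    return 2.0 * z  # derivative of f(z) = z^2


rng = np.random.default_rng(0)
eps = rng.standard_normal(200_000)  # noise drawn independently of mu and sigma
z = mu + sigma * eps                # reparameterized samples from N(mu, sigma^2)

# Pathwise (reparameterization) gradient estimates: dz/dmu = 1, dz/dsigma = eps.
grad_mu = np.mean(f_prime(z) * 1.0)
grad_sigma = np.mean(f_prime(z) * eps)

print(f"d/dmu:    estimate {grad_mu:.3f}, exact {2 * mu:.3f}")
print(f"d/dsigma: estimate {grad_sigma:.3f}, exact {2 * sigma:.3f}")
```

Both Monte Carlo estimates should land close to the exact values 2μ = 3 and 2σ = 1.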
Choosing Your Path: When to Use Which
After all of this, the practical question: which method should you actually use?
Here is how I think about it. Start with the data size. If you have millions of data points and need results in minutes, VI is almost certainly the right choice. Standard MCMC evaluates the likelihood over the full dataset at every step (at every leapfrog step, for HMC/NUTS), so a million-row dataset makes each iteration expensive. Stochastic-gradient MCMC variants exist, but they remain an active research area. For datasets with hundreds to tens of thousands of observations, both approaches are viable.
Next, ask how much you trust the uncertainty estimates. If you’re building a recommendation system and the uncertainty is nice-to-have, VI is fine. If you’re making medical decisions, or if your scientific paper will be judged on the faithfulness of its posterior estimates, use MCMC (specifically NUTS). Reviewers in Bayesian statistics will expect convergence diagnostics: R̂ < 1.01, effective sample sizes above 400, no divergent transitions, and trace plots that look like hairy caterpillars.
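To make the R̂ criterion concrete, here is a sketch of the classic split-R̂ computation. It cuts each chain in half, then compares between-chain variance with within-chain variance. Treat it as illustrative. The R̂ < 1.01 guideline is usually applied to the rank-normalized refinement of this statistic, which tools such as ArviZ and Stan report. The "stuck" chains below are synthetic.

```python
import numpy as np


def split_rhat(chains):
    """Classic split-R-hat for an array of shape (n_chains, n_draws).

    Each chain is cut in half so that drift within a chain also shows up
    as disagreement between (half-)chains.
    """
    _, n_draws = chains.shape
    half = n_draws // 2
    splits = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = splits.shape[1]
    W = np.mean(np.var(splits, axis=1, ddof=1))     # mean within-chain variance
    B = n * np.var(np.mean(splits, axis=1), ddof=1)  # between-chain variance
    var_hat = (n - 1) / n * W + B / n                # pooled posterior variance estimate
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(0)
good = rng.standard_normal((4, 2000))                   # four well-mixed chains
stuck = good + np.array([0.0, 0.0, 3.0, 3.0])[:, None]  # two chains sitting elsewhere
print(f"well-mixed: R-hat = {split_rhat(good):.3f}")
print(f"stuck:      R-hat = {split_rhat(stuck):.3f}")
```

For independent, well-mixed chains the value should sit just above 1. The two chains offset by 3 push it far above any usual threshold.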
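For deep latent-variable generative models like VAEs, amortized VI is the standard. These models have too many latent variables and too much data for MCMC to be feasible, and the entire training loop is built around optimizing the ELBO. (Diffusion models are also trained on an ELBO-derived objective, though their “encoder” is a fixed noising process rather than a learned network. Normalizing flows used as standalone generative models are typically trained by exact maximum likelihood instead.)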
A pattern I’ve found useful in practice: start with VI to prototype quickly. Get the model structure right, and check that the results are sensible. Then, for the final analysis, validate with MCMC on a representative subset. If the VI and MCMC posteriors agree, that’s good evidence the VI approximation is adequate on the full dataset. If they disagree and the MCMC diagnostics look healthy, trust the MCMC results.
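The agreement check can be sketched on a toy problem. Everything here is an assumption for illustration:

- The model is conjugate, so the Gaussian "VI" answer is available in closed form. In a real workflow it would come from your VI fit on the full data.
- Random-walk Metropolis-Hastings stands in for the MCMC method. In practice you would run NUTS on a representative subset.
- The tolerance rule for what counts as agreement is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy model and data (assumed): z ~ N(0, 1), x_i | z ~ N(z, 1).
x = rng.normal(loc=0.7, scale=1.0, size=50)
n = x.size

# Stand-in for the VI result: for this conjugate model a Gaussian q is exact,
# q(z) = N(sum(x) / (n + 1), 1 / (n + 1)).
vi_mean = x.sum() / (n + 1)
vi_sd = np.sqrt(1.0 / (n + 1))


def log_post(z):
    """Unnormalized log posterior: log prior + log likelihood."""
    return -0.5 * z**2 - 0.5 * np.sum((x - z) ** 2)


# Random-walk Metropolis-Hastings on the same posterior.
z, lp = 0.0, log_post(0.0)
draws = np.empty(50_000)
for t in range(draws.size):
    proposal = z + 0.3 * rng.standard_normal()
    lp_prop = log_post(proposal)
    if np.log(rng.uniform()) < lp_prop - lp:  # accept with probability min(1, ratio)
        z, lp = proposal, lp_prop
    draws[t] = z
draws = draws[5_000:]  # discard burn-in

mcmc_mean, mcmc_sd = draws.mean(), draws.std(ddof=1)
# Hypothetical agreement rule: means within 0.1 posterior sd, sds within 10%.
agree = abs(mcmc_mean - vi_mean) < 0.1 * vi_sd and abs(mcmc_sd / vi_sd - 1.0) < 0.1
print(f"VI:   mean {vi_mean:.3f}, sd {vi_sd:.3f}")
print(f"MCMC: mean {mcmc_mean:.3f}, sd {mcmc_sd:.3f}  -> agree: {agree}")
```

In a non-conjugate model, the interesting case is when this check fails. A mean-field fit that underestimates the spread shows up as an MCMC standard deviation noticeably larger than the VI one.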
My favorite thing about this whole landscape is that MCMC and VI aren’t really competitors — they’re complementary tools for different regimes. The field spent years arguing about which was better. The answer, like most things in engineering, is: it depends.
| Consideration | MCMC (NUTS) | Variational Inference |
|---|---|---|
| Speed | Slow (minutes to hours) | Fast (seconds to minutes) |
| Scalability | Small/medium data | Millions of data points |
| Accuracy | Asymptotically exact | Biased approximation |
| Multimodal posteriors | Can explore (slowly) | Tends to collapse to one mode |
| Uncertainty quality | Faithful (when converged) | Often underestimates |
| Gradient needed? | Yes (for HMC/NUTS) | Usually (reparameterization gradients); score-function estimators avoid model gradients |
| Best for | Publication, medical, safety-critical | Prototyping, production, deep generative models |
Wrap-Up
If you’re still with me, thank you. I hope it was worth it.
We started with a simple observation: joint probability tables are exponentially large, and graphical models compress them by encoding which variables directly influence which others. We learned two flavors, Bayesian networks for causal stories and Markov random fields for undirected correlations, and picked up plate notation and d-separation along the way. We saw that exact inference works beautifully on trees but becomes exponentially expensive on densely connected graphs. That pushed us into the two great families of approximation. MCMC explores the posterior by running a Markov chain whose long-run distribution is the posterior. Variational inference finds the closest member of a simple family by optimizing the ELBO. And we saw how amortized inference replaces the per-data-point VI optimization with a single neural network forward pass, which is how VAEs work under the hood.
My hope is that the next time you encounter a plate diagram, a discussion of d-separation, or a debate about MCMC versus VI, instead of that familiar urge to nod along and Google it later, you’ll have a pretty darn good mental model of what’s going on under the hood.
Resources and Credits
Koller & Friedman, Probabilistic Graphical Models (2009). The 1,200-page bible. Exhaustive, rigorous, and the definitive reference. Not a weekend read, but wildly comprehensive.
Bishop, Pattern Recognition and Machine Learning, Chapters 8 & 10. The most readable introduction to graphical models and variational inference in a textbook. Still the one I recommend first.
Blei, Kucukelbir & McAuliffe, “Variational Inference: A Review for Statisticians” (2017). The clearest exposition of modern VI I’ve found. Connects the classical ideas to stochastic VI and amortized inference.
Betancourt, “A Conceptual Introduction to Hamiltonian Monte Carlo” (2017). The single best explanation of HMC. Makes the physics intuition visceral. If you read one paper on MCMC, make it this one.
Kingma & Welling, “Auto-Encoding Variational Bayes” (2013). The O.G. VAE paper. Where amortized inference meets deep learning. Remarkably concise for how influential it’s been.
McElreath, Statistical Rethinking (2020). A warm, opinionated introduction to Bayesian modeling with extensive use of MCMC. The accompanying lectures on YouTube are unforgettable.