Training Stability

Chapter 7: Deep Learning Foundations
Gradient Flow · Normalization · Initialization · Residual Connections
Section Overview

Deep networks don't train by default — they vanish, explode, or stagnate. This section traces why that happens from first principles, then builds up the four pillars that make modern deep learning possible: proper initialization, normalization layers, residual connections, and gradient clipping. We'll follow a single toy network through each failure mode and its fix, so by the end you can diagnose a broken training run from the gradient norms alone.

The Core Problem: Why Deep Networks Break

I avoided staring at gradient flow for a long time. Training worked when I copied the right recipe, and I told myself that was enough. But the first time a 40-layer network produced NaN on batch 12 — and I had no idea which of the 40 layers was responsible — I realized I'd been driving without understanding the engine. Here is that stare.

The culprit is the chain rule. Backpropagation computes the gradient for each weight by multiplying local gradients from every layer between the loss and that weight. For a weight sitting in layer 1 of a 50-layer network, its gradient is a product of roughly 50 terms. Each term includes the derivative of that layer's activation function and the weights in that layer. If those terms are consistently a bit less than 1, the product collapses toward zero. If they're consistently a bit more than 1, the product explodes toward infinity.

That's really the entire story. Everything else in this section — initialization, normalization, residual connections, clipping — exists to keep those per-layer multipliers hovering near 1. Not shrinking, not growing. Stable.
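To put rough numbers on "a bit less" and "a bit more", here is a small sketch that multiplies one constant per-layer factor across 50 layers. The factors 0.9, 1.0, and 1.1 are illustrative choices, not measurements. In a real network every layer contributes a different factor.

```python
def gradient_scale(per_layer_factor, depth):
    """Multiply together `depth` copies of the same chain-rule factor."""
    scale = 1.0
    for _ in range(depth):
        scale *= per_layer_factor   # one factor per layer the gradient passes through
    return scale

for factor in [0.9, 1.0, 1.1]:
    print(f"{factor}^50 = {gradient_scale(factor, 50):.3g}")
# 0.9^50 = 0.00515
# 1.0^50 = 1
# 1.1^50 = 117
```

A 10% deviation per layer already means a roughly 200x shrink or a more than 100x blow-up over 50 layers.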

Let's make this concrete with a running example. Imagine we're training a network with 20 layers to classify handwritten digits — a small enough problem that we can watch every gradient by hand.

Vanishing Gradients

Suppose every layer in our 20-layer digit classifier uses the sigmoid activation. Sigmoid squashes its input to the range (0, 1), and its derivative is σ'(x) = σ(x) · (1 − σ(x)). That derivative hits a maximum of 0.25 when x = 0 — and that's the ceiling. Everywhere else it's smaller. Once a neuron saturates (output near 0 or 1), its local derivative drops toward zero.
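The ceiling is easy to check numerically. This sketch uses only Python's `math` module. The derivative is exactly 0.25 at x = 0 and falls off quickly as the neuron saturates.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)        # σ'(x) = σ(x) · (1 − σ(x))

for x in [0.0, 2.0, 5.0, 10.0]:
    print(f"x = {x:4.1f}  σ'(x) = {sigmoid_grad(x):.2e}")
# x =  0.0  σ'(x) = 2.50e-01
# x =  2.0  σ'(x) = 1.05e-01
# x =  5.0  σ'(x) = 6.65e-03
# x = 10.0  σ'(x) = 4.54e-05
```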

Now trace what happens during backpropagation. The gradient for layer 1's weights gets multiplied by the sigmoid derivative at every layer it passes through. In the best possible case — every neuron sitting at exactly zero input, every derivative at its maximum of 0.25 — the gradient shrinks by a factor of 0.25 at each layer.

# Best-case gradient magnitude with sigmoid across depth
for depth in [10, 20, 50]:
    grad = 0.25 ** depth
    print(f"After {depth:2d} layers: 0.25^{depth} = {grad:.2e}")

# After 10 layers: 0.25^10 = 9.54e-07
# After 20 layers: 0.25^20 = 9.09e-13
# After 50 layers: 0.25^50 = 7.89e-31

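After 10 layers the gradient is about one millionth of its original value. After 20, it's less than a trillionth. After 50, it's below 10⁻³⁰. Float32 can still represent that number, but it is far too small to move any weight. In float16 the product would already have underflowed to exactly zero before layer 20. And remember, 0.25 is the best case. Real neurons are often partly saturated, with derivatives well below 0.25, which makes the actual numbers worse.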
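Whether a vanishing gradient literally becomes zero depends on the floating-point format. This sketch assumes NumPy is available. It repeats the best-case factor of 0.25, rounds to the given dtype after every multiplication, and reports where the product underflows.

```python
import numpy as np

def repeated_product(dtype, factor=0.25, depth=50):
    """Multiply 1.0 by `factor` `depth` times, rounding to `dtype` at every step."""
    grad = dtype(1.0)
    for _ in range(depth):
        grad = grad * dtype(factor)      # result stays in `dtype`
    return grad

def first_zero_depth(dtype, factor=0.25, max_depth=200):
    """First depth at which the repeated product underflows to exactly 0."""
    for depth in range(1, max_depth + 1):
        if repeated_product(dtype, factor, depth) == 0:
            return depth
    return None

print("float32, 50 layers:", repeated_product(np.float32))            # about 7.9e-31, not zero
print("float16 hits 0 at depth", first_zero_depth(np.float16))         # 13
```

Float32 still holds 0.25⁵⁰ ≈ 7.9 × 10⁻³¹. Float16's smallest positive value is about 6 × 10⁻⁸, so it loses the gradient after 13 layers. Either way, the weight update is negligible long before the number is literally zero.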

I'll be honest — the first time I computed 0.25²⁰ and saw 10⁻¹³, I didn't believe it. Surely something else would compensate? But no. That's literally what happens. The early layers receive gradients so small that their weights barely move from their random starting values. The last few layers near the loss train fine — they see healthy gradients. The result is a network where the top floors are furnished and the foundation is still scaffolding. A penthouse on sand.

How You Spot It

The symptom is deceptive. Loss plateaus early, and it looks like the model has converged. But it hasn't learned anything useful — it's more like it gave up. The giveaway is checking gradient norms per layer. The last few layers show healthy values (maybe 10⁻² to 10⁰). The first few layers show values like 10⁻⁹ or smaller. That ratio — orders of magnitude of difference between early and late layers — is the fingerprint of vanishing gradients.

⚠️ The Deceptive Plateau

A flat loss curve doesn't always mean convergence. If you're training a deep network with sigmoid or tanh and the loss stops improving early, print the gradient norms for your first few layers before concluding the model has learned what it can. You might discover that the early layers never started learning at all.
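Here is a minimal sketch of that check, assuming NumPy. It runs a hand-written forward and backward pass through a plain 20-layer sigmoid network and returns the norm of ∂L/∂W for every layer. The setup is illustrative: width 64, Xavier-scaled Gaussian weights, a single random input, and a stand-in loss gradient of all ones.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def weight_grad_norms(depth=20, width=64, seed=0):
    """Backprop through a plain sigmoid MLP and return ||dL/dW|| for each layer."""
    rng = np.random.default_rng(seed)
    # Assumed setup: square layers, Xavier-scaled Gaussian weights, one random input
    weights = [rng.normal(0.0, np.sqrt(1.0 / width), size=(width, width))
               for _ in range(depth)]
    activations, pre_acts = [rng.normal(size=width)], []
    for W in weights:                          # forward pass
        pre_acts.append(W @ activations[-1])
        activations.append(sigmoid(pre_acts[-1]))

    grad_a = np.ones(width)                    # stand-in for dL/d(output)
    norms = [0.0] * depth
    for l in reversed(range(depth)):           # backward pass
        s = sigmoid(pre_acts[l])
        grad_z = grad_a * s * (1.0 - s)        # chain rule through σ
        norms[l] = np.linalg.norm(np.outer(grad_z, activations[l]))  # ||dL/dW_l||
        grad_a = weights[l].T @ grad_z         # gradient w.r.t. the input of layer l
    return norms

norms = weight_grad_norms()
for l in [0, 4, 9, 14, 19]:
    print(f"layer {l + 1:2d}: {norms[l]:.2e}")
```

The first layer's norm comes out many orders of magnitude below the last layer's. That gap is the fingerprint described above.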

Exploding Gradients

Same chain rule, opposite failure mode. If those per-layer multipliers are consistently greater than 1, the product doesn't collapse; it detonates. This happens when weight matrices amplify the vectors passing through them. That means large singular values, or, for a matrix applied repeatedly, a spectral radius above 1.

# Gradient amplification across layers
for factor in [1.5, 2.0, 3.0]:
    result = factor ** 50
    print(f"{factor}^50 = {result:.2e}")

# 1.5^50 = 6.38e+08
# 2.0^50 = 1.13e+15
# 3.0^50 = 7.18e+23

A per-layer multiplier of 1.5 is not even that far from 1. Yet after 50 layers it makes the first layer's gradient about 638 million times larger than the output gradient. Weight updates become enormous, weights jump to absurd values, and the next forward pass produces inf or NaN. Training is over, and it can happen in a single step.

Back to our digit classifier. Say we initialized the weights too large. Or consider an RNN, where the same weight matrix is applied at every time step. There, backpropagation through time is literally repeated multiplication by one matrix. If that matrix's spectral radius exceeds 1, gradients tend to grow exponentially with sequence length. A 200-token sentence means multiplying by the same matrix 200 times. The math is unforgiving.
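The RNN case can be sketched directly, assuming NumPy. To isolate the effect of the spectral radius, the hypothetical recurrent matrix here is a random orthogonal matrix scaled by `radius`. Every eigenvalue, and every singular value, then has magnitude exactly `radius`. A generic matrix behaves similarly in the long run, but with messier transients. Activation derivatives are ignored.

```python
import numpy as np

def backprop_through_time(radius, steps=200, size=32, seed=0):
    """Norm of a unit gradient after `steps` multiplications by the same W^T."""
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.normal(size=(size, size)))   # random orthogonal matrix
    W = radius * Q                                        # all |eigenvalues| == radius
    grad = rng.normal(size=size)
    grad /= np.linalg.norm(grad)                          # start with norm 1
    for _ in range(steps):
        grad = W.T @ grad                                 # same matrix at every time step
    return np.linalg.norm(grad)

for r in [0.95, 1.0, 1.05]:
    print(f"spectral radius {r}: norm after 200 steps = {backprop_through_time(r):.2e}")
```

A 5% deviation from 1 turns into a factor of roughly 10⁴ to 10⁵ in either direction over 200 steps.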

How You Spot It

Exploding gradients are much louder than vanishing ones. Your loss is decreasing normally — maybe even looking promising — then suddenly spikes to NaN or inf. No warning, no gradual degradation. One moment the model is learning, the next moment it's dead. If you log weight values, you'll see them balloon to astronomical magnitudes in the batch right before the crash.

💡 Quick Diagnostic

Loss suddenly jumps to NaN? Likely exploding gradients. Loss plateaus and early layers barely update? Likely vanishing gradients. The root cause is the same — the chain rule amplifying deviations from 1 — but they fail in very different ways and require different fixes.

Weight Initialization: The First Line of Defense

The most direct way to keep per-layer multipliers near 1 is to start with the right weights. Wrong initialization doesn't give your network a slow start — it gives it a death sentence.

The Symmetry Problem

Before we talk about scale, there's an even more basic failure. If you initialize all weights to zero, every neuron in a layer computes the same weighted sum, produces the same output, receives the same gradient, and gets the same update. After one training step, they're still identical. After a million steps, still identical. You've paid for 512 neurons and gotten the computational power of one.

This is the symmetry problem. Identical neurons have no mechanism to become different. Random initialization breaks that symmetry — each neuron starts with slightly different weights, specializes during training, and eventually learns a different feature. Randomness isn't a compromise here. It's a requirement.
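A small sketch makes this concrete. It assumes NumPy and a hypothetical 4-3-2 tanh network trained on one example with squared-error loss. The "symmetric" run starts every weight at the constant 0.5 rather than zero. With all-zero weights, this particular network gets exactly zero gradients and never moves at all. A nonzero constant lets it train, and shows that training alone never separates the hidden neurons.

```python
import numpy as np

def train_step(W1, W2, x, y, lr=0.1):
    """One SGD step on a 1-hidden-layer tanh network with loss 0.5 * ||W2 h - y||^2."""
    h = np.tanh(W1 @ x)                                   # hidden activations
    err = W2 @ h - y                                      # dL/d(output)
    grad_W2 = np.outer(err, h)
    grad_W1 = np.outer((W2.T @ err) * (1.0 - h ** 2), x)  # back through tanh
    return W1 - lr * grad_W1, W2 - lr * grad_W2

rng = np.random.default_rng(0)
x, y = rng.normal(size=4), rng.normal(size=2)
W1_same, W2_same = np.full((3, 4), 0.5), np.full((2, 3), 0.5)     # identical neurons
W1_rand, W2_rand = rng.normal(0, 0.5, (3, 4)), rng.normal(0, 0.5, (2, 3))

for _ in range(100):
    W1_same, W2_same = train_step(W1_same, W2_same, x, y)
    W1_rand, W2_rand = train_step(W1_rand, W2_rand, x, y)

# Are all hidden neurons (rows of W1) still identical after 100 steps?
print("constant init:", np.allclose(W1_same, W1_same[0]))   # True
print("random init:  ", np.allclose(W1_rand, W1_rand[0]))   # False
```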

Scale Is Everything

So we initialize randomly. But from what distribution? Suppose we draw weights from N(0, 10²), a normal distribution with standard deviation 10. Then activations explode: every layer multiplies their scale by a large factor. Well before layer 20 of our digit classifier the values are astronomically large, or, with sigmoid, every neuron is pinned at 0 or 1. If we draw from N(0, 0.001²), activations collapse instead. After a few layers every neuron outputs approximately the same value, and gradients vanish.

Let's trace this through our 20-layer network to see why. For a single neuron with n inputs, the output before the activation is z = w₁x₁ + w₂x₂ + ... + wₙxₙ. If the weights and inputs are independent with mean zero, the variance of z is n · Var(w) · Var(x). That factor of n is the key — it means the output variance depends on the layer width. If Var(w) is too large, variance grows at each layer. Too small, it shrinks. We need Var(w) calibrated so that the variance of activations stays roughly constant from layer to layer. Not growing, not shrinking. Stable.
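The n · Var(w) · Var(x) formula is easy to check empirically. This sketch assumes NumPy; n = 512 and the sample count are arbitrary. It draws zero-mean, unit-variance inputs and weights with n · Var(w) set to 0.5, 1, and 2, then measures how much one layer scales the variance.

```python
import numpy as np

def empirical_gain(scale, n=512, samples=2000, seed=0):
    """Measure Var(z) / Var(x) for z = w · x when n · Var(w) = scale."""
    rng = np.random.default_rng(seed)
    x = rng.normal(0.0, 1.0, size=(samples, n))          # zero-mean, unit-variance inputs
    w = rng.normal(0.0, np.sqrt(scale / n), size=n)      # Var(w) = scale / n
    z = x @ w                                            # one pre-activation per sample
    return z.var() / x.var()

for scale in [0.5, 1.0, 2.0]:
    print(f"n·Var(w) = {scale}:  Var(z)/Var(x) ≈ {empirical_gain(scale):.2f}")
```

Stack 20 such layers and those per-layer factors compound: 0.5²⁰ is about 10⁻⁶ and 2²⁰ about 10⁶. That is why the target is a gain of exactly 1.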

Xavier / Glorot Initialization (2010)

Xavier Glorot and Yoshua Bengio worked out the math for this. Their derivation assumes roughly linear activations (valid for sigmoid and tanh near zero, which is where properly initialized neurons start). To preserve variance in both the forward pass and the backward pass, they arrived at:

Var(W) = 2 / (fan_in + fan_out)

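Here fan_in is the number of inputs to the neuron and fan_out is the number of outputs. The forward pass wants fan_in · Var(W) = 1, and the backward pass wants fan_out · Var(W) = 1. Both can hold only when fan_in = fan_out. Glorot and Bengio compromise by averaging the two fans: ((fan_in + fan_out) / 2) · Var(W) = 1, which is where the 2 in the numerator comes from. In practice, you sample from a normal distribution with mean 0 and variance 2/(fan_in + fan_out). Alternatively, you sample from the uniform distribution U(−√(6/(fan_in + fan_out)), √(6/(fan_in + fan_out))), which has the same variance.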
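Here is a sketch of both Xavier variants, assuming NumPy and an arbitrary 784 → 256 layer. The uniform bound follows from the fact that U(−a, a) has variance a²/3.

```python
import numpy as np

def xavier_normal_std(fan_in, fan_out):
    """Standard deviation for Xavier normal: variance 2 / (fan_in + fan_out)."""
    return np.sqrt(2.0 / (fan_in + fan_out))

def xavier_uniform_bound(fan_in, fan_out):
    """Bound a for U(-a, a); its variance a**2 / 3 equals 2 / (fan_in + fan_out)."""
    return np.sqrt(6.0 / (fan_in + fan_out))

fan_in, fan_out = 784, 256
target = 2.0 / (fan_in + fan_out)
rng = np.random.default_rng(0)
a = xavier_uniform_bound(fan_in, fan_out)
W_uniform = rng.uniform(-a, a, size=(fan_out, fan_in))
W_normal = rng.normal(0.0, xavier_normal_std(fan_in, fan_out), size=(fan_out, fan_in))
print(f"target Var(W): {target:.6f}")
print(f"uniform:       {W_uniform.var():.6f}  (bound a = {a:.4f})")
print(f"normal:        {W_normal.var():.6f}")
```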

This works well for sigmoid and tanh — but it has a blind spot.

He / Kaiming Initialization (2015)

ReLU breaks the Xavier derivation. ReLU sets all negative inputs to zero, killing roughly half the activations in each layer and halving the signal's second moment. That means Xavier's variance is too small by a factor of 2. Activations shrink layer by layer, and by layer 20 of our digit classifier the signal variance has fallen by roughly 2²⁰, about a millionfold.

Kaiming He's fix: double the variance and use only fan_in.

Var(W) = 2 / fan_in

The factor of 2 compensates exactly for ReLU zeroing out the negative half of the distribution. With He initialization and ReLU, the variance of activations stays constant across depth. Our 20-layer digit classifier has signal reaching all the way from input to output, and gradients flowing all the way back.
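The factor of 2 is easy to see in simulation. This sketch assumes NumPy; width 512, batch 256, and unit-variance Gaussian inputs are arbitrary choices. It pushes a batch through 20 linear + ReLU layers with each initialization and reports the standard deviation of the final activations.

```python
import numpy as np

def xavier_var(fan_in, fan_out):
    return 2.0 / (fan_in + fan_out)

def he_var(fan_in, fan_out):
    return 2.0 / fan_in

def final_activation_std(weight_var, depth=20, width=512, batch=256, seed=0):
    """Push a batch through `depth` Linear+ReLU layers and return the output std."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(batch, width))                   # unit-variance inputs
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(weight_var(width, width)), size=(width, width))
        h = np.maximum(0.0, h @ W)                        # linear layer, then ReLU
    return h.std()

print(f"Xavier + ReLU, 20 layers: std = {final_activation_std(xavier_var):.1e}")
print(f"He + ReLU,     20 layers: std = {final_activation_std(he_var):.1e}")
```

With Xavier, the output scale shrinks by roughly 2¹⁰ (the variance by about 2²⁰). He keeps it close to where it started.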

💡 The Decision Rule
Activation | Initialization | Why
ReLU / Leaky ReLU | He (Kaiming) | Compensates for ReLU killing ~50% of inputs
Sigmoid / Tanh | Xavier (Glorot) | Assumes near-linear behavior around zero
GELU / SiLU | He (works in practice) | Shaped similarly to ReLU

In PyTorch

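PyTorch's nn.Linear and nn.Conv2d do call kaiming_uniform_ by default, but with a = √5. That works out to U(−1/√fan_in, 1/√fan_in), a variance of 1/(3 · fan_in), well below He's 2/fan_in. The default trains fine for many networks, especially ones with normalization layers. It is not He initialization, though, so for deep ReLU stacks it's worth setting the init explicitly: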

import torch.nn as nn

layer = nn.Linear(512, 256)  # stand-in layer for the examples below

# He / Kaiming: the default choice for ReLU networks
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

# Xavier / Glorot: for sigmoid or tanh networks
nn.init.xavier_normal_(layer.weight)

# Apply custom init across an entire model
def init_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Stand-in model: a small MLP for 28x28 digit images
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
model.apply(init_weights)

Here the biases are set to zero. That doesn't reintroduce the symmetry problem, because the weight matrices are already different. Each neuron's unique weights produce unique gradients regardless of shared bias values.

Normalization Layers: Keeping Activations in Check

Proper initialization sets a good starting point, but it's a one-time thing. As training progresses, the distribution of activations shifts — what was a nicely centered distribution in layer 7 at step 0 might be wildly different at step 10,000. Normalization layers fix this dynamically, re-centering and re-scaling activations at every forward pass.

Batch Normalization

Batch normalization was introduced by Ioffe and Szegedy in 2015, and it changed what was possible in deep learning almost overnight. The mechanics are deceptively straightforward. For each feature in a layer, batch norm computes the mean and variance across all examples in the current mini-batch, subtracts the mean, divides by the standard deviation (plus a tiny ε to avoid division by zero), then applies a learned scale (γ) and shift (β).

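import torch

# Batch norm forward pass (simplified, training mode)
x = torch.randn(32, 64)                        # shape [batch_size, features]
gamma, beta = torch.ones(64), torch.zeros(64)  # learned scale and shift (initial values)
mean = x.mean(dim=0)                           # mean per feature across batch
var  = x.var(dim=0, unbiased=False)            # biased variance, as batch norm uses
x_hat = (x - mean) / (var + 1e-5).sqrt()       # normalize
out = gamma * x_hat + beta                     # learned scale and shift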

The learned γ and β are crucial. Without them, the normalization would force every layer's activations to mean=0, variance=1 — stripping away representational power. γ and β let the network learn what the optimal scale and shift should be, while the normalization step ensures it starts from a stable baseline.

There's a subtlety that matters in production: during training, batch norm uses the current mini-batch statistics. During inference, it uses running averages accumulated during training (an exponential moving average of the means and variances seen during training). This means the model's behavior at inference time doesn't depend on what else is in the batch — a property that matters more than it sounds.
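The training/inference split can be sketched in a few lines of NumPy. `TinyBatchNorm` is a hypothetical minimal class, not a library API. Its momentum of 0.1 matches PyTorch's default for BatchNorm. The sketch skips details such as PyTorch's use of the unbiased variance for the running estimate.

```python
import numpy as np

class TinyBatchNorm:
    """Hypothetical minimal batch norm for inputs of shape [batch, features]."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)          # learned scale
        self.beta = np.zeros(num_features)          # learned shift
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum, self.eps = momentum, eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mean, var = x.mean(axis=0), x.var(axis=0)   # current batch statistics
            m = self.momentum                           # exponential moving average update
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
bn = TinyBatchNorm(num_features=4)
for _ in range(200):                        # "training" on batches drawn from N(3, 2²)
    bn(rng.normal(3.0, 2.0, size=(64, 4)))

bn.training = False                         # inference: frozen running statistics
batch = rng.normal(3.0, 2.0, size=(8, 4))
alone = bn(batch[:1])                       # first example on its own
together = bn(batch)[:1]                    # same example inside a batch of 8
print(np.round(bn.running_mean, 2))         # close to 3 for every feature
print(np.allclose(alone, together))         # True
```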

The original paper claimed batch norm works by reducing internal covariate shift — the shifting of layer input distributions as the weights of earlier layers change. I'll be honest: this explanation is contested. Later research suggests the real benefit is smoothing the loss landscape, making the optimization surface better-conditioned so that gradients are more predictable and larger learning rates are feasible. The debate isn't fully resolved, but what's not debatable is that batch norm works. It makes training faster, allows higher learning rates, and reduces sensitivity to initialization.

Where Batch Norm Breaks Down

Batch norm has a dependency that becomes a liability: it needs a batch. Training with a batch size of 1 happens with very large models that can only fit one example per GPU. There, the batch statistics are meaningless. The mean of a single sample is the sample itself and its variance is zero, so normalization erases the input entirely. Very small batches give noisy, unreliable statistics for the same reason. Batch norm also struggles with variable-length sequences in NLP: tokens at position 50 might exist in only a few sequences in the batch, making statistics unreliable.
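The batch-size-1 failure is easy to reproduce with a bare-bones training-mode batch norm in NumPy. This is a sketch, without the learned γ and β.

```python
import numpy as np

def batch_norm_train(x, eps=1e-5):
    """Training-mode batch norm over the batch axis, without gamma and beta."""
    mean = x.mean(axis=0)            # with one example, this is the example itself
    var = x.var(axis=0)              # ...and this is exactly zero
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
single = rng.normal(size=(1, 6))     # a batch of one example with 6 features
print(batch_norm_train(single))      # [[0. 0. 0. 0. 0. 0.]]: the input is erased
```

With γ and β added back, every example would map to β, whatever its input.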

This is what motivated a different approach.

Layer Normalization

Layer normalization flips the axis. Instead of computing statistics across the batch (across examples, per feature), it computes statistics across features (within a single example). Each sample is normalized independently — no dependency on what else is in the batch, no issues with batch size 1, no problems with variable-length sequences.

import torch

# Layer norm forward pass (simplified)
x = torch.randn(32, 64)                                # shape [batch_size, features]
gamma, beta = torch.ones(64), torch.zeros(64)          # learned per-feature scale and shift
mean = x.mean(dim=-1, keepdim=True)                    # mean across features, per sample
var  = x.var(dim=-1, keepdim=True, unbiased=False)     # variance across features, per sample
x_hat = (x - mean) / (var + 1e-5).sqrt()
out = gamma * x_hat + beta

This is why transformers almost universally use layer norm (or a close variant such as RMSNorm) instead of batch norm. In language models, sequences have variable lengths and per-device batches can be small. Training is also distributed across many GPUs, where synchronizing batch statistics is expensive. Layer norm sidesteps all of these issues.
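A minimal NumPy sketch of layer norm, without γ and β, shows the per-sample independence directly. Normalizing one example alone gives exactly the same result as normalizing it inside a larger batch.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each sample (row) over its own features; no learned scale/shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
batch = rng.normal(2.0, 3.0, size=(5, 8))     # 5 samples, 8 features
alone = layer_norm(batch[:1])                 # first sample as a batch of one
together = layer_norm(batch)[:1]              # same sample inside a batch of 5
print(np.allclose(alone, together))                            # True
print(np.abs(layer_norm(batch).mean(axis=-1)).max() < 1e-9)    # every row centered: True
```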

Pre-Norm vs. Post-Norm: Where You Put It Matters

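The original transformer (2017) placed layer norm after the residual connection, a layout known as Post-LayerNorm: x' = LayerNorm(x + Sublayer(x)). This turns out to be hard to train for deep models. Every residual sum passes through a LayerNorm, so there is no clean identity path from the loss back to the early layers. At initialization, the gradients for parameters near the output also tend to be large. That is why Post-LN models typically need careful learning-rate warmup to avoid diverging.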

Pre-LayerNorm places normalization before the sublayer: x' = x + Sublayer(LayerNorm(x)). This is dramatically more stable. The residual path stays clean (no normalization disrupting it), and the sublayer always receives well-behaved input. GPT-3, PaLM, and most modern large language models use this pre-normalization layout. Llama uses it with RMSNorm in place of LayerNorm.
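The two layouts differ only in where the normalization sits. In this NumPy sketch, the sublayer is a hypothetical stand-in that has learned to contribute nothing. The pre-norm block then passes its input through unchanged, while the post-norm block still re-normalizes it.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def post_ln_block(x, sublayer):
    return layer_norm(x + sublayer(x))     # Post-LN: normalize after the residual sum

def pre_ln_block(x, sublayer):
    return x + sublayer(layer_norm(x))     # Pre-LN: normalize only the sublayer's input

def silent_sublayer(h):
    """Hypothetical sublayer that has learned to contribute nothing."""
    return np.zeros_like(h)

rng = np.random.default_rng(0)
x = rng.normal(5.0, 4.0, size=(2, 8))
print(np.allclose(pre_ln_block(x, silent_sublayer), x))    # True: residual path untouched
print(np.allclose(post_ln_block(x, silent_sublayer), x))   # False: x is re-normalized anyway
```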

An even newer variant, RMSNorm, drops the mean subtraction entirely and normalizes using only the root mean square of the activations. It's computationally cheaper and empirically performs about as well. Llama, Llama 2, and many other recent LLMs use RMSNorm. That is a sign that the mean centering in standard layer norm may not be pulling its weight.
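RMSNorm fits in a few lines. This is a NumPy sketch. The eps of 1e-6 is an assumption (implementations pick different values), and γ is optional here.

```python
import numpy as np

def rms_norm(x, gamma=None, eps=1e-6):
    """Divide each row by its root mean square; unlike layer norm, no mean subtraction."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    out = x / rms
    return out if gamma is None else gamma * out

rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(4, 16))
y = rms_norm(x)
print(np.sqrt(np.mean(y ** 2, axis=-1)))   # ≈ 1 for every row
print(y.mean(axis=-1))                     # clearly nonzero: the mean is kept
```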

💡 Which Normalization When
Architecture | Normalization | Why
CNNs | Batch Norm | Fixed spatial dims, large batches, well-tested
Transformers / LLMs | Layer Norm (Pre-LN) or RMSNorm | Variable lengths, batch-size independent
Small-batch training | Layer Norm or Group Norm | Batch stats unreliable with few samples

Residual Connections: The Gradient Highway

Proper initialization and normalization keep signals healthy in theory, but there's a deeper problem. In 2015, Kaiming He's team observed something paradoxical: a 56-layer plain CNN performed worse than a 20-layer one — not on the test set (that would be overfitting), but on the training set. Adding more layers made optimization harder, not easier. They called this the degradation problem.

Think about it this way. A 56-layer network should be at least as good as a 20-layer one. In the worst case, the extra 36 layers could learn to be identity functions (pass the input through unchanged), and you'd get the same performance. But in practice, the optimizer couldn't find those identity mappings. Learning "do nothing" turns out to be surprisingly difficult for a neural network.

The fix was elegant. Instead of asking a block of layers to learn the desired output directly — y = F(x) — they restructured it as y = F(x) + x. The block only needs to learn the residual: the difference between the desired output and the input. If the ideal behavior is the identity function, F(x) needs to be zero — and pushing weights toward zero is something optimizers are very good at.

The Gradient Perspective

The gradient story is where residual connections really shine. Without the skip connection, the gradient flowing through a block is ∂F(x)/∂x — it depends entirely on whatever F learned, and if F is poorly behaved, the gradient can vanish or explode. With the skip connection, the gradient becomes ∂F(x)/∂x + 1. That "+ 1" is the gradient through the identity path (∂x/∂x = 1), and it's always there regardless of what F is doing.

# Without residual: gradient depends entirely on F
# ∂y/∂x = ∂F(x)/∂x                    → can vanish

# With residual: y = F(x) + x
# ∂y/∂x = ∂F(x)/∂x + 1               → the +1 is always there
# The +1 creates a "gradient highway" — information flows back
# even if the learned block has vanishing gradients

This is why ResNets can train 152 layers (and eventually over 1000 with proper normalization). Plain networks of the same design already get worse going from 20 to 56 layers. The skip connections create a highway for gradients to flow backward without being multiplied through dozens of nonlinear transformations. Nearly every modern deep architecture, from ResNets to transformers to U-Nets, relies on some form of skip connection.
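Here is a scalar caricature of the gradient highway, assuming NumPy. Each of 50 hypothetical blocks has a small local derivative ∂F/∂x drawn uniformly from [−0.1, 0.1]. In a real network these are Jacobian matrices and the +1 becomes the identity matrix, but the scalar case shows the effect.

```python
import numpy as np

rng = np.random.default_rng(0)
depth = 50
local = rng.uniform(-0.1, 0.1, size=depth)   # hypothetical dF/dx for each block

plain = np.prod(local)            # y = F(x):      multiply dF/dx at every block
residual = np.prod(1.0 + local)   # y = F(x) + x:  multiply (dF/dx + 1) at every block

print(f"plain:    {plain:.2e}")    # astronomically small
print(f"residual: {residual:.2f}")  # order 1
```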

Coming back to our digit classifier: with residual connections, even if the sigmoid activations in the middle layers have vanishing derivatives, the gradient can bypass them entirely through the skip paths. The early layers receive learning signal. The network actually trains.

Gradient Clipping: The Safety Net

Initialization, normalization, and residual connections are preventive measures — they make gradient pathologies less likely. Gradient clipping is reactive — it catches explosions that slip through anyway. Think of it as a circuit breaker: you'd rather fix the wiring, but you definitely want the breaker in case something goes wrong.

Norm-Based Clipping

The most common approach computes the total L2 norm of all gradients in the model. If it exceeds a threshold (max_norm), all gradients are scaled down proportionally so the norm equals max_norm. The direction of the gradient is preserved — only the magnitude is capped.

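import torch
model, x = torch.nn.Linear(10, 1), torch.randn(8, 10)             # stand-in model and batch
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(x).pow(2).mean().backward()                                 # loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # after backward(), before step()
optimizer.step()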

A max_norm of 1.0 is the standard starting point for RNNs and transformers. For deep CNNs, values between 0.5 and 5.0 are common. The right value depends on your architecture — if clipping activates on every single training step, your threshold is probably too aggressive and you're artificially limiting learning. If it never activates, it's not doing anything (which might be fine — it's a safety net, not a tuning knob).
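Norm-based clipping is only a few lines of arithmetic. This NumPy sketch mirrors the idea behind torch.nn.utils.clip_grad_norm_ but is not its implementation. The small epsilon in the denominator is an assumption that avoids division by zero.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))   # 1.0 means "leave unchanged"
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 4.0]), np.array([[12.0]])]    # global norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm_before, round(float(norm_after), 4))       # 13.0 1.0
print(clipped[0] / clipped[0][0])                      # ratio 4/3 preserved: same direction
```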

Value-Based Clipping

An alternative is to clamp each individual gradient element to a range, say [-1, 1]. This is simpler but cruder — it can distort the gradient direction because different elements get clipped by different amounts. Norm-based clipping is preferred in nearly every modern setting.
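To see the direction distortion, compare both methods on a hypothetical gradient with one large component. This NumPy sketch measures each result's cosine similarity to the original.

```python
import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

g = np.array([10.0, 0.5, -0.2])                   # one large component, two small ones
by_value = np.clip(g, -1.0, 1.0)                  # clamp each element to [-1, 1]
by_norm = g * min(1.0, 1.0 / np.linalg.norm(g))   # rescale the whole vector to norm 1

print(f"value clipping: cosine = {cosine(g, by_value):.3f}")   # 0.905
print(f"norm clipping:  cosine = {cosine(g, by_norm):.3f}")    # 1.000
```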

Adaptive Gradient Clipping (AGC)

Adaptive gradient clipping (AGC) is a more recent idea from the NFNets paper (Brock et al., 2021). Instead of a fixed global threshold, it compares each layer's gradient norm to its parameter norm (the paper does this unit-wise, one row of the weight matrix at a time). The intuition is that a layer with large weights can tolerate larger gradient updates than one with small weights. AGC clips only when the ratio exceeds a factor λ (typically 0.01–0.05). This was critical for training very deep normalizer-free networks, where batch norm wasn't available to keep things stable.
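Here is a layer-wise AGC sketch in NumPy, following the description above. The eps floor on the parameter norm (1e-3 here) is an assumption. It stops zero-initialized weights from receiving zero-size updates. The paper applies the same rule unit-wise, row by row, which this sketch simplifies to whole tensors.

```python
import numpy as np

def adaptive_clip(grad, param, lam=0.01, eps=1e-3):
    """Layer-wise AGC sketch: cap ||grad|| at lam * max(||param||, eps)."""
    max_grad_norm = lam * max(np.linalg.norm(param), eps)
    grad_norm = np.linalg.norm(grad)
    if grad_norm > max_grad_norm:
        grad = grad * (max_grad_norm / grad_norm)   # same direction, smaller magnitude
    return grad

rng = np.random.default_rng(0)
large_w = rng.normal(0.0, 1.0, size=(64, 64))    # ||W|| ≈ 64
small_w = rng.normal(0.0, 0.01, size=(64, 64))   # ||W|| ≈ 0.64
g = rng.normal(0.0, 0.001, size=(64, 64))        # ||g|| ≈ 0.064, same gradient for both

for name, w in [("large weights", large_w), ("small weights", small_w)]:
    clipped = adaptive_clip(g, w)
    print(f"{name}: ||g|| {np.linalg.norm(g):.4f} -> {np.linalg.norm(clipped):.4f}")
```

The same gradient passes untouched through the layer with large weights. For the layer with small weights it is scaled down about tenfold.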

Dead Neurons: The ReLU Tradeoff

ReLU solved vanishing gradients for active neurons — but it introduced a new failure mode. If a neuron's input becomes permanently negative (due to a large weight update or unfortunate initialization), ReLU outputs zero, the gradient is zero, and the neuron never recovers. It's dead. In a large network, you can lose 10–20% of neurons this way without noticing, because the remaining neurons compensate. But in smaller networks, or with aggressive learning rates, dead neurons can cripple capacity.

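Detection is straightforward. After a forward pass on a batch, check what fraction of activations are exactly zero, and which neurons output zero for every example in the batch. About 50% zeros is expected from ReLU's negative half. A fraction consistently well above that, or neurons that are zero across the whole batch, means you have dying neurons.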
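Here is a sketch of that check in NumPy. The setup is hypothetical: a 32 → 100 ReLU layer with He-scaled weights. Twenty of its neurons have been given a bias of −50 to simulate neurons knocked permanently negative by a bad update.

```python
import numpy as np

def relu_stats(pre_activations):
    """pre_activations: [batch, neurons]. Returns (zero-activation fraction, dead-neuron fraction)."""
    acts = np.maximum(0.0, pre_activations)
    zero_fraction = np.mean(acts == 0)
    dead = np.all(acts == 0, axis=0)          # zero for every example in the batch
    return zero_fraction, dead.mean()

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 32))
W = rng.normal(0.0, np.sqrt(2 / 32), size=(32, 100))
b = np.zeros(100)
b[:20] = -50.0                                # 20 neurons pushed far negative
zero_frac, dead_frac = relu_stats(x @ W + b)
print(f"zero activations: {zero_frac:.0%}, dead neurons: {dead_frac:.0%}")
```

Roughly 60% of activations are zero here: the 20 dead neurons plus about half of the healthy 80.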

Leaky ReLU fixes this by outputting a small fraction of the input (typically 0.01·x) for negative inputs instead of zero. The neuron still has a gradient when inactive, so it can recover. GELU, used in BERT, GPT-2, and many other transformers, takes a smoother approach. It transitions gradually between linear and near-zero rather than having a hard cutoff. Neither is universally better than ReLU, but both keep a nonzero gradient for negative inputs, which largely eliminates the dead neuron problem.
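The difference shows up in the derivatives. This sketch uses only Python's `math` module and the exact GELU, x · Φ(x), where Φ is the standard normal CDF.

```python
import math

def relu_grad(x):
    return 1.0 if x > 0 else 0.0

def leaky_relu_grad(x, slope=0.01):
    return 1.0 if x > 0 else slope

def gelu_grad(x):
    # GELU(x) = x * Φ(x), so GELU'(x) = Φ(x) + x * φ(x)
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

for x in [-3.0, -1.0, 2.0]:
    print(f"x={x:+.1f}  relu'={relu_grad(x):.3f}  "
          f"leaky'={leaky_relu_grad(x):.3f}  gelu'={gelu_grad(x):+.4f}")
```

GELU's gradient at negative inputs is small and can even be slightly negative, but it is not zero. Far into the negative tail it does become tiny.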

Putting It All Together: The Stability Diagnostic

The difference between someone who can debug a broken training run and someone who randomly tries hyperparameters is knowing what to measure. Here's the function I keep in every project:

def gradient_health_check(model):
    """Print per-layer gradient norms after a backward pass."""
    norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            norm = param.grad.norm().item()
            flag = ""
            if norm < 1e-7:
                flag = " ← vanishing!"
            elif norm > 1e3:
                flag = " ← exploding!"
            norms[name] = norm
            print(f"{name:<40s} {norm:>12.2e}{flag}")

    if norms:
        ratio = max(norms.values()) / (min(norms.values()) + 1e-12)
        print(f"\nMax/min gradient ratio: {ratio:.1f}x")
        if ratio > 100:
            print("⚠ Gradient flow problem: >100x difference across layers")

The ratio between the largest and smallest gradient norms tells the whole story. In a healthy network, this ratio stays within 1–2 orders of magnitude. If it exceeds 100×, you have a gradient flow problem, and the fix depends on the direction: vanishing means check your activation functions and add residual connections; exploding means check your initialization and add gradient clipping.

⚠️ The Connection Triangle

Activation functions, weight initialization, and gradient behavior are three vertices of a single triangle. They're deeply coupled:

Sigmoid + large random init = saturated neurons = vanishing gradients. Guaranteed.
ReLU + He init = stable variance = stable gradients. Reliable.
ReLU + large random init = exploding activations = NaN within a few batches.
Any activation + BatchNorm + residual connections = dramatically reduced sensitivity to everything above.

When training fails, check all three vertices. The answer almost always points directly to the fix.

Learning Rate Warmup

One more piece that I initially dismissed but now consider essential for large models: learning rate warmup. At the very start of training, all weights are random and gradients can be noisy and large. Jumping in with your full learning rate is like flooring the accelerator on ice — you might be fine, or you might spin out immediately. Warmup ramps the learning rate linearly from near-zero to its target value over the first few hundred or thousand steps. This gives the optimizer time to find a reasonable region of the loss landscape before taking aggressive steps. Nearly every large transformer uses warmup. It's cheap insurance against early-training instability.
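A linear warmup schedule is a one-line function. This sketch holds the rate constant after warmup. The values `base_lr` = 3e-4 and `warmup_steps` = 1000 are illustrative, and real schedules usually decay afterward (for example with a cosine schedule).

```python
def warmup_lr(step, base_lr=3e-4, warmup_steps=1000):
    """Linear warmup from base_lr / warmup_steps up to base_lr, then constant."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr

for step in [0, 249, 499, 999, 5000]:
    print(f"step {step:5d}: lr = {warmup_lr(step):.2e}")
# step     0: lr = 3.00e-07
# step   249: lr = 7.50e-05
# step   499: lr = 1.50e-04
# step   999: lr = 3.00e-04
# step  5000: lr = 3.00e-04
```

In PyTorch, the same shape can be expressed with torch.optim.lr_scheduler.LambdaLR. Its lr_lambda returns a multiplier on the base rate, i.e. warmup_lr(step) / base_lr.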

The Modern Recipe

He initialization + ReLU/GELU activations + normalization layers + residual connections + gradient clipping + learning rate warmup. That's the stack that makes modern deep learning work. Each piece addresses a specific failure mode. Initialization prevents signal collapse at step 0. Normalization prevents drift during training. Residual connections give gradients a direct path at any depth. Clipping catches the remaining edge cases. Warmup smooths the critical early phase.

None of these were part of the original neural network recipe. Each was invented because someone's training run broke, they figured out why, and they engineered a fix. The stack we have now is the product of 30+ years of people watching networks fail and asking "why."

If you're still with me after all that, thank you. We started from the chain rule and a broken 20-layer sigmoid network, and built up the complete toolkit for making deep networks train. My hope is that the next time you see NaN in your loss, or a plateau that doesn't make sense, you won't reach for random hyperparameter changes. You'll check the gradient norms, look at the activation–initialization–gradient triangle, and know exactly where to intervene.