Training Debugging & Stability

Chapter 8, Section 5 · ~30 min read · Production-Critical · Interview Favorite
TL;DR

Training a deep network is 10% architecture design and 90% debugging why the loss is doing something insane. This section covers the systematic toolkit for diagnosing and fixing training failures:

- the loss landscape geometry that explains why optimization is hard
- the single-batch overfit test, which catches most bugs before they waste a run
- a structured debugging flow from data to gradients
- NaN warfare
- reading loss curves like a cardiologist reads EKGs
- reproducibility rituals and monitoring infrastructure
- the silent bugs that produce no errors but quietly destroy your model
- gradient clipping mechanics and weight initialization theory

We use a running example, a sentiment classifier on movie reviews, and break it in every way possible so we know how to fix it.

The Landscape You're Navigating

Before we talk about debugging, we need to understand the terrain. When we train a neural network, we're searching for a good set of parameters — weights and biases — that minimize a loss function. That loss function defines a surface over parameter space. For a network with n parameters, this surface lives in n+1 dimensions: n dimensions for the parameters, plus one for the loss value. Our sentiment classifier might have 5 million parameters. That means we're navigating a surface in 5,000,001-dimensional space. I don't know about you, but my spatial intuition breaks down somewhere around dimension 4.

Here's the thing that tripped me up for a long time: I kept imagining loss landscapes as bumpy 2D surfaces with valleys and peaks, like mountain terrain. The intuitions from that picture are mostly wrong for high-dimensional spaces. In two dimensions, a saddle point is a rare curiosity — a mountain pass. In millions of dimensions, saddle points are everywhere. They vastly outnumber local minima.

Saddle Points Dominate (And That's Good News)

Dauphin et al. (2014) formalized what many practitioners had suspected: in high-dimensional non-convex optimization, the critical points you encounter are overwhelmingly saddle points, not local minima. This holds especially for critical points with high loss. A critical point is any point where the gradient is zero. At a local minimum, the loss curves upward in every direction. At a saddle point, the loss curves upward in some directions and downward in others.

Think about it probabilistically. As a rough heuristic, suppose that at a random critical point in n-dimensional space, each direction independently has a 50% chance of curving up or down. For that point to be a true local minimum, all n directions must curve upward, which happens with probability roughly 2^(-n). For our 5-million-parameter network, that's an astronomically small number. So the critical points we encounter almost always have escape routes: directions where the loss decreases.
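Here is a small NumPy sketch of the counting argument. It is an illustration only, not the analysis from Dauphin et al. It samples random symmetric matrices as stand-ins for the Hessian at a critical point. That stand-in is an assumption, since real Hessians are not random. It then counts how often all eigenvalues are positive, which is how often the point would be a local minimum. The eigenvalues of a random matrix are not independent, so for n ≥ 2 this fraction drops even faster than the coin-flip estimate 2^(-n).

```python
import numpy as np

def fraction_local_minima(n, trials=20000, seed=0):
    """Fraction of random symmetric n x n 'Hessians' whose eigenvalues are all positive."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((trials, n, n))
    hessians = (a + np.transpose(a, (0, 2, 1))) / 2  # symmetrize each matrix
    eigvals = np.linalg.eigvalsh(hessians)            # batched, shape (trials, n)
    return float(np.mean(np.all(eigvals > 0, axis=1)))

for n in range(1, 7):
    print(f"n={n}: fraction that are minima={fraction_local_minima(n):.4f}, "
          f"coin-flip estimate 2^-n={2.0 ** -n:.4f}")
```

By six dimensions, minima are already a tiny fraction of the sampled critical points. Extrapolate to millions of dimensions and the saddle-point picture follows.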

This is genuinely good news. It means SGD doesn't get "trapped" in local minima the way we feared in the early days of neural networks. The real challenge at saddle points is that gradients are small near them, so training slows down. Momentum-based optimizers like Adam help tremendously here because they carry velocity through the low-gradient saddle region.

Sharp vs Flat Minima: Not All Solutions Are Equal

Not all low-loss regions generalize equally well. Keskar et al. (2017) showed that sharp minima — regions where the loss changes rapidly when you perturb the weights — tend to generalize poorly. Flat minima — regions where you can wiggle the weights substantially without the loss changing much — tend to generalize well. The intuition is elegant: a sharp minimum is "fragile." The training data happens to align in a very specific way at that point, but any slight distribution shift (like moving from training to test data) throws you out of the valley. A flat minimum is "robust." Small perturbations don't matter, which means the solution captures something more fundamental about the data.
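One crude way to put a number on "sharp vs flat" is to add small random noise to the weights and measure how much the loss rises. This is a simplified proxy of our own. It is not the sharpness metric from Keskar et al., which maximizes the loss over a small box around the weights. The helper name `sharpness_probe`, `sigma`, and `num_samples` are illustrative choices, and the sketch assumes a model on the CPU.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def sharpness_probe(model, loss_fn, inputs, labels, sigma=0.01, num_samples=10, seed=0):
    """Average loss increase when every weight gets N(0, sigma^2) noise (rough flatness proxy).

    Noise is applied to deep copies, so `model` itself is left untouched. Assumes a CPU model.
    """
    gen = torch.Generator().manual_seed(seed)
    base_loss = loss_fn(model(inputs), labels).item()
    increases = []
    for _ in range(num_samples):
        perturbed = copy.deepcopy(model)
        for p in perturbed.parameters():
            p.add_(torch.randn(p.shape, generator=gen) * sigma)
        increases.append(loss_fn(perturbed(inputs), labels).item() - base_loss)
    return base_loss, sum(increases) / len(increases)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 32), nn.Tanh(), nn.Linear(32, 2))
inputs, labels = torch.randn(64, 20), torch.randint(0, 2, (64,))
base, rise = sharpness_probe(model, nn.functional.cross_entropy, inputs, labels, sigma=0.05)
print(f"base loss={base:.4f}, mean loss increase under noise={rise:+.4f}")
```

Take two checkpoints with similar training loss and probe them at the same `sigma`. The one whose loss rises more is sitting in the sharper region.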

This connects to a beautiful observation about batch size. Small-batch SGD introduces noise into the gradient estimates, and that noise acts as an implicit regularizer. It pushes the optimizer away from sharp minima, because the noisy gradient overshoots and escapes them. It lets the optimizer settle in flat minima, because the noise can't push it out. Large-batch training tends to converge to sharper minima: the gradient estimates are more precise, so the optimizer settles into whatever narrow valley it finds. This is one reason why naively scaling up batch size often hurts generalization. You lose the regularizing effect of gradient noise.

Skip Connections Reshape the Landscape

Li et al. (2018) produced some of the most visually stunning results in deep learning by visualizing loss landscapes. They showed that a plain deep network (like VGG) has an incredibly chaotic loss surface — riddled with sharp peaks, narrow valleys, and large flat regions where the gradient gives you no useful signal. But add skip connections (as in ResNet), and the landscape becomes dramatically smoother. Almost convex-looking, with broad smooth valleys leading to good minima.

This is why ResNets are so much easier to train than VGGs of comparable depth. It's not that skip connections give you a "better" architecture in some abstract sense — they literally make the optimization problem easier by reshaping the landscape. When you're debugging a deep network that won't train, adding skip connections isn't a hack. It's a fundamental improvement to the loss surface geometry.
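A quick way to feel the difference is to compare how much gradient reaches the first layer of a deep stack, with and without identity shortcuts. The sketch below is a toy stand-in for the plain-vs-ResNet comparison, not the Li et al. experiment. It uses a 30-block stack of Linear + tanh layers with PyTorch's default initialization. The only change between the two variants is whether each block returns `out` or `x + out`. The depth, the width, and the mean-squared-output loss are arbitrary choices.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Linear + tanh, optionally wrapped with an identity shortcut."""
    def __init__(self, dim, residual):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.residual = residual

    def forward(self, x):
        out = torch.tanh(self.linear(x))
        return x + out if self.residual else out

def first_layer_grad_norm(residual, depth=30, dim=64, seed=0):
    torch.manual_seed(seed)                       # same init for both variants
    net = nn.Sequential(*[Block(dim, residual) for _ in range(depth)])
    x = torch.randn(128, dim)
    net(x).pow(2).mean().backward()               # arbitrary scalar loss
    return net[0].linear.weight.grad.norm().item()

plain = first_layer_grad_norm(residual=False)
resnet = first_layer_grad_norm(residual=True)
print(f"first-layer grad norm  plain: {plain:.2e}  residual: {resnet:.2e}")
```

In the plain stack, the backward signal is multiplied by a contracting Jacobian at every block, so its first-layer gradient comes out orders of magnitude smaller. The identity path gives the gradient a route back that skips those multiplications.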

Mode Connectivity

Here's something that still blows my mind. Garipov et al. (2018) showed that independently trained neural networks — same architecture, same data, different random seeds — converge to solutions that can be connected by simple low-loss paths in parameter space. Train model A, train model B from a different random initialization, and you can find a smooth curve through parameter space connecting A and B along which the loss stays low. This means the set of good solutions isn't a bunch of isolated pockets scattered across parameter space. It's more like a connected manifold. The implications for ensemble methods and model averaging are profound, but for debugging, the takeaway is reassuring: there are many good solutions, and the optimizer can find them from many starting points.

SAM: Explicitly Seeking Flat Minima

If flat minima generalize better, why not directly optimize for flatness? That's exactly what Sharpness-Aware Minimization (SAM), proposed by Foret et al. (2021), does. Instead of minimizing the loss at the current point, SAM minimizes the worst-case loss within a small radius rho around the current point. Each update has two steps:

1. Compute the gradient at the current point and move a distance rho along the normalized gradient. This is a first-order approximation of the worst neighbor.
2. Compute the gradient at that perturbed point, restore the original weights, and let a base optimizer step using that gradient.

The effect is that SAM avoids parameters where the loss is low at the exact point but high nearby, which is the definition of a sharp minimum.

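import torch

# SAM optimizer: simplified core logic
# Standard optimizers minimize L(w). SAM minimizes max_{||e||<=rho} L(w + e),
# approximating the inner max with a single normalized gradient-ascent step.

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # The base optimizer (e.g. torch.optim.SGD) shares SAM's param groups
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Save current weights, then perturb toward steepest ascent
                self.state[p]["old_p"] = p.detach().clone()
                p.add_(p.grad * scale)  # climb to the (approximate) worst neighbor

    @torch.no_grad()
    def second_step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.copy_(self.state[p]["old_p"])  # restore original weights
        self.base_optimizer.step()  # step using the gradient computed at the perturbed point

    def _grad_norm(self):
        # Global L2 norm over all parameter gradients
        norms = [
            p.grad.norm(p=2)
            for group in self.param_groups
            for p in group["params"] if p.grad is not None
        ]
        return torch.norm(torch.stack(norms), p=2)

# Usage: two forward-backward passes per training step
# optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)
# criterion(model(inputs), labels).backward()
# optimizer.first_step(); optimizer.zero_grad()
# criterion(model(inputs), labels).backward()
# optimizer.second_step(); optimizer.zero_grad()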

The cost is roughly 2x that of standard training, because each step needs two forward-backward passes. In return, SAM often improves generalization, especially on tasks where you're close to the edge of overfitting. For our sentiment classifier, SAM might buy us 0.5-1% accuracy on the test set. That's not a revolution, but it's meaningful in competitive settings.

Learning Rate Warmup: Why Starting High Is Dangerous

At initialization, our weights are random. The loss landscape at a random point tends to be chaotic: gradients can be large and point in essentially arbitrary directions. If we start with a high learning rate, those large, noisy gradients catapult us far from our starting point into regions of the landscape that may be even worse. The loss spikes, gradients grow even larger, and we enter a death spiral that often ends in NaN.

Learning rate warmup starts with a very small learning rate and gradually increases it over the first few hundred or thousand steps. This gives the network time to move into a more structured region of the landscape, where the gradients carry actual signal, before we start taking large steps. With Adam there's a second reason. Its per-parameter step sizes depend on running estimates of the gradient's second moment, and in the first steps those estimates are built from very few samples. Warmup is especially critical for transformers (particularly the original post-LayerNorm variant) and for deep architectures without skip connections.
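A linear warmup is easy to add to any PyTorch optimizer with `LambdaLR`, which multiplies the base learning rate by whatever factor the lambda returns for the current step. This is a minimal sketch. The 1,000-step warmup length, the linear shape, and the helper name `linear_warmup` are illustrative assumptions. In practice, warmup is usually followed by a decay schedule.

```python
import torch

def linear_warmup(step, warmup_steps=1000):
    """LR multiplier: ramps linearly from 1/warmup_steps up to 1.0, then stays at 1.0."""
    return min(1.0, (step + 1) / warmup_steps)

model = torch.nn.Linear(10, 2)                  # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup)

lrs = []
for step in range(1500):
    lrs.append(optimizer.param_groups[0]["lr"])   # LR used at this step
    # ... forward pass and loss.backward() would go here in real training ...
    optimizer.step()
    scheduler.step()

print(lrs[0], lrs[499], lrs[999], lrs[1499])    # ~1e-06, 5e-04, 1e-03, 1e-03
```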

💡 Key Insight

The loss landscape isn't something you're stuck with. Architecture choices (skip connections), optimizer choices (SAM, momentum), and hyperparameter choices (small batch, warmup) all reshape the effective landscape you're navigating. Debugging a network that won't train is often about making the landscape easier to navigate, not about finding a cleverer way to search a terrible landscape.

The Single-Batch Overfit Test: Your First Debugging Move

Andrej Karpathy's "A Recipe for Training Neural Networks" lists overfitting a single batch among the first sanity checks to run. I'd go further: it's the equivalent of a segfault check in systems programming. Before you tune a single hyperparameter, and before you worry about regularization or data augmentation, you need to verify that your model can memorize a single batch of data. If it can't, you have a fundamental bug. It's not a hyperparameter problem and not a capacity problem; it's a bug.

The test is dead simple. Grab one batch from your training set. Feed it to the model repeatedly for 500-1000 steps with a reasonable learning rate (1e-3 for Adam is usually fine). The loss should decrease steadily and approach zero (or near-zero, depending on the task). For our sentiment classifier with 2 classes, the initial loss should be around -log(1/2) = 0.693 (random guessing). After 500 steps on a single batch of 32 samples, it should be below 0.01. If it's not, stop everything. Don't touch the learning rate. Don't add more layers. Fix the bug.

import torch
import torch.nn as nn

# Our baby sentiment classifier
class SentimentModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        embeds = self.embedding(x)
        _, (hidden, _) = self.lstm(embeds)
        return self.classifier(hidden.squeeze(0))

# The overfit test
model = SentimentModel(vocab_size=10000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

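# Grab ONE batch and freeze it.
# Assumes train_loader yields (token_ids, labels):
# token_ids is a LongTensor (batch, seq_len), labels is a LongTensor (batch,)
train_loader_iter = iter(train_loader)
fixed_batch = next(train_loader_iter)
inputs, labels = fixed_batch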

model.train()
for step in range(500):
    optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        acc = (logits.argmax(dim=1) == labels).float().mean()
        print(f"Step {step}: loss={loss.item():.4f}, acc={acc.item():.3f}")

# After 500 steps: loss should be < 0.01, accuracy should be 1.0
# If not: YOU HAVE A BUG

What Failure Tells You

If the loss doesn't go to near-zero, the failure mode tells you a lot about the bug:

Loss stuck at initial value (e.g., 0.693 for binary classification). The model isn't learning anything. The most common causes: gradients aren't flowing (check for detached tensors, wrong dtype, frozen parameters), the optimizer isn't actually connected to the model's parameters, or the loss function doesn't match the model output format.

Loss decreases but plateaus well above zero. The model has some gradient signal but can't fit the batch. The model's capacity may be too low for even 32 examples, but that's unlikely unless your architecture is very small. More commonly, a data preprocessing bug is destroying information. For example, tokenization maps most words to UNK, or truncation cuts every review down to a handful of tokens. Reviews with different labels then become indistinguishable to the model.

Loss immediately goes to NaN or infinity. Numerical instability from the very first step. Check for division by zero in custom layers, log(0) in losses, or a learning rate that's absurdly high. We'll cover NaN debugging in depth shortly.

Loss oscillates wildly without converging. The learning rate is too high for this architecture on this data. Try reducing it by 10x. If it still oscillates, you likely have a data-label mismatch — the labels are wrong or shuffled.

🔥 Hard-Won Lesson

I once spent two days tuning hyperparameters for a model that wouldn't converge. Grid search over learning rates, architectures, batch sizes — nothing worked. When I finally ran the single-batch overfit test, the loss was stuck at 2.303 (exactly -log(1/10) for 10-class classification). The model was guessing randomly. Turned out, the DataLoader had shuffle=True but the labels tensor wasn't being shuffled in sync with the inputs. Two days. The overfit test would have caught it in 30 seconds.

A Systematic Debugging Flow

Once the overfit test passes (or when you're trying to figure out why it's failing), we need a systematic approach. I've seen too many people debug training by randomly changing hyperparameters and hoping something works. That's not debugging, that's gambling. Here's the flow I follow, and it's roughly in order of "most likely to be the problem."

Step 1: Inspect the Data Pipeline

Most "model bugs" are data bugs. I know this sounds like a platitude, but I mean it literally. Print what the model actually sees. Not what you think it sees. Not what the data documentation says it should see. Print the actual tensor values going into the forward pass.

# ALWAYS do this before anything else
# Assumes train_loader, tokenizer (with a .decode method), vocab_size, and
# num_classes are already defined by your pipeline
batch = next(iter(train_loader))
inputs, labels = batch

print(f"Input shape: {inputs.shape}")        # Expected: (batch_size, seq_len)
print(f"Input dtype: {inputs.dtype}")         # Should be torch.long for embeddings
print(f"Label shape: {labels.shape}")         # Expected: (batch_size,)
print(f"Label dtype: {labels.dtype}")         # Should be torch.long for CrossEntropyLoss
print(f"Label values: {labels[:10]}")         # Are these in [0, num_classes-1]?
print(f"Input range: [{inputs.min()}, {inputs.max()}]")  # Within vocab size?

# For text: decode back to words and visually inspect
for i in range(3):
    tokens = inputs[i].tolist()
    text = tokenizer.decode(tokens)
    label = labels[i].item()
    print(f"Label={label}: {text[:200]}")
    # Does the sentiment match the label? You'd be surprised how often it doesn't.

# Cheap automated versions of the range checks above
assert labels.min() >= 0 and labels.max() < num_classes, "labels outside [0, num_classes-1]"
assert inputs.min() >= 0 and inputs.max() < vocab_size, "token ids outside the vocabulary"

For our sentiment classifier, I've caught the following bugs with this check:

- The tokenizer used the wrong vocabulary, so every word mapped to UNK.
- The padding token wasn't ignored by the LSTM, which polluted the hidden state.
- Labels were 1-indexed instead of 0-indexed. The largest label then equals num_classes, so CrossEntropyLoss throws an index error (a device-side assert on GPU). Worse, if the output layer was sized to max(label)+1, training runs silently with a dead class and every label shifted by one.
- Input sequences were truncated to 10 tokens when the average review is 200 tokens.
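For the padding bug specifically, one standard fix is to pack the padded batch with `torch.nn.utils.rnn.pack_padded_sequence`, so the LSTM stops at each review's true length. This sketch makes two assumptions about your pipeline: it also yields a `lengths` tensor, and the padding index is 0. `PaddingAwareSentimentModel` is a hypothetical variant of our `SentimentModel`.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class PaddingAwareSentimentModel(nn.Module):
    """Hypothetical variant of SentimentModel that ignores padding (pad id assumed to be 0)."""
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_classes=2, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, lengths):
        embeds = self.embedding(x)
        # Packing tells the LSTM each sequence's true length, so the final hidden
        # state comes from the last real token rather than from trailing padding.
        packed = pack_padded_sequence(embeds, lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)   # hidden comes back in the original batch order
        return self.classifier(hidden[-1])

torch.manual_seed(0)
model = PaddingAwareSentimentModel(vocab_size=50, embed_dim=16, hidden_dim=32).eval()
batch = torch.tensor([[5, 6, 7, 0, 0],       # length-3 review, padded to 5
                      [1, 2, 3, 4, 9]])      # length-5 review
lengths = torch.tensor([3, 5])
with torch.no_grad():
    batched = model(batch, lengths)
    alone = model(torch.tensor([[5, 6, 7]]), torch.tensor([3]))
print(torch.allclose(batched[0], alone[0], atol=1e-5))  # expected: True
```

This doubles as a quick check that padding really is ignored. A review's logits should be the same whether it runs alone or padded inside a batch.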

Step 2: Verify the Expected Initial Loss

Before training starts, the model's weights are random, so its predictions should be roughly uniform over classes. For a classification problem with C classes, the expected initial loss with CrossEntropyLoss is:

expected_initial_loss = -log(1/C) = log(C)

Number of Classes (C) | Expected Initial Loss, ln(C)
--- | ---
2 (binary) | 0.693
10 (CIFAR-10) | 2.303
100 (CIFAR-100) | 4.605
1000 (ImageNet) | 6.908
30,000 (language model vocab) | 10.31

If your initial loss is wildly different from this, something is wrong.

- Much lower than expected? Your model is biased toward certain classes before training. It may be leaking information, or the last-layer bias may be initialized badly.
- Much higher? The model is confidently wrong before it has learned anything. That usually means the logits have a huge scale (check the output layer's initialization) or the loss is computing something unexpected.

I once saw an initial loss of 47.0 for a 10-class problem, against an expected 2.303. Note that a double softmax does not produce numbers like that. (A double softmax means a softmax in the model followed by CrossEntropyLoss, which applies log-softmax internally.) Squashing the inputs into [0, 1] pins the loss within about one nat of log(C). The initial loss therefore looks normal, and the bug hides until training stalls.
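All three regimes are easy to reproduce with random tensors. This sketch uses random logits and labels, and every value in it is illustrative. It compares near-uniform logits, logits with a badly inflated scale, and the same inflated logits pushed through a softmax before `F.cross_entropy` (a double softmax).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch = 10, 256
labels = torch.randint(0, num_classes, (batch,))

healthy_logits = 0.01 * torch.randn(batch, num_classes)  # near-uniform predictions
huge_logits = 50.0 * torch.randn(batch, num_classes)     # badly scaled output layer

healthy_loss = F.cross_entropy(healthy_logits, labels).item()
huge_loss = F.cross_entropy(huge_logits, labels).item()
# Double softmax: probabilities fed to a loss that applies log-softmax again
double_softmax_loss = F.cross_entropy(F.softmax(huge_logits, dim=1), labels).item()

print(f"expected log(C) = {math.log(num_classes):.3f}")
print(f"healthy init    = {healthy_loss:.3f}")         # very close to log(C)
print(f"huge logits     = {huge_loss:.1f}")            # far above log(C)
print(f"double softmax  = {double_softmax_loss:.3f}")  # stuck near log(C)
```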

import math

# Sanity check: initial loss (reuses model, criterion, inputs, labels from above)
num_classes = 2  # our binary sentiment classifier
model.eval()
with torch.no_grad():
    logits = model(inputs)
    loss = criterion(logits, labels)
    print(f"Initial loss: {loss.item():.4f}")
    print(f"Expected for {num_classes} classes: {math.log(num_classes):.4f}")
    print(f"Logit stats: mean={logits.mean():.4f}, std={logits.std():.4f}")
    # Logits should have mean near 0 and moderate std if initialization is reasonable
model.train()  # switch back before training

Step 3: The Learning Rate Finder

Leslie Smith (2017) proposed a brilliant technique: instead of guessing the learning rate, sweep across a wide range and let the data tell you. Start with a tiny learning rate (like 1e-7) and increase it exponentially each batch until it reaches a large value (like 1.0 or 10.0). Record the loss at each step, then plot loss against learning rate on a log scale. A common rule of thumb is to pick a learning rate roughly 10x below the point where the loss is lowest, just before it starts to diverge.

import copy
import math
import torch

def lr_finder(model, train_loader, criterion, optimizer_cls=torch.optim.Adam,
              start_lr=1e-7, end_lr=1.0, num_steps=200):
    """Sweep the LR exponentially from start_lr to end_lr, recording smoothed loss.

    The model's weights are restored afterwards, so the sweep doesn't leak into training.
    """
    initial_state = copy.deepcopy(model.state_dict())
    optimizer = optimizer_cls(model.parameters(), lr=start_lr)
    lr_mult = (end_lr / start_lr) ** (1 / (num_steps - 1))

    lrs, losses = [], []
    best_loss = float('inf')
    avg_loss = 0.0
    beta = 0.98  # smoothing factor

    model.train()
    data_iter = iter(train_loader)

    for step in range(num_steps):
        try:
            inputs, labels = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            inputs, labels = next(data_iter)

        current_lr = optimizer.param_groups[0]['lr']  # LR used for this step
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Exponential smoothing with bias correction
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** (step + 1))

        # Stop if loss is NaN/Inf or clearly diverging
        if not math.isfinite(smoothed_loss) or (step > 10 and smoothed_loss > 4 * best_loss):
            break
        best_loss = min(best_loss, smoothed_loss)

        lrs.append(current_lr)
        losses.append(smoothed_loss)

        # Increase LR for the next step
        for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_mult

    model.load_state_dict(initial_state)
    # Plot lrs (log scale) vs losses; pick an LR about 10x below the minimum-loss point
    return lrs, losses

For our sentiment classifier with Adam, the learning rate finder typically shows three phases. The loss starts decreasing from around 1e-5, reaches a minimum around 1e-3 to 3e-3, and then diverges rapidly above 1e-2. The 10x rule then points to something like 1e-4 to 3e-4. This isn't the final answer, since the optimal LR depends on batch size, model size, and warmup schedule. But it gets you into the right ballpark immediately. You don't waste runs on LR=0.1 (diverges) or LR=1e-6 (trains but barely moves).
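The last step can be automated. Below is a tiny helper that simply restates the 10x rule of thumb. It is hypothetical, not part of any library, and it assumes `lrs` and `losses` are the lists returned by `lr_finder`.

```python
def suggest_lr(lrs, losses, factor=10.0):
    """Rule of thumb: take the LR at the minimum smoothed loss and divide by `factor`."""
    if not lrs or len(lrs) != len(losses):
        raise ValueError("lrs and losses must be non-empty and the same length")
    best_idx = min(range(len(losses)), key=losses.__getitem__)
    return lrs[best_idx] / factor

# Synthetic sweep shaped like the one described above: minimum near 3e-3
lrs = [1e-5, 1e-4, 1e-3, 3e-3, 1e-2, 3e-2]
losses = [0.69, 0.66, 0.45, 0.40, 0.90, 3.10]
print(suggest_lr(lrs, losses))  # ≈ 3e-4
```

Eyeball the plot anyway. A noisy or very flat curve can put the minimum somewhere misleading.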

Step 4: Gradient Health Checks

If the model is training but something feels off — loss is stagnant, convergence is glacially slow, or the model isn't learning certain features — it's time to look at the gradients directly.

def check_gradient_health(model):
    """Print gradient statistics per layer — the training vital signs."""
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad
            print(f"{name:40s} | "
                  f"grad mean={grad.mean().item():+.2e}, "
                  f"grad std={grad.std().item():.2e}, "
                  f"grad norm={grad.norm().item():.2e}, "
                  f"param norm={param.norm().item():.2e}")
        else:
            print(f"{name:40s} | NO GRADIENT")

# Call after loss.backward() but before optimizer.step()
loss.backward()
check_gradient_health(model)
optimizer.step()

What to look for: NO GRADIENT on any parameter means that parameter is disconnected from the loss. It's either frozen (requires_grad=False), or there's a computational graph break somewhere — maybe you called .detach() or .item() or .numpy() on an intermediate tensor, severing the autograd chain. Gradient norms near zero on early layers but healthy on later layers means vanishing gradients — the signal is dying as it flows backward through the network. Gradient norms that are wildly different across layers (e.g., 1e-8 for early layers, 1e+3 for later layers) suggest the network needs better initialization, normalization layers, or gradient clipping.
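Here's what a graph break looks like from the outside. `AccidentallyDetached` is a hypothetical buggy model that calls `.detach()` on its encoder output. `params_without_grad` is a small helper of our own, not a PyTorch API. It lists trainable parameters that still lack a gradient after `backward()`, which is the same signal `check_gradient_health` prints as NO GRADIENT.

```python
import torch
import torch.nn as nn

class AccidentallyDetached(nn.Module):
    """Hypothetical buggy model: .detach() severs the graph below the head."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        features = torch.relu(self.encoder(x)).detach()  # BUG: blocks gradient to encoder
        return self.head(features)

def params_without_grad(model):
    """Names of trainable parameters that received no gradient."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]

model = AccidentallyDetached()
loss = nn.functional.cross_entropy(model(torch.randn(4, 8)), torch.tensor([0, 1, 0, 1]))
loss.backward()
print(params_without_grad(model))  # ['encoder.weight', 'encoder.bias']
```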

Step 5: Architecture Bugs — The Double Softmax Trap

This deserves its own subsection because it catches an embarrassing number of people, including experienced practitioners who should know better (I include myself in that group).

PyTorch's nn.CrossEntropyLoss expects raw logits — unnormalized scores. Internally, it applies log_softmax and then nll_loss. If your model's forward method ends with nn.Softmax() or F.softmax(), you're applying softmax twice. The first softmax squashes your logits into [0, 1]. The second (inside CrossEntropyLoss) applies log-softmax to those already-squashed values. The resulting loss is mathematically valid — gradients still flow — but the effective optimization landscape is distorted, and training is severely crippled.

# BUG: Double softmax — the model applies softmax, then CrossEntropyLoss applies it again
class BrokenModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)
        self.softmax = nn.Softmax(dim=1)  # <-- THIS IS THE BUG

    def forward(self, x):
        return self.softmax(self.fc(x))  # Returns probabilities, not logits

# CrossEntropyLoss internally does log_softmax(input), so this computes:
# loss = nll_loss(log_softmax(softmax(logits)))  <-- double softmax!

# FIX: Return raw logits from the model
class FixedModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)  # Raw logits — let the loss function handle normalization

The sneaky part is that this bug doesn't crash. It doesn't produce NaN. The model still learns, kind of. It'll converge to maybe 70% accuracy on a task where it should reach 90%. You'll think you need a bigger model or better hyperparameters, and you'll waste days before realizing the loss function was wrong all along.

Step 6: torch.autograd.detect_anomaly()

When you've exhausted the manual checks and something is still producing NaN or unexpected values, PyTorch has a nuclear option: torch.autograd.detect_anomaly(). In anomaly mode, autograd records a stack trace for every forward operation and checks the output of every backward function for NaN. When a NaN appears, it raises an error whose traceback points to the forward operation responsible. It does not inspect forward-pass outputs, so pair it with an explicit torch.isfinite check on the loss.

# ONLY use for debugging: anomaly mode makes training much slower
with torch.autograd.detect_anomaly():
    logits = model(inputs)
    loss = criterion(logits, labels)
    loss.backward()
    # If a backward operation produces NaN, you'll get a RuntimeError,
    # plus a traceback of the forward operation that created it

Never leave this on in production training. The overhead is massive. Use it to find the bug, fix it, remove it.
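To see what anomaly mode buys you, here's a deliberately broken toy loss, purely for illustration. sqrt(0) has an infinite derivative, and multiplying the sqrt output by zero turns the backward pass into 0/0 = NaN. Without anomaly mode, the NaN lands silently in the gradient. With it, `backward()` raises a RuntimeError that names the backward function that produced the NaN.

```python
import torch

def buggy_loss(x):
    # d/dx sqrt(x) is infinite at 0; multiplying by zero makes backward compute 0/0 = NaN
    return (torch.sqrt(x) * 0.0).sum()

x = torch.zeros(3, requires_grad=True)

# Without anomaly mode: backward() "succeeds" and NaN lands silently in x.grad
buggy_loss(x).backward()
silent_grad = x.grad.clone()
print(silent_grad)  # tensor([nan, nan, nan])

# With anomaly mode: backward() raises at the op that produced the NaN
x.grad = None
anomaly_message = ""
try:
    with torch.autograd.detect_anomaly():
        buggy_loss(x).backward()
except RuntimeError as err:
    anomaly_message = str(err)
print(anomaly_message)  # names the backward function (SqrtBackward0) that returned NaN
```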

💡 Key Insight

The debugging flow has a hierarchy: data first, loss second, gradients third, architecture fourth. Most bugs are in the first two. If you jump to tweaking architecture or hyperparameters before verifying data integrity, you're debugging in the dark.

NaN Loss: Causes, Mechanisms, and Fixes

NaN in your loss is the training equivalent of a kernel panic. Everything was fine, and then suddenly every number in your model is NaN, and the entire run is irrecoverable. I've had this happen on hour 47 of a 48-hour training run. It's not fun. Let's understand exactly how NaN propagates and how to defend against it.

The Chain Reaction

NaN rarely starts everywhere at once. It starts in one operation, one layer, one parameter. Then it spreads through the computation graph like an infection. A single NaN gradient updates one weight to NaN. That NaN weight produces NaN activations for every input in the next forward pass. Those NaN activations produce NaN gradients for every other weight. Within one or two steps, every parameter in the model is NaN. Once it starts, there's no recovery without loading a checkpoint.

Here's the typical chain:

1. A particularly adversarial batch (outlier inputs, mislabeled samples) produces an unusually large loss.
2. The large loss produces large gradients, which produce a huge weight update.
3. The updated weight now holds an extreme value.
4. In the next forward pass, that extreme weight produces an activation that overflows to infinity.
5. Infinity times a nonzero number is still infinity, but infinity minus infinity is NaN, and so is infinity times zero.
6. From there, any arithmetic involving NaN is NaN. Game over.
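The infection is easy to reproduce. This sketch uses a toy two-layer network, not our sentiment model. It first checks the arithmetic rules above. Then it sets a single weight to NaN, as one bad update would, and runs one forward/backward pass.

```python
import torch
import torch.nn as nn

inf = torch.tensor(float("inf"))
print(inf - inf, inf * 0)  # tensor(nan) tensor(nan)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
with torch.no_grad():
    model[0].weight[0, 0] = float("nan")  # one corrupted weight, as after one bad update

x = torch.randn(16, 4)
loss = nn.functional.cross_entropy(model(x), torch.randint(0, 2, (16,)))
loss.backward()

# After a single pass, every parameter's gradient contains NaN,
# so the next optimizer.step() would corrupt the whole model
nan_grads = {name: torch.isnan(p.grad).any().item() for name, p in model.named_parameters()}
print(torch.isnan(loss).item(), nan_grads)
```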

Float16 Overflow

If you're training in mixed precision (which you should be, for speed), note that float16 has a maximum representable value of 65,504. That sounds like a lot until you notice that squaring 256 already gives 65,536, which overflows. Softmax involves exp(x), and a naive exp(x) overflows float16 for x > 11.09 (since ln 65,504 ≈ 11.09). Matrix multiplications with large hidden dimensions can easily produce pre-activation values that exceed that range.

# Float16 range demonstration
import torch
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.float16).min)   # -65504.0
print(torch.finfo(torch.float16).tiny)  # 6.1035e-05 (smallest positive normal)

# This overflows in float16:
x = torch.tensor([256.0], dtype=torch.float16)
print(x ** 2)  # tensor([inf], dtype=torch.float16)

# This is why exp-heavy ops like softmax run in float32 under autocast
logits_fp16 = torch.tensor([12.0, 13.0, 14.0], dtype=torch.float16)
print(torch.exp(logits_fp16))  # tensor([inf, inf, inf], dtype=torch.float16)

logits_fp32 = logits_fp16.float()
print(torch.exp(logits_fp32))  # roughly [162755, 442413, 1202604]: fine in fp32

Common Numerical Instability Sources

log(0): If your model produces an output that's exactly 0.0 (after softmax, or in a probability-output model), and you take the log, you get -inf. Then any arithmetic with -inf can produce NaN. The fix: add an epsilon guard. Use torch.log(x + 1e-8) instead of torch.log(x) in custom loss functions. The built-in CrossEntropyLoss handles this internally via log_softmax, which is numerically stable.

sqrt(0) with gradients: The derivative of sqrt(x) is 1/(2*sqrt(x)). At x=0, that's 1/0 = inf. Even if the forward pass is fine, the backward pass explodes. The fix: torch.sqrt(x + 1e-8).

exp(large): exp(x) overflows for x greater than about 88 in float32 (or about 11 in float16). This comes up in softmax, Gaussian mixture models, and anything involving probabilities. The standard fix is the max-subtraction trick, the same idea behind log-sum-exp. Instead of computing exp(x_i) / sum(exp(x_j)), compute exp(x_i - max(x)) / sum(exp(x_j - max(x))). The two are mathematically identical, but every exponent is now at most 0, so nothing overflows. PyTorch's built-in softmax and torch.logsumexp do this automatically.
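Here is a minimal sketch of that trick. The `naive_softmax` and `stable_softmax` helpers are our own illustrations, compared against PyTorch's built-ins on inputs large enough to overflow float32.

```python
import torch

def naive_softmax(x):
    e = torch.exp(x)                      # overflows to inf for large x
    return e / e.sum(dim=-1, keepdim=True)

def stable_softmax(x):
    # Subtracting the row max cancels in the ratio but keeps every exponent <= 0
    shifted = x - x.max(dim=-1, keepdim=True).values
    e = torch.exp(shifted)
    return e / e.sum(dim=-1, keepdim=True)

x = torch.tensor([[1000.0, 999.0, 998.0]])
print(naive_softmax(x))                   # tensor([[nan, nan, nan]]): inf / inf
print(stable_softmax(x))
print(torch.softmax(x, dim=-1))           # matches stable_softmax

# Same idea for log-sum-exp: log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))
m = x.max(dim=-1, keepdim=True).values
manual_lse = m.squeeze(-1) + torch.log(torch.exp(x - m).sum(dim=-1))
print(manual_lse, torch.logsumexp(x, dim=-1))
```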

Division by zero: Custom normalization layers, attention mechanisms with zero-valued denominators, or any place where you divide by a computed quantity that might be zero. Always add epsilon to denominators.

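# Common epsilon guards for numerical stability
import torch

# Example inputs chosen to trigger each failure mode
predicted_prob = torch.tensor([0.0, 0.3, 0.7])     # contains an exact zero
sum_of_squares = torch.tensor([0.0, 4.0])           # contains an exact zero
x = torch.tensor([[0.0, 0.0], [3.0, 4.0]])          # first row has zero norm
scores = torch.tensor([[100.0, 99.0, 98.0]])        # exp(100) overflows float32

# BAD: log(0) = -inf, and -inf * 0 (e.g. a zero weight or mask) = NaN
log_prob = torch.log(predicted_prob)

# GOOD: epsilon guard prevents log(0)
log_prob = torch.log(predicted_prob + 1e-8)

# BAD: forward is fine, but the sqrt gradient is inf at 0
norm = torch.sqrt(sum_of_squares)

# GOOD: epsilon inside the sqrt
norm = torch.sqrt(sum_of_squares + 1e-8)

# BAD: division by a potentially zero norm (0 / 0 = NaN)
normalized = x / x.norm(dim=-1, keepdim=True)

# GOOD: clamp the denominator
normalized = x / x.norm(dim=-1, keepdim=True).clamp(min=1e-8)

# BAD: custom attention without numerical stability (inf / inf = NaN)
attn_weights = torch.exp(scores) / torch.exp(scores).sum(dim=-1, keepdim=True)

# GOOD: use the numerically stable built-in softmax
attn_weights = torch.softmax(scores, dim=-1)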

GradScaler for Mixed Precision

PyTorch's GradScaler is the frontline defense against NaN in mixed-precision training. The core idea is clever. Float16 has limited dynamic range, so small gradient values underflow to zero (losing information) while large values overflow to inf. GradScaler scales the loss up before the backward pass, which scales all gradients up proportionally so small values stay representable in float16. Before the optimizer step, it scales the gradients back down by the same factor. If any gradient is inf or NaN (usually a sign that the scale factor is too large), GradScaler skips the optimizer step entirely and reduces the scale factor for the next iteration.

import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

# Assumes a CUDA GPU, plus SentimentModel, train_loader, and num_epochs from earlier
scaler = GradScaler()
model = SentimentModel(vocab_size=10000).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()

        # Forward pass in mixed precision (float16 where safe, float32 where needed)
        with autocast(device_type='cuda'):
            logits = model(inputs)
            loss = criterion(logits, labels)

        # Backward pass on the scaled loss, so small gradients flowing
        # through float16 ops don't underflow to zero
        scaler.scale(loss).backward()

        # Divide the gradients by the scale factor (in place) and record
        # whether any of them are inf/NaN
        scaler.unscale_(optimizer)

        # Optional: clip gradients AFTER unscaling, BEFORE step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step: skipped if inf/NaN was recorded, otherwise a normal optimizer.step()
        scaler.step(optimizer)
        # Adjust the scale factor for the next iteration
        scaler.update()

        # To monitor: scaler.get_scale() drops whenever a step was skipped

The scaler.update() call at the end adjusts the scale factor dynamically:

- If the step was skipped, it halves the scale factor (the default backoff_factor is 0.5).
- After a streak of successful steps (growth_interval, 2,000 by default), it doubles the scale to use more of float16's range.

This self-tuning behavior means you rarely need to touch the initial scale. The default of 2^16 = 65,536 works for most training runs.
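The underflow half of the story fits in a few lines. This sketch mimics what GradScaler does conceptually: scale, cast to float16, then unscale in float32. The gradient value 1e-8 is an arbitrary example, and 65,536 mirrors the default initial scale.

```python
import torch

scale = 65536.0            # 2**16, GradScaler's default initial scale
tiny_grad = 1e-8           # an illustrative small gradient value

# Unscaled: below float16's smallest subnormal (~6e-8), so it rounds to zero
unscaled_fp16 = torch.tensor(tiny_grad, dtype=torch.float16)

# Scaled: 1e-8 * 65536 ≈ 6.6e-4, comfortably representable in float16
scaled_fp16 = torch.tensor(tiny_grad * scale, dtype=torch.float16)

# Unscale in float32, as happens before the optimizer step
recovered = scaled_fp16.float() / scale

print(unscaled_fp16.item())   # 0.0: the gradient information is gone
print(recovered.item())       # ~1e-08: survives, up to float16 rounding
```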

Using detect_anomaly() to Find the Source

When NaN appears and you don't know which operation produced it, detect_anomaly is your debugger. But use it strategically — wrap only a few suspicious batches, not your entire training loop.

import torch

# Strategy: run normally, switch to anomaly detection only when NaN is detected.
# model, optimizer, criterion and train_loader are assumed to be defined elsewhere.
last_good_checkpoint = {k: v.clone() for k, v in model.state_dict().items()}  # in case step 0 fails
for step, (inputs, labels) in enumerate(train_loader):
    optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, labels)

    if not torch.isfinite(loss):  # catches both NaN and Inf
        print(f"NaN/Inf detected at step {step}!")
        print(f"Input stats: min={inputs.min()}, max={inputs.max()}")
        print(f"Logit stats: min={logits.min()}, max={logits.max()}")

        # Reload last good checkpoint and re-run this batch with anomaly detection.
        # Only the weights are restored, so if the NaN came from the weights drifting
        # since the checkpoint rather than from this batch, it may not reproduce.
        model.load_state_dict(last_good_checkpoint)
        optimizer.zero_grad()
        with torch.autograd.detect_anomaly():
            logits = model(inputs)
            loss = criterion(logits, labels)
            loss.backward()  # Raises with a traceback to the op whose backward produced NaN
    else:
        loss.backward()
        optimizer.step()
        # Save checkpoint periodically
        if step % 100 == 0:
            last_good_checkpoint = {k: v.clone() for k, v in model.state_dict().items()}
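To see what the anomaly-mode error looks like before you need it, here's a minimal CPU-only sketch. The NaN source (the square root of a negative number) is contrived; in real training it would be something like a log of zero or an overflowing exponent.

```python
import torch

# sqrt of a negative number is NaN in the forward pass; its backward
# (0.5 / sqrt(x)) is NaN too, which anomaly mode reports by op name.
x = torch.tensor([-1.0], requires_grad=True)

caught = None
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)          # forward: NaN, but no error yet
    try:
        y.sum().backward()     # backward: the sqrt backward returns NaN -> RuntimeError
    except RuntimeError as err:
        caught = str(err)

print(caught)  # names the offending backward function, e.g. SqrtBackward0
```

Anomaly mode also prints the forward-pass traceback of the op that created the failing backward node, which is usually the line you actually want.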
⚠️ Warning

A common mistake with mixed precision: calling .item() on the scaled loss instead of the unscaled loss for logging. The scaled loss might be 65536x larger than the real loss. Always log loss.item() from the loss computed inside the autocast context, not the scaled version passed to backward().

Reading Loss Curves Like a Cardiologist Reads EKGs

A loss curve is the EKG of your model's training. Learning to read it — really read it, not glance at it and shrug — is one of the most valuable skills you can develop. Every pathology has a signature. Let's catalog them.

The Healthy Training Curve

A healthy training curve for our sentiment classifier looks like this: training loss drops steeply for the first few epochs, then gradually levels off. Validation loss follows the same trajectory but slightly above. The gap between them is small and relatively constant. Both curves eventually flatten. That gap represents the generalization gap — how much the model has memorized training-specific patterns versus learning generalizable features. A small, stable gap means good regularization.

Overfitting: The Crossover

The classic overfitting pattern is unmistakable: training loss keeps decreasing smoothly, but at some epoch — the crossover point — validation loss stops decreasing and starts increasing. The model has exhausted the generalizable patterns and is now memorizing noise in the training set. Every additional epoch past the crossover makes the model worse on unseen data.

The fix depends on severity. Mild overfitting (val loss increases slowly): add dropout, increase weight decay, or use early stopping. Severe overfitting (val loss climbs steeply while train loss drops to near-zero): you need more training data, data augmentation, or a smaller model. For our sentiment classifier, switching from a 3-layer LSTM to a 1-layer LSTM, or adding dropout=0.5 between layers, often tames overfitting without sacrificing too much capacity.

Underfitting: Both Curves High and Flat

When both training and validation loss are high and neither is decreasing meaningfully, the model is underfitting. It doesn't have enough capacity to capture the patterns in the data, or it hasn't trained long enough, or the learning rate is too low. Check the learning rate first — underfitting with a tiny LR looks identical to underfitting with insufficient capacity, but one is fixed in 5 seconds and the other requires architecture changes.

Double Descent: When More Capacity Gets Worse, Then Better

This is one of the most counterintuitive phenomena in modern deep learning, and it directly challenges the classical bias-variance tradeoff that generations of ML students were taught. The double descent phenomenon goes like this: as you increase model capacity (more parameters), test error initially follows the classical U-curve — it first decreases (underfitting region), then increases (overfitting region). But if you keep increasing capacity past the interpolation threshold — the point where the model has enough parameters to perfectly fit the training data — test error starts decreasing again. The classical picture predicts that overparameterized models should be terrible. In practice, they're often the best.

The interpolation threshold is the critical point where the model has exactly enough parameters to achieve zero training loss. At this point, the model is forced into a very specific, often brittle solution — the only configuration of weights that achieves zero training loss. With even more parameters, there are now many configurations that achieve zero training loss, and SGD (with its implicit regularization from noise) tends to find ones that are smooth and generalizable. Think of it like fitting a polynomial to data points: if you have exactly as many polynomial coefficients as data points, the fit is forced and often wild. If you have many more coefficients than data points, there are infinitely many perfect fits, and regularization can select a smooth one.
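The polynomial analogy is easy to try with NumPy. The sketch below (data and sizes are made up) fits 10 noisy points with 10 Legendre coefficients, where exactly one interpolating fit exists, and with 40, where np.linalg.lstsq returns the minimum-norm interpolant. That minimum-norm choice stands in for the implicit regularization attributed to SGD; it illustrates the "many perfect fits, pick a tame one" idea rather than reproducing double descent itself.

```python
import numpy as np
from numpy.polynomial import legendre

# 10 noisy samples of a smooth function on [-1, 1]
rng = np.random.default_rng(0)
x = np.linspace(-1.0, 1.0, 10)
y = np.sin(np.pi * x) + 0.1 * rng.standard_normal(x.size)

def min_norm_fit(n_coeffs):
    """Least-squares fit with n_coeffs Legendre coefficients.

    For n_coeffs >= len(x) the fit interpolates the data, and lstsq
    returns the minimum-norm solution among all exact fits.
    """
    A = legendre.legvander(x, n_coeffs - 1)      # shape (10, n_coeffs)
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    max_residual = np.max(np.abs(A @ coef - y))
    return coef, max_residual

coef_exact, res_exact = min_norm_fit(10)  # exactly as many coefficients as points
coef_wide, res_wide = min_norm_fit(40)    # 4x overparameterized

print(f"10 coeffs: |coef| = {np.linalg.norm(coef_exact):.2f}, max residual = {res_exact:.1e}")
print(f"40 coeffs: |coef| = {np.linalg.norm(coef_wide):.2f}, max residual = {res_wide:.1e}")
```

Both fits pass through every point. The wider one can never have a larger coefficient norm, because the 10-coefficient solution padded with zeros is still an exact fit, so the minimum-norm interpolant is at least as small (a rough proxy for "smoother" in this orthogonal basis).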

Double descent also occurs along the epoch axis: for a fixed-size model, test error may increase during training, then decrease again if you train long enough. This is deeply unsettling to anyone raised on "stop early to avoid overfitting," and it has been reported across a range of architectures and datasets, most visibly when the training labels are noisy.

Grokking: The Surprise Generalization

Power et al. (2022) described something even stranger: grokking, or delayed generalization. The model memorizes the training set quickly — training loss goes to zero, training accuracy hits 100%. But validation performance is no better than random. For thousands of additional training steps, nothing happens. And then, suddenly, validation accuracy jumps from chance level to near-perfect. The model generalizes long, long after it has memorized.

Grokking has been observed most dramatically on small algorithmic tasks (like modular arithmetic), but similar phenomena occur in larger settings. The mechanism appears to be related to weight decay: during the memorization phase, the model uses large weights to encode specific input-output mappings. As training continues, weight decay slowly pushes the weights down, and at some point, the model "discovers" that there's a simpler, more generalizable solution that achieves the same training loss with smaller weights. It undergoes a phase transition from a memorization solution to a generalization solution.

The practical implication is uncomfortable: you can't always trust early stopping. Sometimes the model needs to train much longer than you'd expect to find the generalizable solution. If your loss curves look like grokking — perfect training performance but stagnant validation — increasing weight decay and being patient with training might pay off.

Loss Spikes

You're watching the loss curve, and suddenly there's a sharp upward spike: loss jumps from 0.3 to 2.5 in a single step, then (hopefully) recovers within a few steps. What happened? Usually one of two things. Either a "bad batch" (outliers, corrupted samples, or mislabeled examples) produced an anomalously large loss, or the effective learning rate changed abruptly (for example, at a warm restart in a scheduler that uses restarts).

If you have gradient clipping enabled, the spike usually recovers because the clipping limits the damage from the anomalous gradient. Without clipping, a single spike can throw off the optimizer's momentum estimates, and it can take many steps to recover, or the run may never recover. This is why gradient clipping is standard practice for long training runs, especially for transformers and RNNs.

💡 Key Insight

Log your loss per-batch, not per-epoch. Epoch-averaged losses smooth over spikes and anomalies that are critical diagnostic information. A single bad batch that spikes the loss by 100x is invisible in the epoch average but might be destabilizing your training. Per-batch logging catches it.
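Per-batch logging also makes spikes easy to flag automatically. LossSpikeDetector below is a hypothetical helper: it keeps an exponential moving average of recent losses and flags any batch whose loss exceeds a multiple of it, so you can record the offending batch for inspection. The decay, ratio, and warmup values are illustrative, not tuned.

```python
class LossSpikeDetector:
    """Flag batches whose loss jumps well above a running average."""

    def __init__(self, decay=0.98, spike_ratio=3.0, warmup=20):
        self.decay = decay
        self.spike_ratio = spike_ratio
        self.warmup = warmup      # don't flag anything until the average has settled
        self.ema = None
        self.seen = 0

    def update(self, loss_value):
        """Record one per-batch loss (a Python float); return True if it is a spike."""
        self.seen += 1
        if self.ema is None:
            self.ema = loss_value
            return False
        is_spike = self.seen > self.warmup and loss_value > self.spike_ratio * self.ema
        if not is_spike:
            # Only fold normal batches into the baseline, so a spike
            # doesn't inflate the threshold for the following steps
            self.ema = self.decay * self.ema + (1 - self.decay) * loss_value
        return is_spike


# Usage: feed loss.item() every batch and keep the indices of suspicious batches
detector = LossSpikeDetector()
losses = [0.3] * 50 + [2.5] + [0.3] * 10   # the 0.3 -> 2.5 jump described above
spikes = [i for i, value in enumerate(losses) if detector.update(value)]
print(spikes)   # [50]
```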

Reproducibility: Fighting Nondeterminism

Reproducibility in deep learning is harder than it should be. You set a random seed, you expect the same results. But there are several independent sources of nondeterminism lurking in a typical PyTorch training pipeline, and if you miss any of them, your results will vary between runs even with the same seed.

The Sources of Chaos

Random number generators: Python's random module, NumPy's numpy.random, and PyTorch's torch.random are three independent RNGs. Setting one doesn't affect the others. Your data augmentation might use NumPy's RNG, your dropout uses PyTorch's, and a hand-rolled shuffle might use Python's random.shuffle. You need to seed all three.

CUDA atomics: Operations like scatter_add, index_add, and the backward pass of embedding lookups with duplicate indices use atomic operations on the GPU. The order in which those atomic additions land depends on thread scheduling, which varies between runs. The final result is identical in exact arithmetic but can differ in floating-point arithmetic, because floating-point addition is not associative.

cuDNN benchmark mode: When torch.backends.cudnn.benchmark is True (it defaults to False, but many training scripts turn it on for speed), cuDNN benchmarks multiple convolution algorithms on the first call for each input shape and picks the fastest. The selected algorithm can change between runs due to hardware state and timing noise, and different algorithms produce slightly different floating-point results. Separately, some cuDNN algorithms are inherently non-deterministic.

DataLoader workers: When using num_workers > 0, each worker process gets its own copy of the dataset and its own random state. The default random state for each worker depends on the base seed plus the worker ID, but if you're not careful about how you initialize them, the randomness within each worker can be unpredictable.

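Python hash randomization: Python 3 randomizes the hash seed for str and bytes objects at startup for security reasons. This changes the iteration order of sets of strings (and anything else derived from those hash values). Dicts preserve insertion order since Python 3.7, so they're only affected indirectly, for example when filled by iterating over a set. If your data pipeline iterates over sets anywhere in a way that affects processing order, results will differ.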
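Non-associativity is easy to see with ordinary Python floats. GPU atomics do the same thing at scale, except the summation order is decided by thread scheduling rather than by your code.

```python
# Floating-point addition is not associative: the same three numbers,
# summed in a different order, give different results.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
print(left, right, left == right)   # 0.6000000000000001 0.6 False
```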

The Seeding Ceremony

import random
import numpy as np
import torch
import os

def seed_everything(seed=42):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU
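    # Note: PYTHONHASHSEED is read at interpreter startup, so this line only affects
    # Python interpreters started afterwards (e.g., spawned worker processes), not the
    # current one. To fix hashing here too, set it at launch: PYTHONHASHSEED=42 python train.py
    os.environ['PYTHONHASHSEED'] = str(seed)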

    # Force cuDNN to use deterministic algorithms (slower but reproducible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

The torch.backends.cudnn.deterministic = True flag forces cuDNN to use deterministic convolution algorithms. This comes at a performance cost — sometimes 10-20% slower — because the deterministic algorithms may not be the fastest ones available. The benchmark = False flag disables the algorithm search, which both saves time on the first iteration and removes another source of nondeterminism.
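Because PYTHONHASHSEED is read only when an interpreter starts, the reliable way to pin it is in the environment of the process you launch. The sketch below checks this by starting two child interpreters with the same seed; the snippet they run is arbitrary and only there to produce hash-dependent output.

```python
import os
import subprocess
import sys

# Hash-dependent output: a string hash and the iteration order of a set of strings
SNIPPET = "print(hash('sentiment'), list({'pos', 'neg', 'neutral', 'mixed'}))"

def run_with_hash_seed(seed):
    """Launch a fresh interpreter with PYTHONHASHSEED set in its environment."""
    env = dict(os.environ, PYTHONHASHSEED=str(seed))
    result = subprocess.run([sys.executable, "-c", SNIPPET],
                            env=env, capture_output=True, text=True, check=True)
    return result.stdout.strip()

first = run_with_hash_seed(42)
second = run_with_hash_seed(42)
print(first == second)   # True: same seed, same hashes and same set order
```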

Worker Initialization

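DataLoader workers need special attention. Each worker is a separate process with its own copy of the Python, NumPy, and PyTorch random states. PyTorch seeds each worker's torch RNG with a base seed plus the worker ID, but depending on your PyTorch version and the libraries your dataset code uses, other RNGs may end up duplicated across workers or seeded in ways you didn't intend. Seeding them explicitly in a worker_init_fn is a cheap safeguard.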

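import random

import numpy as np
import torch

def worker_init_fn(worker_id):
    """Ensure each DataLoader worker has a unique but deterministic seed."""
    # Inside a worker, torch.initial_seed() already equals base_seed + worker_id,
    # so it is unique per worker; fold it into NumPy's 32-bit seed range.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# train_dataset is assumed to be defined elsewhere
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    worker_init_fn=worker_init_fn,
    generator=torch.Generator().manual_seed(42)  # Seed the shuffling RNG too
)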

The generator parameter seeds the RNG used for shuffling. Without it, the shuffle order depends on the global PyTorch RNG state, which might have been modified by model initialization or other operations between when you set the global seed and when the DataLoader starts shuffling.
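A quick way to convince yourself: build the same loader twice with identically seeded generators, disturbing the global RNG in between. The toy dataset and num_workers=0 are simplifications for the check; with workers you would still pass a worker_init_fn.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 10 "examples"; only the order they come out in matters here
dataset = TensorDataset(torch.arange(10))

def epoch_order(seed):
    """Return the example order for one epoch of a loader seeded with `seed`."""
    loader = DataLoader(dataset, batch_size=5, shuffle=True,
                        generator=torch.Generator().manual_seed(seed))
    return [int(x) for (batch,) in loader for x in batch]

torch.manual_seed(0)
order_a = epoch_order(42)
torch.rand(1000)             # consume the global RNG, e.g. as model init would
order_b = epoch_order(42)
print(order_a == order_b)    # True: shuffling no longer depends on the global RNG
```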

The Strictest Mode

PyTorch provides a nuclear option: torch.use_deterministic_algorithms(True). This is stricter than cudnn.deterministic. It forces every PyTorch operation to use a deterministic implementation, and if an operation doesn't have one, it raises a RuntimeError instead of silently being non-deterministic. This is useful for discovering which operations in your pipeline are non-deterministic, but it can break training if you use operations that don't have deterministic GPU implementations (like certain scatter operations).

import os
# Needed for deterministic cuBLAS matmuls on CUDA 10.2+; set it before the
# first cuBLAS call (safest: before importing torch, or in the shell).
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch

# Strictest reproducibility: raises RuntimeError on any op without a deterministic implementation
torch.use_deterministic_algorithms(True)

# To find offenders without crashing, warn instead of raising:
# torch.use_deterministic_algorithms(True, warn_only=True)

The Uncomfortable Truth

Even with all seeds set, all deterministic flags enabled, and all workers properly initialized, results can differ between PyTorch versions, CUDA versions, and GPU architectures (a V100 and an A100 may give different results). If your experiment's conclusions depend on the 4th decimal place of accuracy being exactly reproducible across different machines, you have a fragile experiment. Document the exact software and hardware stack, report results with confidence intervals from multiple seeds, and accept that bitwise reproducibility across environments is a losing battle.

🔥 Hard-Won Lesson

I once spent a week trying to reproduce results from a paper. My numbers were consistently 0.3% lower. Turned out the paper used PyTorch 1.9 and I was on 1.13. The change in default initialization for nn.MultiheadAttention between those versions was enough to shift the results. The code was identical. The seeds were identical. The library versions weren't.

Training Monitoring: Your Early Warning System

Training a deep network without proper monitoring is like flying a plane without instruments. You might make it, but when something goes wrong, you won't know until you've already crashed. Let's set up the instruments.

What to Log and Why

Loss per batch: Not per epoch. Per batch. Epoch-level averages hide important information — a single catastrophic batch, a gradual drift, periodic oscillations tied to data ordering. Batch-level losses show you the real dynamics of training.

Gradient norms: The global gradient norm (across all parameters) is the single most useful early warning signal. If it starts creeping up over training, you're heading toward instability. If it suddenly spikes, a bad batch or learning rate change is causing problems. If it drops to near-zero, your model has stopped learning (vanishing gradients, or it's converged).

Parameter norms: Track the L2 norm of each layer's weights over training. If parameter norms grow unboundedly, something is wrong — probably a missing weight decay or a feedback loop in the architecture. If they collapse to near-zero, something is aggressively regularizing or the learning signal is absent.

Learning rate: Always log the actual learning rate, especially when using schedulers. It's embarrassingly easy to have a scheduler that you think is doing cosine annealing but is actually stuck at the initial LR because you forgot to call scheduler.step().

import torch
from torch.utils.tensorboard import SummaryWriter

# model, optimizer, scheduler (stepped once per batch), criterion,
# train_loader and num_epochs are assumed to be defined elsewhere
writer = SummaryWriter('runs/sentiment_v1')

global_step = 0
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()

        # ---- MONITORING BLOCK ----
        # 1. Log the loss (use .item() to detach from graph!)
        writer.add_scalar('train/loss_batch', loss.item(), global_step)

        # 2. Global (pre-clip) gradient norm: the most important health metric
        grad_norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
        total_norm = torch.stack(grad_norms).norm(2).item()
        writer.add_scalar('train/grad_norm', total_norm, global_step)

        # 3. Per-layer gradient norms (find where gradients vanish/explode)
        for name, p in model.named_parameters():
            if p.grad is not None:
                writer.add_scalar(f'grad_norm/{name}', p.grad.norm().item(), global_step)

        # 4. Learning rate
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/lr', current_lr, global_step)

        # 5. Alert conditions
        if total_norm > 100:
            print(f"WARNING: grad norm {total_norm:.1f} at step {global_step}")
        if not torch.isfinite(loss):
            # Raise rather than break: break would only leave the inner loop,
            # and training would silently resume at the next epoch
            writer.close()
            raise RuntimeError(f"FATAL: non-finite loss at step {global_step}")
        # ---- END MONITORING BLOCK ----

        # Clip gradients before stepping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        global_step += 1

writer.close()

The Memory Leak Trap: loss vs loss.item()

This bug is so common and so insidious that it deserves a spotlight. When you compute loss = criterion(logits, labels), the loss tensor carries a reference to the computation graph used to compute it: the chain of operations back to the weights and, until .backward() frees them, every intermediate tensor saved for the backward pass. If you store loss in a list (e.g., losses.append(loss)) or accumulate it with total_loss += loss, you keep that history alive. It's worst when .backward() is never called on the stored tensor, as in a validation loop without torch.no_grad(): every batch's saved activations stay in memory. Even after .backward(), each stored loss still pins a GPU tensor and its graph nodes, so memory creeps up batch after batch. Eventually your GPU runs out of memory, and the error message says "CUDA out of memory" but doesn't tell you why.

# BUG: every stored loss keeps a GPU tensor and its autograd graph alive
all_losses = []
for inputs, labels in train_loader:
    loss = criterion(model(inputs), labels)
    loss.backward()
    all_losses.append(loss)  # MEMORY LEAK! Keeps the graph alive
    optimizer.step()
    optimizer.zero_grad()

# FIX: Use .item() to extract the Python float, releasing the graph
all_losses = []
for inputs, labels in train_loader:
    loss = criterion(model(inputs), labels)
    loss.backward()
    all_losses.append(loss.item())  # Just a float — graph is freed
    optimizer.step()
    optimizer.zero_grad()

The same trap applies to logging. writer.add_scalar('loss', loss, step) happens to be safe because SummaryWriter converts the tensor to a NumPy value immediately, but other loggers may not do that for you (wandb.log({"loss": loss}) might keep a reference, depending on the library version). Always use loss.item() when extracting scalar values for logging or accumulation.
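The difference shows up directly in the types. In this toy loop, a single weight vector stands in for the model and backward() is never called, as in a validation loop that forgets torch.no_grad(). The tensor accumulator carries an autograd graph; the .item() accumulator is just a float.

```python
import torch

w = torch.randn(3, requires_grad=True)   # stand-in for model weights

running_tensor = 0.0
running_float = 0.0
for _ in range(5):
    loss = (w ** 2).sum()          # stand-in for criterion(model(inputs), labels)
    running_tensor += loss         # BUG: chains every batch's loss into one graph
    running_float += loss.item()   # FIX: a plain Python float, no autograd history

print(type(running_tensor).__name__, running_tensor.grad_fn is not None)  # Tensor True
print(type(running_float).__name__)                                       # float
```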

Per-Layer Gradient Monitoring

Global gradient norm tells you the overall health, but per-layer gradient norms tell you where the problems are. In our sentiment classifier, a healthy gradient profile looks like this: the classifier head (last layer) has the largest gradients, the LSTM layers have moderate gradients, and the embedding layer has the smallest gradients. If the embedding layer gradients are zero but everything else is healthy, the embedding weights are frozen (maybe intentionally, maybe not). If the LSTM gradients are orders of magnitude larger than the classifier gradients, there's likely an instability in the recurrent computation.

import torch

# Periodic detailed gradient check: run every N steps
def log_gradient_histogram(model, writer, step):
    """Log gradient histograms to TensorBoard for visual inspection."""
    for name, param in model.named_parameters():
        if param.grad is not None:
            writer.add_histogram(f'gradients/{name}', param.grad, step)
            writer.add_histogram(f'weights/{name}', param.data, step)

            # Also check for concerning patterns
            grad = param.grad
            if torch.any(torch.isnan(grad)):
                print(f"NaN gradient in {name} at step {step}")
            if torch.any(torch.isinf(grad)):
                print(f"Inf gradient in {name} at step {step}")
            if grad.abs().max().item() < 1e-10:
                print(f"Vanishing gradient in {name} at step {step}: max={grad.abs().max().item():.2e}")
💡 Key Insight

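The ratio of update size to parameter size is more informative than gradient or parameter norms alone. If ||update|| / ||param|| approaches 1, a single step is rewriting the weights almost entirely, which is a sign of instability. For plain SGD the update is lr · grad, so this is the learning rate times the gradient-to-parameter ratio; for Adam-style optimizers it's easiest to measure the applied change directly. A commonly cited rule of thumb puts healthy values around 1e-3, with much larger ratios pointing to a learning rate that's too high and much smaller ones to one that's too low. Track this ratio per layer to catch problems early.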
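Optimizers like Adam rescale the gradient before applying it, so a direct way to get an update-to-weight ratio is to diff each parameter across one optimizer step. The helper below (update_to_weight_ratios is a hypothetical name, and the tiny model is a stand-in for the sentiment classifier) does that per parameter tensor. Since it copies all the weights, you'd call it every few hundred steps rather than on every step.

```python
import torch
import torch.nn as nn

def update_to_weight_ratios(model, optimizer):
    """Take one optimizer step and return ||delta W|| / ||W|| per parameter.

    Measures the update the optimizer actually applied, so it also works for
    Adam/AdamW, where the update is not simply lr * grad. Call it in place of
    optimizer.step() on the steps you want to inspect.
    """
    before = {name: p.detach().clone()
              for name, p in model.named_parameters() if p.requires_grad}
    optimizer.step()
    ratios = {}
    with torch.no_grad():
        for name, p in model.named_parameters():
            if name in before:
                update_norm = (p - before[name]).norm().item()
                weight_norm = before[name].norm().item()
                ratios[name] = update_norm / (weight_norm + 1e-12)
    return ratios

# Usage on a small stand-in model with a single random batch
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = nn.functional.cross_entropy(model(torch.randn(8, 16)), torch.randint(0, 2, (8,)))
loss.backward()
ratios = update_to_weight_ratios(model, optimizer)
for name, r in ratios.items():
    print(f"{name:<10} {r:.1e}")
```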

The Silent Bugs: No Error, No Warning, Quietly Broken

These are the bugs I hate most. They don't crash your program. They don't produce NaN. They don't trigger any warning. They sit there quietly, making your model 5-15% worse than it should be, and you have no idea. You tune hyperparameters to compensate, you try bigger models, you collect more data — none of it fully makes up for a fundamental bug in the training loop. I've compiled this list from personal pain and from watching dozens of colleagues hit the same issues.

Forgetting model.eval() During Validation

This is probably the single most common silent bug in PyTorch training code. During training, you call model.train() which enables dropout (randomly zeroing neurons) and makes BatchNorm use per-batch statistics. During validation, you should call model.eval() which disables dropout and makes BatchNorm use running statistics accumulated during training.

If you forget model.eval() during validation, two things go wrong. First, dropout randomly zeros out neurons during inference, making your validation predictions noisier and worse than they should be. Second, BatchNorm computes mean and variance from the validation batch instead of using the stable running statistics. If your validation batch is small, these batch statistics are noisy, further degrading performance. The net effect: your validation loss is artificially high and noisy, making it look like the model generalizes worse than it actually does. You might add more regularization in response, which makes the model actually worse.

# THE BUG: No model.eval() during validation
for epoch in range(num_epochs):
    # Training
    model.train()
    for inputs, labels in train_loader:
        # ... training loop ...
        pass

    # Validation — BUG: model is still in train mode!
    val_loss = 0
    for inputs, labels in val_loader:
        with torch.no_grad():
            logits = model(inputs)  # Dropout is active! BatchNorm uses batch stats!
            val_loss += criterion(logits, labels).item()

# THE FIX: Explicitly set eval mode and restore train mode after
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        # ... training loop ...
        pass

    model.eval()  # <-- Critical: disables dropout, uses running stats for BN
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            logits = model(inputs)
            val_loss += criterion(logits, labels).item()
    model.train()  # <-- Restore train mode for next epoch

The Inverse: Forgetting model.train() After Evaluation

The mirror bug: you correctly call model.eval() for validation, but forget to call model.train() before the next training epoch. Now dropout is disabled during training. The model trains without any dropout regularization, which often leads to overfitting. And because you're looking at the validation loss (which is correctly computed in eval mode), you see overfitting and wonder why your regularization isn't working. It's because you accidentally turned it off.
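One way to rule out both mistakes is to never flip modes by hand around validation. The evaluating context manager below is a hypothetical helper, not a PyTorch API: it switches to eval mode and disables gradient tracking, then restores whatever mode the model was in before, even if validation raises an exception.

```python
import contextlib

import torch
import torch.nn as nn

@contextlib.contextmanager
def evaluating(model):
    """Put model in eval mode with gradients off, then restore its previous mode."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        model.train(was_training)   # restore exactly what the caller had

# Usage with a stand-in model that has dropout
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.train()
x = torch.randn(2, 4)
with evaluating(model):
    out1, out2 = model(x), model(x)   # dropout is off, so the outputs match
print(torch.equal(out1, out2), model.training)   # True True
```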

The Double Softmax Trap (Revisited)

We covered this in the debugging flow, but it's worth repeating as a silent bug because of how common it is. If your model ends with nn.Softmax(dim=1) and you use nn.CrossEntropyLoss, you're applying softmax twice. The model trains. It learns. It produces non-random predictions. But the accuracy plateau is significantly lower than it should be (maybe 82% instead of 91%). Because CrossEntropyLoss applies its own log-softmax to values that are already squeezed into [0, 1], the loss can never get close to zero, and the gradient flowing back through the extra softmax shrinks as the model becomes confident, which makes it hard to push predictions toward confident and correct.
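The damage is easy to quantify on a two-class toy example: feed a very confident, correct logit vector to nn.CrossEntropyLoss directly, then again after an extra softmax. The logit values are made up for the demonstration.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
labels = torch.tensor([0])

# A model that is extremely confident and correct (2 classes)
logits = torch.tensor([[20.0, -20.0]])

correct_loss = criterion(logits, labels)                                 # raw logits in
double_softmax_loss = criterion(torch.softmax(logits, dim=1), labels)    # probabilities in

print(f"{correct_loss.item():.4f}")          # 0.0000
print(f"{double_softmax_loss.item():.4f}")   # 0.3133
# With inputs confined to [0, 1], the best achievable 2-class loss is
# log(1 + e^-1), no matter how confident the model is.
print(f"{math.log(1 + math.exp(-1)):.4f}")   # 0.3133
```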

Gradient Accumulation + Scheduler Misalignment

Gradient accumulation is a technique for simulating larger batch sizes when you can't fit a full large batch in GPU memory. You accumulate gradients over N mini-batches, then do a single optimizer step. The effective batch size is N × mini_batch_size. The bug arises when your learning rate scheduler steps every mini-batch instead of every optimizer step.

# BUG: Scheduler steps every mini-batch, not every optimizer step
accumulation_steps = 4
for step, (inputs, labels) in enumerate(train_loader):
    logits = model(inputs)
    loss = criterion(logits, labels) / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

    scheduler.step()  # BUG: steps 4x too often!

# FIX: Step the scheduler only on optimizer steps
accumulation_steps = 4
for step, (inputs, labels) in enumerate(train_loader):
    logits = model(inputs)
    loss = criterion(logits, labels) / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # Only step when the optimizer steps

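If you're using a cosine schedule sized for 1000 optimizer steps, but the scheduler steps 4000 times, it finishes its entire decay in the first quarter of training. What happens over the remaining 75% depends on the scheduler: one that clamps at its floor leaves you at the minimum LR for the rest of the run, while PyTorch's CosineAnnealingLR keeps following the cosine and swings back up to the initial LR. Either way, your effective learning rate schedule is completely wrong.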
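What the LR actually does past the intended horizon depends on the scheduler. With PyTorch's CosineAnnealingLR, for example, stepping past T_max does not hold the LR at eta_min; the cosine keeps going and the LR climbs back up. The sketch below drives the scheduler with a dummy parameter for 4x its T_max.

```python
import torch

# Dummy parameter and optimizer, just to drive the scheduler
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=0.0)

lrs = []
for _ in range(4000):      # 4x more scheduler steps than T_max
    optimizer.step()       # param has no grad, so this leaves it unchanged
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(f"after 1000 steps: {lrs[999]:.4f}")    # 0.0000  (end of the intended schedule)
print(f"after 2000 steps: {lrs[1999]:.4f}")   # 0.1000  (back at the initial LR)
print(f"after 3000 steps: {lrs[2999]:.4f}")   # 0.0000
```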

Wrong Dtype: Integer Inputs to Floating-Point Operations

If an integer tensor ends up where a float is expected, PyTorch often handles it silently (elementwise arithmetic promotes integers to float), but gradients don't flow through integer values. For our sentiment classifier, the input token IDs should be torch.long (for the embedding lookup), but if you accidentally cast the embedding output or some intermediate value to int, gradients die silently.

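import torch

# Subtle dtype bug example
x = torch.tensor([1.5, 2.7, 3.2], requires_grad=True)
y = x.int()          # Casts to integer: y has no grad_fn, NO GRADIENTS through this!
z = y.float() * 2.0  # Back to float, but the grad chain is already broken
print(z.requires_grad)  # False

# In a real model the severed branch usually sits next to a healthy one,
# so backward() runs without complaint and the cast branch contributes nothing.
out = x * 1.0 + z
out.sum().backward()
print(x.grad)  # tensor([1., 1., 1.]): only the x * 1.0 path reached x; z's factor of 2 is lost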

Augmentation on Validation Data

Random data augmentation (flips, crops, color jitter, etc.) should be applied ONLY to training data. If you apply random augmentation to validation data, your validation metrics will have unnecessary variance — sometimes a validation sample gets an easy augmentation and is classified correctly, sometimes a hard augmentation makes it fail. The metric bounces around between epochs, making it impossible to tell if the model is genuinely improving. Use deterministic preprocessing (center crop, fixed resize) for validation.

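from torchvision import transforms

# Example channel statistics (the widely used ImageNet values); for your own
# data, compute them on the training split only
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Correct: separate transforms for train and val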
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224),
    transforms.ColorJitter(brightness=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

val_transform = transforms.Compose([
    transforms.CenterCrop(224),       # Deterministic — no randomness
    transforms.ToTensor(),
    transforms.Normalize(mean, std),  # Same normalization as training
])

Preprocessing Leakage

Computing normalization statistics (mean, standard deviation) on the entire dataset, including the validation and test sets, is a subtle form of data leakage. The model's preprocessing step has "seen" the validation data distribution. In practice, the effect is often tiny — the mean and std don't change much when you add 10-20% more data. But in competitive settings or when your dataset is small, it can inflate your metrics by a measurable amount. The correct approach: compute mean and std on the training set only, apply those same values to normalize both training and validation data.
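The fix is a two-line change: compute the statistics from the training split and reuse them everywhere. The random feature matrices below are placeholders for real data.

```python
import torch

# Placeholder feature matrices: rows are examples, columns are features
torch.manual_seed(0)
train_x = torch.randn(800, 5) * 3.0 + 1.0
val_x = torch.randn(200, 5) * 3.0 + 1.0

# Statistics come from the training split only...
mean = train_x.mean(dim=0)
std = train_x.std(dim=0)

# ...and the same values are applied to every split
train_norm = (train_x - mean) / std
val_norm = (val_x - mean) / std

print(train_norm.mean(dim=0).abs().max().item())  # ~0: centered by construction
print(val_norm.mean(dim=0))   # small but nonzero: val never influenced the stats
```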

Frozen Pretrained Layers You Forgot About

When fine-tuning a pretrained model (say, BERT for our sentiment classifier), it's common to start by freezing the pretrained layers and training only the classification head. The bug is forgetting to unfreeze them when you're ready for full fine-tuning.

# Phase 1: Freeze BERT, train only the classifier head
for param in model.bert.parameters():
    param.requires_grad = False

# ... train for a few epochs on just the head ...

# Phase 2: Full fine-tuning — BUG if you forget this!
for param in model.bert.parameters():
    param.requires_grad = True  # Unfreeze BERT layers

# Also: you probably want a smaller learning rate for fine-tuning
optimizer = torch.optim.AdamW([
    {'params': model.bert.parameters(), 'lr': 2e-5},       # Small LR for pretrained
    {'params': model.classifier.parameters(), 'lr': 1e-3},  # Larger LR for head
])

If you forget to unfreeze, the model achieves whatever accuracy the randomly initialized head can manage with fixed BERT features — maybe 85% instead of 93%. You might think BERT isn't helping and give up. But the features are fine; you've accidentally frozen the fine-tuning.

🔥 Hard-Won Lesson

Run a parameter audit before training starts. Print every parameter name, its shape, and whether requires_grad is True. It takes only a few lines of code and catches frozen and missing parameters in one shot.

# Parameter audit — run this before every training job
print("=" * 80)
print(f"{'Parameter':<50} {'Shape':<20} {'Grad':>5} {'#Params':>10}")
print("=" * 80)
total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
    n = param.numel()
    total_params += n
    if param.requires_grad:
        trainable_params += n
    print(f"{name:<50} {str(list(param.shape)):<20} {str(param.requires_grad):>5} {n:>10,}")
print("=" * 80)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
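The audit above checks requires_grad but not whether each trainable parameter actually made it into an optimizer group. A parameter that's missing from the optimizer receives gradients but is never updated, which is exactly what happens if you unfreeze BERT but reuse an optimizer that was built with only the head's parameters. A sketch of that cross-check (audit_optimizer_coverage is a hypothetical helper; the two-layer model is a stand-in):

```python
import torch
import torch.nn as nn

def audit_optimizer_coverage(model, optimizer):
    """Cross-check requires_grad flags against the optimizer's param groups.

    Returns (missing, frozen_in_optimizer): trainable parameters the optimizer
    will never update, and frozen parameters it nonetheless holds.
    """
    group_of = {}
    for group_idx, group in enumerate(optimizer.param_groups):
        for p in group["params"]:
            group_of[id(p)] = group_idx

    missing, frozen_in_optimizer = [], []
    for name, p in model.named_parameters():
        if p.requires_grad and id(p) not in group_of:
            missing.append(name)
        if not p.requires_grad and id(p) in group_of:
            frozen_in_optimizer.append(name)
    return missing, frozen_in_optimizer

# Usage: a stand-in "encoder + head" where the optimizer only got the head
model = nn.ModuleDict({"encoder": nn.Linear(8, 8), "head": nn.Linear(8, 2)})
optimizer = torch.optim.AdamW(model["head"].parameters(), lr=1e-3)
missing, frozen = audit_optimizer_coverage(model, optimizer)
print(missing)   # ['encoder.weight', 'encoder.bias']
print(frozen)    # []
```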

Gradient Clipping: The Safety Net

Gradient clipping is one of those techniques that you almost always want on, and the question is not "should I clip?" but "what threshold?" The idea is straightforward: if a batch produces an anomalously large gradient (because of an outlier input, a numerical edge case, or a coincidentally nasty combination of samples), we cap the gradient magnitude before applying it. The signal is preserved (the direction is the same), but the step size is limited.

How clip_grad_norm_ Works

PyTorch's torch.nn.utils.clip_grad_norm_ computes the global gradient norm — the L2 norm of the gradient vector formed by concatenating all parameter gradients into a single vector. If this norm exceeds the specified max_norm, every gradient is scaled by the same factor to bring the total norm down to max_norm. The formula:

Given the total gradient norm ||g|| and the clipping threshold max_norm:

If ||g|| > max_norm: scale every gradient by max_norm / ||g||

If ||g|| <= max_norm: do nothing

This is a global operation — it considers all parameters together, not each parameter independently. This is important because it preserves the relative magnitudes of gradients across layers. If layer A has a gradient norm of 5 and layer B has a gradient norm of 3, and the total norm is 5.83, and you clip to 1.0, both are scaled by 1.0/5.83 ≈ 0.17. Layer A's gradient is now 0.86 and layer B's is 0.51 — the same ratio as before. Per-parameter clipping would distort these ratios.
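The arithmetic above can be checked directly with clip_grad_norm_. The two "layers" here are bare parameters whose gradients are set by hand to have norms 5 and 3.

```python
import torch

# Two stand-in "layers" whose gradients have norms 5 and 3
layer_a = torch.nn.Parameter(torch.zeros(2))
layer_b = torch.nn.Parameter(torch.zeros(1))
layer_a.grad = torch.tensor([3.0, 4.0])   # ||g_A|| = 5
layer_b.grad = torch.tensor([3.0])        # ||g_B|| = 3

total = torch.nn.utils.clip_grad_norm_([layer_a, layer_b], max_norm=1.0)

print(f"{total.item():.2f}")                 # 5.83  (pre-clip global norm, sqrt(34))
print(f"{layer_a.grad.norm().item():.2f}")   # 0.86
print(f"{layer_b.grad.norm().item():.2f}")   # 0.51  (ratio to layer A unchanged: 3/5)
```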

# The standard gradient clipping pattern
optimizer.zero_grad()
logits = model(inputs)
loss = criterion(logits, labels)
loss.backward()

# AFTER backward, BEFORE step — order matters!
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# grad_norm is the ORIGINAL norm before clipping — useful for logging

optimizer.step()

# Log the pre-clip gradient norm to monitor training health
writer.add_scalar('train/grad_norm_preclip', grad_norm.item(), global_step)

Choosing the Clipping Threshold

There's no universal best value, but there are good defaults and a principled way to tune. The HuggingFace Transformers Trainer defaults to max_grad_norm=1.0, and this works well for most transformer architectures, BERT-style models included. For RNNs and LSTMs, values between 0.5 and 5.0 are typical. For CNNs with batch normalization, you might not need clipping at all, because BatchNorm tends to keep gradient magnitudes in a well-behaved range.

A practical approach: start without clipping, log the gradient norms for the first hundred steps, look at the distribution, and set the clipping threshold at the 90th or 95th percentile. This means you're only clipping the truly anomalous batches, not interfering with normal training dynamics.
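The percentile recipe can be sketched in a few lines. `suggest_clip_threshold` is a hypothetical helper, and the "logged" norms below are synthetic stand-ins, not numbers from a real run. In practice you would feed in the pre-clip norms you recorded during the unclipped warm-up steps.

```python
import torch

def suggest_clip_threshold(grad_norms, percentile=0.95):
    """Hypothetical helper: pick max_norm at a percentile of logged pre-clip norms.

    grad_norms: list of floats recorded while training WITHOUT clipping
    (e.g. the first ~100 steps). Returns the chosen threshold as a float.
    """
    norms = torch.tensor(grad_norms, dtype=torch.float64)
    return torch.quantile(norms, percentile).item()

# Synthetic stand-in for logged norms: mostly between 1 and 2, plus a few spikes
logged = [1.0 + 0.01 * i for i in range(95)] + [8.0, 12.0, 15.0, 20.0, 40.0]

threshold = suggest_clip_threshold(logged, percentile=0.95)
clip_fraction = sum(n > threshold for n in logged) / len(logged)
print(f"suggested max_norm: {threshold:.2f}, would clip {clip_fraction:.0%} of steps")
```

Setting the threshold at the 95th percentile means only the spike batches get clipped, and the bulk of normal steps pass through untouched.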

Architecture          | Typical max_norm | Notes
----------------------|------------------|-------------------------------------
BERT / Transformers   | 1.0              | HuggingFace default
GPT / Large LMs       | 1.0              | Critical for training stability
LSTM / GRU            | 1.0 – 5.0        | RNNs are prone to gradient spikes
ResNet / CNN          | 5.0 or none      | BatchNorm provides natural stability
GAN discriminator     | 0.5 – 1.0        | Very unstable training dynamics

Monitoring the Clip Rate

If gradients are being clipped on every single step, that's a red flag. It means the gradient magnitudes are consistently above the threshold, so clipping is constantly overriding the optimizer's step size. This usually means the learning rate is too high, the batch size is too small (noisy gradients), or the architecture has a structural issue causing amplification. If gradients are never clipped, clipping is inactive: either the threshold sits above every norm the model produces, or training is stable enough not to need it. A healthy clip rate is somewhere in between: most steps proceed normally, but the occasional outlier batch gets tamed.

# Track clip statistics over training
max_norm = 1.0
clip_count = 0
total_steps = 0

for inputs, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()

    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    total_steps += 1
    if grad_norm.item() > max_norm:
        clip_count += 1

    optimizer.step()

    if total_steps % 100 == 0:
        clip_rate = clip_count / total_steps
        print(f"Clip rate: {clip_rate:.1%} over {total_steps} steps")
        # Healthy: 1-10%. Concerning: >50%. Investigate: 100%.
⚠️ Warning

There's also clip_grad_value_, which clips each gradient element independently to a range [-value, value]. This is different from norm clipping and generally less useful because it distorts the gradient direction. Norm clipping preserves direction and only scales magnitude, which is almost always what you want. Use clip_grad_norm_ unless you have a specific reason not to.
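The direction distortion is easy to demonstrate. This sketch gives one hypothetical parameter a gradient dominated by its first component. It clips one copy by norm and another by value, then compares each result to the original direction using cosine similarity.

```python
import torch
import torch.nn.functional as F

# A gradient that points mostly along the first axis
grad = torch.tensor([10.0, 1.0, 0.5])

p_norm = torch.nn.Parameter(torch.zeros(3))
p_value = torch.nn.Parameter(torch.zeros(3))
p_norm.grad = grad.clone()
p_value.grad = grad.clone()

# Norm clipping: rescale the whole vector to length 1.0
torch.nn.utils.clip_grad_norm_([p_norm], max_norm=1.0)
# Value clipping: clamp each element to [-1.0, 1.0] independently
torch.nn.utils.clip_grad_value_([p_value], clip_value=1.0)

cos_norm = F.cosine_similarity(p_norm.grad, grad, dim=0).item()
cos_value = F.cosine_similarity(p_value.grad, grad, dim=0).item()
print(p_value.grad)  # tensor([1.0000, 1.0000, 0.5000])
print(f"cosine to original: norm-clip {cos_norm:.3f}, value-clip {cos_value:.3f}")
```

Norm clipping keeps a cosine of 1.0, so the direction is unchanged. Value clipping flattens the dominant component down to the level of the small ones. The clipped gradient now points somewhere quite different, with a cosine of roughly 0.75.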

Weight Initialization: Where It All Begins

The initial values of your network's weights determine the starting point of optimization. Get it badly wrong, and the model can't learn at all — activations saturate, gradients vanish, and the loss sits at its initial value forever. Get it right, and gradient flow is healthy from the very first step. The theory behind initialization is one of the more elegant pieces of deep learning mathematics.

The Core Problem

Consider a layer y = Wx where x is the input with n elements and W is a weight matrix. Each element of y is a sum of n terms: y_j = sum(W_ji * x_i). Suppose each weight W_ji has variance Var(w) and each input x_i has variance Var(x). If both are zero-mean and all are independent, then the variance of y_j is:

Var(y) = n × Var(w) × Var(x)

If Var(w) is too large, Var(y) will be much larger than Var(x): each layer amplifies the signal. After 10 layers, the activation variance has been multiplied by (n × Var(w))^10. If n × Var(w) = 2, that's 2^10 = 1024× amplification, and activations explode. If Var(w) is too small, each layer attenuates the signal, and after 10 layers, activations (and their gradients) are negligibly small.

The goal is to choose Var(w) so that n × Var(w) ≈ 1, meaning Var(w) = 1/n. This keeps the variance of activations roughly constant across layers. That's the core idea. The details depend on the activation function.
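A quick simulation makes the (n × Var(w))^10 argument concrete. This sketch matches the derivation above: zero-mean Gaussian weights, unit-variance inputs, and no activation function. `activation_variance_after` is a hypothetical helper. The empirical variances only approximately match the theory, because each run draws a finite random network.

```python
import torch

def activation_variance_after(depth, n, gain, batch=1024, seed=0):
    """Push unit-variance inputs through `depth` linear layers (no activation).

    Weights are drawn from N(0, gain / n), so n * Var(w) = gain.
    Returns the empirical variance of the final activations.
    """
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(batch, n, generator=g)
    for _ in range(depth):
        w = torch.randn(n, n, generator=g) * (gain / n) ** 0.5
        x = x @ w.T
    return x.var().item()

n = 256
for gain in (0.5, 1.0, 2.0):
    v = activation_variance_after(depth=10, n=n, gain=gain)
    print(f"n*Var(w) = {gain}: variance after 10 layers ≈ {v:.4g} (theory: {gain ** 10:.4g})")
```

Only n × Var(w) = 1 keeps the variance near its starting value. With 2, the variance lands in the neighborhood of 1024. With 0.5, it collapses toward 0.001.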

Xavier / Glorot Initialization (2010)

Xavier Glorot and Yoshua Bengio proposed initializing weights from a distribution with variance 2 / (fan_in + fan_out), where fan_in is the number of input neurons and fan_out is the number of output neurons. Keeping forward activations stable calls for Var(w) = 1/fan_in. Keeping backward gradients stable calls for Var(w) = 1/fan_out. Averaging fan_in and fan_out is a compromise between the two, and it satisfies both exactly when fan_in = fan_out. This derivation assumes linear activations or tanh (which is approximately linear near zero).

# Xavier/Glorot initialization: designed for tanh or sigmoid
import torch.nn as nn
import torch.nn.init as init

linear = nn.Linear(256, 128)
init.xavier_uniform_(linear.weight)   # Uniform distribution variant
# or
init.xavier_normal_(linear.weight)    # Normal distribution variant

# The uniform variant samples from:
# U(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out)))
# For fan_in=256, fan_out=128: sqrt(6 / 384) = 0.125, so U(-0.125, 0.125)

# The normal variant samples from:
# N(0, 2 / (fan_in + fan_out))   (second argument is the variance)
# For fan_in=256, fan_out=128: N(0, 0.0052) i.e., std ≈ 0.072

Kaiming / He Initialization (2015)

Xavier initialization assumes the activation function is approximately linear around zero. ReLU breaks that assumption: it sets all negative activations to zero, effectively killing half the signal. If you initialize with Xavier and use ReLU, the variance of activations halves at each layer, because on average half the neurons output zero. After 10 layers, the activation variance is attenuated by 2^-10 ≈ 0.001×. Gradients suffer the same fate.

Kaiming He et al. (2015) fixed this by doubling the variance to compensate for ReLU zeroing half its inputs: Var(w) = 2 / fan_in. The factor of 2 in the numerator accounts for the half of the pre-activations that ReLU sets to zero. For Leaky ReLU with negative slope a, the formula generalizes to Var(w) = 2 / ((1 + a²) × fan_in).

# Kaiming/He initialization: designed for ReLU
import torch.nn as nn
import torch.nn.init as init

linear = nn.Linear(256, 128)
init.kaiming_uniform_(linear.weight, nonlinearity='relu')
# or
init.kaiming_normal_(linear.weight, nonlinearity='relu')

# For Leaky ReLU with negative slope 0.01:
init.kaiming_normal_(linear.weight, a=0.01, nonlinearity='leaky_relu')

# Note: nn.Linear's default init calls kaiming_uniform_ with a=sqrt(5), which gives
# U(-1/sqrt(fan_in), 1/sqrt(fan_in)), i.e. variance 1/(3 * fan_in), not He's 2/fan_in.
# For deep ReLU networks without normalization, set He init explicitly as above.

PyTorch Defaults

PyTorch's default initialization for nn.Linear calls kaiming_uniform_ with a=sqrt(5). That works out to U(-1/sqrt(fan_in), 1/sqrt(fan_in)), a variance of 1/(3 × fan_in), which is six times smaller than He's 2/fan_in. nn.Conv2d uses the same scheme. nn.Embedding initializes from N(0, 1). nn.LSTM initializes all weights from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)). In practice, PyTorch defaults work well for most architectures, especially shallow ones or those with normalization layers. You rarely need to set initialization manually unless you're seeing specific symptoms (vanishing activations in deep networks, or training instability in the first few steps).
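It is worth checking what "Kaiming Uniform" means for nn.Linear. PyTorch calls kaiming_uniform_ with a=sqrt(5), which works out to U(-1/sqrt(fan_in), 1/sqrt(fan_in)) rather than the He variance of 2/fan_in. The check below inspects a freshly constructed layer. It assumes a recent PyTorch release where nn.Linear.reset_parameters still uses that scheme.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in = 1024
layer = nn.Linear(fan_in, 512)

observed_std = layer.weight.std().item()
default_bound = 1 / math.sqrt(fan_in)        # default: U(-1/sqrt(fan_in), 1/sqrt(fan_in))
default_std = default_bound / math.sqrt(3)   # std of U(-b, b) is b / sqrt(3)
he_std = math.sqrt(2 / fan_in)               # Kaiming/He std for ReLU

print(f"observed std:               {observed_std:.4f}")
print(f"default (1/sqrt(3*fan_in)): {default_std:.4f}")
print(f"He std (sqrt(2/fan_in)):    {he_std:.4f}")
print(f"max |w| = {layer.weight.abs().max().item():.4f} <= bound {default_bound:.4f}")
```

The observed standard deviation matches 1/sqrt(3 × fan_in), and the He value is about sqrt(6) ≈ 2.45 times larger. That gap is why explicit Kaiming init can matter for deep ReLU stacks without normalization.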

BatchNorm Makes Initialization Matter Less

One of the underappreciated benefits of batch normalization is that it makes weight initialization less critical. During training, BatchNorm normalizes each layer's activations using the current batch's mean and variance, then applies a learnable scale and shift (initialized to 1 and 0). So even if your weights are poorly initialized and produce activations with wildly varying scales, BatchNorm resets the scale at every layer, starting from the very first forward pass. Initialization still matters for layers that are not followed by normalization (such as the final classifier), but its overall impact is greatly diminished.
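A small experiment shows the rescaling at work. This sketch deliberately multiplies a linear layer's weights by 0.1, 1, and 100, then passes the output through a fresh nn.BatchNorm1d. The BatchNorm layer is in training mode by default, so it normalizes with batch statistics. The layer sizes and scales are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(512, 64)

results = {}
for weight_scale in (0.1, 1.0, 100.0):
    linear = nn.Linear(64, 64, bias=False)
    with torch.no_grad():
        linear.weight.mul_(weight_scale)   # deliberately mis-scale the weights
    bn = nn.BatchNorm1d(64)                # train mode by default: uses batch statistics
    pre = linear(x)
    post = bn(pre)
    results[weight_scale] = (pre.std().item(), post.std().item())
    print(f"weight scale {weight_scale:>6}: pre-BN std {results[weight_scale][0]:10.4f}, "
          f"post-BN std {results[weight_scale][1]:.4f}")
```

The pre-BN scale spans three orders of magnitude, yet the post-BN standard deviation stays at about 1 in every case.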

This is one reason modern architectures with BatchNorm (like ResNets) are so easy to train — they're robust to initialization choices that would be fatal for unnormalized networks.

Transformer Initialization

Transformers use a different initialization convention. The standard approach, popularized by the original GPT and BERT implementations, initializes all weights from a normal distribution with mean 0 and standard deviation 0.02. This is well below the Kaiming-derived sqrt(2/768) ≈ 0.051 for a 768-wide layer, and the small standard deviation keeps activations from growing too large early in training. Some implementations follow the GPT-2 paper's suggestion and scale the initialization of residual path weights by 1/sqrt(N), where N is the number of residual layers. Because each of the N residual branches adds its output to the same stream, this keeps the variance of the residual stream from growing with depth at initialization.

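# Transformer-style initialization
import torch
import torch.nn as nn

def init_transformer_weights(module):
    """Initialize weights following GPT-2 / BERT convention (std 0.02)."""
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.padding_idx is not None:
            # Keep the padding embedding at zero, as nn.Embedding does by default
            with torch.no_grad():
                module.weight[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        if module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

# `model` is your transformer (any nn.Module); apply() visits every submodule
model.apply(init_transformer_weights)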
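The function above applies the flat 0.02 convention but not the GPT-2 residual scaling. One way to add it is sketched below, under several assumptions. The block structure (`TinyBlock`), the parameter names `attn_out` and `mlp_out`, and the helper `init_with_residual_scaling` are all hypothetical. Following the GPT-2 paper's definition, N counts residual layers, which here is two per block (attention plus MLP). Real codebases identify their residual projections by their own naming conventions.

```python
import math
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Hypothetical pre-LN block with two residual branches (attention stand-in + MLP)."""
    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn_out = nn.Linear(d_model, d_model)   # stand-in for the attention output projection
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp_in = nn.Linear(d_model, 4 * d_model)
        self.mlp_out = nn.Linear(4 * d_model, d_model)

    def forward(self, x):
        x = x + self.attn_out(self.ln1(x))
        x = x + self.mlp_out(torch.relu(self.mlp_in(self.ln2(x))))
        return x

def init_with_residual_scaling(model, n_layers, std=0.02):
    """Normal(0, std) for all Linear weights; residual output projections get std / sqrt(N)."""
    n_residual = 2 * n_layers   # N = number of residual layers (two per block here)
    for name, p in model.named_parameters():
        if name.endswith("weight") and p.dim() == 2:
            scaled = name.endswith("attn_out.weight") or name.endswith("mlp_out.weight")
            p_std = (std / math.sqrt(n_residual)) if scaled else std
            nn.init.normal_(p, mean=0.0, std=p_std)
        elif name.endswith("bias"):
            nn.init.zeros_(p)
        # 1-D LayerNorm weights are left at their default value of 1

torch.manual_seed(0)
n_layers, d_model = 12, 64
model = nn.Sequential(*[TinyBlock(d_model) for _ in range(n_layers)])
init_with_residual_scaling(model, n_layers)
out = model(torch.randn(2, 5, d_model))
print(model[0].mlp_out.weight.std().item())  # close to 0.02 / sqrt(24) ≈ 0.0041
```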
💡 Key Insight

If your model uses ReLU and no BatchNorm, Kaiming initialization is essential. If your model uses BatchNorm, initialization matters less. If your model is a transformer, use the small-std convention (0.02) and consider scaling residual paths. If you're not sure, the PyTorch defaults are reasonable — but check by printing the first-batch activation statistics. If any layer's output has a mean far from zero or a variance far from one, initialization is likely the problem.
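One way to print those first-batch statistics is with forward hooks. The helper below is a hypothetical sketch. It records the mean and standard deviation of every leaf module's output for a single batch, then removes the hooks. The model runs in whatever train/eval mode it is currently in, so dropout and BatchNorm behave accordingly. ReLU outputs are non-negative, so their mean naturally sits above zero.

```python
import torch
import torch.nn as nn

def first_batch_activation_stats(model, batch):
    """Run one forward pass and record (mean, std) of every leaf module's output."""
    stats = {}
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                out = output.detach().float()
                stats[name] = (out.mean().item(), out.std().item())
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:   # leaf modules only
            hooks.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(batch)
    finally:
        for h in hooks:
            h.remove()
    return stats

# Usage: a small hypothetical 8-layer ReLU MLP with PyTorch's default init, no normalization
torch.manual_seed(0)
layers = []
for _ in range(8):
    layers += [nn.Linear(256, 256), nn.ReLU()]
mlp = nn.Sequential(*layers)

stats = first_batch_activation_stats(mlp, torch.randn(512, 256))
for name, (mean, std) in stats.items():
    print(f"{name:>4}: mean {mean:+.4f}  std {std:.4f}")
```

In this toy MLP the standard deviation shrinks layer by layer, which is the vanishing-activation symptom described above. Switching the Linear layers to Kaiming init is the natural next experiment.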

Putting It All Together: A Debugging Checklist

Let's crystallize everything into a checklist. When your sentiment classifier (or any model) isn't training well, work through this list in order. Each step either fixes the problem or eliminates a category of bugs.

Before Training Starts

1. Data sanity check. Print a few samples. Decode token IDs back to text. Verify labels match the content. Check shapes, dtypes, and value ranges. This takes 2 minutes and catches the majority of bugs.

2. Parameter audit. Print all parameter names, shapes, requires_grad status, and total counts. Verify nothing is accidentally frozen and the optimizer is connected to all trainable parameters.

3. Initial loss check. Forward one batch and compute the loss. Compare it to the expected value for random guessing, -log(1/C) = log(C) for C classes (0.693 for our two-class sentiment task). If it's wildly different, investigate before training.

4. Single-batch overfit test. Train on one fixed batch for 500 steps. Loss must go to near-zero. If it doesn't, you have a fundamental bug — fix it before proceeding.

During Training

5. Monitor loss per-batch, not per-epoch. Look for spikes, plateaus, and divergence.

6. Monitor gradient norms. Global norm per step, per-layer norms periodically. Alert if norm exceeds a threshold.

7. Gradient clipping enabled. max_norm=1.0 is a safe default. Log the clip rate.

8. model.eval() for validation, model.train() for training. Every time. No exceptions.

9. loss.item() for logging. Never store the loss tensor itself in a list.

10. Separate transforms for train and val. Random augmentation on training only.

When Things Go Wrong

11. NaN loss? Check for numerical instabilities: log(0), sqrt(0) (its gradient is infinite), exp(large). Add epsilon guards. If using mixed precision, verify GradScaler is configured. Wrap the forward and backward pass in torch.autograd.detect_anomaly() to find the exact operation that produced the NaN.

12. Loss plateaus? Run the learning rate finder. Check that the optimizer is stepping (scheduler might have reduced LR to zero). Verify gradients are non-zero on all layers.

13. Val loss diverges while train loss drops? Overfitting. Add dropout, weight decay, data augmentation. Or get more data.

14. Both losses high? Underfitting. Increase model capacity, increase learning rate, train longer.

15. Results differ between runs? Seed all RNGs, set deterministic flags, use worker_init_fn. Accept that cross-machine reproducibility is approximate.

# The complete, battle-tested training loop for our sentiment classifier
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler
import random
import numpy as np

# ---- REPRODUCIBILITY ----
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

# ---- MODEL ----
# SentimentModel, train_loader, val_loader and num_epochs are assumed to be defined elsewhere
model = SentimentModel(vocab_size=10000).cuda()

# ---- PARAMETER AUDIT ----
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {total:,} total, {trainable:,} trainable")

# ---- OPTIMIZER AND SCHEDULER ----
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

# ---- INITIAL LOSS CHECK ----
model.eval()
with torch.no_grad():
    sample_batch = next(iter(train_loader))
    init_loss = criterion(model(sample_batch[0].cuda()), sample_batch[1].cuda())
    expected = torch.log(torch.tensor(2.0))  # 2-class: 0.693
    print(f"Initial loss: {init_loss.item():.4f} (expected: {expected.item():.4f})")
model.train()

# ---- TRAINING LOOP ----
best_val_loss = float('inf')
patience_counter = 0
stop_training = False

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    clip_count = 0

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            logits = model(inputs)
            loss = criterion(logits, labels)

        # Check BEFORE backward, so a non-finite loss never produces an update
        if not torch.isfinite(loss):
            print(f"Non-finite loss at epoch {epoch}, step {step}. Stopping.")
            stop_training = True
            break

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip the real (unscaled) gradients

        # Gradient clipping: log the pre-clip norm
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        if grad_norm.item() > 1.0:
            clip_count += 1

        scaler.step(optimizer)  # skips the step if gradients contain inf/NaN
        scaler.update()

        # Safe logging: .item() returns a Python float, so no graph is kept alive
        epoch_loss += loss.item()

    if stop_training:
        break  # exit the epoch loop too, skipping validation on a broken model

    scheduler.step()

    # ---- VALIDATION (eval mode!) ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            logits = model(inputs)
            val_loss += criterion(logits, labels).item() * labels.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
            total_samples += labels.size(0)

    val_loss /= total_samples
    val_acc = correct / total_samples
    avg_train_loss = epoch_loss / len(train_loader)

    print(f"Epoch {epoch}: train_loss={avg_train_loss:.4f}, "
          f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f}, "
          f"clip_rate={clip_count/len(train_loader):.1%}, "
          f"lr={optimizer.param_groups[0]['lr']:.2e}")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= 5:
            print("Early stopping triggered.")
            break

That's the training loop I use for every new project, sentiment classifier or otherwise. It's not the shortest possible loop — it has monitoring, clipping, mixed precision, reproducibility, and early stopping built in. But every one of those features has saved me hours of debugging at some point. The 20 extra lines of code are an investment that pays for itself on the first training run that goes wrong.

💡 Key Insight

Debugging deep learning training isn't about knowing every possible thing that can go wrong. It's about having a systematic flow — data, loss, gradients, architecture — and working through it methodically. The bugs described here cover 95% of training failures I've encountered in practice. The remaining 5% are genuinely novel, and by the time you've ruled out the common issues, you'll have enough information to diagnose the novel ones.