Nice to Know

Chapter 8: Training Deep Networks
TL;DR

Weight averaging smooths noisy training into better models. Knowledge distillation compresses giants into pocket-sized students. Curriculum learning and progressive resizing teach models the way humans learn — easy things first. Gradient checkpointing and torch.compile let you train bigger, faster. Exotic optimizers like SAM and LAMB exist for when you're competing or scaling to thousands of GPUs. And experiment tracking (W&B) is the tool you'll wish you'd used from run one.

Weight Averaging & EMA

Training a neural network with SGD or Adam is noisy by design — each mini-batch nudges the weights in a slightly different direction. The final set of weights you end up with is wherever the optimizer happened to stop, which may not even be the best point it visited during training. Weight averaging techniques address this by saying: instead of trusting the last snapshot, trust the average of many snapshots.

Stochastic Weight Averaging (SWA)

The idea is almost too straightforward to feel like a research contribution. Train your model normally for most of the run. Then, in the final stretch — the last 20–25% of epochs — start saving the weights at the end of each epoch. When training finishes, average all those saved checkpoints into one set of weights. That averaged model is what you deploy.

Why does this work? The averaged weights tend to land in flatter regions of the loss landscape. A flat minimum is one where small perturbations to the weights don't spike the loss. Sharp minima, by contrast, can look great on the training set but tend to generalize worse, because a small shift between training and test data moves you up a steep wall. Averaging across many nearby points is like smearing the solution out across a wider basin, and wider basins tend to generalize better. Izmailov et al. (2018) reported gains of roughly 1–2 accuracy points on CIFAR-100, and smaller gains on CIFAR-10 and ImageNet, at almost no extra training cost.

PyTorch has it built in. You wrap your model in torch.optim.swa_utils.AveragedModel, use SWALR as a scheduler for the averaging phase, and call update_bn at the end to recalculate batch norm statistics for the new averaged weights. That last step is easy to forget, and skipping it quietly tanks your accuracy — the batch norm running stats correspond to the original weights, not the averaged ones. When you'll need it: squeezing the last bit of performance before deployment, especially in vision tasks. It's the cheapest upgrade you can make to a trained model.
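A minimal sketch of that recipe with the real torch.optim.swa_utils API, assuming a hypothetical toy binary-classification dataset and a hypothetical small MLP with a BatchNorm layer (chosen so the BatchNorm fix-up matters). The epoch counts, learning rates and swa_lr are illustrative, not recommendations:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: the label is 1 when the features sum to a positive number.
X = torch.randn(512, 10)
y = (X.sum(dim=1, keepdim=True) > 0).float()
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

# Hypothetical small MLP with a BatchNorm layer.
model = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss_fn = nn.BCEWithLogitsLoss()

swa_model = AveragedModel(model)                # keeps a running uniform average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)   # anneals the LR toward swa_lr while averaging
epochs, swa_start = 20, 15                      # average over the last 25% of epochs

for epoch in range(epochs):
    model.train()
    for xb, yb in loader:
        optimizer.zero_grad()
        loss_fn(model(xb), yb).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)      # fold this epoch's weights into the average
        swa_scheduler.step()

# By default AveragedModel averages parameters, not BatchNorm buffers, so the averaged
# copy's running stats are stale. Recompute them with one pass over the training data.
update_bn(loader, swa_model)

swa_model.eval()
with torch.no_grad():
    accuracy = ((swa_model(X) > 0).float() == y).float().mean().item()
print(f"SWA model accuracy on the training data: {accuracy:.3f}")
```

update_bn takes the first element of each (input, target) batch, which is why the same loader can be reused for the BatchNorm pass.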

Exponential Moving Average (EMA)

EMA takes a different approach to the same problem. Instead of averaging checkpoints at the end, you maintain a shadow copy of the model weights throughout training, updated at every step: EMA = α × EMA + (1 − α) × current_weights, where α is typically 0.999 or 0.9999. The shadow weights are a smoothed-out version of the training trajectory. You train with the real weights but evaluate and deploy with the shadow weights.
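The update rule is only a few lines. A framework-free sketch, assuming parameters are stored in a hypothetical dict mapping names to lists of floats (in PyTorch you would loop over the model's parameter tensors instead); the class name EMA is illustrative, not a library API:

```python
class EMA:
    """Shadow copy of the weights, updated as EMA = alpha * EMA + (1 - alpha) * current."""

    def __init__(self, params, alpha=0.999):
        self.alpha = alpha
        # Copy the values so later training steps don't mutate the shadow weights.
        self.shadow = {name: list(values) for name, values in params.items()}

    def update(self, params):
        a = self.alpha
        for name, values in params.items():
            shadow = self.shadow[name]
            for i, v in enumerate(values):
                shadow[i] = a * shadow[i] + (1.0 - a) * v


# Usage: the "real" weight jumps around 1.0; the shadow weight drifts slowly toward it.
params = {"w": [0.0]}
ema = EMA(params, alpha=0.9)
for noisy in [1.2, 0.8, 1.1, 0.9, 1.0]:
    params["w"][0] = noisy      # stand-in for an optimizer step
    ema.update(params)
print(params["w"], ema.shadow["w"])
```

Because the shadow starts from the initial weights, it lags early on: after five steps with alpha=0.9 it has only reached about 0.41. With alpha at 0.999 or 0.9999 that lag spans thousands of steps, which is why EMA pays off on long runs. Recent PyTorch releases can also maintain an EMA through torch.optim.swa_utils.AveragedModel combined with get_ema_multi_avg_fn.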

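Think of it as a running diary versus a final exam. The real weights are studying: jumping around, correcting mistakes, sometimes overcorrecting. The EMA weights are the diary that records the overall trend, filtering out the day-to-day noise. EMA is a close cousin of Polyak (Polyak–Ruppert) averaging, which takes a uniform average of the iterates; Polyak and Juditsky (1992) showed that this kind of averaging improves the convergence of stochastic optimization. EMA simply swaps the uniform weights for exponentially decaying ones.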

EMA is everywhere in modern deep learning. Diffusion models use it (the original DDPM paper evaluates with EMA weights at a decay of 0.9999). GANs use it (ProGAN and StyleGAN generate samples from an EMA of the generator's weights). Self-supervised methods like BYOL use a target network that is, at its core, an EMA of the online encoder. If you're training anything generative with a popular codebase, EMA is very likely already switched on, whether you know it or not.

The distinction from SWA matters: EMA is online (updated every step, decays old weights exponentially) while SWA is offline (average the last N checkpoints uniformly). EMA adapts more to recent training; SWA treats all saved checkpoints equally. In practice, EMA is used during training and SWA is bolted on at the end. When you'll need it: any production model, generative models, or whenever you want a more stable model without changing your training loop.

Training Strategies

Knowledge Distillation

You have a 10-billion-parameter model that's wonderfully accurate and completely impossible to deploy on a phone. Knowledge distillation is how you transfer what that model knows into something a hundred times smaller.

The key insight — and this is genuinely clever — is that a model's soft probability outputs contain far more information than the hard labels. When a teacher model sees a cat image and outputs cat=0.7, tiger=0.2, lynx=0.1, those secondary probabilities are gold. They encode relationships between classes that the hard label "cat" completely ignores. Hinton et al. (2015) called this "dark knowledge" — the information hiding in the wrong answers.

To amplify this dark knowledge, you use temperature scaling. Instead of the standard softmax, you divide logits by a temperature T before applying softmax: softmax(z_i / T). At T=1, you get normal probabilities. At T=3 or T=5, the distribution softens: the dominant class comes down, the minor classes come up, and those subtle inter-class relationships become louder. Both teacher and student use the same temperature, and the student learns to match the teacher's softened outputs.

The student's training loss is typically a weighted combination: α × KL-divergence between the teacher's and student's soft outputs, plus (1 − α) × standard cross-entropy with the hard ground-truth labels. Hinton et al. also multiply the soft term by T², because softening by T shrinks its gradients by roughly 1/T²; the rescaling keeps the two terms balanced when you change the temperature. The first term teaches the student how the teacher thinks. The second keeps it grounded in reality. This is how companies like Apple deploy models on-device, how DistilBERT got 97% of BERT's performance with 40% fewer parameters, and increasingly how smaller LLMs get trained using outputs from much larger ones. When you'll need it: any time you need a smaller, faster model that retains the quality of a larger one (edge deployment, mobile, latency-critical serving).
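A framework-free sketch of that loss for a single example, assuming raw logit lists and an integer class label. The helper names, the logits and the defaults T=4 and alpha=0.7 are illustrative assumptions; the T² factor is the rescaling from Hinton et al. (2015):

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax: softmax(z_i / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.7):
    # Soft term: match the teacher's softened distribution, rescaled by T^2.
    soft = kl_divergence(softmax(teacher_logits, T), softmax(student_logits, T)) * T * T
    # Hard term: ordinary cross-entropy with the ground-truth label at T = 1.
    hard = -math.log(softmax(student_logits)[label])
    return alpha * soft + (1.0 - alpha) * hard


teacher = [3.0, 1.9, 1.2]                               # hypothetical logits: cat, tiger, lynx
print([round(p, 3) for p in softmax(teacher, T=1.0)])   # peaked
print([round(p, 3) for p in softmax(teacher, T=4.0)])   # softened: minor classes grow
print(distillation_loss([2.0, 0.5, 0.1], teacher, label=0))
```

In a real training loop the same computation runs on batched tensors, and only the student's parameters receive gradients; the teacher is frozen.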

Curriculum Learning

Bengio et al. (2009) formalized something intuitive: humans learn better when you start with easy material and progress to hard material. What if neural networks do too?

The idea is to rank your training examples by difficulty and feed them in order — easy examples early in training, hard examples later. The model builds a solid foundation before tackling edge cases. In noisy datasets, this is particularly effective because the easy examples tend to be the clean ones, giving the model a stable starting signal before introducing the noisy, ambiguous samples that might confuse an untrained model.

The hard part — and the reason curriculum learning isn't universally adopted — is defining "difficulty." You need a scoring function, and there's no one-size-fits-all answer. Common approaches: use the loss value itself (examples with low loss on a pre-trained model are "easy"), use model confidence, or hand-craft heuristics (shorter sentences are easier, lower-resolution images are easier). Self-paced learning automates this by letting the model itself decide what's easy at any point during training, which is gaining traction in 2024 for LLMs and multi-modal tasks.

I'll be honest — in my experience, curriculum learning helps noticeably on messy, real-world datasets and sequence tasks, but on clean benchmarks like CIFAR-10, it barely moves the needle. If you're struggling with convergence on noisy data, it's worth trying. If your data is already clean and well-balanced, your time is better spent elsewhere. When you'll need it: noisy or imbalanced datasets, sequence modeling, or any task where you can clearly define what "easy" means.
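One common way to implement this is a pacing function: sort the examples by a difficulty score, then draw each batch from a pool that starts with the easiest fraction of the data and grows to the whole dataset. A minimal sketch, assuming a hypothetical difficulty function (sentence length in words) and a linear pacing schedule; curriculum_batches and start_frac are illustrative names, not a standard API:

```python
import math
import random


def curriculum_batches(examples, difficulty, num_steps, batch_size, start_frac=0.2, seed=0):
    """Yield batches from the easiest `start_frac` of the data, growing linearly to 100%."""
    if len(examples) < batch_size:
        raise ValueError("need at least batch_size examples")
    rng = random.Random(seed)
    ordered = sorted(examples, key=difficulty)      # easy -> hard
    n = len(ordered)
    for step in range(num_steps):
        # Linear pacing: the fraction of the dataset "unlocked" at this step.
        frac = start_frac + (1.0 - start_frac) * step / max(1, num_steps - 1)
        pool = ordered[: max(batch_size, math.ceil(frac * n))]
        yield rng.sample(pool, batch_size)


sentences = ["a cat", "the dog ran", "birds fly south in winter", "hi",
             "a long and winding sentence about curriculum learning", "ok then"]
for batch in curriculum_batches(sentences, difficulty=lambda s: len(s.split()),
                                num_steps=3, batch_size=2, start_frac=0.5):
    print(batch)
```

Swapping the difficulty function for per-example loss from a pre-trained model gives the loss-based variant described above.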

Progressive Resizing

This is curriculum learning applied to image resolution, and it's one of those techniques that makes you wonder why it isn't the default everywhere. Start training on small images — 64×64 or 128×128. Once the model has learned coarse features (shapes, colors, basic patterns), bump the resolution to 224×224 or higher for the final phase.

The economics are compelling. A 64×64 image has 16× fewer pixels than a 256×256 image, which means the forward pass is dramatically cheaper. You can fit larger batch sizes in the same VRAM, which stabilizes gradients. Early epochs are fast, so you iterate quickly. The model learns global structure first — the outline of a dog, the general color palette — and only later refines the fine-grained details like fur texture or whisker positions.

Jeremy Howard and the fast.ai team popularized this technique, and it's a staple in Kaggle winning solutions for image competitions. The technique also acts as a regularizer: small images physically can't overfit on high-frequency details because those details aren't there. When you'll need it: image classification and segmentation tasks, especially when VRAM is limited or you want faster iteration during early experiments.
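The schedule itself is simple bookkeeping. A sketch, assuming a hypothetical three-phase schedule (64, then 128, then 224 pixels) and a model that accepts any input size, which holds for most CNNs ending in global average pooling. In practice you would rebuild the data pipeline's resize step whenever the phase changes; the helper names are illustrative:

```python
def resolution_for_epoch(epoch, schedule):
    """schedule: list of (start_epoch, image_size) pairs sorted by start_epoch."""
    size = schedule[0][1]
    for start, phase_size in schedule:
        if epoch >= start:
            size = phase_size
    return size


def relative_pixel_cost(size, reference=224):
    """Pixels per image relative to the reference resolution (a rough proxy for conv FLOPs)."""
    return (size / reference) ** 2


schedule = [(0, 64), (6, 128), (12, 224)]   # hypothetical: epochs 0-5, 6-11, 12+
for epoch in range(15):
    size = resolution_for_epoch(epoch, schedule)
    print(epoch, size, f"{relative_pixel_cost(size):.2f}x the pixels of 224x224")
```

The same helper confirms the arithmetic above: relative_pixel_cost(64, reference=256) is 1/16.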

Memory & Compute Tricks

Gradient Checkpointing

Here's the memory problem with training deep networks: during the forward pass, every layer produces activations, and those activations need to be kept around for the backward pass to compute gradients. A 100-layer ResNet stores 100 sets of activations. A transformer with 96 layers stores 96 attention maps, 96 sets of intermediate states. This eats VRAM alive.

Gradient checkpointing (also called activation checkpointing) makes a deliberate trade: throw away most activations during the forward pass, and recompute them on the fly during the backward pass. You only keep activations at selected "checkpoint" layers. When backprop reaches a gap, it re-runs the forward pass from the nearest checkpoint to regenerate what it needs.

The math works out elegantly. If your model has L layers and you place a checkpoint every √L layers, activation memory drops from O(L) to O(√L), a massive reduction, at the cost of re-running each segment's forward pass once during backprop: roughly one extra forward pass overall. In practice this can mean memory savings on the order of 60% for a 20–30% increase in training time, though the exact numbers depend on the model and on where the checkpoints go. Together with parameter-efficient fine-tuning (LoRA) and quantization, it is a big part of why you can fine-tune a 7B-parameter LLM on a single consumer GPU. Without these tricks, you'd need a cluster. PyTorch: torch.utils.checkpoint. Hugging Face Transformers: set gradient_checkpointing=True in TrainingArguments, or call model.gradient_checkpointing_enable(). When you'll need it: any time your model barely fits in VRAM, which is most of the time with modern architectures.
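To see the trade concretely, here is a toy, framework-free simulation: a chain of L scalar layers, each a hypothetical y = tanh(w·x), differentiated once with every layer input stored and once with inputs kept only every √L layers, the rest being recomputed segment by segment during the backward pass. It counts stored activations and forward calls rather than measuring real GPU memory, and the function names are illustrative:

```python
import math


def layer(w, x):
    """A toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_grad(w, x):
    """dy/dx for the toy layer, evaluated at its input x."""
    y = math.tanh(w * x)
    return w * (1.0 - y * y)


def grad_store_all(weights, x0):
    """Plain backprop: keep every layer's input, O(L) activations."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer(w, x)
    grad = 1.0
    for w, x_in in zip(reversed(weights), reversed(inputs)):
        grad *= layer_grad(w, x_in)              # chain rule, last layer first
    return grad, len(inputs)


def grad_checkpointed(weights, x0, segment):
    """Keep only every `segment`-th input; recompute the others during backward."""
    L = len(weights)
    checkpoints, x, forward_calls = {}, x0, 0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x                   # stored activation
        x = layer(w, x)
        forward_calls += 1
    grad, peak_stored = 1.0, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, L)
        # Re-run the forward pass from the checkpoint to rebuild this segment's inputs.
        seg_inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            seg_inputs.append(layer(weights[i], seg_inputs[-1]))
            forward_calls += 1
        peak_stored = max(peak_stored, len(checkpoints) + len(seg_inputs))
        for i in reversed(range(start, end)):
            grad *= layer_grad(weights[i], seg_inputs[i - start])
        # seg_inputs goes out of use here, freeing this segment's activations.
    return grad, peak_stored, forward_calls


weights = [0.9 + 0.01 * (i % 7) for i in range(100)]
g_full, stored_full = grad_store_all(weights, 0.5)
g_ckpt, stored_ckpt, calls = grad_checkpointed(weights, 0.5, segment=int(math.sqrt(len(weights))))
print(stored_full, stored_ckpt, calls)   # prints: 100 20 190
print(math.isclose(g_full, g_ckpt))      # prints: True
```

With L = 100 the checkpointed version holds at most 20 activations instead of 100 and makes 90 extra forward calls, just under one extra forward pass, while producing the same gradient.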

torch.compile (PyTorch 2.0+)

For years, PyTorch's eager execution — run each operation immediately, one at a time — was its greatest strength for debugging and its biggest performance weakness. Every operation launches a separate GPU kernel, returns to Python, launches another kernel. The overhead adds up.

torch.compile changes the game. One line — model = torch.compile(model) — and PyTorch's TorchDynamo intercepts your Python bytecode, extracts the tensor operations into a computation graph, and hands that graph to TorchInductor for optimization. Inductor fuses adjacent operations into single GPU kernels (eliminating kernel launch overhead), optimizes memory layouts, and can even generate custom Triton kernels on the fly.

Speedups vary widely by model. PyTorch's own benchmarks at the 2.0 release reported training about 1.4× faster on average on an A100 across 163 open-source models, usually with no changes to the model code. The catch: the first forward pass is slow because it's compiling the graph (this can take 30 seconds to several minutes for large models). Subsequent passes are fast. As of 2024, it's stable for standard architectures (CNNs, transformers, most attention variants), but data-dependent Python control flow causes "graph breaks" that fall back to eager execution and eat into the speedup, changing input shapes can trigger recompilation, and custom CUDA extensions or unusual tensor operations may not compile at all. The error messages when compilation fails are... improving. Try it first on any new project. When you'll need it: any PyTorch training or inference workload where speed matters, which is all of them.

Exotic Optimizers

Lookahead Optimizer

Lookahead (Zhang et al., 2019) wraps any base optimizer — typically Adam or SGD — and adds a second set of "slow weights." The inner optimizer updates the "fast weights" normally for k steps (usually 5–10). Then Lookahead interpolates the slow weights toward the fast weights: slow = slow + α × (fast − slow), and resets the fast weights to match. It's the optimizer periodically asking: "Am I still heading somewhere useful, or did I wander off?"

This reduces variance in the optimization trajectory and can stabilize training in situations where Adam alone oscillates. The overhead is negligible. The catch is that empirical improvements are inconsistent — sometimes it helps, sometimes it doesn't, and it's hard to predict which. It never caught on widely, but the idea is elegant and worth knowing about. When you'll need it: unstable training where you've already tuned learning rate and weight decay. It's a low-risk experiment.
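A minimal sketch of the Lookahead loop wrapped around plain SGD, on a hypothetical 2-D quadratic bowl that is ten times steeper in one direction, the kind of landscape where the inner optimizer zig-zags. k=5 follows the typical range above; alpha=0.5 is an assumed, commonly used value, and the learning rate and step counts are arbitrary. lookahead_sgd is an illustrative name, not a library API:

```python
def lookahead_sgd(grad_fn, w0, lr=0.15, k=5, alpha=0.5, outer_steps=20):
    """Lookahead around SGD: k fast SGD steps, then pull the slow weights toward the fast ones."""
    slow = list(w0)
    fast = list(w0)
    for _ in range(outer_steps):
        for _ in range(k):                                        # inner optimizer: plain SGD
            g = grad_fn(fast)
            fast = [w - lr * gi for w, gi in zip(fast, g)]
        slow = [s + alpha * (f - s) for s, f in zip(slow, fast)]  # slow += alpha * (fast - slow)
        fast = list(slow)                                         # reset fast weights to slow
    return slow


def quadratic_grad(w):
    """Gradient of the hypothetical loss f(w) = 0.5 * (w0^2 + 10 * w1^2), minimum at (0, 0)."""
    return [w[0], 10.0 * w[1]]


print(lookahead_sgd(quadratic_grad, [1.0, 1.0]))
```

Setting alpha=1.0 makes the slow weights simply copy the fast ones, which recovers plain SGD; smaller alpha damps the inner optimizer's oscillation.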

LAMB & LARS

When you scale batch sizes to 32K, 64K, or beyond — the kind of scale you hit when training across hundreds of GPUs — standard optimizers fall apart. The problem is that different layers have wildly different gradient magnitudes. A learning rate that's appropriate for the first convolutional layer can blow up the final classifier head, or vice versa.

LARS (Layer-wise Adaptive Rate Scaling, You et al. 2017) solves this by computing a per-layer learning rate based on the ratio of the layer's weight norm to its gradient norm. LAMB (Layer-wise Adaptive Moments for Batch training, You et al. 2020) extends this idea to Adam-style optimizers. LARS trained ResNet-50 on ImageNet in under 30 minutes. LAMB trained BERT in 76 minutes. These are engineering achievements that only matter at a scale most of us will never personally encounter — but they're how the big labs achieve those impossibly fast training times you see in papers. When you'll need it: distributed training across 64+ GPUs with very large batch sizes. Otherwise, Adam with a good schedule will serve you well.
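The core of LARS is the per-layer trust ratio η·‖w‖ / (‖g‖ + β·‖w‖). A sketch of one LARS step without momentum, assuming hypothetical layers stored as lists of floats; lars_step is an illustrative name, and trust_coef (η) and weight_decay (β) are example values. With zero weight decay, every layer's update ends up the same size relative to its weights, however different the raw gradient magnitudes are:

```python
import math


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def lars_step(layers, grads, global_lr=1.0, trust_coef=0.001, weight_decay=0.0):
    """One LARS step (no momentum): scale each layer's LR by ||w|| / (||g|| + wd * ||w||)."""
    updated, local_lrs = {}, {}
    for name, w in layers.items():
        g = grads[name]
        w_norm, g_norm = l2_norm(w), l2_norm(g)
        if w_norm > 0 and g_norm > 0:
            local_lr = trust_coef * w_norm / (g_norm + weight_decay * w_norm)
        else:
            local_lr = 1.0                                  # fall back to the global LR
        local_lrs[name] = local_lr
        updated[name] = [wi - global_lr * local_lr * (gi + weight_decay * wi)
                         for wi, gi in zip(w, g)]
    return updated, local_lrs


# Two hypothetical layers whose gradient norms differ by roughly four orders of magnitude.
layers = {"conv1": [0.5, -0.3, 0.8], "head": [0.1, 0.2]}
grads = {"conv1": [0.0001, 0.0002, -0.0001], "head": [1.5, -2.0]}
new_layers, local_lrs = lars_step(layers, grads)
for name in layers:
    delta = [a - b for a, b in zip(layers[name], new_layers[name])]
    print(name, local_lrs[name], l2_norm(delta) / l2_norm(layers[name]))
```

LAMB applies the same kind of trust ratio, but to the Adam update direction rather than to the raw gradient.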

Sharpness-Aware Minimization (SAM)

SAM (Foret et al., 2021) approaches generalization from a different angle. Instead of asking "find weights where the loss is low," it asks "find weights where the loss is low and the neighborhood is also low." It explicitly seeks flat regions of the loss landscape by adding a perturbation step: first, compute the gradient and take a small step of radius ρ in the direction of steepest ascent to find (approximately) the worst-case neighbor; then compute the gradient at that perturbed point and use it to update the original weights with an ordinary descent step.

Two forward-backward passes per step. That's the entire problem with SAM: it doubles training time. The generalization improvements are real and fairly consistent (often in the 0.5–1.5% range on standard benchmarks), which makes SAM popular in Kaggle competitions where accuracy matters more than training cost. For production models where you're running hundreds of training runs during hyperparameter search, the 2× cost is brutal. Variants refine the idea: ASAM (adaptive SAM) makes the sharpness measure scale-invariant, and LookSAM cuts the overhead by recomputing the expensive SAM gradient only periodically, but the core trade-off remains. When you'll need it: competitions, final production model training where every fraction of a percent matters, or research where generalization is the primary concern.
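A framework-free sketch of one SAM step, assuming a hypothetical grad_fn that returns the gradient as a list of floats; rho (the neighborhood radius), lr and the toy loss are illustrative. Counting calls to the gradient function makes the 2× cost explicit:

```python
import math


def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM update: ascend to the worst-case neighbour, then descend with its gradient."""
    g = grad_fn(w)                                          # 1st forward/backward pass
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
    eps = [rho * gi / norm for gi in g]                     # step of length rho uphill
    g_adv = grad_fn([wi + ei for wi, ei in zip(w, eps)])    # 2nd pass, at the perturbed point
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]       # applied to the ORIGINAL weights


counter = {"calls": 0}


def counted_grad(w):
    """Gradient of the hypothetical loss 0.5 * sum(w_i^2), with a call counter."""
    counter["calls"] += 1
    return list(w)


w = [1.0, -2.0]
for _ in range(100):
    w = sam_step(w, counted_grad)
print(w, counter["calls"])   # two gradient evaluations per step
```

In PyTorch the same pattern is usually written as two backward passes per batch, with the weights perturbed in place between them and restored before the optimizer step.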

Large-Scale Training Frameworks

DeepSpeed

Microsoft's DeepSpeed library exists because a 175-billion-parameter model needs roughly 700 GB just to hold its parameters in fp32 (175B × 4 bytes), and gradients plus Adam's two moment buffers push the total to about 2.8 TB (16 bytes per parameter). No single GPU has that kind of memory. DeepSpeed's ZeRO (Zero Redundancy Optimizer) solves this by sharding: splitting the memory footprint across multiple GPUs.

ZeRO has three stages, each more aggressive. Stage 1 shards optimizer states. These are the biggest item: Adam's momentum and variance buffers alone are twice the size of the fp32 parameters, and in mixed-precision training the fp32 master weights bring optimizer state to about six times the size of the fp16 model. Each GPU keeps only its slice. Stage 2 adds gradient sharding: gradients are also partitioned rather than replicated. Stage 3 goes all the way and shards the model parameters themselves. At Stage 3, no single GPU holds the full model at any point. Parameters are gathered just-in-time for computation, used, and released. This is how you train 100B+ parameter models on "commodity" hardware (which still means dozens of expensive GPUs, or fewer with DeepSpeed's CPU/NVMe offloading, but not thousands).
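The per-GPU savings are easy to estimate. This sketch uses the accounting from the ZeRO paper for mixed-precision Adam, which is an assumption about your setup: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients and 12 for optimizer states (fp32 master weights, momentum and variance). It ignores activations and temporary buffers, and zero_memory_gb is a hypothetical helper, not part of DeepSpeed:

```python
def zero_memory_gb(num_params, num_gpus, stage):
    """Approximate per-GPU model-state memory (GB) under ZeRO with mixed-precision Adam."""
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params   # bytes
    if stage == 0:                       # plain data parallelism: everything replicated
        total = params + grads + optim
    elif stage == 1:                     # shard optimizer states
        total = params + grads + optim / num_gpus
    elif stage == 2:                     # also shard gradients
        total = params + (grads + optim) / num_gpus
    elif stage == 3:                     # also shard the parameters
        total = (params + grads + optim) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9


for stage in range(4):
    print(stage, f"{zero_memory_gb(7.5e9, 64, stage):.2f} GB per GPU")
```

For a 7.5B-parameter model on 64 GPUs this gives about 120, 31.4, 16.6 and 1.9 GB per GPU for stages 0 through 3. The same 16 bytes per parameter also reproduces the ~2.8 TB figure for a 175B model on one device.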

The trade-off, as always, is communication. Stages 1 and 2 move the same amount of data per step as ordinary data parallelism, so their overhead is close to free. Stage 3 has to gather parameters on the fly, which raises communication volume by about 50% and makes interconnect bandwidth matter. PyTorch now has a native alternative called FSDP (Fully Sharded Data Parallel) that implements similar ideas to ZeRO Stage 3, which is worth knowing if you prefer staying within the PyTorch ecosystem. When you'll need it: models that don't fit on a single GPU even with mixed precision and gradient checkpointing.

Megatron-LM

NVIDIA's Megatron-LM tackles a complementary problem. Where DeepSpeed focuses on memory efficiency, Megatron focuses on compute efficiency across many GPUs. It provides three types of parallelism: tensor parallelism (split individual matrix multiplications across GPUs), pipeline parallelism (split the model into sequential stages, each on different GPUs), and sequence parallelism (split along the sequence dimension the operations, such as LayerNorm and dropout, that tensor parallelism would otherwise replicate on every GPU).

In practice, large-scale training runs often combine both — Megatron-LM handles the parallelism strategy and DeepSpeed handles the memory optimization. If you read a paper that says "we trained a 530B parameter model," both names are probably in the methods section. When you'll need it: training at the 10B+ parameter scale with multi-node GPU clusters. For everything smaller, PyTorch's DDP or FSDP will get you there.
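Tensor parallelism is easiest to see on a transformer MLP block, Y = relu(X·A)·B. A framework-free sketch of the Megatron-style split, simulating two "GPUs" with plain Python lists: A is split by columns, so each shard can apply the elementwise nonlinearity on its own (ReLU stands in here for the GELU used in practice), and B is split by rows, so the final result is a sum of partial products. In a real run that sum is an all-reduce across GPUs. Matrix sizes, values and helper names are arbitrary:

```python
def matmul(X, A):
    """Plain matrix product of two lists of rows."""
    return [[sum(x * a for x, a in zip(row, col)) for col in zip(*A)] for row in X]


def relu(M):
    return [[max(0, v) for v in row] for row in M]


def add(M, N):
    return [[a + b for a, b in zip(r, s)] for r, s in zip(M, N)]


X = [[1, 2, 3], [4, 5, 6]]                          # activations: 2 tokens x 3 features
A = [[1, 0, -1, 2], [0, 1, 1, -1], [2, -1, 0, 1]]   # first weight: 3 x 4
B = [[1, 2], [0, 1], [-1, 0], [3, -2]]              # second weight: 4 x 2

reference = matmul(relu(matmul(X, A)), B)           # single-device result

# Simulated "GPU" i holds a block of A's columns and the matching block of B's rows.
shards = [(slice(0, 2), slice(0, 2)), (slice(2, 4), slice(2, 4))]
partials = []
for cols, rows in shards:
    A_i = [row[cols] for row in A]                  # column shard of A (3 x 2)
    B_i = B[rows]                                   # row shard of B (2 x 2)
    partials.append(matmul(relu(matmul(X, A_i)), B_i))   # no communication needed

parallel = partials[0]
for p in partials[1:]:
    parallel = add(parallel, p)                     # stands in for the all-reduce
print(reference == parallel)                        # True
```

Splitting A by rows instead would force a communication step before the nonlinearity, because ReLU of a sum is not the sum of ReLUs; the column-then-row pairing needs only one all-reduce per block.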

Experiment Tracking

Weights & Biases (W&B)

Every ML engineer eventually reaches the point where they have 47 training runs, each with slightly different hyperparameters, and no idea which one produced that promising loss curve they saw three days ago. That's the moment you wish you'd set up experiment tracking from the start.

W&B is the tool that's become the de facto standard for this. Two lines — wandb.init(project="my-project") and wandb.log({"loss": loss, "lr": lr}) — and you get real-time loss curves, system GPU/CPU monitoring, hyperparameter logging, and artifact versioning. The dashboard lets you overlay runs, compare hyperparameters against metrics, and run automated sweeps. It's free for personal use and ubiquitous in research papers.

Alternatives exist and have their place. MLflow is fully open-source and gives you complete data sovereignty — many enterprises use it for production model governance because nothing leaves your infrastructure. TensorBoard is lightweight and works fine for quick local experiments. Neptune and Comet fall somewhere in between. A common pattern at companies: W&B for R&D iteration speed, MLflow for production model registry and compliance. When you'll need it: from your very first serious training run. The earlier you start logging, the less time you spend recreating experiments you forgot to save.

Interview Corner

Senior interviewers like to probe whether you've actually used these techniques or are reciting definitions. A few things worth internalizing:

EMA and SWA are not the same thing. EMA updates at every step with exponential decay — recent weights matter more. SWA averages the last N checkpoints uniformly — all are weighted equally. Mixing them up signals you've only read the titles.

Knowledge distillation's temperature parameter is the key detail. At T=1 you get standard softmax. Higher T softens the distribution and amplifies the "dark knowledge": those secondary class probabilities that encode inter-class similarities. If someone asks "why not train the student directly on the same data?" the answer is: because the teacher's soft targets carry information about how the classes relate that one-hot labels cannot express, giving the student a richer training signal per example.

Gradient checkpointing brings activation memory down to O(√L), not O(1). You don't eliminate activation storage; you trade the full O(L) for O(√L) by recomputing segments. If someone asks the cost, it's roughly one extra forward pass through each segment, not a doubling of total compute.

If DeepSpeed ZeRO comes up, know the stages: Stage 1 shards optimizer states, Stage 2 adds gradient sharding, Stage 3 shards everything including parameters. And know that PyTorch's FSDP is the native counterpart: its default FULL_SHARD strategy corresponds to ZeRO Stage 3, while SHARD_GRAD_OP roughly matches Stage 2. That comparison comes up frequently.

When You'll Need These

None of these are day-one essentials. The main sections of this chapter cover what you need to train deep networks competently. Come back here when the basics aren't enough. Model doesn't fit in memory? Gradient checkpointing and DeepSpeed. Last percent of accuracy? SWA, EMA, or SAM. Deploying a giant to a phone? Knowledge distillation. Training too slow? torch.compile. Lost track of experiments? W&B. You'll know when you need each one.