Scaling & Efficiency
I avoided thinking about distributed training for an embarrassingly long time. Every time I hit a CUDA out-of-memory error, I'd reduce the batch size to something pathetic, wait three days for training to finish, and pretend that was fine. When people talked about ring all-reduce and pipeline bubbles and 3D parallelism, I nodded along and immediately forgot everything. Finally, I was staring at a 7-billion parameter model that needed to fit on hardware I actually had access to, and the discomfort of not understanding what happens when one GPU isn't enough grew too great. Here is that dive.
Scaling deep learning training means two things: making a single model train faster across multiple devices, and making sure you aren't wasting compute on things that don't matter. The first half is distributed training — the machinery of splitting work across GPUs. The second is efficiency — mixed precision, memory tricks, and understanding how model size, data, and compute relate to each other through scaling laws. Both of these became essential knowledge sometime around 2020, and they're now table stakes for anyone training anything non-trivial.
Before we start, a heads-up. We'll be talking about GPU memory hierarchies, communication algorithms, and some arithmetic around bytes-per-parameter. You don't need to know any of it beforehand. We'll add the concepts we need one at a time, with explanation.
This isn't a short journey, but I hope you'll be glad you came.
The VRAM Problem: Where All Your Memory Goes
Let's start with the thing that makes everything else necessary. Your GPU has a fixed amount of VRAM — Video RAM, the GPU's dedicated memory. Everything that exists during training has to live there simultaneously: the model's parameters, the gradients computed during backpropagation, the optimizer's internal bookkeeping, and the intermediate activations saved during the forward pass so the backward pass can use them.
Here's a concrete picture. Imagine you're training a modest 1-billion parameter model. Each parameter is a floating-point number. In the standard float32 format, that's 4 bytes per number. So the parameters alone eat 4 GB. But the parameters are the small part.
The gradients during the backward pass are the same size as the parameters — another 4 GB. The optimizer state is where things get painful. Adam (and AdamW, which is what almost everyone uses) stores two extra tensors per parameter: the running mean of gradients (first moment) and the running mean of squared gradients (second moment). That's another 8 GB. So before you've even processed a single data point, your 1B model has consumed 16 GB on parameters, gradients, and optimizer state alone.
Then there are activations — the intermediate outputs of every layer, saved during the forward pass because backpropagation needs them. Their size scales with batch size, sequence length, hidden dimension, and the number of layers. For a transformer with a reasonable batch size, activations can easily double or triple the total memory usage.
Let me make that tangible. Here's the arithmetic for a 1B parameter model using mixed precision (float16 weights, float32 optimizer):
| Component | Size |
|---|---|
| Parameters (1B × 2 bytes in FP16) | ~2 GB |
| Gradients (1B × 2 bytes) | ~2 GB |
| Adam states (1B × 8 bytes, kept in FP32) | ~8 GB |
| Master weights (1B × 4 bytes, FP32 copy) | ~4 GB |
| Activations (varies with batch size) | ~2–8 GB |
| Total | ~18–24 GB |
That 1B model barely fits on a 24 GB consumer GPU (RTX 3090/4090) with a small batch size. Now think about a 7B model: 7× everything, so roughly 112 GB for parameters, gradients, optimizer state, and master weights alone — before activations. That's two A100-80GB GPUs minimum, and we haven't even started thinking about throughput.
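If you want to run this arithmetic for other model sizes, here's the table's accounting as a tiny helper: a rough sketch that ignores framework overhead, the CUDA context, and memory fragmentation.

```python
def training_memory_gb(params_billion: float, activations_gb: float = 0.0) -> float:
    """Rough mixed-precision training footprint, matching the table above."""
    params = 2 * params_billion   # FP16 weights
    grads = 2 * params_billion    # FP16 gradients
    adam = 8 * params_billion     # FP32 first + second moments
    master = 4 * params_billion   # FP32 master copy of weights
    return params + grads + adam + master + activations_gb

print(training_memory_gb(1))  # 16.0 GB before activations
print(training_memory_gb(7))  # 112.0 GB before activations
```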
This is the force that drives everything in this section. Every technique we'll look at — distributed training, mixed precision, gradient checkpointing, sharding — exists because VRAM runs out and we have to get creative about either splitting the work or reducing the memory footprint.
RuntimeError: CUDA out of memory. Tried to allocate X MiB — this is the single most common error in deep learning. Before searching for exotic solutions, try the four simplest fixes in this order: reduce batch size, enable mixed precision, use gradient accumulation, enable gradient checkpointing. These solve 90% of OOM errors.
Mixed Precision: The Free Lunch
The first thing to reach for when memory is tight — and honestly, even when it isn't — is mixed precision training. The idea is refreshingly straightforward. Most of the arithmetic in a neural network doesn't need the full precision of float32. Matrix multiplications, convolutions, element-wise operations — they do fine with half the bits. So we run the computationally expensive parts in float16 (or bfloat16), keep a master copy of the weights in float32 for the optimizer to work with, and get the best of both worlds.
The payoff is real: roughly 2× speedup on Tensor Core GPUs (anything from V100 onward) and around 40% less memory because activations are halved. The accuracy cost is negligible — the rounding noise from 16-bit arithmetic is completely drowned out by the inherent noise of mini-batch stochastic gradient descent.
There's one wrinkle. Float16 has a narrow dynamic range — its maximum value is about 65,504, compared to float32's 3.4 × 10³⁸. Small gradients can underflow to zero and vanish. The solution is loss scaling: multiply the loss by a large number before the backward pass (inflating the gradients into float16's representable range), then divide the gradients back down before the optimizer step.
BFloat16 sidesteps this problem entirely. Google designed it with the same exponent range as float32 (8 exponent bits) but fewer mantissa bits (7 instead of float16's 10). Same dynamic range, slightly less precision, no loss scaling needed. If your hardware supports BF16 (A100 and newer, or TPUs), use it. Life is easier.
| Format | Bits | Exponent | Mantissa | Max Value | Loss Scaling? |
|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | ~3.4 × 10³⁸ | No |
| FP16 | 16 | 5 | 10 | ~65,504 | Yes |
| BF16 | 16 | 8 | 7 | ~3.4 × 10³⁸ | No |
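Before we get to the training-loop API, a quick sketch makes the underflow problem concrete. The 2¹⁶ scale factor is just an illustrative value; in practice the scaler picks it dynamically.

```python
import torch

g = torch.tensor(1e-8)                # a small but realistic gradient value
print(g.to(torch.float16))            # tensor(0., dtype=torch.float16) -- underflow
print((g * 2**16).to(torch.float16))  # non-zero: scaling lifts it into FP16's range
print(g.to(torch.bfloat16))           # non-zero: BF16 keeps FP32's exponent range
```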
In PyTorch, the implementation is almost insultingly simple. Four extra lines turn on Automatic Mixed Precision (AMP):
```python
from torch.amp import autocast, GradScaler

scaler = GradScaler("cuda")

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with autocast("cuda"):             # FP16 for matmuls and convs
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()      # scale the loss so small grads survive FP16
    scaler.step(optimizer)             # unscales grads; skips the step on overflow
    scaler.update()                    # adjusts the scale factor for next iteration
```
The autocast context manager routes matrix multiplications and convolutions to FP16 Tensor Cores, while keeping numerically sensitive operations (softmax, layer norm, loss functions) in FP32. The GradScaler dynamically adjusts the loss scaling factor — if gradients overflow, it halves the scale and skips the step; if things are stable, it gradually increases the scale to capture smaller gradients.
For BF16 on A100 or newer, you don't even need the scaler:
```python
with autocast("cuda", dtype=torch.bfloat16):  # BF16: full FP32 range, no GradScaler
    outputs = model(inputs)
    loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
I'll be honest — when I first heard "half the precision, same results," I didn't believe it. But the math checks out. The optimizer updates happen in FP32, so the accumulated weight changes are identical. The only noise introduced is in the forward and backward passes, and that noise is dwarfed by the stochasticity of mini-batch SGD. If your mixed precision run diverges but your FP32 run doesn't, the bug is in your scaling setup, not in the concept.
Memory Optimization: Trading Compute for Bytes
Mixed precision halves your activation memory. But sometimes that's not enough. Two more techniques let you push further by trading compute time for memory savings.
Gradient Accumulation
The idea is almost too simple to feel like a technique. You want an effective batch size of 128, but only 32 samples fit in memory. So you run 4 forward-backward passes with batch size 32, accumulating the gradients without stepping the optimizer, and then step once after all 4. The gradients add up to the same thing as if you'd processed 128 samples at once.
```python
accumulation_steps = 4

for i, (inputs, targets) in enumerate(dataloader):
    with autocast("cuda"):
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps  # average, don't sum
    scaler.scale(loss).backward()     # gradients accumulate across micro-batches
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
The division by accumulation_steps matters — without it, your effective learning rate would be 4× too large, because the gradients from 4 micro-batches are summed rather than averaged. Total training time is roughly unchanged (you still do the same total compute), but peak memory stays at the single micro-batch level.
Gradient Checkpointing
This one trades compute for memory in a more dramatic way. During a normal forward pass, every layer saves its output (the activations) because backpropagation needs them. For a 100-layer network, that's 100 layers worth of activations sitting in memory.
With gradient checkpointing (also called activation checkpointing or activation recomputation), you only save the activations at certain "checkpoint" layers. When the backward pass reaches a layer whose activations weren't saved, it recomputes them on-the-fly from the nearest checkpoint. This reduces activation memory from O(N) to roughly O(√N) for N layers, at the cost of about 30% extra compute time — you're running parts of the forward pass twice.
```python
from torch.utils.checkpoint import checkpoint

# Instead of:  output = transformer_block(x)
# Use:
output = checkpoint(transformer_block, x, use_reentrant=False)
```
In practice, people checkpoint every transformer block or every few residual blocks. It's the go-to technique when your model fits in memory for inference but not for training (because training needs to store activations and gradients on top of the parameters).
I'm still developing my intuition for the optimal checkpoint placement — there's a real engineering tradeoff between memory reduction and the wall-clock cost of recomputation, and it depends heavily on the specific architecture and hardware. The safe default: checkpoint every transformer layer and accept the ~30% slowdown.
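As a slightly fuller sketch, here's a hypothetical encoder that checkpoints every block. The module is made up, but the checkpoint call is the real PyTorch API:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedEncoder(nn.Module):
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks

    def forward(self, x):
        for block in self.blocks:
            if self.training:
                # Save only the block's input; recompute its internals in backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)  # inference stores no activations anyway
        return x
```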
Congratulations on making it this far. You can stop here if you want. You now understand why VRAM is the bottleneck, how mixed precision gives you a ~2× speedup for free, how gradient accumulation lets you simulate large batches on small GPUs, and how gradient checkpointing trades compute for memory. That's enough to train most models on a single GPU efficiently. The short version of everything that follows: when one GPU isn't enough, DDP copies the model to multiple GPUs and splits the data; when the model itself doesn't fit on one GPU, FSDP or DeepSpeed shards the model across GPUs; and scaling laws tell you how to spend your compute budget wisely. There. You're 60% of the way there. But if the discomfort of not knowing what happens inside a distributed training run is nagging at you, read on.
DistributedDataParallel: The Workhorse
Let's picture a scenario. You've got a model that fits on one GPU, but training takes four days. You have access to 4 GPUs. The most natural thing in the world is to split the data four ways, give each GPU the same model, let them each compute gradients on their slice, and then combine the gradients so every GPU gets the same update. That's data parallelism.
PyTorch offers two implementations of this idea, and one of them you should never touch.
DataParallel: The One You Should Never Use
nn.DataParallel is the tempting one-liner — model = nn.DataParallel(model) — and it's fatally flawed. It runs in a single Python process, so it hits the GIL (Global Interpreter Lock). GPU 0 gathers all outputs and computes the loss, using far more memory than the other GPUs. Gradients are collected to GPU 0 and then broadcast back, sequentially. The result: it often runs slower with 4 GPUs than with 2, because the GPU-0 bottleneck dominates. It exists for historical reasons. Forget it exists.
DistributedDataParallel: How It Actually Works
DDP does the right thing. It spawns one process per GPU. Each process holds its own copy of the model, gets a unique slice of the training data via a DistributedSampler, and computes gradients independently. The magic happens in how the gradients are synchronized.
The core operation is called all-reduce. Think of it this way: every GPU has computed gradients for its local batch, and we need every GPU to end up with the average of all those gradients. The naive approach — send everything to one GPU, average there, send back — would create a bottleneck at that one GPU. Ring all-reduce is the clever alternative.
Imagine the GPUs arranged in a ring. Each GPU splits its gradient tensor into N chunks (where N is the number of GPUs). In the first phase (reduce-scatter), each GPU sends one chunk to its neighbor and receives a chunk from the other direction, adding them together as they pass around the ring. After N-1 steps, each GPU holds the fully reduced version of one chunk. In the second phase (all-gather), the complete chunks are passed around until every GPU has the full result. The elegant thing: at every step, each GPU is sending and receiving simultaneously, maxing out the network bandwidth. The total data each GPU sends is 2 × (N-1)/N times the gradient size — which approaches 2× the gradient size regardless of how many GPUs you add. That's bandwidth-optimal.
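To put numbers on that, here's the back-of-envelope for 1B FP16 gradients on a hypothetical 8-GPU ring:

```python
n_gpus = 8
grad_bytes = 1e9 * 2  # 1B parameters x 2 bytes (FP16 gradients)

# Each GPU sends 2 * (N-1)/N times the gradient size across the two phases
sent_per_gpu = 2 * (n_gpus - 1) / n_gpus * grad_bytes
print(f"{sent_per_gpu / 1e9:.2f} GB sent per GPU")  # 3.50 GB -- ~2x gradient size
```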
But DDP doesn't wait for the entire backward pass to finish before starting communication. It uses gradient bucketing: as gradients are computed layer by layer during the backward pass, they're grouped into buckets (25 MB each by default). As soon as a bucket is full, its all-reduce kicks off in the background on a separate CUDA stream, overlapping communication with the computation of earlier layers' gradients. By the time the backward pass finishes, most of the gradient synchronization is already done.
The result: near-linear scaling. 4 GPUs typically give 3.6–3.9× the throughput of one.
Here's the pattern you'll reuse in every DDP training script:
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup():
    # torchrun sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, LOCAL_RANK
    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    return rank

def train():
    rank = setup()
    world_size = dist.get_world_size()
    model = MyModel().to(rank)                 # MyModel, dataset, num_epochs defined elsewhere
    model = DDP(model, device_ids=[rank])
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=dist.get_rank())
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)               # different shuffle each epoch
        for batch in dataloader:
            loss = model(batch.to(rank))       # assumes the model returns its loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    dist.destroy_process_group()

# Launch with: torchrun --nproc_per_node=4 train_script.py
```
A few things to notice. init_process_group("nccl") uses NVIDIA's NCCL library — the standard for GPU-to-GPU communication, and what implements the ring all-reduce under the hood. The DistributedSampler splits the dataset so each GPU processes a unique subset; without it, all GPUs train on identical data and you've gained nothing. The sampler.set_epoch(epoch) call ensures different shuffling each epoch — forgetting this is a subtle bug that leads to the same data split every epoch. And torchrun is the launcher that spawns one process per GPU, handles environment variables, and supports fault tolerance.
Two more gotchas that will save you hours of debugging. First, log metrics and save checkpoints only on rank 0 — every process runs the same training loop, and you'll get duplicate logs and corrupted checkpoint files if you save from all ranks. Access the unwrapped model via model.module when saving. Second, if you're using BatchNorm, convert to SyncBatchNorm before wrapping with DDP, because standard BatchNorm computes statistics per-GPU, and with small per-GPU batch sizes those statistics become noisy.
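Both fixes are one-liners. A minimal sketch, reusing the rank and model from the DDP script above:

```python
# Convert BatchNorm to SyncBatchNorm *before* wrapping in DDP
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank])

# ... training loop ...

# Save (and log) from rank 0 only, unwrapping the DDP module
if dist.get_rank() == 0:
    torch.save(model.module.state_dict(), "checkpoint.pt")
```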
When the Model Doesn't Fit: FSDP and DeepSpeed ZeRO
DDP replicates the entire model on every GPU. For a 7B parameter model with Adam, that's ~112 GB of state per GPU — more than any single GPU has. This is where sharding enters the picture.
The core insight is that most of the memory sitting on each GPU during DDP training is redundant. Every GPU holds a complete copy of the optimizer state, even though each GPU only needs those states during the optimizer step. If you have 8 GPUs, you're storing the same 8 GB of Adam moments 8 times. That's 56 GB wasted.
DeepSpeed ZeRO (Zero Redundancy Optimizer), developed by Microsoft, attacks this redundancy in three progressive stages. Think of it as increasingly aggressive sharding:
ZeRO Stage 1 partitions only the optimizer states across GPUs. Each GPU still holds all parameters and gradients, but stores only 1/N of the optimizer state (where N is the GPU count). Memory per GPU drops from P + G + O to P + G + O/N.
ZeRO Stage 2 adds gradient partitioning. Each GPU holds all parameters but only 1/N of gradients and optimizer states. Memory: P + G/N + O/N. The communication pattern changes — instead of all-reduce on gradients, each GPU reduce-scatters its gradient shard directly, which is slightly more efficient.
ZeRO Stage 3 goes all the way: parameters, gradients, and optimizer states are all sharded. Each GPU holds only 1/N of everything. Memory per GPU: (P + G + O) / N. When a layer needs its full parameters for a forward or backward pass, the participating GPUs gather them temporarily, compute, then discard the non-local shards. This is the maximum memory savings — and the maximum communication overhead.
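To make the three stages concrete, here's that bookkeeping as a sketch. It uses the ZeRO paper's mixed-precision accounting (2-byte parameters, 2-byte gradients, and 12 bytes per parameter of FP32 moments plus master weights), so treat the outputs as estimates, not profiler numbers:

```python
def zero_memory_gb(params_billion: float, n_gpus: int, stage: int) -> float:
    P = 2 * params_billion    # FP16 parameters
    G = 2 * params_billion    # FP16 gradients
    O = 12 * params_billion   # FP32 moments + master weights
    if stage == 0:            # plain DDP: everything replicated
        return P + G + O
    if stage == 1:            # shard optimizer states
        return P + G + O / n_gpus
    if stage == 2:            # shard gradients too
        return P + G / n_gpus + O / n_gpus
    return (P + G + O) / n_gpus  # stage 3: shard everything

for stage in range(4):
    print(f"7B model, 8 GPUs, ZeRO-{stage}: {zero_memory_gb(7, 8, stage):.1f} GB/GPU")
# ZeRO-0: 112.0, ZeRO-1: 38.5, ZeRO-2: 26.2, ZeRO-3: 14.0
```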
PyTorch's FSDP (Fully Sharded Data Parallel) is essentially ZeRO Stage 3 built natively into PyTorch. Same idea, different implementation. The practical difference: FSDP integrates tightly with PyTorch's API and ecosystem, while DeepSpeed offers more advanced features (NVMe offloading, custom optimizers, more knobs) and tends to edge out at massive scale (64+ GPUs).
```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

model = FSDP(
    model,
    use_orig_params=True,             # keep original parameter names for the optimizer
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # compute in BF16
        reduce_dtype=torch.bfloat16,  # communicate gradients in BF16
        buffer_dtype=torch.bfloat16,
    ),
)
```
The tradeoff is always the same: more sharding means less memory per GPU but more communication to gather and scatter parameters for every layer. The practical rule is to use the lowest ZeRO stage (or DDP if possible) that fits your model. DDP first. If that OOMs, try ZeRO Stage 1. Then Stage 2. Then FSDP/ZeRO Stage 3. Each step trades some communication overhead for more memory headroom.
DeepSpeed also offers ZeRO-Infinity, which extends offloading to CPU RAM and even NVMe storage. This technically makes model size bounded only by your disk space, but the IO overhead is substantial. It's a tool of last resort — useful for experimentation with truly enormous models on limited hardware, but not something you'd want for production training.
Model Parallelism: Splitting the Model Itself
Everything we've seen so far is a form of data parallelism — same model (or shards of it), different data on each GPU. But there's a second axis: splitting the model's computation across GPUs. This is model parallelism, and it comes in two flavors.
Tensor Parallelism
Tensor parallelism splits individual layers across GPUs. Consider a large linear layer that multiplies a hidden state by a weight matrix. If the weight matrix is too wide for one GPU — or even if it isn't, but you want to speed up the computation — you can split the matrix column-wise across 4 GPUs, have each GPU compute its quarter of the output, and then combine the results. In transformer architectures, this is typically applied to the attention heads and the MLP layers: each GPU computes a subset of attention heads or a portion of the feed-forward network.
The requirement is a fast interconnect between GPUs. Tensor parallelism involves communication at every layer — partial results need to be combined before the next layer can proceed. On NVLink (600 GB/s on A100, 900 GB/s on H100), this is fast enough to be practical. Over PCIe (64 GB/s), the communication overhead would kill your throughput. This is why tensor parallelism is typically used within a single node where GPUs are connected by NVLink, not across nodes.
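Here's a toy single-process sketch of the column-wise split. Real tensor parallelism (Megatron-LM style) puts each shard on a different GPU and combines the partial outputs with collective communication, but the arithmetic is exactly this:

```python
import torch

x = torch.randn(8, 512)             # batch of hidden states
W = torch.randn(512, 2048)          # full weight matrix
shards = W.chunk(4, dim=1)          # each "GPU" holds a 512 x 512 column slice
partials = [x @ w for w in shards]  # each device computes its slice of the output
out = torch.cat(partials, dim=1)    # the all-gather step, done locally here
assert torch.allclose(out, x @ W, atol=1e-4)
```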
Pipeline Parallelism
Pipeline parallelism partitions the model by layers. GPU 0 gets layers 1–10, GPU 1 gets layers 11–20, and so on. Data flows through the pipeline: GPU 0 processes a micro-batch, sends the activations to GPU 1, and immediately starts on the next micro-batch. In principle, all GPUs are busy once the pipeline fills up.
In practice, there's a problem called pipeline bubbles. At the start and end of a batch, some GPUs sit idle because the pipeline isn't full yet. If you have 4 stages and process 4 micro-batches, the first GPU finishes its forward passes before the last GPU has started, and the same happens in reverse for the backward pass. Micro-batching helps — more micro-batches means a fuller pipeline — but doesn't eliminate the bubble entirely. Typical pipeline efficiency is 80–90%, meaning 10–20% of GPU time is wasted waiting.
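The standard back-of-envelope (from the GPipe/1F1B analyses): with p pipeline stages and m micro-batches per batch, the bubble fraction is roughly (p − 1)/(m + p − 1). A quick sketch:

```python
def bubble_fraction(stages: int, micro_batches: int) -> float:
    # Fraction of time GPUs sit idle while the pipeline fills and drains
    return (stages - 1) / (micro_batches + stages - 1)

print(bubble_fraction(4, 4))   # ~0.43 -- pipeline is empty almost half the time
print(bubble_fraction(4, 16))  # ~0.16 -- more micro-batches shrink the bubble
```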
I'll be honest: pipeline parallelism has always felt like the least elegant of the three approaches to me. The bubble overhead is real, the scheduling gets complicated (interleaved schedules, 1F1B scheduling), and debugging pipeline-parallel code is a genuine headache. But for very deep models where you need to split across many nodes, it's sometimes the only option.
3D Parallelism: Combining Everything
State-of-the-art systems like Megatron-LM (NVIDIA's framework for training very large language models) use all three dimensions simultaneously. A 64-GPU cluster might be organized as 4-way data parallelism × 4-way pipeline parallelism × 4-way tensor parallelism: 4 × 4 × 4 = 64 GPUs. Tensor parallelism operates within a node (fast NVLink), pipeline parallelism spans across nodes, and data parallelism replicates the entire pipeline for higher throughput.
This is the architecture behind GPT-4, LLaMA, Gemini, and every other model with tens of billions of parameters. You probably won't need to implement 3D parallelism yourself — frameworks like Megatron-LM, DeepSpeed, and Hugging Face Accelerate handle the orchestration — but understanding the vocabulary is essential for reading the engineering blogs and papers that describe these systems.
| Strategy | What's Split | Communication Pattern | When to Use |
|---|---|---|---|
| Data Parallel (DDP) | Data batches | All-reduce on gradients | Model fits on one GPU; need more throughput |
| FSDP / ZeRO | Model state (sharded) | Gather/scatter per layer | Model state doesn't fit on one GPU |
| Tensor Parallel | Layer computation | All-reduce per layer (fast interconnect) | Wide layers; intra-node with NVLink |
| Pipeline Parallel | Layer groups | Point-to-point activation passing | Very deep models; cross-node |
Flash Attention: Rethinking the Memory Hierarchy
So far we've been attacking memory at the level of parameters, gradients, and optimizer states. But there's another massive memory hog hiding inside transformers: the attention computation itself.
Standard self-attention computes an N × N matrix (where N is the sequence length) of attention scores — queries times keys. For a sequence length of 4,096, that's a 16-million element matrix per attention head. For 32 heads, that's half a billion numbers. This matrix gets materialized in GPU HBM (the main GPU memory), then softmax is applied, then it's multiplied by the values. Three reads from and writes to slow HBM memory. For long sequences, this becomes the bottleneck.
Flash Attention (Tri Dao et al., 2022) is an IO-aware algorithm that restructures the computation to avoid materializing the full attention matrix. The key insight is that GPUs have a small, ultra-fast on-chip memory called SRAM (about 192 KB per streaming multiprocessor on an A100, compared to its 80 GB of HBM). Standard attention bounces data back and forth between fast SRAM and slow HBM. Flash Attention tiles the computation — processes small blocks of queries, keys, and values that fit entirely in SRAM, computes partial softmax results, and combines them. The full N × N matrix never exists in memory.
The result: memory goes from O(N²) to O(N) in sequence length, and wall-clock time drops by 2–4× because the GPU spends less time waiting for memory transfers and more time doing math. And — this is the part that surprised me — it's exact. Flash Attention computes the mathematically identical result to standard attention. No approximation, no accuracy loss. It's purely a scheduling optimization that respects the GPU's memory hierarchy.
In practice, you don't implement Flash Attention yourself. PyTorch 2.0+ includes it automatically through torch.nn.functional.scaled_dot_product_attention, and Hugging Face transformers enable it with a single flag. It's become standard — there's no reason not to use it.
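A minimal sketch of the PyTorch 2.x route, with hypothetical shapes; PyTorch dispatches to the Flash Attention kernel when the dtype, device, and head dimension allow:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim), BF16 on GPU -- eligible for the fused kernel
q = torch.randn(4, 32, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([4, 32, 4096, 64])
```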
Scaling Laws: How to Spend Your Compute Budget
All the techniques above are about how to train efficiently. Scaling laws answer a different question: what to train. Given a fixed compute budget, how big should your model be, and how much data should you train on?
In 2020, Kaplan et al. (OpenAI) showed that the loss of a language model decreases as a power law with model size N, dataset size D, and compute budget C:
L(N, D) ≈ L∞ + k_N · N^(−α) + k_D · D^(−β)
In plain English: make the model bigger, loss goes down. Train on more data, loss goes down. Both follow smooth, predictable curves. This is what makes neural network scaling so remarkable — unlike most engineering problems, the returns are predictable across orders of magnitude.
But there's a catch. Kaplan's analysis suggested that for a fixed compute budget, you should make the model as big as possible and train it on relatively little data. This led to the era of enormous models trained on "only" a few hundred billion tokens.
In 2022, Hoffmann et al. (DeepMind) challenged this with the Chinchilla paper, which showed that most large models were undertrained. Their refined analysis found that for a fixed compute budget C, the optimal allocation is roughly:
D* ≈ 20 × N
That is, you should train on about 20 tokens per parameter. The compute budget relates to both as C ≈ 6 × N × D (where C is in FLOPs). A 7B parameter model should ideally see about 140 billion tokens. A 70B model should see about 1.4 trillion tokens.
This single insight reshaped the field. It's why LLaMA-1 (7B) was trained on 1 trillion tokens — far more than earlier models of similar size. It's why the conversation shifted from "how big can we make the model" to "do we have enough data."
My favorite thing about scaling laws is that, aside from the high-level empirical observation, no one is completely certain why they hold so cleanly. We know that loss decreases as a power law with compute. We can fit the exponents. But the theoretical explanation for why neural networks exhibit such predictable scaling behavior is still an active area of research. It works, we can plan around it, and we don't fully understand why. I find that kind of humbling.
Before spending $100K training a large model, do the Chinchilla math. C = 6ND. For your compute budget C, the optimal model has N = √(C/120) parameters and trains on D = 20N tokens. If you can't get 20 tokens per parameter, make the model smaller. An optimally-trained smaller model will beat an undertrained larger one at the same compute cost.
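Here's that arithmetic as a tiny helper. The constants come straight from C = 6ND and D = 20N; treat the result as a first-order estimate:

```python
import math

def chinchilla_optimal(compute_flops: float) -> tuple[float, float]:
    """C = 6*N*D with D = 20*N  =>  N = sqrt(C/120), D = 20*N."""
    n_params = math.sqrt(compute_flops / 120)
    n_tokens = 20 * n_params
    return n_params, n_tokens

n, d = chinchilla_optimal(1e23)  # a hypothetical ~1e23 FLOP budget
print(f"~{n / 1e9:.0f}B params, ~{d / 1e12:.2f}T tokens")  # ~29B params, ~0.58T tokens
```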
Hyperparameter Optimization: Searching Wisely
The model architecture and distributed training setup get the most attention, but there's a quieter question that determines whether any training run succeeds: did you pick reasonable hyperparameters?
Hyperparameters are the choices you make before training starts — learning rate, batch size, weight decay, dropout rate, number of layers, optimizer choice, learning rate schedule. Gradient descent can't help with these. You have to search for them.
Grid search tries every combination from a predefined grid. It's exhaustive and exponential: 3 learning rates × 3 batch sizes × 3 dropout rates = 27 experiments. Add two more hyperparameters and you're at 243. The problem: grid search allocates equal budget to every dimension, even the ones that barely matter.
Random search samples randomly from distributions. It sounds careless, but it's surprisingly effective. Bergstra & Bengio (2012) showed why: most hyperparameters don't matter equally. Learning rate might explain 80% of the performance variation, while dropout explains 5%. A 5×5 grid gives you only 5 unique learning rates across 25 experiments. Random search gives you 25 unique learning rates — sampling the important dimension 5× more densely. With 60 random trials, you have a 95% chance of finding a configuration within the top 5% of the search space for the dimensions that matter.
Bayesian optimization goes further. It builds a surrogate model — a probabilistic model of how validation loss varies as a function of hyperparameters — and uses an acquisition function to decide where to sample next, balancing exploitation (near the best results so far) and exploration (regions of high uncertainty). Optuna is the go-to tool, using Tree-structured Parzen Estimators (TPE) by default. Combined with trial pruning (killing bad runs early), Bayesian optimization can find better hyperparameters in 2–5× fewer trials than random search.
```python
import optuna
import torch

def objective(trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)

    model = build_model(dropout=dropout)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                  weight_decay=weight_decay)
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, train_loader)
        val_loss = evaluate(model, val_loader)
        trial.report(val_loss, epoch)   # feed intermediate results to the pruner
        if trial.should_prune():
            raise optuna.TrialPruned()  # kill unpromising runs early
    return val_loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=100)
print(study.best_params)
```
The practical priority order, from hundreds of experiments: learning rate is almost always the most important hyperparameter. A 10× wrong learning rate is worse than getting everything else wrong combined. Batch size matters next (affects both optimization dynamics and generalization). Weight decay is critical for transformers, less so for CNNs with dropout. Architecture choices (depth, width) matter less than most people think within a reasonable range.
Most models are undertrained, not undertuned. Before launching a 200-trial HPO study, ask: Am I training long enough? Is my data pipeline a bottleneck? Is my learning rate schedule reasonable? Fixing your training recipe — longer warmup, proper weight decay, cosine schedule — often delivers more improvement than any amount of hyperparameter search. HPO is the last step, not the first.
Putting It Together: The Decision Tree
If you're staring at a training job wondering which of these techniques to apply, here's the decision process I use:
Start with a single GPU. Enable mixed precision (BF16 if your GPU supports it, FP16 with GradScaler otherwise). Enable Flash Attention if you're training a transformer. These are free wins — no tradeoffs, no added complexity worth worrying about.
If you hit OOM, reduce batch size first. If your effective batch size gets too small for good convergence, add gradient accumulation. If activations are the bottleneck (check with torch.cuda.memory_summary()), enable gradient checkpointing and accept the ~30% slowdown.
If one GPU is too slow, move to DDP. Scale your learning rate linearly with the effective batch size and add a warmup period. Use torchrun as your launcher.
If the model doesn't fit on one GPU even with mixed precision and checkpointing, move to FSDP (PyTorch native) or DeepSpeed ZeRO (if you need more advanced features). Start with the lowest ZeRO stage that fits.
If you're training at true scale — tens of billions of parameters — you're in 3D parallelism territory. Use Megatron-LM or DeepSpeed's pipeline parallelism, combined with tensor parallelism within nodes and data parallelism across pipeline replicas. At this point you probably have a team of infrastructure engineers, and this section has given you enough vocabulary to have productive conversations with them.
And before you commit significant compute to any large training run, do the Chinchilla math. Make sure your model size and data budget are balanced. An optimally-trained 7B model will outperform an undertrained 13B model at the same compute cost.
Wrapping Up
If you're still with me, thank you. I hope it was worth it.
We started with a simple problem — GPU memory runs out — and traced it all the way through: mixed precision to halve your memory and double your speed, gradient accumulation and checkpointing to squeeze more out of a single GPU, DDP to scale across multiple GPUs with near-linear throughput, FSDP and DeepSpeed ZeRO to shard the model when it doesn't fit on one GPU, tensor and pipeline parallelism to split the computation itself, Flash Attention to fix the memory hierarchy bottleneck in transformers, and scaling laws to tell us how much model and how much data our compute budget can actually support.
My hope is that the next time you see a CUDA out-of-memory error — or someone mentions ring all-reduce and pipeline bubbles in a design review — instead of nodding along and quietly panicking, you'll have a pretty good mental model of what's happening under the hood and which lever to pull next.