PyTorch Training
PyTorch won the framework war — and it won by being honest about what it is. There's no .fit() hiding the machinery behind a curtain. You write the loop. You own every line. That means understanding what a tensor actually remembers, how autograd tapes your computation graph and then eats the tape during .backward(), why nn.Module quietly registers parameters behind your back, and how DataLoader's multiprocessing pipeline can silently bottleneck your entire training run. This section builds PyTorch from the ground up — tensors, then autograd, then the training loop, then Dataset/DataLoader, then nn.Module internals, then the real-world patterns that separate a training script that works from one that works well: mixed precision, gradient accumulation, gradient checkpointing, torch.compile, and distributed training. If you learn one framework deeply, make it this one.
The Confession
I avoided writing a real training loop for longer than I'd like to admit. Every time someone showed me PyTorch code, I'd nod along, copy-paste the loop from a tutorial, change the model class, and pray it worked. When it didn't, I'd stare at the loss curve and wiggle hyperparameters. I didn't understand what zero_grad() actually cleared, or why the graph vanished after .backward(), or what pin_memory was pinning. Finally, the discomfort of not understanding machinery I depended on every day grew unbearable, and I decided to dive in properly. Here is that dive.
PyTorch was released by Facebook AI Research in 2016, building on the Lua-based Torch library. It popularized an idea that was still radical at the time (pioneered by frameworks like Chainer): build the computation graph dynamically as your Python code runs, instead of pre-declaring a static graph. That idea, define-by-run, turned out to be what researchers and engineers actually wanted. PyTorch now dominates arXiv, powers Hugging Face, and is increasingly the default in production via TorchServe, ONNX export, and torch.compile.
Before we start, a heads-up. We're going to walk through tensors, automatic differentiation, multiprocessing, GPU memory management, and even some compiler internals. You don't need to know any of it beforehand. We'll add what we need, one piece at a time.
This isn't a short journey, but I hope you'll be glad you came.
What We'll Cover
Tensors: The Thing That Remembers · Autograd: The Tape Recorder · The Training Loop — Seven Lines That Run the World · Dataset and DataLoader: The Data Pipeline · nn.Module: The Building Block That Registers Itself · Device Management: The Same-Room Rule · Saving and Loading: Checkpoints Done Right · Inference: Two Switches, Not One · Rest Stop · The Patterns That Matter: Gradient Accumulation · Mixed Precision: Two Number Systems · Gradient Checkpointing: Trading Compute for Memory · torch.compile: The JIT That Actually Works · Distributed Training: DDP and FSDP · The Gotchas That Will Eat Your Afternoon · Putting It All Together
Tensors: The Thing That Remembers
Everything in PyTorch starts with a tensor. A tensor is a multi-dimensional array — conceptually identical to a NumPy ndarray, but carrying two capabilities that change everything: it can live on a GPU, and it can remember how it was created.
Let's build our running example. Imagine we're training a tiny sentiment classifier — three movie reviews, each represented as a small vector of features. We'll use this classifier throughout the section, starting absurdly small and gradually adding complexity until we have a production-grade training pipeline.
import torch
# Our three movie reviews as feature vectors (3 samples, 4 features each)
reviews = torch.tensor([
[0.8, 0.1, 0.9, 0.2], # positive review
[0.1, 0.9, 0.2, 0.8], # negative review
[0.5, 0.5, 0.5, 0.5], # ambiguous review
], dtype=torch.float32)
labels = torch.tensor([1, 0, 1], dtype=torch.long) # 1=positive, 0=negative
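That dtype=torch.float32 matters more than it looks. Tensors track their data type, device, and shape as first-class properties, and every operation checks all three. Devices must match. Shapes must be broadcast-compatible. Dtypes are either promoted (float32 + float64 gives float64) or, for stricter operations like matrix multiplication, required to match. Break a rule and PyTorch raises a RuntimeError. This strictness feels annoying at first. It saves you from silent numerical bugs later.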
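To make those rules concrete, here is a small CPU-only sketch (the tensor names are purely illustrative). It shows broadcasting, dtype promotion in an elementwise op, and two cases that do raise. The device rule can't be demonstrated without a GPU, so it is only described: adding a CUDA tensor to a CPU tensor raises, with 0-dim CPU tensors as the narrow exception.

```python
import torch

a = torch.ones(3, 4, dtype=torch.float32)
b = torch.arange(4, dtype=torch.float64)  # shape (4,), different dtype

# Shapes broadcast: (3, 4) + (4,) -> (3, 4)
# Dtypes are promoted: float32 + float64 -> float64
c = a + b

# matmul does not promote: mismatched dtypes raise
try:
    a @ b
    matmul_failed = False
except RuntimeError:
    matmul_failed = True

# Trailing dims 4 vs 3 cannot broadcast, so this raises too
try:
    a + torch.ones(3)
    shape_failed = False
except RuntimeError:
    shape_failed = True

print(c.shape, c.dtype, matmul_failed, shape_failed)
```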
NumPy interop is nearly free. torch.from_numpy() shares the underlying memory — no copy — and .numpy() goes the other way. But there's a catch: shared memory means changing one changes the other. If that surprises you during debugging, you're not the first.
import numpy as np
arr = np.array([1.0, 2.0, 3.0])
t = torch.from_numpy(arr) # shared memory — no copy
arr[0] = 99.0
print(t[0])  # tensor(99., dtype=torch.float64): surprise
For GPU tensors, the story changes. .to('cuda') copies data to GPU memory. The CPU and GPU tensors are now independent.
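A short sketch contrasting the zero-copy path with torch.tensor(), which always copies, and showing that .numpy() shares memory in the other direction too:

```python
import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])   # NumPy defaults to float64

shared = torch.from_numpy(arr)    # shares arr's buffer
copied = torch.tensor(arr)        # always copies

arr[0] = 99.0
print(shared[0].item(), copied[0].item())  # 99.0 1.0

# .numpy() also shares memory, but only for CPU tensors that don't require grad
back = shared.numpy()
back[1] = -1.0
print(arr[1])  # -1.0: all three views point at the same buffer

# For a GPU tensor (or one that requires grad), detach and copy to CPU first:
# gpu_tensor.detach().cpu().numpy()
```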
Autograd: The Tape Recorder
If tensors are the nouns of PyTorch, autograd is the verb. It's the system that makes training possible, and understanding it — really understanding it — is the difference between writing a training loop and knowing what your training loop does.
The analogy I keep coming back to: imagine you're recording a cooking show. Every time you chop a vegetable, every time you add a spice, the camera records it. That recording is the computation graph. At the end of the meal, someone asks "how did the salt affect the final taste?" — and you rewind the tape to trace the path from salt to the finished dish. That rewind is .backward().
Setting requires_grad=True on a tensor is like turning the camera on. From that moment, every operation involving that tensor gets recorded. Each operation creates a Function node in the graph, and each tensor carries a grad_fn attribute pointing back to the Function that created it.
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x + 1 # y = x² + 2x + 1
print(y.grad_fn)  # <AddBackward0 object at 0x...>: the last operation that created y
y.backward() # rewind the tape: dy/dx = 2x + 2 = 8.0
print(x.grad) # tensor(8.)
Here's what happened under the hood. When we computed x ** 2, PyTorch created a PowBackward node and saved x inside it (needed to compute the derivative of x²). When we computed + 2 * x, it created a MulBackward and an AddBackward. When we computed + 1, another AddBackward. These nodes form a directed acyclic graph from y back to x.
Calling y.backward() walks this graph in reverse — that's the "back" in backpropagation — computing ∂y/∂x using the chain rule at each step. The result lands in x.grad. And then the graph is destroyed. Gone. Freed from memory. Next forward pass builds a fresh one.
This is why it's called define-by-run: the graph IS your Python code. Want an if statement? Fine. Want a for loop whose length depends on the input? Fine. The graph captures whatever actually executed.
def weird_function(x):
    if x.item() > 0:
        return x ** 2
    else:
        return x ** 3
x = torch.tensor(2.0, requires_grad=True)
y = weird_function(x)
y.backward()
print(x.grad) # tensor(4.) — gradient of x², because x > 0
I'll be honest — the first time I saw that a dynamic graph could handle arbitrary Python control flow and still compute correct gradients, I didn't believe it. The trick is that autograd doesn't know about the if statement. It only sees the operations that actually ran. If x > 0, it recorded x ** 2 and never saw x ** 3. The graph is a recording of what happened, not a description of what could happen.
A few things to internalize about autograd, because they'll bite you otherwise. Gradients accumulate by default. Calling .backward() adds to whatever is already in .grad, rather than replacing it. This is a deliberate design choice (we'll use it for gradient accumulation later), but it means you must zero gradients before each training step. Also, the tape is gone after .backward() — if you call it twice on the same graph, you get a RuntimeError unless you pass retain_graph=True. And .detach() severs a tensor from the graph, producing a new tensor that looks identical but has no history. Use this for logging values without leaking memory.
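The three rules above fit in a few lines. This is a minimal CPU sketch with a single scalar weight (w is just an illustrative name):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# 1. Gradients accumulate: two backward passes add up in w.grad
(w * 3).backward()
(w * 3).backward()
accumulated = w.grad.item()      # 6.0, i.e. 3 + 3
w.grad = None                    # what optimizer.zero_grad() does by default

# 2. Saved values are freed after backward unless retain_graph=True
y = w ** 2                       # PowBackward0 saves w for the backward pass
y.backward(retain_graph=True)    # keep the saved values alive
y.backward()                     # works, then frees them
try:
    y.backward()                 # third call: nothing left to rewind
    third_call_failed = False
except RuntimeError:
    third_call_failed = True

# 3. detach(): same values, no history
z = (w * 5).detach()
print(accumulated, w.grad.item(), third_call_failed, z.requires_grad, z.grad_fn)
# 6.0 8.0 True False None
```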
The Training Loop — Seven Lines That Run the World
Keras gives you .fit(). PyTorch gives you seven lines and says "your problem now." This is a feature, not a limitation. Those seven lines are the heartbeat of every PyTorch project, and knowing why each one exists — not what it does, but why — is what separates someone who uses PyTorch from someone who understands it.
Let's train our sentiment classifier. One weight matrix, one bias, one loss function, one optimizer:
import torch
import torch.nn as nn
# Our toy data
reviews = torch.tensor([[0.8,0.1,0.9,0.2],[0.1,0.9,0.2,0.8],[0.5,0.5,0.5,0.5]])
labels = torch.tensor([1, 0, 1])
# A single linear layer: 4 features → 2 classes
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(5):
    optimizer.zero_grad()              # 1. clear old gradients
    outputs = model(reviews)           # 2. forward pass: builds the graph
    loss = criterion(outputs, labels)  # 3. compute scalar loss
    loss.backward()                    # 4. walk graph backward: compute gradients
    optimizer.step()                   # 5. update parameters using gradients
    print(f"Epoch {epoch}: loss={loss.item():.4f}")
Five lines inside the loop, but the real training loop in practice has seven core operations. Let's see the full canonical pattern with batched data, GPU handling, and everything a real project needs:
# assumes model, criterion, optimizer, scheduler, dataloader, device, num_epochs are defined
for epoch in range(num_epochs):
    model.train()                          # set training mode at the start of every epoch
    for inputs, targets in dataloader:
        inputs = inputs.to(device)         # move data to the model's device
        targets = targets.to(device)
        optimizer.zero_grad()              # clear accumulated gradients
        outputs = model(inputs)            # forward pass (graph is built here)
        loss = criterion(outputs, targets) # compute loss
        loss.backward()                    # backward pass (graph is consumed here)
        optimizer.step()                   # update weights
        scheduler.step()                   # adjust learning rate (per-batch schedulers such as OneCycleLR)
Let me walk through each line, because I've been bitten by every single one.
model.train() sets the training flag on the model and every submodule, and among the built-in layers two behave differently because of it. Dropout layers start randomly zeroing activations (they're inactive during eval). BatchNorm layers normalize with the current batch's statistics and update their running averages, instead of using the stored running averages. It has no effect on gradient computation. But forget to call it after an eval(), and your model trains with dropout disabled and BatchNorm's running statistics frozen. The loss curve will look fine-ish. Your model will underperform. You'll blame the learning rate. I've done this.
inputs.to(device) moves tensors to whatever device your model lives on. If your model is on GPU and your data is on CPU, you get an immediate RuntimeError. Every tensor in an operation must be on the same device — think of it as the same-room rule. I'll cover the subtle asymmetry between .to() on modules vs. tensors later.
optimizer.zero_grad() is the line that catches everyone at least once. Because gradients accumulate by default, skipping it means each step's gradients pile on top of every previous step's. The effective gradient becomes a running sum of everything that came before, so updates keep growing and stale directions never fade; the loss may still go down for a while, then oscillate or diverge. The fix is one line once you know the problem. The diagnosis is not, because nothing crashes.
outputs = model(inputs) runs the forward pass. Calling a model like a function invokes its forward() method plus any registered hooks (more on those later). This is where the computation graph gets built — every matmul, every activation, every normalization becomes a node that .backward() will later traverse.
loss = criterion(outputs, targets) computes a scalar loss. A detail worth knowing: nn.CrossEntropyLoss expects raw logits, not probabilities. It applies LogSoftmax internally using the LogSumExp trick, which is more numerically stable than doing softmax() followed by log() yourself. If you feed it softmax'd outputs, your loss function is effectively computing softmax twice, and your gradients are wrong. This has ruined entire training runs for people.
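A quick way to convince yourself of both claims, sketched with made-up logits for three samples: CrossEntropyLoss matches log_softmax followed by nll_loss, and feeding it probabilities gives a different loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical raw model outputs for 3 samples, 2 classes
logits = torch.tensor([[2.0, -1.0], [0.5, 3.0], [-2.0, 4.0]])
targets = torch.tensor([1, 0, 1])

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))   # True: same computation

# The bug: passing probabilities means softmax is effectively applied twice
wrong = nn.CrossEntropyLoss()(logits.softmax(dim=1), targets)
print(ce.item(), wrong.item())      # different values
```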
loss.backward() is where the magic happens. Starting from the loss scalar, autograd traverses the computational graph backward, applying the chain rule at each node, and deposits ∂loss/∂param into every parameter's .grad attribute. After this call, the graph is destroyed — freed from memory. This is why you can't call .backward() twice without retain_graph=True.
optimizer.step() reads the .grad on each parameter and updates the parameter accordingly. For plain SGD (no momentum, no weight decay), it's literally param -= lr * param.grad. For Adam, it's more involved: it maintains running averages of gradients and squared gradients to compute per-parameter adaptive step sizes. The optimizer holds references to your model's parameters from when you created it with model.parameters().
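A small sketch checking that claim numerically for plain SGD, with a random toy batch standing in for real data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
lr = 0.1

# Snapshot the weights, then run one optimizer step
before = [p.detach().clone() for p in model.parameters()]
opt = torch.optim.SGD(model.parameters(), lr=lr)  # plain SGD: no momentum, no weight decay
opt.zero_grad()
nn.CrossEntropyLoss()(model(x), y).backward()
grads = [p.grad.clone() for p in model.parameters()]
opt.step()

# The update should be exactly param -= lr * grad
matches = all(torch.allclose(p, b - lr * g)
              for p, b, g in zip(model.parameters(), before, grads))
print(matches)  # True
```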
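There's one more line worth mentioning: scheduler.step() adjusts the learning rate. Schedulers count calls to step(), not epochs, so where you call it depends on how you configured the scheduler. StepLR's step_size and CosineAnnealingLR's T_max are usually expressed in epochs, so those are typically stepped once per epoch; OneCycleLR's total_steps is expressed in batches, so it's stepped once per batch. Call it at the wrong frequency and your LR schedule will either race through in the first epoch or barely move. Always check the scheduler's documentation.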
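To see why the frequency matters, here is a sketch that steps a StepLR once per batch even though step_size=2 was meant as "every 2 epochs". The single dummy parameter and the 3 batches per epoch are arbitrary choices for illustration.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)
# StepLR counts scheduler.step() calls: step_size=2 means "decay every 2 calls"
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

batches_per_epoch = 3
lrs = []
for epoch in range(2):
    for batch in range(batches_per_epoch):
        opt.step()      # real training would compute gradients first
        sched.step()    # per-batch call: the bug if you meant "every 2 epochs"
        lrs.append(sched.get_last_lr()[0])

print(lrs)  # roughly [1.0, 0.1, 0.1, 0.01, 0.01, 0.001]: decays every 2 batches
```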
Dataset and DataLoader: The Data Pipeline
The training loop above fed all three reviews at once. That works for toy data. Real datasets have millions of samples, live on disk, and need transformations. PyTorch separates this into two pieces: Dataset defines what each sample is, and DataLoader handles how samples become batches.
Let's grow our sentiment classifier. Instead of three hardcoded vectors, imagine we have a CSV of movie reviews with features pre-extracted.
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
class ReviewDataset(Dataset):
    def __init__(self, csv_path):
        df = pd.read_csv(csv_path)  # expects a 'label' column plus numeric feature columns
        self.features = torch.tensor(df.drop('label', axis=1).values, dtype=torch.float32)
        self.labels = torch.tensor(df['label'].values, dtype=torch.long)
    def __len__(self):
        return len(self.features)
    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
Two methods. That's the contract. __len__ returns the number of samples. __getitem__ returns one sample given an index. By default, the DataLoader calls __getitem__ once per index and hands the resulting list of samples to a collate function, which stacks them into batch tensors.
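The per-sample contract is easy to observe. The sketch below uses a hypothetical in-memory ToyReviews dataset (random features, no CSV) that counts its own __getitem__ calls; with num_workers=0 everything runs in the main process, so the counter is visible.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyReviews(Dataset):
    """Hypothetical in-memory stand-in for ReviewDataset."""
    def __init__(self, n=10, num_features=4):
        self.features = torch.randn(n, num_features)
        self.labels = torch.randint(0, 2, (n,))
        self.calls = 0  # counts __getitem__ calls (valid only with num_workers=0)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        self.calls += 1
        return self.features[idx], self.labels[idx]

ds = ToyReviews()
loader = DataLoader(ds, batch_size=4, shuffle=False)  # default collate stacks samples

shapes = [(x.shape, y.shape) for x, y in loader]
print(shapes)    # two batches of 4, then a final batch of 2
print(ds.calls)  # 10: one __getitem__ call per sample
```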
The DataLoader is where things get interesting — and where performance bottlenecks silently live.
train_loader = DataLoader(
dataset=train_dataset,
batch_size=64,
shuffle=True, # randomize order each epoch
num_workers=4, # parallel loading processes
pin_memory=True, # page-locked memory for faster GPU transfer
prefetch_factor=2, # batches prefetched per worker
persistent_workers=True, # keep workers alive between epochs
drop_last=True, # drop incomplete final batch
)
When num_workers=0, data loading happens in the main process, sequentially. Your GPU finishes a training step, waits while the next batch is loaded and transformed, processes it, waits again. If loading a batch takes anywhere near as long as a training step, a large share of your GPU time is spent idle.
With num_workers=4, PyTorch spawns four separate Python processes. Each worker builds whole batches: it calls your __getitem__ for every index in the batch it was assigned, applies transforms, collates the samples, and sends the finished batch back to the main process through a queue, with the tensors passed via shared memory. While the GPU crunches one batch, workers are already preparing the next. With prefetch_factor=2, each worker keeps two batches in flight, so up to num_workers × prefetch_factor batches can be ready or in preparation at any time.
Think of it like a kitchen with four prep cooks and one chef. The chef (GPU) should never be waiting for ingredients. If the chef finishes a dish and there's nothing prepped, the restaurant slows down. You want enough prep cooks to keep the chef busy, but not so many that they're bumping into each other and eating up all the counter space (RAM).
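pin_memory=True is the detail that ties the pipeline to the GPU. The DataLoader places each batch in page-locked (pinned) CPU memory, which the GPU can read directly. Copies from pinned memory are faster, and they only become asynchronous with respect to the CPU when you also ask for it with .to(device, non_blocking=True); from ordinary pageable memory, the copy is staged through an extra buffer and the CPU waits for it. It's almost always worth enabling when training on CUDA, but pinned memory is a limited system resource, so don't pin far more than you need.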
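A minimal sketch of the pairing, assuming a toy TensorDataset in place of real reviews. It falls back to the CPU when no GPU is present, in which case pinning is simply not requested and non_blocking has nothing to overlap.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

ds = TensorDataset(torch.randn(256, 4), torch.randint(0, 2, (256,)))
# Only request pinning when there is a GPU to copy to
loader = DataLoader(ds, batch_size=64, pin_memory=use_cuda)

for inputs, targets in loader:
    # With a pinned source, non_blocking=True lets the copy overlap with other work;
    # from pageable memory it behaves like an ordinary copy
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    # ... forward/backward would go here ...

print(inputs.device, inputs.shape)
```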
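persistent_workers=True keeps worker processes alive between epochs (it requires num_workers > 0). Without it, PyTorch shuts down and respawns all workers at every epoch boundary, paying the process startup cost and re-running any expensive setup in your Dataset each time; with large datasets that can add seconds of pure overhead per epoch. With it, workers stay warm.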
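A word about worker_init_fn. Each worker process starts from a copy of the parent's state. PyTorch reseeds its own generator (and Python's random module) differently in every worker, but other libraries' random state can be inherited as-is; older PyTorch versions did exactly that with NumPy's global RNG, so NumPy-based augmentations came out identical across workers. A robust fix is to derive each library's seed from torch.initial_seed(), which already differs per worker and is redrawn each time the workers start. (A fixed seed such as 42 + worker_id avoids the duplication but replays the same augmentations every epoch.)
def worker_init(worker_id):
    # torch.initial_seed() is base_seed + worker_id; base_seed is redrawn when workers start
    np.random.seed(torch.initial_seed() % 2**32)
loader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init)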
nn.Module: The Building Block That Registers Itself
Every model in PyTorch is an nn.Module. Every layer is an nn.Module. A model composed of layers is a Module of Modules. It's turtles all the way down.
The contract is two methods: define your layers in __init__, define the computation in forward. But the machinery underneath that contract is what makes nn.Module powerful — and occasionally surprising.
Let's build a proper classifier for our reviews:
import torch.nn as nn
import torch.nn.functional as F
class SentimentClassifier(nn.Module):
    def __init__(self, input_dim=4, hidden_dim=32, num_classes=2):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.layer2 = nn.Linear(hidden_dim, num_classes)
    def forward(self, x):
        x = self.layer1(x)
        x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.layer2(x)
        return x  # raw logits: never apply softmax here
When you write self.layer1 = nn.Linear(input_dim, hidden_dim), something quiet happens. nn.Module overrides Python's __setattr__. It inspects the thing being assigned. If it's an nn.Parameter, it goes into an internal _parameters dictionary. If it's another nn.Module, it goes into _modules. This is how model.parameters() can recursively discover every learnable weight in a nested model — it walks _modules and _parameters all the way down.
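You can watch the registration happen by peeking at those internal dictionaries. This sketch uses a hypothetical Tiny module; _modules and _parameters are private attributes, used here only for inspection, and named_parameters() is the public way to get the same information.

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)              # Module    -> _modules
        self.scale = nn.Parameter(torch.ones(1))  # Parameter -> _parameters
        self.note = "not a tensor"                # plain attribute, not registered

m = Tiny()
print(list(m._modules))     # ['layer']
print(list(m._parameters))  # ['scale']
names = [name for name, _ in m.named_parameters()]
print(names)                # ['scale', 'layer.weight', 'layer.bias']: own params, then children
```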
This auto-registration is why storing layers in a plain Python list is a trap:
# BROKEN — parameters invisible to optimizer
class BadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 10) for _ in range(3)]  # plain list!
# CORRECT: use nn.ModuleList
class GoodModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
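Assigning a plain list still goes through __setattr__, but a list is neither a Parameter nor a Module, so it's stored as an ordinary attribute and the layers inside it are never registered. For BadModel, model.parameters() yields nothing, and the optimizer constructor will in fact refuse the empty parameter list. The nastier case is a real model that also has properly registered layers: the optimizer happily trains those, while the listed layers stay at their random initialization, are missing from state_dict(), and get left behind by .to(device). On the CPU, training runs, nothing crashes, and the model is quietly worse than it should be. I've spent an afternoon on this exact bug.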
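Counting parameters makes the difference visible immediately. These are the BadModel and GoodModel classes from above, repeated so the snippet runs on its own; count is a small illustrative helper.

```python
import torch.nn as nn

class BadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(10, 10) for _ in range(3)]  # plain list: not registered

class GoodModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

def count(model):
    return sum(p.numel() for p in model.parameters())

print(count(BadModel()), count(GoodModel()))  # 0 330
print(len(BadModel().state_dict()))           # 0: nothing would be checkpointed either
```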
There's another registration mechanism that's less well-known but equally important: register_buffer. It stores tensors that are part of the model's state but aren't parameters — things like BatchNorm's running mean and variance. Buffers are saved in state_dict(), moved with .to(device), but never updated by the optimizer.
class ModelWithState(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.register_buffer('running_count', torch.zeros(1))
    def forward(self, x):
        self.running_count += 1  # tracked in state_dict, moves with .to()
        return self.linear(x)
Hooks are the third mechanism worth knowing. register_forward_hook attaches a function that runs after every forward pass of a module, which is useful for extracting intermediate features, debugging, or monitoring activations. register_full_backward_hook does the same for the backward pass, which is useful for gradient analysis or custom gradient modification. The old register_backward_hook was deprecated because it attached only to the last autograd operation inside the module, so for modules that perform several operations its grad_input and grad_output could be incomplete or wrong.
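A sketch of the feature-extraction use case, assuming a toy three-layer nn.Sequential; the features dict and the save_output helper are illustrative names, not PyTorch API.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
features = {}

def save_output(name):
    # Forward hooks receive (module, inputs, output); returning None keeps output unchanged
    def hook(module, inputs, output):
        features[name] = output.detach()  # detach so the stored copy doesn't hold the graph
    return hook

handle = model[1].register_forward_hook(save_output("relu"))
out = model(torch.randn(3, 4))
handle.remove()  # hooks stay attached until removed

print(features["relu"].shape)                # torch.Size([3, 8])
print((features["relu"] >= 0).all().item())  # True: post-ReLU activations
```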
Device Management: The Same-Room Rule
Every tensor in an operation must be on the same device. Model on GPU, data on CPU? RuntimeError. Two tensors on different GPUs? RuntimeError. This is the same-room rule, and it's PyTorch's most common source of errors for beginners.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentimentClassifier().to(device) # in-place for modules
inputs = torch.randn(16, 4)
inputs = inputs.to(device) # NOT in-place for tensors — must reassign
Here's the critical asymmetry that trips everyone up at least once. For nn.Module, .to(device) moves all parameters and buffers in place and returns self. For tensors, it returns a new tensor on the target device, and the original stays where it was. Writing tensor.to(device) without reassigning does nothing useful: the moved copy is created and immediately discarded. Your model then crashes on the next line with a device mismatch, and the error points at that operation rather than at the forgotten assignment.
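The asymmetry isn't specific to devices; it applies to every .to() call. So that the sketch runs on a CPU-only machine, it uses a dtype conversion in place of a device move:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
returned = layer.to(torch.float64)   # modules: converted in place, returns self
print(returned is layer, layer.weight.dtype)  # True torch.float64

x = torch.randn(3, 4)                # float32
x.to(torch.float64)                  # tensors: returns a NEW tensor, result discarded
print(x.dtype)                       # torch.float32 (unchanged)
x = x.to(torch.float64)              # reassign to keep the result
print(layer(x).dtype)                # torch.float64
```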
For Apple Silicon: use torch.device('mps') instead of 'cuda'. Not all operations are supported on MPS yet, but coverage improves with every PyTorch release.
Saving and Loading: Checkpoints Done Right
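Always save the state_dict(), not the model object. Saving the whole model pickles it, and pickle records your class by its module path and name rather than storing the code itself. Rename a class? Move a file? Change an import? Loading fails because that path no longer resolves. Recent PyTorch releases also default torch.load to weights_only=True, which refuses to unpickle arbitrary objects such as a whole model.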
# Save weights only
torch.save(model.state_dict(), 'model_weights.pt')
# Load weights — create architecture first, then load state
model = SentimentClassifier()
model.load_state_dict(torch.load('model_weights.pt', map_location=device))
model.to(device)
model.eval()
For resuming training after a crash, you need the full training state: not just model weights, but optimizer state (for Adam, the running averages of gradients and squared gradients plus a step count), scheduler state, the current epoch, and your best validation loss. Without the optimizer state, Adam's moment estimates restart from zero, so the first updates after resuming behave like the start of training rather than a continuation; this often shows up as a bump in the loss curve.
# Save full checkpoint
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_val_loss': best_val_loss,
}
torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')
# Resume
ckpt = torch.load('checkpoint_epoch_42.pt', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
The map_location=device argument is essential. If you saved on GPU 0 and load on CPU (or a different GPU), map_location handles the remapping. Without it, PyTorch tries to load onto the original device, which may not exist on your current machine.
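A round-trip sketch of the save/resume pattern above, using an in-memory io.BytesIO buffer in place of a checkpoint file and a toy nn.Linear after one Adam step. It also shows what the optimizer state actually holds, which is exactly what would be lost by saving only the model weights.

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
nn.functional.mse_loss(model(torch.randn(8, 4)), torch.randn(8, 2)).backward()
opt.step()  # creates Adam state per parameter: step, exp_avg, exp_avg_sq

buf = io.BytesIO()  # stands in for a checkpoint file path
torch.save({"epoch": 0,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opt.state_dict()}, buf)
buf.seek(0)

ckpt = torch.load(buf, map_location="cpu")
new_model = nn.Linear(4, 2)
new_opt = torch.optim.Adam(new_model.parameters(), lr=1e-3)
new_model.load_state_dict(ckpt["model_state_dict"])
new_opt.load_state_dict(ckpt["optimizer_state_dict"])

state = new_opt.state_dict()["state"][0]  # state for the first parameter
print(sorted(state))  # ['exp_avg', 'exp_avg_sq', 'step']
```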
For partial loading — say you've added new layers to a model but want to keep the old weights — use strict=False:
model.load_state_dict(ckpt['model_state_dict'], strict=False)
# Missing keys: new layers initialized randomly
# Unexpected keys: old layers that no longer exist are ignored
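load_state_dict also returns the keys it skipped, which is worth logging when you load with strict=False. A sketch with toy nn.Sequential models (the architectures are arbitrary), including the case strict=False does not cover:

```python
import torch.nn as nn

old = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
# New architecture: same layers plus an extra head at index 3
new = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2), nn.Linear(2, 3))

result = new.load_state_dict(old.state_dict(), strict=False)
print(result.missing_keys)     # ['3.weight', '3.bias']: keep their random init
print(result.unexpected_keys)  # []

# strict=False does NOT forgive shape changes
wider = nn.Sequential(nn.Linear(4, 16))
try:
    wider.load_state_dict(old.state_dict(), strict=False)
    shape_error = False
except RuntimeError:
    shape_error = True
print(shape_error)  # True: size mismatch still raises
```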
Inference: Two Switches, Not One
Running inference correctly requires flipping two independent switches. Forgetting either one is a bug — and the bugs are different.
model.eval() # switch 1: change layer behavior
with torch.inference_mode():  # switch 2: disable gradient tracking
    outputs = model(inputs.to(device))
    predictions = outputs.argmax(dim=1)
model.eval() changes how specific layers behave. Dropout stops zeroing neurons — all neurons are active. BatchNorm switches from per-batch statistics to its stored running mean and variance. Without this, dropout randomly kills neurons during inference, making your predictions nondeterministic, and BatchNorm uses noisy single-batch statistics instead of the stable running averages it accumulated during training.
torch.inference_mode() disables the tape recorder entirely. No computation graph is built, no gradient metadata is stored. This saves significant memory (the graph for a large model can consume gigabytes) and speeds up computation. It's the modern replacement for torch.no_grad() — stricter (tensors created inside can't accidentally flow into later gradient computations) and faster.
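These two switches are independent. eval() changes layer behavior but doesn't touch gradients. inference_mode() disables gradients but doesn't touch layer behavior. You need both. Skip eval() and your predictions are wrong or noisy; skip inference_mode() and they're correct but cost extra memory and time for a graph nobody will use. Both are real mistakes that real people ship to production.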
After evaluation, remember to switch back: model.train(). Otherwise your next training epoch runs with dropout disabled.
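A minimal sketch of both switches on a toy model with dropout; torch.is_inference() reports whether a tensor was created under inference_mode.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 2))
x = torch.randn(5, 4)

model.eval()                     # switch 1: dropout off
with torch.inference_mode():     # switch 2: no graph recorded
    out1 = model(x)
    out2 = model(x)

print(torch.equal(out1, out2))            # True: deterministic with dropout disabled
print(out1.requires_grad, out1.grad_fn)   # False None: nothing recorded
print(torch.is_inference(out1))           # True

model.train()                    # back to training mode for the next epoch
```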
Rest Stop
Congratulations on making it this far. You can stop here if you want.
You now have a complete mental model of PyTorch training: tensors that record their history, autograd that rewinds the tape, the seven-line training loop, Dataset/DataLoader for feeding data, nn.Module for building models, device management, checkpointing, and inference. This is enough to write a working training pipeline for most projects.
But it doesn't tell the whole story. Real training runs hit GPU memory limits, take hours or days, run across multiple GPUs, and need every bit of speed the hardware can offer. The next sections cover the patterns that address these realities: gradient accumulation, mixed precision, gradient checkpointing, torch.compile, and distributed training.
The short version: gradient accumulation simulates larger batches without more memory, mixed precision runs most operations in float16 or bfloat16 for up to roughly 2× speed on GPUs with Tensor Cores, gradient checkpointing trades compute for memory by recomputing activations during backward, torch.compile JIT-compiles your model for speedups often in the 15-40% range, and DDP/FSDP scale training across GPUs. There. You're 70% of the way.
But if the discomfort of not knowing what's underneath is nagging at you, read on.
The Patterns That Matter: Gradient Accumulation
You want a batch size of 256 but your GPU can only fit 64 samples. The typical advice is "use gradient accumulation." But what does that actually mean, mechanically?
Remember that autograd accumulates gradients by default — that thing that seemed like a footgun earlier? It's now a feature. Call .backward() four times before calling optimizer.step(), and the .grad on each parameter contains the sum of four backward passes' worth of gradients. Divide the loss by the accumulation count to keep the gradient magnitude equivalent to what a single large batch would produce.
accum_steps = 4 # effective batch = 64 × 4 = 256
optimizer.zero_grad()
for step, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, targets) / accum_steps  # normalize
    loss.backward()  # gradients accumulate
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
The math works out: if each mini-batch contributes gradients g₁, g₂, g₃, g₄, the accumulated gradient is (g₁ + g₂ + g₃ + g₄) / 4, which is the mean gradient over 256 samples — exactly what you'd get from a single batch of 256. Going back to our kitchen analogy: the chef is tasting four dishes and averaging the feedback before adjusting the recipe, rather than adjusting after every single dish.
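The equivalence can be checked directly. This sketch, using a toy nn.Linear and random data in float64 to keep rounding noise small, compares the gradient of one 256-sample batch against four accumulated 64-sample micro-batches. It relies on two assumptions worth stating: the loss uses mean reduction, and the micro-batches are equally sized.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2).double()
criterion = nn.CrossEntropyLoss()   # default reduction="mean"
x = torch.randn(256, 4, dtype=torch.float64)
y = torch.randint(0, 2, (256,))

# Reference: one batch of 256
model.zero_grad()
criterion(model(x), y).backward()
big = [p.grad.clone() for p in model.parameters()]

# Accumulation: 4 micro-batches of 64, each loss divided by accum_steps
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (criterion(model(xb), yb) / accum_steps).backward()
accum = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(big, accum)))  # True
```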
One subtlety: BatchNorm still sees batches of 64, not 256. Its per-batch normalization and its running statistics are computed from each micro-batch. If the micro-batches themselves are very small, consider GroupNorm or LayerNorm, which don't depend on batch size.
Mixed Precision: Two Number Systems
Modern GPUs have specialized hardware (Tensor Cores on NVIDIA, Matrix Cores on AMD) that multiplies float16 or bfloat16 matrices at far higher throughput than float32, and halving the bytes per value also halves memory traffic. Mixed precision training exploits this: run the expensive operations (matrix multiplications, convolutions) in lower precision for speed, and keep the sensitive operations (loss computation, softmax, weight updates) in float32 for accuracy.
PyTorch's Automatic Mixed Precision (AMP) handles the precision routing automatically:
from torch.amp import GradScaler, autocast
scaler = GradScaler('cuda')
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    with autocast('cuda'):  # forward pass in mixed precision
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(optimizer)     # unscale gradients in place so clipping sees true values
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)         # skips the step if any gradient is inf/nan
    scaler.update()                # adjust scale factor for the next iteration
The autocast context manager hooks into PyTorch's dispatcher. For each operation, it checks a policy: safe operations (matmul, conv, linear) run in float16; numerically sensitive ones (softmax, log, loss functions) run in float32. You don't choose per operation: PyTorch applies a fixed built-in policy, listed in the autocast op reference of the AMP documentation.
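Autocast also works on the CPU with bfloat16, which makes the per-op policy easy to observe without a GPU. A sketch with a toy nn.Linear:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)   # parameters stay float32
x = torch.randn(8, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)                     # linear is on the lower-precision list
    loss = out.float().pow(2).mean()   # explicit cast back for the reduction

loss.backward()
print(out.dtype)                # torch.bfloat16
print(model.weight.dtype)       # torch.float32: autocast casts per op, not the weights
print(model.weight.grad.dtype)  # torch.float32: gradients match the parameters' dtype
```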
The GradScaler solves a specific problem with float16: its dynamic range is narrow. The largest value is 65504, the smallest normal value is about 6×10⁻⁵, and subnormals bottom out near 6×10⁻⁸. Small gradients, especially in early layers of deep networks, can underflow to zero. GradScaler multiplies the loss by a large factor (2¹⁶ = 65536 by default at the start) before backward, which scales all gradients up into the representable range. Before the optimizer step, it divides them back down. If it detects inf or nan (meaning the scale was too large), it skips that optimizer step and halves the scale; after a long run of steps without overflow (2000 by default), it doubles the scale again. Over time, the scale factor settles near the largest value that doesn't overflow.
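The underflow problem, and why scaling fixes it, shows up with a single number. The gradient value 1e-8 below is an arbitrary illustrative choice; 2¹⁶ matches GradScaler's default starting scale.

```python
import torch

fp16 = torch.finfo(torch.float16)
print(fp16.max, fp16.tiny)   # 65504.0 6.103515625e-05 (tiny = smallest normal)

tiny_grad = 1e-8
print(torch.tensor(tiny_grad, dtype=torch.float16).item())  # 0.0: underflow

scale = 2.0 ** 16
scaled = torch.tensor(tiny_grad * scale, dtype=torch.float16)
print(scaled.item() > 0)                # True: survives once scaled
print(scaled.float().item() / scale)    # close to 1e-8 after unscaling in float32

# bfloat16 keeps float32's exponent range, so the unscaled value survives
print(torch.tensor(tiny_grad, dtype=torch.bfloat16).item() > 0)  # True
```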
An important development: bfloat16 has the same exponent range as float32 (8 exponent bits vs. float16's 5), so gradient underflow is much less likely; the price is precision, with 7 explicit mantissa bits instead of float16's 10. On Ampere and newer NVIDIA GPUs (A100, H100), bfloat16 is increasingly preferred because you often don't need GradScaler at all:
with autocast('cuda', dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, targets)
loss.backward() # no scaler needed with bfloat16
optimizer.step()
Gradient Checkpointing: Trading Compute for Memory
During a forward pass, PyTorch stores intermediate activations at every layer, because they're needed during backward to compute gradients. For a transformer with 24 layers and batch size 32, those activations can consume 10+ GB, depending on sequence length and hidden size. Gradient checkpointing says: don't store them. Recompute them during backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(256, 256), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(256, 256), nn.ReLU())
        self.head = nn.Linear(256, 10)
    def forward(self, x):
        # only each block's input is kept; its internals are recomputed during backward
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return self.head(x)
With checkpointing, during forward, only the inputs to each block are saved. During backward, when autograd needs the intermediate activations from block1, it re-runs block1's forward pass on the fly to regenerate them. How much memory you save depends on how many intermediate activations each block creates relative to its input; the cost is roughly one extra forward pass. The classic result: for a chain of N layers, checkpointing every √N layers brings activation memory down from O(N) to O(√N).
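A sketch that checks checkpointing doesn't change the math, using torch.utils.checkpoint.checkpoint_sequential to split a toy 16-layer chain into √16 = 4 segments. Only gradients are compared here; measuring the memory saving itself would need a GPU and its memory statistics.

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
n_layers = 16
model = nn.Sequential(*[nn.Sequential(nn.Linear(32, 32), nn.Tanh()) for _ in range(n_layers)])
x = torch.randn(4, 32)

# Ordinary backward: every intermediate activation is stored
model(x).sum().backward()
reference = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Checkpointed: ~sqrt(N) segments; activations inside a segment are recomputed
segments = int(math.sqrt(n_layers))  # 4
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()
checkpointed = [p.grad for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(reference, checkpointed)))  # True
```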
I'm still developing my intuition for exactly when checkpointing pays off. The tradeoff depends on the ratio of memory savings to compute overhead. For transformers with many identical blocks, it's almost always worth it. For models with lightweight layers, the overhead of recomputation may not justify the memory savings.
One critical requirement: the checkpointed function must be deterministic. If it contains dropout with a different random seed on recomputation, the gradients will be wrong. PyTorch handles this by saving and restoring the RNG state, but custom stochastic operations need explicit care.
torch.compile: The JIT That Actually Works
PyTorch 2.0 introduced torch.compile, and — I'll be honest — after years of abandoned JIT attempts (TorchScript, I'm looking at you), I was skeptical. But torch.compile is different. It actually works on real code, and the speedups are meaningful.
model = SentimentClassifier().to(device)
model = torch.compile(model) # one line — that's it
Under the hood, three systems work together. TorchDynamo hooks into the Python interpreter at the bytecode level, watching your model's forward pass execute and capturing the PyTorch operations into an FX graph, a clean intermediate representation. AOTAutograd traces the matching backward pass ahead of time, so training steps get compiled too, not just inference. TorchInductor, the compiler backend, takes those graphs and generates optimized code: fusing operations, eliminating memory round-trips, and emitting Triton kernels for GPU (or C++ for CPU). The compiled code is cached behind guards (checks on input shapes, dtypes, and similar properties), and later calls that pass the guards skip compilation entirely.
When Dynamo encounters Python code it can't trace, it creates a graph break. Typical examples are control flow that depends on tensor values and calls into non-PyTorch libraries. The code before the break runs as one compiled graph, and the unsupported part runs in ordinary eager Python. The code after the break is then compiled as a separate graph. This fallback mechanism is why torch.compile works on real-world models where TorchScript would choke.
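You can watch graph breaks happen by giving `torch.compile` a custom backend. A backend is a callable that receives each captured FX `GraphModule` plus example inputs and returns something to run in its place. This sketch assumes PyTorch 2.x. `counting_backend` is a hypothetical name: it simply records each graph and runs it unchanged, so no compiler toolchain is involved. The `if` on a tensor's value is data-dependent control flow, which Dynamo cannot capture, so the function should be split around it.

```python
import torch

captured_graphs = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    captured_graphs.append(gm)  # Dynamo calls this once per captured graph
    return gm.forward           # run the captured graph as-is (no Inductor codegen)

def f(x):
    y = torch.sin(x) + 1
    if y.sum() > 0:             # branch depends on tensor *values*: graph break here
        return torch.cos(y) * 2
    return torch.tan(y) * 2

compiled_f = torch.compile(f, backend=counting_backend)
x = torch.randn(8)
out = compiled_f(x)

print(len(captured_graphs))     # more than one graph: code before and after the break
for gm in captured_graphs:
    print(gm.graph)             # the FX IR Dynamo captured for each piece
```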
Three compilation modes to know:
torch.compile(model, mode='default') # safe, moderate speedup
torch.compile(model, mode='reduce-overhead') # uses CUDA graphs — lower launch overhead
torch.compile(model, mode='max-autotune') # benchmarks multiple kernel variants — slowest compile, fastest run
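Speedups commonly land somewhere around 15-40%. They depend heavily on the model architecture and on how memory-bound its operations are, so measure on your own model. The first call is slow because that's when compilation happens; later calls that pass the same guards run at compiled speed. For training, you pay the compile cost once on the first batch, not once per epoch. After that, expect extra recompiles only when a guard fails. One example is a smaller final batch with a new shape. Another is switching to model.eval() and torch.inference_mode() for validation. Once the first epoch has hit those cases, everything is cached and the rest fly. Passing drop_last=True to the training DataLoader avoids the odd-sized batch altogether.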
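The sketch below shows that recompilation is driven by guards rather than by epochs, assuming PyTorch 2.x. `counting_backend` is a hypothetical minimal custom backend that records each compilation and then runs the captured graph unchanged. The three-layer model and the batch sizes are illustrative.

```python
import torch
import torch.nn as nn

compiles = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    compiles.append(gm)   # invoked once per (re)compilation
    return gm.forward     # execute the captured graph eagerly

model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
compiled = torch.compile(model, backend=counting_backend)

compiled(torch.randn(64, 4))   # first call: trace and compile
compiled(torch.randn(64, 4))   # same shape: guards pass, cached graph reused
after_full_batches = len(compiles)

compiled(torch.randn(40, 4))   # ragged last batch: shape guard fails, recompile
after_ragged_batch = len(compiles)
print(after_full_batches, after_ragged_batch)  # 1 2
```

In recent PyTorch versions, a recompile caused by a changed size typically marks that dimension as dynamic. Further batch sizes then usually reuse the second graph rather than compiling again.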
Distributed Training: DDP and FSDP
When one GPU isn't enough — because your dataset is massive or your model takes days to converge — you need distributed training. PyTorch offers two main strategies: DistributedDataParallel (DDP) for data parallelism, and FullyShardedDataParallel (FSDP) when the model itself doesn't fit on one GPU.
DDP is the workhorse. Each GPU gets a complete copy of the model and a different slice of the data. During each backward pass, DDP synchronizes gradients across all GPUs with an all-reduce operation, typically through NCCL, which can use a ring algorithm. In a ring, each GPU communicates only with its two neighbors. The gradients are split into N chunks that travel around the ring in 2(N-1) steps for N GPUs: N-1 steps of reduce-scatter, then N-1 of all-gather. The clever part: DDP overlaps this communication with the backward pass. It groups gradients into buckets, and as soon as a bucket's gradients are ready, its all-reduce starts while the rest of backward continues. The GPU is computing and communicating simultaneously.
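The ring algorithm is easy to simulate in a single process. The sketch below is a teaching model, not how NCCL or DDP are actually implemented. The hypothetical `ring_allreduce` treats each list entry as one GPU's gradient, splits it into N chunks, and counts communication steps. Each step moves only one chunk per GPU. As a result, per-GPU traffic stays close to twice the gradient size no matter how many GPUs join the ring.

```python
import torch

def ring_allreduce(tensors):
    """Simulate a sum all-reduce over len(tensors) ranks arranged in a ring."""
    n = len(tensors)
    # Each rank splits its gradient into n chunks (cloned so inputs stay untouched)
    chunks = [[c.clone() for c in torch.tensor_split(t, n)] for t in tensors]
    steps = 0
    # Phase 1, reduce-scatter: afterwards rank r holds the full sum of chunk (r + 1) % n
    for s in range(n - 1):
        # Snapshot every rank's outgoing chunk first, so all sends happen "at once"
        sends = [(r, (r - s) % n, chunks[r][(r - s) % n].clone()) for r in range(n)]
        for r, c, payload in sends:
            chunks[(r + 1) % n][c] += payload   # right neighbor accumulates
        steps += 1
    # Phase 2, all-gather: circulate finished chunks until every rank has all of them
    for s in range(n - 1):
        sends = [(r, (r + 1 - s) % n, chunks[r][(r + 1 - s) % n].clone()) for r in range(n)]
        for r, c, payload in sends:
            chunks[(r + 1) % n][c] = payload    # right neighbor overwrites its stale copy
        steps += 1
    return [torch.cat(cs) for cs in chunks], steps

torch.manual_seed(0)
grads = [torch.randn(10) for _ in range(4)]     # 4 simulated GPUs
reduced, steps = ring_allreduce(grads)
expected = torch.stack(grads).sum(dim=0)
print(steps)                                                  # 6 == 2 * (4 - 1)
print(all(torch.allclose(r, expected) for r in reduced))      # True
```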
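import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])  # set by torchrun for each process
torch.cuda.set_device(local_rank)           # pin this process to its own GPU
model = SentimentClassifier().to(local_rank)
model = DDP(model, device_ids=[local_rank])
# Training loop is almost identical; DDP handles gradient sync.
# Shard the data with DataLoader(..., sampler=DistributedSampler(dataset))
# and call sampler.set_epoch(epoch) at the start of every epoch.
dist.destroy_process_group()  # after training finishes
# Launch with: torchrun --nproc_per_node=4 train.py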
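DDP itself does not split the data. Giving each GPU "a different slice" is the job of `torch.utils.data.distributed.DistributedSampler`. Normally it reads the world size and rank from the initialized process group. You can also pass `num_replicas` and `rank` explicitly, which lets this single-process sketch show what two ranks would each see. The 10-element dataset and the seed are arbitrary.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))
world_size = 2

def shard(rank, epoch):
    # Explicit num_replicas/rank: no init_process_group needed for this simulation
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=0)
    sampler.set_epoch(epoch)  # every rank shuffles with the same permutation per epoch
    return list(sampler)

epoch0 = [shard(r, epoch=0) for r in range(world_size)]
epoch1 = [shard(r, epoch=1) for r in range(world_size)]
print(epoch0)
print(sorted(epoch0[0] + epoch0[1]) == list(range(10)))  # True: disjoint, full coverage
```

Forgetting `set_epoch` means every epoch reuses the same shuffle order on every rank.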
FSDP goes further. Instead of replicating the full model on each GPU, it shards parameters, gradients, and optimizer states across GPUs, so each GPU holds only a fraction of the model. During forward, each FSDP unit all-gathers the parameters it needs from the other GPUs, computes, and frees them again. Backward repeats the gather, then reduce-scatters the gradients so every GPU keeps only its own shard. This is the PyTorch equivalent of DeepSpeed's ZeRO-3, and it enables training models far larger than a single GPU's memory.
The tradeoff: DDP is simpler, more mature, and faster for models that fit in GPU memory. FSDP has more communication overhead but unlocks models that are otherwise untrainable. If your model, including its gradients and optimizer state, fits on one GPU, use DDP. If it doesn't, FSDP is your path.
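Here is a back-of-the-envelope sketch of why sharding matters. It assumes plain fp32 training with Adam, which keeps about 16 bytes of state per parameter: 4 for the weight, 4 for its gradient, and 8 for Adam's two moment estimates. It ignores activations, buffers, and the parameters FSDP temporarily gathers during forward and backward, so real usage is higher. The 7B-parameter, 8-GPU numbers and the helper name `per_gpu_state_gb` are illustrative.

```python
def per_gpu_state_gb(num_params: float, world_size: int, sharded: bool,
                     bytes_per_param: int = 16) -> float:
    """Weights + gradients + Adam state per GPU, in GB (1e9 bytes)."""
    total_bytes = num_params * bytes_per_param
    # DDP replicates everything; full sharding splits it evenly across ranks
    per_gpu = total_bytes / world_size if sharded else total_bytes
    return per_gpu / 1e9

ddp_gb = per_gpu_state_gb(7e9, world_size=8, sharded=False)
fsdp_gb = per_gpu_state_gb(7e9, world_size=8, sharded=True)
print(f"DDP:  {ddp_gb:.0f} GB per GPU")   # 112 GB: more than an 80 GB GPU holds
print(f"FSDP: {fsdp_gb:.0f} GB per GPU")  # 14 GB
```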
The Gotchas That Will Eat Your Afternoon
I've catalogued these from personal experience and from watching others debug the same issues repeatedly. Most of them don't crash; they silently degrade your results. The ones that do raise an error tend to point somewhere other than the real cause.
Forgetting optimizer.zero_grad() — Gradients pile up across batches. Loss appears to converge. Model performance plateaus far below what it should be. Nothing crashes. You blame the learning rate, the architecture, the data. I've wasted entire afternoons on this.
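The accumulation is easy to see on a single parameter. This tiny sketch (the values are arbitrary) calls backward twice without clearing. It then clears the gradient by setting `.grad` to `None`, which is what `optimizer.zero_grad(set_to_none=True)` does for each parameter.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([3.0])

(w * x).sum().backward()
first = w.grad.clone()      # d(w*x)/dw = x = 3

(w * x).sum().backward()    # no zero_grad: the new gradient is *added* to the old one
second = w.grad.clone()

w.grad = None               # what zero_grad(set_to_none=True) does per parameter
(w * x).sum().backward()
third = w.grad.clone()

print(first.item(), second.item(), third.item())  # 3.0 6.0 3.0
```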
Forgetting model.eval() during inference — Dropout is still randomly killing neurons. BatchNorm uses noisy per-batch statistics instead of its smooth running averages. Eval metrics look worse than they should, and they vary between runs. You think your model is bad. It might not be.
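Here is a quick way to see the difference. `nn.Dropout` changes behavior based on the module's `training` flag, which `model.train()` and `model.eval()` toggle (BatchNorm keys off the same flag). In this sketch the input size and `p` are arbitrary. Training mode zeroes roughly half the inputs and scales the survivors by 1/(1-p), while eval mode passes the input through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)   # random mask; kept entries scaled to 1 / (1 - 0.5) = 2.0
drop.eval()
eval_out = drop(x)    # identity: no masking, no scaling

print(train_out.unique())        # tensor([0., 2.])
print(torch.equal(eval_out, x))  # True
```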
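The .to(device) asymmetry — Calling tensor.to('cuda') without assigning the result (tensor = tensor.to('cuda')) does nothing to tensor. Tensor.to() returns a new tensor, and the original stays on the CPU. You get a device mismatch error as soon as it meets a GPU tensor. Modules move their parameters in place; tensors do not.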
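You can see the asymmetry without a GPU, because `.to()` also converts dtypes. In this sketch, `torch.float64` stands in for `'cuda'`. `Tensor.to` returns a new tensor and leaves the original alone. `nn.Module.to` converts the module's parameters in place and returns the module itself.

```python
import torch
import torch.nn as nn

t = torch.zeros(3)              # float32
t.to(torch.float64)             # result discarded: t itself is unchanged
dtype_without_assign = t.dtype  # torch.float32
t = t.to(torch.float64)         # reassign to keep the converted tensor
dtype_with_assign = t.dtype     # torch.float64

m = nn.Linear(3, 1)
returned = m.to(torch.float64)  # modules convert their parameters in place
print(dtype_without_assign, dtype_with_assign)  # torch.float32 torch.float64
print(m.weight.dtype, returned is m)            # torch.float64 True
```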
Storing loss instead of loss.item() for logging — If you append the loss tensor (not the scalar) to a list, you keep that step's autograd graph reachable, so it can never be freed. Memory creeps up with every training step. It grows much faster if the graph still holds saved activations, for example a validation loss computed without torch.no_grad(). Eventually your process gets OOM-killed and you blame your model's size. The fix: loss.item() extracts a Python float, severing the link to the graph.
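The difference is visible in what each list holds. In this sketch (a toy linear model and random data), every stored tensor still carries a `grad_fn`, which is a live reference into that step's autograd graph. The `.item()` values are plain Python floats with no link to autograd.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
history_bad, history_good = [], []
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    history_bad.append(loss)          # tensor: keeps grad_fn, so the graph stays reachable
    history_good.append(loss.item())  # Python float: nothing for autograd to keep alive

print(history_bad[0].grad_fn is not None)  # True
print(type(history_good[0]))               # <class 'float'>
```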
In-place operations breaking autograd — Operations like relu_() (note the underscore) or tensor[0] = value modify tensors in-place. If autograd saved that tensor for backward, and you modify it after, PyTorch throws "one of the variables needed for gradient computation has been modified by an inplace operation." The error message is confusing because the in-place operation might be far from where autograd saved the tensor. Avoid in-place ops during training unless you're sure they're safe.
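Here is a minimal reproduction with arbitrary values. `exp` saves its own output for the backward pass, so modifying that output in place invalidates it. The out-of-place version computes the same forward value and backpropagates without complaint.

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, requires_grad=True)

# Broken: in-place edit of a tensor autograd saved for backward
b = a.exp()        # exp's backward needs its output b
b.add_(1)          # in-place: bumps b's version counter
try:
    b.sum().backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message is not None and "inplace operation" in error_message)  # True

# Fixed: an out-of-place op leaves the saved tensor untouched
a.grad = None
b = a.exp()
c = b + 1
c.sum().backward()
print(torch.allclose(a.grad, a.exp()))  # True: d/da sum(exp(a) + 1) = exp(a)
```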
.detach() vs .data — Both return a tensor that shares storage with the original and is cut off from the graph. The difference is version tracking. An in-place change made through .detach() bumps the version counter it shares with the original. If that data was needed for backward, autograd then raises an error. .data gets no such tracking, so the same change goes unnoticed and silently corrupts gradients. Use .detach().
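The sketch below writes zeros into a tensor that autograd saved, once through each accessor. `sigmoid` saves its output for backward, and the hypothetical helper `run` is just a harness for the comparison. Through `.data`, backward silently computes a gradient from the zeroed output. Through `.detach()`, autograd notices the change and refuses.

```python
import torch

def run(sever):
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()      # sigmoid saves its output for the backward pass
    c = sever(out)         # .data or .detach(): both share out's storage
    c.zero_()              # in-place write through the shared storage
    try:
        out.sum().backward()
        return a.grad      # gradient computed from the corrupted saved output
    except RuntimeError:
        return "error"     # autograd detected the modification

print(run(lambda t: t.data))      # tensor([0., 0., 0.]): silently wrong
print(run(lambda t: t.detach()))  # error: caught by version tracking
```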
CrossEntropyLoss with softmax'd inputs — nn.CrossEntropyLoss expects raw logits and applies log-softmax internally. If your model already applies softmax in forward(), the loss is computed on softmax(softmax(logits)). Probabilities squashed into [0, 1] make poor logits. The loss can never get close to zero, the gradients shrink, and the model trains slowly or converges to a worse solution. Return raw logits from forward(). Always.
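One confident, correct binary prediction is enough to show the damage. The logits here are arbitrary. Fed raw logits, the loss is essentially zero. Fed probabilities, it can never fall below about 0.313 for two classes, because the best possible input, [1, 0], is itself treated as a pair of logits.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.tensor([[10.0, -10.0]])  # confident and correct for class 0
target = torch.tensor([0])

loss_raw = criterion(logits, target).item()                    # what the loss expects
loss_double = criterion(logits.softmax(dim=1), target).item()  # softmax applied twice
print(loss_raw)              # essentially zero
print(f"{loss_double:.4f}")  # 0.3133: log(1 + e^-1), the floor for two classes
```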
Putting It All Together
Here is a complete, production-grade training script: our tiny sentiment classifier grown up. It combines the core patterns from this section: the training loop, DataLoader, device management, mixed precision, gradient clipping, saving the best checkpoint, and evaluation.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import GradScaler, autocast
# SentimentClassifier is the nn.Module defined earlier in this section.
def main():
    # ── Setup ──────────────────────────────────────────────────
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'
    num_epochs = 20
    X_train = torch.randn(1000, 4)
    y_train = torch.randint(0, 2, (1000,))
    X_val = torch.randn(200, 4)
    y_val = torch.randint(0, 2, (200,))
    # Pinned memory only helps when copying batches to a CUDA device
    train_loader = DataLoader(TensorDataset(X_train, y_train),
                              batch_size=64, shuffle=True, num_workers=2, pin_memory=use_cuda)
    val_loader = DataLoader(TensorDataset(X_val, y_val),
                            batch_size=128, shuffle=False, num_workers=2, pin_memory=use_cuda)
    # ── Model ──────────────────────────────────────────────────
    model = SentimentClassifier(input_dim=4, hidden_dim=64, num_classes=2)
    model = model.to(device)
    # model = torch.compile(model)  # uncomment for PyTorch 2.0+ speedup
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler('cuda', enabled=use_cuda)  # becomes a pass-through on CPU
    # ── Training ───────────────────────────────────────────────
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for inputs, targets in train_loader:
            # non_blocking lets the copy from pinned memory overlap with GPU work
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with autocast('cuda', enabled=use_cuda):
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale first so clipping sees the true norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)      # skips the update if gradients contain inf/NaN
            scaler.update()
            train_loss += loss.item() * inputs.size(0)
        scheduler.step()                # once per epoch, matching T_max=num_epochs
        # ── Validation ─────────────────────────────────────────
        model.eval()
        val_loss, correct = 0.0, 0
        with torch.inference_mode():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item() * inputs.size(0)
                correct += (outputs.argmax(1) == targets).sum().item()
        avg_train = train_loss / len(train_loader.dataset)
        avg_val = val_loss / len(val_loader.dataset)
        val_acc = correct / len(val_loader.dataset)
        print(f"Epoch {epoch+1:>2}/{num_epochs} Train: {avg_train:.4f} "
              f"Val: {avg_val:.4f} Acc: {val_acc:.1%}")
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'val_loss': avg_val,
            }, 'best_model.pt')
# The guard matters: with num_workers > 0, worker processes may re-import this file
if __name__ == '__main__':
    main()
If you're still with me, thank you. I hope it was worth it.
We started with three hardcoded tensors and a five-line loop. We learned how autograd records a computation graph like a tape recorder and then erases the tape during backward. We built the canonical training loop and understood why each of its seven lines exists. We saw how DataLoader orchestrates parallel workers, pinned memory, and prefetching to keep the GPU fed. We opened up nn.Module and found the hidden __setattr__ that auto-registers parameters. And we covered the patterns that make real training fast: mixed precision, gradient accumulation, checkpointing, compilation, and distributed training.
My hope is that the next time you stare at a PyTorch training script — whether it's a hundred lines or ten thousand — instead of copy-pasting and hoping, you'll read each line and know exactly what it does, what it costs, and what breaks if you remove it. That's the difference between using a framework and understanding it.