Python Basics

Chapter 1: Python & Programming Foundations
Bits & Bytes · Objects · Data Structures · DataLoader · Overloading · Context Managers
Bits, Bytes, and How Your Computer Remembers Things
What Is an Object, Really? — From C Structs to Python Names
Data Structures — What's Actually Happening Inside the Four Containers
PyTorch DataLoader — How Your Training Loop Actually Loads Data
Operator Overloading — Teaching Objects New Tricks
Functions Are Objects
Closures — Captured State
Context Managers — PyTorch's Train, Eval, and No-Grad
UTF-8 — How Text Becomes Bytes
Wrap-Up
Resources

UTF-8 — How Text Becomes Bytes

We've spent time with integers and floats — numbers the CPU handles natively. Text is different. A character isn't a number in any hardware sense; the CPU has no idea what an "A" is. Everything is bytes, so at some point every character has to be mapped to a sequence of bytes. The encoding that rules the modern internet — and that you will encounter constantly in ML preprocessing, file I/O, and tokenizer code — is UTF-8.

The Problem: 256 Slots Aren't Enough

ASCII (1963) solved the original problem: map the English alphabet, digits, and control characters to 7-bit codes. That's 128 slots — plenty for a teletype machine. Extended ASCII stretched to 8 bits (256 slots) to fit Western European accented characters. Then the world showed up: Arabic, Chinese, Japanese, Korean, Cyrillic, emoji, mathematical symbols. There are over 140,000 characters in use today. One byte cannot hold them.

The solution is Unicode — a universal catalog that assigns every character a code point: an integer in the range U+0000 to U+10FFFF (roughly 1.1 million possible slots, about 150,000 currently assigned). Unicode answers the question "what number does this character get?" It does not answer "how do I store that number in memory?" That's where UTF-8 comes in.

UTF-8: Variable-Width Encoding

UTF-8 (designed by Ken Thompson and Rob Pike in 1992, on a placemat in a New Jersey diner) encodes each code point as 1 to 4 bytes using a clever self-describing bit pattern. The key insight: the first byte tells you how many bytes the character uses.

Code point range     Byte pattern                           Payload bits   Example
U+0000 – U+007F      0xxxxxxx                               7 bits         'A' → 0x41
U+0080 – U+07FF      110xxxxx 10xxxxxx                      11 bits        'é' → 0xC3 0xA9
U+0800 – U+FFFF      1110xxxx 10xxxxxx 10xxxxxx             16 bits        '€' → 0xE2 0x82 0xAC
U+10000 – U+10FFFF   11110xxx 10xxxxxx 10xxxxxx 10xxxxxx    21 bits        '😀' → 0xF0 0x9F 0x98 0x80
UTF-8 encoding table. A first byte that starts with 0 is a complete 1-byte (ASCII) character; otherwise, the number of leading 1s in the first byte gives the total byte count. Continuation bytes always start with 10.

The bit pattern is self-synchronizing by design. If you land in the middle of a UTF-8 stream (say, at a random offset in a file), you can always tell whether you're looking at a start byte or a continuation byte: start bytes begin with 0 (ASCII) or 11; continuation bytes always begin with 10. Scanning backwards from any position, you reach the start of the current character within at most three bytes. UTF-16 only offers this once you know you are aligned to a 16-bit code unit: its surrogate halves are distinguishable, but at an arbitrary byte offset you cannot tell which byte of a code unit you landed on. That weakness, together with UTF-8's byte-for-byte ASCII compatibility, is one of many reasons UTF-8 won.
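A minimal sketch of that backward scan, assuming the buffer holds valid UTF-8. The helper names is_continuation and char_start are hypothetical, not standard-library functions:

```python
def is_continuation(byte: int) -> bool:
    """Continuation bytes have the bit pattern 10xxxxxx."""
    return byte & 0b1100_0000 == 0b1000_0000


def char_start(data: bytes, offset: int) -> int:
    """Walk backwards from offset to the first byte of the character containing it."""
    while offset > 0 and is_continuation(data[offset]):
        offset -= 1
    return offset


data = "a€b".encode("utf-8")          # b'a\xe2\x82\xacb'
start = char_start(data, 3)           # offset 3 is the last byte of '€'
print(start)                          # 1
print(data[start:start + 3].decode("utf-8"))  # € (a 3-byte character)
```

No decoding state is needed: the byte values alone say where characters begin, which is what lets tools resume safely at arbitrary offsets in large files.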

Let's trace through a concrete example. The Euro sign has code point U+20AC (decimal 8364). In binary that's 10 0000 1010 1100: 14 significant bits, which puts it in the 3-byte range (more than 11 bits, at most 16). Padded to 16 bits as 0010 0000 1010 1100 and packed into the 1110xxxx 10xxxxxx 10xxxxxx template:

Code point:  U+20AC = 0010 0000 1010 1100
Split:       xxxx = 0010 | xxxxxx = 000010 | xxxxxx = 101100
Bytes:       1110 0010 | 10 000010 | 10 101100
             = 0xE2     | 0x82      | 0xAC

Peel off the framing bits (1110, 10, 10) and you recover the original code point. This is all your text editor, HTTP server, and Python runtime are doing under the hood every time they handle non-ASCII text.
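The same packing can be written as a few shifts and masks. This is a minimal sketch for illustration (utf8_encode is a hypothetical name); it follows the table above but, unlike Python's real codec, does not reject the surrogate range U+D800 to U+DFFF:

```python
def utf8_encode(code_point: int) -> bytes:
    """Encode one Unicode code point by hand, following the UTF-8 table."""
    if code_point < 0x80:                       # 1 byte: 0xxxxxxx
        return bytes([code_point])
    if code_point < 0x800:                      # 2 bytes: 110xxxxx 10xxxxxx
        return bytes([0b1100_0000 | (code_point >> 6),
                      0b1000_0000 | (code_point & 0x3F)])
    if code_point < 0x10000:                    # 3 bytes: 1110xxxx 10xxxxxx 10xxxxxx
        return bytes([0b1110_0000 | (code_point >> 12),
                      0b1000_0000 | ((code_point >> 6) & 0x3F),
                      0b1000_0000 | (code_point & 0x3F)])
    if code_point <= 0x10FFFF:                  # 4 bytes: 11110xxx + three continuation bytes
        return bytes([0b1111_0000 | (code_point >> 18),
                      0b1000_0000 | ((code_point >> 12) & 0x3F),
                      0b1000_0000 | ((code_point >> 6) & 0x3F),
                      0b1000_0000 | (code_point & 0x3F)])
    raise ValueError(f"not a Unicode code point: {code_point:#x}")


for ch in "Aé€😀":
    mine = utf8_encode(ord(ch))
    assert mine == ch.encode("utf-8")           # agrees with Python's own codec
    print(ch, f"U+{ord(ch):04X}", mine.hex(" "))
```

Each "& 0x3F" keeps the next 6 payload bits, and each OR with 0b10000000 adds the continuation framing, exactly the 10xxxxxx slots from the Euro trace.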

Python: str Is Not Bytes

Python 3 made a clean separation that Python 2 famously bungled: str is a sequence of Unicode code points (abstract characters). bytes is a sequence of raw octets. They are completely different types. To move between them, you encode (str → bytes) or decode (bytes → str), always specifying an encoding.

s = "café"          # str: 4 code points: U+0063 U+0061 U+0066 U+00E9
b = s.encode("utf-8")
print(b)            # b'caf\xc3\xa9'  (5 bytes: é is 2 bytes in UTF-8)
print(len(s))       # 4  (code points)
print(len(b))       # 5  (bytes)

# Decode back
print(b.decode("utf-8"))    # café

# Wrong encoding → garbage or error
print(b.decode("latin-1"))  # cafÃ©  (each byte misread as its own character: mojibake)
try:
    b.decode("ascii")
except UnicodeDecodeError as err:
    print(err)              # 'ascii' codec can't decode byte 0xc3 in position 3: ordinal not in range(128)

The len() discrepancy (4 characters but 5 bytes) is the source of a whole class of bugs in ML text preprocessing. Slicing a bytes object, as in data[:max_len], counts bytes, not characters, and can cut a multi-byte sequence in half, leaving an invalid UTF-8 stream. Decode to str first, operate on code points, then encode back if needed.
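A small demonstration of the truncation bug and two ways around it. The sample string is arbitrary; errors="ignore" is a standard decode option, but note that it also silently drops any other invalid bytes in the input:

```python
text = "naïve café"
data = text.encode("utf-8")          # 12 bytes for 10 characters

cut = data[:3]                       # slices bytes: 'n', 'a', and only the first half of 'ï'
try:
    cut.decode("utf-8")
except UnicodeDecodeError as err:
    broken_reason = err.reason
    print("broken:", broken_reason)  # broken: unexpected end of data

# Option 1: truncate the str (counts code points), then encode
safe = text[:3].encode("utf-8")      # b'na\xc3\xaf'

# Option 2: keep a byte budget, but drop the dangling partial character
trimmed = data[:3].decode("utf-8", errors="ignore")
print(trimmed)                       # na
```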

BOM and UTF-16

You'll sometimes see files with a BOM (Byte Order Mark): the character U+FEFF written at the very start of a file, whose encoded bytes signal the encoding and byte order. The UTF-8 BOM is 0xEF 0xBB 0xBF and is almost never needed (UTF-8 has no endianness issue), but some Windows tools emit it anyway. Read such files with encoding='utf-8-sig' in Python to strip it silently. UTF-16 needs a BOM (0xFF 0xFE for little-endian, 0xFE 0xFF for big-endian), or you must name the byte order explicitly with utf-16-le / utf-16-be, because each code unit is 2 bytes and byte order matters. UTF-8 is the right default for all new work.
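To see the BOM without creating files, this sketch builds the bytes in memory (the CSV-like payload is made up). The codec names are the same ones you would pass as open(path, encoding=...):

```python
import codecs

payload = codecs.BOM_UTF8 + "id,text\n1,héllo\n".encode("utf-8")
print(payload[:3])                     # b'\xef\xbb\xbf'

plain = payload.decode("utf-8")        # the BOM survives as an invisible U+FEFF
print(repr(plain[:3]))                 # '\ufeffid'
clean = payload.decode("utf-8-sig")    # the BOM is stripped
print(repr(clean[:2]))                 # 'id'

# UTF-16: byte order is either named in the codec or announced by a BOM
print("A".encode("utf-16-le"))         # b'A\x00'
print("A".encode("utf-16-be"))         # b'\x00A'
with_bom = "A".encode("utf-16")        # plain "utf-16" prepends a BOM in native order
print(with_bom[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE))  # True
```

A stray U+FEFF is a classic cause of a first CSV column named '\ufeffid' instead of 'id'.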

Why This Matters for ML

In practice you will hit UTF-8 in three places. First, reading data: always open text files with open(path, encoding='utf-8'). Without that argument, Python falls back to a locale-dependent default (often cp1252 on Windows, and on Linux whatever the server or container locale says), which is a silent disaster when a script moves between machines with different locale settings. Second, tokenizers: byte-level BPE tokenizers like GPT-2's operate on UTF-8 bytes, not characters, and the vocabulary maps byte sequences to token IDs. When a tokenizer is described as byte-level (for example, a ByteLevel pre-tokenizer in a Hugging Face tokenizer.json), it works at the UTF-8 byte level, and the 256 possible byte values form the base vocabulary. Third, string length heuristics: if you're comparing string lengths, measuring tokens, or slicing context windows, always know whether you're counting code points or bytes.

import json

# Reading a dataset — always specify encoding
with open("dataset.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# Counting characters correctly
text = "こんにちは"   # 5 Japanese characters
print(len(text))                    # 5  (correct: code points)
print(len(text.encode("utf-8")))    # 15 (3 bytes each in UTF-8)

# Byte-level BPE: each byte is a valid "token" slot
vocab_base = {bytes([i]): i for i in range(256)}   # 256-entry base vocab
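Before any merges are applied, a byte-level tokenizer sees text as a stream of those 256 base IDs. This sketch shows only that first step; real GPT-2 tokenization also remaps bytes to printable characters, pre-splits the text, and applies learned merges, all omitted here. base_token_ids is a hypothetical helper, and vocab_base is rebuilt so the block runs on its own:

```python
vocab_base = {bytes([i]): i for i in range(256)}   # same 256-entry base vocab as above


def base_token_ids(text: str) -> list[int]:
    """Map text to base-vocabulary IDs: one ID per UTF-8 byte (no merges applied)."""
    return [vocab_base[bytes([b])] for b in text.encode("utf-8")]


ids = base_token_ids("café 😀")
print(ids)        # [99, 97, 102, 195, 169, 32, 240, 159, 152, 128]
print(len(ids))   # 10 IDs for 6 characters

# Any byte sequence round-trips, so there is never an "unknown character"
print(bytes(ids).decode("utf-8"))   # café 😀
```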

UTF-8 is one of those things you can use for years without thinking about it — until the day a user submits text with an emoji, or you try to concatenate a str and a bytes, or your carefully crafted 512-token context window silently truncates a Japanese sentence mid-character. Once you understand the bit layout, none of this is mysterious. It's just bytes with a well-defined framing scheme.

Floating-Point — Where Things Get Weird

Integers are clean — every integer within range has an exact binary representation. Floating-point numbers are a different story entirely, and this is where most ML precision bugs come from.

A floating-point number follows the IEEE 754 standard (published in 1985, revised in 2008 and again in 2019). The idea is borrowed from scientific notation: instead of writing 602,200,000,000,000,000,000,000, a scientist writes $6.022 \times 10^{23}$. The computer does the same thing in binary.

A 32-bit float (float32) is carved into three fields:

float32 layout: [ S: 1 bit, sign ] [ EXPONENT E: 8 bits, biased by 127 ] [ MANTISSA M (fraction): 23 bits, implicit leading 1 ]

$$\text{value} = (-1)^{S} \times 2^{E-127} \times (1.M)_2$$

The sign bit $S$ (1 bit) determines positive or negative. The exponent $E$ (8 bits) sets the scale: how big or small the number is. The mantissa $M$ (also called the significand or fraction, 23 bits) holds the significant bits. There's a hidden trick: for normal numbers, the leading 1 before the binary point is implicit (not stored), giving you effectively 24 bits of precision. The value is $(-1)^{S} \times 2^{E-127} \times (1.M)_2$. Exponent fields of all 0s and all 1s are reserved for zero, subnormal numbers, infinities, and NaN.

Let's trace through a concrete example. The number 6.5 in float32:

6.5 in binary is $110.1_2$ (that's 4 + 2 + 0.5). In normalized scientific notation: $1.101_2 \times 2^{2}$. So the sign bit is 0 (positive), the exponent field is 2 + 127 = 129 = 10000001 in binary (the 127 bias shifts the range so negative exponents don't need a separate sign), and the mantissa is 10100000000000000000000 (the 101 after the implicit leading 1, padded with zeros to 23 bits).
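You can check this trace with the standard struct module, which exposes the raw float32 bits. A minimal sketch (float32_fields is a hypothetical helper); the rebuild formula only holds for normal numbers, not zero, subnormals, inf, or NaN:

```python
import struct


def float32_fields(x: float) -> tuple[int, int, int]:
    """Split x (rounded to float32) into its sign, exponent, and mantissa fields."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))   # the 32 raw bits as an int
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF      # 8 bits, still biased by 127
    mantissa = bits & 0x7FFFFF          # 23 bits; the implicit leading 1 is not stored
    return sign, exponent, mantissa


s, e, m = float32_fields(6.5)
print(s, f"{e:08b}", f"{m:023b}")   # 0 10000001 10100000000000000000000
print(e - 127)                      # 2 (unbiased exponent)

# Rebuild the value with (-1)^S * 2^(E-127) * (1.M)_2
value = (-1) ** s * 2.0 ** (e - 127) * (1 + m / 2**23)
print(value)                        # 6.5
```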

This representation has a fundamental consequence that bites people every day:

>>> 0.1 + 0.2
0.30000000000000004

0.1 in decimal is an infinitely repeating fraction in binary (0.0001100110011...), the same way 1/3 is 0.333... in decimal. The mantissa can only store 23 bits (for float32) or 52 bits (for float64), so it gets rounded. The error is tiny, but it's there. In ML, these rounding errors can accumulate over millions of gradient updates.
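A quick way to see the rounding, and the usual defenses against it, using only the standard library. Note that Python's float is float64, so float32 tensors lose even more:

```python
import math
from decimal import Decimal

print(Decimal(0.1))        # the exact binary value stored for 0.1, slightly above 0.1
print((0.1).hex())         # 0x1.999999999999ap-4: the repeating pattern, rounded at the last digit

total = 0.0
for _ in range(10):
    total += 0.1           # every addition rounds again
print(total == 1.0)        # False
print(total)               # 0.9999999999999999

print(math.isclose(total, 1.0))       # True: compare with a tolerance instead of ==
print(math.fsum([0.1] * 10) == 1.0)   # True: fsum tracks the low-order bits that sum() drops
```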

Here's why this matters for ML practitioners — the common float types you'll encounter:

Type               Bits   Exponent bits   Mantissa bits   Precision (decimal digits)   Range
float16 (half)     16     5               10              ~3.3                         ±65,504
bfloat16           16     8               7               ~2.4                         $\pm 3.4 \times 10^{38}$
float32 (single)   32     8               23              ~7.2                         $\pm 3.4 \times 10^{38}$
float64 (double)   64     11              52              ~15.9                        $\pm 1.8 \times 10^{308}$

Notice bfloat16: Google Brain designed it specifically for deep learning. It has the same 8 exponent bits as float32, so it covers the same range of magnitudes (crucial for gradient values that can spike), but with far fewer mantissa bits (less precision). The tradeoff: you lose precision in the low-order digits, but you almost never overflow. Standard float16's largest finite value is 65,504, which sounds like a lot until a gradient spike pushes a value past that threshold: it overflows to inf, subsequent operations turn it into NaN, and your training run is dead.
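You can feel both formats without a GPU. This sketch uses struct's 'e' format code (IEEE 754 half precision) for float16, and approximates bfloat16 by keeping the top 16 bits of the float32 pattern. Two assumptions to flag: real hardware and PyTorch usually round to nearest-even rather than truncate, and struct raises OverflowError where a tensor cast would produce inf. The helper names are hypothetical:

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision (struct's 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    """Approximate bfloat16: keep the sign, 8 exponent bits, and top 7 mantissa bits of float32."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))   # truncation, not rounding
    return y


print(to_float16(65504.0))      # 65504.0, the largest finite float16
try:
    to_float16(1e5)
    overflowed = False
except OverflowError:
    overflowed = True
print("1e5 overflows float16:", overflowed)   # True

print(to_bfloat16(1e38))        # about 9.97e+37: coarse, but finite
print(to_bfloat16(3.14159))     # 3.140625: only 2-3 significant digits survive
```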

There are three special values in IEEE 754 that you'll encounter: inf (infinity: exponent all 1s, mantissa all 0s), -inf (the same with the sign bit set), and NaN (Not a Number: exponent all 1s, any non-zero mantissa). NaN is the one IEEE 754 value that is not equal to itself: NaN != NaN is True, and every other comparison involving NaN (<, <=, ==, >=, >) returns False. That's by design: it signals that something went wrong, so test for it explicitly with math.isnan or torch.isnan rather than with ==.
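The bit patterns and the self-inequality can be checked directly. A short sketch with a hypothetical f32_bits helper; the exact NaN payload bits can vary, so only the field rules are checked:

```python
import math
import struct


def f32_bits(x: float) -> int:
    """The raw 32-bit pattern of x stored as float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


inf_bits = f32_bits(math.inf)
print(f"{inf_bits:#010x}")                # 0x7f800000: exponent all 1s, mantissa 0

nan_bits = f32_bits(math.nan)
exponent = (nan_bits >> 23) & 0xFF
mantissa = nan_bits & 0x7FFFFF
print(exponent == 0xFF, mantissa != 0)    # True True

nan = float("nan")
print(nan == nan, nan != nan)             # False True
print(math.isnan(nan))                    # True: the reliable test
print(math.inf - math.inf)                # nan: one common way NaN appears mid-training
```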

import torch
x = torch.tensor(1e38, dtype=torch.float16)
print(x)  # tensor(inf, dtype=torch.float16): overflowed past float16's max of 65,504

y = torch.tensor(1e38, dtype=torch.bfloat16)
print(y)  # a finite value close to 1e38 (rounded to bfloat16's 2-3 digits): same range as float32, no overflow

I'm still developing my intuition for when to use which precision type. The general rule in 2024-era ML: train in bfloat16 or mixed precision (float32 master weights, bfloat16 forward/backward), accumulate gradients in float32, and quantize to int8 or int4 for inference. But the optimal recipe changes with every new GPU architecture.

How Bytes Sit in Memory

One more piece of the puzzle. Your computer's memory is a long, flat array of bytes, each with a numeric address. When you store a 4-byte float32, it occupies four consecutive addresses. But there's a quirk: which byte goes first?

Big-endian stores the most significant byte first (the "big end", like writing a number left to right). Little-endian stores the least significant byte first. Most modern CPUs (x86, ARM in little-endian mode) use little-endian. Network protocols historically use big-endian. This almost never matters in Python, until you're reading binary model weight files saved on a different architecture, or parsing a binary format or network protocol that stores multi-byte values big-endian. Then it matters enormously.

import struct
val = 6.5
packed = struct.pack('f', val)    # pack as float32
print([f'0x{b:02x}' for b in packed])
# ['0x00', '0x00', '0xd0', '0x40'] — little-endian on x86
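The 'f' format above uses the machine's native byte order. The struct prefixes '<' and '>' force little- and big-endian explicitly, and int.to_bytes does the same for integers. A short sketch showing what goes wrong if you read with the wrong order:

```python
import struct
import sys

val = 6.5
little = struct.pack("<f", val)     # '<': little-endian, standard size, no padding
big = struct.pack(">f", val)        # '>': big-endian ("network order")
print(little.hex(" "))              # 00 00 d0 40
print(big.hex(" "))                 # 40 d0 00 00

# Reading with the wrong byte order silently yields a different number (here a tiny denormal)
wrong = struct.unpack("<f", big)[0]
print(wrong == val)                 # False

n = 0x12345678
print(n.to_bytes(4, "little").hex(" "))   # 78 56 34 12
print(n.to_bytes(4, "big").hex(" "))      # 12 34 56 78
print(sys.byteorder)                      # 'little' on x86 and most ARM systems
```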

We now know what the raw material is — bits grouped into bytes, bytes interpreted as integers or floats according to specific rules, sitting at addresses in a flat memory array. But Python doesn't expose any of this directly. You never write malloc(4) or worry about byte order in normal Python code. That's because Python wraps every value in something called an object. And understanding what that object actually is — at the C level — is the key to understanding everything else about Python's behavior.

What Is an Object, Really? — From C Structs to Python Names

Here's the mental model shift that took me an embarrassingly long time. In most programming introductions, variables are described as boxes that hold values. You put the number 5 into a box labeled x. That mental model is actively wrong in Python, and it's the source of some of the most persistent bugs in production code. To understand why, we need to go one level deeper — into the C code that Python itself is written in.

CPython — The Interpreter You're Actually Running

When you type python in your terminal, you're running CPython — the reference implementation of Python, written in C. There are other implementations (PyPy, Jython, GraalPy), but CPython is what ships by default, what pip installs packages for, and what runs your PyTorch training loops. Every Python object you create is, under the hood, a C struct allocated on the heap.

The root of everything is PyObject, a C struct defined in Include/object.h in the CPython source code. Every single Python object — integers, strings, lists, your custom classes, even None — starts with this structure:

// Simplified from CPython source (Include/object.h)
typedef struct _object {
    Py_ssize_t ob_refcnt;    // reference count
    PyTypeObject *ob_type;   // pointer to the type object
} PyObject;

Two fields. That's it. Every Python object begins with a reference count (how many names or containers are pointing at this object) and a pointer to its type (which tells Python what operations this object supports). A Python integer isn't the raw number 42 sitting in a register. It's a PyLongObject — a struct that starts with the PyObject header and then adds fields to store the actual numeric value:

// Simplified from CPython source
typedef struct {
    PyObject ob_base;        // refcount + type pointer
    Py_ssize_t ob_size;      // number of "digits"
    uint32_t ob_digit[1];    // flexible array of 30-bit digits
} PyLongObject;

This is why Python integers have no size limit: they're stored as arrays of 30-bit "digits" that can grow. A small number like 42 uses one digit; a number with thousands of decimal digits uses hundreds of them (each 30-bit digit holds about 9 decimal digits). The tradeoff: an integer that takes 4 bytes in C takes 28 bytes in Python on a typical 64-bit build (a 16-byte PyObject header, an 8-byte size field, and one 4-byte digit), before any memory allocator overhead. That overhead is the cost of Python's flexibility.
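You can watch the digit array grow with sys.getsizeof. Exact byte counts depend on the CPython version and platform, so treat the printed sizes as illustrative; only the trend (bigger numbers, bigger objects) is guaranteed:

```python
import ctypes
import sys

print(ctypes.sizeof(ctypes.c_int))   # 4: a plain C int on mainstream platforms
print(sys.int_info.bits_per_digit)   # 30 on typical 64-bit builds

for n in (42, 2**30, 2**60, 10**1000):
    # with 30-bit digits, each extra digit adds 4 bytes on top of the fixed header
    print(f"{n.bit_length():5d} bits -> {sys.getsizeof(n):4d} bytes")
```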

Names Are Post-It Notes, Not Boxes

Now here's where it all connects. When you write:

a = [1, 2, 3]

Python does three things. First, it gets PyLongObjects for the integers 1, 2, and 3 (small integers like these already exist, because CPython caches them). Second, it creates a PyListObject on the heap: a C struct containing the PyObject header, a pointer to an array of PyObject* pointers (one for each element), and a length. Third, it binds the name a to the list object's address in the current namespace. At module level that namespace is literally a Python dictionary mapping strings to object pointers; inside functions, CPython stores locals in a faster fixed array, but the binding works the same way.

The name a doesn't contain the list. It points to the list. It's a post-it note stuck on the object, not a box holding it. If you write:

b = a

Python doesn't copy the list. It doesn't create a new PyListObject. It takes the name b and points it at the same object a already points to. The list's ob_refcnt goes from 1 to 2. Two names, one object.
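CPython lets you observe the reference count with sys.getrefcount. The absolute number is an implementation detail (the call itself briefly holds one extra reference), so this sketch only looks at the change caused by b = a:

```python
import sys

a = [1, 2, 3]
before = sys.getrefcount(a)   # includes a temporary reference held by the call
b = a                         # a second name: no copy, just one more reference
after = sys.getrefcount(a)
print(after - before)         # 1
print(b is a)                 # True: one object, two names
```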

[Figure: names in the namespace dict point to objects on the heap. After b = a, names a and b point to the same list object (refcnt: 2), while a third name c points to a separate list object (refcnt: 1). Assignment increments the reference count; nothing is copied.]
Python variables are names in a namespace dictionary, pointing to PyObject structs on the heap — not boxes that hold values.

Let's see what this means concretely with our experiment tracker:

best_run = results[1]       # {"lr": 0.001, "epochs": 20, "accuracy": 0.91}
alias = best_run

How many dictionaries exist in memory right now? One. A single PyDictObject on the heap. The names best_run, alias, and results[1] are all entries in namespaces (or container internal arrays) that store the same pointer — the memory address of that one dictionary. If we modify through any of those names, every other name sees the change:

alias["accuracy"] = 0.95
print(best_run["accuracy"])  # 0.95 — same object, remember

This behavior has a name: aliasing. When two or more names refer to the same object in memory, they are aliases. You can verify this with Python's built-in id() function, which in CPython returns the memory address of the underlying PyObject:

print(id(best_run) == id(alias))  # True — same PyObject address
print(best_run is alias)          # True — 'is' checks identity, not equality

The is operator checks whether two names point to the same C struct in memory: identity. The == operator checks whether two objects have the same value: equality (by calling __eq__). These are different questions. Two separate dictionaries can have identical contents (== returns True) while being distinct C structs at different addresses (is returns False).
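A short sketch of identity versus equality, plus the usual way to get an independent object. The dictionaries are made-up experiment records; copy.copy is shallow, so nested mutable values would still be shared (copy.deepcopy copies those too):

```python
import copy

run_a = {"lr": 0.001, "epochs": 20}
run_b = {"lr": 0.001, "epochs": 20}   # built separately, same contents
alias = run_a

print(run_a == run_b)   # True: equal values (dict.__eq__ compares contents)
print(run_a is run_b)   # False: two distinct objects
print(run_a is alias)   # True: the same object

clone = copy.copy(run_a)   # a new dict with the same contents
clone["lr"] = 0.01
print(run_a["lr"])         # 0.001: the original is unaffected
```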

Why This Design? Why Not Boxes?

This is the question I wish someone had answered for me earlier. In C, a variable is a box — a named region of memory at a fixed address with a fixed size. int x = 5; allocates 4 bytes on the stack, labels them x, and writes the bit pattern for 5 directly into those bytes. Assignment y = x copies the 4 bytes. This is fast and memory-efficient, but it means the type and size are frozen at compile time.

Python made a different choice. By making everything a heap-allocated PyObject referenced by name, you get: dynamic typing (the name x can point to an integer now and a string later, because the type is stored in the object, not the name), automatic memory management (when the reference count hits zero, the object gets freed — no free() calls), and the ability for everything to be first-class (functions, classes, modules are all objects you can pass around). The cost is overhead — every integer carries a 28+ byte wrapper — but for an orchestration language that delegates heavy computation to C and CUDA, that tradeoff is worth it.

Mutability — The Other Half of the Story

Not all objects let you change their contents. Python draws a hard line: some types are mutable (you can change the object in-place) and some are immutable (any "change" creates a new object).

Lists, dictionaries, and sets are mutable. Integers, floats, strings, and tuples are immutable. When you write x = 5; x = x + 1, you're not modifying the integer object 5. Python produces an integer object 6 and rebinds the name x to point at it. The old object's reference count drops by 1, and in general, once nothing else points to an object, CPython frees it. Small integers are a special case: CPython pre-allocates -5 through 256 and keeps them alive for the whole process, so the 5 and 6 here are reused, never freed.

a = [1, 2, 3]
print(hex(id(a)))  # e.g. 0x7f3...: address of the list object
a.append(4)
print(hex(id(a)))  # SAME address: mutated in place

s = "hello"
print(hex(id(s)))  # e.g. 0x7f4...: address of the string object
s = s + " world"
print(hex(id(s)))  # DIFFERENT address: a new string was created

Identity also explains why x is None is the correct way to check for None, never x == None. There's exactly one None object in all of Python, a singleton PyObject, so an identity check is a single pointer comparison. An equality check calls __eq__, which could do anything on a custom class.
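A deliberately pathological class (hypothetical, for illustration only) shows why == None cannot be trusted:

```python
class AlwaysEqual:
    """A class whose __eq__ claims equality with everything."""

    def __eq__(self, other):
        return True


obj = AlwaysEqual()
print(obj == None)   # True: __eq__ ran and lied
print(obj is None)   # False: identity compares addresses, no method call
```

Linters flag the first comparison for exactly this reason.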

We now have the complete foundation: every Python value is a PyObject on the heap, names are entries in namespace dictionaries that store pointers to those objects, and mutability determines whether you can change an object in place or must create a new one. But our experiment tracker is still a list of dictionaries, and we need to think about why we'd choose a list versus a set versus a dictionary — not stylistically, but mechanically. What's actually happening inside these containers?

Data Structures — What's Actually Happening Inside the Four Containers

Python gives us four core container types, and the choice between them isn't stylistic — it has real performance consequences rooted in how each one is built at the C level. I avoided learning these internals for years, treating Big-O notation as something to memorize for interviews. Then I had a data pipeline that took 14 hours and a colleague brought it down to 20 minutes by swapping a list membership check for a set. That's when I realized: the performance isn't magic. It's architecture.

[Figure: Python containers, what lives inside.
list: dynamic array [ptr, ptr, ptr, ...]; index O(1), append amortized O(1), search O(n), insert O(n)
dict: compact hash table, hash(key) → index → value; get/set/del O(1) average; preserves insertion order (3.7+)
set: hash table of keys only, hash(elem) → slot; membership O(1) average; no duplicates, no ordering
tuple: fixed-size immutable array (ptr, ptr, ptr); hashable when its elements are, so usable as a dict key; less memory
Performance isn't magic: it's architecture.]

List — A Dynamic Array of Pointers

A Python list is not a linked list. It's a dynamic array — a contiguous block of memory holding pointers to PyObjects. The C struct looks roughly like this:

// Simplified from CPython (Include/cpython/listobject.h)
typedef struct {
    PyObject ob_base;
    Py_ssize_t ob_size;       // current number of elements
    PyObject **ob_item;       // pointer to array of PyObject pointers
    Py_ssize_t allocated;     // capacity of the array
} PyListObject;

The key field is ob_item — a pointer to a C array of PyObject* pointers. Each pointer is 8 bytes (on a 64-bit system). The elements themselves live elsewhere on the heap; the list only stores their addresses. This means indexing is O(1): to get element i, Python computes ob_item + (i × 8) and dereferences the pointer. One multiplication, one memory access. Done.

Appending is clever. When the internal array fills up, Python doesn't allocate space for exactly one more element; it over-allocates. The growth pattern roughly follows: 0, 4, 8, 16, 24, 32, 40, 52, 64, ... Each resize allocates more space than needed, so future appends don't require a new allocation. This gives append an amortized O(1) time complexity: averaged over many appends, each one costs O(1), even though an occasional append triggers a resize that copies the entire pointer array.
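You can watch the over-allocation happen. This sketch assumes CPython, where sys.getsizeof on a list includes its allocated pointer array, so the extra bytes divided by the pointer size give the capacity:

```python
import struct
import sys

POINTER_SIZE = struct.calcsize("P")   # bytes per pointer: 8 on a 64-bit build
EMPTY = sys.getsizeof([])             # a list object with no allocated slots

items = []
capacities = []
for i in range(70):
    items.append(i)
    # extra bytes beyond the empty list, divided by pointer size = allocated slots
    capacities.append((sys.getsizeof(items) - EMPTY) // POINTER_SIZE)

previous = 0
for length, capacity in enumerate(capacities, start=1):
    if capacity != previous:          # capacity changes only when an append triggers a resize
        print(f"append #{length:2d} -> capacity {capacity}")
        previous = capacity
```

On recent CPython (3.9 and later) this should show resizes at appends 1, 5, 9, 17, 25, and so on, with capacities following the 4, 8, 16, 24, 32, 40, 52, 64, 76 pattern; other versions may grow differently.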

But searching — if x in my_list — is O(n). Python has to walk through every pointer in the array, dereference each one, and call __eq__ on the pointed-to object. If your list has a million elements, that's up to a million pointer dereferences and comparisons. For the deduplication pipeline that took 14 hours, this is exactly where the time went: if item in seen_list inside a loop over millions of items. That's O(n) × O(n) = O(n²).

Set — A Hash Table That Changes Everything

A Python set is a hash table — the same data structure that makes dictionaries fast, but storing only keys, no values. To understand why if x in my_set is O(1) while if x in my_list is O(n), we need to understand hashing.

A hash function takes any hashable object and maps it to an integer. Python's built-in hash() does this: hash("hello") returns some large integer, hash(42) returns 42, hash((1, 2)) returns some integer. The hash value is computed from the object's content. Two objects with equal values produce the same hash (requirement), but two different objects can produce the same hash (called a collision — the pigeonhole principle guarantees this when you map an infinite set of possible values to a finite set of integers).

Here's how a set uses hashing for O(1) lookup. Internally, the set maintains an array of slots (8 to start in CPython). When you add an element, Python computes hash(element) modulo the number of slots to get a slot index (CPython uses a bit mask, which is equivalent because the table size is always a power of two), and stores the element (pointer) in that slot. When you check element in my_set, Python computes the same slot index, jumps directly to that slot, and checks if the stored element is equal. No scanning. No walking through every element. One hash computation, one (or a few, in case of collision) slot checks.

# What "x in my_set" actually does under the hood:
# 1. Compute hash(x) → some integer, say 738291
# 2. Compute 738291 % table_size → slot index, say 3
# 3. Look at slot 3:
#    - Empty? → x is not in the set. Done.
#    - Occupied? → Compare stored element with x using ==
#      - Equal? → x IS in the set. Done.
#      - Not equal? → Collision. Probe next slot. Repeat.

When there's a collision (two elements hash to the same slot), CPython uses open addressing: it probes a sequence of other slots until it finds the element or an empty slot. Dicts choose the next slot with a perturbation scheme that mixes in more bits of the hash; sets first check a few neighboring slots, then perturb. As long as the table isn't too full (CPython grows dicts at about 2/3 full and sets at about 3/5 full), the expected number of probes stays roughly constant. That's why average-case lookup is O(1).
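To make the steps concrete, here is a toy open-addressing hash set. It is a teaching sketch, not CPython's implementation: ToyHashSet is a hypothetical name, it probes linearly (next slot) instead of perturbing, grows 4x at 2/3 full, and does not support deletion:

```python
class ToyHashSet:
    """Open-addressing hash set: linear probing, grow at 2/3 full, no deletion."""

    _EMPTY = object()   # sentinel marking an unused slot

    def __init__(self):
        self._slots = [self._EMPTY] * 8   # start with 8 slots, like CPython
        self._count = 0

    def _find_slot(self, item):
        """Return the index holding item, or the empty slot where it would go."""
        mask = len(self._slots) - 1       # table size is a power of two
        i = hash(item) & mask             # same as hash % size for power-of-two sizes
        while True:
            stored = self._slots[i]
            if stored is self._EMPTY or stored == item:
                return i
            i = (i + 1) & mask            # collision: probe the next slot

    def add(self, item):
        i = self._find_slot(item)
        if self._slots[i] is self._EMPTY:
            self._slots[i] = item
            self._count += 1
            if self._count * 3 >= len(self._slots) * 2:   # about 2/3 full: grow
                self._resize(len(self._slots) * 4)

    def _resize(self, new_size):
        old = [s for s in self._slots if s is not self._EMPTY]
        self._slots = [self._EMPTY] * new_size
        for item in old:                  # every item is rehashed into the bigger table
            self._slots[self._find_slot(item)] = item

    def __contains__(self, item):
        return self._slots[self._find_slot(item)] is not self._EMPTY

    def __len__(self):
        return self._count


s = ToyHashSet()
for word in ["lr", "epochs", "batch", "lr", "dropout"]:
    s.add(word)
print(len(s))            # 4: the duplicate "lr" was ignored
print("batch" in s)      # True
print("momentum" in s)   # False
```

Because the table is never allowed to fill up, the probe loop always reaches an empty slot, which is what keeps both lookups and inserts short on average.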

The tradeoff: you can only put hashable objects in a set. That means immutable types — integers, strings, tuples of hashable elements. You can't put a list in a set because lists are mutable: if you could, and then you mutated the list, its hash would change, and the set would "lose" it (it'd be in the wrong slot). Python prevents this by making list unhashable.

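# The 14-hour pipeline fix (million_items and process() stand in for your real data and per-item work):
seen_list = []                    # before: "item not in seen_list" scans the list, O(n) per check
seen_set = set()                  # after: "item not in seen_set" is O(1) on average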

for item in million_items:
    if item not in seen_set:      # O(1) — hash, jump, done
        seen_set.add(item)
        process(item)

That's the difference between O(n²) and O(n). Between 14 hours and 20 minutes.
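A self-contained way to feel the difference. The data is synthetic and the timings depend entirely on your machine; what matters is the growth rate (double the unique items and the list version takes roughly four times as long, the set version roughly twice):

```python
import time


def dedupe_with_list(items):
    seen, out = [], []
    for item in items:
        if item not in seen:      # O(len(seen)) scan on every check
            seen.append(item)
            out.append(item)
    return out


def dedupe_with_set(items):
    seen, out = set(), []
    for item in items:
        if item not in seen:      # O(1) average hash lookup
            seen.add(item)
            out.append(item)
    return out


items = list(range(3_000)) * 2    # 6,000 items, 3,000 unique

for fn in (dedupe_with_list, dedupe_with_set):
    start = time.perf_counter()
    result = fn(items)
    elapsed = time.perf_counter() - start
    print(f"{fn.__name__}: {len(result)} unique in {elapsed:.4f}s")
```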

Dict — A Hash Table with Values Attached

A Python dict is essentially a set that stores key-value pairs instead of keys alone. Since Python 3.6 (and guaranteed since 3.7), dictionaries preserve insertion order. CPython achieves this with a clever compact dictionary design introduced in Python 3.6:

There are two arrays. The first is a hash table (sparse array of indices). The second is a dense array of (hash, key, value) entries, stored in insertion order. When you do d["lr"], Python computes hash("lr"), uses it to index into the sparse hash table, which gives an index into the dense array, which holds the key-value pair. If the key matches, you have your value. O(1) on average.

# What d["lr"] = 0.01 does internally:
# 1. hash("lr") → 2847291   (illustrative: str hashes are randomized per process)
# 2. 2847291 % sparse_size → slot 5 in hash table
# 3. sparse[5] = 0  (index into dense array)
# 4. dense[0] = (2847291, "lr", 0.01)
#
# Lookup d["lr"]:
# 1. hash("lr") → 2847291
# 2. 2847291 % sparse_size → slot 5
# 3. sparse[5] = 0 → check dense[0].key == "lr" → yes!
# 4. Return dense[0].value → 0.01

This compact design uses about 25% less memory than the old implementation (pre-3.6), because the sparse table only stores small integers (indices) instead of full key-value entries, and the dense array has no gaps.
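A toy version of the two-array layout, to show why insertion order falls out for free. This is a sketch, not CPython's code: ToyCompactDict is a hypothetical name, it uses linear probing instead of perturbation, stores indices in a plain Python list rather than a packed 1/2/4/8-byte array, and has no deletion:

```python
class ToyCompactDict:
    """Sparse index table + dense entry list, in the style of CPython's compact dict."""

    def __init__(self):
        self._indices = [None] * 8   # sparse: slot -> position in self._entries
        self._entries = []           # dense: (hash, key, value) in insertion order

    def _slot(self, key, h):
        mask = len(self._indices) - 1
        i = h & mask
        while True:
            idx = self._indices[i]
            if idx is None:
                return i
            entry_hash, entry_key, _ = self._entries[idx]
            if entry_hash == h and entry_key == key:
                return i
            i = (i + 1) & mask       # linear probing (CPython perturbs instead)

    def __setitem__(self, key, value):
        h = hash(key)
        i = self._slot(key, h)
        if self._indices[i] is None:                 # new key: append to the dense array
            self._indices[i] = len(self._entries)
            self._entries.append((h, key, value))
            if len(self._entries) * 3 >= len(self._indices) * 2:
                self._rebuild(len(self._indices) * 4)
        else:                                        # existing key: overwrite in place
            self._entries[self._indices[i]] = (h, key, value)

    def __getitem__(self, key):
        h = hash(key)
        idx = self._indices[self._slot(key, h)]
        if idx is None:
            raise KeyError(key)
        return self._entries[idx][2]

    def _rebuild(self, size):
        self._indices = [None] * size                # only the sparse table is rebuilt
        for pos, (h, key, _) in enumerate(self._entries):
            self._indices[self._slot(key, h)] = pos

    def keys(self):
        return [key for _, key, _ in self._entries]  # insertion order, for free


d = ToyCompactDict()
d["lr"] = 0.01
d["epochs"] = 10
d["lr"] = 0.001        # overwriting keeps the key's original position
print(d["lr"])         # 0.001
print(d.keys())        # ['lr', 'epochs']
```

Growing the table only rewrites the small sparse index list; the dense entries never move, which is both why order is preserved and why the layout is compact.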

For our experiment tracker, dictionaries are the natural choice for each experiment record — fast lookup by key name, and we can add or remove fields freely.

Tuple — The Immutable, Cached, Hashable Record

A tuple is the simplest container: a fixed-size array of PyObject* pointers, allocated once, never resized. No over-allocation, no growth strategy, no hash table. The C struct is lean:

// Simplified
typedef struct {
    PyObject ob_base;
    Py_ssize_t ob_size;
    PyObject *ob_item[1];    // fixed array, inline
} PyTupleObject;

Because tuples are immutable, Python can optimize aggressively. CPython recycles the memory of discarded small tuples through per-size free lists, and the empty tuple () is a singleton: every time you write (), you get the same object. Tuples whose elements are all hashable are themselves hashable, which means you can use them as dictionary keys or set elements. This is useful when you want to use a compound key:

# Track which hyperparameter combos you've tried
tried = set()
tried.add((0.01, 10))      # (lr, epochs) as tuple — hashable ✓
tried.add((0.001, 20))
print((0.01, 10) in tried)  # True, O(1) lookup

You can't do this with a list: tried.add([0.01, 10]) raises TypeError: unhashable type: 'list'.

Tuples also use less memory than lists because they don't need the allocated field or the indirection of a separate pointer array. The elements are stored inline in the struct. For returning multiple values from a function — return loss, accuracy — tuples are the idiomatic choice. Python actually creates a tuple behind the scenes.
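A quick check of both claims. Exact byte counts vary by CPython version, so the sketch only compares them; the numbers in evaluate() are illustrative, not real results, and the empty-tuple identity is CPython behavior rather than a language guarantee:

```python
import sys

as_list = [0.001, 20, 0.91]
as_tuple = (0.001, 20, 0.91)
print(sys.getsizeof(as_list), sys.getsizeof(as_tuple))   # the tuple is smaller


def evaluate():
    loss, accuracy = 0.42, 0.91    # illustrative numbers
    return loss, accuracy          # packs both values into one tuple


result = evaluate()
print(type(result).__name__)       # tuple
loss, accuracy = result            # unpacking on the caller's side

empty_a = tuple()
empty_b = tuple([])
print(empty_a is empty_b)          # True in CPython: one shared empty tuple
```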

Our containers are now more than black boxes — we can see the gears turning. But so far we've been working with data that fits in memory. What happens when your dataset is 50GB? You can't load it all at once. You need to load it piece by piece, on demand. That's where lazy iteration meets real-world ML infrastructure.

PyTorch DataLoader — How Your Training Loop Actually Loads Data

I'll be honest — I used DataLoader for over a year before I opened its source code. I knew you pass it a dataset, it gives you batches, and num_workers > 0 makes it faster. That was my entire mental model. Then I ran into a deadlock in a multiprocessing DataLoader on a shared cluster, and I realized I had no idea what was happening under the hood. So let's build it up from scratch.

The core idea is lazy iteration. Instead of loading your entire 50GB dataset into memory, you load one batch at a time. Each iteration of your training loop fetches 32 (or 64, or 128) samples, processes them, and discards them before fetching the next batch. Memory usage stays constant at O(batch_size), not O(dataset_size). That's the fundamental trick.

The Two Contracts: Dataset and DataLoader

PyTorch's data loading system is built on two abstractions. A Dataset defines what your data is and how to access individual samples. A DataLoader defines how to iterate over the dataset — batching, shuffling, parallelism.

A map-style Dataset implements two methods:

import torch
from torch.utils.data import Dataset, DataLoader

class ExperimentDataset(Dataset):
    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        features = torch.tensor([r["lr"], r["epochs"]], dtype=torch.float32)
        target = torch.tensor(r["accuracy"], dtype=torch.float32)
        return features, target

__len__ tells the DataLoader how many samples exist. __getitem__ fetches one sample by index. For a basic map-style dataset, that's the entire contract: the default sampler and fetcher need nothing else from your Dataset.

What Happens Inside the DataLoader

When you write for batch in loader, here's what actually happens, step by step:

dataset = ExperimentDataset(results)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for features, targets in loader:
    print(features.shape, targets.shape)
    # torch.Size([2, 2]) torch.Size([2])

The for loop calls iter(loader), which creates a _SingleProcessDataLoaderIter (or _MultiProcessingDataLoaderIter if num_workers > 0). This iterator object holds the state for one pass through the dataset.

[Figure: what happens when you write for batch in loader. 1. Sampler generates indices: [2, 0, 1] (shuffled) or [0, 1, 2] (sequential). 2. BatchSampler groups them into batches: [[2, 0], [1]] with batch_size=2. 3. Fetch calls dataset[2], dataset[0] for each index in the batch. 4. Collate stacks the samples into tensors (torch.stack). Result: one batch of features and targets. Memory stays O(batch_size), not O(dataset_size); the previous batch is garbage-collected.]

Four stages per batch: the Sampler generates a permutation of indices (shuffled for training, sequential for validation). The BatchSampler chunks those indices into groups of batch_size. The Fetcher calls dataset[idx] for each index in the current batch. And the Collate function (default_collate) stacks the individual samples into batched tensors.
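To see the four stages without any PyTorch machinery, here is a toy, pure-Python imitation. It is a sketch, not PyTorch's implementation: plain lists stand in for tensors, the helper names (sampler, batch_sampler, collate, toy_loader) and the records are hypothetical, and the real library uses classes such as RandomSampler, SequentialSampler, and BatchSampler plus default_collate, which stacks tensors:

```python
import random


def sampler(n, shuffle, rng):
    """Stage 1: produce an order over dataset indices."""
    order = list(range(n))
    if shuffle:
        rng.shuffle(order)
    return order


def batch_sampler(indices, batch_size):
    """Stage 2: chunk the indices into batches (the last one may be short)."""
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]


def collate(samples):
    """Stage 4: turn [(features, target), ...] into (all features, all targets)."""
    features, targets = zip(*samples)
    return list(features), list(targets)


def toy_loader(dataset, batch_size, shuffle=False, seed=0):
    rng = random.Random(seed)
    order = sampler(len(dataset), shuffle, rng)
    for batch_indices in batch_sampler(order, batch_size):
        samples = [dataset[i] for i in batch_indices]   # Stage 3: fetch one sample at a time
        yield collate(samples)


records = [([0.01, 10], 0.82), ([0.001, 20], 0.91), ([0.01, 50], 0.89)]   # hypothetical data
for features, targets in toy_loader(records, batch_size=2, shuffle=True):
    print(features, targets)
```

Because toy_loader is a generator, each batch is built only when the loop asks for it, which is the lazy, O(batch_size) behavior described above.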

The num_workers Machinery

When num_workers=0 (the default), everything happens in the main process. The fetching step calls dataset[idx] synchronously, which means your GPU sits idle while the CPU loads and preprocesses data. This is often the bottleneck — your expensive GPU is waiting for data.

When num_workers > 0, PyTorch spawns that many worker processes (not threads, since the GIL would negate the benefit for CPU-bound Python work). Each worker runs in its own process with its own copy of the dataset object. The main process sends lists of batch indices to the workers through multiprocessing queues; each worker fetches and preprocesses its samples, collates them into a batch, and sends the result back through a shared result queue. The main process consumes batches from that results queue.

# What num_workers=4 looks like internally:
#
# Main process:   [send indices] → Queue → [receive batches] → GPU
#                                    ↑           ↑
# Worker 0: dataset[idx] → preprocess → result queue
# Worker 1: dataset[idx] → preprocess → result queue
# Worker 2: dataset[idx] → preprocess → result queue
# Worker 3: dataset[idx] → preprocess → result queue
#
# Workers prefetch: while GPU processes batch N,
# workers prepare batch N+1, N+2, ...

The prefetch_factor parameter (default 2 when num_workers > 0) controls how many batches each worker loads in advance. With 4 workers and prefetch_factor=2, up to 8 batches can be in flight while the GPU processes the current one. This is a form of double buffering (more precisely, multi-buffering), and it's the key to hiding data loading latency behind GPU computation.
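The bounded-queue idea behind prefetching can be sketched in plain Python. This is an illustration under simplifying assumptions, not the DataLoader's code: one background thread stands in for the worker processes, and the hypothetical prefetch argument plays the role of num_workers * prefetch_factor. A thread is fine here only because the toy does no CPU-heavy work.

```python
import queue
import threading

def prefetching_iter(batches, prefetch=2):
    """Yield batches while a background thread loads ahead (toy prefetcher)."""
    q = queue.Queue(maxsize=prefetch)   # at most `prefetch` finished batches wait here
    done = object()                     # sentinel marking the end of the stream

    def producer():
        for batch in batches:
            q.put(batch)                # blocks while the queue is full
        q.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        batch = q.get()                 # blocks until the next batch is ready
        if batch is done:
            return
        yield batch

# Hypothetical "loading" step: three small batches produced lazily
source = ([i, i + 1] for i in range(0, 6, 2))
print(list(prefetching_iter(source, prefetch=2)))  # [[0, 1], [2, 3], [4, 5]]
```

The maxsize bound is what keeps memory in check: the producer runs ahead of the consumer, but never by more than prefetch batches.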

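There are gotchas. Each worker process is created by forking (or spawning) the main process and gets its own copy of the dataset object. If your dataset holds a large in-memory structure, memory usage can grow roughly in proportion to num_workers: even with fork's copy-on-write, Python's reference counting writes to the pages that hold Python objects, so large lists or dicts of Python objects tend to get copied anyway. The fix is to use memory-mapped files or lazy loading in your __getitem__. Randomness is the other trap. PyTorch seeds torch and Python's random module differently in each worker, and since version 1.9 it also reseeds NumPy's global RNG per worker; older versions left NumPy's state identical across forked workers, and libraries that keep their own RNG state can still end up duplicated. A worker_init_fn that seeds each worker explicitly protects you from identical augmentations across workers.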

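import random
import numpy as np
import torch

def worker_init_fn(worker_id):
    # Inside a worker, torch.initial_seed() already returns base_seed + worker_id,
    # so it differs per worker; reduce it to NumPy's valid 32-bit seed range.
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)

loader = DataLoader(dataset, batch_size=32, num_workers=4,
                    worker_init_fn=worker_init_fn)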

I still occasionally get bitten by the pin_memory gotcha. Setting pin_memory=True makes the DataLoader copy each batch into page-locked (pinned) host memory, which enables faster CPU-to-GPU transfers via DMA and lets .to(device, non_blocking=True) copy asynchronously. It's almost always worth enabling when training on GPU, but pinned memory can't be paged out, so it increases host memory pressure and can cause out-of-memory errors if your batches are large and many prefetched batches are in flight.
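A minimal sketch of the usual pairing, assuming a toy TensorDataset in place of a real dataset: enable pin_memory only when a GPU is present, and pass non_blocking=True when moving each batch. On a CPU-only machine the same code runs with pinning off, and the .to() calls simply return the tensors unchanged.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Hypothetical toy data: 64 samples with 2 features and a scalar target each
dataset = TensorDataset(torch.randn(64, 2), torch.randn(64))

# Pin only when there is a GPU to transfer to
loader = DataLoader(dataset, batch_size=16, shuffle=True, pin_memory=use_cuda)

for features, targets in loader:
    # From pinned memory, non_blocking=True lets the host-to-GPU copy run
    # asynchronously; on CPU, .to(device) just returns the same tensor.
    features = features.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    # ... forward/backward pass would go here ...
```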

The DataLoader's lazy, batched iteration is the reason training on datasets larger than your RAM is possible at all. But all of these containers and iteration patterns share something we haven't talked about yet — the ability to redefine what operators like +, [], and len() mean for custom objects. That mechanism has a name.

Operator Overloading — Teaching Objects New Tricks

Every time you write a + b in Python, you're making a function call. Python translates a + b into a.__add__(b). The + operator isn't hardcoded to mean "add numbers." It means "call the __add__ method on the left operand, passing the right operand as an argument." What that method does is entirely up to the object.

This is operator overloading — defining what operators mean for your custom types. Python uses special methods (called dunder methods, short for "double underscore") to implement this. We've already seen some: __len__ for len(), __getitem__ for [] indexing. The full set is extensive, but here are the ones you'll actually encounter in ML code:

| Operation | Syntax | Dunder Method | ML Example |
|---|---|---|---|
| Addition | a + b | __add__ | Tensor + Tensor |
| Multiplication | a * b | __mul__ | Tensor * scalar |
| Matrix multiply | a @ b | __matmul__ | Weight @ Input |
| Indexing | a[i] | __getitem__ | Dataset[idx] |
| Length | len(a) | __len__ | len(dataset) |
| String repr | repr(a) | __repr__ | print(tensor) |
| Call as function | a(x) | __call__ | model(input) |
| Comparison | a == b | __eq__ | Tensor comparison |
| Boolean | bool(a) | __bool__ | if tensor: ... (one-element tensors) |

Let's build a concrete example. Suppose we want our experiment tracker results to support comparison and arithmetic — finding which run is "better" and computing average metrics:

class ExperimentResult:
    def __init__(self, lr, epochs, accuracy):
        self.lr = lr
        self.epochs = epochs
        self.accuracy = accuracy

    def __repr__(self):
        return f"Exp(lr={self.lr:g}, epochs={self.epochs}, acc={self.accuracy:.3f})"

    def __gt__(self, other):
        if not isinstance(other, ExperimentResult):
            return NotImplemented       # let Python try the other operand
        return self.accuracy > other.accuracy

    def __eq__(self, other):
        if not isinstance(other, ExperimentResult):
            return NotImplemented
        return self.accuracy == other.accuracy

    def __add__(self, other):
        # "Average" of two experiments
        if not isinstance(other, ExperimentResult):
            return NotImplemented
        return ExperimentResult(
            lr=(self.lr + other.lr) / 2,
            epochs=(self.epochs + other.epochs) // 2,
            accuracy=(self.accuracy + other.accuracy) / 2
        )

exp1 = ExperimentResult(0.01, 10, 0.82)
exp2 = ExperimentResult(0.001, 20, 0.91)

print(exp2 > exp1)            # True: __gt__ compares accuracy
print(sorted([exp1, exp2]))   # [Exp(lr=0.01, epochs=10, acc=0.820), Exp(lr=0.001, epochs=20, acc=0.910)]
print(exp1 + exp2)            # Exp(lr=0.0055, epochs=15, acc=0.865)

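Now sorted() works on our custom objects, even though we never defined __lt__. Sorting compares pairs with <; when exp1 < exp2 finds no __lt__ on the left operand, Python falls back to the reflected comparison exp2 > exp1, which calls our __gt__. The dunder methods are the entire interface sort needs.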
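Defining a single comparison is enough for sorted(), but <= and >= on ExperimentResult would still raise TypeError. The usual idiom is to define __eq__ plus one ordering method and let functools.total_ordering derive the rest. A minimal sketch with a hypothetical Run class:

```python
from functools import total_ordering

@total_ordering
class Run:
    """Hypothetical result type ordered by accuracy."""
    def __init__(self, name, accuracy):
        self.name = name
        self.accuracy = accuracy

    def __eq__(self, other):
        if not isinstance(other, Run):
            return NotImplemented
        return self.accuracy == other.accuracy

    def __lt__(self, other):
        if not isinstance(other, Run):
            return NotImplemented
        return self.accuracy < other.accuracy

runs = [Run("a", 0.82), Run("b", 0.91), Run("c", 0.89)]
print([r.name for r in sorted(runs)])   # ['a', 'c', 'b']
print(max(runs).name)                   # b
print(runs[0] >= runs[1])               # False (__ge__ supplied by total_ordering)
```

Returning NotImplemented for foreign types, rather than raising or returning False, lets Python try the other operand and produce a clean TypeError if neither side knows how to compare.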

The __call__ Method — Why model(x) Works

This is perhaps the most important dunder method in all of PyTorch. When you write output = model(input), Python calls model.__call__(input). In PyTorch's nn.Module, __call__ does more than call your forward method: it runs any registered forward pre-hooks, calls forward, runs any forward hooks on the result, and sets up registered backward hooks. Calling model.forward(input) directly would skip all of that machinery, which is why you call the model itself.

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel(2, 1)
x = torch.tensor([[0.01, 10.0]])
output = model(x)      # calls model.__call__(x) → runs hooks → calls forward(x)
print(output)           # tensor([[...]], grad_fn=<AddmmBackward0>)
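The wrapper pattern is easy to see in a toy version. The sketch below is hypothetical and far simpler than nn.Module: TinyModule, Doubler, and the hook signature are invented here, loosely modeled on PyTorch's forward hooks. It only shows why calling the object runs extra logic that calling forward() skips.

```python
class TinyModule:
    """Hypothetical base class: __call__ wraps forward() with hooks."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # hook(module, input, output) may return a replacement output
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, x):
        output = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, output)
            if result is not None:
                output = result
        return output


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


calls = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: calls.append((inp, out)))
print(m(3))          # 6
print(m.forward(3))  # 6
print(calls)         # [(3, 6)]: only the m(3) call ran the hook
```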

The __matmul__ operator (@) is another one you'll use constantly. It was added in Python 3.5 specifically because the scientific computing community needed a clean syntax for matrix multiplication. Before @, you had to write np.dot(A, B) or A.dot(B), which gets unreadable with chains of multiplications. Now: output = W @ x + b. That's a linear layer in one line.
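To see that @ is just another hook into the object, here is a hypothetical list-of-lists Matrix class that defines __matmul__ and __add__ (no shape checking, purely illustrative). That is enough to write the linear-layer expression W @ x + b:

```python
class Matrix:
    """Hypothetical list-of-rows matrix; no shape checking."""
    def __init__(self, rows):
        self.rows = rows

    def __matmul__(self, other):
        if not isinstance(other, Matrix):
            return NotImplemented
        cols = list(zip(*other.rows))          # columns of the right operand
        return Matrix([[sum(a * b for a, b in zip(row, col)) for col in cols]
                       for row in self.rows])

    def __add__(self, other):
        if not isinstance(other, Matrix):
            return NotImplemented
        return Matrix([[a + b for a, b in zip(r1, r2)]
                       for r1, r2 in zip(self.rows, other.rows)])

    def __repr__(self):
        return f"Matrix({self.rows})"

W = Matrix([[1, 2], [3, 4]])
x = Matrix([[1], [1]])
b = Matrix([[10], [20]])
print(W @ x + b)   # Matrix([[13], [27]])
```

@ has the same precedence as *, so W @ x + b parses as (W @ x) + b, just as in the math.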

The Reverse Operators

There's a subtlety. When Python evaluates 2 * tensor, it first tries int.__mul__(2, tensor). The integer doesn't know how to multiply with a tensor, so it returns NotImplemented. Python then tries the reflected operator: tensor.__rmul__(2). PyTorch defines __rmul__ on tensors, so 2 * tensor works. Without this fallback mechanism, you'd have to always write tensor * 2, never 2 * tensor.

x = torch.tensor([1.0, 2.0, 3.0])
print(2 * x)    # tensor([2., 4., 6.]) — int.__mul__ fails → tensor.__rmul__ succeeds
print(x * 2)    # tensor([2., 4., 6.]) — tensor.__mul__ succeeds directly
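The same protocol works for your own classes. In this hypothetical Vec class, __mul__ returns NotImplemented for operands it doesn't understand, and __rmul__ handles the number * Vec case after int.__mul__ gives up:

```python
class Vec:
    """Hypothetical 1-D vector supporting Vec * number and number * Vec."""
    def __init__(self, values):
        self.values = list(values)

    def __mul__(self, other):
        if not isinstance(other, (int, float)):
            return NotImplemented         # tell Python to try the other operand
        return Vec(v * other for v in self.values)

    def __rmul__(self, other):
        # Reached for `2 * v` once int.__mul__(2, v) has returned NotImplemented
        return self.__mul__(other)

    def __repr__(self):
        return f"Vec({self.values})"

v = Vec([1.0, 2.0, 3.0])
print(v * 2)   # Vec([2.0, 4.0, 6.0]): Vec.__mul__
print(2 * v)   # Vec([2.0, 4.0, 6.0]): int.__mul__ fails, Vec.__rmul__ succeeds
```

Delete __rmul__ and 2 * v raises TypeError, because neither operand knows how to handle the pair.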

Operator overloading is what makes PyTorch feel like doing math. loss = (predictions - targets) ** 2 works because tensors define __sub__ and __pow__. Every operator in that expression is a method call, and whenever an input requires gradients, each call also records a node in the computation graph for autograd. It's dunders all the way down.

We've now seen how Python lets objects define their own behavior for built-in operations. But objects don't exist in isolation — they get passed around, stored, and combined. The most fundamental unit of combination is a function. And in Python, functions themselves are objects.

Functions Are Objects

This is one of those ideas that sounds academic until you see what it makes possible. In Python, a function is an object — an instance of the function type, living on the heap like any other PyObject. It has attributes (__name__, __doc__, __defaults__). You can assign it to a variable, put it in a list, pass it as an argument to another function, return it from a function.

This property is called being a first-class citizen, and it's the foundation of everything that follows — closures, higher-order functions, callbacks.

def square(x):
    return x * x

def apply(func, value):
    return func(value)

print(apply(square, 5))   # 25 — we passed a function as an argument

We already used this when we wrote sorted(results, key=lambda r: r["accuracy"]). The key parameter accepts a function. We passed a lambda — an anonymous function expression — as the sorting criterion. The sorted function calls our lambda on each element to decide the ordering.

A lambda is nothing special. It's a shorthand for defining a function that consists of a single expression. lambda r: r["accuracy"] is exactly equivalent to:

def get_accuracy(r):
    return r["accuracy"]

The only difference is that lambda doesn't bind a name. It creates the function object and hands it over, unnamed. Use lambda for short, throwaway functions. Use def for anything more than one expression — readability matters more than cleverness.
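A quick way to check both claims, that functions are ordinary objects and that lambda skips the name binding, is to inspect __name__ and to store functions in a dict. The metrics table below is a hypothetical example:

```python
def get_accuracy(r):
    return r["accuracy"]

by_lambda = lambda r: r["accuracy"]

print(get_accuracy.__name__)  # get_accuracy
print(by_lambda.__name__)     # <lambda>

# Hypothetical lookup table: functions stored as ordinary dict values
metrics = {"accuracy": get_accuracy, "lr": lambda r: r["lr"]}
run = {"lr": 0.01, "accuracy": 0.82}
print({name: fn(run) for name, fn in metrics.items()})  # {'accuracy': 0.82, 'lr': 0.01}
```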

Functions being objects is interesting on its own, but the real power emerges when a function remembers variables from the scope where it was created. That's a closure.

Closures — Captured State

A closure is a function that remembers variables from the enclosing scope even after that scope has finished executing. The function "closes over" those variables, carrying them along wherever it goes.

def make_threshold_filter(min_accuracy):
    def passes(result):
        return result["accuracy"] >= min_accuracy
    return passes

is_good = make_threshold_filter(0.85)
print(is_good(results[0]))  # False (0.82 < 0.85)
print(is_good(results[1]))  # True (0.91 >= 0.85)

When make_threshold_filter(0.85) is called, it creates the inner function passes and returns it. At this point, make_threshold_filter has finished executing — its local scope is gone. But passes still has access to min_accuracy because Python stores it in a cell object attached to the function. The inner function carries its environment with it.
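You can see the cell directly. A function's __code__.co_freevars lists the names it closes over, and __closure__ holds the matching cell objects. This sketch repeats make_threshold_filter so it runs on its own:

```python
def make_threshold_filter(min_accuracy):
    def passes(result):
        return result["accuracy"] >= min_accuracy
    return passes

is_good = make_threshold_filter(0.85)
print(is_good.__code__.co_freevars)          # ('min_accuracy',)
print(is_good.__closure__[0].cell_contents)  # 0.85
print(is_good({"accuracy": 0.91}))           # True
```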

There's a gotcha here that catches even experienced developers. Closures capture the variable, not the value. This is called late binding:

filters = []
for threshold in [0.7, 0.8, 0.9]:
    filters.append(lambda r: r["accuracy"] >= threshold)

# You'd expect three different filters. But:
print(filters[0]({"accuracy": 0.75}))  # False — threshold is 0.9, not 0.7!

All three lambdas closed over the same variable threshold, not the value it held at each iteration. By the time you call any of them, threshold has its final loop value of 0.9. The fix is to capture the current value as a default argument:

filters = []
for threshold in [0.7, 0.8, 0.9]:
    filters.append(lambda r, t=threshold: r["accuracy"] >= t)
    # t=threshold evaluates at definition time, freezing the value

Remember our earlier discussion about default arguments being evaluated at definition time? That behavior, which caused the mutable default argument bug, is the same mechanism that fixes the late binding problem here. Same rule, opposite consequence. Python is consistent — it's our mental models that are inconsistent.

Closures are the engine behind many patterns in ML code: the lr_lambda you hand to a LambdaLR scheduler is often a closure over hyperparameters like the warmup length, callback functions close over training state, and decorators such as contextlib's @contextmanager return inner functions that close over the function they wrap. Speaking of which, there's a pattern of "set up, do work, tear down" that shows up constantly in PyTorch training loops.
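As one concrete case, torch.optim.lr_scheduler.LambdaLR takes an lr_lambda function and multiplies each parameter group's initial learning rate by its return value. A closure is a natural way to build one. The sketch below is pure Python, and the linear-warmup rule and make_warmup name are illustrative assumptions, not a recommended schedule:

```python
def make_warmup(warmup_steps):
    """Return an lr_lambda-style closure: linear warmup, then a flat multiplier of 1."""
    def lr_lambda(step):
        # warmup_steps is read from the closure cell on every call
        return min(1.0, (step + 1) / warmup_steps)
    return lr_lambda

schedule = make_warmup(warmup_steps=4)
print([schedule(s) for s in range(6)])  # [0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```

With PyTorch you would pass it as LambdaLR(optimizer, lr_lambda=make_warmup(4)); each scheduler.step() advances the counter that gets passed to lr_lambda.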

Context Managers — PyTorch's Train, Eval, and No-Grad

Every PyTorch training loop has a rhythm: train for an epoch, evaluate on validation data, repeat. And every time you switch from training to evaluation, there's a set of state changes that must happen and must be undone afterward. Getting this wrong causes some of the most insidious bugs in ML: your validation numbers look plausible, but your model is actually still in training mode with dropout active, or gradients are silently being computed and eating GPU memory during inference.

Python's with statement and the context manager protocol exist to handle exactly this pattern: set something up, do work, and guarantee the teardown happens even if something crashes in the middle. Let me first show you the key state switches in PyTorch (torch.no_grad(), which is a context manager, and model.train()/model.eval(), which are plain method calls), and then we'll see how context managers work under the hood.

torch.no_grad() — The Memory Saver

During training, PyTorch builds a computation graph: every operation on tensors that require gradients is recorded so that .backward() can compute gradients. The graph also keeps intermediate activations alive for the backward pass, and that consumes memory. A lot of memory. During inference (prediction), you don't need gradients. You want the forward pass to run without recording anything.

model = SimpleModel(2, 1)
x = torch.tensor([[0.01, 10.0]])

# WITHOUT no_grad — computation graph is built (wasteful during inference)
output = model(x)
print(output.requires_grad)  # True — graph was recorded, memory allocated

# WITH no_grad — no graph, no gradient memory
with torch.no_grad():
    output = model(x)
    print(output.requires_grad)  # False — clean, no graph overhead

What torch.no_grad() actually does: when you enter the with block, it sets a thread-local flag that tells PyTorch's autograd engine to stop recording operations. Every tensor operation inside the block skips graph construction. When you exit the block — whether normally or via an exception — the flag is restored to its previous state. The "guarantee" part is crucial: if your inference code throws an error, you don't want the no-grad flag stuck on for the rest of your program.

model.train() and model.eval() — Behavioral Switches

Some layers behave differently during training versus inference. Dropout randomly zeros out neurons during training (for regularization) but passes everything through during inference. BatchNorm uses batch statistics during training but switches to running-average statistics during inference. These layers check the model's training flag to decide which behavior to use.

model.train()   # sets self.training = True on model and ALL submodules
# Dropout is active, BatchNorm uses batch stats

model.eval()    # sets self.training = False on model and ALL submodules
# Dropout is disabled, BatchNorm uses running stats

A common mistake: calling model.eval() but forgetting torch.no_grad(). They do different things. model.eval() changes layer behavior (dropout, batchnorm). torch.no_grad() disables gradient computation. You almost always want both during validation:

def validate(model, val_loader):
    model.eval()                          # change layer behavior
    total_loss = 0.0
    with torch.no_grad():                 # stop recording gradients
        for features, targets in val_loader:
            output = model(features)
            loss = criterion(output, targets)
            total_loss += loss.item()
    model.train()                         # restore training behavior
    return total_loss / len(val_loader)

Notice we have to remember to call model.train() at the end. If we forget — or if an exception interrupts the function — the model stays in eval mode for subsequent training iterations. Dropout won't fire, BatchNorm will use stale running stats, and your training will silently degrade.
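A small check makes the independence of the two switches visible. The toy network below is hypothetical. After model.eval() alone, the output still requires grad because autograd is still recording; torch.no_grad() turns recording off but does not touch layer behavior.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy network with a dropout layer
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.eval()                          # dropout now acts as the identity...
out_eval = net(x)
print(out_eval.requires_grad)       # True: ...but autograd is still recording

with torch.no_grad():               # recording off; layer modes untouched
    out_both = net(x)
print(out_both.requires_grad)       # False
print(torch.equal(out_eval, out_both))  # True: dropout is off in eval mode

net.train()                         # restore training behavior
```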

How Context Managers Work Under the Hood

A context manager is any object that implements two dunder methods: __enter__ and __exit__. When Python encounters with obj as x, it calls x = obj.__enter__(). When the block ends (normally or via exception), it calls obj.__exit__(exc_type, exc_val, exc_tb).
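A hypothetical tracing context manager shows the order of these calls, what the as target receives, and what __exit__ sees when the block raises:

```python
class Traced:
    """Hypothetical context manager that records the protocol calls."""
    def __init__(self, log):
        self.log = log

    def __enter__(self):
        self.log.append("enter")
        return "value from __enter__"   # bound to the `as` target

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log.append(f"exit({exc_type.__name__ if exc_type else None})")
        return False                    # falsy: any exception keeps propagating

log = []
with Traced(log) as x:
    log.append(x)

try:
    with Traced(log):
        raise ValueError("boom")
except ValueError:
    log.append("caught outside")

print(log)
# ['enter', 'value from __enter__', 'exit(None)', 'enter', 'exit(ValueError)', 'caught outside']
```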

torch.no_grad() is a class that works like this (simplified from PyTorch source):

# Simplified from torch/autograd/grad_mode.py
import torch

class no_grad:
    def __enter__(self):
        self.prev = torch.is_grad_enabled()   # save previous state
        torch.set_grad_enabled(False)          # disable gradients
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.set_grad_enabled(self.prev)      # restore previous state
        return False                           # don't suppress exceptions

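The __exit__ method restores the previous grad state regardless of how the block ended. That's the safety guarantee: even if your inference code raises an error, gradient recording returns to whatever it was before the block (enabled, in a normal training script), and returning False lets the exception keep propagating.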
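Saving and restoring the previous state, rather than unconditionally re-enabling gradients, is what makes nesting safe. The pure-Python sketch below mimics that logic; grad_enabled and no_grad_sketch are hypothetical stand-ins for autograd's thread-local grad mode and torch.no_grad:

```python
grad_enabled = True   # hypothetical stand-in for autograd's thread-local flag

class no_grad_sketch:
    def __enter__(self):
        global grad_enabled
        self.prev = grad_enabled    # remember whatever was active before
        grad_enabled = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global grad_enabled
        grad_enabled = self.prev    # restore it, even if the block raised
        return False                # don't suppress the exception

with no_grad_sketch():
    with no_grad_sketch():
        pass
    after_inner = grad_enabled      # False: the outer block is still active

try:
    with no_grad_sketch():
        raise RuntimeError("inference failed")
except RuntimeError:
    pass

print(after_inner, grad_enabled)    # False True
```

If __exit__ simply set the flag back to True, leaving the inner block would re-enable gradients while the outer no-grad block was still running.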

You can build your own context managers for common ML patterns. Here's one that handles the train/eval switching properly:

import torch
from contextlib import contextmanager

@contextmanager
def eval_mode(model):
    """Context manager that handles model.eval() + no_grad safely."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        if was_training:
            model.train()

# Now validation is safe even if something crashes:
with eval_mode(model) as m:
    for features, targets in val_loader:
        output = m(features)
        # ... if an exception occurs here, the previous mode is still restored

The @contextmanager decorator from contextlib turns a generator function into a context manager. The code before yield is __enter__, the code after yield (in the finally block) is __exit__. Every concept we've built so far — closures, dunder methods, generators — is working together here.
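You can exercise the same generator-based pattern without PyTorch. In this sketch FakeModel is a hypothetical stand-in that has only nn.Module's training flag and its train()/eval() methods, and torch.no_grad() is left out; the point is that the finally clause restores the mode even when the with block raises.

```python
from contextlib import contextmanager

class FakeModel:
    """Hypothetical stand-in with nn.Module's training flag and train()/eval()."""
    def __init__(self):
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

@contextmanager
def eval_mode(model):
    was_training = model.training
    model.eval()
    try:
        yield model          # the with-block body runs here
    finally:
        if was_training:     # runs on normal exit AND on exceptions
            model.train()

m = FakeModel()
try:
    with eval_mode(m):
        assert m.training is False
        raise RuntimeError("validation crashed")
except RuntimeError:
    pass
print(m.training)  # True: training mode was restored
```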

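Another context manager you'll encounter in production: torch.amp.autocast (the older torch.cuda.amp.autocast() spelling is deprecated) for mixed-precision training. Inside the block it runs eligible operations, such as matrix multiplications and convolutions, in float16 or bfloat16 while keeping precision-sensitive ones in float32, enabling faster computation on modern GPUs: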

from torch.amp import autocast, GradScaler

scaler = GradScaler()

for features, targets in train_loader:
    optimizer.zero_grad()

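    with autocast(device_type='cuda'):         # eligible ops (e.g. matmuls) run in float16
        output = model(features)
        loss = criterion(output, targets)

    scaler.scale(loss).backward()              # scale up the loss so small float16 grads don't underflow
    scaler.step(optimizer)                     # unscales grads; skips the step if they contain inf/NaN
    scaler.update()                            # adjusts the scale factor for the next iteration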
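If you don't have a GPU handy, you can still watch autocast change dtypes temporarily: torch.autocast also supports device_type="cpu", where the lower-precision dtype is bfloat16. A minimal sketch with random tensors, relying on torch.mm being one of the ops CPU autocast runs in bfloat16:

```python
import torch

a = torch.rand(8, 8)   # float32 by default
b = torch.rand(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    inside = torch.mm(a, b)       # eligible op: runs in bfloat16

outside = torch.mm(a, b)          # autocast has exited: back to float32

print(inside.dtype, outside.dtype)  # torch.bfloat16 torch.float32
```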

The pattern is always the same: set up a temporary state, yield control to the user's code, guarantee cleanup. It's the most reliable way to manage the complex state transitions that real ML training requires.