Frameworks: TensorFlow, Keras & JAX
PyTorch owns research and is rapidly eating production. TensorFlow still runs a large share of the world's deployed models, and its graph compilation and mobile story (TFLite) remain hard to match. JAX is the wildcard: a functional, compiler-first library where you write NumPy-style code and get auto-diff, JIT compilation, vectorization, and multi-device parallelism through composable function transforms. Keras 3 now runs on all three backends. You don't need to master every framework, but understanding why each exists, what trade-off it makes, and when it's the right tool will come up in every senior-level conversation about ML infrastructure.
Why This Section Exists
I'll be honest — for a long time I avoided thinking about frameworks that weren't PyTorch. The previous section taught us to train models by writing explicit forward passes, computing losses, calling .backward(), and stepping an optimizer. That felt like the "real" way. Everything else looked like either a shortcut that hides what matters, or an academic exercise in functional programming. I procrastinated this dive for longer than I should have.
But here's what kept nagging at me: every time I saw a production ML system at a large company, it was TensorFlow. Every time I read a DeepMind paper pushing the boundary on scale, it was JAX. Every time someone prototyped something in an afternoon with three lines of code, it was Keras. These frameworks aren't historical curiosities. They represent fundamentally different answers to the question: what should the contract between a programmer and a GPU look like?
Before we start, a heads-up. We're going to be looking at computation graphs, JIT compilation, functional transforms, and multi-backend architecture. You don't need to know any of it beforehand. We covered the training loop in PyTorch last section — that's all the context we need. We'll add what we need, one piece at a time.
This isn't a short journey, but I hope you'll be glad you came.
Three Philosophies, One Goal
Every deep learning framework does the same three things: represent tensors, compute gradients, and run operations on accelerators. The differences are in how they organize the contract between your code and the hardware. Let me use an analogy that will follow us through this section.
Think of building a house. PyTorch is like a construction crew that executes your instructions as you say them — "put a brick here, now one there." You see every step happen in real time. You can stop mid-wall, inspect the mortar, change your mind. That's eager execution. TensorFlow's core approach is more like handing an architect a blueprint, and the architect optimizes the entire plan before a single brick is laid — rearranging the plumbing, fusing redundant walls, pre-cutting all the lumber at once. That's graph compilation. JAX takes the most radical approach: you describe a mathematical function for "what a house should look like," and JAX gives you back a new function that builds it — optimized, parallelized, and differentiable. That's function transformation.
Same house. Three very different relationships between you and the process.
Our Running Example: A Tiny Classifier
To keep things concrete, we're going to build the same thing in every framework: a small neural network that classifies handwritten digits (MNIST — 28×28 grayscale images, 10 classes). Two hidden layers, ReLU activations, softmax output. Nothing fancy. The architecture is identical everywhere — what changes is the code you write to express it and the machinery that runs underneath.
We'll watch how each framework handles the same four operations: define the model, compute the loss, get gradients, update parameters. By the end, you'll have a visceral sense of why these frameworks feel so different despite doing the same math.
TensorFlow: The Graph Compiler
TensorFlow was released by Google in 2015, and for its first few years, it was the dominant deep learning framework. Period. If you were doing deep learning between 2015 and 2019, you were probably writing TensorFlow.
But there's something you need to know about those early days: TensorFlow 1.x was painful. You had to define your entire computation as a static graph using placeholder variables, then "run" that graph inside a Session object. You couldn't print a tensor's value during execution. You couldn't set a breakpoint and inspect intermediate results. Debugging felt like trying to fix a car engine by reading the blueprint instead of opening the hood.
TensorFlow 2 (released in 2019) fixed this by making eager execution the default — operations run immediately, like PyTorch. But TF's core identity didn't change. Under the hood, TensorFlow still wants to compile your code into an optimized computation graph. It still wants to be the architect with the blueprint.
Eager Mode vs Graph Mode — The Central Tension
This is the single most important thing to understand about TensorFlow, and it comes up in interviews constantly. TF lives in two modes, and the tension between them explains almost everything about the framework's design.
In eager mode, every operation executes immediately and returns a concrete value. You can print it, inspect it with a Python debugger, mix it with regular Python code. It feels natural. It's also slower, because every operation passes through the Python interpreter and is dispatched to the GPU individually, so per-op overhead adds up and nothing can be optimized across operations.
In graph mode, you wrap a function with @tf.function and TensorFlow traces it, recording every operation without executing it, to build a computation graph. TF's built-in graph optimizer (Grappler) then prunes unused nodes, folds constants, and eliminates redundant work. If you also pass jit_compile=True, the graph goes to XLA, the same compiler JAX uses, which fuses operations and generates hardware-specific kernels. Either way the result runs fast, because Python is out of the loop during execution.
import tensorflow as tf
# Eager mode: this runs immediately, returns a concrete value
x = tf.constant([1.0, 2.0, 3.0])
y = x * 2 + 1
print(y) # tf.Tensor([3. 5. 7.], shape=(3,), dtype=float32)
# Graph mode: wrap with @tf.function to compile
@tf.function
def compute(x):
    return x * 2 + 1
# First call: TF traces the function, builds a graph, compiles it
# Subsequent calls with same shape/dtype: runs compiled version directly
result = compute(x)
The first time you call a @tf.function-decorated function, TF traces it — feeding special "tracer" objects through the function to record every operation. On future calls with the same input shapes and dtypes, TF skips Python entirely and runs the compiled graph. If you call it with different shapes, it retraces and compiles again.
Here's where it gets tricky. TF uses a system called AutoGraph to convert Python control flow (if-statements, for-loops) into graph operations. A Python if whose condition is a tensor becomes a tf.cond. A Python for or while loop over tensor values becomes a tf.while_loop. Most of the time this works. Sometimes it doesn't, and the error messages will make you question your career choices. I still occasionally get caught by a tracing error where a Python conditional works fine in eager mode but silently does the wrong thing inside a @tf.function.
Python side effects inside a @tf.function only happen during tracing, not during execution. A print() inside a tf.function fires once per trace and then never again on calls that reuse the compiled graph. Use tf.print() instead. This single gotcha has confused more people than any other aspect of TensorFlow.
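To make the tracing behaviour concrete, here is a minimal sketch. The function name double_plus_one and the trace_log list are made up for illustration, and the snippet assumes TensorFlow 2.x with default tf.function settings.

```python
import tensorflow as tf

trace_log = []  # plain Python list, touched only while TF is tracing

@tf.function
def double_plus_one(x):
    trace_log.append("traced")   # Python side effect: runs at trace time only
    tf.print("executing on", x)  # graph op: runs on every call
    return x * 2 + 1

double_plus_one(tf.constant([1.0, 2.0, 3.0]))  # first call: traces, then runs
double_plus_one(tf.constant([4.0, 5.0, 6.0]))  # same shape and dtype: reuses the graph
double_plus_one(tf.constant([1.0, 2.0]))       # new shape: triggers a retrace

print(len(trace_log))  # counts traces, not calls
```

Three calls happen, but the Python-side list should only grow when a trace happens, while tf.print reports on every call.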
Our MNIST Classifier in Keras
Keras is TensorFlow's high-level API. Where PyTorch gives you an explicit training loop — forward, loss, backward, step — Keras collapses all of that into three method calls. Here's our digit classifier:
import tensorflow as tf
from tensorflow import keras
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.1)
predictions = model.predict(x_test)
compile configures the optimizer and loss. fit runs the entire training loop — batching, forward pass, loss computation, backprop, parameter update, metric tracking, validation. predict runs inference. Three calls. That's the entire training pipeline.
If you're coming from PyTorch where you wrote a 30-line training loop, this feels like cheating. And in a way, it is — you're trading control for speed of development. For standard architectures with standard training, Keras is genuinely hard to beat on iteration speed. But the moment you need a custom training step, a non-standard loss, or a training loop that branches based on intermediate results, you start fighting Keras's abstractions instead of using them.
Keras does offer a middle ground: subclassing keras.Model and overriding train_step gives you PyTorch-level control while keeping Keras's metric tracking and callback system. Most production Keras code I've seen lives in this space — not the three-liner, not raw TensorFlow, but somewhere in between.
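Whichever level you work at, the explicit step underneath looks roughly like the following. This is a minimal sketch of one gradient-descent update written with tf.GradientTape on a toy linear model, not a full train_step override; the variables w and b, the data, and the learning rate are invented for illustration.

```python
import tensorflow as tf

# Toy linear model and data (illustrative values only)
w = tf.Variable([1.0, 2.0, 3.0])
b = tf.Variable(0.0)
x = tf.constant([[1.0, 0.0, 2.0],
                 [0.0, 1.0, 1.0]])
y = tf.constant([1.0, 2.0])
learning_rate = 0.01

# Forward pass and loss are recorded on the tape
with tf.GradientTape() as tape:
    preds = tf.linalg.matvec(x, w) + b
    loss = tf.reduce_mean((preds - y) ** 2)

# Backward pass: gradients of the loss with respect to each variable
grad_w, grad_b = tape.gradient(loss, [w, b])

# Parameter update: plain SGD, mutating the variables in place
w.assign_sub(learning_rate * grad_w)
b.assign_sub(learning_rate * grad_b)
```

The same four operations from the PyTorch loop are all visible here: forward, loss, gradients, update. A Keras train_step override wraps a step like this and adds metric tracking around it.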
Where TensorFlow Still Wins
The reason TensorFlow remains enormously relevant despite PyTorch's research dominance comes down to deployment infrastructure that took years to build and has no PyTorch equivalent of the same maturity.
TF Serving is a production model server with automatic batching, model versioning, gRPC/REST endpoints, and canary deployments built in. You export a SavedModel and TF Serving handles the rest. Companies that serve millions of predictions per second — think Google, Airbnb, Twitter — often run TF Serving or derivatives of it.
TFLite compiles models for mobile and edge devices. If your model needs to run on a phone with 2GB of RAM and no internet connection, TFLite is the most battle-tested path. It handles quantization (shrinking model precision from float32 to int8), operator fusion, and hardware-specific acceleration (GPU delegates, NNAPI on Android, CoreML on iOS).
TensorFlow.js runs models in the browser. tf.data pipelines handle datasets too large for memory with prefetching, shuffling, and parallel loading. These aren't flashy features, but they're the plumbing that makes ML work in production at scale.
You will still run into TensorFlow 1.x code: tf.placeholder, tf.variable_scope, and friends are all deprecated, but still running in production at plenty of companies. If someone hands you a TF1 codebase, don't panic. The math is the same. The API is what's different.
Keras 3: The Plot Twist
Here's something that caught me off guard when I first heard about it. In late 2023, François Chollet (Keras's creator) released Keras 3, and it's no longer tied to TensorFlow. Keras 3 is a multi-backend framework. The same Keras code can run on TensorFlow, PyTorch, or JAX — you pick the backend with an environment variable.
# Set before importing keras
import os
os.environ["KERAS_BACKEND"] = "jax" # or "torch" or "tensorflow"
import keras
from keras import layers
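model = keras.Sequential([
    layers.Flatten(),  # 28x28 images -> 784-vectors
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])
# Integer class labels, as in the earlier example, need the sparse loss
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(x_train, y_train, epochs=5)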
Same model definition. Same compile → fit. But underneath, it's running on JAX's XLA compiler, or PyTorch's autograd, or TensorFlow's graph executor. The idea is "write once, run anywhere" for deep learning.
This is a bigger deal than it might seem. It means Keras is no longer "TensorFlow's high-level API" — it's positioning itself as a universal interface. If you prototype in Keras and later realize you need JAX's function transforms for a specific research direction, you change one line instead of rewriting your model. If your company's production stack is TensorFlow but your research team uses PyTorch, Keras 3 can bridge that gap.
The caveat: not everything works equally well on every backend yet. Custom ops, some advanced layers, and certain preprocessing utilities have backend-specific limitations. But the core model-building and training loop works, and it's improving rapidly. Note that Keras 3 is the standalone keras package. The old TensorFlow-only Keras 2 is in maintenance mode and lives on as the separate tf-keras package; since TensorFlow 2.16, tf.keras itself resolves to Keras 3 by default.
But we haven't talked about JAX yet, and JAX represents a fundamentally different way of thinking about computation. If the idea of "what if we treated neural network training as function composition" makes you curious, read on.
JAX: Functions All the Way Down
JAX was released by Google in 2018 and it starts from a premise that sounds almost academic: what if everything were a pure function?
A pure function always produces the same output for the same input, and has no side effects — no modifying global state, no mutating arguments, no printing. In regular Python, this is a nice-to-have. In JAX, it's a hard requirement for using the framework's most powerful features.
Why would anyone want that constraint? Because pure functions are safe to transform. If a function has no hidden dependencies and no side effects, you can safely compile it, differentiate it, vectorize it, or run it across eight TPUs — and you're guaranteed the result is correct. That's the deal JAX offers: give up mutation, get transformations.
Let's go back to our house-building analogy. In JAX, you don't build a house. You write a mathematical description of what a house is — a function from (lot, materials, blueprint) → house. Then JAX gives you new functions: one that computes how the house changes if you tweak the blueprint (that's differentiation), one that builds a whole neighborhood from your single-house function (that's vectorization), one that assigns different workers to different houses in parallel (that's device parallelism). All from the same original function description.
The Four Transforms
JAX gives you four function transforms, and understanding these is the key to understanding why people use JAX at all. Let's build our MNIST classifier piece by piece.
jax.grad takes a function and returns a new function that computes its gradient. Not a gradient value — a gradient function. This is the most fundamental departure from PyTorch. In PyTorch, you call loss.backward() and gradients accumulate on tensor .grad attributes. In JAX, there's no .backward(), no gradient attributes, no side effects at all. You pass a function in, you get a function out.
import jax
import jax.numpy as jnp
def mse_loss(params, x, y):
    preds = jnp.dot(x, params['w']) + params['b']
    return jnp.mean((preds - y) ** 2)
# grad returns a FUNCTION, not a value
grad_fn = jax.grad(mse_loss)
# Call the gradient function with actual data
params = {'w': jnp.ones((3,)), 'b': 0.0}
grads = grad_fn(params, x_batch, y_batch)
# grads has the same structure as params: {'w': array(...), 'b': scalar array}
Notice something subtle: grads is a dictionary with the exact same structure as params. JAX knows how to differentiate through nested data structures — dicts, lists, tuples, combinations of all of them. These nested structures are called pytrees, and they're how JAX handles model parameters. No special Parameter class, no nn.Module with a .parameters() method. Parameters are data. Gradients are data with the same shape.
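A small sketch may help make pytrees tangible. The nested params dictionary, loss_fn, and the 0.1 step size below are invented for illustration; the jax.tree_util functions are the library's own utilities for inspecting and mapping over pytrees.

```python
import jax
import jax.numpy as jnp

# A nested parameter structure: any mix of dicts, lists, and tuples is a pytree
params = {
    "layer1": {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))},
    "scale": jnp.array(2.0),
}

def loss_fn(p, x):
    hidden = jnp.dot(x, p["layer1"]["w"]) + p["layer1"]["b"]
    return p["scale"] * jnp.sum(hidden ** 2)

x = jnp.array([1.0, 2.0])
grads = jax.grad(loss_fn)(params, x)  # differentiates w.r.t. the first argument

# Flatten to see the leaves (the actual arrays) and the tree structure
leaves, treedef = jax.tree_util.tree_flatten(params)
same_structure = jax.tree_util.tree_structure(grads) == treedef

# tree_map applies a function leaf by leaf across matching pytrees
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Because gradients mirror the parameter structure exactly, a whole SGD update is a single tree_map over the two trees.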
jax.jit takes a function and compiles it using XLA (Accelerated Linear Algebra), the same compiler TensorFlow uses when you opt in with jit_compile=True. The first time you call a jit-compiled function, JAX traces it, feeding special tracer objects through the function to record every operation, and compiles the resulting computation graph. On future calls with the same input shapes and dtypes, the compiled version runs directly, without re-running the Python function body.
@jax.jit
def training_step(params, x, y, lr=0.01):
    grads = jax.grad(mse_loss)(params, x, y)
    # Update params (functional style: return new params, don't mutate)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params
# First call: traces + compiles. Subsequent calls: runs compiled code.
params = training_step(params, x_batch, y_batch)
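The speed difference can be dramatic. Depending on the workload and hardware, a jit-compiled JAX function can run one to two orders of magnitude faster than the same function without JIT, because XLA fuses operations, optimizes memory access patterns, and generates hardware-specific kernels. The first call pays a one-time compilation cost, so the gain shows up on repeated calls.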
jax.vmap is automatic vectorization. You write a function that processes a single example, and vmap lifts it to process an entire batch. No loops, no manual batch dimension handling. This is wildly useful for things like per-example gradient computation, which is a pain in PyTorch but trivial in JAX.
def predict_single(params, x):
    """Process a SINGLE input vector, with no batch dimension."""
    return jnp.dot(x, params['w']) + params['b']
# vmap it: now it handles batches automatically
predict_batch = jax.vmap(predict_single, in_axes=(None, 0))
# None = don't map over params, 0 = map over first axis of x
predictions = predict_batch(params, x_batch) # works on entire batch
jax.pmap is like vmap but across devices. Write a function for one GPU, and pmap runs a copy of it on every available GPU or TPU, with each device processing one slice of the leading input axis. This is how Google trains models on TPU pods with hundreds of accelerators: they write the single-device logic and let the transform handle the distribution. Newer JAX code often gets the same effect from jax.jit combined with jax.sharding, but the mental model carries over.
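Here is a minimal pmap sketch that also runs on a single-device machine. The function sum_of_squares and the toy data are invented for illustration; the only requirement pmap imposes is that the leading axis of the input matches the number of devices.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # 1 on a CPU-only machine

def sum_of_squares(shard):
    # Written for ONE device: sees a single shard of shape (4,)
    return jnp.sum(shard ** 2)

# Leading axis must equal the number of devices: one shard per device
data = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)
per_device = jax.pmap(sum_of_squares)(data)  # shape (n_devices,)
```

On one device this behaves like a compiled call over a batch of size one; on eight devices the same code would process eight shards in parallel.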
The real power is composability. These transforms stack. jax.jit(jax.vmap(jax.grad(loss_fn))) gives you a compiled, batched gradient function in one line. You can differentiate a vectorized function, compile the result, and distribute it across devices. Each transform is independent and they compose cleanly because everything is a pure function with no hidden state to corrupt.
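As a concrete instance of stacking transforms, the sketch below computes per-example gradients. The function example_loss and the toy batch are assumptions made for illustration. The function is written for a single example, grad differentiates it, vmap maps it over the batch axis, and jit compiles the result.

```python
import jax
import jax.numpy as jnp

def example_loss(params, x, y):
    # Loss for ONE example: x has shape (3,), y is a scalar
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2

# grad -> vmap -> jit: params are shared (None), x and y are mapped over axis 0
per_example_grads = jax.jit(
    jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))
)

params = {"w": jnp.ones((3,)), "b": jnp.array(0.0)}
x_batch = jnp.arange(12.0).reshape(4, 3)
y_batch = jnp.zeros((4,))

grads = per_example_grads(params, x_batch, y_batch)
# Each leaf gains a leading batch axis: grads["w"] is (4, 3), grads["b"] is (4,)
```

The result is still a pytree shaped like params, only with one gradient per example instead of one averaged gradient.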
The Gotchas — Why JAX Isn't for Everyone
I'm still developing my intuition for when JAX's constraints become a net negative rather than a net positive. Here are the real costs of functional purity:
No in-place mutation. JAX arrays are immutable. x[0] = 5 doesn't work. You use x.at[0].set(5), which returns a new array. Coming from NumPy or PyTorch, this feels like fighting the language. It's not — it's a deliberate design choice that makes JIT compilation safe — but it takes time to internalize.
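A short sketch of the difference, using only jax.numpy. The flag mutation_failed is an illustrative variable, not part of any API.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

try:
    x[0] = 5.0              # NumPy-style in-place write
    mutation_failed = False
except TypeError:
    mutation_failed = True  # JAX arrays reject item assignment

y = x.at[0].set(5.0)        # functional update: returns a new array
z = y.at[1:].add(1.0)       # other updates follow the same .at[...] pattern
```

The original x is untouched after both updates, which is what makes it safe for the compiler to reorder or fuse these operations.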
Manual random state. NumPy has a global random number generator. Call np.random.normal() anywhere and you get random numbers. JAX has no global state (because global state is a side effect). Instead, you pass an explicit random key, and you split it every time you need more randomness:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(3,))
# Need more random numbers? Split again.
key, subkey = jax.random.split(key)
more_noise = jax.random.normal(subkey, shape=(5,))
If you reuse a key without splitting, you get the same random numbers. This is by design — reproducibility is baked in — but it's a footgun if you're not careful.
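The footgun is easy to demonstrate. In this sketch the variable names are illustrative: two draws from the same key are identical, and a draw from a freshly split subkey differs.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))     # same key reused: same numbers as a

key, subkey = jax.random.split(key)  # derive fresh keys
c = jax.random.normal(subkey, (3,))  # different numbers

reused_key_repeats = bool(jnp.array_equal(a, b))
split_key_differs = not bool(jnp.array_equal(a, c))
```

A common habit is to treat a key as consumed the moment it has been passed to any jax.random function, and to split before every new use.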
Python control flow inside JIT. Because jax.jit traces your function, Python if statements and for loops that depend on array values don't work as you'd expect. An if on a traced value raises a concretization error, because the tracer has no concrete value to branch on. A condition or loop bound that is static gets evaluated once during tracing and baked into the compiled graph, so a Python for loop is simply unrolled. If you need data-dependent branching, you use jax.lax.cond. If you need data-dependent loops, you use jax.lax.scan or jax.lax.while_loop. This is the same tracing problem TensorFlow has with @tf.function, and it produces similarly confusing errors.
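The following sketch contrasts the two styles on a toy absolute-value function. The function names are made up for illustration, and the snippet assumes a scalar input.

```python
import jax
import jax.numpy as jnp

@jax.jit
def abs_python_if(x):
    if x < 0:  # needs a concrete bool, but x is a tracer here
        return -x
    return x

@jax.jit
def abs_lax_cond(x):
    # Both branches are traced; the choice between them happens at run time
    return jax.lax.cond(x < 0, lambda v: -v, lambda v: v, x)

try:
    abs_python_if(jnp.array(-3.0))
    python_if_worked = True
except TypeError:  # JAX's concretization errors subclass TypeError
    python_if_worked = False

result = abs_lax_cond(jnp.array(-3.0))
```

For something this simple jnp.where or jnp.abs would be the idiomatic choice; lax.cond earns its keep when the two branches are expensive and only one should run.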
print() inside a jit-compiled function executes during tracing (once) but not during actual computation. If you're debugging a loss that's returning NaN, print(loss) shows you a tracer object, not a number. Use jax.debug.print() or drop the @jit decorator while debugging.
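A minimal sketch of the two kinds of print side by side. The function scaled_loss is invented for illustration; jax.debug.print takes a format string plus the values to fill in.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_loss(x):
    loss = jnp.mean(x ** 2)
    print("trace-time view:", loss)                 # shows a tracer, once per trace
    jax.debug.print("run-time loss = {v}", v=loss)  # shows the value, on every call
    return loss * 0.5

out = scaled_loss(jnp.array([1.0, 2.0, 3.0]))  # traces: both prints fire
out = scaled_loss(jnp.array([2.0, 2.0, 2.0]))  # cached: only jax.debug.print fires
```

The second call reuses the compiled function because shape and dtype match, so the plain print stays silent.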
The JAX Ecosystem: Flax, Equinox, and Optax
JAX itself doesn't provide neural network layers. It gives you arrays and transforms. Building a neural network from raw JAX is like building a car from raw steel — possible, educational, but not how you'd ship a product. Libraries fill this gap.
Flax is Google's official neural network library for JAX. It provides an nn.Module base class with explicit parameter management. Parameters live in a separate dictionary, not on the module object. This separation feels odd if you're used to PyTorch, but it's the natural way to handle state in a functional framework — your model definition is separate from your model's parameters.
import flax.linen as nn
class MNISTClassifier(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return nn.log_softmax(x)  # log-probabilities, not raw logits
model = MNISTClassifier()
# Initialize: pass a random key and a dummy input to discover parameter shapes
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28)))
# Forward pass: model.apply takes params explicitly
log_probs = model.apply(params, x_batch)
See how model.init returns parameters as a pytree, and model.apply takes them as an argument? The model object itself holds no state. You can jit-compile model.apply, take its gradient, vmap it — because it's a pure function of (params, input).
Equinox takes a different approach. Instead of separating parameters from modules, Equinox makes modules themselves pytrees. Your model is a dataclass whose fields are parameters and sub-modules. This feels more like PyTorch — the model "has" its parameters — but it still plays nicely with JAX transforms because pytrees are JAX-native.
Optax is the optimizer library. Rather than monolithic optimizer classes, Optax builds optimizers by chaining small transforms — scale by learning rate, apply momentum, clip gradients. You compose the chain you want:
import optax
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # gradient clipping
    optax.adam(learning_rate=1e-3)   # Adam update
)
opt_state = optimizer.init(params)
# Training step: get grads, apply optimizer, return new params and state
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
This composability is Optax's superpower. Want gradient clipping + Adam + weight decay + learning rate warmup? Chain them. Each transform is independent and testable. It's the same philosophy as JAX itself — small, composable, functional pieces.
The Convergence: torch.compile and the Blurring Lines
Something interesting happened in 2023 that blurs the lines between these frameworks. PyTorch 2.0 introduced torch.compile, bringing compilation to the eager-execution world.
Under the hood, torch.compile uses two new components. TorchDynamo is a frontend that intercepts Python bytecode to capture a computation graph, without requiring you to change your code. TorchInductor is a backend compiler that takes that graph and generates optimized code: Triton kernels on GPUs (Triton is a GPU programming language), and C++ with OpenMP on CPUs.
import torch
model = MyModel()
# One line to compile — no code changes needed
compiled_model = torch.compile(model)
# Use exactly as before
output = compiled_model(input_batch)
The key difference from JAX's JIT: when Dynamo encounters Python code it can't trace (a dynamic condition, an unsupported operation), it falls back to eager mode for that section and continues. JAX would throw an error. This means torch.compile is far more forgiving — it optimizes what it can and runs the rest normally. The trade-off is that it might not achieve the same level of optimization as JAX on fully traceable code.
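The fallback behaviour can be seen with a function that branches on a tensor value. This is a minimal sketch: gated_scale is an invented example, and backend="eager" is chosen here only so the snippet exercises Dynamo's capture without needing a GPU or C++ toolchain for Inductor.

```python
import torch

def gated_scale(x):
    y = x * 2
    # Data-dependent Python branch: Dynamo cannot capture this in one graph,
    # so it splits the function here (a "graph break") and carries on
    if y.sum() > 0:
        return y + 1
    return y - 1

# backend="eager" runs the captured graphs without Inductor code generation
compiled = torch.compile(gated_scale, backend="eager")

x = torch.tensor([1.0, 2.0, 3.0])
eager_out = gated_scale(x)
compiled_out = compiled(x)
```

The equivalent function under jax.jit would raise an error at the if statement; here the compiled and eager results should simply agree.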
This convergence is real. PyTorch now has compilation (like TF and JAX). TensorFlow has eager execution (like PyTorch). JAX has high-level neural network libraries (like PyTorch and Keras). The frameworks are learning from each other. The differences are narrowing. But the core philosophies — imperative vs. graph-compiled vs. functional — still shape how you think about and write code in each one.
The Full Picture
| | PyTorch | TensorFlow / Keras | JAX |
|---|---|---|---|
| Core paradigm | Imperative, eager-first, OOP. torch.compile adds optional graph compilation. | Graph compilation at its core. Eager mode as default UX since TF2. Keras provides declarative high-level API. | Functional, transform-based. Pure functions + composable transforms (grad, jit, vmap, pmap). |
| Gradient computation | loss.backward(): gradients accumulate on tensor .grad attributes (imperative, stateful) | tf.GradientTape context manager records ops, then tape.gradient() (imperative with explicit recording) | jax.grad(fn) returns a new function (purely functional, no state) |
| Compilation | torch.compile (Dynamo + Inductor). Graceful fallback on untraceable code. | @tf.function + XLA. AutoGraph converts Python control flow. | @jax.jit + XLA. Strict: requires pure functions, no side effects. |
| Production deployment | TorchServe, ONNX export, TorchScript. Improving rapidly. | TF Serving, TFLite (mobile), TF.js (browser), SavedModel. Most mature. | Limited native deployment. Often export to TF SavedModel or ONNX. |
| Multi-device | DistributedDataParallel (DDP), FSDP. Manual setup. | tf.distribute.Strategy. Relatively simple API. | jax.pmap / jax.sharding. First-class, especially on TPUs. |
| Ecosystem | HuggingFace, torchvision, torchaudio. Largest open-source model ecosystem. | TF Hub, tf.data, TFX pipeline. Enterprise and mobile ecosystem. | Flax, Equinox, Optax, Orbax. Smaller but concentrated at Google/DeepMind. |
| Community (2024) | ~70% of research papers. Growing in production. | Declining in research, still massive in deployed production. | Small but influential. Used for AlphaFold, Gemini, PaLM. |
When to Reach for What
Let's drop the diplomacy and be direct about this. Every "framework comparison" article ends with "it depends" and lists pros and cons. That's true, but not helpful. Here's a more honest heuristic.
Default to PyTorch. For research, for learning, for most new projects. The ecosystem is largest, the open-source models you'll want to use are in PyTorch, and the debugging experience (eager execution, real Python stack traces) is the best. PyTorch 2.0's torch.compile closes the performance gap with compiled frameworks. Meta, Microsoft, Tesla, and most AI startups are PyTorch shops. If you only learn one framework deeply, this is the one.
Reach for TensorFlow/Keras when you're deploying to mobile or edge devices (TFLite has no real PyTorch equivalent at the same maturity), you're inheriting an existing TF production system (rewriting is rarely worth it), or you're prototyping something standard and want the fastest path from idea to trained model (Keras's compile → fit is unbeatable for standard architectures). With Keras 3, you also get multi-backend flexibility — prototype in Keras, deploy on whatever backend your infrastructure supports.
Reach for JAX when you need composable function transforms that don't exist elsewhere (differentiating through a physics simulation, computing per-example gradients efficiently, custom higher-order derivatives), you're working with TPU pods (JAX has first-class TPU support through XLA), or you're in the Google/DeepMind ecosystem where JAX is the primary framework. AlphaFold, Gemini, and PaLM were all built with JAX. If large-scale Google research is your world, JAX isn't optional.
My favorite thing about frameworks is that, once you understand one deeply, the others take days to pick up, not months. The concepts (layers, losses, optimizers, gradient computation, batching, data loading) are identical. What changes is how you express them. An nn.Linear in PyTorch is a keras.layers.Dense in Keras is an nn.Dense in Flax. The weights are the same matrix. The forward pass is the same multiplication. The gradient is the same derivative.
Wrap-Up
If you're still with me, thank you. I hope it was worth it.
We started with a question — why do other frameworks exist when PyTorch works fine? — and traced three different answers. TensorFlow said "give me a blueprint and I'll optimize the whole plan before execution." JAX said "write pure functions and I'll give you differentiation, compilation, vectorization, and parallelism as composable transforms." Keras said "let me handle the ceremony so you can focus on the architecture." We built the same classifier in each, watched how they handle the same four operations differently, and saw how PyTorch 2.0's torch.compile is blurring the lines between eager and compiled execution.
My hope is that the next time you encounter a TensorFlow codebase in production or a JAX-based research paper, instead of feeling lost or dismissive, you'll recognize the philosophy behind the syntax — and you'll have a pretty good mental model of what's going on under the hood.
Resources
- Keras 3 announcement — François Chollet's blog post explaining the multi-backend vision. Wildly clarifying on where Keras is headed. keras.io/keras_3
- JAX Quickstart — The official tutorial is one of the best-written framework introductions I've seen. Starts from NumPy, builds to transforms. jax.readthedocs.io
- "You Don't Know JAX" — The Common Gotchas page. Required reading before writing any serious JAX code. Common Gotchas
- PyTorch 2.0 blog post — The design behind torch.compile, Dynamo, and Inductor. Insightful on why PyTorch chose "graph breaks" over strict tracing. pytorch.org
- Equinox documentation — Patrick Kidger's library is beautifully designed and the docs explain the "models as pytrees" philosophy better than anything else. docs.kidger.site/equinox
What You Should Now Be Able To Do
- Explain TensorFlow's eager-vs-graph tension and why @tf.function exists
- Read a Keras compile → fit → predict pipeline and know what each call does under the hood
- Describe what "pure function" means in JAX and why that constraint enables grad, jit, vmap, and pmap
- Explain what pytrees are and why JAX uses them instead of Parameter objects
- Compare how PyTorch, TF, and JAX each handle gradient computation differently
- Know what Keras 3 is and how it differs from tf.keras
- Articulate when you'd choose TF over PyTorch (mobile, serving, existing infrastructure) and when you'd reach for JAX (composable transforms, TPUs, Google ecosystem)
- Explain how torch.compile bridges the eager-vs-compiled gap and how it differs from jax.jit
- Know that Flax, Equinox, and Optax exist and what role each plays in the JAX ecosystem