A complete PyTorch training script can look surprisingly short. That brevity is deceptive. Inside one five-line loop, PyTorch is building a computation graph, differentiating a scalar objective, and updating every trainable parameter in the network.
This article unpacks a plain multilayer perceptron training script line by line, with the focus where it belongs: the mathematics of logits, gradients, and parameter updates.
The full script
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

net.apply(init_weights)
net.to(device)

batch_size, lr, num_epochs = 256, 0.1, 10

transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST("./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr)

for epoch in range(num_epochs):
    net.train()
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        logits = net(X)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

The code looks compact because PyTorch hides the bookkeeping. The underlying logic is still the standard mathematical pipeline.
The model computes logits, not probabilities
The network architecture is

$$\mathbf{o} = \mathbf{W}_2\,\mathrm{ReLU}(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2.$$

Here $\mathbf{o}$ is the logit vector in $\mathbb{R}^{10}$. The last layer does not apply Softmax. That job is delegated to the loss function, which uses a stable log-softmax implementation internally.
The first layer uses nn.Flatten() because Fashion-MNIST images start as tensors of shape $(1, 28, 28)$. Flattening turns each image into a vector in $\mathbb{R}^{784}$, which is what the first linear layer expects.
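A quick shape check makes the flattening concrete. The batch here is random stand-in data, not the real loader:

```python
import torch
from torch import nn

# A batch of Fashion-MNIST-shaped images: (batch, channels, height, width).
x = torch.randn(256, 1, 28, 28)

flat = nn.Flatten()(x)  # collapses every dimension except the batch dimension
print(flat.shape)       # torch.Size([256, 784])
```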
Why net.apply(init_weights) exists
PyTorch modules are trees. A sequential network contains submodules, and those submodules may contain parameters. The call
```python
net.apply(init_weights)
```

recursively visits every submodule and calls your function on it. In this script, the initializer only acts on nn.Linear layers.
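A toy sketch of that traversal, using a small stand-in network rather than the script's model. apply() visits children before the container itself:

```python
import torch
from torch import nn

# A stand-in network, smaller than the one in the script.
net = nn.Sequential(nn.Flatten(), nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

visited = []
net.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # each submodule, then the Sequential container last
```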
The actual initialization is small-Gaussian weights plus zero biases:

$$W_{ij} \sim \mathcal{N}(0,\, 0.01^2), \qquad \mathbf{b} = \mathbf{0}.$$
The point is not that this is the universally best initializer. The point is to start in a regime where activations and gradients are not immediately extreme.
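A small sanity check of that scheme, applied to a standalone layer rather than the full network:

```python
import torch
from torch import nn

layer = nn.Linear(784, 256)
nn.init.normal_(layer.weight, mean=0.0, std=0.01)
nn.init.zeros_(layer.bias)

# With ~200k weights, the empirical std should sit very close to 0.01.
print(layer.weight.std().item())
print(layer.bias.abs().sum().item())  # all biases are exactly zero
```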
Why data and model must live on the same device
net.to(device) moves every parameter tensor in the model onto the selected device. That is why the training loop also contains
```python
X, y = X.to(device), y.to(device)
```

Matrix multiplications and loss evaluations cannot mix tensors from different devices. If the model lives on MPS and the data stays on CPU, the forward pass will fail.
Cross-entropy expects logits
nn.CrossEntropyLoss() combines log-softmax with negative log-likelihood. For one example with logit vector $\mathbf{o}$ and label $y$, the objective is

$$\ell(\mathbf{o}, y) = -\log \frac{\exp(o_y)}{\sum_{k} \exp(o_k)} = \log \sum_{k} \exp(o_k) - o_y.$$
That is why you pass logits directly. If you manually apply Softmax first, you lose numerical stability and you duplicate work the loss function was already going to do.
The default reduction is a batch mean. So if the batch contains $n$ examples, the scalar loss returned by PyTorch is

$$L = \frac{1}{n} \sum_{i=1}^{n} \ell\!\left(\mathbf{o}^{(i)}, y^{(i)}\right).$$
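This identity is easy to verify numerically on made-up logits and labels:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # a batch of 4 raw score vectors
y = torch.tensor([3, 0, 7, 7])

builtin = F.cross_entropy(logits, y)   # default reduction="mean"

# Log-softmax, pick out each example's true-label entry, negate, average.
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), y].mean()

print(torch.allclose(builtin, manual))  # True
```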
The training loop has five mathematical steps
1. Forward pass
logits = net(X) evaluates the model with the current parameter values. During this forward pass, PyTorch records the operations needed to differentiate the final scalar loss with respect to those parameters.
2. Loss evaluation
loss = loss_fn(logits, y) turns the batch predictions and labels into one scalar objective. That scalar is the root node for backpropagation.
3. Gradient reset
optimizer.zero_grad() is necessary because PyTorch accumulates gradients by default. Without clearing them, the next backward call would add new gradients to the previous ones.
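A minimal demonstration of that accumulation, using a single scalar parameter:

```python
import torch

w = torch.tensor([2.0], requires_grad=True)

loss = (w ** 2).sum()      # dL/dw = 2w = 4 at w = 2
loss.backward()
first = w.grad.clone()     # tensor([4.])

loss = (w ** 2).sum()
loss.backward()            # without zeroing, the new gradient is *added*
second = w.grad.clone()    # tensor([8.])

w.grad.zero_()             # what optimizer.zero_grad() does for each parameter
print(first, second, w.grad)
```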
4. Backpropagation
loss.backward() applies the chain rule through the entire computation graph. For each trainable parameter $\theta$, it computes a numerical tensor representing

$$\frac{\partial L}{\partial \theta}$$

and stores that tensor in theta.grad.
The useful distinction here is this: a symbolic derivative is a function, but backward() evaluates that derivative at the current parameter values and the current batch. What lands in .grad is a concrete numerical array.
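A one-parameter example of that distinction: the symbolic derivative of $x^3$ is $3x^2$, and backward() evaluates it at the current value $x = 2$:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
f = (x ** 3).sum()  # symbolic derivative: 3x^2
f.backward()
print(x.grad)       # tensor([12.]) — the derivative evaluated at x = 2
```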
5. Optimizer step
With vanilla SGD, optimizer.step() applies

$$\theta \leftarrow \theta - \eta\, \frac{\partial L}{\partial \theta}$$

to every trainable parameter in the model. The learning rate $\eta$ determines the step size. Large gradients imply the loss is sensitive to that parameter, so the update magnitude grows accordingly.
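The update rule can be checked against optimizer.step() on a toy parameter:

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()
loss.backward()                        # w.grad = 2w

expected = w.detach() - 0.1 * w.grad   # theta <- theta - eta * grad, by hand
optimizer.step()                       # the in-place update

print(torch.allclose(w.detach(), expected))  # True
```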
What evaluation mode changes
In evaluation, two pieces of state matter:
- model.eval() switches layers such as Dropout and BatchNorm into inference behavior.
- torch.no_grad() disables gradient tracking, which saves memory and computation.
Accuracy is then just

$$\mathrm{accuracy} = \frac{1}{n} \sum_{i=1}^{n} \mathbb{1}\!\left[\operatorname*{argmax}_k\, o_k^{(i)} = y^{(i)}\right].$$

Since Softmax preserves ordering, $\operatorname*{argmax}_k o_k$ gives the same predicted class as $\operatorname*{argmax}_k \mathrm{softmax}(\mathbf{o})_k$.
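Putting the pieces together, here is a sketch of an evaluation pass. The batch is random stand-in data rather than the real test loader, so the accuracy itself is meaningless — only the mechanics matter:

```python
import torch
from torch import nn

net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, 10))

X = torch.randn(32, 1, 28, 28)       # stand-in batch
y = torch.randint(0, 10, (32,))      # stand-in labels

net.eval()                           # inference behavior for Dropout/BatchNorm
with torch.no_grad():                # no graph recording during evaluation
    logits = net(X)
    preds = logits.argmax(dim=1)     # argmax over logits == argmax over softmax
    acc = (preds == y).float().mean().item()
print(acc)
```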
The main takeaway
A pure PyTorch MLP training script looks short because the framework automates the hard bookkeeping, not because the mathematics disappeared.
- The model outputs logits.
- Cross-entropy turns those logits into a scalar objective.
- Autograd computes every parameter gradient by reverse-mode differentiation.
- The optimizer uses those gradients to update parameters.
Once that pipeline is clear, training code stops looking like ceremony and starts looking like a direct implementation of the optimization problem you wrote down on paper.