MathIsimple

Practical Transfer Learning with ResNet in PyTorch

Building a multi-task fruit classifier on a small dataset, with frozen backbone, weighted losses, cosine annealing, and early stopping.

Most real computer-vision projects do not have ImageNet's 1.2 million labeled images. They have a few hundred or a few thousand. Training a CNN from scratch on that much data either overfits immediately or plateaus at uselessly low accuracy. Transfer learning — taking a network pretrained on ImageNet and adapting it to a target task — is what makes serious image classification accessible at small scale.

This article walks through a complete multi-task transfer-learning pipeline in PyTorch using ResNet-18: building a shared backbone with two prediction heads, deciding which layers to freeze, weighting losses across tasks, and structuring training with cosine annealing and early stopping. The example task — predicting both fruit type (six classes) and freshness (fresh / rotten) from a single image — is small enough to be transparent and realistic enough to be instructive.

Why not just train two separate models?

The most direct approach to a two-output problem is two independent ResNets, one for each output. It works, but it pays for the simplicity:

  • Inference cost doubles. Each model runs the full backbone on every image.
  • Parameter count doubles. Two models means twice the storage and memory footprint.
  • The two backbones extract highly redundant features. Recognizing "this is an apple" and recognizing "this apple is rotten" both rely on color, texture, and shape — there is no reason to learn those features twice.

Multi-task learning addresses all three. A shared backbone extracts features once. Two small classification heads branch off at the end, one per task. When the tasks are related, the shared backbone benefits from receiving gradient signals from both heads, and each task implicitly regularizes the other.

Shared backbone with two classification heads

import torch.nn as nn

# FRUIT_TO_IDX and FRESHNESS_TO_IDX are the dataset's label-to-index dicts
# (6 and 2 entries); _build_backbone is described in the next snippet.
class MultiTaskFruitModel(nn.Module):
    def __init__(self, backbone_name="resnet18", pretrained=True):
        super().__init__()
        self.backbone, feature_dim = self._build_backbone(backbone_name, pretrained)
        self.fruit_head = nn.Linear(feature_dim, len(FRUIT_TO_IDX))         # 6 classes
        self.freshness_head = nn.Linear(feature_dim, len(FRESHNESS_TO_IDX)) # 2 classes

    def forward(self, x):
        features = self.backbone(x)                       # (N, feature_dim) pooled features
        fruit_logits = self.fruit_head(features)          # (N, 6)
        freshness_logits = self.freshness_head(features)  # (N, 2)
        return fruit_logits, freshness_logits

The crucial trick is in _build_backbone:

feature_dim = backbone.fc.in_features   # 512 for ResNet-18
backbone.fc = nn.Identity()             # discard the original 1000-class head

ResNet-18's last layer is Linear(512 → 1000), mapping to ImageNet's 1000 classes. Replacing it with nn.Identity() turns the backbone into a 512-dimensional feature extractor. Two new linear heads consume that 512-dimensional vector independently.

The data flow at inference time:

Image (224x224x3)
    │
    ▼
ResNet backbone  →  512-dim feature vector
                         │
        ┌────────────────┴────────────────┐
        ▼                                 ▼
    fruit_head (512→6)            freshness_head (512→2)
        │                                 │
        ▼                                 ▼
    Six fruit logits             Fresh/rotten logits

One backbone pass produces both predictions. The structure also creates a useful side effect: when the freshness task discovers that brown patches indicate rot, the gradient pushes the backbone to extract patch features more reliably. Better patch features benefit fruit classification too, because color and texture are class-discriminative there as well. This implicit regularization is a real advantage when the two tasks are correlated.

Multi-task loss: combining gradients in one backward pass

Each task contributes its own cross-entropy loss. The two losses are combined into a scalar before backpropagation:

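criterion_fruit = nn.CrossEntropyLoss()
criterion_fresh = nn.CrossEntropyLoss()

fruit_logits, fresh_logits = model(images)        # one backbone pass, two sets of logits
loss_fruit = criterion_fruit(fruit_logits, fruit_labels)
loss_fresh = criterion_fresh(fresh_logits, fresh_labels)

loss = 2.0 * loss_fruit + 1.0 * loss_fresh        # weighted sum -> a single scalar
loss.backward()                                   # gradients from both heads reach the backbone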

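The weights (2.0 for fruit classification, 1.0 for freshness) are a deliberate heuristic, not arbitrary. Fruit classification has six classes and is the harder task; freshness has two classes and is easier. Scaling a loss by a weight scales its gradient by the same factor. The 2:1 ratio therefore tilts the shared backbone's updates toward the features fruit classification needs, balancing the gradient pressure between the heads. The ratio itself is a hyperparameter and is worth tuning on the validation set.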

loss.backward() is called once. PyTorch differentiates the combined loss with respect to all parameters automatically. The gradient signals from both heads flow back into the shared backbone and add element-wise, each scaled by its loss weight, so a single optimizer step updates everything.
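The claim that per-head gradients simply add in the shared backbone can be checked directly. This sketch uses toy stand-ins with arbitrary sizes and random data: a Linear "backbone" and two Linear heads. It compares the backbone gradient of the weighted loss with the weighted sum of the per-task gradients, using torch.autograd.grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (sizes are arbitrary): a shared "backbone" and two task heads.
backbone = nn.Linear(8, 4)
fruit_head = nn.Linear(4, 6)       # 6 fruit classes
freshness_head = nn.Linear(4, 2)   # fresh / rotten
ce = nn.CrossEntropyLoss()

x = torch.randn(5, 8)
fruit_y = torch.randint(0, 6, (5,))
fresh_y = torch.randint(0, 2, (5,))

features = backbone(x)
loss_fruit = ce(fruit_head(features), fruit_y)
loss_fresh = ce(freshness_head(features), fresh_y)

# Gradient of each task loss alone w.r.t. the shared weight (graph kept for reuse).
(g_fruit,) = torch.autograd.grad(loss_fruit, backbone.weight, retain_graph=True)
(g_fresh,) = torch.autograd.grad(loss_fresh, backbone.weight, retain_graph=True)

# Gradient of the weighted sum, which is what a single backward pass produces.
(g_combined,) = torch.autograd.grad(2.0 * loss_fruit + 1.0 * loss_fresh, backbone.weight)

print(torch.allclose(g_combined, 2.0 * g_fruit + 1.0 * g_fresh, atol=1e-6))  # True
```

This is just linearity of differentiation. The loss weights act as per-task gradient multipliers on every shared parameter.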

Freezing layers: protecting pretrained features

The single most important decision in transfer learning with a small dataset is which layers to freeze. ResNet-18's features are hierarchical:

  • Shallow layers (conv1, layer1, layer2): edges, color gradients, textures. Broadly reusable across natural-image tasks.
  • Deep layers (layer3, layer4): object parts, semantic concepts. More specific to ImageNet's 1000 classes, so less directly reusable for a new task.

ResNet-18 was trained on 1.2 million ImageNet images. Its shallow layers have seen far more edge and texture variation than a few thousand fruit photos can provide. Letting all layers train from the start risks corrupting those carefully learned features with noisy gradients from a small dataset.

def freeze_backbone_layers(model, freeze_until="layer2"):
    freeze_order = ["conv1", "bn1", "layer1", "layer2", "layer3"]
    freeze_idx = freeze_order.index(freeze_until)
    layers_to_freeze = freeze_order[:freeze_idx + 1]

    for name, param in model.backbone.named_parameters():
        for layer_name in layers_to_freeze:
            if name.startswith(layer_name):
                param.requires_grad = False
                break

Freezing up to layer2 means:

conv1   ← frozen (edge detection, color gradients)
bn1     ← frozen
layer1  ← frozen (low-level texture)
layer2  ← frozen (mid-level shape)
─────────────────────────────────────
layer3  ← trained (high-level fruit-specific features)
layer4  ← trained
fruit_head      ← trained
freshness_head  ← trained

The mechanism behind requires_grad = False is direct: PyTorch skips gradient computation for parameters with this flag set, and the optimizer skips them too. There is no special "freeze" API — it is just a flag on each parameter.
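Setting requires_grad = False does not freeze BatchNorm's running mean and variance. These are buffers, not parameters, and they keep updating on every forward pass while the module is in train mode. Below is a minimal sketch of one common remedy. It assumes ResNet's top-level module names (conv1, bn1, layer1, ...) and the same freeze_until="layer2" split. Both helper names are hypothetical.

```python
import torch.nn as nn

# Top-level ResNet blocks frozen by freeze_until="layer2".
FROZEN_PREFIXES = ("conv1", "bn1", "layer1", "layer2")


def freeze_prefixes(backbone: nn.Module, prefixes=FROZEN_PREFIXES) -> None:
    """Turn off gradients for every parameter under the given top-level modules."""
    for name, param in backbone.named_parameters():
        if name.split(".")[0] in prefixes:   # "layer1.0.conv1.weight" -> "layer1"
            param.requires_grad = False


def freeze_bn_stats(backbone: nn.Module, prefixes=FROZEN_PREFIXES) -> None:
    """Put BatchNorm layers inside frozen blocks into eval mode.

    model.train() switches every submodule back to train mode, so this must be
    called again after each model.train() (e.g. at the start of every epoch).
    """
    for name, module in backbone.named_modules():
        if name.split(".")[0] in prefixes and isinstance(module, nn.BatchNorm2d):
            module.eval()   # normalize with, and stop updating, running_mean / running_var
```

Applied to the article's model, this would be freeze_prefixes(model.backbone) once, and freeze_bn_stats(model.backbone) right after each model.train() call. Whether to keep ImageNet's BatchNorm statistics fixed is itself a design choice worth validating.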

Freezing protects what a small dataset cannot replace: the pretrained features. The deeper the freeze, the more pretrained capacity is preserved, at the cost of fewer trainable parameters for adapting to the new task. layer2 is a reasonable default for small datasets in domains close to ImageNet (natural photographs, common objects). For domains far from ImageNet (medical imaging, satellite imagery), shallower freezing or full fine-tuning often does better, but those settings need more data to avoid overfitting.

Pass only trainable parameters to the optimizer

import torch.optim as optim

trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

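Filtering is mostly hygiene rather than a correctness requirement. PyTorch's built-in optimizers, AdamW included, skip any parameter whose .grad is None. They also create optimizer state (the first- and second-moment estimates) lazily, the first time a parameter actually has a gradient during step(). Frozen parameters that slip through are therefore neither updated nor allocated state. Passing only the trainable set still makes the intent explicit in code and keeps the optimizer's parameter groups limited to what actually trains.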
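Whether frozen parameters cost the optimizer anything can be checked directly. This toy sketch uses a two-layer nn.Sequential, not the article's model. It freezes the first layer and deliberately passes all parameters to AdamW. After one step, only the trainable layer has optimizer state and the frozen weights are unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].requires_grad_(False)            # freeze the first layer's weight and bias

# Deliberately pass *every* parameter, frozen ones included.
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
frozen_before = model[0].weight.detach().clone()

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()                           # frozen params keep .grad == None
opt.step()

print(len(opt.state))                                  # 2 (trainable weight and bias only)
print(torch.equal(model[0].weight, frozen_before))     # True
print(model[0].weight.grad is None)                    # True
```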

AdamW is the modern default for fine-tuning. Its weight_decay is decoupled from the gradient-based update. Each step shrinks the weights directly, $\theta \leftarrow \theta - \eta\lambda\theta$, instead of adding $\lambda\theta$ to the gradient as plain Adam's weight_decay does. In plain Adam, the per-parameter adaptive scaling would then distort the penalty. For small-dataset fine-tuning, a learning rate of $10^{-4}$ is conservative and stable; aggressive learning rates risk disrupting the pretrained features the freezing was meant to protect.
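To see what "decoupled" means numerically, the sketch below gives a single scalar weight a zero task gradient. It then takes one step with each optimizer at lr = 0.1 and weight_decay = 0.1, values chosen only to make the arithmetic visible. AdamW shrinks the weight by exactly lr × weight_decay × θ. Adam folds the decay into the gradient, and its adaptive normalization turns that small penalty into a step of roughly lr.

```python
import torch


def one_step(opt_cls, lr=0.1, weight_decay=0.1):
    """Take one optimizer step on a scalar weight (initially 1.0) whose loss gradient is zero."""
    w = torch.nn.Parameter(torch.tensor([1.0]))
    opt = opt_cls([w], lr=lr, weight_decay=weight_decay)
    w.grad = torch.zeros_like(w)     # no signal from the loss: only weight decay acts
    opt.step()
    return w.item()


print(f"AdamW: {one_step(torch.optim.AdamW):.4f}")   # 0.9900 = 1 - lr * weight_decay
print(f"Adam:  {one_step(torch.optim.Adam):.4f}")    # 0.9000, about one full lr step
```

Real fine-tuning uses much smaller values (lr=1e-4, weight_decay=1e-4 in the optimizer above). The structural difference is the same, though: AdamW's decay is proportional to the weight and independent of the gradient history.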

Learning-rate schedule and early stopping

Two scheduling decisions matter for small-dataset transfer learning. The first is the learning-rate schedule. Cosine annealing from $10^{-4}$ down to $10^{-6}$ over the planned training run gives the optimizer large steps early to find a good basin and small steps late for precise convergence:

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
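When only the scheduler changes the learning rate, CosineAnnealingLR follows the closed form $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi t / T_{\max}\right)\right)$. Below is a plain-Python sketch of that curve. The function name cosine_lr is hypothetical. The sketch assumes one scheduler.step() per epoch, as in this article's loop, with T_max = 40 chosen only for illustration.

```python
import math


def cosine_lr(epoch: int, t_max: int, lr_max: float = 1e-4, lr_min: float = 1e-6) -> float:
    """Learning rate after `epoch` scheduler steps under cosine annealing (closed form)."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 10, 20, 30, 40):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch, t_max=40):.2e}")
```

The rate starts at 1e-4, passes the midpoint 5.05e-5 at epoch 20, and ends at 1e-6. It changes slowly near both ends and fastest in the middle, so the early epochs stay close to the full learning rate and the late epochs stay close to the floor.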

The second is early stopping. With limited data, training too long produces a model that fits the training set almost perfectly but generalizes poorly. Track the average validation accuracy across both tasks and stop once it has not improved for a fixed number of epochs (the patience):

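import torch

def evaluate(model, val_loader, device="cpu"):
    """Per-task validation accuracy; `device` must match where the model lives."""
    fruit_correct, fresh_correct, total = 0, 0, 0
    model.eval()                      # BatchNorm uses running stats; dropout is off
    with torch.no_grad():             # no autograd graph needed for evaluation
        for images, fruit_y, fresh_y in val_loader:
            images, fruit_y, fresh_y = images.to(device), fruit_y.to(device), fresh_y.to(device)
            fruit_logits, fresh_logits = model(images)
            fruit_correct += (fruit_logits.argmax(1) == fruit_y).sum().item()
            fresh_correct += (fresh_logits.argmax(1) == fresh_y).sum().item()
            total += images.size(0)
    return {
        "fruit_acc": fruit_correct / total,
        "fresh_acc": fresh_correct / total,
    }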

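import copy

# Training loop with early stopping.
# num_epochs, train_loader, val_loader and train_one_epoch are defined elsewhere.
patience = 8
best_score = float("-inf")   # first epoch always counts as an improvement, so best_state is set
patience_counter = 0
best_state = None

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer)   # should call model.train() itself
    scheduler.step()                                  # one cosine step per epoch (T_max=num_epochs)
    val_metrics = evaluate(model, val_loader)
    score = (val_metrics["fruit_acc"] + val_metrics["fresh_acc"]) / 2.0

    if score > best_score:
        best_score = score
        patience_counter = 0
        best_state = copy.deepcopy(model.state_dict())   # a snapshot, not a live reference
    else:
        patience_counter += 1
        if patience_counter >= patience:
            break

model.load_state_dict(best_state)   # restore the best validation snapshot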

Tracking the average accuracy across tasks ensures both heads are doing reasonable work — a model that aces fruit classification but fails on freshness is not a useful multi-task model. Restoring the best state at the end means the deployed model is the snapshot from the best validation epoch, not the one from the last training step (which may have started overfitting).
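The loop calls a train_one_epoch helper that the article does not define. Below is a minimal sketch of one. It assumes batches shaped (images, fruit_labels, fresh_labels), like the ones evaluate consumes, and the 2.0 / 1.0 loss weights from earlier. The device and weight arguments are hypothetical additions with defaults, so the loop's three-argument call still works.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, train_loader, optimizer, device="cpu",
                    fruit_weight=2.0, fresh_weight=1.0):
    """Hypothetical helper: one pass over train_loader with the weighted multi-task loss.

    Assumes model(images) returns (fruit_logits, freshness_logits).
    Returns the mean combined loss per sample.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()                      # undo evaluate()'s model.eval()
    running_loss, n = 0.0, 0
    for images, fruit_y, fresh_y in train_loader:
        images, fruit_y, fresh_y = images.to(device), fruit_y.to(device), fresh_y.to(device)
        optimizer.zero_grad()          # clear gradients from the previous batch
        fruit_logits, fresh_logits = model(images)
        loss = (fruit_weight * criterion(fruit_logits, fruit_y)
                + fresh_weight * criterion(fresh_logits, fresh_y))
        loss.backward()                # one backward pass through heads and backbone
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        n += images.size(0)
    return running_loss / n
```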

When to fine-tune vs. freeze

Two clean rules cover most practical situations:

Freeze most of the backbone when the target dataset is small (a few hundred to a few thousand examples) and the domain is close to ImageNet. Train only the last stage and the new heads. Convergence is fast, overfitting is suppressed, and the pretrained features do most of the work.

Fine-tune all layers when the target dataset is large enough to support meaningful gradient updates throughout the backbone, or when the domain is genuinely different from ImageNet (medical imaging, satellite imagery, microscopy). Use a smaller learning rate for the pretrained backbone than for the new heads. A common practice is two parameter groups: $10^{-5}$ for the backbone, $10^{-3}$ for the heads.
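Two parameter groups take only a few lines with AdamW. The sketch assumes, as in MultiTaskFruitModel, that all pretrained weights sit under a backbone attribute. Every other trainable parameter is treated as part of the new heads. make_finetune_optimizer is a hypothetical name.

```python
import torch.nn as nn
import torch.optim as optim


def make_finetune_optimizer(model: nn.Module, backbone_lr=1e-5, head_lr=1e-3,
                            weight_decay=1e-4) -> optim.AdamW:
    """AdamW with a small LR for pretrained backbone weights and a larger LR for the heads."""
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue                                  # still-frozen layers stay out
        if name.startswith("backbone."):
            backbone_params.append(param)
        else:
            head_params.append(param)
    return optim.AdamW(
        [{"params": backbone_params, "lr": backbone_lr},
         {"params": head_params, "lr": head_lr}],
        weight_decay=weight_decay,                    # shared default for both groups
    )
```

Each group keeps its own learning rate. A scheduler such as CosineAnnealingLR anneals each group from its own starting value toward the shared eta_min.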

The main takeaway

Transfer learning is not about "using someone else's model." It is about identifying which parts of an existing trained network are reusable for a new task and which parts need to be relearned, then structuring training so the reusable parts are protected while the new parts converge.

For a small-dataset multi-task project, the pattern is reliable. A partially frozen ResNet-18 backbone provides 512-dimensional features cheaply, and two linear heads consume those features for two correlated tasks. A weighted sum of cross-entropy losses lets one backward pass train everything, and cosine annealing plus early stopping handle convergence without overfitting. Each design choice corresponds to a specific risk in the small-data regime: freezing protects pretrained features, multi-task heads share computation, weighted loss balances task difficulty, and early stopping preserves the best validation snapshot. Together they turn ImageNet's pretrained backbone into a working classifier for a task it was never trained on.
