Transfer Learning is the practice of taking a model trained on one task and repurposing it for another. In computer vision it is the default starting point whenever labeled data or compute is limited.
1. The Ultimate Shortcut
Training a deep CNN from scratch typically requires a large labeled dataset and days of GPU time. Why do that when you can borrow the brain of a model that already knows how to see?
This is the core philosophy of Transfer Learning. First, we load a model such as ResNet or VGG that was pre-trained on ImageNet. It has already learned to detect edges, textures, shapes, and object parts, so it acts as a powerful feature extractor right out of the box.
import torchvision.models as models
import torch.nn as nn

# Load ResNet18 with ImageNet weights (torchvision >= 0.13; older releases used pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
2. Preserving Knowledge (Freezing)
We don't want to destroy the pre-trained weights during training. If we let gradients flow all the way back through the entire network, updates driven by our small dataset can overwrite the carefully learned ImageNet features and cause the model to overfit.
To prevent this, we 'freeze' the base layers by setting each parameter's requires_grad attribute to False. Autograd then computes no gradients for those weights and the optimizer leaves them untouched, so the model keeps its foundational vision capabilities while backpropagation becomes much cheaper.
# Freeze all parameters in the base model
for param in model.parameters():
    param.requires_grad = False
3. Replacing the Head
Now we replace the final classification layer. The ImageNet head outputs scores for 1,000 classes, but if we only need 2 (for example, a simple cat vs. dog classifier), we swap out this 'head' of the model.
We grab the number of input features going into the final layer, and then overwrite that layer with a brand new, randomly initialized Linear layer mapped to our specific number of output classes.
num_ftrs = model.fc.in_features
# Replace last layer with a new linear layer
model.fc = nn.Linear(num_ftrs, 2)
# New layer has requires_grad=True by default
4. Targeted Fine-Tuning
By training only this new layer, we leverage the 'vision' of the original model while adapting it to our specific task with very little data.
The optimizer will only update the weights of our new classification head because the rest of the model is frozen; passing only model.fc.parameters() to the optimizer makes this explicit. Once the head is stable, we can optionally unfreeze a few of the top base layers and fine-tune them at a lower learning rate, but for small datasets that resemble ImageNet, training the new head alone is often enough for strong results.
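# Model is ready for fine-tuning: list the parameters that will receive gradients
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight', 'fc.bias']
# Only these fc weights are updated by the optimizer; BatchNorm running
# statistics in the frozen base still change whenever the model is in train mode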