Image Segmentation is the process of partitioning a digital image into multiple segments. It provides a pixel-wise understanding of the visual scene.
1. Semantic Segmentation
Bounding boxes are helpful, but they are clumsy. If a pedestrian is standing next to a car, the boxes overlap. For high-stakes AI like autonomous driving, we need pixel-perfect precision. Welcome to Image Segmentation and the U-Net architecture.
Semantic Segmentation is the process of classifying every single pixel in an image into a category. Instead of outputting a box [x, y, w, h], the neural network outputs a 'mask': a new image where each pixel represents a class label (e.g., Road=1, Car=2, Person=3).
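To make the contrast concrete, here is a minimal sketch in plain Python. The tiny 4x5 mask, the class IDs (including a Background=0 label), and the helper `bounding_box` are all made up for illustration. Treating x, y as the top-left corner of the box is also an assumption, since box conventions vary between libraries.

```python
# Hypothetical class IDs; Road and Car follow the example labels in the text.
BACKGROUND, ROAD, CAR = 0, 1, 2

# A tiny 4x5 segmentation mask: one class ID per pixel.
mask = [
    [0, 0, 0, 0, 0],
    [0, 2, 2, 2, 0],
    [0, 0, 2, 2, 0],
    [1, 1, 1, 1, 1],
]


def bounding_box(mask, class_id):
    """Return [x, y, w, h] of the tightest box around class_id, or None."""
    rows = [r for r, row in enumerate(mask) if class_id in row]
    cols = [c for row in mask for c, value in enumerate(row) if value == class_id]
    if not rows:
        return None
    x, y = min(cols), min(rows)
    return [x, y, max(cols) - x + 1, max(rows) - y + 1]


car_box = bounding_box(mask, CAR)                  # what a detector would report
car_pixels = sum(row.count(CAR) for row in mask)   # what the mask records
box_area = car_box[2] * car_box[3]
print(car_box, car_pixels, box_area)
```

The box covers 6 pixels, but only 5 of them belong to the car. The mask records exactly which ones, and that extra information is what segmentation provides.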
# Segmentation vs Detection
# Detection: Returns [x, y, w, h] for objects
# Segmentation: Returns a matrix of shape [Height, Width]
# where every value is a class ID integer.
2. The U-Net Architecture
One of the most widely used segmentation architectures is the U-Net, originally introduced in 2015 for biomedical image segmentation. It gets its name because its diagram looks like the letter 'U'.
The left side is the 'Encoder' (which compresses the image) and the right side is the 'Decoder' (which expands it back). In standard object detection networks, the network only shrinks the image to extract features. U-Net must also expand the image back up in a 'Decoder' phase because the final output mask must have the exact same Height x Width resolution as the original input image.
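The shrink-then-expand shape can be traced with a few lines of plain Python. This is a sketch under stated assumptions: the helper `trace_unet_shapes` is hypothetical, and the choices of 64 base channels, 4 levels, and "halve the size, double the channels" at each level are common conventions rather than something the text fixes.

```python
def trace_unet_shapes(height, width, base_channels=64, depth=4):
    """List (stage, channels, height, width) through a U-shaped network.

    Assumes each encoder level halves height/width and doubles channels,
    and each decoder level does the reverse.
    """
    if height % (2 ** depth) or width % (2 ** depth):
        raise ValueError("height and width must be divisible by 2**depth")

    shapes = []
    channels, h, w = base_channels, height, width
    # Walk down the left side of the 'U'.
    for level in range(depth):
        shapes.append((f"encoder{level + 1}", channels, h, w))
        channels, h, w = channels * 2, h // 2, w // 2
    shapes.append(("bottleneck", channels, h, w))
    # Walk back up the right side of the 'U'.
    for level in range(depth, 0, -1):
        channels, h, w = channels // 2, h * 2, w * 2
        shapes.append((f"decoder{level}", channels, h, w))
    return shapes


shapes = trace_unet_shapes(256, 256)
for stage, channels, h, w in shapes:
    print(f"{stage:<11} channels={channels:<5} size={h}x{w}")
```

For a 256x256 input the bottleneck sits at 16x16, and the last decoder stage is back at 256x256, which is what lets the output mask line up pixel for pixel with the input.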
import torch.nn as nn
# The U-Net structure
# 1. Encoder (Downsampling path)
# 2. Bottleneck (Deepest features)
# 3. Decoder (Upsampling path)
3. The Encoder (Compressing Space)
Let's look at the Encoder block. It uses normal Convolutions and Max Pooling. Max Pooling cuts the height and width in half.
As we go down the 'U', the feature maps get spatially smaller, but we increase the number of channels (feature depth). The Encoder learns 'WHAT' is in the image (a car, a dog), but because we shrink the feature maps, we lose the precise spatial coordinates of 'WHERE' the object boundaries are. If we just blindly upsampled this back to full size, the edges would be blurry and imprecise.
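The loss of 'WHERE' can be seen on a toy example. The sketch below is illustrative only: it pools a 4x4 image whose edge does not fall on the 2x2 pooling grid, then upsamples with nearest-neighbour interpolation (an assumed stand-in for "blind" upsampling) and checks whether the edge came back.

```python
import torch
import torch.nn.functional as F

# A 4x4 single-channel "image": column 0 is dark, columns 1-3 are bright,
# so there is a sharp vertical edge between column 0 and column 1.
image = torch.zeros(1, 1, 4, 4)
image[:, :, :, 1:] = 1.0

# Downsample as an encoder would, then naively upsample back.
pooled = F.max_pool2d(image, kernel_size=2, stride=2)             # [1, 1, 2, 2]
restored = F.interpolate(pooled, scale_factor=2, mode="nearest")  # [1, 1, 4, 4]

# True only if the round trip reproduced the original image exactly.
edge_survived = torch.equal(restored, image)
print(edge_survived)
```

Every 2x2 window contains at least one bright pixel, so the pooled map is uniformly bright and the restored image no longer has any edge at all. The resolution is back, but the boundary position is gone.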
def encoder_block(in_channels, out_channels):
    return nn.Sequential(
        # 3x3 convolution; padding=1 keeps height and width unchanged
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        # Halves both height and width
        nn.MaxPool2d(kernel_size=2, stride=2)
    )
4. Skip Connections (The Magic Bridge)
This spatial loss is where U-Net performs magic: Skip Connections. Instead of just passing data sequentially, U-Net takes the high-resolution feature maps from the early Encoder stages and copies them directly across to the matching Decoder stages.
We 'concatenate' them together. The primary architectural purpose of these skip connections is to provide the decoder with lost high-resolution spatial details for sharp object boundaries. We combine the semantic depth with spatial clarity.
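The concatenation step can be sketched on dummy tensors. The shapes and the follow-up `fuse` convolution are assumptions chosen for illustration; the only hard requirement is that both tensors share the same height and width.

```python
import torch
import torch.nn as nn

# Hypothetical shapes: [batch, channels, height, width]
encoder_features = torch.randn(1, 64, 32, 32)  # saved early, spatially detailed
upsampled = torch.randn(1, 64, 32, 32)         # decoder output at the same resolution

# Concatenate along the channel dimension (dim=1); height and width must match.
merged = torch.cat([upsampled, encoder_features], dim=1)   # [1, 128, 32, 32]

# A convolution then mixes the two sources back down to 64 channels.
fuse = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
fused = fuse(merged)                                       # [1, 64, 32, 32]
print(merged.shape, fused.shape)
```

Because concatenation doubles the channel count, the first decoder convolution after a skip connection needs twice as many input channels as the upsampled tensor alone would suggest.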
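import torch

class UNet(nn.Module):
    # Layer definitions (__init__) are omitted in this excerpt.
    def forward(self, x):
        # 1. Save high-res encoder output
        enc1_out = self.encoder1(x)
        # ... deeper encoder stages and upsampling produce `upsampled`,
        # which must match enc1_out in height and width ...
        # 2. In decoder, concatenate it along the channel dimension
        dec1_input = torch.cat([upsampled, enc1_out], dim=1)
5. The Decoder (Transposed Convolutions)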
With the skip connections providing the 'Where', the Decoder block performs 'Transposed Convolutions'. These play the opposite role to pooling: rather than an exact inverse, a transposed convolution is a learned upsampling that expands a small feature map into a larger one. With a 2x2 kernel and a stride of 2, it doubles the height and width at each step.
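The doubling follows from the output-size formula given in the PyTorch `ConvTranspose2d` documentation. The helper below is a hypothetical plain-Python restatement of that formula for one spatial axis, offered as a sketch for checking sizes by hand.

```python
def conv_transpose_output_size(size, kernel_size, stride=1, padding=0,
                               output_padding=0, dilation=1):
    """Output length along one spatial axis of a transposed convolution."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# kernel_size=2, stride=2 doubles the size: (16 - 1) * 2 + 1 + 1 = 32
doubled = conv_transpose_output_size(16, kernel_size=2, stride=2)

# Another common doubling configuration: kernel_size=4, stride=2, padding=1
also_doubled = conv_transpose_output_size(16, kernel_size=4, stride=2, padding=1)
print(doubled, also_doubled)
```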
At the very end of the network, the final output layer maps the channels down to the exact number of classes you are trying to predict. If you are predicting 'Background', 'Car', and 'Road', the final output channel depth is 3. If you only want to classify 'Healthy' or 'Tumor', the depth is 2.
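As a sketch of that final layer: a 1x1 convolution is a common choice for the output head, though the text does not say which layer is used, so treat it as an assumption. The 64-channel decoder output and the variable names are likewise hypothetical.

```python
import torch
import torch.nn as nn

decoder_features = torch.randn(1, 64, 32, 32)  # hypothetical final decoder output

# A 1x1 convolution changes only the channel count: one channel per class.
three_class_head = nn.Conv2d(64, 3, kernel_size=1)  # Background, Car, Road
two_class_head = nn.Conv2d(64, 2, kernel_size=1)    # Healthy, Tumor

scores = three_class_head(decoder_features)  # [1, 3, 32, 32]: per-pixel class scores
mask = scores.argmax(dim=1)                  # [1, 32, 32]: winning class ID per pixel
print(scores.shape, mask.shape)
```

Taking the argmax over the channel dimension turns the per-class scores into the [Height, Width] matrix of class IDs described earlier.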
def decoder_block(in_channels, out_channels):
    return nn.Sequential(
        # Expands spatial dimensions
        nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
        nn.ReLU()
    )
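Putting the pieces together, the following is a minimal two-level sketch, not a reference implementation. `TinyUNet`, `conv_block`, and the channel sizes (16/32/64) are hypothetical. Unlike the `encoder_block` above, pooling is applied outside the convolution block here, so that each skip connection carries features from before the downsampling step.

```python
import torch
import torch.nn as nn


def conv_block(in_channels, out_channels):
    # 3x3 convolution that keeps height and width unchanged (padding=1)
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net, for illustration only."""

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        self.enc1 = conv_block(in_channels, 16)
        self.enc2 = conv_block(16, 32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = conv_block(32, 64)
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec2 = conv_block(64, 32)  # 32 upsampled + 32 from the skip
        self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec1 = conv_block(32, 16)  # 16 upsampled + 16 from the skip
        self.head = nn.Conv2d(16, num_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                    # [N, 16, H, W]
        e2 = self.enc2(self.pool(e1))        # [N, 32, H/2, W/2]
        b = self.bottleneck(self.pool(e2))   # [N, 64, H/4, W/4]
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))   # [N, 32, H/2, W/2]
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # [N, 16, H, W]
        return self.head(d1)                 # [N, num_classes, H, W]


model = TinyUNet(num_classes=3)
image = torch.randn(1, 3, 64, 64)  # height and width must be divisible by 4
logits = model(image)
mask = logits.argmax(dim=1)        # one class ID per pixel
print(logits.shape, mask.shape)
```

The output has one score channel per class at the input's full resolution, and the argmax collapses it into a mask with the same height and width as the image.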