
Batch Normalization, ResNet, and DenseNet: Making Very Deep CNNs Trainable

How three architectural ideas — normalization, residual learning, and dense connectivity — fixed the gradient and feature-flow problems that capped CNN depth.

Batch Normalization · ResNet · DenseNet · CNN Architecture · PyTorch

Three architectural ideas — batch normalization, residual learning, and dense connectivity — each solved a specific failure mode of deep convolutional networks. Together they shifted the practical ceiling on CNN depth from around 20 layers to over 150, and they remain the structural foundation of nearly every modern image model.

This article covers all three in sequence: what problem each one solves, how it solves it mathematically, and how to implement it in PyTorch from scratch. The architecture progression is not accidental — each idea sets up the next.

Part 1: The degradation problem

Adding layers to a convolutional network should, in principle, make it more expressive. In practice, the early 2010s produced a consistent and puzzling result: beyond a modest depth, accuracy on the training set stopped improving. This was not overfitting — the training error itself plateaued or rose. Researchers found that a 56-layer network regularly underperformed a 20-layer one on the same task.

The degradation problem: as depth increases past a threshold, training accuracy degrades — not because the network overfits, but because optimization becomes harder. A deeper network that should, in theory, be able to represent anything the shallower network can represent, fails to learn even the identity mapping through its additional layers.

The identity mapping observation is the key intuition. If you take a network that converges at n layers and add k more layers, the optimal solution for those added layers would be to do nothing — pass the input through unchanged. This sounds trivial, but gradient-based optimization has difficulty learning this. The added layers instead learn slightly incorrect mappings, accumulating small errors that compound through the stack and degrade the overall function.

What residual learning changes

The insight behind ResNet is a reformulation of what each layer should be asked to learn. Instead of asking a stack of layers to learn a mapping H(x) from scratch, force the stack to learn a residual:

\mathcal{F}(x) = \mathcal{H}(x) - x \quad \Longrightarrow \quad \mathcal{H}(x) = \mathcal{F}(x) + x

The shortcut connection adds the original input x directly to the output of the residual block. If the optimal behavior for a block is the identity, the network only needs to drive F(x) to zero — a far simpler target than learning the full identity mapping through many weight layers. Empirically, residual functions are much easier for gradient descent to learn than unreferenced mappings.

The shortcut connection also provides a direct path for gradients. In a plain deep network, gradients must flow through every weight layer in sequence; repeated matrix multiplications can attenuate or amplify them unpredictably. The shortcut bypasses those layers entirely, delivering gradient signal back to early layers without passing through the weights. This is distinct from vanishing gradient fixes like LSTM gates — here the bypass is structural, not gated.
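
A toy autograd check makes the structural point concrete. This is not the ResNet block itself, just a three-element tensor standing in for an activation, but it shows that even when the residual branch contributes essentially nothing, the shortcut still delivers a full-strength gradient to the input:

import torch

x = torch.ones(3, requires_grad=True)
residual = 0.0 * x.pow(2)    # a residual branch driven to zero
h = residual + x             # shortcut addition: H(x) = F(x) + x
h.sum().backward()
print(x.grad)                # tensor([1., 1., 1.]): the identity path carries the gradient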

Part 2: Batch normalization

ResNet needs deep networks to train. Deep networks need stable activation distributions to converge. Batch normalization is the mechanism that provides that stability, and it appears throughout ResNet as a first-class component rather than an afterthought.

The shifting distribution problem

A deep network processes data through a sequence of layers. Each layer receives the output of the one before it. During training, every weight update modifies what an earlier layer sends downstream — so the input distribution to every subsequent layer changes after each parameter update. Downstream layers chase a moving target rather than fitting a stable function.

This effect compounds with depth. A small change in an early layer propagates through all later layers, and the deeper the network, the more pronounced the instability. The practical symptoms: activations saturate in the flat regions of sigmoid or tanh where gradients vanish, gradient magnitudes become inconsistent across layers, and training grows sensitive to the learning rate in ways that are hard to predict.

The two-step mechanism

Batch normalization addresses the distribution shift directly. Given a mini-batch of m activation values, it computes batch statistics and normalizes each value to approximately zero mean and unit variance:

\mu_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} x_i \qquad \sigma^2_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} \bigl(x_i - \mu_\mathcal{B}\bigr)^2 \qquad \hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \varepsilon}}

Fixing every layer's output to zero mean and unit variance would constrain the model unnecessarily. The second step recovers expressivity with a learned affine transform:

y_i = \gamma \hat{x}_i + \beta

γ and β are learned per channel. If the normalized representation is already optimal, the network sets γ = 1, β = 0. If it needs a different scale or shift, it learns those. The normalization step removes distributional chaos; the affine step hands full expressive control back to the network.
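
The sketch below makes the two steps concrete by reproducing BatchNorm2d's training-mode output by hand; the 8×16×32×32 activation tensor is arbitrary:

import torch
import torch.nn as nn

x = torch.randn(8, 16, 32, 32)            # [batch, channels, height, width]

bn = nn.BatchNorm2d(16)
bn.train()
reference = bn(x)

# Step 1: normalize with per-channel batch statistics over (N, H, W)
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_hat = (x - mu) / torch.sqrt(var + bn.eps)

# Step 2: learned affine transform (γ is initialized to 1, β to 0)
gamma = bn.weight.view(1, -1, 1, 1)
beta = bn.bias.view(1, -1, 1, 1)
manual = gamma * x_hat + beta

print(torch.allclose(reference, manual, atol=1e-5))   # True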

Placement, bias elimination, and inference mode

The standard placement in a ResNet block is after the convolution, before the activation:

Conv2d(bias=False)  →  BatchNorm2d  →  ReLU

When batch normalization follows a convolutional layer, any learned bias term is immediately subtracted by the normalization step. Setting bias=False on the preceding convolution is therefore standard — it eliminates a redundant parameter without any loss of expressive power.

At inference time, there is no mini-batch to compute statistics from. PyTorch tracks a running exponential moving average of batch statistics throughout training:

\mu_{\text{running}} \leftarrow (1 - \alpha)\,\mu_{\text{running}} + \alpha\,\mu_\mathcal{B}

Here α is the momentum argument of BatchNorm2d, 0.1 by default. After training, running_mean and running_var approximate population statistics. The model must be explicitly switched between modes:

model.train()   # uses current batch statistics
model.eval()    # uses running_mean / running_var

Forgetting model.eval() before inference is one of the most common silent bugs in PyTorch. The output will vary with batch composition rather than reflecting the learned function — a particularly insidious class of bug because the model runs without error.
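
A small check of the distinction (the tensor shapes here are arbitrary): after switching to eval mode, the output for a fixed input no longer depends on what else is in the batch.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
sample = torch.randn(1, 4, 8, 8)

bn.train()
for _ in range(100):                       # let the running statistics accumulate
    bn(torch.randn(32, 4, 8, 8))

bn.eval()
alone    = bn(sample)
in_batch = bn(torch.cat([sample, torch.randn(7, 4, 8, 8)]))[:1]
print(torch.allclose(alone, in_batch))     # True: eval mode ignores batch composition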

Part 3: ResNet and residual learning

ResNet-18 is the smallest member of the ResNet family, but it contains every structural idea that appears in ResNet-50, ResNet-101, and beyond. The only differences between variants are the block type and the number of blocks per stage. Building ResNet-18 from scratch makes the architecture fully transparent.

The BasicBlock: residual addition in practice

The fundamental unit of ResNet-18 is the BasicBlock: two 3×3 convolutions with a shortcut connection that adds the block's input to its output.

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_channels)
        self.relu  = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_channels)

        # The shortcut must match the output shape exactly.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = out + identity          # residual addition
        return self.relu(out)

The shortcut path handles two shape-mismatch situations. When stride=2, the convolution halves the spatial dimensions, so the input x has twice the spatial resolution of the output. When the channel count changes (for example, 64 → 128), the depths do not align. In both cases, a 1×1 convolution projects x into the correct shape. When neither condition holds, nn.Identity() passes x through unchanged — zero-cost at both training and inference time.

The second convolution always uses stride=1. Downsampling happens at most once per stage: in the first block, and only when the stage is built with stride=2 (stages 2–4; stage 1 keeps the resolution set by the stem's max pool). Subsequent blocks within the stage perform feature extraction at fixed spatial resolution.
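
A quick shape check of the two shortcut cases, using the BasicBlock defined above (the input sizes match stage 1 of ResNet-18):

same = BasicBlock(64, 64, stride=1)(torch.randn(1, 64, 56, 56))
down = BasicBlock(64, 128, stride=2)(torch.randn(1, 64, 56, 56))
print(same.shape)   # torch.Size([1, 64, 56, 56])   (nn.Identity() shortcut)
print(down.shape)   # torch.Size([1, 128, 28, 28])  (1×1 projection shortcut)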

The ResNet backbone: stages and spatial flow

The full network is organized as a stem followed by four stages, each stage containing one or more BasicBlocks:

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_channels = 64

        # Stem: 7×7 conv reduces 224×224 → 112×112, then max pool → 56×56
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.layer1 = self._make_layer(block,  64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        layers = [block(self.in_channels, out_channels, stride=stride)]
        self.in_channels = out_channels * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)      # [B,  64, 56, 56]
        x = self.layer1(x)    # [B,  64, 56, 56]
        x = self.layer2(x)    # [B, 128, 28, 28]
        x = self.layer3(x)    # [B, 256, 14, 14]
        x = self.layer4(x)    # [B, 512,  7,  7]
        x = self.avgpool(x)   # [B, 512,  1,  1]
        x = torch.flatten(x, 1)
        return self.fc(x)

def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

The _make_layer method accepts the block class itself rather than an instance. This lets the same ResNet constructor work with both BasicBlock and Bottleneck by changing only one argument at the call site. The expansion class attribute on the block type tells the final linear layer how many channels to expect: BasicBlock.expansion = 1, Bottleneck.expansion = 4.

The name ResNet-18 comes from counting the learnable layers: 8 blocks × 2 convolutions per block = 16 convolutional layers, plus the 7×7 stem convolution and the final fully connected layer, for 18 total. ResNet-50 and ResNet-101 replace BasicBlock with a three-layer Bottleneck block (1×1 → 3×3 → 1×1) and use deeper stage configurations.
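
A brief sanity check of the implementation above, covering output shape and total parameter count (which should land near the 11.7 million of torchvision's resnet18):

model = resnet18(num_classes=1000)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)                                      # torch.Size([2, 1000])
print(sum(p.numel() for p in model.parameters()) / 1e6)    # ≈ 11.7 (million parameters)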

ResNet as a transfer learning backbone

ResNet-18 pretrained on ImageNet converges within a few epochs on small datasets that would take far longer to train from scratch. The standard approach removes the final linear layer and replaces it with a new head sized for the target task:

import torch
import torch.nn as nn
import torchvision

model = torchvision.models.resnet18(weights="IMAGENET1K_V1")

# Replace the 1000-class head with a 2-class head
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Option A: fine-tune all parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Option B: freeze all layers except the new head (faster, less data needed)
for name, param in model.named_parameters():
    if "fc" not in name:
        param.requires_grad = False
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
)

Option B is often sufficient when the target domain is close to ImageNet (natural photographs, common objects). Option A makes sense when the domain diverges (medical imaging, satellite imagery) or when enough training data is available to meaningfully update the early convolutional layers without overfitting.

Part 4: DenseNet and dense connectivity

ResNet solved the depth problem by adding shortcuts. DenseNet asked a different question: what if every layer could see every previous layer's output directly? The answer changes the fundamental structure of feature flow through the network — and reduces the parameter count substantially in the process.

Why addition dilutes information

ResNet's shortcut computes H(x) = F(x) + x. Once the addition happens, the original x and the transformed F(x) are merged into one tensor. Downstream layers cannot distinguish which features came from which source. Early-layer features — edges, textures, color gradients — survive multiple additions in diluted form, but they cannot be accessed cleanly by a deep layer that wants to reuse them directly.

Addition also imposes a channel-count constraint: F(x) must have exactly as many channels as x. This forces each layer in a ResNet to maintain a large channel count (256 or 512 in deeper stages) so that the identity path and the residual path remain compatible, even in layers that do not need that many channels for their actual computation.

DenseNet replaces addition with concatenation. Instead of F(x) + x, each layer produces:

x_\ell = H_\ell\bigl([x_0,\; x_1,\; \ldots,\; x_{\ell-1}]\bigr)

Every layer receives the concatenation of all preceding layers' feature maps as input, and contributes its output to all subsequent layers. No information is lost to mixing. Early features remain available in their original form at every depth.
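
The difference is easy to see on toy tensors (shapes chosen arbitrarily): addition merges the two sources into a single 64-channel tensor, while concatenation keeps the original x intact alongside the new features.

x  = torch.randn(1, 64, 28, 28)
fx = torch.randn(1, 64, 28, 28)

resnet_style   = fx + x                      # [1, 64, 28, 28]:  sources are mixed
densenet_style = torch.cat([x, fx], dim=1)   # [1, 128, 28, 28]: x survives unchanged
print(resnet_style.shape, densenet_style.shape)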

Growth rate and Dense Blocks

Because every layer concatenates all previous outputs, the number of input channels grows with depth. DenseNet controls this with the growth rate k: each layer contributes exactly k new channels to the feature stack. k is typically small — 12 or 32. The reasoning: since all previous features are preserved by concatenation and accessible directly, each layer does not need to maintain a large channel count to retain information. It only needs to contribute a small number of genuinely new features.

A Dense Block with ℓ layers and initial input width c₀ produces a final output of width c₀ + ℓ·k. For a block with 12 layers and k = 32, starting from 64 channels, the output is 448 channels — achieved with far fewer parameters than a ResNet stage at comparable depth.

PyTorch implementation

The core layer is a bottleneck design that first compresses the accumulated input channels, then extracts exactly kk new features, and concatenates the result:

class DenseLayer(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super().__init__()
        # 1×1: compress accumulated channels to 4k
        self.bottleneck = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False),
        )
        # 3×3: extract k new features from the compressed representation
        self.conv = nn.Sequential(
            nn.BatchNorm2d(4 * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x):
        out = self.conv(self.bottleneck(x))
        return torch.cat([x, out], dim=1)   # ← concatenation, not addition

The pre-activation order (BN → ReLU → Conv) is intentional. Each DenseLayer receives input that is the concatenation of feature maps from layers trained at different times with different distributions. Normalizing before the convolution brings this heterogeneous input to a common scale, giving the convolutional kernel a consistent signal to learn from.

A DenseBlock chains multiple DenseLayers, tracking the growing channel count:

class DenseBlock(nn.Module):
    def __init__(self, in_channels, num_layers, growth_rate):
        super().__init__()
        layers = []
        for i in range(num_layers):
            # Layer i sees: initial channels + i layers × k channels each
            layers.append(DenseLayer(in_channels + i * growth_rate, growth_rate))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)   # each call widens x by k channels
        return x
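
A quick check of the growth-rate arithmetic from earlier, using the DenseBlock above: 12 layers with k = 32 starting from 64 channels should end at 448 channels.

block = DenseBlock(in_channels=64, num_layers=12, growth_rate=32)
out = block(torch.randn(1, 64, 56, 56))
print(out.shape)   # torch.Size([1, 448, 56, 56])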

Between Dense Blocks, a TransitionLayer reduces both channel count and spatial resolution, preventing the feature stack from growing unbounded across blocks:

class TransitionLayer(nn.Module):
    def __init__(self, in_channels, compression=0.5):
        super().__init__()
        out_channels = int(in_channels * compression)   # θ=0.5 by default
        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.transition(x)

The full DenseNet-121 (the smallest standard configuration from the original paper) assembles stem, four Dense Blocks, three Transition Layers, and a classifier:

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_layers=(6, 12, 24, 16),
                 init_channels=64, compression=0.5, num_classes=1000):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, init_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        features = nn.Sequential()
        in_channels = init_channels

        for i, num_layers in enumerate(block_layers):
            features.add_module(f'DenseBlock_{i+1}',
                                DenseBlock(in_channels, num_layers, growth_rate))
            in_channels = in_channels + num_layers * growth_rate

            if i != len(block_layers) - 1:
                features.add_module(f'Transition_{i+1}',
                                    TransitionLayer(in_channels, compression))
                in_channels = int(in_channels * compression)

        features.add_module('BN_final', nn.BatchNorm2d(in_channels))
        features.add_module('ReLU_final', nn.ReLU(inplace=True))

        self.features = features
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.features(x)
        x = x.mean(dim=[2, 3])   # global average pooling
        return self.classifier(x)

Following the channel arithmetic: DenseNet-121 starts with 64 channels from the stem. After the four Dense Blocks and three Transition Layers (each compressing by 0.5), the final feature dimension before the classifier is 1024. The total parameter count is approximately 8 million — compared to 11.7 million for ResNet-18 and 25.6 million for ResNet-50, while achieving substantially higher ImageNet accuracy than ResNet-18 and approaching ResNet-50 at roughly a third of its parameter count.
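
The same numbers can be read directly off the implementation; a brief sanity check using the DenseNet class defined above:

model = DenseNet(growth_rate=32, block_layers=(6, 12, 24, 16))
print(model.classifier.in_features)                        # 1024
print(sum(p.numel() for p in model.parameters()) / 1e6)    # ≈ 8.0 (million parameters)
print(model(torch.randn(1, 3, 224, 224)).shape)            # torch.Size([1, 1000])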

ResNet vs. DenseNet: practical considerations

ResNet is the more commonly deployed option because its shortcut addition is memory-efficient: the sum has the same shape as its inputs, so activation size stays constant as a stage deepens. DenseNet's concatenation keeps all previous feature maps within a block alive simultaneously, which raises peak memory consumption during training as the block gets deeper.

DenseNet tends to outperform ResNet of comparable parameter count on image classification benchmarks, particularly when training data is limited. The direct feature reuse provides a form of implicit ensemble diversity — each layer can select from a rich palette of features at multiple scales rather than being limited to what the immediately preceding layer provides.

For deployment on memory-constrained hardware, ResNet is usually the better choice. For tasks where model accuracy matters more than inference-time memory budget — or where the dataset is small enough that feature reuse reduces overfitting — DenseNet is worth serious consideration.

The main takeaway

Each of the three architectures in this article represents a different answer to the same underlying question: how do you move information — both feature representations and gradient signals — efficiently through a very deep network?

Batch normalization stabilizes the optimization landscape by keeping activation distributions well-behaved throughout training. It does not change what the network can represent; it makes that representation reachable by gradient descent. The two practical rules that follow directly from understanding it: bias=False before a batch norm layer, and model.eval() before inference.

ResNet's shortcut connection reformulates the learning problem. By asking each block to learn a residual F(x) = H(x) − x rather than the full mapping H(x), it makes the identity solution trivially learnable and provides a direct gradient path that bypasses the weight layers entirely. The degradation problem disappears.

DenseNet changes the connectivity topology. By concatenating — not adding — all previous feature maps as input to each layer, it eliminates information loss through mixing, enables direct feature reuse across depths, and allows each layer to maintain a narrow channel count while still drawing on the full feature history. The cost is peak memory consumption during training; the benefit is higher accuracy per parameter.

For most standard tasks, ResNet-18 or ResNet-50 with ImageNet pretraining is the practical starting point. DenseNet deserves consideration when training data is limited, when the task benefits from multi-scale feature access, or when parameter count is constrained more than memory. Both architectures remain in active use in modern pipelines, often as backbone components within larger systems.
