MathIsimple

Why Deep Networks Fail to Train: Vanishing Gradients and Initialization

A deep multilayer perceptron can fail long before the optimizer gets a fair chance. The code runs, the data loads, and the loss either stays flat or turns into NaN. In many of those cases, the real problem is not the optimizer at all. It is the scale of the signals you created at initialization.

The central issue is simple: in a deep network, both activations and gradients are pushed through a long sequence of matrix multiplications and nonlinearities. If each layer amplifies the signal a little, the whole system can blow up. If each layer shrinks it a little, the signal can disappear before it reaches the early layers.

Why depth makes gradient problems unavoidable

Consider a depth-$L$ network. A first-layer parameter only affects the loss through every layer that comes after it, so the gradient has the chain-rule form

$$\frac{\partial \mathcal{L}}{\partial W_1} = \frac{\partial \mathcal{L}}{\partial z_L} \frac{\partial z_L}{\partial z_{L-1}} \cdots \frac{\partial z_2}{\partial z_1} \frac{\partial z_1}{\partial W_1}$$

The important qualitative fact is not the exact notation. It is that the gradient depends on a long product of local sensitivity terms. Each term acts like a multiplier. When those multipliers are consistently below one, gradients shrink exponentially. When they are consistently above one, gradients grow exponentially.

This is why a shallow network may train without drama while a five-layer or ten-layer network with almost identical code becomes unstable. Depth turns a small scale mismatch into a compounding effect.
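
To make the compounding concrete, here is a minimal sketch that treats every layer as the same scalar multiplier. That is a deliberate simplification, since real layers contribute matrices rather than scalars, and `compounded_scale` is a hypothetical helper name used only for illustration.

```python
def compounded_scale(per_layer_multiplier: float, depth: int) -> float:
    """Scale of a signal after `depth` identical scalar multipliers (toy model)."""
    return per_layer_multiplier ** depth


# Compare a slightly shrinking and a slightly growing multiplier at several depths.
for depth in (2, 10, 50):
    shrink = compounded_scale(0.9, depth)
    grow = compounded_scale(1.1, depth)
    print(f"depth={depth:2d}  0.9^depth={shrink:.4g}  1.1^depth={grow:.4g}")
```

A 10% mismatch per layer is barely visible at depth 2, but at depth 50 it changes the scale by a factor of more than a hundred in either direction.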

The multiplier comes from both the weights and the nonlinearity

Each layer contributes two ingredients to that multiplier. The first is the weight matrix. If its singular values are well above one, activations and gradients are amplified. If they are well below one, both signals are damped. In practical terms, very large random weights make the network numerically aggressive, while very small weights can make it inert.

The second ingredient is the derivative of the activation function. With sigmoid,

$$\sigma(z) = \frac{1}{1 + e^{-z}}, \qquad \sigma'(z) = \sigma(z)(1 - \sigma(z))$$

the derivative becomes very small whenever $z$ is far from zero. That is the saturation problem. If early random weights push pre-activations into those saturated regions, each layer contributes another factor near zero, and the backward signal collapses.
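
A small numeric sketch of the saturation effect, using only the standard library. The sigmoid derivative peaks at 0.25 at zero, so even the best case shrinks the activation part of the product at every layer. The function names below are illustrative, not taken from any library.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z: float) -> float:
    # sigma'(z) = sigma(z) * (1 - sigma(z))
    s = sigmoid(z)
    return s * (1.0 - s)


# The derivative falls off quickly as z moves away from zero.
for z in (0.0, 2.0, 5.0, 10.0):
    print(f"z={z:5.1f}  sigma'(z)={sigmoid_grad(z):.6f}")

# Best case: all 10 layers sit exactly at z = 0, where the derivative is largest.
best_case_product = sigmoid_grad(0.0) ** 10
print(f"activation-derivative product over 10 layers, best case: {best_case_product:.3e}")
```

This ignores the weight matrices entirely, so it only bounds the contribution of the nonlinearity. The weights can offset or worsen it.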

Tanh has the same structural issue, although centered activations help a little. ReLU avoids two-sided saturation, but it has its own asymmetry: the derivative is $1$ on the positive side and $0$ on the negative side. If too many units land in the negative region and stay there, you get dead neurons and partial gradient shutdown.
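
One rough way to spot dead ReLU units is to count how many units have a non-positive pre-activation for every sample in a batch. The sketch below assumes PyTorch. `dead_unit_fraction` is a hypothetical helper, and the bias of -100 is an artificial value chosen only to force the failure mode.

```python
import torch
import torch.nn as nn


def dead_unit_fraction(pre_activations: torch.Tensor) -> float:
    """Fraction of units whose pre-activation is <= 0 for every sample in the batch."""
    dead = (pre_activations <= 0).all(dim=0)
    return dead.float().mean().item()


torch.manual_seed(0)
x = torch.randn(256, 16)
layer = nn.Linear(16, 64)

with torch.no_grad():
    healthy = dead_unit_fraction(layer(x))
    layer.bias.fill_(-100.0)  # artificially push every unit into the negative region
    shifted = dead_unit_fraction(layer(x))

print(f"dead fraction, default init: {healthy:.2f}")
print(f"dead fraction, bias = -100:  {shifted:.2f}")
```

A unit that is silent on one batch is not necessarily dead for good, so this is best read as a warning sign to track over several batches.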

In deep networks, the effective multiplier is never "just the weights" or "just the activation." It is the interaction between the two.

Initialization has three jobs, not one

A good initialization is not merely "small random numbers." It has to satisfy three separate goals at once.

  • Break symmetry so different neurons do not learn identical features.
  • Keep forward activations at a stable scale across depth.
  • Keep backward gradients at a stable scale across depth.

The first requirement rules out all-zero weights. The other two are harder. Ideally, if the input to a layer has variance near one, the output should also have variance near one. Likewise, if the gradient entering a layer has a reasonable variance, the gradient leaving it should not immediately explode or vanish.

Those are variance-preservation goals. Xavier and He initialization are best understood as principled attempts to approximate them.
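
The symmetry-breaking requirement is easy to see in a toy experiment. The sketch below, assuming PyTorch, gives every parameter the same constant value (0.5, an arbitrary stand-in for any non-random initialization) and checks whether all hidden units receive identical gradients. The helper name is hypothetical.

```python
import torch
import torch.nn as nn


def hidden_grad_rows_identical(model: nn.Sequential, x: torch.Tensor) -> bool:
    """True if every hidden unit in the first layer receives the same weight gradient."""
    model.zero_grad(set_to_none=True)
    model(x).sum().backward()
    grad = model[0].weight.grad  # one row per hidden unit
    return all(torch.allclose(grad[0], grad[i]) for i in range(1, grad.shape[0]))


def make_mlp() -> nn.Sequential:
    return nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 1))


torch.manual_seed(0)
x = torch.randn(8, 4)

constant_model = make_mlp()
for param in constant_model.parameters():
    nn.init.constant_(param, 0.5)  # every unit starts out identical

random_model = make_mlp()  # PyTorch's default random initialization

print("constant init, identical gradients:", hidden_grad_rows_identical(constant_model, x))
print("random init, identical gradients:  ", hidden_grad_rows_identical(random_model, x))
```

With identical starting weights and identical gradients, the hidden units stay copies of each other after every update, which is why random initialization is needed at all.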

Xavier initialization balances fan-in and fan-out

Suppose a layer computes $z = Wx$ and the coordinates of $x$ are centered with equal variance. Then a rough independence argument gives

$$\operatorname{Var}(z) \approx n_{\text{in}} \operatorname{Var}(W) \operatorname{Var}(x)$$

where $n_{\text{in}}$ is the fan-in. Preserving forward variance suggests $\operatorname{Var}(W) \approx 1 / n_{\text{in}}$. A mirror calculation for backward flow suggests $\operatorname{Var}(W) \approx 1 / n_{\text{out}}$. Xavier initialization splits the difference:

$$\operatorname{Var}(W) = \frac{2}{n_{\text{in}} + n_{\text{out}}}$$

That compromise works well for activations like tanh, where positive and negative responses are treated more symmetrically. In PyTorch, the corresponding tools are nn.init.xavier_uniform_ and nn.init.xavier_normal_.
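
As a sanity check, the sketch below applies nn.init.xavier_normal_ to a single linear layer and compares the measured weight variance and forward variance with the formulas in this section. The layer sizes, batch size, and seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_in, n_out = 512, 256

layer = nn.Linear(n_in, n_out, bias=False)
nn.init.xavier_normal_(layer.weight)  # default gain of 1

# Weight variance should be close to 2 / (n_in + n_out).
target_var = 2.0 / (n_in + n_out)
weight_var = layer.weight.var().item()

# Forward check of Var(z) ~= n_in * Var(W) * Var(x) with unit-variance inputs.
x = torch.randn(4096, n_in)
with torch.no_grad():
    z = layer(x)
predicted_z_var = n_in * weight_var * x.var().item()
measured_z_var = z.var().item()

print(f"weight variance: measured {weight_var:.5f}, target {target_var:.5f}")
print(f"output variance: measured {measured_z_var:.3f}, predicted {predicted_z_var:.3f}")
```

Because the fan-in and fan-out differ here, the forward variance lands near 2 n_in / (n_in + n_out) = 4/3 rather than exactly 1. That gap is the price of the compromise between forward and backward flow.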

He initialization compensates for ReLU's one-sided gate

ReLU changes the variance story because roughly half the mass is cut off at zero. If pre-activations are roughly symmetric around zero, ReLU zeroes out the negative half, which cuts the second moment of the signal roughly in half. If you keep Xavier's scale unchanged, that loss compounds and the signal can decay layer by layer.

He initialization compensates for that by increasing the weight variance to

$$\operatorname{Var}(W) = \frac{2}{n_{\text{in}}}$$

The extra factor of two is not arbitrary. It is the variance correction needed when the activation keeps only the positive side of a roughly symmetric signal. In PyTorch, that family appears as nn.init.kaiming_uniform_ and nn.init.kaiming_normal_.

A good default rule is straightforward: use Xavier for tanh-like activations, and use He for ReLU-like activations.
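
The difference shows up quickly in a deep ReLU stack. The sketch below, assuming PyTorch, pushes a random batch through 20 bias-free linear layers with ReLU and reports the standard deviation of the final activations under each scheme. The depth, width, and the helper `final_activation_std` are illustrative choices.

```python
import torch
import torch.nn as nn


def final_activation_std(init_fn, depth=20, width=256, batch=512, seed=0):
    """Std of activations after `depth` Linear+ReLU layers initialized with `init_fn`."""
    torch.manual_seed(seed)
    h = torch.randn(batch, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)
            h = torch.relu(layer(h))
    return h.std().item()


xavier_std = final_activation_std(nn.init.xavier_normal_)
he_std = final_activation_std(
    lambda w: nn.init.kaiming_normal_(w, nonlinearity="relu")
)

print(f"Xavier + ReLU, std after 20 layers: {xavier_std:.3e}")
print(f"He + ReLU, std after 20 layers:     {he_std:.3e}")
```

With equal fan-in and fan-out, Xavier's variance is 1/n, so each ReLU layer roughly halves the second moment and the signal fades with depth, while He's factor of two keeps it at a roughly constant scale.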

How to inspect whether a network is healthy

When training behaves strangely, you do not have to guess. You can instrument the model and inspect activation statistics and gradient norms directly.

import torch
import torch.nn as nn

def check_gradient_health(model, input_sample):
    """Print per-layer activation statistics and per-parameter gradient statistics."""
    activations = {}

    def make_hook(name):
        # Store each layer's output, detached so the autograd graph is not kept alive.
        def hook(module, inputs, output):
            activations[name] = output.detach()
        return hook

    # Record the outputs of every linear and convolutional layer.
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            hooks.append(module.register_forward_hook(make_hook(name)))

    try:
        # One forward/backward pass with a throwaway scalar loss.
        model.zero_grad(set_to_none=True)
        output = model(input_sample)
        loss = output.sum()
        loss.backward()
    finally:
        # Always remove the hooks, even if the pass raises.
        for hook in hooks:
            hook.remove()

    print("\n[Activations] name, mean, std")
    for name, act in activations.items():
        print(name, act.mean().item(), act.std().item())

    print("\n[Gradients] name, norm, mean")
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(name, param.grad.norm().item(), param.grad.mean().item())

You are not looking for perfect constants. You are looking for warning signs. If activation standard deviations collapse toward zero early in the network, forward signal is dying. If early-layer gradient norms are tiny compared with late-layer gradients, backpropagation is fading before it reaches the front of the model. If norms blow up by orders of magnitude, you are in the opposite regime.
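
One way to turn the early-versus-late comparison into a single number is the ratio of the first weight gradient norm to the last. The sketch below assumes PyTorch and a plain sequential model. `first_to_last_grad_ratio` is a hypothetical helper, and the 10-layer sigmoid network with default initialization is a toy chosen to trigger vanishing gradients.

```python
import torch
import torch.nn as nn


def first_to_last_grad_ratio(model: nn.Module) -> float:
    """Ratio of the first weight matrix's gradient norm to the last one's."""
    weight_grad_norms = [
        param.grad.norm().item()
        for name, param in model.named_parameters()
        if name.endswith("weight") and param.grad is not None
    ]
    return weight_grad_norms[0] / weight_grad_norms[-1]


torch.manual_seed(0)
width, depth = 64, 10

# Deep sigmoid MLP with PyTorch's default initialization.
layers = []
for _ in range(depth):
    layers += [nn.Linear(width, width), nn.Sigmoid()]
layers.append(nn.Linear(width, 1))
model = nn.Sequential(*layers)

# Throwaway scalar loss, as in the inspection function above.
model(torch.randn(32, width)).sum().backward()
ratio = first_to_last_grad_ratio(model)
print(f"first/last gradient norm ratio: {ratio:.3e}")
```

A ratio many orders of magnitude below one points to vanishing gradients, and a ratio many orders above one points to explosion. No universal threshold exists, so the number is most useful for comparing initializations on the same architecture.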

A practical troubleshooting checklist

When loss is flat or unstable, walk through the simple causes first.

  1. Check for NaN or inf in activations, loss, and gradients.
  2. Inspect whether the learning rate is wildly too large or too small.
  3. Look at activation statistics to see whether the model is saturating.
  4. Look at gradient norms across layers to detect vanishing or exploding behavior.
  5. Confirm that inputs are normalized to a sensible scale.
  6. Revisit initialization for custom parameters or custom modules.
  7. If the network is deep, consider architectural aids such as normalization layers or residual connections.
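
The first item on the checklist can be automated with torch.isfinite. The sketch below scans parameter values and gradients only. Checking activations and the loss would use the same call on those tensors. `find_non_finite` is a hypothetical helper name, and the NaN is injected by hand purely to demonstrate the report.

```python
import torch
import torch.nn as nn


def find_non_finite(model: nn.Module) -> list[str]:
    """Names of parameters whose values or gradients contain NaN or inf."""
    bad = []
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            bad.append(f"{name} (value)")
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(f"{name} (grad)")
    return bad


model = nn.Linear(4, 2)
clean_report = find_non_finite(model)

with torch.no_grad():
    model.weight[0, 0] = float("nan")  # simulate a corrupted weight
corrupted_report = find_non_finite(model)

print("before corruption:", clean_report)
print("after corruption: ", corrupted_report)
```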

The main takeaway

Initialization matters because deep learning is a scale-sensitive process. Before the model has learned anything, your random draw determines whether signals stay in a trainable range or whether they decay or explode before optimization even starts.

Xavier and He initialization are not magic formulas. They are variance-preserving design rules. Once you understand that, "training instability" stops feeling mysterious. It becomes a question of signal propagation, and that question is something you can inspect and fix.
