
RNNs from Scratch: Backpropagation Through Time and Gradient Clipping

Time-axis weight sharing, the BPTT chain rule, and why exploding gradients are not optional to handle.

Tags: RNN · BPTT · Gradient Clipping · PyTorch · Sequence Models

A recurrent neural network is the simplest answer to the question "how do I let a model see arbitrarily long history?" — share one set of weights across time, maintain a hidden state, and update that state with every new input. The architecture fits in a few lines of code, but training it correctly requires understanding backpropagation through time, gradient clipping, and why the parameter explosion you might expect from variable-length sequences never materializes.

The RNN recurrence

A vanilla RNN is defined by two equations applied at every time step $t$:

$$h_t = \phi(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$
$$o_t = W_{hq} h_t + b_q$$

The hidden state $h_t$ is the running summary. It depends on the current input $x_t$ through $W_{xh}$ and on the previous hidden state through $W_{hh}$. The activation $\phi$ is usually tanh, which keeps the hidden state bounded in $[-1, 1]$. The output $o_t$ is a linear projection of the hidden state — for language modeling, those would be the logits over the vocabulary.

The single most important property of these equations is that the same parameters $W_{xh}$, $W_{hh}$, $W_{hq}$ are used at every time step. Time-axis parameter sharing is the recurrent counterpart of weight sharing in convolution. Parameter count is constant in sequence length. A 4-token sequence and a 4000-token sequence use exactly the same number of parameters.
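Parameter sharing is easy to verify directly with PyTorch's built-in `nn.RNN` (the sizes below are hypothetical, chosen only for illustration). The parameter count is fixed by `input_size` and `hidden_size` alone, and the same module processes sequences of any length:

```python
import torch
from torch import nn

# Hypothetical sizes for illustration
rnn = nn.RNN(input_size=32, hidden_size=64, batch_first=True)
n_params = sum(p.numel() for p in rnn.parameters())
# W_xh (64x32) + W_hh (64x64) + two bias vectors of length 64

out_short, _ = rnn(torch.randn(1, 4, 32))      # 4-step sequence
out_long, _ = rnn(torch.randn(1, 4000, 32))    # 4000-step sequence
# Same module, same n_params, regardless of sequence length
```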

Why tanh, not ReLU

In feedforward networks, ReLU is the default. In RNNs, tanh dominates. The reason is that the hidden state is fed back into itself. With ReLU, an unbounded positive activation can grow without limit: $h_t = \text{ReLU}(W_{hh} h_{t-1} + \ldots)$ can produce $h_t > h_{t-1}$ indefinitely. After hundreds of steps, hidden states explode.

Tanh squeezes its output into $[-1, 1]$, which prevents this kind of magnitude blow-up. The trade-off is that tanh saturates: when the pre-activation is large in absolute value, the gradient is nearly zero. That contributes to the vanishing gradient problem, which is the deeper reason gated cells like LSTM and GRU were eventually invented.
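The blow-up is easy to reproduce with a toy recurrence. Using the fixed matrix $1.5I$ (a deliberately bad case, chosen so the spectral norm is clearly above 1), the ReLU state grows by a factor of 1.5 per step while the tanh state stays pinned inside $[-1, 1]$:

```python
import torch

# Toy recurrence matrix with spectral norm 1.5
W = 1.5 * torch.eye(8)
h_tanh = torch.full((8,), 0.5)
h_relu = torch.full((8,), 0.5)

for _ in range(50):
    h_tanh = torch.tanh(W @ h_tanh)  # bounded: stays inside (-1, 1)
    h_relu = torch.relu(W @ h_relu)  # grows by 1.5x every step
```

After 50 steps the ReLU state has grown by roughly $1.5^{50} \approx 6 \times 10^8$, while the tanh state is still below 1 in magnitude.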

A from-scratch forward pass

The cleanest way to internalize the recurrence is to build the forward pass without using any high-level RNN module. The pattern: initialize parameters, define a step function, loop over time:

import torch
from torch import nn

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    W_xh = normal((num_inputs, num_hiddens))
    W_hh = normal((num_hiddens, num_hiddens))
    b_h  = torch.zeros(num_hiddens, device=device)
    W_hq = normal((num_hiddens, num_outputs))
    b_q  = torch.zeros(num_outputs, device=device)

    params = [W_xh, W_hh, b_h, W_hq, b_q]
    for p in params:
        p.requires_grad_(True)
    return params


def init_rnn_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),)


def rnn_forward(inputs, state, params):
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:                              # X: (batch_size, vocab_size)
        H = torch.tanh(X @ W_xh + H @ W_hh + b_h) # (batch_size, num_hiddens)
        Y = H @ W_hq + b_q                        # (batch_size, vocab_size)
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

The input shape conventions deserve attention. inputs is iterated along the time dimension, so each X is one time slice with shape (batch_size, vocab_size). The Python for loop over time is intentional — it makes the recurrence visible. PyTorch's built-in nn.RNN compiles the same logic into optimized CUDA kernels, but the math is identical.
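A quick smoke test confirms the shape conventions. The sketch below uses condensed copies of the functions above (gradient tracking omitted, since this only checks shapes) and hypothetical sizes:

```python
import torch
import torch.nn.functional as F

# Condensed copies of the listings above
def get_params(vocab_size, num_hiddens, device):
    normal = lambda shape: torch.randn(size=shape, device=device) * 0.01
    return [normal((vocab_size, num_hiddens)),       # W_xh
            normal((num_hiddens, num_hiddens)),      # W_hh
            torch.zeros(num_hiddens, device=device), # b_h
            normal((num_hiddens, vocab_size)),       # W_hq
            torch.zeros(vocab_size, device=device)]  # b_q

def rnn_forward(inputs, state, params):
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        H = torch.tanh(X @ W_xh + H @ W_hh + b_h)
        outputs.append(H @ W_hq + b_q)
    return torch.cat(outputs, dim=0), (H,)

# Hypothetical sizes: 28-token vocab, batch of 2, 5 time steps
vocab_size, num_hiddens, batch, steps = 28, 64, 2, 5
params = get_params(vocab_size, num_hiddens, torch.device('cpu'))
tokens = torch.randint(0, vocab_size, (steps, batch))
inputs = F.one_hot(tokens, vocab_size).float()   # (steps, batch, vocab)
state = (torch.zeros(batch, num_hiddens),)
Y, (H,) = rnn_forward(inputs, state, params)
# Y: (steps * batch, vocab_size), H: (batch, num_hiddens)
```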

What backpropagation through time is actually doing

Training an RNN means minimizing a loss summed over time. For language modeling with cross-entropy at every position:

$$\mathcal{L} = \frac{1}{T} \sum_{t=1}^{T} \ell(o_t, y_t)$$

Since the same weights are used at every time step, the gradient of $\mathcal{L}$ with respect to $W_{hh}$ accumulates contributions from every time step. This is what backpropagation through time (BPTT) does. The chain rule unrolls the recurrence:

$$\frac{\partial \mathcal{L}}{\partial W_{hh}} = \frac{1}{T} \sum_{t=1}^{T} \frac{\partial \ell_t}{\partial h_t} \frac{\partial h_t}{\partial W_{hh}}$$

Each $\partial h_t / \partial W_{hh}$ contains a chain back through every previous time step, because $h_t$ depends on $h_{t-1}$, which depends on $h_{t-2}$, and so on. Conceptually, the unrolled computation graph is a deep feedforward network whose depth equals the sequence length, and BPTT is just standard backpropagation applied to that unrolled graph.

Where vanishing and exploding gradients come from

The gradient of the hidden state at time $T$ with respect to the hidden state at time $t$ involves a product of Jacobians:

$$\frac{\partial h_T}{\partial h_t} = \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}} = \prod_{k=t+1}^{T} \operatorname{diag}\big(\phi'(\cdot)\big)\, W_{hh}^\top$$

The spectral norm of $W_{hh}$ is effectively raised to the power of the gap $T - t$. If its largest singular value is below 1, the product shrinks exponentially and gradients from far back vanish. If it is above 1, the product grows exponentially and gradients explode.

This is the same vanishing/exploding gradient problem that plagues very deep feedforward networks, but in an RNN it is structurally guaranteed because the same weight matrix is multiplied many times. The depth of the computation graph along the time axis can easily reach hundreds.
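The exponential behavior is easy to isolate numerically. The sketch below drops the $\operatorname{diag}(\phi'(\cdot))$ factor and uses a scaled orthogonal matrix, so every singular value equals the chosen scale and the backpropagated norm is exactly $\text{scale}^{T-t}$:

```python
import torch

torch.manual_seed(0)
n, gap = 32, 60

def backprop_norm(scale):
    # Scaled orthogonal matrix: every singular value equals `scale`,
    # so the product over the gap has norm exactly scale ** gap
    Q, _ = torch.linalg.qr(torch.randn(n, n))
    W = scale * Q
    g = torch.randn(n)
    g = g / g.norm()        # unit-norm gradient arriving at time T
    for _ in range(gap):
        g = W.T @ g         # one backward step through the recurrence
    return g.norm().item()

shrunk = backprop_norm(0.9)   # about 0.9 ** 60: vanished
grown = backprop_norm(1.1)    # about 1.1 ** 60: exploded
```

A gap of 60 steps is already enough to separate the two regimes by five orders of magnitude.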

Gradient clipping is not optional

Exploding gradients in an RNN destroy training in one step: a single batch with a very large gradient norm pushes parameters into a region where the loss is undefined or astronomically large, and the optimizer never recovers. The standard fix is gradient norm clipping:

$$g \leftarrow \min\!\left(1, \frac{\theta}{\|g\|}\right) g$$

If the gradient norm exceeds the threshold $\theta$, scale the entire gradient vector down so its norm equals $\theta$. This preserves direction but caps magnitude. In PyTorch:

def clip_gradients(net, theta):
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net
    norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
    if norm > theta:
        for p in params:
            p.grad[:] *= theta / norm

Typical values of $\theta$ are 1.0 to 5.0 for character-level language models. Clipping does not help vanishing gradients — that needs architectural changes (gated cells, residual connections, attention) — but it makes the difference between an RNN that trains and one that diverges in the first epoch.
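In practice you rarely write the clipping loop yourself: `torch.nn.utils.clip_grad_norm_` implements the same rescaling in place and returns the total norm measured before clipping. A minimal check (toy sizes, hypothetical) that the post-clip norm respects the threshold:

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.RNN(input_size=8, hidden_size=16)
out, _ = net(torch.randn(5, 3, 8))   # (seq=5, batch=3, features=8)
out.sum().backward()

# In-place rescaling; returns the total norm measured before clipping
total_norm = nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)

norm_after = torch.sqrt(sum((p.grad ** 2).sum() for p in net.parameters()))
```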

Truncated BPTT: bounding the unroll length

Backpropagating through 10,000 time steps of a long document is computationally infeasible and hurts gradient quality anyway. Truncated BPTT processes the sequence in chunks of length $k$ and only backpropagates within each chunk. Hidden states are passed forward (so the model still sees long context at the forward pass), but their gradient graphs are detached at chunk boundaries:

for X, Y in data_iter:
    if state is None or use_random_iter:
        state = init_state(batch_size, num_hiddens, device)
    else:
        # Detach to truncate BPTT — keep value, drop gradient history
        if isinstance(state, tuple):
            state = tuple(s.detach() for s in state)
        else:
            state = state.detach()

    y_hat, state = net(X, state)
    loss = criterion(y_hat, Y).mean()
    optimizer.zero_grad()
    loss.backward()
    clip_gradients(net, theta=1.0)
    optimizer.step()

Random sampling of subsequences (no carry-over) and sequential partitioning (state carried between batches with detach) are the two standard regimes. Sequential partitioning gives the model longer effective context but requires the detach step every batch.
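The effect of `detach` can be verified in isolation: the carried tensor keeps its values but reports `requires_grad=False`, so a subsequent backward pass cannot reach into the previous chunk. A minimal sketch with a toy 4-unit state:

```python
import torch

W = torch.eye(4, requires_grad=True)
h = torch.ones(4)

# "Chunk 1": two recurrence steps whose graph will be dropped
for _ in range(2):
    h = torch.tanh(W @ h)

carried = h.detach()        # same values, no gradient history

# "Chunk 2": backward reaches this step only, never chunk 1
out = torch.tanh(W @ carried)
out.sum().backward()
```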

Generation: hidden state warmup, then sample

After training, language generation works by feeding a prompt one token at a time to warm up the hidden state, then sampling from the model's output distribution to extend the sequence:

def predict(prefix, num_preds, net, vocab, device):
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]],
                                     device=device).reshape(1, 1)
    # Warmup: feed prefix, do not sample
    for c in prefix[1:]:
        _, state = net(get_input(), state)
        outputs.append(vocab[c])
    # Generate: sample, append, continue
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

The warmup phase is essential. Without it, the hidden state starts at zero and the first generated tokens are decoded from a context-free state. Feeding the prefix populates the hidden state with the right conditioning before sampling begins.

Evaluating language models with perplexity

Cross-entropy loss is the training objective, but its raw value is hard to interpret. Perplexity is the standard reporting metric:

$$\text{PPL} = \exp\!\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(x_i \mid x_{<i})\right)$$

Perplexity has a clean interpretation: it is the effective number of equally likely choices the model weighs at each position. A perfect language model has perplexity 1 (no uncertainty). A uniform model over a 10,000-token vocabulary has perplexity 10,000. Perplexity 100 means the model is, on average, as uncertain as if it were choosing among 100 equally likely tokens — much better than random, far from perfect.

Reporting perplexity makes loss values comparable across models and datasets in a way that raw cross-entropy does not.
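The uniform-model claim is a one-liner to check. With all-zero logits over a 10,000-token vocabulary, mean cross-entropy is $\log 10000$ regardless of the targets, and exponentiating recovers a perplexity of 10,000:

```python
import torch
import torch.nn.functional as F

vocab_size, n_tokens = 10_000, 50
logits = torch.zeros(n_tokens, vocab_size)        # uniform predictions
targets = torch.randint(0, vocab_size, (n_tokens,))

ce = F.cross_entropy(logits, targets)             # mean NLL = log(vocab_size)
ppl = torch.exp(ce)                               # perplexity = vocab_size
```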

The main takeaway

A vanilla RNN is two equations: a recurrent hidden state and a linear output. Time-axis parameter sharing keeps the parameter count constant in sequence length. Backpropagation through time unrolls the recurrence into a deep computation graph, which is why exploding and vanishing gradients are structural — they come from raising $W_{hh}$ to the power of the time gap.

Three engineering practices are non-negotiable: tanh as the hidden activation to bound state magnitude, gradient clipping to survive exploding gradients, and truncated BPTT to bound the unroll length. With those in place, even a from-scratch RNN can train on character-level language modeling. The vanishing-gradient half of the problem requires architectural changes, which is exactly the territory of GRU and LSTM.
