A recurrent neural network is the simplest answer to the question "how do I let a model see arbitrarily long history?" — share one set of weights across time, maintain a hidden state, and update that state with every new input. The architecture fits in a few lines of code, but training it correctly requires understanding backpropagation through time, gradient clipping, and where the parameter explosion you might expect from variable-length sequences actually goes.
The RNN recurrence
A vanilla RNN is defined by two equations applied at every time step $t$:

$$H_t = \tanh(X_t W_{xh} + H_{t-1} W_{hh} + b_h)$$

$$O_t = H_t W_{hq} + b_q$$

The hidden state $H_t$ is the running summary. It depends on the current input $X_t$ through $W_{xh}$ and on the previous hidden state $H_{t-1}$ through $W_{hh}$. The activation is usually tanh, which keeps the hidden state bounded in $(-1, 1)$. The output $O_t$ is a linear projection of the hidden state — for language modeling, those would be the logits over the vocabulary.
The single most important property of these equations is that the same parameters $W_{xh}$, $W_{hh}$, $b_h$ are used at every time step. Time-axis parameter sharing is the recurrent counterpart of weight sharing in convolution. Parameter count is constant in sequence length. A 4-token sequence and a 4000-token sequence use exactly the same number of parameters.
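To make that concrete, here is a quick back-of-the-envelope check; the sizes are hypothetical, and the count depends only on the vocabulary and hidden sizes, never on how many time steps you unroll.

vocab_size, num_hiddens = 28, 512   # hypothetical sizes
n_params = (vocab_size * num_hiddens      # W_xh
            + num_hiddens * num_hiddens   # W_hh
            + num_hiddens                 # b_h
            + num_hiddens * vocab_size    # W_hq
            + vocab_size)                 # b_q
print(n_params)  # 291356: the same for a 4-token or a 4000-token sequence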
Why tanh, not ReLU
In feedforward networks, ReLU is the default. In RNNs, tanh dominates. The reason is that the hidden state is fed back into itself. With ReLU, an unbounded positive activation can grow without limit: $H_t = \mathrm{ReLU}(X_t W_{xh} + H_{t-1} W_{hh} + b_h)$ can produce $\|H_t\| > \|H_{t-1}\|$ indefinitely. After hundreds of steps, hidden states explode.
Tanh squeezes its output into $(-1, 1)$, which prevents this kind of magnitude blow-up. The trade-off is that tanh saturates: when the pre-activation is large in absolute value, the gradient is nearly zero. That contributes to the vanishing gradient problem, which is the deeper reason gated cells like LSTM and GRU were eventually invented.
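A small numerical sketch makes the difference tangible. The weight scale and seed below are arbitrary assumptions chosen so the ReLU recurrence blows up; with tanh, every coordinate of the state stays in $(-1, 1)$ no matter how many steps you run.

import torch

torch.manual_seed(0)
num_hiddens = 64
W_hh = torch.randn(num_hiddens, num_hiddens) * 0.5  # deliberately large scale
h_relu = torch.abs(torch.randn(num_hiddens))
h_tanh = h_relu.clone()
for _ in range(200):
    h_relu = torch.relu(h_relu @ W_hh)   # unbounded: the norm can grow every step
    h_tanh = torch.tanh(h_tanh @ W_hh)   # bounded: each entry stays in (-1, 1)
print(h_relu.norm())  # typically inf or nan: the state has exploded
print(h_tanh.norm())  # at most sqrt(num_hiddens) = 8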
A from-scratch forward pass
The cleanest way to internalize the recurrence is to build the forward pass without using any high-level RNN module. The pattern: initialize parameters, define a step function, loop over time:
import torch

def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        # Small random init keeps early pre-activations out of tanh saturation
        return torch.randn(size=shape, device=device) * 0.01

    W_xh = normal((num_inputs, num_hiddens))    # input-to-hidden
    W_hh = normal((num_hiddens, num_hiddens))   # hidden-to-hidden, shared across time
    b_h = torch.zeros(num_hiddens, device=device)
    W_hq = normal((num_hiddens, num_outputs))   # hidden-to-output
    b_q = torch.zeros(num_outputs, device=device)
    params = [W_xh, W_hh, b_h, W_hq, b_q]
    for p in params:
        p.requires_grad_(True)
    return params
def init_rnn_state(batch_size, num_hiddens, device):
    # Returned as a tuple so cells with more than one state tensor can reuse the interface
    return (torch.zeros((batch_size, num_hiddens), device=device),)
def rnn_forward(inputs, state, params):
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:  # X: (batch_size, vocab_size)
        H = torch.tanh(X @ W_xh + H @ W_hh + b_h)  # (batch_size, num_hiddens)
        Y = H @ W_hq + b_q                         # (batch_size, vocab_size)
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

The input shape conventions deserve attention. inputs is iterated along the time dimension, so each X is one time slice with shape (batch_size, vocab_size). The Python for loop over time is intentional — it makes the recurrence visible. PyTorch's built-in nn.RNN compiles the same logic into optimized CUDA kernels, but the math is identical.
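A quick usage sketch, with hypothetical sizes, shows the shapes end to end: token indices come in as (batch_size, num_steps), get one-hot encoded and transposed so time is the leading axis, and the outputs come back stacked along time.

import torch
import torch.nn.functional as F

batch_size, num_steps, vocab_size, num_hiddens = 2, 5, 28, 256  # hypothetical sizes
device = torch.device('cpu')
params = get_params(vocab_size, num_hiddens, device)
state = init_rnn_state(batch_size, num_hiddens, device)

tokens = torch.randint(0, vocab_size, (batch_size, num_steps))    # fake token indices
inputs = F.one_hot(tokens.T, vocab_size).float()  # (num_steps, batch_size, vocab_size)
Y, new_state = rnn_forward(inputs, state, params)
print(Y.shape)             # torch.Size([10, 28]): num_steps * batch_size rows of logits
print(new_state[0].shape)  # torch.Size([2, 256])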
What backpropagation through time is actually doing
Training an RNN means minimizing a loss summed over time. For language modeling with cross-entropy at every position:

$$L = \sum_{t=1}^{T} \ell(O_t, y_t)$$

where $O_t$ are the logits at step $t$ and $y_t$ is the true next token.
Since the same weights are used at every time step, the gradient of $L$ with respect to $W_{hh}$ accumulates contributions from every time step. This is what backpropagation through time (BPTT) does. The chain rule unrolls the recurrence:

$$\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial \ell(O_t, y_t)}{\partial O_t}\,\frac{\partial O_t}{\partial H_t}\,\frac{\partial H_t}{\partial W_{hh}}$$
Each $\partial H_t / \partial W_{hh}$ contains a chain back through every previous time step, because $H_t$ depends on $H_{t-1}$, which depends on $H_{t-2}$, and so on. Conceptually, the unrolled computation graph is a deep feedforward network whose depth equals the sequence length, and BPTT is just standard backpropagation applied to that unrolled graph.
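You can watch autograd perform exactly this accumulation on a toy scalar recurrence (the numbers below are arbitrary): a single shared weight receives one gradient that sums contributions from every time step.

import torch

w = torch.tensor(0.5, requires_grad=True)   # shared "recurrent weight"
h = torch.tensor(0.0)
loss = torch.tensor(0.0)
for x in [1.0, -0.5, 0.25, 2.0]:            # a 4-step toy sequence
    h = torch.tanh(w * h + x)               # same w reused at every step
    loss = loss + h ** 2                    # a loss term at every position
loss.backward()                             # BPTT on the unrolled graph
print(w.grad)  # one number, accumulated over all four time steps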
Where vanishing and exploding gradients come from
The gradient of the hidden state at time $t$ with respect to the hidden state at an earlier time $k$ involves a product of Jacobians:

$$\frac{\partial H_t}{\partial H_k} = \prod_{i=k+1}^{t} \frac{\partial H_i}{\partial H_{i-1}} = \prod_{i=k+1}^{t} \mathrm{diag}\!\left(1 - \tanh^2(\cdot)\right) W_{hh}^{\top}$$
Whatever the spectral radius of $W_{hh}$ is gets raised, roughly, to the power of the gap $t - k$. If the largest singular value is below 1, the product shrinks exponentially — gradients from far back vanish. If it is above 1, the product grows exponentially — gradients explode.
This is the same vanishing/exploding gradient problem that plagues very deep feedforward networks, but in an RNN it is structurally guaranteed because the same weight matrix $W_{hh}$ is multiplied in many times. The depth of the computation graph along the time axis can easily reach hundreds.
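The exponential behavior is easy to see with plain matrix powers (the 2x2 matrices below are made up): push the largest singular value just below or just above 1 and watch the norm after 100 multiplications.

import torch

g = torch.ones(2)
W_shrink = torch.tensor([[0.9, 0.0], [0.0, 0.5]])  # largest singular value 0.9
W_grow = torch.tensor([[1.1, 0.0], [0.0, 0.5]])    # largest singular value 1.1
print((torch.linalg.matrix_power(W_shrink, 100) @ g).norm())  # ~2.7e-5: vanished
print((torch.linalg.matrix_power(W_grow, 100) @ g).norm())    # ~1.4e4: exploded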
Gradient clipping is not optional
Exploding gradients in an RNN destroy training in one step: a single batch with a very large gradient norm pushes parameters into a region where the loss is undefined or astronomically large, and the optimizer never recovers. The standard fix is gradient norm clipping:

$$\mathbf{g} \leftarrow \min\!\left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g}$$

If the gradient norm $\|\mathbf{g}\|$ exceeds the threshold $\theta$, scale the entire gradient vector down so its norm equals $\theta$. This preserves direction but caps magnitude. In PyTorch:
import torch
from torch import nn

def clip_gradients(net, theta):
    """Rescale gradients in place so their global norm is at most theta."""
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net
    norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
    if norm > theta:
        for p in params:
            p.grad[:] *= theta / norm

Typical values of $\theta$ are 1.0 to 5.0 for character-level language models. Clipping does not help vanishing gradients — that needs architectural changes (gated cells, residual connections, attention) — but it makes the difference between an RNN that trains and one that diverges in the first epoch.
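PyTorch also ships this as a built-in. For an nn.Module, the one-liner below (called between loss.backward() and optimizer.step()) has essentially the same effect as the hand-rolled version; net here stands for whatever model you are training.

import torch

# Rescales all gradients of net in place so their global norm is at most 1.0
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)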
Truncated BPTT: bounding the unroll length
Backpropagating through 10,000 time steps of a long document is computationally infeasible and hurts gradient quality anyway. Truncated BPTT processes the sequence in chunks of a fixed length (the num_steps of each minibatch) and only backpropagates within each chunk. Hidden states are passed forward (so the model still sees long context in the forward pass), but their gradient graphs are detached at chunk boundaries:
state = None
for X, Y in data_iter:
    if state is None or use_random_iter:
        state = init_state(batch_size, num_hiddens, device)
    else:
        # Detach to truncate BPTT — keep value, drop gradient history
        if isinstance(state, tuple):
            state = tuple(s.detach() for s in state)
        else:
            state = state.detach()
    y_hat, state = net(X, state)
    loss = criterion(y_hat, Y).mean()
    optimizer.zero_grad()
    loss.backward()
    clip_gradients(net, theta=1.0)
    optimizer.step()

Random sampling of subsequences (no carry-over) and sequential partitioning (state carried between batches with detach) are the two standard regimes. Sequential partitioning gives the model longer effective context but requires the detach step every batch.
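For completeness, here is a minimal sketch of what a sequential-partitioning iterator can look like, assuming corpus is a flat list of token indices. The point is only that consecutive minibatches are adjacent in the text, so the carried-over hidden state really does summarize the tokens that came just before.

import random
import torch

def seq_data_iter_sequential(corpus, batch_size, num_steps):
    # Random starting offset so different epochs see different chunk boundaries
    offset = random.randint(0, num_steps - 1)
    num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_size
    Xs = torch.tensor(corpus[offset: offset + num_tokens])
    Ys = torch.tensor(corpus[offset + 1: offset + 1 + num_tokens])
    Xs = Xs.reshape(batch_size, -1)
    Ys = Ys.reshape(batch_size, -1)
    num_batches = Xs.shape[1] // num_steps
    for i in range(0, num_steps * num_batches, num_steps):
        yield Xs[:, i: i + num_steps], Ys[:, i: i + num_steps]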
Generation: hidden state warmup, then sample
After training, language generation works by feeding a prompt one token at a time to warm up the hidden state, then repeatedly decoding a next token from the model's output distribution (greedy argmax in the code below) and feeding it back in to extend the sequence:
def predict(prefix, num_preds, net, vocab, device):
    """Generate num_preds tokens continuing the string prefix."""
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]],
                                     device=device).reshape(1, 1)
    # Warmup: feed the prefix to build up the hidden state, do not generate
    for c in prefix[1:]:
        _, state = net(get_input(), state)
        outputs.append(vocab[c])
    # Generate: greedy argmax, append, feed back in
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

The warmup phase is essential. Without it, the hidden state starts at zero and the first generated tokens are decoded from a context-free state. Feeding the prefix populates the hidden state with the right conditioning before generation begins.
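If you want genuine sampling rather than greedy decoding, the generation step can be swapped for a temperature-scaled draw from the softmax. The helper below is a hypothetical variant, not part of the predict function above.

import torch

def sample_next_token(logits, temperature=1.0):
    # logits: (1, vocab_size) output of the network for the latest time step
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())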
Evaluating language models with perplexity
Cross-entropy loss is the training objective, but its raw value is hard to interpret. Perplexity is the standard reporting metric:

$$\mathrm{PPL} = \exp\!\left(-\frac{1}{n} \sum_{t=1}^{n} \log P(x_t \mid x_{t-1}, \ldots, x_1)\right)$$

the exponential of the average negative log-likelihood per token.
Perplexity has a clean interpretation: it is the average number of equally likely choices the model considers at each position. A perfect language model has perplexity 1 (no uncertainty). A uniform model over a 10,000-token vocabulary has perplexity 10,000. Perplexity 100 means the model is, on average, choosing among 100 equally likely tokens — much better than random, far from perfect.
Reporting perplexity makes loss values comparable across models and datasets in a way that raw cross-entropy does not.
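The conversion is one line: exponentiate the average per-token cross-entropy (in nats). The loss value below is made up for illustration.

import math

avg_cross_entropy = 4.6              # hypothetical average loss per token, in nats
print(math.exp(avg_cross_entropy))   # ~99.5, i.e. perplexity around 100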
The main takeaway
A vanilla RNN is two equations: a recurrent hidden state and a linear output. Time-axis parameter sharing keeps the parameter count constant in sequence length. Backpropagation through time unrolls the recurrence into a deep computation graph, which is why exploding and vanishing gradients are structural — they come from raising $W_{hh}$ (via the product of Jacobians) to the power of the time gap.
Three engineering practices are non-negotiable: tanh as the hidden activation to bound state magnitude, gradient clipping to survive exploding gradients, and truncated BPTT to bound the unroll length. With those in place, even a from-scratch RNN can train on character-level language modeling. The vanishing-gradient half of the problem requires architectural changes, which is exactly the territory of GRU and LSTM.