Once the math of RNN, GRU, and LSTM cells is clear, the PyTorch implementation is mostly bookkeeping — input shapes, hidden-state initialization, gradient clipping, and a training loop that handles truncated BPTT correctly. This article is a practical guide to building, training, and evaluating sequence models in PyTorch, using a character-level language model as the running example.
Input shape conventions
PyTorch's nn.RNN, nn.GRU, and nn.LSTM share the same shape conventions. The default expects input shape (num_steps, batch_size, feature_dim) — time first. Setting batch_first=True swaps to (batch_size, num_steps, feature_dim), which matches conventions in most other frameworks and is generally easier to reason about.
A character-level language model with vocabulary size 28 (lowercase letters plus space and unknown), batch size 32, and 35 characters per sample produces one-hot input tensors of shape (32, 35, 28) with batch_first=True, or (35, 32, 28) in the default time-first layout. With an embedding layer in place of one-hot encoding, the embedded batch has shape (32, 35, embed_dim) for whatever embedding dimension is chosen.
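A quick shape check makes these conventions concrete. The minimal sketch below uses random token indices in place of real text and a hidden size of 16 chosen only for illustration. It builds one-hot and embedded inputs, then runs the same LSTM weights in both layouts; copying the state dict works because batch_first changes only how inputs and outputs are laid out, not the parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, batch_size, num_steps, embed_dim, hidden = 28, 32, 35, 64, 16

# Random token indices stand in for encoded text: (B, T)
tokens = torch.randint(0, vocab_size, (batch_size, num_steps))

# One-hot encoding: (B, T) -> (B, T, V), cast to float for the LSTM
one_hot = F.one_hot(tokens, num_classes=vocab_size).float()

# Embedding lookup: (B, T) -> (B, T, E)
embedded = nn.Embedding(vocab_size, embed_dim)(tokens)

# Batch-first LSTM consumes (B, T, V) and returns (B, T, H)
lstm_bf = nn.LSTM(vocab_size, hidden, batch_first=True)
out_bf, _ = lstm_bf(one_hot)

# Default time-first LSTM with identical weights consumes (T, B, V)
lstm_tf = nn.LSTM(vocab_size, hidden)
lstm_tf.load_state_dict(lstm_bf.state_dict())
out_tf, _ = lstm_tf(one_hot.transpose(0, 1))  # returns (T, B, H)
```

Transposing out_tf back to (B, T, H) should reproduce out_bf, which is the point: the two layouts compute the same thing, so the choice is about how readable the surrounding code is.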
A complete model: embeddings, LSTM, output projection
import torch
import torch.nn as nn
import torch.nn.functional as F


class CharLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=64,
                 hidden_size=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.LSTM applies dropout only between stacked layers
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, state=None):
        # x: (batch_size, num_steps) of token indices
        embedded = self.embedding(x)  # (B, T, E)
        if state is None:
            state = self.init_state(x.size(0), x.device)
        out, state = self.lstm(embedded, state)  # out: (B, T, H)
        logits = self.output(out)  # (B, T, vocab_size)
        return logits, state

    def init_state(self, batch_size, device):
        # LSTM state is a tuple (h, c), each (num_layers, batch_size, hidden_size)
        h = torch.zeros(self.num_layers, batch_size,
                        self.hidden_size, device=device)
        c = torch.zeros(self.num_layers, batch_size,
                        self.hidden_size, device=device)
        return (h, c)
Three pieces matter here. The embedding layer converts integer token indices to dense vectors. The lookup computes the same thing as multiplying a one-hot vector by a learned weight matrix, but without materializing the sparse one-hot tensor, so it is far cheaper, and the LSTM's input width becomes embed_dim rather than the vocabulary size. The LSTM produces an output of shape (B, T, H) with one hidden vector per time step. The final linear layer projects each hidden vector to vocabulary-sized logits, ready for cross-entropy.
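The hidden state for an LSTM is a tuple (h, c), where h and c each have shape (num_layers, batch_size, hidden_size). nn.GRU and nn.RNN keep no cell state, so their state is a single tensor h of the same shape. Code that manipulates the state, such as the detach step in the training loop below, has to handle whichever form the chosen layer returns.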
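The detach step in the training loop below indexes state[0] and state[1], so it only works for the LSTM tuple. A minimal sketch of a layer-agnostic alternative follows; detach_state is a hypothetical helper name, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn


def detach_state(state):
    """Cut a recurrent state out of the autograd graph, keeping its values.

    Handles nn.LSTM's (h, c) tuple and the single tensor used by nn.GRU and nn.RNN.
    """
    if state is None:
        return None
    if isinstance(state, tuple):
        return tuple(s.detach() for s in state)
    return state.detach()


torch.manual_seed(0)
x = torch.randn(4, 10, 8)  # (batch_size, num_steps, feature_dim)
lstm = nn.LSTM(8, 16, num_layers=2, batch_first=True)
gru = nn.GRU(8, 16, num_layers=2, batch_first=True)

_, lstm_raw = lstm(x)  # tuple (h, c), each (num_layers, B, H) = (2, 4, 16)
_, gru_raw = gru(x)    # single tensor h, also (2, 4, 16)

lstm_state = detach_state(lstm_raw)
gru_state = detach_state(gru_raw)
```

Replacing the tuple-specific detach line in the loops with state = detach_state(state) would let the same loop drive a GRU or plain RNN.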
Cross-entropy with sequence-shaped logits
nn.CrossEntropyLoss (and its functional form F.cross_entropy) expects logits of shape (N, C) and targets of shape (N,). It also accepts (N, C, d1, ...) logits with (N, d1, ...) targets, but that layout puts the class axis second, so with sequence outputs of shape (B, T, V) the simplest pattern is to flatten the batch and time axes:
logits, _ = model(x)  # (B, T, V)
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),  # (B*T, V)
    targets.reshape(-1),             # (B*T,)
)
Training the model to predict the next character at every position means each batch yields B * T training signals, far more than the B signals from predicting only the last position. The result is correct as long as logits and targets are flattened in the same (batch, time) order, which calling reshape on both guarantees.
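As a sanity check, the flattened call should give the same mean loss as the multi-dimensional form with the class axis moved to dim 1. The sketch uses random logits and targets with small, arbitrary sizes; reduction='none' additionally exposes one loss per position.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 4, 7, 28
logits = torch.randn(B, T, V)          # stand-in for model output
targets = torch.randint(0, V, (B, T))  # next-token indices

# Flattened form: (B*T, V) logits against (B*T,) targets
flat_loss = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))

# Multi-dimensional form: class axis moved to dim 1, (B, V, T) against (B, T)
perm_loss = F.cross_entropy(logits.permute(0, 2, 1), targets)

# reduction='none' keeps one loss per (batch, time) position
per_token = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction='none')
```

The mean of per_token equals both scalar losses, which confirms that the default reduction averages over all B * T positions.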
A complete training loop with detach and clipping
def train_one_epoch(model, data_iter, optimizer, device, theta=1.0):
    model.train()
    total_loss = 0.0
    total_tokens = 0
    state = None
    for x, y in data_iter:
        x, y = x.to(device), y.to(device)
        # Truncated BPTT: detach state from previous batch
        if state is not None:
            state = (state[0].detach(), state[1].detach())
        logits, state = model(x, state)
        loss = F.cross_entropy(
            logits.reshape(-1, model.vocab_size),
            y.reshape(-1),
        )
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), theta)
        optimizer.step()
        total_loss += loss.item() * y.numel()
        total_tokens += y.numel()
    return total_loss / total_tokens
Three details are easy to get wrong:
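Detaching the state. In PyTorch, loss.backward() frees the graph it has just traversed. Without .detach(), the next batch's graph stays attached to that freed graph through the carried state, and its backward pass raises a RuntimeError ("Trying to backward through the graph a second time"). Working around the error with retain_graph=True instead makes every batch's graph extend back through every previous batch, eventually exhausting memory. With detach, the value of the hidden state is preserved (so the forward pass still sees long context) but the gradient graph is cut at the batch boundary.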
Gradient clipping. torch.nn.utils.clip_grad_norm_ rescales all gradients in place so that their combined L2 norm is at most theta, and returns the norm measured before clipping. The threshold of 1.0 is conservative; values up to 5.0 are common.
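Tracking tokens, not batches. Loss should be averaged per predicted token, not per batch, so that perplexity is computed correctly even when batches differ in size (a smaller final batch, for example). Because F.cross_entropy returns the mean over tokens by default, multiplying loss.item() by y.numel() recovers the batch's total negative log-likelihood; dividing the running sum by total_tokens gives the per-token average loss.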
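Each of the three details can be checked in isolation. The minimal sketch below uses a tiny stand-alone LSTM and linear head (sizes chosen arbitrarily) rather than the CharLanguageModel above. run_two_batches is a hypothetical helper that carries state across two batches with or without detaching; the clipping check uses a hand-set gradient, and the averaging check uses made-up batch losses.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 8)


# 1. Detach: carry state across two batches and try both backward passes
def run_two_batches(detach):
    """Return True if both backward passes succeed."""
    state = None
    for _ in range(2):
        x = torch.randn(4, 5, 8)
        y = torch.randint(0, 8, (4, 5))
        if state is not None and detach:
            state = tuple(s.detach() for s in state)
        out, state = lstm(x, state)
        loss = F.cross_entropy(head(out).reshape(-1, 8), y.reshape(-1))
        try:
            loss.backward()
        except RuntimeError:
            # Backward reached the previous batch's graph, already freed
            return False
    return True


detach_ok = run_two_batches(detach=True)
no_detach_ok = run_two_batches(detach=False)

# 2. Clipping: a gradient of [3, 4] has L2 norm 5 and is rescaled to norm 1
p = nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([3.0, 4.0])
norm_before = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)  # total norm before clipping

# 3. Per-token averaging with hypothetical numbers: a full batch of
#    32 * 35 = 1120 tokens with mean loss 2.0, then a final partial batch
#    of 8 * 35 = 280 tokens with mean loss 3.0
batches = [(2.0, 1120), (3.0, 280)]
per_batch_avg = sum(l for l, _ in batches) / len(batches)  # overweights the small batch
per_token_avg = sum(l * n for l, n in batches) / sum(n for _, n in batches)
perplexity = math.exp(per_token_avg)
```

The per-batch average comes out at 2.5 while the per-token average is 2.2; only the latter, exponentiated, is the perplexity of the data.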
Random sampling versus sequential partitioning
Long documents are sliced into training subsequences in two main regimes. Random sampling picks subsequences at random offsets and shuffles them. Each batch's hidden state must be reinitialized to zero because the subsequences are not adjacent in the original text.
Sequential partitioning picks subsequences that follow one another in the original text: row i of each batch continues exactly where row i of the previous batch ended. This lets the model carry hidden state across batches, which improves long-range modeling but requires the detach step above.
# Random sampling: reinitialize state every batch
for x, y in data_iter:
    state = model.init_state(x.size(0), device)
    logits, _ = model(x, state)
    ...

# Sequential partitioning: carry and detach state
state = None
for x, y in data_iter:
    if state is not None:
        state = (state[0].detach(), state[1].detach())
    logits, state = model(x, state)
    ...
Sequential partitioning generally produces lower perplexity on small corpora because the model sees true cross-batch context. Random sampling is more robust when batches contain unrelated documents.
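Neither loop above shows where data_iter comes from. A minimal sketch, assuming the whole corpus is already encoded as a 1-D tensor of token indices, is a pair of generators (random_iter and sequential_iter are hypothetical names) whose targets are the inputs shifted one token ahead. The reshape in sequential_iter is what gives sequential partitioning its row-by-row continuity.

```python
import random
import torch


def random_iter(corpus, batch_size, num_steps):
    """Random sampling: shuffled, non-adjacent subsequences."""
    # Random start offset so epochs see different subsequence boundaries
    corpus = corpus[random.randint(0, num_steps - 1):]
    num_subseqs = (len(corpus) - 1) // num_steps  # reserve one token for targets
    starts = list(range(0, num_subseqs * num_steps, num_steps))
    random.shuffle(starts)
    for i in range(0, len(starts) - batch_size + 1, batch_size):
        batch_starts = starts[i:i + batch_size]
        x = torch.stack([corpus[s:s + num_steps] for s in batch_starts])
        # Targets are the inputs shifted one token ahead
        y = torch.stack([corpus[s + 1:s + 1 + num_steps] for s in batch_starts])
        yield x, y


def sequential_iter(corpus, batch_size, num_steps):
    """Sequential partitioning: row i of batch k+1 continues row i of batch k."""
    num_tokens = ((len(corpus) - 1) // batch_size) * batch_size
    # Split the corpus into batch_size contiguous streams, one per row
    xs = corpus[:num_tokens].reshape(batch_size, -1)
    ys = corpus[1:num_tokens + 1].reshape(batch_size, -1)
    for i in range(0, xs.size(1) - num_steps + 1, num_steps):
        yield xs[:, i:i + num_steps], ys[:, i:i + num_steps]
```

Batches from random_iter can be consumed in any order; batches from sequential_iter must be consumed in order, or the carried state no longer matches the text it is supposed to summarize.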
Evaluating with perplexity
@torch.no_grad()
def evaluate_perplexity(model, data_iter, device):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    state = None
    for x, y in data_iter:
        x, y = x.to(device), y.to(device)
        # Harmless under no_grad (no graph is built); kept to mirror training
        if state is not None:
            state = (state[0].detach(), state[1].detach())
        logits, state = model(x, state)
        # reduction='sum' gives this batch's total NLL directly
        loss = F.cross_entropy(
            logits.reshape(-1, model.vocab_size),
            y.reshape(-1),
            reduction='sum',
        )
        total_loss += loss.item()
        total_tokens += y.numel()
    avg_nll = total_loss / total_tokens
    return torch.exp(torch.tensor(avg_nll)).item()
Perplexity is the exponential of the average per-token negative log-likelihood, $\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{t=1}^{N} \log p(x_t \mid x_{<t})\right)$, which is exactly what avg_nll and the final torch.exp compute. A perfect model has perplexity 1; a uniform model over a vocabulary of size $V$ has perplexity $V$. Reasonable LSTM language models on small corpora reach perplexity 2–10 on character level and 50–200 on word level.
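The two reference points can be checked numerically. In this sketch, with random targets and a vocabulary of 28, all-zero logits play the uniform model and a large logit margin on the correct class stands in for a perfect model; both choices are illustrative assumptions.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, N = 28, 100
targets = torch.randint(0, V, (N,))

# Uniform model: identical logits for every class, so mean NLL = log V
uniform_logits = torch.zeros(N, V)
uniform_ppl = math.exp(F.cross_entropy(uniform_logits, targets).item())

# Near-perfect model: very large logit on the correct class, very small elsewhere
confident_logits = torch.full((N, V), -20.0)
confident_logits[torch.arange(N), targets] = 20.0
confident_ppl = math.exp(F.cross_entropy(confident_logits, targets).item())
```

uniform_ppl lands at 28 up to float rounding and confident_ppl at 1, matching the bounds stated above.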
Generation: warmup, then sample
@torch.no_grad()
def generate(model, prefix, num_chars, char_to_idx, idx_to_char,
             temperature=1.0, device='cpu'):
    # prefix must contain at least one character
    model.eval()
    state = model.init_state(1, device)
    # Warmup: feed all but the last prefix character to populate hidden state
    input_ids = [char_to_idx[c] for c in prefix]
    output = list(prefix)
    if len(input_ids) > 1:  # an LSTM cannot run on a zero-length sequence
        x = torch.tensor([input_ids[:-1]], device=device)
        _, state = model(x, state)
    # Sample one token at a time, starting from the last prefix character
    last_id = input_ids[-1]
    for _ in range(num_chars):
        x = torch.tensor([[last_id]], device=device)
        logits, state = model(x, state)
        logits = logits[0, -1] / temperature
        probs = F.softmax(logits, dim=-1)
        last_id = torch.multinomial(probs, num_samples=1).item()
        output.append(idx_to_char[last_id])
    return ''.join(output)
Two design choices in this loop matter. The warmup phase feeds the prompt through the model without sampling, populating the hidden state with proper conditioning; the last prompt character is held back because it becomes the first input of the sampling loop. Without warmup, generation would start from a context-free zero state.
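Feeding the prompt in a single call is only a shortcut: it should leave the LSTM in the same state that stepping through the prompt one character at a time would. The sketch below checks that with a freshly initialized embedding and two-layer LSTM (sizes arbitrary) and a hypothetical encoded prompt.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(28, 8)
lstm = nn.LSTM(8, 16, num_layers=2, batch_first=True)
prompt_ids = torch.tensor([[3, 7, 1, 0, 12]])  # hypothetical encoded prompt, (1, T)

with torch.no_grad():
    # Warmup in one call over the whole prompt
    _, (h_full, c_full) = lstm(embedding(prompt_ids))

    # Same prompt, one token per call, threading the state through
    state = None
    for t in range(prompt_ids.size(1)):
        _, state = lstm(embedding(prompt_ids[:, t:t + 1]), state)
    h_step, c_step = state
```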
The temperature parameter rescales logits before softmax. Lower temperatures (e.g., 0.5) make the distribution peakier — the model picks high-probability tokens more often, generating more conservative text. Higher temperatures (e.g., 1.5) flatten the distribution, encouraging diverse but riskier samples. temperature=1.0 uses the model's raw probabilities.
Sampling with torch.multinomial introduces randomness; using argmax would always pick the most likely next token, which often produces repetitive text.
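A small numeric sketch shows both effects on a single hypothetical set of next-token logits: dividing by the temperature before softmax, as generate() does, and greedy argmax versus multinomial sampling.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.1])  # hypothetical next-token logits

# Divide by temperature before softmax, as in generate()
probs = {t: F.softmax(logits / t, dim=-1) for t in (0.5, 1.0, 1.5)}

# Greedy decoding always returns the same index
greedy_id = torch.argmax(logits).item()

# Sampling draws indices in proportion to the tempered probabilities
samples = torch.multinomial(probs[1.0], num_samples=1000, replacement=True)
```

For these logits the top probability is about 0.86 at temperature 0.5, 0.66 at 1.0, and 0.56 at 1.5, while greedy decoding picks index 0 every time.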
Device, dtype, and performance details
Three details matter for practical training:
- Move state and inputs to the same device. Hidden state initialization must use device=device; otherwise the first forward pass crashes with a device mismatch.
- Use nn.LSTM(..., batch_first=True). The default time-first layout is faster on some hardware but error-prone in user code. Most projects benefit from setting batch_first=True consistently.
- Pin memory and prefetch in the DataLoader. RNN training is often I/O-bound at small model sizes. Setting pin_memory=True and num_workers=2 in the DataLoader keeps the GPU fed.
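A minimal sketch of the device and DataLoader points, assuming the (x, y) pairs have already been materialized as index tensors; make_loader is a hypothetical helper and the data is random. pin_memory is enabled only when CUDA is present, since page-locked host memory helps only copies to a CUDA device, and non_blocking=True lets such copies run asynchronously.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def make_loader(x, y, batch_size=32, num_workers=2, shuffle=True):
    """Wrap pre-built (x, y) index tensors in a DataLoader.

    shuffle=True suits random sampling only; sequential partitioning needs
    shuffle=False and a sample order that keeps consecutive batches contiguous.
    """
    return DataLoader(
        TensorDataset(x, y),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Page-locked host memory only speeds up copies to a CUDA device
        pin_memory=torch.cuda.is_available(),
    )


# Hypothetical data: 100 samples of 35 token indices each
x = torch.randint(0, 28, (100, 35))
y = torch.randint(0, 28, (100, 35))

# num_workers=0 keeps this demo in the main process
loader = make_loader(x, y, num_workers=0)
for xb, yb in loader:
    # Inputs and targets go to the same device as the model and its state
    xb = xb.to(device, non_blocking=True)
    yb = yb.to(device, non_blocking=True)
batches = list(loader)
```

In a real training script, num_workers=2 as suggested above starts background worker processes, which on some platforms requires the usual if __name__ == '__main__': guard.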
When to stop training
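Track validation perplexity and use early stopping with patience around 5–10 epochs. Sequence models can plateau and then resume improving after a learning rate drop, so combining early stopping with a learning rate scheduler (e.g., ReduceLROnPlateau) often works better than either alone. Give the scheduler a shorter patience than early stopping, as the loop below does (2 versus 10 epochs), so the learning rate has a chance to drop before training stops.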
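How the scheduler reacts to a plateau is easiest to see on a toy run. This sketch attaches the same ReduceLROnPlateau settings as the loop that follows to a throwaway SGD optimizer and feeds it a hypothetical sequence of validation perplexities that stops improving.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)

lrs = []
for val_ppl in [10.0, 9.0, 9.0, 9.0, 9.0]:  # hypothetical validation perplexities
    scheduler.step(val_ppl)
    lrs.append(optimizer.param_groups[0]['lr'])
```

The learning rate halves on the third consecutive epoch without improvement, when the count of bad epochs exceeds patience=2, long before an early-stopping counter with patience 10 would end training.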
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)
best_ppl = float('inf')
patience_counter = 0
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_iter, optimizer, device)
    val_ppl = evaluate_perplexity(model, val_iter, device)
    scheduler.step(val_ppl)
    if val_ppl < best_ppl:
        best_ppl = val_ppl
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= 10:
            break
The main takeaway
The PyTorch sequence-modeling pattern is short but every detail matters. Use nn.Embedding for token inputs, nn.LSTM or nn.GRU with batch_first=True for the recurrence, and a final linear layer for vocabulary-sized logits. In the training loop, detach state across batches for truncated BPTT, clip gradients to handle exploding gradients, and average loss per token for clean perplexity computation.
For evaluation, perplexity gives a directly interpretable number. For generation, warm up the hidden state with a prompt before sampling, and tune temperature to balance fluency against diversity. With these patterns in place, scaling from a toy character model to a multi-layer word-level model is mostly a matter of swapping the dataset, growing the embedding and hidden dimensions, and waiting longer.