
PyTorch Sequence Models: A Practical Training Guide

Embeddings, batch_first, detach, gradient clipping, perplexity, and warmup-based generation.

PyTorch · LSTM · Sequence Training · Perplexity · Text Generation

Once the math of RNN, GRU, and LSTM cells is clear, the PyTorch implementation is mostly bookkeeping — input shapes, hidden-state initialization, gradient clipping, and a training loop that handles truncated BPTT correctly. This article is a practical guide to building, training, and evaluating sequence models in PyTorch, using a character-level language model as the running example.

Input shape conventions

PyTorch's nn.RNN, nn.GRU, and nn.LSTM share the same shape conventions. The default expects input shape (num_steps, batch_size, feature_dim) — time first. Setting batch_first=True swaps to (batch_size, num_steps, feature_dim), which matches conventions in most other frameworks and is generally easier to reason about.

A character-level language model with vocabulary size 28 (lowercase letters plus space and unknown), batch size 32, and 35 characters per sample produces tensors of shape (32, 35, 28) when using one-hot encoding. With an embedding layer, the embedded shape becomes (32, 35, embed_dim) for whatever embedding dimension is chosen.
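
To make the conventions concrete, here is a quick shape check using the dimensions above (vocabulary 28, embedding dimension 64, hidden size 256); it is a sketch, separate from the model built in the next section:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, num_steps, vocab_size, embed_dim = 32, 35, 28, 64
tokens = torch.randint(0, vocab_size, (batch_size, num_steps))  # integer indices, (32, 35)

one_hot = F.one_hot(tokens, num_classes=vocab_size).float()
print(one_hot.shape)                       # torch.Size([32, 35, 28])

embedding = nn.Embedding(vocab_size, embed_dim)
print(embedding(tokens).shape)             # torch.Size([32, 35, 64])

# With batch_first=True the recurrent layer accepts (B, T, E) directly
lstm = nn.LSTM(embed_dim, hidden_size=256, batch_first=True)
out, (h, c) = lstm(embedding(tokens))
print(out.shape)                           # torch.Size([32, 35, 256])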

A complete model: embeddings, LSTM, output projection

import torch
import torch.nn as nn
import torch.nn.functional as F


class CharLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=64,
                 hidden_size=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, state=None):
        # x: (batch_size, num_steps) of token indices
        embedded = self.embedding(x)             # (B, T, E)
        if state is None:
            state = self.init_state(x.size(0), x.device)
        out, state = self.lstm(embedded, state)  # out: (B, T, H)
        logits = self.output(out)                # (B, T, vocab_size)
        return logits, state

    def init_state(self, batch_size, device):
        h = torch.zeros(self.num_layers, batch_size,
                        self.hidden_size, device=device)
        c = torch.zeros(self.num_layers, batch_size,
                        self.hidden_size, device=device)
        return (h, c)

Three pieces matter here. The embedding layer converts integer token indices to dense vectors — both far cheaper and far more expressive than one-hot encoding. The LSTM produces an output of shape (B, T, H) with one hidden vector per time step. The final linear layer projects each hidden vector to vocabulary-sized logits, ready for cross-entropy.

The hidden state for an LSTM is a tuple (h, c), each of shape (num_layers, batch_size, hidden_size). For nn.GRU the state is a single tensor of that shape, and nn.RNN likewise uses a single tensor.
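
A quick check of the state shapes, again a sketch using the dimensions from this article:

import torch
import torch.nn as nn

x = torch.randn(32, 35, 64)                            # (B, T, E)

lstm = nn.LSTM(64, 256, num_layers=2, batch_first=True)
out, (h, c) = lstm(x)
print(h.shape, c.shape)                                # torch.Size([2, 32, 256]) each

gru = nn.GRU(64, 256, num_layers=2, batch_first=True)
out, h = gru(x)
print(h.shape)                                         # torch.Size([2, 32, 256]), a single tensor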

Cross-entropy with sequence-shaped logits

nn.CrossEntropyLoss expects logits of shape (N, C) and targets of shape (N,). With sequence outputs of shape (B, T, V), the right pattern is to flatten:

logits, _ = model(x)                       # (B, T, V)
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),         # (B*T, V)
    targets.reshape(-1),                     # (B*T,)
)

Training the model to predict the next character at every position means each batch yields B * T training signals, which is far more efficient than predicting only the final position. F.cross_entropy averages over all B * T predictions correctly as long as the logits and targets are flattened consistently.
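
For completeness, the per-position targets come from shifting the token sequence by one, so that y[t] is the character that follows x[t]. A minimal sketch (corpus_ids is a stand-in for the tokenized corpus, not a name used elsewhere in this article):

import torch

corpus_ids = torch.randint(0, 28, (1000,))   # stand-in for the tokenized corpus
num_steps = 35

seq = corpus_ids[:num_steps + 1]
x = seq[:-1].unsqueeze(0)    # (1, T) input characters
y = seq[1:].unsqueeze(0)     # (1, T) targets: the character that follows each input
# Stacking batch_size such slices gives x and y of shape (B, T), as used below.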

A complete training loop with detach and clipping

def train_one_epoch(model, data_iter, optimizer, device, theta=1.0):
    model.train()
    total_loss = 0.0
    total_tokens = 0
    state = None

    for x, y in data_iter:
        x, y = x.to(device), y.to(device)

        # Truncated BPTT: detach state from previous batch
        if state is not None:
            state = (state[0].detach(), state[1].detach())

        logits, state = model(x, state)
        loss = F.cross_entropy(
            logits.reshape(-1, model.vocab_size),
            y.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), theta)
        optimizer.step()

        total_loss += loss.item() * y.numel()
        total_tokens += y.numel()

    return total_loss / total_tokens

Three details are easy to get wrong:

Detaching the state. Without .detach(), each batch's gradient graph extends back through every previous batch, eventually exhausting memory. With detach, the value of the hidden state is preserved (so the forward pass still carries long-range context), but the gradient graph is cut at the batch boundary.

Gradient clipping. torch.nn.utils.clip_grad_norm_ rescales gradients in place. The threshold θ of 1.0 used here is conservative; values up to 5.0 are common.

Tracking tokens, not batches. Loss should be averaged per predicted token, not per batch, so that perplexity is computed correctly. Multiplying loss.item() by y.numel() recovers the total negative log-likelihood for the batch, which is then divided by the total token count to get the per-token average loss.
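
One small, optional extension: clip_grad_norm_ returns the total gradient norm measured before clipping, so the optimizer step inside train_one_epoch above can log it to see how often clipping actually fires. A sketch, reusing the names from that loop:

optimizer.zero_grad()
loss.backward()
# The returned value is the pre-clipping norm of all parameter gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), theta)
optimizer.step()
if grad_norm > theta:
    print(f'clipping fired: grad norm {grad_norm:.2f} > {theta}')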

Random sampling versus sequential partitioning

Long documents are sliced into training subsequences in two main regimes. Random sampling picks subsequences at random offsets and shuffles them. Each batch's hidden state must be reinitialized to zero because the subsequences are not adjacent in the original text.

Sequential partitioning picks subsequences that follow one another in the original text. This lets the model carry hidden state across batches, which improves long-range modeling but requires the detach step above.

# Random sampling — reinitialize state every batch
for x, y in data_iter:
    state = model.init_state(x.size(0), device)
    logits, _ = model(x, state)
    ...

# Sequential partitioning — carry and detach state
state = None
for x, y in data_iter:
    if state is not None:
        state = (state[0].detach(), state[1].detach())
    logits, state = model(x, state)
    ...

Sequential partitioning generally produces lower perplexity on small corpora because the model sees true cross-batch context. Random sampling is more robust when batches contain unrelated documents.
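
For reference, a minimal sequential-partitioning iterator might look like the sketch below (corpus_ids is a 1-D tensor of token indices; real implementations often add a random starting offset so epochs differ):

import torch

def seq_data_iter_sequential(corpus_ids, batch_size, num_steps):
    # Trim so the corpus splits evenly into batch_size parallel streams
    num_tokens = ((len(corpus_ids) - 1) // batch_size) * batch_size
    xs = corpus_ids[:num_tokens].reshape(batch_size, -1)
    ys = corpus_ids[1:num_tokens + 1].reshape(batch_size, -1)
    num_batches = xs.shape[1] // num_steps
    for i in range(0, num_batches * num_steps, num_steps):
        # Consecutive minibatches continue each row of the corpus,
        # so carrying hidden state across them is legitimate
        yield xs[:, i:i + num_steps], ys[:, i:i + num_steps]

# Toy usage: 10,000 tokens sliced into (32, 35) minibatches
for x, y in seq_data_iter_sequential(torch.randint(0, 28, (10_000,)), 32, 35):
    print(x.shape, y.shape)    # torch.Size([32, 35]) for both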

Evaluating with perplexity

@torch.no_grad()
def evaluate_perplexity(model, data_iter, device):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    state = None

    for x, y in data_iter:
        x, y = x.to(device), y.to(device)
        if state is not None:
            state = (state[0].detach(), state[1].detach())
        logits, state = model(x, state)
        loss = F.cross_entropy(
            logits.reshape(-1, model.vocab_size),
            y.reshape(-1),
            reduction='sum',
        )
        total_loss += loss.item()
        total_tokens += y.numel()

    avg_nll = total_loss / total_tokens
    return torch.exp(torch.tensor(avg_nll)).item()

Perplexity is exp(average per-token negative log-likelihood). A perfect model has perplexity 1; a uniform model over a vocabulary of size V has perplexity V. Reasonable LSTM language models on small corpora reach perplexity 2–10 at the character level and 50–200 at the word level.
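
As a sanity check on the formula, a uniform model over the 28-character vocabulary assigns probability 1/28 to every next character, so its average negative log-likelihood is log 28 and its perplexity equals the vocabulary size:

import math

vocab_size = 28
avg_nll = -math.log(1.0 / vocab_size)   # = log(28)
print(math.exp(avg_nll))                # ~28.0, the perplexity of the uniform model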

Generation: warmup, then sample

@torch.no_grad()
def generate(model, prefix, num_chars, char_to_idx, idx_to_char,
             temperature=1.0, device='cpu'):
    model.eval()
    state = model.init_state(1, device)

    # Warmup: feed all but the last prefix character to populate the hidden state
    input_ids = [char_to_idx[c] for c in prefix]
    output = list(prefix)
    if len(input_ids) > 1:  # a single-character prefix has nothing to warm up on
        x = torch.tensor([input_ids[:-1]], device=device)
        _, state = model(x, state)

    # Sample one token at a time
    last_id = input_ids[-1]
    for _ in range(num_chars):
        x = torch.tensor([[last_id]], device=device)
        logits, state = model(x, state)
        logits = logits[0, -1] / temperature
        probs = F.softmax(logits, dim=-1)
        last_id = torch.multinomial(probs, num_samples=1).item()
        output.append(idx_to_char[last_id])

    return ''.join(output)

Two design choices in this loop matter. The warmup phase feeds the prompt through the model without sampling, populating the hidden state with proper conditioning. Without warmup, generated tokens come from a context-free zero state.

The temperature parameter rescales logits before softmax. Lower temperatures (e.g., 0.5) make the distribution peakier — the model picks high-probability tokens more often, generating more conservative text. Higher temperatures (e.g., 1.5) flatten the distribution, encouraging diverse but riskier samples. temperature=1.0 uses the model's raw probabilities.

Sampling with torch.multinomial introduces randomness; using argmax would always pick the most likely next token, which often produces repetitive text.
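
A small illustration of the temperature effect on toy logits (not from a trained model):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])

for t in (0.5, 1.0, 1.5):
    probs = F.softmax(logits / t, dim=-1)
    print(t, [round(p, 3) for p in probs.tolist()])
# Lower temperature concentrates mass on the top token;
# higher temperature spreads it toward the alternatives.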

Device and performance details

Three details matter for practical training:

  • Move state and inputs to the same device. Hidden state initialization must use device=device, otherwise the first forward pass crashes with a device mismatch.
  • Use nn.LSTM(..., batch_first=True). The default time-first layout is faster on some hardware but error-prone in user code. Most projects benefit from setting batch_first=True consistently.
  • Pin memory and prefetch in the DataLoader. RNN training is often I/O-bound at small model sizes. Setting pin_memory=True and num_workers=2 in the DataLoader keeps the GPU fed.
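
A sketch of a DataLoader configured along these lines; the dataset here is a toy stand-in, and the worker count is a starting point rather than a rule:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for a map-style dataset of (x, y) token-index pairs
xs = torch.randint(0, 28, (1024, 35))
ys = torch.randint(0, 28, (1024, 35))
dataset = TensorDataset(xs, ys)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,       # appropriate for the random-sampling regime
    pin_memory=True,    # page-locked host memory speeds up host-to-GPU copies
    num_workers=2,      # background workers keep batches prefetched
)

for x, y in train_loader:
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)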

When to stop training

Track validation perplexity and use early stopping with patience around 5–10 epochs. Sequence models can plateau and then resume improving after a learning rate drop, so combining early stopping with a learning rate scheduler (e.g., ReduceLROnPlateau) usually outperforms either alone.

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)

best_ppl = float('inf')
patience_counter = 0
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_iter, optimizer, device)
    val_ppl = evaluate_perplexity(model, val_iter, device)
    scheduler.step(val_ppl)

    if val_ppl < best_ppl:
        best_ppl = val_ppl
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= 10:
            break

The main takeaway

The PyTorch sequence-modeling pattern is short but every detail matters. Use nn.Embedding for token inputs, nn.LSTM or nn.GRU with batch_first=True for the recurrence, and a final linear layer for vocabulary-sized logits. In the training loop, detach state across batches for truncated BPTT, clip gradients to handle exploding gradients, and average loss per token for clean perplexity computation.

For evaluation, perplexity gives a directly interpretable number. For generation, warm up the hidden state with a prompt before sampling, and tune temperature to balance fluency against diversity. With these patterns in place, scaling from a toy character model to a multi-layer word-level model is mostly a matter of swapping the dataset, growing the embedding and hidden dimensions, and waiting longer.
