Tin Rabzelj
Tin Rabzelj
Dashed Line

Sequence to Sequence Learning with Neural Networks | Paper Notes

8/6/2025

The seq2seq model is a way to map one sequence to another, for example translating a sentence into another language or converting speech to text. The core idea is: compress the input into a fixed-size representation, then expand it back out into the target sequence.

The architecture:

Encoder: takes an input sequence (like an English sentence) and reads it token by token using an LSTM network. LSTMs are particularly good at this because they can remember important information from earlier in the sequence, even when processing long sentences. The encoder's job is to compress the whole sequence into a single fixed-size vector (in practice, the LSTM's final hidden and cell states).

Decoder: takes that compressed representation and expands it back out into the target sequence. It's basically a language model that generates text, but it's conditioned on the input sequence's meaning. The decoder predicts one token at a time, using both the encoded input and the tokens it has already generated.

Deep networks work better. Going deeper significantly improved performance. They used 4-layer LSTMs with 1000 cells each, totaling 384M parameters. Each additional layer reduced perplexity by about 10%.

Reversing the order of words in the source sentence (while keeping the target in normal order) dramatically improved performance. If you're translating "a b c" to "α β γ", you instead train on "c b a" to "α β γ". The authors don't have a conclusive explanation for why this works. Their hypothesis is that it introduces many short-term dependencies between source and target words. After reversal, "a" sits right next to "α", which it should map to, so the first target words are close to their source counterparts, making it easier for the model to learn the mapping.

The model outperformed traditional phrase-based translation systems and could handle very long sentences, something that was previously challenging for neural approaches.

Implementation

model.py
import torch
import torch.nn as nn
import random


class Encoder(nn.Module):
    """
    LSTM encoder that summarises a source sequence in its final LSTM state.

    Only the final ``(hidden, cell)`` pair is returned; the per-step outputs
    are discarded, so the decoder is conditioned on a fixed-size summary.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        # Attribute names double as state_dict keys, so checkpoints depend on them.
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, source):
        """
        Encode a batch of token-id sequences.

        Args:
            source: token ids of shape [source_len, batch_size]
                (sequence-first; the LSTM is built without ``batch_first``).

        Returns:
            ``(hidden, cell)``, each [n_layers, batch_size, hidden_dim]
            (the LSTM is unidirectional).
        """
        token_vectors = self.dropout(self.embedding(source))
        _per_step_outputs, final_state = self.rnn(token_vectors)
        hidden, cell = final_state
        return hidden, cell


class Decoder(nn.Module):
    """
    Single-step LSTM decoder.

    Each call consumes one token per batch element plus the previous LSTM
    state and returns vocabulary logits for the next token together with the
    updated state.
    """

    def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        # Target vocabulary size; read by Seq2SeqModel to size its output buffer.
        self.output_dim = output_dim
        # Attribute names double as state_dict keys, so checkpoints depend on them.
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        """
        Advance the decoder by one time step.

        Args:
            input: current token ids, shape [batch_size].
            hidden, cell: previous LSTM state, each [n_layers, batch_size, hidden_dim].

        Returns:
            ``(logits, hidden, cell)`` where logits is [batch_size, output_dim].
        """
        # Give the tokens a length-1 sequence axis: [1, batch_size].
        one_step = input.unsqueeze(0)
        step_vectors = self.dropout(self.embedding(one_step))

        step_output, (next_hidden, next_cell) = self.rnn(step_vectors, (hidden, cell))

        # Drop the length-1 sequence axis before projecting onto the vocabulary.
        logits = self.fc_out(step_output.squeeze(0))
        return logits, next_hidden, next_cell


class Seq2SeqModel(nn.Module):
    """
    Encoder-decoder wrapper that decodes the target one step at a time.

    The encoder's final ``(hidden, cell)`` state initialises the decoder. Each
    step's logits are stored in ``outputs[t]``. The next decoder input is either
    the ground-truth token or the model's own argmax prediction (teacher
    forcing).
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """
        Encode ``source`` and decode ``target_len - 1`` steps.

        Args:
            source: source token ids, shape [source_len, batch_size].
            target: target token ids, shape [target_len, batch_size]. Row 0 is
                the first decoder input (the <sos> row in this project).
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token instead of the model's own prediction.

        Returns:
            Tensor of shape [target_len, batch_size, decoder.output_dim]. Row 0
            stays all zeros because no prediction is made for position 0.
        """
        target_len, batch_size = target.shape[0], target.shape[1]
        target_vocab_size = self.decoder.output_dim

        # Allocate directly on the target device rather than building the
        # buffer on the CPU and copying it over on every forward pass.
        outputs = torch.zeros(
            target_len, batch_size, target_vocab_size, device=self.device
        )

        hidden, cell = self.encoder(source)

        # First input to the decoder is the <SOS> token.
        decoder_input = target[0, :]

        for t in range(1, target_len):
            output, hidden, cell = self.decoder(decoder_input, hidden, cell)
            outputs[t] = output

            # The coin is flipped per time step, so one sequence can mix forced
            # and free-running steps.
            teacher_force = random.random() < teacher_forcing_ratio
            decoder_input = target[t] if teacher_force else output.argmax(1)

        return outputs
data.py
import torch
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import kagglehub
from kagglehub import KaggleDatasetAdapter

# Special tokens shared by the vocabularies, the dataset and the collate
# function. load_and_prepare_data passes them to Vocabulary as `specials`,
# so they receive the lowest indices.
UNK_TOKEN = "<unk>"  # Unknown (out-of-vocabulary) word
PAD_TOKEN = "<pad>"  # Padding; ignored by the training loss
SOS_TOKEN = "<sos>"  # Start of Sequence
EOS_TOKEN = "<eos>"  # End of Sequence


class Vocabulary:
    """
    Bidirectional mapping between tokens and integer indices.

    The special tokens get the first indices, in the order given. Every other
    token in ``counter`` whose count is at least ``min_freq`` follows, in the
    counter's iteration order. Tokenisation is ``str.lower().split()``, so
    punctuation stays attached to words.
    """

    def __init__(self, counter, specials, min_freq=1):
        self.word2idx = {token: i for i, token in enumerate(specials)}
        self.idx2word = {i: token for i, token in enumerate(specials)}

        for word, count in counter.items():
            if count >= min_freq and word not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word

    def __len__(self):
        return len(self.word2idx)

    def string_to_indices(self, s):
        """
        Lower-case and whitespace-split ``s``, then map tokens to indices.

        Tokens missing from the vocabulary map to the ``<unk>`` index. Raises
        KeyError if ``<unk>`` is not one of the specials.
        """
        unk_idx = self.word2idx[UNK_TOKEN]
        return [self.word2idx.get(token, unk_idx) for token in s.lower().split()]

    def indices_to_string(self, indices):
        """
        Join the tokens for ``indices`` with spaces; unknown indices become ``<unk>``.

        Each index is normalised with ``operator.index``. This lets callers pass
        numpy integers or tensor elements, for example by iterating
        ``output.argmax(-1)``. Tensors hash by identity, so without the
        normalisation every tensor element would miss the dict and decode
        as ``<unk>``.
        """
        from operator import index

        return " ".join(self.idx2word.get(index(i), UNK_TOKEN) for i in indices)


class Seq2SeqDataset(Dataset):
    """
    Map-style dataset of (source, target) sentence pairs as index tensors.

    Each item is a pair of 1-D tensors. The source sentence is optionally
    reversed, as in the paper. Both sides are then wrapped in their
    vocabulary's ``<sos>`` ... ``<eos>`` indices. Reversal happens before
    wrapping, so ``<sos>`` stays first.
    """

    def __init__(
        self,
        source_sentences,
        target_sentences,
        source_vocab,
        target_vocab,
        reverse_source=True,
    ):
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.reverse_source = reverse_source

    def __len__(self):
        return len(self.source_sentences)

    def __getitem__(self, idx):
        src_ids = self.source_vocab.string_to_indices(self.source_sentences[idx])
        if self.reverse_source:
            src_ids = src_ids[::-1]
        tgt_ids = self.target_vocab.string_to_indices(self.target_sentences[idx])

        def framed(vocab, ids):
            # Bracket the ids with this vocabulary's own <sos>/<eos> indices.
            return [vocab.word2idx[SOS_TOKEN], *ids, vocab.word2idx[EOS_TOKEN]]

        return (
            torch.tensor(framed(self.source_vocab, src_ids)),
            torch.tensor(framed(self.target_vocab, tgt_ids)),
        )


def collate_fn(batch, *, source_pad_idx=None, target_pad_idx=None):
    """
    Right-pad the (source, target) pairs of a batch to common lengths.

    Args:
        batch: sequence of ``(source_ids, target_ids)`` pairs of 1-D integer
            tensors, as produced by ``Seq2SeqDataset.__getitem__``.
        source_pad_idx: padding index for the source side.
        target_pad_idx: padding index for the target side.

    When a pad index is omitted, it is read from the module-level
    ``source_vocab`` / ``target_vocab`` globals that ``load_and_prepare_data``
    sets, so plain ``collate_fn(batch)`` keeps working. Passing the indices
    explicitly (e.g. via ``functools.partial``) removes that hidden dependency.
    This matters when DataLoader workers are separate processes that never
    ran ``load_and_prepare_data``.

    Returns:
        ``(padded_sources, padded_targets)``. Each is a batch-first
        ``torch.long`` tensor of shape [batch_size, longest_sequence_on_that_side].
    """
    from torch.nn.utils.rnn import pad_sequence

    sources, targets = zip(*batch)

    if source_pad_idx is None:
        source_pad_idx = source_vocab.word2idx[PAD_TOKEN]
    if target_pad_idx is None:
        target_pad_idx = target_vocab.word2idx[PAD_TOKEN]

    padded_sources = pad_sequence(
        list(sources), batch_first=True, padding_value=source_pad_idx
    ).long()
    padded_targets = pad_sequence(
        list(targets), batch_first=True, padding_value=target_pad_idx
    ).long()

    return padded_sources, padded_targets


def load_and_prepare_data(dataset_size=1_000, batch_size=32):
    """
    Load the English→French pairs, build vocabularies and a shuffling DataLoader.

    Args:
        dataset_size: number of leading sentence pairs to use. ``None`` or any
            negative value means "use every row".
        batch_size: DataLoader batch size.

    Returns:
        ``(loader, source_vocab, target_vocab)``.

    Side effects:
        Sets the module globals ``source_vocab`` / ``target_vocab``, which
        ``collate_fn`` reads when it is not given explicit pad indices. Prints
        both vocabulary sizes.
    """
    ds = kagglehub.load_dataset(
        KaggleDatasetAdapter.HUGGING_FACE,
        "rajpulapakura/english-to-french-small-dataset",
        "english_french.csv",
    )

    # Slicing with a negative size would silently drop rows ([:-1] loses the
    # last pair), so normalise "negative" to "no limit".
    limit = None if dataset_size is None or dataset_size < 0 else dataset_size
    source_sentences = list(ds["English"][:limit])
    target_sentences = list(ds["French"][:limit])

    # Same lower-case/whitespace tokenisation as Vocabulary.string_to_indices.
    source_counter = Counter(" ".join(source_sentences).lower().split())
    target_counter = Counter(" ".join(target_sentences).lower().split())

    specials = [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]

    global source_vocab, target_vocab
    source_vocab = Vocabulary(source_counter, specials)
    target_vocab = Vocabulary(target_counter, specials)

    print(f"Source vocabulary size: {len(source_vocab)}")
    print(f"Target vocabulary size: {len(target_vocab)}")

    dataset = Seq2SeqDataset(
        source_sentences, target_sentences, source_vocab, target_vocab
    )
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )

    return loader, source_vocab, target_vocab
main.py
import time
import math
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader

from data import Vocabulary, load_and_prepare_data, SOS_TOKEN, EOS_TOKEN, PAD_TOKEN
from model import Encoder, Decoder, Seq2SeqModel

N_EPOCHS = 100  # Training epochs run by train()
DATASET_SIZE = 500  # Sentence pairs to load; the old "-1" note presumably meant "all rows" — confirm load_and_prepare_data handles negatives
GRADIENT_CLIP = 1.0  # Max global gradient norm passed to clip_grad_norm_
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The weights with the lowest training loss are saved to MODEL_FILE during
# training and reloaded before inference.
CHECKPOINT_DIR = "./checkpoint"
MODEL_FILE = os.path.join(CHECKPOINT_DIR, "model.pth")


def train(
    model: Seq2SeqModel,
    train_iterator: DataLoader,
    source_vocab: Vocabulary,
    target_vocab: Vocabulary,
):
    """
    Train ``model`` for ``N_EPOCHS`` epochs and checkpoint the best weights.

    Every parameter is first re-initialised from U(-0.08, 0.08). The model is
    then optimised with Adam (default settings) against a cross-entropy loss
    that ignores target padding. After each epoch whose mean training loss is
    the lowest so far, ``model.state_dict()`` is written to ``MODEL_FILE``.

    Args:
        model: the seq2seq model; trained in place.
        train_iterator: DataLoader yielding batch-first ``(source, target)`` pairs.
        source_vocab: unused; kept for interface compatibility.
        target_vocab: provides the padding index excluded from the loss.
    """
    # Initialise each parameter exactly once. Combining model.apply(fn) with a
    # recursive named_parameters() loop inside fn would re-draw every parameter
    # once per enclosing module.
    for param in model.parameters():
        nn.init.uniform_(param, -0.08, 0.08)

    optimizer = optim.Adam(model.parameters())
    target_pad_idx = target_vocab.word2idx[PAD_TOKEN]
    criterion = nn.CrossEntropyLoss(ignore_index=target_pad_idx)

    print("Training...")
    best_train_loss = float("inf")

    for epoch in range(N_EPOCHS):
        start_time = time.time()
        train_loss = train_epoch(model, train_iterator, optimizer, criterion)
        elapsed = time.time() - start_time
        epoch_mins, epoch_secs = divmod(int(elapsed), 60)

        # "Best" is judged on training loss; there is no validation split.
        if train_loss < best_train_loss:
            best_train_loss = train_loss
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            torch.save(model.state_dict(), MODEL_FILE)
            print(f"|---> Best model saved to {MODEL_FILE}")

        print(f"Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s")
        print(f"\tTrain Loss: {train_loss:.3f} | PPL: {math.exp(train_loss):7.3f}")

    print("\nTraining finished.")


def train_epoch(
    model: Seq2SeqModel,
    iterator: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
):
    """
    Run one optimisation pass over ``iterator``; return the mean per-batch loss.

    The model is called with its default teacher-forcing ratio. Position 0 of
    each target (the decoder's first input) is excluded from the loss.
    Gradients are clipped to ``GRADIENT_CLIP`` before every optimizer step.
    """
    model.train()
    running_loss = 0

    for source, target in iterator:
        # Batches arrive batch-first; the model expects [seq_len, batch].
        source = source.to(DEVICE).permute(1, 0)
        target = target.to(DEVICE).permute(1, 0)

        optimizer.zero_grad()

        # logits: [target_len, batch_size, vocab_size]
        logits = model(source, target)
        vocab_size = logits.shape[-1]

        # Flatten time and batch axes so CrossEntropyLoss sees (N, C) vs (N,).
        flat_logits = logits[1:].reshape(-1, vocab_size)
        flat_target = target[1:].reshape(-1)

        loss = criterion(flat_logits, flat_target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(iterator)


def translate_sentence(
    sentence: str,
    model: Seq2SeqModel,
    source_vocab: Vocabulary,
    target_vocab: Vocabulary,
    device: torch.device,
    max_len=50,
):
    """
    Greedily decode a translation of ``sentence``.

    The source is preprocessed exactly like a training example from
    ``Seq2SeqDataset`` with ``reverse_source=True``:
    ``source_vocab.string_to_indices``, then reversal, then ``<sos>``/``<eos>``
    wrapping. Decoding stops at ``<eos>`` or after ``max_len`` steps.

    Returns:
        The predicted target tokens joined by spaces. ``<eos>`` is never
        included; if ``max_len`` is reached first, every predicted token is
        kept.
    """
    model.eval()

    source_indexes = source_vocab.string_to_indices(sentence)
    # The source sentence must be reversed, which matches the training process.
    source_indexes.reverse()
    source_indexes = (
        [source_vocab.word2idx[SOS_TOKEN]]
        + source_indexes
        + [source_vocab.word2idx[EOS_TOKEN]]
    )
    # [source_len, 1]: a batch holding a single sentence.
    source_tensor = torch.tensor(
        source_indexes, dtype=torch.long, device=device
    ).unsqueeze(1)

    sos_idx = target_vocab.word2idx[SOS_TOKEN]
    eos_idx = target_vocab.word2idx[EOS_TOKEN]
    predicted = []

    with torch.no_grad():
        hidden, cell = model.encoder(source_tensor)

        # The initial input to the decoder is the <sos> token.
        decoder_input = torch.tensor([sos_idx], dtype=torch.long, device=device)

        for _ in range(max_len):
            output, hidden, cell = model.decoder(decoder_input, hidden, cell)
            pred_token = output.argmax(1).item()
            if pred_token == eos_idx:
                break
            predicted.append(pred_token)
            decoder_input = torch.tensor([pred_token], dtype=torch.long, device=device)

    # Only <eos> is stripped; slicing off the last token unconditionally would
    # drop a real word whenever max_len is hit.
    return " ".join(target_vocab.idx2word[i] for i in predicted)


def main():
    """Load data, build and train the model, then translate one sample sentence."""
    train_iterator, source_vocab, target_vocab = load_and_prepare_data(
        dataset_size=DATASET_SIZE
    )

    # Model hyper-parameters shared by the encoder and decoder.
    embedding_dim = 64
    hidden_dim = 512
    n_layers = 2
    dropout = 0.2

    encoder = Encoder(
        input_dim=len(source_vocab),
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        n_layers=n_layers,
        dropout=dropout,
    )
    decoder = Decoder(
        output_dim=len(target_vocab),
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        n_layers=n_layers,
        dropout=dropout,
    )
    model = Seq2SeqModel(encoder, decoder, DEVICE).to(DEVICE)

    train(model, train_iterator, source_vocab, target_vocab)

    # Restore the checkpoint with the lowest training loss, not the last epoch.
    model.load_state_dict(torch.load(MODEL_FILE))

    source_sentence = "Hello!"
    translation = translate_sentence(
        source_sentence, model, source_vocab, target_vocab, DEVICE
    )

    print("\n--- INFERENCE ---")
    print(f"Source: {source_sentence}")
    print(f"Predicted Translation: {translation}")


if __name__ == "__main__":
    main()

8/6/2025

Read more