Tin Rabzelj

Neural Machine Translation by Jointly Learning to Align and Translate | Paper Notes

8/10/2025

This paper introduces the attention mechanism for neural machine translation. It proposes the RNNsearch model as an extension to the basic encoder-decoder architecture.

The goal is to replace the fixed-length context vector with a dynamic attention mechanism. Instead of compressing the entire source sentence into a single vector, the model uses "soft alignment" to selectively focus on different parts of the source sentence when generating each target word.

Architecture

Encoder: a bidirectional RNN (BiRNN).

  • Forward RNN: reads the input from $x_1$ to $x_{T_x}$, producing forward hidden states $\overrightarrow{h}_1, \dots, \overrightarrow{h}_{T_x}$.
  • Backward RNN: reads the input from $x_{T_x}$ to $x_1$, producing backward hidden states $\overleftarrow{h}_1, \dots, \overleftarrow{h}_{T_x}$.
  • Annotations: each word $x_j$ gets an annotation $h_j = \left[\overrightarrow{h}_j^\top; \overleftarrow{h}_j^\top\right]^\top$ by concatenating its forward and backward states.

This bidirectional approach ensures each annotation contains information about both preceding and following words.
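
As a minimal sketch of how these annotations arise in PyTorch (the dimensions here are illustrative, not the paper's):

import torch
import torch.nn as nn

# Illustrative sizes only; the paper uses much larger dimensions.
emb_dim, enc_hid_dim, src_len, batch_size = 8, 16, 5, 2

birnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
embedded = torch.randn(src_len, batch_size, emb_dim)

# The output at each position already concatenates the forward and
# backward hidden states, i.e. the annotations h_j.
annotations, _ = birnn(embedded)
print(annotations.shape)  # torch.Size([5, 2, 32]) = [src_len, batch_size, enc_hid_dim * 2]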

Decoder: generates the translation using attention over the encoder annotations.

The decoder generates each target word $y_i$ using:

$$p(y_i \mid y_1, \dots, y_{i-1}, x) = g(y_{i-1}, s_i, c_i),$$

where:

  • $s_i$ is the decoder's hidden state at step $i$.
  • $c_i$ is a dynamic context vector computed for each target word.

The context vector $c_i$ is computed as a weighted sum of all source annotations:

$$c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j.$$

The attention weights $\alpha_{ij}$ represent how much attention to pay to source word $j$ when generating target word $i$:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})},$$

where $e_{ij} = a(s_{i-1}, h_j)$ is an alignment model (implemented as a feedforward neural network) that scores how well the input around position $j$ matches the output at position $i$.
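
A minimal sketch of one decoder step of this computation, using the single-layer feedforward parametrization of the alignment model and toy dimensions (not the paper's):

import torch
import torch.nn as nn
import torch.nn.functional as F

src_len, batch, enc_dim2, dec_dim = 5, 2, 32, 16  # illustrative sizes

h = torch.randn(src_len, batch, enc_dim2)   # annotations h_j
s_prev = torch.randn(batch, dec_dim)        # previous decoder state s_{i-1}

# Alignment model a(s_{i-1}, h_j): a small feedforward network.
W = nn.Linear(dec_dim, dec_dim, bias=False)
U = nn.Linear(enc_dim2, dec_dim, bias=False)
v = nn.Linear(dec_dim, 1, bias=False)

e = v(torch.tanh(W(s_prev).unsqueeze(0) + U(h))).squeeze(2)  # e_ij: [src_len, batch]
alpha = F.softmax(e, dim=0)                                  # attention weights alpha_ij
c = torch.einsum("sb,sbd->bd", alpha, h)                     # context vector c_i: [batch, enc_dim2]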

Key architectural advantages

Variable-length encoding: this alleviates the bottleneck of compressing all information into a single vector, which benefits longer sentences.

Selective attention: allows the model to dynamically focus on the most relevant parts of the input when generating each target word, rather than processing the entire sentence uniformly.

Soft alignment: differentiable alignment allows end-to-end training.

Because attention and soft alignment are differentiable, the entire translation system, including the alignment mechanism, can be trained jointly end-to-end, which yields more robust and accurate translations, especially on longer inputs where the fixed-vector approach struggles.

This architecture was groundbreaking because it introduced the attention mechanism that became fundamental to modern NLP.

Implementation

model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(
        self,
        input_dim: int,
        emb_dim: int,
        enc_hid_dim: int,
        dec_hid_dim: int,
        n_layers: int = 1,
        dropout: float = 0.0,
    ):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(
            emb_dim, enc_hid_dim, n_layers, bidirectional=True, dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        # FC layer must map the concatenated encoder states to the decoder's hidden dimension
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)

    def forward(self, src: torch.Tensor):
        # src: [src_len, batch_size]
        embedded = self.dropout(self.embedding(src))
        # embedded: [src_len, batch_size, emb_dim]
        outputs, hidden = self.rnn(embedded)
        # outputs: [src_len, batch_size, enc_hid_dim * 2] (bidirectional)
        # hidden: [n_layers * 2, batch_size, enc_hid_dim]

        # Concatenate the final forward and backward hidden states
        # and pass it through the FC layer to create the initial decoder hidden state.
        s = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = torch.tanh(self.fc(s)).unsqueeze(0)
        # hidden: [1, batch_size, dec_hid_dim]

        return outputs, hidden


class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(Attention, self).__init__()
        # The encoder output dimension is enc_hid_dim * 2 due to bidirectionality.
        self.W = nn.Linear(dec_hid_dim, dec_hid_dim, bias=False)
        self.U = nn.Linear(enc_hid_dim * 2, dec_hid_dim, bias=False)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden: [1, batch_size, dec_hid_dim]
        # encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]
        src_len = encoder_outputs.shape[0]

        # Repeat the decoder hidden state for each encoder time step.
        hidden_repeated = hidden.repeat(src_len, 1, 1)
        # hidden_repeated: [src_len, batch_size, dec_hid_dim]

        # Calculate alignment scores.
        energy = torch.tanh(self.W(hidden_repeated) + self.U(encoder_outputs))
        # energy: [src_len, batch_size, dec_hid_dim]

        attention = self.v(energy).squeeze(2)
        # attention: [src_len, batch_size]

        return F.softmax(attention, dim=0)


class Decoder(nn.Module):
    def __init__(
        self,
        output_dim: int,
        emb_dim: int,
        enc_hid_dim: int,
        dec_hid_dim: int,
        n_layers: int = 1,
        dropout: float = 0.0,
    ):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.attention = Attention(enc_hid_dim, dec_hid_dim)
        # The context vector will have dimension enc_hid_dim * 2.
        self.rnn = nn.GRU(
            emb_dim + (enc_hid_dim * 2), dec_hid_dim, n_layers, dropout=dropout
        )
        self.fc_out = nn.Linear(emb_dim + dec_hid_dim + (enc_hid_dim * 2), output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, input: torch.Tensor, hidden: torch.Tensor, encoder_outputs: torch.Tensor
    ):
        # input: [batch_size]
        # hidden: [1, batch_size, dec_hid_dim]
        # encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]

        input = input.unsqueeze(0)
        # input: [1, batch_size]
        embedded = self.dropout(self.embedding(input))
        # embedded: [1, batch_size, emb_dim]

        # Get attention weights
        attn_weights = self.attention(hidden, encoder_outputs)
        # attn_weights: [src_len, batch_size]

        # Compute context vector (weighted sum of encoder outputs)
        # encoder_outputs is [src_len, batch_size, enc_hid_dim * 2]
        context = torch.einsum("sb,sbd->bd", attn_weights, encoder_outputs)
        context = context.unsqueeze(0)
        # context: [1, batch_size, enc_hid_dim * 2]

        rnn_input = torch.cat((embedded, context), dim=2)
        # rnn_input: [1, batch_size, emb_dim + enc_hid_dim * 2]

        output, hidden = self.rnn(rnn_input, hidden)
        # output: [1, batch_size, dec_hid_dim]
        # hidden: [n_layers, batch_size, dec_hid_dim]

        # Predict next token.
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        context = context.squeeze(0)
        prediction = self.fc_out(torch.cat((embedded, output, context), dim=1))
        # prediction: [batch_size, output_dim]

        return prediction, hidden, attn_weights


class RNNSearchModel(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, device: torch.device):
        super(RNNSearchModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(
        self, src: torch.Tensor, trg: torch.Tensor, teacher_forcing_ratio: float = 0.5
    ):
        # src: [src_len, batch_size]
        # trg: [trg_len, batch_size]

        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        encoder_outputs, hidden = self.encoder(src)

        input = trg[0, :]  # <sos>
        for t in range(1, trg_len):
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            outputs[t] = output
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1

        return outputs
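
Not part of the original files, but a quick shape check with random tensors (the vocabulary sizes and lengths below are made up) can catch wiring mistakes early:

import torch
from model import Encoder, Decoder, RNNSearchModel

device = torch.device("cpu")
enc = Encoder(input_dim=100, emb_dim=32, enc_hid_dim=64, dec_hid_dim=64)
dec = Decoder(output_dim=120, emb_dim=32, enc_hid_dim=64, dec_hid_dim=64)
model = RNNSearchModel(enc, dec, device).to(device)

src = torch.randint(0, 100, (7, 4))   # [src_len, batch_size]
trg = torch.randint(0, 120, (9, 4))   # [trg_len, batch_size]
out = model(src, trg)
print(out.shape)  # torch.Size([9, 4, 120]) = [trg_len, batch_size, output_dim]
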
utility.py
from mosestokenizer import MosesTokenizer

tokenize_en = MosesTokenizer("en")
tokenize_fr = MosesTokenizer("fr")


def tokenize(text: str, lang: str) -> list[str]:
    """
    Tokenizes a text string for a given language.
    """
    if lang == "en":
        return tokenize_en(text)
    elif lang == "fr":
        return tokenize_fr(text)
    else:
        raise ValueError(f"Unsupported language: {lang}")
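
For example, tokenizing an English sentence (a hypothetical input; the exact token list depends on the Moses rules) might look like:

print(tokenize("Hello, world!", "en"))
# e.g. ['Hello', ',', 'world', '!']
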
main.py
import os

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model import Encoder, Decoder, RNNSearchModel
from data import load_and_prepare_data, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN
from utility import tokenize

CHECKPOINT_DIR = "./checkpoint"
MODEL_FILE = os.path.join(CHECKPOINT_DIR, "model.pth")


ENC_EMB_DIM = 512
DEC_EMB_DIM = 512
ENC_HID_DIM = 500
DEC_HID_DIM = 1000
ENC_N_LAYERS = 1
DEC_N_LAYERS = 1
ENC_DROPOUT = 0.0
DEC_DROPOUT = 0.0

N_EPOCHS = 10
CLIP = 1.0  # Gradient clipping value
TEACHER_FORCING_RATIO = 0.5


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_model(model, train_iterator, trg_pad_idx):
    # Initialize weights.
    def init_weights(m):
        for name, param in m.named_parameters():
            if "weight" in name:
                nn.init.normal_(param.data, mean=0, std=0.01)
            else:
                nn.init.constant_(param.data, 0)

    model.apply(init_weights)

    # Define optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)

    best_loss = float("inf")

    # Training loop
    for epoch in range(N_EPOCHS):
        train_loss = train_epoch(
            model,
            train_iterator,
            optimizer,
            criterion,
            CLIP,
            TEACHER_FORCING_RATIO,
            device,
        )

        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(model.state_dict(), MODEL_FILE)

        print(f"Epoch: {epoch + 1:02}")
        print(f"\tTrain Loss: {train_loss:.3f}")


def train_epoch(
    model, iterator, optimizer, criterion, clip, teacher_forcing_ratio, device
):
    model.train()
    epoch_loss = 0

    for i, batch in enumerate(tqdm(iterator)):
        src, trg = batch
        src = src.to(device).t()  # Transpose for the model
        trg = trg.to(device).t()  # Transpose for the model

        optimizer.zero_grad()

        output = model(src, trg, teacher_forcing_ratio)
        # trg = [trg_len, batch_size]
        # output = [trg_len, batch_size, output_dim]

        output_dim = output.shape[-1]
        output = output[1:].reshape(-1, output_dim)
        trg = trg[1:].reshape(-1)
        # output = [(trg_len - 1) * batch_size, output_dim]
        # trg = [(trg_len - 1) * batch_size]

        loss = criterion(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def translate_sentence(sentence, source_vocab, target_vocab, model, device, max_len=50):
    model.eval()

    # Use the same tokenizer as the training data.
    if isinstance(sentence, str):
        tokens = tokenize(sentence, "en")
    else:
        # If the input is already a list, treat it as pre-tokenized.
        tokens = [token for token in sentence]

    # Convert to indices, handling unknown words. The vocabulary is case-sensitive.
    src_indexes = [
        source_vocab.word2idx.get(token, source_vocab.word2idx[UNK_TOKEN])
        for token in tokens
    ]

    # Add SOS and EOS tokens.
    src_indexes = (
        [source_vocab.word2idx[SOS_TOKEN]]
        + src_indexes
        + [source_vocab.word2idx[EOS_TOKEN]]
    )

    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)

    trg_indexes = [target_vocab.word2idx[SOS_TOKEN]]
    for i in range(max_len):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
        with torch.no_grad():
            output, hidden, _ = model.decoder(trg_tensor, hidden, encoder_outputs)

        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)

        if pred_token == target_vocab.word2idx[EOS_TOKEN]:
            break

    # Return the translation, excluding the leading SOS token and a trailing EOS token if one was generated.
    trg_tokens = [target_vocab.idx2word.get(i, UNK_TOKEN) for i in trg_indexes]
    if trg_tokens and trg_tokens[-1] == EOS_TOKEN:
        trg_tokens = trg_tokens[:-1]
    return " ".join(trg_tokens[1:])


def main():
    # Load data and vocabularies.
    train_iterator, source_vocab, target_vocab = load_and_prepare_data()

    INPUT_DIM = len(source_vocab)
    OUTPUT_DIM = len(target_vocab)
    TRG_PAD_IDX = target_vocab.word2idx[PAD_TOKEN]

    # Initialize models.
    encoder = Encoder(
        INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_N_LAYERS, ENC_DROPOUT
    )
    decoder = Decoder(
        OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_N_LAYERS, DEC_DROPOUT
    )
    model = RNNSearchModel(encoder, decoder, device).to(device)

    if not os.path.exists(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR)

    train_model(model, train_iterator, TRG_PAD_IDX)

    if os.path.exists(MODEL_FILE):
        print(f"Loading best model from {MODEL_FILE}")
        model.load_state_dict(torch.load(MODEL_FILE, map_location=device))

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params:,} parameters")

    sentences = [
        "Hello!",
    ]
    for sentence in sentences:
        translation = translate_sentence(
            sentence, source_vocab, target_vocab, model, device
        )
        print(f"\nSource: {sentence}")
        print(f"Translated: {translation}")


if __name__ == "__main__":
    main()
