Neural Machine Translation by Jointly Learning to Align and Translate | Paper Notes
8/10/2025
This paper introduces the attention mechanism for neural machine translation. It proposes the RNNsearch model as an extension to the basic encoder-decoder architecture.
The goal is to replace the fixed-length context vector with a dynamic attention mechanism. Instead of compressing the entire source sentence into a single vector, the model uses "soft alignment" to selectively focus on different parts of the source sentence when generating each target word.
Architecture
Encoder: uses a bidirectional RNN (BiRNN).
- Forward RNN: reads the input sequence from $x_1$ to $x_{T_x}$, producing forward hidden states $(\overrightarrow{h}_1, \dots, \overrightarrow{h}_{T_x})$.
- Backward RNN: reads the sequence in reverse, from $x_{T_x}$ to $x_1$, producing backward hidden states $(\overleftarrow{h}_1, \dots, \overleftarrow{h}_{T_x})$.
- Annotations: each word $x_j$ gets an annotation $h_j = \big[\overrightarrow{h}_j^{\top}; \overleftarrow{h}_j^{\top}\big]^{\top}$ by concatenating its forward and backward states.
This bidirectional approach ensures each annotation contains information about both preceding and following words.
Decoder: generates the translation, attending over the encoder annotations at each step.
The decoder generates each target word $y_i$ according to

$$p(y_i \mid y_1, \dots, y_{i-1}, \mathbf{x}) = g(y_{i-1}, s_i, c_i)$$

where:
- $s_i$ is the decoder's hidden state at step $i$.
- $c_i$ is a dynamic context vector computed for each target word.

The context vector is computed as a weighted sum of all source annotations:

$$c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j$$

The attention weight $\alpha_{ij}$ represents how much attention to pay to source word $j$ when generating target word $i$:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}, \qquad e_{ij} = a(s_{i-1}, h_j)$$

where $a$ is an alignment model (implemented as a feedforward neural network) that scores how well the input around position $j$ matches the output at position $i$.
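In the paper, the alignment model is parametrized as a single-layer MLP that is trained jointly with the rest of the network:

$$a(s_{i-1}, h_j) = v_a^{\top} \tanh\!\left(W_a s_{i-1} + U_a h_j\right)$$

where $W_a$, $U_a$, and $v_a$ are learned parameters; these correspond to the `W`, `U`, and `v` linear layers of the `Attention` module in the implementation below.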
Key architectural advantages
Variable-length encoding: the source sentence is represented as a sequence of annotations rather than a single fixed-length vector, which alleviates the compression bottleneck and especially benefits longer sentences.
Selective attention: allows the model to dynamically focus on the most relevant parts of the input when generating each target word, rather than processing the entire sentence uniformly.
Soft alignment: the alignment is differentiable, so the entire translation system, including the alignment mechanism, can be trained jointly end-to-end. This leads to more robust translations, especially for longer inputs where a fixed-length encoding struggles.
This architecture was groundbreaking because it introduced the attention mechanism that became fundamental to modern NLP.
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(
self,
input_dim: int,
emb_dim: int,
enc_hid_dim: int,
dec_hid_dim: int,
n_layers: int = 1,
dropout: float = 0.0,
):
super(Encoder, self).__init__()
self.embedding = nn.Embedding(input_dim, emb_dim)
self.rnn = nn.GRU(
emb_dim, enc_hid_dim, n_layers, bidirectional=True, dropout=dropout
)
self.dropout = nn.Dropout(dropout)
# FC layer must map the concatenated encoder states to the decoder's hidden dimension
self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
def forward(self, src: torch.Tensor):
# src: [src_len, batch_size]
embedded = self.dropout(self.embedding(src))
# embedded: [src_len, batch_size, emb_dim]
outputs, hidden = self.rnn(embedded)
# outputs: [src_len, batch_size, enc_hid_dim * 2] (bidirectional)
# hidden: [n_layers * 2, batch_size, enc_hid_dim]
# Concatenate the final forward and backward hidden states
# and pass it through the FC layer to create the initial decoder hidden state.
s = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
hidden = torch.tanh(self.fc(s)).unsqueeze(0)
# hidden: [1, batch_size, dec_hid_dim]
return outputs, hidden
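# Quick shape check for the Encoder -- a minimal sketch with made-up sizes,
# not the paper's actual hyperparameters.
_enc = Encoder(input_dim=1000, emb_dim=32, enc_hid_dim=64, dec_hid_dim=100)
_src = torch.randint(0, 1000, (7, 2))  # [src_len=7, batch_size=2]
_outputs, _hidden = _enc(_src)
assert _outputs.shape == (7, 2, 64 * 2)  # annotations: enc_hid_dim * 2 (bidirectional)
assert _hidden.shape == (1, 2, 100)      # initial decoder state: dec_hid_dim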
class Attention(nn.Module):
def __init__(self, enc_hid_dim, dec_hid_dim):
super(Attention, self).__init__()
# The encoder output dimension is enc_hid_dim * 2 due to bidirectionality.
self.W = nn.Linear(dec_hid_dim, dec_hid_dim, bias=False)
self.U = nn.Linear(enc_hid_dim * 2, dec_hid_dim, bias=False)
self.v = nn.Linear(dec_hid_dim, 1, bias=False)
def forward(self, hidden, encoder_outputs):
# hidden: [1, batch_size, dec_hid_dim]
# encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]
src_len = encoder_outputs.shape[0]
# Repeat the decoder hidden state for each encoder time step.
hidden_repeated = hidden.repeat(src_len, 1, 1)
# hidden_repeated: [src_len, batch_size, dec_hid_dim]
# Calculate alignment scores.
energy = torch.tanh(self.W(hidden_repeated) + self.U(encoder_outputs))
# energy: [src_len, batch_size, dec_hid_dim]
attention = self.v(energy).squeeze(2)
# attention: [src_len, batch_size]
return F.softmax(attention, dim=0)
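# Sanity check for Attention (sketch with made-up sizes): the weights over the
# source positions should sum to 1 for every batch element.
_attn = Attention(enc_hid_dim=64, dec_hid_dim=100)
_dec_hidden = torch.randn(1, 2, 100)  # [1, batch_size, dec_hid_dim]
_enc_outs = torch.randn(7, 2, 128)    # [src_len, batch_size, enc_hid_dim * 2]
_weights = _attn(_dec_hidden, _enc_outs)
assert _weights.shape == (7, 2)
assert torch.allclose(_weights.sum(dim=0), torch.ones(2))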
class Decoder(nn.Module):
def __init__(
self,
output_dim: int,
emb_dim: int,
enc_hid_dim: int,
dec_hid_dim: int,
n_layers: int = 1,
dropout: float = 0.0,
):
super(Decoder, self).__init__()
self.output_dim = output_dim
self.embedding = nn.Embedding(output_dim, emb_dim)
self.attention = Attention(enc_hid_dim, dec_hid_dim)
# The context vector will have dimension enc_hid_dim * 2.
self.rnn = nn.GRU(
emb_dim + (enc_hid_dim * 2), dec_hid_dim, n_layers, dropout=dropout
)
self.fc_out = nn.Linear(emb_dim + dec_hid_dim + (enc_hid_dim * 2), output_dim)
self.dropout = nn.Dropout(dropout)
def forward(
self, input: torch.Tensor, hidden: torch.Tensor, encoder_outputs: torch.Tensor
):
# input: [batch_size]
# hidden: [1, batch_size, dec_hid_dim]
# encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]
input = input.unsqueeze(0)
# input: [1, batch_size]
embedded = self.dropout(self.embedding(input))
# embedded: [1, batch_size, emb_dim]
# Get attention weights
attn_weights = self.attention(hidden, encoder_outputs)
# attn_weights: [src_len, batch_size]
# Compute context vector (weighted sum of encoder outputs)
# encoder_outputs is [src_len, batch_size, enc_hid_dim * 2]
context = torch.einsum("sb,sbd->bd", attn_weights, encoder_outputs)
context = context.unsqueeze(0)
# context: [1, batch_size, enc_hid_dim * 2]
rnn_input = torch.cat((embedded, context), dim=2)
# rnn_input: [1, batch_size, emb_dim + enc_hid_dim * 2]
output, hidden = self.rnn(rnn_input, hidden)
# output: [1, batch_size, dec_hid_dim]
# hidden: [n_layers, batch_size, dec_hid_dim]
# Predict next token.
embedded = embedded.squeeze(0)
output = output.squeeze(0)
context = context.squeeze(0)
prediction = self.fc_out(torch.cat((embedded, output, context), dim=1))
# prediction: [batch_size, output_dim]
return prediction, hidden, attn_weights
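# One decoding step, end to end -- a sketch that reuses the Encoder defined
# above with made-up sizes.
_enc = Encoder(input_dim=1000, emb_dim=32, enc_hid_dim=64, dec_hid_dim=100)
_dec = Decoder(output_dim=1200, emb_dim=32, enc_hid_dim=64, dec_hid_dim=100)
_src = torch.randint(0, 1000, (7, 2))       # [src_len, batch_size]
_encoder_outputs, _hidden = _enc(_src)
_prev_token = torch.randint(0, 1200, (2,))  # previous target token, [batch_size]
_pred, _hidden, _attn_weights = _dec(_prev_token, _hidden, _encoder_outputs)
assert _pred.shape == (2, 1200)             # logits over the target vocabulary
assert _attn_weights.shape == (7, 2)        # one weight per source position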
class RNNSearchModel(nn.Module):
def __init__(self, encoder: Encoder, decoder: Decoder, device: torch.device):
super(RNNSearchModel, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.device = device
def forward(
self, src: torch.Tensor, trg: torch.Tensor, teacher_forcing_ratio: float = 0.5
):
# src: [src_len, batch_size]
# trg: [trg_len, batch_size]
batch_size = src.shape[1]
trg_len = trg.shape[0]
trg_vocab_size = self.decoder.output_dim
outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
encoder_outputs, hidden = self.encoder(src)
input = trg[0, :] # <sos>
for t in range(1, trg_len):
output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
outputs[t] = output
teacher_force = torch.rand(1).item() < teacher_forcing_ratio
top1 = output.argmax(1)
input = trg[t] if teacher_force else top1
return outputs
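# Full forward pass with dummy data (sketch, made-up sizes). Note that
# outputs[0] stays all zeros because decoding starts at t = 1; the training
# loop below accounts for this by dropping the first position in the loss.
_device = torch.device("cpu")
_model = RNNSearchModel(
    Encoder(1000, 32, 64, 100), Decoder(1200, 32, 64, 100), _device
)
_src = torch.randint(0, 1000, (7, 2))  # [src_len, batch_size]
_trg = torch.randint(0, 1200, (9, 2))  # [trg_len, batch_size]
_out = _model(_src, _trg)
assert _out.shape == (9, 2, 1200)      # [trg_len, batch_size, trg_vocab_size]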
from mosestokenizer import MosesTokenizer
tokenize_en = MosesTokenizer("en")
tokenize_fr = MosesTokenizer("fr")
def tokenize(text: str, lang: str) -> list[str]:
"""
Tokenizes a text string for a given language.
"""
if lang == "en":
return tokenize_en(text)
elif lang == "fr":
return tokenize_fr(text)
else:
raise ValueError(f"Unsupported language: {lang}")
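# Example (assumed output, not verified here): the Moses tokenizer splits off
# punctuation, so tokenize("Hello, world!", "en") should give something like
# ['Hello', ',', 'world', '!'].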
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from model import Encoder, Decoder, RNNSearchModel
from data import load_and_prepare_data, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN
from utility import tokenize
CHECKPOINT_DIR = "./checkpoint"
MODEL_FILE = os.path.join(CHECKPOINT_DIR, "model.pth")
ENC_EMB_DIM = 512
DEC_EMB_DIM = 512
ENC_HID_DIM = 500
DEC_HID_DIM = 1000
ENC_N_LAYERS = 1
DEC_N_LAYERS = 1
ENC_DROPOUT = 0.0
DEC_DROPOUT = 0.0
N_EPOCHS = 10
CLIP = 1.0 # Gradient clipping value
TEACHER_FORCING_RATIO = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train_model(model, train_iterator, trg_pad_idx):
# Initialize weights.
def init_weights(m):
for name, param in m.named_parameters():
if "weight" in name:
nn.init.normal_(param.data, mean=0, std=0.01)
else:
nn.init.constant_(param.data, 0)
model.apply(init_weights)
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
best_loss = float("inf")
# Training loop
for epoch in range(N_EPOCHS):
train_loss = train_epoch(
model,
train_iterator,
optimizer,
criterion,
CLIP,
TEACHER_FORCING_RATIO,
device,
)
if train_loss < best_loss:
best_loss = train_loss
torch.save(model.state_dict(), MODEL_FILE)
print(f"Epoch: {epoch + 1:02}")
print(f"\tTrain Loss: {train_loss:.3f}")
def train_epoch(
model, iterator, optimizer, criterion, clip, teacher_forcing_ratio, device
):
model.train()
epoch_loss = 0
for i, batch in enumerate(tqdm(iterator)):
src, trg = batch
src = src.to(device).t() # Transpose for the model
trg = trg.to(device).t() # Transpose for the model
optimizer.zero_grad()
output = model(src, trg, teacher_forcing_ratio)
# trg = [trg_len, batch_size]
# output = [trg_len, batch_size, output_dim]
output_dim = output.shape[-1]
output = output[1:].reshape(-1, output_dim)
trg = trg[1:].reshape(-1)
# output = [(trg_len - 1) * batch_size, output_dim]
# trg = [(trg_len - 1) * batch_size]
loss = criterion(output, trg)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(iterator)
def translate_sentence(sentence, source_vocab, target_vocab, model, device, max_len=50):
model.eval()
# Use the same tokenizer as the training data.
if isinstance(sentence, str):
tokens = tokenize(sentence, "en")
else:
# If the input is already a list, treat it as pre-tokenized.
tokens = [token for token in sentence]
# Convert to indices, handling unknown words. The vocabulary is case-sensitive.
src_indexes = [
source_vocab.word2idx.get(token, source_vocab.word2idx[UNK_TOKEN])
for token in tokens
]
# Add SOS and EOS tokens.
src_indexes = (
[source_vocab.word2idx[SOS_TOKEN]]
+ src_indexes
+ [source_vocab.word2idx[EOS_TOKEN]]
)
src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)
with torch.no_grad():
encoder_outputs, hidden = model.encoder(src_tensor)
trg_indexes = [target_vocab.word2idx[SOS_TOKEN]]
for i in range(max_len):
trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
with torch.no_grad():
output, hidden, _ = model.decoder(trg_tensor, hidden, encoder_outputs)
pred_token = output.argmax(1).item()
trg_indexes.append(pred_token)
if pred_token == target_vocab.word2idx[EOS_TOKEN]:
break
    # Return the translation, excluding the SOS token and the EOS token (if it was generated).
    trg_tokens = [target_vocab.idx2word.get(i, UNK_TOKEN) for i in trg_indexes]
    if trg_tokens and trg_tokens[-1] == EOS_TOKEN:
        trg_tokens = trg_tokens[:-1]
    return " ".join(trg_tokens[1:])
def main():
# Load data and vocabularies.
train_iterator, source_vocab, target_vocab = load_and_prepare_data()
INPUT_DIM = len(source_vocab)
OUTPUT_DIM = len(target_vocab)
TRG_PAD_IDX = target_vocab.word2idx[PAD_TOKEN]
# Initialize models.
encoder = Encoder(
INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_N_LAYERS, ENC_DROPOUT
)
decoder = Decoder(
OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_N_LAYERS, DEC_DROPOUT
)
model = RNNSearchModel(encoder, decoder, device).to(device)
if not os.path.exists(CHECKPOINT_DIR):
os.makedirs(CHECKPOINT_DIR)
train_model(model, train_iterator, TRG_PAD_IDX)
if os.path.exists(MODEL_FILE):
print(f"Loading best model from {MODEL_FILE}")
model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
total_params = sum(p.numel() for p in model.parameters())
print(f"Model has {total_params:,} parameters")
sentences = [
"Hello!",
]
for sentence in sentences:
translation = translate_sentence(
sentence, source_vocab, target_vocab, model, device
)
print(f"\nSource: {sentence}")
print(f"Translated: {translation}")
if __name__ == "__main__":
main()