Attention Is All You Need | Paper Notes
8/12/2025
https://arxiv.org/abs/1706.03762
Limitations of RNNs
RNNs can't be parallelized within training examples. They process sequences step by step, generating a hidden state h_t as a function of the previous hidden state h_{t-1} and the input at position t. Position t can't be computed until position t-1 is complete.
That inhibits training on longer sequences, because memory limitations restrict how many examples can be batched together. For example, if an RNN processes sequences of length 1000, each hidden state takes significant memory, and you can only fit 32 examples in a batch (you're GPU poor), you're limited to processing 32K tokens at once. With shorter sequences of length 100, you could fit 320 examples in the same batch, processing the same number of tokens but with better parallelization across examples.
RNNs struggle with long-range dependencies because of the length of the paths that forward and backward signals must traverse. The path length between any two positions in the input and output sequences is O(n), where n is the sequence length, which makes it difficult to learn relationships between distant elements in a sequence.
The self-attention layers address these by:
- connecting all positions with a constant number of sequentially executed operations (O(1) instead of O(n)),
- allowing full parallelization during training,
- providing constant-length paths between any two positions in the sequence.
Architecture
The Transformer follows the encoder-decoder architecture that is common in neural sequence transduction models.
Encoder
The encoder is a stack of N = 6 identical layers. Each layer contains exactly two sub-layers:
- Multi-head self-attention mechanism - allows each position to attend to all positions in the input sequence.
- Position-wise fully connected FFN - applies the same feed-forward transformation to each position independently.
Each sub-layer has a residual connection followed by layer normalization:
LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself.
All sub-layers and embedding layers produce outputs of dimension d_model = 512.
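A minimal sketch of this post-norm pattern in PyTorch; the SublayerConnection name is mine, and the dropout applied to the sub-layer output before the residual addition follows the paper's regularization setup:

import torch
import torch.nn as nn

class SublayerConnection(nn.Module):
    """Post-norm residual wrapper: LayerNorm(x + Dropout(Sublayer(x)))."""

    def __init__(self, d_model: int = 512, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, sublayer) -> torch.Tensor:
        # `sublayer` is any callable of x, e.g. self-attention or the position-wise FFN.
        return self.norm(x + self.dropout(sublayer(x)))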
Decoder
The decoder also consists of a stack of N = 6 identical layers, but with three sub-layers each:
- Masked multi-head self-attention - allows each position to attend only to earlier positions in the output sequence.
- Multi-head attention over encoder output - performs encoder-decoder attention, where queries come from the previous decoder layer and keys/values come from the encoder output.
- Position-wise fully connected feed-forward network - identical to the encoder's feed-forward layer.
Each sub-layer uses residual connections followed by layer normalization, just like the encoder.
Features
The self-attention sub-layer in the decoder is modified to prevent positions from attending to subsequent positions. This mask is referred to as a "causal mask." The masking, combined with the fact that the output embeddings are offset by one position, ensures that predictions for position i can depend only on the known outputs at positions less than i.
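A quick sketch of such a mask in boolean form and how it feeds into the attention scores (the implementation further down builds the equivalent additive -inf mask that nn.Transformer expects):

import torch

seq_len = 5
# True marks the future positions a query is not allowed to attend to.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Masked positions get -inf, so softmax assigns them zero weight.
scores = torch.randn(seq_len, seq_len)
weights = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)
# Row i now has non-zero weights only for columns <= i.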
The consistent d_model = 512 dimension throughout all layers enables clean residual connections and simplifies the architecture.
Both encoder and decoder extensively use residual connections around each sub-layer, which helps with gradient flow and enables training of deeper networks. Layer norm is applied after each sub-layer to stabilize training and improve convergence.
Attention
An attention function maps a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is determined by a compatibility function between the query and the corresponding key.
The paper introduces "Scaled Dot-Product Attention":
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
The scaling factor 1/sqrt(d_k) is there because for large values of d_k the dot products grow large in magnitude, pushing the softmax into regions where its gradients are extremely small.
Dot-product attention is much faster and more space-efficient than additive attention since it can be implemented using optimized matrix multiplication GPU code.
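A minimal PyTorch sketch of the formula above; the function and argument names are my own (newer PyTorch versions also ship torch.nn.functional.scaled_dot_product_attention as a fused equivalent):

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k: (..., seq_len, d_k); v: (..., seq_len, d_v)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Boolean mask: True marks positions to exclude from attention.
        scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v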
Multi-head attention
Instead of performing a single attention function with d_model-dimensional keys, values, and queries, "multi-head attention" uses multiple "attention heads":
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V).
How it works:
- Project the queries, keys, and values h times with different learned linear projections to d_k, d_k, and d_v dimensions, respectively.
- Perform the attention function in parallel on each projected version, yielding d_v-dimensional output values.
- Concatenate the results and project them once more with W^O.
Parameter dimensions:
- W_i^Q ∈ R^(d_model × d_k), W_i^K ∈ R^(d_model × d_k), W_i^V ∈ R^(d_model × d_v), W^O ∈ R^(h·d_v × d_model)
The paper uses h = 8 parallel attention heads with d_k = d_v = d_model / h = 64. Due to the reduced dimension of each head, the total computational cost is similar to single-head attention with full dimensionality.
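A sketch under those settings, assuming batch-first (batch, seq, d_model) tensors; using one fused nn.Linear per role is an implementation convenience equivalent to the h separate W_i projections:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head attention with d_k = d_v = d_model / h."""

    def __init__(self, d_model: int = 512, h: int = 8):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch = q.size(0)

        def split(x, proj):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            return proj(x).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q, k, v = split(q, self.w_q), split(k, self.w_k), split(v, self.w_v)
        scores = q @ k.transpose(-2, -1) / self.d_k**0.5
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ v
        # Concatenate heads and apply the output projection W^O.
        out = out.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.w_o(out)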
Attention in the Transformer
In encoder-decoder attention, the queries come from the previous decoder layer while the keys and values are drawn from the encoder's output. This allows the decoder to focus on relevant parts of the input sequence.
In encoder self-attention: keys, values, and queries come from the same source, specifically the previous encoder layer, enabling each position to attend to all other positions in the input.
Decoder self-attention works similarly to encoder self-attention, but includes masking to prevent positions from attending to future tokens.
Multi-head attention allows the model to attend to information from different representation subspaces at different positions. With a single attention head, averaging would prevent this, but multiple heads can focus on different types of relationships simultaneously.
Why self-attention
Self-attention layers connect all positions with a constant number of sequential operations versus O(n) for recurrent layers, are faster than recurrent layers when the sequence length n is smaller than the representation dimensionality d (which is typical), and provide O(1) maximum path lengths between any two positions, compared to O(n) for recurrent layers and O(log_k(n)) for (dilated) convolutional layers, which makes it easier to learn long-range dependencies.
Positional encoding
The Transformer contains no recurrence and no convolutions, so it has no built-in notion of token order. The attention mechanism is permutation-invariant: without positional information, it produces the same output regardless of the input order.
Positional encodings allow the model to take relative positions into account.
Positional encodings are added to the input embeddings at the bottoms of both the encoder and decoder stacks. They have the same dimension as the embeddings, so they can be summed element-wise.
The paper uses sinusoidal encodings:
PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
where pos is the position in the sequence, i is the dimension index, and each dimension of the positional encoding corresponds to a sinusoid.
The wavelengths form a geometric progression from 2π to 10000 · 2π. Each position gets a unique "fingerprint."
They hypothesized this function would allow the model to learn to attend by relative positions, because for any fixed offset k, PE(pos + k) can be represented as a linear function of PE(pos).
Positional encodings are the same for all inputs. The authors also experimented with learned encodings and found that they produce nearly identical results. An example of learned encodings can be seen in the BERT model, where they are called "positional embeddings."
Implementation
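# model.py: positional encoding and a thin wrapper around PyTorch's nn.TransformerEncoder /
# nn.TransformerDecoder (file name inferred from the `from model import Transformer` import
# in the training script further down).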
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
"""
Implements positional encoding for transformer models using sinusoidal functions.
"""
def __init__(self, d_model: int, max_len: int = 5000):
super(PositionalEncoding, self).__init__()
# Create a matrix to hold positional encodings for all positions up to `max_len`.
pe = torch.zeros(max_len, d_model)
# Create a column vector of positions `[0, 1, 2, ..., max_len-1]`.
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# Create the division term for the sinusoidal functions.
# This creates different frequencies for each dimension pair.
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# Apply sine to even indices (0, 2, 4, ...).
pe[:, 0::2] = torch.sin(position * div_term)
# Apply cosine to odd indices (1, 3, 5, ...).
pe[:, 1::2] = torch.cos(position * div_term)
# Reshape from (max_len, d_model) to (max_len, 1, d_model) to match expected input shape.
pe = pe.unsqueeze(0).transpose(0, 1)
# Register as buffer so it's saved with the model but not trained.
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Add positional encoding to input, using only the relevant sequence length.
return x + self.pe[: x.size(0), :]
class Transformer(nn.Module):
def __init__(
self,
num_tokens: int,
d_model: int = 256,
head_count: int = 8,
num_encoder_layers: int = 3,
num_decoder_layers: int = 3,
ff_dim: int = 2048,
dropout: float = 0.1,
):
super(Transformer, self).__init__()
self.d_model = d_model
self.embedding = nn.Embedding(num_tokens, d_model)
self.pos_encoder = PositionalEncoding(d_model)
encoder_layers = nn.TransformerEncoderLayer(
d_model, head_count, ff_dim, dropout, batch_first=False
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layers, num_encoder_layers
)
decoder_layers = nn.TransformerDecoderLayer(
d_model, head_count, ff_dim, dropout, batch_first=False
)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layers, num_decoder_layers
)
self.fc_out = nn.Linear(d_model, num_tokens)
@staticmethod
def generate_square_subsequent_mask(size):
mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
return mask
def encode(self, src, src_mask, src_padding_mask=None):
"""Encodes the source sequence."""
src_emb = self.embedding(src) * math.sqrt(self.d_model)
src_pos = self.pos_encoder(src_emb)
return self.transformer_encoder(src_pos, src_mask, src_padding_mask)
def decode(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: torch.Tensor,
memory_key_padding_mask=None,
tgt_padding_mask=None,
):
"""Decodes the target sequence using the encoder's memory."""
tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
tgt_pos = self.pos_encoder(tgt_emb)
return self.transformer_decoder(
tgt_pos,
memory,
tgt_mask,
None, # Optional memory_mask
tgt_padding_mask,
memory_key_padding_mask,
)
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
src_mask: torch.Tensor,
tgt_mask: torch.Tensor,
src_padding_mask: torch.Tensor,
tgt_padding_mask: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
):
"""The forward pass of the Transformer model."""
memory = self.encode(src, src_mask, src_padding_mask)
output = self.decode(
tgt, memory, tgt_mask, memory_key_padding_mask, tgt_padding_mask
)
return self.fc_out(output)
def get_parameter_count(self) -> tuple[int, int]:
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
return total_params, trainable_params
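# data.py: shared BPE tokenizer and DataLoader for an English-French dataset
# (file name inferred from the `from data import load_data, ...` import below).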
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
import kagglehub
from kagglehub import KaggleDatasetAdapter
from config import CHECKPOINT_DIR
UNK_TOKEN = "<unk>"
PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
SPECIAL_TOKENS = [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]
TOKENIZER_FILE = os.path.join(CHECKPOINT_DIR, "tokenizer.json")
VOCAB_SIZE = 37_000
class TrainDataset(Dataset):
def __init__(self, source_sentences, target_sentences, tokenizer):
self.source_sentences = source_sentences
self.target_sentences = target_sentences
self.tokenizer = tokenizer
def __len__(self):
return len(self.source_sentences)
def __getitem__(self, idx):
source = self.source_sentences[idx]
target = self.target_sentences[idx]
source_indices = self.tokenizer.encode(source).ids
target_indices = self.tokenizer.encode(target).ids
return torch.tensor(source_indices), torch.tensor(target_indices)
def build_or_load_tokenizer(sentences):
    # Reuse a previously trained tokenizer if one exists on disk.
    if os.path.exists(TOKENIZER_FILE):
        print(f"<-- Loading tokenizer from {TOKENIZER_FILE}")
        return Tokenizer.from_file(TOKENIZER_FILE)
    tokenizer = Tokenizer(BPE(unk_token=UNK_TOKEN))
    tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(
vocab_size=VOCAB_SIZE,
min_frequency=2,
special_tokens=SPECIAL_TOKENS,
)
tokenizer.train_from_iterator(sentences, trainer)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    tokenizer.save(TOKENIZER_FILE)
print(f"--> Tokenizer saved to {TOKENIZER_FILE}")
return tokenizer
def collate_fn(batch, tokenizer):
source_batch, target_batch = [], []
sos_token_id = tokenizer.token_to_id(SOS_TOKEN)
eos_token_id = tokenizer.token_to_id(EOS_TOKEN)
pad_token_id = tokenizer.token_to_id(PAD_TOKEN)
for source, target in batch:
source_batch.append(
torch.cat(
[
torch.tensor([sos_token_id]),
source,
torch.tensor([eos_token_id]),
]
)
)
target_batch.append(
torch.cat(
[
torch.tensor([sos_token_id]),
target,
torch.tensor([eos_token_id]),
]
)
)
source_padded = pad_sequence(source_batch, padding_value=pad_token_id)
target_padded = pad_sequence(target_batch, padding_value=pad_token_id)
return source_padded, target_padded
def load_data(batch_size=32):
ds = kagglehub.load_dataset(
KaggleDatasetAdapter.HUGGING_FACE,
"rajpulapakura/english-to-french-small-dataset",
"english_french.csv",
)
dataset_size = int(len(ds["English"]) * 0.1)
if dataset_size > 0:
source_sentences = [row for row in ds["English"][:dataset_size]]
target_sentences = [row for row in ds["French"][:dataset_size]]
else:
source_sentences = [row for row in ds["English"]]
target_sentences = [row for row in ds["French"]]
# Train a shared tokenizer on both source and target sentences
tokenizer = build_or_load_tokenizer(source_sentences + target_sentences)
dataset = TrainDataset(source_sentences, target_sentences, tokenizer)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=lambda b: collate_fn(b, tokenizer),
)
return loader, tokenizer
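# Training and inference script: masking helpers, training loop, checkpointing, and greedy decoding.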
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from data import load_data, PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN
from model import Transformer
from config import CHECKPOINT_DIR
MODEL_FILE = os.path.join(CHECKPOINT_DIR, "model.pth")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0005
def save_checkpoint(model, optimizer, epoch, loss):
if not os.path.exists(CHECKPOINT_DIR):
os.makedirs(CHECKPOINT_DIR)
print(f"--> Saving checkpoint for epoch {epoch + 1}")
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
},
MODEL_FILE,
)
def load_checkpoint(model, optimizer=None):
if not os.path.exists(MODEL_FILE):
print("! No checkpoint found, cannot load model.")
return
print(f"<-- Loading checkpoint from {MODEL_FILE}")
checkpoint = torch.load(MODEL_FILE, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
if optimizer:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
print(f"<-- Checkpoint loaded (epoch {epoch + 1}, loss: {loss:.4f})")
def create_masks(src, tgt, tokenizer, device):
src_seq_len = src.shape[0]
tgt_seq_len = tgt.shape[0]
pad_token_id = tokenizer.token_to_id(PAD_TOKEN)
tgt_mask = Transformer.generate_square_subsequent_mask(tgt_seq_len).to(device)
src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)
src_padding_mask = (src == pad_token_id).transpose(0, 1)
tgt_padding_mask = (tgt == pad_token_id).transpose(0, 1)
return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
def train_model(model, train_loader, tokenizer):
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id(PAD_TOKEN))
optimizer = optim.Adam(
model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9
)
for epoch in range(EPOCHS):
model.train()
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS}")
total_loss = 0
for src, tgt in progress_bar:
src, tgt = src.to(DEVICE), tgt.to(DEVICE)
tgt_input = tgt[:-1, :]
src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_masks(
src, tgt_input, tokenizer, DEVICE
)
optimizer.zero_grad()
output = model(
src,
tgt_input,
src_mask,
tgt_mask,
src_padding_mask,
tgt_padding_mask,
src_padding_mask,
)
tgt_out = tgt[1:, :]
loss = criterion(output.reshape(-1, output.shape[-1]), tgt_out.reshape(-1))
loss.backward()
optimizer.step()
total_loss += loss.item()
progress_bar.set_postfix({"loss": total_loss / (progress_bar.n + 1)})
save_checkpoint(model, optimizer, epoch, total_loss / len(train_loader))
def translate_sentence(model, sentence, tokenizer, device, max_len=50):
model.eval()
# Get special token ids.
sos_token_id = tokenizer.token_to_id(SOS_TOKEN)
eos_token_id = tokenizer.token_to_id(EOS_TOKEN)
pad_token_id = tokenizer.token_to_id(PAD_TOKEN)
# Prepare source sentence by adding SOS and EOS tokens.
src_indices = tokenizer.encode(sentence).ids
src_tensor = (
torch.tensor([sos_token_id] + src_indices + [eos_token_id])
.unsqueeze(1)
.to(device)
)
src_mask = torch.zeros(
src_tensor.shape[0], src_tensor.shape[0], device=device
).type(torch.bool)
src_padding_mask = (src_tensor == pad_token_id).transpose(0, 1)
with torch.no_grad():
memory = model.encode(src_tensor, src_mask, src_padding_mask)
# Prepare target sentence, starting with SOS token
tgt_indices = [sos_token_id]
for i in range(max_len):
tgt_tensor = torch.tensor(tgt_indices).unsqueeze(1).to(device)
tgt_mask = Transformer.generate_square_subsequent_mask(tgt_tensor.size(0)).to(
device
)
tgt_padding_mask = torch.zeros(
tgt_tensor.shape[1], tgt_tensor.shape[0], device=device
).type(torch.bool)
with torch.no_grad():
output = model.decode(
tgt_tensor, memory, tgt_mask, src_padding_mask, tgt_padding_mask
)
prob = model.fc_out(output[-1, :, :])
pred_token = prob.argmax(1).item()
tgt_indices.append(pred_token)
if pred_token == eos_token_id:
break
# Decode the indices to a string, skipping special tokens.
return tokenizer.decode(tgt_indices, skip_special_tokens=True)
def main():
train_loader, tokenizer = load_data(batch_size=BATCH_SIZE)
vocab_size = tokenizer.get_vocab_size()
print(f"Vocab size: {vocab_size}")
model = Transformer(num_tokens=vocab_size).to(DEVICE)
train_model(model, train_loader, tokenizer)
total_params, trainable_params = model.get_parameter_count()
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
load_checkpoint(model)
print("Testing...")
for sentence in [
"Hello!",
]:
translation = translate_sentence(model, sentence, tokenizer, DEVICE)
print(f"Source: {sentence}")
print(f"Translated: {translation}\n")
if __name__ == "__main__":
main()