RoFormer: Enhanced Transformer with Rotary Position Embedding | Paper Notes
8/27/2025
Introduces RoPE.
In self-attention, the inner product between query and key vectors determines how much attention one token pays to another. Ideally, this inner product should depend only on the relative distance between the two tokens, rather than on their absolute positions.
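To encode relative positions, we need a function $g$ that takes in the embeddings $x_m$, $x_n$ and their relative position $m - n$:
$$\langle f_q(x_m, m), f_k(x_n, n) \rangle = g(x_m, x_n, m - n)$$
The ultimate goal is to find an equivalent encoding mechanism to solve the functions $f_q(x_m, m)$ and $f_k(x_n, n)$ to conform to the aforementioned relation.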
They start by simplifying the problem to the 2D case and re-framing it using complex numbers, which can express 2D vectors in terms of magnitudes and angles. By enforcing the constraint that the inner product between two position-encoded vectors must depend only on their relative distance, they deduce two key properties. First, the length of the vectors must not change with position. Second, the angle of the vectors must be a linear function of their position. Rotation satisfies both of these conditions. They conclude that the absolute position of a token can be encoded by rotating its vector representation by an angle proportional to its position. This solution is then generalized to higher dimensions by grouping the vector's dimensions into pairs and applying a 2D rotation to each pair independently.
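A compact way to see why rotation works in the 2D case, as a sketch rather than the paper's full derivation. Here $q$ and $k$ stand for the already-projected query and key written as complex numbers, and $\theta$ is a fixed constant; this shorthand is an assumption of these notes, not the paper's exact notation.

$$
f(q, m) = q\,e^{i m \theta}, \qquad f(k, n) = k\,e^{i n \theta}
$$

$$
\langle f(q, m), f(k, n) \rangle
= \mathrm{Re}\big[\, q\,e^{i m \theta}\;\overline{k\,e^{i n \theta}} \,\big]
= \mathrm{Re}\big[\, q\,\bar{k}\;e^{i (m - n) \theta} \,\big]
$$

The absolute positions $m$ and $n$ appear only through the difference $m - n$, and $|q\,e^{i m \theta}| = |q|$, so the length is unchanged and the angle grows linearly with position.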
There's too much math for me.
In simple terms
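Formulation:
$$x' = R(\theta_{pos})\,x, \qquad R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}$$
where $pos$ is the position of a token, $x$ is the embedding vector, and $x'$ is $x$ rotated by the angle $\theta_{pos}$.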
Steps to encode in 2D space:
- Start with embeddings for two tokens at positions $m$ and $n$: $x_m$ and $x_n$.
- For each token, calculate rotation angles $\theta_m$ and $\theta_n$ based on positions.
- Apply the rotation by multiplying embeddings with positional encodings (cosines and sines).
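The steps above can be sketched in plain Python for a single 2D pair. The helper names `rotate_2d` and `dot`, the example vectors, and the angle step `theta=0.1` are all made up for illustration. The point is only that two position pairs with the same offset give the same score.

```python
import math


def rotate_2d(vec, pos, theta=0.1):
    """Rotate a 2D vector counter-clockwise by the angle pos * theta."""
    angle = pos * theta  # rotation angle grows linearly with position
    c, s = math.cos(angle), math.sin(angle)
    x, y = vec
    return (x * c - y * s, x * s + y * c)


def dot(a, b):
    """Inner product of two 2D vectors."""
    return a[0] * b[0] + a[1] * b[1]


# Hypothetical 2D query and key embeddings.
q = (1.0, 2.0)
k = (0.5, -1.5)

# Same relative offset (2), different absolute positions.
score_a = dot(rotate_2d(q, 3), rotate_2d(k, 1))
score_b = dot(rotate_2d(q, 12), rotate_2d(k, 10))

# The two scores agree up to floating-point error.
print(score_a, score_b)
```

Shifting both positions by the same amount rotates both vectors by the same extra angle, which leaves their inner product unchanged.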
import torch


def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embedding to x of shape (batch_size, seq_len, embed_dim)."""
    batch_size, seq_len, embed_dim = x.shape
    assert embed_dim % 2 == 0, "Embedding dimension must be even."

    # Compute inverse frequencies: base^(-i / half_dim) for i = 0 .. half_dim - 1
    half_dim = embed_dim // 2
    exponents = torch.arange(0, half_dim, dtype=torch.float32, device=x.device) / half_dim
    inv_freq = 1.0 / (base ** exponents)

    # Compute angles: one angle per (position, frequency) pair
    positions = torch.arange(seq_len, dtype=torch.float32, device=x.device)
    freqs = torch.outer(positions, inv_freq)  # (seq_len, half_dim)

    # Compute cosines and sines
    cos = torch.cos(freqs)  # (seq_len, half_dim)
    sin = torch.sin(freqs)  # (seq_len, half_dim)

    # Expand to full embed_dim
    cos = torch.cat((cos, cos), dim=-1).unsqueeze(0)  # (1, seq_len, embed_dim)
    sin = torch.cat((sin, sin), dim=-1).unsqueeze(0)  # (1, seq_len, embed_dim)

    # Split x into halves; dimension i is paired with dimension i + half_dim
    x1 = x[..., :half_dim]
    x2 = x[..., half_dim:]

    # Compute rotated version: each pair (x1, x2) becomes (-x2, x1)
    rot_x = torch.cat((-x2, x1), dim=-1)

    # Apply rotation: (x1 * cos - x2 * sin, x2 * cos + x1 * sin)
    rotated = x * cos + rot_x * sin
    return rotated
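As a sanity check on the multi-dimensional case, here is a small pure-Python sketch that rotates one vector at one position. It uses the same pairing as `apply_rope` (dimension `i` with dimension `i + half`) and the same inverse frequencies. The names `rope_reference` and `dot` and the example vectors are hypothetical, and the code is meant for checking properties rather than for real use.

```python
import math


def rope_reference(vec, pos, base=10000.0):
    """Rotate each pair (vec[i], vec[i + half]) by the angle pos * base**(-i / half)."""
    d = len(vec)
    assert d % 2 == 0, "Embedding dimension must be even."
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        angle = pos * base ** (-i / half)  # lower i means a faster rotation
        c, s = math.cos(angle), math.sin(angle)
        a, b = vec[i], vec[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


# Hypothetical 4-dimensional query and key.
q = [1.0, -0.5, 2.0, 0.25]
k = [0.3, 1.2, -0.7, 0.9]

# Same relative offset (3) at very different absolute positions.
near = dot(rope_reference(q, 5), rope_reference(k, 2))
far = dot(rope_reference(q, 105), rope_reference(k, 102))

# The two scores agree up to floating-point error.
print(near, far)
```

Each pair is rotated with its own frequency, but within a pair the query and key are rotated by angles that differ only by the position offset, so the total score still depends only on that offset.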