Tin Rabzelj

Ring Attention with Blockwise Transformers for Near-Infinite Context | Paper Notes

9/10/2025

https://arxiv.org/abs/2310.01889

Proposes a memory-efficient transformer architecture that allows the context length to scale linearly with the number of devices.

Blockwise parallel transformers

Blockwise Parallel Transformers (BPT) is a memory-efficient technique that computes both the self-attention and the feed-forward network (FFN) of a Transformer layer block by block. It avoids materializing the full attention matrix.

A major downside is that BPT requires storing the entire output of each layer. This is necessary because the next layer's self-attention needs access to all of the previous layer's outputs.
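
A minimal NumPy sketch of the blockwise idea (hypothetical function names and shapes; it omits residual connections and normalization, and only blocks over queries here — blocking over keys/values uses the running softmax described later):

```python
import numpy as np

def ffn(x, w1, w2):
    # Position-wise two-layer feed-forward network with ReLU.
    return np.maximum(x @ w1, 0.0) @ w2

def blockwise_layer(q, k, v, w1, w2, block_size):
    # Attention and the FFN are evaluated one query block at a time, so the
    # largest score matrix held in memory is (block_size, seq_len) rather
    # than (seq_len, seq_len).
    seq_len, d = q.shape
    out = np.empty_like(q)
    for start in range(0, seq_len, block_size):
        q_blk = q[start:start + block_size]                # one query block
        scores = q_blk @ k.T / np.sqrt(d)                  # (block, seq_len)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)     # row-wise softmax
        attn = weights @ v                                 # attention output
        out[start:start + block_size] = ffn(attn, w1, w2)  # blockwise FFN
    return out
```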

Challenges

Computation Delay: When a sequence is split, each device holds only part of the necessary data. A device needs its local query block and the key-value blocks from all other devices. One way to get them is for a device to request a key-value block from another device, wait for it, and then do the computation. This slows everything down.

Memory Overhead: If a device fetches all the necessary key-value blocks from other hosts, it must store them in its own memory. This increases memory usage, which defeats the purpose of distributing the sequence in the first place.

Solution

Ring Attention distributes a long sequence across devices arranged in a ring. Each device gets one block and has a previous and a next neighbor.

The goal is for each device to compute the attention of its own block (its query block) against all the others. To make this possible, key-value blocks are passed around the ring.

Data passing happens at the same time as the computation, so communication can be hidden behind compute.

Each device starts with its own query block and its corresponding key-value block. Each device computes the attention between its own query block and its own key-value block. While performing the next blockwise attention computation, each device does two things simultaneously:

  • sends a copy of the key-value block it just used to the next device in the ring,
  • receives a new key-value block from the previous device.

This repeats until every query block has been handled against every key-value block.
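
A toy schedule in Python (hypothetical ring size) showing which key-value block each device holds at each step; block indices simply rotate around the ring:

```python
num_devices = 4  # hypothetical ring size

# Device i starts with key-value block i. At every step each device sends
# its current KV block to the next device and receives one from the
# previous, so after num_devices steps every query block has met every
# KV block exactly once.
for step in range(num_devices):
    for device in range(num_devices):
        kv_block = (device - step) % num_devices
        print(f"step {step}: device {device} computes Q_{device} x KV_{kv_block}")
```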

This works because of a property of self-attention that the paper calls "permutation invariance." The order in which you compute the attention between a query and a set of key-value blocks doesn't matter, as long as you combine the results correctly in the end.

Each device incrementally and independently builds its own final output block as the other blocks rotate through.

A standard softmax requires all the scores at once to compute the normalizing denominator. The denominator is the sum of the exponentials of all the scores. If you only have one block of keys, you can't compute it.
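
Written out, for one query with scores $s_1, \dots, s_n$ over all $n$ keys and values $v_1, \dots, v_n$:

$$
\text{output} = \frac{\sum_{j=1}^{n} e^{s_j} \, v_j}{\sum_{j=1}^{n} e^{s_j}}
$$

The denominator depends on every score, not just those from the local key-value block.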

To solve this, each device keeps track of three running "statistics" as each new key-value block arrives.

  • The Running Numerator: the weighted sum of the value vectors seen so far.
  • The Running Denominator: the sum of the exponentiated scores seen so far, used for the final normalization.
  • The Running Maximum Score: used for numerical stability, as is typical when implementing softmax (the update rule is sketched below).
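
One standard way to write the update when a new block with scores $s_j$ and values $v_j$ arrives (writing $m$ for the running maximum, and $\text{num}$, $\text{den}$ for the running numerator and denominator):

$$
\begin{aligned}
m' &= \max\big(m,\ \max_j s_j\big) \\
\text{num}' &= \text{num} \cdot e^{m - m'} + \sum_j e^{s_j - m'} \, v_j \\
\text{den}' &= \text{den} \cdot e^{m - m'} + \sum_j e^{s_j - m'}
\end{aligned}
$$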

Let's say you are one device in the ring ("Device A") and hold "Query Block A".

  1. Initialization: Device A initializes its numerator, denominator and maximum score tracker.
  2. First Computation (Local): It first computes attention using its own "key-value block A". It calculates the scores, updates its running maximum score, and updates its numerator and denominator based on this first block.
  3. The Ring Rotation: The ring rotates. Let's say "key-value block B" arrives from the previous device.
  • Device A computes the scores between "query block A" and "key-value block B".
  • It checks if any of these new scores are higher than its current running maximum. If so, it updates the maximum and rescales its existing numerator and denominator to be consistent with the new maximum.
  • It then adds the contribution of block B to the rescaled numerator and denominator.
  4. Repeat: This repeats for every key-value block. With each new block, Device A updates its three running statistics.
  5. Finalization: After N-1 steps (where N is the number of devices), Device A has seen every key-value block in the entire sequence. At this point, its running statistics are complete. The final output for its query block is simply: Final Output Block A = Running Numerator / Running Denominator (see the sketch below).
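
A runnable NumPy sketch of this procedure (hypothetical names and shapes; a real implementation would place one query block per accelerator and overlap the key-value transfers with compute, whereas this simulates the ring with plain arrays):

```python
import numpy as np

def ring_attention(q_blocks, k_blocks, v_blocks):
    # q_blocks[i] plays the role of device i's query block.
    n = len(q_blocks)
    d = q_blocks[0].shape[-1]
    outputs = []
    for i in range(n):                         # "device" i
        q = q_blocks[i]
        num = np.zeros_like(q)                 # running numerator
        den = np.zeros((q.shape[0], 1))        # running denominator
        m = np.full((q.shape[0], 1), -np.inf)  # running max score
        for step in range(n):                  # ring rotation
            j = (i - step) % n                 # KV block currently held
            scores = q @ k_blocks[j].T / np.sqrt(d)
            m_new = np.maximum(m, scores.max(axis=-1, keepdims=True))
            scale = np.exp(m - m_new)          # rescale old statistics
            p = np.exp(scores - m_new)         # exponentiate under new max
            num = num * scale + p @ v_blocks[j]
            den = den * scale + p.sum(axis=-1, keepdims=True)
            m = m_new
        outputs.append(num / den)              # finalization
    return np.concatenate(outputs, axis=0)

# Tiny check against ordinary full attention on a random sequence.
rng = np.random.default_rng(0)
n_dev, blk, d = 4, 3, 8
q = rng.normal(size=(n_dev * blk, d))
k = rng.normal(size=(n_dev * blk, d))
v = rng.normal(size=(n_dev * blk, d))

ring_out = ring_attention(np.split(q, n_dev), np.split(k, n_dev), np.split(v, n_dev))

scores = q @ k.T / np.sqrt(d)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(np.allclose(ring_out, weights @ v))  # True
```

The check at the bottom illustrates the permutation-invariance point: each device processes the key-value blocks in a different rotated order, yet the combined result matches ordinary attention.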
