Tin Rabzelj

Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention | Paper Notes

8/29/2025

https://arxiv.org/abs/2404.07143

Infini-attention is an attention mechanism designed to let Transformer models process infinitely long sequences while using a bounded amount of memory. It combines two forms of attention within a single Transformer block: standard local attention for short-term context and a compressive memory for long-term context.

Input sequences are processed in fixed-length segments. For each segment, the model computes local attention and also reads from and writes to the compressive memory. Local attention is standard scaled dot-product attention, restricted to the current segment. Instead of discarding the key-value (KV) states of past segments, the model compresses and stores them in a fixed-size "compressive memory" matrix. When processing a new segment, it uses the new query (Q) vectors to retrieve relevant information from this compressed summary of the entire history.
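
As a rough, hypothetical sketch of this segment-wise flow in PyTorch (single head, no query/key/value projections, made-up sizes; not the paper's actual implementation):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
seq_len, segment_len, d_model = 8192, 2048, 64
x = torch.randn(seq_len, d_model)

# The long input is processed as consecutive, non-overlapping segments.
for segment in x.split(segment_len, dim=0):
    # Local attention: causal dot-product attention restricted to this segment
    # (queries = keys = values = the raw segment, just to show the local scope).
    A_dot = F.scaled_dot_product_attention(
        segment.unsqueeze(0), segment.unsqueeze(0), segment.unsqueeze(0),
        is_causal=True,
    ).squeeze(0)
    # The compressive-memory retrieval, update, and gating steps are sketched below.
```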

The memory matrix is an associative matrix that stores compressed information from past segments. Each attention head in each layer maintains its own memory matrix with dimensions $d_{key} \times d_{value}$, so its size does not grow with sequence length.
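
For a concrete sense of the shapes, a minimal sketch in PyTorch, assuming made-up layer/head counts and $d_{key} = d_{value} = 64$:

```python
import torch

# Hypothetical model configuration for illustration.
num_layers, num_heads = 12, 8
d_key = d_value = 64

# One associative memory matrix and one normalization vector per attention head
# in every layer; both are fixed-size and independent of sequence length.
memory = torch.zeros(num_layers, num_heads, d_key, d_value)
z_norm = torch.zeros(num_layers, num_heads, d_key)
```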

The model combines local and long-term attention using a learned gating mechanism. A single learnable parameter $\beta$ per attention head controls how much weight to give local information versus historical information. The model updates the compressive memory with the new key and value states.

Content is retrieved from the compressive memory using the current query states. A normalization term keeps the retrieval numerically stable.

$$A_{mem} = \frac{\sigma(Q)\, M_{s-1}}{\sigma(Q)\, z_{s-1}}$$

$A_{mem} \in \mathbb{R}^{N \times d_{value}}$ is the retrieved content, $Q$ is the query, $M_{s-1} \in \mathbb{R}^{d_{key} \times d_{value}}$ is the memory state from the previous segment, $z_{s-1}$ is a normalization term, and $\sigma$ is a nonlinear activation function.
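
A minimal sketch of this retrieval step for a single head, assuming PyTorch and element-wise ELU + 1 as $\sigma$ (the activation the paper borrows from linear attention):

```python
import torch
import torch.nn.functional as F

def retrieve(Q, M_prev, z_prev):
    """A_mem = (sigma(Q) M_{s-1}) / (sigma(Q) z_{s-1}) for one head.

    Q:      (N, d_key)        queries of the current segment
    M_prev: (d_key, d_value)  memory state from the previous segment
    z_prev: (d_key,)          normalization term from the previous segment
    """
    sigma_q = F.elu(Q) + 1.0                        # element-wise ELU + 1
    numerator = sigma_q @ M_prev                    # (N, d_value)
    denominator = (sigma_q @ z_prev).unsqueeze(-1)  # (N, 1)
    return numerator / denominator                  # A_mem: (N, d_value)
```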

The memory is updated with the new key-value pairs from the current segment. This is done by adding the outer products of the transformed keys and values, summed over the segment's positions, to the existing memory matrix.

$$M_{s} \leftarrow M_{s-1} + \sigma(K)^T V$$

The normalization term is updated as:

$$z_{s} \leftarrow z_{s-1} + \sum_{t=1}^{N}\sigma(K_t)$$
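
Both updates together, sketched under the same assumptions (single head, PyTorch, ELU + 1 as $\sigma$):

```python
import torch
import torch.nn.functional as F

def update_memory(K, V, M_prev, z_prev):
    """Linear update: M_s = M_{s-1} + sigma(K)^T V and z_s = z_{s-1} + sum_t sigma(K_t)."""
    sigma_k = F.elu(K) + 1.0             # (N, d_key)
    M_new = M_prev + sigma_k.T @ V       # (d_key, d_value), sum of outer products
    z_new = z_prev + sigma_k.sum(dim=0)  # (d_key,)
    return M_new, z_new
```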

An alternative update rule (inspired by the delta rule) first subtracts the existing value associated with a key before adding the new value. This prevents modification of the memory if the key-value binding already exists.

$$M_{s} \leftarrow M_{s-1} + \sigma(K)^T \left(V - \frac{\sigma(K)\, M_{s-1}}{\sigma(K)\, z_{s-1}}\right)$$
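
The delta-rule variant under the same assumptions:

```python
import torch
import torch.nn.functional as F

def update_memory_delta(K, V, M_prev, z_prev):
    """Delta rule: retrieve what the memory already stores for these keys,
    then write only the difference to the new values."""
    sigma_k = F.elu(K) + 1.0                                           # (N, d_key)
    retrieved = (sigma_k @ M_prev) / (sigma_k @ z_prev).unsqueeze(-1)  # (N, d_value)
    M_new = M_prev + sigma_k.T @ (V - retrieved)
    z_new = z_prev + sigma_k.sum(dim=0)
    return M_new, z_new
```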

$M_s$ is the new memory state, $M_{s-1}$ is the previous state, and $K$ and $V$ are the keys and values from the current segment.

Long-term context injection is the process of combining the information retrieved from the compressive memory with the local attention context. This is done with a learned gating scalar $\beta$, which determines the balance between the long-term and the short-term (local) context. The final attention output is a weighted sum of these two components.

$$A = \text{sigmoid}(\beta) \odot A_{mem} + (1 - \text{sigmoid}(\beta)) \odot A_{dot}$$

Where:

  • $A$ is the final aggregated attention context.
  • $A_{mem}$ is the content retrieved from the long-term compressive memory.
  • $A_{dot}$ is the local attention state from standard dot-product attention.
  • $\beta$ is a single learnable scalar parameter.
  • The sigmoid function squashes $\beta$ to a value between $0$ and $1$, acting as a soft gate.
  • $\odot$ is element-wise multiplication.
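
A sketch of the gate for a single head, where `inject_long_term_context` and the tensors below are purely illustrative:

```python
import torch

def inject_long_term_context(A_mem, A_dot, beta):
    """A = sigmoid(beta) * A_mem + (1 - sigmoid(beta)) * A_dot."""
    g = torch.sigmoid(beta)  # scalar gate in (0, 1)
    return g * A_mem + (1.0 - g) * A_dot

# Example: blend retrieved long-term content with the local attention output.
N, d_value = 2048, 64
A_mem = torch.randn(N, d_value)
A_dot = torch.randn(N, d_value)
beta = torch.zeros(())  # learned in practice; sigmoid(0) = 0.5 gives equal weighting
A = inject_long_term_context(A_mem, A_dot, beta)
```

The gated per-head outputs are then concatenated across heads and projected, as in a standard multi-head attention block.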
