Tin Rabzelj

Seed Diffusion: A Large-Scale Diffusion Language Model with High-Speed Inference | Paper Notes

8/7/2025

https://seed.bytedance.com/en/seed_diffusion

https://www.arxiv.org/abs/2508.02193

The model reaches an inference speed of 2,146 tokens/s on H20 GPUs. That is substantially faster than autoregressive models of similar size, and it matches their performance on coding benchmarks. It is also faster than existing diffusion language models such as the Mercury models and Gemini Diffusion.

Two-Stage Training

The model is first trained on pattern completion, filling in masked spans of code, which is the standard objective for mask-based diffusion models. It is then trained to perform edits and corrections, which improves its ability to understand and fix logical issues in code. This "edit-based forward process" applies a number of random edits (substitutions, deletions, insertions). The number of edits is the length of the original sequence multiplied by a value from a predefined noise schedule, rounded down. That schedule value controls the approximate signal-to-noise ratio. In effect, the goal is a corrupted sequence that lies roughly a target Levenshtein distance away from the original, although the actual distance can be smaller than the number of edits.
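A minimal sketch of how the two stages and the noise schedule might fit together. It assumes a linear schedule, a hypothetical `mask_only_fraction` switch point, and an edit-schedule cap of 0.1 (matching the range mentioned in the code further down). None of these values are taken from the paper. It also switches entirely to edit-based corruption in the second stage, which is a simplification. The edit budget follows the formula described above.

```python
import math
import random


def linear_schedule(t: float, max_value: float) -> float:
    """Hypothetical linear noise schedule: 0 at t = 0, rising to max_value at t = 1."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must be in [0, 1]")
    return max_value * t


def corruption_stage(step: int, total_steps: int, mask_only_fraction: float) -> str:
    """Choose the forward process for a training step under a two-stage curriculum.

    Steps before the (hypothetical) mask_only_fraction of training use mask-based
    corruption; later steps switch to edit-based corruption.
    """
    return "mask" if step / total_steps < mask_only_fraction else "edit"


def edit_budget(sequence_length: int, alpha_t: float) -> int:
    """Number of random edits to apply: floor(sequence_length * alpha_t)."""
    return math.floor(sequence_length * alpha_t)


if __name__ == "__main__":
    total_steps = 1_000
    sequence_length = 200
    for step in (0, 500, 900):
        stage = corruption_stage(step, total_steps, mask_only_fraction=0.8)
        # Assumption: the noise level t is sampled per example, independently of the training step.
        t = random.random()
        if stage == "mask":
            print(f"step {step}: mask-based, gamma_t = {linear_schedule(t, 1.0):.2f}")
        else:
            alpha_t = linear_schedule(t, 0.1)
            print(f"step {step}: edit-based, {edit_budget(sequence_length, alpha_t)} edits")
```

Keeping the stage choice separate from the noise level makes it easy to swap in a different schedule or switch point without touching the corruption functions themselves.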

Constrained-Order Diffusion

Because code requires a logical order (such as declaring variables before use), this method counteracts the arbitrary generation order of typical diffusion models. The pre-trained diffusion model itself generates a large pool of possible generation paths (trajectories). The highest-quality trajectories are selected from this pool to form a distilled dataset, and fine-tuning on that dataset teaches the model to follow a correct code structure.
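The selection step can be sketched as score-and-filter over the pool. This is a toy illustration, not the paper's quality criterion. `Trajectory`, `order_score`, and `select_trajectories` are hypothetical names, and the score only checks that a declaration is revealed before its use.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Trajectory:
    """One generation path: the order in which token positions were unmasked."""
    reveal_order: list[int]


def order_score(trajectory: Trajectory, dependencies: list[tuple[int, int]]) -> float:
    """Hypothetical quality score: fraction of (before, after) position pairs
    where `before` is revealed earlier than `after` (e.g. a declaration before its use)."""
    if not dependencies:
        return 1.0
    rank = {position: step for step, position in enumerate(trajectory.reveal_order)}
    satisfied = sum(rank[before] < rank[after] for before, after in dependencies)
    return satisfied / len(dependencies)


def select_trajectories(
    pool: list[Trajectory],
    score_fn: Callable[[Trajectory], float],
    keep: int,
) -> list[Trajectory]:
    """Keep the `keep` highest-scoring trajectories to form a distillation dataset.
    sorted() is stable, so ties keep their original pool order."""
    return sorted(pool, key=score_fn, reverse=True)[:keep]


if __name__ == "__main__":
    # Position 0 declares a variable that position 3 uses.
    deps = [(0, 3)]
    pool = [Trajectory([0, 1, 2, 3]), Trajectory([3, 2, 1, 0]), Trajectory([1, 0, 3, 2])]
    best = select_trajectories(pool, lambda t: order_score(t, deps), keep=2)
    for trajectory in best:
        print(trajectory.reveal_order)  # prints [0, 1, 2, 3] then [1, 0, 3, 2]
```

Passing the score as a function keeps the filtering logic independent of whatever quality measure is actually used.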

On-Policy Diffusion Learning

In this training method, the model directly optimizes its own generation process for efficiency. It learns to take the shortest path to a correct output, which minimizes the number of inference steps and therefore increases speed.
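The objective described above is to use fewer decoding steps while still producing a correct output. It can be illustrated with a selection over sampled trajectories. This is only a sketch of the objective under stated assumptions. The actual method updates the model on its own samples, whereas `shortest_valid` (a hypothetical helper) merely picks among precomputed trajectories. The string-equality verifier is also a stand-in.

```python
from typing import Callable, Optional

# Tokens committed in one decoding step, as (position, token) pairs.
Step = list[tuple[int, str]]


def final_sequence(trajectory: list[Step], length: int) -> str:
    """Apply every step's committed tokens to an all-mask canvas ("?" = still masked)."""
    canvas = ["?"] * length
    for step in trajectory:
        for position, token in step:
            canvas[position] = token
    return "".join(canvas)


def shortest_valid(
    trajectories: list[list[Step]],
    is_valid: Callable[[str], bool],
    length: int,
) -> Optional[list[Step]]:
    """Among trajectories whose output passes the verifier, return the one with the fewest steps."""
    valid = [t for t in trajectories if is_valid(final_sequence(t, length))]
    return min(valid, key=len) if valid else None


if __name__ == "__main__":
    target = "abcd"
    one_per_step = [[(i, c)] for i, c in enumerate(target)]  # 4 steps
    two_per_step = [[(0, "a"), (1, "b")], [(2, "c"), (3, "d")]]  # 2 steps
    all_at_once_wrong = [[(0, "a"), (1, "b"), (2, "x"), (3, "d")]]  # 1 step, wrong output
    best = shortest_valid([one_per_step, two_per_step, all_at_once_wrong], lambda s: s == target, len(target))
    print(len(best))  # 2: the one-step trajectory is faster but fails the verifier
```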

Code blocks

Code for generating corrupted sequences:

import random
import math

# Define a special token for masking, as mentioned in the paper.
MASK_TOKEN = "[MASK]"
# Define a sample vocabulary for insertions and substitutions.
VOCABULARY = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,;. "

def mask_based_corruption(original_sequence: list[str], gamma_t: float) -> list[str]:
    """Replace each token with MASK_TOKEN independently with probability gamma_t."""
    if not 0.0 <= gamma_t <= 1.0:
        raise ValueError("gamma_t must be between 0 and 1.")

    corrupted_sequence = []
    for token in original_sequence:
        # Determine whether to mask the token based on the noise schedule gamma_t.
        if random.random() < gamma_t:
            corrupted_sequence.append(MASK_TOKEN)
        else:
            corrupted_sequence.append(token)
    return corrupted_sequence

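def edit_based_corruption(original_sequence: list[str], alpha_t: float) -> list[str]:
    """Apply floor(len(original_sequence) * alpha_t) random substitutions, deletions, and insertions."""
    if not 0.0 <= alpha_t <= 0.1:
        print(f"Warning: alpha_t is {alpha_t}, but the paper suggests a range of [0, 0.1].")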

    corrupted_sequence = list(original_sequence)

    # Calculate the target number of edits.
    k = math.floor(len(original_sequence) * alpha_t)

    for _ in range(k):
        # Randomly choose an edit operation: 0 for substitution, 1 for deletion, 2 for insertion.
        # If deletions have emptied the sequence, only an insertion is possible.
        edit_type = random.choice([0, 1, 2]) if corrupted_sequence else 2

        # Apply Substitution.
        if edit_type == 0:
            position = random.randint(0, len(corrupted_sequence) - 1)
            random_token = random.choice(VOCABULARY)
            # Ensure the substituted token is different
            while random_token == corrupted_sequence[position]:
                random_token = random.choice(VOCABULARY)
            corrupted_sequence[position] = random_token

        # Apply Deletion.
        elif edit_type == 1:
            position = random.randint(0, len(corrupted_sequence) - 1)
            del corrupted_sequence[position]

        # Apply Insertion.
        else: # edit_type == 2
            position = random.randint(0, len(corrupted_sequence))
            random_token = random.choice(VOCABULARY)
            corrupted_sequence.insert(position, random_token)

    return corrupted_sequence


if __name__ == "__main__":
    original_code = "def factorial(n):"
    tokenized_code = list(original_code)

    print(f"Original Sequence: {''.join(tokenized_code)}")
    print("-" * 20)

    # A gamma_t of 0.4 means roughly 40% of tokens will be masked.
    gamma_t_schedule = 0.4
    masked_sequence = mask_based_corruption(tokenized_code, gamma_t_schedule)
    print("Mask-based Corruption Example (gamma_t = 0.4):")
    print(f"Corrupted Sequence: {''.join(masked_sequence)}")
    print("-" * 20)

    # An alpha_t of 0.1 means the number of edits is 10% of the sequence length, rounded down
    # (1 edit for this 17-character string).
    alpha_t_schedule = 0.1
    edited_sequence = edit_based_corruption(tokenized_code, alpha_t_schedule)
    print("Edit-based Corruption Example (alpha_t = 0.1):")
    print(f"Corrupted Sequence: {''.join(edited_sequence)}")
    print("-" * 20)
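Random edits can overlap or cancel, so the number of edits `k` is only an upper bound on the Levenshtein distance between the original and corrupted sequences. Measuring the distance directly makes this visible. The `levenshtein` helper below is a standard dynamic-programming edit distance written for these notes, not code from the paper.

```python
def levenshtein(a: list[str], b: list[str]) -> int:
    """Classic dynamic-programming edit distance with unit-cost insert, delete, and substitute."""
    previous = list(range(len(b) + 1))
    for i, token_a in enumerate(a, start=1):
        current = [i]
        for j, token_b in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,                         # delete token_a
                current[j - 1] + 1,                      # insert token_b
                previous[j - 1] + (token_a != token_b),  # substitute (free if equal)
            ))
        previous = current
    return previous[-1]


if __name__ == "__main__":
    original = list("def factorial(n):")
    edited = list(original)
    edited.insert(0, "x")  # edit 1: insertion
    del edited[0]          # edit 2: deletes the token just inserted
    print(levenshtein(original, edited))  # 0: two edits, no net change
    print(levenshtein(original, list("def factorial(m):")))  # 1: a single substitution
```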
