Seed Diffusion: A Large-Scale Diffusion Language Model with High-Speed Inference | Paper Notes
8/7/2025
https://seed.bytedance.com/en/seed_diffusion
https://www.arxiv.org/abs/2508.02193
The model reaches an inference speed of 2,146 tokens/s on H20 GPUs, a significant speedup over similarly sized autoregressive models while matching their performance on coding benchmarks. It is also faster than existing diffusion models such as Mercury and Gemini Diffusion.
Two-Stage Training
The model first learns pattern completion by filling in masked spans of code, the standard objective for mask-based diffusion models. It is then trained to perform edits and corrections, which improves its ability to understand and fix logical issues in the code. The "edit-based forward process" applies a number of random text edits determined by multiplying the length of the original sequence by a value from a predefined noise schedule that controls the approximate signal-to-noise ratio. For example, a 100-token sequence with a schedule value of 0.05 receives floor(100 × 0.05) = 5 random edits. Essentially, the goal is to generate a corrupted sequence that is approximately, though not exactly, a target Levenshtein distance away from the original (both corruption processes are implemented in the Code blocks section below).
Constrained-Order Diffusion
Because code requires a logical order (like declaring variables before use), this method corrects the random generation order of typical diffusion models. The pre-trained diffusion model itself generates a large pool of possible generation paths (trajectories). The highest-quality trajectories are then selected from this pool to create a distilled dataset, and the model is fine-tuned on it so that it learns to respect correct code structure, as sketched below.
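A minimal sketch of the trajectory-distillation step, assuming a hypothetical model.sample_trajectory sampler and a caller-supplied score_fn quality heuristic; neither is an API published with the paper:

import heapq

def distill_trajectories(model, prompts, score_fn, samples_per_prompt=16, keep_top_k=2):
    # `model.sample_trajectory(prompt)` is a hypothetical stand-in for sampling
    # one full generation order from the pre-trained diffusion model.
    distilled = []
    for prompt in prompts:
        # Build a large pool of candidate generation orders for this prompt.
        pool = [model.sample_trajectory(prompt) for _ in range(samples_per_prompt)]
        # Keep only the highest-scoring orders (e.g., those that respect
        # declare-before-use structure) as fine-tuning targets.
        for trajectory in heapq.nlargest(keep_top_k, pool, key=score_fn):
            distilled.append((prompt, trajectory))
    return distilled

Fine-tuning on the distilled pairs then biases the model toward causally sensible generation orders.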
On-Policy Diffusion Learning
This is a training method in which the model directly optimizes its own generation process for efficiency: it learns to find the shortest possible path to a correct output, minimizing the number of inference steps needed and therefore boosting its speed. One possible formulation is sketched below.
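A sketch of one way to express this objective, assuming a REINFORCE-style surrogate with hypothetical generate_with_trace and verifier hooks; the paper optimizes the model's own sampling process, but this exact objective and API are my assumptions:

def on_policy_update(model, optimizer, prompt, verifier, step_penalty=0.01):
    # `generate_with_trace` is a hypothetical hook that runs the model's own
    # sampler and records the output, the number of denoising steps taken,
    # and the log-probability of the sampled trajectory.
    trace = model.generate_with_trace(prompt)
    # Reward correct outputs and penalize each extra inference step, pushing
    # the model toward shorter generation paths.
    reward = float(verifier(trace.output)) - step_penalty * trace.num_steps
    loss = -reward * trace.log_prob  # REINFORCE-style surrogate loss (assumption)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return reward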
Code blocks
Code for generating corrupted sequences:
import random
import math
# Define a special token for masking, as mentioned in the paper.
MASK_TOKEN = "[MASK]"
# Define a sample vocabulary for insertions and substitutions.
VOCABULARY = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,;. "
def mask_based_corruption(original_sequence: list[str], gamma_t: float) -> list[str]:
    if not 0.0 <= gamma_t <= 1.0:
        raise ValueError("gamma_t must be between 0 and 1.")
    corrupted_sequence = []
    for token in original_sequence:
        # Determine whether to mask the token based on the noise schedule gamma_t.
        if random.random() < gamma_t:
            corrupted_sequence.append(MASK_TOKEN)
        else:
            corrupted_sequence.append(token)
    return corrupted_sequence

def edit_based_corruption(original_sequence: list[str], alpha_t: float) -> list[str]:
    if not 0.0 <= alpha_t <= 0.1:
        print(f"Warning: alpha_t is {alpha_t}, but the paper suggests a range of [0, 0.1].")
    corrupted_sequence = list(original_sequence)
    # Calculate the target number of edits.
    k = math.floor(len(original_sequence) * alpha_t)
    for _ in range(k):
        if not corrupted_sequence:
            break
        # Randomly choose an edit operation: 0 for substitution, 1 for deletion, 2 for insertion.
        edit_type = random.choice([0, 1, 2])
        # Apply substitution.
        if edit_type == 0:
            position = random.randint(0, len(corrupted_sequence) - 1)
            random_token = random.choice(VOCABULARY)
            # Ensure the substituted token is different.
            while random_token == corrupted_sequence[position]:
                random_token = random.choice(VOCABULARY)
            corrupted_sequence[position] = random_token
        # Apply deletion.
        elif edit_type == 1:
            position = random.randint(0, len(corrupted_sequence) - 1)
            del corrupted_sequence[position]
        # Apply insertion.
        else:  # edit_type == 2
            position = random.randint(0, len(corrupted_sequence))
            random_token = random.choice(VOCABULARY)
            corrupted_sequence.insert(position, random_token)
    return corrupted_sequence

if __name__ == "__main__":
    original_code = "def factorial(n):"
    tokenized_code = list(original_code)
    print(f"Original Sequence: {''.join(tokenized_code)}")
    print("-" * 20)

    # A gamma_t of 0.4 means roughly 40% of tokens will be masked.
    gamma_t_schedule = 0.4
    masked_sequence = mask_based_corruption(tokenized_code, gamma_t_schedule)
    print("Mask-based Corruption Example (gamma_t = 0.4):")
    print(f"Corrupted Sequence: {''.join(masked_sequence)}")
    print("-" * 20)

    # An alpha_t of 0.1 means the number of edits will be 10% of the sequence length.
    alpha_t_schedule = 0.1
    edited_sequence = edit_based_corruption(tokenized_code, alpha_t_schedule)
    print("Edit-based Corruption Example (alpha_t = 0.1):")
    print(f"Corrupted Sequence: {''.join(edited_sequence)}")
    print("-" * 20)