Paper Review: Large Language Diffusion Models

LLaDA is a diffusion-based alternative to autoregressive models for LLMs. It models distributions through a forward data masking process and a reverse process, using a Transformer to predict masked tokens. By optimizing a likelihood bound, it enables principled probabilistic inference.
LLaDA demonstrates strong scalability, outperforming self-constructed ARM baselines. LLaDA 8B rivals LLaMA3 8B in in-context learning and, after supervised fine-tuning, shows impressive instruction-following abilities. It also surpasses GPT-4o in a reversal poem completion task, addressing the reversal curse.
The Approach

Probabilistic Formulation
LLaDA models distributions using a forward and reverse process, unlike autoregressive models. The forward process progressively masks tokens in a sequence until fully masked, while the reverse process recovers tokens by predicting masked elements. A mask predictor predicts all masked tokens based on partially masked input. It is trained using a cross-entropy loss applied only to masked tokens.
The training objective is an upper bound on negative log-likelihood, making LLaDA a principled generative model. Unlike masked language models, which use a fixed masking ratio, LLaDA applies a random masking ratio, improving scalability and enabling natural in-context learning. Its generative formulation ensures Fisher consistency, suggesting strong potential for large-scale applications.
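A minimal PyTorch sketch of this objective (the `mask_predictor` interface and `MASK_ID` value are assumptions, not the authors' code): sample a masking ratio t uniformly, mask each token independently with probability t, and compute cross-entropy only on the masked positions, reweighted by 1/t.

```python
import torch
import torch.nn.functional as F

MASK_ID = 0  # placeholder; the real mask-token id depends on the tokenizer

def llada_loss(mask_predictor, x0):
    """One Monte Carlo estimate of the masked-diffusion training objective.

    x0: (batch, seq_len) clean token ids.
    mask_predictor: bidirectional Transformer mapping token ids to logits of
                    shape (batch, seq_len, vocab_size). Assumed interface.
    """
    b, l = x0.shape
    # Random masking ratio t ~ U(0, 1] per sequence (not a fixed ratio as in BERT-style MLM).
    t = torch.rand(b, 1, device=x0.device).clamp(min=1e-3)
    # Forward process: mask each token independently with probability t.
    is_masked = torch.rand(b, l, device=x0.device) < t
    xt = torch.where(is_masked, torch.full_like(x0, MASK_ID), x0)

    logits = mask_predictor(xt)  # predict all tokens simultaneously
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")  # (b, l)
    # Cross-entropy only on masked tokens, weighted by 1/t (upper bound on NLL),
    # averaged per token here for readability.
    return (ce * is_masked.float() / t).sum() / (b * l)
```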
Pre-training
LLaDA uses a Transformer-based architecture similar to existing LLMs but without causal masking, allowing it to see the entire input when making predictions. Because LLaDA does not support KV caching, it uses vanilla multi-head attention and a reduced FFN dimension to keep the parameter count comparable to standard LLMs.
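For illustration only (standard PyTorch, not the authors' code), the architectural difference amounts to not applying a causal mask in attention:

```python
import torch
import torch.nn.functional as F

# An autoregressive Transformer would set is_causal=True here; LLaDA's mask
# predictor attends over the entire partially masked sequence in both directions.
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
```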
LLaDA was pre-trained on 2.3 trillion tokens with a fixed sequence length of 4096, using 0.13 million H800 GPU hours.
Training used Monte Carlo sampling to estimate the objective. To improve handling of variable-length data, 1% of pre-training samples have sequence lengths drawn uniformly between 1 and 4096 tokens.
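A sketch of how such a data pipeline could draw sequence lengths (the exact mechanism is an assumption):

```python
import random

def sample_sequence_length(max_len=4096, p_random=0.01):
    # Roughly 1% of pre-training samples get a uniformly random length;
    # the rest use the fixed 4096-token length.
    if random.random() < p_random:
        return random.randint(1, max_len)
    return max_len
```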
Supervised Fine-Tuning
LLaDA improves instruction-following ability through SFT using 4.5 million prompt-response pairs. SFT trains the model to predict responses given prompts by modeling a conditional distribution. The prompt remains unmasked, while response tokens are masked and predicted.
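A sketch of the SFT objective under the same assumptions as above (`mask_predictor`, `MASK_ID`): only response tokens are masked, and the loss is computed on those positions conditioned on the clean prompt.

```python
import torch
import torch.nn.functional as F

MASK_ID = 0  # placeholder mask-token id

def llada_sft_loss(mask_predictor, prompt, response):
    """SFT variant of the loss: the prompt stays clean, response tokens are masked.

    prompt:   (batch, prompt_len) token ids
    response: (batch, resp_len) token ids
    """
    b, resp_len = response.shape
    t = torch.rand(b, 1, device=response.device).clamp(min=1e-3)
    is_masked = torch.rand(b, resp_len, device=response.device) < t
    noisy_response = torch.where(is_masked, torch.full_like(response, MASK_ID), response)

    x = torch.cat([prompt, noisy_response], dim=1)    # the prompt is never masked
    logits = mask_predictor(x)[:, prompt.shape[1]:]   # logits for response positions
    ce = F.cross_entropy(logits.transpose(1, 2), response, reduction="none")
    return (ce * is_masked.float() / t).sum() / (b * resp_len)
```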
Inference
LLaDA supports text generation and likelihood evaluation.
For generation, it samples responses by discretizing the reverse process, starting from a fully masked response and predicting tokens iteratively. The number of sampling steps controls the trade-off between efficiency and quality.
To improve sampling accuracy, some predicted tokens are remasked at each step to stay consistent with the forward process. The authors explore two remasking strategies: low-confidence remasking, which remasks the predicted tokens with the lowest confidence scores, and semi-autoregressive remasking, which generates text block by block from left to right after fine-tuning.
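A simplified sketch of reverse-process sampling with low-confidence remasking (the schedule and helper names are assumptions; semi-autoregressive remasking would apply the same loop block by block):

```python
import torch

MASK_ID = 0  # placeholder; in practice a dedicated mask token outside the normal vocabulary

@torch.no_grad()
def generate(mask_predictor, prompt, resp_len=128, steps=64):
    """Start from a fully masked response; at each step keep only the most
    confident predictions and remask the rest."""
    b = prompt.shape[0]
    masked = torch.full((b, resp_len), MASK_ID, dtype=prompt.dtype, device=prompt.device)
    x = torch.cat([prompt, masked], dim=1)
    resp_slice = slice(prompt.shape[1], prompt.shape[1] + resp_len)

    for step in range(steps):
        logits = mask_predictor(x)[:, resp_slice]        # (b, resp_len, vocab)
        conf, pred = logits.softmax(dim=-1).max(dim=-1)  # confidence and argmax token
        still_masked = x[:, resp_slice] == MASK_ID
        # Already-finalized positions get infinite confidence so they are never remasked.
        conf = torch.where(still_masked, conf, torch.full_like(conf, float("inf")))

        # Linearly decrease the number of masked tokens toward zero.
        n_keep_masked = int(resp_len * (1 - (step + 1) / steps))
        order = conf.argsort(dim=-1)                     # lowest confidence first
        remask = torch.zeros_like(still_masked)
        remask.scatter_(1, order[:, :n_keep_masked], True)

        new_resp = torch.where(remask, torch.full_like(pred, MASK_ID), pred)
        # Never overwrite tokens finalized in earlier steps.
        x[:, resp_slice] = torch.where(still_masked, new_resp, x[:, resp_slice])
    return x[:, resp_slice]
```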
For likelihood evaluation, LLaDA leverages a lower-variance reformulation of the loss function for more stable probability estimation. Additionally, it uses unsupervised classifier-free guidance to improve evaluation quality.
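A sketch of the lower-variance estimator for conditional likelihood (same assumed interface as above; classifier-free guidance is omitted): instead of masking each token independently with probability t, the number of masked tokens is drawn uniformly and exactly that many positions are masked.

```python
import torch
import torch.nn.functional as F

MASK_ID = 0  # placeholder mask-token id

@torch.no_grad()
def conditional_nll_estimate(mask_predictor, prompt, response, n_samples=128):
    """Monte Carlo estimate of an upper bound on -log p(response | prompt).

    prompt, response: (1, len) token ids (single sequence for simplicity).
    """
    L = response.shape[1]
    total = 0.0
    for _ in range(n_samples):
        l = torch.randint(1, L + 1, (1,)).item()                 # number of masked tokens
        idx = torch.randperm(L, device=response.device)[:l]      # which positions to mask
        noisy = response.clone()
        noisy[:, idx] = MASK_ID
        x = torch.cat([prompt, noisy], dim=1)
        logits = mask_predictor(x)[:, prompt.shape[1]:]
        ce = F.cross_entropy(logits.transpose(1, 2), response, reduction="none")
        # Sum cross-entropy over the masked positions, weighted by L / l.
        total += (L / l) * ce[:, idx].sum().item()
    return total / n_samples
```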
Experiments

Experiments show that LLaDA scales competitively with ARMs, outperforming them on MMLU and GSM8K and closing the gap on some tasks at larger scales.

LLaDA 8B was evaluated for in-context learning and instruction-following against existing LLMs of similar scale across 15 benchmarks covering general tasks, mathematics, code, and Chinese.
After pretraining on 2.3T tokens, LLaDA 8B outperformed LLaMA2 7B on nearly all tasks and was competitive with LLaMA3 8B, showing an advantage in math and Chinese tasks. Differences in data quality and distribution may explain variations in performance.

SFT improved performance on most tasks, though some, like MMLU, had lower scores, possibly due to suboptimal SFT data quality. Without RL alignment, LLaDA 8B Instruct performed slightly below LLaMA3 8B Instruct.

LLaDA was tested for reversal reasoning using a dataset of 496 famous Chinese poem sentence pairs, where models had to generate the next (forward) or previous (reversal) line without fine-tuning. Unlike GPT-4o and Qwen 2.5, which showed a performance gap between forward and reversal tasks, LLaDA excelled in both, effectively overcoming the reversal curse.
This success was achieved without task-specific modifications, likely due to LLaDA’s uniform token treatment, which avoids the inductive biases of autoregressive models.
Additionally, remasking strategies and sampling steps were analyzed for their impact on performance. Case studies showcased LLaDA 8B Instruct’s ability to generate fluent, extended text, engage in multi-turn dialogue, retain conversation history, and support multiple languages, marking a significant departure from traditional ARMs.
paperreview deeplearning nlp transformer llm diffusion