---
license: mit
datasets:
- Skylion007/openwebtext
tags:
- diffusion
---
# Generalized Interpolating Discrete Diffusion
By Dimitri von Rütte, Janis Fluri, Yuhui Ding, Antonio Orvieto, Bernhard Schölkopf, Thomas Hofmann
<div style="display: flex; gap: 8px;">
<a href="https://www.arxiv.org/abs/2503.04482"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2503.04482-d22c2c.svg"></a>
<a href="https://colab.research.google.com/drive/1Xv4RyZhXHkIpIZeMYahl_4kMthLxKdg_?usp=sharing"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a>
<a href="https://github.com/dvruette/gidd"><img alt="GitHub" src="https://img.shields.io/badge/GitHub-GIDD-blue"></a>
</div>
---

We present Generalized Interpolating Discrete Diffusion (GIDD), a novel framework for training discrete diffusion models.
GIDD can be seen as a generalization of the popular masked diffusion paradigm (MDM) to any diffusion process that can be written as a linear interpolation between a data distribution and some (time-variable) mixing distribution.
We demonstrate the flexibility of GIDD by training models on a hybrid diffusion process that combines masking and uniform noise.
The model is therefore trained not only to "fill in the blanks" (i.e. the masked tokens), but also to assess the correctness of already-filled-in tokens and, if necessary, replace incorrect tokens with more plausible ones.
We show that GIDD models trained on hybrid noise have better sample quality (generative PPL) than mask-only models, and that they are able to identify and correct their own mistakes in generated samples through a self-correction step.
This repository contains all training and evaluation code necessary for reproducing the results in the paper.
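To build intuition for the hybrid diffusion process described above, here is a minimal, self-contained sketch (not the GIDD library's actual implementation) of one forward-noising step that interpolates between the data and a mixing distribution combining masking and uniform noise. The function name, argument names, and token IDs are illustrative assumptions.

```python
import random

def hybrid_noise(tokens, alpha_t, p_u, vocab_size, mask_id):
    """Illustrative hybrid forward step (NOT the GIDD library API):
    each token survives with probability alpha_t; otherwise it is
    replaced by a uniformly random token with probability p_u, or by
    the mask token with probability 1 - p_u."""
    noisy = []
    for tok in tokens:
        if random.random() < alpha_t:
            noisy.append(tok)  # token survives
        elif random.random() < p_u:
            noisy.append(random.randrange(vocab_size))  # uniform noise
        else:
            noisy.append(mask_id)  # masked
    return noisy

random.seed(0)
clean = [5, 12, 7, 42, 3, 19]
print(hybrid_noise(clean, alpha_t=0.5, p_u=0.2, vocab_size=50, mask_id=50))
```

Setting `p_u = 0` recovers the mask-only (MDM) special case, while `p_u > 0` corresponds to the hybrid models in the checkpoint table below; a model trained on such noise must learn both to unmask and to spot and fix corrupted tokens.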
### Pretrained Checkpoints
Our trained checkpoints are available at the links below. All models were trained on 131B tokens from the [OpenWebText](https://huggingface.co/datasets/Skylion007/openwebtext) dataset using the [GPT-2 tokenizer](https://huggingface.co/openai-community/gpt2).
| Model | Small (169.6M) | Base (424.5M) |
|-------|-------|------|
| GIDD+ (p_u = 0.0) | [dvruette/gidd-small-p_unif-0.0](https://huggingface.co/dvruette/gidd-small-p_unif-0.0) | [dvruette/gidd-base-p_unif-0.0](https://huggingface.co/dvruette/gidd-base-p_unif-0.0) |
| GIDD+ (p_u = 0.1) | [dvruette/gidd-small-p_unif-0.1](https://huggingface.co/dvruette/gidd-small-p_unif-0.1) | [dvruette/gidd-base-p_unif-0.1](https://huggingface.co/dvruette/gidd-base-p_unif-0.1) |
| GIDD+ (p_u = 0.2) | [dvruette/gidd-small-p_unif-0.2](https://huggingface.co/dvruette/gidd-small-p_unif-0.2) | [dvruette/gidd-base-p_unif-0.2](https://huggingface.co/dvruette/gidd-base-p_unif-0.2) |
## Use the Model
1. Install the GIDD repo:
```bash
pip install git+https://github.com/dvruette/gidd
```
2. To quickly download a trained model and experiment with it, the `GiddPipeline` class is the most convenient entry point:
```python
from gidd import GiddPipeline
# Download a pretrained model from HuggingFace
pipe = GiddPipeline.from_pretrained("dvruette/gidd-base-p_unif-0.2", trust_remote_code=True)
# Generate samples
texts = pipe.generate(num_samples=4, num_inference_steps=128)
# Run self-correction step
corrected_texts = pipe.self_correction(texts, num_inference_steps=128, early_stopping=True, temperature=0.1)
print(corrected_texts)
```