Self-Distillation Through Time (SDTT)
SDTT is a distillation method for diffusion language models. Recent diffusion language models such as SEDD and MDLM achieve strong results, but their non-causal architecture prevents KV-caching, which makes sampling from them slow. We therefore devise a novel distillation method that reduces the inference latency of discrete diffusion models. After distillation, we can sample up to 8x faster than GPT-2 with KV-caching. Find more details below and in our GitHub repo.
Using SDTT
- We released 3 groups of models:
  - The baseline students distilled with the `kld`, `mse` and `tvd` objectives, distilled from a model trained for 1M steps.
  - The students from the scaling experiments, with sizes `sm`, `md`, `large`, distilled from models trained for 400k steps.
  - The teachers from the scaling experiments, with sizes `sm`, `md`, `large`, before any distillation.
- To load those models, first install our code:
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
- You can then import our models, sample from them, and evaluate them:
Load the baseline students
from sdtt import load_small_student
student = load_small_student(loss="kld", round=7) # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2) # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1) # load the tvd student after the first distillation round
Load the students from the scaling experiment
from sdtt import load_scaling_student
student = load_scaling_student(size="sm", round=7) # load small student after the last distillation round
student = load_scaling_student(size="md", round=1) # load medium student after the first distillation round
student = load_scaling_student(size="large", round=3) # load large student after the third distillation round
Load the teachers from the scaling experiment
from sdtt import load_scaling_teacher
teacher = load_scaling_teacher(size="sm") # load small teacher
teacher = load_scaling_teacher(size="md") # load medium teacher
teacher = load_scaling_teacher(size="large") # load large teacher
Sample from the pretrained models
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch
model = load_small_student(loss="kld", round=7) # load model, see above
model.cuda() # put model on gpu
# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)
# Conditional generation, based on a prompt
# Prepare a prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)
def project_fn(x):
    # Overwrite the first prompt_len tokens of every example with the prompt
    x[:, :prompt_len] = prompt_tokens
    return x # Don't forget to return
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn,
)
cond_text = model.tokenizer.batch_decode(tokens)
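To make the role of `project_fn` more concrete, here is a minimal pure-Python sketch of how a sampler might use such a callback. It is a toy stand-in, not the `sdtt` implementation: `toy_sample`, `MASK`, the random "denoising" rule and the token ids are all hypothetical, and we assume the callback is applied to the whole batch after every sampling step, which may differ from what `model.sample` actually does.

```python
import random

MASK = -1  # hypothetical id standing in for a masked position


def toy_sample(n_samples, num_steps, seq_len, vocab_size, project_fn=None, seed=0):
    """Toy masked sampler (NOT the sdtt sampler).

    Starts from fully masked sequences, fills random masked positions with
    random tokens at each step, and applies project_fn after every step.
    """
    rng = random.Random(seed)
    x = [[MASK] * seq_len for _ in range(n_samples)]
    if project_fn is not None:
        x = project_fn(x)  # clamp the prompt before the first step
    for step in range(num_steps):
        # Number of positions that should be filled once this step is done
        target = round(seq_len * (step + 1) / num_steps)
        for row in x:
            masked = [i for i, tok in enumerate(row) if tok == MASK]
            n_filled = seq_len - len(masked)
            n_new = max(0, min(len(masked), target - n_filled))
            for i in rng.sample(masked, n_new):
                row[i] = rng.randrange(vocab_size)
        if project_fn is not None:
            x = project_fn(x)  # re-impose the prompt after each step
    return x


prompt_tokens = [7, 3, 9]  # made-up token ids, for illustration only
prompt_len = len(prompt_tokens)


def project_fn(x):
    # Overwrite the first prompt_len tokens of every example with the prompt
    for row in x:
        row[:prompt_len] = prompt_tokens
    return x  # the sampler uses the returned batch


samples = toy_sample(
    n_samples=4,
    num_steps=8,
    seq_len=16,
    vocab_size=50,
    project_fn=project_fn,
)
```

Because the callback runs after every step, the prompt positions stay fixed no matter what the sampler writes there, while the remaining positions are generated freely. This is also why the callback must return the batch.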
For more details, please see our GitHub repository: https://github.com/jdeschena/sdtt
Model Details
Our small checkpoints are distilled from the MDLM checkpoints. We also release medium (424M) and large (863M) checkpoints that we pretrained ourselves.
Citation
Please cite our work using the BibTeX entry below:
@article{deschenaux2024autoregressionfastllmsselfdistillation,
  title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
  author={Deschenaux, Justin and Gulcehre, Caglar},
  eprint={2410.21035},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.21035},
}
Contact
Justin Deschenaux (justin.deschenaux@epfl.ch)