|
---
license: mit
---
|
# A Text-Conditioned Diffusion-Prior |
|
|
|
## Training Details |
|
Training details can be found [here](https://wandb.ai/nousr_laion/1B%20Prior/reports/Distributed-Training-of-the-Prior--VmlldzoyMDkxMDQ5?accessToken=md54qpjikfxhf366iv64rxv94d47z05iojh28335fz6qlov11vlq313z63z42h3m) |
|
## Source Code |
|
The checkpoints here are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch
|
|
|
## Community: LAION |
|
Join Us!: https://discord.gg/uPMftTmrvS |
|
|
|
--- |
|
|
|
# Models |
|
The repo currently has many models. I recommend using the latest EMA checkpoints, as they are the best-performing models right now.
|
|
|
> **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of *DALLE2 PyTorch* and massively underperform compared to recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README commit timestamp).
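If you want to pull a checkpoint down programmatically before loading it with the code below, a minimal sketch using `huggingface_hub` might look like this. The repo id and filename here are placeholders, not real paths; check the Files tab of this repo for the actual EMA checkpoint names.

```python
from huggingface_hub import hf_hub_download

# placeholder repo id / filename -- substitute the actual EMA checkpoint you want
prior_path = hf_hub_download(
    repo_id="<this-repo-id>",
    filename="<latest-ema-checkpoint>.pth",
)
```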
|
|
|
### Loading a model might look something like this:
|
|
|
> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and are having issues loading checkpoints, please reach out on the LAION Discord.
|
|
|
```python
import torch

from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer


def load_diffusion_model(dprior_path, device):
    # if you are getting issues with size mismatches, it's likely this configuration
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4,
    )

    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # this will load the entire trainer
    # if you only want EMA weights for inference you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that, please make a PR on GitHub!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
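Once the trainer is loaded, extracting just the EMA weights for inference might look something like the sketch below. It assumes the trainer exposes its EMA copy of the prior as `trainer.ema_diffusion_prior.ema_model` (how recent versions of DALLE2-pytorch name it; double-check against your installed version), that `DiffusionPrior.sample` accepts CLIP token ids with a `cond_scale` argument, and the checkpoint path is a placeholder.

```python
import clip
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "prior_checkpoint.pth" is a placeholder path for a downloaded checkpoint
trainer = load_diffusion_model("prior_checkpoint.pth", device)

# assumption: with use_ema=True the trainer keeps an EMA copy of the prior here
ema_prior = trainer.ema_diffusion_prior.ema_model
torch.save(ema_prior.state_dict(), "prior_ema_only.pth")  # EMA weights only, for inference

# sampling an image embedding from text (tokenized with OpenAI CLIP's tokenizer)
tokens = clip.tokenize(["a corgi wearing a top hat"]).to(device)
with torch.no_grad():
    image_embed = ema_prior.sample(tokens, cond_scale=1.0)
print(image_embed.shape)  # expected: (1, 768) for ViT-L/14 embeddings
```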