---
license: mit
---
# A Text-Conditioned Diffusion-Prior
## Training Details
Training details can be found [here](https://wandb.ai/nousr_laion/1B%20Prior/reports/Distributed-Training-of-the-Prior--VmlldzoyMDkxMDQ5?accessToken=md54qpjikfxhf366iv64rxv94d47z05iojh28335fz6qlov11vlq313z63z42h3m)
## Source Code
The models are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch
## Community: LAION
Join Us!: https://discord.gg/uPMftTmrvS
---
# Models
The repo currently contains many models. I recommend using the latest EMA checkpoints, as they are the best-performing models right now.
> **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of *DALLE2 PyTorch* and massively underperform compared to recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README commit's timestamp).
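If you want to pull a checkpoint down for backup (or for local inference), a download via `huggingface_hub` might look like the sketch below. The filename is a placeholder, so substitute a real file from this repo's file listing.
```python
# a minimal sketch using huggingface_hub; "ema_latest.pth" is a placeholder filename,
# pick an actual checkpoint name from this repo's "Files and versions" tab
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="nousr/conditioned-prior",  # assumed from this repo's namespace
    filename="ema_latest.pth",          # hypothetical filename
)
print(checkpoint_path)  # local path inside the HF cache; copy it elsewhere to keep a backup
```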
### Loading the models might look something like this:
> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and having issues loading checkpoints, please reach out on the LAION Discord.
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path, device):
    # if you are getting size-mismatch errors, it's likely this configuration
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # this will load the entire trainer
    # if you only want EMA weights for inference you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that, please make a PR on GitHub!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
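Once the trainer is loaded, sampling an image embedding from a caption might look roughly like the sketch below. This is a minimal example, not the canonical inference path: it assumes the `tokenizer` helper and the `DiffusionPrior.sample` method from DALLE2-pytorch, and that the trainer exposes the wrapped prior as `trainer.diffusion_prior`; the checkpoint path is a placeholder.
```python
# a hedged usage sketch; "prior_checkpoint.pth" is a placeholder path and the
# attribute/method names below are assumptions based on the DALLE2-pytorch API
import torch
from dalle2_pytorch.tokenizer import tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trainer = load_diffusion_model("prior_checkpoint.pth", device)

# tokenize a caption and sample an image embedding from the prior
tokens = tokenizer.tokenize(["a photograph of a red fox in the snow"]).to(device)
with torch.no_grad():
    image_embed = trainer.diffusion_prior.sample(tokens)  # (batch, 768) image embedding
```
The resulting embedding is what you would hand to a DALLE-2 style decoder to generate an image.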