license: mit
A Text-Conditioned Diffusion-Prior
Training Details
Training details can be found here: https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx
Source Code
The models are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch
Community: LAION
Join Us!: https://discord.gg/uPMftTmrvS
Models
The repo currently contains many models. I recommend using the latest EMA model checkpoints, as they are the best-performing models right now.
DISCLAIMER: I will be removing many of the older models. They were trained on older versions of DALLE2 PyTorch and substantially underperform recent models. If for whatever reason you want an old model, please make a backup (you have 7 days from this README commit timestamp).
Note: This repo's documentation will get an overhaul ~soon~. If you're reading this and having issues loading checkpoints, please reach out on the LAION Discord.
Loading the models might look something like this:
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer


def load_diffusion_model(dprior_path, device):
    # If you are getting size mismatch errors, this configuration
    # likely does not match the checkpoint you are loading
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # this will load the entire trainer
    # If you only want EMA weights for inference, you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that, please make a PR on GitHub!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
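
If loading fails with size mismatches, it can help to see exactly which parameters disagree before guessing at which configuration value is off. The following is a minimal sketch, not part of DALLE2 PyTorch: `find_shape_mismatches` is a hypothetical helper that compares any two name-to-tensor mappings by shape. The demo uses stand-in objects so it runs without a checkpoint. With real data you would pass the model's `state_dict()` and the weights dictionary read from the checkpoint file, and which key of the checkpoint holds those weights is an assumption you would need to verify yourself.

```python
from types import SimpleNamespace


def find_shape_mismatches(expected, loaded):
    """Compare two name -> tensor-like mappings (anything with a .shape).

    Returns (mismatched, missing, unexpected):
      mismatched: {name: (expected_shape, loaded_shape)} for shared names
      missing:    names the model expects but the checkpoint lacks
      unexpected: names in the checkpoint that the model does not define
    """
    mismatched = {}
    for name, tensor in expected.items():
        if name in loaded:
            want = tuple(tensor.shape)
            got = tuple(loaded[name].shape)
            if want != got:
                mismatched[name] = (want, got)
    missing = sorted(set(expected) - set(loaded))
    unexpected = sorted(set(loaded) - set(expected))
    return mismatched, missing, unexpected


# Stand-ins for tensors so the sketch runs without a checkpoint.
# The parameter names below are made up for illustration only.
model_side = {
    "proj.weight": SimpleNamespace(shape=(768, 768)),
    "proj.bias": SimpleNamespace(shape=(768,)),
}
checkpoint_side = {
    "proj.weight": SimpleNamespace(shape=(512, 512)),
    "extra.bias": SimpleNamespace(shape=(512,)),
}

mismatched, missing, unexpected = find_shape_mismatches(model_side, checkpoint_side)
print("shape mismatches:", mismatched)
print("missing from checkpoint:", missing)
print("unexpected in checkpoint:", unexpected)
```

A mismatch in the last dimension of many parameters usually points at `dim` or `image_embed_dim`, while missing or unexpected layer names more often point at `depth` or at a checkpoint saved by a different library version.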