---
license: mit
---
# A Text-Conditioned Diffusion-Prior
## Training Details
Training details can be found [here](https://wandb.ai/nousr_laion/1B%20Prior/reports/Distributed-Training-of-the-Prior--VmlldzoyMDkxMDQ5?accessToken=md54qpjikfxhf366iv64rxv94d47z05iojh28335fz6qlov11vlq313z63z42h3m)
## Source Code
Models are trained with the diffusion prior trainer from https://github.com/lucidrains/DALLE2-pytorch
## Community: LAION
Join us: https://discord.gg/uPMftTmrvS
---
# Models
The repo currently hosts many models. I recommend using the latest EMA checkpoints, as they are the best-performing models right now.
> **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of *DALLE2 PyTorch* and massively underperform compared to recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README's commit timestamp).
### Loading the models might look something like this:
> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and having issues loading checkpoints, please reach out on the LAION Discord.
```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer
def load_diffusion_model(dprior_path, device):
    # if you are seeing size-mismatch errors, this configuration is the likely culprit
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4,
    )

    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # this will load the entire trainer
    # if you only want the EMA weights for inference, you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that, please make a PR on GitHub!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
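The comment in the loader above notes that EMA weights currently have to be extracted by hand for inference. One way to do that is to filter the checkpoint's state dict by key prefix; below is a minimal sketch (the `ema_model.` prefix and the flat key layout are assumptions about how the trainer namespaces its EMA copy, and may not match your checkpoint exactly):

```python
def extract_ema_state_dict(checkpoint_state, ema_prefix="ema_model."):
    # keep only the EMA entries and strip the prefix so the result
    # could be loaded into a plain DiffusionPrior via load_state_dict
    return {
        key[len(ema_prefix):]: value
        for key, value in checkpoint_state.items()
        if key.startswith(ema_prefix)
    }

# toy example with plain numbers standing in for tensors
ckpt = {
    "ema_model.net.weight": 1.0,
    "ema_model.net.bias": 2.0,
    "net.weight": 3.0,  # online (non-EMA) weight, dropped
}
print(extract_ema_state_dict(ckpt))
# -> {'net.weight': 1.0, 'net.bias': 2.0}
```

Inspect your checkpoint's keys first (e.g. `torch.load(path, map_location="cpu").keys()`) to confirm the actual prefix before relying on this.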