---
license: mit
---

# A Text-Conditioned Diffusion-Prior

## Training Details

Training details can be found [here](https://wandb.ai/nousr_laion/1B%20Prior/reports/Distributed-Training-of-the-Prior--VmlldzoyMDkxMDQ5?accessToken=md54qpjikfxhf366iv64rxv94d47z05iojh28335fz6qlov11vlq313z63z42h3m)

## Source Code

Models are trained with the diffusion-prior trainer from https://github.com/lucidrains/DALLE2-pytorch

## Community: LAION

Join us!: https://discord.gg/uPMftTmrvS

---

# Models

This repo currently hosts many models. I recommend using the latest EMA checkpoints, as they are the best-performing models right now.

> **_DISCLAIMER_**: **I will be removing many of the older models.** They were trained on older versions of *DALLE2 PyTorch* and massively underperform compared to recent models. **If for whatever reason you want an old model, please make a backup** (you have 7 days from this README's commit timestamp).

### Loading the models might look something like this:

> Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this and having issues loading checkpoints, please reach out on the LAION Discord.

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer


def load_diffusion_model(dprior_path, device):
    # if you are getting size-mismatch errors, this configuration is the likely culprit
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4,
    )

    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # this will load the entire trainer
    # if you only want the EMA weights for inference, you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that, please make a PR on GitHub!)
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```
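
Once the trainer is loaded, sampling a CLIP image embedding from a caption might look like the sketch below. The checkpoint path is a placeholder, and the sketch assumes the tokenizer bundled with *DALLE2 PyTorch* and the `sample` method on `DiffusionPrior`:

```python
import torch
from dalle2_pytorch.tokenizer import tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# placeholder path -- point this at whichever checkpoint you downloaded
trainer = load_diffusion_model("path/to/your/checkpoint.pth", device)

# tokenize a caption and sample a predicted CLIP image embedding for it
tokens = tokenizer.tokenize(["a photograph of a sunflower"]).to(device)
image_embed = trainer.diffusion_prior.sample(tokens)  # shape: (1, 768)
```

The resulting embedding is what you would feed to a decoder trained on ViT-L/14 CLIP image embeddings.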
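
As the comment in `load_diffusion_model` notes, the checkpoint stores the full trainer state, and there is no helper for extracting just the EMA weights yet. A minimal sketch of doing it yourself, assuming the trainer exposes its EMA copy as `trainer.ema_diffusion_prior.ema_model` (the `ema-pytorch` wrapper that *DALLE2 PyTorch* uses):

```python
def extract_ema_prior(dprior_path, device):
    # hypothetical helper: load the full trainer, then keep only the EMA copy
    trainer = load_diffusion_model(dprior_path, device)

    # the EMA wrapper keeps a shadow copy of the prior in `ema_model`
    # (an assumption based on the ema-pytorch wrapper the trainer uses)
    ema_prior = trainer.ema_diffusion_prior.ema_model
    ema_prior.eval()

    return ema_prior
```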