conditioned-prior / README.md
license: mit

A Text-Conditioned Diffusion-Prior

Training Details

Training details can be found here

Source Code

The models are DiffusionPriorTrainer checkpoints from https://github.com/lucidrains/DALLE2-pytorch

Community: LAION

Join Us!: https://discord.gg/uPMftTmrvS


Models

The repo currently contains many models. I recommend using the latest EMA model checkpoints, as they are the best-performing models right now.
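
For readers unfamiliar with the term, EMA checkpoints hold an exponential moving average of the weights seen during training rather than the weights from the final step. The sketch below illustrates the idea on plain Python lists. The helper `ema_update` and the decay value `beta=0.99` are assumptions made for illustration. They are not taken from DALLE2-pytorch, whose own EMA implementation may differ in its details.

```python
def ema_update(ema_weights, new_weights, beta=0.99):
    """Blend the running average with the newest weights.

    beta is a hypothetical decay rate chosen for illustration.
    """
    if len(ema_weights) != len(new_weights):
        raise ValueError("weight lists must have the same length")
    return [beta * e + (1.0 - beta) * w for e, w in zip(ema_weights, new_weights)]


# Toy run: the raw weight jumps between 0.0 and 2.0 on alternating steps,
# while the averaged weight stays close to 1.0.
ema = [1.0]
for step in range(100):
    raw = [0.0] if step % 2 == 0 else [2.0]
    ema = ema_update(ema, raw, beta=0.99)
```

The averaged copy changes slowly even when the raw weights are noisy, which is the usual motivation for preferring EMA weights at inference time.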

DISCLAIMER: I will be removing many of the older models. They were trained on older versions of DALLE2 PyTorch and massively underperform compared to recent models. If, for whatever reason, you want an old model, please make a backup (you have 7 days from this README commit timestamp).

Note: This repo's documentation will get an overhaul ~soon~. If you're reading this and having issues loading checkpoints, please reach out on LAION.

Loading the models might look something like this:

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path, device):
    # If you get size-mismatch errors when loading, the most likely cause is that
    # this network configuration differs from the one the checkpoint was trained with.
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )
    
    # currently, only ViT-L/14 models are being trained
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )
    
    # This loads the entire trainer.
    # If you only want the EMA weights for inference, you will need to extract them yourself for now
    # (if you beat me to writing a nice function for that, please make a PR on GitHub!).
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )
    
    trainer.load(dprior_path)
    
    return trainer
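
If loading fails with size mismatches, as the comment in the loader warns, it can help to list exactly which parameters disagree. The sketch below is one way to do that, assuming you have already collected two `{parameter_name: shape}` mappings, one from the model you built and one from the checkpoint. The function name and the example parameter names are hypothetical and do not come from DALLE2-pytorch.

```python
def find_shape_mismatches(expected_shapes, checkpoint_shapes):
    """Compare two {parameter_name: shape} mappings.

    Returns (mismatched, missing, unexpected):
      mismatched: {name: (expected_shape, checkpoint_shape)}
      missing:    names the model expects but the checkpoint lacks
      unexpected: names the checkpoint has but the model does not
    """
    mismatched = {
        name: (tuple(shape), tuple(checkpoint_shapes[name]))
        for name, shape in expected_shapes.items()
        if name in checkpoint_shapes
        and tuple(checkpoint_shapes[name]) != tuple(shape)
    }
    missing = sorted(set(expected_shapes) - set(checkpoint_shapes))
    unexpected = sorted(set(checkpoint_shapes) - set(expected_shapes))
    return mismatched, missing, unexpected


# Hypothetical parameter names and shapes, for illustration only.
model_shapes = {"proj.weight": (768, 768), "proj.bias": (768,)}
ckpt_shapes = {"proj.weight": (512, 512), "extra.bias": (512,)}
mismatched, missing, unexpected = find_shape_mismatches(model_shapes, ckpt_shapes)
```

With PyTorch, a mapping like this can be built from any state dict with `{k: tuple(v.shape) for k, v in state_dict.items()}`. The layout of the saved checkpoint file is not described in this README, so you will need to inspect the loaded object to find which entry holds the weights.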