A Text-Conditioned Diffusion-Prior

Training Details

[Updated Reports Coming]

Source Code

Models are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch

Community: LAION

Join Us!: https://discord.gg/uPMftTmrvS


Intro

A properly trained prior will allow you to translate between two embedding spaces. If you know a priori that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.

Motivation

Before we dive into the model, let’s look at a quick example of where the model may be helpful.

For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.

CLIP is a contrastive model that learns to maximize the cosine similarity between a given image and its caption. However, there is no guarantee that these embeddings are in the same space. While the embeddings generated are close, the image and text embeddings occupy two disjoint sets.
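As a rough illustration of the similarity measure mentioned above, here is a minimal sketch of cosine similarity in plain Python. The three-dimensional vectors are made-up stand-ins, not real CLIP output, and the variable names are purely illustrative.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Toy 3-d stand-ins for CLIP embeddings (hypothetical values, not model output)
image_embedding = [0.9, 0.1, 0.4]
matching_caption = [0.7, 0.3, 0.6]
unrelated_caption = [-0.2, 0.9, -0.1]

sim_match = cosine_similarity(image_embedding, matching_caption)
sim_other = cosine_similarity(image_embedding, unrelated_caption)
print(f"matching: {sim_match:.3f}, unrelated: {sim_other:.3f}")
```

Note that even the matching pair scores below 1.0: a high similarity says the vectors point in a similar direction, not that one can stand in for the other.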

# Load Models
clip_model, preprocess = clip.load("ViT-L/14") # clip.load returns the model and its image preprocessing transform
decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings

# Retrieve prompt from user and encode with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)

# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)

Question: Can you spot the issue here?

Answer: We’re trying to generate an image from a text embedding!

Unfortunately, we run into the issue previously mentioned: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.

# Load Models
prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings

# Retrieve prompt from user and encode with a prior
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
image_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!

# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(image_embedding)

With the prior we are able to successfully generate embeddings within CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.

You may be asking yourself the following question:

"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"

OpenAI covers this topic in their DALLE-2 paper. The TL;DR is "it doesn't work as well as decoders trained on image embeddings"... also, it's just an example :smile:

Usage

Using a pre-trained prior is quite simple.

Loading Checkpoints

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path, device):
    # device: the torch device to place the model on
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer

Here we instantiate a model that matches the configuration it was trained with, and then load the weights (just like any other PyTorch model!).

Sampling

Once we have a pre-trained model, generating embeddings is quite simple!

# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")
# predict an embedding
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)

The resulting tensor returned from .sample() is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on ViT-L/14 embeddings will predict an embedding of shape (1, 768).

For CLIP priors, this is quite handy as it means that you can use prior.sample(tokenized_text) as a drop-in replacement for clip.encode_text().

Some things to note:

  • It is possible to specify the number of embeddings to sample (the default suggested by OpenAI is n=2). Put simply, the idea is to avoid getting unlucky with a bad generation by creating two embeddings and selecting the one with the higher cosine similarity to the prompt.
  • You may specify a higher conditioning scale than the default (1.0). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than 1.0, but YMMV.
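To make the two notes above concrete, here is a minimal sketch in plain Python. It is not the library's code: `select_best` and `apply_cond_scale` are hypothetical helper names, the vectors are toy values, and the guidance formula shown is the usual classifier-free guidance combination, which this README does not spell out.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def select_best(candidates, text_embedding):
    """Return the sampled embedding most similar to the text embedding."""
    return max(candidates, key=lambda c: cosine_similarity(c, text_embedding))

def apply_cond_scale(cond_pred, uncond_pred, cond_scale=1.0):
    """Classifier-free guidance: start from the unconditional prediction and
    move toward (or past) the conditional one by a factor of cond_scale."""
    return [u + (c - u) * cond_scale for c, u in zip(cond_pred, uncond_pred)]

# Toy values (hypothetical): two sampled candidates and one text embedding
text_embedding = [1.0, 0.0]
candidates = [[0.2, 0.9], [0.8, 0.3]]
best = select_best(candidates, text_embedding)  # keeps the closer candidate

cond, uncond = [1.0, 2.0], [0.0, 1.0]
same = apply_cond_scale(cond, uncond, cond_scale=1.0)    # equals cond
pushed = apply_cond_scale(cond, uncond, cond_scale=2.0)  # extrapolates past cond
```

With a scale of 1.0 the result is just the conditional prediction. Larger values extrapolate beyond it, which is one plausible reason higher scales can degrade results.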

Training

Overview

Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.

Dataset

To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage img2dataset to pull images from a list of URLs and clip_retrieval to generate the actual embeddings that can be used in the prior's dataloader.

Looking for more info?

This readme continues in the official DALLE2-pytorch repo! You can find more details on training, metrics, and more there.
