---
license: mit
---

# A Text-Conditioned Diffusion-Prior

## Training Details

[Updated Reports Coming]

## Source Code

Models are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch

## Community: LAION

Join Us!: https://discord.gg/uPMftTmrvS

---

## Intro

A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.

### Motivation

Before we dive into the model, let's look at a quick example of where the model may be helpful. For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.

> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption. However, there is no guarantee that the two embeddings live in the same space: while the generated embeddings are ***close***, the image and text embeddings occupy two disjoint sets.

```python
# Load Models
clip_model = clip.load("ViT-L/14")
decoder = Decoder(checkpoint="best.pth")  # A decoder trained on CLIP image embeddings

# Retrieve prompt from user and encode with CLIP
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = clip_model.encode_text(tokenized_text)

# Now, pass the text embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```

> **Question**: *Can you spot the issue here?*
>
> **Answer**: *We're trying to generate an image from a text embedding!*

Unfortunately, we run into the issue previously mentioned: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.

```python
# Load Models
prior = Prior(checkpoint="prior.pth")  # A prior trained to go from: text -> clip text emb -> clip img emb
decoder = Decoder(checkpoint="decoder.pth")  # A decoder trained on CLIP image embeddings

# Retrieve prompt from user and encode with the prior
prompt = "A corgi wearing sunglasses"
tokenized_text = tokenize(prompt)
text_embedding = prior.sample(tokenized_text)  # <-- now we get an embedding in the same space as images!

# Now, pass the predicted image embedding to the decoder
predicted_image = decoder.sample(text_embedding)
```

With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason the decoder will perform much better, as it receives input that is much closer to its training data.

> **You may be asking yourself the following question:**
>
> *"Why don't you just train the decoder on CLIP text embeddings instead of image embeddings?"*
>
> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:
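Incidentally, the gap between the two embedding spaces is easy to measure for yourself. Below is a minimal, hypothetical sketch using OpenAI's `clip` package (the image path is a placeholder): for a well-matched image/caption pair, the cosine similarity typically lands around 0.2-0.3, not near 1.0.

```python
# Hypothetical sketch: measure how far apart CLIP's text and image
# embeddings are, even for a well-matched image/caption pair.
# Assumes OpenAI's `clip` package; "corgi.jpg" is a placeholder path.
import clip
import torch
import torch.nn.functional as F
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

image = preprocess(Image.open("corgi.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["A corgi wearing sunglasses"]).to(device)

with torch.no_grad():
    image_embedding = model.encode_image(image)  # (1, 768)
    text_embedding = model.encode_text(text)     # (1, 768)

# High by CLIP's standards, yet far from 1.0: the two modalities
# occupy separate regions of the embedding space.
print(F.cosine_similarity(image_embedding, text_embedding).item())
```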
## Usage

Using a pre-trained prior is quite simple.

### Loading Checkpoints

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path, device="cuda" if torch.cuda.is_available() else "cpu"):
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4
    )

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    trainer.load(dprior_path)

    return trainer
```

Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).

### Sampling

Once we have a pre-trained model, generating embeddings is quite simple!

```python
import clip

# tokenize the text
tokenized_text = clip.tokenize("<your amazing prompt>")
# predict an embedding (`prior` is the trainer returned by load_diffusion_model above)
predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
```

The tensor returned from `.sample()` has the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).

> For CLIP priors, this is quite handy as it means that you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.

**Some things to note:**

* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea is to avoid getting unlucky with a bad embedding generation by creating two candidates and selecting the one with the higher cosine similarity to the prompt (a conceptual sketch of this re-ranking appears in the appendix at the end of this README).
* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only for the decoder. Local testing has shown poor results with anything higher than `1.0`, but *ymmv*.

---

## Training

### Overview

Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency, and is generally recommended as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.

## Dataset

To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) to generate the actual embeddings that the prior's dataloader can consume (a short loading sketch appears in the appendix at the end of this README).

# Looking for more info?

This readme continues in the official DALLE2-pytorch repo! You can find more details on training, metrics, and more [here](https://github.com/lucidrains/DALLE2-pytorch/blob/main/prior.md).
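---

## Appendix: Illustrative Sketches

As referenced in the sampling notes above, here is a conceptual sketch of the "sample two, keep the best" re-ranking idea. It is a hypothetical illustration of what the library handles for you internally, not its actual implementation; it assumes `prior` is a loaded trainer (see *Loading Checkpoints*) whose `.sample()` returns one embedding per prompt.

```python
# Conceptual, hypothetical sketch of re-ranking sampled embeddings by
# cosine similarity to the prompt. Assumes `prior` is a loaded trainer
# (see "Loading Checkpoints") and OpenAI's `clip` package is installed.
import clip
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-L/14", device=device)

tokens = clip.tokenize("A corgi wearing sunglasses").to(device)

with torch.no_grad():
    # Score candidates against the prompt's (normalized) CLIP text embedding.
    text_embedding = F.normalize(clip_model.encode_text(tokens), dim=-1)

    # Draw two candidate image embeddings for the same prompt.
    candidates = torch.cat([prior.sample(tokens) for _ in range(2)], dim=0)
    candidates = F.normalize(candidates, dim=-1)  # (2, 768)

    # Keep the candidate most similar to the text embedding.
    scores = candidates @ text_embedding.T        # (2, 1) cosine similarities
    best_embedding = candidates[scores.argmax()]  # (768,)
```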
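As referenced in the dataset section, here is a minimal sketch of streaming precomputed embeddings with [embedding-reader](https://github.com/rom1504/embedding-reader), the library behind `EmbeddingReader`. The folder path is a hypothetical output location from `clip_retrieval`.

```python
# Minimal sketch: stream precomputed CLIP image embeddings from disk.
# "embeddings/img_emb" is a hypothetical clip_retrieval output folder.
from embedding_reader import EmbeddingReader

reader = EmbeddingReader(embeddings_folder="embeddings/img_emb", file_format="npy")
print("embedding count:", reader.count)
print("dimension:", reader.dimension)  # e.g. 768 for ViT-L/14

# Batches arrive as (numpy array, metadata dataframe) pairs.
for embeddings, meta in reader(batch_size=10**5, start=0, end=reader.count):
    print(embeddings.shape)  # (batch, 768)
    break
```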