from pathlib import Path
import json

import numpy as np
from tqdm import tqdm

# torch
import torch
from einops import repeat

# vision imports
from PIL import Image

# dalle related classes and utils
from dalle_pytorch import VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer

import gradio as gr

# load every trained DALL-E listed in model_paths.json (name -> checkpoint path)
with open("model_paths.json") as f:
    models = json.load(f)

vae = VQGanVAE(None, None)  # default pretrained VQGAN

dalles = {}
for name, model_path in models.items():
    assert Path(model_path).exists(), f'trained DALL-E {model_path} must exist'

    load_obj = torch.load(model_path)
    dalle_params, _, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
    dalle_params.pop('vae', None)  # drop any serialized VAE config; we attach our own VQGAN instead

    dalle = DALLE(vae = vae, **dalle_params).cuda()
    dalle.load_state_dict(weights)
    dalles[name] = dalle

batch_size = 4
filter_thres = 0.9  # top-k filtering threshold passed to generate_images

def generate(text_input):
    num_images = 4
    dalle_name = "weird_car"  # model choice is hardcoded for this demo
    dalle = dalles[dalle_name]

    # tokenize the prompt once, then repeat it so several images are sampled per prompt
    tokens = tokenizer.tokenize([text_input], dalle.text_seq_len).cuda()
    tokens = repeat(tokens, '() n -> b n', b = num_images)

    outputs = []
    for chunk in tqdm(tokens.split(batch_size), desc = f'generating images for - {text_input}'):
        output = dalle.generate_images(chunk, filter_thres = filter_thres)
        outputs.append(output)
    outputs = torch.cat(outputs)

    # convert CHW float tensors in [0, 1] to HWC uint8 PIL images
    response = []
    for image in tqdm(outputs, desc = 'saving images'):
        np_image = np.moveaxis(image.cpu().numpy(), 0, -1)
        formatted = (np_image * 255).clip(0, 255).astype('uint8')
        response.append(Image.fromarray(formatted))
    return response

# note: gr.outputs.Carousel exists only in older (pre-3.x) Gradio releases
iface = gr.Interface(fn=generate, inputs="text", outputs=gr.outputs.Carousel("image"))
iface.launch(share=True)
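# Example `model_paths.json` layout — a sketch inferred from how the file is read
# above: a flat mapping from model name to checkpoint path, where "weird_car" is
# the key generate() selects. The path below is hypothetical:
#
# {
#     "weird_car": "./checkpoints/weird_car/dalle.pt"
# }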