apapiu committed on
Commit
78c2594
1 Parent(s): ee3829a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ from PIL import Image
4
+
5
+ from tld.denoiser import Denoiser
6
+ from tld.diffusion import DiffusionGenerator
7
+
8
+ from diffusers import AutoencoderKL, AutoencoderTiny
9
+ from tqdm import tqdm
10
+ import clip
11
+ import torch
12
+ import numpy as np
13
+ import torchvision.utils as vutils
14
+ import torchvision.transforms as transforms
15
+ from torch.utils.data import DataLoader, TensorDataset
16
+ from PIL import Image
17
+
18
# Use the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Tensor -> PIL converter used to turn the generated image grid into a PIL image.
to_pil = transforms.ToPILImage()


# --- config ---
vae_scale_factor = 8
img_size = 32
# dtype the denoiser and the VAE are loaded in.
model_dtype = torch.float32
26
+
27
@torch.no_grad()
def encode_text(label, model):
    """Return ``model``'s text encoding of ``label`` as a CPU tensor.

    ``label`` is handed directly to ``clip.tokenize`` (with ``truncate=True``),
    the resulting tokens are moved to the module-level ``device``, and the
    output of ``model.encode_text`` is moved to the CPU before being returned.
    Gradient tracking is disabled for the whole call.
    """
    tokens = clip.tokenize(label, truncate=True)
    encoded = model.encode_text(tokens.to(device))
    return encoded.cpu()
32
+
33
def _safe_prompt_stem(prompt, max_bytes=200):
    """Return ``prompt`` made safe to embed in a single filename component."""
    import re  # function-scope import keeps this block self-contained

    # Path separators would redirect the save into another (possibly missing)
    # directory; control characters are not valid in filenames everywhere.
    stem = re.sub(r"[\\/\x00-\x1f]+", "_", str(prompt))
    # Cap the UTF-8 size so the final name (stem + "_cfg:.._seed:...png")
    # stays below the common 255-byte filename limit.
    return stem.encode("utf-8")[:max_bytes].decode("utf-8", "ignore")


def generate_image_from_text(prompt, class_guidance=6, seed=11, num_imgs=1, img_size=32):
    """Generate ``num_imgs`` images for ``prompt`` and return them as one PIL grid.

    The prompt is repeated ``num_imgs`` times, encoded with ``encode_text`` and
    the module-level ``clip_model``, and passed to ``diffuser.generate`` with a
    fixed 15 iterations. The result is tiled with ``make_grid`` using
    ``int(sqrt(num_imgs))`` images per row, converted with ``to_pil``, saved as
    a PNG in the current working directory, and returned.

    Args:
        prompt: Text prompt. It is also used (sanitised) in the saved filename.
        class_guidance: Forwarded to ``diffuser.generate`` as ``class_guidance``.
        seed: Forwarded to ``diffuser.generate`` as ``seed``.
        num_imgs: Number of images to generate and tile.
        img_size: Currently unused by this function; kept for compatibility.

    Returns:
        The PIL image produced by ``to_pil`` from the tiled grid.
    """
    n_iter = 15
    nrow = int(np.sqrt(num_imgs))

    cur_prompts = [prompt] * num_imgs
    labels = encode_text(cur_prompts, clip_model)
    # The second return value (named out_latent in the original) is not needed here.
    out, _ = diffuser.generate(
        labels=labels,
        num_imgs=num_imgs,
        class_guidance=class_guidance,
        seed=seed,
        n_iter=n_iter,
        exponent=1,
        scale_factor=8,
        sharp_f=0,
        bright_f=0,
    )

    # (out + 1) / 2 followed by clip(0, 1) — presumably the generator returns
    # pixels in [-1, 1]; confirm against DiffusionGenerator.generate.
    grid = vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)
    out = to_pil(grid.float().clip(0, 1))

    # The prompt comes straight from the UI, so never use it verbatim as a path.
    out.save(f"{_safe_prompt_stem(prompt)}_cfg:{class_guidance}_seed:{seed}.png")

    print("Images Generated and Saved. They will shortly output below.")
    return out
57
+
58
+
59
+
60
# Build the denoiser in the configured dtype, load its checkpoint on the CPU,
# then move the populated model to the target device.
denoiser = Denoiser(
    image_size=32,
    noise_embed_dims=256,
    patch_size=2,
    embed_dim=768,
    dropout=0,
    n_layers=12,
)

# NOTE(review): torch.load unpickles the file; only use checkpoints you trust
# (consider weights_only=True if the installed torch version supports it).
state_dict = torch.load("state_dict_378000.pth", map_location="cpu")

denoiser = denoiser.to(model_dtype)
denoiser.load_state_dict(state_dict)
denoiser = denoiser.to(device)

# VAE, loaded in the same dtype as the denoiser and moved to the device.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=model_dtype,
)
vae = vae.to(device)

# CLIP model used by encode_text; `preprocess` is returned by clip.load but
# is not used anywhere in this script.
clip_model, preprocess = clip.load("ViT-L/14")
clip_model = clip_model.to(device)

# Generator object whose generate() is called by generate_image_from_text.
diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
77
+
78
# Define the Gradio interface.
# The two inputs map positionally onto generate_image_from_text's first two
# parameters: the prompt text and class_guidance.
iface = gr.Interface(
    fn=generate_image_from_text,  # The function to generate the image
    inputs=[
        "text",
        # The bare "slider" shortcut starts at its minimum (0), which would
        # override the function's class_guidance default of 6. Keep the same
        # 0-100 range but start at the intended default.
        gr.Slider(minimum=0, maximum=100, value=6),
    ],
    outputs="image",
    title="Text-to-Image Generator",
    description="Enter a text prompt to generate an image.",
)

# Launch the app
iface.launch()