apapiu commited on
Commit
da56136
1 Parent(s): 12b2843

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -84
app.py CHANGED
@@ -1,90 +1,35 @@
1
  import gradio as gr
2
- from PIL import Image
3
  import requests
4
-
5
- from tld.denoiser import Denoiser
6
- from tld.diffusion import DiffusionGenerator
7
-
8
- from diffusers import AutoencoderKL, AutoencoderTiny
9
- from tqdm import tqdm
10
- import clip
11
- import torch
12
- import numpy as np
13
- import torchvision.utils as vutils
14
- import torchvision.transforms as transforms
15
- from torch.utils.data import DataLoader, TensorDataset
16
  from PIL import Image
17
-
18
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
- to_pil = transforms.ToPILImage()
20
-
21
-
22
- def download_file(url, filename):
23
-
24
- with requests.get(url, stream=True) as r:
25
- r.raise_for_status()
26
- with open(filename, 'wb') as f:
27
- for chunk in r.iter_content(chunk_size=8192):
28
- f.write(chunk)
29
-
30
- @torch.no_grad()
31
- def encode_text(label, model):
32
- text_tokens = clip.tokenize(label, truncate=True).to(device)
33
- text_encoding = model.encode_text(text_tokens)
34
- return text_encoding.cpu()
35
-
36
- def generate_image_from_text(prompt, class_guidance=6, seed=11, num_imgs=1, img_size = 32):
37
-
38
- n_iter = 15
39
- nrow = int(np.sqrt(num_imgs))
40
-
41
- cur_prompts = [prompt]*num_imgs
42
- labels = encode_text(cur_prompts, clip_model)
43
- out, out_latent = diffuser.generate(labels=labels,
44
- num_imgs=num_imgs,
45
- class_guidance=class_guidance,
46
- seed=seed,
47
- n_iter=n_iter,
48
- exponent=1,
49
- scale_factor=8,
50
- sharp_f=0,
51
- bright_f=0
52
- )
53
-
54
- out = to_pil((vutils.make_grid((out+1)/2, nrow=nrow, padding=4)).float().clip(0, 1))
55
-
56
- out.save(f'{prompt}_cfg:{class_guidance}_seed:{seed}.png')
57
-
58
- print("Images Generated and Saved. They will shortly output below.")
59
- return out
60
-
61
- ###config:
62
- vae_scale_factor = 8
63
- img_size = 32
64
- model_dtype = torch.float32
65
-
66
- file_url = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
67
- local_filename = "state_dict_378000.pth"
68
- download_file(file_url, local_filename)
69
-
70
-
71
- denoiser = Denoiser(image_size=32, noise_embed_dims=256, patch_size=2,
72
- embed_dim=768, dropout=0, n_layers=12)
73
-
74
-
75
- state_dict = torch.load('state_dict_378000.pth', map_location=torch.device('cpu'))
76
-
77
- denoiser = denoiser.to(model_dtype)
78
- denoiser.load_state_dict(state_dict)
79
- denoiser = denoiser.to(device)
80
-
81
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix",
82
- torch_dtype=model_dtype).to(device)
83
-
84
- clip_model, preprocess = clip.load("ViT-L/14")
85
- clip_model = clip_model.to(device)
86
-
87
- diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
88
 
89
  # Define the Gradio interface
90
  iface = gr.Interface(
 
1
  import gradio as gr
 
2
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
3
  from PIL import Image
4
+ from io import BytesIO
5
+ import os
6
+
7
+ token = os.environ['AUTH_TOKEN']
8
+ runpod_id = os.environ['RUNPOD_ID']
9
+
10
+ url = 'https://{runpod_id}-8000.proxy.runpod.net/generate-image/'
11
+
12
+ def generate_image_from_text(prompt, class_guidance, token):
13
+ headers = {
14
+ 'Authorization': f'Bearer {token}'
15
+ }
16
+
17
+ data = {
18
+ "prompt": prompt,
19
+ "class_guidance": class_guidance,
20
+ "seed": 11,
21
+ "num_imgs": 4,
22
+ "img_size": 32
23
+ }
24
+
25
+ response = requests.post(url, json=data, headers=headers)
26
+
27
+ if response.status_code == 200:
28
+ image = Image.open(BytesIO(response.content))
29
+ else:
30
+ print("Failed to fetch image:", response.status_code, response.text)
31
+
32
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # Define the Gradio interface
35
  iface = gr.Interface(