Khalil committed on
Commit
35ed471
1 Parent(s): b41a54a

Fix image generation function.

Browse files
Files changed (2) hide show
  1. app.py +16 -16
  2. text2punks/utils.py +3 -1
app.py CHANGED
@@ -1,9 +1,11 @@
1
  # system
2
 
3
  import os
 
4
 
5
- os.system("gdown https://drive.google.com/uc?id=1--27E5dk8GzgvpVL0ofr-m631iymBpUH")
6
- os.system("gdown https://drive.google.com/uc?id=191a5lTsUPQ1hXaeo6kVNbo_W3WYuXsmF")
 
7
 
8
  # plot
9
 
@@ -17,12 +19,7 @@ import gradio as gr
17
 
18
  # text2punks utils
19
 
20
- from text2punks.utils import to_pil_image, model_loader, generate_image
21
-
22
-
23
- batch_size = 32
24
- num_images = 32
25
- top_prediction = 8
26
 
27
  # nobs to tune
28
 
@@ -34,21 +31,24 @@ temperature = 1.25
34
  def compose_predictions(images):
35
 
36
  increased_h = 0
37
- h, w = images[0].shape[0], images[0].shape[1]
38
- image_grid = Image.new("RGB", (len(images)*w, h))
39
 
40
- for i, img_ in enumerate(images):
41
- image_grid.paste(to_pil_image(img_), (i*w, increased_h))
 
 
 
42
 
43
- return img
44
 
45
 
46
- def run_inference(prompt, num_images=32, num_preds=8):
47
 
48
  t2p_path, clip_path = './Text2Punk-final-7.pt', './clip-final.pt'
49
  text2punk, clip = model_loader(t2p_path, clip_path)
50
 
51
- images = generate_image(prompt_text=prompt, top_k=top_k, temperature=temperature, num_images=num_images, batch_size=batch_size, top_prediction=top_prediction, text2punk_model=text2punk, clip_model=clip)
52
  predictions = compose_predictions(images)
53
 
54
  output_title = f"""
@@ -69,7 +69,7 @@ Text2Cryptopunks is an AI model that generates Cryptopunks images from text prom
69
 
70
  gr.Interface(run_inference,
71
  inputs=[gr.inputs.Textbox(label='type somthing like this : "An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask."')],
72
- outputs=outputs,
73
  title='Text2Cryptopunks',
74
  description=description,
75
  article="<p style='text-align: center'> Created by kTonpa | <a href='https://github.com/kTonpa/Text2CryptoPunks'>GitHub</a>",
 
1
  # system
2
 
3
  import os
4
+ from pathlib import Path
5
 
6
+ if not Path('./Text2Punk-final-7.pt').exists() and not Path('./clip-final.pt').exists():
7
+ os.system("gdown https://drive.google.com/uc?id=1--27E5dk8GzgvpVL0ofr-m631iymBpUH")
8
+ os.system("gdown https://drive.google.com/uc?id=191a5lTsUPQ1hXaeo6kVNbo_W3WYuXsmF")
9
 
10
  # plot
11
 
 
19
 
20
  # text2punks utils
21
 
22
+ from text2punks.utils import resize, to_pil_image, model_loader, generate_image
 
 
 
 
 
23
 
24
  # nobs to tune
25
 
 
31
  def compose_predictions(images):
32
 
33
  increased_h = 0
34
+ b, c, h, w = *images.shape,
35
+ image_grid = Image.new("RGB", (b*w*4, h*4), color=0)
36
 
37
+ for i in range(b):
38
+ # resize(images[i], 96)
39
+ print(images[i].shape)
40
+ img_ = to_pil_image(images[i])
41
+ image_grid.paste(img_, (i*w*4, increased_h))
42
 
43
+ return image_grid
44
 
45
 
46
+ def run_inference(prompt, num_images=32, batch_size=32, num_preds=8):
47
 
48
  t2p_path, clip_path = './Text2Punk-final-7.pt', './clip-final.pt'
49
  text2punk, clip = model_loader(t2p_path, clip_path)
50
 
51
+ images, _ = generate_image(prompt_text=prompt, top_k=top_k, temperature=temperature, num_images=num_images, batch_size=batch_size, top_prediction=num_preds, text2punk_model=text2punk, clip_model=clip)
52
  predictions = compose_predictions(images)
53
 
54
  output_title = f"""
 
69
 
70
  gr.Interface(run_inference,
71
  inputs=[gr.inputs.Textbox(label='type somthing like this : "An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask."')],
72
+ outputs=outputs,
73
  title='Text2Cryptopunks',
74
  description=description,
75
  article="<p style='text-align: center'> Created by kTonpa | <a href='https://github.com/kTonpa/Text2CryptoPunks'>GitHub</a>",
text2punks/utils.py CHANGED
@@ -26,9 +26,11 @@ codebook = torch.load('./text2punks/data/codebook.pt')
26
  def exists(val):
27
  return val is not None
28
 
 
 
29
 
30
  def to_pil_image(image_tensor):
31
- return F.to_pil_image(image_tensor)
32
 
33
 
34
  def model_loader(text2punk_path, clip_path):
 
26
  def exists(val):
27
  return val is not None
28
 
29
+ def resize(image_tensor, size):
30
+ return F.resize(image_tensor, (size, size), F.InterpolationMode.NEAREST)
31
 
32
  def to_pil_image(image_tensor):
33
+ return F.to_pil_image(image_tensor.type(torch.uint8))
34
 
35
 
36
  def model_loader(text2punk_path, clip_path):