Commit 8b393cd by jaekookang
Parent: 831f916

update app

Files changed (2):
  1. gradio_imagecompletion.py +35 -3
  2. requirements.txt +1 -0
gradio_imagecompletion.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
 import matplotlib.pyplot as plt
 
 import os
+import numpy as np
 import requests
 from glob import glob
 import gradio as gr
@@ -35,15 +36,46 @@ model.to(device)
 
 def process_image(image):
     logger.info('--- image file received')
-    return image.name
+    # prepare 7 copies of the input image, shape (7, 1024)
+    batch_size = 7
+    encoding = feature_extractor([image for _ in range(batch_size)], return_tensors="pt")
+    # create primers
+    samples = encoding.pixel_values.numpy()
+    n_px = feature_extractor.size
+    clusters = feature_extractor.clusters
+    n_px_crop = 16
+    primers = samples.reshape(-1, n_px * n_px)[:, :n_px_crop * n_px]  # crop top n_px_crop rows; these become the conditioning tokens
+
+    # get the conditioned image (from the first primer tensor), padded with black pixels to 32x32
+    primers_img = np.reshape(np.rint(127.5 * (clusters[primers[0]] + 1.0)), [n_px_crop, n_px, 3]).astype(np.uint8)
+    primers_img = np.pad(primers_img, pad_width=((0, 16), (0, 0), (0, 0)), mode="constant")
+
+    # generate (no beam search)
+    context = np.concatenate((np.full((batch_size, 1), model.config.vocab_size - 1), primers), axis=1)
+    context = torch.tensor(context).to(device)
+    output = model.generate(input_ids=context, max_length=n_px * n_px + 1, temperature=1.0, do_sample=True, top_k=40)
+    # decode back to images (convert color-cluster tokens back to pixels)
+    samples = output[:, 1:].cpu().detach().numpy()
+    samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples]
+
+    samples_img = [primers_img] + samples_img
+
+    # stack the primer and the 7 completions into a 2x4 grid
+    row1 = np.hstack(samples_img[:4])
+    row2 = np.hstack(samples_img[4:])
+    result = np.vstack([row1, row2])
+
+    # return as a PIL Image
+    completion = Image.fromarray(result)
+    return completion
 
 
 iface = gr.Interface(
     process_image,
     title="이미지의 절반을 지우고 절반을 채워 넣어주는 Image Completion 데모입니다 (ImageGPT)",
     description='주어진 이미지의 절반 아래를 AI가 채워 넣어줍니다',
-    inputs=gr.inputs.Image(type="pil"),
-    outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
+    inputs=gr.inputs.Image(type="pil", label='인풋 이미지'),
+    outputs=gr.outputs.Image(type="pil", label='AI가 그린 결과'),
     examples=examples,
     enable_queue=True,
     article='<p style="text-align:center">i-Scream AI</p>',
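Note: the Korean UI strings translate as — title: "An Image Completion demo that erases half of an image and fills it back in (ImageGPT)"; description: "The AI fills in the bottom half of the given image"; input label: "input image"; output label: "result drawn by the AI".

The new `process_image` body relies on names defined earlier in the file and visible here only through the hunk headers (`from PIL import Image`, `model.to(device)`): `feature_extractor`, `model`, `device`, `logger`, and `examples`. A minimal sketch of that surrounding setup, assuming the small ImageGPT checkpoint from the Hugging Face Hub and a local examples folder (the checkpoint id and the glob path are assumptions, not shown in this diff):

import torch
from glob import glob
from loguru import logger
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling

# Assumption: the small ImageGPT checkpoint; the Space may pin a different one.
model_id = 'openai/imagegpt-small'
feature_extractor = ImageGPTFeatureExtractor.from_pretrained(model_id)
model = ImageGPTForCausalImageModeling.from_pretrained(model_id)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)  # matches the context line of the second hunk

examples = sorted(glob('examples/*.jpg'))  # hypothetical location of the demo images

The file presumably ends with iface.launch() to start the demo; that line is outside this diff.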
requirements.txt CHANGED
@@ -4,3 +4,4 @@ torch==1.9.0
 loguru==0.5.3
 transformers==4.13.0
 Pillow==8.4.0
+numpy==1.19.5
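For reference on the arithmetic in `process_image`: ImageGPT quantizes every pixel to one of 512 color clusters whose centroids (`feature_extractor.clusters`) are stored scaled to [-1, 1], so `np.rint(127.5 * (clusters[s] + 1.0))` recovers 8-bit RGB values. A minimal sketch of the shape bookkeeping, assuming the standard ImageGPT resolution (`feature_extractor.size == 32`, hence the 1024 tokens per image mentioned in the code comments):

import numpy as np

n_px, n_px_crop, batch_size = 32, 16, 7  # resolution, primer rows, number of completions

primer_tokens = n_px_crop * n_px  # 512 conditioning tokens: the top half of the image
context_len = 1 + primer_tokens   # 513: SOS token (model.config.vocab_size - 1) + primer
max_length = n_px * n_px + 1      # 1025: SOS + one token per pixel

# The collage returned to Gradio: primer image + 7 completions as 32x32 tiles in a 2x4 grid
tiles = 1 + batch_size
grid_shape = (2 * n_px, (tiles // 2) * n_px, 3)
print(primer_tokens, context_len, max_length, tiles, grid_shape)  # 512 513 1025 8 (64, 128, 3)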