File size: 3,548 Bytes
eb6c57c
 
afb22d5
 
 
 
 
 
eb6c57c
 
 
 
 
 
8b393cd
eb6c57c
27c80ca
0d8626e
 
eb6c57c
 
 
 
 
 
 
 
0d8626e
 
27c80ca
eb6c57c
68d2539
 
 
eb6c57c
68d2539
 
 
eb6c57c
68d2539
 
 
 
 
 
eb6c57c
 
 
8b393cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68d2539
8b393cd
 
 
 
 
 
 
 
 
 
 
 
 
 
68d2539
8b393cd
27c80ca
 
eb6c57c
 
 
68d2539
8b393cd
 
eb6c57c
ec64b87
db478c5
eb6c57c
f128ee8
ec64b87
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
'''Image Completion Demo (ImageGPT)

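- Paper: https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf (Generative Pretraining from Pixels, Chen et al., ICML 2020)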
- Code: https://huggingface.co/spaces/nielsr/imagegpt-completion

---
- 2021-12-10 first created
    - examples changed
'''

from PIL import Image
import matplotlib.pyplot as plt

import os
import numpy as np
from glob import glob
import gradio as gr
from loguru import logger

import torch
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling

# ========== Settings ==========
# Example inputs shown under the demo: every .jpg in ./examples (resolved
# relative to the current working directory), in sorted order.
EXAMPLE_DIR = 'examples'
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.jpg')))

# ========== Logger ==========
# Append log records to app.log so they survive restarts.
logger.add('app.log', mode='a')
logger.info('===== APP RESTARTED =====')

# ========== Models ==========
# Hugging Face Hub id of the pretrained ImageGPT checkpoint; the feature
# extractor and the model must come from the same checkpoint. To control
# where weights are downloaded, pass `cache_dir=...` to both calls below.
MODEL_NAME = "openai/imagegpt-medium"
feature_extractor = ImageGPTFeatureExtractor.from_pretrained(MODEL_NAME)
model = ImageGPTForCausalImageModeling.from_pretrained(MODEL_NAME)
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE)
logger.info(f'model loaded (DEVICE:{DEVICE})')

def _tokens_to_pixels(tokens, clusters, height, width):
    """Decode ImageGPT color-cluster tokens into an RGB ``uint8`` image.

    Args:
        tokens: 1-D integer array of cluster indices, ``height * width`` long.
        clusters: the feature extractor's color palette; each entry is an RGB
            triple, presumably in [-1, 1], which ``127.5 * (x + 1)`` maps onto
            [0, 255].
        height: number of pixel rows in the output.
        width: number of pixel columns in the output.

    Returns:
        ``np.uint8`` array of shape ``(height, width, 3)``.
    """
    rgb = np.rint(127.5 * (clusters[tokens] + 1.0))
    return np.reshape(rgb, [height, width, 3]).astype(np.uint8)


def process_image(image):
    """Complete the bottom part of ``image`` with ImageGPT.

    The image is tokenized by the module-level ``feature_extractor``. Its top
    ``n_px_crop`` rows of tokens are kept as a prompt, and the model samples
    the remaining tokens ``batch_size`` times.

    Args:
        image: a PIL image (the Gradio input is configured with ``type="pil"``).

    Returns:
        A PIL image laid out as a 2x4 grid of ``n_px`` x ``n_px`` tiles. The
        first tile is the prompt (its unseen rows are black); the other seven
        are sampled completions.

    Raises:
        ValueError: if ``image`` is None (e.g. submitted without an upload).
    """
    logger.info('--- image file received')
    if image is None:
        raise ValueError('no input image provided')
    # Each palette entry is an RGB triple (see the trailing 3 in the reshapes),
    # so make sure the extractor sees a 3-channel image even if the caller
    # passes RGBA or grayscale.
    image = image.convert('RGB')

    batch_size = 7   # number of independent completions to sample
    n_px_crop = 16   # rows of tokens kept from the input as the prompt

    # Tokenize the same image batch_size times -> shape (batch_size, n_px * n_px).
    encoding = feature_extractor([image] * batch_size, return_tensors="pt")
    samples = encoding.pixel_values.numpy()
    n_px = feature_extractor.size
    clusters = feature_extractor.clusters

    # Keep only the top n_px_crop rows of tokens; these condition the model.
    primers = samples.reshape(-1, n_px * n_px)[:, :n_px_crop * n_px]

    # Prompt preview: decoded top rows, padded with black (zero) rows back to
    # the full n_px x n_px size so it can share the grid with the samples.
    primers_img = _tokens_to_pixels(primers[0], clusters, n_px_crop, n_px)
    primers_img = np.pad(
        primers_img,
        pad_width=((0, n_px - n_px_crop), (0, 0), (0, 0)),
        mode="constant",
    )

    # Prepend the start-of-sequence token id (vocab_size - 1) and sample the
    # rest of each sequence with top-k sampling (no beam search).
    sos = np.full((batch_size, 1), model.config.vocab_size - 1)
    context = torch.tensor(np.concatenate((sos, primers), axis=1)).to(DEVICE)
    output = model.generate(
        input_ids=context,
        max_length=n_px * n_px + 1,
        temperature=1.0,
        do_sample=True,
        top_k=40,
    )

    # Drop the start token and decode every sequence back to pixels.
    generated = output[:, 1:].cpu().detach().numpy()
    tiles = [primers_img] + [_tokens_to_pixels(s, clusters, n_px, n_px) for s in generated]

    # Prompt + completions arranged as two rows of equal length.
    per_row = len(tiles) // 2
    grid = np.vstack([np.hstack(tiles[:per_row]), np.hstack(tiles[per_row:])])

    completion = Image.fromarray(grid)
    logger.info('--- image generated')
    return completion


iface = gr.Interface(
    process_image,
    title="์ด๋ฏธ์ง€์˜ ์ ˆ๋ฐ˜์„ ์ง€์šฐ๊ณ  ์ ˆ๋ฐ˜์„ ์ฑ„์›Œ ๋„ฃ์–ด์ฃผ๋Š” Image Completion ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค (ImageGPT)",
    description='์ฃผ์–ด์ง„ ์ด๋ฏธ์ง€์˜ ์ ˆ๋ฐ˜ ์•„๋ž˜๋ฅผ AI๊ฐ€ ์ฑ„์›Œ ๋„ฃ์–ด์ค๋‹ˆ๋‹ค (CPU๋กœ ์•ฝ 100์ดˆ ์ •๋„ ์†Œ์š”๋ฉ๋‹ˆ๋‹ค)',
    inputs=gr.inputs.Image(type="pil", label='์ธํ’‹ ์ด๋ฏธ์ง€'),
    outputs=gr.outputs.Image(type="pil", label='AI๊ฐ€ ๊ทธ๋ฆฐ ๊ฒฐ๊ณผ'),
    examples=examples,
    enable_queue=True,
    article='<p style="text-align:center">Based on <a href="https://huggingface.co/spaces/nielsr/imagegpt-completion">๐Ÿค— Link</a></p>',
)
if __name__ == '__main__':
    iface.launch(debug=True)