import os
os.system('pip install git+https://github.com/huggingface/transformers --upgrade')
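# Note: transformers is installed from the main branch at startup; presumably ImageGPT
# support had not yet landed in a stable release when this demo was written. With a
# recent release, a plain `pip install transformers` should work as well.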

import gradio as gr
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling
import torch
import numpy as np
import requests
from PIL import Image

feature_extractor = ImageGPTFeatureExtractor.from_pretrained("openai/imagegpt-medium")
model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
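# ImageGPT models a 32x32 image as a sequence of pixel tokens drawn from a 512-color
# palette ("clusters"); the vocabulary is those 512 cluster ids plus one start-of-sequence
# token, which is why model.config.vocab_size - 1 is used as the SOS id below.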

# load image examples
urls = ['https://avatars.githubusercontent.com/u/326577?v=4',
        'https://upload.wikimedia.org/wikipedia/commons/thumb/6/6e/Football_%28soccer_ball%29.svg/1200px-Football_%28soccer_ball%29.svg.png',
        'https://i.imgflip.com/4/4t0m5.jpg',
        ]
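# save each example to disk so Gradio can serve it from the examples gallery defined below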
for idx, url in enumerate(urls):
  image = Image.open(requests.get(url, stream=True).raw)
  image.save(f"image_{idx}.png")

def process_image(image):
    # make sure we have a 3-channel RGB image (uploads may be grayscale or RGBA)
    image = image.convert("RGB")
    # prepare 8 identical copies of the image, each quantized to 1024 cluster tokens, shape (8, 1024)
    batch_size = 8
    encoding = feature_extractor([image for _ in range(batch_size)], return_tensors="pt")
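    # encoding.pixel_values contains the color-quantized cluster indices for each copy
    # (one token per pixel). Note: newer transformers releases expose this processor as
    # ImageGPTImageProcessor and reportedly return the tokens under `input_ids` instead;
    # this file relies on the API available at the time of the pip install above.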

    # create primers
    samples = encoding.pixel_values.numpy()
    n_px = feature_extractor.size
    clusters = feature_extractor.clusters
    n_px_crop = 16
    primers = samples.reshape(-1,n_px*n_px)[:,:n_px_crop*n_px] # crop top n_px_crop rows. These will be the conditioning tokens
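    # the primer keeps only the top n_px_crop rows (n_px_crop * n_px tokens per image);
    # the model will be asked to sample the remaining rows conditioned on them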
    
    # generate (no beam search)
    context = np.concatenate((np.full((batch_size, 1), model.config.vocab_size - 1), primers), axis=1)
    context = torch.tensor(context).to(device)
    output = model.generate(input_ids=context, max_length=n_px*n_px + 1, temperature=1.0, do_sample=True, top_k=40)
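    # context = [SOS token] + primer tokens for each copy; generate() samples the missing
    # pixel tokens autoregressively (temperature 1.0, top-k 40), so each output row holds
    # n_px*n_px + 1 tokens including the SOS token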

    # decode back to images (convert color cluster tokens back to pixels)
    samples = output[:,1:].cpu().detach().numpy()
    samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples] 
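    # the cluster centroids lie in [-1, 1], so 127.5 * (c + 1.0) maps each RGB centroid
    # back to 8-bit pixel values in [0, 255]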
    
    # stack images horizontally
    row1 = np.hstack(samples_img[:4])
    row2 = np.hstack(samples_img[4:])
    result = np.vstack([row1, row2])
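    # lay the 8 completions out as a 2 x 4 grid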
    
    # return as PIL Image
    completion = Image.fromarray(result)

    return completion

title = "Interactive demo: ImageGPT"
description = "Demo for OpenAI's ImageGPT: Generative Pretraining from Pixels. To use it, simply upload an image or pick one of the example images below and click 'submit'. Results will show up after the model has finished sampling the missing pixels, which may take a while."
article = "<p style='text-align: center'><a href='https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf'>Generative Pretraining from Pixels (ImageGPT)</a> | <a href='https://openai.com/blog/image-gpt/'>Official blog</a></p>"
examples =[f"image_{idx}.png" for idx in range(len(urls))]
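# Note: gr.inputs / gr.outputs and the enable_queue argument belong to the older Gradio API
# this demo was written against. On current Gradio releases, a rough equivalent (a sketch,
# not tested against this file) would look like:
#
#   iface = gr.Interface(fn=process_image,
#                        inputs=gr.Image(type="pil"),
#                        outputs=gr.Image(type="pil"),
#                        title=title, description=description,
#                        article=article, examples=examples)
#   iface.queue().launch(debug=True)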

iface = gr.Interface(fn=process_image, 
                     inputs=gr.inputs.Image(type="pil"), 
                     outputs=gr.outputs.Image(type="pil"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     enable_queue=True)
iface.launch(debug=True)