File size: 3,551 Bytes
eed789a
 
 
 
09e8220
eed789a
da92d39
eed789a
 
 
 
7283bfa
09e8220
eed789a
 
 
 
a765318
714dc8a
 
 
1d9e1c4
eed789a
 
 
 
 
e793580
 
6358443
eed789a
 
2c35b0b
422ecfe
ff59753
eed789a
 
 
e793580
2790989
e793580
2790989
eed789a
 
 
 
 
40b1ad3
eed789a
40b1ad3
eed789a
e793580
 
40b1ad3
1d9e1c4
 
 
40b1ad3
 
 
eed789a
e793580
eed789a
 
 
 
1d9e1c4
eed789a
 
 
e793580
eed789a
 
 
5a44bd5
 
eed789a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import subprocess
import sys

# Install the latest transformers from GitHub before it is imported below.
# Running pip via ``sys.executable -m pip`` guarantees the package is installed
# into the interpreter that runs this app (a bare ``pip`` on PATH may belong to
# a different Python), and passing an argument list avoids going through a shell.
# Failure is deliberately non-fatal (check=False): the import below then uses
# whatever transformers version is already installed.
# NOTE(review): installing unpinned git HEAD makes the Space non-reproducible;
# consider pinning a release in requirements.txt instead.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/huggingface/transformers",
        "--upgrade",
    ],
    check=False,
)

import gradio as gr
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling
import torch
import numpy as np
import requests
from PIL import Image
import matplotlib.pyplot as plt

# Pick the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The feature extractor turns images into colour-cluster token ids (see
# ``process_image``); the causal image-modeling head samples new tokens.
# ``Module.to`` returns the module itself, so the moved model is bound directly.
feature_extractor = ImageGPTFeatureExtractor.from_pretrained("openai/imagegpt-medium")
model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-medium").to(device)

# Download the example images shown in the Gradio UI and save them locally as
# image_0.png, image_1.png, ... (the names the ``examples`` list refers to).
urls = ['https://i.imgflip.com/4/4t0m5.jpg',
        'https://cdn.openai.com/image-gpt/completions/igpt-xl-miscellaneous-2-orig.png',
        'https://cdn.openai.com/image-gpt/completions/igpt-xl-miscellaneous-29-orig.png',
        'https://cdn.openai.com/image-gpt/completions/igpt-xl-openai-cooking-0-orig.png'
        ]
for idx, url in enumerate(urls):
    # A timeout keeps app start-up from hanging forever on a dead host; the
    # ``with`` block releases the connection once the image has been written.
    with requests.get(url, stream=True, timeout=30) as response:
        # Fail with a clear HTTP error instead of an opaque "cannot identify
        # image file" from PIL when the server returns an error page.
        response.raise_for_status()
        # ``raw`` is the undecoded socket stream; ask urllib3 to undo any
        # Content-Encoding (e.g. gzip) so PIL sees the actual image bytes.
        response.raw.decode_content = True
        with Image.open(response.raw) as image:
            image.save(f"image_{idx}.png")

def _tokens_to_pixels(tokens, clusters, height, width):
    """Map colour-cluster token ids to an RGB ``uint8`` array of shape (height, width, 3)."""
    # 127.5 * (c + 1) maps -1 -> 0 and 1 -> 255; this assumes the cluster
    # centres are stored in the normalized [-1, 1] range.
    pixels = np.rint(127.5 * (clusters[tokens] + 1.0))
    return np.reshape(pixels, [height, width, 3]).astype(np.uint8)


def process_image(image):
    """Prime ImageGPT with the top rows of ``image`` and sample completions.

    The image is colour-quantized by ``feature_extractor`` into a sequence of
    cluster tokens. The first ``n_px_crop`` rows of tokens are kept as the
    conditioning context, and ``model.generate`` samples the remaining tokens
    for ``batch_size`` independent completions.

    Args:
        image: a PIL image, as delivered by the Gradio ``Image(type="pil")`` input.

    Returns:
        A PIL image laid out as a grid of two rows. The first tile is the
        conditioning crop, padded with black rows to full size; the other tiles
        are the sampled completions.
    """
    # 7 completions + 1 primer preview = 8 tiles, shown as two rows of 4.
    batch_size = 7
    n_px_crop = 16  # number of top token rows kept as the conditioning context

    # Every batch entry is the same image; diversity comes from sampling.
    encoding = feature_extractor([image] * batch_size, return_tensors="pt")
    samples = encoding.input_ids.numpy()
    clusters = feature_extractor.clusters

    # Older feature extractors store ``size`` as an int, while newer transformers
    # image processors store it as {"height": ..., "width": ...}.
    size = feature_extractor.size
    if isinstance(size, dict):
        height, width = size["height"], size["width"]
    else:
        height = width = size

    # Crop the top n_px_crop rows. These are the conditioning tokens.
    primers = samples.reshape(-1, height * width)[:, : n_px_crop * width]

    # Preview of the conditioning crop (first primer), padded at the bottom
    # with black pixels so it matches the size of the completions.
    primer_img = _tokens_to_pixels(primers[0], clusters, n_px_crop, width)
    primer_img = np.pad(
        primer_img,
        pad_width=((0, height - n_px_crop), (0, 0), (0, 0)),
        mode="constant",
    )

    # Prepend the start token (id vocab_size - 1) to every primer, then sample
    # the remaining pixels (top-k sampling, no beam search).
    start_tokens = np.full((batch_size, 1), model.config.vocab_size - 1)
    context = torch.tensor(np.concatenate((start_tokens, primers), axis=1)).to(device)
    output = model.generate(
        input_ids=context,
        max_length=height * width + 1,
        temperature=1.0,
        do_sample=True,
        top_k=40,
    )

    # Drop the start token and decode each sequence back into pixels.
    generated = output[:, 1:].detach().cpu().numpy()
    completions = [_tokens_to_pixels(seq, clusters, height, width) for seq in generated]

    # Lay the tiles out as two rows of equal length.
    tiles = [primer_img] + completions
    half = len(tiles) // 2
    grid = np.vstack([np.hstack(tiles[:half]), np.hstack(tiles[half:])])

    return Image.fromarray(grid)

# Text shown around the Gradio demo.
title = "Interactive demo: ImageGPT"
description = "Demo for OpenAI's ImageGPT: Generative Pretraining from Pixels. To use it, simply upload an image or use one of the example images below and click 'submit'. Results will show up in a few seconds."
# The ImageGPT paper (Chen et al., ICML 2020) is hosted by OpenAI; link it directly.
article = "<p style='text-align: center'><a href='https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf'>ImageGPT: Generative Pretraining from Pixels</a> | <a href='https://openai.com/blog/image-gpt/'>Official blog</a></p>"
# Must match the file names written by the example-download loop.
examples = [f"image_{idx}.png" for idx in range(len(urls))]

# NOTE(review): gr.inputs / gr.outputs and enable_queue belong to the legacy
# Gradio API (deprecated in Gradio 3, removed in Gradio 4). Keep them only while
# the Space pins an older Gradio; confirm the pinned version before migrating
# to gr.Image(...) and iface.queue().
iface = gr.Interface(fn=process_image,
                     inputs=gr.inputs.Image(type="pil"),
                     outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     enable_queue=True)
iface.launch(debug=True)