import os
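# ImageGPT support had only just landed in transformers when this demo was written,
# so the Space installs the library from source; an unpinned install like this may
# break if the upstream API changes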
os.system('pip install git+https://github.com/huggingface/transformers --upgrade')
import gradio as gr
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling
import torch
import numpy as np
import requests
from PIL import Image
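# ImageGPT works on 32x32 images whose pixels are quantized to 512 RGB color clusters;
# the feature extractor handles resizing + quantization, and the model predicts cluster tokens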
feature_extractor = ImageGPTFeatureExtractor.from_pretrained("openai/imagegpt-medium")
model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# load image examples
urls = ['https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1608042047613-5f1158120c833276f61f1a84.jpeg',
'https://upload.wikimedia.org/wikipedia/commons/thumb/6/6e/Football_%28soccer_ball%29.svg/1200px-Football_%28soccer_ball%29.svg.png',
'https://i.imgflip.com/4/4t0m5.jpg',
'https://cdn.openai.com/image-gpt/completions/igpt-xl-miscellaneous-2-orig.png',
'https://cdn.openai.com/image-gpt/completions/igpt-xl-miscellaneous-29-orig.png',
'https://cdn.openai.com/image-gpt/completions/igpt-xl-openai-cooking-0-orig.png'
]
for idx, url in enumerate(urls):
    image = Image.open(requests.get(url, stream=True).raw)
    image.save(f"image_{idx}.png")
def process_image(image):
    # prepare a batch of 7 copies of the image, encoded as token sequences of shape (7, 1024)
    batch_size = 7
    encoding = feature_extractor([image for _ in range(batch_size)], return_tensors="pt")
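    # all 7 copies share the same primer; sampling during generation makes each completion different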
    # create primers
    samples = encoding.input_ids.numpy()
    n_px = feature_extractor.size  # 32 pixels per side for this checkpoint
    clusters = feature_extractor.clusters
    n_px_crop = 16
    primers = samples.reshape(-1, n_px * n_px)[:, :n_px_crop * n_px]  # crop the top n_px_crop rows; these are the conditioning tokens
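    # clusters is a (512, 3) palette of RGB centroids in [-1, 1]; np.rint(127.5 * (c + 1.0))
    # maps a cluster center back to uint8 pixel values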
    # get the conditioned image (from the first primer tensor), padded with black pixels to 32x32
    primers_img = np.reshape(np.rint(127.5 * (clusters[primers[0]] + 1.0)), [n_px_crop, n_px, 3]).astype(np.uint8)
    primers_img = np.pad(primers_img, pad_width=((0, 16), (0, 0), (0, 0)), mode="constant")
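    # ImageGPT's vocabulary holds 513 ids: 512 color clusters plus a start-of-sequence
    # token (vocab_size - 1), which is prepended to each primer below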
    # generate the missing rows by sampling (no beam search)
    context = np.concatenate((np.full((batch_size, 1), model.config.vocab_size - 1), primers), axis=1)
    context = torch.tensor(context).to(device)
    output = model.generate(input_ids=context, max_length=n_px * n_px + 1, temperature=1.0, do_sample=True, top_k=40)
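    # output still contains the start-of-sequence token at position 0; strip it before decoding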
    # decode back to images (convert color cluster tokens back to pixels)
    samples = output[:, 1:].cpu().detach().numpy()
    samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples]
    samples_img = [primers_img] + samples_img
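    # 8 tiles in total (the padded primer + 7 completions), laid out as a 2x4 grid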
    # stack the tiles into two rows of four
    row1 = np.hstack(samples_img[:4])
    row2 = np.hstack(samples_img[4:])
    result = np.vstack([row1, row2])
    # return as a PIL image
    completion = Image.fromarray(result)
    return completion
title = "Interactive demo: ImageGPT"
description = "Demo for OpenAI's ImageGPT: Generative Pretraining from Pixels. To use it, simply upload an image or pick one of the example images below and click 'Submit'. The model completes the bottom half of the image pixel by pixel, so results can take a while to show up."
article = "<p style='text-align: center'><a href='https://cdn.openai.com/papers/Generative_Pretraining_from_Pixels_V2.pdf'>Generative Pretraining from Pixels</a> | <a href='https://openai.com/blog/image-gpt/'>Official blog</a></p>"
examples =[f"image_{idx}.png" for idx in range(len(urls))]
iface = gr.Interface(fn=process_image,
                     inputs=gr.inputs.Image(type="pil"),
                     outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     enable_queue=True)
iface.launch(debug=True)