|
import gradio as gr |
|
import requests |
|
from io import BytesIO |
|
from PIL import Image |
|
import torch |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
|
|
|
# Hugging Face checkpoint name for both the tokenizer and the language model.
_CHECKPOINT = 'gpt2'

# Loaded once when the module is imported; generate_caption reads these
# module-level objects on every call.
tokenizer = GPT2Tokenizer.from_pretrained(_CHECKPOINT)
model = GPT2LMHeadModel.from_pretrained(_CHECKPOINT)
|
|
|
def _load_image(image, *, timeout=10):
    """Return *image* as an RGB ``PIL.Image.Image``, downloading URLs first.

    Accepts an ``http(s)`` URL string, a local file path string, a PIL image,
    or anything ``PIL.Image.fromarray`` accepts (e.g. the numpy array that
    ``gr.inputs.Image()`` passes by default).

    Raises:
        ValueError: If *image* is ``None`` (e.g. nothing was uploaded).
        requests.HTTPError: If a URL download returns an error status.
        PIL.UnidentifiedImageError: If the bytes/file are not a decodable image.
    """
    if image is None:
        raise ValueError("No image provided.")
    if isinstance(image, Image.Image):
        return image.convert('RGB')
    if isinstance(image, str):
        if image.startswith(('http://', 'https://')):
            # Bounded wait so a dead host cannot hang the request forever.
            response = requests.get(image, timeout=timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        with Image.open(image) as img:
            return img.convert('RGB')
    return Image.fromarray(image).convert('RGB')


def generate_caption(image, *, max_length=200, timeout=10):
    """Generate GPT-2 text prompted as a description of *image*.

    NOTE(review): GPT-2 is a text-only language model and never receives the
    pixels. The image is loaded only to validate the input. The returned
    "caption" is GPT-2 sampling a continuation of
    ``"This is an image of <url>."`` for URL inputs, or of
    ``"This is an image of"`` otherwise. Real captioning would need a
    vision-encoder-decoder model.

    Args:
        image: An ``http(s)`` URL string (the form the original code expected),
            a local file path, a ``PIL.Image.Image``, or an array accepted by
            ``PIL.Image.fromarray``.
        max_length: Maximum total token count, prompt included, passed to
            ``model.generate``.
        timeout: Seconds to wait when downloading a URL input.

    Returns:
        The decoded generated text, including the prompt.

    Raises:
        ValueError: If *image* is ``None``.
        requests.HTTPError: If downloading a URL returns an error status.
        PIL.UnidentifiedImageError: If the input cannot be decoded as an image.
    """
    # Validation only: the decoded image is not consumed by GPT-2.
    _load_image(image, timeout=timeout)

    if isinstance(image, str) and image.startswith(('http://', 'https://')):
        # No trailing space: GPT-2's BPE attaches spaces to the *next* token,
        # so a trailing space would become an unnatural standalone token.
        prompt = f"This is an image of {image}."
    else:
        prompt = "This is an image of"

    # Calling the tokenizer also yields the attention_mask that generate() wants.
    encoded = tokenizer(prompt, return_tensors='pt')
    output = model.generate(
        **encoded,
        max_length=max_length,
        do_sample=True,
        # GPT-2 has no pad token; reuse EOS so generate() does not warn.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
|
|
# Gradio UI: one image input, one text output.
# gr.inputs/gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# the top-level components below replace them. type="pil" hands the callback
# a PIL image instead of the default numpy array, so generate_caption must
# accept a PIL image here.
inputs = gr.Image(type='pil')

outputs = gr.Textbox()

gr.Interface(
    fn=generate_caption,
    inputs=inputs,
    outputs=outputs,
    title='Image Captioning with GPT-2',
    description='Upload an image and get a detailed caption generated by GPT-2.',
).launch()
|
|