import gradio as gr
import torch
from PIL import Image
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the GPT-2 model and tokenizer once at startup
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

def generate_caption(image: Image.Image) -> str:
    # Preprocess the uploaded image (Gradio passes a PIL image when type='pil')
    img = image.convert('RGB')
    img = img.resize((224, 224))

    # Generate a caption with GPT-2 from a fixed text prompt. GPT-2 is a
    # text-only model, so the output is conditioned on the prompt rather
    # than on the image pixels.
    input_text = "This is an image of "
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    with torch.no_grad():
        output = model.generate(input_ids, max_length=200, do_sample=True,
                                pad_token_id=tokenizer.eos_token_id)
    caption = tokenizer.decode(output[0], skip_special_tokens=True)
    return caption

# Create the Gradio interface (gr.inputs / gr.outputs are deprecated; use gr.Image / gr.Textbox)
inputs = gr.Image(type='pil')
outputs = gr.Textbox()
gr.Interface(fn=generate_caption, inputs=inputs, outputs=outputs, title='Image Captioning with GPT-2', description='Upload an image and get a detailed caption generated by GPT-2.').launch()