ViTGPT2 / app.py
gagan3012's picture
Update app.py
92d33fb
raw
history blame contribute delete
No virus
1.59 kB
import torch
from PIL import Image
from transformers import (AutoTokenizer, VisionEncoderDecoderModel,
ViTFeatureExtractor)
import gradio as gr
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint identifiers handed to the `from_pretrained` loaders below.
encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "gpt2"
model_checkpoint = "gagan3012/ViTGPT2I2A"

# The image preprocessor is loaded from the encoder checkpoint and the
# tokenizer from the decoder checkpoint; the captioning model itself comes
# from its own checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)

model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)
# Place the model on the device selected earlier in this file.
model = model.to(device)
def predict(image):
    """Generate a one-line caption for an image.

    Args:
        image: The image to caption, as delivered by the Gradio image input
            (declared with ``type="pil"`` in this file), or ``None`` when no
            image was provided.

    Returns:
        The decoded caption with every ``<|endoftext|>`` marker removed and
        everything after the first newline dropped. An empty string is
        returned when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to caption: return early rather than handing None to the
        # feature extractor.
        return ""

    # Preprocess into "pt" tensors and move them to the same device as the model.
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    # Only one image is passed in, so take the first generated sequence.
    caption_ids = model.generate(pixel_values, max_length=50)[0]
    decoded = tokenizer.decode(caption_ids)
    # Strip the end-of-text marker, then keep only the first line.
    return decoded.replace("<|endoftext|>", "").split("\n")[0]
# One image input, delivered to `predict` as a PIL image (type="pil").
inputs = [gr.inputs.Image(type="pil", label="Original Image")]
# One text box that displays the returned caption.
outputs = [gr.outputs.Textbox(label="Caption")]
# Text passed to gr.Interface as the page title and description.
title = "Image Captioning using ViT + GPT2"
description = (
    "ViT and GPT2 are used to generate Image Caption for the uploaded images"
)
# HTML link passed as the interface's `article`; the leading space is kept.
# NOTE(review): the link names 'ViTGPT2_vizwiz' while the model loaded in this
# file is 'gagan3012/ViTGPT2I2A' - confirm which repo is intended.
article = (
    " <a href='https://huggingface.co/gagan3012/ViTGPT2_vizwiz'>"
    "Model Repo on Hugging Face Model Hub</a>"
)
# One single-element row per example image file.
examples = [
    [image_file]
    for image_file in ("duck.jpg", "dice.jpg", "banana.jpg", "avacado.jpg")
]
# Build the captioning interface around `predict` and launch it immediately.
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
).launch(debug=True, enable_queue=True)