# Page-header residue from the web file viewer (not part of the program):
# krypticmouse's picture
# Update app.py
# ed498b4
# raw history blame
# No virus
# 1.52 kB
import torch
import re
import gradio as gr
from pathlib import Path
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
def predict(image, max_length=64, num_beams=4):
    """Generate a caption for an image with the module-level captioning model.

    Args:
        image: Image object exposing ``convert('RGB')`` (a PIL image).
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The first line of the decoded text, with every ``<|endoftext|>``
        marker removed.
    """
    image = image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Keep the input on the configured device; it is passed to generate as-is
    # rather than being forced back to the CPU afterwards.
    pixel_values = pixel_values.to(device)
    with torch.no_grad():
        # Honour the caller-supplied generation settings instead of silently
        # ignoring them.
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
        )
    text = tokenizer.decode(output_ids[0])
    # Drop the end-of-text marker and keep only the first line of the output.
    text = text.replace('<|endoftext|>', '').split('\n')
    return text[0]
# Identifier passed to from_pretrained for both the model and the tokenizer.
model_path = "team-indain-image-caption/hindi-image-captioning"
# Device the model is placed on; predict() moves its inputs to the same one.
device = "cpu"

# Load the encoder-decoder model and place it on the device in one expression.
model = VisionEncoderDecoderModel.from_pretrained(model_path).to(device)
print("Loaded model")

# Image preprocessor that turns an input image into pixel_values tensors.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
print("Loaded feature_extractor")

# Tokenizer used by predict() to decode generated token ids back into text.
tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Loaded tokenizer")
# Text shown in the web UI. Note that `article` (not `description`) is what
# gets passed to gr.Interface as its description below.
title = "Hindi Image Captioning"
description = ""

# Gradio input/output components: a PIL image in, a text box of captions out.
# NOTE(review): `input` shadows the builtin of the same name; the name is kept
# so that any code relying on these module-level names keeps working.
input = gr.inputs.Image(label="Image to search", type='pil', optional=False)
output = gr.outputs.Textbox(type="auto", label="Captions")

article = "This huggingface presents a demo for Image captioning in Hindi built with VIT Encoder and GPT2 Decoder"

# f-string so the index is substituted: example_1.jpg ... example_5.jpg.
# (A plain string here would yield five identical literal "{i}" paths.)
# The list is not currently passed to gr.Interface.
example = [f"./examples/example_{i}.jpg" for i in range(1, 6)]

interface = gr.Interface(
    fn=predict,
    inputs=input,
    theme="grass",
    outputs=output,
    title=title,
    description=article,
)

# Start the web server; share=True is passed to request a shareable link.
interface.launch(share=True)