MahsaShahidi's picture
Update app.py
48d3b8d
raw history blame
No virus
1.57 kB
import torch
import re
import gradio as gr
from pathlib import Path
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
def predict(image, max_length=30, num_beams=4):
    """Generate a caption for a single image.

    Args:
        image: Image object exposing ``convert`` (the interface below is
            configured to pass PIL images); it is converted to RGB first.
        max_length: Upper bound on the length of the generated token sequence.
        num_beams: Number of beams used by beam search during generation.

    Returns:
        The decoded caption text with special tokens stripped.
    """
    rgb_image = image.convert("RGB")
    pixel_values = feature_extractor(images=rgb_image, return_tensors="pt").pixel_values
    # Keep the inputs on the same device as the model; forcing them back to the
    # CPU here would break generation whenever `device` is not "cpu".
    pixel_values = pixel_values.to(device)
    with torch.no_grad():
        # Forward the decoding options so the function's parameters actually
        # control generation instead of being silently ignored.
        caption_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
        )[0]
    return tokenizer.decode(caption_ids, skip_special_tokens=True)
# Checkpoint identifier and the device inference runs on.
model_path = "MahsaShahidi/Persian-Image-Captioning"
device = "cpu"

# Captioning model, placed on the target device (`.to` returns the same module).
model = VisionEncoderDecoderModel.from_pretrained(model_path).to(device)
print("Loaded model")

# Image preprocessor that produces the `pixel_values` fed to the model.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
print("Loaded feature_extractor")

# Tokenizer used to decode generated token ids back into text.
tokenizer = AutoTokenizer.from_pretrained(
    "HooshvareLab/bert-fa-base-uncased-clf-persiannews"
)
print("Loaded tokenizer")
# --- Gradio UI -------------------------------------------------------------
title = "Persian Image Captioning"
# Shown as the interface description (rendered alongside the title).
article = "This HuggingFace Space presents a demo for Persian Image Captioning on VIT as its Encoder and ParsBERT (v2.0) as its Decoder"

# Named so they do not shadow the `input` builtin.
# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy component namespaces;
# confirm the pinned Gradio version still provides them.
image_input = gr.inputs.Image(label="Image to search", type="pil", optional=False)
caption_output = gr.outputs.Textbox(type="auto", label="Captions")

# Example images expected next to this script: image-1.jpg .. image-3.jpg.
images = [f"./image-{i}.jpg" for i in range(1, 4)]

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=caption_output,
    examples=images,
    title=title,
    description=article,
)
interface.launch()