import gradio as gr
import tempfile
from transformers import MT5ForConditionalGeneration, MT5Tokenizer, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
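
# Pipeline: caption an uploaded image in English with ViT-GPT2,
# then translate the caption to Persian with mT5.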
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
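# English-to-Persian translation model and its tokenizer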
model_name = "SeyedAli/English-to-Persian-Translation-mT5-V1"
translation_tokenizer = MT5Tokenizer.from_pretrained(model_name)
translation_model = MT5ForConditionalGeneration.from_pretrained(model_name)
translation_model = translation_model.to(device)
def run_translation_model(input_string, **generator_args):
    # Tokenize the English text and move it to the same device as the model
    input_ids = translation_tokenizer.encode(input_string, return_tensors="pt").to(device)
    res = translation_model.generate(input_ids, **generator_args)
    output = translation_tokenizer.batch_decode(res, skip_special_tokens=True)
    return output
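# Image-captioning model (ViT encoder + GPT-2 decoder), image processor, and tokenizer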
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = model.to(device)
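# Beam-search settings for caption generation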
max_length = 32
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def predict_step(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    # Caption the first image in English, then translate the caption to Persian
    return run_translation_model(preds[0])[0]
def ImageCaptioning(image):
    # Save the uploaded image (a NumPy array) to a temporary PNG file,
    # since predict_step expects a list of file paths
    with tempfile.NamedTemporaryFile(suffix=".png") as temp_image_file:
        Image.fromarray(image).save(temp_image_file.name)
        caption = predict_step([temp_image_file.name])
    return caption
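# Gradio interface: image input, Persian caption output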
iface = gr.Interface(fn=ImageCaptioning, inputs="image", outputs="text")
iface.launch(share=False)