# app.py by wangjin2000 (Hugging Face Hub).
# Commit 672c8e9 "Update app.py" (2.29 kB).
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
import torch
# Run on the GPU when one is available; img2cap_com moves its inputs here too.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inference with the pre-trained model.
checkpoint1 = "microsoft/git-base"
# Shared by both models: img2cap_com preprocesses with this single processor.
processor = AutoProcessor.from_pretrained(checkpoint1)
# Move weights to `device`: img2cap_com sends its pixel tensors to `device`,
# and generate() fails if model weights and inputs live on different devices.
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Inference with the fine-tuned model.
checkpoint2 = "wangjin2000/git-base-finetune"
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)

# English -> Chinese translator used on every generated caption.
# The bare "translation" task cannot pick a default model for an en->zh pair
# (transformers' built-in defaults only cover T5 for en->fr/de/ro), so an
# en->zh model is named explicitly. src_lang/tgt_lang are forwarded to the
# pipeline unchanged; tokenizers that do not use them ignore them.
en_zh_translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-en-zh",
    src_lang="en",
    tgt_lang="zh",
)
def _caption_and_translate(model, pixel_values):
    """Caption preprocessed pixels with *model* and pair the caption with its translation.

    Args:
        model: A captioning model loaded with ``AutoModelForCausalLM``.
        pixel_values: Pixel tensor produced by ``processor`` and already on ``device``.

    Returns:
        A two-element list ``[caption, translation]``. ``caption`` is the decoded
        English string. ``translation`` is the raw return value of
        ``en_zh_translator(caption)``.
    """
    generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # NOTE(review): per the transformers docs the translation pipeline returns a
    # list of dicts with a "translation_text" key. It is passed through as-is,
    # so the output Textbox shows its repr; consider unwrapping the text.
    return [caption, en_zh_translator(caption)]


def img2cap_com(image):
    """Caption *image* with both the pre-trained and the fine-tuned model.

    Both models share ``processor``, so the image is preprocessed once and the
    same pixel tensor is fed to each model.

    Args:
        image: PIL image from the Gradio ``Image`` input (``type="pil"``).

    Returns:
        A tuple ``(pretrained_result, finetuned_result)``. Each element is a
        ``[caption, translation]`` list as built by ``_caption_and_translate``.
    """
    pixel_values = processor(images=image, return_tensors="pt").to(device).pixel_values
    return (
        _caption_and_translate(model1, pixel_values),
        _caption_and_translate(model2, pixel_values),
    )
# Gradio UI: one image in, one caption textbox per model out (the order
# matches the tuple returned by img2cap_com).
# gr.Image / gr.Textbox replace the deprecated gr.inputs / gr.outputs
# namespaces, which were removed in Gradio 4; both exist throughout Gradio 3.x.
inputs = [
    gr.Image(type="pil", label="Original Image"),
]
outputs = [
    gr.Textbox(label="Caption from pre-trained model"),
    gr.Textbox(label="Caption from fine-tuned model"),
]
title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."
# Example images Image1.png .. Image6.png, presumably shipped next to app.py
# (TODO confirm). Each example is a one-element list, one value per input.
examples = [[f"Image{i}.png"] for i in range(1, 7)]

# Bound to `demo`, the name Gradio's reload mode looks up.
demo = gr.Interface(
    img2cap_com,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
)
demo.launch()