import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

# Use the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inference with the pre-trained model
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)
#model1 = AutoModelForCausalLM.from_pretrained(checkpoint1, use_auth_token=access_token)
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Inference with the fine-tuned model
checkpoint2 = "wangjin2000/git-base-finetune"
#model2 = AutoModelForCausalLM.from_pretrained(checkpoint2, use_auth_token=access_token)
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)
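
# Note on the commented-out loaders above: if the fine-tuned checkpoint were
# private, a sketch of the current pattern would be the following (assuming an
# HF_TOKEN secret is configured for the Space; `token` supersedes the
# deprecated `use_auth_token` argument):
#
#   import os
#   model2 = AutoModelForCausalLM.from_pretrained(
#       checkpoint2, token=os.environ["HF_TOKEN"]
#   ).to(device)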
def img2cap_com(image):
    # Caption from the pre-trained model
    input1 = processor(images=image, return_tensors="pt").to(device)
    pixel_values1 = input1.pixel_values
    generated_id1 = model1.generate(pixel_values=pixel_values1, max_length=50)
    generated_caption1 = processor.batch_decode(generated_id1, skip_special_tokens=True)[0]

    # Caption from the fine-tuned model
    input2 = processor(images=image, return_tensors="pt").to(device)
    pixel_values2 = input2.pixel_values
    generated_id2 = model2.generate(pixel_values=pixel_values2, max_length=50)
    generated_caption2 = processor.batch_decode(generated_id2, skip_special_tokens=True)[0]

    return generated_caption1, generated_caption2
inputs = [
    gr.Image(type="pil", label="Original Image"),
]

outputs = [
    gr.Textbox(label="Caption from pre-trained model"),
    gr.Textbox(label="Caption from fine-tuned model"),
]
title = "Image Captioning Using Pre-trained and Fine-tuned Models"
description = "GIT-base is used to generate a caption for the uploaded image."
examples = [
    ["Image1.png"],
    ["Image2.png"],
    ["Image3.png"],
    ["Image4.png"],
    ["Image5.png"],
    ["Image6.png"],
]
gr.Interface(
    fn=img2cap_com,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
).queue().launch(debug=True)
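
# A minimal sketch for sanity-checking the captioning function without the UI,
# assuming one of the example images (e.g. "Image1.png") is present locally:
#
#   from PIL import Image
#   pre_caption, ft_caption = img2cap_com(Image.open("Image1.png"))
#   print(pre_caption, ft_caption)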