Spaces:
Runtime error
Runtime error
File size: 1,991 Bytes
cafea32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained GIT-base checkpoint. Its processor (image preprocessing +
# tokenizer) is reused for the fine-tuned model below as well.
# NOTE(review): assumes the fine-tuned checkpoint shares GIT-base's
# preprocessing and vocabulary — confirm if the fine-tune changed either.
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)
# Move the weights onto `device`: the captioning code sends its inputs there,
# and a model left on the CPU fails with a device-mismatch error under CUDA.
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Fine-tuned variant of GIT-base, placed on the same device.
checkpoint2 = "wangjin2000/git-base-finetune"
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)
def _caption(model, pixel_values, max_length=50):
    """Generate a single caption with ``model`` for preprocessed pixel values.

    Args:
        model: A captioning model exposing ``generate`` and ``device``.
        pixel_values: Pixel tensor produced by the module-level ``processor``.
        max_length: Maximum length passed through to ``model.generate``.

    Returns:
        The first decoded caption, with special tokens stripped.
    """
    # Place the input on the model's own device so generation works whether
    # the model lives on the CPU or the GPU.
    pixel_values = pixel_values.to(model.device)
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def img2cap_com(image):
    """Caption ``image`` with both the pre-trained and the fine-tuned model.

    Args:
        image: The uploaded image (a PIL image, per the interface's
            ``type="pil"``), or ``None`` when the user submits nothing.

    Returns:
        A ``(pretrained_caption, finetuned_caption)`` pair of strings; both
        are empty when no image was supplied.
    """
    # Gradio passes None when the input is cleared; the processor cannot
    # handle that, so answer with empty captions instead of crashing.
    if image is None:
        return "", ""
    # Preprocess once and feed the same pixels to both models.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    return _caption(model1, pixel_values), _caption(model2, pixel_values)
# One PIL image in; two captions out (pre-trained vs. fine-tuned model).
# Components come from the top-level `gr` namespace: the legacy
# `gr.inputs` / `gr.outputs` modules were removed in Gradio 4.
inputs = [
    gr.Image(type="pil", label="Original Image"),
]
outputs = [
    gr.Textbox(label="Caption from pre-trained model"),
    gr.Textbox(label="Caption from fine-tuned model"),
]
title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."
# Example image files, resolved relative to the working directory
# (presumably shipped alongside this script in the Space repo — verify).
examples = [
    ["Image1.png"],
    ["Image2.png"],
    ["Image3.png"],
    ["Image4.png"],
    ["Image5.png"],
    ["Image6.png"],
]

# The old `theme="huggingface"` string is not a built-in theme name in current
# Gradio releases, so the default theme is used instead.
demo = gr.Interface(
    img2cap_com,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
)
# Request queuing is enabled with `queue()`; `launch(enable_queue=...)` no
# longer exists in Gradio 4.
demo.queue().launch(debug=True)
|