File size: 1,991 Bytes
cafea32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from transformers import AutoProcessor,  AutoModelForCausalLM
import torch

# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained baseline model. Its processor is reused for both models.
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)
# Move the weights to the selected device so they match input tensors that
# are placed on `device` at inference time; a CPU model fed CUDA tensors
# raises a device-mismatch error.
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Fine-tuned model, compared side by side with the baseline.
checkpoint2 = "wangjin2000/git-base-finetune"
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)

def _generate_caption(model, pixel_values):
    """Generate and decode a single caption from *model* for *pixel_values*."""
    # Feed the model tensors on whatever device its weights live on, so the
    # call works regardless of where the model was placed.
    generated_ids = model.generate(
        pixel_values=pixel_values.to(model.device),
        max_length=50,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def img2cap_com(image):
    """Caption *image* with both the pre-trained and the fine-tuned model.

    Args:
        image: The image to caption, in a form accepted by ``processor``
            (the UI is configured to supply a PIL image), or ``None`` when
            no image was provided.

    Returns:
        A ``(pretrained_caption, finetuned_caption)`` tuple of strings.
        Both captions are empty strings when *image* is ``None``.
    """
    if image is None:
        # Nothing to caption; avoid a crash inside the processor.
        return "", ""

    # Both models share one processor, so the image is preprocessed once.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    return (
        _generate_caption(model1, pixel_values),
        _generate_caption(model2, pixel_values),
    )

# Input component: a single uploaded image, handed to the callback as PIL.
# Top-level component classes replace the deprecated `gr.inputs` /
# `gr.outputs` namespaces.
inputs = [
    gr.Image(type="pil", label="Original Image"),
]

# Output components: one caption per model, in the order the callback
# returns them.
outputs = [
    gr.Textbox(label="Caption from pre-trained model"),
    gr.Textbox(label="Caption from fine-tuned model"),
]

title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."

# Example images offered in the UI; paths are relative to the working
# directory of the app.
examples = [
    ["Image1.png"],
    ["Image2.png"],
    ["Image3.png"],
    ["Image4.png"],
    ["Image5.png"],
    ["Image6.png"],
]

demo = gr.Interface(
    img2cap_com,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
)

# Queueing is enabled through `.queue()`; the `enable_queue` keyword of
# `launch()` is deprecated.
demo.queue().launch(debug=True)