import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Inference with the pre-trained model
checkpoint1 = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint1)
model1 = AutoModelForCausalLM.from_pretrained(checkpoint1).to(device)

# Inference with the fine-tuned model
checkpoint2 = "wangjin2000/git-base-finetune"
model2 = AutoModelForCausalLM.from_pretrained(checkpoint2).to(device)

# English-to-Chinese translation of the generated captions (MarianMT en-zh model)
en_zh_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh")

def img2cap_com(image):
    # Preprocess the uploaded image once and reuse the pixel values for both models
    inputs = processor(images=image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values

    # Caption from the pre-trained model, followed by its Chinese translation
    generated_ids1 = model1.generate(pixel_values=pixel_values, max_length=50)
    generated_caption1 = processor.batch_decode(generated_ids1, skip_special_tokens=True)[0]
    translated_caption1 = en_zh_translator(generated_caption1)[0]["translation_text"]

    # Caption from the fine-tuned model, followed by its Chinese translation
    generated_ids2 = model2.generate(pixel_values=pixel_values, max_length=50)
    generated_caption2 = processor.batch_decode(generated_ids2, skip_special_tokens=True)[0]
    translated_caption2 = en_zh_translator(generated_caption2)[0]["translation_text"]

    return (
        f"{generated_caption1}\n{translated_caption1}",
        f"{generated_caption2}\n{translated_caption2}",
    )
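
# Quick local sanity check outside Gradio (assumes one of the example images listed
# below is present next to this script):
#   from PIL import Image
#   print(img2cap_com(Image.open("Image1.png")))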

inputs = [
    gr.Image(type="pil", label="Original Image")
]

outputs = [
    gr.Textbox(label="Caption from pre-trained model"),
    gr.Textbox(label="Caption from fine-tuned model"),
]

title = "Image Captioning using Pre-trained and Fine-tuned Model"
description = "GIT-base is used to generate Image Caption for the uploaded image."

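# Example images; these files are expected to sit alongside this script so Gradio can load them.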
examples = [
    ["Image1.png"],
    ["Image2.png"],
    ["Image3.png"],
    ["Image4.png"],
    ["Image5.png"],
    ["Image6.png"]
]

gr.Interface(
    img2cap_com,
    inputs,
    outputs,
    title=title,
    description=description,
    examples=examples,
    theme="huggingface",
).launch()