import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch

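# download sample images used in the Gradio examples below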
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')

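# load GIT and BLIP checkpoints fine-tuned for VQA, in base and large variants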
git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")

git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-vqav2")
git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-vqav2")

blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")

# vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
# vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

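# run inference on GPU when available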
device = "cuda" if torch.cuda.is_available() else "cpu"

git_model_base.to(device)
blip_model_base.to(device)
git_model_large.to(device)
blip_model_large.to(device)
# vilt_model.to(device)

def generate_answer_git(processor, model, image, question):
    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # prepare question: GIT expects the [CLS] token id prepended to the question tokens
    input_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=128)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # GIT generates the question followed by the answer, so strip the question from the output
    return generated_answer[0].replace(question, '').replace(question.lower(), '').strip()


def generate_answer_blip(processor, model, image, question):
    # prepare image + question
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    generated_ids = model.generate(**inputs, max_length=128)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer[0].strip()


def generate_answer_vilt(processor, model, image, question):
    # prepare image + question
    encoding = processor(images=image, text=question, max_length=128, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**encoding)

    # ViLT treats VQA as classification over a fixed set of candidate answers
    predicted_class_idx = outputs.logits.argmax(-1).item()

    return model.config.id2label[predicted_class_idx]


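# run every active model on the same image/question pair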
def generate_answers(image, question):
    answer_git_base = generate_answer_git(git_processor_base, git_model_base, image, question)

    answer_git_large = generate_answer_git(git_processor_large, git_model_large, image, question)

    answer_blip_base = generate_answer_blip(blip_processor_base, blip_model_base, image, question)

    answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)

    # answer_vilt = generate_answer_vilt(vilt_processor, vilt_model, image, question)

    return answer_git_base, answer_git_large, answer_blip_base, answer_blip_large#, answer_vilt

   
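# Gradio interface configuration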
examples = [["cats.jpg", "How many cats are there?"], ["stop_sign.png", "What's behind the stop sign?"], ["astronaut.jpg", "What's the astronaut riding on?"]]
outputs = [
    gr.outputs.Textbox(label="Answer generated by GIT-base"),
    gr.outputs.Textbox(label="Answer generated by GIT-large"),
    gr.outputs.Textbox(label="Answer generated by BLIP-base"),
    gr.outputs.Textbox(label="Answer generated by BLIP-large"),
    # gr.outputs.Textbox(label="Answer generated by ViLT"),
]

title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio demo to compare GIT and BLIP, two state-of-the-art vision+language models, each in a base and a large variant. To use it, simply upload an image, type a question and click 'Submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

interface = gr.Interface(fn=generate_answers, 
                         inputs=[gr.inputs.Image(type="pil"), gr.inputs.Textbox(label="Question")],
                         outputs=outputs,
                         examples=examples, 
                         title=title,
                         description=description,
                         article=article, 
                         enable_queue=True)
interface.launch(debug=True)