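# Gradio demo for Multimodal-CoT (MM-CoT): given an image and a question, the app
# first generates a rationale and then an answer, both conditioned on ViT image features.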
import gradio as gr
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import T5Tokenizer

from model import T5ForMultimodalGeneration

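# Two-stage MM-CoT checkpoints from the Hugging Face Hub: one T5 model generates
# the rationale, the other generates the final answer.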
rationale_model_dir = "cooelf/MM-CoT-UnifiedQA-Base-Rationale-Joint"
answer_model_dir = "cooelf/MM-CoT-UnifiedQA-Base-Answer-Joint"

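# Frozen ViT-B/16 at 384x384 resolution is used as the image feature extractor;
# forward_features returns 577 tokens (576 patches + CLS) of dim 768, matching the
# patch_size passed to the two multimodal T5 models below.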
vit_model = timm.create_model("vit_base_patch16_384", pretrained=True, num_classes=0)
vit_model.eval()
config = resolve_data_config({}, model=vit_model)
transform = create_transform(**config)
tokenizer = T5Tokenizer.from_pretrained(rationale_model_dir)
r_model = T5ForMultimodalGeneration.from_pretrained(rationale_model_dir, patch_size=(577, 768))
a_model = T5ForMultimodalGeneration.from_pretrained(answer_model_dir, patch_size=(577, 768))

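# Run the two-stage MM-CoT pipeline on one (image, question) pair and return (rationale, answer).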
def inference_chat(input_image, input_text):
    # Extract frozen ViT patch features for the input image (no gradients needed).
    with torch.no_grad():
        pixel_values = transform(input_image).unsqueeze(0)
        out = vit_model.forward_features(pixel_values)
        image_features = out.detach()
    
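    # Stage 1: encode the question and generate a rationale conditioned on the image features.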
    source = tokenizer.batch_encode_plus(
        [input_text],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    source_ids = source["input_ids"]
    source_mask = source["attention_mask"]
    rationale = r_model.generate(
        input_ids=source_ids,
        attention_mask=source_mask,
        image_ids=image_features,
        max_length=512,
        num_beams=1,
        do_sample=False
    ) 
    rationale = tokenizer.batch_decode(rationale, skip_special_tokens=True)[0]
    print(rationale)
    
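    # Stage 2: append the generated rationale to the prompt and generate the final answer.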
    input_text = input_text + "\n" + rationale + "\nAnswer:"
    print(input_text)
    
    source = tokenizer.batch_encode_plus(
        [input_text],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    source_ids = source["input_ids"]
    source_mask = source["attention_mask"]
    answer = a_model.generate(
        input_ids=source_ids,
        attention_mask=source_mask,
        image_ids=image_features,
        max_length=64,
        num_beams=1,
        do_sample=False
    ) 
    
    answer = tokenizer.batch_decode(answer, skip_special_tokens=True)[0]
    return rationale, answer


title = """# Multimodal-CoT"""

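# Gradio Blocks UI: image and question inputs on the left, generated rationale and answer on the right.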
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])
    gr.Markdown(title)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="Question")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )
        with gr.Column():
            rationale = gr.Textbox(lines=0, label="Rationale")
            answer = gr.Textbox(lines=0, label="Answer")
            
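        # Event wiring: pressing Enter in the question box or clicking Submit runs inference; Clear resets the fields.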
        chat_input.submit(
            inference_chat,
            [image_input, chat_input],
            [rationale, answer],
        )
        clear_button.click(
            lambda: ("", [], "", ""),
            [],
            [chat_input, state, rationale, answer],
            queue=False,
        )
        submit_button.click(
            inference_chat,
            [image_input, chat_input],
            [rationale, answer],
        )
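    # Bundled example (api/61.png) with its question, reference rationale, and answer.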
    examples = [
        [
            'api/61.png',
            "Question: Think about the magnetic force between the magnets in each pair. Which of the following statements is true?\nContext: The images below show two pairs of magnets. The magnets in different pairs do not affect each other. All the magnets shown are made of the same material, but some of them are different sizes and shapes.\nOptions: (A) The magnitude of the magnetic force is the same in both pairs. (B) The magnitude of the magnetic force is smaller in Pair 1. (C) The magnitude of the magnetic force is smaller in Pair 2.\nSolution:",
            "Magnet sizes affect the magnitude of the magnetic force. Imagine magnets that are the same shape and made of the same material. The smaller the magnets, the smaller the magnitude of the magnetic force between them.\nMagnet A is the same size in both pairs. But Magnet B is smaller in Pair 2 than in Pair 1. So, the magnitude of the magnetic force is smaller in Pair 2 than in Pair 1.",
            "The answer is (C).",
        ],
    ]
    examples = gr.Examples(
        examples=examples, inputs=[image_input, chat_input, rationale, answer],
    )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch()