##### VQA MED Demo

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import CLIPTokenizer

from CLIP import clip
from Transformers_for_Caption import Transformer_Caption

device = "cuda" if torch.cuda.is_available() else "cpu"

class Config(object):
    def __init__(self):
        # Transformer decoder settings
        self.hidden_dim = 512
        self.pad_token_id = 0
        self.max_position_embeddings = 76  # maximum length (in tokens) of the generated answer
        self.layer_norm_eps = 1e-12
        self.dropout = 0.1
        self.vocab_size = 49408  # CLIP BPE vocabulary size

        self.enc_layers = 1
        self.dec_layers = 1
        self.dim_feedforward = 1024
        self.nheads = 4
        self.pre_norm = True
        # Dataset
        self.limit = -1



##### OUR MODEL
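# Overview of the model defined below: the CLIP ViT-B/32 backbone encodes the
# image and the question into token-level feature sequences, the two sequences
# are concatenated along the sequence dimension, a 2-layer Transformer caption
# decoder attends over them, and a linear head over the CLIP BPE vocabulary
# predicts the answer token by token.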

class VQA_Net(nn.Module):
    def __init__(self, num_classes):
        super(VQA_Net, self).__init__()
        # Alternative vision backbones explored during development:
        # self.VIT = deit_base_distilled_patch16_224(pretrained=True)
        # self.VIT = vit_base_patch16_224_dino(pretrained=True)
        # self.VIT = vit_base_patch32_sam_224(pretrained=True)  # note that we used only 6 layers
        # self.VIT = maxvit_rmlp_nano_rw_256(pretrained=True)
        # self.VIT = vit_base_patch8_224(pretrained=True)
        # self.VIT = tf_efficientnetv2_m(pretrained=True, features_only=True, out_indices=(1, 3), feature_location='expansion')
        self.backbone, _ = clip.load('ViT-B/32', device, jit=False)  # CLIP ViT-B/32 image/text encoder
        self.input_proj = nn.LayerNorm(512)
        self.transformer_decoder = Transformer_Caption(config, num_decoder_layers=2)
        # Kept as nested Sequentials so the parameter names match the trained checkpoint.
        self.mlp = nn.Sequential(nn.Sequential(nn.Linear(512, num_classes)))
        self.samples_proj = nn.Identity()
        self.question_proj = nn.Identity()

    def forward(self, samples, question_in, answer_out, mask_answer):
        # The bundled CLIP encoders return three outputs; the last one is the
        # token-level feature sequence used here.
        _, _, samples = self.backbone.encode_image(samples)
        samples = samples.float()
        samples = self.samples_proj(samples)

        _, _, question_in = self.backbone.encode_text(question_in)
        question_in = self.question_proj(question_in.float())

        # Concatenate image and question tokens along the sequence dimension.
        samples = torch.cat((samples, question_in), dim=1)

        # The decoder expects (seq_len, batch, dim) inputs, hence the permute.
        hs = self.transformer_decoder(self.input_proj(samples.permute(1, 0, 2).float()), answer_out, tgt_mask=mask_answer)
        out = self.mlp(hs.permute(1, 0, 2))
        return out

config = Config()
Tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
My_VQA = VQA_Net(num_classes=len(Tokenizer))
My_VQA.load_state_dict(torch.load("./PathVQA_2Decoders_1024_30iterations_Trial4_CLIPVIT32.pth.tar", map_location=torch.device(device)))
My_VQA.to(device)  # keep the decoder/head on the same device as the CLIP backbone
My_VQA.eval()      # inference only: disable dropout


tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet normalization statistics
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
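# The transforms above are applied to each uploaded PIL image in
# infer_answer_question; the result is a (3, 224, 224) float tensor that is
# batched with unsqueeze(0) inside answer_question before being passed to the
# CLIP image encoder.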


def answer_question(image, text_question):
    with torch.no_grad():
        start_token = Tokenizer.convert_tokens_to_ids("<|startoftext|>")
        end_token = Tokenizer.convert_tokens_to_ids("<|endoftext|>")

        # Answer buffer: starts with <|startoftext|>; positions are unmasked as they are filled.
        caption = torch.zeros((1, config.max_position_embeddings), dtype=torch.long, device=device)
        cap_mask = torch.ones((1, config.max_position_embeddings), dtype=torch.bool, device=device)
        caption[:, 0] = start_token
        cap_mask[:, 0] = False

        # Keep only the text before the question mark, lower-cased.
        if text_question.find('?') > -1:
            text_question = text_question.split('?')[0].lower()
        text_question = np.array(Tokenizer.encode_plus(text_question, max_length=77, padding='max_length',
                                                       return_attention_mask=True, return_token_type_ids=False,
                                                       truncation=True)['input_ids'])

        # Greedy autoregressive decoding, one token per step.
        for i in range(config.max_position_embeddings - 1):
            predictions = My_VQA(image.unsqueeze(0).to(device), torch.Tensor(text_question).unsqueeze(0).long().to(device), caption, cap_mask)
            predictions = predictions[:, i, :]
            predicted_id = torch.argmax(predictions, axis=-1)
            caption[:, i + 1] = predicted_id[0]
            cap_mask[:, i + 1] = False
            if predicted_id[0] == end_token:  # stop at <|endoftext|> (id 49407)
                break

        cap_result_intermediate = Tokenizer.decode(caption[0].tolist(), skip_special_tokens=True)
        # Unfilled (zero) positions decode to '!', so split them off; the caller keeps element [0].
        cap_result = cap_result_intermediate.split('!')
        return cap_result


def infer_answer_question(image, text):
    if text is None:
        cap_result = "Please write a question."
    elif image is None:
        cap_result = "Please upload an image."
    else:
        image_encoded = tfms(image)
        cap_result = answer_question(image_encoded, text)[0]

    return cap_result
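
# A minimal sketch of calling the pipeline outside Gradio, assuming the example
# image "train_0000.jpg" listed below is available in the working directory:
#
#   from PIL import Image
#   img = Image.open("train_0000.jpg").convert("RGB")
#   print(infer_answer_question(img, "Where are liver stem cells (oval cells) located?"))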


image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [["train_0000.jpg", "Where are liver stem cells (oval cells) located?"],
            ["train_0001.jpg", "What are stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0002.jpg", "What are bile duct cells and canals of Hering stained here with for cytokeratin 7?"],
            ["train_0003.jpg", "Are bile duct cells and canals of Hering stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0018.jpg", "Is there an infarct in the brain hypertrophy?"],
            ["train_0019.jpg", "What is ischemic coagulative necrosis?"]]

title = "Vision–Language Model for Visual Question Answering in Medical Imagery"
description = "Y Bazi, MMA Rahhal, L Bashmal, M Zuair. <a href='https://www.mdpi.com/2306-5354/10/3/380' target='_blank'>Vision–Language Model for Visual Question Answering in Medical Imagery</a>. Bioengineering, 2023.<br><br>" \
              "Gradio demo of a medical VQA model trained on the PathVQA dataset. To use it, upload an image, type a question and click 'Submit', or click one of the examples to load them."
# Link to the paper and GitHub code
website = ""
article = f"<p style='text-align: center'><a href='{website}' target='_blank'>BigMed@KSU</a></p>"

interface = gr.Interface(fn=infer_answer_question,
                         inputs=[image, question],
                         outputs=answer,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True, enable_queue=True)  # enable_queue applies to older Gradio releases; newer ones use interface.queue().launch(debug=True)