# VQA-in-Medical-Imagery / MED_VQA_Huggyface_Gradio.py
# bigmed@bigmed
# Commit 3b55855: "update website line to empty waiting for website deployment"
# (Hugging Face file-page residue: raw | history blame contribute delete | No virus | 8.3 kB)
##### VQA MED Demo
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from CLIP import clip
from Transformers_for_Caption import Transformer_Caption
import numpy as np
import torchvision.transforms as transforms
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
class Config:
    """Model and decoding settings.

    An instance is passed to ``Transformer_Caption`` when the VQA model is
    built, and ``max_position_embeddings`` also bounds the generated answer
    length.
    """

    def __init__(self):
        # Transformer shape
        self.hidden_dim = 512
        self.nheads = 4
        self.dim_feedforward = 1024  # an alternative value noted in the code was 2048
        self.enc_layers = 1
        self.dec_layers = 1
        self.pre_norm = True

        # Regularisation / numerics
        self.dropout = 0.1
        self.layer_norm_eps = 1e-12

        # Vocabulary and sequence length
        self.vocab_size = 49408
        self.pad_token_id = 0
        self.max_position_embeddings = 76

        # Dataset (-1 presumably means "no limit" — TODO confirm)
        self.limit = -1
##### OUR MODEL
class VQA_Net(nn.Module):
    """CLIP ViT-B/32 based visual question answering network.

    The CLIP backbone encodes both the image and the tokenised question. The
    two feature tensors are concatenated along dim 1, layer-normalised, and
    passed to a ``Transformer_Caption`` decoder together with the partial
    answer. A linear head maps each decoder output to ``num_classes`` scores.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE: the attribute names below (and the doubly nested Sequential in
        # `mlp`) determine the state_dict keys, so they must stay unchanged for
        # the saved checkpoint to load.
        self.backbone, _ = clip.load('ViT-B/32', device, jit=False)
        self.input_proj = nn.LayerNorm(512)
        self.transformer_decoder = Transformer_Caption(config, num_decoder_layers=2)
        self.mlp = nn.Sequential(nn.Sequential(nn.Linear(512, num_classes)))
        # Identity placeholders kept from experiments with learned projections.
        self.samples_proj = nn.Identity()
        self.question_proj = nn.Identity()

    def forward(self, samples, question_in, answer_out, mask_answer):
        """Return per-position output scores for the partial answer ``answer_out``.

        ``samples`` is the image batch, ``question_in`` the question token ids,
        and ``mask_answer`` is forwarded to the decoder as ``tgt_mask``.
        Assumes the encoders return (batch, seq, 512) tensors — TODO confirm.
        """
        # The project's CLIP encoders return three values; only the last is used.
        _, _, image_feats = self.backbone.encode_image(samples)
        image_feats = self.samples_proj(image_feats.float())

        _, _, text_feats = self.backbone.encode_text(question_in)
        text_feats = self.question_proj(text_feats.float())

        # Join image and question features, then swap to (seq, batch, dim).
        fused = torch.cat((image_feats, text_feats), dim=1).permute(1, 0, 2)
        decoded = self.transformer_decoder(
            self.input_proj(fused.float()), answer_out, tgt_mask=mask_answer
        )
        # Back to (batch, seq, dim) before the classification head.
        return self.mlp(decoded.permute(1, 0, 2))
# --- Model, tokenizer and preprocessing (built once, at import time) ---
config = Config()
Tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
My_VQA = VQA_Net(num_classes=len(Tokenizer))
My_VQA.load_state_dict(
    torch.load(
        "./PathVQA_2Decoders_1024_30iterations_Trial4_CLIPVIT32.pth.tar",
        map_location=torch.device(device),
    )
)
# clip.load() puts only the backbone on `device`. The other layers built in
# VQA_Net.__init__ (LayerNorm, decoder, MLP head) get no device argument and
# start on PyTorch's default device (CPU), so move the whole model to one device.
My_VQA.to(device)
# nn.Module defaults to training mode; switch to eval mode so training-only
# behaviour (e.g. dropout, config.dropout = 0.1) is off during inference.
My_VQA.eval()

# Image preprocessing: resize to 224x224, convert to a tensor and normalise
# with the ImageNet per-channel mean/std.
tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def answer_question(image, text_question):
    """Greedily decode an answer to ``text_question`` about ``image``.

    Args:
        image: a single preprocessed image tensor (output of ``tfms``, no
            batch dimension).
        text_question: the question as free text. If it contains '?', only the
            text before the first '?' is used (lower-cased).

    Returns:
        The decoded text split on '!'. Element 0 is the answer; see the
        comment at the end for why the split is needed.
    """
    start_token = Tokenizer.convert_tokens_to_ids("<|startoftext|>")
    end_token = Tokenizer.convert_tokens_to_ids("<|endoftext|>")
    max_len = config.max_position_embeddings

    # Answer buffer: position 0 holds the start token; later positions are
    # filled one by one. cap_mask is cleared for every filled position.
    caption = torch.zeros((1, max_len), dtype=torch.long, device=device)
    cap_mask = torch.ones((1, max_len), dtype=torch.bool, device=device)
    caption[:, 0] = start_token
    cap_mask[:, 0] = False

    if "?" in text_question:
        text_question = text_question.split("?")[0].lower()
    # 77 tokens = CLIP's text context length. padding="max_length" replaces the
    # deprecated pad_to_max_length=True.
    question_ids = Tokenizer(
        text_question,
        max_length=77,
        padding="max_length",
        truncation=True,
    )["input_ids"]

    # Loop-invariant inputs: build them once instead of every decoding step.
    question = torch.tensor(question_ids, dtype=torch.long, device=device).unsqueeze(0)
    batch_image = image.unsqueeze(0).to(device)

    with torch.no_grad():
        for i in range(max_len - 1):
            predictions = My_VQA(batch_image, question, caption, cap_mask)
            # The output at position i predicts the token at position i + 1.
            next_id = int(torch.argmax(predictions[:, i, :], dim=-1)[0])
            caption[:, i + 1] = next_id
            cap_mask[:, i + 1] = False
            if next_id == end_token:
                break

    decoded = Tokenizer.decode(caption[0].tolist(), skip_special_tokens=True)
    # Unfilled positions keep token id 0, which is '!' in CLIP's vocabulary.
    # Splitting on '!' drops that trailing filler.
    return decoded.split("!")
def infer_answer_question(image, text):
    """Gradio callback: answer the question ``text`` about the PIL ``image``.

    Returns a prompt message instead of an answer when the question is
    missing or blank, or when no image was provided. The question is checked
    first.
    """
    # Gradio passes "" (not None) for an empty textbox, so check for blank text too.
    if text is None or not text.strip():
        return "please write a question"
    if image is None:
        return "please upload an image"
    # tfms normalises exactly three channels. Convert defensively so grayscale,
    # RGBA or palette images from non-Gradio callers do not break Normalize.
    image_encoded = tfms(image.convert("RGB"))
    return answer_question(image_encoded, text)[0]
# --- Gradio user interface ---
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")

# (image file, question) pairs offered as clickable examples.
examples = [["train_0000.jpg", "Where are liver stem cells (oval cells) located?"],
            ["train_0001.jpg", "What are stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0002.jpg", "What are bile duct cells and canals of Hering stained here with for cytokeratin 7?"],
            ["train_0003.jpg", "Are bile duct cells and canals of Hering stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0018.jpg", "Is there an infarct in the brain hypertrophy?"],
            ["train_0019.jpg", "What is ischemic coagulative necrosis?"]]

title = "Vision–Language Model for Visual Question Answering in Medical Imagery"
# Parenthesised implicit concatenation instead of a trailing-backslash
# continuation that ran into a comment line.
description = (
    "Y Bazi, MMA Rahhal, L Bashmal, M Zuair. <a href='https://www.mdpi.com/2306-5354/10/3/380' target='_blank'> Vision–Language Model for Visual Question Answering in Medical Imagery</a>. Bioengineering, 2023<br><br>"
    "Gradio Demo for VQA medical model trained on PathVQA dataset, To use it, upload your image and type a question and click 'submit', or click one of the examples to load them."
)

# Link to the paper / GitHub code; left empty until the website is deployed.
website = ""
article = f"<p style='text-align: center'><a href='{website}' target='_blank'>BigMed@KSU</a></p>"

interface = gr.Interface(fn=infer_answer_question,
                         inputs=[image, question],
                         outputs=answer,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
# launch(enable_queue=...) is deprecated (and removed in later Gradio
# releases). Request queuing is enabled through queue() instead.
interface.queue().launch(debug=True)