File size: 7,457 Bytes
37b3db0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135

import gradio as gr
# Switch path to root of project
import os
import sys
# Locate the ``src`` folder underneath the directory the script was launched from.
current_dir = os.getcwd()
src_path = os.path.join(current_dir, 'src')
# Make ``src`` importable, then switch the working directory to it so the
# relative paths used further down (e.g. ``../docs/...``) resolve from there.
sys.path.append(src_path)
os.chdir(src_path)
from open_clip import create_model_and_transforms
from huggingface_hub import hf_hub_download
from open_clip import HFTokenizer
import torch

class create_unimed_clip_model:
    """Zero-shot image classifier backed by a UniMed-CLIP checkpoint.

    Constructing an instance downloads the checkpoint that belongs to
    ``model_name`` from the Hugging Face Hub and builds the tokenizer, the
    model and the image preprocessing transform on the CPU. Calling the
    instance scores one image against a list of text labels.
    """

    # Public model name -> (open_clip architecture, Hub repo id, weights file).
    _CHECKPOINTS = {
        "ViT/B-16": (
            "ViT-B-16-quickgelu",
            "UzairK/unimed-clip-vit-b16",
            "unimed-clip-vit-b16.pt",
        ),
        "ViT/L-14@336px-base-text": (
            "ViT-L-14-336-quickgelu",
            "UzairK/unimed_clip_vit_l14_base_text_encoder",
            "unimed_clip_vit_l14_base_text_encoder.pt",
        ),
    }
    # Both supported checkpoints use the same text encoder.
    _TEXT_ENCODER_NAME = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
    _CONTEXT_LENGTH = 256

    def __init__(self, model_name):
        """Download the weights for ``model_name`` and build the model.

        Args:
            model_name: One of the keys of ``_CHECKPOINTS``
                (``"ViT/B-16"`` or ``"ViT/L-14@336px-base-text"``).

        Raises:
            ValueError: If ``model_name`` is not a supported model.
        """
        # Validate first so a typo fails fast, before any download starts,
        # instead of surfacing later as a missing-attribute error.
        try:
            architecture, repo_id, filename = self._CHECKPOINTS[model_name]
        except KeyError:
            supported = ", ".join(repr(name) for name in self._CHECKPOINTS)
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of: {supported}"
            ) from None

        # Inference is pinned to the CPU.
        self.device = 'cpu'
        mean = (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
        std = (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std

        self.model_name = architecture
        self.text_encoder_name = self._TEXT_ENCODER_NAME
        # Local path of the downloaded pretrained weights.
        self.pretrained = hf_hub_download(repo_id=repo_id, filename=filename)

        self.tokenizer = HFTokenizer(
            self.text_encoder_name,
            context_length=self._CONTEXT_LENGTH,
        )
        self.model, _, self.processor = create_model_and_transforms(
            self.model_name,
            self.pretrained,
            precision='amp',
            device=self.device,
            force_quick_gelu=True,
            pretrained_image=False,
            mean=mean, std=std,
            inmem=True,
            text_encoder_name=self.text_encoder_name,
        )

    def __call__(self, input_image, candidate_labels, hypothesis_template):
        """Score ``input_image`` against every entry of ``candidate_labels``.

        Args:
            input_image: Image in whatever form ``self.processor`` accepts.
            candidate_labels: Text labels to rank. Duplicate labels collapse
                into a single key of the returned dict.
            hypothesis_template: Optional prefix; when non-empty it is put in
                front of each label, separated by a single space.

        Returns:
            dict mapping each label to a float score. Scores come from a
            softmax over the labels, so they sum to 1.

        Raises:
            ValueError: If ``candidate_labels`` is empty.
        """
        # Materialise once: the labels are iterated twice below.
        labels = list(candidate_labels)
        if not labels:
            raise ValueError("candidate_labels must contain at least one label")

        pixels = self.processor(input_image).unsqueeze(0).to(self.device)

        prefix = hypothesis_template + " " if hypothesis_template else ""
        tokens = torch.cat(
            [self.tokenizer(prefix + label).to(self.device) for label in labels],
            dim=0,
        )

        # Perform inference
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            image_features = self.model.encode_image(pixels)
            # NOTE(review): only the text features are L2-normalised here and no
            # logit scale is applied; kept as-is to preserve the demo's scores,
            # confirm this is intended.
            probs = (image_features @ text_features.t()).softmax(dim=-1).cpu().numpy()
        return {label: float(score) for label, score in zip(labels, probs[0])}

# One eagerly-built classifier per selectable model, keyed by the name shown in the UI.
pipes = {
    name: create_unimed_clip_model(model_name=name)
    for name in ("ViT/B-16", "ViT/L-14@336px-base-text")
}
# Gradio input widgets. They are handed to the callback positionally, so the
# order here is significant: image, label text, model choice, prompt template.
inputs = [
    gr.Image(label="Image", type="pil"),
    gr.Textbox(label="Candidate Labels (comma-separated)"),
    gr.Radio(
        label="Model",
        choices=["ViT/B-16", "ViT/L-14@336px-base-text"],
        value="ViT/B-16",
    ),
    gr.Textbox(
        label="Prompt Template",
        value="",
        placeholder="Optional prompt template as prefix",
    ),
]
# The per-label scores are rendered through a Label component.
outputs = gr.Label(label="Predicted Scores")

def shot(image, labels_text, model_name, hypothesis_template):
    """Gradio callback: classify ``image`` against comma-separated labels.

    Args:
        image: Image supplied by the image input, or ``None`` if none was given.
        labels_text: Candidate labels separated by commas.
        model_name: Key into ``pipes`` selecting which model to run.
        hypothesis_template: Optional prefix put in front of every label.

    Returns:
        dict mapping each label to its score, as returned by the selected model.

    Raises:
        gr.Error: If no image was provided or no non-empty label remains.
    """
    if image is None:
        raise gr.Error("Please provide an image.")

    # Drop blank entries so a trailing or doubled comma does not add an empty
    # label that would then be scored like a real class.
    labels = [label.strip() for label in (labels_text or "").split(",")]
    labels = [label for label in labels if label]
    if not labels:
        raise gr.Error("Please provide at least one candidate label.")

    return pipes[model_name](
        input_image=image,
        candidate_labels=labels,
        hypothesis_template=hypothesis_template,
    )
# Define examples

# Every example row offers the same six candidate captions. They are joined
# with ", " because the label text box is split on commas by the callback.
_EXAMPLE_LABELS = ", ".join([
    "CT scan image displaying the anatomical structure of the right kidney.",
    "pneumonia is indicated in this chest X-ray image.",
    "this is a MRI photo of a brain.",
    "this fundus image shows optic nerve damage due to glaucoma.",
    "a histopathology slide showing Tumor",
    "Cardiomegaly is evident in the X-ray image of the chest.",
])

# Sample images, relative to the working directory.
_SAMPLE_IMAGE_DIR = "../docs/sample_images"
_SAMPLE_IMAGE_FILES = (
    "brain_MRI.jpg",
    "ct_scan_right_kidney.jpg",
    "retina_glaucoma.jpg",
    "tumor_histo_pathology.jpg",
    "xray_cardiomegaly.jpg",
    "xray_pneumonia.png",
)

# One row per sample image: [image path, label text, model name, prompt template].
examples = [
    [_SAMPLE_IMAGE_DIR + "/" + file_name, _EXAMPLE_LABELS, "ViT/B-16", ""]
    for file_name in _SAMPLE_IMAGE_FILES
]

# Assemble the demo UI: ``shot`` is the callback, wired to the input and
# output components and the example rows defined above in this file.
iface = gr.Interface(
    fn=shot,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    # HTML description; the line breaks and leading spaces are part of the text.
    description=(
        "<p>Demo for UniMed CLIP, a family of strong Medical Contrastive VLMs trained on UniMed-dataset. For more information about our project, refer to our paper and github repository. <br>\n"
        "            Paper: <a href='https://arxiv.org/abs/2412.10372'>https://arxiv.org/abs/2412.10372</a> <br>\n"
        "            Github: <a href='https://github.com/mbzuai-oryx/UniMed-CLIP'>https://github.com/mbzuai-oryx/UniMed-CLIP</a> <br><br>\n"
        "            <b>[DEMO USAGE]</b> To begin with the demo, provide a picture (either upload manually, or select from the given examples) and class labels. Optionally you can also add template as an prefix to the class labels. <br> </p>"
    ),
    title="Zero-shot Medical Image Classification with UniMed-CLIP",
)

# Start serving the interface.
iface.launch()