sitammeur commited on
Commit
db50b86
·
verified ·
1 Parent(s): cf7232f

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +81 -0
  2. classifier.py +53 -0
  3. exception.py +50 -0
  4. logger.py +21 -0
  5. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import gradio as gr
3
+ from classifier import MedSigLIPClassifier
4
+ from logger import logging
5
+ from exception import CustomExceptionHandling
6
+
7
+ # Initialize the classifier
8
+ # This might take a moment to download/load the model
9
+ classifier = MedSigLIPClassifier()
10
+
11
+
12
+ def infer(image, candidate_labels):
13
+ """Infer function to predict the probability of the given image and candidate labels."""
14
+ try:
15
+ if not image:
16
+ raise gr.Error("No image uploaded")
17
+
18
+ # Split labels by comma and strip whitespace
19
+ labels = [l.strip() for l in candidate_labels.split(",") if l.strip()]
20
+
21
+ if not labels:
22
+ raise gr.Error("No labels provided")
23
+
24
+ # Call the classifier
25
+ logging.info("Calling the classifier")
26
+ return classifier.predict(image, labels)
27
+
28
+ except Exception as e:
29
+ # Custom exception handling
30
+ raise CustomExceptionHandling(e, sys) from e
31
+
32
+
33
+ # Gradio interface
34
+ with gr.Blocks() as demo:
35
+ with gr.Column():
36
+
37
+ gr.Markdown("# **MedSigLIP Zero-Shot Classification**")
38
+ gr.Markdown(
39
+ "This is a demo of MedSigLIP (448) for zero-shot classification trained on medical images."
40
+ )
41
+
42
+ with gr.Row():
43
+ # Add image input, text input and run button
44
+ with gr.Column():
45
+ image_input = gr.Image(
46
+ type="pil", label="Image", placeholder="Upload an image", height=310
47
+ )
48
+ text_input = gr.Textbox(
49
+ label="Labels",
50
+ placeholder="Enter your input labels here (comma separated)",
51
+ )
52
+ run_button = gr.Button("Run")
53
+ with gr.Column():
54
+ output_label = gr.Label(label="Output", num_top_classes=3)
55
+
56
+ # Add examples
57
+ gr.Examples(
58
+ examples=[
59
+ [
60
+ "./images/sample1.png",
61
+ "a photo of a leg with no rash, a photo of a leg with a rash",
62
+ ],
63
+ [
64
+ "./images/sample2.png",
65
+ "a photo of an arm with no rash, a photo of an arm with a rash",
66
+ ],
67
+ ],
68
+ inputs=[image_input, text_input],
69
+ outputs=[output_label],
70
+ fn=infer,
71
+ cache_examples=True,
72
+ cache_mode="lazy",
73
+ )
74
+
75
+ # Add run button click event
76
+ run_button.click(
77
+ fn=infer, inputs=[image_input, text_input], outputs=[output_label]
78
+ )
79
+
80
+ # Launch the app
81
+ demo.launch(debug=False, theme=gr.themes.Soft())
classifier.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import AutoProcessor, AutoModel
5
+ import tensorflow as tf
6
+
7
+
8
+ class MedSigLIPClassifier:
9
+ """MedSigLIPClassifier class for zero-shot classification of medical images."""
10
+
11
+ def __init__(self, model_id="google/medsiglip-448"):
12
+ """Initialize the classifier with the given model ID."""
13
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ self.model = AutoModel.from_pretrained(model_id).to(self.device)
15
+ self.processor = AutoProcessor.from_pretrained(model_id)
16
+
17
+ def _resize(self, image):
18
+ """Resizes the image using TensorFlow's resize method to match MedSigLIP training preprocessing."""
19
+ return Image.fromarray(
20
+ tf.image.resize(
21
+ images=image, size=[448, 448], method="bilinear", antialias=False
22
+ )
23
+ .numpy()
24
+ .astype(np.uint8)
25
+ )
26
+
27
+ def predict(self, image: Image.Image, candidate_labels: list[str]):
28
+ """Predicts the probabilities for the given image and candidate labels."""
29
+ # Ensure image is RGB
30
+ if image.mode != "RGB":
31
+ image = image.convert("RGB")
32
+
33
+ # Resize image
34
+ resized_image = self._resize(image)
35
+
36
+ # Prepare inputs
37
+ inputs = self.processor(
38
+ text=candidate_labels,
39
+ images=resized_image,
40
+ padding="max_length",
41
+ return_tensors="pt",
42
+ ).to(self.device)
43
+
44
+ # Inference
45
+ with torch.no_grad():
46
+ outputs = self.model(**inputs)
47
+
48
+ logits_per_image = outputs.logits_per_image
49
+ probs = torch.softmax(logits_per_image, dim=1)
50
+
51
+ # Format results
52
+ probs_list = probs[0].tolist()
53
+ return {label: prob for label, prob in zip(candidate_labels, probs_list)}
exception.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module defines a custom exception handling class and a function to get error message with details of the error.
3
+ """
4
+
5
+ # Standard Library
6
+ import sys
7
+
8
+ # Local imports
9
+ from logger import logging
10
+
11
+
12
+ # Function Definition to get error message with details of the error (file name and line number) when an error occurs in the program
13
+ def get_error_message(error, error_detail: sys):
14
+ """
15
+ Get error message with details of the error.
16
+
17
+ Args:
18
+ - error (Exception): The error that occurred.
19
+ - error_detail (sys): The details of the error.
20
+
21
+ Returns:
22
+ str: A string containing the error message along with the file name and line number where the error occurred.
23
+ """
24
+ _, _, exc_tb = error_detail.exc_info()
25
+
26
+ # Get error details
27
+ file_name = exc_tb.tb_frame.f_code.co_filename
28
+ return "Error occured in python script name [{0}] line number [{1}] error message[{2}]".format(
29
+ file_name, exc_tb.tb_lineno, str(error)
30
+ )
31
+
32
+
33
+ # Custom Exception Handling Class Definition
34
+ class CustomExceptionHandling(Exception):
35
+ """
36
+ Custom Exception Handling:
37
+ This class defines a custom exception that can be raised when an error occurs in the program.
38
+ It takes an error message and an error detail as input and returns a formatted error message when the exception is raised.
39
+ """
40
+
41
+ # Constructor
42
+ def __init__(self, error_message, error_detail: sys):
43
+ """Initialize the exception"""
44
+ super().__init__(error_message)
45
+
46
+ self.error_message = get_error_message(error_message, error_detail=error_detail)
47
+
48
+ def __str__(self):
49
+ """String representation of the exception"""
50
+ return self.error_message
logger.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing the required modules
2
+ import os
3
+ import logging
4
+ from datetime import datetime
5
+
6
+ # Creating a log file with the current date and time as the name of the file
7
+ LOG_FILE = f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
8
+
9
+ # Creating a logs folder if it does not exist
10
+ logs_path = os.path.join(os.getcwd(), "logs")
11
+ os.makedirs(logs_path, exist_ok=True)
12
+
13
+ # Setting the log file path and the log level
14
+ LOG_FILE_PATH = os.path.join(logs_path, LOG_FILE)
15
+
16
+ # Configuring the logger
17
+ logging.basicConfig(
18
+ filename=LOG_FILE_PATH,
19
+ format="[ %(asctime)s ] %(lineno)d %(name)s - %(levelname)s - %(message)s",
20
+ level=logging.INFO,
21
+ )
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ pillow
4
+ numpy
5
+ requests
6
+ tensorflow
7
+ gradio
8
+ sentencepiece
9
+ protobuf