File size: 5,547 Bytes
8889d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import torch
from PIL import Image
from datasets import load_dataset
import random

from skincancer_vit.model import SkinCancerViTModel

# Repository identifier passed to `SkinCancerViTModel.from_pretrained` below.
HF_MODEL_REPO = "ethicalabs/SkinCancerViT"
# Run on CUDA when torch reports it as available; otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# The model and dataset are loaded once, at import time, and are then read as
# module-level globals by the prediction handlers and the UI definition below.
print(f"Loading SkinCancerViT model from {HF_MODEL_REPO} to {DEVICE}...")

model = SkinCancerViTModel.from_pretrained(HF_MODEL_REPO)
model.to(DEVICE)
model.eval()  # Set to evaluation mode
print("Model loaded successfully.")

# Only the "test" split is loaded; it is the pool random samples are drawn from.
print("Loading 'marmal88/skin_cancer' dataset for random samples...")
dataset = load_dataset("marmal88/skin_cancer", split="test")
print("Dataset loaded successfully.")


def predict_uploaded_image(image: Image.Image, age: int, localization: str) -> str:
    """
    Run the model on a user-supplied image plus its accompanying metadata.

    Args:
        image: The uploaded lesion image, or ``None`` if nothing was uploaded.
        age: The patient age entered in the form, or ``None`` if left blank.
        localization: The selected lesion localization; empty means unselected.

    Returns:
        A Markdown string: the predicted diagnosis with its confidence, a
        prompt asking for the missing input, or an error message.
    """
    # Validate inputs in a fixed order; the first failing check decides the message.
    if model is None:
        problem = "Error: Model not loaded. Please check the console for details."
    elif image is None:
        problem = "Please upload an image."
    elif age is None:
        problem = "Please enter an age."
    elif not localization:
        problem = "Please select a localization."
    else:
        problem = None

    if problem is not None:
        return problem

    try:
        # Delegate the whole pipeline to the model's full_predict method.
        diagnosis, score = model.full_predict(
            raw_image=image,
            raw_age=age,
            raw_localization=localization,
            device=DEVICE,
        )
        # Formatting stays inside the try so a malformed result is reported
        # to the user the same way as a failed prediction.
        message = f"Predicted Diagnosis: **{diagnosis}** (Confidence: {score:.4f})"
    except Exception as err:
        message = f"Prediction Error: {err}"
    return message


# --- Prediction Function for Random Sample ---
def predict_random_sample() -> "tuple[Image.Image | None, str]":
    """
    Pick a random sample from the loaded dataset and run the model on it.

    Returns:
        An ``(image, markdown)`` pair, one value for each of the two output
        components this handler is wired to. ``image`` is the sample's
        ``"image"`` field, or ``None`` when no prediction could be made.
        ``markdown`` holds the sample details together with the model's
        prediction, or an error message.
    """
    # Every exit path must produce two values: the click handler feeds both an
    # image component and a markdown component, so a bare string does not fit.
    if model is None:
        return None, "Error: Model not loaded. Please check the console for details."
    if dataset is None:
        return None, "Error: Dataset not loaded. Cannot select random sample."

    try:
        # Select a random sample from the dataset. An empty dataset raises
        # ValueError here, which is reported through the handler below.
        sample = dataset[random.randrange(len(dataset))]

        sample_image = sample["image"]
        sample_age = sample["age"]
        sample_localization = sample["localization"]
        sample_true_dx = sample["dx"]

        # Call the model's full_predict method
        predicted_dx, confidence = model.full_predict(
            raw_image=sample_image,
            raw_age=sample_age,
            raw_localization=sample_localization,
            device=DEVICE,
        )

        # Build the Markdown report: ground truth first, then the prediction.
        result_str = (
            f"**Random Sample Details:**\n"
            f"- Age: {sample_age}\n"
            f"- Localization: {sample_localization}\n"
            f"- True Diagnosis: **{sample_true_dx}**\n\n"
            f"**Model Prediction:**\n"
            f"- Predicted Diagnosis: **{predicted_dx}**\n"
            f"- Confidence: {confidence:.4f}\n"
            f"- Correct Prediction: {'✅ Yes' if predicted_dx == sample_true_dx else '❌ No'}"
        )
        return sample_image, result_str
    except Exception as e:
        # UI boundary: surface the failure in the markdown pane instead of crashing.
        return None, f"Prediction Error on Random Sample: {e}"


# --- Gradio Interface ---
with gr.Blocks(title="Skin Cancer ViT Predictor") as demo:
    # Page header: short description of the app plus the medical disclaimer.
    gr.Markdown(
        """
        # Skin Cancer ViT Predictor
        This application demonstrates the `SkinCancerViT` multimodal model for skin cancer diagnosis.
        It can take an uploaded image with patient metadata or predict on a random sample from the dataset.
        **Disclaimer:** This tool is for demonstration and research purposes only and should not be used for medical diagnosis.
        """
    )

    # Tab 1: predict on a sample drawn at random from the loaded dataset.
    with gr.Tab("Predict on Random Sample"):
        gr.Markdown("## Get a Prediction from a Random Sample in the Test Set")
        random_sample_button = gr.Button("Get Random Sample Prediction")

        # The sampled image and its textual report are shown side by side.
        with gr.Row():
            output_random_image = gr.Image(
                type="pil", label="Random Sample Image", height=250, width=250
            )
            output_random_details = gr.Markdown(
                "Random sample details and prediction will appear here."
            )

        # The handler takes no inputs and returns an (image, markdown) pair,
        # matched positionally to the two output components.
        random_sample_button.click(
            fn=predict_random_sample,
            inputs=[],
            outputs=[
                output_random_image,
                output_random_details,
            ],
        )

    # Tab 2: predict on a user-uploaded image with user-entered metadata.
    with gr.Tab("Upload Image & Predict"):
        gr.Markdown("## Upload Your Image and Provide Patient Data")
        with gr.Row():
            image_input = gr.Image(
                type="pil", label="Upload Skin Lesion Image (224x224 preferred)"
            )
            with gr.Column():
                age_input = gr.Number(
                    label="Patient Age", minimum=0, maximum=120, step=1
                )
                # Choices come from the model config's localization mapping.
                # The keys view is materialised into a list, the documented
                # type for Dropdown choices.
                localization_input = gr.Dropdown(
                    choices=list(model.config.localization_to_id.keys()),
                    label="Lesion Localization",
                    value="unknown",  # Default value
                )
                predict_button = gr.Button("Get Prediction")

        output_upload = gr.Markdown("Prediction will appear here.")

        # Inputs are passed positionally as (image, age, localization).
        predict_button.click(
            fn=predict_uploaded_image,
            inputs=[image_input, age_input, localization_input],
            outputs=output_upload,
        )

if __name__ == "__main__":
    demo.launch(share=False)