File size: 5,547 Bytes
8889d9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import gradio as gr
import torch
from PIL import Image
from datasets import load_dataset
import random
from skincancer_vit.model import SkinCancerViTModel
HF_MODEL_REPO = "ethicalabs/SkinCancerViT"
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_model(repo_id, device):
    """Download the SkinCancerViT weights, move them to *device*, set eval mode."""
    print(f"Loading SkinCancerViT model from {repo_id} to {device}...")
    loaded = SkinCancerViTModel.from_pretrained(repo_id)
    loaded.to(device)
    # Inference only: switch the module out of training mode.
    loaded.eval()
    print("Model loaded successfully.")
    return loaded


def _load_test_split():
    """Fetch the test split used by the "random sample" tab."""
    print("Loading 'marmal88/skin_cancer' dataset for random samples...")
    split = load_dataset("marmal88/skin_cancer", split="test")
    print("Dataset loaded successfully.")
    return split


# Both resources are loaded once, at import time, and shared by the handlers.
model = _load_model(HF_MODEL_REPO, DEVICE)
dataset = _load_test_split()
def predict_uploaded_image(image: Image.Image, age: int, localization: str) -> str:
    """Run the model on a user-uploaded lesion image plus patient metadata.

    Args:
        image: The uploaded PIL image (``None`` when nothing was uploaded).
        age: Patient age from the number input (``None`` when left empty).
        localization: Lesion site chosen from the dropdown.

    Returns:
        A Markdown string: the prediction with its confidence, a message naming
        the first missing input, or the text of any error raised by the model.
    """
    # Preconditions, checked in order; the first failing one is reported.
    preconditions = (
        (model is None, "Error: Model not loaded. Please check the console for details."),
        (image is None, "Please upload an image."),
        (age is None, "Please enter an age."),
        (not localization, "Please select a localization."),
    )
    for failed, message in preconditions:
        if failed:
            return message

    # Formatting stays inside the try so any failure is shown in the UI
    # rather than crashing the request.
    try:
        dx, score = model.full_predict(
            raw_image=image,
            raw_age=age,
            raw_localization=localization,
            device=DEVICE,
        )
        return f"Predicted Diagnosis: **{dx}** (Confidence: {score:.4f})"
    except Exception as err:
        return f"Prediction Error: {err}"
# --- Prediction Function for Random Sample ---
def predict_random_sample() -> "tuple[Image.Image | None, str]":
    """Pick a random test-set sample, run the model on it and report the result.

    Returns:
        An ``(image, markdown)`` pair, matching the two outputs (image and
        markdown components) this handler is wired to. Every path returns a
        pair: on failure the image slot is ``None`` and the markdown slot holds
        the error message. A bare string would not match the two outputs.
    """
    if model is None:
        return None, "Error: Model not loaded. Please check the console for details."
    if dataset is None:
        return None, "Error: Dataset not loaded. Cannot select random sample."
    if len(dataset) == 0:
        # Otherwise random.randrange(0) raises an opaque "empty range" error.
        return None, "Error: Dataset is empty. Cannot select random sample."
    try:
        # Select a random sample from the dataset.
        sample = dataset[random.randrange(len(dataset))]
        sample_image = sample["image"]
        sample_age = sample["age"]
        sample_localization = sample["localization"]
        sample_true_dx = sample["dx"]

        predicted_dx, confidence = model.full_predict(
            raw_image=sample_image,
            raw_age=sample_age,
            raw_localization=sample_localization,
            device=DEVICE,
        )

        # Markdown report shown next to the sample image.
        result_str = (
            f"**Random Sample Details:**\n"
            f"- Age: {sample_age}\n"
            f"- Localization: {sample_localization}\n"
            f"- True Diagnosis: **{sample_true_dx}**\n\n"
            f"**Model Prediction:**\n"
            f"- Predicted Diagnosis: **{predicted_dx}**\n"
            f"- Confidence: {confidence:.4f}\n"
            f"- Correct Prediction: {'✅ Yes' if predicted_dx == sample_true_dx else '❌ No'}"
        )
        return sample_image, result_str
    except Exception as e:
        # UI boundary: surface the failure instead of crashing the request.
        return None, f"Prediction Error on Random Sample: {e}"
# --- Gradio Interface ---
# The dropdown offers exactly the localization categories in the model's own
# config mapping. Materialize the keys as a plain list, since Gradio's Dropdown
# documents `choices` as a list. Use "unknown" as the default only if the model
# actually knows that category; otherwise start with no selection.
LOCALIZATION_CHOICES = list(model.config.localization_to_id.keys())
DEFAULT_LOCALIZATION = "unknown" if "unknown" in LOCALIZATION_CHOICES else None

with gr.Blocks(title="Skin Cancer ViT Predictor") as demo:
    gr.Markdown(
        """
# Skin Cancer ViT Predictor
This application demonstrates the `SkinCancerViT` multimodal model for skin cancer diagnosis.
It can take an uploaded image with patient metadata or predict on a random sample from the dataset.
**Disclaimer:** This tool is for demonstration and research purposes only and should not be used for medical diagnosis.
"""
    )

    with gr.Tab("Predict on Random Sample"):
        gr.Markdown("## Get a Prediction from a Random Sample in the Test Set")
        random_sample_button = gr.Button("Get Random Sample Prediction")
        # Side-by-side: the sampled image and its details/prediction.
        with gr.Row():
            output_random_image = gr.Image(
                type="pil", label="Random Sample Image", height=250, width=250
            )
            output_random_details = gr.Markdown(
                "Random sample details and prediction will appear here."
            )
        # predict_random_sample returns an (image, markdown) pair.
        random_sample_button.click(
            fn=predict_random_sample,
            inputs=[],
            outputs=[output_random_image, output_random_details],
        )

    with gr.Tab("Upload Image & Predict"):
        gr.Markdown("## Upload Your Image and Provide Patient Data")
        with gr.Row():
            image_input = gr.Image(
                type="pil", label="Upload Skin Lesion Image (224x224 preferred)"
            )
            with gr.Column():
                age_input = gr.Number(
                    label="Patient Age", minimum=0, maximum=120, step=1
                )
                localization_input = gr.Dropdown(
                    LOCALIZATION_CHOICES,
                    label="Lesion Localization",
                    value=DEFAULT_LOCALIZATION,
                )
                predict_button = gr.Button("Get Prediction")
        output_upload = gr.Markdown("Prediction will appear here.")
        predict_button.click(
            fn=predict_uploaded_image,
            inputs=[image_input, age_input, localization_input],
            outputs=output_upload,
        )
def _launch_app() -> None:
    """Serve the Gradio demo locally, without creating a public share link."""
    demo.launch(share=False)


if __name__ == "__main__":
    _launch_app()
|