KublaiKhan1's picture
Update app.py
6bce5a8 verified
raw
history blame
9.77 kB
import torch
import gradio as gr
from torchvision.transforms import v2 as transforms
from PIL import Image
import numpy as np
import cv2
from torchvision.transforms.v2 import functional
# Preprocessing constants shared by every classifier tab.
RESIZE_DIM = 224
# Per-channel mean / std used for normalization (the values conventionally
# used with ImageNet-pretrained backbones).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Each table maps a classifier output index to its display label; the list
# position is the index. BreakHis classes: ["TA", "MC", "F", "DC"].
BREAKHIS_LABELS = dict(enumerate([
    "Tubular Adenoma (TA) - Benign",
    "Mucinous Carcinoma (MC) - Malignant",
    "Fibroadenoma (F) - Benign",
    "Ductal Carcinoma (DC) - Malignant",
]))

GLEASON_LABELS = dict(enumerate([
    "Benign",
    "Gleason 3",
    "Gleason 4",
    "Gleason 5",
]))

BACH_LABELS = dict(enumerate([
    "Benign",
    "InSitu",
    "Invasive",
    "Normal",
]))

CRC_LABELS = dict(enumerate([
    "ADI",
    "BACK",
    "DEB",
    "LYM",
    "MUC",
    "MUS",
    "NORM",
    "STR",
    "TUM",
]))
# Build the DINOv2 ViT-g/14 (with registers) architecture from torch.hub; the
# weights it ships with are overwritten by the pathology checkpoint below.
print("Loading DinoV2 base model...")
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')

print("Loading custom pathology checkpoint...")
# map_location="cpu" lets the checkpoint deserialize on hosts that lack the
# device it was saved from; the model is moved to the target device later.
checkpoint = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/SophontAI/OpenMidnight/resolve/main/teacher_checkpoint_load.pt",
    map_location="cpu",
)
# Replace the positional-embedding parameter with the checkpoint's tensor
# *before* load_state_dict, so that parameter has the checkpoint's shape.
# NOTE(review): presumably the checkpoint's pos_embed size differs from the
# hub model's default - confirm.
dinov2.pos_embed = torch.nn.parameter.Parameter(checkpoint["pos_embed"])
dinov2.load_state_dict(checkpoint)
dinov2.eval()
def setup_linear(path, num_classes=4, embed_dim=1536):
    """Load a trained linear classification head from a checkpoint file.

    Args:
        path: Checkpoint file readable by ``torch.load``. It must contain a
            ``"state_dict"`` mapping with ``"head.weight"`` and
            ``"head.bias"`` entries.
        num_classes: Number of outputs of the head (defaults to 4).
        embed_dim: Size of the feature vector fed to the head.

    Returns:
        A ``torch.nn.Linear(embed_dim, num_classes)`` in eval mode, on the
        CPU, holding the checkpoint's weight and bias.

    Raises:
        KeyError: If the checkpoint lacks the expected entries.
        RuntimeError: If the stored tensors do not have the requested shape.
    """
    print(f"Loading {path} linear classifier...")
    # Deserialize onto the CPU so a checkpoint written from a GPU also loads
    # on CPU-only hosts; callers move the layer to their device afterwards.
    head_state = torch.load(path, map_location="cpu")["state_dict"]

    linear = torch.nn.Linear(embed_dim, num_classes)
    # load_state_dict (unlike assigning ``.data``) verifies tensor shapes and
    # casts to the layer's dtype, so a mismatched checkpoint fails here at
    # start-up rather than during the first prediction.
    linear.load_state_dict({
        "weight": head_state["head.weight"],
        "bias": head_state["head.bias"],
    })
    linear.eval()
    return linear


def setup_linear_crc(path):
    """Load the 9-class colorectal (CRC) head; see ``setup_linear``."""
    return setup_linear(path, num_classes=9)
# Prefer CUDA when it is usable; otherwise everything stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dinov2 = dinov2.to(device)

# Checkpoint files of the per-dataset linear heads.
breakhis_path = "./breakhis_best.ckpt"
gleason_path = "./gleason_best.ckpt"
bach_path = "./bach_best.ckpt"
crc_path = "./crc_best.ckpt"

# The three 4-class heads are loaded in order and placed on the chosen
# device, followed by the 9-class colorectal head.
breakhis_linear, gleason_linear, bach_linear = (
    setup_linear(ckpt_path).to(device)
    for ckpt_path in (breakhis_path, gleason_path, bach_path)
)
crc_linear = setup_linear_crc(crc_path).to(device)

print(f"Models loaded on {device}")
# Preprocessing applied to every image before it reaches the backbone:
# resize, center-crop to RESIZE_DIM, convert to float32 with value scaling,
# then normalize with the configured channel statistics.
_preprocessing_steps = [
    transforms.Resize(size=RESIZE_DIM),
    transforms.CenterCrop(size=RESIZE_DIM),
    transforms.ToDtype(dtype=torch.float32, scale=True),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
model_transforms = transforms.Compose(_preprocessing_steps)
def cv_path(path, flags=None):
    """Read an image file with OpenCV and convert it to a torchvision image.

    Args:
        path: Filesystem path of the image to read.
        flags: OpenCV ``imread`` flags. ``None`` (the default) means
            ``cv2.IMREAD_COLOR``.

    Returns:
        The result of ``functional.to_image`` applied to the uint8 pixel
        array; 3-D arrays are converted from OpenCV's BGR order to RGB first.

    Raises:
        ValueError: If ``path`` is empty/``None`` or the file cannot be read
            as an image.
    """
    if not path:
        raise ValueError("No image was provided.")
    if flags is None:
        flags = cv2.IMREAD_COLOR

    image = cv2.imread(path, flags=flags)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {path!r}")

    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if image.ndim == 2 and flags == cv2.IMREAD_COLOR:
        # Give a single-channel result an explicit channel axis.
        image = image[:, :, np.newaxis]
    image = np.asarray(image, dtype=np.uint8)
    return functional.to_image(image)
def predict_breakhis(image):
    """Classify ``image`` with the BreakHis head; see ``predict_class``."""
    return predict_class(image, linear=breakhis_linear, dataset="breakhis")


def predict_gleason(image):
    """Classify ``image`` with the Gleason head; see ``predict_class``."""
    return predict_class(image, linear=gleason_linear, dataset="gleason")


def predict_bach(image):
    """Classify ``image`` with the BACH head; see ``predict_class``."""
    return predict_class(image, linear=bach_linear, dataset="bach")


def predict_crc(image):
    """Classify ``image`` with the CRC head; see ``predict_class``."""
    return predict_class(image, linear=crc_linear, dataset="crc")
def predict_class(image, linear, dataset):
    """Classify an image file with the backbone and one linear head.

    Args:
        image: Path of the image file to classify (it is read with
            ``cv_path``).
        linear: Linear head applied to the backbone embedding.
        dataset: Which label table names the outputs: ``"breakhis"``,
            ``"gleason"``, ``"bach"`` or ``"crc"``.

    Returns:
        dict: Maps each class label to its softmax probability (a float).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    label_tables = {
        "breakhis": BREAKHIS_LABELS,
        "gleason": GLEASON_LABELS,
        "bach": BACH_LABELS,
        "crc": CRC_LABELS,
    }
    # Reject an unknown dataset before spending time on the forward pass
    # (previously this produced an empty result without any error).
    if dataset not in label_tables:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(label_tables)}"
        )
    labels = label_tables[dataset]

    image = cv_path(image)
    # Preprocess and add the batch dimension the model expects.
    image_tensor = model_transforms(image).unsqueeze(0).to(device)

    with torch.no_grad():
        embedding = dinov2(image_tensor)
        logits = linear(embedding)
        probs = torch.nn.functional.softmax(logits, dim=1)
    print(logits)
    print(probs)

    # Single-image batch: row 0 holds the class distribution.
    return {labels[idx]: prob for idx, prob in enumerate(probs[0].tolist())}
# BreakHis tab: breast tumor type classification from an uploaded image file.
_breakhis_description = """
Upload a breast histopathology image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 700x460 resolution. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for BreakHis tumor classification.
**Tumor Types:**
- **Benign tumors:** Tubular Adenoma (TA), Fibroadenoma (F)
- **Malignant tumors:** Mucinous Carcinoma (MC), Ductal Carcinoma (DC)
These 4 classes were selected from the full BreakHis dataset as they have sufficient patient counts (≥7 patients) for robust evaluation.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""
# Sample images offered in the UI.
_breakhis_examples = [
    "./SOB_B_TA-14-13200-40-001.png",
    "./SOB_M_MC-14-10147-40-001.png",
    "./SOB_B_F-14-14134-40-001.png",
]
breakhis = gr.Interface(
    title="BreakHis Breast Tumor Classification",
    description=_breakhis_description,
    fn=predict_breakhis,
    # The prediction function receives the upload as a file path.
    inputs=gr.Image(type="filepath", label="Upload Breast Histopathology Image"),
    outputs=gr.Label(num_top_classes=4, label="Tumor Type Prediction"),
    examples=_breakhis_examples,
    theme=gr.themes.Soft(),
)
# Gleason tab: prostate pattern classification from an uploaded image file.
_gleason_description = """
Upload a prostate cancer image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 750x750 resolution. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for gleason tumor classification.
Images are classified as benign, Gleason pattern 3, 4 or 5.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""
# Sample images offered in the UI.
_gleason_examples = [
    "./ZT111_4_A_1_12_patch_13_class_2.jpg",
    "./ZT204_6_A_1_10_patch_10_class_3.jpg",
]
gleason = gr.Interface(
    title="Gleason Prostate Tumor Classification",
    description=_gleason_description,
    fn=predict_gleason,
    # The prediction function receives the upload as a file path.
    inputs=gr.Image(type="filepath", label="Upload Prostate Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Gleason Tumor Type Prediction"),
    examples=_gleason_examples,
    theme=gr.themes.Soft(),
)
# CRC tab: colorectal tissue classification (9 classes) from an uploaded file.
_crc_description = """
Upload a colorectal cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally at 224x224. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for colorectal tumor classification.
The tissue classes are: Adipose (ADI), background (BACK), debris (DEB), lymphocytes (LYM), mucus (MUC), smooth muscle (MUS), normal colon mucosa (NORM), cancer-associated stroma (STR) and colorectal adenocarcinoma epithelium (TUM)
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""
# Sample images offered in the UI.
_crc_examples = [
    "./ADI-TCGA-AAICEQFN.png",
    "./BACK-TCGA-AARRNSTS.png",
    "./DEB-TCGA-AANNAWLE.png",
]
crc = gr.Interface(
    title="Colorectal Tumor Classification",
    description=_crc_description,
    fn=predict_crc,
    # The prediction function receives the upload as a file path.
    inputs=gr.Image(type="filepath", label="Upload Colorectal Cancer Image"),
    outputs=gr.Label(num_top_classes=9, label="CRC Tumor Type Prediction"),
    examples=_crc_examples,
    theme=gr.themes.Soft(),
)
# BACH tab: four-class (benign / in situ / invasive / normal) classification
# from an uploaded image file. BACH is a breast cancer histology dataset, so
# the description asks for a breast image (it previously said "prostate",
# carried over from the Gleason tab).
bach = gr.Interface(
    fn=predict_bach,
    # The prediction function receives the upload as a file path.
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Bach Tumor Type Prediction"),
    title="Tumor Classification",
    description="""
Upload a breast cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally between 224x224 and 1536x2048 resolution. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for tumor classification.
Images are classified as benign, normal, invasive, inSitu
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
""",
    # Sample images offered in the UI.
    examples=[
        "./b001.png",
        "./n001.png",
        "./is001.png",
        "./iv001.png",
    ],
    theme=gr.themes.Soft(),
)
# Combine the four demos into a single tabbed app; each tab title is paired
# with the interface at the same position.
_tabs = {
    "BreakHis": breakhis,
    "Gleason": gleason,
    "CRC": crc,
    "Bach": bach,
}
demo = gr.TabbedInterface(list(_tabs.values()), list(_tabs.keys()))

if __name__ == "__main__":
    demo.launch()