Spaces:
Running
on
A100
Running
on
A100
| import torch | |
| import gradio as gr | |
| from torchvision.transforms import v2 as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| from torchvision.transforms.v2 import functional | |
# Constants
# RESIZE_DIM is the side length (pixels) used for both the resize and the
# centre crop; NORMALIZE_MEAN / NORMALIZE_STD are the per-channel statistics
# handed to the normalisation transform.
RESIZE_DIM = 224
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Class-index -> display-label tables, one per linear head.  Each table is
# built from an ordered sequence, so a label's position is its class index.

# BreakHis tumor type labels (classes: ["TA", "MC", "F", "DC"])
BREAKHIS_LABELS = dict(enumerate((
    "Tubular Adenoma (TA) - Benign",
    "Mucinous Carcinoma (MC) - Malignant",
    "Fibroadenoma (F) - Benign",
    "Ductal Carcinoma (DC) - Malignant",
)))

# Gleason pattern labels for the prostate head.
GLEASON_LABELS = dict(enumerate((
    "Benign",
    "Gleason 3",
    "Gleason 4",
    "Gleason 5",
)))

# BACH labels.
BACH_LABELS = dict(enumerate((
    "Benign",
    "InSitu",
    "Invasive",
    "Normal",
)))

# Colorectal tissue-class abbreviations (nine classes).
CRC_LABELS = dict(enumerate((
    "ADI",
    "BACK",
    "DEB",
    "LYM",
    "MUC",
    "MUS",
    "NORM",
    "STR",
    "TUM",
)))
# Build the DinoV2 backbone architecture through torch.hub; its weights are
# replaced by the pathology checkpoint loaded below.
print("Loading DinoV2 base model...")
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')

print("Loading custom pathology checkpoint...")
# map_location="cpu" lets the checkpoint deserialize on hosts without CUDA
# (the app explicitly falls back to CPU when no GPU is present).
checkpoint = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/SophontAI/OpenMidnight/resolve/main/teacher_checkpoint_load.pt",
    map_location="cpu",
)

# The positional-embedding parameter is replaced with the checkpoint's tensor
# *before* load_state_dict -- presumably because its shape differs from the
# hub model's, which would make the strict load fail.  Keep this order.
# NOTE: despite its name, `new_shape` holds the embedding tensor itself.
new_shape = checkpoint["pos_embed"]
dinov2.pos_embed = torch.nn.parameter.Parameter(new_shape)
dinov2.load_state_dict(checkpoint)

# Inference only: switch the backbone to eval mode.
dinov2.eval()
def setup_linear(path):
    """Build an eval-mode linear classification head from a saved checkpoint.

    The checkpoint must contain a ``"state_dict"`` mapping holding the tensors
    ``"head.weight"`` and ``"head.bias"``.  The layer is sized from the stored
    weight matrix (``out_features x in_features``) rather than from hard-coded
    dimensions, so the reported layer shape always matches the loaded weights.

    Args:
        path: Location of the checkpoint file, passed to ``torch.load``.

    Returns:
        torch.nn.Linear: Layer in eval mode, with CPU weights, carrying the
        checkpoint's weight and bias.  Callers move it to their device.

    Raises:
        KeyError: If the checkpoint lacks ``"state_dict"``, ``"head.weight"``
            or ``"head.bias"``.
    """
    print(f"Loading {path} linear classifier...")
    # map_location="cpu": a checkpoint written on a GPU machine must still
    # load when CUDA is unavailable.
    linear_checkpoint = torch.load(path, map_location="cpu")
    state = linear_checkpoint["state_dict"]
    linear_weights = state["head.weight"]
    linear_bias = state["head.bias"]

    # nn.Linear stores its weight as (out_features, in_features).
    num_classes, embed_dim = linear_weights.shape
    linear = torch.nn.Linear(embed_dim, num_classes)
    linear.weight.data = linear_weights
    linear.bias.data = linear_bias
    linear.eval()
    return linear
def setup_linear_crc(path):
    """Build the eval-mode linear head for the colorectal (CRC) classifier.

    The checkpoint must contain a ``"state_dict"`` mapping holding the tensors
    ``"head.weight"`` and ``"head.bias"``.  The layer is sized from the stored
    weight matrix (``out_features x in_features``) rather than from hard-coded
    dimensions, so the reported layer shape always matches the loaded weights.

    Args:
        path: Location of the checkpoint file, passed to ``torch.load``.

    Returns:
        torch.nn.Linear: Layer in eval mode, with CPU weights, carrying the
        checkpoint's weight and bias.  Callers move it to their device.

    Raises:
        KeyError: If the checkpoint lacks ``"state_dict"``, ``"head.weight"``
            or ``"head.bias"``.
    """
    print(f"Loading {path} linear classifier...")
    # map_location="cpu": a checkpoint written on a GPU machine must still
    # load when CUDA is unavailable.
    linear_checkpoint = torch.load(path, map_location="cpu")
    state = linear_checkpoint["state_dict"]
    linear_weights = state["head.weight"]
    linear_bias = state["head.bias"]

    # nn.Linear stores its weight as (out_features, in_features).
    num_classes, embed_dim = linear_weights.shape
    linear = torch.nn.Linear(embed_dim, num_classes)
    linear.weight.data = linear_weights
    linear.bias.data = linear_bias
    linear.eval()
    return linear
# Choose the compute device (CUDA when present, otherwise CPU) and place the
# backbone on it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dinov2 = dinov2.to(device)

# Checkpoint locations of the per-dataset linear heads.
breakhis_path = "./breakhis_best.ckpt"
gleason_path = "./gleason_best.ckpt"
bach_path = "./bach_best.ckpt"
crc_path = "./crc_best.ckpt"

# Load every head and move it onto the same device as the backbone.
breakhis_linear = setup_linear(breakhis_path).to(device)
gleason_linear = setup_linear(gleason_path).to(device)
bach_linear = setup_linear(bach_path).to(device)
crc_linear = setup_linear_crc(crc_path).to(device)

print(f"Models loaded on {device}")
# Preprocessing pipeline applied to every image before it reaches the
# backbone, in order: resize, centre-crop to RESIZE_DIM, convert to float32
# (with value scaling enabled), then normalise with the module-level
# mean/std constants.
_preprocessing_steps = [
    transforms.Resize(RESIZE_DIM),
    transforms.CenterCrop(RESIZE_DIM),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
model_transforms = transforms.Compose(_preprocessing_steps)
def cv_path(path):
    """Read an image file with OpenCV and convert it for the transform pipeline.

    The file is decoded with ``cv2.IMREAD_COLOR``, converted from OpenCV's BGR
    channel order to RGB, cast to ``uint8`` and passed through
    ``functional.to_image``.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The value produced by ``functional.to_image`` for the decoded RGB
        array.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    # Bound locally: the original code compared against a `flags` name that
    # was never defined, which raised NameError whenever that branch ran.
    flags = cv2.IMREAD_COLOR
    image = cv2.imread(path, flags=flags)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {path!r}")

    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif image.ndim == 2 and flags == cv2.IMREAD_COLOR:
        # Defensive: give a single-channel result an explicit channel axis.
        image = image[:, :, np.newaxis]

    image = np.asarray(image, dtype=np.uint8)
    return functional.to_image(image)
def predict_breakhis(image):
    """Classify ``image`` with the BreakHis head (Gradio callback).

    Delegates to ``predict_class`` with the BreakHis linear layer and label
    set, and returns its label -> probability dict unchanged.
    """
    prediction = predict_class(image, breakhis_linear, "breakhis")
    return prediction
def predict_gleason(image):
    """Classify ``image`` with the Gleason head (Gradio callback).

    Delegates to ``predict_class`` with the Gleason linear layer and label
    set, and returns its label -> probability dict unchanged.
    """
    prediction = predict_class(image, gleason_linear, "gleason")
    return prediction
def predict_bach(image):
    """Classify ``image`` with the BACH head (Gradio callback).

    Delegates to ``predict_class`` with the BACH linear layer and label set,
    and returns its label -> probability dict unchanged.
    """
    prediction = predict_class(image, bach_linear, "bach")
    return prediction
def predict_crc(image):
    """Classify ``image`` with the colorectal (CRC) head (Gradio callback).

    Delegates to ``predict_class`` with the CRC linear layer and label set,
    and returns its label -> probability dict unchanged.
    """
    prediction = predict_class(image, crc_linear, "crc")
    return prediction
def predict_class(image, linear, dataset):
    """Classify one histopathology image with the given linear head.

    The image is read from disk, preprocessed with ``model_transforms``,
    embedded with the DinoV2 backbone and scored by ``linear``; the softmax
    probabilities are returned keyed by the dataset's class labels.

    Args:
        image: Path of the image file (it is read with ``cv_path``).
        linear: Linear classifier applied to the backbone embedding.
        dataset: Label set to report: ``"breakhis"``, ``"gleason"``,
            ``"bach"`` or ``"crc"``.

    Returns:
        dict: Class label -> probability (``float``), one entry per output
        of ``linear``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    label_tables = {
        "breakhis": BREAKHIS_LABELS,
        "gleason": GLEASON_LABELS,
        "bach": BACH_LABELS,
        "crc": CRC_LABELS,
    }
    # Resolve the label table up front so an unknown dataset fails loudly
    # before the forward pass, instead of silently producing an empty dict.
    if dataset not in label_tables:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(label_tables)}"
        )
    labels = label_tables[dataset]

    # Preprocess: read from disk, apply transforms, add a batch dimension.
    image = cv_path(image)
    image_tensor = model_transforms(image).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed, and the probabilities must
    # be grad-free to be converted to NumPy below.
    with torch.no_grad():
        embedding = dinov2(image_tensor)
        logits = linear(embedding)
        print(logits)
        probs = torch.nn.functional.softmax(logits, dim=1)
        print(probs)

    # Single image in the batch, so row 0 holds its class probabilities.
    return {
        labels[idx]: float(prob)
        for idx, prob in enumerate(probs[0].cpu().numpy())
    }
# Create Gradio interface
# BreakHis tab: the description (rendered by Gradio) and the bundled example
# images are named first, then handed to the Interface constructor.
_breakhis_description = """
Upload a breast histopathology image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 700x460 resolution. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for BreakHis tumor classification.
**Tumor Types:**
- **Benign tumors:** Tubular Adenoma (TA), Fibroadenoma (F)
- **Malignant tumors:** Mucinous Carcinoma (MC), Ductal Carcinoma (DC)
These 4 classes were selected from the full BreakHis dataset as they have sufficient patient counts (≥7 patients) for robust evaluation.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""
_breakhis_examples = [
    "./SOB_B_TA-14-13200-40-001.png",
    "./SOB_M_MC-14-10147-40-001.png",
    "./SOB_B_F-14-14134-40-001.png",
]
breakhis = gr.Interface(
    fn=predict_breakhis,
    # type="filepath": the callback receives a path on disk, not pixel data.
    inputs=gr.Image(type="filepath", label="Upload Breast Histopathology Image"),
    outputs=gr.Label(num_top_classes=4, label="Tumor Type Prediction"),
    title="BreakHis Breast Tumor Classification",
    description=_breakhis_description,
    examples=_breakhis_examples,
    theme=gr.themes.Soft(),
)
# Gleason tab: the description (rendered by Gradio) and the bundled example
# images are named first, then handed to the Interface constructor.
_gleason_description = """
Upload a prostate cancer image to predict the tumor type. Your image must be at 40X magnification, and ideally between 224x224 and 750x750 resolution. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for gleason tumor classification.
Images are classified as benign, Gleason pattern 3, 4 or 5.
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""
_gleason_examples = [
    "./ZT111_4_A_1_12_patch_13_class_2.jpg",
    "./ZT204_6_A_1_10_patch_10_class_3.jpg",
]
gleason = gr.Interface(
    fn=predict_gleason,
    # type="filepath": the callback receives a path on disk, not pixel data.
    inputs=gr.Image(type="filepath", label="Upload Prostate Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Gleason Tumor Type Prediction"),
    title="Gleason Prostate Tumor Classification",
    description=_gleason_description,
    examples=_gleason_examples,
    theme=gr.themes.Soft(),
)
# Colorectal (CRC) tab: the description (rendered by Gradio) and the bundled
# example images are named first, then handed to the Interface constructor.
_crc_description = """
Upload a colorectal cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally at 224x224. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for colorectal tumor classification.
The tissue classes are: Adipose (ADI), background (BACK), debris (DEB), lymphocytes (LYM), mucus (MUC), smooth muscle (MUS), normal colon mucosa (NORM), cancer-associated stroma (STR) and colorectal adenocarcinoma epithelium (TUM)
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
"""
_crc_examples = [
    "./ADI-TCGA-AAICEQFN.png",
    "./BACK-TCGA-AARRNSTS.png",
    "./DEB-TCGA-AANNAWLE.png",
]
crc = gr.Interface(
    fn=predict_crc,
    # type="filepath": the callback receives a path on disk, not pixel data.
    inputs=gr.Image(type="filepath", label="Upload Colorectal Cancer Image"),
    # Nine tissue classes, so show all nine in the label widget.
    outputs=gr.Label(num_top_classes=9, label="CRC Tumor Type Prediction"),
    title="Colorectal Tumor Classification",
    description=_crc_description,
    examples=_crc_examples,
    theme=gr.themes.Soft(),
)
# BACH tab.  BACH is a breast cancer histology dataset (classes: Benign,
# InSitu, Invasive, Normal); the description previously said "prostate",
# a copy-paste slip from the Gleason tab, which is corrected here.
bach = gr.Interface(
    fn=predict_bach,
    # type="filepath": the callback receives a path on disk, not pixel data.
    inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
    outputs=gr.Label(num_top_classes=4, label="Bach Tumor Type Prediction"),
    title="Tumor Classification",
    description="""
Upload a breast cancer image to predict the tumor type. Your image must be at 20X magnification, and ideally between 224x224 and 1536x2048 resolution. Do not otherwise modify your image.
This model uses a custom-trained DinoV2 foundation model for pathology images
with a linear classifier for tumor classification.
Images are classified as benign, normal, invasive, inSitu
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
""",
    # Bundled sample images shown under the upload widget.
    examples=[
        "./b001.png",
        "./n001.png",
        "./is001.png",
        "./iv001.png",
    ],
    theme=gr.themes.Soft(),
)
# Combine the per-dataset interfaces into a single tabbed app; the two lists
# are parallel (interface i is shown under tab title i).
_tab_interfaces = [breakhis, gleason, crc, bach]
_tab_titles = ["BreakHis", "Gleason", "CRC", "Bach"]
demo = gr.TabbedInterface(_tab_interfaces, _tab_titles)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()