# -*- coding: utf-8 -*-
"""gradio_app_chestvision.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1cqkxJwlxFdpD6iRy-LNc5fhC5yoRYsw8
"""

# !pip install --upgrade gradio

"""### Import dependencies"""
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch.utils.data import random_split
import pytorch_lightning as torch_light
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torchmetrics
from torchmetrics import Metric
import os
import shutil
import subprocess
import pandas as pd
from PIL import Image
import gradio
from functools import partial
| """### Set parameters""" | |
| configs = { | |
| "IMAGE_SIZE": (224, 224), # Resize images to (W, H) | |
| "NUM_CHANNELS": 3, # RGB images | |
| "NUM_CLASSES": 15, # Number of output labels | |
| # ImageNet dataset normalization values (for pretrained backbones) | |
| "MEAN": (0.485, 0.456, 0.406), | |
| "STD": (0.229, 0.224, 0.225), | |
| "DEFAULT_BACKBONE": "ConvNeXt(tiny)", | |
| "THRESHOLD": 0.5 | |
| } | |
BACKBONE_REGISTRY = {
    "ConvNeXt(small)": {
        "torchvision_name": "convnext_small",
        "ckpt": "ConvNeXt(small).ckpt"},
    "ConvNeXt(tiny)": {
        "torchvision_name": "convnext_tiny",
        "ckpt": "ConvNeXt(tiny).ckpt"},
    "EfficientNet(b3)": {
        "torchvision_name": "efficientnet_b3",
        "ckpt": "EfficientNet(b3).ckpt"},
    "EfficientNet(v2_small)": {
        "torchvision_name": "efficientnet_v2_s",
        "ckpt": "EfficientNet(v2_small).ckpt"},
    "RegNet(x3_2GF)": {
        "torchvision_name": "regnet_x_3_2gf",
        "ckpt": "RegNet(x3_2GF).ckpt"},
    "ResNet50": {
        "torchvision_name": "resnet50",
        "ckpt": "ResNet50.ckpt"}
}

MODEL_CACHE = {}
| """### Define helper functions""" | |
| # helper function for loading pre-trained model | |
| # =================================================================================================== | |
| def get_pretrained_model(pretrained_model_name: str, num_classes: int, freeze_backbone: bool = True): | |
| """ | |
| Load a pretrained Torchvision classification model and replace | |
| ONLY its final Linear layer for transfer learning or fine-tuning. | |
| """ | |
| print(f"Loading pretrained [{pretrained_model_name}] model") | |
| # Load pretrained model from torchvision | |
| model = getattr(torchvision.models, pretrained_model_name)(weights="DEFAULT") | |
| # Optionally freeze all pretrained parameters (the backbone) | |
| if freeze_backbone: | |
| for param in model.parameters(): | |
| param.requires_grad = False | |
| # Get the last top-level module (typically the classifier head) from the pretrained model | |
| layer_modules = list(model.named_children()) | |
| last_module_name, last_module = layer_modules[-1] | |
| if isinstance(last_module, nn.Sequential): | |
| # If the classifier is a Sequential module, replace its final layer | |
| num_in_features = last_module[-1].in_features # find the dimensionality of the input to the last layer of the network | |
| # Replace the output layer of the classifier head to match the num. classes in the current task | |
| last_module[-1] = nn.Linear(in_features = num_in_features, out_features = num_classes) # Output layer | |
| else: | |
| # Otherwise, replace the module directly (e.g., ResNet-style fc layer) | |
| in_features = last_module.in_features | |
| setattr(model, last_module_name, nn.Linear(in_features, num_classes)) | |
| return model | |
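
# Quick sanity check for the helper above (a minimal sketch, commented out so the
# app doesn't download weights on startup; "resnet50" and the dummy batch are
# illustrative assumptions):
#
#   sanity_model = get_pretrained_model("resnet50", num_classes=configs["NUM_CLASSES"])
#   dummy_batch = torch.randn(2, configs["NUM_CHANNELS"], *configs["IMAGE_SIZE"])
#   print(sanity_model(dummy_batch).shape)  # expected: torch.Size([2, 15])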

# Helper transform for preprocessing input images
# ===================================================================================================
preprocess_fxn = transforms.Compose(
    [transforms.Resize(size=configs["IMAGE_SIZE"][::-1]),  # Resize expects (H, W), so reverse the (W, H) config
     transforms.ToTensor(),
     transforms.Normalize(configs["MEAN"], configs["STD"], inplace=True)])
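
# Usage sketch for the preprocessing pipeline (commented out; the file name is an
# illustrative assumption). A single PIL image becomes a batched tensor of shape
# [1, 3, 224, 224]:
#
#   img = Image.open("sample_cxr.png").convert("RGB")
#   x = preprocess_fxn(img).unsqueeze(0)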

# Map numeric outputs to string labels
labels_dict = {
    0: "Atelectasis",
    1: "Cardiomegaly",
    2: "Consolidation",
    3: "Edema",
    4: "Effusion",
    5: "Emphysema",
    6: "Fibrosis",
    7: "Hernia",
    8: "Infiltration",
    9: "Mass",
    10: "No finding",
    11: "Nodule",
    12: "Pleural_Thickening",
    13: "Pneumonia",
    14: "Pneumothorax"}
| """### Create torch lightning model (i.e., classifier) module""" | |
| class modelModule(torch_light.LightningModule): | |
| def __init__(self, num_classes = configs['NUM_CLASSES'], backbone_model_name = 'efficientnet_b3'): | |
| super().__init__() | |
| self.num_classes = num_classes | |
| self.backbone_model_name = backbone_model_name | |
| # Load a pretrained backbone and replace its final layer | |
| self.model = get_pretrained_model( | |
| num_classes = self.num_classes, | |
| pretrained_model_name = self.backbone_model_name ) | |
| # Binary classification loss operating on raw logits | |
| self.loss_function = torch.nn.BCEWithLogitsLoss() | |
| # self.accuracy_function = torchmetrics.Accuracy(task="multilabel", num_labels=self.num_classes) | |
| # self.f1_score_function = torchmetrics.F1Score(task="multilabel", num_labels=self.num_classes) | |
| self.accuracy_function = torchmetrics.classification.MultilabelAccuracy(num_labels=self.num_classes, average="weighted", threshold=0.5) | |
| self.f1_score_function = torchmetrics.classification.MultilabelF1Score(num_labels=self.num_classes, average="weighted", threshold=0.5) | |
| self.auroc_function = torchmetrics.classification.MultilabelAUROC(num_labels=self.num_classes, average="weighted", thresholds=10) | |
| self.map_score_function = torchmetrics.classification.MultilabelAveragePrecision(num_labels=self.num_classes, average="weighted", thresholds=10) | |
| # average options: macro (simple average), micro (sum), weighted (weight by class size, then avg) | |
| # threshold: Threshold for transforming probability to binary (0,1) predictions. For some metrics (e.g., AUROC), represents the number of thresholds (evenly spaced b/n 0–1) the metric should be computed at (resulting array of values are the averaged to obtain the final score) | |
| def forward(self, x): | |
| # Forward pass through the backbone model | |
| return self.model(x) | |
    def _common_step(self, batch, batch_idx):
        """
        Shared logic for train / val / test steps.
        Computes loss and evaluation metrics.
        """
        x, y = batch
        # Compute model predictions (raw logits)
        y_hat = self.forward(x)
        # Compute loss (BCEWithLogitsLoss expects logits + float labels and
        # already averages over the batch and all classes)
        loss = self.loss_function(y_hat, y.float())
        accuracy = self.accuracy_function(y_hat, y)
        f1_score = self.f1_score_function(y_hat, y)
        auroc = self.auroc_function(y_hat, y)
        mAP = self.map_score_function(y_hat, y)  # mean average precision
        return loss, y_hat, y, accuracy, f1_score, auroc, mAP
    def training_step(self, batch, batch_idx):
        # Run shared step
        loss, y_hat, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)
        # Log epoch-level training metrics
        self.log_dict(
            {"train_loss": loss, "train_accuracy": accuracy, "train_f1_score": f1_score, "train_auroc": auroc, "train_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)
        # Lightning expects the loss key for backprop
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        # Run shared step
        loss, y_hat, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)
        # Log validation metrics
        self.log_dict(
            {"val_loss": loss, "val_accuracy": accuracy, "val_f1_score": f1_score, "val_auroc": auroc, "val_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Run shared step
        loss, y_hat, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)
        # Log test metrics
        self.log_dict(
            {"test_loss": loss, "test_accuracy": accuracy, "test_f1_score": f1_score, "test_auroc": auroc, "test_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def predict_step(self, batch, batch_idx):
        """
        Prediction logic used by trainer.predict().
        Returns model outputs without computing loss.
        """
        x = batch if not isinstance(batch, (tuple, list)) else batch[0]
        logits = self.forward(x)
        # Convert logits to probabilities for inference
        probs = torch.sigmoid(logits)
        return probs

    def configure_optimizers(self):
        # Optimizer over the model's parameters (frozen ones are simply not updated)
        optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer
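
# Training sketch (commented out; this app only runs inference). The DataLoaders
# and trainer settings below are illustrative assumptions, not the actual
# training setup used to produce the checkpoints:
#
#   clf = modelModule(backbone_model_name="resnet50")
#   trainer = torch_light.Trainer(max_epochs=10,
#                                 callbacks=[EarlyStopping(monitor="val_loss")])
#   trainer.fit(clf, train_dataloaders=train_loader, val_dataloaders=val_loader)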
| """### Create function for running inference (i.e., assistive medical diagnosis)""" | |
| def run_diagnosis( | |
| backbone_name, | |
| input_image, | |
| threshold, | |
| preprocess_fn=None, | |
| Idx2labels=None | |
| ): | |
| # Preprocess | |
| x = preprocess_fn(input_image).unsqueeze(0) | |
| # Resolve backbone | |
| backbone_info = BACKBONE_REGISTRY[backbone_name] | |
| ckpt_path = os.path.join(CKPT_ROOT, backbone_info["ckpt"]) | |
| if not os.path.exists(ckpt_path): | |
| raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") | |
| # Load model (cache for speed) | |
| if backbone_name not in MODEL_CACHE: | |
| print(f"Loading model weights from {ckpt_path}") | |
| MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint( | |
| ckpt_path, backbone_model_name=backbone_info["torchvision_name"]) | |
| model = MODEL_CACHE[backbone_name] | |
| model.eval() | |
| # Forward | |
| logits = model(x) | |
| probs = torch.sigmoid(logits)[0].cpu().numpy() | |
| print("predicted logits\n") | |
| for i, logit_ in enumerate(logits): | |
| print(f"{Idx2labels[i]}: {logit_}") | |
| output_probs = { | |
| Idx2labels[i]: float(p) for i, p in enumerate(probs)} | |
| predicted_classes = [ | |
| Idx2labels[i] for i, p in enumerate(probs) if p >= threshold] | |
| return "\n".join(predicted_classes), output_probs | |
| """### Gradio app""" | |
| CKPT_ROOT = os.path.join(os.getcwd(), "Trained models") | |
| example_list_dir = os.path.join(os.getcwd(), "Curated test samples") | |
| example_list_img_names = os.listdir(example_list_dir) | |
# Each example must supply one value per Interface input, in the same order:
# (backbone, image, threshold)
example_list = [
    [configs["DEFAULT_BACKBONE"], os.path.join(example_list_dir, example_img), configs["THRESHOLD"]]
    for example_img in example_list_img_names[:8]
    if example_img.lower().endswith(".png")]

gradio_app = gradio.Interface(
    fn=partial(run_diagnosis, preprocess_fn=preprocess_fxn, Idx2labels=labels_dict),
    inputs=[gradio.Dropdown(list(BACKBONE_REGISTRY.keys()), value="ResNet50", label="Select Backbone Model"),
            gradio.Image(type="pil", label="Load chest X-ray image here"),
            gradio.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.4, label="Set Prediction Threshold")],
    outputs=[gradio.Textbox(label="Predicted Medical Condition(s)"),
             gradio.Label(label="Predicted Probabilities", show_label=False)],
    examples=example_list,
    cache_examples=False,
    title="ChestVision",
    description="Deep CNN-based solutions for assistive medical diagnosis",
    article="Author: C. Foli (02.2026) | Website: coming soon...")

if __name__ == "__main__":
    gradio_app.launch()
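    # To get a public URL (e.g., when running in Colab), Gradio's standard
    # share option can be used instead: gradio_app.launch(share=True)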