| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms, models |
| from PIL import Image |
| import os |
| import time |
|
|
| |
| |
| |
# Preprocessing applied to every image before it reaches the network.
# The mean/std values are the standard ImageNet channel statistics that
# torchvision's pretrained backbones are normalized with.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(_INPUT_SIZE),  # fixed spatial size for the model input
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], C x H x W
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
| |
| |
| |
class FineTunedResNet(nn.Module):
    """ResNet-50 backbone whose final layer is replaced by an MLP classifier.

    The replacement head is a stack of ``Linear -> BatchNorm1d -> ReLU ->
    Dropout`` stages (one per entry in ``hidden_dims``) followed by a final
    ``Linear`` layer that emits ``num_classes`` logits.

    With the default arguments the module layout (and therefore the
    ``state_dict`` keys, e.g. ``resnet.fc.0.weight`` ... ``resnet.fc.12.bias``)
    is identical to the original hard-coded 1024/512/256 head, so existing
    checkpoints keep loading.

    Args:
        num_classes: Number of output logits.
        pretrained: When true (default), initialise the backbone with
            ``ResNet50_Weights.DEFAULT``. Pass ``False`` to skip fetching
            those weights, e.g. when a full checkpoint is loaded right after
            construction and would overwrite them anyway.
        hidden_dims: Widths of the hidden stages of the classification head.
        dropout: Dropout probability used after every hidden stage.
    """

    def __init__(
        self,
        num_classes=4,
        *,
        pretrained=True,
        hidden_dims=(1024, 512, 256),
        dropout=0.5,
    ):
        super().__init__()
        backbone_weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=backbone_weights)

        # Build the head stage by stage; the layer order must stay
        # Linear, BatchNorm1d, ReLU, Dropout so Sequential indices (and the
        # checkpoint keys derived from them) do not shift.
        head_layers = []
        in_features = self.resnet.fc.in_features
        for width in hidden_dims:
            head_layers.extend(
                [
                    nn.Linear(in_features, width),
                    nn.BatchNorm1d(width),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                ]
            )
            in_features = width
        head_layers.append(nn.Linear(in_features, num_classes))

        self.resnet.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the backbone and head; returns raw logits."""
        return self.resnet(x)
|
|
| |
| |
| |
# Location of the fine-tuned checkpoint, relative to the working directory.
MODEL_PATH = "models/final_fine_tuned_resnet50.pth"

# Human-readable labels; position i is the label shown for output index i
# (predict() indexes this list with the top-k indices).
CLASSES = ["🦠 COVID", "🫁 Normal", "🦠 Pneumonia", "🦠 TB"]

# Fail fast with a clear message instead of a torch.load traceback.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

model = FineTunedResNet(num_classes=len(CLASSES))

# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source (consider weights_only=True if the installed torch
# version supports it).
_state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(_state_dict)

# Inference only: switch BatchNorm/Dropout to eval behaviour and stay on CPU.
model.eval().to("cpu")
|
|
| |
| |
| |
def predict(image: Image.Image) -> str:
    """Classify a chest X-ray and describe the most likely classes.

    Args:
        image: PIL image from the Gradio image input. ``None`` (nothing
            uploaded) is tolerated and answered with a prompt instead of
            raising.

    Returns:
        A multi-line string listing up to three classes with their softmax
        probabilities (highest first), followed by the elapsed prediction
        time in seconds.
    """
    if image is None:
        # Submitting without an upload would otherwise crash in transform().
        return "Please upload a chest X-ray image."

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    start = time.perf_counter()

    # The normalisation step uses three-channel statistics; force RGB so
    # grayscale or RGBA inputs do not fail with a channel mismatch.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
    probs = F.softmax(logits, dim=1)[0]

    # Never request more entries than there are classes.
    top_k = min(3, probs.numel())
    top_probs, top_idxs = torch.topk(probs, top_k)

    elapsed = time.perf_counter() - start

    lines = ["Top Predictions:", ""]
    lines.extend(
        f"{CLASSES[idx]} → {prob:.4f}"
        for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
    )
    lines.extend(["", f"⏱️ Prediction Time: {elapsed:.2f} seconds"])
    return "\n".join(lines)
|
|
| |
| |
| |
# Sample X-rays offered as one-click inputs in the UI. Gradio expects one
# row per example, each row holding one value per input component.
_example_paths = (
    "examples/Pneumonia/02009view1_frontal.jpg",
    "examples/Pneumonia/02055view1_frontal.jpg",
    "examples/Pneumonia/03152view1_frontal.jpg",
    "examples/COVID/11547_2020_1200_Fig3_HTML-a.png",
    "examples/COVID/11547_2020_1200_Fig3_HTML-b.png",
    "examples/COVID/11547_2020_1203_Fig1_HTML-b.png",
    "examples/Normal/06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg",
    "examples/Normal/08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg",
    "examples/Normal/IM-0178-0001.jpeg",
)
examples = [[example_path] for example_path in _example_paths]
|
|
| |
| |
| |
# Image files shown on the "Model Performance" tab: pictures/1.png through
# pictures/5.png, in display order.
visualization_images = [f"pictures/{number}.png" for number in range(1, 6)]
|
|
def display_visualizations():
    """Load every visualization image from disk.

    Returns:
        A list of PIL images, one per path in ``visualization_images`` and
        in the same order.
    """
    loaded = []
    for path in visualization_images:
        # Image.open is lazy and keeps the file open; read the pixel data
        # inside the context manager so the handle is released right away.
        with Image.open(path) as img:
            img.load()
        loaded.append(img)
    return loaded
|
|
| |
| |
| |
# --- "Predict" tab: upload an X-ray, get the formatted prediction text. ---
_xray_input = gr.Image(type="pil", label="Upload Chest X-ray")
_prediction_output = gr.Textbox(label="Prediction Result")

prediction_interface = gr.Interface(
    fn=predict,
    inputs=_xray_input,
    outputs=_prediction_output,
    examples=examples,
    cache_examples=False,  # run the model on click instead of precomputing
    title="Lung Disease Detection XVI",
    description="""
Upload a chest X-ray image to detect:
🦠 COVID-19 • 🦠 Pneumonia • 🫁 Normal • 🦠 Tuberculosis
""",
)

# --- "Model Performance" tab: one image slot per visualization file. ---
_visualization_outputs = [
    gr.Image(type="pil", label=f"Visualization {slot}")
    for slot in range(1, len(visualization_images) + 1)
]

visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=_visualization_outputs,
    title="Model Performance Visualizations",
)

# Combine both interfaces into a single tabbed application.
app = gr.TabbedInterface(
    interface_list=[prediction_interface, visualization_interface],
    tab_names=["Predict", "Model Performance"],
)

# Start the Gradio server.
app.launch()
|
|