import os

import cv2
import gdown
import numpy as np
import torch
from PIL import Image
from torch import nn
from torchvision import models, transforms

# Run on the GPU when one is available, otherwise on the CPU. Both models and
# every input tensor are placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class AIRadModel(nn.Module):
    """EfficientNet-B0 classifier with a replaced classification head.

    The stock torchvision head is swapped for dropout followed by a single
    linear layer producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.model = models.efficientnet_b0(pretrained=False)
        # The new head must accept the same feature width as the stock
        # head's linear layer, so read it from the wrapped network.
        self.num_features = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(self.num_features, num_classes),  # Two classes: normal, pneumonia
        )

    def forward(self, x):
        """Return the raw (un-normalized) class logits for input batch ``x``."""
        return self.model(x)


class AIRadSimModel(nn.Module):
    """ResNet-50 whose final fully connected layer outputs ``num_classes`` logits."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.sim_model = models.resnet50(pretrained=False)
        self.sim_model.fc = nn.Linear(self.sim_model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw (un-normalized) class logits for input batch ``x``."""
        return self.sim_model(x)


def _download_checkpoint(file_id, destination, force=False):
    """Fetch a Google Drive file to *destination*, reusing an existing copy.

    Args:
        file_id: Google Drive file identifier.
        destination: local path to write the file to.
        force: re-download even if *destination* already exists.

    Returns:
        *destination*.

    Raises:
        RuntimeError: if gdown reports that the download failed.
    """
    if force or not os.path.isfile(destination):
        url = f"https://drive.google.com/uc?id={file_id}"
        if gdown.download(url, destination, quiet=False) is None:
            raise RuntimeError(f"Failed to download checkpoint {file_id!r} to {destination!r}")
    return destination


def load_model(force_download=False):
    """Build ``AIRadModel``, load its weights from Google Drive, and return it.

    The returned model is in eval mode and lives on ``device``.

    Args:
        force_download: re-download the checkpoint even if a local copy exists.
    """
    model = AIRadModel(num_classes=2)
    model_checkpoint = _download_checkpoint(
        '1CKkdQ5nKWkz3L-ZdgyrJ5SE-oiFwXnSJ', 'model_checkpoint.pth', force=force_download
    )
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. map_location lets GPU-saved weights load on CPU.
    model.load_state_dict(torch.load(model_checkpoint, map_location=device))
    model.to(device)
    model.eval()
    return model


def load_sim_model(force_download=False):
    """Build ``AIRadSimModel``, load its weights from Google Drive, and return it.

    The returned model is in eval mode and lives on ``device``.

    Args:
        force_download: re-download the checkpoint even if a local copy exists.
    """
    sim_model = AIRadSimModel(num_classes=2)
    sim_model_checkpoint = _download_checkpoint(
        'cjdDsW5QAIlOneOPLg0uYqTURSr0oOLq', 'sim_model_checkpoint.pth', force=force_download
    )
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. map_location lets GPU-saved weights load on CPU.
    sim_model.load_state_dict(torch.load(sim_model_checkpoint, map_location=device))
    sim_model.to(device)
    sim_model.eval()
    # Return the module itself; calling it here would run a forward pass
    # with no input.
    return sim_model


model = load_model()
sim_model = load_sim_model()

class_names = {0: 'normal', 1: 'pneumonia'}

# Resize/crop to 224x224, convert to a tensor, and normalize with the
# ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image_path):
    """Screen an image with ``sim_model``, then classify it with ``model``.

    The pneumonia classifier runs only when the screening model predicts
    class 1. Presumably class 1 means "this is an X-ray" (see the inline
    comment below); confirm against the training labels.

    Args:
        image_path: a path or file object accepted by ``PIL.Image.open``.

    Returns:
        ``(label, confidence)`` when the image passes screening. ``label`` is
        a value of ``class_names`` and ``confidence`` is its softmax
        probability. Returns the string ``"error"`` when screening rejects
        the image.
    """
    with Image.open(image_path) as raw_image:
        image_np = np.array(raw_image.convert("RGB"))
    # Edge-preserving smoothing before the standard preprocessing.
    image_np = cv2.bilateralFilter(image_np, 9, 75, 75)
    image_tensor = preprocess(Image.fromarray(image_np)).unsqueeze(0).to(device)

    with torch.no_grad():
        # Use ResNet50 to predict if the image is an X-ray
        sim_output = sim_model(image_tensor)
        _, predicted_sim = torch.max(sim_output, 1)
        if predicted_sim.item() != 1:
            return "error"

        output = model(image_tensor)
        _, predicted = torch.max(output, 1)
        predicted_class = predicted.item()
        confidence = torch.nn.functional.softmax(output, dim=1)[0][predicted_class].item()
    return class_names[predicted_class], confidence