airad / airad.py
UncleanCode's picture
uploaded the inference endpoint .py file
86c6a58 verified
raw
history blame
2.73 kB
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
import cv2
import numpy as np
import gdown
class AIRadModel(nn.Module):
    """EfficientNet-B0 backbone with a replaced classification head.

    The backbone is built with ``pretrained=False``; trained weights are
    expected to be loaded afterwards (``load_model`` does this from a
    checkpoint). The backbone's ``classifier`` is replaced by
    ``Dropout(p=0.2) -> Linear(in_features, num_classes)``.

    Args:
        num_classes: Number of output logits (default 2: normal, pneumonia).
    """

    def __init__(self, num_classes=2):
        super(AIRadModel, self).__init__()
        self.model = models.efficientnet_b0(pretrained=False)
        # Read the head's input width from *this* backbone. A bare ``model``
        # would resolve to the module-level global, which is still unassigned
        # while ``load_model()`` is constructing this class (NameError).
        self.num_features = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(self.num_features, num_classes),  # Two classes: normal, pneumonia
        )

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for the batch ``x``."""
        return self.model(x)
class AIRadSimModel(nn.Module):
    """ResNet-50 whose final fully connected layer emits ``num_classes`` logits.

    The backbone is built with ``pretrained=False``; trained weights are
    expected to be loaded afterwards. ``predict`` in this module runs this
    network first and only proceeds to the pneumonia classifier when it
    predicts class 1.

    Args:
        num_classes: Number of output logits (default 2).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        backbone = models.resnet50(pretrained=False)
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        # Registered under the same attribute name so checkpoint keys
        # ("sim_model.*") are unchanged.
        self.sim_model = backbone

    def forward(self, x):
        """Return the raw (pre-softmax) logits for the batch ``x``."""
        logits = self.sim_model(x)
        return logits
def load_model():
    """Download the classifier checkpoint and return an ``AIRadModel`` in eval mode.

    The checkpoint is fetched from Google Drive with ``gdown`` into
    ``model_checkpoint.pth`` (a relative path, i.e. the current working
    directory) and loaded into a freshly constructed model.

    Returns:
        AIRadModel: Model with the downloaded weights, switched to ``eval()``.
    """
    model = AIRadModel(num_classes=2)
    file_id = '1CKkdQ5nKWkz3L-ZdgyrJ5SE-oiFwXnSJ'
    gdrive_url = f"https://drive.google.com/uc?id={file_id}"
    model_checkpoint = 'model_checkpoint.pth'
    gdown.download(gdrive_url, model_checkpoint, quiet=False)
    # map_location="cpu": a checkpoint saved from a GPU would otherwise fail to
    # deserialize on a CPU-only host. load_state_dict copies the values into
    # the model's own parameters, so the caller still controls the device.
    # NOTE(review): torch.load unpickles the file and this file comes from a
    # remote download; consider weights_only=True if the installed torch
    # version supports it.
    state_dict = torch.load(model_checkpoint, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
def load_sim_model():
    """Download the gate-model checkpoint and return an ``AIRadSimModel`` in eval mode.

    The checkpoint is fetched from Google Drive with ``gdown`` into
    ``sim_model_checkpoint.pth`` (a relative path, i.e. the current working
    directory) and loaded into a freshly constructed model.

    Returns:
        AIRadSimModel: Model with the downloaded weights, switched to ``eval()``.
    """
    sim_model = AIRadSimModel(num_classes=2)
    sim_file_id = 'cjdDsW5QAIlOneOPLg0uYqTURSr0oOLq'
    # Build the URL from this function's own file id (``file_id`` is a local
    # of load_model and is undefined here).
    sim_gdrive_url = f"https://drive.google.com/uc?id={sim_file_id}"
    sim_model_checkpoint = 'sim_model_checkpoint.pth'
    gdown.download(sim_gdrive_url, sim_model_checkpoint, quiet=False)
    # map_location="cpu": allow a GPU-saved checkpoint to load on a CPU-only host.
    # NOTE(review): torch.load unpickles a remotely downloaded file; consider
    # weights_only=True if the installed torch version supports it.
    state_dict = torch.load(sim_model_checkpoint, map_location="cpu")
    sim_model.load_state_dict(state_dict)
    sim_model.eval()
    # Return the module itself; calling it here would run a forward pass with
    # no input and raise TypeError.
    return sim_model
# Device used for inference: the first CUDA GPU when available, else the CPU.
# ``predict`` sends its input tensor here, so it must be defined at module level.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both networks are downloaded and loaded at import time, then moved to
# ``device`` so they match the input tensors built in ``predict``.
model = load_model().to(device)
sim_model = load_sim_model().to(device)

# Maps AIRadModel output indices to human-readable labels.
class_names = {0: 'normal', 1: 'pneumonia'}

# Resize -> center-crop to 224x224 -> tensor -> per-channel normalization.
# The mean/std values are the common ImageNet statistics; presumably they
# match the training pipeline (not visible here).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def predict(image_path):
    """Classify a chest image as 'normal' or 'pneumonia'.

    Pipeline: load as RGB -> OpenCV bilateral filter (d=9, sigmaColor=75,
    sigmaSpace=75) -> ``preprocess`` -> ``sim_model`` gate. Only when the
    gate predicts class 1 is ``model`` run on the same tensor.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        ``(class_name, confidence)`` when the gate passes, where
        ``confidence`` is the softmax probability of the predicted class.
        The string ``"error"`` when the gate rejects the image.
    """
    # Close the file promptly; convert() returns an independent image copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    filtered = cv2.bilateralFilter(np.array(image), 9, 75, 75)
    image = Image.fromarray(filtered)
    image_tensor = preprocess(image).unsqueeze(0)

    # Place the input on whatever device each model's weights live on, rather
    # than relying on a module-level ``device`` global.
    sim_device = next(sim_model.parameters()).device
    model_device = next(model.parameters()).device

    with torch.no_grad():
        # Gate with the ResNet-50 model; per the original intent, class 1
        # means "the image is an X-ray".
        sim_output = sim_model(image_tensor.to(sim_device))
        _, predicted_sim = torch.max(sim_output, 1)
        if predicted_sim.item() != 1:
            return "error"

        output = model(image_tensor.to(model_device))
        _, predicted = torch.max(output, 1)
        predicted_class = predicted.item()
        probabilities = torch.nn.functional.softmax(output, dim=1)
        confidence = probabilities[0][predicted_class].item()

    return class_names[predicted_class], confidence