UncleanCode commited on
Commit
86c6a58
1 Parent(s): 7047552

uploaded the inference endpoint .py file

Browse files
Files changed (1) hide show
  1. airad.py +86 -0
airad.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import cv2
6
+ import numpy as np
7
+ import gdown
8
+
9
+ class AIRadModel(nn.Module):
10
+ def __init__(self,num_classes=2):
11
+ super(AIRadModel,self).__init__()
12
+ self.model = models.efficientnet_b0(pretrained=False)
13
+ self.num_features = model.classifier[1].in_features
14
+ self.model.classifier = nn.Sequential(
15
+ nn.Dropout(p=0.2),
16
+ nn.Linear(self.num_features, num_classes) # Two classes: normal, pneumonia
17
+ )
18
+
19
+ def forward(self, x):
20
+ return self.model(x)
21
+
22
+ class AIRadSimModel(nn.Module):
23
+ def __init__(self, num_classes=2):
24
+ super(AIRadSimModel,self).__init__()
25
+ self.sim_model = models.resnet50(pretrained=False)
26
+ self.sim_model.fc = nn.Linear(self.sim_model.fc.in_features,num_classes)
27
+
28
+ def forward(self,x):
29
+ return self.sim_model(x)
30
+
31
+
32
+ def load_model():
33
+ model = AIRadModel(num_classes=2)
34
+ file_id = '1CKkdQ5nKWkz3L-ZdgyrJ5SE-oiFwXnSJ'
35
+ gdrive_url = f"https://drive.google.com/uc?id={file_id}"
36
+ model_checkpoint = 'model_checkpoint.pth'
37
+ gdown.download(gdrive_url, model_checkpoint, quiet=False)
38
+ model.load_state_dict(torch.load(model_checkpoint))
39
+ model.eval()
40
+ return model
41
+
42
+ def load_sim_model():
43
+ sim_model = AIRadSimModel(num_classes=2)
44
+ sim_file_id = 'cjdDsW5QAIlOneOPLg0uYqTURSr0oOLq'
45
+ sim_gdrive_url = f"https://drive.google.com/uc?id={file_id}"
46
+ sim_model_checkpoint = 'sim_model_checkpoint.pth'
47
+ gdown.download(sim_gdrive_url, sim_model_checkpoint, quiet=False)
48
+ sim_model.load_state_dict(torch.load(sim_model_checkpoint))
49
+ sim_model.eval()
50
+ return sim_model()
51
+
52
+ model = load_model()
53
+ sim_model = load_sim_model()
54
+ class_names = {0: 'normal', 1: 'pneumonia'}
55
+
56
+ preprocess = transforms.Compose([
57
+ transforms.Resize(256),
58
+ transforms.CenterCrop(224),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
61
+ ])
62
+
63
+ def predict(image_path):
64
+ image = Image.open(image_path).convert("RGB")
65
+ image_np = np.array(image)
66
+ image_np = cv2.bilateralFilter(image_np, 9, 75, 75)
67
+ image = Image.fromarray(image_np)
68
+ image_tensor = preprocess(image).unsqueeze(0).to(device)
69
+
70
+ # Use ResNet50 to predict if the image is an X-ray
71
+ with torch.no_grad():
72
+ sim_output = sim_model(image_tensor)
73
+ _, predicted_sim = torch.max(sim_output, 1)
74
+ predicted_class_sim = predicted_sim.item()
75
+
76
+ if predicted_class_sim == 1:
77
+ with torch.no_grad():
78
+ output = model(image_tensor)
79
+ _, predicted = torch.max(output, 1)
80
+ predicted_class = predicted.item()
81
+ confidence = torch.nn.functional.softmax(output, dim=1)[0][predicted_class].item()
82
+ return class_names[predicted_class] ,confidence
83
+
84
+ else:
85
+ return "error"
86
+