File size: 1,296 Bytes
5e0bfa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# utils/model_utils.py
import torch
import torch.nn as nn
from torchvision.models import densenet121, DenseNet121_Weights

# Disease label names. predict() zips this list with the model's output
# vector, so position i here names output index i.
# NOTE(review): the order presumably has to match the label order the
# checkpoint was trained with — confirm before reordering or editing.
DISEASE_LIST = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion',
    'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass',
    'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax'
]

# CheXNet architecture: DenseNet-121 backbone with a linear multi-label head
class CheXNet(nn.Module):
    """DenseNet-121 feature extractor followed by a linear classification head.

    The backbone is built with torchvision's
    ``DenseNet121_Weights.IMAGENET1K_V1`` weights. Only its ``features``
    stack is kept; the stock classifier is replaced by a fresh ``nn.Linear``
    producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output scores produced by the head.
    """

    def __init__(self, num_classes=14):
        super().__init__()
        backbone = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
        head_inputs = backbone.classifier.in_features
        # The attribute names below become state_dict key prefixes, so they
        # must stay ``features`` / ``classifier`` for checkpoints to load.
        self.features = backbone.features
        self.classifier = nn.Linear(head_inputs, num_classes)

    def forward(self, x):
        """Return the raw (pre-sigmoid) scores of the linear head for *x*."""
        feature_maps = self.features(x)
        # NOTE(review): torchvision's own DenseNet.forward applies a ReLU
        # between ``features`` and the pooling step; this model does not.
        # Confirm that matches how the checkpoint was trained.
        pooled = nn.functional.adaptive_avg_pool2d(feature_maps, 1)
        return self.classifier(torch.flatten(pooled, start_dim=1))

def load_model(url, device):
    """Download a checkpoint from *url* and return it as a model in eval mode.

    The file is saved as ``dannynet.pth`` in the current working directory
    and then read with ``torch.load``. The checkpoint may be either a whole
    pickled ``nn.Module`` or a state dict for :class:`CheXNet`.

    Args:
        url: Location of the checkpoint to download.
        device: Passed to ``torch.load`` as ``map_location``; a model rebuilt
            from a state dict is also moved to this device.

    Returns:
        An ``nn.Module`` switched to evaluation mode.

    Raises:
        TypeError: If the checkpoint is neither a module nor a state dict.
    """
    model_path = "dannynet.pth"
    torch.hub.download_url_to_file(url, model_path)

    # SECURITY NOTE(review): torch.load unpickles the downloaded file, and a
    # fully pickled module can execute arbitrary code while loading. Only
    # pass trusted URLs. Recent PyTorch releases default to
    # ``weights_only=True``, which rejects fully pickled modules — confirm
    # which checkpoint format this URL actually serves.
    checkpoint = torch.load(model_path, map_location=device)

    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    elif isinstance(checkpoint, dict):
        # A bare state dict has no ``eval``; rebuild the architecture first.
        # CheXNet() starts from ImageNet weights, which are then overwritten.
        model = CheXNet()
        model.load_state_dict(checkpoint)
        model.to(device)
    else:
        raise TypeError(
            "Unsupported checkpoint type: "
            f"{type(checkpoint).__name__}; expected nn.Module or state dict"
        )

    model.eval()
    return model

def predict(model, img_tensor, device):
    """Run *model* on one image tensor and map each label to its score.

    A leading dimension is added to ``img_tensor`` before the forward pass,
    and only the first row of the output is used.
    (Assumes ``img_tensor`` is a single unbatched image — TODO confirm.)

    Args:
        model: Callable model that returns one row of scores per input.
        img_tensor: Input tensor for a single sample.
        device: Device the input is moved to before the forward pass.

    Returns:
        Dict pairing each ``DISEASE_LIST`` name with the sigmoid of the
        corresponding output value (NumPy scalars).
    """
    with torch.no_grad():
        batch = img_tensor.unsqueeze(0).to(device)
        logits = model(batch)
        first_row = logits[0]
        scores = torch.sigmoid(first_row).cpu().numpy()
    return {label: score for label, score in zip(DISEASE_LIST, scores)}