Spaces:
Running
Running
File size: 1,296 Bytes
5e0bfa2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
# utils/model_utils.py
import torch
import torch.nn as nn
from torchvision.models import densenet121, DenseNet121_Weights
# Human-readable names for the model's outputs. Order matters: predict()
# pairs these positionally with the classifier's logits.
# NOTE(review): presumably the 14 ChestX-ray14 findings in alphabetical
# order — confirm this matches the label order used during training.
DISEASE_LIST = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Effusion',
    'Emphysema',
    'Fibrosis',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pleural_Thickening',
    'Pneumonia',
    'Pneumothorax',
]
# CheXNet architecture: a DenseNet-121 convolutional trunk followed by a
# single linear layer that emits one logit per label.
class CheXNet(nn.Module):
    """DenseNet-121 backbone with a linear multi-label classification head.

    Args:
        num_classes: Number of output logits (one per label). Defaults to 14.
        pretrained: When True (the default), initialise the backbone from
            torchvision's ImageNet weights, which may require a download.
            Pass False to skip that, e.g. when a trained state dict will be
            loaded over the weights immediately afterwards.
    """

    def __init__(self, num_classes=14, pretrained=True):
        super().__init__()
        weights = DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        base_model = densenet121(weights=weights)
        # Keep only the convolutional trunk; torchvision's 1000-class head is
        # discarded and replaced by one sized for ``num_classes``.
        self.features = base_model.features
        self.classifier = nn.Linear(base_model.classifier.in_features, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``.

        ``x`` is assumed to be an image batch of shape ``(N, 3, H, W)`` —
        TODO confirm against the preprocessing pipeline.
        """
        x = self.features(x)
        # NOTE(review): torchvision's own DenseNet.forward applies ReLU to the
        # features before pooling; this model does not. Left unchanged because
        # the trained checkpoint was presumably produced with this exact
        # forward pass — confirm against the training code before "fixing".
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.classifier(x)
def load_model(url, device, model_path="dannynet.pth"):
    """Download a serialized model from ``url`` and return it in eval mode.

    Args:
        url: Location of the checkpoint. It is fetched with
            ``torch.hub.download_url_to_file`` on every call.
        device: Passed to ``torch.load`` as ``map_location`` so the stored
            tensors are restored onto this device.
        model_path: Local file the download is written to. Defaults to
            ``"dannynet.pth"`` in the current working directory.

    Returns:
        The unpickled model, after ``.eval()`` has been called on it.
    """
    import inspect

    torch.hub.download_url_to_file(url, model_path)

    load_kwargs = {"map_location": device}
    # The checkpoint is a whole pickled nn.Module (we call .eval() on it), which
    # torch >= 2.6 refuses to load under its default ``weights_only=True``.
    # Opt out explicitly wherever the parameter exists; older torch versions
    # without it always perform a full unpickle.
    # SECURITY: full unpickling can execute arbitrary code — only load
    # checkpoints from a URL you trust.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    model.eval()
    return model
def predict(model, img_tensor, device, labels=None):
    """Run ``model`` on one image and return a label -> probability mapping.

    Args:
        model: Module producing one logit per label for a batch of inputs.
            It is assumed to already be on ``device`` and in eval mode.
        img_tensor: A single un-batched input; a batch dimension of size 1 is
            prepended before the forward pass.
        device: Device the input is moved to.
        labels: Names of the model's outputs, in output order. Defaults to
            ``DISEASE_LIST``.

    Returns:
        A dict mapping each label to its sigmoid probability as a plain
        Python ``float``, so the result is JSON-serializable.

    Raises:
        ValueError: If the number of model outputs differs from ``len(labels)``.
    """
    if labels is None:
        labels = DISEASE_LIST
    with torch.no_grad():
        logits = model(img_tensor.unsqueeze(0).to(device))
    # Independent sigmoid per output (multi-label), not a softmax over labels.
    probs = torch.sigmoid(logits[0]).cpu().tolist()
    # zip() would silently drop surplus labels or outputs; fail loudly instead.
    if len(probs) != len(labels):
        raise ValueError(
            f"model produced {len(probs)} outputs but {len(labels)} labels were given"
        )
    return dict(zip(labels, probs))
|