Spaces:
Sleeping
Sleeping
Update app33443.py
Browse files- app33443.py +21 -5
app33443.py
CHANGED
|
@@ -32,10 +32,20 @@ model = models.efficientnet_b0(weights=None)
|
|
| 32 |
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4) # 4 classes
|
| 33 |
|
| 34 |
try:
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
except Exception as e:
|
| 38 |
-
print(f"⚠️ Error loading model: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
model = model.to(device)
|
| 41 |
model.eval()
|
|
@@ -136,13 +146,19 @@ def predict_chest_xray(image, show_gradcam=True):
|
|
| 136 |
|
| 137 |
# Get probabilities
|
| 138 |
probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
pred_class = int(output.argmax(dim=1).item())
|
| 140 |
pred_label = CLASSES[pred_class]
|
| 141 |
confidence = float(probs[pred_class]) * 100
|
| 142 |
|
| 143 |
-
# Create results
|
| 144 |
results = {
|
| 145 |
-
CLASSES[i]: float(probs[i] * 100) for i in range(len(CLASSES))
|
| 146 |
}
|
| 147 |
|
| 148 |
# Generate visualizations
|
|
|
|
| 32 |
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4) # 4 classes
|
| 33 |
|
| 34 |
try:
|
| 35 |
+
# Try loading best.pt from root directory (HuggingFace Spaces location)
|
| 36 |
+
model.load_state_dict(torch.load('best.pt', map_location=device))
|
| 37 |
+
print("✅ Multi-class model loaded successfully from best.pt!")
|
| 38 |
except Exception as e:
|
| 39 |
+
print(f"⚠️ Error loading model from best.pt: {e}")
|
| 40 |
+
try:
|
| 41 |
+
# Fallback to checkpoints directory
|
| 42 |
+
model.load_state_dict(torch.load('checkpoints/best_multiclass.pt', map_location=device))
|
| 43 |
+
print("✅ Multi-class model loaded successfully from checkpoints/best_multiclass.pt!")
|
| 44 |
+
except Exception as e2:
|
| 45 |
+
print(f"❌ CRITICAL ERROR: Could not load model from any location!")
|
| 46 |
+
print(f" - best.pt error: {e}")
|
| 47 |
+
print(f" - checkpoints/best_multiclass.pt error: {e2}")
|
| 48 |
+
raise RuntimeError("Model file not found! Please ensure best.pt is uploaded to the Space.")
|
| 49 |
|
| 50 |
model = model.to(device)
|
| 51 |
model.eval()
|
|
|
|
| 146 |
|
| 147 |
# Get probabilities
|
| 148 |
probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
|
| 149 |
+
|
| 150 |
+
# Safety check: ensure probabilities sum to ~1.0
|
| 151 |
+
prob_sum = np.sum(probs)
|
| 152 |
+
if not (0.99 <= prob_sum <= 1.01):
|
| 153 |
+
print(f"⚠️ WARNING: Probability sum is {prob_sum}, not 1.0. Model may not be loaded correctly!")
|
| 154 |
+
|
| 155 |
pred_class = int(output.argmax(dim=1).item())
|
| 156 |
pred_label = CLASSES[pred_class]
|
| 157 |
confidence = float(probs[pred_class]) * 100
|
| 158 |
|
| 159 |
+
# Create results - ensure values are between 0-100
|
| 160 |
results = {
|
| 161 |
+
CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100))) for i in range(len(CLASSES))
|
| 162 |
}
|
| 163 |
|
| 164 |
# Generate visualizations
|