mgbam commited on
Commit
37e2134
·
verified ·
1 Parent(s): 08b9085

Update app33443.py

Browse files
Files changed (1) hide show
  1. app33443.py +21 -5
app33443.py CHANGED
@@ -32,10 +32,20 @@ model = models.efficientnet_b0(weights=None)
32
  model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4) # 4 classes
33
 
34
  try:
35
- model.load_state_dict(torch.load('checkpoints/best_multiclass.pt', map_location=device))
36
- print("✅ Multi-class model loaded successfully!")
 
37
  except Exception as e:
38
- print(f"⚠️ Error loading model: {e}")
 
 
 
 
 
 
 
 
 
39
 
40
  model = model.to(device)
41
  model.eval()
@@ -136,13 +146,19 @@ def predict_chest_xray(image, show_gradcam=True):
136
 
137
  # Get probabilities
138
  probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
 
 
 
 
 
 
139
  pred_class = int(output.argmax(dim=1).item())
140
  pred_label = CLASSES[pred_class]
141
  confidence = float(probs[pred_class]) * 100
142
 
143
- # Create results
144
  results = {
145
- CLASSES[i]: float(probs[i] * 100) for i in range(len(CLASSES))
146
  }
147
 
148
  # Generate visualizations
 
32
  model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4) # 4 classes
33
 
34
  try:
35
+ # Try loading best.pt from root directory (HuggingFace Spaces location)
36
+ model.load_state_dict(torch.load('best.pt', map_location=device))
37
+ print("✅ Multi-class model loaded successfully from best.pt!")
38
  except Exception as e:
39
+ print(f"⚠️ Error loading model from best.pt: {e}")
40
+ try:
41
+ # Fallback to checkpoints directory
42
+ model.load_state_dict(torch.load('checkpoints/best_multiclass.pt', map_location=device))
43
+ print("✅ Multi-class model loaded successfully from checkpoints/best_multiclass.pt!")
44
+ except Exception as e2:
45
+ print(f"❌ CRITICAL ERROR: Could not load model from any location!")
46
+ print(f" - best.pt error: {e}")
47
+ print(f" - checkpoints/best_multiclass.pt error: {e2}")
48
+ raise RuntimeError("Model file not found! Please ensure best.pt is uploaded to the Space.")
49
 
50
  model = model.to(device)
51
  model.eval()
 
146
 
147
  # Get probabilities
148
  probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
149
+
150
+ # Safety check: ensure probabilities sum to ~1.0
151
+ prob_sum = np.sum(probs)
152
+ if not (0.99 <= prob_sum <= 1.01):
153
+ print(f"⚠️ WARNING: Probability sum is {prob_sum}, not 1.0. Model may not be loaded correctly!")
154
+
155
  pred_class = int(output.argmax(dim=1).item())
156
  pred_label = CLASSES[pred_class]
157
  confidence = float(probs[pred_class]) * 100
158
 
159
+ # Create results - ensure values are between 0-100
160
  results = {
161
+ CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100))) for i in range(len(CLASSES))
162
  }
163
 
164
  # Generate visualizations