Spaces:
Sleeping
Sleeping
Commit
·
63106a4
1
Parent(s):
b9a7765
fixed the weight issue
Browse files
app.py
CHANGED
|
@@ -146,11 +146,15 @@ model = ResNet50(num_classes=1000)
|
|
| 146 |
try:
|
| 147 |
checkpoint = torch.load("best_model.pt", map_location=device)
|
| 148 |
if 'model_state_dict' in checkpoint:
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
| 150 |
print(f"Model loaded successfully! Top-1 accuracy: {checkpoint.get('top1_accuracy', 'N/A'):.2f}%")
|
| 151 |
print(f"Top-5 accuracy: {checkpoint.get('top5_accuracy', 'N/A'):.2f}%")
|
| 152 |
else:
|
| 153 |
-
|
|
|
|
| 154 |
print("Model loaded successfully!")
|
| 155 |
except Exception as e:
|
| 156 |
print(f"Warning: Could not load model weights: {e}")
|
|
|
|
| 146 |
try:
|
| 147 |
checkpoint = torch.load("best_model.pt", map_location=device)
|
| 148 |
if 'model_state_dict' in checkpoint:
|
| 149 |
+
state_dict = checkpoint['model_state_dict']
|
| 150 |
+
# Remove _orig_mod. prefix if present (from torch.compile)
|
| 151 |
+
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
| 152 |
+
model.load_state_dict(state_dict)
|
| 153 |
print(f"Model loaded successfully! Top-1 accuracy: {checkpoint.get('top1_accuracy', 'N/A'):.2f}%")
|
| 154 |
print(f"Top-5 accuracy: {checkpoint.get('top5_accuracy', 'N/A'):.2f}%")
|
| 155 |
else:
|
| 156 |
+
state_dict = {k.replace('_orig_mod.', ''): v for k, v in checkpoint.items()}
|
| 157 |
+
model.load_state_dict(state_dict)
|
| 158 |
print("Model loaded successfully!")
|
| 159 |
except Exception as e:
|
| 160 |
print(f"Warning: Could not load model weights: {e}")
|