arghyaiitb commited on
Commit
63106a4
·
1 Parent(s): b9a7765

fixed the weight issue

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -146,11 +146,15 @@ model = ResNet50(num_classes=1000)
146
  try:
147
  checkpoint = torch.load("best_model.pt", map_location=device)
148
  if 'model_state_dict' in checkpoint:
149
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
150
  print(f"Model loaded successfully! Top-1 accuracy: {checkpoint.get('top1_accuracy', 'N/A'):.2f}%")
151
  print(f"Top-5 accuracy: {checkpoint.get('top5_accuracy', 'N/A'):.2f}%")
152
  else:
153
- model.load_state_dict(checkpoint)
 
154
  print("Model loaded successfully!")
155
  except Exception as e:
156
  print(f"Warning: Could not load model weights: {e}")
 
146
  try:
147
  checkpoint = torch.load("best_model.pt", map_location=device)
148
  if 'model_state_dict' in checkpoint:
149
+ state_dict = checkpoint['model_state_dict']
150
+ # Remove _orig_mod. prefix if present (from torch.compile)
151
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
152
+ model.load_state_dict(state_dict)
153
  print(f"Model loaded successfully! Top-1 accuracy: {checkpoint.get('top1_accuracy', 'N/A'):.2f}%")
154
  print(f"Top-5 accuracy: {checkpoint.get('top5_accuracy', 'N/A'):.2f}%")
155
  else:
156
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in checkpoint.items()}
157
+ model.load_state_dict(state_dict)
158
  print("Model loaded successfully!")
159
  except Exception as e:
160
  print(f"Warning: Could not load model weights: {e}")