codedad commited on
Commit
285d957
·
verified ·
1 Parent(s): a7b6188

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -6,19 +6,22 @@ from PIL import Image
6
  # 1. Load your model (Ensure this matches your training architecture)
7
  # Change 'models.resnet18' if you used a different one
8
  # --- UPDATED MODEL ARCHITECTURE ---
9
- from torchvision import models
10
  import torch.nn as nn
 
11
 
12
- # 1. Use resnet50 instead of resnet18 to match your weights
13
- model = models.resnet50()
14
 
15
- # 2. Update the final layer to match your 2 classes (Defect, Normal)
16
- # Note: ResNet50 uses 2048 features in the final layer
17
- model.fc = nn.Linear(2048, 2)
 
 
 
 
 
18
 
19
- # 3. Load the weights
20
- model.load_state_dict(torch.load("fine_tuned_model.pt", map_location="cpu"))
21
- model.eval()
22
  model.load_state_dict(torch.load("fine_tuned_model.pt", map_location="cpu"))
23
  model.eval()
24
 
 
6
  # 1. Load your model (Ensure this matches your training architecture)
7
  # Change 'models.resnet18' if you used a different one
8
  # --- UPDATED MODEL ARCHITECTURE ---
 
9
  import torch.nn as nn
10
+ from torchvision import models
11
 
12
+ # 1. Initialize ResNet-50 (matches the 2048 feature size in your error)
13
+ model = models.resnet50()
14
 
15
+ # 2. Recreate the EXACT Sequential head used during your training
16
+ # This fixes the "Missing key: fc.0.weight" and "fc.3.weight" errors
17
+ model.fc = nn.Sequential(
18
+ nn.Linear(2048, 256), # fc.0
19
+ nn.ReLU(), # fc.1
20
+ nn.Dropout(0.4), # fc.2
21
+ nn.Linear(256, 2) # fc.3
22
+ )
23
 
24
+ # 3. Load your weights
 
 
25
  model.load_state_dict(torch.load("fine_tuned_model.pt", map_location="cpu"))
26
  model.eval()
27