carlosabadia commited on
Commit
ee240b7
·
1 Parent(s): 51b7d90

fixed state_dict

Browse files
__pycache__/model.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
app.py CHANGED
@@ -18,7 +18,15 @@ vit16, vit16_transforms = create_vit16_model(
18
  num_classes=101, # could also use len(class_names)
19
  )
20
 
21
- # Load saved weights
 
 
 
 
 
 
 
 
22
  vit16.load_state_dict(
23
  torch.load(
24
  f="model_food101_20_percent.pth",
@@ -26,6 +34,8 @@ vit16.load_state_dict(
26
  )
27
  )
28
 
 
 
29
  ### 3. Predict function ###
30
 
31
  # Create predict function
 
18
  num_classes=101, # could also use len(class_names)
19
  )
20
 
21
+
22
+ state_dict = torch.load("model_food101_20_percent.pth")
23
+ state_dict["heads.0.weight"] = state_dict.pop("heads.weight")
24
+ state_dict["heads.0.bias"] = state_dict.pop("heads.bias")
25
+ # save new state_dict in .pth
26
+ torch.save(state_dict, "model_food101_20_percent.pth")
27
+
28
+
29
+
30
  vit16.load_state_dict(
31
  torch.load(
32
  f="model_food101_20_percent.pth",
 
34
  )
35
  )
36
 
37
+
38
+
39
  ### 3. Predict function ###
40
 
41
  # Create predict function
model.py CHANGED
@@ -28,8 +28,7 @@ def create_vit16_model(num_classes:int=101,
28
 
29
  # Change classifier head with random seed for reproducibility
30
  torch.manual_seed(seed)
31
- model.classifier = nn.Sequential(
32
- nn.Linear(in_features=768, out_features=num_classes).to("cpu"),
33
- )
34
 
35
  return model, transforms
 
28
 
29
  # Change classifier head with random seed for reproducibility
30
  torch.manual_seed(seed)
31
+ model.heads = nn.Sequential(nn.Linear(in_features=768, # keep this the same as original model
32
+ out_features=num_classes)) # update to reflect target number of classes
 
33
 
34
  return model, transforms
model_food101_20_percent.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:26afc8e7a4a879d88f402030c346e073fdc2671ed6928c87305a1e2776f72c5b
3
  size 343564561
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4357a334ed5737baacaf7a99b0ba491ef88c61580790277acec9ef877cd77c9
3
  size 343564561