Jfink09 commited on
Commit
3718d60
1 Parent(s): 0983653

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +1 -0
model.py CHANGED
@@ -20,6 +20,7 @@ def create_resnet50_model(num_classes:int=10, # 4
20
  weights = torchvision.models.ResNet50_Weights.DEFAULT
21
  transforms = weights.transforms()
22
  model = torchvision.models.resnet50(weights=weights)
 
23
 
24
  # 4. Freeze all layers in base model
25
  for param in model.parameters():
 
20
  weights = torchvision.models.ResNet50_Weights.DEFAULT
21
  transforms = weights.transforms()
22
  model = torchvision.models.resnet50(weights=weights)
23
+ model.fc = nn.Linear(2048, 10)
24
 
25
  # 4. Freeze all layers in base model
26
  for param in model.parameters():