zhangjiaheng001 commited on
Commit
bbe469e
1 Parent(s): 2541188

change the model classifiler to suit b6

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -29,8 +29,8 @@ def create_effnetb2_model(num_classes:int=3,
29
  # Change classifier head with random seed for reproducibility
30
  torch.manual_seed(seed)
31
  model.classifier = nn.Sequential(
32
- nn.Dropout(p=0.3, inplace=True),
33
- nn.Linear(in_features=1408, out_features=num_classes),
34
  )
35
 
36
  return model, transforms
 
29
  # Change classifier head with random seed for reproducibility
30
  torch.manual_seed(seed)
31
  model.classifier = nn.Sequential(
32
+ nn.Dropout(p=0.5, inplace=True),
33
+ nn.Linear(in_features=2304, out_features=num_classes),
34
  )
35
 
36
  return model, transforms