Stanlito commited on
Commit
83a9d40
1 Parent(s): 17838e7

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +9 -11
model.py CHANGED
@@ -3,22 +3,20 @@ import torchvision
3
 
4
  from torch import nn
5
 
6
- def create_transformer_model(num_classes:int=325, # default output classes = 3 (pizza, steak, sushi)
7
- seed:int=42):
8
- # 1, 2, 3 Create EffNetB2 pretrained weights, transforms and model
9
  weights = torchvision.models.ViT_B_16_Weights.DEFAULT
10
  transforms = weights.transforms()
11
  model = torchvision.models.vit_b_16(weights=weights)
12
 
13
- # 4. Freeze all layers in the base model
14
  for param in model.parameters():
15
  param.requires_grad = False
16
 
17
- # 5. Change classifier head with random seed for reproducibility
18
  torch.manual_seed(seed)
19
- model.classifier = nn.Sequential(
20
- nn.Dropout(p=0.3, inplace=True),
21
- nn.Linear(in_features=768, out_features=num_classes)
22
- )
23
-
24
- return model, transforms
 
3
 
4
  from torch import nn
5
 
6
+ def create_vit_model(num_classes:int=325,
7
+ seed:int=42):
8
+ # Create ViT_B_16 pretrained weights, transforms and model
9
  weights = torchvision.models.ViT_B_16_Weights.DEFAULT
10
  transforms = weights.transforms()
11
  model = torchvision.models.vit_b_16(weights=weights)
12
 
13
+ # Freeze all of the base layers
14
  for param in model.parameters():
15
  param.requires_grad = False
16
 
17
+ # Change classifier head to suit our needs
18
  torch.manual_seed(seed)
19
+ model.heads = nn.Sequential(nn.Linear(in_features=768,
20
+ out_features=325))
21
+
22
+ return model, transforms