KaranNag commited on
Commit
6d062a0
1 Parent(s): ff3068c
Files changed (1) hide show
  1. model.py +34 -0
model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torchvision
4
+
5
+ from torch import nn
6
+
7
+
8
+ def create_vit_model(num_classes:int=3,
9
+ seed:int=42):
10
+ """Creates a ViT-B/16 feature extractor model and transforms.
11
+
12
+ Args:
13
+ num_classes (int, optional): number of target classes. Defaults to 3.
14
+ seed (int, optional): random seed value for output layer. Defaults to 42.
15
+
16
+ Returns:
17
+ model (torch.nn.Module): ViT-B/16 feature extractor model.
18
+ transforms (torchvision.transforms): ViT-B/16 image transforms.
19
+ """
20
+ # Create ViT_B_16 pretrained weights, transforms and model
21
+ weights = torchvision.models.ViT_B_16_Weights.DEFAULT
22
+ transforms = weights.transforms()
23
+ model = torchvision.models.vit_b_16(weights=weights)
24
+
25
+ # Freeze all layers in model
26
+ for param in model.parameters():
27
+ param.requires_grad = False
28
+
29
+ # Change classifier head to suit our needs (this will be trainable)
30
+ torch.manual_seed(seed)
31
+ model.heads = nn.Sequential(nn.Linear(in_features=768, # keep this the same as original model
32
+ out_features=num_classes)) # update to reflect target number of classes
33
+
34
+ return model, transforms