Vedmani commited on
Commit
ec236ec
1 Parent(s): 6b70dc2

added model

Browse files
Files changed (1) hide show
  1. models.py +22 -0
models.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import timm
3
+
4
+
5
+ class EfficientNet(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.efficientnet = timm.create_model(model_name="efficientnet_b0", pretrained=True, num_classes=25)
9
+
10
+ """# Set requires_grad to False for all parameters except the output layer
11
+ for name, param in self.efficientnet.named_parameters():
12
+ if not name.startswith('classifier'):
13
+ param.requires_grad = False"""
14
+ # print number of parameters including final layer
15
+ trainable_params = sum(p.numel() for p in self.efficientnet.parameters() if p.requires_grad)
16
+ total_params = sum(p.numel() for p in self.efficientnet.parameters())
17
+ """print("Efficientnet_b0 with 25 classes initialized")
18
+ print(f"Trainable parameters: {trainable_params}")
19
+ print(f"Total parameters: {total_params}")"""
20
+
21
+ def forward(self, x):
22
+ return self.efficientnet(x)