RashiAgarwal commited on
Commit
042c19a
1 Parent(s): a5dff2d

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +23 -0
train.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main file for training Yolo model on Pascal VOC and COCO dataset
3
+ """
4
+
5
+ import torch
6
+ from pytorch_lightning import LightningModule, Trainer, seed_everything
7
+ import torch
8
+
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
+ torch.backends.cudnn.benchmark = True
13
+
14
+ class YOLOTraining(LightningModule):
15
+ def __init__(self,model):
16
+ super().__init__()
17
+ self.model = model
18
+
19
+ def forward(self, x):
20
+ return self.model(x)
21
+
22
+ if __name__ == "__main__":
23
+ num_classes = 20