torinriley commited on
Commit
3f726f1
1 Parent(s): dd95c76

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -0
model.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3
+
4
+ def get_model(num_classes):
5
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
6
+
7
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
8
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
9
+
10
+ return model