Spaces:
Sleeping
Sleeping
File size: 1,183 Bytes
d70f24c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def create_model(num_classes, pretrained=True, coco_model=False):
    """Build a torchvision Faster R-CNN ResNet-50 FPN v2 detector.

    Args:
        num_classes: Number of classes the replacement box predictor head
            outputs. Ignored when ``coco_model`` is True, because the
            original head is kept.
        pretrained: If True, load
            ``FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT`` for the whole
            detector. If False, build the architecture with ``weights=None``
            so no detector checkpoint is loaded.
        coco_model: If True, return the detector with its original box
            predictor untouched instead of attaching a new head. Combine
            with ``pretrained=True`` to get the stock COCO model.

    Returns:
        The constructed ``torchvision`` Faster R-CNN model.
    """
    # Honour ``pretrained``. Previously DEFAULT weights were loaded
    # unconditionally, so pretrained=False had no effect.
    weights = (
        torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        if pretrained
        else None
    )
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights)
    if coco_model:  # Return the COCO model (original head) for COCO classes.
        return model
    # The new head must consume the same feature size as the existing one.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Swap in a fresh classification/box-regression head sized for num_classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
if __name__ == '__main__':
    # Smoke-run: build the stock COCO detector, dump its structure,
    # then report how many parameters it holds in total and how many
    # are trainable (requires_grad=True).
    detector = create_model(num_classes=81, pretrained=True, coco_model=True)
    print(detector)

    all_params = list(detector.parameters())
    n_total = sum(tensor.numel() for tensor in all_params)
    n_trainable = sum(
        tensor.numel() for tensor in all_params if tensor.requires_grad
    )

    print(f"{n_total:,} total parameters.")
    print(f"{n_trainable:,} training parameters.")