import torchvision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor def create_model(num_classes, pretrained=True, coco_model=False): # Load Faster RCNN pre-trained model model = torchvision.models.detection.fasterrcnn_resnet50_fpn( weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT ) if coco_model: # Return the COCO pretrained model for COCO classes. return model # Get the number of input features in_features = model.roi_heads.box_predictor.cls_score.in_features # define a new head for the detector with required number of classes model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) return model if __name__ == '__main__': model = create_model(num_classes=81, pretrained=True, coco_model=True) print(model) # Total parameters and trainable parameters. total_params = sum(p.numel() for p in model.parameters()) print(f"{total_params:,} total parameters.") total_trainable_params = sum( p.numel() for p in model.parameters() if p.requires_grad) print(f"{total_trainable_params:,} training parameters.")