Juartaurus's picture
Upload folder using huggingface_hub
1865436
raw
history blame
1.41 kB
import torchvision.models.detection as models
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def set_parameter_requires_grad(model, tune_only: bool = False):
    """Optionally freeze the parameters owned by the model's direct children.

    Args:
        model: A module exposing ``children()``; each child must expose
            ``parameters()``.
        tune_only: When True, set ``requires_grad = False`` on every parameter
            reachable through ``model.children()``. When False, nothing is
            changed.

    Parameters registered directly on ``model`` itself (not inside a child)
    are not touched.
    """
    if not tune_only:
        return
    # Flatten child -> parameter so a single loop does the freezing.
    child_params = (p for sub in model.children() for p in sub.parameters())
    for p in child_params:
        p.requires_grad = False
def initialize_model(model_name: str,
                     num_classes: int,
                     tune_only: bool = False,
                     use_pretrained: bool = True):
    """Build a torchvision detection model whose heads predict ``num_classes``.

    Only builders from ``torchvision.models.detection`` whose names start with
    ``"maskrcnn"`` or ``"fasterrcnn"`` are supported.

    Args:
        model_name: Name of a builder in ``torchvision.models.detection``,
            e.g. ``"fasterrcnn_resnet50_fpn"``.
        num_classes: Number of classes for the replacement prediction head(s).
            NOTE(review): torchvision's convention is that this count includes
            the background class; confirm against the dataset's labels.
        tune_only: If True, freeze the pre-existing parameters (via
            ``set_parameter_requires_grad``) before the new heads are attached,
            so only the freshly created heads remain trainable.
        use_pretrained: Forwarded to the builder as ``pretrained=``.

    Returns:
        ``(model, input_size)``. ``input_size`` is always 0. It is kept only so
        the return shape stays the same for existing callers.

    Raises:
        ValueError: If ``model_name`` is not a supported Mask/Faster R-CNN
            builder. This is checked before any model is constructed.
    """
    # Validate before building. A missing name used to hit a `lambda: None`
    # fallback that crashed with TypeError on `pretrained=`. Unsupported but
    # existing builders were instantiated (possibly downloading weights) and
    # only then rejected.
    if not model_name.startswith(("maskrcnn", "fasterrcnn")):
        raise ValueError("{0} is not supported!".format(model_name))
    builder = getattr(models, model_name, None)
    if not callable(builder):
        raise ValueError("{0} is not supported!".format(model_name))

    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    input_size = 0
    model_ft = builder(pretrained=use_pretrained)
    # Freeze first, so the heads created below stay trainable.
    set_parameter_requires_grad(model_ft, tune_only)

    # Both families classify boxes through roi_heads.box_predictor. Resize it to
    # num_classes so box labels agree with the new class set.
    in_features = model_ft.roi_heads.box_predictor.cls_score.in_features
    model_ft.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    if model_name.startswith("maskrcnn"):
        # The mask head lives under roi_heads. Assigning model_ft.mask_predictor
        # only attached an unused module and left the original head in place.
        mask_predictor_in_channels = (
            model_ft.roi_heads.mask_predictor.conv5_mask.in_channels
        )
        mask_dim_reduced = 256
        model_ft.roi_heads.mask_predictor = MaskRCNNPredictor(
            mask_predictor_in_channels, mask_dim_reduced, num_classes
        )

    return model_ft, input_size