File size: 1,406 Bytes
1865436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torchvision.models.detection as models
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def set_parameter_requires_grad(model, tune_only: bool = False):
    """Optionally freeze every parameter owned by the direct children of ``model``.

    Args:
        model: Expected to be a ``torch.nn.Module`` (anything exposing
            ``.children()`` whose items expose ``.parameters()``).
        tune_only: When true, set ``requires_grad = False`` on every parameter
            reachable through ``model.children()``. When false, nothing happens.

    Parameters registered directly on ``model`` itself, rather than inside one
    of its child modules, are intentionally left as they are.
    """
    if not tune_only:
        return
    for submodule in model.children():
        for weight in submodule.parameters():
            weight.requires_grad = False


def initialize_model(model_name: str,
                     num_classes: int,
                     tune_only: bool = False,
                     use_pretrained: bool = True):
    """Build a torchvision detection model whose heads predict ``num_classes``.

    Args:
        model_name: Name of a builder in ``torchvision.models.detection``;
            it must start with ``"maskrcnn"`` or ``"fasterrcnn"``.
        num_classes: Number of output classes for the replaced heads
            (torchvision detection heads count background as a class).
        tune_only: If true, freeze the existing parameters (via
            ``set_parameter_requires_grad``) before the heads are replaced,
            so only the newly created heads remain trainable.
        use_pretrained: Forwarded to the builder as ``pretrained=``.

    Returns:
        A ``(model, input_size)`` tuple. ``input_size`` is always 0; it is
        kept only so callers that unpack two values keep working.

    Raises:
        ValueError: If ``model_name`` is not a Mask R-CNN / Faster R-CNN
            builder available in ``torchvision.models.detection``.
    """
    input_size = 0

    is_maskrcnn = model_name.startswith("maskrcnn")
    is_fasterrcnn = model_name.startswith("fasterrcnn")
    builder = getattr(models, model_name, None)
    # Validate before building: unknown names must raise ValueError (not a
    # TypeError from calling a placeholder), and unsupported-but-existing
    # builders should not be instantiated (and possibly download weights)
    # only to be rejected afterwards.
    if not (is_maskrcnn or is_fasterrcnn) or not callable(builder):
        raise ValueError("{0} is not supported!".format(model_name))

    model_ft = builder(pretrained=use_pretrained)
    # Freeze first so that the heads created below stay trainable.
    set_parameter_requires_grad(model_ft, tune_only)

    # Both Faster R-CNN and Mask R-CNN classify/regress boxes through
    # roi_heads.box_predictor, so it must be resized in either case.
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    in_features = model_ft.roi_heads.box_predictor.cls_score.in_features
    model_ft.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    if is_maskrcnn:
        # The mask head lives on roi_heads; assigning to model_ft.mask_predictor
        # would only create an unused attribute. Reuse the existing head's
        # channel sizes instead of hard-coding them.
        old_mask_predictor = model_ft.roi_heads.mask_predictor
        mask_in_channels = old_mask_predictor.conv5_mask.in_channels
        mask_dim_reduced = old_mask_predictor.conv5_mask.out_channels
        model_ft.roi_heads.mask_predictor = MaskRCNNPredictor(
            mask_in_channels, mask_dim_reduced, num_classes
        )

    return model_ft, input_size