license: cc-by-4.0

PyTorch Faster R-CNN with a ResNet-50 FPN backbone, fine-tuned on grayscale COCO.

The COCO images were converted to grayscale with PIL. All hyperparameters and the number of epochs were kept identical to the PyTorch reference implementation.
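The card only says that PIL was used for the grayscale conversion. The sketch below assumes PIL's standard `"L"` mode (ITU-R 601-2 luma) and shows how an RGB image becomes a single-channel float tensor of shape `(1, H, W)`, which is the input the one-channel `conv1` expects. The helper name `to_grayscale_tensor` is hypothetical.

```python
import numpy as np
import torch
from PIL import Image


def to_grayscale_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float tensor of shape (1, H, W) in [0, 1].

    Assumption: PIL's "L" mode is the grayscale conversion; the card only
    says PIL was used.
    """
    gray = img.convert("L")                           # (H, W) uint8 luma
    arr = np.asarray(gray, dtype=np.float32) / 255.0  # scale to [0, 1]
    return torch.from_numpy(arr).unsqueeze(0)         # add channel dim -> (1, H, W)


# Usage: a synthetic pure-red RGB image, 6 pixels wide and 4 pixels high
rgb = Image.new("RGB", (6, 4), (255, 0, 0))
x = to_grayscale_tensor(rgb)
print(x.shape)  # torch.Size([1, 4, 6])
```

Note that PIL sizes are `(width, height)`, while the resulting tensor is `(channels, height, width)`.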

It can be used as a pretrained model for multispectral imaging, as suggested in this paper.

The weights are provided as a `state_dict`. To initialize the model, run:

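import torch
import torchvision

model_path = 'path/to/checkpoint.pth'  # placeholder: path to the downloaded weights file

# Load pretrained weights (the state_dict is stored under the 'model' key).
# On PyTorch >= 2.6, torch.load defaults to weights_only=True; if loading fails,
# pass weights_only=False, but only for files you trust.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))['model']
# Load torchvision model (its RGB COCO weights are replaced by load_state_dict below)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
# Adapt input convolution to a single (grayscale) input channel
model.backbone.body.conv1 = torch.nn.Conv2d(1, 64,
                            kernel_size=(7, 7), stride=(2, 2),
                            padding=(3, 3), bias=False).requires_grad_(True)
model.load_state_dict(state_dict)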

If it is going to be used for multispectral data, replace the first layer and duplicate its weights across all input channels:

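import math

import torch
import torchvision

in_chans = 4  # placeholder: number of bands in your multispectral data

# model_path points to the downloaded checkpoint, as above
state_dict = torch.load(model_path, map_location=torch.device('cpu'))['model']
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')

# Duplicate the weights: tile the conv1 kernels (1 input channel in this
# checkpoint) until there are in_chans input channels
conv1_weight = state_dict['backbone.body.conv1.weight']
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
src_chans = conv1_weight.shape[1]
repeat = int(math.ceil(in_chans / src_chans))
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
# Optional: rescale so identical bands give the same activation magnitude as one channel
#conv1_weight *= (src_chans / float(in_chans))
conv1_weight = conv1_weight.to(conv1_type)
state_dict['backbone.body.conv1.weight'] = conv1_weight

# Replace the input convolution so it accepts in_chans channels
model.backbone.body.conv1 = torch.nn.Conv2d(in_chans, 64,
                                            kernel_size=(7, 7), stride=(2, 2),
                                            padding=(3, 3), bias=False).requires_grad_(True)
model.load_state_dict(state_dict)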
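The duplication step can be checked in isolation. The sketch below uses a random stand-in tensor rather than the real checkpoint, tiles a one-channel `conv1` kernel over four hypothetical bands with a helper named `expand_conv1_weight` (hypothetical), and compares the responses when every band contains the same grayscale image.

```python
import math

import torch
import torch.nn.functional as F


def expand_conv1_weight(weight: torch.Tensor, in_chans: int, rescale: bool = False) -> torch.Tensor:
    """Tile conv1 weights of shape (out, src, kh, kw) to (out, in_chans, kh, kw).

    rescale=True applies the optional src/in_chans factor.
    """
    src = weight.shape[1]
    repeat = math.ceil(in_chans / src)
    out = weight.repeat(1, repeat, 1, 1)[:, :in_chans]
    if rescale:
        out = out * (src / in_chans)
    return out


torch.manual_seed(0)
w_gray = torch.randn(64, 1, 7, 7)  # stand-in for the grayscale conv1 weights
gray = torch.rand(1, 1, 32, 32)
bands = gray.repeat(1, 4, 1, 1)    # 4 identical bands

y_gray = F.conv2d(gray, w_gray, stride=2, padding=3)
y_plain = F.conv2d(bands, expand_conv1_weight(w_gray, 4), stride=2, padding=3)
y_scaled = F.conv2d(bands, expand_conv1_weight(w_gray, 4, rescale=True), stride=2, padding=3)

print(torch.allclose(y_plain, 4 * y_gray, atol=1e-4))  # True: the band responses add up
print(torch.allclose(y_scaled, y_gray, atol=1e-4))     # True: rescaling undoes that
```

Without rescaling, correlated bands produce roughly `in_chans` times the grayscale activation in the first layer, which is what the optional scaling line compensates for. Whether to enable it depends on how correlated your bands are.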

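For Faster R-CNN, the input transform may also need to be adapted: torchvision's default `GeneralizedRCNNTransform` normalizes with the three-channel ImageNet mean and standard deviation, which do not match a one-channel or multispectral input. Here is an example: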

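import torchvision

# in_chans as above; non-RGB bands get a generic mean/std of 0.5/0.25
coco_mean = [0.5] * in_chans
coco_std = [0.25] * in_chans
if in_chans > 3:
    # keep the ImageNet statistics for the first three (RGB) bands
    coco_mean[:3] = [0.485, 0.456, 0.406]
    coco_std[:3] = [0.229, 0.224, 0.225]
transform = torchvision.models.detection.transform.GeneralizedRCNNTransform(min_size=800,
                                                                            max_size=1333,
                                                                            image_mean=coco_mean,
                                                                            image_std=coco_std)
# Replace the model's default (3-channel) transform
model.transform = transform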