File size: 1,231 Bytes
a5f8a35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Packages that torch.hub checks are importable before running an entrypoint.
# torchvision is listed because this module imports it and builds its model
# on top of torchvision's ResNet-50.
dependencies = ["torch", "torchvision"]

import torch
import torchvision


def resnet50(pretrained: bool = False, **kwargs):
    r"""
    ResNet-50 visual backbone from the best performing VirTex model: pretrained
    for bicaptioning on COCO Captions, with textual head ``L = 1, H = 2048``.

    This is a torchvision-like model, with the last ``avgpool`` and ``fc``
    modules replaced with ``nn.Identity()`` modules. torchvision's
    ``ResNet.forward`` still flattens its activations between those two
    modules, so a batch of image tensors with size ``(B, 3, 224, 224)``
    yields features of size ``(B, 2048 * 7 * 7)``, where B = batch size.
    Use ``features.view(B, 2048, 7, 7)`` to recover the spatial feature map.

    Args:
        pretrained (bool): Whether to load model with pretrained weights.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``torchvision.models.resnet50``.

    Returns:
        The modified torchvision ResNet-50 module.
    """

    # Create a torchvision resnet50 with randomly initialized weights; the
    # VirTex weights (if requested) are loaded below instead.
    model = torchvision.models.resnet50(pretrained=False, **kwargs)

    # Replace global average pooling and fully connected layers with identity
    # modules so the model outputs (flattened) convolutional features.
    model.avgpool = torch.nn.Identity()
    model.fc = torch.nn.Identity()

    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            "https://umich.box.com/shared/static/gsjqm4i4fm1wpzi947h27wweljd8gcpy.pth",
            progress=False,
            # Deserialize onto CPU so a checkpoint saved from GPU tensors can
            # also be loaded on CPU-only machines; ``load_state_dict`` then
            # copies the values into the model's existing parameters.
            map_location="cpu",
        )
        # The downloaded checkpoint keeps the backbone weights under "model".
        model.load_state_dict(checkpoint["model"])
    return model