r"""
A utility module to easily load common VirTex models (optionally with pretrained
weights) using a single line of code.

Get our full best-performing VirTex model (with pretrained weights) as:

>>> import virtex.model_zoo as mz
>>> model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)

Any config available in the ``configs/`` directory under the project root can
be specified here, although this command need not be executed from the project
root. For more details on available models, refer to :doc:`usage/model_zoo`.

Part of this code is adapted from Detectron2's model zoo, which was originally
implemented by the developers of this codebase, with reviews and further
changes by the Detectron2 developers.
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import pkg_resources

from fvcore.common.download import download
import torch

from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager


class _ModelZooUrls(object):
    r"""Mapping from config names to URL suffixes of pretrained weights."""

    URL_PREFIX = "https://umich.box.com/shared/static"

    CONFIG_PATH_TO_URL_SUFFIX = {

        # Pretraining Task Ablations
        "task_ablations/bicaptioning_R_50_L1_H2048.yaml": "zu8zxtxrron29icd76owgjzojmfcgdk3.pth",
        "task_ablations/captioning_R_50_L1_H2048.yaml": "1q9qh1cj2u4r5laj7mefd2mlzwthnga7.pth",
        "task_ablations/token_classification_R_50.yaml": "idvoxjl60pzpcllkbvadqgvwazil2mis.pth",
        "task_ablations/multilabel_classification_R_50.yaml": "yvlflmo0klqy3m71p6ug06c6aeg282hy.pth",
        "task_ablations/masked_lm_R_50_L1_H2048.yaml": "x3eij00eslse9j35t9j9ijyj8zkbkizh.pth",

        # Width Ablations
        "width_ablations/bicaptioning_R_50_L1_H512.yaml": "wtk18v0vffws48u5yrj2qjt94wje1pit.pth",
        "width_ablations/bicaptioning_R_50_L1_H768.yaml": "e94n0iexdvksi252bn7sm2vqjnyt9okf.pth",
        "width_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth",
        "width_ablations/bicaptioning_R_50_L1_H2048.yaml": "zu8zxtxrron29icd76owgjzojmfcgdk3.pth",

        # Depth Ablations
        "depth_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth",
        "depth_ablations/bicaptioning_R_50_L2_H1024.yaml": "9e88f6l13a9r8wq5bbe8qnoh9zenanq3.pth",
        "depth_ablations/bicaptioning_R_50_L3_H1024.yaml": "4cv8052xiq91h7lyx52cp2a6m7m9qkgo.pth",
        "depth_ablations/bicaptioning_R_50_L4_H1024.yaml": "bk5w4471mgvwa5mv6e4c7htgsafzmfm0.pth",

        # Backbone Ablations
        "backbone_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth",
        "backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml": "19vcaf1488945836kir9ebm5itgtugaw.pth",
        "backbone_ablations/bicaptioning_R_101_L1_H1024.yaml": "nptbh4jsj0c0kjsnc2hw754fkikpgx9v.pth",
    }
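

# Hypothetical helper (a sketch, not part of the original module): enumerate
# the zoo configs that ship pretrained weights, i.e. keys of the mapping above.
def available_models():
    r"""Return a sorted list of config names that have pretrained weights."""
    return sorted(_ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX)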


def get(config_path: str, pretrained: bool = False):
    r"""
    Get a model specified by its relative config path under the project's
    ``configs/`` directory.

    Parameters
    ----------
    config_path: str
        Name of config file relative to ``configs/`` directory under project
        root. (For example, ``width_ablations/bicaptioning_R_50_L1_H2048.yaml``)
    pretrained: bool, optional (default = False)
        If ``True``, the model is initialized with pretrained weights; if
        ``False``, the weights are initialized randomly.
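
    Examples
    --------
    A minimal usage sketch (mirrors the module docstring; with
    ``pretrained=True``, the checkpoint is downloaded to
    ``~/.torch/virtex_cache`` on first use, which requires network access):

    >>> import virtex.model_zoo as mz
    >>> model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)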
    """

    # Get the original path to the config file (shipped inside the package).
    _pkg_config_path = pkg_resources.resource_filename(
        "virtex.model_zoo", os.path.join("configs", config_path)
    )
    if not os.path.exists(_pkg_config_path):
        raise RuntimeError(f"{config_path} not available in Model Zoo!")

    _C = Config(_pkg_config_path)
    model = PretrainingModelFactory.from_config(_C)

    if pretrained:
        # Get the checkpoint URL for this config path; fail loudly if no
        # pretrained weights exist for it.
        if config_path not in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX:
            raise RuntimeError(
                f"Pretrained weights for {config_path} not available in Model Zoo!"
            )
        url_suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX[config_path]
        checkpoint_url = f"{_ModelZooUrls.URL_PREFIX}/{url_suffix}"

        # Download the pretrained weights and save them with a sensible name;
        # the download is skipped if the file already exists in the cache.
        checkpoint_path = download(
            checkpoint_url,
            dir=os.path.expanduser("~/.torch/virtex_cache"),
            filename=os.path.basename(config_path).replace(".yaml", ".pth")
        )
        CheckpointManager(model=model).load(checkpoint_path)

    return model
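

if __name__ == "__main__":
    # A minimal smoke test (a sketch, not part of the original module). It
    # assumes the factory returns a standard ``torch.nn.Module``, so
    # ``parameters()`` is available; no download happens with pretrained=False.
    model = get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=False)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Built model with {num_params / 1e6:.1f}M parameters.")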