Spaces:
Runtime error
Runtime error
r"""
A utility module to easily load common VirTex models (optionally with pretrained
weights) using a single line of code.

Get our best-performing full VirTex model (with pretrained weights) as:

.. code-block:: python

    import virtex.model_zoo as mz
    model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)

Any config available in the ``configs/`` directory under the project root can
be specified here, although this command need not be executed from the
project root. For more details on available models, refer to
:doc:`usage/model_zoo`.

Part of this code is adapted from Detectron2's model zoo, which was originally
implemented by the developers of this codebase, with reviews and further
changes by Detectron2 developers.
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
import os | |
import pkg_resources | |
from fvcore.common.download import download | |
import torch | |
from virtex.config import Config | |
from virtex.factories import PretrainingModelFactory | |
from virtex.utils.checkpointing import CheckpointManager | |
class _ModelZooUrls(object): | |
r"""Mapping from config names to URL suffixes of pretrained weights.""" | |
URL_PREFIX = "https://umich.box.com/shared/static" | |
CONFIG_PATH_TO_URL_SUFFIX = { | |
# Pretraining Task Ablations | |
"task_ablations/bicaptioning_R_50_L1_H2048.yaml": "zu8zxtxrron29icd76owgjzojmfcgdk3.pth", | |
"task_ablations/captioning_R_50_L1_H2048.yaml": "1q9qh1cj2u4r5laj7mefd2mlzwthnga7.pth", | |
"task_ablations/token_classification_R_50.yaml": "idvoxjl60pzpcllkbvadqgvwazil2mis.pth", | |
"task_ablations/multilabel_classification_R_50.yaml": "yvlflmo0klqy3m71p6ug06c6aeg282hy.pth", | |
"task_ablations/masked_lm_R_50_L1_H2048.yaml": "x3eij00eslse9j35t9j9ijyj8zkbkizh.pth", | |
# Width Ablations | |
"width_ablations/bicaptioning_R_50_L1_H512.yaml": "wtk18v0vffws48u5yrj2qjt94wje1pit.pth", | |
"width_ablations/bicaptioning_R_50_L1_H768.yaml": "e94n0iexdvksi252bn7sm2vqjnyt9okf.pth", | |
"width_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth", | |
"width_ablations/bicaptioning_R_50_L1_H2048.yaml": "zu8zxtxrron29icd76owgjzojmfcgdk3.pth", | |
# Depth Ablations | |
"depth_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth", | |
"depth_ablations/bicaptioning_R_50_L2_H1024.yaml": "9e88f6l13a9r8wq5bbe8qnoh9zenanq3.pth", | |
"depth_ablations/bicaptioning_R_50_L3_H1024.yaml": "4cv8052xiq91h7lyx52cp2a6m7m9qkgo.pth", | |
"depth_ablations/bicaptioning_R_50_L4_H1024.yaml": "bk5w4471mgvwa5mv6e4c7htgsafzmfm0.pth", | |
# Backbone Ablations | |
"backbone_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth", | |
"backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml": "19vcaf1488945836kir9ebm5itgtugaw.pth", | |
"backbone_ablations/bicaptioning_R_101_L1_H1024.yaml": "nptbh4jsj0c0kjsnc2hw754fkikpgx9v.pth", | |
} | |
def get(config_path: str, pretrained: bool = False):
    r"""
    Get a VirTex model specified by a config path relative to the ``configs/``
    directory shipped inside the ``virtex.model_zoo`` package.

    Parameters
    ----------
    config_path: str
        Path of the config file relative to the ``configs/`` directory of the
        ``virtex.model_zoo`` package. (For example,
        ``width_ablations/bicaptioning_R_50_L1_H2048.yaml``)
    pretrained: bool, optional (default = False)
        If ``True``, will initialize the model with the pretrained weights. If
        ``False``, the weights will be initialized randomly.

    Returns
    -------
    The model built by ``PretrainingModelFactory.from_config`` for this
    config, with the pretrained checkpoint loaded when ``pretrained=True``.

    Raises
    ------
    RuntimeError
        If ``config_path`` is not shipped with the package, or if
        ``pretrained=True`` and no pretrained checkpoint is registered for
        ``config_path`` in ``_ModelZooUrls``.
    """
    # Resolve the config file shipped inside the installed package, so the
    # lookup does not depend on the current working directory.
    _pkg_config_path = pkg_resources.resource_filename(
        "virtex.model_zoo", os.path.join("configs", config_path)
    )
    if not os.path.exists(_pkg_config_path):
        raise RuntimeError("{} not available in Model Zoo!".format(config_path))

    # Resolve the checkpoint URL *before* building the model, so a request for
    # weights that do not exist fails fast instead of after model construction.
    checkpoint_url = None
    if pretrained:
        url_suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX.get(config_path)
        if url_suffix is None:
            raise RuntimeError(
                "{} has no pretrained weights in Model Zoo!".format(config_path)
            )
        checkpoint_url = f"{_ModelZooUrls.URL_PREFIX}/{url_suffix}"

    _C = Config(_pkg_config_path)
    model = PretrainingModelFactory.from_config(_C)

    if checkpoint_url is not None:
        # Download the pretrained weights into a local cache with a readable
        # name; per the original design this is skipped if the file exists.
        # NOTE(review): the cache file is named after the config *basename*, so
        # configs from different ablation folders with the same basename (e.g.
        # ``bicaptioning_R_50_L1_H1024.yaml``) share one cache file. That is
        # only safe while every such group maps to the same checkpoint suffix
        # in ``_ModelZooUrls`` -- keep that invariant when adding entries.
        checkpoint_path = download(
            checkpoint_url,
            dir=os.path.expanduser("~/.torch/virtex_cache"),
            filename=os.path.basename(config_path).replace(".yaml", ".pth"),
        )
        CheckpointManager(model=model).load(checkpoint_path)
    return model