mshukor
init
3eb682b
raw
history blame
No virus
1.7 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Model construction functions."""
import torch
from fvcore.common.registry import Registry
# Name -> constructor registry for video models. `build_model` looks an entry
# up by `cfg.MODEL.MODEL_NAME` and calls it with the config object.
MODEL_REGISTRY = Registry("MODEL")
MODEL_REGISTRY.__doc__ = (
    "\n"
    "Registry for video model.\n"
    "The registered object will be called with `obj(cfg)`.\n"
    "The call should return a `torch.nn.Module` object.\n"
)
def build_model(cfg, gpu_id=None):
    """
    Builds the video model.

    The constructor registered in `MODEL_REGISTRY` under
    `cfg.MODEL.MODEL_NAME` is called with `cfg`. When `cfg.NUM_GPUS` is
    non-zero the resulting model is moved to a CUDA device, and when it is
    greater than one the model is additionally wrapped in
    `torch.nn.parallel.DistributedDataParallel` bound to that device.

    Args:
        cfg (configs): configs that contains the hyper-parameters to build the
            backbone. Details can be seen in slowfast/config/defaults.py.
        gpu_id (Optional[int]): specify the gpu index to build model. When
            None, the device returned by `torch.cuda.current_device()` is
            used.

    Returns:
        The object produced by the registered constructor, moved to the
        selected GPU and/or wrapped in DistributedDataParallel as described
        above.

    Raises:
        AssertionError: if `cfg.NUM_GPUS` is larger than
            `torch.cuda.device_count()`, or is not 0 while CUDA is
            unavailable.
    """
    # Validate with explicit raises instead of `assert` statements so the
    # checks are not stripped under `python -O`. AssertionError and the
    # messages are kept so existing callers observe the same failure.
    if torch.cuda.is_available():
        if not (cfg.NUM_GPUS <= torch.cuda.device_count()):
            raise AssertionError("Cannot use more GPU devices than available")
    elif not (cfg.NUM_GPUS == 0):
        raise AssertionError(
            "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs."
        )

    # Construct the model
    name = cfg.MODEL.MODEL_NAME
    model = MODEL_REGISTRY.get(name)(cfg)

    if cfg.NUM_GPUS:
        if gpu_id is None:
            # Determine the GPU used by the current process
            cur_device = torch.cuda.current_device()
        else:
            cur_device = gpu_id
        # Transfer the model to the current GPU device
        model = model.cuda(device=cur_device)

    # Use multi-process data parallel model in the multi-gpu setting.
    # `cur_device` is always bound here: NUM_GPUS > 1 implies the branch
    # above has run.
    if cfg.NUM_GPUS > 1:
        # Make model replica operate on the current device
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )
    return model