# EZ-Crossword / Model_utils_inf.py
# Ujjwal123's picture
# copied the whole api code from django and updated the dockerfile
# a04b340
import collections
import glob
import logging
import os
from typing import List, Tuple

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.serialization import default_restore_location
# Shared logger for this module; getLogger() without a name is the root logger.
logger = logging.getLogger()

# Immutable record for a saved training checkpoint. Field order matters for
# positional construction; load_states_from_checkpoint builds it by keyword
# from the dict stored in the checkpoint file.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    "model_dict optimizer_dict scheduler_dict offset epoch encoder_params",
)
def setup_for_distributed_mode(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: object,
    n_gpu: int = 1,
    local_rank: int = -1,
    fp16: bool = False,
    fp16_opt_level: str = "O1",
) -> Tuple[nn.Module, torch.optim.Optimizer]:
    """Move the model to ``device`` and wrap it for fp16 / multi-GPU execution.

    Args:
        model: Module to prepare; moved to ``device`` in place.
        optimizer: Optimizer tied to ``model``; only replaced when ``fp16`` is set.
        device: Target passed straight to ``model.to``.
        n_gpu: When greater than 1 the model is wrapped in ``DataParallel``.
        local_rank: When not -1 the model is wrapped in
            ``DistributedDataParallel`` pinned to that device index.
        fp16: Enable apex mixed-precision initialisation.
        fp16_opt_level: apex ``opt_level`` forwarded to ``amp.initialize``.

    Returns:
        The (possibly wrapped) model and the (possibly re-initialised) optimizer.

    Raises:
        ImportError: ``fp16`` was requested but apex cannot be imported.
    """
    model.to(device)

    if fp16:
        try:
            import apex
            from apex import amp

            apex.amp.register_half_function(torch, "einsum")
        except ImportError as err:
            # Chain the original failure so the real import problem stays visible.
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
    return model, optimizer
def move_to_cuda(sample):
    """Return a copy of ``sample`` with every nested tensor sent through ``.cuda()``.

    Dicts keep their keys, lists and tuples both come back as lists, and any
    other value is passed through untouched. An empty ``sample`` of any kind
    yields an empty dict.
    """
    # Use len() rather than truthiness: bool() of a multi-element tensor raises.
    if len(sample) == 0:
        return {}

    def _to_gpu(item):
        if torch.is_tensor(item):
            return item.cuda()
        if isinstance(item, dict):
            return {name: _to_gpu(entry) for name, entry in item.items()}
        if isinstance(item, (list, tuple)):
            # Tuples are rebuilt as lists, same as the list case.
            return [_to_gpu(entry) for entry in item]
        return item

    return _to_gpu(sample)
def move_to_device(sample, device):
    """Return a copy of ``sample`` with every nested tensor moved via ``.to(device)``.

    Dicts keep their keys, lists and tuples both come back as lists, and any
    other value is passed through untouched. An empty ``sample`` of any kind
    yields an empty dict.
    """
    # Use len() rather than truthiness: bool() of a multi-element tensor raises.
    if len(sample) == 0:
        return {}

    def _relocate(item):
        if torch.is_tensor(item):
            return item.to(device)
        if isinstance(item, dict):
            return {name: _relocate(entry) for name, entry in item.items()}
        if isinstance(item, (list, tuple)):
            # Tuples are rebuilt as lists, same as the list case.
            return [_relocate(entry) for entry in item]
        return item

    return _relocate(sample)
def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """Build a LambdaLR whose multiplier ramps up linearly, then decays linearly.

    The multiplier grows from 0 to 1 over ``warmup_steps`` steps, then falls
    linearly so that it reaches 0 at ``training_steps``; it never goes below 0.
    """

    def lr_lambda(current_step):
        if current_step >= warmup_steps:
            # Decay phase: fraction of the post-warmup budget still remaining.
            remaining = float(training_steps - current_step)
            decay_span = float(max(1, training_steps - warmup_steps))
            return max(0.0, remaining / decay_span)
        # Warmup phase; max(1, ...) avoids division by zero.
        warmup_span = float(max(1, warmup_steps))
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def init_weights(modules: List):
    """Re-initialise the parameters of the given modules in place.

    * ``nn.Linear`` / ``nn.Embedding``: weights drawn from N(0, 0.02); a
      Linear bias, when present, is zeroed.
    * ``nn.LayerNorm``: weight set to 1 and bias to 0, when those parameters
      exist.
    * Any other module type is left untouched.

    Args:
        modules: Iterable of modules to initialise.
    """
    for module in modules:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False (or without a bias)
            # holds None here; skip instead of raising AttributeError.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
def get_model_obj(model: nn.Module):
    """Return ``model.module`` when that attribute exists, otherwise ``model`` itself.

    Wrappers such as DataParallel expose the wrapped network as ``.module``.
    """
    return getattr(model, "module", model)
def get_model_file(args, file_prefix) -> str:
    """Pick the checkpoint file to load.

    ``args.model_file`` wins when it is set and exists on disk. Otherwise the
    files in ``args.output_dir`` whose names start with ``file_prefix`` are
    listed and the one with the greatest ``os.path.getctime`` is returned.
    Returns None when no candidate is found.
    """
    explicit_path = args.model_file
    if explicit_path and os.path.exists(explicit_path):
        return explicit_path

    if args.output_dir:
        pattern = os.path.join(args.output_dir, file_prefix + "*")
        candidates = glob.glob(pattern)
    else:
        candidates = []
    logger.info("Checkpoint files %s", candidates)

    if not candidates:
        return None
    return max(candidates, key=os.path.getctime)
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Load a checkpoint from disk onto the CPU and wrap it in a CheckpointState.

    Args:
        model_file: Path to the checkpoint. A tuple is also accepted, in which
            case its first element is used as the path.

    Returns:
        A CheckpointState built from the top-level keys of the loaded dict;
        those keys must match the CheckpointState fields exactly.
    """
    # Unwrap first so the log line shows the path that is actually opened.
    if isinstance(model_file, tuple):
        model_file = model_file[0]
    logger.info("Reading saved model from %s", model_file)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    state_dict = torch.load(
        model_file,
        # Restore every storage on the CPU regardless of where it was saved.
        map_location=lambda storage, _location: default_restore_location(storage, "cpu"),
    )
    logger.info("model_state_dict keys %s", state_dict.keys())
    return CheckpointState(**state_dict)