Spaces:
Sleeping
Sleeping
File size: 4,773 Bytes
a04b340 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import collections
import glob
import logging
import os
from typing import List, Tuple

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.serialization import default_restore_location
logger = logging.getLogger()
# Immutable container for everything persisted in a training checkpoint:
# the model/optimizer/scheduler state dicts, the resume position
# (offset, epoch), and the encoder hyper-parameters used to rebuild the model.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    "model_dict optimizer_dict scheduler_dict offset epoch encoder_params",
)
def setup_for_distributed_mode(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: object,
    n_gpu: int = 1,
    local_rank: int = -1,
    fp16: bool = False,
    fp16_opt_level: str = "O1",
) -> Tuple[nn.Module, torch.optim.Optimizer]:
    """Move ``model`` to ``device`` and wrap it for the requested training mode.

    Applies, in order: optional apex/amp fp16 initialization, ``DataParallel``
    when more than one GPU is available, and ``DistributedDataParallel`` when
    ``local_rank`` indicates a distributed run.

    :param model: model to prepare (moved to ``device`` in place).
    :param optimizer: optimizer paired with the model (re-created by amp when
        ``fp16`` is on).
    :param device: target device for the model parameters.
    :param n_gpu: number of local GPUs; >1 enables ``DataParallel``.
    :param local_rank: distributed rank; -1 means non-distributed.
    :param fp16: enable apex mixed-precision training.
    :param fp16_opt_level: apex amp opt level (e.g. "O1").
    :returns: the (possibly wrapped) model and optimizer.
    :raises ImportError: if ``fp16`` is requested but apex is not installed.
    """
    model.to(device)
    if fp16:
        try:
            import apex
            from apex import amp

            # einsum is not amp-safe by default; register it for half precision.
            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    if local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
    return model, optimizer
def move_to_cuda(sample):
    """Recursively move every tensor inside ``sample`` onto the default CUDA device.

    Dicts, lists and tuples are traversed; non-tensor leaves are returned
    unchanged. An empty ``sample`` short-circuits to an empty dict.

    :param sample: arbitrarily nested batch structure (dict/list/tuple/tensor).
    :returns: the same structure with all tensors on GPU.
    """
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda()
        if isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            # Bug fix: the original returned a list here, silently changing
            # the container type of tuple inputs. Preserve the tuple.
            return tuple(_move_to_cuda(x) for x in maybe_tensor)
        return maybe_tensor

    return _move_to_cuda(sample)
def move_to_device(sample, device):
    """Recursively move every tensor inside ``sample`` onto ``device``.

    Dicts, lists and tuples are traversed; non-tensor leaves are returned
    unchanged. An empty ``sample`` short-circuits to an empty dict.

    :param sample: arbitrarily nested batch structure (dict/list/tuple/tensor).
    :param device: target ``torch.device`` (or device string).
    :returns: the same structure with all tensors on ``device``.
    """
    if len(sample) == 0:
        return {}

    def _move_to_device(maybe_tensor, device):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(device)
        if isinstance(maybe_tensor, dict):
            return {
                key: _move_to_device(value, device)
                for key, value in maybe_tensor.items()
            }
        if isinstance(maybe_tensor, list):
            return [_move_to_device(x, device) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            # Bug fix: the original returned a list here, silently changing
            # the container type of tuple inputs. Preserve the tuple.
            return tuple(_move_to_device(x, device) for x in maybe_tensor)
        return maybe_tensor

    return _move_to_device(sample, device)
def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """

    def _lr_factor(step):
        # Warmup phase: ramp linearly from 0 up to the base learning rate.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Decay phase: fall linearly to 0 at training_steps, clamped below at 0.
        remaining = training_steps - step
        return max(0.0, remaining / max(1, training_steps - warmup_steps))

    return LambdaLR(optimizer, _lr_factor, last_epoch)
def init_weights(modules: List):
    """Apply BERT-style initialization to each module in ``modules``:
    Linear/Embedding weights ~ N(0, 0.02), LayerNorm weight=1 / bias=0,
    and Linear biases zeroed.
    """
    for layer in modules:
        if isinstance(layer, (nn.Linear, nn.Embedding)):
            layer.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(layer, nn.LayerNorm):
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # Linear biases are zeroed regardless of the weight init above.
        if isinstance(layer, nn.Linear) and layer.bias is not None:
            layer.bias.data.zero_()
def get_model_obj(model: nn.Module):
    """Return the underlying model, unwrapping a DataParallel/
    DistributedDataParallel ``.module`` wrapper if present."""
    return getattr(model, "module", model)
def get_model_file(args, file_prefix) -> str:
    """Resolve which checkpoint file to load.

    Prefers an explicitly given ``args.model_file`` that exists on disk;
    otherwise picks the most recently created file in ``args.output_dir``
    matching ``file_prefix``. Returns None when nothing is found.
    """
    # Explicit checkpoint path wins when it actually exists.
    if args.model_file and os.path.exists(args.model_file):
        return args.model_file

    candidates = []
    if args.output_dir:
        candidates = glob.glob(os.path.join(args.output_dir, file_prefix + "*"))
    logger.info("Checkpoint files %s", candidates)

    if not candidates:
        return None
    # Newest checkpoint by creation time.
    return max(candidates, key=os.path.getctime)
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Load a saved checkpoint from ``model_file`` into a CheckpointState.

    Tensors are remapped to CPU so the checkpoint can be read regardless of
    the devices it was saved from.

    :param model_file: path to the checkpoint (a stray 1-tuple is unwrapped).
    :returns: ``CheckpointState`` built from the saved state dict.
    """
    # Defensive unwrap: some callers pass the path as a 1-element tuple.
    if isinstance(model_file, tuple):
        model_file = model_file[0]
    # Bug fix: original format string was "from s" (missing %), so the
    # path was never interpolated into the log message.
    logger.info("Reading saved model from %s", model_file)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state_dict = torch.load(
        model_file, map_location=lambda s, l: default_restore_location(s, "cpu")
    )
    logger.info("model_state_dict keys %s", state_dict.keys())
    return CheckpointState(**state_dict)