|
|
|
|
|
|
|
|
|
|
|
import ast |
|
import collections |
|
import contextlib |
|
import logging |
|
import os |
|
import re |
|
import time |
|
import traceback |
|
from collections import OrderedDict |
|
from typing import Any, Dict, Optional, Union |
|
from random import randint |
|
|
|
import torch |
|
from fairseq.dataclass.configs import CheckpointConfig |
|
from fairseq.dataclass.utils import ( |
|
convert_namespace_to_omegaconf, |
|
overwrite_args_by_name, |
|
) |
|
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP |
|
from fairseq.file_io import PathManager |
|
from fairseq.models import FairseqDecoder, FairseqEncoder |
|
from omegaconf import DictConfig, open_dict, OmegaConf |
|
|
|
|
|
# Module-level logger, named after this module, shared by the checkpoint helpers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    """Write the checkpoint files that are due and prune the ones no longer wanted.

    The best validation score seen so far is remembered on the function
    attribute ``save_checkpoint.best`` and is updated on every call that
    receives a non-``None`` ``val_loss`` (even when ``cfg.no_save`` is set).

    The state is written once via ``trainer.save_checkpoint`` to the first
    file name whose condition holds, then copied to every other due name
    (copies are skipped with a warning in asynchronous write mode).

    Args:
        cfg: checkpoint configuration (save directory, save intervals,
            ``keep_*`` retention settings, best-metric settings, ...).
        trainer: object providing ``data_parallel_rank``,
            ``checkpoint_suffix``, ``get_num_updates()``,
            ``consolidate_optimizer()`` and ``save_checkpoint()``.
        epoch_itr: epoch iterator providing ``epoch``, ``end_of_epoch()``
            and ``state_dict()``.
        val_loss: latest validation score, or ``None`` if not available.
    """
    from fairseq import meters

    # Only one worker should attempt to create the save directory.
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    def remove_checkpoint(old_chk):
        # Local files are removed directly; anything else that still exists
        # according to PathManager is removed through PathManager.
        if os.path.lexists(old_chk):
            os.remove(old_chk)
        elif PathManager.exists(old_chk):
            PathManager.rm(old_chk)

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        # The suffix is part of both the scan pattern and the file name so
        # that these files match the pruning pattern used further down.
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
            keep_match=True,
        )
        if len(chkpts) > 0:
            # Entries are (path, score) sorted by score in descending order,
            # so the worst retained score sits at one end of the list.
            _, worst_best = (
                chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            )

        # Random digits appended to the score to resolve ties between names.
        rand_sfx = randint(0, cfg.keep_best_checkpoints)
        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                # Perform the copy outside the assert so it still happens
                # when assertions are disabled (python -O).
                copied = PathManager.copy(checkpoints[0], cp, overwrite=True)
                assert copied, f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # Remove old update-based checkpoints; checkpoint_paths returns them
        # sorted by update count in descending order.
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            # Checkpoints whose update count is a multiple of the pattern
            # are excluded from pruning.
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            remove_checkpoint(old_chk)

    if cfg.keep_last_epochs > 0:
        # Remove old epoch checkpoints (sorted by epoch, descending).
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            remove_checkpoint(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # Keep only the best N checkpoints according to the validation metric.
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            remove_checkpoint(old_chk)
|
|
|
|
|
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.

    Returns:
        tuple: ``(extra_state, epoch_itr)`` where ``extra_state`` is whatever
        ``trainer.load_checkpoint`` returned and ``epoch_itr`` is the train
        iterator (restored from the checkpoint unless the dataloader is reset
        or no extra state was returned).

    Raises:
        ValueError: if ``finetune_from_model`` is combined with any reset
            option or with a non-default ``restore_file``, or if the model to
            fine-tune from does not exist.
    """
    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or --reset-lr-scheduler or --reset-meters or --reset-dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if cfg.restore_file == "checkpoint_last.pt":
        # "checkpoint_last.pt" is resolved inside save_dir with the suffix.
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if cfg.finetune_from_model is not None and first_launch:
            # No last checkpoint to resume from: start from the pretrained
            # model instead and reset all training state.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    # Carry the best score over only when resuming without any reset.
    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # Restore the iterator from the checkpoint.
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
|
|
|
|
|
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.

    Args:
        path: checkpoint location, resolved through ``PathManager``.
        arg_overrides (dict, optional): values written over the stored
            ``args`` namespace and/or ``cfg`` before upgrading.
        load_on_all_ranks (bool): call ``torch.distributed.barrier()`` after
            removing the local copy of a PathManager-managed file.

    Returns:
        The loaded state after ``_upgrade_state_dict`` has been applied.
    """
    local_path = PathManager.get_local_path(path)
    # For PathManager-managed paths, drop the local copy and ask for it again
    # so that the file read below is freshly fetched (presumably because the
    # local copy may be stale -- confirm against PathManager's caching).
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # Another process may already have removed it; that is benign.
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:
        # Temporarily make omegaconf treat every value as primitive while the
        # stored cfg is converted; always restore the original predicate,
        # even if the conversion raises.
        from omegaconf import _utils

        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True
        try:
            state["cfg"] = OmegaConf.create(state["cfg"])
        finally:
            _utils.is_primitive_type = old_primitive
        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
|
|
|
|
|
def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Load an ensemble of models and return it with its configuration.

    Thin wrapper around ``load_model_ensemble_and_task`` that drops the task
    from the result.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
        strict, suffix, num_shards, state: forwarded unchanged to
            ``load_model_ensemble_and_task``

    Returns:
        tuple: ``(ensemble, args)`` -- the first two items returned by
        ``load_model_ensemble_and_task``.
    """
    sharded_and_strict = strict and num_shards > 1
    assert (
        not sharded_and_strict
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"

    ensemble, args, _ = load_model_ensemble_and_task(
        filenames=filenames,
        arg_overrides=arg_overrides,
        task=task,
        strict=strict,
        suffix=suffix,
        num_shards=num_shards,
        state=state,
    )
    return ensemble, args
|
|
|
|
|
def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    """Pick the file name to read for shard ``shard_idx`` of a checkpoint.

    Candidates are tried in this order:

    1. ``<name><suffix>-shard<idx>.pt`` if it exists according to
       ``PathManager``;
    2. ``<name>_part<idx>.pt`` (built from the un-suffixed name) when
       ``num_shards > 1``;
    3. otherwise the suffixed name itself.
    """
    suffixed = filename.replace(".pt", suffix + ".pt")

    # The last three characters (".pt" for well-formed names) are dropped
    # before the shard marker is appended.
    fsdp_candidate = f"{suffixed[:-3]}-shard{shard_idx}.pt"
    if PathManager.exists(fsdp_candidate):
        return fsdp_candidate

    if num_shards > 1:
        return f"{filename[:-3]}_part{shard_idx}.pt"

    return suffixed
|
|
|
|
|
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Load an ensemble of models together with their config and task.

    Args:
        filenames (List[str]): checkpoint files to load, one model per entry
        arg_overrides (Dict[str,Any], optional): overrides applied when a
            checkpoint is read from disk
        task: task to build models with; set up from the checkpoint's
            ``cfg.task`` when ``None``
        strict (bool): forwarded to ``model.load_state_dict``; must be False
            when ``num_shards > 1``
        suffix (str): checkpoint suffix used to resolve shard file names
        num_shards (int): number of shard files per model (must be > 0)
        state: an already loaded checkpoint state, used in place of reading
            the first file; only allowed with exactly one filename

    Returns:
        tuple: ``(ensemble, cfg, task)`` where ``cfg`` is the config of the
        last checkpoint processed (``None`` if ``filenames`` is empty).

    Raises:
        IOError: if a resolved checkpoint file does not exist.
        RuntimeError: if a checkpoint carries neither ``args`` nor ``cfg``.
        ImportError: if FSDP shards are found but FSDP is unavailable.
    """
    # Validate arguments before the (lazy) fairseq import.
    assert state is None or len(filenames) == 1
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"

    from fairseq import tasks

    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                # FSDP shards are collected and consolidated once the last
                # shard has been read.
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])

                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # Checkpoint without FSDP metadata (or a single shard): build
                # and load the model directly from this file.
                model = task.build_model(cfg.model)
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # Reset so the next shard / model is read from disk.
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                # shard_idx is zero-based, so shard_idx + 1 shards are done.
                num_loaded = shard_idx + 1
                logger.info(
                    f"Loaded {num_loaded} shards in {elapsed:.2f}s, "
                    f"{elapsed / num_loaded:.2f}s/shard"
                )

        ensemble.append(model)
    return ensemble, cfg, task
|
|
|
|
|
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Return the checkpoint files in directory `path` that match `pattern`.

    A file counts as a checkpoint when its whole name matches the pattern.
    Results are sorted in descending order: by the first regex group
    (converted to float) when the pattern has groups, otherwise by the
    file's position in the directory listing.

    With ``keep_match=True`` each item is a ``(full_path, sort_key)`` tuple;
    otherwise each item is just the full path.
    """
    matcher = re.compile(pattern)

    ranked = []
    for position, name in enumerate(PathManager.ls(path)):
        hit = matcher.fullmatch(name)
        if hit is None:
            continue
        sort_key = float(hit.group(1)) if hit.groups() else position
        ranked.append((sort_key, hit.group(0)))
    ranked.sort(reverse=True)

    if keep_match:
        return [(os.path.join(path, name), sort_key) for sort_key, name in ranked]
    return [os.path.join(path, name) for _, name in ranked]
|
|
|
|
|
def torch_persistent_save(obj, filename, async_write: bool = False):
    """Serialize ``obj`` to ``filename`` through ``PathManager``.

    * ``async_write=True``: write through ``PathManager.opena``.
    * otherwise, if the path supports renaming: write to
      ``filename + ".tmp"`` and rename it onto ``filename`` afterwards.
    * otherwise: write straight to ``filename``.
    """
    if async_write:
        with PathManager.opena(filename, "wb") as stream:
            _torch_persistent_save(obj, stream)
        return

    if not PathManager.supports_rename(filename):
        # No rename support: write the target file in place.
        with PathManager.open(filename, "wb") as stream:
            _torch_persistent_save(obj, stream)
        return

    # Write to a temporary name first, then move it into place.
    staging_name = filename + ".tmp"
    with PathManager.open(staging_name, "wb") as stream:
        _torch_persistent_save(obj, stream)
    PathManager.rename(staging_name, filename)
|
|
|
|
|
def _torch_persistent_save(obj, f): |
|
if isinstance(f, str): |
|
with PathManager.open(f, "wb") as h: |
|
torch_persistent_save(obj, h) |
|
return |
|
for i in range(3): |
|
try: |
|
return torch.save(obj, f) |
|
except Exception: |
|
if i == 2: |
|
logger.error(traceback.format_exc()) |
|
|
|
|
|
def _upgrade_state_dict(state): |
|
"""Helper for upgrading old model checkpoints.""" |
|
|
|
|
|
if "optimizer_history" not in state: |
|
state["optimizer_history"] = [ |
|
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]} |
|
] |
|
state["last_optimizer_state"] = state["optimizer"] |
|
del state["optimizer"] |
|
del state["best_loss"] |
|
|
|
if "epoch" in state and "extra_state" not in state: |
|
state["extra_state"] = { |
|
"epoch": state["epoch"], |
|
"batch_offset": state["batch_offset"], |
|
"val_loss": state["val_loss"], |
|
} |
|
del state["epoch"] |
|
del state["batch_offset"] |
|
del state["val_loss"] |
|
|
|
if "optimizer" in state["optimizer_history"][-1]: |
|
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"] |
|
for optim_hist in state["optimizer_history"]: |
|
del optim_hist["optimizer"] |
|
|
|
if "optimizer_name" not in state["optimizer_history"][-1]: |
|
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG" |
|
|
|
if "lr_scheduler_state" not in state["optimizer_history"][-1]: |
|
state["optimizer_history"][-1]["lr_scheduler_state"] = { |
|
"best": state["optimizer_history"][-1]["best_loss"] |
|
} |
|
del state["optimizer_history"][-1]["best_loss"] |
|
|
|
if "num_updates" not in state["optimizer_history"][-1]: |
|
state["optimizer_history"][-1]["num_updates"] = 0 |
|
|
|
if ( |
|
"args" in state |
|
and hasattr(state["args"], "max_positions") |
|
and not hasattr(state["args"], "max_source_positions") |
|
): |
|
state["args"].max_source_positions = state["args"].max_positions |
|
state["args"].max_target_positions = state["args"].max_positions |
|
|
|
if "train_iterator" not in state["extra_state"]: |
|
state["extra_state"]["train_iterator"] = { |
|
"epoch": state["extra_state"]["epoch"], |
|
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0), |
|
} |
|
|
|
|
|
if "args" in state and state["args"] is not None: |
|
|
|
if not hasattr(state["args"], "task"): |
|
state["args"].task = "translation" |
|
|
|
if getattr(state["args"], "raw_text", False): |
|
state["args"].dataset_impl = "raw" |
|
elif getattr(state["args"], "lazy_load", False): |
|
state["args"].dataset_impl = "lazy" |
|
|
|
if state["extra_state"]["train_iterator"] is not None: |
|
state["extra_state"]["train_iterator"]["epoch"] = max( |
|
state["extra_state"]["train_iterator"].get("epoch", 1), 1 |
|
) |
|
|
|
if hasattr(state["args"], "remove_bpe"): |
|
state["args"].post_process = state["args"].remove_bpe |
|
|
|
if hasattr(state["args"], "min_lr"): |
|
state["args"].stop_min_lr = state["args"].min_lr |
|
del state["args"].min_lr |
|
|
|
if ( |
|
hasattr(state["args"], "criterion") |
|
and state["args"].criterion in [ |
|
"binary_cross_entropy", |
|
"kd_binary_cross_entropy", |
|
] |
|
): |
|
state["args"].criterion = "wav2vec" |
|
|
|
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: |
|
delattr(state["args"], "log_keys") |
|
|
|
if ( |
|
hasattr(state["args"], "task") |
|
and state["args"].task == "speech_pretraining" |
|
): |
|
state["args"].task = "audio_pretraining" |
|
|
|
if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": |
|
state["args"].arch = "wav2vec" |
|
|
|
if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): |
|
state["args"].lr = [state["args"].lr] |
|
|
|
if ( |
|
hasattr(state["args"], "data") |
|
and isinstance(state["args"].data, list) |
|
and len(state["args"].data) > 0 |
|
): |
|
state["args"].data = state["args"].data[0] |
|
|
|
for key in [ |
|
"static_teachers", |
|
"static_teacher_weights", |
|
"dynamic_teachers", |
|
"dynamic_teacher_weights", |
|
]: |
|
if key in state["args"]: |
|
delattr(state["args"], key) |
|
|
|
state["cfg"] = convert_namespace_to_omegaconf(state["args"]) |
|
|
|
if "cfg" in state and state["cfg"] is not None: |
|
cfg = state["cfg"] |
|
with open_dict(cfg): |
|
|
|
if ( |
|
"task" in cfg |
|
and "eval_wer_config" in cfg.task |
|
and isinstance(cfg.task.eval_wer_config.print_alignment, bool) |
|
): |
|
cfg.task.eval_wer_config.print_alignment = "hard" |
|
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): |
|
cfg.generation.print_alignment = "hard" |
|
if ( |
|
"model" in cfg |
|
and "w2v_args" in cfg.model |
|
and cfg.model.w2v_args is not None |
|
and ( |
|
hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args |
|
) |
|
and hasattr(cfg.model.w2v_args.task, "eval_wer_config") |
|
and cfg.model.w2v_args.task.eval_wer_config is not None |
|
and isinstance( |
|
cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool |
|
) |
|
): |
|
cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" |
|
|
|
return state |
|
|
|
|
|
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    A model trained with LayerDrop tolerates having layers removed at
    inference time. When ``model_cfg`` lists layers to keep for the encoder
    and/or decoder, this returns a new state dict containing only those
    layers, renumbered consecutively from 0, so that a smaller model can be
    loaded from a larger checkpoint.

    The input ``state_dict`` is returned unchanged when ``model_cfg`` is
    missing/empty, has no architecture name, names the ``ptt_transformer``
    architecture, or requests no pruning. After pruning, the
    ``*_layers_to_keep`` settings on ``model_cfg`` are reset to ``None``.

    Model-loading code calls this; it does not need to be called directly.
    """
    if model_cfg is None:
        arch = None
    elif isinstance(model_cfg, DictConfig):
        arch = model_cfg._name
    else:
        arch = getattr(model_cfg, "arch", None)

    if not model_cfg or arch is None or arch == "ptt_transformer":
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not (encoder_layers_to_keep or decoder_layers_to_keep):
        return state_dict

    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        # Map each kept layer index (as text) to its new consecutive index.
        kept = sorted(int(token) for token in layers_to_keep.split(","))
        renumbering = {str(old): str(new) for new, old in enumerate(kept)}
        locator = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": locator, "mapping_dict": renumbering}

    requested = (
        (encoder_layers_to_keep, "encoder"),
        (decoder_layers_to_keep, "decoder"),
    )
    pruning_passes = [
        create_pruning_pass(layers, component)
        for layers, component in requested
        if layers
    ]

    layer_index = re.compile(r"\.layers\.(\d+)\.")
    new_state_dict = {}
    for key, value in state_dict.items():
        found = layer_index.search(key)

        # Entries without a layer number are copied over as they are.
        if not found:
            new_state_dict[key] = value
            continue

        # Numbered entries are kept only if some pass maps their layer
        # number and its regex matches the key; otherwise they are left out.
        original_layer_number = found.group(1)
        for pruning_pass in pruning_passes:
            renumbering = pruning_pass["mapping_dict"]
            if original_layer_number not in renumbering:
                continue
            substitution_match = pruning_pass["substitution_regex"].search(key)
            if not substitution_match:
                continue
            renamed = "".join(
                (
                    key[: substitution_match.start(1)],
                    renumbering[original_layer_number],
                    key[substitution_match.end(1) :],
                )
            )
            new_state_dict[renamed] = value

    # The layers are pruned now, so the *_layers_to_keep settings are cleared.
    # A DictConfig has to be opened for writing first.
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        for setting in ("encoder_layers_to_keep", "decoder_layers_to_keep"):
            if hasattr(model_cfg, setting):
                setattr(model_cfg, setting, None)

    return new_state_dict
|
|
|
|
|
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Fill ``component`` (a FairseqEncoder or FairseqDecoder) with the matching
    weights stored in the ``checkpoint`` file and return it.

    Only model entries whose name starts with ``"encoder"`` / ``"decoder"``
    are used, with that prefix and the following separator character
    stripped. Loading is strict, so a failure here may indicate that the
    architecture of ``component`` differs from the one in the checkpoint.

    Raises:
        IOError: if the checkpoint file does not exist.
        ValueError: if ``component`` is neither an encoder nor a decoder.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)

    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )

    # e.g. "encoder.input_layers.0.0.weight" becomes "input_layers.0.0.weight"
    prefix_length = len(component_type) + 1
    component_state_dict = OrderedDict(
        (name[prefix_length:], weights)
        for name, weights in state["model"].items()
        if name.startswith(component_type)
    )
    component.load_state_dict(component_state_dict, strict=True)
    return component
|
|
|
|
|
def verify_checkpoint_directory(save_dir: str) -> None:
    """Make sure ``save_dir`` exists and that a file can be created in it.

    The directory is created if it is missing. Writability is probed with a
    uniquely named temporary file that is removed again, so existing files
    (including one called ``dummy``) are never touched and concurrent
    callers do not interfere with each other.

    Raises:
        OSError: if the directory cannot be created or written to; a warning
            is logged first when the write probe fails.
    """
    import tempfile

    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    try:
        # The temporary file is deleted automatically when it is closed.
        with tempfile.NamedTemporaryFile(mode="w", prefix="dummy", dir=save_dir):
            pass
    except OSError:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise
|
|