Spaces:
Runtime error
Runtime error
kevinwang676
commited on
Create load_model.py
Browse files- load_model.py +936 -0
load_model.py
ADDED
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import ast
|
7 |
+
import collections
|
8 |
+
import contextlib
|
9 |
+
import inspect
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import re
|
13 |
+
import time
|
14 |
+
import traceback
|
15 |
+
from collections import OrderedDict
|
16 |
+
from pathlib import Path
|
17 |
+
from typing import Any, Dict, Optional, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from fairseq.data import data_utils
|
22 |
+
from fairseq.dataclass.configs import CheckpointConfig
|
23 |
+
from fairseq.dataclass.utils import (
|
24 |
+
convert_namespace_to_omegaconf,
|
25 |
+
overwrite_args_by_name,
|
26 |
+
)
|
27 |
+
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
|
28 |
+
from fairseq.file_io import PathManager
|
29 |
+
from fairseq.models import FairseqDecoder, FairseqEncoder
|
30 |
+
from omegaconf import DictConfig, OmegaConf, open_dict
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
36 |
+
from fairseq import meters
|
37 |
+
|
38 |
+
# only one worker should attempt to create the required dir
|
39 |
+
if trainer.data_parallel_rank == 0:
|
40 |
+
os.makedirs(cfg.save_dir, exist_ok=True)
|
41 |
+
|
42 |
+
prev_best = getattr(save_checkpoint, "best", val_loss)
|
43 |
+
if val_loss is not None:
|
44 |
+
best_function = max if cfg.maximize_best_checkpoint_metric else min
|
45 |
+
save_checkpoint.best = best_function(val_loss, prev_best)
|
46 |
+
|
47 |
+
if cfg.no_save:
|
48 |
+
return None
|
49 |
+
|
50 |
+
trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
|
51 |
+
|
52 |
+
if not trainer.should_save_checkpoint_on_current_rank:
|
53 |
+
if trainer.always_call_state_dict_during_save_checkpoint:
|
54 |
+
trainer.state_dict()
|
55 |
+
return None
|
56 |
+
|
57 |
+
write_timer = meters.StopwatchMeter()
|
58 |
+
write_timer.start()
|
59 |
+
|
60 |
+
epoch = epoch_itr.epoch
|
61 |
+
end_of_epoch = epoch_itr.end_of_epoch()
|
62 |
+
updates = trainer.get_num_updates()
|
63 |
+
|
64 |
+
logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
|
65 |
+
|
66 |
+
def is_better(a, b):
|
67 |
+
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
|
68 |
+
|
69 |
+
suffix = trainer.checkpoint_suffix
|
70 |
+
checkpoint_conds = collections.OrderedDict()
|
71 |
+
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
|
72 |
+
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
|
73 |
+
)
|
74 |
+
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
|
75 |
+
not end_of_epoch
|
76 |
+
and cfg.save_interval_updates > 0
|
77 |
+
and updates % cfg.save_interval_updates == 0
|
78 |
+
)
|
79 |
+
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
|
80 |
+
not hasattr(save_checkpoint, "best")
|
81 |
+
or is_better(val_loss, save_checkpoint.best)
|
82 |
+
)
|
83 |
+
if val_loss is not None and cfg.keep_best_checkpoints > 0:
|
84 |
+
worst_best = getattr(save_checkpoint, "best", None)
|
85 |
+
chkpts = checkpoint_paths(
|
86 |
+
cfg.save_dir,
|
87 |
+
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
88 |
+
cfg.best_checkpoint_metric, suffix
|
89 |
+
),
|
90 |
+
)
|
91 |
+
if len(chkpts) > 0:
|
92 |
+
p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
|
93 |
+
worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
|
94 |
+
# add random digits to resolve ties
|
95 |
+
with data_utils.numpy_seed(epoch, updates, val_loss):
|
96 |
+
rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
|
97 |
+
|
98 |
+
checkpoint_conds[
|
99 |
+
"checkpoint.best_{}_{:.3f}{}{}.pt".format(
|
100 |
+
cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
|
101 |
+
)
|
102 |
+
] = worst_best is None or is_better(val_loss, worst_best)
|
103 |
+
checkpoint_conds[
|
104 |
+
"checkpoint_last{}.pt".format(suffix)
|
105 |
+
] = not cfg.no_last_checkpoints
|
106 |
+
|
107 |
+
extra_state = {
|
108 |
+
"train_iterator": epoch_itr.state_dict(),
|
109 |
+
"val_loss": val_loss,
|
110 |
+
}
|
111 |
+
|
112 |
+
# Going forward, different tasks could expose an API like this to dump all
|
113 |
+
# the checkpoint worthy attributes in a dictionary which then will be
|
114 |
+
# merged with the parent dictionary to create the "extra_state". This
|
115 |
+
# allows for an extensible yet simple design to checkpoint task level
|
116 |
+
# attributes
|
117 |
+
if hasattr(trainer.task, "get_checkpoint_dict"):
|
118 |
+
extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
|
119 |
+
logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")
|
120 |
+
|
121 |
+
if hasattr(save_checkpoint, "best"):
|
122 |
+
extra_state.update({"best": save_checkpoint.best})
|
123 |
+
|
124 |
+
checkpoints = [
|
125 |
+
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
|
126 |
+
]
|
127 |
+
saved_cp = None
|
128 |
+
if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
|
129 |
+
saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state)
|
130 |
+
for cp in checkpoints[1:]:
|
131 |
+
if cfg.write_checkpoints_asynchronously:
|
132 |
+
# TODO[ioPath]: Need to implement a delayed asynchronous
|
133 |
+
# file copying/moving feature.
|
134 |
+
logger.warning(
|
135 |
+
f"ioPath is not copying {checkpoints[0]} to {cp} "
|
136 |
+
"since async write mode is on."
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
assert PathManager.copy(
|
140 |
+
checkpoints[0], cp, overwrite=True
|
141 |
+
), f"Failed to copy {checkpoints[0]} to {cp}"
|
142 |
+
|
143 |
+
write_timer.stop()
|
144 |
+
logger.info(
|
145 |
+
"Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
|
146 |
+
checkpoints[0], epoch, updates, val_loss, write_timer.sum
|
147 |
+
)
|
148 |
+
)
|
149 |
+
|
150 |
+
if (
|
151 |
+
not end_of_epoch
|
152 |
+
and cfg.keep_interval_updates > 0
|
153 |
+
and trainer.should_save_checkpoint_on_current_rank
|
154 |
+
):
|
155 |
+
# remove old checkpoints; checkpoints are sorted in descending order
|
156 |
+
if cfg.keep_interval_updates_pattern == -1:
|
157 |
+
checkpoints = checkpoint_paths(
|
158 |
+
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
|
159 |
+
)
|
160 |
+
else:
|
161 |
+
checkpoints = checkpoint_paths(
|
162 |
+
cfg.save_dir,
|
163 |
+
pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
|
164 |
+
keep_match=True,
|
165 |
+
)
|
166 |
+
checkpoints = [
|
167 |
+
x[0]
|
168 |
+
for x in checkpoints
|
169 |
+
if x[1] % cfg.keep_interval_updates_pattern != 0
|
170 |
+
]
|
171 |
+
|
172 |
+
for old_chk in checkpoints[cfg.keep_interval_updates :]:
|
173 |
+
if os.path.lexists(old_chk):
|
174 |
+
os.remove(old_chk)
|
175 |
+
elif PathManager.exists(old_chk):
|
176 |
+
PathManager.rm(old_chk)
|
177 |
+
|
178 |
+
if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank:
|
179 |
+
# remove old epoch checkpoints; checkpoints are sorted in descending order
|
180 |
+
checkpoints = checkpoint_paths(
|
181 |
+
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
|
182 |
+
)
|
183 |
+
for old_chk in checkpoints[cfg.keep_last_epochs :]:
|
184 |
+
if os.path.lexists(old_chk):
|
185 |
+
os.remove(old_chk)
|
186 |
+
elif PathManager.exists(old_chk):
|
187 |
+
PathManager.rm(old_chk)
|
188 |
+
|
189 |
+
if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank:
|
190 |
+
# only keep the best N checkpoints according to validation metric
|
191 |
+
checkpoints = checkpoint_paths(
|
192 |
+
cfg.save_dir,
|
193 |
+
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
194 |
+
cfg.best_checkpoint_metric, suffix
|
195 |
+
),
|
196 |
+
)
|
197 |
+
if not cfg.maximize_best_checkpoint_metric:
|
198 |
+
checkpoints = checkpoints[::-1]
|
199 |
+
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
|
200 |
+
if os.path.lexists(old_chk):
|
201 |
+
os.remove(old_chk)
|
202 |
+
elif PathManager.exists(old_chk):
|
203 |
+
PathManager.rm(old_chk)
|
204 |
+
|
205 |
+
return saved_cp
|
206 |
+
|
207 |
+
|
208 |
+
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
209 |
+
"""
|
210 |
+
Load a checkpoint and restore the training iterator.
|
211 |
+
|
212 |
+
*passthrough_args* will be passed through to
|
213 |
+
``trainer.get_train_iterator``.
|
214 |
+
"""
|
215 |
+
|
216 |
+
reset_optimizer = cfg.reset_optimizer
|
217 |
+
reset_lr_scheduler = cfg.reset_lr_scheduler
|
218 |
+
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
|
219 |
+
reset_meters = cfg.reset_meters
|
220 |
+
reset_dataloader = cfg.reset_dataloader
|
221 |
+
|
222 |
+
if cfg.finetune_from_model is not None and (
|
223 |
+
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
|
224 |
+
):
|
225 |
+
raise ValueError(
|
226 |
+
"--finetune-from-model can not be set together with either --reset-optimizer"
|
227 |
+
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
228 |
+
)
|
229 |
+
|
230 |
+
suffix = trainer.checkpoint_suffix
|
231 |
+
if (
|
232 |
+
cfg.restore_file == "checkpoint_last.pt"
|
233 |
+
): # default value of restore_file is 'checkpoint_last.pt'
|
234 |
+
checkpoint_path = os.path.join(
|
235 |
+
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
|
236 |
+
)
|
237 |
+
first_launch = not PathManager.exists(checkpoint_path)
|
238 |
+
if first_launch and getattr(cfg, "continue_once", None) is not None:
|
239 |
+
checkpoint_path = cfg.continue_once
|
240 |
+
elif cfg.finetune_from_model is not None and first_launch:
|
241 |
+
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
242 |
+
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
243 |
+
if PathManager.exists(cfg.finetune_from_model):
|
244 |
+
checkpoint_path = cfg.finetune_from_model
|
245 |
+
reset_optimizer = True
|
246 |
+
reset_lr_scheduler = True
|
247 |
+
reset_meters = True
|
248 |
+
reset_dataloader = True
|
249 |
+
logger.info(
|
250 |
+
f"loading pretrained model from {checkpoint_path}: "
|
251 |
+
"optimizer, lr scheduler, meters, dataloader will be reset"
|
252 |
+
)
|
253 |
+
else:
|
254 |
+
raise ValueError(
|
255 |
+
f"--finetune-from-model {cfg.finetune_from_model} does not exist"
|
256 |
+
)
|
257 |
+
elif suffix is not None:
|
258 |
+
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
|
259 |
+
else:
|
260 |
+
checkpoint_path = cfg.restore_file
|
261 |
+
|
262 |
+
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
|
263 |
+
raise ValueError(
|
264 |
+
"--finetune-from-model and --restore-file (non-default value) "
|
265 |
+
"can not be specified together: " + str(cfg)
|
266 |
+
)
|
267 |
+
|
268 |
+
extra_state = trainer.load_checkpoint(
|
269 |
+
checkpoint_path,
|
270 |
+
reset_optimizer,
|
271 |
+
reset_lr_scheduler,
|
272 |
+
optimizer_overrides,
|
273 |
+
reset_meters=reset_meters,
|
274 |
+
)
|
275 |
+
|
276 |
+
if (
|
277 |
+
extra_state is not None
|
278 |
+
and "best" in extra_state
|
279 |
+
and not reset_optimizer
|
280 |
+
and not reset_meters
|
281 |
+
):
|
282 |
+
save_checkpoint.best = extra_state["best"]
|
283 |
+
|
284 |
+
if extra_state is not None and not reset_dataloader:
|
285 |
+
# restore iterator from checkpoint
|
286 |
+
itr_state = extra_state["train_iterator"]
|
287 |
+
epoch_itr = trainer.get_train_iterator(
|
288 |
+
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
|
289 |
+
)
|
290 |
+
epoch_itr.load_state_dict(itr_state)
|
291 |
+
|
292 |
+
# Preload the checkpoint for the task
|
293 |
+
task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
|
294 |
+
if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
|
295 |
+
trainer.task.set_checkpoint_dict(task_cp_dict)
|
296 |
+
else:
|
297 |
+
epoch_itr = trainer.get_train_iterator(
|
298 |
+
epoch=1, load_dataset=True, **passthrough_args
|
299 |
+
)
|
300 |
+
|
301 |
+
trainer.lr_step(epoch_itr.epoch)
|
302 |
+
|
303 |
+
return extra_state, epoch_itr
|
304 |
+
|
305 |
+
|
306 |
+
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
|
307 |
+
"""Loads a checkpoint to CPU (with upgrading for backward compatibility).
|
308 |
+
|
309 |
+
If doing single-GPU training or if the checkpoint is only being loaded by at
|
310 |
+
most one process on each node (current default behavior is for only rank 0
|
311 |
+
to read the checkpoint from disk), load_on_all_ranks should be False to
|
312 |
+
avoid errors from torch.distributed not having been initialized or
|
313 |
+
torch.distributed.barrier() hanging.
|
314 |
+
|
315 |
+
If all processes on each node may be loading the checkpoint
|
316 |
+
simultaneously, load_on_all_ranks should be set to True to avoid I/O
|
317 |
+
conflicts.
|
318 |
+
|
319 |
+
There's currently no support for > 1 but < all processes loading the
|
320 |
+
checkpoint on each node.
|
321 |
+
"""
|
322 |
+
local_path = PathManager.get_local_path(path)
|
323 |
+
# The locally cached file returned by get_local_path() may be stale for
|
324 |
+
# remote files that are periodically updated/overwritten (ex:
|
325 |
+
# checkpoint_last.pt) - so we remove the local copy, sync across processes
|
326 |
+
# (if needed), and then download a fresh copy.
|
327 |
+
if local_path != path and PathManager.path_requires_pathmanager(path):
|
328 |
+
try:
|
329 |
+
os.remove(local_path)
|
330 |
+
except FileNotFoundError:
|
331 |
+
# With potentially multiple processes removing the same file, the
|
332 |
+
# file being missing is benign (missing_ok isn't available until
|
333 |
+
# Python 3.8).
|
334 |
+
pass
|
335 |
+
if load_on_all_ranks:
|
336 |
+
torch.distributed.barrier()
|
337 |
+
local_path = PathManager.get_local_path(path)
|
338 |
+
|
339 |
+
with open(local_path, "rb") as f:
|
340 |
+
state = torch.load(f, map_location=torch.device("cpu"))
|
341 |
+
|
342 |
+
if "args" in state and state["args"] is not None and arg_overrides is not None:
|
343 |
+
args = state["args"]
|
344 |
+
for arg_name, arg_val in arg_overrides.items():
|
345 |
+
setattr(args, arg_name, arg_val)
|
346 |
+
|
347 |
+
if "cfg" in state and state["cfg"] is not None:
|
348 |
+
|
349 |
+
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
|
350 |
+
# omegaconf version that supports object flags, or when we migrate all existing models
|
351 |
+
from omegaconf import __version__ as oc_version
|
352 |
+
from omegaconf import _utils
|
353 |
+
|
354 |
+
if oc_version < "2.2":
|
355 |
+
old_primitive = _utils.is_primitive_type
|
356 |
+
_utils.is_primitive_type = lambda _: True
|
357 |
+
|
358 |
+
state["cfg"] = OmegaConf.create(state["cfg"])
|
359 |
+
|
360 |
+
_utils.is_primitive_type = old_primitive
|
361 |
+
OmegaConf.set_struct(state["cfg"], True)
|
362 |
+
else:
|
363 |
+
state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
|
364 |
+
|
365 |
+
if arg_overrides is not None:
|
366 |
+
overwrite_args_by_name(state["cfg"], arg_overrides)
|
367 |
+
|
368 |
+
state = _upgrade_state_dict(state)
|
369 |
+
return state
|
370 |
+
|
371 |
+
|
372 |
+
def load_model_ensemble(
|
373 |
+
filenames,
|
374 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
375 |
+
task=None,
|
376 |
+
strict=True,
|
377 |
+
suffix="",
|
378 |
+
num_shards=1,
|
379 |
+
state=None,
|
380 |
+
):
|
381 |
+
"""Loads an ensemble of models.
|
382 |
+
|
383 |
+
Args:
|
384 |
+
filenames (List[str]): checkpoint files to load
|
385 |
+
arg_overrides (Dict[str,Any], optional): override model args that
|
386 |
+
were used during model training
|
387 |
+
task (fairseq.tasks.FairseqTask, optional): task to use for loading
|
388 |
+
"""
|
389 |
+
assert not (
|
390 |
+
strict and num_shards > 1
|
391 |
+
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
392 |
+
ensemble, args, _task = load_model_ensemble_and_task(
|
393 |
+
filenames,
|
394 |
+
arg_overrides,
|
395 |
+
task,
|
396 |
+
strict,
|
397 |
+
suffix,
|
398 |
+
num_shards,
|
399 |
+
state,
|
400 |
+
)
|
401 |
+
return ensemble, args
|
402 |
+
|
403 |
+
|
404 |
+
def get_maybe_sharded_checkpoint_filename(
|
405 |
+
filename: str, suffix: str, shard_idx: int, num_shards: int
|
406 |
+
) -> str:
|
407 |
+
orig_filename = filename
|
408 |
+
filename = filename.replace(".pt", suffix + ".pt")
|
409 |
+
fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
|
410 |
+
model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
|
411 |
+
if PathManager.exists(fsdp_filename):
|
412 |
+
return fsdp_filename
|
413 |
+
elif num_shards > 1:
|
414 |
+
return model_parallel_filename
|
415 |
+
else:
|
416 |
+
return filename
|
417 |
+
|
418 |
+
|
419 |
+
def load_model_ensemble_and_task(
|
420 |
+
filenames,
|
421 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
422 |
+
task=None,
|
423 |
+
strict=True,
|
424 |
+
suffix="",
|
425 |
+
num_shards=1,
|
426 |
+
state=None,
|
427 |
+
):
|
428 |
+
assert state is None or len(filenames) == 1
|
429 |
+
|
430 |
+
from fairseq import tasks
|
431 |
+
|
432 |
+
assert not (
|
433 |
+
strict and num_shards > 1
|
434 |
+
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
435 |
+
ensemble = []
|
436 |
+
cfg = None
|
437 |
+
for filename in filenames:
|
438 |
+
orig_filename = filename
|
439 |
+
model_shard_state = {"shard_weights": [], "shard_metadata": []}
|
440 |
+
assert num_shards > 0
|
441 |
+
st = time.time()
|
442 |
+
for shard_idx in range(num_shards):
|
443 |
+
filename = get_maybe_sharded_checkpoint_filename(
|
444 |
+
orig_filename, suffix, shard_idx, num_shards
|
445 |
+
)
|
446 |
+
|
447 |
+
if not PathManager.exists(filename):
|
448 |
+
raise IOError("Model file not found: {}".format(filename))
|
449 |
+
if state is None:
|
450 |
+
state = load_checkpoint_to_cpu(filename, arg_overrides)
|
451 |
+
if "args" in state and state["args"] is not None:
|
452 |
+
cfg = convert_namespace_to_omegaconf(state["args"])
|
453 |
+
elif "cfg" in state and state["cfg"] is not None:
|
454 |
+
cfg = state["cfg"]
|
455 |
+
else:
|
456 |
+
raise RuntimeError(
|
457 |
+
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
458 |
+
)
|
459 |
+
|
460 |
+
if task is None:
|
461 |
+
task = tasks.setup_task(cfg.task, from_checkpoint=True)
|
462 |
+
|
463 |
+
if "task_state" in state:
|
464 |
+
task.load_state_dict(state["task_state"])
|
465 |
+
|
466 |
+
argspec = inspect.getfullargspec(task.build_model)
|
467 |
+
|
468 |
+
if "fsdp_metadata" in state and num_shards > 1:
|
469 |
+
model_shard_state["shard_weights"].append(state["model"])
|
470 |
+
model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
|
471 |
+
# check FSDP import before the code goes too far
|
472 |
+
if not has_FSDP:
|
473 |
+
raise ImportError(
|
474 |
+
"Cannot find FullyShardedDataParallel. "
|
475 |
+
"Please install fairscale with: pip install fairscale"
|
476 |
+
)
|
477 |
+
if shard_idx == num_shards - 1:
|
478 |
+
consolidated_model_state = FSDP.consolidate_shard_weights(
|
479 |
+
shard_weights=model_shard_state["shard_weights"],
|
480 |
+
shard_metadata=model_shard_state["shard_metadata"],
|
481 |
+
)
|
482 |
+
if "from_checkpoint" in argspec.args:
|
483 |
+
model = task.build_model(cfg.model, from_checkpoint=True)
|
484 |
+
else:
|
485 |
+
model = task.build_model(cfg.model)
|
486 |
+
if (
|
487 |
+
"optimizer_history" in state
|
488 |
+
and len(state["optimizer_history"]) > 0
|
489 |
+
and "num_updates" in state["optimizer_history"][-1]
|
490 |
+
):
|
491 |
+
model.set_num_updates(
|
492 |
+
state["optimizer_history"][-1]["num_updates"]
|
493 |
+
)
|
494 |
+
model.load_state_dict(
|
495 |
+
consolidated_model_state, strict=strict, model_cfg=cfg.model
|
496 |
+
)
|
497 |
+
else:
|
498 |
+
# model parallel checkpoint or unsharded checkpoint
|
499 |
+
# support old external tasks
|
500 |
+
|
501 |
+
if "from_checkpoint" in argspec.args:
|
502 |
+
model = task.build_model(cfg.model, from_checkpoint=True)
|
503 |
+
else:
|
504 |
+
model = task.build_model(cfg.model)
|
505 |
+
if (
|
506 |
+
"optimizer_history" in state
|
507 |
+
and len(state["optimizer_history"]) > 0
|
508 |
+
and "num_updates" in state["optimizer_history"][-1]
|
509 |
+
):
|
510 |
+
model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
|
511 |
+
model.load_state_dict(
|
512 |
+
state["model"], strict=strict, model_cfg=cfg.model
|
513 |
+
)
|
514 |
+
|
515 |
+
# reset state so it gets loaded for the next model in ensemble
|
516 |
+
state = None
|
517 |
+
if shard_idx % 10 == 0 and shard_idx > 0:
|
518 |
+
elapsed = time.time() - st
|
519 |
+
logger.info(
|
520 |
+
f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
|
521 |
+
)
|
522 |
+
|
523 |
+
# build model for ensemble
|
524 |
+
ensemble.append(model)
|
525 |
+
return ensemble, cfg, task
|
526 |
+
|
527 |
+
|
528 |
+
def load_model_ensemble_and_task_from_hf_hub(
|
529 |
+
model_id,
|
530 |
+
cache_dir: Optional[str] = None,
|
531 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
532 |
+
**kwargs: Any,
|
533 |
+
):
|
534 |
+
try:
|
535 |
+
from huggingface_hub import snapshot_download
|
536 |
+
except ImportError:
|
537 |
+
raise ImportError(
|
538 |
+
"You need to install huggingface_hub to use `load_from_hf_hub`. "
|
539 |
+
"See https://pypi.org/project/huggingface-hub/ for installation."
|
540 |
+
)
|
541 |
+
|
542 |
+
library_name = "fairseq"
|
543 |
+
cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
|
544 |
+
cache_dir = snapshot_download(
|
545 |
+
model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
|
546 |
+
)
|
547 |
+
|
548 |
+
_arg_overrides = arg_overrides or {}
|
549 |
+
_arg_overrides["data"] = cache_dir
|
550 |
+
return load_model_ensemble_and_task(
|
551 |
+
[p.as_posix() for p in Path(cache_dir).glob("*.pt")],
|
552 |
+
arg_overrides=_arg_overrides,
|
553 |
+
)
|
554 |
+
|
555 |
+
|
556 |
+
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
|
557 |
+
"""Retrieves all checkpoints found in `path` directory.
|
558 |
+
|
559 |
+
Checkpoints are identified by matching filename to the specified pattern. If
|
560 |
+
the pattern contains groups, the result will be sorted by the first group in
|
561 |
+
descending order.
|
562 |
+
"""
|
563 |
+
pt_regexp = re.compile(pattern)
|
564 |
+
files = PathManager.ls(path)
|
565 |
+
|
566 |
+
entries = []
|
567 |
+
for i, f in enumerate(files):
|
568 |
+
m = pt_regexp.fullmatch(f)
|
569 |
+
if m is not None:
|
570 |
+
idx = float(m.group(1)) if len(m.groups()) > 0 else i
|
571 |
+
entries.append((idx, m.group(0)))
|
572 |
+
if keep_match:
|
573 |
+
return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
|
574 |
+
else:
|
575 |
+
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
|
576 |
+
|
577 |
+
|
578 |
+
def torch_persistent_save(obj, filename, async_write: bool = False):
|
579 |
+
if async_write:
|
580 |
+
with PathManager.opena(filename, "wb") as f:
|
581 |
+
_torch_persistent_save(obj, f)
|
582 |
+
else:
|
583 |
+
if PathManager.supports_rename(filename):
|
584 |
+
# do atomic save
|
585 |
+
with PathManager.open(filename + ".tmp", "wb") as f:
|
586 |
+
_torch_persistent_save(obj, f)
|
587 |
+
PathManager.rename(filename + ".tmp", filename)
|
588 |
+
else:
|
589 |
+
# fallback to non-atomic save
|
590 |
+
with PathManager.open(filename, "wb") as f:
|
591 |
+
_torch_persistent_save(obj, f)
|
592 |
+
|
593 |
+
|
594 |
+
def _torch_persistent_save(obj, f):
|
595 |
+
if isinstance(f, str):
|
596 |
+
with PathManager.open(f, "wb") as h:
|
597 |
+
torch_persistent_save(obj, h)
|
598 |
+
return
|
599 |
+
for i in range(3):
|
600 |
+
try:
|
601 |
+
return torch.save(obj, f)
|
602 |
+
except Exception:
|
603 |
+
if i == 2:
|
604 |
+
logger.error(traceback.format_exc())
|
605 |
+
raise
|
606 |
+
else:
|
607 |
+
time.sleep(2.5)
|
608 |
+
|
609 |
+
|
610 |
+
def _upgrade_state_dict(state):
|
611 |
+
"""Helper for upgrading old model checkpoints."""
|
612 |
+
|
613 |
+
# add optimizer_history
|
614 |
+
if "optimizer_history" not in state:
|
615 |
+
state["optimizer_history"] = [
|
616 |
+
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
|
617 |
+
]
|
618 |
+
state["last_optimizer_state"] = state["optimizer"]
|
619 |
+
del state["optimizer"]
|
620 |
+
del state["best_loss"]
|
621 |
+
# move extra_state into sub-dictionary
|
622 |
+
if "epoch" in state and "extra_state" not in state:
|
623 |
+
state["extra_state"] = {
|
624 |
+
"epoch": state["epoch"],
|
625 |
+
"batch_offset": state["batch_offset"],
|
626 |
+
"val_loss": state["val_loss"],
|
627 |
+
}
|
628 |
+
del state["epoch"]
|
629 |
+
del state["batch_offset"]
|
630 |
+
del state["val_loss"]
|
631 |
+
# reduce optimizer history's memory usage (only keep the last state)
|
632 |
+
if "optimizer" in state["optimizer_history"][-1]:
|
633 |
+
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
|
634 |
+
for optim_hist in state["optimizer_history"]:
|
635 |
+
del optim_hist["optimizer"]
|
636 |
+
# record the optimizer class name
|
637 |
+
if "optimizer_name" not in state["optimizer_history"][-1]:
|
638 |
+
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
|
639 |
+
# move best_loss into lr_scheduler_state
|
640 |
+
if "lr_scheduler_state" not in state["optimizer_history"][-1]:
|
641 |
+
state["optimizer_history"][-1]["lr_scheduler_state"] = {
|
642 |
+
"best": state["optimizer_history"][-1]["best_loss"]
|
643 |
+
}
|
644 |
+
del state["optimizer_history"][-1]["best_loss"]
|
645 |
+
# keep track of number of updates
|
646 |
+
if "num_updates" not in state["optimizer_history"][-1]:
|
647 |
+
state["optimizer_history"][-1]["num_updates"] = 0
|
648 |
+
# use stateful training data iterator
|
649 |
+
if "train_iterator" not in state["extra_state"]:
|
650 |
+
state["extra_state"]["train_iterator"] = {
|
651 |
+
"epoch": state["extra_state"].get("epoch", 0),
|
652 |
+
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
|
653 |
+
}
|
654 |
+
|
655 |
+
# backward compatibility, cfg updates
|
656 |
+
if "args" in state and state["args"] is not None:
|
657 |
+
# old model checkpoints may not have separate source/target positions
|
658 |
+
if hasattr(state["args"], "max_positions") and not hasattr(
|
659 |
+
state["args"], "max_source_positions"
|
660 |
+
):
|
661 |
+
state["args"].max_source_positions = state["args"].max_positions
|
662 |
+
state["args"].max_target_positions = state["args"].max_positions
|
663 |
+
# default to translation task
|
664 |
+
if not hasattr(state["args"], "task"):
|
665 |
+
state["args"].task = "translation"
|
666 |
+
# --raw-text and --lazy-load are deprecated
|
667 |
+
if getattr(state["args"], "raw_text", False):
|
668 |
+
state["args"].dataset_impl = "raw"
|
669 |
+
elif getattr(state["args"], "lazy_load", False):
|
670 |
+
state["args"].dataset_impl = "lazy"
|
671 |
+
# epochs start at 1
|
672 |
+
if state["extra_state"]["train_iterator"] is not None:
|
673 |
+
state["extra_state"]["train_iterator"]["epoch"] = max(
|
674 |
+
state["extra_state"]["train_iterator"].get("epoch", 1), 1
|
675 |
+
)
|
676 |
+
# --remove-bpe ==> --postprocess
|
677 |
+
if hasattr(state["args"], "remove_bpe"):
|
678 |
+
state["args"].post_process = state["args"].remove_bpe
|
679 |
+
# --min-lr ==> --stop-min-lr
|
680 |
+
if hasattr(state["args"], "min_lr"):
|
681 |
+
state["args"].stop_min_lr = state["args"].min_lr
|
682 |
+
del state["args"].min_lr
|
683 |
+
# binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
|
684 |
+
if hasattr(state["args"], "criterion") and state["args"].criterion in [
|
685 |
+
"binary_cross_entropy",
|
686 |
+
"kd_binary_cross_entropy",
|
687 |
+
]:
|
688 |
+
state["args"].criterion = "wav2vec"
|
689 |
+
# remove log_keys if it's None (criteria will supply a default value of [])
|
690 |
+
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
|
691 |
+
delattr(state["args"], "log_keys")
|
692 |
+
# speech_pretraining => audio pretraining
|
693 |
+
if (
|
694 |
+
hasattr(state["args"], "task")
|
695 |
+
and state["args"].task == "speech_pretraining"
|
696 |
+
):
|
697 |
+
state["args"].task = "audio_pretraining"
|
698 |
+
# audio_cpc => wav2vec
|
699 |
+
if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
|
700 |
+
state["args"].arch = "wav2vec"
|
701 |
+
# convert legacy float learning rate to List[float]
|
702 |
+
if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
|
703 |
+
state["args"].lr = [state["args"].lr]
|
704 |
+
# convert task data arg to a string instead of List[string]
|
705 |
+
if (
|
706 |
+
hasattr(state["args"], "data")
|
707 |
+
and isinstance(state["args"].data, list)
|
708 |
+
and len(state["args"].data) > 0
|
709 |
+
):
|
710 |
+
state["args"].data = state["args"].data[0]
|
711 |
+
|
712 |
+
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
|
713 |
+
|
714 |
+
if "cfg" in state and state["cfg"] is not None:
|
715 |
+
cfg = state["cfg"]
|
716 |
+
with open_dict(cfg):
|
717 |
+
# any upgrades for Hydra-based configs
|
718 |
+
if (
|
719 |
+
"task" in cfg
|
720 |
+
and "eval_wer_config" in cfg.task
|
721 |
+
and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
|
722 |
+
):
|
723 |
+
cfg.task.eval_wer_config.print_alignment = "hard"
|
724 |
+
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
|
725 |
+
cfg.generation.print_alignment = (
|
726 |
+
"hard" if cfg.generation.print_alignment else None
|
727 |
+
)
|
728 |
+
if (
|
729 |
+
"model" in cfg
|
730 |
+
and "w2v_args" in cfg.model
|
731 |
+
and cfg.model.w2v_args is not None
|
732 |
+
and (
|
733 |
+
hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
|
734 |
+
)
|
735 |
+
and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
|
736 |
+
and cfg.model.w2v_args.task.eval_wer_config is not None
|
737 |
+
and isinstance(
|
738 |
+
cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
|
739 |
+
)
|
740 |
+
):
|
741 |
+
cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
|
742 |
+
|
743 |
+
return state
|
744 |
+
|
745 |
+
|
746 |
+
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
|
747 |
+
"""Prune the given state_dict if desired for LayerDrop
|
748 |
+
(https://arxiv.org/abs/1909.11556).
|
749 |
+
|
750 |
+
Training with LayerDrop allows models to be robust to pruning at inference
|
751 |
+
time. This function prunes state_dict to allow smaller models to be loaded
|
752 |
+
from a larger model and re-maps the existing state_dict for this to occur.
|
753 |
+
|
754 |
+
It's called by functions that load models from checkpoints and does not
|
755 |
+
need to be called directly.
|
756 |
+
"""
|
757 |
+
arch = None
|
758 |
+
if model_cfg is not None:
|
759 |
+
arch = (
|
760 |
+
model_cfg._name
|
761 |
+
if isinstance(model_cfg, DictConfig)
|
762 |
+
else getattr(model_cfg, "arch", None)
|
763 |
+
)
|
764 |
+
|
765 |
+
if not model_cfg or arch is None or arch == "ptt_transformer":
|
766 |
+
# args should not be none, but don't crash if it is.
|
767 |
+
return state_dict
|
768 |
+
|
769 |
+
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
770 |
+
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
771 |
+
|
772 |
+
if not encoder_layers_to_keep and not decoder_layers_to_keep:
|
773 |
+
return state_dict
|
774 |
+
|
775 |
+
# apply pruning
|
776 |
+
logger.info(
|
777 |
+
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
|
778 |
+
)
|
779 |
+
|
780 |
+
def create_pruning_pass(layers_to_keep, layer_name):
|
781 |
+
keep_layers = sorted(
|
782 |
+
int(layer_string) for layer_string in layers_to_keep.split(",")
|
783 |
+
)
|
784 |
+
mapping_dict = {}
|
785 |
+
for i in range(len(keep_layers)):
|
786 |
+
mapping_dict[str(keep_layers[i])] = str(i)
|
787 |
+
|
788 |
+
regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
|
789 |
+
return {"substitution_regex": regex, "mapping_dict": mapping_dict}
|
790 |
+
|
791 |
+
pruning_passes = []
|
792 |
+
if encoder_layers_to_keep:
|
793 |
+
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
|
794 |
+
if decoder_layers_to_keep:
|
795 |
+
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
|
796 |
+
|
797 |
+
new_state_dict = {}
|
798 |
+
for layer_name in state_dict.keys():
|
799 |
+
match = re.search(r"\.layers\.(\d+)\.", layer_name)
|
800 |
+
# if layer has no number in it, it is a supporting layer, such as an
|
801 |
+
# embedding
|
802 |
+
if not match:
|
803 |
+
new_state_dict[layer_name] = state_dict[layer_name]
|
804 |
+
continue
|
805 |
+
|
806 |
+
# otherwise, layer should be pruned.
|
807 |
+
original_layer_number = match.group(1)
|
808 |
+
# figure out which mapping dict to replace from
|
809 |
+
for pruning_pass in pruning_passes:
|
810 |
+
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
|
811 |
+
"substitution_regex"
|
812 |
+
].search(layer_name):
|
813 |
+
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
|
814 |
+
substitution_match = pruning_pass["substitution_regex"].search(
|
815 |
+
layer_name
|
816 |
+
)
|
817 |
+
new_state_key = (
|
818 |
+
layer_name[: substitution_match.start(1)]
|
819 |
+
+ new_layer_number
|
820 |
+
+ layer_name[substitution_match.end(1) :]
|
821 |
+
)
|
822 |
+
new_state_dict[new_state_key] = state_dict[layer_name]
|
823 |
+
|
824 |
+
# Since layers are now pruned, *_layers_to_keep are no longer needed.
|
825 |
+
# This is more of "It would make it work fix" rather than a proper fix.
|
826 |
+
if isinstance(model_cfg, DictConfig):
|
827 |
+
context = open_dict(model_cfg)
|
828 |
+
else:
|
829 |
+
context = contextlib.ExitStack()
|
830 |
+
with context:
|
831 |
+
if hasattr(model_cfg, "encoder_layers_to_keep"):
|
832 |
+
model_cfg.encoder_layers_to_keep = None
|
833 |
+
if hasattr(model_cfg, "decoder_layers_to_keep"):
|
834 |
+
model_cfg.decoder_layers_to_keep = None
|
835 |
+
|
836 |
+
return new_state_dict
|
837 |
+
|
838 |
+
|
839 |
+
def load_pretrained_component_from_model(
|
840 |
+
component: Union[FairseqEncoder, FairseqDecoder],
|
841 |
+
checkpoint: str,
|
842 |
+
strict: bool = True,
|
843 |
+
):
|
844 |
+
"""
|
845 |
+
Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
|
846 |
+
provided `component` object. If state_dict fails to load, there may be a
|
847 |
+
mismatch in the architecture of the corresponding `component` found in the
|
848 |
+
`checkpoint` file.
|
849 |
+
"""
|
850 |
+
if not PathManager.exists(checkpoint):
|
851 |
+
raise IOError("Model file not found: {}".format(checkpoint))
|
852 |
+
state = load_checkpoint_to_cpu(checkpoint)
|
853 |
+
if isinstance(component, FairseqEncoder):
|
854 |
+
component_type = "encoder"
|
855 |
+
elif isinstance(component, FairseqDecoder):
|
856 |
+
component_type = "decoder"
|
857 |
+
else:
|
858 |
+
raise ValueError(
|
859 |
+
"component to load must be either a FairseqEncoder or "
|
860 |
+
"FairseqDecoder. Loading other component types are not supported."
|
861 |
+
)
|
862 |
+
component_state_dict = OrderedDict()
|
863 |
+
for key in state["model"].keys():
|
864 |
+
if key.startswith(component_type):
|
865 |
+
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
|
866 |
+
component_subkey = key[len(component_type) + 1 :]
|
867 |
+
component_state_dict[component_subkey] = state["model"][key]
|
868 |
+
component.load_state_dict(component_state_dict, strict=strict)
|
869 |
+
return component
|
870 |
+
|
871 |
+
|
872 |
+
def verify_checkpoint_directory(save_dir: str) -> None:
|
873 |
+
if not os.path.exists(save_dir):
|
874 |
+
os.makedirs(save_dir, exist_ok=True)
|
875 |
+
temp_file_path = os.path.join(save_dir, "dummy")
|
876 |
+
try:
|
877 |
+
with open(temp_file_path, "w"):
|
878 |
+
pass
|
879 |
+
except OSError as e:
|
880 |
+
logger.warning(
|
881 |
+
"Unable to access checkpoint save directory: {}".format(save_dir)
|
882 |
+
)
|
883 |
+
raise e
|
884 |
+
else:
|
885 |
+
os.remove(temp_file_path)
|
886 |
+
|
887 |
+
|
888 |
+
def save_ema_as_checkpoint(src_path, dst_path):
|
889 |
+
state = load_ema_from_checkpoint(src_path)
|
890 |
+
torch_persistent_save(state, dst_path)
|
891 |
+
|
892 |
+
|
893 |
+
def load_ema_from_checkpoint(fpath):
|
894 |
+
"""Loads exponential moving averaged (EMA) checkpoint from input and
|
895 |
+
returns a model with ema weights.
|
896 |
+
|
897 |
+
Args:
|
898 |
+
fpath: A string path of checkpoint to load from.
|
899 |
+
|
900 |
+
Returns:
|
901 |
+
A dict of string keys mapping to various values. The 'model' key
|
902 |
+
from the returned dict should correspond to an OrderedDict mapping
|
903 |
+
string parameter names to torch Tensors.
|
904 |
+
"""
|
905 |
+
params_dict = collections.OrderedDict()
|
906 |
+
new_state = None
|
907 |
+
|
908 |
+
with PathManager.open(fpath, "rb") as f:
|
909 |
+
new_state = torch.load(
|
910 |
+
f,
|
911 |
+
map_location=(
|
912 |
+
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
|
913 |
+
),
|
914 |
+
)
|
915 |
+
|
916 |
+
# EMA model is stored in a separate "extra state"
|
917 |
+
model_params = new_state["extra_state"]["ema"]
|
918 |
+
|
919 |
+
for key in list(model_params.keys()):
|
920 |
+
p = model_params[key]
|
921 |
+
if isinstance(p, torch.HalfTensor):
|
922 |
+
p = p.float()
|
923 |
+
if key not in params_dict:
|
924 |
+
params_dict[key] = p.clone()
|
925 |
+
# NOTE: clone() is needed in case of p is a shared parameter
|
926 |
+
else:
|
927 |
+
raise ValueError("Key {} is repeated in EMA model params.".format(key))
|
928 |
+
|
929 |
+
if len(params_dict) == 0:
|
930 |
+
raise ValueError(
|
931 |
+
f"Input checkpoint path '{fpath}' does not contain "
|
932 |
+
"ema model weights, is this model trained with EMA?"
|
933 |
+
)
|
934 |
+
|
935 |
+
new_state["model"] = params_dict
|
936 |
+
return new_state
|