kevinwang676 committed on
Commit 9cabf4f
1 Parent(s): 38e25d3

Delete load_model.py

Files changed (1)
  1. load_model.py +0 -936
load_model.py DELETED
@@ -1,936 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import ast
-import collections
-import contextlib
-import inspect
-import logging
-import os
-import re
-import time
-import traceback
-from collections import OrderedDict
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
-
-import numpy as np
-import torch
-from fairseq.data import data_utils
-from fairseq.dataclass.configs import CheckpointConfig
-from fairseq.dataclass.utils import (
-    convert_namespace_to_omegaconf,
-    overwrite_args_by_name,
-)
-from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
-from fairseq.file_io import PathManager
-from fairseq.models import FairseqDecoder, FairseqEncoder
-from omegaconf import DictConfig, OmegaConf, open_dict
-
-logger = logging.getLogger(__name__)
-
-
-def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
-    from fairseq import meters
-
-    # only one worker should attempt to create the required dir
-    if trainer.data_parallel_rank == 0:
-        os.makedirs(cfg.save_dir, exist_ok=True)
-
-    prev_best = getattr(save_checkpoint, "best", val_loss)
-    if val_loss is not None:
-        best_function = max if cfg.maximize_best_checkpoint_metric else min
-        save_checkpoint.best = best_function(val_loss, prev_best)
-
-    if cfg.no_save:
-        return None
-
-    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state
-
-    if not trainer.should_save_checkpoint_on_current_rank:
-        if trainer.always_call_state_dict_during_save_checkpoint:
-            trainer.state_dict()
-        return None
-
-    write_timer = meters.StopwatchMeter()
-    write_timer.start()
-
-    epoch = epoch_itr.epoch
-    end_of_epoch = epoch_itr.end_of_epoch()
-    updates = trainer.get_num_updates()
-
-    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
-
-    def is_better(a, b):
-        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
-
-    suffix = trainer.checkpoint_suffix
-    checkpoint_conds = collections.OrderedDict()
-    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
-        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
-    )
-    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
-        not end_of_epoch
-        and cfg.save_interval_updates > 0
-        and updates % cfg.save_interval_updates == 0
-    )
-    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
-        not hasattr(save_checkpoint, "best")
-        or is_better(val_loss, save_checkpoint.best)
-    )
-    if val_loss is not None and cfg.keep_best_checkpoints > 0:
-        worst_best = getattr(save_checkpoint, "best", None)
-        chkpts = checkpoint_paths(
-            cfg.save_dir,
-            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
-                cfg.best_checkpoint_metric, suffix
-            ),
-        )
-        if len(chkpts) > 0:
-            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
-            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
-        # add random digits to resolve ties
-        with data_utils.numpy_seed(epoch, updates, val_loss):
-            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
-
-        checkpoint_conds[
-            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
-                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
-            )
-        ] = worst_best is None or is_better(val_loss, worst_best)
-    checkpoint_conds[
-        "checkpoint_last{}.pt".format(suffix)
-    ] = not cfg.no_last_checkpoints
-
-    extra_state = {
-        "train_iterator": epoch_itr.state_dict(),
-        "val_loss": val_loss,
-    }
-
-    # Going forward, different tasks could expose an API like this to dump all
-    # the checkpoint worthy attributes in a dictionary which then will be
-    # merged with the parent dictionary to create the "extra_state". This
-    # allows for an extensible yet simple design to checkpoint task level
-    # attributes
-    if hasattr(trainer.task, "get_checkpoint_dict"):
-        extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
-        logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")
-
-    if hasattr(save_checkpoint, "best"):
-        extra_state.update({"best": save_checkpoint.best})
-
-    checkpoints = [
-        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
-    ]
-    saved_cp = None
-    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
-        saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state)
-        for cp in checkpoints[1:]:
-            if cfg.write_checkpoints_asynchronously:
-                # TODO[ioPath]: Need to implement a delayed asynchronous
-                # file copying/moving feature.
-                logger.warning(
-                    f"ioPath is not copying {checkpoints[0]} to {cp} "
-                    "since async write mode is on."
-                )
-            else:
-                assert PathManager.copy(
-                    checkpoints[0], cp, overwrite=True
-                ), f"Failed to copy {checkpoints[0]} to {cp}"
-
-        write_timer.stop()
-        logger.info(
-            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
-                checkpoints[0], epoch, updates, val_loss, write_timer.sum
-            )
-        )
-
-    if (
-        not end_of_epoch
-        and cfg.keep_interval_updates > 0
-        and trainer.should_save_checkpoint_on_current_rank
-    ):
-        # remove old checkpoints; checkpoints are sorted in descending order
-        if cfg.keep_interval_updates_pattern == -1:
-            checkpoints = checkpoint_paths(
-                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
-            )
-        else:
-            checkpoints = checkpoint_paths(
-                cfg.save_dir,
-                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
-                keep_match=True,
-            )
-            checkpoints = [
-                x[0]
-                for x in checkpoints
-                if x[1] % cfg.keep_interval_updates_pattern != 0
-            ]
-
-        for old_chk in checkpoints[cfg.keep_interval_updates :]:
-            if os.path.lexists(old_chk):
-                os.remove(old_chk)
-            elif PathManager.exists(old_chk):
-                PathManager.rm(old_chk)
-
-    if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank:
-        # remove old epoch checkpoints; checkpoints are sorted in descending order
-        checkpoints = checkpoint_paths(
-            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
-        )
-        for old_chk in checkpoints[cfg.keep_last_epochs :]:
-            if os.path.lexists(old_chk):
-                os.remove(old_chk)
-            elif PathManager.exists(old_chk):
-                PathManager.rm(old_chk)
-
-    if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank:
-        # only keep the best N checkpoints according to validation metric
-        checkpoints = checkpoint_paths(
-            cfg.save_dir,
-            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
-                cfg.best_checkpoint_metric, suffix
-            ),
-        )
-        if not cfg.maximize_best_checkpoint_metric:
-            checkpoints = checkpoints[::-1]
-        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
-            if os.path.lexists(old_chk):
-                os.remove(old_chk)
-            elif PathManager.exists(old_chk):
-                PathManager.rm(old_chk)
-
-    return saved_cp
-
-
-def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
-    """
-    Load a checkpoint and restore the training iterator.
-
-    *passthrough_args* will be passed through to
-    ``trainer.get_train_iterator``.
-    """
-
-    reset_optimizer = cfg.reset_optimizer
-    reset_lr_scheduler = cfg.reset_lr_scheduler
-    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
-    reset_meters = cfg.reset_meters
-    reset_dataloader = cfg.reset_dataloader
-
-    if cfg.finetune_from_model is not None and (
-        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
-    ):
-        raise ValueError(
-            "--finetune-from-model can not be set together with either --reset-optimizer"
-            " or reset_lr_scheduler or reset_meters or reset_dataloader"
-        )
-
-    suffix = trainer.checkpoint_suffix
-    if (
-        cfg.restore_file == "checkpoint_last.pt"
-    ):  # default value of restore_file is 'checkpoint_last.pt'
-        checkpoint_path = os.path.join(
-            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
-        )
-        first_launch = not PathManager.exists(checkpoint_path)
-        if first_launch and getattr(cfg, "continue_once", None) is not None:
-            checkpoint_path = cfg.continue_once
-        elif cfg.finetune_from_model is not None and first_launch:
-            # if there is no last checkpoint to restore, start the finetune from pretrained model
-            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
-            if PathManager.exists(cfg.finetune_from_model):
-                checkpoint_path = cfg.finetune_from_model
-                reset_optimizer = True
-                reset_lr_scheduler = True
-                reset_meters = True
-                reset_dataloader = True
-                logger.info(
-                    f"loading pretrained model from {checkpoint_path}: "
-                    "optimizer, lr scheduler, meters, dataloader will be reset"
-                )
-            else:
-                raise ValueError(
-                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
-                )
-    elif suffix is not None:
-        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
-    else:
-        checkpoint_path = cfg.restore_file
-
-    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
-        raise ValueError(
-            "--finetune-from-model and --restore-file (non-default value) "
-            "can not be specified together: " + str(cfg)
-        )
-
-    extra_state = trainer.load_checkpoint(
-        checkpoint_path,
-        reset_optimizer,
-        reset_lr_scheduler,
-        optimizer_overrides,
-        reset_meters=reset_meters,
-    )
-
-    if (
-        extra_state is not None
-        and "best" in extra_state
-        and not reset_optimizer
-        and not reset_meters
-    ):
-        save_checkpoint.best = extra_state["best"]
-
-    if extra_state is not None and not reset_dataloader:
-        # restore iterator from checkpoint
-        itr_state = extra_state["train_iterator"]
-        epoch_itr = trainer.get_train_iterator(
-            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
-        )
-        epoch_itr.load_state_dict(itr_state)
-
-        # Preload the checkpoint for the task
-        task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
-        if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
-            trainer.task.set_checkpoint_dict(task_cp_dict)
-    else:
-        epoch_itr = trainer.get_train_iterator(
-            epoch=1, load_dataset=True, **passthrough_args
-        )
-
-    trainer.lr_step(epoch_itr.epoch)
-
-    return extra_state, epoch_itr
-
-
-def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
-    """Loads a checkpoint to CPU (with upgrading for backward compatibility).
-
-    If doing single-GPU training or if the checkpoint is only being loaded by at
-    most one process on each node (current default behavior is for only rank 0
-    to read the checkpoint from disk), load_on_all_ranks should be False to
-    avoid errors from torch.distributed not having been initialized or
-    torch.distributed.barrier() hanging.
-
-    If all processes on each node may be loading the checkpoint
-    simultaneously, load_on_all_ranks should be set to True to avoid I/O
-    conflicts.
-
-    There's currently no support for > 1 but < all processes loading the
-    checkpoint on each node.
-    """
-    local_path = PathManager.get_local_path(path)
-    # The locally cached file returned by get_local_path() may be stale for
-    # remote files that are periodically updated/overwritten (ex:
-    # checkpoint_last.pt) - so we remove the local copy, sync across processes
-    # (if needed), and then download a fresh copy.
-    if local_path != path and PathManager.path_requires_pathmanager(path):
-        try:
-            os.remove(local_path)
-        except FileNotFoundError:
-            # With potentially multiple processes removing the same file, the
-            # file being missing is benign (missing_ok isn't available until
-            # Python 3.8).
-            pass
-        if load_on_all_ranks:
-            torch.distributed.barrier()
-        local_path = PathManager.get_local_path(path)
-
-    with open(local_path, "rb") as f:
-        state = torch.load(f, map_location=torch.device("cpu"))
-
-    if "args" in state and state["args"] is not None and arg_overrides is not None:
-        args = state["args"]
-        for arg_name, arg_val in arg_overrides.items():
-            setattr(args, arg_name, arg_val)
-
-    if "cfg" in state and state["cfg"] is not None:
-
-        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
-        # omegaconf version that supports object flags, or when we migrate all existing models
-        from omegaconf import __version__ as oc_version
-        from omegaconf import _utils
-
-        if oc_version < "2.2":
-            old_primitive = _utils.is_primitive_type
-            _utils.is_primitive_type = lambda _: True
-
-            state["cfg"] = OmegaConf.create(state["cfg"])
-
-            _utils.is_primitive_type = old_primitive
-            OmegaConf.set_struct(state["cfg"], True)
-        else:
-            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
-
-        if arg_overrides is not None:
-            overwrite_args_by_name(state["cfg"], arg_overrides)
-
-    state = _upgrade_state_dict(state)
-    return state
-
-
-def load_model_ensemble(
-    filenames,
-    arg_overrides: Optional[Dict[str, Any]] = None,
-    task=None,
-    strict=True,
-    suffix="",
-    num_shards=1,
-    state=None,
-):
-    """Loads an ensemble of models.
-
-    Args:
-        filenames (List[str]): checkpoint files to load
-        arg_overrides (Dict[str,Any], optional): override model args that
-            were used during model training
-        task (fairseq.tasks.FairseqTask, optional): task to use for loading
-    """
-    assert not (
-        strict and num_shards > 1
-    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
-    ensemble, args, _task = load_model_ensemble_and_task(
-        filenames,
-        arg_overrides,
-        task,
-        strict,
-        suffix,
-        num_shards,
-        state,
-    )
-    return ensemble, args
-
-
-def get_maybe_sharded_checkpoint_filename(
-    filename: str, suffix: str, shard_idx: int, num_shards: int
-) -> str:
-    orig_filename = filename
-    filename = filename.replace(".pt", suffix + ".pt")
-    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
-    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
-    if PathManager.exists(fsdp_filename):
-        return fsdp_filename
-    elif num_shards > 1:
-        return model_parallel_filename
-    else:
-        return filename
-
-
-def load_model_ensemble_and_task(
-    filenames,
-    arg_overrides: Optional[Dict[str, Any]] = None,
-    task=None,
-    strict=True,
-    suffix="",
-    num_shards=1,
-    state=None,
-):
-    assert state is None or len(filenames) == 1
-
-    from fairseq import tasks
-
-    assert not (
-        strict and num_shards > 1
-    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
-    ensemble = []
-    cfg = None
-    for filename in filenames:
-        orig_filename = filename
-        model_shard_state = {"shard_weights": [], "shard_metadata": []}
-        assert num_shards > 0
-        st = time.time()
-        for shard_idx in range(num_shards):
-            filename = get_maybe_sharded_checkpoint_filename(
-                orig_filename, suffix, shard_idx, num_shards
-            )
-
-            if not PathManager.exists(filename):
-                raise IOError("Model file not found: {}".format(filename))
-            if state is None:
-                state = load_checkpoint_to_cpu(filename, arg_overrides)
-            if "args" in state and state["args"] is not None:
-                cfg = convert_namespace_to_omegaconf(state["args"])
-            elif "cfg" in state and state["cfg"] is not None:
-                cfg = state["cfg"]
-            else:
-                raise RuntimeError(
-                    f"Neither args nor cfg exist in state keys = {state.keys()}"
-                )
-
-            if task is None:
-                task = tasks.setup_task(cfg.task, from_checkpoint=True)
-
-            if "task_state" in state:
-                task.load_state_dict(state["task_state"])
-
-            argspec = inspect.getfullargspec(task.build_model)
-
-            if "fsdp_metadata" in state and num_shards > 1:
-                model_shard_state["shard_weights"].append(state["model"])
-                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
-                # check FSDP import before the code goes too far
-                if not has_FSDP:
-                    raise ImportError(
-                        "Cannot find FullyShardedDataParallel. "
-                        "Please install fairscale with: pip install fairscale"
-                    )
-                if shard_idx == num_shards - 1:
-                    consolidated_model_state = FSDP.consolidate_shard_weights(
-                        shard_weights=model_shard_state["shard_weights"],
-                        shard_metadata=model_shard_state["shard_metadata"],
-                    )
-                    if "from_checkpoint" in argspec.args:
-                        model = task.build_model(cfg.model, from_checkpoint=True)
-                    else:
-                        model = task.build_model(cfg.model)
-                    if (
-                        "optimizer_history" in state
-                        and len(state["optimizer_history"]) > 0
-                        and "num_updates" in state["optimizer_history"][-1]
-                    ):
-                        model.set_num_updates(
-                            state["optimizer_history"][-1]["num_updates"]
-                        )
-                    model.load_state_dict(
-                        consolidated_model_state, strict=strict, model_cfg=cfg.model
-                    )
-            else:
-                # model parallel checkpoint or unsharded checkpoint
-                # support old external tasks
-
-                if "from_checkpoint" in argspec.args:
-                    model = task.build_model(cfg.model, from_checkpoint=True)
-                else:
-                    model = task.build_model(cfg.model)
-                if (
-                    "optimizer_history" in state
-                    and len(state["optimizer_history"]) > 0
-                    and "num_updates" in state["optimizer_history"][-1]
-                ):
-                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
-                model.load_state_dict(
-                    state["model"], strict=strict, model_cfg=cfg.model
-                )
-
-            # reset state so it gets loaded for the next model in ensemble
-            state = None
-            if shard_idx % 10 == 0 and shard_idx > 0:
-                elapsed = time.time() - st
-                logger.info(
-                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
-                )
-
-        # build model for ensemble
-        ensemble.append(model)
-    return ensemble, cfg, task
-
-
-def load_model_ensemble_and_task_from_hf_hub(
-    model_id,
-    cache_dir: Optional[str] = None,
-    arg_overrides: Optional[Dict[str, Any]] = None,
-    **kwargs: Any,
-):
-    try:
-        from huggingface_hub import snapshot_download
-    except ImportError:
-        raise ImportError(
-            "You need to install huggingface_hub to use `load_from_hf_hub`. "
-            "See https://pypi.org/project/huggingface-hub/ for installation."
-        )
-
-    library_name = "fairseq"
-    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
-    cache_dir = snapshot_download(
-        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
-    )
-
-    _arg_overrides = arg_overrides or {}
-    _arg_overrides["data"] = cache_dir
-    return load_model_ensemble_and_task(
-        [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
-        arg_overrides=_arg_overrides,
-    )
-
-
-def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
-    """Retrieves all checkpoints found in `path` directory.
-
-    Checkpoints are identified by matching filename to the specified pattern. If
-    the pattern contains groups, the result will be sorted by the first group in
-    descending order.
-    """
-    pt_regexp = re.compile(pattern)
-    files = PathManager.ls(path)
-
-    entries = []
-    for i, f in enumerate(files):
-        m = pt_regexp.fullmatch(f)
-        if m is not None:
-            idx = float(m.group(1)) if len(m.groups()) > 0 else i
-            entries.append((idx, m.group(0)))
-    if keep_match:
-        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
-    else:
-        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
-
-
-def torch_persistent_save(obj, filename, async_write: bool = False):
-    if async_write:
-        with PathManager.opena(filename, "wb") as f:
-            _torch_persistent_save(obj, f)
-    else:
-        if PathManager.supports_rename(filename):
-            # do atomic save
-            with PathManager.open(filename + ".tmp", "wb") as f:
-                _torch_persistent_save(obj, f)
-            PathManager.rename(filename + ".tmp", filename)
-        else:
-            # fallback to non-atomic save
-            with PathManager.open(filename, "wb") as f:
-                _torch_persistent_save(obj, f)
-
-
-def _torch_persistent_save(obj, f):
-    if isinstance(f, str):
-        with PathManager.open(f, "wb") as h:
-            torch_persistent_save(obj, h)
-        return
-    for i in range(3):
-        try:
-            return torch.save(obj, f)
-        except Exception:
-            if i == 2:
-                logger.error(traceback.format_exc())
-                raise
-            else:
-                time.sleep(2.5)
-
-
-def _upgrade_state_dict(state):
-    """Helper for upgrading old model checkpoints."""
-
-    # add optimizer_history
-    if "optimizer_history" not in state:
-        state["optimizer_history"] = [
-            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
-        ]
-        state["last_optimizer_state"] = state["optimizer"]
-        del state["optimizer"]
-        del state["best_loss"]
-    # move extra_state into sub-dictionary
-    if "epoch" in state and "extra_state" not in state:
-        state["extra_state"] = {
-            "epoch": state["epoch"],
-            "batch_offset": state["batch_offset"],
-            "val_loss": state["val_loss"],
-        }
-        del state["epoch"]
-        del state["batch_offset"]
-        del state["val_loss"]
-    # reduce optimizer history's memory usage (only keep the last state)
-    if "optimizer" in state["optimizer_history"][-1]:
-        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
-        for optim_hist in state["optimizer_history"]:
-            del optim_hist["optimizer"]
-    # record the optimizer class name
-    if "optimizer_name" not in state["optimizer_history"][-1]:
-        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
-    # move best_loss into lr_scheduler_state
-    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
-        state["optimizer_history"][-1]["lr_scheduler_state"] = {
-            "best": state["optimizer_history"][-1]["best_loss"]
-        }
-        del state["optimizer_history"][-1]["best_loss"]
-    # keep track of number of updates
-    if "num_updates" not in state["optimizer_history"][-1]:
-        state["optimizer_history"][-1]["num_updates"] = 0
-    # use stateful training data iterator
-    if "train_iterator" not in state["extra_state"]:
-        state["extra_state"]["train_iterator"] = {
-            "epoch": state["extra_state"].get("epoch", 0),
-            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
-        }
-
-    # backward compatibility, cfg updates
-    if "args" in state and state["args"] is not None:
-        # old model checkpoints may not have separate source/target positions
-        if hasattr(state["args"], "max_positions") and not hasattr(
-            state["args"], "max_source_positions"
-        ):
-            state["args"].max_source_positions = state["args"].max_positions
-            state["args"].max_target_positions = state["args"].max_positions
-        # default to translation task
-        if not hasattr(state["args"], "task"):
-            state["args"].task = "translation"
-        # --raw-text and --lazy-load are deprecated
-        if getattr(state["args"], "raw_text", False):
-            state["args"].dataset_impl = "raw"
-        elif getattr(state["args"], "lazy_load", False):
-            state["args"].dataset_impl = "lazy"
-        # epochs start at 1
-        if state["extra_state"]["train_iterator"] is not None:
-            state["extra_state"]["train_iterator"]["epoch"] = max(
-                state["extra_state"]["train_iterator"].get("epoch", 1), 1
-            )
-        # --remove-bpe ==> --postprocess
-        if hasattr(state["args"], "remove_bpe"):
-            state["args"].post_process = state["args"].remove_bpe
-        # --min-lr ==> --stop-min-lr
-        if hasattr(state["args"], "min_lr"):
-            state["args"].stop_min_lr = state["args"].min_lr
-            del state["args"].min_lr
-        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
-        if hasattr(state["args"], "criterion") and state["args"].criterion in [
-            "binary_cross_entropy",
-            "kd_binary_cross_entropy",
-        ]:
-            state["args"].criterion = "wav2vec"
-        # remove log_keys if it's None (criteria will supply a default value of [])
-        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
-            delattr(state["args"], "log_keys")
-        # speech_pretraining => audio pretraining
-        if (
-            hasattr(state["args"], "task")
-            and state["args"].task == "speech_pretraining"
-        ):
-            state["args"].task = "audio_pretraining"
-        # audio_cpc => wav2vec
-        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
-            state["args"].arch = "wav2vec"
-        # convert legacy float learning rate to List[float]
-        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
-            state["args"].lr = [state["args"].lr]
-        # convert task data arg to a string instead of List[string]
-        if (
-            hasattr(state["args"], "data")
-            and isinstance(state["args"].data, list)
-            and len(state["args"].data) > 0
-        ):
-            state["args"].data = state["args"].data[0]
-
-        state["cfg"] = convert_namespace_to_omegaconf(state["args"])
-
-    if "cfg" in state and state["cfg"] is not None:
-        cfg = state["cfg"]
-        with open_dict(cfg):
-            # any upgrades for Hydra-based configs
-            if (
-                "task" in cfg
-                and "eval_wer_config" in cfg.task
-                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
-            ):
-                cfg.task.eval_wer_config.print_alignment = "hard"
-            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
-                cfg.generation.print_alignment = (
-                    "hard" if cfg.generation.print_alignment else None
-                )
-            if (
-                "model" in cfg
-                and "w2v_args" in cfg.model
-                and cfg.model.w2v_args is not None
-                and (
-                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
-                )
-                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
-                and cfg.model.w2v_args.task.eval_wer_config is not None
-                and isinstance(
-                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
-                )
-            ):
-                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
-
-    return state
-
-
-def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
-    """Prune the given state_dict if desired for LayerDrop
-    (https://arxiv.org/abs/1909.11556).
-
-    Training with LayerDrop allows models to be robust to pruning at inference
-    time. This function prunes state_dict to allow smaller models to be loaded
-    from a larger model and re-maps the existing state_dict for this to occur.
-
-    It's called by functions that load models from checkpoints and does not
-    need to be called directly.
-    """
-    arch = None
-    if model_cfg is not None:
-        arch = (
-            model_cfg._name
-            if isinstance(model_cfg, DictConfig)
-            else getattr(model_cfg, "arch", None)
-        )
-
-    if not model_cfg or arch is None or arch == "ptt_transformer":
-        # args should not be none, but don't crash if it is.
-        return state_dict
-
-    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
-    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
-
-    if not encoder_layers_to_keep and not decoder_layers_to_keep:
-        return state_dict
-
-    # apply pruning
-    logger.info(
-        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
-    )
-
-    def create_pruning_pass(layers_to_keep, layer_name):
-        keep_layers = sorted(
-            int(layer_string) for layer_string in layers_to_keep.split(",")
-        )
-        mapping_dict = {}
-        for i in range(len(keep_layers)):
-            mapping_dict[str(keep_layers[i])] = str(i)
-
-        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
-        return {"substitution_regex": regex, "mapping_dict": mapping_dict}
-
-    pruning_passes = []
-    if encoder_layers_to_keep:
-        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
-    if decoder_layers_to_keep:
-        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
-
-    new_state_dict = {}
-    for layer_name in state_dict.keys():
-        match = re.search(r"\.layers\.(\d+)\.", layer_name)
-        # if layer has no number in it, it is a supporting layer, such as an
-        # embedding
-        if not match:
-            new_state_dict[layer_name] = state_dict[layer_name]
-            continue
-
-        # otherwise, layer should be pruned.
-        original_layer_number = match.group(1)
-        # figure out which mapping dict to replace from
-        for pruning_pass in pruning_passes:
-            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
-                "substitution_regex"
-            ].search(layer_name):
-                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
-                substitution_match = pruning_pass["substitution_regex"].search(
-                    layer_name
-                )
-                new_state_key = (
-                    layer_name[: substitution_match.start(1)]
-                    + new_layer_number
-                    + layer_name[substitution_match.end(1) :]
-                )
-                new_state_dict[new_state_key] = state_dict[layer_name]
-
-    # Since layers are now pruned, *_layers_to_keep are no longer needed.
-    # This is more of "It would make it work fix" rather than a proper fix.
-    if isinstance(model_cfg, DictConfig):
-        context = open_dict(model_cfg)
-    else:
-        context = contextlib.ExitStack()
-    with context:
-        if hasattr(model_cfg, "encoder_layers_to_keep"):
-            model_cfg.encoder_layers_to_keep = None
-        if hasattr(model_cfg, "decoder_layers_to_keep"):
-            model_cfg.decoder_layers_to_keep = None
-
-    return new_state_dict
-
-
-def load_pretrained_component_from_model(
-    component: Union[FairseqEncoder, FairseqDecoder],
-    checkpoint: str,
-    strict: bool = True,
-):
-    """
-    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
-    provided `component` object. If state_dict fails to load, there may be a
-    mismatch in the architecture of the corresponding `component` found in the
-    `checkpoint` file.
-    """
-    if not PathManager.exists(checkpoint):
-        raise IOError("Model file not found: {}".format(checkpoint))
-    state = load_checkpoint_to_cpu(checkpoint)
-    if isinstance(component, FairseqEncoder):
-        component_type = "encoder"
-    elif isinstance(component, FairseqDecoder):
-        component_type = "decoder"
-    else:
-        raise ValueError(
-            "component to load must be either a FairseqEncoder or "
-            "FairseqDecoder. Loading other component types are not supported."
-        )
-    component_state_dict = OrderedDict()
-    for key in state["model"].keys():
-        if key.startswith(component_type):
-            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
-            component_subkey = key[len(component_type) + 1 :]
-            component_state_dict[component_subkey] = state["model"][key]
-    component.load_state_dict(component_state_dict, strict=strict)
-    return component
-
-
-def verify_checkpoint_directory(save_dir: str) -> None:
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir, exist_ok=True)
-    temp_file_path = os.path.join(save_dir, "dummy")
-    try:
-        with open(temp_file_path, "w"):
-            pass
-    except OSError as e:
-        logger.warning(
-            "Unable to access checkpoint save directory: {}".format(save_dir)
-        )
-        raise e
-    else:
-        os.remove(temp_file_path)
-
-
-def save_ema_as_checkpoint(src_path, dst_path):
-    state = load_ema_from_checkpoint(src_path)
-    torch_persistent_save(state, dst_path)
-
-
-def load_ema_from_checkpoint(fpath):
-    """Loads exponential moving averaged (EMA) checkpoint from input and
-    returns a model with ema weights.
-
-    Args:
-        fpath: A string path of checkpoint to load from.
-
-    Returns:
-        A dict of string keys mapping to various values. The 'model' key
-        from the returned dict should correspond to an OrderedDict mapping
-        string parameter names to torch Tensors.
-    """
-    params_dict = collections.OrderedDict()
-    new_state = None
-
-    with PathManager.open(fpath, "rb") as f:
-        new_state = torch.load(
-            f,
-            map_location=(
-                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
-            ),
-        )
-
-    # EMA model is stored in a separate "extra state"
-    model_params = new_state["extra_state"]["ema"]
-
-    for key in list(model_params.keys()):
-        p = model_params[key]
-        if isinstance(p, torch.HalfTensor):
-            p = p.float()
-        if key not in params_dict:
-            params_dict[key] = p.clone()
-            # NOTE: clone() is needed in case of p is a shared parameter
-        else:
-            raise ValueError("Key {} is repeated in EMA model params.".format(key))
-
-    if len(params_dict) == 0:
-        raise ValueError(
-            f"Input checkpoint path '{fpath}' does not contain "
-            "ema model weights, is this model trained with EMA?"
-        )
-
-    new_state["model"] = params_dict
-    return new_state
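
For reference, the deleted module bundled fairseq's standard checkpoint helpers (load_checkpoint_to_cpu, load_model_ensemble, load_model_ensemble_and_task_from_hf_hub, and friends). The sketch below shows how such a module is typically driven; it is not part of this commit, and the module name load_model as well as the checkpoint and data paths are placeholders.

# Usage sketch only -- not part of the deleted file or of this commit.
# Assumes the repo root is on PYTHONPATH so the file imports as `load_model`,
# and that the paths below point at a real fairseq checkpoint and data directory.
from load_model import load_model_ensemble

models, cfg = load_model_ensemble(
    ["/path/to/checkpoint.pt"],               # checkpoint files; a single-file "ensemble" here
    arg_overrides={"data": "/path/to/data"},  # override the data dir recorded at training time
)
model = models[0]
model.eval()  # loaded models are ordinary torch.nn.Module instances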