import os
import shutil
import pytorch_lightning as pl


class CopyPretrainedCheckpoints(pl.callbacks.Callback):
    def __init__(self):
        super().__init__()

    def on_fit_start(self, trainer, pl_module):
        """Before training, move the pre-trained checkpoints
        to the current checkpoint directory.

        """
        # copy any pre-trained checkpoints to new directory
        if pl_module.hparams.processor_model == "proxy":
            pretrained_ckpt_dir = os.path.join(
                pl_module.logger.experiment.log_dir, "pretrained_checkpoints"
            )
            if not os.path.isdir(pretrained_ckpt_dir):
                os.makedirs(pretrained_ckpt_dir)
            cp_proxy_ckpts = []
            for proxy_ckpt in pl_module.hparams.proxy_ckpts:
                new_ckpt = shutil.copy(
                    proxy_ckpt,
                    pretrained_ckpt_dir,
                )
                cp_proxy_ckpts.append(new_ckpt)
                print(f"Moved checkpoint to {new_ckpt}.")
            # overwrite to the paths in current experiment logs
            pl_module.hparams.proxy_ckpts = cp_proxy_ckpts
            print(pl_module.hparams.proxy_ckpts)
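

# Usage sketch (illustrative, not part of the original file): the callback is
# attached to a pl.Trainer and expects the LightningModule's hparams to expose
# `processor_model` and `proxy_ckpts`; `MyProxyModule` below is a placeholder
# for whatever module this project actually trains.
#
#     import pytorch_lightning as pl
#
#     model = MyProxyModule(processor_model="proxy", proxy_ckpts=["path/to/proxy.ckpt"])
#     trainer = pl.Trainer(callbacks=[CopyPretrainedCheckpoints()])
#     trainer.fit(model)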