File size: 1,226 Bytes
c4b2b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import glob
import os
from collections import OrderedDict

import torch


class Saver(object):
    """Create a fresh, numbered experiment directory and write files into it.

    On construction the directory ``run/<train_dataset>/<checkname>`` is
    scanned for existing ``experiment_<N>`` entries and a new
    ``experiment_<N+1>`` directory is created, where ``N`` is the largest
    numeric suffix found (``experiment_0`` when none exist).

    Attributes set by ``__init__``:
        args: the argument object passed in; ``train_dataset``, ``checkname``,
            ``lr`` and ``epochs`` are read from it by this class.
        directory: ``run/<train_dataset>/<checkname>``.
        runs: existing ``experiment_*`` paths found at construction time,
            ordered by numeric suffix (entries without a numeric suffix
            come first).
        experiment_dir: the directory created for this run.
    """

    def __init__(self, args):
        """Pick the next free experiment index and create its directory.

        Args:
            args: object exposing ``train_dataset`` and ``checkname``
                attributes, used to build the base directory path.
        """
        self.args = args
        self.directory = os.path.join("run", args.train_dataset, args.checkname)

        found = glob.glob(os.path.join(self.directory, "experiment_*"))
        # Order by the numeric suffix: a plain string sort places
        # "experiment_10" before "experiment_2", which would make the
        # "latest" run wrong as soon as ten or more runs exist.
        self.runs = sorted(found, key=lambda path: (self._run_index(path), path))

        # Entries whose suffix is not a non-negative integer are ignored
        # when choosing the next index.
        indices = [idx for idx in map(self._run_index, self.runs) if idx >= 0]
        run_id = max(indices) + 1 if indices else 0

        self.experiment_dir = os.path.join(
            self.directory, "experiment_{}".format(str(run_id))
        )
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(self.experiment_dir, exist_ok=True)

    @staticmethod
    def _run_index(path):
        """Return the integer after the last ``_`` in the basename, or -1.

        -1 is returned when that suffix is not a non-negative integer.
        """
        suffix = os.path.basename(path).rsplit("_", 1)[-1]
        try:
            index = int(suffix)
        except ValueError:
            return -1
        return index if index >= 0 else -1

    def save_checkpoint(self, state, filename="checkpoint.pth.tar"):
        """Save ``state`` with ``torch.save`` inside the experiment directory.

        Args:
            state: object handed to ``torch.save`` unchanged.
            filename: file name, joined onto ``self.experiment_dir``.
        """
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)

    def save_experiment_config(self):
        """Write selected arguments to ``parameters.txt`` as ``key:value`` lines.

        The written keys are ``train_dataset``, ``lr`` and ``epoch`` (the
        latter taken from ``args.epochs``), in that order.
        """
        logfile = os.path.join(self.experiment_dir, "parameters.txt")
        p = OrderedDict()
        p["train_dataset"] = self.args.train_dataset
        p["lr"] = self.args.lr
        p["epoch"] = self.args.epochs

        # The context manager closes the file even if a write fails.
        with open(logfile, "w") as log_file:
            for key, val in p.items():
                log_file.write(key + ":" + str(val) + "\n")