import os

import torch
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter


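class LossHistory:
    """Log per-epoch values to text files, TensorBoard, and a loss plot."""

    def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir

        # exist_ok avoids a FileExistsError when the directory already exists.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            # model is iterated, so it is expected to be a sequence of modules;
            # each one is traced with a dummy batch of two 3-channel images.
            for m in model:
                dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
                self.writer.add_graph(m, dummy_input)
        except Exception:
            # Graph export is best-effort; scalar logging still works without it.
            pass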

    def append_loss(self, epoch, **kwargs):
        os.makedirs(self.log_dir, exist_ok=True)

        for key, value in kwargs.items():
            if not hasattr(self, key):
                setattr(self, key, [])
            #---------------------------------#
            #   Append the value to its list
            #---------------------------------#
            getattr(self, key).append(value)
        
            #---------------------------------#
            #   Write to the txt file
            #---------------------------------#
            with open(os.path.join(self.log_dir, key + ".txt"), 'a') as f:
                f.write(str(value))
                f.write("\n")
                
            #---------------------------------#
            #   Write to TensorBoard
            #---------------------------------#
            self.writer.add_scalar(key, value, epoch)
            
        self.loss_plot(**kwargs)

    def loss_plot(self, **kwargs):
        plt.figure()
        
        # Plot the full history of every key passed in this call.
        for key in kwargs:
            losses = getattr(self, key)
            plt.plot(range(len(losses)), losses, linewidth=2, label=key)

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")