File size: 3,941 Bytes
39aef76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os

class WandbLogger:
    """
    Log metrics, images, checkpoints and evaluation tables using
    `Weights and Biases` (W&B).

    A W&B run is started on construction unless one is already active.
    """

    # Column layouts of the optional result tables. The tables themselves are
    # created on demand by ``log_eval_data`` (see ``eval_table`` /
    # ``infer_table``), so nothing is allocated when they are never used.
    _EVAL_COLUMNS = ('fake_image', 'sr_image', 'hr_image', 'psnr', 'ssim')
    _INFER_COLUMNS = ('fake_image', 'sr_image', 'hr_image')

    def __init__(self, project='diff_derain', save_dir='./experiments'):
        """
        Import wandb and start a run if none is active yet.

        project: W&B project name passed to ``wandb.init``.
        save_dir: local directory passed as ``dir`` to ``wandb.init``.

        Raises ImportError if the ``wandb`` package is not installed.
        """
        try:
            import wandb
        except ImportError as exc:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from exc

        self._wandb = wandb

        # Reuse a run the caller may already have started instead of
        # opening a second one.
        if self._wandb.run is None:
            self._wandb.init(project=project, dir=save_dir)

        self.config = self._wandb.config

        # None until log_eval_data adds the first row of the matching kind.
        self.eval_table = None
        self.infer_table = None

    def log_metrics(self, metrics, commit=True):
        """
        Log train/validation metrics onto W&B.

        metrics: dictionary of metrics to be logged.
        commit: forwarded to ``wandb.log``; when False the values are added
            to the current step without finalizing it.
        """
        self._wandb.log(metrics, commit=commit)

    def log_image(self, key_name, image_array):
        """
        Log a single image onto W&B.

        key_name: name of the key to log the image under.
        image_array: image data accepted by ``wandb.Image``
            (presumably a numpy array).
        """
        self._wandb.log({key_name: self._wandb.Image(image_array)})

    def log_images(self, key_name, list_images):
        """
        Log a list of images onto W&B under one key.

        key_name: name of the key to log the images under.
        list_images: iterable of image data accepted by ``wandb.Image``.
        """
        self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]})

    def log_checkpoint(self, current_epoch, current_step):
        """
        Upload the checkpoint files of one epoch/step as a W&B artifact of
        type ``model`` with the alias ``latest``.

        The files ``I{step}_E{epoch}_gen.pth`` and ``I{step}_E{epoch}_opt.pth``
        are taken from ``self.config.path['checkpoint']``; they are expected
        to already exist on disk.

        current_epoch: the current epoch.
        current_step: the current iteration (batch) step.
        """
        checkpoint_dir = self.config.path['checkpoint']
        prefix = 'I{}_E{}'.format(current_step, current_epoch)

        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )
        model_artifact.add_file(os.path.join(checkpoint_dir, prefix + '_gen.pth'))
        model_artifact.add_file(os.path.join(checkpoint_dir, prefix + '_opt.pth'))
        self._wandb.log_artifact(model_artifact, aliases=["latest"])

    def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
        """
        Append one row to the evaluation or inference table.

        When both ``psnr`` and ``ssim`` are given, the row (three images plus
        the two metrics) goes to ``eval_table``. Otherwise the row (three
        images only) goes to ``infer_table``. The target table is created
        on first use.
        """
        images = (
            self._wandb.Image(fake_img),
            self._wandb.Image(sr_img),
            self._wandb.Image(hr_img),
        )
        if psnr is not None and ssim is not None:
            if self.eval_table is None:
                self.eval_table = self._wandb.Table(columns=list(self._EVAL_COLUMNS))
            self.eval_table.add_data(*images, psnr, ssim)
        else:
            if self.infer_table is None:
                self.infer_table = self._wandb.Table(columns=list(self._INFER_COLUMNS))
            self.infer_table.add_data(*images)

    def log_eval_table(self, commit=False):
        """
        Log every table that has been created so far.

        ``eval_table`` is logged under ``'eval_data'`` and ``infer_table``
        under ``'infer_data'``, together in one ``wandb.log`` call. Nothing
        is logged when neither table exists.

        commit: forwarded to ``wandb.log``.
        """
        tables = {}
        if self.eval_table is not None:
            tables['eval_data'] = self.eval_table
        if self.infer_table is not None:
            tables['infer_data'] = self.infer_table
        if tables:
            self._wandb.log(tables, commit=commit)