David Piscasio committed
Commit 0c094b2
1 Parent(s): 32da7be

Added util folder

util/__init__.py ADDED
@@ -0,0 +1 @@
+ """This package includes a miscellaneous collection of useful helper functions."""
util/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (295 Bytes)

util/__pycache__/html.cpython-38.pyc ADDED
Binary file (3.64 kB)

util/__pycache__/util.cpython-38.pyc ADDED
Binary file (3.23 kB)

util/__pycache__/visualizer.cpython-38.pyc ADDED
Binary file (9.42 kB)

util/get_data.py ADDED
@@ -0,0 +1,110 @@
+ from __future__ import print_function
+ import os
+ import tarfile
+ import requests
+ from warnings import warn
+ from zipfile import ZipFile
+ from bs4 import BeautifulSoup
+ from os.path import abspath, isdir, join, basename
+
+
+ class GetData(object):
+     """A Python script for downloading CycleGAN or pix2pix datasets.
+
+     Parameters:
+         technique (str) -- One of: 'cyclegan' or 'pix2pix'.
+         verbose (bool)  -- If True, print additional information.
+
+     Examples:
+         >>> from util.get_data import GetData
+         >>> gd = GetData(technique='cyclegan')
+         >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.
+
+     Alternatively, you can use the bash scripts 'scripts/download_pix2pix_model.sh'
+     and 'scripts/download_cyclegan_model.sh'.
+     """
+
+     def __init__(self, technique='cyclegan', verbose=True):
+         url_dict = {
+             'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
+             'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
+         }
+         self.url = url_dict.get(technique.lower())
+         self._verbose = verbose
+
+     def _print(self, text):
+         if self._verbose:
+             print(text)
+
+     @staticmethod
+     def _get_options(r):
+         soup = BeautifulSoup(r.text, 'lxml')
+         options = [h.text for h in soup.find_all('a', href=True)
+                    if h.text.endswith(('.zip', 'tar.gz'))]
+         return options
+
+     def _present_options(self):
+         r = requests.get(self.url)
+         options = self._get_options(r)
+         print('Options:\n')
+         for i, o in enumerate(options):
+             print("{0}: {1}".format(i, o))
+         choice = input("\nPlease enter the number of the "
+                        "dataset above you wish to download:")
+         return options[int(choice)]
+
+     def _download_data(self, dataset_url, save_path):
+         if not isdir(save_path):
+             os.makedirs(save_path)
+
+         base = basename(dataset_url)
+         temp_save_path = join(save_path, base)
+
+         with open(temp_save_path, "wb") as f:
+             r = requests.get(dataset_url)
+             f.write(r.content)
+
+         if base.endswith('.tar.gz'):
+             obj = tarfile.open(temp_save_path)
+         elif base.endswith('.zip'):
+             obj = ZipFile(temp_save_path, 'r')
+         else:
+             raise ValueError("Unknown File Type: {0}.".format(base))
+
+         self._print("Unpacking Data...")
+         obj.extractall(save_path)
+         obj.close()
+         os.remove(temp_save_path)
+
+     def get(self, save_path, dataset=None):
+         """
+
+         Download a dataset.
+
+         Parameters:
+             save_path (str) -- A directory to save the data to.
+             dataset (str)   -- (optional). A specific dataset to download.
+                                Note: this must include the file extension.
+                                If None, options will be presented for you
+                                to choose from.
+
+         Returns:
+             save_path_full (str) -- the absolute path to the downloaded data.
+
+         """
+         if dataset is None:
+             selected_dataset = self._present_options()
+         else:
+             selected_dataset = dataset
+
+         save_path_full = join(save_path, selected_dataset.split('.')[0])
+
+         if isdir(save_path_full):
+             warn("\n'{0}' already exists. Voiding Download.".format(
+                 save_path_full))
+         else:
+             self._print('Downloading Data...')
+             url = "{0}/{1}".format(self.url, selected_dataset)
+             self._download_data(url, save_path=save_path)
+
+         return abspath(save_path_full)
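
A quick usage sketch for the class above (not part of the committed file): it takes the non-interactive path through GetData.get by naming the archive directly, so no menu is shown. The archive name 'horse2zebra.zip' is an assumption about what the CycleGAN server currently lists; omit the dataset argument to pick from the presented options instead.

# Hypothetical usage sketch; 'horse2zebra.zip' is assumed to be available on the server.
from util.get_data import GetData

gd = GetData(technique='cyclegan', verbose=True)
data_path = gd.get(save_path='./datasets', dataset='horse2zebra.zip')
print('Dataset extracted to:', data_path)
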
util/html.py ADDED
@@ -0,0 +1,86 @@
+ import dominate
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+ import os
+
+
+ class HTML:
+     """This HTML class allows us to save images and write texts into a single HTML file.
+
+     It consists of functions such as <add_header> (add a text header to the HTML file),
+     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
+     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+     """
+
+     def __init__(self, web_dir, title, refresh=0):
+         """Initialize the HTML class
+
+         Parameters:
+             web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
+             title (str)   -- the webpage name
+             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+         """
+         self.title = title
+         self.web_dir = web_dir
+         self.img_dir = os.path.join(self.web_dir, 'images')
+         if not os.path.exists(self.web_dir):
+             os.makedirs(self.web_dir)
+         if not os.path.exists(self.img_dir):
+             os.makedirs(self.img_dir)
+
+         self.doc = dominate.document(title=title)
+         if refresh > 0:
+             with self.doc.head:
+                 meta(http_equiv="refresh", content=str(refresh))
+
+     def get_image_dir(self):
+         """Return the directory that stores images"""
+         return self.img_dir
+
+     def add_header(self, text):
+         """Insert a header into the HTML file
+
+         Parameters:
+             text (str) -- the header text
+         """
+         with self.doc:
+             h3(text)
+
+     def add_images(self, ims, txts, links, width=400):
+         """Add images to the HTML file
+
+         Parameters:
+             ims (str list)   -- a list of image paths
+             txts (str list)  -- a list of image names shown on the website
+             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+         """
+         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
+         self.doc.add(self.t)
+         with self.t:
+             with tr():
+                 for im, txt, link in zip(ims, txts, links):
+                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
+                         with p():
+                             with a(href=os.path.join('images', link)):
+                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
+                             br()
+                             p(txt)
+
+     def save(self):
+         """Save the current content to the HTML file"""
+         html_file = '%s/index.html' % self.web_dir
+         f = open(html_file, 'wt')
+         f.write(self.doc.render())
+         f.close()
+
+
+ if __name__ == '__main__':  # we show an example usage here.
+     html = HTML('web/', 'test_html')
+     html.add_header('hello world')
+
+     ims, txts, links = [], [], []
+     for n in range(4):
+         ims.append('image_%d.png' % n)
+         txts.append('text_%d' % n)
+         links.append('image_%d.png' % n)
+     html.add_images(ims, txts, links)
+     html.save()
util/image_pool.py ADDED
@@ -0,0 +1,54 @@
+ import random
+ import torch
+
+
+ class ImagePool():
+     """This class implements an image buffer that stores previously generated images.
+
+     This buffer enables us to update discriminators using a history of generated images
+     rather than the ones produced by the latest generators.
+     """
+
+     def __init__(self, pool_size):
+         """Initialize the ImagePool class
+
+         Parameters:
+             pool_size (int) -- the size of the image buffer; if pool_size=0, no buffer will be created
+         """
+         self.pool_size = pool_size
+         if self.pool_size > 0:  # create an empty pool
+             self.num_imgs = 0
+             self.images = []
+
+     def query(self, images):
+         """Return an image from the pool.
+
+         Parameters:
+             images: the latest generated images from the generator
+
+         Returns images from the buffer.
+
+         With probability 0.5, the buffer returns the input images.
+         With probability 0.5, the buffer returns images previously stored in the buffer
+         and inserts the current images into the buffer.
+         """
+         if self.pool_size == 0:  # if the buffer size is 0, do nothing
+             return images
+         return_images = []
+         for image in images:
+             image = torch.unsqueeze(image.data, 0)
+             if self.num_imgs < self.pool_size:   # if the buffer is not full, keep inserting current images into the buffer
+                 self.num_imgs = self.num_imgs + 1
+                 self.images.append(image)
+                 return_images.append(image)
+             else:
+                 p = random.uniform(0, 1)
+                 if p > 0.5:  # by 50% chance, the buffer returns a previously stored image and inserts the current image into the buffer
+                     random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
+                     tmp = self.images[random_id].clone()
+                     self.images[random_id] = image
+                     return_images.append(tmp)
+                 else:       # by another 50% chance, the buffer returns the current image
+                     return_images.append(image)
+         return_images = torch.cat(return_images, 0)   # collect all the images and return
+         return return_images
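
The 50/50 replacement rule in query is easiest to see in a training loop. The sketch below is illustrative and not part of the committed file: the pool size of 50 and the tensor shape are assumptions. Once the buffer is full, each query returns a mix of the newest fakes and older stored ones, which is what keeps the discriminator from overfitting to the latest generator output.

# Illustrative sketch: route freshly generated fakes through the pool before
# they reach the discriminator. pool_size and shapes are assumed values.
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
for step in range(100):
    fake_B = torch.randn(1, 3, 256, 256)   # stand-in for netG(real_A)
    fake_for_D = pool.query(fake_B)        # newest fake or a previously stored one
    # the discriminator would now be updated with fake_for_D rather than fake_B
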
util/util.py ADDED
@@ -0,0 +1,103 @@
+ """This module contains simple helper functions."""
+ from __future__ import print_function
+ import torch
+ import numpy as np
+ from PIL import Image
+ import os
+
+
+ def tensor2im(input_image, imtype=np.uint8):
+     """Convert a Tensor array into a numpy image array.
+
+     Parameters:
+         input_image (tensor) -- the input image tensor array
+         imtype (type)        -- the desired type of the converted numpy array
+     """
+     if not isinstance(input_image, np.ndarray):
+         if isinstance(input_image, torch.Tensor):  # get the data from a variable
+             image_tensor = input_image.data
+         else:
+             return input_image
+         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+         if image_numpy.shape[0] == 1:  # grayscale to RGB
+             image_numpy = np.tile(image_numpy, (3, 1, 1))
+         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+     else:  # if it is a numpy array, do nothing
+         image_numpy = input_image
+     return image_numpy.astype(imtype)
+
+
+ def diagnose_network(net, name='network'):
+     """Calculate and print the mean of the average absolute gradients
+
+     Parameters:
+         net (torch network) -- Torch network
+         name (str)          -- the name of the network
+     """
+     mean = 0.0
+     count = 0
+     for param in net.parameters():
+         if param.grad is not None:
+             mean += torch.mean(torch.abs(param.grad.data))
+             count += 1
+     if count > 0:
+         mean = mean / count
+     print(name)
+     print(mean)
+
+
+ def save_image(image_numpy, image_path, aspect_ratio=1.0):
+     """Save a numpy image to the disk
+
+     Parameters:
+         image_numpy (numpy array) -- input numpy array
+         image_path (str)          -- the path of the image
+     """
+
+     image_pil = Image.fromarray(image_numpy)
+     h, w, _ = image_numpy.shape
+
+     if aspect_ratio > 1.0:
+         image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+     if aspect_ratio < 1.0:
+         image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+     image_pil.save(image_path)
+
+
+ def print_numpy(x, val=True, shp=False):
+     """Print the mean, min, max, median, std, and size of a numpy array
+
+     Parameters:
+         val (bool) -- whether to print the values of the numpy array
+         shp (bool) -- whether to print the shape of the numpy array
+     """
+     x = x.astype(np.float64)
+     if shp:
+         print('shape,', x.shape)
+     if val:
+         x = x.flatten()
+         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+ def mkdirs(paths):
+     """Create empty directories if they don't exist
+
+     Parameters:
+         paths (str list) -- a list of directory paths
+     """
+     if isinstance(paths, list) and not isinstance(paths, str):
+         for path in paths:
+             mkdir(path)
+     else:
+         mkdir(paths)
+
+
+ def mkdir(path):
+     """Create a single empty directory if it doesn't exist
+
+     Parameters:
+         path (str) -- a single directory path
+     """
+     if not os.path.exists(path):
+         os.makedirs(path)
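
Taken together, tensor2im and save_image cover the usual round trip from a network output in [-1, 1] to an image file on disk. A minimal sketch (not part of the committed file; the random tensor and the './results' path are placeholders):

# Sketch: convert a [-1, 1] image tensor to a uint8 array and write it to disk.
import torch
from util.util import tensor2im, save_image, mkdir

fake = torch.rand(1, 3, 128, 128) * 2 - 1   # stand-in for a generator output in [-1, 1]
img = tensor2im(fake)                       # (128, 128, 3) uint8 numpy array
mkdir('./results')                          # created only if it does not exist
save_image(img, './results/sample.png')
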
util/visualizer.py ADDED
@@ -0,0 +1,257 @@
+ import numpy as np
+ import os
+ import sys
+ import ntpath
+ import time
+ from . import util, html
+ from subprocess import Popen, PIPE
+
+
+ try:
+     import wandb
+ except ImportError:
+     print('Warning: wandb package cannot be found. The option "--use_wandb" will result in an error.')
+
+ if sys.version_info[0] == 2:
+     VisdomExceptionBase = Exception
+ else:
+     VisdomExceptionBase = ConnectionError
+
+
+ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, use_wandb=False):
+     """Save images to the disk.
+
+     Parameters:
+         webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
+         visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
+         image_path (str)         -- the string is used to create image paths
+         aspect_ratio (float)     -- the aspect ratio of saved images
+         width (int)              -- the images will be resized to width x width
+
+     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+     """
+     image_dir = webpage.get_image_dir()
+     short_path = ntpath.basename(image_path[0])
+     name = os.path.splitext(short_path)[0]
+
+     webpage.add_header(name)
+     ims, txts, links = [], [], []
+     ims_dict = {}
+     for label, im_data in visuals.items():
+         im = util.tensor2im(im_data)
+         image_name = '%s_%s.png' % (name, label)
+         save_path = os.path.join(image_dir, image_name)
+         util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+         ims.append(image_name)
+         txts.append(label)
+         links.append(image_name)
+         if use_wandb:
+             ims_dict[label] = wandb.Image(im)
+     webpage.add_images(ims, txts, links, width=width)
+     if use_wandb:
+         wandb.log(ims_dict)
+
+
+ class Visualizer():
+     """This class includes several functions that can display/save images and print/save logging information.
+
+     It uses the Python library 'visdom' for display, and the Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+     """
+
+     def __init__(self, opt):
+         """Initialize the Visualizer class
+
+         Parameters:
+             opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         Step 1: Cache the training/test options
+         Step 2: connect to a visdom server
+         Step 3: create an HTML object for saving HTML files
+         Step 4: create a logging file to store training losses
+         """
+         self.opt = opt  # cache the option
+         self.display_id = opt.display_id
+         self.use_html = opt.isTrain and not opt.no_html
+         self.win_size = opt.display_winsize
+         self.name = opt.name
+         self.port = opt.display_port
+         self.saved = False
+         self.use_wandb = opt.use_wandb
+         self.current_epoch = 0
+         self.ncols = opt.display_ncols
+
+         if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
+             import visdom
+             self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
+             if not self.vis.check_connection():
+                 self.create_visdom_connections()
+
+         if self.use_wandb:
+             self.wandb_run = wandb.init(project='CycleGAN-and-pix2pix', name=opt.name, config=opt) if not wandb.run else wandb.run
+             self.wandb_run._label(repo='CycleGAN-and-pix2pix')
+
+         if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
+             self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+             self.img_dir = os.path.join(self.web_dir, 'images')
+             print('create web directory %s...' % self.web_dir)
+             util.mkdirs([self.web_dir, self.img_dir])
+         # create a logging file to store training losses
+         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+         with open(self.log_name, "a") as log_file:
+             now = time.strftime("%c")
+             log_file.write('================ Training Loss (%s) ================\n' % now)
+
+     def reset(self):
+         """Reset the self.saved status"""
+         self.saved = False
+
+     def create_visdom_connections(self):
+         """If the program could not connect to the Visdom server, this function will start a new server at port <self.port>"""
+         cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
+         print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
+         print('Command: %s' % cmd)
+         Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
+
+     def display_current_results(self, visuals, epoch, save_result):
+         """Display current results on visdom; save current results to an HTML file.
+
+         Parameters:
+             visuals (OrderedDict) -- dictionary of images to display or save
+             epoch (int)           -- the current epoch
+             save_result (bool)    -- whether to save the current results to an HTML file
+         """
+         if self.display_id > 0:  # show images in the browser using visdom
+             ncols = self.ncols
+             if ncols > 0:        # show all the images in one visdom panel
+                 ncols = min(ncols, len(visuals))
+                 h, w = next(iter(visuals.values())).shape[:2]
+                 table_css = """<style>
+                         table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
+                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+                         </style>""" % (w, h)  # create a table css
+                 # create a table of images.
+                 title = self.name
+                 label_html = ''
+                 label_html_row = ''
+                 images = []
+                 idx = 0
+                 for label, image in visuals.items():
+                     image_numpy = util.tensor2im(image)
+                     label_html_row += '<td>%s</td>' % label
+                     images.append(image_numpy.transpose([2, 0, 1]))
+                     idx += 1
+                     if idx % ncols == 0:
+                         label_html += '<tr>%s</tr>' % label_html_row
+                         label_html_row = ''
+                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
+                 while idx % ncols != 0:
+                     images.append(white_image)
+                     label_html_row += '<td></td>'
+                     idx += 1
+                 if label_html_row != '':
+                     label_html += '<tr>%s</tr>' % label_html_row
+                 try:
+                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
+                                     padding=2, opts=dict(title=title + ' images'))
+                     label_html = '<table>%s</table>' % label_html
+                     self.vis.text(table_css + label_html, win=self.display_id + 2,
+                                   opts=dict(title=title + ' labels'))
+                 except VisdomExceptionBase:
+                     self.create_visdom_connections()
+
+             else:  # show each image in a separate visdom panel
+                 idx = 1
+                 try:
+                     for label, image in visuals.items():
+                         image_numpy = util.tensor2im(image)
+                         self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+                                        win=self.display_id + idx)
+                         idx += 1
+                 except VisdomExceptionBase:
+                     self.create_visdom_connections()
+
+         if self.use_wandb:
+             columns = [key for key, _ in visuals.items()]
+             columns.insert(0, 'epoch')
+             result_table = wandb.Table(columns=columns)
+             table_row = [epoch]
+             ims_dict = {}
+             for label, image in visuals.items():
+                 image_numpy = util.tensor2im(image)
+                 wandb_image = wandb.Image(image_numpy)
+                 table_row.append(wandb_image)
+                 ims_dict[label] = wandb_image
+             self.wandb_run.log(ims_dict)
+             if epoch != self.current_epoch:
+                 self.current_epoch = epoch
+                 result_table.add_data(*table_row)
+                 self.wandb_run.log({"Result": result_table})
+
+
+         if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
+             self.saved = True
+             # save images to the disk
+             for label, image in visuals.items():
+                 image_numpy = util.tensor2im(image)
+                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+                 util.save_image(image_numpy, img_path)
+
+             # update the website
+             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+             for n in range(epoch, 0, -1):
+                 webpage.add_header('epoch [%d]' % n)
+                 ims, txts, links = [], [], []
+
+                 for label, image in visuals.items():
+                     image_numpy = util.tensor2im(image)
+                     img_path = 'epoch%.3d_%s.png' % (n, label)
+                     ims.append(img_path)
+                     txts.append(label)
+                     links.append(img_path)
+                 webpage.add_images(ims, txts, links, width=self.win_size)
+             webpage.save()
+
+     def plot_current_losses(self, epoch, counter_ratio, losses):
+         """Display the current losses on the visdom display: dictionary of error labels and values
+
+         Parameters:
+             epoch (int)           -- current epoch
+             counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
+             losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
+         """
+         if not hasattr(self, 'plot_data'):
+             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
+         self.plot_data['X'].append(epoch + counter_ratio)
+         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+         try:
+             self.vis.line(
+                 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
+                 Y=np.array(self.plot_data['Y']),
+                 opts={
+                     'title': self.name + ' loss over time',
+                     'legend': self.plot_data['legend'],
+                     'xlabel': 'epoch',
+                     'ylabel': 'loss'},
+                 win=self.display_id)
+         except VisdomExceptionBase:
+             self.create_visdom_connections()
+         if self.use_wandb:
+             self.wandb_run.log(losses)
+
+     # losses: same format as |losses| of plot_current_losses
+     def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+         """Print current losses on the console; also save the losses to the disk
+
+         Parameters:
+             epoch (int)          -- current epoch
+             iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+             losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+             t_comp (float)       -- computational time per data point (normalized by batch_size)
+             t_data (float)       -- data loading time per data point (normalized by batch_size)
+         """
+         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+         for k, v in losses.items():
+             message += '%s: %.3f ' % (k, v)
+
+         print(message)  # print the message
+         with open(self.log_name, "a") as log_file:
+             log_file.write('%s\n' % message)  # save the message
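
For context, a training script typically drives Visualizer roughly as sketched below. This is not part of the committed file: the option object is mocked with SimpleNamespace, and the chosen values (display_id=0 to skip visdom, use_wandb=False) are assumptions that let the sketch run without a visdom or wandb server.

# Hedged sketch of a training loop driving Visualizer with mocked options.
from collections import OrderedDict
from types import SimpleNamespace
import torch
from util.visualizer import Visualizer

opt = SimpleNamespace(display_id=0, isTrain=True, no_html=False, display_winsize=256,
                      name='demo_experiment', display_port=8097, use_wandb=False,
                      display_ncols=4, checkpoints_dir='./checkpoints',
                      display_server='http://localhost', display_env='main')
vis = Visualizer(opt)

for epoch in range(1, 3):
    vis.reset()
    visuals = OrderedDict([('real_A', torch.rand(1, 3, 64, 64) * 2 - 1),
                           ('fake_B', torch.rand(1, 3, 64, 64) * 2 - 1)])
    losses = OrderedDict([('G_GAN', 0.7), ('D_real', 0.3)])
    vis.display_current_results(visuals, epoch, save_result=True)
    vis.print_current_losses(epoch, iters=100, losses=losses, t_comp=0.05, t_data=0.01)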