import ntpath
import os
import time
from collections import OrderedDict
from io import BytesIO

import numpy as np
import torch
from PIL import Image as PILImage

from . import html
from . import util

try:
    # Legacy Python 2 module; kept so existing references keep resolving
    # where it is installed. BytesIO is imported unconditionally above
    # because the TensorBoard image path always needs it.
    from StringIO import StringIO
except ImportError:
    pass


class Visualizer:
    """Write training visuals and losses to TF summaries, HTML and a text log.

    Which sinks are active is decided from the option object passed to the
    constructor (``isTrain``, ``tf_log``, ``no_html``).
    """

    def __init__(self, opt):
        """Configure the enabled log sinks and create their directories.

        Args:
            opt: Option object. The attributes read here are ``isTrain``,
                ``tf_log``, ``no_html``, ``display_winsize``, ``name`` and
                ``checkpoints_dir``; other methods also read ``label_nc``
                and ``crop_size``.
        """
        self.opt = opt
        self.tf_log = opt.isTrain and opt.tf_log
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name

        if self.tf_log:
            # Imported lazily so TensorFlow is only required when TF logging
            # is actually requested.
            import tensorflow as tf
            self.tf = tf
            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
            self.writer = tf.summary.FileWriter(self.log_dir)

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])

        if opt.isTrain:
            self.log_name = os.path.join(
                opt.checkpoints_dir, opt.name, 'loss_log.txt')
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write(
                    '================ Training Loss (%s) ================\n'
                    % now)

    @staticmethod
    def _to_scalar(value):
        """Reduce a loss entry to a Python float.

        Returns:
            The mean of a tensor (0.0 for an empty tensor), ``float(value)``
            for a plain ``float``/``int``, or ``None`` for any other type.
        """
        if isinstance(value, torch.Tensor):
            if value.numel() == 0:
                return 0.0
            if not value.is_floating_point():
                # mean() is not defined for integer/bool tensors.
                value = value.float()
            return float(value.mean().item())
        if isinstance(value, (float, int)):
            return float(value)
        return None

    def convert_map_to_numpy(self, data_map):
        """Convert a single-channel map tensor to a uint8 ``(H, W, 3)`` array.

        A 4-D tensor is reduced to its first element, a 3-D tensor to its
        first channel, and a 2-D tensor is used as is. Values are min-max
        normalised to ``[0, 255]``; a (near-)constant map becomes all zeros.

        Args:
            data_map: ``torch.Tensor`` of rank 2, 3 or 4.

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` and shape ``(H, W, 3)``, or
            ``None`` when ``data_map`` is not a tensor, is empty, or has an
            unsupported rank.
        """
        if not isinstance(data_map, torch.Tensor):
            return None
        if data_map.dim() == 4:
            data_map = data_map[0]
        if data_map.dim() == 3:
            if data_map.size(0) == 0:
                return None
            data_map = data_map[0]
        if data_map.dim() != 2 or data_map.numel() == 0:
            return None

        # detach() is required: numpy() refuses tensors that track gradients.
        map_numpy = data_map.detach().cpu().float().numpy()
        min_val, max_val = np.min(map_numpy), np.max(map_numpy)
        if max_val - min_val > 1e-6:
            map_numpy = (map_numpy - min_val) / (max_val - min_val)
        else:
            map_numpy = np.zeros_like(map_numpy)
        map_numpy = (map_numpy * 255.0).astype(np.uint8)
        # Replicate the single channel so the map can sit next to RGB images.
        return np.stack((map_numpy,) * 3, axis=-1)

    def display_current_results(self, visuals, epoch, step):
        """Log the current visuals to TF summaries and/or the HTML page.

        Args:
            visuals: Mapping of label to image tensor (``None`` entries are
                skipped). Labels containing ``'light_map'`` go through
                ``convert_map_to_numpy``, labels containing ``'input_label'``
                through ``util.tensor2label``, everything else through
                ``util.tensor2im``.
            epoch: Current epoch, used in summary tags and file names.
            step: Global step, used as summary step and in file names.
        """
        visuals_np = OrderedDict()
        for label, image in visuals.items():
            if image is None:
                continue
            if 'light_map' in label:
                image_numpy = self.convert_map_to_numpy(image)
            elif 'input_label' in label:
                image_numpy = util.tensor2label(
                    image, self.opt.label_nc, tile=False)
            else:
                image_numpy = util.tensor2im(image, tile=False)
            if image_numpy is None:
                # The converter could not handle this visual; skip it rather
                # than crash on the attribute access below.
                continue
            if image_numpy.ndim == 4:
                image_numpy = image_numpy[0]
            visuals_np[label] = image_numpy

        if self.tf_log:
            img_summaries = []
            for label, image_numpy in visuals_np.items():
                # Best effort: one image that cannot be encoded must not
                # prevent the remaining ones from being logged.
                try:
                    s = BytesIO()
                    pil_img = PILImage.fromarray(image_numpy)
                    pil_img.save(s, format="jpeg")
                    img_sum = self.tf.Summary.Image(
                        encoded_image_string=s.getvalue(),
                        height=image_numpy.shape[0],
                        width=image_numpy.shape[1])
                    img_summaries.append(self.tf.Summary.Value(
                        tag=f'epoch_{epoch}/{label}', image=img_sum))
                except Exception as e:
                    print(f"Could not write image {label} to TF logs: {e}")
            if img_summaries:
                summary = self.tf.Summary(value=img_summaries)
                self.writer.add_summary(summary, step)

        if self.use_html:
            webpage = html.HTML(
                self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
            webpage.add_header('Epoch [%d] Iteration [%d]' % (epoch, step))
            visuals_for_html = []
            labels_for_html = []
            standard_height = self.opt.crop_size
            for label, image_numpy in visuals_np.items():
                pil_img = PILImage.fromarray(image_numpy)
                if pil_img.height != standard_height:
                    # Bring every visual to a common height so they can be
                    # concatenated side by side.
                    aspect_ratio = pil_img.width / pil_img.height
                    # Never ask PIL for a zero-width image.
                    new_width = max(1, int(standard_height * aspect_ratio))
                    pil_img = pil_img.resize(
                        (new_width, standard_height), PILImage.LANCZOS)
                visuals_for_html.append(np.array(pil_img))
                labels_for_html.append(label)
            if not visuals_for_html:
                return
            try:
                concatenated_image = np.concatenate(visuals_for_html, axis=1)
                image_name = 'epoch%.3d_iter%.7d_combined.png' % (epoch, step)
                save_path = os.path.join(self.img_dir, image_name)
                util.save_image(concatenated_image, save_path)
                webpage.add_images(
                    [image_name], [' | '.join(labels_for_html)], [image_name],
                    width=self.win_size * len(visuals_for_html))
                webpage.save()
            except ValueError as e:
                print(f"Error during HTML image concatenation for step {step}: {e}")
                print("Skipping HTML log for this step. Image shapes might be incompatible even after resizing.")

    def plot_current_errors(self, errors, step):
        """Write each loss value as a scalar TF summary (if TF logging is on).

        Args:
            errors: Mapping of tag to tensor, ``float`` or ``int``; entries
                of any other type are ignored.
            step: Global step passed to the summary writer.
        """
        if not self.tf_log:
            return
        for tag, value in errors.items():
            value_to_log = self._to_scalar(value)
            if value_to_log is None:
                continue
            summary = self.tf.Summary(value=[
                self.tf.Summary.Value(tag=tag, simple_value=value_to_log)])
            self.writer.add_summary(summary, step)

    def print_current_errors(self, epoch, i, errors, t):
        """Print the current losses and append the same line to the loss log.

        Args:
            epoch: Current epoch.
            i: Iteration counter shown in the message.
            errors: Mapping of name to tensor, ``float`` or ``int``; entries
                of any other type are left out of the message.
            t: Time value shown in the message.
        """
        parts = ['(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)]
        for k, v_orig in errors.items():
            v_to_print = self._to_scalar(v_orig)
            if v_to_print is None:
                continue
            parts.append('%s: %.3f ' % (k, v_to_print))
        message = ''.join(parts)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    def save_images(self, webpage, visuals, image_path_list, alpha=1.0):
        """Save all visuals of one sample as a single side-by-side PNG.

        The image is written to ``<image dir>/alpha_<alpha>/<name>.png`` and
        registered on ``webpage``. Visuals that are ``None``, cannot be
        converted, or do not end up with three channels are left out.

        Args:
            webpage: Page object providing ``get_image_dir()`` and
                ``add_images(...)``.
            visuals: Mapping of label to image tensor.
            image_path_list: Source image path, or a list/tuple whose first
                element is that path; only its base name is used.
            alpha: Tensor (its mean is used), number, or anything accepted
                by ``float()``; falls back to 1.0 when it cannot be
                converted.
        """
        visuals_np = OrderedDict()
        for label, image in visuals.items():
            if image is None:
                continue
            if 'light_map' in label:
                converted = self.convert_map_to_numpy(image)
            else:
                converted = util.tensor2im(image)
            if converted is not None:
                visuals_np[label] = converted

        base_image_dir = webpage.get_image_dir()
        if isinstance(image_path_list, (list, tuple)):
            image_path_str = image_path_list[0]
        else:
            image_path_str = image_path_list
        short_path = ntpath.basename(image_path_str)
        name_prefix = os.path.splitext(short_path)[0]

        current_alpha_float = self._to_scalar(alpha)
        if current_alpha_float is None:
            try:
                current_alpha_float = float(alpha)
            except (TypeError, ValueError):
                # float() raises TypeError for e.g. None and ValueError for
                # unparsable strings; both fall back to the default.
                current_alpha_float = 1.0

        alpha_folder_name = "alpha_{:.3f}".format(
            current_alpha_float).replace('.', '_')
        specific_alpha_image_dir = os.path.join(
            base_image_dir, alpha_folder_name)
        util.mkdirs(specific_alpha_image_dir)
        image_name_final = '%s.png' % (name_prefix)
        save_path = os.path.join(specific_alpha_image_dir, image_name_final)

        images_to_concatenate = []
        for label, image_numpy in visuals_np.items():
            img_to_add = image_numpy
            if image_numpy.ndim == 4 and image_numpy.shape[0] == 1:
                img_to_add = image_numpy.squeeze(0)
            elif image_numpy.ndim != 2 and image_numpy.ndim != 3:
                continue
            if img_to_add.ndim == 2:
                img_to_add = np.stack((img_to_add,) * 3, axis=-1)
            if img_to_add.ndim == 3 and img_to_add.shape[2] == 1:
                img_to_add = np.concatenate([img_to_add] * 3, axis=2)
            if img_to_add.shape[2] == 3:
                images_to_concatenate.append(img_to_add)
        if not images_to_concatenate:
            return

        try:
            image_concatenated_horizontally = np.concatenate(
                images_to_concatenate, axis=1)
            util.save_image(
                image_concatenated_horizontally, save_path, create_dir=True)
        except ValueError as e:
            print(f"Error concatenating images for {save_path}: {e}")
            print("Concatenated images list content (shapes):")
            for idx, vis_np_item in enumerate(images_to_concatenate):
                print(f" Visual {idx}: shape {vis_np_item.shape if hasattr(vis_np_item, 'shape') else 'N/A'}")
            # Nothing was written, so do not add a page entry that would
            # point to a missing file.
            return

        relative_image_path_for_html = os.path.join(
            alpha_folder_name, image_name_final)
        webpage.add_images(
            [relative_image_path_for_html],
            [f"{name_prefix}_alpha_{current_alpha_float:.3f}"],
            [relative_image_path_for_html],
            width=self.win_size * len(images_to_concatenate))