"""Utility functions.""" import base64 import os import subprocess import cv2 import numpy as np import torch from models import MODEL_ZOO from models import build_generator from models import parse_gan_type __all__ = ['postprocess', 'load_generator', 'factorize_weight', 'HtmlPageVisualizer'] CHECKPOINT_DIR = 'checkpoints' def to_tensor(array): """Converts a `numpy.ndarray` to `torch.Tensor`. Args: array: The input array to convert. Returns: A `torch.Tensor` with dtype `torch.FloatTensor` on cuda device. """ assert isinstance(array, np.ndarray) return torch.from_numpy(array).type(torch.FloatTensor) def postprocess(images, min_val=-1.0, max_val=1.0): """Post-processes images from `torch.Tensor` to `numpy.ndarray`. Args: images: A `torch.Tensor` with shape `NCHW` to process. min_val: The minimum value of the input tensor. (default: -1.0) max_val: The maximum value of the input tensor. (default: 1.0) Returns: A `numpy.ndarray` with shape `NHWC` and pixel range [0, 255]. """ assert isinstance(images, torch.Tensor) images = images.detach().cpu().numpy() images = (images - min_val) * 255 / (max_val - min_val) images = np.clip(images + 0.5, 0, 255).astype(np.uint8) images = images.transpose(0, 2, 3, 1) return images def load_generator(model_name, from_hf_hub=False): """Loads pre-trained generator. Args: model_name: Name of the model. Should be a key in `models.MODEL_ZOO`. Returns: A generator, which is a `torch.nn.Module`, with pre-trained weights loaded. Raises: KeyError: If the input `model_name` is not in `models.MODEL_ZOO`. """ if model_name not in MODEL_ZOO: raise KeyError(f'Unknown model name `{model_name}`!') model_config = MODEL_ZOO[model_name].copy() url = model_config.pop('url') # URL to download model if needed. # Build generator. print(f'Building generator for model `{model_name}` ...') generator = build_generator(**model_config) print(f'Finish building generator.') if from_hf_hub and "hf_hub_repo" in model_config.keys(): checkpoint = generator.from_pretrained(model_config["hf_hub_repo"]) generator.load_state_dict(checkpoint) print("loaded from hf_hub") else: # Load pre-trained weights. os.makedirs(CHECKPOINT_DIR, exist_ok=True) checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth') print(f'Loading checkpoint from `{checkpoint_path}` ...') if not os.path.exists(checkpoint_path): print(f' Downloading checkpoint from `{url}` ...') subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url]) print(f' Finish downloading checkpoint.') checkpoint = torch.load(checkpoint_path, map_location='cpu') if 'generator_smooth' in checkpoint: generator.load_state_dict(checkpoint['generator_smooth']) else: generator.load_state_dict(checkpoint['generator']) #generator = generator.cuda() generator.eval() print(f'Finish loading checkpoint.') return generator def parse_indices(obj, min_val=None, max_val=None): """Parses indices. The input can be a list or a tuple or a string, which is either a comma separated list of numbers 'a, b, c', or a dash separated range 'a - c'. Space in the string will be ignored. Args: obj: The input object to parse indices from. min_val: If not `None`, this function will check that all indices are equal to or larger than this value. (default: None) max_val: If not `None`, this function will check that all indices are equal to or smaller than this value. (default: None) Returns: A list of integers. Raises: If the input is invalid, i.e., neither a list or tuple, nor a string. """ if obj is None or obj == '': indices = [] elif isinstance(obj, int): indices = [obj] elif isinstance(obj, (list, tuple, np.ndarray)): indices = list(obj) elif isinstance(obj, str): indices = [] splits = obj.replace(' ', '').split(',') for split in splits: numbers = list(map(int, split.split('-'))) if len(numbers) == 1: indices.append(numbers[0]) elif len(numbers) == 2: indices.extend(list(range(numbers[0], numbers[1] + 1))) else: raise ValueError(f'Unable to parse the input!') else: raise ValueError(f'Invalid type of input: `{type(obj)}`!') assert isinstance(indices, list) indices = sorted(list(set(indices))) for idx in indices: assert isinstance(idx, int) if min_val is not None: assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!' if max_val is not None: assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!' return indices def factorize_weight(generator, layer_idx='all'): """Factorizes the generator weight to get semantics boundaries. Args: generator: Generator to factorize. layer_idx: Indices of layers to interpret, especially for StyleGAN and StyleGAN2. (default: `all`) Returns: A tuple of (layers_to_interpret, semantic_boundaries, eigen_values). Raises: ValueError: If the generator type is not supported. """ # Get GAN type. gan_type = parse_gan_type(generator) # Get layers. if gan_type == 'pggan': layers = [0] elif gan_type in ['stylegan', 'stylegan2']: if layer_idx == 'all': layers = list(range(generator.num_layers)) else: layers = parse_indices(layer_idx, min_val=0, max_val=generator.num_layers - 1) # Factorize semantics from weight. weights = [] for idx in layers: layer_name = f'layer{idx}' if gan_type == 'stylegan2' and idx == generator.num_layers - 1: layer_name = f'output{idx // 2}' if gan_type == 'pggan': weight = generator.__getattr__(layer_name).weight weight = weight.flip(2, 3).permute(1, 0, 2, 3).flatten(1) elif gan_type in ['stylegan', 'stylegan2']: weight = generator.synthesis.__getattr__(layer_name).style.weight.T weights.append(weight.cpu().detach().numpy()) weight = np.concatenate(weights, axis=1).astype(np.float32) weight = weight / np.linalg.norm(weight, axis=0, keepdims=True) eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T)) return layers, eigen_vectors.T, eigen_values def get_sortable_html_header(column_name_list, sort_by_ascending=False): """Gets header for sortable html page. Basically, the html page contains a sortable table, where user can sort the rows by a particular column by clicking the column head. Example: column_name_list = [name_1, name_2, name_3] header = get_sortable_html_header(column_name_list) footer = get_sortable_html_footer() sortable_table = ... html_page = header + sortable_table + footer Args: column_name_list: List of column header names. sort_by_ascending: Default sorting order. If set as `True`, the html page will be sorted by ascending order when the header is clicked for the first time. Returns: A string, which represents for the header for a sortable html page. """ header = '\n'.join([ '', '', '', '', '', '', '', '', '', '', '', '', '', '']) for idx, name in enumerate(column_name_list): header += f' \n' header += '\n' header += '\n' header += '\n' return header def get_sortable_html_footer(): """Gets footer for sortable html page. Check function `get_sortable_html_header()` for more details. """ return '\n
{name}
\n\n\n\n' def parse_image_size(obj): """Parses object to a pair of image size, i.e., (width, height). Args: obj: The input object to parse image size from. Returns: A two-element tuple, indicating image width and height respectively. Raises: If the input is invalid, i.e., neither a list or tuple, nor a string. """ if obj is None or obj == '': width = height = 0 elif isinstance(obj, int): width = height = obj elif isinstance(obj, (list, tuple, np.ndarray)): numbers = tuple(obj) if len(numbers) == 0: width = height = 0 elif len(numbers) == 1: width = height = numbers[0] elif len(numbers) == 2: width = numbers[0] height = numbers[1] else: raise ValueError(f'At most two elements for image size.') elif isinstance(obj, str): splits = obj.replace(' ', '').split(',') numbers = tuple(map(int, splits)) if len(numbers) == 0: width = height = 0 elif len(numbers) == 1: width = height = numbers[0] elif len(numbers) == 2: width = numbers[0] height = numbers[1] else: raise ValueError(f'At most two elements for image size.') else: raise ValueError(f'Invalid type of input: {type(obj)}!') return (max(0, width), max(0, height)) def encode_image_to_html_str(image, image_size=None): """Encodes an image to html language. NOTE: Input image is always assumed to be with `RGB` channel order. Args: image: The input image to encode. Should be with `RGB` channel order. image_size: This field is used to resize the image before encoding. `0` disables resizing. (default: None) Returns: A string which represents the encoded image. """ if image is None: return '' assert image.ndim == 3 and image.shape[2] in [1, 3] # Change channel order to `BGR`, which is opencv-friendly. image = image[:, :, ::-1] # Resize the image if needed. width, height = parse_image_size(image_size) if height or width: height = height or image.shape[0] width = width or image.shape[1] image = cv2.resize(image, (width, height)) # Encode the image to html-format string. encoded_image = cv2.imencode('.jpg', image)[1].tostring() encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8') html_str = f'' return html_str def get_grid_shape(size, row=0, col=0, is_portrait=False): """Gets the shape of a grid based on the size. This function makes greatest effort on making the output grid square if neither `row` nor `col` is set. If `is_portrait` is set as `False`, the height will always be equal to or smaller than the width. For example, if input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, output shape will be (3, 5). Otherwise, the height will always be equal to or larger than the width. Args: size: Size (height * width) of the target grid. is_portrait: Whether to return a portrait size of a landscape size. (default: False) Returns: A two-element tuple, representing height and width respectively. """ assert isinstance(size, int) assert isinstance(row, int) assert isinstance(col, int) if size == 0: return (0, 0) if row > 0 and col > 0 and row * col != size: row = 0 col = 0 if row > 0 and size % row == 0: return (row, size // row) if col > 0 and size % col == 0: return (size // col, col) row = int(np.sqrt(size)) while row > 0: if size % row == 0: col = size // row break row = row - 1 return (col, row) if is_portrait else (row, col) class HtmlPageVisualizer(object): """Defines the html page visualizer. This class can be used to visualize image results as html page. Basically, it is based on an html-format sorted table with helper functions `get_sortable_html_header()`, `get_sortable_html_footer()`, and `encode_image_to_html_str()`. To simplify the usage, specifying the following fields are enough to create a visualization page: (1) num_rows: Number of rows of the table (header-row exclusive). (2) num_cols: Number of columns of the table. (3) header contents (optional): Title of each column. NOTE: `grid_size` can be used to assign `num_rows` and `num_cols` automatically. Example: html = HtmlPageVisualizer(num_rows, num_cols) html.set_headers([...]) for i in range(num_rows): for j in range(num_cols): html.set_cell(i, j, text=..., image=..., highlight=False) html.save('visualize.html') """ def __init__(self, num_rows=0, num_cols=0, grid_size=0, is_portrait=True, viz_size=None): if grid_size > 0: num_rows, num_cols = get_grid_shape( grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait) assert num_rows > 0 and num_cols > 0 self.num_rows = num_rows self.num_cols = num_cols self.viz_size = parse_image_size(viz_size) self.headers = ['' for _ in range(self.num_cols)] self.cells = [[{ 'text': '', 'image': '', 'highlight': False, } for _ in range(self.num_cols)] for _ in range(self.num_rows)] def set_header(self, col_idx, content): """Sets the content of a particular header by column index.""" self.headers[col_idx] = content def set_headers(self, contents): """Sets the contents of all headers.""" if isinstance(contents, str): contents = [contents] assert isinstance(contents, (list, tuple)) assert len(contents) == self.num_cols for col_idx, content in enumerate(contents): self.set_header(col_idx, content) def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False): """Sets the content of a particular cell. Basically, a cell contains some text as well as an image. Both text and image can be empty. Args: row_idx: Row index of the cell to edit. col_idx: Column index of the cell to edit. text: Text to add into the target cell. (default: None) image: Image to show in the target cell. Should be with `RGB` channel order. (default: None) highlight: Whether to highlight this cell. (default: False) """ self.cells[row_idx][col_idx]['text'] = text self.cells[row_idx][col_idx]['image'] = encode_image_to_html_str( image, self.viz_size) self.cells[row_idx][col_idx]['highlight'] = bool(highlight) def save(self, save_path): """Saves the html page.""" html = '' for i in range(self.num_rows): html += f'\n' for j in range(self.num_cols): text = self.cells[i][j]['text'] image = self.cells[i][j]['image'] if self.cells[i][j]['highlight']: color = ' bgcolor="#FF8888"' else: color = '' if text: html += f' {text}

{image}\n' else: html += f' {image}\n' html += f'\n' header = get_sortable_html_header(self.headers) footer = get_sortable_html_footer() with open(save_path, 'w') as f: f.write(header + html + footer)