Spaces:
Runtime error
Runtime error
# python3.7 | |
"""Contains the base class for generator.""" | |
import os | |
import sys | |
import logging | |
import numpy as np | |
import torch | |
from . import model_settings | |
__all__ = ['BaseGenerator'] | |
def get_temp_logger(logger_name='logger'):
    """Gets a temporary logger.

    This logger will print all levels of messages onto the screen (stdout).
    Calling this function again with the same name returns the same logger
    without attaching a second handler.

    Args:
        logger_name: Name of the logger.

    Returns:
        A `logging.Logger`.

    Raises:
        ValueError: If the input `logger_name` is empty.
    """
    if not logger_name:
        raise ValueError('Input `logger_name` should not be empty!')

    logger = logging.getLogger(logger_name)
    # Inspect only the handlers attached to *this* logger. `hasHandlers()`
    # also walks up to the root logger, so as soon as root logging was
    # configured elsewhere the setup below was skipped and the logger
    # inherited root's level, silently dropping DEBUG/INFO messages.
    if not logger.handlers:
        logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
        sh = logging.StreamHandler(stream=sys.stdout)
        sh.setLevel(logging.DEBUG)
        sh.setFormatter(formatter)
        logger.addHandler(sh)
        # This logger prints on its own; forwarding records to ancestor
        # handlers as well would print every message twice.
        logger.propagate = False
    return logger
class BaseGenerator(object):
    """Base class for generator used in GAN variants.

    NOTE: The model should be defined with pytorch, and only used for inference.
    """

    def __init__(self, model_name, logger=None):
        """Initializes with specific settings.

        The model should be registered in `model_settings.py` with proper
        settings first. Every registered setting is copied onto the instance
        as an attribute. The following attributes are necessary:

        (1) gan_type: Type of the GAN model.
        (2) latent_space_dim: Dimension of the latent space. Should be a tuple.
        (3) resolution: Resolution of the synthesis.

        The following attributes are optional:

        (4) min_val: Minimum value of the raw output. (default: -1.0)
        (5) max_val: Maximum value of the raw output. (default: 1.0)
        (6) channel_order: Channel order of the output image, `RGB` or `BGR`.
            (default: `RGB`)
        (7) output_channels: Number of output channels. (default: 3)
        (8) model_path: File loaded with `load()` if it exists.
        (9) tf_model_path: File converted with `convert_tf_model()` if it
            exists and `model_path` does not.

        Args:
            model_name: Name with which the model is registered.
            logger: Logger for recording log messages. If set as `None`, a
                default logger, which prints messages from all levels to
                screen, will be created. (default: None)

        Raises:
            AttributeError: If some necessary attributes are missing.
        """
        self.model_name = model_name
        for key, val in model_settings.MODEL_POOL[model_name].items():
            setattr(self, key, val)
        self.use_cuda = model_settings.USE_CUDA
        self.batch_size = model_settings.MAX_IMAGES_ON_DEVICE
        self.logger = logger or get_temp_logger(model_name + '_generator')
        self.model = None
        self.run_device = 'cuda' if self.use_cuda else 'cpu'
        self.cpu_device = 'cpu'

        # Check necessary settings.
        self.check_attr('gan_type')
        self.check_attr('latent_space_dim')
        self.check_attr('resolution')
        self.min_val = getattr(self, 'min_val', -1.0)
        self.max_val = getattr(self, 'max_val', 1.0)
        self.output_channels = getattr(self, 'output_channels', 3)
        self.channel_order = getattr(self, 'channel_order', 'RGB').upper()
        assert self.channel_order in ['RGB', 'BGR']

        # Build model and load pre-trained weights.
        self.build()
        # `or ''` tolerates a path registered explicitly as `None`, which
        # `os.path.isfile()` would otherwise reject with a `TypeError`.
        if os.path.isfile(getattr(self, 'model_path', None) or ''):
            self.load()
        elif os.path.isfile(getattr(self, 'tf_model_path', None) or ''):
            self.convert_tf_model()
        else:
            self.logger.warning('No pre-trained model will be loaded!')

        # Change to inference mode and GPU mode if needed.
        # Compare with `None` rather than testing truthiness: container
        # modules such as `nn.Sequential` define `__len__`, so a truthiness
        # check would misjudge whether `build()` produced a model.
        assert self.model is not None, 'Derived `build()` must set `self.model`!'
        self.model.eval().to(self.run_device)

    def check_attr(self, attr_name):
        """Checks the existence of a particular attribute.

        Args:
            attr_name: Name of the attribute to check.

        Raises:
            AttributeError: If the target attribute is missing.
        """
        if not hasattr(self, attr_name):
            raise AttributeError(
                f'`{attr_name}` is missing for model `{self.model_name}`!')

    def build(self):
        """Builds the graph. Must assign `self.model`."""
        raise NotImplementedError('Should be implemented in derived class!')

    def load(self):
        """Loads pre-trained weights."""
        raise NotImplementedError('Should be implemented in derived class!')

    def convert_tf_model(self, test_num=10):
        """Converts models weights from tensorflow version.

        Args:
            test_num: Number of images to generate for testing whether the
                conversion is done correctly. `0` means skipping the test.
                (default: 10)
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def sample(self, num):
        """Samples latent codes randomly.

        Args:
            num: Number of latent codes to sample. Should be positive.

        Returns:
            A `numpy.ndarray` as sampled latent codes.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def preprocess(self, latent_codes):
        """Preprocesses the input latent code if needed.

        Args:
            latent_codes: The input latent codes for preprocessing.

        Returns:
            The preprocessed latent codes which can be used as final input for
            the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def easy_sample(self, num):
        """Wraps functions `sample()` and `preprocess()` together."""
        return self.preprocess(self.sample(num))

    def synthesize(self, latent_codes, **kwargs):
        """Synthesizes images with given latent codes.

        NOTE: The latent codes should have already been preprocessed.

        Args:
            latent_codes: Input latent codes for image synthesis.
            **kwargs: Extra model-specific options. `easy_synthesize()`
                forwards its keyword arguments here.

        Returns:
            A dictionary whose values are raw outputs from the generator.
        """
        raise NotImplementedError('Should be implemented in derived class!')

    def get_value(self, tensor):
        """Gets value of a `torch.Tensor`.

        Args:
            tensor: The input tensor to get value from.

        Returns:
            A `numpy.ndarray`. An input that already is a `numpy.ndarray` is
            returned unchanged.

        Raises:
            ValueError: If the tensor is neither of `torch.Tensor` type nor of
                `numpy.ndarray` type.
        """
        if isinstance(tensor, np.ndarray):
            return tensor
        if isinstance(tensor, torch.Tensor):
            return tensor.to(self.cpu_device).detach().numpy()
        raise ValueError(f'Unsupported input type `{type(tensor)}`!')

    def postprocess(self, images):
        """Postprocesses the output images if needed.

        This function assumes the input numpy array is with shape [batch_size,
        channel, height, width]. Here, `channel = 3` for color image and
        `channel = 1` for grayscale image. The returned images are with shape
        [batch_size, height, width, channel]. NOTE: The channel order of the
        output image will always be `RGB`.

        Args:
            images: The raw output from the generator.

        Returns:
            The postprocessed images with dtype `numpy.uint8` with range
            [0, 255].

        Raises:
            ValueError: If the input `images` are not with type `numpy.ndarray`
                or not with shape [batch_size, channel, height, width].
        """
        if not isinstance(images, np.ndarray):
            raise ValueError('Images should be with type `numpy.ndarray`!')

        # Shape validation is skipped for models whose name contains
        # `stylegan2` or `stylegan3`; their output still goes through the same
        # 4-D transpose below.
        is_stylegan2_or_3 = ('stylegan2' in self.model_name
                             or 'stylegan3' in self.model_name)
        if not is_stylegan2_or_3:
            images_shape = images.shape
            if len(images_shape) != 4 or images_shape[1] not in [1, 3]:
                raise ValueError('Input should be with shape [batch_size, channel, '
                                 'height, width], where channel equals to 1 or 3. '
                                 f'But {images_shape} is received!')

        # Map [min_val, max_val] linearly onto [0, 255]; adding 0.5 before the
        # truncating uint8 cast rounds to the nearest integer.
        images = (images - self.min_val) * 255 / (self.max_val - self.min_val)
        images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
        images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        if self.channel_order == 'BGR':
            images = images[:, :, :, ::-1]
        return images

    def easy_synthesize(self, latent_codes, **kwargs):
        """Wraps functions `synthesize()` and `postprocess()` together."""
        outputs = self.synthesize(latent_codes, **kwargs)
        if 'image' in outputs:
            outputs['image'] = self.postprocess(outputs['image'])
        return outputs

    def get_batch_inputs(self, latent_codes):
        """Gets batch inputs from a collection of latent codes.

        This function will yield at most `self.batch_size` latent_codes at a
        time.

        Args:
            latent_codes: The input latent codes for generation. First dimension
                should be the total number.
        """
        total_num = latent_codes.shape[0]
        for i in range(0, total_num, self.batch_size):
            yield latent_codes[i:i + self.batch_size]