# python 3.7
"""Utility functions for image editing from latent space."""
import os.path | |
import numpy as np | |
# Names exported by a star-import of this module.
# NOTE(review): `manipulate2` and `MPC` are defined further down in this file
# but are not listed here -- confirm whether leaving them out is intentional.
__all__ = [
    'parse_indices', 'interpolate', 'mix_style',
    'get_layerwise_manipulation_strength', 'manipulate', 'parse_boundary_list'
]
def parse_indices(obj, min_val=None, max_val=None):
    """Parses indices.

    Accepted inputs:

    - `None` or an empty / whitespace-only string: no indices.
    - A single integer (built-in or NumPy integer scalar).
    - A list, tuple, or `numpy.ndarray` of integers.
    - A string holding a comma separated list whose items are either a single
      number or a dash separated, inclusive range, e.g. '0, 2, 4 - 6'. Spaces
      in the string are ignored.

    Args:
        obj: The input object to parse indices from.
        min_val: If not `None`, every index must be equal to or larger than
            this value. (default: None)
        max_val: If not `None`, every index must be equal to or smaller than
            this value. (default: None)

    Returns:
        A sorted list of unique integers.

    Raises:
        ValueError: If the input type is unsupported, or if an item of a
            string input is neither a number nor a two-ended range.
        AssertionError: If an index is not an integer, or lies outside
            `[min_val, max_val]`.
    """
    if obj is None:
        indices = []
    elif isinstance(obj, str):
        # Strings are dispatched before any `==` comparison so that array
        # inputs are never compared element-wise against ''.
        indices = []
        compact = obj.replace(' ', '')
        items = compact.split(',') if compact else []
        for item in items:
            numbers = [int(number) for number in item.split('-')]
            if len(numbers) == 1:
                indices.append(numbers[0])
            elif len(numbers) == 2:
                indices.extend(range(numbers[0], numbers[1] + 1))
            else:
                # Previously such items were dropped without any notice.
                raise ValueError(f'Invalid index range `{item}`, expected '
                                 f'a single number or `start-end`!')
    elif isinstance(obj, (int, np.integer)):
        indices = [int(obj)]
    elif isinstance(obj, np.ndarray):
        # `tolist()` yields built-in scalars, so integer arrays pass the
        # `isinstance(idx, int)` check below.
        indices = obj.ravel().tolist()
    elif isinstance(obj, (list, tuple)):
        indices = [int(idx) if isinstance(idx, np.integer) else idx
                   for idx in obj]
    else:
        raise ValueError(f'Invalid type of input: {type(obj)}!')

    indices = sorted(set(indices))
    # Assertions are kept (rather than raising another exception type) so that
    # callers relying on `AssertionError` keep working.
    for idx in indices:
        assert isinstance(idx, int)
        if min_val is not None:
            assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
        if max_val is not None:
            assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'

    return indices
def interpolate(src_codes, dst_codes, step=5):
    """Linearly interpolates between two sets of latent codes.

    Args:
        src_codes: Source codes, with shape [num, *code_shape].
        dst_codes: Target codes, with shape [num, *code_shape].
        step: Number of interpolation steps, source and target included. For
            example, `step = 5` inserts three samples between each pair.
            (default: 5)

    Returns:
        Interpolated codes, with shape [num, step, *code_shape].

    Raises:
        ValueError: If the two sets of codes have different shapes, or have
            fewer than two dimensions.
    """
    shapes_ok = src_codes.ndim >= 2 and src_codes.shape == dst_codes.shape
    if not shapes_ok:
        raise ValueError(f'Shapes of source codes and target codes should both be '
                         f'[num, *code_shape], but {src_codes.shape} and '
                         f'{dst_codes.shape} are received!')

    expected_shape = (src_codes.shape[0], step, *src_codes.shape[1:])

    # Insert the step axis right after the batch axis.
    start = src_codes[:, np.newaxis]
    end = dst_codes[:, np.newaxis]

    # Interpolation weights live on the step axis and broadcast elsewhere.
    weight_shape = [1] * start.ndim
    weight_shape[1] = step
    weights = np.linspace(0.0, 1.0, step).reshape(weight_shape)

    results = start + weights * (end - start)
    assert results.shape == expected_shape
    return results
def mix_style(style_codes,
              content_codes,
              num_layers=1,
              mix_layers=None,
              is_style_layerwise=True,
              is_content_layerwise=True):
    """Mixes styles from style codes into content codes.

    Each style or content code consists of `num_layers` per-layer codes. Style
    mixing replaces the content code's entries at the selected layers with the
    entries of the style code at the same layers.

    For example, with style and content codes of shape [10, 512] (10 layers,
    one 512-dimensional code each) and layers 0, 1, 2 selected, the first
    three rows of the content code are replaced by the first three rows of the
    style code.

    NOTE: Single-layer codes are also supported by setting
    `is_style_layerwise` or `is_content_layerwise` to `False`, in which case
    the corresponding codes are first repeated `num_layers` times.

    Args:
        style_codes: Style codes, with shape [num_styles, *code_shape] or
            [num_styles, num_layers, *code_shape].
        content_codes: Content codes, with shape [num_contents, *code_shape]
            or [num_contents, num_layers, *code_shape].
        num_layers: Total number of layers in the generative model.
            (default: 1)
        mix_layers: Indices of the layers to mix, in any form accepted by
            `parse_indices()`. An empty selection (e.g. `None`) replaces all
            layers, i.e. the content code is fully replaced by the style
            code. (default: None)
        is_style_layerwise: Whether `style_codes` are layer-wise codes.
            (default: True)
        is_content_layerwise: Whether `content_codes` are layer-wise codes.
            (default: True)

    Returns:
        Codes after style mixing, with shape [num_styles, num_contents,
        num_layers, *code_shape].

    Raises:
        ValueError: If `content_codes` or `style_codes` has an invalid shape.
    """
    # Expand single-layer codes along a new layer axis (axis 1).
    if not is_style_layerwise:
        style_codes = style_codes[:, np.newaxis]
        style_reps = [1] * style_codes.ndim
        style_reps[1] = num_layers
        style_codes = np.tile(style_codes, style_reps)
    if not is_content_layerwise:
        content_codes = content_codes[:, np.newaxis]
        content_reps = [1] * content_codes.ndim
        content_reps[1] = num_layers
        content_codes = np.tile(content_codes, content_reps)

    if (style_codes.ndim < 3 or
            style_codes.shape[1] != num_layers or
            style_codes.shape[1:] != content_codes.shape[1:]):
        raise ValueError(f'Shapes of style codes and content codes should be '
                         f'[num_styles, num_layers, *code_shape] and '
                         f'[num_contents, num_layers, *code_shape] respectively, '
                         f'but {style_codes.shape} and {content_codes.shape} are '
                         f'received!')

    # An empty selection means "mix every layer".
    layer_indices = (
        parse_indices(mix_layers, min_val=0, max_val=num_layers - 1)
        or list(range(num_layers)))

    num_styles = style_codes.shape[0]
    num_contents = content_codes.shape[0]
    code_shape = content_codes.shape[2:]
    full_shape = (num_styles, num_contents, num_layers, *code_shape)

    # True where the output takes its value from the style code.
    from_style = np.zeros(full_shape, dtype=bool)
    from_style[:, :, layer_indices] = True

    # Styles vary along axis 0 and contents along axis 1; broadcasting pairs
    # every style with every content.
    results = np.where(from_style,
                       style_codes[:, np.newaxis],
                       content_codes[np.newaxis])
    assert results.shape == full_shape
    return results
def get_layerwise_manipulation_strength(num_layers,
                                        truncation_psi,
                                        truncation_layers):
    """Gets the per-layer strength for manipulation.

    Recall the truncation trick applied to layers [0, truncation_layers):

        w = truncation_psi * w + (1 - truncation_psi) * w_avg

    When one boundary is used to manipulate several layers, the truncated
    layers [0, truncation_layers) and the remaining layers
    [truncation_layers, num_layers) need different strengths to cancel the
    effect of the truncation trick. Truncated layers get `truncation_psi`,
    all other layers get 1.

    Args:
        num_layers: Total number of layers.
        truncation_psi: Strength assigned to the truncated layers.
        truncation_layers: Number of leading layers the truncation trick is
            applied to. Non-positive values leave every strength at 1.

    Returns:
        A list of `num_layers` strengths.
    """
    strength = [1.0] * num_layers
    if truncation_layers <= 0:
        return strength
    for layer_idx in range(truncation_layers):
        strength[layer_idx] = truncation_psi
    return strength
def manipulate(latent_codes,
               boundary,
               start_distance=-5.0,
               end_distance=5.0,
               step=21,
               layerwise_manipulation=False,
               num_layers=1,
               manipulate_layers=None,
               is_code_layerwise=False,
               is_boundary_layerwise=False,
               layerwise_manipulation_strength=1.0):
    """Manipulates the given latent codes with respect to a boundary.

    Takes a set of latent codes and a boundary, and returns, for every input
    code, a sequence of `step` codes moved along the boundary direction.

    For example, with `step = 10`, `latent_codes` of shape [num, *code_shape]
    and a unit-norm `boundary` of shape [1, *code_shape], the output has shape
    [num, 10, *code_shape]. Within each sequence, the first code lies
    `start_distance` away from the input code along `boundary`, the last one
    lies `end_distance` away, and the codes in between are spaced linearly.
    Distances are sign sensitive.

    NOTE: Layer-wise manipulation is also supported, in which case the
    generator is expected to take one latent code per layer. Only the layers
    selected by `manipulate_layers` are moved; the others keep their input
    code.

    NOTE: The boundary is assumed to be normalized to unit norm already.

    Args:
        latent_codes: Input latent codes, with shape [num, *code_shape] or
            [num, num_layers, *code_shape].
        boundary: The semantic boundary, with shape [1, *code_shape] or
            [1, num_layers, *code_shape].
        start_distance: Start point for manipulation. (default: -5.0)
        end_distance: End point for manipulation. (default: 5.0)
        step: Number of manipulation steps. (default: 21)
        layerwise_manipulation: Whether to perform layer-wise manipulation.
            When `False`, `num_layers`, `manipulate_layers` and
            `layerwise_manipulation_strength` are ignored. (default: False)
        num_layers: Number of layers, a positive integer. Only used when
            `layerwise_manipulation` is `True`. (default: 1)
        manipulate_layers: Indices of the layers to manipulate, in any form
            accepted by `parse_indices()`. An empty selection (e.g. `None`)
            manipulates all layers. (default: None)
        is_code_layerwise: Whether the input codes are layer-wise. If `False`,
            the codes are repeated `num_layers` times first. (default: False)
        is_boundary_layerwise: Whether the boundary is layer-wise. If `False`,
            the boundary is repeated `num_layers` times first.
            (default: False)
        layerwise_manipulation_strength: Per-layer manipulation strength, as
            a tuple, list, or `numpy.ndarray` with one entry per layer; a
            single number is used for all layers. Only used when
            `layerwise_manipulation` is `True`. See
            `get_layerwise_manipulation_strength()` for how to compensate
            for the truncation trick. (default: 1.0)

    Returns:
        Manipulated codes, with shape [num, step, *code_shape] if
        `layerwise_manipulation` is `False`, or [num, step, num_layers,
        *code_shape] if it is `True`.

    Raises:
        ValueError: If the latent codes, boundary, or strength have an
            invalid shape or type.
    """
    if boundary.ndim < 2 or boundary.shape[0] != 1:
        raise ValueError(f'Boundary should be with shape [1, *code_shape] or '
                         f'[1, num_layers, *code_shape], but '
                         f'{boundary.shape} is received!')

    # Plain manipulation is handled as layer-wise manipulation of one layer.
    if not layerwise_manipulation:
        assert not is_code_layerwise
        assert not is_boundary_layerwise
        num_layers = 1
        manipulate_layers = None
        layerwise_manipulation_strength = 1.0

    # An empty selection means "manipulate every layer".
    layer_indices = (
        parse_indices(manipulate_layers, min_val=0, max_val=num_layers - 1)
        or list(range(num_layers)))

    assert num_layers > 0

    # Latent codes -> [num, num_layers, *code_shape].
    if is_code_layerwise:
        codes = latent_codes
        if codes.shape[1] != num_layers:
            raise ValueError(f'Latent codes should be with shape [num, num_layers, '
                             f'*code_shape], where `num_layers` equals to '
                             f'{num_layers}, but {codes.shape} is received!')
    else:
        codes = latent_codes[:, np.newaxis]
        code_reps = [1] * codes.ndim
        code_reps[1] = num_layers
        codes = np.tile(codes, code_reps)

    # Boundary -> [num_layers, *code_shape].
    if is_boundary_layerwise:
        direction = boundary[0]
        if direction.shape[0] != num_layers:
            raise ValueError(f'Boundary should be with shape [num_layers, '
                             f'*code_shape], where `num_layers` equals to '
                             f'{num_layers}, but {direction.shape} is received!')
    else:
        boundary_reps = [1] * boundary.ndim
        boundary_reps[0] = num_layers
        direction = np.tile(boundary, boundary_reps)

    # Per-layer strength, one entry per layer.
    strength = layerwise_manipulation_strength
    if isinstance(strength, (int, float)):
        strength = [float(strength)] * num_layers
    elif isinstance(strength, (list, tuple)):
        if len(strength) != num_layers:
            raise ValueError(f'Shape of layer-wise manipulation strength `{len(strength)}` '
                             f'mismatches number of layers `{num_layers}`!')
    elif isinstance(strength, np.ndarray):
        if strength.size != num_layers:
            raise ValueError(f'Shape of layer-wise manipulation strength `{strength.size}` '
                             f'mismatches number of layers `{num_layers}`!')
    else:
        raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!')

    # Scale each layer's direction by its strength.
    strength = np.array(strength).reshape(
        [num_layers] + [1] * (direction.ndim - 1))
    direction = direction * strength

    if codes.shape[1:] != direction.shape:
        raise ValueError(f'Latent code shape {codes.shape} and boundary shape '
                         f'{direction.shape} mismatch!')

    num = codes.shape[0]
    code_shape = codes.shape[2:]
    full_shape = (num, step, num_layers, *code_shape)

    # Axis layout from here on: [num, step, num_layers, *code_shape].
    codes = codes[:, np.newaxis]
    direction = direction[np.newaxis, np.newaxis]
    distance_shape = [1] * codes.ndim
    distance_shape[1] = step
    distances = np.linspace(start_distance, end_distance, step).reshape(
        distance_shape)

    # Only the selected layers are moved; others keep the input code at
    # every step (via broadcasting along the step axis).
    is_manipulatable = np.zeros(full_shape, dtype=bool)
    is_manipulatable[:, :, layer_indices] = True
    results = np.where(is_manipulatable, codes + distances * direction, codes)
    assert results.shape == full_shape

    if layerwise_manipulation:
        return results
    return results[:, :, 0]
def manipulate2(latent_codes,
                proj,
                mindex,
                start_distance=-5.0,
                end_distance=5.0,
                step=21,
                layerwise_manipulation=False,
                num_layers=1,
                manipulate_layers=None,
                is_code_layerwise=False,
                layerwise_manipulation_strength=1.0):
    """Manipulates latent codes through a projection instead of a boundary.

    Mirrors `manipulate()`, but the moved codes come from `MPC()`: the codes
    are mapped by `proj`, the component(s) `mindex` are shifted over `step`
    linearly spaced distances, and the result is mapped back.

    NOTE: `layerwise_manipulation_strength` is only validated here; it does
    not scale the result. The moved code is computed once by `MPC()` and the
    same code is written to every selected layer.

    Args:
        latent_codes: Input latent codes, with shape [num, *code_shape] or
            [num, num_layers, *code_shape].
        proj: Object providing `transform()` and `inverse_transform()`
            (presumably a fitted PCA-like model -- verify against callers).
        mindex: Index of the projected component(s) to shift.
        start_distance: Start point for manipulation. (default: -5.0)
        end_distance: End point for manipulation. (default: 5.0)
        step: Number of manipulation steps. (default: 21)
        layerwise_manipulation: Whether to perform layer-wise manipulation.
            When `False`, `num_layers`, `manipulate_layers` and
            `layerwise_manipulation_strength` are ignored. (default: False)
        num_layers: Number of layers, a positive integer. (default: 1)
        manipulate_layers: Indices of the layers to manipulate, in any form
            accepted by `parse_indices()`. An empty selection (e.g. `None`)
            manipulates all layers. (default: None)
        is_code_layerwise: Whether the input codes are layer-wise. If `False`,
            the codes are repeated `num_layers` times first. (default: False)
        layerwise_manipulation_strength: A number, or a tuple, list, or
            `numpy.ndarray` with one entry per layer. (default: 1.0)

    Returns:
        Manipulated codes, with shape [num, step, *code_shape] if
        `layerwise_manipulation` is `False`, or [num, step, num_layers,
        *code_shape] if it is `True`.

    Raises:
        ValueError: If the latent codes or the strength have an invalid shape
            or type.
    """
    # Plain manipulation is handled as layer-wise manipulation of one layer.
    if not layerwise_manipulation:
        assert not is_code_layerwise
        num_layers = 1
        manipulate_layers = None
        layerwise_manipulation_strength = 1.0

    # An empty selection means "manipulate every layer".
    layer_indices = (
        parse_indices(manipulate_layers, min_val=0, max_val=num_layers - 1)
        or list(range(num_layers)))

    assert num_layers > 0

    # Latent codes -> [num, num_layers, *code_shape].
    if is_code_layerwise:
        codes = latent_codes
        if codes.shape[1] != num_layers:
            raise ValueError(f'Latent codes should be with shape [num, num_layers, '
                             f'*code_shape], where `num_layers` equals to '
                             f'{num_layers}, but {codes.shape} is received!')
    else:
        codes = latent_codes[:, np.newaxis]
        code_reps = [1] * codes.ndim
        code_reps[1] = num_layers
        codes = np.tile(codes, code_reps)

    # The strength is validated for parity with `manipulate()`, nothing more.
    strength = layerwise_manipulation_strength
    if isinstance(strength, (int, float)):
        pass
    elif isinstance(strength, (list, tuple)):
        if len(strength) != num_layers:
            raise ValueError(f'Shape of layer-wise manipulation strength `{len(strength)}` '
                             f'mismatches number of layers `{num_layers}`!')
    elif isinstance(strength, np.ndarray):
        if strength.size != num_layers:
            raise ValueError(f'Shape of layer-wise manipulation strength `{strength.size}` '
                             f'mismatches number of layers `{num_layers}`!')
    else:
        raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!')

    num = codes.shape[0]
    code_shape = codes.shape[2:]
    full_shape = (num, step, num_layers, *code_shape)

    # Axis layout from here on: [num, step, num_layers, *code_shape].
    codes = codes[:, np.newaxis]

    is_manipulatable = np.zeros(full_shape, dtype=bool)
    is_manipulatable[:, :, layer_indices] = True

    # `MPC` returns [num, step, feature]; add a layer axis so the same moved
    # code broadcasts over all layers, then keep it only where selected.
    moved = MPC(proj, codes, mindex, start_distance, end_distance, step)
    moved = moved[:, :, np.newaxis]
    results = np.where(is_manipulatable, moved, codes)
    assert results.shape == full_shape

    if layerwise_manipulation:
        return results
    return results[:, :, 0]
def MPC(proj, x, mindex, start_distance, end_distance, step):
    """Shifts projected component(s) of the codes and maps the result back.

    Only `x[:, 0, 0, :]` is used, i.e. the code at the first position of
    axes 1 and 2 for every sample.

    Args:
        x: Codes indexed as `x[:, 0, 0, :]`; per the original comment the
            shape is (batch_size, 1, num_layers, feature).
        proj: Object providing `transform()` and `inverse_transform()`
            (presumably a fitted PCA-like model -- verify against callers).
        mindex: Index of the projected component(s) to shift.
        start_distance: First offset added to the selected component(s).
        end_distance: Last offset added to the selected component(s).
        step: Number of linearly spaced offsets.

    Returns:
        Array of shape (batch_size, step, -1) holding the inverse-transformed
        codes for every offset.
    """
    # Project the codes: (batch_size, num_components).
    projected = proj.transform(x[:, 0, 0, :])

    # Repeat along a new step axis: (batch_size, step, num_components).
    projected = projected[:, np.newaxis]
    reps = [1] * projected.ndim
    reps[1] = step
    trajectory = np.tile(projected, reps)

    # Add one offset per step to the selected component(s), in place.
    offsets = np.linspace(start_distance, end_distance, step)[np.newaxis, :]
    trajectory[:, :, mindex] += offsets

    # `inverse_transform` is called on 2-D input, so flatten batch and step.
    batch_size, num_steps = trajectory.shape[0], trajectory.shape[1]
    flat = trajectory.reshape((-1, trajectory.shape[-1]))
    restored = proj.inverse_transform(flat)
    return restored.reshape((batch_size, num_steps, -1))
def parse_boundary_list(boundary_list_path):
    """Parses a boundary list.

    A text file listing boundaries simplifies image manipulation with a large
    number of boundaries. This function reads such a list file.

    Each non-blank line has the format `($NAME, $SPACE_TYPE): $PATH`. A line
    starting with `DISABLE` is skipped. Only the first `:` separates the
    boundary info from the path, so the path itself may contain `:` (e.g. a
    Windows drive letter). Blank lines are ignored.

    Sample:

        (age, z): $AGE_BOUNDARY_PATH
        (gender, w): $GENDER_BOUNDARY_PATH
        DISABLE(pose, wp): $POSE_BOUNDARY_PATH

    Args:
        boundary_list_path: Path to the boundary list.

    Returns:
        A dictionary whose key is a two-element tuple
        (boundary_name, space_type), with `space_type` lower-cased, and whose
        value is the corresponding boundary path.

    Raises:
        ValueError: If the boundary list does not exist, or a line does not
            follow the expected format.
    """
    if not os.path.isfile(boundary_list_path):
        raise ValueError(f'Boundary list `{boundary_list_path}` does not exist!')

    boundaries = {}
    with open(boundary_list_path, 'r') as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line or line.startswith('DISABLE'):
                continue

            # Split on the first ':' only, so paths may contain colons.
            boundary_info, separator, boundary_path = line.partition(':')
            # Drop the surrounding parentheses before splitting name / type.
            info_fields = boundary_info.strip()[1:-1].split(',')
            if not separator or len(info_fields) != 2:
                raise ValueError(
                    f'Line {line_number} of boundary list '
                    f'`{boundary_list_path}` should be with format '
                    f'`($NAME, $SPACE_TYPE): $PATH`, but `{line}` is '
                    f'received!')

            boundary_name = info_fields[0].strip()
            space_type = info_fields[1].strip().lower()
            boundaries[(boundary_name, space_type)] = boundary_path.strip()
    return boundaries