"""Gradio demo for Region-Based Semantic Factorization in GANs (ReSeFa)."""
# python 3.7
import io
import os
import warnings

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

from models import build_model

warnings.filterwarnings(action='ignore', category=UserWarning)

# Download the StyleGAN2 FFHQ checkpoint used by the demo (requires `gdown`).
os.system("gdown https://drive.google.com/uc?id=1--h-4E5LSxe6VTp9rjJhtoILxxNH0X5n")

def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Post-processes image to pixel range [0, 255] with dtype `uint8`.

    This function is particularly used to handle the results produced by deep
    models.

    NOTE: The input image is assumed to be with format `NCHW`, and the returned
    image will always be with format `NHWC`.

    Args:
        image: The input image for post-processing.
        min_val: Expected minimum value of the input image.
        max_val: Expected maximum value of the input image.

    Returns:
        The post-processed image.
    """
    assert isinstance(image, np.ndarray)

    image = image.astype(np.float64)
    image = (image - min_val) / (max_val - min_val) * 255
    image = np.clip(image + 0.5, 0, 255).astype(np.uint8)

    assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
    return image.transpose(0, 2, 3, 1)
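
# Example (illustrative, not part of the original demo): a 2x3x4x4 float batch
# spanning [-1, 1] maps to a 2x4x4x3 `uint8` batch spanning [0, 255].
#   >>> fake = np.stack([-np.ones((3, 4, 4)), np.ones((3, 4, 4))])
#   >>> out = postprocess_image(fake)
#   >>> out.shape, int(out.min()), int(out.max())
#   ((2, 4, 4, 3), 0, 255)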


def to_numpy(data):
    """Converts the input data to `numpy.ndarray`."""
    if isinstance(data, (int, float)):
        return np.array(data)
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    raise TypeError(f'Not supported data type `{type(data)}` for '
                    f'converting to `numpy.ndarray`!')
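
# Example (illustrative, not part of the original demo): tensors are detached
# and moved to CPU before conversion; Python scalars become 0-d arrays.
#   >>> to_numpy(torch.ones(2, 3)).shape
#   (2, 3)
#   >>> to_numpy(1.5)
#   array(1.5)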


def linear_interpolate(latent_code,
                       boundary,
                       layer_index=None,
                       start_distance=-10.0,
                       end_distance=10.0,
                       steps=7):
    """Interpolates between a latent code and a boundary.

    Only the layers selected by `layer_index` are shifted along the boundary;
    all other layers keep the original code. (Not called by the demo below;
    kept as a reference for step-wise manipulation.)
    """
    assert (len(latent_code.shape) == 3 and len(boundary.shape) == 3 and
            latent_code.shape[0] == 1 and boundary.shape[0] == 1 and
            latent_code.shape[1] == boundary.shape[1])
    linspace = np.linspace(start_distance, end_distance, steps)
    linspace = linspace.reshape([-1, 1, 1]).astype(np.float32)
    inter_code = linspace * boundary
    is_manipulatable = np.zeros(inter_code.shape, dtype=bool)
    is_manipulatable[:, layer_index, :] = True
    mani_code = np.where(is_manipulatable, latent_code + inter_code, latent_code)
    return mani_code
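
# Example (illustrative, not part of the original demo): interpolating one
# 18-layer W+ code along a boundary, editing layer 5 only.
#   >>> code = np.random.randn(1, 18, 512).astype(np.float32)
#   >>> bnd = np.random.randn(1, 18, 512).astype(np.float32)
#   >>> linear_interpolate(code, bnd, layer_index=[5], steps=7).shape
#   (7, 18, 512)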


def imshow(images, col, viz_size=256):
    """Fuses a batch of images into a single grid and returns it as PIL."""
    num, height, width, channels = images.shape
    assert num % col == 0
    row = num // col

    fused_image = np.zeros((viz_size * row, viz_size * col, channels),
                           dtype=np.uint8)

    for idx, image in enumerate(images):
        i, j = divmod(idx, col)
        y = i * viz_size
        x = j * viz_size
        if height != viz_size or width != viz_size:
            image = cv2.resize(image, (viz_size, viz_size))
        fused_image[y:y + viz_size, x:x + viz_size] = image

    data = io.BytesIO()
    if channels == 4:
        Image.fromarray(fused_image).save(data, 'png')
    elif channels == 3:
        Image.fromarray(fused_image).save(data, 'jpeg')
    else:
        raise ValueError(f'Unsupported number of channels: {channels}!')
    return Image.open(io.BytesIO(data.getvalue()))


print('Building generator')
checkpoint_path = 'stylegan2-ffhq-config-f-1024x1024.pth'
config = dict(model_type='StyleGAN2Generator',
              resolution=1024,
              w_dim=512,
              fmaps_base=int(1 * (32 << 10)),
              fmaps_max=512,)
generator = build_model(**config)
print(f'Loading checkpoint from `{checkpoint_path}` ...')
checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
if 'generator_smooth' in checkpoint:
    generator.load_state_dict(checkpoint['generator_smooth'])
else:
    generator.load_state_dict(checkpoint['generator'])
generator = generator.eval().cpu()
print('Finished loading checkpoint.')

print('Loading boundaries')
ATTRS = ['eyebrows', 'eyesize', 'gaze_direction', 'nose_length', 'mouth', 'lipstick']
boundaries = {}
for attr in ATTRS:
    boundary_path = f'directions/ffhq/stylegan2/{attr}.npy'
    boundaries[attr] = np.load(boundary_path)
print('Generator and boundaries are ready.')
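
# NOTE (assumption, not stated in the original code): each loaded boundary is
# taken to be a single direction of shape (1, 1, w_dim); `inference` below
# tiles it across `generator.num_layers` so it can be added to selected
# layers of a W+ code.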


def inference(num_of_image, seed, trunc_psi, eyebrows, eyesize,
              gaze_direction, nose_length, mouth, lipstick):
    print('Sampling latent codes with the given seed.')
    trunc_layers = 8
    np.random.seed(seed)
    latent_z = np.random.randn(num_of_image, generator.z_dim)
    latent_z = torch.from_numpy(latent_z.astype(np.float32)).cpu()
    # Map z to the layer-wise W+ space and truncate the first `trunc_layers`
    # layers toward the average latent code.
    wp = generator.mapping(latent_z, None)['wp']
    if trunc_psi < 1.0:
        w_avg = generator.w_avg
        w_avg = w_avg.reshape(1, -1, generator.w_dim)[:, :trunc_layers]
        wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi)
    with torch.no_grad():
        images_ori = generator.synthesis(wp)['image']
    images_ori = postprocess_image(to_numpy(images_ori))
    # The unedited images could be previewed with:
    #   imshow(images_ori, col=images_ori.shape[0])
    latent_wp = to_numpy(wp)

    # Shift the W+ code along each attribute direction. Fine-grained
    # attributes (eyebrows, lipstick) edit layers 8-11; the rest edit
    # layers 4-7.
    attr_values = {'eyebrows': eyebrows, 'eyesize': eyesize,
                   'gaze_direction': gaze_direction,
                   'nose_length': nose_length,
                   'mouth': mouth, 'lipstick': lipstick}
    new_codes = latent_wp.copy()
    for attr_name in ATTRS:
        if attr_name in ['eyebrows', 'lipstick']:
            layers_idx = [8, 9, 10, 11]
        else:
            layers_idx = [4, 5, 6, 7]
        step = attr_values[attr_name]  # Replaces the fragile `eval(attr_name)`.
        direction = boundaries[attr_name]
        direction = np.tile(direction, [1, generator.num_layers, 1])
        new_codes[:, layers_idx, :] += direction[:, layers_idx, :] * step
    new_codes = torch.from_numpy(new_codes.astype(np.float32)).cpu()
    with torch.no_grad():
        images_mani = generator.synthesis(new_codes)['image']
    images_mani = postprocess_image(to_numpy(images_mani))
    return imshow(images_mani, col=images_mani.shape[0])
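
# Example (illustrative, not part of the original demo): render two faces from
# seed 210 with slightly raised eyebrows and every other attribute unchanged.
#   >>> grid = inference(2, 210, 0.7, eyebrows=4, eyesize=0,
#   ...                  gaze_direction=0, nose_length=0, mouth=0, lipstick=0)
#   >>> grid  # A PIL.Image grid with the two edited faces side by side.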

title = "resefa"
description = "## Gradio Demo for [Region-Based Semantic Factorization in GANs](https://github.com/zhujiapeng/resefa)"
gr.Interface(inference,[gr.Slider(1, 3, value=1,label="num_of_image",step=1),
gr.Slider(0, 10000, value=210,label="seed",step=1),
gr.Slider(0, 1, value=0.7,step=0.1,label="truncation psi"),
gr.Slider(-12, 12, value=0,label="eyebrows",step=1),
gr.Slider(-12, 12, value=0,label="eyesize",step=1),
gr.Slider(-12, 12, value=0,label="gaze direction",step=1),
gr.Slider(-12, 12, value=0,label="nose_length",step=1),
gr.Slider(-12, 12, value=0,label="mouth",step=1),
gr.Slider(-12, 12, value=0,label="lipstick",step=1),
],gr.Image(type="pil"),title=title,description=description).launch()