akhaliq HF staff commited on
Commit
0f0efe7
1 Parent(s): 8ca3a29

add app and requirements

Browse files
Files changed (2) hide show
  1. app.py +193 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import gradio as gr
4
+
5
+ MODEL_DIR = 'models/pretrain'
6
+ os.makedirs(MODEL_DIR, exist_ok=True)
7
+
8
+ os.system("wget https://hkustconnect-my.sharepoint.com/:u:/g/personal/jzhubt_connect_ust_hk/ETYVen9KXGlAia2gH6pcZswB9Lw-21vWrE75OACvG2SBow\?e\=SCGqg0\&download=1 -O $MODEL_DIR/stylegan2-ffhq-config-f-1024x1024.pth --quiet")
9
+
10
+
11
+ # python 3.7
12
+ """Demo."""
13
+ import io
14
+ import cv2
15
+ import warnings
16
+ import numpy as np
17
+ import torch
18
+ from PIL import Image
19
+ from models import build_model
20
+
21
+ warnings.filterwarnings(action='ignore', category=UserWarning)
22
+
23
+ def postprocess_image(image, min_val=-1.0, max_val=1.0):
24
+ """Post-processes image to pixel range [0, 255] with dtype `uint8`.
25
+
26
+ This function is particularly used to handle the results produced by deep
27
+ models.
28
+
29
+ NOTE: The input image is assumed to be with format `NCHW`, and the returned
30
+ image will always be with format `NHWC`.
31
+
32
+ Args:
33
+ image: The input image for post-processing.
34
+ min_val: Expected minimum value of the input image.
35
+ max_val: Expected maximum value of the input image.
36
+
37
+ Returns:
38
+ The post-processed image.
39
+ """
40
+ assert isinstance(image, np.ndarray)
41
+
42
+ image = image.astype(np.float64)
43
+ image = (image - min_val) / (max_val - min_val) * 255
44
+ image = np.clip(image + 0.5, 0, 255).astype(np.uint8)
45
+
46
+ assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
47
+ return image.transpose(0, 2, 3, 1)
48
+
49
+
50
+ def to_numpy(data):
51
+ """Converts the input data to `numpy.ndarray`."""
52
+ if isinstance(data, (int, float)):
53
+ return np.array(data)
54
+ if isinstance(data, np.ndarray):
55
+ return data
56
+ if isinstance(data, torch.Tensor):
57
+ return data.detach().cpu().numpy()
58
+ raise TypeError(f'Not supported data type `{type(data)}` for '
59
+ f'converting to `numpy.ndarray`!')
60
+
61
+
62
+ def linear_interpolate(latent_code,
63
+ boundary,
64
+ layer_index=None,
65
+ start_distance=-10.0,
66
+ end_distance=10.0,
67
+ steps=7):
68
+ """Interpolate between the latent code and boundary."""
69
+ assert (len(latent_code.shape) == 3 and len(boundary.shape) == 3 and
70
+ latent_code.shape[0] == 1 and boundary.shape[0] == 1 and
71
+ latent_code.shape[1] == boundary.shape[1])
72
+ linspace = np.linspace(start_distance, end_distance, steps)
73
+ linspace = linspace.reshape([-1, 1, 1]).astype(np.float32)
74
+ inter_code = linspace * boundary
75
+ is_manipulatable = np.zeros(inter_code.shape, dtype=bool)
76
+ is_manipulatable[:, layer_index, :] = True
77
+ mani_code = np.where(is_manipulatable, latent_code+inter_code, latent_code)
78
+ return mani_code
79
+
80
+
81
+ def imshow(images, col, viz_size=256):
82
+ """Shows images in one figure."""
83
+ num, height, width, channels = images.shape
84
+ assert num % col == 0
85
+ row = num // col
86
+
87
+ fused_image = np.zeros((viz_size*row, viz_size*col, channels), dtype=np.uint8)
88
+
89
+ for idx, image in enumerate(images):
90
+ i, j = divmod(idx, col)
91
+ y = i * viz_size
92
+ x = j * viz_size
93
+ if height != viz_size or width != viz_size:
94
+ image = cv2.resize(image, (viz_size, viz_size))
95
+ fused_image[y:y + viz_size, x:x + viz_size] = image
96
+
97
+ fused_image = np.asarray(fused_image, dtype=np.uint8)
98
+ data = io.BytesIO()
99
+ if channels == 4:
100
+ Image.fromarray(fused_image).save(data, 'png')
101
+ elif channels == 3:
102
+ Image.fromarray(fused_image).save(data, 'jpeg')
103
+ else:
104
+ raise ValueError('Image channel error')
105
+ im_data = data.getvalue()
106
+ image = Image.open(io.BytesIO(im_data))
107
+ return image
108
+
109
+ print('Building generator')
110
+
111
+ checkpoint_path=f'{MODEL_DIR}/stylegan2-ffhq-config-f-1024x1024.pth'
112
+ config = dict(model_type='StyleGAN2Generator',
113
+ resolution=1024,
114
+ w_dim=512,
115
+ fmaps_base=int(1 * (32 << 10)),
116
+ fmaps_max=512,)
117
+ generator = build_model(**config)
118
+ print(f'Loading checkpoint from `{checkpoint_path}` ...')
119
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
120
+ if 'generator_smooth' in checkpoint:
121
+ generator.load_state_dict(checkpoint['generator_smooth'])
122
+ else:
123
+ generator.load_state_dict(checkpoint['generator'])
124
+ generator = generator.eval().cpu()
125
+ print('Finish loading checkpoint.')
126
+
127
+ print('Loading boundaries')
128
+ ATTRS = ['eyebrows', 'eyesize', 'gaze_direction', 'nose_length', 'mouth', 'lipstick']
129
+ boundaries = {}
130
+ for attr in ATTRS:
131
+ boundary_path = os.path.join(f'directions/ffhq/stylegan2/{attr}.npy')
132
+ boundary = np.load(boundary_path)
133
+ boundaries[attr] = boundary
134
+ print('Generator and boundaries are ready.')
135
+
136
+
137
+ def inference(num_of_image,seed,trunc_psi,eyebrows,eyesize,gaze_direction,nose_length,mouth,lipstick):
138
+ print('Sampling latent codes with given seed.')
139
+ num_of_image = num_of_image #@param {type:"slider", min:1, max:8, step:1}
140
+ seed = seed #@param {type:"slider", min:0, max:10000, step:1}
141
+ trunc_psi = trunc_psi #@param {type:"slider", min:0, max:1, step:0.1}
142
+ trunc_layers = 8
143
+ np.random.seed(seed)
144
+ latent_z = np.random.randn(num_of_image, generator.z_dim)
145
+ latent_z = torch.from_numpy(latent_z.astype(np.float32))
146
+ latent_z = latent_z.cpu()
147
+ wp = generator.mapping(latent_z, None)['wp']
148
+ if trunc_psi < 1.0:
149
+ w_avg = generator.w_avg
150
+ w_avg = w_avg.reshape(1, -1, generator.w_dim)[:, :trunc_layers]
151
+ wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi)
152
+ with torch.no_grad():
153
+ images_ori = generator.synthesis(wp)['image']
154
+ images_ori = postprocess_image(to_numpy(images_ori))
155
+ print('Original images are shown as belows.')
156
+ imshow(images_ori, col=images_ori.shape[0])
157
+ latent_wp = to_numpy(wp)
158
+
159
+
160
+
161
+ eyebrows = eyebrows #@param {type:"slider", min:-12.0, max:12.0, step:2}
162
+ eyesize = eyesize #@param {type:"slider", min:-12.0, max:12.0, step:2}
163
+ gaze_direction = gaze_direction #@param {type:"slider", min:-12.0, max:12.0, step:2}
164
+ nose_length = nose_length #@param {type:"slider", min:-12.0, max:12.0, step:2}
165
+ mouth = mouth #@param {type:"slider", min:-12.0, max:12.0, step:2}
166
+ lipstick = lipstick #@param {type:"slider", min:-12.0, max:12.0, step:2}
167
+
168
+ new_codes = latent_wp.copy()
169
+ for attr_name in ATTRS:
170
+ if attr_name in ['eyebrows', 'lipstick']:
171
+ layers_idx = [8,9,10,11]
172
+ else:
173
+ layers_idx = [4,5,6,7]
174
+ step = eval(attr_name)
175
+ direction = boundaries[attr_name]
176
+ direction = np.tile(direction, [1, generator.num_layers, 1])
177
+ new_codes[:, layers_idx, :] += direction[:, layers_idx, :] * step
178
+ new_codes = torch.from_numpy(new_codes.astype(np.float32)).cpu()
179
+ with torch.no_grad():
180
+ images_mani = generator.synthesis(new_codes)['image']
181
+ images_mani = postprocess_image(to_numpy(images_mani))
182
+ return imshow(images_mani, col=images_mani.shape[0])
183
+
184
+ gr.Interface(inference,[gr.Slider(1, 3, value=1,label="num_of_image"),
185
+ gr.Slider(0, 10000, value=210,label="seed"),
186
+ gr.Slider(0, 1, value=0.7,step=0.1,label="truncation psi"),
187
+ gr.Slider(-12, 12, value=0,label="eyebrows"),
188
+ gr.Slider(-12, 12, value=0,label="eyesize"),
189
+ gr.Slider(-12, 12, value=0,label="gaze direction"),
190
+ gr.Slider(-12, 12, value=0,label="nose_length"),
191
+ gr.Slider(-12, 12, value=0,label="mouth"),
192
+ gr.Slider(-12, 12, value=0,label="lipstick"),
193
+ ],gr.Image(type="pil")).launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ scikit-video
3
+ pillow
4
+ opencv-python-headless
5
+ numpy