skytnt committed
Commit
389a403
1 Parent(s): 78d24f8
app.py ADDED
@@ -0,0 +1,325 @@
+ import ast
+
+ import gradio as gr
+ import imageio
+ import numpy as np
+ import onnx
+ import onnxruntime as rt
+ from numpy.random import RandomState
+ from skimage import transform
+
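+ # The app chains four ONNX models: g_mapping / g_synthesis appear to be a
+ # StyleGAN-style generator pair (z, psi -> w -> image), fb_encoder predicts
+ # latent offsets for image inversion, and waifu_dect is a YOLO-style
+ # full-body character detector.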
+
+
+ def get_inter(r1, r2):
+     # intersection area of two xyxy rectangles
+     h_inter = max(min(r1[3], r2[3]) - max(r1[1], r2[1]), 0)
+     w_inter = max(min(r1[2], r2[2]) - max(r1[0], r2[0]), 0)
+     return h_inter * w_inter
+
+
+ def iou(r1, r2):
+     # intersection over union of two xyxy rectangles
+     s1 = (r1[2] - r1[0]) * (r1[3] - r1[1])
+     s2 = (r2[2] - r2[0]) * (r2[3] - r2[1])
+     i = get_inter(r1, r2)
+     return i / (s1 + s2 - i)
+
+
+ def letterbox(im, new_shape=(640, 640), color=(0.5, 0.5, 0.5), stride=32):
+     # Resize and pad image while meeting stride-multiple constraints
+     shape = im.shape[:2]  # current shape [height, width]
+
+     # Scale ratio (new / old)
+     r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+
+     # Compute padding
+     new_unpad = int(round(shape[0] * r)), int(round(shape[1] * r))
+     dw, dh = new_shape[1] - new_unpad[1], new_shape[0] - new_unpad[0]  # wh padding
+     dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
+
+     dw /= 2  # divide padding into 2 sides
+     dh /= 2
+
+     if shape != new_unpad:  # resize
+         im = transform.resize(im, new_unpad)
+     top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+     left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+     im_new = np.full((new_unpad[0] + top + bottom, new_unpad[1] + left + right, 3), color, dtype=np.float32)
+     im_new[top:new_unpad[0] + top, left:new_unpad[1] + left] = im
+     return im_new  # return the padded image, not the unpadded resize
+
+
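+ # letterbox (above) follows the usual YOLO-style convention: scale to fit,
+ # then pad height/width only up to the next multiple of `stride`, splitting
+ # the padding between both sides and filling with `color`.
+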
+ def nms(pred, conf_thres, iou_thres, max_instance=20):  # pred (anchor_num, 5 + cls_num)
+     nc = pred.shape[1] - 5
+     candidates = [list() for _ in range(nc)]
+     for x in pred:
+         if x[4] < conf_thres:
+             continue
+         cls = np.argmax(x[5:])
+         p = x[4] * x[5 + cls]  # objectness * class confidence
+         if conf_thres <= p:
+             box = (x[0] - x[2] / 2, x[1] - x[3] / 2, x[0] + x[2] / 2, x[1] + x[3] / 2)  # xywh2xyxy
+             candidates[cls].append([p, box])
+     result = [list() for _ in range(nc)]
+     for i, candidate in enumerate(candidates):
+         candidate = sorted(candidate, key=lambda a: a[0], reverse=True)
+         candidate = candidate[:max_instance]
+         for x in candidate:
+             ok = True
+             for r in result[i]:
+                 if iou(r[1], x[1]) > iou_thres:
+                     ok = False
+                     break
+             if ok:
+                 result[i].append(x)
+
+     return result
+
+
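+ # nms (above) is a plain greedy per-class NMS: result[cls] holds up to
+ # max_instance entries of [score, (x1, y1, x2, y2)], kept in descending
+ # score order, each overlapping every earlier kept box by at most iou_thres.
+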
+ class Model:
+     def __init__(self):
+         self.img_avg = None
+         self.detector = None
+         self.encoder = None
+         self.g_synthesis = None
+         self.g_mapping = None
+         self.w_avg = None
+         self.detector_stride = None
+         self.detector_imgsz = None
+         self.detector_class_names = None
+         self.load_models("./models/")
+
+     def load_models(self, model_dir):
+         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+         g_mapping = onnx.load(model_dir + "g_mapping.onnx")
+         w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
+         w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
+         w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
+         self.w_avg = w_avg
+         self.g_mapping = rt.InferenceSession(model_dir + "g_mapping.onnx", providers=providers)
+         self.g_synthesis = rt.InferenceSession(model_dir + "g_synthesis.onnx", providers=providers)
+         self.encoder = rt.InferenceSession(model_dir + "fb_encoder.onnx", providers=providers)
+         self.detector = rt.InferenceSession(model_dir + "waifu_dect.onnx", providers=providers)
+         detector_meta = self.detector.get_modelmeta().custom_metadata_map
+         self.detector_stride = int(detector_meta['stride'])
+         self.detector_imgsz = 1088
+         self.detector_class_names = ast.literal_eval(detector_meta['names'])  # safer than eval
+
+         self.img_avg = transform.resize(self.g_synthesis.run(None, {'w': w_avg})[0][0].transpose(1, 2, 0),
+                                         (256, 256)).transpose(2, 0, 1)[np.newaxis, :]
+
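+     # w_avg has shape (1, 16, 512): the mapping network's average latent,
+     # broadcast over the 16 synthesis inputs. img_avg is its 256x256 render
+     # and serves as the encoder's starting reconstruction.
+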
+     def get_img(self, w):
+         img = self.g_synthesis.run(None, {'w': w})[0]
+         return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]
+
+     def get_w(self, z, psi):
+         return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi], dtype=np.float32)})[0]
+
+     def encode_img(self, img, iteration=5):
+         target_img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
+             np.float32)
+         w = self.w_avg.copy()
+         from_img = self.img_avg.copy()
+         for _ in range(iteration):
+             dimg = np.concatenate([target_img, from_img], axis=1)
+             dw = self.encoder.run(None, {'dimg': dimg})[0]
+             w += dw
+             from_img = transform.resize(self.g_synthesis.run(None, {'w': w})[0][0].transpose(1, 2, 0),
+                                         (256, 256)).transpose(2, 0, 1)[np.newaxis, :]
+         return w
+
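+     # encode_img (above) is iterative refinement: each pass feeds the target
+     # and the current reconstruction to the encoder, which predicts a
+     # residual dw that nudges w toward the target image.
+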
+     def detect(self, im0, conf_thres, iou_thres, detail=False):
+         if im0 is None:
+             return []
+         img = letterbox((im0 / 255).astype(np.float32), (self.detector_imgsz, self.detector_imgsz),
+                         stride=self.detector_stride)
+         # Convert HWC to NCHW
+         img = img.transpose(2, 0, 1)
+         img = img[np.newaxis, :]
+         pred = self.detector.run(None, {'images': img})[0][0]
+         dets = nms(pred, conf_thres, iou_thres)
+         imgs = []
+         # Print results
+         s = '%gx%g ' % img.shape[2:]  # print string
+         for i, det in enumerate(dets):
+             n = len(det)
+             s += f"{n} {self.detector_class_names[i]}{'s' * (n > 1)}, "  # add to string
+         if detail:
+             print(s)
+         waifu_rects = []
+         head_rects = []
+         body_rects = []
+
+         for i, det in enumerate(dets):
+             for x in det:
+                 # Rescale boxes from img_size to im0 size
+                 wr = im0.shape[1] / img.shape[3]
+                 hr = im0.shape[0] / img.shape[2]
+                 x[1] = (int(x[1][0] * wr), int(x[1][1] * hr),
+                         int(x[1][2] * wr), int(x[1][3] * hr))
+                 if i == 0:
+                     head_rects.append(x[1])
+                 elif i == 1:
+                     body_rects.append(x[1])
+                 elif i == 2:
+                     waifu_rects.append(x[1])
+         for j, waifu_rect in enumerate(waifu_rects):
+             msg = f'waifu {j + 1} '
+             head_num = 0
+             body_num = 0
+             hr, br = None, None
+             for r in head_rects:
+                 if get_inter(r, waifu_rect) / ((r[2] - r[0]) * (r[3] - r[1])) > 0.75:
+                     hr = r
+                     head_num += 1
+             if head_num != 1:
+                 if detail:
+                     print(msg + f'head num error: {head_num}')
+                 continue
+             for r in body_rects:
+                 if get_inter(r, waifu_rect) / ((r[2] - r[0]) * (r[3] - r[1])) > 0.65:
+                     br = r
+                     body_num += 1
+             if body_num != 1:
+                 if detail:
+                     print(msg + f'body num error: {body_num}')
+                 continue
+             bounds = (int(min(waifu_rect[0], hr[0], br[0])),
+                       int(min(waifu_rect[1], hr[1], br[1])),
+                       int(max(waifu_rect[2], hr[2], br[2])),
+                       int(max(waifu_rect[3], hr[3], br[3])))
+             if (bounds[2] - bounds[0]) / (bounds[3] - bounds[1]) > 0.7:
+                 if detail:
+                     print(msg + "ratio out of limit")
+                 continue
+             # expand the bounds
+             expand_pixel = (bounds[3] - bounds[1]) // 20
+             bounds = (max(bounds[0] - expand_pixel // 2, 0),
+                       max(bounds[1] - expand_pixel, 0),
+                       min(bounds[2] + expand_pixel // 2, im0.shape[1]),
+                       min(bounds[3] + expand_pixel, im0.shape[0]),
+                       )
+             if bounds[3] - bounds[1] >= (bounds[2] - bounds[0]) * 2:  # crop width, keeping the height
+                 cx = (bounds[2] + bounds[0]) // 2
+                 h = bounds[3] - bounds[1]
+                 w = h // 2
+                 w2 = w // 2
+                 l1 = max(cx - w2, 0)
+                 r1 = min(cx + w2, im0.shape[1])
+                 bounds = (l1, bounds[1], r1, bounds[3])
+                 temp_bound = (w2 - (cx - l1), 0, w2 + (r1 - cx), h)
+             else:  # crop height, keeping the width
+                 cy = (bounds[3] + bounds[1]) // 2
+                 w = bounds[2] - bounds[0]
+                 h = w * 2
+                 h2 = h // 2
+                 tp1 = max(cy - h2, 0)
+                 b1 = min(cy + h2, im0.shape[0])
+                 bounds = (bounds[0], tp1, bounds[2], b1)
+                 temp_bound = (0, h2 - (cy - tp1), w, h2 + (b1 - cy))
+             temp_img = np.full((h, w, 3), 255, dtype=np.uint8)
+             temp_img[temp_bound[1]:temp_bound[3], temp_bound[0]:temp_bound[2]] = im0[bounds[1]:bounds[3],
+                                                                                      bounds[0]:bounds[2]]
+             temp_img = transform.resize(temp_img, (1024, 512), preserve_range=True).astype(np.uint8)
+             imgs.append(temp_img)
+         return imgs
+
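+     # detect (above) returns one 1024x512 crop per character whose waifu box
+     # is matched by exactly one head and one body box; each crop is padded
+     # with white to a 2:1 height:width ratio before resizing.
+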
+     def gen_video(self, w1, w2, path, frame_num=10):
+         # fps = frame_num // 2, so the clip always lasts about two seconds
+         video = imageio.get_writer(path, mode='I', fps=frame_num // 2, codec='libx264', bitrate='16M')
+         lin = np.linspace(0, 1, frame_num)
+         for i in range(0, frame_num):
+             # linear interpolation between the two latents in w space
+             img = self.get_img(((1 - lin[i]) * w1) + (lin[i] * w2))
+             video.append_data(img)
+         video.close()
+
+
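+ # Gradio callbacks below; `model` is the module-level instance created in
+ # the __main__ block.
+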
+ def gen_fn(use_seed, seed, psi):
+     # the 2 ** 31 offset keeps the RandomState seed non-negative for negative inputs
+     z = RandomState(int(seed) + 2 ** 31).randn(1, 512) if use_seed else np.random.randn(1, 512)
+     w = model.get_w(z.astype(dtype=np.float32), psi)
+     img_out = model.get_img(w)
+     return img_out, w.tolist(), img_out
+
+
+ def encode_img_fn(img):
+     imgs = model.detect(img, 0.2, 0.03)
+     if len(imgs) == 0:
+         # must return a value for each of the five output components
+         return "failed to detect waifu", None, None, None, None
+     w = model.encode_img(imgs[0])
+     img_out = model.get_img(w)
+     return "success", imgs[0], img_out, w.tolist(), img_out
+
+
+ def gen_video_fn(w1, w2, frame):
+     if w1 is None or w2 is None:
+         return None
+     model.gen_video(np.array(w1, dtype=np.float32), np.array(w2, dtype=np.float32), "video.mp4", int(frame))
+     return "video.mp4"
+
+
+ if __name__ == '__main__':
+     model = Model()
+
+     app = gr.Blocks()
+     with app:
+         gr.Markdown("# full-body anime\n\n"
+                     "the model is not great, just for fun.")
+         with gr.Tabs():
+             with gr.TabItem("generate image"):
+                 with gr.Column():
+                     with gr.Row():
+                         gen_input1 = gr.Checkbox(value=False, label="use seed")
+                         gen_input2 = gr.Number(value=1, label="seed")
+                         gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="psi")
+                     gen_submit = gr.Button("Run")
+                     gen_output1 = gr.Image()
+                     select_img_input_w1 = gr.Json(visible=False)
+                     select_img_input_img1 = gr.Image(visible=False)
+
+             with gr.TabItem("encode image"):
+                 with gr.Column():
+                     encode_img_input = gr.Image()
+                     encode_img_submit = gr.Button("Run")
+                     encode_img_output1 = gr.Textbox(label="message")
+                     with gr.Row():
+                         encode_img_output2 = gr.Image(label="detected")
+                         encode_img_output3 = gr.Image(label="encoded")
+                     select_img_input_w2 = gr.Json(visible=False)
+                     select_img_input_img2 = gr.Image(visible=False)
+
+             with gr.TabItem("generate video"):
+                 with gr.Column():
+                     gr.Markdown("## generate video between 2 images")
+                     with gr.Row():
+                         with gr.Column():
+                             gr.Markdown("please select image 1")
+                             select_img1_dropdown = gr.Dropdown(label="source",
+                                                                choices=["current generated image",
+                                                                         "current encoded image"], type="index")
+                             select_img1_button = gr.Button("select")
+                             select_img1_output_img = gr.Image(label="image 1")
+                             select_img1_output_w = gr.Json(visible=False)
+                         with gr.Column():
+                             gr.Markdown("please select image 2")
+                             select_img2_dropdown = gr.Dropdown(label="source",
+                                                                choices=["current generated image",
+                                                                         "current encoded image"], type="index")
+                             select_img2_button = gr.Button("select")
+                             select_img2_output_img = gr.Image(label="image 2")
+                             select_img2_output_w = gr.Json(visible=False)
+                     generate_video_frame = gr.Slider(minimum=10, maximum=30, step=1, label="frame", value=10)
+                     generate_video_button = gr.Button("generate")
+                     generate_video_output = gr.Video()
+
+         gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3],
+                          [gen_output1, select_img_input_w1, select_img_input_img1])
+         encode_img_submit.click(encode_img_fn, [encode_img_input],
+                                 [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
+                                  select_img_input_img2])
+         select_img1_button.click(lambda i, img1, img2, w1, w2: (img1, w1) if i == 0 else (img2, w2),
+                                  [select_img1_dropdown, select_img_input_img1, select_img_input_img2,
+                                   select_img_input_w1, select_img_input_w2],
+                                  [select_img1_output_img, select_img1_output_w])
+         select_img2_button.click(lambda i, img1, img2, w1, w2: (img1, w1) if i == 0 else (img2, w2),
+                                  [select_img2_dropdown, select_img_input_img1, select_img_input_img2,
+                                   select_img_input_w1, select_img_input_w2],
+                                  [select_img2_output_img, select_img2_output_w])
+         generate_video_button.click(gen_video_fn, [select_img1_output_w, select_img2_output_w, generate_video_frame],
+                                     [generate_video_output])
+
+     app.launch()
models/fb_encoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5be6b51e6f5046c6b402f4b0af62903363891571c72fd576afcf8d64f9e71dd9
+ size 706131142
models/g_mapping.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e8b35dc40e6ba4f08cb269b56617814c8406418a76965ae1f9f8ea41813fc30
+ size 8411020
models/g_synthesis.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bacba02853729b8c9635c83c6194943158c5b992f742123aa1ce6cbd94c138f7
+ size 112794026
models/waifu_dect.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a5de6949912bf94c3307f2b18ebc7b49f309e713b1799d29805ccd882e327d3
+ size 83550422
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ onnx
+ onnxruntime-gpu
+ scikit-image
+ imageio-ffmpeg