skytnt committed on
Commit
e94371d
·
1 Parent(s): 389a403

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +325 -325
app.py CHANGED
@@ -1,325 +1,325 @@
1
- import gradio as gr
2
- import imageio
3
- import numpy as np
4
- import onnx
5
- import onnxruntime as rt
6
- from numpy.random import RandomState
7
- from skimage import transform
8
-
9
-
10
def get_inter(r1, r2):
    """Return the overlap area of two axis-aligned boxes.

    Both boxes are indexed as ``(x1, y1, x2, y2)``. Boxes that do not
    overlap along an axis contribute 0 for that axis, so the result is 0.
    """
    overlap_h = max(min(r1[3], r2[3]) - max(r1[1], r2[1]), 0)
    overlap_w = max(min(r1[2], r2[2]) - max(r1[0], r2[0]), 0)
    return overlap_h * overlap_w
14
-
15
-
16
def iou(r1, r2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0.0 when the union area is not positive (degenerate boxes)
    instead of dividing by zero.
    """
    area1 = (r1[2] - r1[0]) * (r1[3] - r1[1])
    area2 = (r2[2] - r2[0]) * (r2[3] - r2[1])
    inter = get_inter(r1, r2)
    union = area1 + area2 - inter
    if union <= 0:
        return 0.0
    return inter / union
21
-
22
-
23
def letterbox(im, new_shape=(640, 640), color=(0.5, 0.5, 0.5), stride=32):
    """Resize an image to fit ``new_shape`` and pad it with ``color``.

    The image is scaled by one factor so it fits inside ``new_shape``
    (height, width) with its aspect ratio unchanged. The leftover space on
    each axis is reduced modulo ``stride`` and split between the two sides.

    Args:
        im: Image array indexed (height, width, channel); it is copied into
            a 3-channel float32 canvas.
        new_shape: Target (height, width).
        color: Per-channel fill value used for the padding.
        stride: The padding on each axis is the leftover space modulo this.

    Returns:
        A new float32 array of shape (new_h + pad_h, new_w + pad_w, 3)
        holding the resized image surrounded by the padding.
    """
    src_h, src_w = im.shape[:2]  # current shape [height, width]

    # Scale ratio (new / old)
    ratio = min(new_shape[0] / src_h, new_shape[1] / src_w)

    # Size of the resized image before padding
    new_h = int(round(src_h * ratio))
    new_w = int(round(src_w * ratio))

    # Padding per side; only the remainder modulo the stride is kept
    pad_w = np.mod(new_shape[1] - new_w, stride) / 2
    pad_h = np.mod(new_shape[0] - new_h, stride) / 2

    # Compare like with like: both tuples are (height, width)
    if (src_h, src_w) != (new_h, new_w):
        im = transform.resize(im, (new_h, new_w))

    # The -0.1 / +0.1 nudges send an odd pixel of padding to the far side
    top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
    left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))

    im_new = np.full((new_h + top + bottom, new_w + left + right, 3), color, dtype=np.float32)
    im_new[top:new_h + top, left:new_w + left] = im
    return im_new
45
-
46
-
47
def nms(pred, conf_thres, iou_thres, max_instance=20):  # pred (anchor_num, 5 + cls_num)
    """Per-class greedy non-maximum suppression.

    Args:
        pred: 2-D array with one row per anchor laid out as
            ``cx, cy, w, h, objectness, class scores...``.
        conf_thres: Minimum value for both the objectness and the combined
            ``objectness * best class score``.
        iou_thres: A candidate is dropped when its IoU with an already kept
            box of the same class exceeds this value.
        max_instance: Only the highest scoring ``max_instance`` candidates
            of each class are considered for suppression.

    Returns:
        A list with one entry per class; each entry is a list of
        ``[score, (x1, y1, x2, y2)]`` items ordered by descending score.
    """
    nc = pred.shape[1] - 5
    candidates = [[] for _ in range(nc)]

    # Drop low-objectness anchors in one vectorised step before looping
    for row in pred[pred[:, 4] >= conf_thres]:
        cls = int(np.argmax(row[5:]))
        score = row[4] * row[5 + cls]
        if conf_thres <= score:
            cx, cy, w, h = row[0], row[1], row[2], row[3]
            box = (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)  # xywh2xyxy
            candidates[cls].append([score, box])

    result = []
    for per_class in candidates:
        ranked = sorted(per_class, key=lambda c: c[0], reverse=True)[:max_instance]
        kept = []
        for cand in ranked:
            # Keep the candidate unless it overlaps a stronger kept box too much
            if not any(iou(k[1], cand[1]) > iou_thres for k in kept):
                kept.append(cand)
        result.append(kept)

    return result
72
-
73
-
74
class Model:
    """Holds the ONNX sessions of the demo and the operations built on them.

    Four ONNX Runtime sessions are loaded from a model directory: a mapping
    session (``z``, ``psi`` -> ``w``), a synthesis session (``w`` -> image),
    an encoder session (image pair -> ``w`` update) and a detector session.
    """

    def __init__(self):
        self.img_avg = None  # 256x256 rendering of w_avg, batch-first
        self.detector = None
        self.encoder = None
        self.g_synthesis = None
        self.g_mapping = None
        self.w_avg = None  # average latent read from the mapping model file
        self.detector_stride = None
        self.detector_imgsz = None
        self.detector_class_names = None
        self.load_models("./models/")

    def load_models(self, model_dir):
        """Create all inference sessions from the files under ``model_dir``.

        ``model_dir`` is concatenated with the file names directly, so it
        must end with a path separator.
        """
        # Local import so this block does not depend on the file's import list
        import ast

        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        # The average latent is stored as the "w_avg" initializer of the mapping graph
        g_mapping = onnx.load(model_dir + "g_mapping.onnx")
        w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
        w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
        # Repeat it 16 times along a new axis and add a leading batch axis
        w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
        self.w_avg = w_avg

        self.g_mapping = rt.InferenceSession(model_dir + "g_mapping.onnx", providers=providers)
        self.g_synthesis = rt.InferenceSession(model_dir + "g_synthesis.onnx", providers=providers)
        self.encoder = rt.InferenceSession(model_dir + "fb_encoder.onnx", providers=providers)
        self.detector = rt.InferenceSession(model_dir + "waifu_dect.onnx", providers=providers)

        detector_meta = self.detector.get_modelmeta().custom_metadata_map
        self.detector_stride = int(detector_meta['stride'])
        self.detector_imgsz = 1088
        # The names entry is a Python literal stored as text; parse it
        # without executing arbitrary code from the model file.
        self.detector_class_names = ast.literal_eval(detector_meta['names'])

        self.img_avg = self._synth_256(w_avg)

    def _synth_256(self, w):
        """Render ``w`` and return the first image resized to 256x256, batch-first."""
        img = self.g_synthesis.run(None, {'w': w})[0][0].transpose(1, 2, 0)
        return transform.resize(img, (256, 256)).transpose(2, 0, 1)[np.newaxis, :]

    def get_img(self, w):
        """Render ``w`` and return the first image as a uint8 (H, W, C) array."""
        img = self.g_synthesis.run(None, {'w': w})[0]
        # Move channels last, then map the output to the 0..255 range
        return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]

    def get_w(self, z, psi):
        """Run the mapping session on ``z`` with the scalar ``psi``."""
        return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi], dtype=np.float32)})[0]

    def encode_img(self, img, iteration=5):
        """Estimate a latent ``w`` for ``img`` by iterative refinement.

        Starting from the average latent, each iteration feeds the target
        image and the current rendering to the encoder, adds the returned
        update to ``w`` and renders again.

        Args:
            img: Image array indexed (height, width, channel) with values
                on a 0..255 scale.
            iteration: Number of refinement steps.

        Returns:
            The refined latent array (a modified copy of ``self.w_avg``).
        """
        # Scale 0..255 to -1..1, resize to 256x256 and make it batch-first
        target_img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[
            np.newaxis, :].astype(np.float32)
        w = self.w_avg.copy()
        from_img = self.img_avg.copy()
        for _ in range(iteration):
            dimg = np.concatenate([target_img, from_img], axis=1)
            w += self.encoder.run(None, {'dimg': dimg})[0]
            from_img = self._synth_256(w)
        return w

    @staticmethod
    def _match_parts(rects, outer, min_fraction):
        """Count the rects whose area lies mostly inside ``outer``.

        Returns ``(count, last_match)`` where a rect matches when more than
        ``min_fraction`` of its own area overlaps ``outer``.
        """
        count, last = 0, None
        for r in rects:
            area = (r[2] - r[0]) * (r[3] - r[1])
            if area <= 0:
                # Integer truncation can collapse a tiny box; it cannot match
                continue
            if get_inter(r, outer) / area > min_fraction:
                last = r
                count += 1
        return count, last

    @staticmethod
    def _crop_portrait(im0, bounds):
        """Cut a 1:2 (width:height) crop around ``bounds`` out of ``im0``.

        ``bounds`` is ``(x1, y1, x2, y2)``. The box is first grown by 1/20
        of its height (half of that horizontally) and clipped to the image.
        The crop is then widened or heightened to a 1:2 aspect around the
        box centre; parts that fall outside the image stay white.

        Returns:
            A uint8 array of shape (h, w, 3), or None when the crop would
            have no pixels.
        """
        img_h, img_w = im0.shape[:2]

        # Expand the box
        margin = (bounds[3] - bounds[1]) // 20
        x1 = max(bounds[0] - margin // 2, 0)
        y1 = max(bounds[1] - margin, 0)
        x2 = min(bounds[2] + margin // 2, img_w)
        y2 = min(bounds[3] + margin, img_h)

        if y2 - y1 >= (x2 - x1) * 2:  # keep the height, derive the width
            cx = (x2 + x1) // 2
            h = y2 - y1
            w = h // 2
            half = w // 2
            left = max(cx - half, 0)
            right = min(cx + half, img_w)
            src = (left, y1, right, y2)
            dst = (half - (cx - left), 0, half + (right - cx), h)
        else:  # keep the width, derive the height
            cy = (y2 + y1) // 2
            w = x2 - x1
            h = w * 2
            half = h // 2
            top = max(cy - half, 0)
            bottom = min(cy + half, img_h)
            src = (x1, top, x2, bottom)
            dst = (0, half - (cy - top), w, half + (bottom - cy))

        if h <= 0 or w <= 0:
            return None
        canvas = np.full((h, w, 3), 255, dtype=np.uint8)
        canvas[dst[1]:dst[3], dst[0]:dst[2]] = im0[src[1]:src[3], src[0]:src[2]]
        return canvas

    def detect(self, im0, conf_thres, iou_thres, detail=False):
        """Detect figures in ``im0`` and return 1024x512 crops of them.

        Detector class 0 is treated as head, 1 as body and 2 as the whole
        figure. A figure is kept only when exactly one head and exactly one
        body box lie mostly inside it and the merged box has a width/height
        ratio of at most 0.7.

        Args:
            im0: Image array indexed (height, width, channel) on a 0..255
                scale, or None.
            conf_thres: Confidence threshold passed to ``nms``.
            iou_thres: IoU threshold passed to ``nms``.
            detail: When true, print a detection summary and the reason for
                every rejected figure.

        Returns:
            A list of uint8 arrays of shape (1024, 512, 3); empty when
            ``im0`` is None or nothing qualifies.
        """
        if im0 is None:
            return []
        src_h, src_w = im0.shape[:2]
        img = letterbox((im0 / 255).astype(np.float32), (self.detector_imgsz, self.detector_imgsz),
                        stride=self.detector_stride)
        # Convert HWC -> 1CHW
        img = img.transpose(2, 0, 1)[np.newaxis, :]
        pred = self.detector.run(None, {'images': img})[0][0]
        dets = nms(pred, conf_thres, iou_thres)

        if detail:
            s = '%gx%g ' % img.shape[2:]  # print string
            for i, det in enumerate(dets):
                n = len(det)
                s += f"{n} {self.detector_class_names[i]}{'s' * (n > 1)}, "  # add to string
            print(s)

        # Undo the letterbox: boxes are in network-input pixels, where the
        # resized image sits at (pad_left, pad_top). Recompute the resized
        # size the same way letterbox does and strip the padding first.
        net_h, net_w = img.shape[2:]
        ratio = min(self.detector_imgsz / src_h, self.detector_imgsz / src_w)
        fit_w = max(int(round(src_w * ratio)), 1)
        fit_h = max(int(round(src_h * ratio)), 1)
        pad_left = int(round((net_w - fit_w) / 2 - 0.1))
        pad_top = int(round((net_h - fit_h) / 2 - 0.1))
        scale_x = src_w / fit_w
        scale_y = src_h / fit_h

        head_rects, body_rects, waifu_rects = [], [], []
        rects_by_class = {0: head_rects, 1: body_rects, 2: waifu_rects}
        for i, det in enumerate(dets):
            for x in det:
                bx1, by1, bx2, by2 = x[1]
                rect = (int((bx1 - pad_left) * scale_x), int((by1 - pad_top) * scale_y),
                        int((bx2 - pad_left) * scale_x), int((by2 - pad_top) * scale_y))
                if i in rects_by_class:
                    rects_by_class[i].append(rect)

        imgs = []
        for j, waifu_rect in enumerate(waifu_rects):
            msg = f'waifu {j + 1} '
            head_num, hr = self._match_parts(head_rects, waifu_rect, 0.75)
            if head_num != 1:
                if detail:
                    print(msg + f'head num error: {head_num}')
                continue
            body_num, br = self._match_parts(body_rects, waifu_rect, 0.65)
            if body_num != 1:
                if detail:
                    print(msg + f'body num error: {body_num}')
                continue

            # Smallest box covering the figure, its head and its body
            bounds = (int(min(waifu_rect[0], hr[0], br[0])),
                      int(min(waifu_rect[1], hr[1], br[1])),
                      int(max(waifu_rect[2], hr[2], br[2])),
                      int(max(waifu_rect[3], hr[3], br[3])))
            box_h = bounds[3] - bounds[1]
            if box_h <= 0 or (bounds[2] - bounds[0]) / box_h > 0.7:
                if detail:
                    print(msg + "ratio out of limit")
                continue

            crop = self._crop_portrait(im0, bounds)
            if crop is None:
                continue
            crop = transform.resize(crop, (1024, 512), preserve_range=True).astype(np.uint8)
            imgs.append(crop)
        return imgs

    def gen_video(self, w1, w2, path, frame_num=10):
        """Write a video that blends linearly from ``w1`` to ``w2``.

        ``frame_num`` frames are rendered at ``frame_num // 2`` frames per
        second and written to ``path``.
        """
        video = imageio.get_writer(path, mode='I', fps=frame_num // 2, codec='libx264', bitrate='16M')
        try:
            for t in np.linspace(0, 1, frame_num):
                # Keep the blend in float32 regardless of NumPy's scalar promotion rules
                w = (((1 - t) * w1) + (t * w2)).astype(np.float32)
                video.append_data(self.get_img(w))
        finally:
            # Close the writer even if a frame fails so the file is released
            video.close()
230
-
231
-
232
def gen_fn(use_seed, seed, psi):
    """Generate one image, optionally from a fixed seed.

    Args:
        use_seed: When true, ``seed`` selects the noise vector; otherwise
            NumPy's global generator is used.
        seed: Number converted with ``int``; it is offset by 2**31 and
            wrapped into the 0..2**32-1 range that ``RandomState`` accepts.
        psi: Value forwarded to the mapping model.

    Returns:
        ``(image, w as nested lists, image)`` for the three wired outputs.
    """
    if use_seed:
        # Wrap so very large or very negative seeds do not raise ValueError
        wrapped_seed = (int(seed) + 2 ** 31) % (2 ** 32)
        z = RandomState(wrapped_seed).randn(1, 512)
    else:
        z = np.random.randn(1, 512)
    w = model.get_w(z.astype(dtype=np.float32), psi)
    img_out = model.get_img(w)
    return img_out, w.tolist(), img_out
237
-
238
-
239
def encode_img_fn(img):
    """Detect a figure in ``img``, encode it and render the result.

    Returns:
        ``(message, detected crop, rendered image, w as nested lists,
        rendered image)``. When nothing is detected the message reports the
        failure and the other four values are None, so the tuple always
        matches the five wired outputs.
    """
    imgs = model.detect(img, 0.2, 0.03)
    if not imgs:
        return "failed to detect waifu", None, None, None, None
    detected = imgs[0]
    w = model.encode_img(detected)
    img_out = model.get_img(w)
    return "success", detected, img_out, w.tolist(), img_out
246
-
247
-
248
def gen_video_fn(w1, w2, frame):
    """Render a blend video between two latents and return its path.

    ``w1`` and ``w2`` are nested lists as stored by the other callbacks;
    when either is None no video is produced and None is returned.
    """
    if w1 is None or w2 is None:
        return None
    out_path = "video.mp4"
    start = np.array(w1, dtype=np.float32)
    end = np.array(w2, dtype=np.float32)
    model.gen_video(start, end, out_path, int(frame))
    return out_path
253
-
254
-
255
if __name__ == '__main__':
    # Module-level instance used by the callback functions
    model = Model()


    def select_source(i, img1, img2, w1, w2):
        """Return the (image, w) pair of the generated (index 0) or encoded source."""
        return (img1, w1) if i == 0 else (img2, w2)


    source_choices = ["current generated image", "current encoded image"]

    app = gr.Blocks()
    with app:
        gr.Markdown("# full-body anime\n\n"
                    "the model is not good, just for fun.")
        with gr.Tabs():
            with gr.TabItem("generate image"):
                with gr.Column():
                    with gr.Row():
                        gen_input1 = gr.Checkbox(value=False, label="use seed")
                        gen_input2 = gr.Number(value=1, label="seed")
                    gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="psi")
                    gen_submit = gr.Button("Run")
                    gen_output1 = gr.Image()
                    # Hidden holders for the latest generated latent and image
                    select_img_input_w1 = gr.Json(visible=False)
                    select_img_input_img1 = gr.Image(visible=False)

            with gr.TabItem("encode image"):
                with gr.Column():
                    encode_img_input = gr.Image()
                    encode_img_submit = gr.Button("Run")
                    encode_img_output1 = gr.Textbox(label="message")
                    with gr.Row():
                        encode_img_output2 = gr.Image(label="detected")
                        encode_img_output3 = gr.Image(label="encoded")
                    # Hidden holders for the latest encoded latent and image
                    select_img_input_w2 = gr.Json(visible=False)
                    select_img_input_img2 = gr.Image(visible=False)

            with gr.TabItem("generate video"):
                with gr.Column():
                    gr.Markdown("## generate video between 2 images")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("please select image 1")
                            select_img1_dropdown = gr.Dropdown(label="source",
                                                               choices=source_choices, type="index")
                            select_img1_button = gr.Button("select")
                            select_img1_output_img = gr.Image(label="image 1")
                            select_img1_output_w = gr.Json(visible=False)
                        with gr.Column():
                            gr.Markdown("please select image 2")
                            select_img2_dropdown = gr.Dropdown(label="source",
                                                               choices=source_choices, type="index")
                            select_img2_button = gr.Button("select")
                            select_img2_output_img = gr.Image(label="image 2")
                            select_img2_output_w = gr.Json(visible=False)
                    generate_video_frame = gr.Slider(minimum=10, maximum=30, step=1, label="frame", value=10)
                    generate_video_button = gr.Button("generate")
                    generate_video_output = gr.Video()

        # Wire the callbacks
        gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3],
                         [gen_output1, select_img_input_w1, select_img_input_img1])
        encode_img_submit.click(encode_img_fn, [encode_img_input],
                                [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
                                 select_img_input_img2])
        select_img1_button.click(select_source,
                                 [select_img1_dropdown, select_img_input_img1, select_img_input_img2,
                                  select_img_input_w1, select_img_input_w2],
                                 [select_img1_output_img, select_img1_output_w])
        select_img2_button.click(select_source,
                                 [select_img2_dropdown, select_img_input_img1, select_img_input_img2,
                                  select_img_input_w1, select_img_input_w2],
                                 [select_img2_output_img, select_img2_output_w])
        generate_video_button.click(gen_video_fn, [select_img1_output_w, select_img2_output_w, generate_video_frame],
                                    [generate_video_output])

    app.launch()
 
1
+ import gradio as gr
2
+ import imageio
3
+ import numpy as np
4
+ import onnx
5
+ import onnxruntime as rt
6
+ from numpy.random import RandomState
7
+ from skimage import transform
8
+
9
+
10
def get_inter(r1, r2):
    """Return the overlap area of two axis-aligned boxes.

    Both boxes are indexed as ``(x1, y1, x2, y2)``. Boxes that do not
    overlap along an axis contribute 0 for that axis, so the result is 0.
    """
    overlap_h = max(min(r1[3], r2[3]) - max(r1[1], r2[1]), 0)
    overlap_w = max(min(r1[2], r2[2]) - max(r1[0], r2[0]), 0)
    return overlap_h * overlap_w
14
+
15
+
16
def iou(r1, r2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0.0 when the union area is not positive (degenerate boxes)
    instead of dividing by zero.
    """
    area1 = (r1[2] - r1[0]) * (r1[3] - r1[1])
    area2 = (r2[2] - r2[0]) * (r2[3] - r2[1])
    inter = get_inter(r1, r2)
    union = area1 + area2 - inter
    if union <= 0:
        return 0.0
    return inter / union
21
+
22
+
23
def letterbox(im, new_shape=(640, 640), color=(0.5, 0.5, 0.5), stride=32):
    """Resize an image to fit ``new_shape`` and pad it with ``color``.

    The image is scaled by one factor so it fits inside ``new_shape``
    (height, width) with its aspect ratio unchanged. The leftover space on
    each axis is reduced modulo ``stride`` and split between the two sides.

    Args:
        im: Image array indexed (height, width, channel); it is copied into
            a 3-channel float32 canvas.
        new_shape: Target (height, width).
        color: Per-channel fill value used for the padding.
        stride: The padding on each axis is the leftover space modulo this.

    Returns:
        A new float32 array of shape (new_h + pad_h, new_w + pad_w, 3)
        holding the resized image surrounded by the padding.
    """
    src_h, src_w = im.shape[:2]  # current shape [height, width]

    # Scale ratio (new / old)
    ratio = min(new_shape[0] / src_h, new_shape[1] / src_w)

    # Size of the resized image before padding
    new_h = int(round(src_h * ratio))
    new_w = int(round(src_w * ratio))

    # Padding per side; only the remainder modulo the stride is kept
    pad_w = np.mod(new_shape[1] - new_w, stride) / 2
    pad_h = np.mod(new_shape[0] - new_h, stride) / 2

    # Compare like with like: both tuples are (height, width)
    if (src_h, src_w) != (new_h, new_w):
        im = transform.resize(im, (new_h, new_w))

    # The -0.1 / +0.1 nudges send an odd pixel of padding to the far side
    top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
    left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))

    im_new = np.full((new_h + top + bottom, new_w + left + right, 3), color, dtype=np.float32)
    im_new[top:new_h + top, left:new_w + left] = im
    return im_new
46
+
47
+
48
def nms(pred, conf_thres, iou_thres, max_instance=20):  # pred (anchor_num, 5 + cls_num)
    """Per-class greedy non-maximum suppression.

    Args:
        pred: 2-D array with one row per anchor laid out as
            ``cx, cy, w, h, objectness, class scores...``.
        conf_thres: Minimum value for both the objectness and the combined
            ``objectness * best class score``.
        iou_thres: A candidate is dropped when its IoU with an already kept
            box of the same class exceeds this value.
        max_instance: Only the highest scoring ``max_instance`` candidates
            of each class are considered for suppression.

    Returns:
        A list with one entry per class; each entry is a list of
        ``[score, (x1, y1, x2, y2)]`` items ordered by descending score.
    """
    nc = pred.shape[1] - 5
    candidates = [[] for _ in range(nc)]

    # Drop low-objectness anchors in one vectorised step before looping
    for row in pred[pred[:, 4] >= conf_thres]:
        cls = int(np.argmax(row[5:]))
        score = row[4] * row[5 + cls]
        if conf_thres <= score:
            cx, cy, w, h = row[0], row[1], row[2], row[3]
            box = (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)  # xywh2xyxy
            candidates[cls].append([score, box])

    result = []
    for per_class in candidates:
        ranked = sorted(per_class, key=lambda c: c[0], reverse=True)[:max_instance]
        kept = []
        for cand in ranked:
            # Keep the candidate unless it overlaps a stronger kept box too much
            if not any(iou(k[1], cand[1]) > iou_thres for k in kept):
                kept.append(cand)
        result.append(kept)

    return result
73
+
74
+
75
class Model:
    """Holds the ONNX sessions of the demo and the operations built on them.

    Four ONNX Runtime sessions are loaded from a model directory: a mapping
    session (``z``, ``psi`` -> ``w``), a synthesis session (``w`` -> image),
    an encoder session (image pair -> ``w`` update) and a detector session.
    """

    def __init__(self):
        self.img_avg = None  # 256x256 rendering of w_avg, batch-first
        self.detector = None
        self.encoder = None
        self.g_synthesis = None
        self.g_mapping = None
        self.w_avg = None  # average latent read from the mapping model file
        self.detector_stride = None
        self.detector_imgsz = None
        self.detector_class_names = None
        self.load_models("./models/")

    def load_models(self, model_dir):
        """Create all inference sessions from the files under ``model_dir``.

        ``model_dir`` is concatenated with the file names directly, so it
        must end with a path separator.
        """
        # Local import so this block does not depend on the file's import list
        import ast

        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        # The average latent is stored as the "w_avg" initializer of the mapping graph
        g_mapping = onnx.load(model_dir + "g_mapping.onnx")
        w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
        w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
        # Repeat it 16 times along a new axis and add a leading batch axis
        w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
        self.w_avg = w_avg

        self.g_mapping = rt.InferenceSession(model_dir + "g_mapping.onnx", providers=providers)
        self.g_synthesis = rt.InferenceSession(model_dir + "g_synthesis.onnx", providers=providers)
        self.encoder = rt.InferenceSession(model_dir + "fb_encoder.onnx", providers=providers)
        self.detector = rt.InferenceSession(model_dir + "waifu_dect.onnx", providers=providers)

        detector_meta = self.detector.get_modelmeta().custom_metadata_map
        self.detector_stride = int(detector_meta['stride'])
        self.detector_imgsz = 1088
        # The names entry is a Python literal stored as text; parse it
        # without executing arbitrary code from the model file.
        self.detector_class_names = ast.literal_eval(detector_meta['names'])

        self.img_avg = self._synth_256(w_avg)

    def _synth_256(self, w):
        """Render ``w`` and return the first image resized to 256x256, batch-first."""
        img = self.g_synthesis.run(None, {'w': w})[0][0].transpose(1, 2, 0)
        return transform.resize(img, (256, 256)).transpose(2, 0, 1)[np.newaxis, :]

    def get_img(self, w):
        """Render ``w`` and return the first image as a uint8 (H, W, C) array."""
        img = self.g_synthesis.run(None, {'w': w})[0]
        # Move channels last, then map the output to the 0..255 range
        return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]

    def get_w(self, z, psi):
        """Run the mapping session on ``z`` with the scalar ``psi``."""
        return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi], dtype=np.float32)})[0]

    def encode_img(self, img, iteration=5):
        """Estimate a latent ``w`` for ``img`` by iterative refinement.

        Starting from the average latent, each iteration feeds the target
        image and the current rendering to the encoder, adds the returned
        update to ``w`` and renders again.

        Args:
            img: Image array indexed (height, width, channel) with values
                on a 0..255 scale.
            iteration: Number of refinement steps.

        Returns:
            The refined latent array (a modified copy of ``self.w_avg``).
        """
        # Scale 0..255 to -1..1, resize to 256x256 and make it batch-first
        target_img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[
            np.newaxis, :].astype(np.float32)
        w = self.w_avg.copy()
        from_img = self.img_avg.copy()
        for _ in range(iteration):
            dimg = np.concatenate([target_img, from_img], axis=1)
            w += self.encoder.run(None, {'dimg': dimg})[0]
            from_img = self._synth_256(w)
        return w

    @staticmethod
    def _match_parts(rects, outer, min_fraction):
        """Count the rects whose area lies mostly inside ``outer``.

        Returns ``(count, last_match)`` where a rect matches when more than
        ``min_fraction`` of its own area overlaps ``outer``.
        """
        count, last = 0, None
        for r in rects:
            area = (r[2] - r[0]) * (r[3] - r[1])
            if area <= 0:
                # Integer truncation can collapse a tiny box; it cannot match
                continue
            if get_inter(r, outer) / area > min_fraction:
                last = r
                count += 1
        return count, last

    @staticmethod
    def _crop_portrait(im0, bounds):
        """Cut a 1:2 (width:height) crop around ``bounds`` out of ``im0``.

        ``bounds`` is ``(x1, y1, x2, y2)``. The box is first grown by 1/20
        of its height (half of that horizontally) and clipped to the image.
        The crop is then widened or heightened to a 1:2 aspect around the
        box centre; parts that fall outside the image stay white.

        Returns:
            A uint8 array of shape (h, w, 3), or None when the crop would
            have no pixels.
        """
        img_h, img_w = im0.shape[:2]

        # Expand the box
        margin = (bounds[3] - bounds[1]) // 20
        x1 = max(bounds[0] - margin // 2, 0)
        y1 = max(bounds[1] - margin, 0)
        x2 = min(bounds[2] + margin // 2, img_w)
        y2 = min(bounds[3] + margin, img_h)

        if y2 - y1 >= (x2 - x1) * 2:  # keep the height, derive the width
            cx = (x2 + x1) // 2
            h = y2 - y1
            w = h // 2
            half = w // 2
            left = max(cx - half, 0)
            right = min(cx + half, img_w)
            src = (left, y1, right, y2)
            dst = (half - (cx - left), 0, half + (right - cx), h)
        else:  # keep the width, derive the height
            cy = (y2 + y1) // 2
            w = x2 - x1
            h = w * 2
            half = h // 2
            top = max(cy - half, 0)
            bottom = min(cy + half, img_h)
            src = (x1, top, x2, bottom)
            dst = (0, half - (cy - top), w, half + (bottom - cy))

        if h <= 0 or w <= 0:
            return None
        canvas = np.full((h, w, 3), 255, dtype=np.uint8)
        canvas[dst[1]:dst[3], dst[0]:dst[2]] = im0[src[1]:src[3], src[0]:src[2]]
        return canvas

    def detect(self, im0, conf_thres, iou_thres, detail=False):
        """Detect figures in ``im0`` and return 1024x512 crops of them.

        Detector class 0 is treated as head, 1 as body and 2 as the whole
        figure. A figure is kept only when exactly one head and exactly one
        body box lie mostly inside it and the merged box has a width/height
        ratio of at most 0.7.

        Args:
            im0: Image array indexed (height, width, channel) on a 0..255
                scale, or None.
            conf_thres: Confidence threshold passed to ``nms``.
            iou_thres: IoU threshold passed to ``nms``.
            detail: When true, print a detection summary and the reason for
                every rejected figure.

        Returns:
            A list of uint8 arrays of shape (1024, 512, 3); empty when
            ``im0`` is None or nothing qualifies.
        """
        if im0 is None:
            return []
        src_h, src_w = im0.shape[:2]
        img = letterbox((im0 / 255).astype(np.float32), (self.detector_imgsz, self.detector_imgsz),
                        stride=self.detector_stride)
        # Convert HWC -> 1CHW
        img = img.transpose(2, 0, 1)[np.newaxis, :]
        pred = self.detector.run(None, {'images': img})[0][0]
        dets = nms(pred, conf_thres, iou_thres)

        if detail:
            s = '%gx%g ' % img.shape[2:]  # print string
            for i, det in enumerate(dets):
                n = len(det)
                s += f"{n} {self.detector_class_names[i]}{'s' * (n > 1)}, "  # add to string
            print(s)

        # Undo the letterbox: boxes are in network-input pixels, where the
        # resized image sits at (pad_left, pad_top). Recompute the resized
        # size the same way letterbox does and strip the padding first.
        net_h, net_w = img.shape[2:]
        ratio = min(self.detector_imgsz / src_h, self.detector_imgsz / src_w)
        fit_w = max(int(round(src_w * ratio)), 1)
        fit_h = max(int(round(src_h * ratio)), 1)
        pad_left = int(round((net_w - fit_w) / 2 - 0.1))
        pad_top = int(round((net_h - fit_h) / 2 - 0.1))
        scale_x = src_w / fit_w
        scale_y = src_h / fit_h

        head_rects, body_rects, waifu_rects = [], [], []
        rects_by_class = {0: head_rects, 1: body_rects, 2: waifu_rects}
        for i, det in enumerate(dets):
            for x in det:
                bx1, by1, bx2, by2 = x[1]
                rect = (int((bx1 - pad_left) * scale_x), int((by1 - pad_top) * scale_y),
                        int((bx2 - pad_left) * scale_x), int((by2 - pad_top) * scale_y))
                if i in rects_by_class:
                    rects_by_class[i].append(rect)

        imgs = []
        for j, waifu_rect in enumerate(waifu_rects):
            msg = f'waifu {j + 1} '
            head_num, hr = self._match_parts(head_rects, waifu_rect, 0.75)
            if head_num != 1:
                if detail:
                    print(msg + f'head num error: {head_num}')
                continue
            body_num, br = self._match_parts(body_rects, waifu_rect, 0.65)
            if body_num != 1:
                if detail:
                    print(msg + f'body num error: {body_num}')
                continue

            # Smallest box covering the figure, its head and its body
            bounds = (int(min(waifu_rect[0], hr[0], br[0])),
                      int(min(waifu_rect[1], hr[1], br[1])),
                      int(max(waifu_rect[2], hr[2], br[2])),
                      int(max(waifu_rect[3], hr[3], br[3])))
            box_h = bounds[3] - bounds[1]
            if box_h <= 0 or (bounds[2] - bounds[0]) / box_h > 0.7:
                if detail:
                    print(msg + "ratio out of limit")
                continue

            crop = self._crop_portrait(im0, bounds)
            if crop is None:
                continue
            crop = transform.resize(crop, (1024, 512), preserve_range=True).astype(np.uint8)
            imgs.append(crop)
        return imgs

    def gen_video(self, w1, w2, path, frame_num=10):
        """Write a video that blends linearly from ``w1`` to ``w2``.

        ``frame_num`` frames are rendered at ``frame_num // 2`` frames per
        second and written to ``path``.
        """
        video = imageio.get_writer(path, mode='I', fps=frame_num // 2, codec='libx264', bitrate='16M')
        try:
            for t in np.linspace(0, 1, frame_num):
                # Keep the blend in float32 regardless of NumPy's scalar promotion rules
                w = (((1 - t) * w1) + (t * w2)).astype(np.float32)
                video.append_data(self.get_img(w))
        finally:
            # Close the writer even if a frame fails so the file is released
            video.close()
230
+
231
+
232
def gen_fn(use_seed, seed, psi):
    """Generate one image, optionally from a fixed seed.

    Args:
        use_seed: When true, ``seed`` selects the noise vector; otherwise
            NumPy's global generator is used.
        seed: Number converted with ``int``; it is offset by 2**31 and
            wrapped into the 0..2**32-1 range that ``RandomState`` accepts.
        psi: Value forwarded to the mapping model.

    Returns:
        ``(image, w as nested lists, image)`` for the three wired outputs.
    """
    if use_seed:
        # Wrap so very large or very negative seeds do not raise ValueError
        wrapped_seed = (int(seed) + 2 ** 31) % (2 ** 32)
        z = RandomState(wrapped_seed).randn(1, 512)
    else:
        z = np.random.randn(1, 512)
    w = model.get_w(z.astype(dtype=np.float32), psi)
    img_out = model.get_img(w)
    return img_out, w.tolist(), img_out
237
+
238
+
239
def encode_img_fn(img):
    """Detect a figure in ``img``, encode it and render the result.

    Returns:
        ``(message, detected crop, rendered image, w as nested lists,
        rendered image)``. When nothing is detected the message reports the
        failure and the other four values are None, so the tuple always
        matches the five wired outputs.
    """
    imgs = model.detect(img, 0.2, 0.03)
    if not imgs:
        return "failed to detect waifu", None, None, None, None
    detected = imgs[0]
    w = model.encode_img(detected)
    img_out = model.get_img(w)
    return "success", detected, img_out, w.tolist(), img_out
246
+
247
+
248
def gen_video_fn(w1, w2, frame):
    """Render a blend video between two latents and return its path.

    ``w1`` and ``w2`` are nested lists as stored by the other callbacks;
    when either is None no video is produced and None is returned.
    """
    if w1 is None or w2 is None:
        return None
    out_path = "video.mp4"
    start = np.array(w1, dtype=np.float32)
    end = np.array(w2, dtype=np.float32)
    model.gen_video(start, end, out_path, int(frame))
    return out_path
253
+
254
+
255
if __name__ == '__main__':
    # Module-level instance used by the callback functions
    model = Model()


    def select_source(i, img1, img2, w1, w2):
        """Return the (image, w) pair of the generated (index 0) or encoded source."""
        return (img1, w1) if i == 0 else (img2, w2)


    source_choices = ["current generated image", "current encoded image"]

    app = gr.Blocks()
    with app:
        gr.Markdown("# full-body anime\n\n"
                    "the model is not good, just for fun.")
        with gr.Tabs():
            with gr.TabItem("generate image"):
                with gr.Column():
                    with gr.Row():
                        gen_input1 = gr.Checkbox(value=False, label="use seed")
                        gen_input2 = gr.Number(value=1, label="seed")
                    gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="psi")
                    gen_submit = gr.Button("Run")
                    gen_output1 = gr.Image()
                    # Hidden holders for the latest generated latent and image
                    select_img_input_w1 = gr.Json(visible=False)
                    select_img_input_img1 = gr.Image(visible=False)

            with gr.TabItem("encode image"):
                with gr.Column():
                    encode_img_input = gr.Image()
                    encode_img_submit = gr.Button("Run")
                    encode_img_output1 = gr.Textbox(label="message")
                    with gr.Row():
                        encode_img_output2 = gr.Image(label="detected")
                        encode_img_output3 = gr.Image(label="encoded")
                    # Hidden holders for the latest encoded latent and image
                    select_img_input_w2 = gr.Json(visible=False)
                    select_img_input_img2 = gr.Image(visible=False)

            with gr.TabItem("generate video"):
                with gr.Column():
                    gr.Markdown("## generate video between 2 images")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("please select image 1")
                            select_img1_dropdown = gr.Dropdown(label="source",
                                                               choices=source_choices, type="index")
                            select_img1_button = gr.Button("select")
                            select_img1_output_img = gr.Image(label="image 1")
                            select_img1_output_w = gr.Json(visible=False)
                        with gr.Column():
                            gr.Markdown("please select image 2")
                            select_img2_dropdown = gr.Dropdown(label="source",
                                                               choices=source_choices, type="index")
                            select_img2_button = gr.Button("select")
                            select_img2_output_img = gr.Image(label="image 2")
                            select_img2_output_w = gr.Json(visible=False)
                    generate_video_frame = gr.Slider(minimum=10, maximum=30, step=1, label="frame", value=10)
                    generate_video_button = gr.Button("generate")
                    generate_video_output = gr.Video()

        # Wire the callbacks
        gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3],
                         [gen_output1, select_img_input_w1, select_img_input_img1])
        encode_img_submit.click(encode_img_fn, [encode_img_input],
                                [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
                                 select_img_input_img2])
        select_img1_button.click(select_source,
                                 [select_img1_dropdown, select_img_input_img1, select_img_input_img2,
                                  select_img_input_w1, select_img_input_w2],
                                 [select_img1_output_img, select_img1_output_w])
        select_img2_button.click(select_source,
                                 [select_img2_dropdown, select_img_input_img1, select_img_input_img2,
                                  select_img_input_w1, select_img_input_w2],
                                 [select_img2_output_img, select_img2_output_w])
        generate_video_button.click(gen_video_fn, [select_img1_output_w, select_img2_output_w, generate_video_frame],
                                    [generate_video_output])

    app.launch()