Harisreedhar commited on
Commit
0b756df
1 Parent(s): 7cb2f8d
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ *.pyc
app.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import time
5
+ import torch
6
+ import shutil
7
+ import gfpgan
8
+ import argparse
9
+ import platform
10
+ import datetime
11
+ import subprocess
12
+ import insightface
13
+ import onnxruntime
14
+ import numpy as np
15
+ import gradio as gr
16
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
17
+
18
+ from face_analyser import detect_conditions, analyse_face
19
+ from utils import trim_video, StreamerThread, ProcessBar, open_directory
20
+ from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list
21
+ from swapper import (
22
+ swap_face,
23
+ swap_face_with_condition,
24
+ swap_specific,
25
+ swap_options_list,
26
+ )
27
+
28
+ ## ------------------------------ USER ARGS ------------------------------
29
+
30
+ parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
31
+ parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
32
+ parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
33
+ parser.add_argument(
34
+ "--colab", action="store_true", help="Enable colab mode", default=False
35
+ )
36
+ user_args = parser.parse_args()
37
+
38
+ ## ------------------------------ DEFAULTS ------------------------------
39
+
40
+ USE_COLAB = user_args.colab
41
+ USE_CUDA = user_args.cuda
42
+ DEF_OUTPUT_PATH = user_args.out_dir
43
+ WORKSPACE = None
44
+ OUTPUT_FILE = None
45
+ CURRENT_FRAME = None
46
+ STREAMER = None
47
+ DETECT_CONDITION = "left most"
48
+ DETECT_SIZE = 640
49
+ DETECT_THRESH = 0.6
50
+ NUM_OF_SRC_SPECIFIC = 10
51
+ MASK_INCLUDE = [
52
+ "Skin",
53
+ "R-Eyebrow",
54
+ "L-Eyebrow",
55
+ "L-Eye",
56
+ "R-Eye",
57
+ "Nose",
58
+ "Mouth",
59
+ "L-Lip",
60
+ "U-Lip"
61
+ ]
62
+ MASK_EXCLUDE = ["R-Ear", "L-Ear", "Hair", "Hat"]
63
+ MASK_BLUR = 25
64
+
65
+ FACE_SWAPPER = None
66
+ FACE_ANALYSER = None
67
+ FACE_ENHANCER = None
68
+ FACE_PARSER = None
69
+
70
+ ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
71
+ # Note: For AMD,MAC or non CUDA users, change settings here
72
+
73
+ PROVIDER = ["CPUExecutionProvider"]
74
+
75
+ if USE_CUDA:
76
+ available_providers = onnxruntime.get_available_providers()
77
+ if "CUDAExecutionProvider" in available_providers:
78
+ print("\n********** Running on CUDA **********\n")
79
+ PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
80
+ else:
81
+ USE_CUDA = False
82
+ print("\n********** CUDA unavailable running on CPU **********\n")
83
+ else:
84
+ USE_CUDA = False
85
+ print("\n********** Running on CPU **********\n")
86
+
87
+
88
+ ## ------------------------------ LOAD MODELS ------------------------------
89
+
90
+ def load_face_analyser_model(name="buffalo_l"):
91
+ global FACE_ANALYSER
92
+ if FACE_ANALYSER is None:
93
+ FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
94
+ FACE_ANALYSER.prepare(
95
+ ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
96
+ )
97
+
98
+
99
+ def load_face_swapper_model(name="./assets/pretrained_models/inswapper_128.onnx"):
100
+ global FACE_SWAPPER
101
+ path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
102
+ if FACE_SWAPPER is None:
103
+ FACE_SWAPPER = insightface.model_zoo.get_model(path, providers=PROVIDER)
104
+
105
+
106
+ def load_face_enhancer_model(name="./assets/pretrained_models/GFPGANv1.4.pth"):
107
+ global FACE_ENHANCER
108
+ path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
109
+ if FACE_ENHANCER is None:
110
+ FACE_ENHANCER = gfpgan.GFPGANer(model_path=path, upscale=1)
111
+
112
+
113
+ def load_face_parser_model(name="./assets/pretrained_models/79999_iter.pth"):
114
+ global FACE_PARSER
115
+ path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
116
+ if FACE_PARSER is None:
117
+ FACE_PARSER = init_parser(name, use_cuda=USE_CUDA)
118
+
119
+
120
+ load_face_analyser_model()
121
+ load_face_swapper_model()
122
+
123
+ ## ------------------------------ MAIN PROCESS ------------------------------
124
+
125
+
126
+ def process(
127
+ input_type,
128
+ image_path,
129
+ video_path,
130
+ directory_path,
131
+ source_path,
132
+ output_path,
133
+ output_name,
134
+ keep_output_sequence,
135
+ condition,
136
+ age,
137
+ distance,
138
+ face_enhance,
139
+ enable_face_parser,
140
+ mask_include,
141
+ mask_exclude,
142
+ mask_blur,
143
+ *specifics,
144
+ ):
145
+ global WORKSPACE
146
+ global OUTPUT_FILE
147
+ global PREVIEW
148
+ WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
149
+
150
+ ## ------------------------------ GUI UPDATE FUNC ------------------------------
151
+
152
+ def ui_before():
153
+ return (
154
+ gr.update(visible=True, value=PREVIEW),
155
+ gr.update(interactive=False),
156
+ gr.update(interactive=False),
157
+ gr.update(visible=False),
158
+ )
159
+
160
+ def ui_after():
161
+ return (
162
+ gr.update(visible=True, value=PREVIEW),
163
+ gr.update(interactive=True),
164
+ gr.update(interactive=True),
165
+ gr.update(visible=False),
166
+ )
167
+
168
+ def ui_after_vid():
169
+ return (
170
+ gr.update(visible=False),
171
+ gr.update(interactive=True),
172
+ gr.update(interactive=True),
173
+ gr.update(value=OUTPUT_FILE, visible=True),
174
+ )
175
+
176
+ ## ------------------------------ LOAD PENDING MODELS ------------------------------
177
+ start_time = time.time()
178
+ specifics = list(specifics)
179
+ half = len(specifics) // 2
180
+ sources = specifics[:half]
181
+ specifics = specifics[half:]
182
+
183
+ yield "### \n ⌛ Loading face analyser model...", *ui_before()
184
+ load_face_analyser_model()
185
+
186
+ yield "### \n ⌛ Loading face swapper model...", *ui_before()
187
+ load_face_swapper_model()
188
+
189
+ if face_enhance:
190
+ yield "### \n ⌛ Loading face enhancer model...", *ui_before()
191
+ load_face_enhancer_model()
192
+
193
+ if enable_face_parser:
194
+ yield "### \n ⌛ Loading face parsing model...", *ui_before()
195
+ load_face_parser_model()
196
+
197
+ yield "### \n ⌛ Analysing Face...", *ui_before()
198
+
199
+ mi = mask_regions_to_list(mask_include)
200
+ me = mask_regions_to_list(mask_exclude)
201
+ models = {
202
+ "swap": FACE_SWAPPER,
203
+ "enhance": FACE_ENHANCER,
204
+ "enhance_sett": face_enhance,
205
+ "face_parser": FACE_PARSER,
206
+ "face_parser_sett": (enable_face_parser, mi, me, int(mask_blur)),
207
+ }
208
+
209
+ ## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
210
+
211
+ analysed_source_specific = []
212
+ if condition == "Specific Face":
213
+ for source, specific in zip(sources, specifics):
214
+ if source is None or specific is None:
215
+ continue
216
+ analysed_source = analyse_face(
217
+ source,
218
+ FACE_ANALYSER,
219
+ return_single_face=True,
220
+ detect_condition=DETECT_CONDITION,
221
+ )
222
+ analysed_specific = analyse_face(
223
+ specific,
224
+ FACE_ANALYSER,
225
+ return_single_face=True,
226
+ detect_condition=DETECT_CONDITION,
227
+ )
228
+ analysed_source_specific.append([analysed_source, analysed_specific])
229
+ else:
230
+ source = cv2.imread(source_path)
231
+ analysed_source = analyse_face(
232
+ source,
233
+ FACE_ANALYSER,
234
+ return_single_face=True,
235
+ detect_condition=DETECT_CONDITION,
236
+ )
237
+
238
+ ## ------------------------------ IMAGE ------------------------------
239
+
240
+ if input_type == "Image":
241
+ target = cv2.imread(image_path)
242
+ analysed_target = analyse_face(target, FACE_ANALYSER, return_single_face=False)
243
+ if condition == "Specific Face":
244
+ swapped = swap_specific(
245
+ analysed_source_specific,
246
+ analysed_target,
247
+ target,
248
+ models,
249
+ threshold=distance,
250
+ )
251
+ else:
252
+ swapped = swap_face_with_condition(
253
+ target, analysed_target, analysed_source, condition, age, models
254
+ )
255
+
256
+ filename = os.path.join(output_path, output_name + ".png")
257
+ cv2.imwrite(filename, swapped)
258
+ OUTPUT_FILE = filename
259
+ WORKSPACE = output_path
260
+ PREVIEW = swapped[:, :, ::-1]
261
+
262
+ tot_exec_time = time.time() - start_time
263
+ _min, _sec = divmod(tot_exec_time, 60)
264
+
265
+ yield f"Completed in {int(_min)} min {int(_sec)} sec.", *ui_after()
266
+
267
+ ## ------------------------------ VIDEO ------------------------------
268
+
269
+ elif input_type == "Video":
270
+ temp_path = os.path.join(output_path, output_name, "sequence")
271
+ os.makedirs(temp_path, exist_ok=True)
272
+
273
+ video_clip = VideoFileClip(video_path)
274
+ duration = video_clip.duration
275
+ fps = video_clip.fps
276
+ total_frames = video_clip.reader.nframes
277
+
278
+ analysed_targets = []
279
+ process_bar = ProcessBar(30, total_frames)
280
+ yield "### \n ⌛ Analysing...", *ui_before()
281
+ for i, frame in enumerate(video_clip.iter_frames()):
282
+ analysed_targets.append(
283
+ analyse_face(frame, FACE_ANALYSER, return_single_face=False)
284
+ )
285
+ info_text = "Analysing Faces || "
286
+ info_text += process_bar.get(i)
287
+ print("\033[1A\033[K", end="", flush=True)
288
+ print(info_text)
289
+ if i % 10 == 0:
290
+ yield "### \n" + info_text, *ui_before()
291
+ video_clip.close()
292
+
293
+ image_sequence = []
294
+ video_clip = VideoFileClip(video_path)
295
+ audio_clip = video_clip.audio if video_clip.audio is not None else None
296
+ process_bar = ProcessBar(30, total_frames)
297
+ yield "### \n ⌛ Swapping...", *ui_before()
298
+ for i, frame in enumerate(video_clip.iter_frames()):
299
+ swapped = frame
300
+ analysed_target = analysed_targets[i]
301
+
302
+ if condition == "Specific Face":
303
+ swapped = swap_specific(
304
+ frame,
305
+ analysed_target,
306
+ analysed_source_specific,
307
+ models,
308
+ threshold=distance,
309
+ )
310
+ else:
311
+ swapped = swap_face_with_condition(
312
+ frame, analysed_target, analysed_source, condition, age, models
313
+ )
314
+
315
+ image_path = os.path.join(temp_path, f"frame_{i}.png")
316
+ cv2.imwrite(image_path, swapped[:, :, ::-1])
317
+ image_sequence.append(image_path)
318
+
319
+ info_text = "Swapping Faces || "
320
+ info_text += process_bar.get(i)
321
+ print("\033[1A\033[K", end="", flush=True)
322
+ print(info_text)
323
+ if i % 6 == 0:
324
+ PREVIEW = swapped
325
+ yield "### \n" + info_text, *ui_before()
326
+
327
+ yield "### \n ⌛ Merging...", *ui_before()
328
+ edited_video_clip = ImageSequenceClip(image_sequence, fps=fps)
329
+
330
+ if audio_clip is not None:
331
+ edited_video_clip = edited_video_clip.set_audio(audio_clip)
332
+
333
+ output_video_path = os.path.join(output_path, output_name + ".mp4")
334
+ edited_video_clip.set_duration(duration).write_videofile(
335
+ output_video_path, codec="libx264"
336
+ )
337
+ edited_video_clip.close()
338
+ video_clip.close()
339
+
340
+ if os.path.exists(temp_path) and not keep_output_sequence:
341
+ yield "### \n ⌛ Removing temporary files...", *ui_before()
342
+ shutil.rmtree(temp_path)
343
+
344
+ WORKSPACE = output_path
345
+ OUTPUT_FILE = output_video_path
346
+
347
+ tot_exec_time = time.time() - start_time
348
+ _min, _sec = divmod(tot_exec_time, 60)
349
+
350
+ yield f"✔️ Completed in {int(_min)} min {int(_sec)} sec.", *ui_after_vid()
351
+
352
+ ## ------------------------------ DIRECTORY ------------------------------
353
+
354
+ elif input_type == "Directory":
355
+ source = cv2.imread(source_path)
356
+ source = analyse_face(
357
+ source,
358
+ FACE_ANALYSER,
359
+ return_single_face=True,
360
+ detect_condition=DETECT_CONDITION,
361
+ )
362
+ extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
363
+ temp_path = os.path.join(output_path, output_name)
364
+ if os.path.exists(temp_path):
365
+ shutil.rmtree(temp_path)
366
+ os.mkdir(temp_path)
367
+ swapped = None
368
+
369
+ files = []
370
+ for file_path in glob.glob(os.path.join(directory_path, "*")):
371
+ if any(file_path.lower().endswith(ext) for ext in extensions):
372
+ files.append(file_path)
373
+
374
+ files_length = len(files)
375
+ filename = None
376
+ for i, file_path in enumerate(files):
377
+ target = cv2.imread(file_path)
378
+ analysed_target = analyse_face(
379
+ target, FACE_ANALYSER, return_single_face=False
380
+ )
381
+
382
+ if condition == "Specific Face":
383
+ swapped = swap_specific(
384
+ target,
385
+ analysed_target,
386
+ analysed_source_specific,
387
+ models,
388
+ threshold=distance,
389
+ )
390
+ else:
391
+ swapped = swap_face_with_condition(
392
+ target, analysed_target, analysed_source, condition, age, models
393
+ )
394
+
395
+ filename = os.path.join(temp_path, os.path.basename(file_path))
396
+ cv2.imwrite(filename, swapped)
397
+ info_text = f"### \n ⌛ Processing file {i+1} of {files_length}"
398
+ PREVIEW = swapped[:, :, ::-1]
399
+ yield info_text, *ui_before()
400
+
401
+ WORKSPACE = temp_path
402
+ OUTPUT_FILE = filename
403
+
404
+ tot_exec_time = time.time() - start_time
405
+ _min, _sec = divmod(tot_exec_time, 60)
406
+
407
+ yield f"✔️ Completed in {int(_min)} min {int(_sec)} sec.", *ui_after()
408
+
409
+ ## ------------------------------ STREAM ------------------------------
410
+
411
+ elif input_type == "Stream":
412
+ yield "### \n ⌛ Starting...", *ui_before()
413
+ global STREAMER
414
+ STREAMER = StreamerThread(src=directory_path)
415
+ STREAMER.start()
416
+
417
+ while True:
418
+ try:
419
+ target = STREAMER.frame
420
+ analysed_target = analyse_face(
421
+ target, FACE_ANALYSER, return_single_face=False
422
+ )
423
+ if condition == "Specific Face":
424
+ swapped = swap_specific(
425
+ target,
426
+ analysed_target,
427
+ analysed_source_specific,
428
+ models,
429
+ threshold=distance,
430
+ )
431
+ else:
432
+ swapped = swap_face_with_condition(
433
+ target, analysed_target, analysed_source, condition, age, models
434
+ )
435
+ PREVIEW = swapped[:, :, ::-1]
436
+ yield f"Streaming...", *ui_before()
437
+ except AttributeError:
438
+ yield "Streaming...", *ui_before()
439
+ STREAMER.stop()
440
+
441
+
442
+ ## ------------------------------ GRADIO FUNC ------------------------------
443
+
444
+
445
+ def update_radio(value):
446
+ if value == "Image":
447
+ return (
448
+ gr.update(visible=True),
449
+ gr.update(visible=False),
450
+ gr.update(visible=False),
451
+ )
452
+ elif value == "Video":
453
+ return (
454
+ gr.update(visible=False),
455
+ gr.update(visible=True),
456
+ gr.update(visible=False),
457
+ )
458
+ elif value == "Directory":
459
+ return (
460
+ gr.update(visible=False),
461
+ gr.update(visible=False),
462
+ gr.update(visible=True),
463
+ )
464
+ elif value == "Stream":
465
+ return (
466
+ gr.update(visible=False),
467
+ gr.update(visible=False),
468
+ gr.update(visible=True),
469
+ )
470
+
471
+
472
+ def swap_option_changed(value):
473
+ if value == swap_options_list[1] or value == swap_options_list[2]:
474
+ return (
475
+ gr.update(visible=True),
476
+ gr.update(visible=False),
477
+ gr.update(visible=True),
478
+ )
479
+ elif value == swap_options_list[5]:
480
+ return (
481
+ gr.update(visible=False),
482
+ gr.update(visible=True),
483
+ gr.update(visible=False),
484
+ )
485
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
486
+
487
+
488
+ def video_changed(video_path):
489
+ sliders_update = gr.Slider.update
490
+ button_update = gr.Button.update
491
+ number_update = gr.Number.update
492
+
493
+ if video_path is None:
494
+ return (
495
+ sliders_update(minimum=0, maximum=0, value=0),
496
+ sliders_update(minimum=1, maximum=1, value=1),
497
+ number_update(value=1),
498
+ )
499
+ try:
500
+ clip = VideoFileClip(video_path)
501
+ fps = clip.fps
502
+ total_frames = clip.reader.nframes
503
+ clip.close()
504
+ return (
505
+ sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
506
+ sliders_update(
507
+ minimum=0, maximum=total_frames, value=total_frames, interactive=True
508
+ ),
509
+ number_update(value=fps),
510
+ )
511
+ except:
512
+ return (
513
+ sliders_update(value=0),
514
+ sliders_update(value=0),
515
+ number_update(value=1),
516
+ )
517
+
518
+
519
+ def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
520
+ yield "### \n ⌛ Applying new values..."
521
+ global FACE_ANALYSER
522
+ global DETECT_CONDITION
523
+ DETECT_CONDITION = detect_condition
524
+ FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
525
+ FACE_ANALYSER.prepare(
526
+ ctx_id=0,
527
+ det_size=(int(detection_size), int(detection_size)),
528
+ det_thresh=float(detection_threshold),
529
+ )
530
+ yield f"### \n ✔️ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"
531
+
532
+
533
+ def stop_running():
534
+ global STREAMER
535
+ if hasattr(STREAMER, "stop"):
536
+ STREAMER.stop()
537
+ STREAMER = None
538
+ return "Cancelled"
539
+
540
+
541
+ def slider_changed(show_frame, video_path, frame_index):
542
+ if not show_frame:
543
+ return None, None
544
+ if video_path is None:
545
+ return None, None
546
+ clip = VideoFileClip(video_path)
547
+ frame = clip.get_frame(frame_index / clip.fps)
548
+ frame_array = np.array(frame)
549
+ clip.close()
550
+ return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
551
+ visible=False
552
+ )
553
+
554
+
555
+ def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
556
+ yield video_path, f"### \n ⌛ Trimming video frame {start_frame} to {stop_frame}..."
557
+ try:
558
+ output_path = os.path.join(output_path, output_name)
559
+ trimmed_video = trim_video(video_path, output_path, start_frame, stop_frame)
560
+ yield trimmed_video, "### \n ✔️ Video trimmed and reloaded."
561
+ except Exception as e:
562
+ print(e)
563
+ yield video_path, "### \n ❌ Video trimming failed. See console for more info."
564
+
565
+
566
+ ## ------------------------------ GRADIO GUI ------------------------------
567
+
568
+ css = """
569
+ footer{display:none !important}
570
+ """
571
+
572
+ with gr.Blocks(css=css) as interface:
573
+ gr.Markdown("# 🗿 Swap Mukham")
574
+ gr.Markdown("### Face swap app based on insightface inswapper.")
575
+ with gr.Row():
576
+ with gr.Row():
577
+ with gr.Column(scale=0.4):
578
+ with gr.Tab("📄 Swap Condition"):
579
+ swap_option = gr.Radio(
580
+ swap_options_list,
581
+ show_label=False,
582
+ value=swap_options_list[0],
583
+ interactive=True,
584
+ )
585
+ age = gr.Number(
586
+ value=25, label="Value", interactive=True, visible=False
587
+ )
588
+
589
+ with gr.Tab("🎚️ Detection Settings"):
590
+ detect_condition_dropdown = gr.Dropdown(
591
+ detect_conditions,
592
+ label="Condition",
593
+ value=DETECT_CONDITION,
594
+ interactive=True,
595
+ info="This condition is only used when multiple faces are detected on source or specific image.",
596
+ )
597
+ detection_size = gr.Number(
598
+ label="Detection Size", value=DETECT_SIZE, interactive=True
599
+ )
600
+ detection_threshold = gr.Number(
601
+ label="Detection Threshold",
602
+ value=DETECT_THRESH,
603
+ interactive=True,
604
+ )
605
+ apply_detection_settings = gr.Button("Apply settings")
606
+
607
+ with gr.Tab("📤 Output Settings"):
608
+ output_directory = gr.Text(
609
+ label="Output Directory",
610
+ value=DEF_OUTPUT_PATH,
611
+ interactive=True,
612
+ )
613
+ output_name = gr.Text(
614
+ label="Output Name", value="Result", interactive=True
615
+ )
616
+ keep_output_sequence = gr.Checkbox(
617
+ label="Keep output sequence", value=False, interactive=True
618
+ )
619
+
620
+ with gr.Tab("🪄 Other Settings"):
621
+ with gr.Accordion("Enhance Face", open=True):
622
+ enable_face_enhance = gr.Checkbox(
623
+ label="Enable GFPGAN", value=False, interactive=True
624
+ )
625
+ with gr.Accordion("Advanced Mask", open=False):
626
+ enable_face_parser_mask = gr.Checkbox(
627
+ label="Enable Face Parsing",
628
+ value=False,
629
+ interactive=True,
630
+ )
631
+
632
+ mask_include = gr.Dropdown(
633
+ mask_regions.keys(),
634
+ value=MASK_INCLUDE,
635
+ multiselect=True,
636
+ label="Include",
637
+ interactive=True,
638
+ )
639
+ mask_exclude = gr.Dropdown(
640
+ mask_regions.keys(),
641
+ value=MASK_EXCLUDE,
642
+ multiselect=True,
643
+ label="Exclude",
644
+ interactive=True,
645
+ )
646
+ mask_blur = gr.Number(
647
+ label="Blur Mask",
648
+ value=MASK_BLUR,
649
+ minimum=0,
650
+ interactive=True,
651
+ )
652
+
653
+ source_image_input = gr.Image(
654
+ label="Source face", type="filepath", interactive=True
655
+ )
656
+
657
+ with gr.Box(visible=False) as specific_face:
658
+ for i in range(NUM_OF_SRC_SPECIFIC):
659
+ idx = i + 1
660
+ code = "\n"
661
+ code += f"with gr.Tab(label='({idx})'):"
662
+ code += "\n\twith gr.Row():"
663
+ code += f"\n\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')"
664
+ code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
665
+ exec(code)
666
+
667
+ distance_slider = gr.Slider(
668
+ minimum=0,
669
+ maximum=2,
670
+ value=0.6,
671
+ interactive=True,
672
+ label="Distance",
673
+ info="Lower distance is more similar and higher distance is less similar to the target face.",
674
+ )
675
+
676
+ with gr.Group():
677
+ input_type = gr.Radio(
678
+ ["Image", "Video", "Directory", "Stream"],
679
+ label="Target Type",
680
+ value="Video",
681
+ )
682
+
683
+ with gr.Box(visible=False) as input_image_group:
684
+ image_input = gr.Image(
685
+ label="Target Image", interactive=True, type="filepath"
686
+ )
687
+
688
+ with gr.Box(visible=True) as input_video_group:
689
+ vid_widget = gr.Video if USE_COLAB else gr.Text
690
+ video_input = vid_widget(
691
+ label="Target Video Path", interactive=True
692
+ )
693
+ with gr.Accordion("✂️ Trim video", open=False):
694
+ with gr.Column():
695
+ with gr.Row():
696
+ set_slider_range_btn = gr.Button(
697
+ "Set frame range", interactive=True
698
+ )
699
+ show_trim_preview_btn = gr.Checkbox(
700
+ label="Show frame when slider change",
701
+ value=True,
702
+ interactive=True,
703
+ )
704
+
705
+ video_fps = gr.Number(
706
+ value=30,
707
+ interactive=False,
708
+ label="Fps",
709
+ visible=False,
710
+ )
711
+ start_frame = gr.Slider(
712
+ minimum=0,
713
+ maximum=1,
714
+ value=0,
715
+ step=1,
716
+ interactive=True,
717
+ label="Start Frame",
718
+ info="",
719
+ )
720
+ end_frame = gr.Slider(
721
+ minimum=0,
722
+ maximum=1,
723
+ value=1,
724
+ step=1,
725
+ interactive=True,
726
+ label="End Frame",
727
+ info="",
728
+ )
729
+ trim_and_reload_btn = gr.Button(
730
+ "Trim and Reload", interactive=True
731
+ )
732
+
733
+ with gr.Box(visible=False) as input_directory_group:
734
+ direc_input = gr.Text(label="Path", interactive=True)
735
+
736
+ with gr.Column(scale=0.6):
737
+ info = gr.Markdown(value="...")
738
+
739
+ with gr.Row():
740
+ swap_button = gr.Button("✨ Swap", variant="primary")
741
+ cancel_button = gr.Button("⛔ Cancel")
742
+
743
+ preview_image = gr.Image(label="Output", interactive=False)
744
+ preview_video = gr.Video(
745
+ label="Output", interactive=False, visible=False
746
+ )
747
+
748
+ with gr.Row():
749
+ output_directory_button = gr.Button(
750
+ "📂", interactive=False, visible=not USE_COLAB
751
+ )
752
+ output_video_button = gr.Button(
753
+ "🎬", interactive=False, visible=not USE_COLAB
754
+ )
755
+
756
+ with gr.Column():
757
+ gr.Markdown(
758
+ '[!["Buy Me A Coffee"](https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png)](https://www.buymeacoffee.com/harisreedhar)'
759
+ )
760
+ gr.Markdown(
761
+ "### [Source code](https://github.com/harisreedhar/Swap-Mukham) . [Disclaimer](https://github.com/harisreedhar/Swap-Mukham#disclaimer) . [Gradio](https://gradio.app/)"
762
+ )
763
+
764
+ ## ------------------------------ GRADIO EVENTS ------------------------------
765
+
766
+ set_slider_range_event = set_slider_range_btn.click(
767
+ video_changed,
768
+ inputs=[video_input],
769
+ outputs=[start_frame, end_frame, video_fps],
770
+ )
771
+
772
+ trim_and_reload_event = trim_and_reload_btn.click(
773
+ fn=trim_and_reload,
774
+ inputs=[video_input, output_directory, output_name, start_frame, end_frame],
775
+ outputs=[video_input, info],
776
+ )
777
+
778
+ start_frame_event = start_frame.release(
779
+ fn=slider_changed,
780
+ inputs=[show_trim_preview_btn, video_input, start_frame],
781
+ outputs=[preview_image, preview_video],
782
+ show_progress=False,
783
+ )
784
+
785
+ end_frame_event = end_frame.release(
786
+ fn=slider_changed,
787
+ inputs=[show_trim_preview_btn, video_input, end_frame],
788
+ outputs=[preview_image, preview_video],
789
+ show_progress=False,
790
+ )
791
+
792
+ input_type.change(
793
+ update_radio,
794
+ inputs=[input_type],
795
+ outputs=[input_image_group, input_video_group, input_directory_group],
796
+ )
797
+ swap_option.change(
798
+ swap_option_changed,
799
+ inputs=[swap_option],
800
+ outputs=[age, specific_face, source_image_input],
801
+ )
802
+
803
+ apply_detection_settings.click(
804
+ analyse_settings_changed,
805
+ inputs=[detect_condition_dropdown, detection_size, detection_threshold],
806
+ outputs=[info],
807
+ )
808
+
809
+ src_specific_inputs = []
810
+ gen_variable_txt = ",".join(
811
+ [f"src{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
812
+ + [f"trg{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
813
+ )
814
+ exec(f"src_specific_inputs = ({gen_variable_txt})")
815
+ swap_inputs = [
816
+ input_type,
817
+ image_input,
818
+ video_input,
819
+ direc_input,
820
+ source_image_input,
821
+ output_directory,
822
+ output_name,
823
+ keep_output_sequence,
824
+ swap_option,
825
+ age,
826
+ distance_slider,
827
+ enable_face_enhance,
828
+ enable_face_parser_mask,
829
+ mask_include,
830
+ mask_exclude,
831
+ mask_blur,
832
+ *src_specific_inputs,
833
+ ]
834
+
835
+ swap_outputs = [
836
+ info,
837
+ preview_image,
838
+ output_directory_button,
839
+ output_video_button,
840
+ preview_video,
841
+ ]
842
+
843
+ swap_event = swap_button.click(
844
+ fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=False
845
+ )
846
+
847
+ cancel_button.click(
848
+ fn=stop_running,
849
+ inputs=None,
850
+ outputs=[info],
851
+ cancels=[
852
+ swap_event,
853
+ trim_and_reload_event,
854
+ set_slider_range_event,
855
+ start_frame_event,
856
+ end_frame_event,
857
+ ],
858
+ show_progress=False,
859
+ )
860
+ output_directory_button.click(
861
+ lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
862
+ )
863
+ output_video_button.click(
864
+ lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
865
+ )
866
+
867
+ if __name__ == "__main__":
868
+ if USE_COLAB:
869
+ print("Running in colab mode")
870
+
871
+ interface.queue(concurrency_count=2, max_size=20).launch(share=USE_COLAB)
assets/images/logo.png ADDED
assets/pretrained_models/79999_iter.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
3
+ size 53289463
assets/pretrained_models/GFPGANv1.4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2cd4703ab14f4d01fd1383a8a8b266f9a5833dacee8e6a79d3bf21a1b6be5ad
3
+ size 348632874
assets/pretrained_models/inswapper_128.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af
3
+ size 554253681
assets/pretrained_models/readme.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ## Downolad these models here
2
+ - [inswapper_128.onnx](https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx)
3
+ - [GFPGANv1.4.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth)
4
+ - [79999_iter.pth](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812)
face_analyser.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ detect_conditions = [
2
+ "left most",
3
+ "right most",
4
+ "top most",
5
+ "bottom most",
6
+ "most width",
7
+ "most height",
8
+ ]
9
+
10
+
11
+ def analyse_face(image, model, return_single_face=True, detect_condition="left most"):
12
+ faces = model.get(image)
13
+ if not return_single_face:
14
+ return faces
15
+
16
+ total_faces = len(faces)
17
+ if total_faces == 1:
18
+ return faces[0]
19
+
20
+ print(f"{total_faces} face detected. Using {detect_condition} face.")
21
+ if detect_condition == "left most":
22
+ return sorted(faces, key=lambda face: face["bbox"][0])[0]
23
+ elif detect_condition == "right most":
24
+ return sorted(faces, key=lambda face: face["bbox"][0])[-1]
25
+ elif detect_condition == "top most":
26
+ return sorted(faces, key=lambda face: face["bbox"][1])[0]
27
+ elif detect_condition == "bottom most":
28
+ return sorted(faces, key=lambda face: face["bbox"][1])[-1]
29
+ elif detect_condition == "most width":
30
+ return sorted(faces, key=lambda face: face["bbox"][2])[-1]
31
+ elif detect_condition == "most height":
32
+ return sorted(faces, key=lambda face: face["bbox"][3])[-1]
face_parsing/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list
face_parsing/model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from .resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18()
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if not module.bias is None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used, since I replace this with the resnet feature with the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
+ self.init_weight()
154
+
155
+ def forward(self, x):
156
+ feat = self.conv1(x)
157
+ feat = self.conv2(feat)
158
+ feat = self.conv3(feat)
159
+ feat = self.conv_out(feat)
160
+ return feat
161
+
162
+ def init_weight(self):
163
+ for ly in self.children():
164
+ if isinstance(ly, nn.Conv2d):
165
+ nn.init.kaiming_normal_(ly.weight, a=1)
166
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167
+
168
+ def get_params(self):
169
+ wd_params, nowd_params = [], []
170
+ for name, module in self.named_modules():
171
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
+ wd_params.append(module.weight)
173
+ if not module.bias is None:
174
+ nowd_params.append(module.bias)
175
+ elif isinstance(module, nn.BatchNorm2d):
176
+ nowd_params += list(module.parameters())
177
+ return wd_params, nowd_params
178
+
179
+
180
+ class FeatureFusionModule(nn.Module):
181
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
182
+ super(FeatureFusionModule, self).__init__()
183
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
+ self.conv1 = nn.Conv2d(out_chan,
185
+ out_chan//4,
186
+ kernel_size = 1,
187
+ stride = 1,
188
+ padding = 0,
189
+ bias = False)
190
+ self.conv2 = nn.Conv2d(out_chan//4,
191
+ out_chan,
192
+ kernel_size = 1,
193
+ stride = 1,
194
+ padding = 0,
195
+ bias = False)
196
+ self.relu = nn.ReLU(inplace=True)
197
+ self.sigmoid = nn.Sigmoid()
198
+ self.init_weight()
199
+
200
+ def forward(self, fsp, fcp):
201
+ fcat = torch.cat([fsp, fcp], dim=1)
202
+ feat = self.convblk(fcat)
203
+ atten = F.avg_pool2d(feat, feat.size()[2:])
204
+ atten = self.conv1(atten)
205
+ atten = self.relu(atten)
206
+ atten = self.conv2(atten)
207
+ atten = self.sigmoid(atten)
208
+ feat_atten = torch.mul(feat, atten)
209
+ feat_out = feat_atten + feat
210
+ return feat_out
211
+
212
+ def init_weight(self):
213
+ for ly in self.children():
214
+ if isinstance(ly, nn.Conv2d):
215
+ nn.init.kaiming_normal_(ly.weight, a=1)
216
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217
+
218
+ def get_params(self):
219
+ wd_params, nowd_params = [], []
220
+ for name, module in self.named_modules():
221
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
+ wd_params.append(module.weight)
223
+ if not module.bias is None:
224
+ nowd_params.append(module.bias)
225
+ elif isinstance(module, nn.BatchNorm2d):
226
+ nowd_params += list(module.parameters())
227
+ return wd_params, nowd_params
228
+
229
+
230
+ class BiSeNet(nn.Module):
231
+ def __init__(self, n_classes, *args, **kwargs):
232
+ super(BiSeNet, self).__init__()
233
+ self.cp = ContextPath()
234
+ ## here self.sp is deleted
235
+ self.ffm = FeatureFusionModule(256, 256)
236
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
+ self.init_weight()
240
+
241
+ def forward(self, x):
242
+ H, W = x.size()[2:]
243
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
246
+
247
+ feat_out = self.conv_out(feat_fuse)
248
+ feat_out16 = self.conv_out16(feat_cp8)
249
+ feat_out32 = self.conv_out32(feat_cp16)
250
+
251
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
+ return feat_out, feat_out16, feat_out32
255
+
256
+ def init_weight(self):
257
+ for ly in self.children():
258
+ if isinstance(ly, nn.Conv2d):
259
+ nn.init.kaiming_normal_(ly.weight, a=1)
260
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261
+
262
+ def get_params(self):
263
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264
+ for name, child in self.named_children():
265
+ child_wd_params, child_nowd_params = child.get_params()
266
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267
+ lr_mul_wd_params += child_wd_params
268
+ lr_mul_nowd_params += child_nowd_params
269
+ else:
270
+ wd_params += child_wd_params
271
+ nowd_params += child_nowd_params
272
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273
+
274
+
275
+ if __name__ == "__main__":
276
+ net = BiSeNet(19)
277
+ net.cuda()
278
+ net.eval()
279
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
280
+ out, out16, out32 = net(in_ten)
281
+ print(out.shape)
282
+
283
+ net.get_params()
face_parsing/resnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight()
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self):
83
+ state_dict = modelzoo.load_url(resnet18_url)
84
+ self_state_dict = self.state_dict()
85
+ for k, v in state_dict.items():
86
+ if 'fc' in k: continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if not module.bias is None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18()
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
face_parsing/swap.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import cv2
4
+ import numpy as np
5
+
6
+ from .model import BiSeNet
7
+
8
+ mask_regions = {
9
+ "Background":0,
10
+ "Skin":1,
11
+ "L-Eyebrow":2,
12
+ "R-Eyebrow":3,
13
+ "L-Eye":4,
14
+ "R-Eye":5,
15
+ "Eye-G":6,
16
+ "L-Ear":7,
17
+ "R-Ear":8,
18
+ "Ear-R":9,
19
+ "Nose":10,
20
+ "Mouth":11,
21
+ "U-Lip":12,
22
+ "L-Lip":13,
23
+ "Neck":14,
24
+ "Neck-L":15,
25
+ "Cloth":16,
26
+ "Hair":17,
27
+ "Hat":18
28
+ }
29
+
30
+ run_with_cuda = False
31
+
32
+ def init_parser(pth_path, use_cuda=False):
33
+ global run_with_cuda
34
+ run_with_cuda = use_cuda
35
+
36
+ n_classes = 19
37
+ net = BiSeNet(n_classes=n_classes)
38
+ if run_with_cuda:
39
+ net.cuda()
40
+ net.load_state_dict(torch.load(pth_path))
41
+ else:
42
+ net.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
43
+ net.eval()
44
+ return net
45
+
46
+
47
+ def image_to_parsing(img, net):
48
+ img = cv2.resize(img, (512, 512))
49
+ img = img[:,:,::-1]
50
+ transform = transforms.Compose([
51
+ transforms.ToTensor(),
52
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
53
+ ])
54
+ img = transform(img.copy())
55
+ img = torch.unsqueeze(img, 0)
56
+
57
+ with torch.no_grad():
58
+ if run_with_cuda:
59
+ img = img.cuda()
60
+ out = net(img)[0]
61
+ parsing = out.squeeze(0).cpu().numpy().argmax(0)
62
+ return parsing
63
+
64
+
65
+ def get_mask(parsing, classes):
66
+ res = parsing == classes[0]
67
+ for val in classes[1:]:
68
+ res += parsing == val
69
+ return res
70
+
71
+ def swap_regions(source, target, net, includes=[1,2,3,4,5,10,11,12,13], excludes=[7,8], blur_size=25):
72
+ parsing = image_to_parsing(source, net)
73
+ if len(includes) == 0:
74
+ return source, np.zeros_like(source)
75
+ include_mask = get_mask(parsing, includes)
76
+ include_mask = np.repeat(np.expand_dims(include_mask.astype('float32'), axis=2), 3, 2)
77
+ if len(excludes) > 0:
78
+ exclude_mask = get_mask(parsing, excludes)
79
+ exclude_mask = np.repeat(np.expand_dims(exclude_mask.astype('float32'), axis=2), 3, 2)
80
+ include_mask -= exclude_mask
81
+ mask = 1 - cv2.GaussianBlur(include_mask.clip(0,1), (0, 0), blur_size)
82
+ result = (1 - mask) * cv2.resize(source, (512, 512)) + mask * cv2.resize(target, (512, 512))
83
+ result = cv2.resize(result.astype("float32"), (source.shape[1], source.shape[0]))
84
+ return result, mask.astype('float32')
85
+
86
+ def mask_regions_to_list(values):
87
+ out_ids = []
88
+ for value in values:
89
+ if value in mask_regions.keys():
90
+ out_ids.append(mask_regions.get(value))
91
+ return out_ids
gfpgan/weights/detection_Resnet50_Final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
3
+ size 109497761
gfpgan/weights/parsing_parsenet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
3
+ size 85331193
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
2
+ gradio>=3.33.1
3
+ insightface==0.7.3
4
+ moviepy>=1.0.3
5
+ numpy
6
+ opencv-python>=4.7.0.72
7
+ opencv-python-headless>=4.7.0.72
8
+ onnx==1.14.0
9
+ onnxruntime==1.15.0
10
+ gfpgan==1.3.8
swapper.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from insightface.utils import face_align
4
+ from face_parsing.swap import swap_regions
5
+ from utils import add_logo_to_image
6
+
7
+ swap_options_list = [
8
+ "All face",
9
+ "Age less than",
10
+ "Age greater than",
11
+ "All Male",
12
+ "All Female",
13
+ "Specific Face",
14
+ ]
15
+
16
+
17
+ def swap_face(whole_img, target_face, source_face, models):
18
+ inswapper = models.get("swap")
19
+ face_enhancer = models.get("enhance", None)
20
+ face_parser = models.get("face_parser", None)
21
+ fe_enable = models.get("enhance_sett", False)
22
+
23
+ bgr_fake, M = inswapper.get(whole_img, target_face, source_face, paste_back=False)
24
+ image_size = 128 if not fe_enable else 512
25
+ aimg, _ = face_align.norm_crop2(whole_img, target_face.kps, image_size=image_size)
26
+
27
+ if face_parser is not None:
28
+ fp_enable, mi, me, mb = models.get("face_parser_sett")
29
+ if fp_enable:
30
+ bgr_fake, parsed_mask = swap_regions(
31
+ bgr_fake, aimg, face_parser, includes=mi, excludes=me, blur_size=mb
32
+ )
33
+
34
+ if fe_enable:
35
+ _, bgr_fake, _ = face_enhancer.enhance(
36
+ bgr_fake, paste_back=True, has_aligned=True
37
+ )
38
+ bgr_fake = bgr_fake[0]
39
+ M /= 0.25
40
+
41
+ IM = cv2.invertAffineTransform(M)
42
+
43
+ img_white = np.full((aimg.shape[0], aimg.shape[1]), 255, dtype=np.float32)
44
+ bgr_fake = cv2.warpAffine(
45
+ bgr_fake, IM, (whole_img.shape[1], whole_img.shape[0]), borderValue=0.0
46
+ )
47
+ img_white = cv2.warpAffine(
48
+ img_white, IM, (whole_img.shape[1], whole_img.shape[0]), borderValue=0.0
49
+ )
50
+ img_white[img_white > 20] = 255
51
+ img_mask = img_white
52
+ mask_h_inds, mask_w_inds = np.where(img_mask == 255)
53
+ mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
54
+ mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
55
+ mask_size = int(np.sqrt(mask_h * mask_w))
56
+
57
+ k = max(mask_size // 10, 10)
58
+ img_mask = cv2.erode(img_mask, np.ones((k, k), np.uint8), iterations=1)
59
+
60
+ k = max(mask_size // 20, 5)
61
+ kernel_size = (k, k)
62
+ blur_size = tuple(2 * i + 1 for i in kernel_size)
63
+ img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) / 255
64
+
65
+ img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])
66
+ fake_merged = img_mask * bgr_fake + (1 - img_mask) * whole_img.astype(np.float32)
67
+ fake_merged = add_logo_to_image(fake_merged.astype("uint8"))
68
+ return fake_merged
69
+
70
+
71
+ def swap_face_with_condition(
72
+ whole_img, target_faces, source_face, condition, age, models
73
+ ):
74
+ swapped = whole_img.copy()
75
+
76
+ for target_face in target_faces:
77
+ if condition == "All face":
78
+ swapped = swap_face(swapped, target_face, source_face, models)
79
+ elif condition == "Age less than" and target_face["age"] < age:
80
+ swapped = swap_face(swapped, target_face, source_face, models)
81
+ elif condition == "Age greater than" and target_face["age"] > age:
82
+ swapped = swap_face(swapped, target_face, source_face, models)
83
+ elif condition == "All Male" and target_face["gender"] == 1:
84
+ swapped = swap_face(swapped, target_face, source_face, models)
85
+ elif condition == "All Female" and target_face["gender"] == 0:
86
+ swapped = swap_face(swapped, target_face, source_face, models)
87
+
88
+ return swapped
89
+
90
+
91
+ def swap_specific(source_specifics, target_faces, whole_img, models, threshold=0.6):
92
+ swapped = whole_img.copy()
93
+
94
+ for source_face, specific_face in source_specifics:
95
+ specific_embed = specific_face["embedding"]
96
+ specific_embed /= np.linalg.norm(specific_embed)
97
+
98
+ for target_face in target_faces:
99
+ target_embed = target_face["embedding"]
100
+ target_embed /= np.linalg.norm(target_embed)
101
+ cosine_distance = 1 - np.dot(specific_embed, target_embed)
102
+ if cosine_distance > threshold:
103
+ continue
104
+ swapped = swap_face(swapped, target_face, source_face, models)
105
+
106
+ return swapped
utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import shutil
6
+ import platform
7
+ import datetime
8
+ import subprocess
9
+ from threading import Thread
10
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
11
+ from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
12
+
13
+
14
+ def trim_video(video_path, output_path, start_frame, stop_frame):
15
+ video_name, _ = os.path.splitext(os.path.basename(video_path))
16
+ trimmed_video_filename = video_name + "_trimmed" + ".mp4"
17
+ temp_path = os.path.join(output_path, "trim")
18
+ os.makedirs(temp_path, exist_ok=True)
19
+ trimmed_video_file_path = os.path.join(temp_path, trimmed_video_filename)
20
+
21
+ video = VideoFileClip(video_path)
22
+ fps = video.fps
23
+ start_time = start_frame / fps
24
+ duration = (stop_frame - start_frame) / fps
25
+
26
+ trimmed_video = video.subclip(start_time, start_time + duration)
27
+ trimmed_video.write_videofile(
28
+ trimmed_video_file_path, codec="libx264", audio_codec="aac"
29
+ )
30
+ trimmed_video.close()
31
+ video.close()
32
+
33
+ return trimmed_video_file_path
34
+
35
+
36
+ def open_directory(path=None):
37
+ if path is None:
38
+ return
39
+ try:
40
+ os.startfile(path)
41
+ except:
42
+ subprocess.Popen(["xdg-open", path])
43
+
44
+
45
+ class StreamerThread(object):
46
+ def __init__(self, src=0):
47
+ self.capture = cv2.VideoCapture(src)
48
+ self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
49
+ self.FPS = 1 / 30
50
+ self.FPS_MS = int(self.FPS * 1000)
51
+ self.thread = None
52
+ self.stopped = False
53
+ self.frame = None
54
+
55
+ def start(self):
56
+ self.thread = Thread(target=self.update, args=())
57
+ self.thread.daemon = True
58
+ self.thread.start()
59
+
60
+ def stop(self):
61
+ self.stopped = True
62
+ self.thread.join()
63
+ print("stopped")
64
+
65
+ def update(self):
66
+ while not self.stopped:
67
+ if self.capture.isOpened():
68
+ (self.status, self.frame) = self.capture.read()
69
+ time.sleep(self.FPS)
70
+
71
+
72
+ class ProcessBar:
73
+ def __init__(self, bar_length, total, before="⬛", after="🟨"):
74
+ self.bar_length = bar_length
75
+ self.total = total
76
+ self.before = before
77
+ self.after = after
78
+ self.bar = [self.before] * bar_length
79
+ self.start_time = time.time()
80
+
81
+ def get(self, index):
82
+ total = self.total
83
+ elapsed_time = time.time() - self.start_time
84
+ average_time_per_iteration = elapsed_time / (index + 1)
85
+ remaining_iterations = total - (index + 1)
86
+ estimated_remaining_time = remaining_iterations * average_time_per_iteration
87
+
88
+ self.bar[int(index / total * self.bar_length)] = self.after
89
+ info_text = f"({index+1}/{total}) {''.join(self.bar)} "
90
+ info_text += f"(ETR: {int(estimated_remaining_time // 60)} min {int(estimated_remaining_time % 60)} sec)"
91
+ return info_text
92
+
93
+
94
+ logo_image = cv2.imread("./assets/images/logo.png", cv2.IMREAD_UNCHANGED)
95
+
96
+
97
+ def add_logo_to_image(img, logo=logo_image):
98
+ logo_size = int(img.shape[1] * 0.1)
99
+ logo = cv2.resize(logo, (logo_size, logo_size))
100
+ if logo.shape[2] == 4:
101
+ alpha = logo[:, :, 3]
102
+ else:
103
+ alpha = np.ones_like(logo[:, :, 0]) * 255
104
+ padding = int(logo_size * 0.1)
105
+ roi = img.shape[0] - logo_size - padding, img.shape[1] - logo_size - padding
106
+ for c in range(0, 3):
107
+ img[roi[0] : roi[0] + logo_size, roi[1] : roi[1] + logo_size, c] = (
108
+ alpha / 255.0
109
+ ) * logo[:, :, c] + (1 - alpha / 255.0) * img[
110
+ roi[0] : roi[0] + logo_size, roi[1] : roi[1] + logo_size, c
111
+ ]
112
+ return img