charbelgrower committed on
Commit 50d5956 • 1 Parent(s): 9461142
Files changed (2)
  1. app.py +29 -27
  2. merge +1675 -0
app.py CHANGED
@@ -28,7 +28,7 @@ from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_
 
 ## ------------------------------ USER ARGS ------------------------------
 
-parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
+parser = argparse.ArgumentParser(description="Swap Face Swapper")
 parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
 parser.add_argument("--batch_size", help="Gpu batch size", default=32)
 parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
@@ -79,23 +79,12 @@ FACE_ENHANCER_LIST.extend(cv2_interpolations)
 ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
 # Note: Non CUDA users may change settings here
 
-PROVIDER = ["CPUExecutionProvider"]
-
-if USE_CUDA:
-    available_providers = onnxruntime.get_available_providers()
-    if "CUDAExecutionProvider" in available_providers:
-        print("\n********** Running on CUDA **********\n")
-        PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
-    else:
-        USE_CUDA = False
-        print("\n********** CUDA unavailable running on CPU **********\n")
-else:
-    USE_CUDA = False
-    print("\n********** Running on CPU **********\n")
-
-device = "cuda" if USE_CUDA else "cpu"
+PROVIDER = ["CPUExecutionProvider"]  # Default to CPU provider
+device = "cpu"
 EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
 
+
 ## ------------------------------ LOAD MODELS ------------------------------
 
 def load_face_analyser_model(name="buffalo_l"):
@@ -131,7 +120,7 @@ load_face_swapper_model()
 ## ------------------------------ MAIN PROCESS ------------------------------
 
 
-@spaces.GPU(duration=300, enable_queue=True)
+@spaces.GPU(duration=600, enable_queue=True)
 def process(
     input_type,
     image_path,
@@ -162,9 +151,22 @@ def process(
     global WORKSPACE
     global OUTPUT_FILE
     global PREVIEW
-    WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
-
-    ## ------------------------------ GUI UPDATE FUNC ------------------------------
+    global USE_CUDA  # Access global variables
+    global device
+    global PROVIDER
+    global FACE_ANALYSER, FACE_SWAPPER, FACE_ENHANCER, FACE_PARSER, NSFW_DETECTOR
+
+    # Set CUDA usage and device
+    USE_CUDA = True
+    device = "cuda"
+    PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+
+    # Reset models to None to reload them with GPU
+    FACE_ANALYSER = None
+    FACE_SWAPPER = None
+    FACE_ENHANCER = None
+    FACE_PARSER = None
+    NSFW_DETECTOR = None  ## ------------------------------ GUI UPDATE FUNC ------------------------------
 
     def ui_before():
         return (
@@ -230,14 +232,14 @@ def process(
     def swap_process(image_sequence):
         ## ------------------------------ CONTENT CHECK ------------------------------
 
-        yield "### \n ⌛ Checking contents...", *ui_before()
-        nsfw = NSFW_DETECTOR.is_nsfw(image_sequence)
-        if nsfw:
-            message = "NSFW Content detected !!!"
-            yield f"### \n 🔞 {message}", *ui_before()
-            assert not nsfw, message
-            return False
-        EMPTY_CACHE()
+        # yield "### \n ⌛ Checking contents...", *ui_before()
+        # nsfw = NSFW_DETECTOR.is_nsfw(image_sequence)
+        # if nsfw:
+        #     message = "NSFW Content detected !!!"
+        #     yield f"### \n 🔞 {message}", *ui_before()
+        #     assert not nsfw, message
+        #     return False
+        # EMPTY_CACHE()
 
         ## ------------------------------ ANALYSE FACE ------------------------------
 
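The net effect of this diff: the import-time CUDA probe is removed, everything starts on CPU, the GPU duration is raised from 300 s to 600 s, the NSFW content check is commented out, and the GPU is claimed only inside the @spaces.GPU-decorated process function (the Hugging Face ZeroGPU pattern, where CUDA exists only while a decorated call is running). Nulling the model globals there forces the load_* helpers to rebuild their sessions with the CUDA execution provider. A minimal sketch of the same reset-and-reload idea, with hypothetical names (get_session, model.onnx):

import onnxruntime

PROVIDER = ["CPUExecutionProvider"]  # import-time default: CPU only
_session = None

def get_session(model_path="model.onnx"):  # illustrative model path
    # Build the ONNX Runtime session lazily with whatever providers are current.
    global _session
    if _session is None:
        _session = onnxruntime.InferenceSession(model_path, providers=PROVIDER)
    return _session

def on_gpu_attached():
    # Once a GPU exists (e.g. inside a @spaces.GPU call): switch providers and
    # drop the stale session so the next get_session() rebuilds on CUDA.
    global PROVIDER, _session
    PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    _session = None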
merge ADDED
@@ -0,0 +1,1675 @@
#### FACE_ENHANCER.PY CODE START ###

import os
import cv2
import torch
import gfpgan
from PIL import Image
from upscaler.RealESRGAN import RealESRGAN
from upscaler.codeformer import CodeFormerEnhancer

def gfpgan_runner(img, model):
    _, imgs, _ = model.enhance(img, paste_back=True, has_aligned=True)
    return imgs[0]


def realesrgan_runner(img, model):
    img = model.predict(img)
    return img


def codeformer_runner(img, model):
    img = model.enhance(img)
    return img


supported_enhancers = {
    "CodeFormer": ("./assets/pretrained_models/codeformer.onnx", codeformer_runner),
    "GFPGAN": ("./assets/pretrained_models/GFPGANv1.4.pth", gfpgan_runner),
    "REAL-ESRGAN 2x": ("./assets/pretrained_models/RealESRGAN_x2.pth", realesrgan_runner),
    "REAL-ESRGAN 4x": ("./assets/pretrained_models/RealESRGAN_x4.pth", realesrgan_runner),
    "REAL-ESRGAN 8x": ("./assets/pretrained_models/RealESRGAN_x8.pth", realesrgan_runner)
}

cv2_interpolations = ["LANCZOS4", "CUBIC", "NEAREST"]

def get_available_enhancer_names():
    available = []
    for name, data in supported_enhancers.items():
        path = os.path.join(os.path.abspath(os.path.dirname(__file__)), data[0])
        if os.path.exists(path):
            available.append(name)
    return available


def load_face_enhancer_model(name='GFPGAN', device="cpu"):
    assert name in get_available_enhancer_names() + cv2_interpolations, f"Face enhancer {name} unavailable."
    if name in supported_enhancers.keys():
        model_path, model_runner = supported_enhancers.get(name)
        model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)
    if name == 'CodeFormer':
        model = CodeFormerEnhancer(model_path=model_path, device=device)
    elif name == 'GFPGAN':
        model = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=device)
    elif name == 'REAL-ESRGAN 2x':
        model = RealESRGAN(device, scale=2)
        model.load_weights(model_path, download=False)
    elif name == 'REAL-ESRGAN 4x':
        model = RealESRGAN(device, scale=4)
        model.load_weights(model_path, download=False)
    elif name == 'REAL-ESRGAN 8x':
        model = RealESRGAN(device, scale=8)
        model.load_weights(model_path, download=False)
    elif name == 'LANCZOS4':
        model = None
        model_runner = lambda img, _: cv2.resize(img, (512, 512), interpolation=cv2.INTER_LANCZOS4)
    elif name == 'CUBIC':
        model = None
        model_runner = lambda img, _: cv2.resize(img, (512, 512), interpolation=cv2.INTER_CUBIC)
    elif name == 'NEAREST':
        model = None
        model_runner = lambda img, _: cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    else:
        model = None
    return (model, model_runner)


#### FACE_ENHANCER.PY CODE END ###

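# Usage sketch for the (model, runner) pair returned above; the runner
# signature is uniform across backends, so the call site never changes.
# "face.png" is a hypothetical input and assumes the GFPGAN weights exist.
def _enhancer_usage_example():
    model, runner = load_face_enhancer_model(name="GFPGAN", device="cpu")
    face = cv2.imread("face.png")    # BGR face crop (hypothetical file)
    enhanced = runner(face, model)   # every runner takes (image, model)
    cv2.imwrite("face_enhanced.png", enhanced)
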
#### FACE_SWAPPER.PY CODE START ###

import time
import torch
import onnx
import cv2
import onnxruntime
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from onnx import numpy_helper
from skimage import transform as trans
import torchvision.transforms.functional as TF  # aliased so it is not shadowed by torch.nn.functional below
import torch.nn.functional as F
from utils import mask_crop, laplacian_blending


arcface_dst = np.array(
    [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
     [41.5493, 92.3655], [70.7299, 92.2041]],
    dtype=np.float32)

def estimate_norm(lmk, image_size=112, mode='arcface'):
    assert lmk.shape == (5, 2)
    assert image_size % 112 == 0 or image_size % 128 == 0
    if image_size % 112 == 0:
        ratio = float(image_size) / 112.0
        diff_x = 0
    else:
        ratio = float(image_size) / 128.0
        diff_x = 8.0 * ratio
    dst = arcface_dst * ratio
    dst[:, 0] += diff_x
    tform = trans.SimilarityTransform()
    tform.estimate(lmk, dst)
    M = tform.params[0:2, :]
    return M


def norm_crop2(img, landmark, image_size=112, mode='arcface'):
    M = estimate_norm(landmark, image_size, mode)
    warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
    return warped, M


class Inswapper():
    def __init__(self, model_file=None, batch_size=32, providers=['CPUExecutionProvider']):
        self.model_file = model_file
        self.batch_size = batch_size

        model = onnx.load(self.model_file)
        graph = model.graph
        self.emap = numpy_helper.to_array(graph.initializer[-1])

        self.session_options = onnxruntime.SessionOptions()
        self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=providers)

    def forward(self, imgs, latents):
        preds = []
        for img, latent in zip(imgs, latents):
            img = img / 255
            pred = self.session.run(['output'], {'target': img, 'source': latent})[0]
            preds.append(pred)
        return preds  # without this return the method silently discarded its results

    def get(self, imgs, target_faces, source_faces):
        imgs = list(imgs)

        preds = [None] * len(imgs)
        matrs = [None] * len(imgs)

        for idx, (img, target_face, source_face) in enumerate(zip(imgs, target_faces, source_faces)):
            matrix, blob, latent = self.prepare_data(img, target_face, source_face)
            pred = self.session.run(['output'], {'target': blob, 'source': latent})[0]
            pred = pred.transpose((0, 2, 3, 1))[0]
            pred = np.clip(255 * pred, 0, 255).astype(np.uint8)[:, :, ::-1]

            preds[idx] = pred
            matrs[idx] = matrix

        return (preds, matrs)

    def prepare_data(self, img, target_face, source_face):
        if isinstance(img, str):
            img = cv2.imread(img)

        aligned_img, matrix = norm_crop2(img, target_face.kps, 128)

        blob = cv2.dnn.blobFromImage(aligned_img, 1.0 / 255, (128, 128), (0., 0., 0.), swapRB=True)

        latent = source_face.normed_embedding.reshape((1, -1))
        latent = np.dot(latent, self.emap)
        latent /= np.linalg.norm(latent)

        return (matrix, blob, latent)

    def batch_forward(self, img_list, target_f_list, source_f_list):
        num_samples = len(img_list)
        num_batches = (num_samples + self.batch_size - 1) // self.batch_size

        for i in tqdm(range(num_batches), desc="Generating face"):
            start_idx = i * self.batch_size
            end_idx = min((i + 1) * self.batch_size, num_samples)

            batch_img = img_list[start_idx:end_idx]
            batch_target_f = target_f_list[start_idx:end_idx]
            batch_source_f = source_f_list[start_idx:end_idx]

            batch_pred, batch_matr = self.get(batch_img, batch_target_f, batch_source_f)

            yield batch_pred, batch_matr


def paste_to_whole(foreground, background, matrix, mask=None, crop_mask=(0, 0, 0, 0), blur_amount=0.1, erode_amount=0.15, blend_method='linear'):
    inv_matrix = cv2.invertAffineTransform(matrix)
    fg_shape = foreground.shape[:2]
    bg_shape = (background.shape[1], background.shape[0])
    foreground = cv2.warpAffine(foreground, inv_matrix, bg_shape, borderValue=0.0)

    if mask is None:
        mask = np.full(fg_shape, 1., dtype=np.float32)
        mask = mask_crop(mask, crop_mask)
        mask = cv2.warpAffine(mask, inv_matrix, bg_shape, borderValue=0.0)
    else:
        assert fg_shape == mask.shape[:2], "foreground & mask shape mismatch!"
        mask = mask_crop(mask, crop_mask).astype('float32')
        mask = cv2.warpAffine(mask, inv_matrix, (background.shape[1], background.shape[0]), borderValue=0.0)

    _mask = mask.copy()
    _mask[_mask > 0.05] = 1.
    non_zero_points = cv2.findNonZero(_mask)
    _, _, w, h = cv2.boundingRect(non_zero_points)
    mask_size = int(np.sqrt(w * h))

    if erode_amount > 0:
        kernel_size = max(int(mask_size * erode_amount), 1)
        structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
        mask = cv2.erode(mask, structuring_element)

    if blur_amount > 0:
        kernel_size = max(int(mask_size * blur_amount), 3)
        if kernel_size % 2 == 0:
            kernel_size += 1
        mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)

    mask = np.tile(np.expand_dims(mask, axis=-1), (1, 1, 3))

    if blend_method == 'laplacian':
        composite_image = laplacian_blending(foreground, background, mask.clip(0, 1), num_levels=4)
    else:
        composite_image = mask * foreground + (1 - mask) * background

    return composite_image.astype("uint8").clip(0, 255)

#### FACE_SWAPPER.PY CODE END ###

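# Usage sketch of how the pieces above compose: align (inside prepare_data),
# swap through the ONNX session, then warp back with paste_to_whole, which
# inverts the alignment matrix. The frame path is hypothetical; target_face
# and source_face are assumed to come from the face analyser.
def _swap_round_trip_example(swapper, target_face, source_face):
    frame = cv2.imread("frame_0.jpg")  # hypothetical frame file
    preds, matrs = swapper.get([frame], [target_face], [source_face])
    # preds[0] is a 128x128 BGR face; dividing the affine matrix by 0.25
    # (i.e. scaling by 4, as app.py does with `m /= 0.25`) makes it valid
    # for the 512x512 upscaled crop.
    out = paste_to_whole(
        cv2.resize(preds[0], (512, 512)), frame, matrs[0] / 0.25,
        blend_method="linear",
    )
    cv2.imwrite("frame_0_swapped.jpg", out)
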
#### FACE_ANALYSER.PY CODE START ###

import os
import cv2
import numpy as np
from tqdm import tqdm
from utils import scale_bbox_from_center

detect_conditions = [
    "best detection",
    "left most",
    "right most",
    "top most",
    "bottom most",
    "middle",
    "biggest",
    "smallest",
]

swap_options_list = [
    "All Face",
    "Specific Face",
    "Age less than",
    "Age greater than",
    "All Male",
    "All Female",
    "Left Most",
    "Right Most",
    "Top Most",
    "Bottom Most",
    "Middle",
    "Biggest",
    "Smallest",
]

def get_single_face(faces, method="best detection"):
    total_faces = len(faces)
    if total_faces == 1:
        return faces[0]

    print(f"{total_faces} faces detected. Using {method} face.")
    if method == "best detection":
        return sorted(faces, key=lambda face: face["det_score"])[-1]
    elif method == "left most":
        return sorted(faces, key=lambda face: face["bbox"][0])[0]
    elif method == "right most":
        return sorted(faces, key=lambda face: face["bbox"][0])[-1]
    elif method == "top most":
        return sorted(faces, key=lambda face: face["bbox"][1])[0]
    elif method == "bottom most":
        return sorted(faces, key=lambda face: face["bbox"][1])[-1]
    elif method == "middle":
        return sorted(faces, key=lambda face: (
            (face["bbox"][0] + face["bbox"][2]) / 2 - 0.5) ** 2 +
            ((face["bbox"][1] + face["bbox"][3]) / 2 - 0.5) ** 2)[len(faces) // 2]
    elif method == "biggest":
        return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[-1]
    elif method == "smallest":
        return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[0]


def analyse_face(image, model, return_single_face=True, detect_condition="best detection", scale=1.0):
    faces = model.get(image)
    if scale != 1:  # landmark-scale
        for i, face in enumerate(faces):
            landmark = face['kps']
            center = np.mean(landmark, axis=0)
            landmark = center + (landmark - center) * scale
            faces[i]['kps'] = landmark

    if not return_single_face:
        return faces

    return get_single_face(faces, method=detect_condition)


def cosine_distance(a, b):
    a /= np.linalg.norm(a)
    b /= np.linalg.norm(b)
    return 1 - np.dot(a, b)


def get_analysed_data(face_analyser, image_sequence, source_data, swap_condition="All face", detect_condition="left most", scale=1.0):
    if swap_condition != "Specific Face":
        source_path, age = source_data
        source_image = cv2.imread(source_path)
        analysed_source = analyse_face(source_image, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
    else:
        analysed_source_specifics = []
        source_specifics, threshold = source_data
        for source, specific in zip(*source_specifics):
            if source is None or specific is None:
                continue
            analysed_source = analyse_face(source, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
            analysed_specific = analyse_face(specific, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
            analysed_source_specifics.append([analysed_source, analysed_specific])

    analysed_target_list = []
    analysed_source_list = []
    whole_frame_eql_list = []
    num_faces_per_frame = []

    total_frames = len(image_sequence)
    curr_idx = 0
    for curr_idx, frame_path in tqdm(enumerate(image_sequence), total=total_frames, desc="Analysing face data"):
        frame = cv2.imread(frame_path)
        analysed_faces = analyse_face(frame, face_analyser, return_single_face=False, detect_condition=detect_condition, scale=scale)

        n_faces = 0
        for analysed_face in analysed_faces:
            if swap_condition == "All Face":
                analysed_target_list.append(analysed_face)
                analysed_source_list.append(analysed_source)
                whole_frame_eql_list.append(frame_path)
                n_faces += 1
            elif swap_condition == "Age less than" and analysed_face["age"] < age:
                analysed_target_list.append(analysed_face)
                analysed_source_list.append(analysed_source)
                whole_frame_eql_list.append(frame_path)
                n_faces += 1
            elif swap_condition == "Age greater than" and analysed_face["age"] > age:
                analysed_target_list.append(analysed_face)
                analysed_source_list.append(analysed_source)
                whole_frame_eql_list.append(frame_path)
                n_faces += 1
            elif swap_condition == "All Male" and analysed_face["gender"] == 1:
                analysed_target_list.append(analysed_face)
                analysed_source_list.append(analysed_source)
                whole_frame_eql_list.append(frame_path)
                n_faces += 1
            elif swap_condition == "All Female" and analysed_face["gender"] == 0:
                analysed_target_list.append(analysed_face)
                analysed_source_list.append(analysed_source)
                whole_frame_eql_list.append(frame_path)
                n_faces += 1
            elif swap_condition == "Specific Face":
                for analysed_source, analysed_specific in analysed_source_specifics:
                    distance = cosine_distance(analysed_specific["embedding"], analysed_face["embedding"])
                    if distance < threshold:
                        analysed_target_list.append(analysed_face)
                        analysed_source_list.append(analysed_source)
                        whole_frame_eql_list.append(frame_path)
                        n_faces += 1

        if swap_condition == "Left Most":
            analysed_face = get_single_face(analysed_faces, method="left most")
            analysed_target_list.append(analysed_face)
            analysed_source_list.append(analysed_source)
            whole_frame_eql_list.append(frame_path)
            n_faces += 1

        elif swap_condition == "Right Most":
            analysed_face = get_single_face(analysed_faces, method="right most")
            analysed_target_list.append(analysed_face)
            analysed_source_list.append(analysed_source)
            whole_frame_eql_list.append(frame_path)
            n_faces += 1

        elif swap_condition == "Top Most":
            analysed_face = get_single_face(analysed_faces, method="top most")
            analysed_target_list.append(analysed_face)
            analysed_source_list.append(analysed_source)
            whole_frame_eql_list.append(frame_path)
            n_faces += 1

        elif swap_condition == "Bottom Most":
            analysed_face = get_single_face(analysed_faces, method="bottom most")
            analysed_target_list.append(analysed_face)
            analysed_source_list.append(analysed_source)
            whole_frame_eql_list.append(frame_path)
            n_faces += 1

        elif swap_condition == "Middle":
            analysed_face = get_single_face(analysed_faces, method="middle")
            analysed_target_list.append(analysed_face)
            analysed_source_list.append(analysed_source)
            whole_frame_eql_list.append(frame_path)
            n_faces += 1

        elif swap_condition == "Biggest":
            analysed_face = get_single_face(analysed_faces, method="biggest")
            analysed_target_list.append(analysed_face)
            analysed_source_list.append(analysed_source)
            whole_frame_eql_list.append(frame_path)
            n_faces += 1

        elif swap_condition == "Smallest":
            analysed_face = get_single_face(analysed_faces, method="smallest")
            analysed_target_list.append(analysed_face)
            analysed_source_list.append(analysed_source)
            whole_frame_eql_list.append(frame_path)
            n_faces += 1

        num_faces_per_frame.append(n_faces)

    return analysed_target_list, analysed_source_list, whole_frame_eql_list, num_faces_per_frame


#### FACE_ANALYSER.PY CODE END ###

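# Sketch of how "Specific Face" matching uses cosine_distance above: two
# ArcFace embeddings count as the same identity when their distance is under
# the UI threshold. Note cosine_distance normalises its inputs in place,
# hence the copies. The 0.6 threshold is illustrative, not from the source.
def _identity_match_example(emb_a, emb_b, threshold=0.6):
    return cosine_distance(emb_a.copy(), emb_b.copy()) < threshold
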
#### UTILS.PY CODE START ###


import os
import cv2
import time
import glob
import shutil
import platform
import datetime
import subprocess
import numpy as np
from threading import Thread
from moviepy.editor import VideoFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip


logo_image = cv2.imread("./assets/images/logo.png", cv2.IMREAD_UNCHANGED)


quality_types = ["poor", "low", "medium", "high", "best"]


bitrate_quality_by_resolution = {
    240: {"poor": "300k", "low": "500k", "medium": "800k", "high": "1000k", "best": "1200k"},
    360: {"poor": "500k", "low": "800k", "medium": "1200k", "high": "1500k", "best": "2000k"},
    480: {"poor": "800k", "low": "1200k", "medium": "2000k", "high": "2500k", "best": "3000k"},
    720: {"poor": "1500k", "low": "2500k", "medium": "4000k", "high": "5000k", "best": "6000k"},
    1080: {"poor": "2500k", "low": "4000k", "medium": "6000k", "high": "7000k", "best": "8000k"},
    1440: {"poor": "4000k", "low": "6000k", "medium": "8000k", "high": "10000k", "best": "12000k"},
    2160: {"poor": "8000k", "low": "10000k", "medium": "12000k", "high": "15000k", "best": "20000k"}
}


crf_quality_by_resolution = {
    240: {"poor": 45, "low": 35, "medium": 28, "high": 23, "best": 20},
    360: {"poor": 35, "low": 28, "medium": 23, "high": 20, "best": 18},
    480: {"poor": 28, "low": 23, "medium": 20, "high": 18, "best": 16},
    720: {"poor": 23, "low": 20, "medium": 18, "high": 16, "best": 14},
    1080: {"poor": 20, "low": 18, "medium": 16, "high": 14, "best": 12},
    1440: {"poor": 18, "low": 16, "medium": 14, "high": 12, "best": 10},
    2160: {"poor": 16, "low": 14, "medium": 12, "high": 10, "best": 8}
}


def get_bitrate_for_resolution(resolution, quality):
    available_resolutions = list(bitrate_quality_by_resolution.keys())
    closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
    return bitrate_quality_by_resolution[closest_resolution][quality]


def get_crf_for_resolution(resolution, quality):
    available_resolutions = list(crf_quality_by_resolution.keys())
    closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
    return crf_quality_by_resolution[closest_resolution][quality]


def get_video_bitrate(video_file):
    ffprobe_cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries',
                   'stream=bit_rate', '-of', 'default=noprint_wrappers=1:nokey=1', video_file]
    result = subprocess.run(ffprobe_cmd, stdout=subprocess.PIPE)
    kbps = max(int(result.stdout) // 1000, 10)
    return str(kbps) + 'k'


def trim_video(video_path, output_path, start_frame, stop_frame):
    video_name, _ = os.path.splitext(os.path.basename(video_path))
    trimmed_video_filename = video_name + "_trimmed" + ".mp4"
    temp_path = os.path.join(output_path, "trim")
    os.makedirs(temp_path, exist_ok=True)
    trimmed_video_file_path = os.path.join(temp_path, trimmed_video_filename)

    video = VideoFileClip(video_path, fps_source="fps")
    fps = video.fps
    start_time = start_frame / fps
    duration = (stop_frame - start_frame) / fps

    bitrate = get_bitrate_for_resolution(min(*video.size), "high")

    trimmed_video = video.subclip(start_time, start_time + duration)
    trimmed_video.write_videofile(
        trimmed_video_file_path, codec="libx264", audio_codec="aac", bitrate=bitrate,
    )
    trimmed_video.close()
    video.close()

    return trimmed_video_file_path


def open_directory(path=None):
    if path is None:
        return
    try:
        os.startfile(path)  # Windows
    except:
        subprocess.Popen(["xdg-open", path])  # Linux desktops


class StreamerThread(object):
    def __init__(self, src=0):
        self.capture = cv2.VideoCapture(src)
        self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        self.FPS = 1 / 30
        self.FPS_MS = int(self.FPS * 1000)
        self.thread = None
        self.stopped = False
        self.frame = None

    def start(self):
        self.thread = Thread(target=self.update, args=())
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        self.stopped = True
        self.thread.join()
        print("stopped")

    def update(self):
        while not self.stopped:
            if self.capture.isOpened():
                (self.status, self.frame) = self.capture.read()
            time.sleep(self.FPS)


class ProcessBar:
    def __init__(self, bar_length, total, before="⬛", after="🟨"):
        self.bar_length = bar_length
        self.total = total
        self.before = before
        self.after = after
        self.bar = [self.before] * bar_length
        self.start_time = time.time()

    def get(self, index):
        total = self.total
        elapsed_time = time.time() - self.start_time
        average_time_per_iteration = elapsed_time / (index + 1)
        remaining_iterations = total - (index + 1)
        estimated_remaining_time = remaining_iterations * average_time_per_iteration

        self.bar[int(index / total * self.bar_length)] = self.after
        info_text = f"({index+1}/{total}) {''.join(self.bar)} "
        info_text += f"(ETR: {int(estimated_remaining_time // 60)} min {int(estimated_remaining_time % 60)} sec)"
        return info_text


def add_logo_to_image(img, logo=logo_image):
    logo_size = int(img.shape[1] * 0.1)
    logo = cv2.resize(logo, (logo_size, logo_size))
    if logo.shape[2] == 4:
        alpha = logo[:, :, 3]
    else:
        alpha = np.ones_like(logo[:, :, 0]) * 255
    padding = int(logo_size * 0.1)
    roi = img.shape[0] - logo_size - padding, img.shape[1] - logo_size - padding
    for c in range(0, 3):
        img[roi[0] : roi[0] + logo_size, roi[1] : roi[1] + logo_size, c] = (
            alpha / 255.0
        ) * logo[:, :, c] + (1 - alpha / 255.0) * img[
            roi[0] : roi[0] + logo_size, roi[1] : roi[1] + logo_size, c
        ]
    return img


def split_list_by_lengths(data, length_list):
    split_data = []
    start_idx = 0
    for length in length_list:
        end_idx = start_idx + length
        sublist = data[start_idx:end_idx]
        split_data.append(sublist)
        start_idx = end_idx
    return split_data


def merge_img_sequence_from_ref(ref_video_path, image_sequence, output_file_name):
    video_clip = VideoFileClip(ref_video_path, fps_source="fps")
    fps = video_clip.fps
    duration = video_clip.duration
    total_frames = video_clip.reader.nframes
    audio_clip = video_clip.audio if video_clip.audio is not None else None
    edited_video_clip = ImageSequenceClip(image_sequence, fps=fps)

    if audio_clip is not None:
        edited_video_clip = edited_video_clip.set_audio(audio_clip)

    bitrate = get_bitrate_for_resolution(min(*edited_video_clip.size), "high")

    edited_video_clip.set_duration(duration).write_videofile(
        output_file_name, codec="libx264", bitrate=bitrate,
    )
    edited_video_clip.close()
    video_clip.close()


def scale_bbox_from_center(bbox, scale_width, scale_height, image_width, image_height):
    # Extract the coordinates of the bbox
    x1, y1, x2, y2 = bbox

    # Calculate the center point of the bbox
    center_x = (x1 + x2) / 2
    center_y = (y1 + y2) / 2

    # Calculate the new width and height of the bbox based on the scaling factors
    width = x2 - x1
    height = y2 - y1
    new_width = width * scale_width
    new_height = height * scale_height

    # Calculate the new coordinates of the bbox, considering the image boundaries
    new_x1 = center_x - new_width / 2
    new_y1 = center_y - new_height / 2
    new_x2 = center_x + new_width / 2
    new_y2 = center_y + new_height / 2

    # Adjust the coordinates to ensure the bbox remains within the image boundaries
    new_x1 = max(0, new_x1)
    new_y1 = max(0, new_y1)
    new_x2 = min(image_width - 1, new_x2)
    new_y2 = min(image_height - 1, new_y2)

    # Return the scaled bbox coordinates
    scaled_bbox = [new_x1, new_y1, new_x2, new_y2]
    return scaled_bbox


def laplacian_blending(A, B, m, num_levels=7):
    assert A.shape == B.shape
    assert B.shape == m.shape
    height = m.shape[0]
    width = m.shape[1]
    size_list = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192])
    size = size_list[np.where(size_list > max(height, width))][0]
    GA = np.zeros((size, size, 3), dtype=np.float32)
    GA[:height, :width, :] = A
    GB = np.zeros((size, size, 3), dtype=np.float32)
    GB[:height, :width, :] = B
    GM = np.zeros((size, size, 3), dtype=np.float32)
    GM[:height, :width, :] = m
    gpA = [GA]
    gpB = [GB]
    gpM = [GM]
    for i in range(num_levels):
        GA = cv2.pyrDown(GA)
        GB = cv2.pyrDown(GB)
        GM = cv2.pyrDown(GM)
        gpA.append(np.float32(GA))
        gpB.append(np.float32(GB))
        gpM.append(np.float32(GM))
    lpA = [gpA[num_levels-1]]
    lpB = [gpB[num_levels-1]]
    gpMr = [gpM[num_levels-1]]
    for i in range(num_levels-1, 0, -1):
        LA = np.subtract(gpA[i-1], cv2.pyrUp(gpA[i]))
        LB = np.subtract(gpB[i-1], cv2.pyrUp(gpB[i]))
        lpA.append(LA)
        lpB.append(LB)
        gpMr.append(gpM[i-1])
    LS = []
    for la, lb, gm in zip(lpA, lpB, gpMr):
        ls = la * gm + lb * (1.0 - gm)
        LS.append(ls)
    ls_ = LS[0]
    for i in range(1, num_levels):
        ls_ = cv2.pyrUp(ls_)
        ls_ = cv2.add(ls_, LS[i])
    ls_ = ls_[:height, :width, :]
    # ls_ = (ls_ - np.min(ls_)) * (255.0 / (np.max(ls_) - np.min(ls_)))
    return ls_.clip(0, 255)


def mask_crop(mask, crop):
    top, bottom, left, right = crop
    shape = mask.shape
    top = int(top)
    bottom = int(bottom)
    if top + bottom < shape[0]:  # rows are the first axis
        if top > 0: mask[:top, :] = 0
        if bottom > 0: mask[-bottom:, :] = 0

    left = int(left)
    right = int(right)
    if left + right < shape[1]:  # columns are the second axis
        if left > 0: mask[:, :left] = 0
        if right > 0: mask[:, -right:] = 0

    return mask

def create_image_grid(images, size=128):
    num_images = len(images)
    num_cols = int(np.ceil(np.sqrt(num_images)))
    num_rows = int(np.ceil(num_images / num_cols))
    grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)

    for i, image in enumerate(images):
        row_idx = (i // num_cols) * size
        col_idx = (i % num_cols) * size
        image = cv2.resize(image.copy(), (size, size))
        if image.dtype != np.uint8:
            image = (image.astype('float32') * 255).astype('uint8')
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        grid[row_idx:row_idx + size, col_idx:col_idx + size] = image

    return grid


#### UTILS.PY CODE END ###

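# Toy demonstration of split_list_by_lengths, which app.py uses below to
# regroup the flat per-face prediction list back into per-frame buckets via
# num_faces_per_frame from the analyser:
def _split_example():
    flat_preds = ["a", "b", "c", "d", "e", "f"]   # six faces across three frames
    num_faces_per_frame = [2, 1, 3]
    assert split_list_by_lengths(flat_preds, num_faces_per_frame) == [
        ["a", "b"], ["c"], ["d", "e", "f"]
    ]
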
#### APP.PY CODE START ###

import os
import spaces
import cv2
import glob
import time
import torch
import shutil
import argparse
import platform
import datetime
import subprocess
import insightface
import onnxruntime
import numpy as np
import gradio as gr
import threading
import queue
from tqdm import tqdm
import concurrent.futures
from moviepy.editor import VideoFileClip

from nsfw_checker import NSFWChecker
from face_swapper import Inswapper, paste_to_whole
from face_analyser import detect_conditions, get_analysed_data, swap_options_list
from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list
from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations
from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid

## ------------------------------ USER ARGS ------------------------------

parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
parser.add_argument("--batch_size", help="Gpu batch size", default=32)
parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
parser.add_argument(
    "--colab", action="store_true", help="Enable colab mode", default=False
)
user_args = parser.parse_args()

## ------------------------------ DEFAULTS ------------------------------

USE_COLAB = user_args.colab
USE_CUDA = user_args.cuda
DEF_OUTPUT_PATH = user_args.out_dir
BATCH_SIZE = int(user_args.batch_size)
WORKSPACE = None
OUTPUT_FILE = None
CURRENT_FRAME = None
STREAMER = None
DETECT_CONDITION = "best detection"
DETECT_SIZE = 640
DETECT_THRESH = 0.6
NUM_OF_SRC_SPECIFIC = 10
MASK_INCLUDE = [
    "Skin",
    "R-Eyebrow",
    "L-Eyebrow",
    "L-Eye",
    "R-Eye",
    "Nose",
    "Mouth",
    "L-Lip",
    "U-Lip"
]
MASK_SOFT_KERNEL = 17
MASK_SOFT_ITERATIONS = 10
MASK_BLUR_AMOUNT = 0.1
MASK_ERODE_AMOUNT = 0.15

FACE_SWAPPER = None
FACE_ANALYSER = None
FACE_ENHANCER = None
FACE_PARSER = None
NSFW_DETECTOR = None
FACE_ENHANCER_LIST = ["NONE"]
FACE_ENHANCER_LIST.extend(get_available_enhancer_names())
FACE_ENHANCER_LIST.extend(cv2_interpolations)

## ------------------------------ SET EXECUTION PROVIDER ------------------------------
# Note: Non CUDA users may change settings here

PROVIDER = ["CPUExecutionProvider"]

if USE_CUDA:
    available_providers = onnxruntime.get_available_providers()
    if "CUDAExecutionProvider" in available_providers:
        print("\n********** Running on CUDA **********\n")
        PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        USE_CUDA = False
        print("\n********** CUDA unavailable running on CPU **********\n")
else:
    USE_CUDA = False
    print("\n********** Running on CPU **********\n")

device = "cuda" if USE_CUDA else "cpu"
EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None

## ------------------------------ LOAD MODELS ------------------------------

def load_face_analyser_model(name="buffalo_l"):
    global FACE_ANALYSER
    if FACE_ANALYSER is None:
        FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
        FACE_ANALYSER.prepare(
            ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
        )


def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
    global FACE_SWAPPER
    if FACE_SWAPPER is None:
        batch = int(BATCH_SIZE) if device == "cuda" else 1
        FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)


def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
    global FACE_PARSER
    if FACE_PARSER is None:
        FACE_PARSER = init_parsing_model(path, device=device)

def load_nsfw_detector_model(path="./assets/pretrained_models/open-nsfw.onnx"):
    global NSFW_DETECTOR
    if NSFW_DETECTOR is None:
        NSFW_DETECTOR = NSFWChecker(model_path=path, providers=PROVIDER)


load_face_analyser_model()
load_face_swapper_model()

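# Each load_* helper above is a guarded singleton: it constructs its model
# only while the corresponding global is still None, so process() can call
# them on every request and pay the construction cost once. Resetting the
# global to None (as the commit does before grabbing the GPU) forces a
# rebuild with the current PROVIDER/device. The same guard in isolation,
# with `builder` standing in for a real, expensive constructor:
_EXAMPLE_MODEL = None

def _load_example_model(builder=dict):
    global _EXAMPLE_MODEL
    if _EXAMPLE_MODEL is None:  # no-op on every call after the first
        _EXAMPLE_MODEL = builder()
    return _EXAMPLE_MODEL
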
## ------------------------------ MAIN PROCESS ------------------------------


@spaces.GPU(duration=300, enable_queue=True)
def process(
    input_type,
    image_path,
    video_path,
    directory_path,
    source_path,
    output_path,
    output_name,
    keep_output_sequence,
    condition,
    age,
    distance,
    face_enhancer_name,
    enable_face_parser,
    mask_includes,
    mask_soft_kernel,
    mask_soft_iterations,
    blur_amount,
    erode_amount,
    face_scale,
    enable_laplacian_blend,
    crop_top,
    crop_bott,
    crop_left,
    crop_right,
    *specifics,
):
    global WORKSPACE
    global OUTPUT_FILE
    global PREVIEW
    WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None

    ## ------------------------------ GUI UPDATE FUNC ------------------------------

    def ui_before():
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(visible=False),
        )

    def ui_after():
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(visible=False),
        )

    def ui_after_vid():
        return (
            gr.update(visible=False),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(value=OUTPUT_FILE, visible=True),
        )

    start_time = time.time()
    total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
    get_finsh_text = lambda start_time: f"✔️ Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."

    ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------

    yield "### \n ⌛ Loading NSFW detector model...", *ui_before()
    load_nsfw_detector_model()

    yield "### \n ⌛ Loading face analyser model...", *ui_before()
    load_face_analyser_model()

    yield "### \n ⌛ Loading face swapper model...", *ui_before()
    load_face_swapper_model()

    if face_enhancer_name != "NONE":
        if face_enhancer_name not in cv2_interpolations:
            yield f"### \n ⌛ Loading {face_enhancer_name} model...", *ui_before()
        FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
    else:
        FACE_ENHANCER = None

    if enable_face_parser:
        yield "### \n ⌛ Loading face parsing model...", *ui_before()
        load_face_parser_model()

    includes = mask_regions_to_list(mask_includes)
    specifics = list(specifics)
    half = len(specifics) // 2
    sources = specifics[:half]
    specifics = specifics[half:]
    if crop_top > crop_bott:
        crop_top, crop_bott = crop_bott, crop_top
    if crop_left > crop_right:
        crop_left, crop_right = crop_right, crop_left
    crop_mask = (crop_top, 511-crop_bott, crop_left, 511-crop_right)

    def swap_process(image_sequence):
        ## ------------------------------ CONTENT CHECK ------------------------------

        yield "### \n ⌛ Checking contents...", *ui_before()
        nsfw = NSFW_DETECTOR.is_nsfw(image_sequence)
        if nsfw:
            message = "NSFW Content detected !!!"
            yield f"### \n 🔞 {message}", *ui_before()
            assert not nsfw, message
            return False
        EMPTY_CACHE()

        ## ------------------------------ ANALYSE FACE ------------------------------

        yield "### \n ⌛ Analysing face data...", *ui_before()
        if condition != "Specific Face":
            source_data = source_path, age
        else:
            source_data = ((sources, specifics), distance)
        analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
            FACE_ANALYSER,
            image_sequence,
            source_data,
            swap_condition=condition,
            detect_condition=DETECT_CONDITION,
            scale=face_scale
        )

        ## ------------------------------ SWAP FUNC ------------------------------

        yield "### \n ⌛ Generating faces...", *ui_before()
        preds = []
        matrs = []
        count = 0
        global PREVIEW
        for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
            preds.extend(batch_pred)
            matrs.extend(batch_matr)
            EMPTY_CACHE()
            count += 1

            if USE_CUDA:
                image_grid = create_image_grid(batch_pred, size=128)
                PREVIEW = image_grid[:, :, ::-1]
                yield f"### \n ⌛ Generating face Batch {count}", *ui_before()

        ## ------------------------------ FACE ENHANCEMENT ------------------------------

        generated_len = len(preds)
        if face_enhancer_name != "NONE":
            yield f"### \n ⌛ Upscaling faces with {face_enhancer_name}...", *ui_before()
            for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
                enhancer_model, enhancer_model_runner = FACE_ENHANCER
                pred = enhancer_model_runner(pred, enhancer_model)
                preds[idx] = cv2.resize(pred, (512, 512))
        EMPTY_CACHE()

        ## ------------------------------ FACE PARSING ------------------------------

        if enable_face_parser:
            yield "### \n ⌛ Face-parsing mask...", *ui_before()
            masks = []
            count = 0
            for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(mask_soft_iterations)):
                masks.append(batch_mask)
                EMPTY_CACHE()
                count += 1

                if len(batch_mask) > 1:
                    image_grid = create_image_grid(batch_mask, size=128)
                    PREVIEW = image_grid[:, :, ::-1]
                    yield f"### \n ⌛ Face parsing Batch {count}", *ui_before()
            masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
        else:
            masks = [None] * generated_len

        ## ------------------------------ SPLIT LIST ------------------------------

        split_preds = split_list_by_lengths(preds, num_faces_per_frame)
        del preds
        split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
        del matrs
        split_masks = split_list_by_lengths(masks, num_faces_per_frame)
        del masks

        ## ------------------------------ PASTE-BACK ------------------------------

        yield "### \n ⌛ Pasting back...", *ui_before()
        def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
            whole_img_path = frame_img
            whole_img = cv2.imread(whole_img_path)
            blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
            for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
                p = cv2.resize(p, (512, 512))
                mask = cv2.resize(mask, (512, 512)) if mask is not None else None
                m /= 0.25
                whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
            cv2.imwrite(whole_img_path, whole_img)

        def concurrent_post_process(image_sequence, *args):
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = []
                for idx, frame_img in enumerate(image_sequence):
                    future = executor.submit(post_process, idx, frame_img, *args)
                    futures.append(future)

                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
                    result = future.result()

        concurrent_post_process(
            image_sequence,
            split_preds,
            split_matrs,
            split_masks,
            enable_laplacian_blend,
            crop_mask,
            blur_amount,
            erode_amount
        )


    ## ------------------------------ IMAGE ------------------------------

    if input_type == "Image":
        target = cv2.imread(image_path)
        output_file = os.path.join(output_path, output_name + ".png")
        cv2.imwrite(output_file, target)

        for info_update in swap_process([output_file]):
            yield info_update

        OUTPUT_FILE = output_file
        WORKSPACE = output_path
        PREVIEW = cv2.imread(output_file)[:, :, ::-1]

        yield get_finsh_text(start_time), *ui_after()

    ## ------------------------------ VIDEO ------------------------------

    elif input_type == "Video":
        temp_path = os.path.join(output_path, output_name, "sequence")
        os.makedirs(temp_path, exist_ok=True)

        yield "### \n ⌛ Extracting video frames...", *ui_before()
        image_sequence = []
        cap = cv2.VideoCapture(video_path)
        curr_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret: break
            frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
            cv2.imwrite(frame_path, frame)
            image_sequence.append(frame_path)
            curr_idx += 1
        cap.release()
        cv2.destroyAllWindows()

        for info_update in swap_process(image_sequence):
            yield info_update

        yield "### \n ⌛ Merging sequence...", *ui_before()
        output_video_path = os.path.join(output_path, output_name + ".mp4")
        merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)

        if os.path.exists(temp_path) and not keep_output_sequence:
            yield "### \n ⌛ Removing temporary files...", *ui_before()
            shutil.rmtree(temp_path)

        WORKSPACE = output_path
        OUTPUT_FILE = output_video_path

        yield get_finsh_text(start_time), *ui_after_vid()

    ## ------------------------------ DIRECTORY ------------------------------

    elif input_type == "Directory":
        extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
        temp_path = os.path.join(output_path, output_name)
        if os.path.exists(temp_path):
            shutil.rmtree(temp_path)
        os.mkdir(temp_path)

        file_paths = []
        for file_path in glob.glob(os.path.join(directory_path, "*")):
            if any(file_path.lower().endswith(ext) for ext in extensions):
                img = cv2.imread(file_path)
                new_file_path = os.path.join(temp_path, os.path.basename(file_path))
                cv2.imwrite(new_file_path, img)
                file_paths.append(new_file_path)

        for info_update in swap_process(file_paths):
            yield info_update

        PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]
        WORKSPACE = temp_path
        OUTPUT_FILE = file_paths[-1]

        yield get_finsh_text(start_time), *ui_after()

    ## ------------------------------ STREAM ------------------------------

    elif input_type == "Stream":
        pass


## ------------------------------ GRADIO FUNC ------------------------------


def update_radio(value):
    if value == "Image":
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    elif value == "Video":
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
        )
    elif value == "Directory":
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
        )
    elif value == "Stream":
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
        )


def swap_option_changed(value):
    if value.startswith("Age"):
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )
    elif value == "Specific Face":
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
        )
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)


def video_changed(video_path):
    sliders_update = gr.Slider.update
    button_update = gr.Button.update
    number_update = gr.Number.update

    if video_path is None:
        return (
            sliders_update(minimum=0, maximum=0, value=0),
            sliders_update(minimum=1, maximum=1, value=1),
            number_update(value=1),
        )
    try:
        clip = VideoFileClip(video_path)
        fps = clip.fps
        total_frames = clip.reader.nframes
        clip.close()
        return (
            sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
            sliders_update(
                minimum=0, maximum=total_frames, value=total_frames, interactive=True
            ),
            number_update(value=fps),
        )
    except:
        return (
            sliders_update(value=0),
            sliders_update(value=0),
            number_update(value=1),
        )


def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
    yield "### \n ⌛ Applying new values..."
    global FACE_ANALYSER
    global DETECT_CONDITION
    DETECT_CONDITION = detect_condition
    FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
    FACE_ANALYSER.prepare(
        ctx_id=0,
        det_size=(int(detection_size), int(detection_size)),
        det_thresh=float(detection_threshold),
    )
    yield f"### \n ✔️ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"


def stop_running():
    global STREAMER
    if hasattr(STREAMER, "stop"):
        STREAMER.stop()
        STREAMER = None
    return "Cancelled"


def slider_changed(show_frame, video_path, frame_index):
    if not show_frame:
        return None, None
    if video_path is None:
        return None, None
    clip = VideoFileClip(video_path)
    frame = clip.get_frame(frame_index / clip.fps)
    frame_array = np.array(frame)
    clip.close()
    return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
        visible=False
    )


def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
    yield video_path, f"### \n ⌛ Trimming video frame {start_frame} to {stop_frame}..."
    try:
        output_path = os.path.join(output_path, output_name)
        trimmed_video = trim_video(video_path, output_path, start_frame, stop_frame)
        yield trimmed_video, "### \n ✔️ Video trimmed and reloaded."
    except Exception as e:
        print(e)
        yield video_path, "### \n ❌ Video trimming failed. See console for more info."

1305
+ ## ------------------------------ GRADIO GUI ------------------------------
1306
+
1307
+ css = """
1308
+ footer{display:none !important}
1309
+ """
1310
+
1311
+ with gr.Blocks(css=css) as interface:
+     gr.Markdown("# πŸ—Ώ Swap Mukham")
+     gr.Markdown("### Face swap app based on insightface inswapper.")
+     with gr.Row():
+         with gr.Row():
+             with gr.Column(scale=0.4):
+                 with gr.Tab("πŸ“„ Swap Condition"):
+                     swap_option = gr.Dropdown(
+                         swap_options_list,
+                         info="Choose which face or faces in the target image to swap.",
+                         multiselect=False,
+                         show_label=False,
+                         value=swap_options_list[0],
+                         interactive=True,
+                     )
+                     age = gr.Number(
+                         value=25, label="Value", interactive=True, visible=False
+                     )
+
+                 with gr.Tab("🎚️ Detection Settings"):
+                     detect_condition_dropdown = gr.Dropdown(
+                         detect_conditions,
+                         label="Condition",
+                         value=DETECT_CONDITION,
+                         interactive=True,
+                         info="This condition is only used when multiple faces are detected on the source or specific image.",
+                     )
+                     detection_size = gr.Number(
+                         label="Detection Size", value=DETECT_SIZE, interactive=True
+                     )
+                     detection_threshold = gr.Number(
+                         label="Detection Threshold",
+                         value=DETECT_THRESH,
+                         interactive=True,
+                     )
+                     apply_detection_settings = gr.Button("Apply settings")
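+                     # Detection Size is the square resolution the face detector
+                     # runs at (larger finds smaller faces but is slower); Detection
+                     # Threshold is the minimum confidence for a detection to count.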
+
+                 with gr.Tab("πŸ“€ Output Settings"):
+                     output_directory = gr.Text(
+                         label="Output Directory",
+                         value=DEF_OUTPUT_PATH,
+                         interactive=True,
+                     )
+                     output_name = gr.Text(
+                         label="Output Name", value="Result", interactive=True
+                     )
+                     keep_output_sequence = gr.Checkbox(
+                         label="Keep output sequence", value=False, interactive=True
+                     )
+
+                 with gr.Tab("πŸͺ„ Other Settings"):
+                     face_scale = gr.Slider(
+                         label="Face Scale",
+                         minimum=0,
+                         maximum=2,
+                         value=1,
+                         interactive=True,
+                     )
+
+                     face_enhancer_name = gr.Dropdown(
+                         FACE_ENHANCER_LIST,
+                         label="Face Enhancer",
+                         value="NONE",
+                         multiselect=False,
+                         interactive=True,
+                     )
+
+                     with gr.Accordion("Advanced Mask", open=False):
+                         enable_face_parser_mask = gr.Checkbox(
+                             label="Enable Face Parsing",
+                             value=False,
+                             interactive=True,
+                         )
+
+                         mask_include = gr.Dropdown(
+                             mask_regions.keys(),
+                             value=MASK_INCLUDE,
+                             multiselect=True,
+                             label="Include",
+                             interactive=True,
+                         )
+                         mask_soft_kernel = gr.Number(
+                             label="Soft Erode Kernel",
+                             value=MASK_SOFT_KERNEL,
+                             minimum=3,
+                             interactive=True,
+                             visible=False,
+                         )
+                         mask_soft_iterations = gr.Number(
+                             label="Soft Erode Iterations",
+                             value=MASK_SOFT_ITERATIONS,
+                             minimum=0,
+                             interactive=True,
+                         )
+
+                     with gr.Accordion("Crop Mask", open=False):
+                         crop_top = gr.Slider(label="Top", minimum=0, maximum=511, value=0, step=1, interactive=True)
+                         crop_bott = gr.Slider(label="Bottom", minimum=0, maximum=511, value=511, step=1, interactive=True)
+                         crop_left = gr.Slider(label="Left", minimum=0, maximum=511, value=0, step=1, interactive=True)
+                         crop_right = gr.Slider(label="Right", minimum=0, maximum=511, value=511, step=1, interactive=True)
+
+                     erode_amount = gr.Slider(
+                         label="Mask Erode",
+                         minimum=0,
+                         maximum=1,
+                         value=MASK_ERODE_AMOUNT,
+                         step=0.05,
+                         interactive=True,
+                     )
+
+                     blur_amount = gr.Slider(
+                         label="Mask Blur",
+                         minimum=0,
+                         maximum=1,
+                         value=MASK_BLUR_AMOUNT,
+                         step=0.05,
+                         interactive=True,
+                     )
+
+                     enable_laplacian_blend = gr.Checkbox(
+                         label="Laplacian Blending",
+                         value=True,
+                         interactive=True,
+                     )
+
+                 source_image_input = gr.Image(
+                     label="Source face", type="filepath", interactive=True
+                 )
+
+                 with gr.Group(visible=False) as specific_face:
+                     # Dynamically create one (source, specific) image pair per tab,
+                     # bound to the names src1..srcN / trg1..trgN via exec().
+                     for i in range(NUM_OF_SRC_SPECIFIC):
+                         idx = i + 1
+                         code = "\n"
+                         code += f"with gr.Tab(label='({idx})'):"
+                         code += "\n\twith gr.Row():"
+                         code += f"\n\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')"
+                         code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
+                         exec(code)
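+                     # Each generated block is equivalent to writing statically
+                     # (shown for idx == 1):
+                     #
+                     #     with gr.Tab(label='(1)'):
+                     #         with gr.Row():
+                     #             src1 = gr.Image(interactive=True, type='numpy', label='Source Face 1')
+                     #             trg1 = gr.Image(interactive=True, type='numpy', label='Specific Face 1')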
+
+                     distance_slider = gr.Slider(
+                         minimum=0,
+                         maximum=2,
+                         value=0.6,
+                         interactive=True,
+                         label="Distance",
+                         info="Lower values match only faces very similar to the target face; higher values also match less similar faces.",
+                     )
+
+                 with gr.Group():
+                     input_type = gr.Radio(
+                         ["Image", "Video"],
+                         label="Target Type",
+                         value="Image",
+                     )
+
+                 with gr.Group(visible=True) as input_image_group:
+                     image_input = gr.Image(
+                         label="Target Image", interactive=True, type="filepath"
+                     )
+
+                 with gr.Group(visible=False) as input_video_group:
+                     # NOTE: vid_widget is currently unused; the video input below
+                     # is always a gr.Video, even outside Colab.
+                     vid_widget = gr.Video if USE_COLAB else gr.Text
+                     video_input = gr.Video(
+                         label="Target Video", interactive=True
+                     )
+                     with gr.Accordion("βœ‚οΈ Trim video", open=False):
+                         with gr.Column():
+                             with gr.Row():
+                                 set_slider_range_btn = gr.Button(
+                                     "Set frame range", interactive=True
+                                 )
+                                 show_trim_preview_btn = gr.Checkbox(
+                                     label="Show frame when slider changes",
+                                     value=True,
+                                     interactive=True,
+                                 )
+
+                             video_fps = gr.Number(
+                                 value=30,
+                                 interactive=False,
+                                 label="Fps",
+                                 visible=False,
+                             )
+                             start_frame = gr.Slider(
+                                 minimum=0,
+                                 maximum=1,
+                                 value=0,
+                                 step=1,
+                                 interactive=True,
+                                 label="Start Frame",
+                                 info="",
+                             )
+                             end_frame = gr.Slider(
+                                 minimum=0,
+                                 maximum=1,
+                                 value=1,
+                                 step=1,
+                                 interactive=True,
+                                 label="End Frame",
+                                 info="",
+                             )
+                         trim_and_reload_btn = gr.Button(
+                             "Trim and Reload", interactive=True
+                         )
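+                     # Trim workflow: "Set frame range" fills fps and the frame
+                     # sliders via video_changed (wired up below); the sliders then
+                     # bound the window that "Trim and Reload" cuts.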
+
+                 with gr.Group(visible=False) as input_directory_group:
+                     direc_input = gr.Text(label="Path", interactive=True)
+
+             with gr.Column(scale=0.6):
+                 info = gr.Markdown(value="...")
+
+                 with gr.Row():
+                     swap_button = gr.Button("✨ Swap", variant="primary")
+                     cancel_button = gr.Button("β›” Cancel")
+
+                 preview_image = gr.Image(label="Output", interactive=False)
+                 preview_video = gr.Video(
+                     label="Output", interactive=False, visible=False
+                 )
+
+                 with gr.Row():
+                     output_directory_button = gr.Button(
+                         "πŸ“‚", interactive=False, visible=False
+                     )
+                     output_video_button = gr.Button(
+                         "🎬", interactive=False, visible=False
+                     )
+
+                 with gr.Group():
+                     with gr.Row():
+                         gr.Markdown(
+                             "### [🀝 Sponsor](https://github.com/sponsors/harisreedhar)"
+                         )
+                         gr.Markdown(
+                             "### [πŸ‘¨β€πŸ’» Source code](https://github.com/harisreedhar/Swap-Mukham)"
+                         )
+                         gr.Markdown(
+                             "### [⚠️ Disclaimer](https://github.com/harisreedhar/Swap-Mukham#disclaimer)"
+                         )
+                         gr.Markdown(
+                             "### [🌐 Run in Colab](https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb)"
+                         )
+                         gr.Markdown(
+                             "### [πŸ€— Acknowledgements](https://github.com/harisreedhar/Swap-Mukham#acknowledgements)"
+                         )
+
+     ## ------------------------------ GRADIO EVENTS ------------------------------
+
+     set_slider_range_event = set_slider_range_btn.click(
+         video_changed,
+         inputs=[video_input],
+         outputs=[start_frame, end_frame, video_fps],
+     )
+
+     trim_and_reload_event = trim_and_reload_btn.click(
+         fn=trim_and_reload,
+         inputs=[video_input, output_directory, output_name, start_frame, end_frame],
+         outputs=[video_input, info],
+     )
+
+     start_frame_event = start_frame.release(
+         fn=slider_changed,
+         inputs=[show_trim_preview_btn, video_input, start_frame],
+         outputs=[preview_image, preview_video],
+         show_progress=True,
+     )
+
+     end_frame_event = end_frame.release(
+         fn=slider_changed,
+         inputs=[show_trim_preview_btn, video_input, end_frame],
+         outputs=[preview_image, preview_video],
+         show_progress=True,
+     )
+
+     input_type.change(
+         update_radio,
+         inputs=[input_type],
+         outputs=[input_image_group, input_video_group, input_directory_group],
+     )
+     swap_option.change(
+         swap_option_changed,
+         inputs=[swap_option],
+         outputs=[age, specific_face, source_image_input],
+     )
+
+     apply_detection_settings.click(
+         analyse_settings_changed,
+         inputs=[detect_condition_dropdown, detection_size, detection_threshold],
+         outputs=[info],
+     )
+
+     # Collect the dynamically created src/trg image components into one tuple
+     src_specific_inputs = []
+     gen_variable_txt = ",".join(
+         [f"src{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
+         + [f"trg{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
+     )
+     exec(f"src_specific_inputs = ({gen_variable_txt})")
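+     # e.g. with NUM_OF_SRC_SPECIFIC == 2 the exec above evaluates to
+     # src_specific_inputs = (src1, src2, trg1, trg2)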
+     swap_inputs = [
+         input_type,
+         image_input,
+         video_input,
+         direc_input,
+         source_image_input,
+         output_directory,
+         output_name,
+         keep_output_sequence,
+         swap_option,
+         age,
+         distance_slider,
+         face_enhancer_name,
+         enable_face_parser_mask,
+         mask_include,
+         mask_soft_kernel,
+         mask_soft_iterations,
+         blur_amount,
+         erode_amount,
+         face_scale,
+         enable_laplacian_blend,
+         crop_top,
+         crop_bott,
+         crop_left,
+         crop_right,
+         *src_specific_inputs,
+     ]
+
+     swap_outputs = [
+         info,
+         preview_image,
+         output_directory_button,
+         output_video_button,
+         preview_video,
+     ]
+
+     swap_event = swap_button.click(
+         fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
+     )
+
+     cancel_button.click(
+         fn=stop_running,
+         inputs=None,
+         outputs=[info],
+         cancels=[
+             swap_event,
+             trim_and_reload_event,
+             set_slider_range_event,
+             start_frame_event,
+             end_frame_event,
+         ],
+         show_progress=True,
+     )
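+     # cancels= tells Gradio to abort any of the listed events that are still
+     # running or queued when Cancel is pressed; stop_running() additionally
+     # stops the frame streamer as a safeguard.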
+     output_directory_button.click(
+         lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
+     )
+     output_video_button.click(
+         lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
+     )
+
+ if __name__ == "__main__":
+     if USE_COLAB:
+         print("Running in colab mode")
+
+     interface.launch()
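+     # Several handlers above yield intermediate status text; depending on the
+     # Gradio version this may require the queue to be enabled, e.g.
+     # interface.queue().launch() (an assumption; the queue is not set explicitly here).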
+
+
+ #### APP.PY CODE END ###