Harisreedhar commited on
Commit
7f475d2
1 Parent(s): 27c3130

Add soft erosion and fix face parsing video

Browse files
Files changed (4) hide show
  1. app.py +40 -24
  2. face_parsing/__init__.py +1 -1
  3. face_parsing/swap.py +60 -17
  4. swapper.py +3 -3
app.py CHANGED
@@ -17,7 +17,7 @@ from moviepy.editor import VideoFileClip, ImageSequenceClip
17
 
18
  from face_analyser import detect_conditions, analyse_face
19
  from utils import trim_video, StreamerThread, ProcessBar, open_directory
20
- from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list
21
  from swapper import (
22
  swap_face,
23
  swap_face_with_condition,
@@ -59,8 +59,9 @@ MASK_INCLUDE = [
59
  "L-Lip",
60
  "U-Lip"
61
  ]
62
- MASK_EXCLUDE = ["R-Ear", "L-Ear", "Hair", "Hat"]
63
- MASK_BLUR = 25
 
64
 
65
  FACE_SWAPPER = None
66
  FACE_ANALYSER = None
@@ -84,6 +85,8 @@ else:
84
  USE_CUDA = False
85
  print("\n********** Running on CPU **********\n")
86
 
 
 
87
 
88
  ## ------------------------------ LOAD MODELS ------------------------------
89
 
@@ -114,7 +117,7 @@ def load_face_parser_model(name="./assets/pretrained_models/79999_iter.pth"):
114
  global FACE_PARSER
115
  path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
116
  if FACE_PARSER is None:
117
- FACE_PARSER = init_parser(name, use_cuda=USE_CUDA)
118
 
119
 
120
  load_face_analyser_model()
@@ -137,9 +140,10 @@ def process(
137
  distance,
138
  face_enhance,
139
  enable_face_parser,
140
- mask_include,
141
- mask_exclude,
142
- mask_blur,
 
143
  *specifics,
144
  ):
145
  global WORKSPACE
@@ -196,14 +200,18 @@ def process(
196
 
197
  yield "### \n ⌛ Analysing Face...", *ui_before()
198
 
199
- mi = mask_regions_to_list(mask_include)
200
- me = mask_regions_to_list(mask_exclude)
 
 
 
 
201
  models = {
202
  "swap": FACE_SWAPPER,
203
  "enhance": FACE_ENHANCER,
204
  "enhance_sett": face_enhance,
205
  "face_parser": FACE_PARSER,
206
- "face_parser_sett": (enable_face_parser, mi, me, int(mask_blur)),
207
  }
208
 
209
  ## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
@@ -301,9 +309,9 @@ def process(
301
 
302
  if condition == "Specific Face":
303
  swapped = swap_specific(
304
- frame,
305
- analysed_target,
306
  analysed_source_specific,
 
 
307
  models,
308
  threshold=distance,
309
  )
@@ -381,9 +389,9 @@ def process(
381
 
382
  if condition == "Specific Face":
383
  swapped = swap_specific(
384
- target,
385
- analysed_target,
386
  analysed_source_specific,
 
 
387
  models,
388
  threshold=distance,
389
  )
@@ -636,16 +644,23 @@ with gr.Blocks(css=css) as interface:
636
  label="Include",
637
  interactive=True,
638
  )
639
- mask_exclude = gr.Dropdown(
640
- mask_regions.keys(),
641
- value=MASK_EXCLUDE,
642
- multiselect=True,
643
- label="Exclude",
644
  interactive=True,
 
 
 
 
 
 
 
 
645
  )
646
- mask_blur = gr.Number(
647
- label="Blur Mask",
648
- value=MASK_BLUR,
649
  minimum=0,
650
  interactive=True,
651
  )
@@ -827,8 +842,9 @@ with gr.Blocks(css=css) as interface:
827
  enable_face_enhance,
828
  enable_face_parser_mask,
829
  mask_include,
830
- mask_exclude,
831
- mask_blur,
 
832
  *src_specific_inputs,
833
  ]
834
 
 
17
 
18
  from face_analyser import detect_conditions, analyse_face
19
  from utils import trim_video, StreamerThread, ProcessBar, open_directory
20
+ from face_parsing import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
21
  from swapper import (
22
  swap_face,
23
  swap_face_with_condition,
 
59
  "L-Lip",
60
  "U-Lip"
61
  ]
62
+ MASK_SOFT_KERNEL = 17
63
+ MASK_SOFT_ITERATIONS = 7
64
+ MASK_BLUR_AMOUNT = 20
65
 
66
  FACE_SWAPPER = None
67
  FACE_ANALYSER = None
 
85
  USE_CUDA = False
86
  print("\n********** Running on CPU **********\n")
87
 
88
+ device = "cuda" if USE_CUDA else "cpu"
89
+
90
 
91
  ## ------------------------------ LOAD MODELS ------------------------------
92
 
 
117
  global FACE_PARSER
118
  path = os.path.join(os.path.abspath(os.path.dirname(__file__)), name)
119
  if FACE_PARSER is None:
120
+ FACE_PARSER = init_parser(name, mode=device)
121
 
122
 
123
  load_face_analyser_model()
 
140
  distance,
141
  face_enhance,
142
  enable_face_parser,
143
+ mask_includes,
144
+ mask_soft_kernel,
145
+ mask_soft_iterations,
146
+ blur_amount,
147
  *specifics,
148
  ):
149
  global WORKSPACE
 
200
 
201
  yield "### \n ⌛ Analysing Face...", *ui_before()
202
 
203
+ includes = mask_regions_to_list(mask_includes)
204
+ if mask_soft_iterations > 0:
205
+ smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=int(mask_soft_iterations)).to(device)
206
+ else:
207
+ smooth_mask = None
208
+
209
  models = {
210
  "swap": FACE_SWAPPER,
211
  "enhance": FACE_ENHANCER,
212
  "enhance_sett": face_enhance,
213
  "face_parser": FACE_PARSER,
214
+ "face_parser_sett": (enable_face_parser, includes, smooth_mask, int(blur_amount))
215
  }
216
 
217
  ## ------------------------------ ANALYSE SOURCE & SPECIFIC ------------------------------
 
309
 
310
  if condition == "Specific Face":
311
  swapped = swap_specific(
 
 
312
  analysed_source_specific,
313
+ analysed_target,
314
+ frame,
315
  models,
316
  threshold=distance,
317
  )
 
389
 
390
  if condition == "Specific Face":
391
  swapped = swap_specific(
 
 
392
  analysed_source_specific,
393
+ analysed_target,
394
+ target,
395
  models,
396
  threshold=distance,
397
  )
 
644
  label="Include",
645
  interactive=True,
646
  )
647
+ mask_soft_kernel = gr.Number(
648
+ label="Soft Erode Kernel",
649
+ value=MASK_SOFT_KERNEL,
650
+ minimum=3,
 
651
  interactive=True,
652
+ visible = False
653
+ )
654
+ mask_soft_iterations = gr.Number(
655
+ label="Soft Erode Iterations",
656
+ value=MASK_SOFT_ITERATIONS,
657
+ minimum=0,
658
+ interactive=True,
659
+
660
  )
661
+ blur_amount = gr.Number(
662
+ label="Mask Blur",
663
+ value=MASK_BLUR_AMOUNT,
664
  minimum=0,
665
  interactive=True,
666
  )
 
842
  enable_face_enhance,
843
  enable_face_parser_mask,
844
  mask_include,
845
+ mask_soft_kernel,
846
+ mask_soft_iterations,
847
+ blur_amount,
848
  *src_specific_inputs,
849
  ]
850
 
face_parsing/__init__.py CHANGED
@@ -1 +1 @@
1
- from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list
 
1
+ from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list, SoftErosion
face_parsing/swap.py CHANGED
@@ -1,4 +1,6 @@
1
  import torch
 
 
2
  import torchvision.transforms as transforms
3
  import cv2
4
  import numpy as np
@@ -27,15 +29,44 @@ mask_regions = {
27
  "Hat":18
28
  }
29
 
30
- run_with_cuda = False
 
 
 
 
 
 
 
 
31
 
32
- def init_parser(pth_path, use_cuda=False):
33
- global run_with_cuda
34
- run_with_cuda = use_cuda
 
 
 
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  n_classes = 19
37
  net = BiSeNet(n_classes=n_classes)
38
- if run_with_cuda:
39
  net.cuda()
40
  net.load_state_dict(torch.load(pth_path))
41
  else:
@@ -55,8 +86,7 @@ def image_to_parsing(img, net):
55
  img = torch.unsqueeze(img, 0)
56
 
57
  with torch.no_grad():
58
- if run_with_cuda:
59
- img = img.cuda()
60
  out = net(img)[0]
61
  parsing = out.squeeze(0).cpu().numpy().argmax(0)
62
  return parsing
@@ -68,20 +98,33 @@ def get_mask(parsing, classes):
68
  res += parsing == val
69
  return res
70
 
71
- def swap_regions(source, target, net, includes=[1,2,3,4,5,10,11,12,13], excludes=[7,8], blur_size=25):
72
  parsing = image_to_parsing(source, net)
 
73
  if len(includes) == 0:
74
  return source, np.zeros_like(source)
 
75
  include_mask = get_mask(parsing, includes)
76
- include_mask = np.repeat(np.expand_dims(include_mask.astype('float32'), axis=2), 3, 2)
77
- if len(excludes) > 0:
78
- exclude_mask = get_mask(parsing, excludes)
79
- exclude_mask = np.repeat(np.expand_dims(exclude_mask.astype('float32'), axis=2), 3, 2)
80
- include_mask -= exclude_mask
81
- mask = 1 - cv2.GaussianBlur(include_mask.clip(0,1), (0, 0), blur_size)
82
- result = (1 - mask) * cv2.resize(source, (512, 512)) + mask * cv2.resize(target, (512, 512))
83
- result = cv2.resize(result.astype("float32"), (source.shape[1], source.shape[0]))
84
- return result, mask.astype('float32')
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  def mask_regions_to_list(values):
87
  out_ids = []
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
  import torchvision.transforms as transforms
5
  import cv2
6
  import numpy as np
 
29
  "Hat":18
30
  }
31
 
32
+ # Borrowed from simswap
33
+ # https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
34
+ class SoftErosion(nn.Module):
35
+ def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
36
+ super(SoftErosion, self).__init__()
37
+ r = kernel_size // 2
38
+ self.padding = r
39
+ self.iterations = iterations
40
+ self.threshold = threshold
41
 
42
+ # Create kernel
43
+ y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
44
+ dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
45
+ kernel = dist.max() - dist
46
+ kernel /= kernel.sum()
47
+ kernel = kernel.view(1, 1, *kernel.shape)
48
+ self.register_buffer('weight', kernel)
49
 
50
+ def forward(self, x):
51
+ x = x.float()
52
+ for i in range(self.iterations - 1):
53
+ x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
54
+ x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
55
+
56
+ mask = x >= self.threshold
57
+ x[mask] = 1.0
58
+ x[~mask] /= x[~mask].max()
59
+
60
+ return x, mask
61
+
62
+ device = "cpu"
63
+
64
+ def init_parser(pth_path, mode="cpu"):
65
+ global device
66
+ device = mode
67
  n_classes = 19
68
  net = BiSeNet(n_classes=n_classes)
69
+ if device == "cuda":
70
  net.cuda()
71
  net.load_state_dict(torch.load(pth_path))
72
  else:
 
86
  img = torch.unsqueeze(img, 0)
87
 
88
  with torch.no_grad():
89
+ img = img.to(device)
 
90
  out = net(img)[0]
91
  parsing = out.squeeze(0).cpu().numpy().argmax(0)
92
  return parsing
 
98
  res += parsing == val
99
  return res
100
 
101
+ def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
102
  parsing = image_to_parsing(source, net)
103
+
104
  if len(includes) == 0:
105
  return source, np.zeros_like(source)
106
+
107
  include_mask = get_mask(parsing, includes)
108
+ mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")
109
+
110
+ if smooth_mask is not None:
111
+ mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
112
+ face_mask_tensor = mask_tensor[0] + mask_tensor[1]
113
+ soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
114
+ soft_face_mask_tensor.squeeze_()
115
+ mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)
116
+
117
+ if blur > 0:
118
+ mask = cv2.GaussianBlur(mask, (0, 0), blur)
119
+
120
+ resized_source = cv2.resize((source/255).astype("float32"), (512, 512))
121
+ resized_target = cv2.resize((target/255).astype("float32"), (512, 512))
122
+
123
+ result = mask * resized_source + (1 - mask) * resized_target
124
+ normalized_result = (result - np.min(result)) / (np.max(result) - np.min(result))
125
+ result = cv2.resize((result*255).astype("uint8"), (source.shape[1], source.shape[0]))
126
+
127
+ return result
128
 
129
  def mask_regions_to_list(values):
130
  out_ids = []
swapper.py CHANGED
@@ -25,10 +25,10 @@ def swap_face(whole_img, target_face, source_face, models):
25
  aimg, _ = face_align.norm_crop2(whole_img, target_face.kps, image_size=image_size)
26
 
27
  if face_parser is not None:
28
- fp_enable, mi, me, mb = models.get("face_parser_sett")
29
  if fp_enable:
30
- bgr_fake, parsed_mask = swap_regions(
31
- bgr_fake, aimg, face_parser, includes=mi, excludes=me, blur_size=mb
32
  )
33
 
34
  if fe_enable:
 
25
  aimg, _ = face_align.norm_crop2(whole_img, target_face.kps, image_size=image_size)
26
 
27
  if face_parser is not None:
28
+ fp_enable, includes, smooth_mask, blur_amount = models.get("face_parser_sett")
29
  if fp_enable:
30
+ bgr_fake = swap_regions(
31
+ bgr_fake, aimg, face_parser, smooth_mask, includes=includes, blur=blur_amount
32
  )
33
 
34
  if fe_enable: