Doron Adler commited on
Commit
2c5167f
1 Parent(s): b558187

Sharper output with unsharp mask

Browse files
Files changed (7) hide show
  1. Sample00001.jpg +0 -0
  2. Sample00002.jpg +0 -0
  3. Sample00003.jpg +0 -0
  4. Sample00004.jpg +0 -0
  5. Sample00005.jpg +0 -0
  6. Sample00006.jpg +0 -0
  7. app.py +27 -3
Sample00001.jpg CHANGED
Sample00002.jpg CHANGED
Sample00003.jpg CHANGED
Sample00004.jpg CHANGED
Sample00005.jpg CHANGED
Sample00006.jpg CHANGED
app.py CHANGED
@@ -6,12 +6,27 @@ import face_detection
6
  import PIL
7
  from PIL import Image, ImageOps, ImageFile
8
  import numpy as np
9
-
10
  import torch
 
11
  torch.set_grad_enabled(False)
12
  model = torch.jit.load('u2net_bce_itr_16000_train_3.835149_tar_0.542587-400x_360x.jit.pt')
13
  model.eval()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def normPRED(d):
16
  ma = np.max(d)
17
  mi = np.min(d)
@@ -20,6 +35,12 @@ def normPRED(d):
20
 
21
  return dn
22
 
 
 
 
 
 
 
23
  def array_to_image(array_in):
24
  array_in = normPRED(array_in)
25
  array_in = np.squeeze(255.0*(array_in))
@@ -82,8 +103,11 @@ def face2hero(
82
  else:
83
  input = torch.Tensor(aligned_img)
84
  results = model(input)
85
- d2 = array_to_image(results[1].detach().numpy())
86
- output = img_concat_h(array_to_image(aligned_img), d2)
 
 
 
87
  del results
88
 
89
  return output
 
6
  import PIL
7
  from PIL import Image, ImageOps, ImageFile
8
  import numpy as np
9
+ import cv2 as cv
10
  import torch
11
+
12
  torch.set_grad_enabled(False)
13
  model = torch.jit.load('u2net_bce_itr_16000_train_3.835149_tar_0.542587-400x_360x.jit.pt')
14
  model.eval()
15
 
16
+ # https://en.wikipedia.org/wiki/Unsharp_masking
17
+ # https://stackoverflow.com/a/55590133/1495606
18
+ def unsharp_mask(image, kernel_size=(5, 5), sigma=1.0, amount=2.0, threshold=0):
19
+ """Return a sharpened version of the image, using an unsharp mask."""
20
+ blurred = cv.GaussianBlur(image, kernel_size, sigma)
21
+ sharpened = float(amount + 1) * image - float(amount) * blurred
22
+ sharpened = np.maximum(sharpened, np.zeros(sharpened.shape))
23
+ sharpened = np.minimum(sharpened, 255 * np.ones(sharpened.shape))
24
+ sharpened = sharpened.round().astype(np.uint8)
25
+ if threshold > 0:
26
+ low_contrast_mask = np.absolute(image - blurred) < threshold
27
+ np.copyto(sharpened, image, where=low_contrast_mask)
28
+ return sharpened
29
+
30
  def normPRED(d):
31
  ma = np.max(d)
32
  mi = np.min(d)
 
35
 
36
  return dn
37
 
38
+ def array_to_np(array_in):
39
+ array_in = normPRED(array_in)
40
+ array_in = np.squeeze(255.0*(array_in))
41
+ array_in = np.transpose(array_in, (1, 2, 0))
42
+ return array_in
43
+
44
  def array_to_image(array_in):
45
  array_in = normPRED(array_in)
46
  array_in = np.squeeze(255.0*(array_in))
 
103
  else:
104
  input = torch.Tensor(aligned_img)
105
  results = model(input)
106
+ hero_np_image = array_to_np(results[1].detach().numpy())
107
+ hero_image = unsharp_mask(hero_np_image)
108
+ hero_image = Image.fromarray(hero_image)
109
+
110
+ output = img_concat_h(array_to_image(aligned_img), hero_image)
111
  del results
112
 
113
  return output