Spaces:
Runtime error
Runtime error
add grayscale judgement (#32)
Browse files
basicsr/utils/img_util.py
CHANGED
|
@@ -168,3 +168,4 @@ def crop_border(imgs, crop_border):
|
|
| 168 |
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
|
| 169 |
else:
|
| 170 |
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
|
|
|
|
|
|
| 168 |
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
|
| 169 |
else:
|
| 170 |
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
|
| 171 |
+
|
facelib/utils/face_restoration_helper.py
CHANGED
|
@@ -6,7 +6,7 @@ from torchvision.transforms.functional import normalize
|
|
| 6 |
|
| 7 |
from facelib.detection import init_detection_model
|
| 8 |
from facelib.parsing import init_parsing_model
|
| 9 |
-
from facelib.utils.misc import img2tensor, imwrite
|
| 10 |
|
| 11 |
|
| 12 |
def get_largest_face(det_faces, h, w):
|
|
@@ -125,6 +125,9 @@ class FaceRestoreHelper(object):
|
|
| 125 |
img = img[:, :, 0:3]
|
| 126 |
|
| 127 |
self.input_img = img
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
if min(self.input_img.shape[:2])<512:
|
| 130 |
f = 512.0/min(self.input_img.shape[:2])
|
|
@@ -416,6 +419,9 @@ class FaceRestoreHelper(object):
|
|
| 416 |
fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
|
| 417 |
inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
|
| 418 |
|
|
|
|
|
|
|
|
|
|
| 419 |
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
|
| 420 |
alpha = upsample_img[:, :, 3:]
|
| 421 |
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
|
|
|
|
| 6 |
|
| 7 |
from facelib.detection import init_detection_model
|
| 8 |
from facelib.parsing import init_parsing_model
|
| 9 |
+
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
|
| 10 |
|
| 11 |
|
| 12 |
def get_largest_face(det_faces, h, w):
|
|
|
|
| 125 |
img = img[:, :, 0:3]
|
| 126 |
|
| 127 |
self.input_img = img
|
| 128 |
+
self.is_gray = is_gray(img, threshold=5)
|
| 129 |
+
if self.is_gray:
|
| 130 |
+
print('Grayscale input: True')
|
| 131 |
|
| 132 |
if min(self.input_img.shape[:2])<512:
|
| 133 |
f = 512.0/min(self.input_img.shape[:2])
|
|
|
|
| 419 |
fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
|
| 420 |
inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
|
| 421 |
|
| 422 |
+
if self.is_gray:
|
| 423 |
+
pasted_face = bgr2gray(pasted_face) # convert img into grayscale
|
| 424 |
+
|
| 425 |
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
|
| 426 |
alpha = upsample_img[:, :, 3:]
|
| 427 |
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
|
facelib/utils/misc.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import cv2
|
| 2 |
import os
|
| 3 |
import os.path as osp
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
from torch.hub import download_url_to_file, get_dir
|
| 6 |
from urllib.parse import urlparse
|
|
@@ -139,3 +141,34 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
|
| 139 |
continue
|
| 140 |
|
| 141 |
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import cv2
|
| 2 |
import os
|
| 3 |
import os.path as osp
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
import torch
|
| 7 |
from torch.hub import download_url_to_file, get_dir
|
| 8 |
from urllib.parse import urlparse
|
|
|
|
| 141 |
continue
|
| 142 |
|
| 143 |
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def is_gray(img, threshold=10):
|
| 147 |
+
img = Image.fromarray(img)
|
| 148 |
+
if len(img.getbands()) == 1:
|
| 149 |
+
return True
|
| 150 |
+
img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
|
| 151 |
+
img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
|
| 152 |
+
img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
|
| 153 |
+
diff1 = (img1 - img2).var()
|
| 154 |
+
diff2 = (img2 - img3).var()
|
| 155 |
+
diff3 = (img3 - img1).var()
|
| 156 |
+
diff_sum = (diff1 + diff2 + diff3) / 3.0
|
| 157 |
+
if diff_sum <= threshold:
|
| 158 |
+
return True
|
| 159 |
+
else:
|
| 160 |
+
return False
|
| 161 |
+
|
| 162 |
+
def rgb2gray(img, out_channel=3):
|
| 163 |
+
r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
|
| 164 |
+
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
| 165 |
+
if out_channel == 3:
|
| 166 |
+
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
|
| 167 |
+
return gray
|
| 168 |
+
|
| 169 |
+
def bgr2gray(img, out_channel=3):
|
| 170 |
+
b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
|
| 171 |
+
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
|
| 172 |
+
if out_channel == 3:
|
| 173 |
+
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
|
| 174 |
+
return gray
|