# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from copy import deepcopy

import cv2
import torch
import torch.nn.functional as F
import PIL
from PIL import Image
import numpy as np
import facer
import facer.transform


def resize_image(image, max_size=1024):
    """Downscale an HWC image so its longer side is at most max_size; smaller images are returned unchanged."""
    height, width, _ = image.shape
    if width > max_size or height > max_size:
        if width > height:
            new_width = max_size
            new_height = int((height / width) * max_size)
        else:
            new_height = max_size
            new_width = int((width / height) * max_size)
        image = cv2.resize(image, (new_width, new_height))
    return image


def open_and_resize_image(image_file, max_size=1024, return_type='numpy'):
    """Open an image (path, PIL.Image, or np.ndarray) and resize it so its longer side equals max_size."""
    if isinstance(image_file, (str, PIL.Image.Image)):
        if isinstance(image_file, str):
            img = Image.open(image_file)
        else:
            img = image_file
        width, height = img.size
        if width > height:
            new_width = max_size
            new_height = int((height / width) * max_size)
        else:
            new_height = max_size
            new_width = int((width / height) * max_size)
        img = img.resize((new_width, new_height))
        if return_type == 'numpy':
            return np.array(img.convert('RGB'))
        else:
            return img
    elif isinstance(image_file, np.ndarray):
        height, width, _ = image_file.shape
        if width > height:
            new_width = max_size
            new_height = int((height / width) * max_size)
        else:
            new_height = max_size
            new_width = int((width / height) * max_size)
        img = cv2.resize(image_file, (new_width, new_height))
        assert return_type == 'numpy'
        return img
    else:
        raise TypeError('Unsupported image type: %s' % type(image_file))


def loose_warp_face(input_image, face_detector, face_target_shape=(512, 512), scale=1.3, face_parser=None, device=None, croped_face_scale=3, bg_value=0, croped_face_y_offset=0.0):
    """Get the tight/loose warp of the single face of interest in the image.

    Args:
        input_image: Image path, PIL.Image.Image, or np.ndarray (dtype=np.uint8).
        face_detector: A facer face detector used for face detection.
        face_target_shape: Output resolution of the aligned face.
        scale: Scale of the output image w.r.t. the face it contains.
        face_parser: Optional facer face parser used to mask out the background.
        croped_face_scale: Side length of the square crop, as a multiple of the detected face box.
        croped_face_y_offset: Vertical offset of the crop center, as a fraction of the face height.
        bg_value: Pixel value used to fill the background of the masked face.
    Returns:
        dict with keys 'cropped_face_masked', 'cropped_face', 'cropped_img',
        'cropped_face_mask', and 'align_face' (values are None if no face is found).
    """
    # Standard 5-point alignment template (ArcFace convention), normalized by its 112 px reference size.
    _normalized_face_target_pts = torch.tensor([
        [38.2946, 51.6963],
        [73.5318, 51.5014],
        [56.0252, 71.7366],
        [41.5493, 92.3655],
        [70.729904, 92.2041]]) / 112.0
    target_pts = ((_normalized_face_target_pts -
                   torch.tensor([0.5, 0.5])) / scale
                  + torch.tensor([0.5, 0.5]))
    if face_detector is not None:
        device = next(face_detector.parameters()).device
    if isinstance(input_image, str):
        # Downsample high-res images to avoid OOM.
        np_img = open_and_resize_image(input_image)[:, :, :3]
        img_height, img_width = np_img.shape[:2]
        image_tensor_hwc = torch.from_numpy(np_img)
    elif isinstance(input_image, Image.Image):
        image_tensor_hwc = torch.from_numpy(np.array(input_image)[:, :, :3])
        img_height, img_width = image_tensor_hwc.shape[:2]
        assert image_tensor_hwc.dtype == torch.uint8
    else:
        assert isinstance(input_image, np.ndarray), 'Type %s of input_image is unsupported!' % type(input_image)
        assert input_image.dtype == np.uint8, 'dtype %s of input np.ndarray is unsupported!' % input_image.dtype
        input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)[:, :, :3]
        input_image = resize_image(input_image)
        image_tensor_hwc = torch.from_numpy(input_image)
        img_height, img_width = image_tensor_hwc.shape[:2]
    image_pt_bchw_255 = facer.hwc2bchw(image_tensor_hwc).to(device)
    res = {'cropped_face_masked': None, 'cropped_face': None, 'cropped_img': None, 'cropped_face_mask': None, 'align_face': None}
    if face_detector is not None:
        try:
            face_data = face_detector(image_pt_bchw_255)
        except Exception:
            # Detection failed; return the empty result instead of crashing.
            return res
        if len(face_data) == 0:
            return res
        if face_parser is not None:
            with torch.inference_mode():
                faces = face_parser(image_pt_bchw_255, face_data)
            seg_logits = faces['seg']['logits']
            seg_probs = seg_logits.softmax(dim=1)
            # Hard label map for the first (and only) face of interest.
            seg_probs = seg_probs.argmax(dim=1).unsqueeze(1)[:1]
        face_rects = face_data['rects'][:1]
        x1, y1, x2, y2 = face_rects[0][:4]
        x1 = int(x1.item())
        y1 = int(y1.item())
        x2 = int(x2.item())
        y2 = int(y2.item())
        face_width = x2 - x1
        face_height = y2 - y1
        # Expand the detected box by croped_face_scale around its (optionally shifted) center.
        center_x = int(0.5 * (x1 + x2))
        center_y = int(0.5 * (y1 + y2)) + croped_face_y_offset * face_height
        croped_face_width = face_width * croped_face_scale
        croped_face_height = face_height * croped_face_scale
        x1 = max(int(center_x - 0.5 * croped_face_width), 0)
        x2 = min(int(center_x + 0.5 * croped_face_width), img_width - 1)
        y1 = max(int(center_y - 0.5 * croped_face_height), 0)
        y2 = min(int(center_y + 0.5 * croped_face_height), img_height - 1)
        # Shrink the clipped box to a square whose side is its shorter edge.
        croped_face_height = y2 - y1
        croped_face_width = x2 - x1
        center_x = int(0.5 * (x1 + x2))
        center_y = int(0.5 * (y1 + y2))
        croped_face_len = min(croped_face_height, croped_face_width)
        x1 = int(center_x - 0.5 * croped_face_len)
        y1 = int(center_y - 0.5 * croped_face_len)
        x2 = x1 + croped_face_len
        y2 = y1 + croped_face_len
        croped_image_pt_bchw_255 = image_pt_bchw_255[:, :, y1:y2, x1:x2]
        face_points = face_data['points'][:1]
        batch_inds = face_data['image_ids'][:1]
        # Align the face to the 5-point template using a tanh-warp grid.
        matrix_align = facer.transform.get_face_align_matrix(
            face_points, face_target_shape,
            target_pts=(target_pts * torch.tensor(face_target_shape)))
        grid = facer.transform.make_tanh_warp_grid(
            matrix_align, 0.0, face_target_shape, image_pt_bchw_255.shape[2:])
        image = F.grid_sample(
            image_pt_bchw_255.float()[batch_inds],
            grid, 'bilinear', align_corners=False)
        image_align_raw = deepcopy(image)
        image_align_raw = facer.bchw2hwc(image_align_raw).to(torch.uint8).cpu().numpy()
        image_align_raw = Image.fromarray(image_align_raw)
        image_croped = facer.bchw2hwc(croped_image_pt_bchw_255).to(torch.uint8).cpu().numpy()
        image_croped = Image.fromarray(image_croped)
        if face_parser is not None:
            image_no_mask = deepcopy(image)
            # Broadcast the single-channel label map to the image's channel count.
            new_size = list(seg_probs.shape)
            new_size[1] = image.shape[1]
            seg_probs = seg_probs.expand(new_size)
            assert seg_probs.shape[0] == 1 and image.shape[0] == 1, 'mask shape {} != image shape {}'.format(seg_probs.shape, image.shape)
            # Warp the parsing mask with the same grid, then fill the background with bg_value.
            mask_img = F.grid_sample(seg_probs.float(), grid, 'bilinear', align_corners=False)
            image[mask_img == 0] = bg_value
            mask_img[mask_img != 0] = 1
            assert mask_img.shape[0] == 1
        else:
            image_no_mask = image
            mask_img = None
    else:
        # No detector: return the (unwarped) input image and leave all face-specific outputs empty.
        image = image_pt_bchw_255
        image_no_mask = image_pt_bchw_255
        image_align_raw = None
        image_croped = None
        mask_img = None
    image = facer.bchw2hwc(image).to(torch.uint8).cpu().numpy()
    image_no_mask = facer.bchw2hwc(image_no_mask).to(torch.uint8).cpu().numpy()
    res.update({'cropped_face_masked': Image.fromarray(image), 'cropped_face': Image.fromarray(image_no_mask), 'cropped_img': image_croped, 'cropped_face_mask': mask_img, 'align_face': image_align_raw})
    return res


def tight_warp_face(input_image, face_detector, face_parser=None, device=None):
    return loose_warp_face(input_image, face_detector,
                           face_target_shape=(112, 112), scale=1, face_parser=face_parser, device=device)
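

# Minimal usage sketch (not part of the original pipeline): it illustrates the
# expected inputs and outputs of loose_warp_face. The facer model names
# ('retinaface/mobilenet', 'farl/lapa/448') and the image path are assumptions
# and may need to be adapted to the installed facer version and your data.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Detector is required; the parser is optional and only needed for the masked outputs.
    face_detector = facer.face_detector('retinaface/mobilenet', device=device)
    face_parser = facer.face_parser('farl/lapa/448', device=device)
    # Placeholder path; point this at an image containing a single face.
    result = loose_warp_face('example_face.jpg', face_detector, face_parser=face_parser)
    if result['cropped_face'] is not None:
        result['cropped_face'].save('cropped_face.png')
        result['align_face'].save('aligned_face.png')
    else:
        print('No face detected.')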