File size: 2,441 Bytes
20cf96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import sys
import os
import torch
from . import mtcnn
from .face_yolo import face_yolo_detection
import argparse
from PIL import Image
from tqdm import tqdm
import random
from datetime import datetime

# Run on the first CUDA device when torch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Module-level MTCNN instance, built once at import time and used by
# handle_image_mtcnn below.
mtcnn_model = mtcnn.MTCNN(crop_size=(112, 112), device=DEVICE)

def add_padding(pil_img, top, right, bottom, left, color=(0, 0, 0)):
    """Return a new image with ``pil_img`` placed inside a solid border.

    A fresh image with the same mode as ``pil_img`` is created, enlarged by
    ``left`` + ``right`` pixels horizontally and ``top`` + ``bottom`` pixels
    vertically, filled with ``color``, and the original image is pasted at
    offset ``(left, top)``. The input image itself is not modified here.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    # The paste offset is the size of the left/top border.
    canvas.paste(pil_img, (left, top))
    return canvas

def handle_image_mtcnn(img_path, pil_img):
    """Detect and align a single face with the module-level MTCNN model.

    When ``pil_img`` is ``None`` the image is opened from ``img_path`` and
    converted to RGB; otherwise ``pil_img`` is used as given and ``img_path``
    is not read.

    Returns a ``(bbox, face)`` pair taken from the first entries of what
    ``mtcnn_model.align_multi(..., limit=1)`` returns, or ``(None, None)``
    when alignment raises (the error is printed, not re-raised).
    """
    if pil_img is None:
        image = Image.open(img_path).convert('RGB')
    else:
        image = pil_img
    assert isinstance(image, Image.Image), 'Face alignment requires PIL image or path'

    try:
        detected_boxes, aligned_faces = mtcnn_model.align_multi(image, limit=1)
        # Indexing stays inside the try so that empty results are reported as
        # a failed detection instead of propagating an IndexError.
        first_hit = (detected_boxes[0], aligned_faces[0])
    except Exception as e:
        print(f'Face detection failed: {e}')
        return None, None
    return first_hit

def get_aligned_face(image_path_or_image_paths, rgb_pil_image=None, algorithm='mtcnn'):
    """Detect/align faces for one image path or a list of image paths.

    Args:
        image_path_or_image_paths: A single path (``str``) or a ``list`` of
            paths. A single path is treated as a one-element list.
        rgb_pil_image: Optional already-loaded PIL image. Only used by the
            ``'mtcnn'`` algorithm, where it is forwarded to
            ``handle_image_mtcnn`` for every path (so when it is given, the
            same image is processed for each path instead of opening files).
        algorithm: ``'mtcnn'`` or ``'yolo'``.

    Returns:
        A list with one entry per input path.
        For ``'mtcnn'`` each entry is the ``(bbox, face)`` pair returned by
        ``handle_image_mtcnn`` (``(None, None)`` on detection failure).
        For ``'yolo'`` it is ``list(face_yolo_detection(...))``; the element
        format is defined by that helper.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
            (Previously this case fell through to ``return results`` and
            surfaced as an ``UnboundLocalError``.)
        TypeError: If ``image_path_or_image_paths`` is neither a list nor a
            string.
    """
    if algorithm not in ('mtcnn', 'yolo'):
        raise ValueError(
            f"Unsupported algorithm {algorithm!r}; expected 'mtcnn' or 'yolo'"
        )

    # Normalise the input once so both algorithms share the same handling.
    if isinstance(image_path_or_image_paths, list):
        image_paths = image_path_or_image_paths
    elif isinstance(image_path_or_image_paths, str):
        image_paths = [image_path_or_image_paths]
    else:
        raise TypeError("image_path_or_image_paths must be a list or string")

    if algorithm == 'mtcnn':
        return [handle_image_mtcnn(path, rgb_pil_image) for path in image_paths]

    # algorithm == 'yolo'
    # No yolo_model_path is passed (the original call had it commented out),
    # so whatever face_yolo_detection does when it is omitted applies.
    results = face_yolo_detection(image_paths, use_batch=True, device=DEVICE)
    return list(results)