File size: 2,289 Bytes
b7f710c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from PIL import Image
from typing import Union, List, Tuple
from . import mtcnn
from .face_yolo import face_yolo_detection

# Pick the compute device once: the first CUDA GPU when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# One MTCNN detector built at import time and shared by every call in this
# module. crop_size=(112, 112) is passed through to the project's MTCNN class;
# presumably it is the size of the aligned face crops -- confirm in mtcnn.py.
MTCNN_MODEL = mtcnn.MTCNN(device=DEVICE, crop_size=(112, 112))

def add_image_padding(pil_img: Image.Image, top: int, right: int, bottom: int, left: int,
                      color: Tuple[int, int, int] = (0, 0, 0)) -> Image.Image:
    """Return a copy of *pil_img* surrounded by a solid-colour border.

    Args:
        pil_img: Source image. The result keeps the same mode.
        top: Border height added above the image, in pixels.
        right: Border width added to the right of the image, in pixels.
        bottom: Border height added below the image, in pixels.
        left: Border width added to the left of the image, in pixels.
        color: Border colour. For single-band modes other than "P" (e.g. "L",
            "1", "I", "F"), only the first component is used, because
            ``Image.new`` requires a scalar fill for those modes.

    Returns:
        A new image of size ``(width + left + right, height + top + bottom)``
        with the original pasted at offset ``(left, top)``.
    """
    width, height = pil_img.size
    new_size = (width + left + right, height + top + bottom)

    fill = color
    if (
        pil_img.mode != "P"
        and len(pil_img.getbands()) == 1
        and isinstance(color, tuple)
    ):
        # Image.new rejects a 3-tuple for single-band images; reduce it to a
        # scalar so the default (0, 0, 0) still means black/zero there.
        fill = color[0]

    padded_img = Image.new(pil_img.mode, new_size, fill)
    padded_img.paste(pil_img, (left, top))
    return padded_img

def detect_faces_mtcnn(image: Union[str, Image.Image]) -> Tuple[Union[list, None], Union[Image.Image, None]]:
    """Detect and align the first face in an image with the shared MTCNN model.

    Args:
        image: Path to an image file, or an already-loaded PIL image. Both are
            converted to RGB before detection, so path and image inputs are
            treated the same way.

    Returns:
        ``(bbox, face)``: the first entry of each sequence returned by
        ``MTCNN_MODEL.align_multi(image, limit=1)``. Either element is ``None``
        when its sequence is empty. Both are ``None`` when detection raises;
        the error is printed, not propagated. The concrete type of ``bbox`` is
        whatever ``align_multi`` yields (possibly an array row rather than a
        ``list``) -- not verified here.

    Raises:
        TypeError: If ``image`` is neither a ``str`` nor a PIL image.
        Errors from ``Image.open``/``convert`` for a path (e.g.
        ``FileNotFoundError``) are not caught.
    """
    if isinstance(image, str):
        # ``with`` guarantees the file handle is released after decoding;
        # ``convert`` returns an independent in-memory image.
        with Image.open(image) as opened:
            image = opened.convert('RGB')
    elif isinstance(image, Image.Image):
        # Mirror the path branch so the detector always receives RGB input.
        if image.mode != 'RGB':
            image = image.convert('RGB')
    else:
        raise TypeError("Input must be a PIL Image or path to an image")

    try:
        bboxes, faces = MTCNN_MODEL.align_multi(image, limit=1)
        # Explicit None/len checks instead of truthiness: if align_multi
        # returns a NumPy array of boxes (not visible from here), ``if bboxes``
        # raises "truth value ... is ambiguous" for a non-empty (N, k) array.
        # The except below would then report a detected face as a failure.
        bbox = bboxes[0] if bboxes is not None and len(bboxes) > 0 else None
        face = faces[0] if faces is not None and len(faces) > 0 else None
    except Exception as e:
        # Deliberate best-effort: report the problem and treat it as "no face".
        print(f"MTCNN face detection failed: {e}")
        return None, None
    return bbox, face

def get_aligned_face(image_input: Union[str, List[str]],
                     algorithm: str = 'mtcnn') -> List[Tuple[Union[list, None], Union[Image.Image, None]]]:
    """Run face detection/alignment over one image path or a list of them.

    Args:
        image_input: A single image path, or a list of image paths. A single
            string is wrapped into a one-element list.
        algorithm: ``'mtcnn'`` (one ``detect_faces_mtcnn`` call per item) or
            ``'yolo'`` (one batched ``face_yolo_detection`` call).

    Returns:
        A list with one result per input item, in input order for the MTCNN
        path. For YOLO, it is whatever ``face_yolo_detection`` yields,
        materialised as a list.

    Raises:
        ValueError: If ``algorithm`` is not ``'mtcnn'`` or ``'yolo'``. This is
            checked before the input type.
        TypeError: If ``image_input`` is neither a ``str`` nor a ``list``.
    """
    supported = ('mtcnn', 'yolo')
    if algorithm not in supported:
        raise ValueError("Algorithm must be 'mtcnn' or 'yolo'")

    if isinstance(image_input, str):
        batch = [image_input]
    elif isinstance(image_input, list):
        batch = image_input
    else:
        raise TypeError("Input must be a string or list of strings")

    if algorithm == 'yolo':
        # The YOLO detector receives the whole batch in a single call.
        detections = face_yolo_detection(batch, use_batch=True, device=DEVICE)
        return list(detections)

    # MTCNN handles one image at a time.
    return [detect_faces_mtcnn(item) for item in batch]