File size: 872 Bytes
2a38579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f70127
 
 
 
2a38579
0f70127
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import numpy as np
from facenet_pytorch import MTCNN, InceptionResnetV1

class FacenetEmbedder:
    """Detect faces with MTCNN and embed them with an InceptionResnetV1 model.

    Both networks are placed on ``cuda:0`` when CUDA is available, otherwise on
    the CPU. The embedding network uses the ``'vggface2'`` pretrained weights
    and is switched to evaluation mode.
    """

    def __init__(self):
        # Prefer the first GPU; fall back to CPU so the class works without CUDA.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.mtcnn = MTCNN(device=self.device)
        # .eval() puts the model in inference mode (dropout / batch-norm behaviour).
        self.resnet = InceptionResnetV1(pretrained='vggface2', device=self.device).eval()

    def detect_face(self, batch):
        """Run MTCNN face detection on *batch* and return its raw result.

        The value from ``MTCNN.detect`` is passed through unchanged (in
        facenet_pytorch this is a ``(boxes, probs)`` pair by default).
        """
        return self.mtcnn.detect(batch)

    def encode(self, batch):
        """Return face embeddings for the image(s) in *batch*.

        Args:
            batch: a single image or a batch of images, in any form accepted
                by ``MTCNN.__call__``.

        Returns:
            A list of embeddings (each a list of floats), one per image in
            which a face was detected, or ``None`` when no face was detected
            at all. Images without a detected face are dropped, so the i-th
            embedding does not necessarily belong to the i-th input image.
        """
        face_batch = self.mtcnn(batch)
        # For a non-batched image, facenet_pytorch's MTCNN returns None or a
        # single (3, H, W) face tensor instead of a list; normalise both so we
        # never iterate over None or over the colour channels of one face.
        if face_batch is None:
            return None
        if isinstance(face_batch, torch.Tensor) and face_batch.dim() == 3:
            face_batch = [face_batch]

        faces = [face for face in face_batch if face is not None]
        if not faces:
            return None

        # .to() is a no-op when the tensor is already on self.device.
        aligned = torch.stack(faces).to(self.device)
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            embeddings = self.resnet(aligned)
        return embeddings.cpu().tolist()