prithivMLmods commited on
Commit
b9268a9
·
verified ·
1 Parent(s): 9ea7c68

Update roop/face_analyser.py

Browse files
Files changed (1) hide show
  1. roop/face_analyser.py +55 -55
roop/face_analyser.py CHANGED
@@ -1,55 +1,55 @@
1
- import threading
2
- from typing import Any, Optional, List
3
- import insightface
4
- import numpy
5
- import spaces
6
-
7
- import roop.globals
8
- from roop.typing import Frame, Face
9
-
10
- FACE_ANALYSER = None
11
- THREAD_LOCK = threading.Lock()
12
-
13
- @spaces.GPU
14
- def get_face_analyser() -> Any:
15
- global FACE_ANALYSER
16
-
17
- with THREAD_LOCK:
18
- if FACE_ANALYSER is None:
19
- FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
20
- FACE_ANALYSER.prepare(ctx_id=0)
21
- return FACE_ANALYSER
22
-
23
-
24
- def clear_face_analyser() -> Any:
25
- global FACE_ANALYSER
26
-
27
- FACE_ANALYSER = None
28
-
29
-
30
- def get_one_face(frame: Frame, position: int = 0) -> Optional[Face]:
31
- many_faces = get_many_faces(frame)
32
- if many_faces:
33
- try:
34
- return many_faces[position]
35
- except IndexError:
36
- return many_faces[-1]
37
- return None
38
-
39
-
40
- def get_many_faces(frame: Frame) -> Optional[List[Face]]:
41
- try:
42
- return get_face_analyser().get(frame)
43
- except ValueError:
44
- return None
45
-
46
-
47
- def find_similar_face(frame: Frame, reference_face: Face) -> Optional[Face]:
48
- many_faces = get_many_faces(frame)
49
- if many_faces:
50
- for face in many_faces:
51
- if hasattr(face, 'normed_embedding') and hasattr(reference_face, 'normed_embedding'):
52
- distance = numpy.sum(numpy.square(face.normed_embedding - reference_face.normed_embedding))
53
- if distance < roop.globals.similar_face_distance:
54
- return face
55
- return None
 
1
+ import threading
2
+ from typing import Any, Optional, List
3
+ import insightface
4
+ import numpy
5
+ import spaces
6
+
7
+ import roop.globals
8
+ from roop.typing import Frame, Face
9
+
10
+ FACE_ANALYSER = None
11
+ THREAD_LOCK = threading.Lock()
12
+
13
+ @spaces.GPU()
14
+ def get_face_analyser() -> Any:
15
+ global FACE_ANALYSER
16
+
17
+ with THREAD_LOCK:
18
+ if FACE_ANALYSER is None:
19
+ FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
20
+ FACE_ANALYSER.prepare(ctx_id=0)
21
+ return FACE_ANALYSER
22
+
23
+
24
+ def clear_face_analyser() -> Any:
25
+ global FACE_ANALYSER
26
+
27
+ FACE_ANALYSER = None
28
+
29
+
30
+ def get_one_face(frame: Frame, position: int = 0) -> Optional[Face]:
31
+ many_faces = get_many_faces(frame)
32
+ if many_faces:
33
+ try:
34
+ return many_faces[position]
35
+ except IndexError:
36
+ return many_faces[-1]
37
+ return None
38
+
39
+
40
+ def get_many_faces(frame: Frame) -> Optional[List[Face]]:
41
+ try:
42
+ return get_face_analyser().get(frame)
43
+ except ValueError:
44
+ return None
45
+
46
+
47
+ def find_similar_face(frame: Frame, reference_face: Face) -> Optional[Face]:
48
+ many_faces = get_many_faces(frame)
49
+ if many_faces:
50
+ for face in many_faces:
51
+ if hasattr(face, 'normed_embedding') and hasattr(reference_face, 'normed_embedding'):
52
+ distance = numpy.sum(numpy.square(face.normed_embedding - reference_face.normed_embedding))
53
+ if distance < roop.globals.similar_face_distance:
54
+ return face
55
+ return None