File size: 5,507 Bytes
9ae3d29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from typing import List, Optional, Tuple, Any, Dict
from time import sleep

import cv2
import gradio

import DeepFakeAI.choices
import DeepFakeAI.globals
from DeepFakeAI import wording
from DeepFakeAI.capturer import get_video_frame
from DeepFakeAI.face_analyser import get_many_faces
from DeepFakeAI.face_reference import clear_face_reference
from DeepFakeAI.typing import Frame, FaceRecognition
from DeepFakeAI.uis import core as ui
from DeepFakeAI.uis.typing import ComponentName, Update
from DeepFakeAI.utilities import is_image, is_video

# Widgets owned by this module. They are created and assigned by render() and
# stay None until it has run; listen() dereferences them without a None check,
# so render() must be called before listen().
FACE_RECOGNITION_DROPDOWN : Optional[gradio.Dropdown] = None
REFERENCE_FACE_POSITION_GALLERY : Optional[gradio.Gallery] = None
REFERENCE_FACE_DISTANCE_SLIDER : Optional[gradio.Slider] = None


def render() -> None:
	"""
	Create the face selector widgets and register them with the UI core.

	Builds the face recognition dropdown, the reference face gallery and the
	reference face distance slider inside a gradio Box, stores them in the
	module-level globals and registers each one under its component name.
	When the current target is an image or a video, the gallery is pre-filled
	with the face crops found in its reference frame.
	"""
	global FACE_RECOGNITION_DROPDOWN
	global REFERENCE_FACE_POSITION_GALLERY
	global REFERENCE_FACE_DISTANCE_SLIDER

	with gradio.Box():
		# the gallery and the distance slider only matter in 'reference' recognition mode
		is_reference_mode = 'reference' in DeepFakeAI.globals.face_recognition
		reference_face_gallery_args: Dict[str, Any] = {
			'label': wording.get('reference_face_gallery_label'),
			'height': 120,
			'object_fit': 'cover',
			'columns': 10,
			'allow_preview': False,
			'visible': is_reference_mode
		}
		reference_frame = None
		if is_image(DeepFakeAI.globals.target_path):
			reference_frame = cv2.imread(DeepFakeAI.globals.target_path)
		if is_video(DeepFakeAI.globals.target_path):
			reference_frame = get_video_frame(DeepFakeAI.globals.target_path, DeepFakeAI.globals.reference_frame_number)
		# cv2.imread returns None for unreadable files (and frame grabbing can fail);
		# leave the gallery empty instead of running face detection on None
		if reference_frame is not None:
			reference_face_gallery_args['value'] = extract_gallery_frames(reference_frame)
		FACE_RECOGNITION_DROPDOWN = gradio.Dropdown(
			label = wording.get('face_recognition_dropdown_label'),
			choices = DeepFakeAI.choices.face_recognition,
			value = DeepFakeAI.globals.face_recognition
		)
		REFERENCE_FACE_POSITION_GALLERY = gradio.Gallery(**reference_face_gallery_args)
		REFERENCE_FACE_DISTANCE_SLIDER = gradio.Slider(
			label = wording.get('reference_face_distance_slider_label'),
			value = DeepFakeAI.globals.reference_face_distance,
			maximum = 3,
			step = 0.05,
			visible = is_reference_mode
		)
		ui.register_component('face_recognition_dropdown', FACE_RECOGNITION_DROPDOWN)
		ui.register_component('reference_face_position_gallery', REFERENCE_FACE_POSITION_GALLERY)
		ui.register_component('reference_face_distance_slider', REFERENCE_FACE_DISTANCE_SLIDER)


def listen() -> None:
	"""
	Attach event handlers to the face selector widgets and related components.

	Requires render() to have populated the module-level widgets. Besides the
	widgets' own events, the reference gallery is rebuilt whenever the target
	file or preview frame changes, or a face analyser dropdown is selected
	(only for components that are registered with the UI core).
	"""
	FACE_RECOGNITION_DROPDOWN.select(update_face_recognition, inputs = FACE_RECOGNITION_DROPDOWN, outputs = [ REFERENCE_FACE_POSITION_GALLERY, REFERENCE_FACE_DISTANCE_SLIDER ])
	REFERENCE_FACE_POSITION_GALLERY.select(clear_and_update_face_reference_position)
	REFERENCE_FACE_DISTANCE_SLIDER.change(update_reference_face_distance, inputs = REFERENCE_FACE_DISTANCE_SLIDER)

	# (component name, event method) pairs that should refresh the reference gallery;
	# listed in the same order the handlers were originally bound
	gallery_refresh_triggers : List[Tuple[ComponentName, str]] =\
	[
		('target_file', 'change'),
		('preview_frame_slider', 'change'),
		('face_analyser_direction_dropdown', 'select'),
		('face_analyser_age_dropdown', 'select'),
		('face_analyser_gender_dropdown', 'select')
	]
	for trigger_name, event_method in gallery_refresh_triggers:
		trigger = ui.get_component(trigger_name)
		if not trigger:
			continue
		bind_event = getattr(trigger, event_method)
		bind_event(update_face_reference_position, outputs = REFERENCE_FACE_POSITION_GALLERY)


def update_face_recognition(face_recognition : FaceRecognition) -> Tuple[Update, Update]:
	"""
	Apply the face recognition mode chosen in the dropdown.

	Known modes ('reference' and 'many') are stored in the globals; any other
	value leaves the current mode untouched. Two visibility updates are always
	returned, for the reference gallery and the distance slider, so the handler
	never yields None to its two gradio outputs.

	:param face_recognition: the mode selected in the dropdown
	:return: visibility updates for the reference gallery and the distance slider
	"""
	if face_recognition in ('reference', 'many'):
		DeepFakeAI.globals.face_recognition = face_recognition
	# reflect the mode actually in effect, so the UI stays consistent even for unexpected values
	is_reference_mode = DeepFakeAI.globals.face_recognition == 'reference'
	return gradio.update(visible = is_reference_mode), gradio.update(visible = is_reference_mode)


def clear_and_update_face_reference_position(event: gradio.SelectData) -> Update:
	"""
	Handle a click on the reference face gallery.

	Calls clear_face_reference() (presumably dropping the previously stored
	reference face — confirm in DeepFakeAI.face_reference), then stores the
	clicked gallery index as the new reference face position and rebuilds
	the gallery.
	"""
	clear_face_reference()
	selected_position = event.index
	return update_face_reference_position(selected_position)


def update_face_reference_position(reference_face_position : int = 0) -> Update:
	"""
	Store the reference face position and rebuild the reference face gallery.

	:param reference_face_position: gallery index of the face to use as reference;
		defaults to 0 when the handler is triggered without inputs
	:return: gallery update holding the face crops of the target's reference frame,
		or ``value = None`` when no frame could be read or no faces were found
	"""
	# NOTE(review): the purpose of this fixed delay is not visible here — presumably it lets
	# other handlers fired by the same event (e.g. a target upload) finish first; confirm before removing
	sleep(0.2)
	DeepFakeAI.globals.reference_face_position = reference_face_position
	target_path = DeepFakeAI.globals.target_path
	reference_frame = None
	if is_image(target_path):
		reference_frame = cv2.imread(target_path)
	if is_video(target_path):
		reference_frame = get_video_frame(target_path, DeepFakeAI.globals.reference_frame_number)
	gallery_frames = []
	# cv2.imread returns None for unreadable files (and frame grabbing can fail); skip detection then
	if reference_frame is not None:
		gallery_frames = extract_gallery_frames(reference_frame)
	return gradio.update(value = gallery_frames or None)


def update_reference_face_distance(reference_face_distance : float) -> Update:
	"""Save the slider value as the global reference face distance and echo it back to the slider."""
	DeepFakeAI.globals.reference_face_distance = reference_face_distance
	slider_update = gradio.update(value = reference_face_distance)
	return slider_update


def extract_gallery_frames(reference_frame : Frame) -> List[Frame]:
	"""
	Crop every detected face out of a frame for display in the reference gallery.

	Each face bounding box (start_x, start_y, end_x, end_y) is enlarged by 25% of
	its width/height on every side, clamped to the frame, cropped and passed
	through ui.normalize_frame. Crops are returned in the order get_many_faces
	yields the faces. The gallery index that is stored as the reference face
	position presumably relies on that order, so do not filter or reorder.

	:param reference_frame: frame to search for faces; None yields an empty list
	:return: one normalized crop per detected face
	"""
	if reference_frame is None:
		return []
	crop_frames = []
	frame_height, frame_width = reference_frame.shape[:2]
	for face in get_many_faces(reference_frame):
		start_x, start_y, end_x, end_y = map(int, face['bbox'])
		padding_x = int((end_x - start_x) * 0.25)
		padding_y = int((end_y - start_y) * 0.25)
		# clamp start at 0 so a negative value is not read as an index from the end,
		# and clamp end to the frame size (the original max(0, ...) was the wrong bound)
		start_x = max(0, start_x - padding_x)
		start_y = max(0, start_y - padding_y)
		end_x = min(frame_width, max(0, end_x + padding_x))
		end_y = min(frame_height, max(0, end_y + padding_y))
		crop_frame = reference_frame[start_y:end_y, start_x:end_x]
		crop_frames.append(ui.normalize_frame(crop_frame))
	return crop_frames