"""
File: submit.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Event handler for Gradio app to submit.
License: MIT License
"""

import spaces
import torch
import pandas as pd
import cv2
import gradio as gr

# Importing necessary components for the Gradio app
from app.config import config_data
from app.utils import (
    Timer,
    convert_video_to_audio,
    readetect_speech,
    slice_audio,
    find_intersections,
    calculate_mode,
    find_nearest_frames,
    convert_webm_to_mp4,
)
from app.plots import (
    get_evenly_spaced_frame_indices,
    plot_audio,
    display_frame_info,
    plot_images,
    plot_predictions,
)
from app.data_init import (
    read_audio,
    get_speech_timestamps,
    vad_model,
    video_model,
    asr,
    audio_model,
    text_model,
)
from app.load_models import VideoFeatureExtractor
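

# Presumed ZeroGPU decorator: the Space runs on Zero and `spaces` is imported above,
# so the submit handler is expected to request a GPU for inference.
@spaces.GPU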
def event_handler_submit(
    video: str,
) -> tuple[
    gr.Textbox,
    gr.Plot,
    gr.Plot,
    gr.Plot,
    gr.Plot,
    gr.Row,
    gr.Textbox,
    gr.Textbox,
]:
    with Timer() as timer:
        if video:
            if video.split(".")[-1] == "webm":
                video = convert_webm_to_mp4(video)
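
        # Extract the audio track from the video at the configured sampling rate.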
        audio_file_path = convert_video_to_audio(
            file_path=video, sr=config_data.General_SR
        )
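
        # Read the waveform and run voice activity detection to get speech timestamps.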
        wav, vad_info = readetect_speech(
            file_path=audio_file_path,
            read_audio=read_audio,
            get_speech_timestamps=get_speech_timestamps,
            vad_model=vad_model,
            sr=config_data.General_SR,
        )
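
        # Slice the waveform into analysis windows; window lengths and shift are
        # configured in seconds and converted to samples here.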
        audio_windows = slice_audio(
            start_time=config_data.General_START_TIME,
            end_time=int(len(wav)),
            win_max_length=int(
                config_data.General_WIN_MAX_LENGTH * config_data.General_SR
            ),
            win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
            win_min_length=int(
                config_data.General_WIN_MIN_LENGTH * config_data.General_SR
            ),
        )
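
        # Intersect the audio windows with the detected speech segments.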
        intersections = find_intersections(
            x=audio_windows,
            y=vad_info,
            min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
        )
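
        # Prepare the video branch: preprocess_video() populates the face crops,
        # fps, frame count, and duration used below.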
        vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
        vfe.preprocess_video()
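
        # Transcribe speech: per-window transcriptions plus the full recognized text.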
        transcriptions, total_text = asr(wav, audio_windows)
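
        # Fuse audio, video, and text predictions per window; the audio branch is
        # skipped for windows that contain no speech.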
        window_frames = []
        preds_emo = []
        preds_sen = []
        for w_idx, window in enumerate(audio_windows):
            a_w = intersections[w_idx]
            if not a_w["speech"]:
                a_pred = None
            else:
                wave = wav[a_w["start"] : a_w["end"]].clone()
                a_pred, _ = audio_model(wave)

            v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
            t_pred, _ = text_model(transcriptions[w_idx][0])

            if a_pred:
                pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
                pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
            else:
                pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
                pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2

            frames = list(
                range(
                    int(window["start"] * vfe.fps / config_data.General_SR) + 1,
                    int(window["end"] * vfe.fps / config_data.General_SR) + 2,
                )
            )

            preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
            preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
            window_frames.extend(frames)
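
        # If the audio windows do not cover the whole video, repeat the last
        # prediction for the remaining frames.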
        if max(window_frames) < vfe.frame_number:
            missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
            window_frames.extend(missed_frames)
            preds_emo.extend([preds_emo[-1]] * len(missed_frames))
            preds_sen.extend([preds_sen[-1]] * len(missed_frames))
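
        # Collect frame-level predictions and resolve overlapping windows by
        # taking the per-frame mode.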
        df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
        df_pred["frames"] = window_frames
        df_pred["pred_emo"] = preds_emo
        df_pred["pred_sent"] = preds_sen
        df_pred = df_pred.groupby("frames").agg(
            {
                "pred_emo": calculate_mode,
                "pred_sent": calculate_mode,
            }
        )
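
        # Select evenly spaced frames for visualization and plot the audio waveform.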
        frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)
        num_frames = len(wav)
        time_axis = [i / config_data.General_SR for i in range(num_frames)]
        plt_audio = plot_audio(
            time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
        )
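
        # Show the detected face nearest to each selected frame, resized and
        # annotated with its frame number.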
        all_idx_faces = list(vfe.faces[1].keys())
        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)
        faces = []
        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
            cur_face = cv2.resize(
                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
            )
            faces.append(
                display_frame_info(
                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
                )
            )
        plt_faces = plot_images(faces)
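
        # Plot frame-level emotion and sentiment predictions.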
        plt_emo = plot_predictions(
            df_pred,
            "pred_emo",
            "Emotion",
            list(config_data.General_DICT_EMO),
            (12, 2.5),
            [i + 1 for i in frame_indices],
            3,
        )
        plt_sent = plot_predictions(
            df_pred,
            "pred_sent",
            "Sentiment",
            list(config_data.General_DICT_SENT),
            (12, 1.5),
            [i + 1 for i in frame_indices],
            3,
        )
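
    # Return the updated Gradio components: recognized text, plots, and timing info.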
    return (
        gr.Textbox(
            value=" ".join(total_text).strip(),
            info=config_data.InformationMessages_REC_TEXT,
            container=True,
            elem_classes="noti-results",
        ),
        gr.Plot(value=plt_audio, visible=True),
        gr.Plot(value=plt_faces, visible=True),
        gr.Plot(value=plt_emo, visible=True),
        gr.Plot(value=plt_sent, visible=True),
        gr.Row(visible=True),
        gr.Textbox(
            value=config_data.OtherMessages_SEC.format(vfe.dur),
            info=config_data.InformationMessages_VIDEO_DURATION,
            container=True,
            visible=True,
        ),
        gr.Textbox(
            value=timer.execution_time,
            info=config_data.InformationMessages_INFERENCE_TIME,
            container=True,
            visible=True,
        ),
    )