"""
File: submit.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Event handler for the submit action of the Gradio app.
License: MIT License
"""
import spaces
import torch
import pandas as pd
import cv2
import gradio as gr
# Importing necessary components for the Gradio app
from app.config import config_data
from app.utils import (
    Timer,
    convert_video_to_audio,
    readetect_speech,
    slice_audio,
    find_intersections,
    calculate_mode,
    find_nearest_frames,
    convert_webm_to_mp4,
)
from app.plots import (
    get_evenly_spaced_frame_indices,
    plot_audio,
    display_frame_info,
    plot_images,
    plot_predictions,
)
from app.data_init import (
    read_audio,
    get_speech_timestamps,
    vad_model,
    video_model,
    asr,
    audio_model,
    text_model,
)
from app.load_models import VideoFeatureExtractor


@spaces.GPU
def event_handler_submit(
    video: str,
) -> tuple[
    gr.Textbox,
    gr.Plot,
    gr.Plot,
    gr.Plot,
    gr.Plot,
    gr.Row,
    gr.Textbox,
    gr.Textbox,
]:
    with Timer() as timer:
        if video:
            if video.split(".")[-1] == "webm":
                video = convert_webm_to_mp4(video)
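
        # Extract the audio track from the video at the configured sampling rate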
        audio_file_path = convert_video_to_audio(
            file_path=video, sr=config_data.General_SR
        )
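
        # Read the waveform and run voice activity detection (VAD) on it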
        wav, vad_info = readetect_speech(
            file_path=audio_file_path,
            read_audio=read_audio,
            get_speech_timestamps=get_speech_timestamps,
            vad_model=vad_model,
            sr=config_data.General_SR,
        )
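
        # Slice the waveform into overlapping analysis windows (all lengths in samples)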
        audio_windows = slice_audio(
            start_time=config_data.General_START_TIME,
            end_time=int(len(wav)),
            win_max_length=int(
                config_data.General_WIN_MAX_LENGTH * config_data.General_SR
            ),
            win_shift=int(config_data.General_WIN_SHIFT * config_data.General_SR),
            win_min_length=int(
                config_data.General_WIN_MIN_LENGTH * config_data.General_SR
            ),
        )
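
        # Keep only the portions of each window that overlap with detected speech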
        intersections = find_intersections(
            x=audio_windows,
            y=vad_info,
            min_length=config_data.General_WIN_MIN_LENGTH * config_data.General_SR,
        )
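
        # Prepare per-frame visual data (face crops) for the whole video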
        vfe = VideoFeatureExtractor(video_model, file_path=video, with_features=False)
        vfe.preprocess_video()
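
        # Transcribe the speech for each window; total_text holds the full transcript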
        transcriptions, total_text = asr(wav, audio_windows)

        window_frames = []
        preds_emo = []
        preds_sen = []
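
        # Predict emotion and sentiment per window, fusing the audio, video, and text modalities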
        for w_idx, window in enumerate(audio_windows):
            a_w = intersections[w_idx]
            if not a_w["speech"]:
                a_pred = None
            else:
                wave = wav[a_w["start"] : a_w["end"]].clone()
                a_pred, _ = audio_model(wave)

            v_pred, _ = vfe(window, config_data.General_WIN_MAX_LENGTH)
            t_pred, _ = text_model(transcriptions[w_idx][0])

            if a_pred:
                pred_emo = (a_pred["emo"] + v_pred["emo"] + t_pred["emo"]) / 3
                pred_sen = (a_pred["sen"] + v_pred["sen"] + t_pred["sen"]) / 3
            else:
                pred_emo = (v_pred["emo"] + t_pred["emo"]) / 2
                pred_sen = (v_pred["sen"] + t_pred["sen"]) / 2

            frames = list(
                range(
                    int(window["start"] * vfe.fps / config_data.General_SR) + 1,
                    int(window["end"] * vfe.fps / config_data.General_SR) + 2,
                )
            )

            preds_emo.extend([torch.argmax(pred_emo).numpy()] * len(frames))
            preds_sen.extend([torch.argmax(pred_sen).numpy()] * len(frames))
            window_frames.extend(frames)
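
        # Propagate the last prediction to any trailing frames not covered by a window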
        if max(window_frames) < vfe.frame_number:
            missed_frames = list(range(max(window_frames) + 1, vfe.frame_number + 1))
            window_frames.extend(missed_frames)
            preds_emo.extend([preds_emo[-1]] * len(missed_frames))
            preds_sen.extend([preds_sen[-1]] * len(missed_frames))
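
        # Aggregate window-level predictions into a single label per frame via the mode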
        df_pred = pd.DataFrame(columns=["frames", "pred_emo", "pred_sent"])
        df_pred["frames"] = window_frames
        df_pred["pred_emo"] = preds_emo
        df_pred["pred_sent"] = preds_sen

        df_pred = df_pred.groupby("frames").agg(
            {
                "pred_emo": calculate_mode,
                "pred_sent": calculate_mode,
            }
        )
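
        # Select evenly spaced frames and plot the waveform with frame markers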
        frame_indices = get_evenly_spaced_frame_indices(vfe.frame_number, 9)

        num_frames = len(wav)
        time_axis = [i / config_data.General_SR for i in range(num_frames)]

        plt_audio = plot_audio(
            time_axis, wav.unsqueeze(0), frame_indices, vfe.fps, (12, 2)
        )
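
        # Pick the detected face crops closest to the selected frame indices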
        all_idx_faces = list(vfe.faces[1].keys())
        need_idx_faces = find_nearest_frames(frame_indices, all_idx_faces)

        faces = []
        for idx_frame, idx_faces in zip(frame_indices, need_idx_faces):
            cur_face = cv2.resize(
                vfe.faces[1][idx_faces], (224, 224), interpolation=cv2.INTER_AREA
            )
            faces.append(
                display_frame_info(
                    cur_face, "Frame: {}".format(idx_frame + 1), box_scale=0.3
                )
            )

        plt_faces = plot_images(faces)
        plt_emo = plot_predictions(
            df_pred,
            "pred_emo",
            "Emotion",
            list(config_data.General_DICT_EMO),
            (12, 2.5),
            [i + 1 for i in frame_indices],
            3,
        )

        plt_sent = plot_predictions(
            df_pred,
            "pred_sent",
            "Sentiment",
            list(config_data.General_DICT_SENT),
            (12, 1.5),
            [i + 1 for i in frame_indices],
            3,
        )
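
    # Return the updated Gradio components: transcript, plots, and timing information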
    return (
        gr.Textbox(
            value=" ".join(total_text).strip(),
            info=config_data.InformationMessages_REC_TEXT,
            container=True,
            elem_classes="noti-results",
        ),
        gr.Plot(value=plt_audio, visible=True),
        gr.Plot(value=plt_faces, visible=True),
        gr.Plot(value=plt_emo, visible=True),
        gr.Plot(value=plt_sent, visible=True),
        gr.Row(visible=True),
        gr.Textbox(
            value=config_data.OtherMessages_SEC.format(vfe.dur),
            info=config_data.InformationMessages_VIDEO_DURATION,
            container=True,
            visible=True,
        ),
        gr.Textbox(
            value=timer.execution_time,
            info=config_data.InformationMessages_INFERENCE_TIME,
            container=True,
            visible=True,
        ),
    )