# MASAI / app /data_init.py
# DmitryRyumin's picture
# Summary
# 15b7f31
# raw
# history blame
# 1.68 kB
"""
File: data_init.py
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov
Description: Initial data loading: builds the VAD, audio, video, text, and ASR models at import time.
License: MIT License
"""
import torch
# Importing necessary components for the Gradio app
from app.config import config_data
from app.gpu_init import device
from app.load_models import (
AudioFeatureExtractor,
VideoModelLoader,
TextFeatureExtractor,
)
from app.utils import ASRModel
# Everything below runs at import time: importing this module loads every
# model once and exposes them as module-level names for the rest of the app.

# Silero voice-activity detector plus its helper utilities, obtained through
# torch.hub; force_reload=False means torch.hub does not force a fresh
# download of the hub repository.
vad_model, vad_utils = torch.hub.load(
    repo_or_dir=config_data.StaticPaths_VAD_MODEL,
    model="silero_vad",
    onnx=False,
    verbose=False,
    force_reload=False,
)
# The hub entry point returns five utilities; only two are bound to names.
(get_speech_timestamps, _, read_audio, _, _) = vad_utils


def _hf_url(relative_path):
    """Prefix *relative_path* with the ``StaticPaths_HF_MODELS`` base string."""
    return config_data.StaticPaths_HF_MODELS + relative_path


# Audio model built from the audio emo/sent checkpoint.
audio_model = AudioFeatureExtractor(
    checkpoint_url=_hf_url(config_data.StaticPaths_EMO_SENT_AUDIO_WEIGHTS),
    folder_path=config_data.StaticPaths_WEIGHTS,
    device=device,
    with_features=False,
)

# Video pipeline: separate checkpoints for the face, emotion, and emo/sent parts.
video_model = VideoModelLoader(
    face_checkpoint_url=_hf_url(config_data.StaticPaths_YOLOV8N_FACE),
    emotion_checkpoint_url=_hf_url(config_data.StaticPaths_EMO_AFFECTNET_WEIGHTS),
    emo_sent_checkpoint_url=_hf_url(config_data.StaticPaths_EMO_SENT_VIDEO_WEIGHTS),
    folder_path=config_data.StaticPaths_WEIGHTS,
    device=device,
)

# Text model built from the text emo/sent checkpoint.
text_model = TextFeatureExtractor(
    checkpoint_url=_hf_url(config_data.StaticPaths_EMO_SENT_TEXT_WEIGHTS),
    folder_path=config_data.StaticPaths_WEIGHTS,
    device=device,
    with_features=False,
)

# Speech recognizer loaded from the StaticPaths_OPENAI_WHISPER checkpoint path.
asr = ASRModel(device=device, checkpoint_path=config_data.StaticPaths_OPENAI_WHISPER)

# The URL helper is only needed while the models above are constructed; remove
# it so the module exposes exactly the same names as before.
del _hf_url