import torch
import gradio as gr

from src.load_html import get_description_html
from src.audio_processor import AudioProcessor
from src.model.behaviour_model import get_behaviour_model
from transformers import (
    pipeline,
    WavLMForSequenceClassification
)


def create_demo():
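    """Build the Gradio demo: load the models and wire the audio-upload-to-plot interface."""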
    device = "cuda" if torch.cuda.is_available() else "cpu"
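
    # Whisper ASR pipeline used to segment and transcribe the uploaded audio.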
    segmentation_model = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        tokenizer="openai/whisper-large-v3-turbo",
        device=device
    )
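
    # WavLM sequence classifier for speech emotion recognition, run in eval (inference) mode.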
    emotion_model = WavLMForSequenceClassification.from_pretrained(
        "links-ads/kk-speech-emotion-recognition"
    )
    emotion_model.to(device)
    emotion_model.eval()
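
    # Behaviour classifier restored from the local weights file.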
    behaviour_model = get_behaviour_model(
        classifier_weights_path="src/model/classifier_weights.bin",
        device=device,
    )
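
    # AudioProcessor combines the segmentation, emotion and behaviour models and
    # produces the plot rendered in the UI.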
    audio_processor = AudioProcessor(
        emotion_model=emotion_model,
        segmentation_model=segmentation_model,
        device=device,
        behaviour_model=behaviour_model,
    )
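
    # UI: description header, audio upload, trigger button and output plot.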
    with gr.Blocks() as demo:
        gr.HTML(get_description_html)

        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        submit_button = gr.Button("Generate Graph")

        graph_output = gr.Plot(label="Generated Graph")
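
        # Run the audio processor on the uploaded file and show the resulting plot.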
        submit_button.click(
            fn=audio_processor,
            inputs=audio_input,
            outputs=graph_output
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(show_api=False)