speech-analysis / app.py
hagenw's picture
Debug
963cdb8
raw
history blame
6.36 kB
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
import audiofile
import audresample
# Run inference on the first CUDA device when one is available, else on the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"
# Number of seconds read from each input file (passed to audiofile.read).
duration = 1
# Hugging Face hub identifiers of the two models used by this app.
age_gender_model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
class AgeGenderHead(nn.Module):
    r"""Age-gender model head.

    Feed-forward head applied to pooled encoder features:
    dropout -> linear -> tanh -> dropout -> linear, mapping
    ``config.hidden_size`` inputs to ``num_labels`` outputs.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        # Extra keyword arguments are accepted but ignored.
        hidden = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))
class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age-gender recognition model.

    A wav2vec 2.0 encoder whose hidden states are mean-pooled and fed into
    two ``AgeGenderHead`` instances: one with a single output (age) and one
    with three outputs (gender), the latter turned into probabilities with
    a softmax.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)
        self.gender = AgeGenderHead(config, 3)
        self.init_weights()

    def forward(
        self,
        input_values,
    ):
        """Return ``(pooled_hidden_states, age, gender_probabilities)``."""
        frame_states = self.wav2vec2(input_values)[0]
        # Average over dim 1 -- presumably the frame/time axis of the
        # encoder output (assumed (batch, frames, hidden); TODO confirm).
        pooled = frame_states.mean(dim=1)
        age = self.age(pooled)
        # Softmax over the 3 gender outputs, so these are probabilities.
        gender_probs = self.gender(pooled).softmax(dim=1)
        return pooled, age, gender_probs
class ExpressionHead(nn.Module):
    r"""Expression model head.

    Feed-forward head applied to pooled encoder features:
    dropout -> linear -> tanh -> dropout -> linear, mapping
    ``config.hidden_size`` inputs to ``config.num_labels`` outputs.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        # Extra keyword arguments are accepted but ignored.
        hidden = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))
class ExpressionModel(Wav2Vec2PreTrainedModel):
    r"""Speech expression model.

    A wav2vec 2.0 encoder whose hidden states are mean-pooled and passed to
    an ``ExpressionHead`` producing ``config.num_labels`` outputs.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ExpressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        """Return ``(pooled_hidden_states, logits)``."""
        frame_states = self.wav2vec2(input_values)[0]
        # Average over dim 1 -- presumably the frame/time axis of the
        # encoder output (assumed (batch, frames, hidden); TODO confirm).
        pooled = frame_states.mean(dim=1)
        return pooled, self.classifier(pooled)
# Load processors and models from the Hugging Face hub.
# The models are moved to ``device`` so that inference runs on the GPU when
# one is available; without this the weights stay on the CPU while inputs
# are sent to ``device``, which fails on CUDA hosts.
age_gender_processor = Wav2Vec2Processor.from_pretrained(age_gender_model_name)
age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name).to(device)
expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
expression_model = ExpressionModel.from_pretrained(expression_model_name).to(device)
def process_func(x: np.ndarray, sampling_rate: int) -> tuple:
    r"""Predict age, gender and expression from a raw audio signal.

    The signal is run through the age-gender model and then through the
    expression model, each preceded by its own Wav2Vec2 processor.

    Args:
        x: raw audio signal (``recognize()`` passes it resampled to 16 kHz)
        sampling_rate: sampling rate of ``x`` in Hz

    Returns:
        A 3-tuple ``(age, gender, expression)``:
        ``age`` is the age-head output multiplied by 100;
        ``gender`` maps ``"female"``, ``"male"``, ``"child"`` to the softmax
        probabilities of the gender head;
        ``expression`` maps ``"arousal"``, ``"dominance"``, ``"valence"`` to
        the raw expression-head outputs.
    """
    results = []
    for processor, model in zip(
        [age_gender_processor, expression_processor],
        [age_gender_model, expression_model],
    ):
        # Run through the processor to normalize the signal. It always
        # returns a batch, so take the first entry and restore a batch axis.
        features = processor(x, sampling_rate=sampling_rate)
        features = features["input_values"][0].reshape(1, -1)
        # Put the input on whatever device the model's weights live on.
        inputs = torch.from_numpy(features).to(model.device)
        with torch.no_grad():
            outputs = model(inputs)
        # Both models return tuples, not tensors:
        # age-gender -> (hidden_states, age, gender_probs),
        # expression -> (hidden_states, logits).
        if len(outputs) == 3:
            # Age-gender model: [age, female, male, child]
            predictions = torch.hstack([outputs[1], outputs[2]])
        else:
            # Expression model: [arousal, dominance, valence]
            predictions = outputs[1]
        results.append(predictions.detach().cpu().numpy()[0])

    age_gender, expression = results
    return (
        100 * age_gender[0],  # age
        {
            "female": age_gender[1],
            "male": age_gender[2],
            "child": age_gender[3],
        },
        {
            "arousal": expression[0],
            "dominance": expression[1],
            "valence": expression[2],
        },
    )
@spaces.GPU
def recognize(input_file):
    r"""Run age, gender and expression recognition on an audio file.

    Only the first ``duration`` seconds of the file are read. The signal is
    resampled to 16 kHz before being handed to ``process_func``.

    Args:
        input_file: path of the uploaded or recorded file, or ``None``

    Raises:
        gr.Error: if no file was submitted
    """
    if input_file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    signal, sampling_rate = audiofile.read(input_file, duration=duration)
    # Resample to the sampling rate supported by the models.
    model_rate = 16000
    resampled = audresample.resample(signal, sampling_rate, model_rate)
    return process_func(resampled, model_rate)
# Markdown shown at the top of the app, linking to both models on the hub.
description = (
    f"Recognize [age and gender](https://huggingface.co/{age_gender_model_name}) "
    f"and [expression](https://huggingface.co/{expression_model_name}) "
    "of an audio file or microphone recording."
)
# Build the Gradio interface: an audio input and submit button on the left,
# predicted age, gender and expression on the right.
with gr.Blocks() as demo:
    # Show the description once, above the tab (it was previously rendered
    # a second time inside the input column).
    gr.Markdown(description)
    with gr.Tab(label="Speech analysis"):
        with gr.Row():
            with gr.Column():
                # Named input_audio to avoid shadowing the builtin input().
                input_audio = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio input",
                )
                gr.Markdown("Only the first second of the audio is processed.")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                output_expression = gr.Label(label="Expression")
    # Order must match the (age, gender, expression) tuple from recognize().
    outputs = [output_age, output_gender, output_expression]
    submit_btn.click(recognize, input_audio, outputs)

demo.launch(debug=True)