Ar4ikov's picture
Update app.py
cd6af6a
raw
history blame
2.23 kB
from transformers import pipeline
import gradio as gr
from pyctcdecode import BeamSearchDecoderCTC
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoConfig, AutoModel, Wav2Vec2FeatureExtractor
import librosa
import numpy as np
import subprocess
# Forwarded as ``trust_remote_code`` when loading the Hub config and model,
# which lets transformers execute Python code shipped with that repository.
TRUST: bool = True
# Sample rate (Hz) handed to the feature extractor alongside the waveform.
SR: int = 16_000
def resample(speech_array, sampling_rate):
    """Resample a waveform to the model's sample rate ``SR``.

    Args:
        speech_array: NumPy array of audio samples. ``torchaudio``'s
            ``Resample`` treats the last axis as time.
        sampling_rate: Sample rate of ``speech_array`` in Hz.

    Returns:
        A float32 NumPy array at ``SR`` Hz with singleton dimensions squeezed
        out.
    """
    # Resample only accepts floating-point waveforms, and by default it caches
    # its kernel as float32. Coerce here instead of relying on every caller.
    speech = torch.from_numpy(np.asarray(speech_array, dtype=np.float32))
    if sampling_rate != SR:
        # Name the target rate explicitly rather than relying on Resample's
        # default ``new_freq`` happening to equal SR.
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=SR
        )
        speech = resampler(speech)
    return speech.squeeze().numpy()
def predict(speech_array, sampling_rate):
    """Run the emotion classifier on a single waveform.

    Args:
        speech_array: NumPy array of audio samples, passed through ``resample``.
        sampling_rate: Sample rate of ``speech_array`` in Hz.

    Returns:
        dict mapping each label name from ``config.id2label`` to its softmax
        probability, rounded to 3 decimal places.
    """
    waveform = resample(speech_array, sampling_rate)
    print(waveform, waveform.shape)

    features = feature_extractor(
        waveform, sampling_rate=SR, return_tensors="pt", padding=True
    )
    model_inputs = {name: tensor.to(device) for name, tensor in features.items()}

    with torch.no_grad():
        classifier = model.to(device)
        logits = classifier(**model_inputs).logits

    # Softmax over the class axis. Only the first batch row is kept, since a
    # single waveform is passed to the feature extractor above.
    probabilities = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
    labelled = {}
    for index, probability in enumerate(probabilities):
        labelled[config.id2label[index]] = round(float(probability), 3)
    return labelled
# Hub repository of the pretrained classifier (a Russian emotion-recognition
# model, going by its name). It is kept in one place so the config, model and
# feature extractor are always loaded from the same checkpoint.
MODEL_ID = "Aniemore/wav2vec2-xlsr-53-russian-emotion-recognition"

config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=TRUST)
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=TRUST)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Place the weights on the target device once at start-up, so later
# ``model.to(device)`` calls are no-ops. Also make inference mode explicit.
model.to(device)
model.eval()
print(device)
def recognize(audio, state=None):
    """Gradio callback: classify the emotion in a microphone recording.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by ``gr.Audio``
            with numpy output, or ``None`` when there is no recording.
        state: Value of the ``"state"`` input. It is not read, only returned
            unchanged when there is nothing to classify. Defaults to ``None``
            instead of a shared mutable ``{}``.

    Returns:
        ``(scores, scores)`` for the ``gr.Label`` and ``"state"`` outputs,
        where ``scores`` is the dict returned by ``predict``. Returns
        ``(None, state)`` when ``audio`` is ``None`` or has no samples.
    """
    if audio is None:
        return None, state
    sr, audio_array = audio
    audio_array = np.asarray(audio_array)
    if audio_array.size == 0:
        # Nothing recorded yet (possible with live=True): nothing to classify.
        return None, state
    if np.issubdtype(audio_array.dtype, np.signedinteger):
        # Integer PCM (gradio's microphone typically yields int16). Rescale to
        # floats in [-1, 1) rather than keeping raw integer magnitudes.
        full_scale = np.float32(-np.iinfo(audio_array.dtype).min)
        audio_array = audio_array.astype(np.float32) / full_scale
    if audio_array.ndim > 1:
        # gradio lays multi-channel audio out as (samples, channels). Average
        # the channels so time is the only remaining axis for the resampler.
        audio_array = audio_array.mean(axis=1)
    audio_array = audio_array.astype(np.float32, copy=False)
    state = predict(audio_array, sr)
    return state, state
def test_some(audio):
sr, audio_array = audio
audio_array = audio_array.astype(np.float32)
return (sr, audio_array)
interface = gr.Interface(
    fn=recognize,
    inputs=[
        # The label reads "Say something..." (Russian).
        gr.Audio(source="microphone", label="Скажите что-нибудь..."),
        "state",
    ],
    outputs=[
        # Show every class the model can emit rather than a hard-coded 7.
        gr.Label(num_top_classes=len(config.id2label)),
        "state",
    ],
    live=True,
    theme="huggingface",
)

# Launch only when run as a script. ``launch(debug=True)`` blocks the main
# thread, so merely importing this module must not start the server.
if __name__ == "__main__":
    interface.launch(debug=True)