Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC | |
import os | |
import soundfile as sf | |
import torch | |
import numpy as np | |
import librosa | |
HF_TOKEN = os.environ.get("HF_TOKEN") | |
model_name = "bond005/wav2vec2-large-ru-golos-with-lm" | |
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name) | |
model = Wav2Vec2ForCTC.from_pretrained(model_name) | |
pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor, feature_extractor=processor.feature_extractor, decoder=processor.decoder) | |
detokenize_dict = {value: key for key, value in processor.tokenizer.get_vocab().items()} | |
dict_v = ["а", "у" "о" "и" "э" "ы" "я" "ю" "е" "ё"] | |
def count_char_borders(predicted_ids, input_values, processor, sample_rate=16000): | |
predicted_ids_l = predicted_ids[0].tolist() | |
duration_sec = input_values.shape[1] / sample_rate | |
ids_c_time = [(i / len(predicted_ids_l) * duration_sec, _id) for i, _id in enumerate(predicted_ids_l)] | |
t_chars_list = [[i[0], detokenize_dict[i[1]]] for i in ids_c_time if i[1] != processor.tokenizer.pad_token_id] | |
t_chars_list_cl = [] | |
cur = None | |
for i, item in enumerate(t_chars_list[:-1]): | |
if i == 0 or cur == None: | |
cur = item | |
if item[1] != t_chars_list[i + 1][1]: | |
cur.append(t_chars_list[i + 1][0]) | |
t_chars_list_cl.append(cur) | |
cur = t_chars_list[i + 1] | |
t_chars_list_cl = [i if i[1] != "|" else [i[0], "", i[2]] for i in t_chars_list_cl] | |
chars, char_start_times, char_end_times = [], [], [] | |
for c in t_chars_list_cl: | |
if c[1].lower() in dict_v and c[1] != "": | |
chars.append("v") | |
elif c[1] != "": | |
chars.append("c") | |
else: | |
chars.append("") | |
char_start_times.append(c[0]) | |
char_end_times.append(c[2]) | |
return chars, char_start_times, char_end_times | |
# обработка seg-файла, получение информации для расчётов | |
# предполагается, что на вход получаем seg либо 'corpres' - с разметкой по корпресу, либо упрощённая разметка 'cv' - с разметкой на согласные и гласные | |
def preprocess(chars, starts, labelled='cv'): | |
start_and_sound = [] | |
# берём из seg-файла метки звуков, отсчёты переводим в секунды, получаем общую длительность | |
for i, item in enumerate(chars): | |
start_time = float(starts[i]) | |
label = item | |
start_and_sound.append([start_time, label]) | |
# заводим переменные, необходимые для расчётов | |
clusters_and_duration = [] | |
pauses = 0 | |
sum_dur_vowels = 0 | |
# флаг для определения границ кластеров. важно, если до и после паузы звуки одного класса | |
postpause_flag = 0 | |
# обработка файлов с гласно-согласной разметкой | |
if labelled == 'cv': | |
total_duration = 0 | |
# определяем к какому классу относится каждый звук и считаем длительность (отдельных гласных и согласных кластеров) | |
for n, i in enumerate(start_and_sound): | |
sound = i[1] | |
# определяем не является ли звук конечным | |
if n != len(start_and_sound) - 1: | |
duration = start_and_sound[n+1][0] - i[0] | |
# выделяем гласные | |
if sound == 'V' or sound == 'v': | |
total_duration += duration | |
# записываем отдельно звук в нулевой позиции в обход ошибки индекса | |
if n == 0: | |
clusters_and_duration.append(['V', duration]) | |
# объединяем длительности, если предыдущий звук тоже был гласным | |
elif clusters_and_duration[-1][0] == 'V' and postpause_flag == 0: | |
clusters_and_duration[-1][1] += duration | |
# фиксируем длительность отдельного гласного звука | |
else: | |
clusters_and_duration.append(['V', duration]) | |
# считаем длителность всех гласных интервалов в записи | |
sum_dur_vowels += duration | |
# снимаем флаг | |
postpause_flag = 0 | |
# выделяем паузы | |
elif sound == '': | |
pauses += duration | |
total_duration += duration | |
# ставим флаг для следующего звука | |
postpause_flag = 1 | |
# выделяем согласные | |
else: | |
total_duration += duration | |
# записываем отдельно звук в нулевой позиции в обход ошибки | |
if n == 0: | |
clusters_and_duration.append(['C', duration]) | |
# объединяем длительности, если предыдущий звук тоже был согласным | |
elif clusters_and_duration[-1][0] == 'C' and postpause_flag == 0: | |
clusters_and_duration[-1][1] += duration | |
# фиксируем длительность отдельного согласного звука | |
else: | |
clusters_and_duration.append(['C', duration]) | |
# снимаем флаг | |
postpause_flag = 0 | |
# функция возвращает метки кластеров и их длительность и общую длительность всех гласных интервалов | |
return clusters_and_duration, sum_dur_vowels, total_duration, pauses | |
def delta_C(cons_clusters): | |
# применяем функцию numpy среднеквадратического отклонения | |
dC = np.std(cons_clusters) | |
return dC | |
def percent_V(vowels, total_wo_pauses): | |
pV = vowels / total_wo_pauses | |
return pV | |
# point_1 = np.array((0, 0, 0)) | |
# point_2 = np.array((3, 3, 3)) | |
def count_eucl(point_1, point_2): | |
# Initializing the points | |
# Get the square of the difference of the 2 vectors | |
square = np.square(point_1 - point_2) | |
# Get the sum of the square | |
sum_square = np.sum(square) | |
# The last step is to get the square root and print the Euclidean distance | |
distance = np.sqrt(sum_square) | |
return distance | |
ex_dict = {"eng": np.array((0.0535, 0.401)), "kat": np.array((0.0452, 0.456)), "jap": np.array((0.0356, 0.531))} | |
def classify_rhytm(dC, pV): | |
our = np.array((dC, pV)) | |
res = {} | |
if (dC > 0.08 and pV > 0.45) or (dC < 0.03 and pV < 0.04): | |
text = "Вы не укладываетесь ни в какие рамки и прекрасны в этом!" | |
else: | |
for k, v in ex_dict.items(): | |
res[k] = count_eucl(our, v) | |
sorted_tuples = sorted(res.items(), key=lambda item: item[1]) | |
sorted_res = {k: v for k, v in sorted_tuples} | |
if [i for i in sorted_res.keys()][0] == "eng": | |
text = "По типу ритма ваша речь близка к тактосчитающим языкам (английский)." | |
if [i for i in sorted_res.keys()][0] == "kat": | |
text = "По типу ритма ваша речь близка к слогосчитающим языкам (испанский)." | |
if [i for i in sorted_res.keys()][0] == "jap": | |
text = "По типу ритма ваша речь близка к моросчитающим языкам (японский)." | |
return text | |
def transcribe(audio): | |
y, sr = librosa.load(audio, sr=16000) | |
input_values = processor(y, sampling_rate=sr, return_tensors="pt").input_values | |
logits = model(input_values).logits | |
predicted_ids = torch.argmax(logits, dim=-1) | |
chars, char_start_times, char_end_times = count_char_borders(predicted_ids, input_values, processor) | |
clusters_and_duration, sum_dur_vowels, total_duration, pauses = preprocess(chars, char_start_times) | |
cons_clusters = [] | |
# параметры для ΔC | |
for x in clusters_and_duration: | |
if x[0] == 'C': | |
cons_clusters.append(x[1]) | |
# параметры для %V | |
vowels_duration = sum_dur_vowels | |
duration_without_pauses = total_duration - pauses | |
# расчёт метрик | |
dC = delta_C(cons_clusters) | |
pV = percent_V(vowels_duration, duration_without_pauses) | |
transcription = processor.batch_decode(logits.detach().numpy()).text[0] | |
text = {"transcription": transcription} | |
text['dC'] = dC | |
text['pV'] = pV | |
cl = classify_rhytm(dC, pV) | |
text['result'] = cl | |
return text | |
iface = gr.Interface( | |
fn=transcribe, | |
inputs=gr.Audio(type="filepath"), | |
outputs="text", | |
title="Mihaj/Wav2Vec2RhytmAnalyzer", | |
description="Демо анализатор ритма на основе модели Wav2Vec large от bond005.", | |
) | |
iface.launch() |