Mihaj's picture
Update app.py
94f428a verified
raw
history blame contribute delete
No virus
9.24 kB
import gradio as gr
from transformers import pipeline, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
import os
import soundfile as sf
import torch
import numpy as np
HF_TOKEN = os.environ.get("HF_TOKEN")
model_name = "bond005/wav2vec2-large-ru-golos-with-lm"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor, feature_extractor=processor.feature_extractor, decoder=processor.decoder)
detokenize_dict = {value: key for key, value in processor.tokenizer.get_vocab().items()}
dict_v = ["а", "у" "о" "и" "э" "ы" "я" "ю" "е" "ё"]
def count_char_borders(predicted_ids, input_values, processor, sample_rate=16000):
predicted_ids_l = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate
ids_c_time = [(i / len(predicted_ids_l) * duration_sec, _id) for i, _id in enumerate(predicted_ids_l)]
t_chars_list = [[i[0], detokenize_dict[i[1]]] for i in ids_c_time if i[1] != processor.tokenizer.pad_token_id]
t_chars_list_cl = []
cur = None
for i, item in enumerate(t_chars_list[:-1]):
if i == 0 or cur == None:
cur = item
if item[1] != t_chars_list[i + 1][1]:
cur.append(t_chars_list[i + 1][0])
t_chars_list_cl.append(cur)
cur = t_chars_list[i + 1]
t_chars_list_cl = [i if i[1] != "|" else [i[0], "", i[2]] for i in t_chars_list_cl]
chars, char_start_times, char_end_times = [], [], []
for c in t_chars_list_cl:
if c[1].lower() in dict_v and c[1] != "":
chars.append("v")
elif c[1] != "":
chars.append("c")
else:
chars.append("")
char_start_times.append(c[0])
char_end_times.append(c[2])
return chars, char_start_times, char_end_times
# обработка seg-файла, получение информации для расчётов
# предполагается, что на вход получаем seg либо 'corpres' - с разметкой по корпресу, либо упрощённая разметка 'cv' - с разметкой на согласные и гласные
def preprocess(chars, starts, labelled='cv'):
start_and_sound = []
# берём из seg-файла метки звуков, отсчёты переводим в секунды, получаем общую длительность
for i, item in enumerate(chars):
start_time = float(starts[i])
label = item
start_and_sound.append([start_time, label])
# заводим переменные, необходимые для расчётов
clusters_and_duration = []
pauses = 0
sum_dur_vowels = 0
# флаг для определения границ кластеров. важно, если до и после паузы звуки одного класса
postpause_flag = 0
# обработка файлов с гласно-согласной разметкой
if labelled == 'cv':
total_duration = 0
# определяем к какому классу относится каждый звук и считаем длительность (отдельных гласных и согласных кластеров)
for n, i in enumerate(start_and_sound):
sound = i[1]
# определяем не является ли звук конечным
if n != len(start_and_sound) - 1:
duration = start_and_sound[n+1][0] - i[0]
# выделяем гласные
if sound == 'V' or sound == 'v':
total_duration += duration
# записываем отдельно звук в нулевой позиции в обход ошибки индекса
if n == 0:
clusters_and_duration.append(['V', duration])
# объединяем длительности, если предыдущий звук тоже был гласным
elif clusters_and_duration[-1][0] == 'V' and postpause_flag == 0:
clusters_and_duration[-1][1] += duration
# фиксируем длительность отдельного гласного звука
else:
clusters_and_duration.append(['V', duration])
# считаем длителность всех гласных интервалов в записи
sum_dur_vowels += duration
# снимаем флаг
postpause_flag = 0
# выделяем паузы
elif sound == '':
pauses += duration
total_duration += duration
# ставим флаг для следующего звука
postpause_flag = 1
# выделяем согласные
else:
total_duration += duration
# записываем отдельно звук в нулевой позиции в обход ошибки
if n == 0:
clusters_and_duration.append(['C', duration])
# объединяем длительности, если предыдущий звук тоже был согласным
elif clusters_and_duration[-1][0] == 'C' and postpause_flag == 0:
clusters_and_duration[-1][1] += duration
# фиксируем длительность отдельного согласного звука
else:
clusters_and_duration.append(['C', duration])
# снимаем флаг
postpause_flag = 0
# функция возвращает метки кластеров и их длительность и общую длительность всех гласных интервалов
return clusters_and_duration, sum_dur_vowels, total_duration, pauses
def delta_C(cons_clusters):
# применяем функцию numpy среднеквадратического отклонения
dC = np.std(cons_clusters)
return dC
def percent_V(vowels, total_wo_pauses):
pV = vowels / total_wo_pauses
return pV
# point_1 = np.array((0, 0, 0))
# point_2 = np.array((3, 3, 3))
def count_eucl(point_1, point_2):
# Initializing the points
# Get the square of the difference of the 2 vectors
square = np.square(point_1 - point_2)
# Get the sum of the square
sum_square = np.sum(square)
# The last step is to get the square root and print the Euclidean distance
distance = np.sqrt(sum_square)
return distance
ex_dict = {"eng": np.array((0.0535, 0.401)), "kat": np.array((0.0452, 0.456)), "jap": np.array((0.0356, 0.531))}
def classify_rhytm(dC, pV):
our = np.array((dC, pV))
res = {}
if (dC > 0.08 and pV > 0.45) or (dC < 0.03 and pV < 0.04):
text = "Вы не укладываетесь ни в какие рамки и прекрасны в этом!"
else:
for k, v in ex_dict.items():
res[k] = count_eucl(our, v)
sorted_tuples = sorted(res.items(), key=lambda item: item[1])
sorted_res = {k: v for k, v in sorted_tuples}
if [i for i in sorted_res.keys()][0] == "eng":
text = "По типу ритма ваша речь близка к тактосчитающим языкам (английский)."
if [i for i in sorted_res.keys()][0] == "kat":
text = "По типу ритма ваша речь близка к слогосчитающим языкам (испанский)."
if [i for i in sorted_res.keys()][0] == "jap":
text = "По типу ритма ваша речь близка к моросчитающим языкам (японский)."
return text
def transcribe(audio):
y, sr = sf.read(audio, samplerate=16000)
input_values = processor(y, sampling_rate=sr, return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
chars, char_start_times, char_end_times = count_char_borders(predicted_ids, input_values, processor)
clusters_and_duration, sum_dur_vowels, total_duration, pauses = preprocess(chars, char_start_times)
cons_clusters = []
# параметры для ΔC
for x in clusters_and_duration:
if x[0] == 'C':
cons_clusters.append(x[1])
# параметры для %V
vowels_duration = sum_dur_vowels
duration_without_pauses = total_duration - pauses
# расчёт метрик
dC = delta_C(cons_clusters) / 5
pV = percent_V(vowels_duration, duration_without_pauses) * 5
transcription = processor.batch_decode(logits.detach().numpy()).text[0]
text = {"transcription": transcription}
text['dC'] = dC
text['pV'] = pV
cl = classify_rhytm(dC, pV)
text['result'] = cl
return text
iface = gr.Interface(
fn=transcribe,
inputs=gr.Audio(type="filepath"),
outputs="text",
title="Mihaj/Wav2Vec2RhytmAnalyzer",
description="Демо анализатор ритма на основе модели Wav2Vec large от bond005.",
)
iface.launch()