Mihaj commited on
Commit
d5a7304
1 Parent(s): 9ad9701

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py CHANGED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
3
+ import os
4
+ import soundfile as sf
5
+ import torch
6
+
7
+ HF_TOKEN = os.environ.get("HF_TOKEN")
8
+
9
+ model_name = "bond005/wav2vec2-large-ru-golos-with-lm
10
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
11
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
12
+ pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor, feature_extractor=processor.feature_extractor, decoder=processor.decoder)
13
+
14
+ dict_v = ["а", "у" "о" "и" "э" "ы" "я" "ю" "е" "ё"]
15
+
16
+ def count_char_borders(predicted_ids, input_values, processor, sample_rate=16000):
17
+ predicted_ids_l = predicted_ids[0].tolist()
18
+ duration_sec = input_values.shape[1] / sample_rate
19
+
20
+ ids_c_time = [(i / len(predicted_ids_l) * duration_sec, _id) for i, _id in enumerate(predicted_ids_l)]
21
+
22
+ t_chars_list = [[i[0], detokenize_dict[i[1]]] for i in ids_c_time if i[1] != processor.tokenizer.pad_token_id]
23
+
24
+ t_chars_list_cl = []
25
+ cur = None
26
+ for i, item in enumerate(t_chars_list[:-1]):
27
+ if i == 0 or cur == None:
28
+ cur = item
29
+ if item[1] != t_chars_list[i + 1][1]:
30
+ cur.append(t_chars_list[i + 1][0])
31
+ t_chars_list_cl.append(cur)
32
+ cur = t_chars_list[i + 1]
33
+
34
+ t_chars_list_cl = [i if i[1] != "|" else [i[0], "", i[2]] for i in t_chars_list_cl]
35
+ chars, char_start_times, char_end_times = [], [], []
36
+ for c in t_chars_list_cl:
37
+ if c[1].lower() in dict_v and c[1] != "":
38
+ chars.append("v")
39
+ elif c[1] != "":
40
+ chars.append("c")
41
+ else:
42
+ chars.append("")
43
+ char_start_times.append(c[0])
44
+ char_end_times.append(c[2])
45
+ return chars, char_start_times, char_end_times
46
+
47
+
48
+
49
+ # обработка seg-файла, получение информации для расчётов
50
+ # предполагается, что на вход получаем seg либо 'corpres' - с разметкой по корпресу, либо упрощённая разметка 'cv' - с разметкой на согласные и гласные
51
+
52
+ def preprocess(chars, starts, labelled='cv'):
53
+ start_and_sound = []
54
+ # берём из seg-файла метки звуков, отсчёты переводим в секунды, получаем общую длительность
55
+ for e in info:
56
+ for i, item in enumerate(chars):
57
+ clean_e = e.strip()
58
+ start_time = float(starts[i])
59
+ label = item
60
+ start_and_sound.append([start_time, label])
61
+
62
+ # заводим переменные, необходимые для расчётов
63
+ clusters_and_duration = []
64
+ pauses = 0
65
+ sum_dur_vowels = 0
66
+ # флаг для определения границ кластеров. важно, если до и после паузы звуки одного класса
67
+ postpause_flag = 0
68
+
69
+ # обработка файлов с гласно-согласной разметкой
70
+ if labelled == 'cv':
71
+ total_duration = 0
72
+ # определяем к какому классу относится каждый звук и считаем длительность (отдельных гласных и согласных кластеров)
73
+ for n, i in enumerate(start_and_sound):
74
+ sound = i[1]
75
+ # определяем не является ли звук конечным
76
+ if n != len(start_and_sound) - 1:
77
+ duration = start_and_sound[n+1][0] - i[0]
78
+ # выделяем гласные
79
+ if sound == 'V' or sound == 'v':
80
+ total_duration += duration
81
+ # записываем отдельно звук в нулевой позиции в обход ошибки индекса
82
+ if n == 0:
83
+ clusters_and_duration.append(['V', duration])
84
+
85
+ # объединяем длительности, если предыдущий звук тоже был гласным
86
+ elif clusters_and_duration[-1][0] == 'V' and postpause_flag == 0:
87
+ clusters_and_duration[-1][1] += duration
88
+
89
+ # фиксируем длительность отдельного гласного звука
90
+ else:
91
+ clusters_and_duration.append(['V', duration])
92
+
93
+ # считаем длителность всех гласных интервалов в записи
94
+ sum_dur_vowels += duration
95
+ # снимаем флаг
96
+ postpause_flag = 0
97
+
98
+ # выделяем паузы
99
+ elif sound == '':
100
+ pauses += duration
101
+ total_duration += duration
102
+ # ставим флаг для следующего звука
103
+ postpause_flag = 1
104
+
105
+ # выделяем согласные
106
+ else:
107
+ total_duration += duration
108
+ # записываем отдельно звук в нулевой позиции в обход ошибки
109
+ if n == 0:
110
+ clusters_and_duration.append(['C', duration])
111
+
112
+ # объединяем длительности, если предыдущий звук тоже был согласным
113
+ elif clusters_and_duration[-1][0] == 'C' and postpause_flag == 0:
114
+ clusters_and_duration[-1][1] += duration
115
+
116
+ # фиксируем длительность отдельного согласного звука
117
+ else:
118
+ clusters_and_duration.append(['C', duration])
119
+
120
+ # снимаем флаг
121
+ postpause_flag = 0
122
+
123
+ # функция возвращает метки кластеров и их длительность и общую длительность всех гласных интервалов
124
+ return clusters_and_duration, sum_dur_vowels, total_duration, pauses
125
+
126
+
127
+ def delta_C(cons_clusters):
128
+ # применяем функцию numpy среднеквадратического отклонения
129
+ dC = np.std(cons_clusters)
130
+ return dC
131
+
132
+ def percent_V(vowels, total_wo_pauses):
133
+ pV = vowels / total_wo_pauses
134
+ return pV
135
+
136
+
137
+
138
+ def transcribe(audio):
139
+ y, sr = sf.read(audio)
140
+ input_values = processor(y, sampling_rate=sr, return_tensors="pt").input_values
141
+
142
+ logits = model(input_values).logits
143
+
144
+ predicted_ids = torch.argmax(logits, dim=-1)
145
+
146
+ chars, char_start_times, char_end_times = count_char_borders(predicted_ids, input_values, processor)
147
+
148
+ clusters_and_duration, sum_dur_vowels, total_duration, pauses = preprocess(chars, char_start_times)
149
+
150
+ # параметры для ΔC
151
+ for x in clusters_and_duration:
152
+ if x[0] == 'C':
153
+ cons_clusters.append(x[1])
154
+
155
+ # параметры для %V
156
+ vowels_duration += sum_dur_vowels
157
+ duration_without_pauses += total_duration - pauses
158
+
159
+ # расчёт метрик
160
+ dC = delta_C(cons_clusters)
161
+ pV = percent_V(vowels_duration, duration_without_pauses)
162
+
163
+ transcription = processor.decode(predicted_ids[0]).lower()
164
+
165
+ text = {"transcription": transcription}
166
+
167
+ text['dC'] = dC
168
+
169
+ text['pV'] = pV
170
+
171
+ return text
172
+
173
+ iface = gr.Interface(
174
+ fn=transcribe,
175
+ inputs=gr.Audio(type="filepath"),
176
+ outputs="text",
177
+ title="Mihaj/Wav2Vec2RhytmAnalyzer",
178
+ description=r"Realtime demo for rhytm analysis using a fine-tuned Wav2Vec large model from bond005. https://huggingface.co/bond005/wav2vec2-large-ru-golos-with-lm",
179
+ )
180
+
181
+ iface.launch()