tools4eu committed on
Commit
97ff100
1 Parent(s): be85aed

added transcribe

Files changed (1)
  1. src/transcribe/transcribe.py +268 -0
src/transcribe/transcribe.py ADDED
@@ -0,0 +1,268 @@
+ from sys import platform
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+ import logging
+ import torch
+ from transformers.utils import is_flash_attn_2_available
+ from pyannote.audio import Pipeline
+ from pyannote.core import Segment
+ import pandas as pd
+
+ languages = {
+     "English": "en",
+     "Chinese": "zh",
+     "German": "de",
+     "Spanish": "es",
+     "Russian": "ru",
+     "Korean": "ko",
+     "French": "fr",
+     "Japanese": "ja",
+     "Portuguese": "pt",
+     "Turkish": "tr",
+     "Polish": "pl",
+     "Catalan": "ca",
+     "Dutch": "nl",
+     "Arabic": "ar",
+     "Swedish": "sv",
+     "Italian": "it",
+     "Indonesian": "id",
+     "Hindi": "hi",
+     "Finnish": "fi",
+     "Vietnamese": "vi",
+     "Hebrew": "he",  # Whisper expects "he"; "iw" is a deprecated ISO code
+     "Ukrainian": "uk",
+     "Greek": "el",
+     "Malay": "ms",
+     "Czech": "cs",
+     "Romanian": "ro",
+     "Danish": "da",
+     "Hungarian": "hu",
+     "Tamil": "ta",
+     "Norwegian": "no",
+     "Thai": "th",
+     "Urdu": "ur",
+     "Croatian": "hr",
+     "Bulgarian": "bg",
+     "Lithuanian": "lt",
+     "Latin": "la",
+     "Maori": "mi",
+     "Malayalam": "ml",
+     "Welsh": "cy",
+     "Slovak": "sk",
+     "Telugu": "te",
+     "Persian": "fa",
+     "Latvian": "lv",
+     "Bengali": "bn",
+     "Serbian": "sr",
+     "Azerbaijani": "az",
+     "Slovenian": "sl",
+     "Kannada": "kn",
+     "Estonian": "et",
+     "Macedonian": "mk",
+     "Breton": "br",
+     "Basque": "eu",
+     "Icelandic": "is",
+     "Armenian": "hy",
+     "Nepali": "ne",
+     "Mongolian": "mn",
+     "Bosnian": "bs",
+     "Kazakh": "kk",
+     "Albanian": "sq",
+     "Swahili": "sw",
+     "Galician": "gl",
+     "Marathi": "mr",
+     "Punjabi": "pa",
+     "Sinhala": "si",
+     "Khmer": "km",
+     "Shona": "sn",
+     "Yoruba": "yo",
+     "Somali": "so",
+     "Afrikaans": "af",
+     "Occitan": "oc",
+     "Georgian": "ka",
+     "Belarusian": "be",
+     "Tajik": "tg",
+     "Sindhi": "sd",
+     "Gujarati": "gu",
+     "Amharic": "am",
+     "Yiddish": "yi",
+     "Lao": "lo",
+     "Uzbek": "uz",
+     "Faroese": "fo",
+     "Haitian creole": "ht",
+     "Pashto": "ps",
+     "Turkmen": "tk",
+     "Nynorsk": "nn",
+     "Maltese": "mt",
+     "Sanskrit": "sa",
+     "Luxembourgish": "lb",
+     "Myanmar": "my",
+     "Tibetan": "bo",
+     "Tagalog": "tl",
+     "Malagasy": "mg",
+     "Assamese": "as",
+     "Tatar": "tt",
+     "Hawaiian": "haw",
+     "Lingala": "ln",
+     "Hausa": "ha",
+     "Bashkir": "ba",
+     "Javanese": "jw",
+     "Sundanese": "su",
+ }
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda:0")
+ elif platform == "darwin":
+     device = torch.device("mps")
+ else:
+     device = torch.device("cpu")
+
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+
+ def get_text_with_timestamp(transcribe_res):
+     # Convert Whisper pipeline chunks into (pyannote Segment, text) pairs
+     timestamp_texts = []
+     for item in transcribe_res["chunks"]:
+         start = item["timestamp"][0]
+         end = item["timestamp"][1]
+         text = item["text"]
+         timestamp_texts.append((Segment(start, end), text))
+     return timestamp_texts
+
+
+ def add_speaker_info_to_text(timestamp_texts, ann):
+     # Label each segment with the speaker who talks most within it
+     spk_text = []
+     for seg, text in timestamp_texts:
+         spk = ann.crop(seg).argmax()
+         spk_text.append((seg, spk, text))
+     return spk_text
+
+
+ def merge_cache(text_cache):
+     # Collapse a run of cached (segment, speaker, text) triples into one triple
+     sentence = "".join([item[-1] for item in text_cache])
+     spk = text_cache[0][1]
+     start = text_cache[0][0].start
+     end = text_cache[-1][0].end
+     return Segment(start, end), spk, sentence
+
+
+ PUNC_SENT_END = [".", "?", "!"]
+
+
+ def merge_sentence(spk_text):
+     # Merge consecutive chunks into sentences, flushing the cache on a
+     # speaker change or on sentence-ending punctuation
+     merged_spk_text = []
+     pre_spk = None
+     text_cache = []
+     for seg, spk, text in spk_text:
+         if spk != pre_spk and pre_spk is not None and len(text_cache) > 0:
+             merged_spk_text.append(merge_cache(text_cache))
+             text_cache = [(seg, spk, text)]
+             pre_spk = spk
+         elif text and text[-1] in PUNC_SENT_END:  # guard against empty chunk text
+             text_cache.append((seg, spk, text))
+             merged_spk_text.append(merge_cache(text_cache))
+             text_cache = []
+             pre_spk = spk
+         else:
+             text_cache.append((seg, spk, text))
+             pre_spk = spk
+     if len(text_cache) > 0:
+         merged_spk_text.append(merge_cache(text_cache))
+     return merged_spk_text
+
+
+ def diarize_text(transcribe_res, diarization_result):
+     timestamp_texts = get_text_with_timestamp(transcribe_res)
+     spk_text = add_speaker_info_to_text(timestamp_texts, diarization_result)
+     res_processed = merge_sentence(spk_text)
+     return res_processed
+
+
+ def make_conversation(transcribe_result, diarization_result):
+     processed = diarize_text(transcribe_result, diarization_result)
+     df = pd.DataFrame(processed, columns=["segment", "speaker", "text"])[
+         ["speaker", "text"]
+     ]
+     # Rows where the speaker differs from the previous row start a new turn;
+     # the cumulative sum gives each run of same-speaker rows a shared key
+     df["key"] = (df["speaker"] != df["speaker"].shift(1)).astype(int).cumsum()
+     conversation = df.groupby(["key", "speaker"])["text"].apply(" ".join).reset_index()
+     conversation_list = list(zip(conversation.text, conversation.speaker))
+     return conversation_list
+
+ def transcriber(input: str, model: str, language: str, translate: bool, diarize: bool, input_diarization_token) -> str:
+     """Transcribes the audio using an OpenAI Whisper model.
+
+     Args:
+         input: file path to the audio file, in any format
+         model: Hugging Face model id of the Whisper checkpoint to use
+         language: name of the language in which the audio is recorded
+         translate: boolean indicator to enable immediate translation to English
+         diarize: boolean indicator to run pyannote speaker diarization first
+         input_diarization_token: Hugging Face access token for the gated
+             diarization pipeline
+     Returns: the transcription as a string, one timestamped chunk per line.
+     """
+     model_id = model
+
+     if diarize:
+         pipeline_diarization = Pipeline.from_pretrained(
+             "pyannote/speaker-diarization-3.1",
+             use_auth_token=input_diarization_token)
+
+         # send pipeline to GPU (when available)
+         pipeline_diarization.to(device)
+
+         # apply pretrained pipeline
+         diarization = pipeline_diarization(input)
+         # NOTE: the diarization result is not yet merged into the returned
+         # text; the make_conversation() call further down is still commented out
+
+         # print the result
+         # for turn, _, speaker in diarization.itertracks(yield_label=True):
+         #     print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")
+
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         model_id,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True,
+         use_safetensors=True,
+         # use Flash Attention 2 only when the package is installed
+         use_flash_attention_2=True if is_flash_attn_2_available() else False
+     )
+
+     print(device)
+
+     model.to(device)
+
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     # map the human-readable language name to a Whisper language code;
+     # None (unknown name) lets Whisper auto-detect the language
+     language = languages.get(language, None)
+     task = None
+     if translate:
+         task = "translate"
+
+     pipe = pipeline(
+         "automatic-speech-recognition",
+         model=model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         max_new_tokens=128,
+         chunk_length_s=15,
+         batch_size=16,
+         return_timestamps=True,
+         torch_dtype=torch_dtype,
+         device=device,
+         generate_kwargs={"task": task, "language": language}
+     )
+
+     results = pipe(input)
+     results["text"] = results["text"].strip()
+
+     # join the timestamped chunks, one per line
+     text = ""
+     chunks = results.get("chunks", [])
+     for chunk in chunks:
+         text += chunk["text"] + "\n"
+
+     # conversation = make_conversation(results, diarization)
+
+     # Transform the list to skip one line each time
+     # conversation_gradio = []
+     # for i in range(0, len(conversation), 2):  # Increment by 2 to skip one line each time
+     #     current_text = conversation[i][0]
+     #     next_text = conversation[i + 1][0] if i + 1 < len(conversation) else ""
+     #     conversation_gradio.append((current_text, next_text))
+
+     return text
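
A minimal usage sketch of the new transcriber entry point, for trying the module locally. The audio path, the openai/whisper-large-v3 checkpoint, the token value, and the import path are illustrative assumptions, not part of this commit:

    from src.transcribe.transcribe import transcriber  # adjust to the repo's packaging

    # Plain transcription of an English recording (file path is a placeholder)
    text = transcriber(
        input="meeting.wav",
        model="openai/whisper-large-v3",
        language="English",
        translate=False,
        diarize=False,
        input_diarization_token=None,
    )
    print(text)

    # With diarization: requires a Hugging Face token that has accepted the
    # gated pyannote/speaker-diarization-3.1 model conditions
    text = transcriber(
        input="meeting.wav",
        model="openai/whisper-large-v3",
        language="English",
        translate=False,
        diarize=True,
        input_diarization_token="hf_xxx",  # placeholder
    )

The speaker-grouping step in make_conversation hinges on a cumulative-sum trick: comparing each speaker label with the previous row's label marks turn boundaries, and the running sum of those marks gives every run of same-speaker rows a shared key. A toy illustration with hypothetical labels:

    import pandas as pd

    df = pd.DataFrame({"speaker": ["A", "A", "B", "A"],
                       "text": ["Hi.", "How are you?", "Fine.", "Good."]})
    # shift(1) compares each row with its predecessor; cumsum numbers the runs
    df["key"] = (df["speaker"] != df["speaker"].shift(1)).astype(int).cumsum()
    turns = df.groupby(["key", "speaker"])["text"].apply(" ".join).reset_index()
    print(list(zip(turns.speaker, turns.text)))
    # [('A', 'Hi. How are you?'), ('B', 'Fine.'), ('A', 'Good.')]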