import math
import os
import sys

import numpy as np
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from pydub import AudioSegment
from speechbrain.inference.classifiers import EncoderClassifier
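# VoxLingua107 class index -> language name, used to decode the SpeechBrain
# language-ID output below.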
INDEX_TO_LANG = {
    0: 'Abkhazian', 1: 'Afrikaans', 2: 'Amharic', 3: 'Arabic', 4: 'Assamese',
    5: 'Azerbaijani', 6: 'Bashkir', 7: 'Belarusian', 8: 'Bulgarian', 9: 'Bengali',
    10: 'Tibetan', 11: 'Breton', 12: 'Bosnian', 13: 'Catalan', 14: 'Cebuano',
    15: 'Czech', 16: 'Welsh', 17: 'Danish', 18: 'German', 19: 'Greek',
    20: 'English', 21: 'Esperanto', 22: 'Spanish', 23: 'Estonian', 24: 'Basque',
    25: 'Persian', 26: 'Finnish', 27: 'Faroese', 28: 'French', 29: 'Galician',
    30: 'Guarani', 31: 'Gujarati', 32: 'Manx', 33: 'Hausa', 34: 'Hawaiian',
    35: 'Hindi', 36: 'Croatian', 37: 'Haitian', 38: 'Hungarian', 39: 'Armenian',
    40: 'Interlingua', 41: 'Indonesian', 42: 'Icelandic', 43: 'Italian', 44: 'Hebrew',
    45: 'Japanese', 46: 'Javanese', 47: 'Georgian', 48: 'Kazakh', 49: 'Central Khmer',
    50: 'Kannada', 51: 'Korean', 52: 'Latin', 53: 'Luxembourgish', 54: 'Lingala',
    55: 'Lao', 56: 'Lithuanian', 57: 'Latvian', 58: 'Malagasy', 59: 'Maori',
    60: 'Macedonian', 61: 'Malayalam', 62: 'Mongolian', 63: 'Marathi', 64: 'Malay',
    65: 'Maltese', 66: 'Burmese', 67: 'Nepali', 68: 'Dutch', 69: 'Norwegian Nynorsk',
    70: 'Norwegian', 71: 'Occitan', 72: 'Panjabi', 73: 'Polish', 74: 'Pushto',
    75: 'Portuguese', 76: 'Romanian', 77: 'Russian', 78: 'Sanskrit', 79: 'Scots',
    80: 'Sindhi', 81: 'Sinhala', 82: 'Slovak', 83: 'Slovenian', 84: 'Shona',
    85: 'Somali', 86: 'Albanian', 87: 'Serbian', 88: 'Sundanese', 89: 'Swedish',
    90: 'Swahili', 91: 'Tamil', 92: 'Telugu', 93: 'Tajik', 94: 'Thai',
    95: 'Turkmen', 96: 'Tagalog', 97: 'Turkish', 98: 'Tatar', 99: 'Ukrainian',
    100: 'Urdu', 101: 'Uzbek', 102: 'Vietnamese', 103: 'Waray', 104: 'Yiddish',
    105: 'Yoruba', 106: 'Chinese',
}
LANG_TO_INDEX = {v: k for k, v in INDEX_TO_LANG.items()}

def identify_languages(file_path, languages: list[str] = ["Russian", "Belarusian", "Ukrainian", "Kazakh"]) -> dict[str, float]:
    """Score an audio file against the given languages with a VoxLingua107 classifier."""
    language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa")
    signal = language_id.load_audio(file_path)
    lang_scores, _, _, _ = language_id.classify_batch(signal)
    # classify_batch returns log-posteriors; exponentiate to get percentages.
    all_scores = {INDEX_TO_LANG[i]: 100 * math.exp(score) for i, score in enumerate(lang_scores[0])}
    return {lang: float(all_scores[lang]) for lang in languages}

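# Shape of the identify_languages() result (numbers are illustrative only,
# not real model output):
# {"Russian": 87.3, "Belarusian": 6.1, "Ukrainian": 4.2, "Kazakh": 0.9}
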
def detect_language_local(file_path):
    """Map the top-scoring language to a Whisper code: 'ru' for East Slavic, else 'kk'."""
    language_scores = identify_languages(file_path)
    language_result = max(language_scores, key=language_scores.get)
    if language_result.lower() in ["russian", "belarusian", "ukrainian"]:
        return "ru"
    return "kk"

def transcribe_and_diarize_audio(filename, language):
    """Return (plain transcript, speaker-attributed transcript) for an audio file."""
    diarized_segments = _diarize_audio(filename)
    combined_diarized_segments = _combine_segments_with_same_speaker(diarized_segments)
    transcribed_segments = _transcribe_audio(filename, language)
    pure_text = "\n".join(segment.text for segment in transcribed_segments)
    segments = _combine_diarized_and_transcribed_segments(
        combined_diarized_segments,
        transcribed_segments,
    )
    diarized_text = " ".join(
        "[%.1fs -> %.1fs] (%s) %s" % (segment["start"], segment["end"], segment["speaker"], segment["text"])
        for segment in segments
    )
    return pure_text, diarized_text

diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.getenv('HUGGING_FACE_TOKEN'),
)
model_size = "large-v3"
transcription_model = WhisperModel(
    model_size,
    # CTranslate2 expects "cuda" or "cpu" (GPU selection goes through
    # device_index), not a torch-style "cuda:0" string.
    device="cuda" if torch.cuda.is_available() else "cpu",
    compute_type="int8",
)

def get_audio_length_in_minutes(file_path):
    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)  # pydub reports length in milliseconds
    return round(duration_ms / 60000, 2)

def _diarize_audio(filename):
    diarization_pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # The pipeline is constrained to exactly two speakers (a two-party conversation).
    diarization_result = diarization_pipeline(filename, min_speakers=2, max_speakers=2)
    diarized_segments = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        diarized_segments.append(
            {
                "segment": {"start": turn.start, "end": turn.end},
                "speaker": speaker,
            }
        )
    return diarized_segments

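# Each entry returned by _diarize_audio() has this shape (values illustrative):
# {"segment": {"start": 0.5, "end": 4.2}, "speaker": "SPEAKER_00"}
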
def _combine_segments_with_same_speaker(segments):
    """Merge consecutive diarization turns by the same speaker into single segments."""
    new_segments = []
    prev_segment = cur_segment = segments[0]

    for i in range(1, len(segments)):
        cur_segment = segments[i]
        # A speaker change closes the previous speaker's span, which runs from
        # the start of that speaker's first turn to the start of the new turn.
        if cur_segment["speaker"] != prev_segment["speaker"]:
            new_segments.append(
                {
                    "segment": {
                        "start": prev_segment["segment"]["start"],
                        "end": cur_segment["segment"]["start"],
                    },
                    "speaker": prev_segment["speaker"],
                }
            )
            prev_segment = segments[i]

    # The loop above never emits the final speaker's span; append it here.
    new_segments.append(
        {
            "segment": {
                "start": prev_segment["segment"]["start"],
                "end": cur_segment["segment"]["end"],
            },
            "speaker": prev_segment["speaker"],
        }
    )
    return new_segments

def _transcribe_audio(filename, language):
    segments, _ = transcription_model.transcribe(
        filename,
        beam_size=20,
        language=language,
    )
    # faster-whisper returns a lazy generator; materialize it so the result can
    # be iterated and sliced more than once downstream.
    return list(segments)

def _combine_diarized_and_transcribed_segments(diarized_segments, transcribed_segments):
    # End timestamp of every ASR chunk; a missing end time sorts to the far right.
    end_timestamps = np.array(
        [
            (chunk.end if chunk.end is not None else sys.float_info.max)
            for chunk in transcribed_segments
        ]
    )
    segmented_preds = []

    for segment in diarized_segments:
        end_time = segment["segment"]["end"]
        # Index of the ASR chunk whose end time is closest to the end of this
        # speaker turn; everything up to it is attributed to the speaker.
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        segmented_preds.append(
            {
                "speaker": segment["speaker"],
                "text": "".join(
                    [chunk.text for chunk in transcribed_segments[: upto_idx + 1]]
                ),
                "start": transcribed_segments[0].start,
                "end": transcribed_segments[upto_idx].end,
            }
        )

        # Drop the chunks just attributed so the next turn starts fresh.
        transcribed_segments = transcribed_segments[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]

        if len(end_timestamps) == 0:
            break

    return segmented_preds
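
# Minimal end-to-end usage sketch. This block is illustrative: "call.wav" is a
# placeholder path, and it assumes HUGGING_FACE_TOKEN is set in the environment
# before the module-level models above are loaded.
if __name__ == "__main__":
    audio_path = "call.wav"  # hypothetical input file
    print("Duration (min):", get_audio_length_in_minutes(audio_path))
    language = detect_language_local(audio_path)  # "ru" or "kk"
    pure_text, diarized_text = transcribe_and_diarize_audio(audio_path, language)
    print(diarized_text)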