# Source: transcrip/func.py (author: zerk1; commit 7987a90, "Create func.py")
import functools
import math
import os
import sys
from collections.abc import Sequence

import numpy as np
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from pydub import AudioSegment
from speechbrain.inference.classifiers import EncoderClassifier
# Whisper checkpoint name used when building the faster-whisper transcription model.
model_size = "large-v3"
# Classifier output index -> language name, for the 107 classes of the VoxLingua107
# language-ID model used by identify_languages.
# NOTE(review): correctness relies on this order matching the classifier's label encoder — confirm.
INDEX_TO_LANG = dict(enumerate((
    'Abkhazian', 'Afrikaans', 'Amharic', 'Arabic', 'Assamese',                  # 0-4
    'Azerbaijani', 'Bashkir', 'Belarusian', 'Bulgarian', 'Bengali',             # 5-9
    'Tibetan', 'Breton', 'Bosnian', 'Catalan', 'Cebuano',                       # 10-14
    'Czech', 'Welsh', 'Danish', 'German', 'Greek',                              # 15-19
    'English', 'Esperanto', 'Spanish', 'Estonian', 'Basque',                    # 20-24
    'Persian', 'Finnish', 'Faroese', 'French', 'Galician',                      # 25-29
    'Guarani', 'Gujarati', 'Manx', 'Hausa', 'Hawaiian',                         # 30-34
    'Hindi', 'Croatian', 'Haitian', 'Hungarian', 'Armenian',                    # 35-39
    'Interlingua', 'Indonesian', 'Icelandic', 'Italian', 'Hebrew',              # 40-44
    'Japanese', 'Javanese', 'Georgian', 'Kazakh', 'Central Khmer',              # 45-49
    'Kannada', 'Korean', 'Latin', 'Luxembourgish', 'Lingala',                   # 50-54
    'Lao', 'Lithuanian', 'Latvian', 'Malagasy', 'Maori',                        # 55-59
    'Macedonian', 'Malayalam', 'Mongolian', 'Marathi', 'Malay',                 # 60-64
    'Maltese', 'Burmese', 'Nepali', 'Dutch', 'Norwegian Nynorsk',               # 65-69
    'Norwegian', 'Occitan', 'Panjabi', 'Polish', 'Pushto',                      # 70-74
    'Portuguese', 'Romanian', 'Russian', 'Sanskrit', 'Scots',                   # 75-79
    'Sindhi', 'Sinhala', 'Slovak', 'Slovenian', 'Shona',                        # 80-84
    'Somali', 'Albanian', 'Serbian', 'Sundanese', 'Swedish',                    # 85-89
    'Swahili', 'Tamil', 'Telugu', 'Tajik', 'Thai',                              # 90-94
    'Turkmen', 'Tagalog', 'Turkish', 'Tatar', 'Ukrainian',                      # 95-99
    'Urdu', 'Uzbek', 'Vietnamese', 'Waray', 'Yiddish',                          # 100-104
    'Yoruba', 'Chinese',                                                        # 105-106
)))
# Reverse lookup: language name -> classifier output index.
LANG_TO_INDEX = {name: index for index, name in INDEX_TO_LANG.items()}
# Candidate languages scored by default; a tuple so the shared default cannot be mutated.
_DEFAULT_CANDIDATE_LANGUAGES = ("Russian", "Belarusian", "Ukrainian", "Kazakh")


@functools.lru_cache(maxsize=1)
def _load_language_id_model():
    """Load the SpeechBrain VoxLingua107 ECAPA language-ID classifier once and reuse it."""
    return EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa")


def identify_languages(
    file_path,
    languages: Sequence[str] = _DEFAULT_CANDIDATE_LANGUAGES,
) -> dict[str, float]:
    """Score how likely the audio in ``file_path`` is to be in each of ``languages``.

    Args:
        file_path: Audio file path, loaded with the classifier's ``load_audio``.
        languages: Language names spelled as in ``INDEX_TO_LANG``. Any sequence
            (list or tuple) is accepted.

    Returns:
        A mapping from each requested language to ``100 * exp(score)``. This is a
        percentage, assuming ``classify_batch`` returns log-probabilities.

    Raises:
        KeyError: If a name in ``languages`` is not a known language.
    """
    # The classifier is cached; re-loading it on every call was the dominant cost.
    language_id = _load_language_id_model()
    signal = language_id.load_audio(file_path)
    lang_scores, _, _, _ = language_id.classify_batch(signal)
    # Scores for the first (only) item of the batch; index them by classifier output position.
    # NOTE(review): relies on INDEX_TO_LANG/LANG_TO_INDEX matching the model's label order.
    row = lang_scores[0]
    return {
        lang: float(100 * math.exp(float(row[LANG_TO_INDEX[lang]])))
        for lang in languages
    }
def detect_language_local(file_path):
    """Return a short language code ("ru" or "kk") for the audio in ``file_path``.

    Takes the highest-scoring candidate from ``identify_languages`` (its default
    candidates). Russian, Belarusian and Ukrainian map to ``"ru"``; any other
    winner maps to ``"kk"``.

    Presumably the result is fed as ``language`` to the transcription helpers —
    verify against callers.
    """
    scores = identify_languages(file_path)
    winner = max(scores, key=scores.get)
    return "ru" if winner.lower() in ("russian", "belarusian", "ukrainian") else "kk"
def transcribe_and_diarize_audio(filename, language):
    """Transcribe ``filename`` and attribute the transcript to diarized speakers.

    Args:
        filename: Audio file path, used for both diarization and transcription.
        language: Language code forwarded to the Whisper transcription.

    Returns:
        A ``(pure_text, diarized_text)`` tuple.
        - ``pure_text`` is the newline-joined text of every transcribed segment.
        - ``diarized_text`` is a space-joined run of ``"[start -> end] (speaker) text"``
          entries, one per speaker turn.
    """
    speaker_turns = _combine_segments_with_same_speaker(_diarize_audio(filename))
    asr_segments = _transcribe_audio(filename, language)

    pure_text = "\n".join(seg.text for seg in asr_segments)

    aligned = _combine_diarized_and_transcribed_segments(speaker_turns, asr_segments)
    diarized_text = " ".join(
        f"[{item['start']:.1f}s -> {item['end']:.1f}s] ({item['speaker']}) {item['text']}"
        for item in aligned
    )
    return pure_text, diarized_text
# pyannote speaker-diarization pipeline, authenticated with the token read from the
# HUGGING_FACE_TOKEN environment variable.
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.getenv('HUGGING_FACE_TOKEN'),
)

# faster-whisper model built from the module-level ``model_size`` checkpoint name, with
# int8 compute. CTranslate2 only accepts "cpu" / "cuda" / "auto" as the device name; the
# GPU ordinal goes in ``device_index``. A "cuda:0" string is rejected as an unsupported
# device.
transcription_model = WhisperModel(
    model_size,
    device="cuda" if torch.cuda.is_available() else "cpu",
    device_index=0,
    compute_type="int8",
)
def get_audio_length_in_minutes(file_path):
    """Return the duration of the audio file at ``file_path`` in minutes, rounded to 2 decimals."""
    # len() of a pydub AudioSegment is its duration in milliseconds.
    milliseconds = len(AudioSegment.from_file(file_path))
    minutes = milliseconds / 60000
    return round(minutes, 2)
def _diarize_audio(filename, min_speakers=2, max_speakers=2):
    """Run the module-level pyannote diarization pipeline on ``filename``.

    Args:
        filename: Audio file path passed to the pipeline.
        min_speakers: Lower bound on the number of speakers (default 2).
        max_speakers: Upper bound on the number of speakers (default 2).

    Returns:
        A list of ``{"segment": {"start": float, "end": float}, "speaker": label}``
        dicts, in the order yielded by ``itertracks``.
    """
    # Use the GPU when torch can see one; otherwise run on the CPU.
    diarization_pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    diarization_result = diarization_pipeline(
        filename,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
    )
    return [
        {
            "segment": {"start": turn.start, "end": turn.end},
            "speaker": speaker,
        }
        for turn, _, speaker in diarization_result.itertracks(yield_label=True)
    ]
def _combine_segments_with_same_speaker(segments):
new_segments = []
prev_segment = cur_segment = segments[0]
for i in range(1, len(segments)):
cur_segment = segments[i]
# check if we have changed speaker ("label")
if cur_segment["speaker"] != prev_segment["speaker"] and i < len(segments):
# add the start/end times for the super-segment to the new list
new_segments.append(
{
"segment": {
"start": prev_segment["segment"]["start"],
"end": cur_segment["segment"]["start"],
},
"speaker": prev_segment["speaker"],
}
)
prev_segment = segments[i]
return new_segments
def _transcribe_audio(filename, language, beam_size=20):
    """Transcribe ``filename`` with the module-level faster-whisper model.

    Args:
        filename: Audio file path passed to ``transcription_model.transcribe``.
        language: Language code forwarded to Whisper.
        beam_size: Beam width used for decoding (default 20).

    Returns:
        A list of the transcribed segments. Callers in this module read their
        ``start``, ``end`` and ``text`` attributes.
    """
    segment_iter, _info = transcription_model.transcribe(
        filename,
        beam_size=beam_size,
        language=language,
    )
    # transcribe() yields segments lazily; materialize them so decoding completes here.
    return list(segment_iter)
def _combine_diarized_and_transcribed_segments(diarized_segments, transcribed_segments):
# get the end timestamps for each chunk from the ASR output
end_timestamps = np.array(
[
(chunk.end if chunk.end is not None else sys.float_info.max)
for chunk in transcribed_segments
]
)
segmented_preds = []
# align the diarizer timestamps and the ASR timestamps
for segment in diarized_segments:
# get the diarizer end timestamp
end_time = segment["segment"]["end"]
# find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
upto_idx = np.argmin(np.abs(end_timestamps - end_time))
segmented_preds.append(
{
"speaker": segment["speaker"],
"text": "".join(
[chunk.text for chunk in transcribed_segments[: upto_idx + 1]]
),
"start": transcribed_segments[0].start,
"end": transcribed_segments[upto_idx].end,
}
)
# crop the transcribed_segmentss and timestamp lists according to the latest timestamp (for faster argmin)
transcribed_segments = transcribed_segments[upto_idx + 1:]
end_timestamps = end_timestamps[upto_idx + 1:]
if len(end_timestamps) == 0:
break
return segmented_preds