File size: 7,076 Bytes
d5c8d04 46736da 119cb0d 46736da c017aa5 46736da c017aa5 46736da c017aa5 f87e3ab c017aa5 46736da c017aa5 46736da c017aa5 119cb0d c017aa5 61afcdf c017aa5 46736da c017aa5 46736da c017aa5 46736da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import argparse
import os
import re

import numpy as np
import torch
from tqdm import tqdm
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
# Maximum number of characters batched into a single call to the translation model
MAX_LENGTH = 500
# Sentinel substituted for "[SPEAKER_NN]:" tags so speaker-turn boundaries
# survive translation and can be split back apart afterwards
MAGIC_STRING = "[$&]"
# When True, intermediate raw/translated text files are also written to disk
DEBUG = False
# Map each human-readable language name to the codes required by the two models:
#   "transcriber": short ISO code used by Whisper (part before the underscore, e.g. "en")
#   "translator":  full MBart50 code (e.g. "en_XX")
# Only languages supported by both Whisper and MBart50 are kept.
language_dict = {
    language_name: {
        "transcriber": language_code.split('_')[0].lower(),
        "translator": language_code,
    }
    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items()
    if language_code.split('_')[0].lower() in WHISPER_LANGUAGES
}
def translate(transcribed_text, source_languaje, target_languaje, translate_model, translate_tokenizer, device="cpu"):
    """Translate *transcribed_text* between two languages with an MBart50 model.

    Args:
        transcribed_text (str): Text to translate.
        source_languaje (str): Source language name, a key of ``language_dict``.
        target_languaje (str): Target language name, a key of ``language_dict``.
        translate_model: ``MBartForConditionalGeneration`` instance.
        translate_tokenizer: ``MBart50TokenizerFast`` instance.
        device (str): Torch device the model lives on (default ``"cpu"``).

    Returns:
        str: The translated text.
    """
    # Resolve human-readable names to MBart50 language codes (e.g. "en_XX")
    source_languaje_code = language_dict[source_languaje]["translator"]
    target_languaje_code = language_dict[target_languaje]["translator"]
    # BUG FIX: the source code was looked up but never applied. MBart50 is a
    # many-to-many model and needs the tokenizer's src_lang set so the input is
    # prefixed with the correct source-language token.
    translate_tokenizer.src_lang = source_languaje_code
    encoded = translate_tokenizer(transcribed_text, return_tensors="pt").to(device)
    # Inference only: disable autograd tracking to save memory and time
    with torch.no_grad():
        generated_tokens = translate_model.generate(
            **encoded,
            # Force the decoder to start generating in the target language
            forced_bos_token_id=translate_tokenizer.lang_code_to_id[target_languaje_code]
        )
    translated = translate_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return translated
def _strip_srt_metadata(srt_lines):
    """Join the text lines of an SRT file into one raw string.

    Subtitle index lines, timestamp lines and blank separators are dropped;
    each "[SPEAKER_NN]:" tag is replaced with MAGIC_STRING so speaker turns
    can be split apart again after translation.
    """
    raw_transcription = ""
    progress_bar = tqdm(total=len(srt_lines), desc='Concatenate transcriptions progress')
    for line in srt_lines:
        if re.match(r"\d+$", line):
            pass  # subtitle index line
        elif re.match(r"\d\d:\d\d:\d\d,\d+ --> \d\d:\d\d:\d\d,\d+", line):
            pass  # timestamp line
        elif re.match(r"^$", line):
            pass  # blank separator
        else:
            line = re.sub(r"\[SPEAKER_\d\d\]:", MAGIC_STRING, line)
            raw_transcription += f"{line} "
        progress_bar.update(1)
    progress_bar.close()
    return raw_transcription


def _translate_in_chunks(segments, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    """Translate speaker segments, batching consecutive segments into chunks
    of fewer than MAX_LENGTH characters to limit model calls.

    Each segment keeps a trailing MAGIC_STRING so the translated output can be
    re-split into the same number of segments.
    """
    translated = ""
    chunk = segments[0] + MAGIC_STRING
    progress_bar = tqdm(total=len(segments), desc='Translate transcriptions progress')
    progress_bar.update(1)
    if len(segments) > 1:
        for segment in segments[1:]:
            if len(chunk) + len(segment) < MAX_LENGTH:
                chunk += segment + MAGIC_STRING
            else:
                translated += translate(chunk, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
                chunk = segment + MAGIC_STRING
            progress_bar.update(1)
        # Translate whatever is left in the final chunk
        translated += translate(chunk, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
    else:
        translated = translate(chunk, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
    progress_bar.close()
    return translated


def _reinsert_timestamps(srt_lines, translated_segments):
    """Rebuild the SRT structure, pairing each original text line with its
    translated segment and restoring the original speaker tag."""
    output = ""
    progress_bar = tqdm(total=len(translated_segments), desc='Add time stamps to translated transcriptions progress')
    i = 0
    for line in srt_lines:
        if re.match(r"\d+$", line):
            output += f"{line}\n"
        elif re.match(r"\d\d:\d\d:\d\d,\d+ --> \d\d:\d\d:\d\d,\d+", line):
            output += f"{line}\n"
        elif re.match(r"^$", line):
            output += f"{line}\n"
        else:
            if i < len(translated_segments):
                segment = translated_segments[i]
                # Remove a single leading space introduced by segment joining
                if len(segment) > 0 and segment[0] == " ":
                    segment = segment[1:]
                speaker_match = re.match(r"\[SPEAKER_\d\d\]:", line)
                speaker = speaker_match.group(0) if speaker_match else ""
                output += f"{speaker} {segment}\n"
                i += 1
            progress_bar.update(1)
    progress_bar.close()
    return output


def main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    """Translate an SRT transcription file and write the translated SRT.

    Reads *transcription_file*, strips SRT metadata, translates the text in
    chunks, re-attaches the original indices/timestamps/speaker tags and
    writes the result to ``translated_transcriptions/<name>_<target>.srt``.
    """
    output_folder = "translated_transcriptions"
    # BUG FIX: the original used transcription_file.split("/") and split("."),
    # which raise ValueError for bare filenames, nested paths, or names
    # containing extra dots. os.path handles all of those.
    transcription_file_name = os.path.splitext(os.path.basename(transcription_file))[0]
    # Robustness: make sure the output directory exists before writing
    os.makedirs(output_folder, exist_ok=True)
    # Read transcription
    with open(transcription_file, "r") as f:
        transcription = f.read().splitlines()
    raw_transcription = _strip_srt_metadata(transcription)
    # Save raw transcription
    if DEBUG:
        with open(f"{output_folder}/{transcription_file_name}_raw.srt", "w") as f:
            f.write(raw_transcription)
    # Split into speaker segments; drop the empty leading segment produced
    # when the transcription starts with a speaker tag
    segments = raw_transcription.split(MAGIC_STRING)
    if segments[0] == "":
        segments = segments[1:]
    translated_transcription = _translate_in_chunks(
        segments, source_languaje, target_languaje, translate_model, translate_tokenizer, device
    )
    # Save translated transcription raw
    if DEBUG:
        with open(f"{output_folder}/{transcription_file_name}_{target_languaje}_raw.srt", "w") as f:
            f.write(translated_transcription)
    # Re-add timestamps (no need to re-read the file: the original only did so
    # because the loop variable shadowed `transcription`)
    translated_with_timestamps = _reinsert_timestamps(
        transcription, translated_transcription.split(MAGIC_STRING)
    )
    # Save translated transcription
    output_file = f"{output_folder}/{transcription_file_name}_{target_languaje}.srt"
    with open(output_file, "w") as f:
        f.write(translated_with_timestamps)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the MBart50
    # many-to-many model, then translate the given transcription file.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("transcription_file", help="Transcribed text")
    arg_parser.add_argument("--source_languaje", type=str, required=True)
    arg_parser.add_argument("--target_languaje", type=str, required=True)
    arg_parser.add_argument("--device", type=str, default="cpu")
    cli_args = arg_parser.parse_args()

    # model
    print("Loading translation model")
    model_id = "facebook/mbart-large-50-many-to-many-mmt"
    translate_model = MBartForConditionalGeneration.from_pretrained(model_id).to(cli_args.device)
    translate_tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
    print("Translation model loaded")

    main(
        cli_args.transcription_file,
        cli_args.source_languaje,
        cli_args.target_languaje,
        translate_model,
        translate_tokenizer,
        cli_args.device,
    )