File size: 3,597 Bytes
46736da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
import argparse
import re

language_dict = {}
# Iterate over the LANGUAGE_NAME_TO_CODE dictionary
for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
    # Extract the language code (the first two characters before the underscore)
    lang_code = language_code.split('_')[0].lower()
    
    # Check if the language code is present in WHISPER_LANGUAGES
    if lang_code in WHISPER_LANGUAGES:
        # Construct the entry for the resulting dictionary
        language_dict[language_name] = {
            "transcriber": lang_code,
            "translator": language_code
        }



def translate(transcribed_text, source_languaje, target_languaje, translate_model, translate_tokenizer, device="cpu"):
    # Get source and target languaje codes
    source_languaje_code = language_dict[source_languaje]["translator"]
    target_languaje_code = language_dict[target_languaje]["translator"]

    encoded = translate_tokenizer(transcribed_text, return_tensors="pt").to(device)
    generated_tokens = translate_model.generate(
        **encoded,
        forced_bos_token_id=translate_tokenizer.lang_code_to_id[target_languaje_code]
    )
    translated = translate_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return translated

def main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    output_folder = "translated_transcriptions"
    _, transcription_file_name = transcription_file.split("/")
    transcription_file_name, _ = transcription_file_name.split(".")

    # Read transcription
    with open(transcription_file, "r") as f:
        transcription = f.read().splitlines()
    
    # Translate
    translate_transcription = ""
    for line in transcription:
        if re.match(r"\d+$", line):
            translate_transcription += f"{line}\n"
        elif re.match(r"\d\d:\d\d:\d\d,\d\d\d --> \d\d:\d\d:\d\d,\d\d\d", line):
            translate_transcription += f"{line}\n"
        elif re.match(r"^$", line):
            translate_transcription += f"{line}\n"
        else:
            translated = translate(line, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
            # translated = line
            translate_transcription += f"{translated}\n"
    
    # Save translation
    output_file = f"{output_folder}/{transcription_file_name}_{target_languaje}.srt"
    with open(output_file, "w") as f:
        f.write(translate_transcription)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("transcription_file", help="Transcribed text")
    parser.add_argument("--source_languaje", type=str, required=True)
    parser.add_argument("--target_languaje", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    transcription_file = args.transcription_file
    source_languaje = args.source_languaje
    target_languaje = args.target_languaje
    device = args.device

    # model
    print("Loading translation model")
    translate_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
    translate_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
    print("Translation model loaded")

    main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device)