File size: 4,001 Bytes
8acc64a
 
7308733
 
 
8acc64a
 
7308733
 
 
5815bc8
7308733
 
5815bc8
 
7308733
 
8acc64a
7308733
 
8acc64a
5815bc8
7308733
 
 
 
8acc64a
7308733
 
 
 
 
 
 
 
 
5815bc8
 
 
 
 
 
 
 
7308733
5815bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7308733
5815bc8
 
 
 
7308733
5815bc8
 
 
 
 
 
7308733
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# src/translator.py
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import re

class Translator:
    """Translate transcript segments with an mBART-50 many-to-many model.

    Optionally shields "technical terms" (acronyms, CamelCase-like identifiers,
    alphanumeric tokens) from translation by swapping them for placeholders
    before translation and restoring them afterwards.
    """

    # Short ISO-style code -> mBART-50 language code.
    # NOTE: mBART-50 uses "pt_XX" for Portuguese; "pt_PT" is not in its vocabulary.
    LANG_CODE_MAP = {
        "en": "en_XX", "es": "es_XX", "fr": "fr_XX", "de": "de_DE", "ja": "ja_XX",
        "zh": "zh_CN", "hi": "hi_IN", "pt": "pt_XX", "ko": "ko_KR", "ta": "ta_IN",
        "uk": "uk_UA", "ru": "ru_RU", "ar": "ar_AR", "vi": "vi_VN",
    }

    # Matches whole words that look like technical terms:
    #   - start with an uppercase letter and end with one (e.g. "PyTorchAPI"),
    #   - two or more consecutive uppercase letters (e.g. "GPU"),
    #   - letters followed by digits (e.g. "python3", "h264_codec").
    _TERM_PATTERN = re.compile(
        r'\b([A-Z][a-zA-Z0-9_]+[A-Z]|[A-Z]{2,}|[A-Za-z]+[0-9]+[A-Za-z_]*)\b'
    )

    def __init__(self, model_name: str = "facebook/mbart-large-50-many-to-many-mmt"):
        """Load the mBART model and tokenizer identified by ``model_name``."""
        print("Loading mBART translation model...")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        print("mBART model loaded.")
        # Per-instance copy so callers may extend the map without affecting
        # other instances.
        self.lang_code_map = dict(self.LANG_CODE_MAP)

    def _extract_technical_terms(self, text: str) -> "tuple[str, dict]":
        """Replace technical terms in ``text`` with ``__TERM<n>__`` placeholders.

        Placeholders are numbered in order of first occurrence, so the result
        is deterministic; repeated occurrences of a term share one placeholder.

        Returns:
            A ``(processed_text, placeholder_map)`` pair where
            ``placeholder_map`` maps each placeholder back to its term.
        """
        placeholder_map = {}
        term_to_placeholder = {}

        def _substitute(match):
            term = match.group(0)
            placeholder = term_to_placeholder.get(term)
            if placeholder is None:
                placeholder = f"__TERM{len(term_to_placeholder)}__"
                term_to_placeholder[term] = placeholder
                placeholder_map[placeholder] = term
            return placeholder

        # Single pass: every match is a complete word, so this replaces exactly
        # the occurrences a per-term whole-word substitution would.
        processed_text = self._TERM_PATTERN.sub(_substitute, text)
        return processed_text, placeholder_map

    def _reinsert_technical_terms(self, translated_text: str, placeholder_map: dict) -> str:
        """Swap placeholders in ``translated_text`` back to their original terms.

        Whitespace around each placeholder is normalised to a single space on
        either side of the restored term, and the final string is stripped.
        """
        for placeholder, term in placeholder_map.items():
            # A callable replacement inserts the term literally (no backslash
            # or group-reference interpretation).
            translated_text = re.sub(
                r'\s*' + re.escape(placeholder) + r'\s*',
                lambda _match, _term=term: f' {_term} ',
                translated_text,
            )
        return translated_text.strip()

    def _translate_texts(self, texts: list, src_lang_code: str, target_lang_code: str) -> list:
        """Translate ``texts`` as one batch and return the decoded strings."""
        self.tokenizer.src_lang = src_lang_code
        encoded_batch = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

        # mBART selects the output language via the forced first decoder token.
        target_lang_id = self.tokenizer.lang_code_to_id[target_lang_code]
        generated_tokens = self.model.generate(**encoded_batch, forced_bos_token_id=target_lang_id)

        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    def translate_segments(self, segments: list, src_lang: str, target_lang: str, preserve_technical: bool) -> list:
        """Translate the ``'text'`` field of every segment.

        Args:
            segments: List of dicts, each with at least a ``'text'`` key.
            src_lang: Short source language code (a key of ``lang_code_map``).
            target_lang: Short target language code (a key of ``lang_code_map``).
            preserve_technical: When true, technical terms are protected with
                placeholders during translation and restored afterwards.

        Returns:
            A new list of shallow-copied segments whose ``'text'`` holds the
            translation; the input segments are not modified.

        Raises:
            ValueError: If either language code is not in ``lang_code_map``.
            Exception: Whatever the tokenizer/model raises if a segment still
                fails during the one-by-one fallback.
        """
        print(f"Step 3: Translating {len(segments)} segments from {src_lang} to {target_lang}...")

        # Ensure source and target languages are supported
        src_lang_code = self.lang_code_map.get(src_lang)
        target_lang_code = self.lang_code_map.get(target_lang)

        if not src_lang_code or not target_lang_code:
            raise ValueError("Unsupported source or target language for translation.")

        # Nothing to translate: avoid handing an empty batch to the tokenizer.
        if not segments:
            print(f"Translation to {target_lang} complete.")
            return []

        original_texts = [segment['text'] for segment in segments]

        placeholder_maps = []
        if preserve_technical:
            processed_texts = []
            for text in original_texts:
                processed_text, placeholder_map = self._extract_technical_terms(text)
                processed_texts.append(processed_text)
                placeholder_maps.append(placeholder_map)
        else:
            processed_texts = original_texts

        try:
            translated_batch = self._translate_texts(processed_texts, src_lang_code, target_lang_code)
        except Exception as e:
            print(f"Error during batch translation: {e}. Falling back to individual translation.")
            # Fallback to segment-by-segment if batch fails (less efficient but
            # more robust). A failure here propagates to the caller.
            translated_batch = [
                self._translate_texts([text], src_lang_code, target_lang_code)[0]
                for text in processed_texts
            ]

        translated_segments = []
        for i, translated_text in enumerate(translated_batch):
            final_text = translated_text
            if preserve_technical:
                final_text = self._reinsert_technical_terms(translated_text, placeholder_maps[i])

            new_segment = segments[i].copy()
            new_segment['text'] = final_text
            translated_segments.append(new_segment)

        print(f"Translation to {target_lang} complete.")
        return translated_segments