Add segmenter
- app.py +10 -31
- requirements.txt +3 -1
- segment.srx +0 -0
- srx_segmenter.py +108 -0
- texttokenizer.py +72 -0
- translate.py +67 -0
app.py
CHANGED
@@ -1,13 +1,11 @@
-import gc
 import os
 from dotenv import load_dotenv
 import gradio as gr
 from AinaTheme import AinaGradioTheme
-import sentencepiece as spm
-import ctranslate2
 from huggingface_hub import snapshot_download
 import nltk
-
+
+from translate import translate_text
 
 nltk.download('punkt')
 
@@ -15,7 +13,7 @@ load_dotenv()
 
 MODELS_PATH = "./models"
 HF_CACHE_DIR = "./hf_cache"
-MAX_INPUT_CHARACTERS= int(os.environ.get("MAX_INPUT_CHARACTERS", default=
+MAX_INPUT_CHARACTERS= int(os.environ.get("MAX_INPUT_CHARACTERS", default=1000))
 
 def download_model(repo_id, revision="main"):
     return snapshot_download(repo_id=repo_id, revision=revision, local_dir=os.path.join(MODELS_PATH, repo_id), cache_dir=HF_CACHE_DIR)
@@ -30,14 +28,14 @@ model_dir_ca_fr = download_model("projecte-aina/mt-aina-ca-fr", revision="main")
 model_dir_fr_ca = download_model("projecte-aina/mt-aina-fr-ca", revision="main")
 
 model_dir_ca_de = download_model("projecte-aina/mt-aina-ca-de", revision="main")
-model_dir_de_ca = download_model("projecte-aina/mt-aina-de-ca", revision="main")
 
 model_dir_ca_it = download_model("projecte-aina/mt-aina-ca-it", revision="main")
-model_dir_it_ca = download_model("projecte-aina/mt-aina-it-ca", revision="main")
 
 model_dir_ca_pt = download_model("projecte-aina/mt-aina-ca-pt", revision="main")
 model_dir_pt_ca = download_model("projecte-aina/mt-aina-pt-ca", revision="main")
 
+model_dir_ca_zh = download_model("projecte-aina/mt-aina-ca-zh", revision="main")
+
 directions = {
     "Catalan": {
         "target": {
@@ -45,6 +43,7 @@ directions = {
             "English": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_en)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_en)}")},
             "French": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_fr)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_fr)}")},
             "German": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_de)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_de)}")},
+            "Chinese": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_zh)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_zh)}")},
             "Italian": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_it)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_it)}")},
             "Portuguese": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_pt)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_pt)}")}
 
@@ -65,16 +64,6 @@ directions = {
             "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_fr_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_fr_ca)}")},
         }
     },
-    "German": {
-        "target": {
-            "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_de_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_de_ca)}")},
-        }
-    },
-    "Italian": {
-        "target": {
-            "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_it_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_it_ca)}")},
-        }
-    },
     "Portuguese": {
         "target": {
             "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_pt_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_pt_ca)}")},
@@ -92,7 +81,7 @@ def get_target_languege_model(source_language, target_language):
     return directions.get(source_language, {}).get("target", {}).get(target_language, {}).get("model")
 
 
-def translate(source, lang_pair):
+def translate(text, source_language, lang_pair):
     """Use CTranslate model to translate a sentence
 
     Args:
@@ -102,18 +91,8 @@ def translate(source, lang_pair):
     Returns:
         Translation of the source text
     """
-
-
-
-    source_sentences = sent_tokenize(source)
-    source_tokenized = sp_model.encode(source_sentences, out_type=str)
-    translations = translator.translate_batch(source_tokenized)
-    translations = [translation[0]["tokens"] for translation in translations]
-    translations_detokenized = sp_model.decode(translations)
-    translation = " ".join(translations_detokenized)
-    translation = translation.replace(' ⁇', ':')
-
-    gc.collect()
+
+    translation = translate_text(text, source_language, lang_pair)
 
     return translation
 
@@ -125,7 +104,7 @@ def translate_input(input, source_language, target_language):
         return None
 
     target_language_model = get_target_languege_model(source_language, target_language)
-    translation = translate(input, target_language_model)
+    translation = translate(input, source_language, target_language_model)
 
     return translation
 
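For reference, a minimal sketch of the lookup-then-translate flow that translate_input() follows after this change; the "Catalan"/"Chinese" pair and the input string are only examples:

    # Sketch only: mirrors the flow in app.py after this change.
    lang_pair = get_target_languege_model("Catalan", "Chinese")    # (spm.model path, model directory) tuple from directions
    if lang_pair is not None:
        translation = translate("Bon dia.", "Catalan", lang_pair)  # delegates to translate_text() in translate.py
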
requirements.txt
CHANGED
@@ -3,4 +3,6 @@ gradio==4.8.0
 ctranslate2==3.23.0
 nltk==3.8.1
 sentencepiece==0.1.99
-python-dotenv==1.0.0
+python-dotenv==1.0.0
+pyonmttok==1.37.1
+lxml==4.9.3
segment.srx
ADDED
The diff for this file is too large to render. See the raw diff.
srx_segmenter.py
ADDED
"""Segment text with SRX.
"""
__version__ = '0.0.2'

import lxml.etree
import regex
from typing import (
    List,
    Set,
    Tuple,
    Dict,
    Optional
)


class SrxSegmenter:
    """Handle segmentation with SRX regex format.
    """
    def __init__(self, rule: Dict[str, List[Tuple[str, Optional[str]]]], source_text: str) -> None:
        self.source_text = source_text
        self.non_breaks = rule.get('non_breaks', [])
        self.breaks = rule.get('breaks', [])

    def _get_break_points(self, regexes: List[Tuple[str, str]]) -> Set[int]:
        return set([
            match.span(1)[1]
            for before, after in regexes
            for match in regex.finditer('({})({})'.format(before, after), self.source_text)
        ])

    def get_non_break_points(self) -> Set[int]:
        """Return segment non break points
        """
        return self._get_break_points(self.non_breaks)

    def get_break_points(self) -> Set[int]:
        """Return segment break points
        """
        return self._get_break_points(self.breaks)

    def extract(self) -> Tuple[List[str], List[str]]:
        """Return segments and whitespaces.
        """
        non_break_points = self.get_non_break_points()
        candidate_break_points = self.get_break_points()

        break_point = sorted(candidate_break_points - non_break_points)
        source_text = self.source_text

        segments = []  # type: List[str]
        whitespaces = []  # type: List[str]
        previous_foot = ""
        for start, end in zip([0] + break_point, break_point + [len(source_text)]):
            segment_with_space = source_text[start:end]
            candidate_segment = segment_with_space.strip()
            if not candidate_segment:
                previous_foot += segment_with_space
                continue

            head, segment, foot = segment_with_space.partition(candidate_segment)

            segments.append(segment)
            whitespaces.append('{}{}'.format(previous_foot, head))
            previous_foot = foot
        whitespaces.append(previous_foot)

        return segments, whitespaces


def parse(srx_filepath: str) -> Dict[str, Dict[str, List[Tuple[str, Optional[str]]]]]:
    """Parse SRX file and return it.
    :param srx_filepath: is source SRX file.
    :return: dict
    """
    tree = lxml.etree.parse(srx_filepath)
    namespaces = {
        'ns': 'http://www.lisa.org/srx20'
    }

    rules = {}

    for languagerule in tree.xpath('//ns:languagerule', namespaces=namespaces):
        rule_name = languagerule.attrib.get('languagerulename')
        if rule_name is None:
            continue

        current_rule = {
            'breaks': [],
            'non_breaks': [],
        }

        for rule in languagerule.xpath('ns:rule', namespaces=namespaces):
            is_break = rule.attrib.get('break', 'yes') == 'yes'
            rule_holder = current_rule['breaks'] if is_break else current_rule['non_breaks']

            beforebreak = rule.find('ns:beforebreak', namespaces=namespaces)
            beforebreak_text = '' if beforebreak.text is None else beforebreak.text

            afterbreak = rule.find('ns:afterbreak', namespaces=namespaces)
            afterbreak_text = '' if afterbreak.text is None else afterbreak.text

            rule_holder.append((beforebreak_text, afterbreak_text))

        rules[rule_name] = current_rule

    #_add_softcatala_rules(rules)
    return rules
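A minimal usage sketch of this module on its own, assuming segment.srx defines a language rule named "Catalan" (an assumption, not verified from this commit):

    from srx_segmenter import SrxSegmenter, parse

    rules = parse('segment.srx')                 # dict keyed by languagerulename
    segmenter = SrxSegmenter(rules['Catalan'], 'Bon dia. Com va tot? Fins aviat.')
    segments, whitespaces = segmenter.extract()  # sentences plus the whitespace between them
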
texttokenizer.py
ADDED
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
#
# Copyright (c) 2020 Jordi Mas i Hernandez <jmas@softcatala.org>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this program; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place - Suite 330,
# Boston, MA 02111-1307, USA.

from __future__ import print_function
from srx_segmenter import SrxSegmenter, parse
import os


def add_breakline_rule(rules, language):
    rules[language]["breaks"].append(["\n",  # Before
                                      ""]    # After
                                     )
    return rules


'''
Splits text into sentences keeping spaces to allow later
to reconstruct the same text but with translatable text changed
'''
class TextTokenizer:
    def __init__(self, language):
        srx_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'segment.srx')
        self.rules = parse(srx_filepath)
        self.language = language
        self.rules = add_breakline_rule(self.rules, language)

    def tokenize(self, sentence):
        strings = []
        translate = []

        segmenter = SrxSegmenter(self.rules[self.language], sentence)
        segments, whitespaces = segmenter.extract()

        for i in range(len(segments)):
            whitespace = whitespaces[i]
            if len(whitespace) > 0:
                strings.append(whitespace)
                translate.append(False)

            string = segments[i]
            strings.append(string)
            translate.append(True)

        return strings, translate

    def sentence_from_tokens(self, sentences, translate, translated):
        num_sentences = len(sentences)
        translation = ''
        for i in range(0, num_sentences):
            if translate[i] is True:
                translation += translated[i]
            else:
                translation += sentences[i]

        return translation.strip()
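A minimal round-trip sketch of TextTokenizer, again assuming a "Catalan" rule set in segment.srx; the input text is only an example:

    from texttokenizer import TextTokenizer

    tokenizer = TextTokenizer("Catalan")
    strings, translate = tokenizer.tokenize("Primera frase. Segona frase.")
    # strings interleaves whitespace (translate[i] is False) with sentences (translate[i] is True);
    # feeding the segments back unchanged reconstructs the (stripped) input.
    rebuilt = tokenizer.sentence_from_tokens(strings, translate, strings)
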
translate.py
ADDED
import gc
import ctranslate2
import pyonmttok
from huggingface_hub import snapshot_download

from texttokenizer import TextTokenizer
import unicodedata


def _normalize_input_string(result):
    result = unicodedata.normalize('NFC', result)
    return result

def _translate_batch(input_batch, spm, model, max_sentence_batch=10):

    batch_input_tokenized = []
    batch_input_markers = []

    #preserve_markup = PreserveMarkup()

    num_sentences = len(input_batch)
    for pos in range(0, num_sentences):
        tokenized = spm.tokenize(input_batch[pos])[0]
        batch_input_tokenized.append(tokenized)

    batch_output = []
    for offset in range(0, len(batch_input_tokenized), max_sentence_batch):
        partial_result = model.translate_batch(batch_input_tokenized[offset:offset+max_sentence_batch], return_scores=False, replace_unknowns=True)
        for pos in range(0, len(partial_result)):
            tokenized = partial_result[pos][0]['tokens']
            translated = spm.detokenize(tokenized)
            batch_output.append(translated)

    return batch_output


def translate_text(sample_text, source_language, lang_pair, max_sentence_batch=20):

    spm = pyonmttok.Tokenizer(mode="none", sp_model_path=lang_pair[0])
    translator = ctranslate2.Translator(lang_pair[1], device="cpu")
    tokenizer = TextTokenizer(source_language)

    text = _normalize_input_string(sample_text)
    sentences, translate = tokenizer.tokenize(text)
    num_sentences = len(sentences)
    sentences_batch = []
    indexes = []
    results = ["" for x in range(num_sentences)]
    for i in range(num_sentences):
        if translate[i] is False:
            continue

        sentences_batch.append(sentences[i])
        indexes.append(i)

    translated_batch = _translate_batch(sentences_batch, spm, translator, max_sentence_batch)
    for pos in range(0, len(translated_batch)):
        i = indexes[pos]
        results[i] = translated_batch[pos]

    #Rebuild split sentences
    translated = tokenizer.sentence_from_tokens(sentences, translate, results)

    gc.collect()

    return translated
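A minimal sketch of how app.py feeds a language pair into translate_text(); the paths below are hypothetical and stand in for the (spm.model path, CTranslate2 model directory) tuple built in the directions dict:

    from translate import translate_text

    lang_pair = ("./models/projecte-aina/mt-aina-ca-en/spm.model",
                 "./models/projecte-aina/mt-aina-ca-en")
    print(translate_text("Bon dia a tothom.", "Catalan", lang_pair))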