PaulNdrei committed on
Commit
5f540b3
1 Parent(s): 719cc71

Add segmenter

Files changed (6)
  1. app.py +10 -31
  2. requirements.txt +3 -1
  3. segment.srx +0 -0
  4. srx_segmenter.py +108 -0
  5. texttokenizer.py +72 -0
  6. translate.py +67 -0
app.py CHANGED
@@ -1,13 +1,11 @@
-import gc
 import os
 from dotenv import load_dotenv
 import gradio as gr
 from AinaTheme import AinaGradioTheme
-import sentencepiece as spm
-import ctranslate2
 from huggingface_hub import snapshot_download
 import nltk
-from nltk import sent_tokenize
+
+from translate import translate_text
 
 nltk.download('punkt')
 
@@ -15,7 +13,7 @@ load_dotenv()
 
 MODELS_PATH = "./models"
 HF_CACHE_DIR = "./hf_cache"
-MAX_INPUT_CHARACTERS= int(os.environ.get("MAX_INPUT_CHARACTERS", default=500))
+MAX_INPUT_CHARACTERS= int(os.environ.get("MAX_INPUT_CHARACTERS", default=1000))
 
 def download_model(repo_id, revision="main"):
     return snapshot_download(repo_id=repo_id, revision=revision, local_dir=os.path.join(MODELS_PATH, repo_id), cache_dir=HF_CACHE_DIR)
@@ -30,14 +28,14 @@ model_dir_ca_fr = download_model("projecte-aina/mt-aina-ca-fr", revision="main")
 model_dir_fr_ca = download_model("projecte-aina/mt-aina-fr-ca", revision="main")
 
 model_dir_ca_de = download_model("projecte-aina/mt-aina-ca-de", revision="main")
-model_dir_de_ca = download_model("projecte-aina/mt-aina-de-ca", revision="main")
 
 model_dir_ca_it = download_model("projecte-aina/mt-aina-ca-it", revision="main")
-model_dir_it_ca = download_model("projecte-aina/mt-aina-it-ca", revision="main")
 
 model_dir_ca_pt = download_model("projecte-aina/mt-aina-ca-pt", revision="main")
 model_dir_pt_ca = download_model("projecte-aina/mt-aina-pt-ca", revision="main")
 
+model_dir_ca_zh = download_model("projecte-aina/mt-aina-ca-zh", revision="main")
+
 directions = {
     "Catalan": {
         "target": {
@@ -45,6 +43,7 @@ directions = {
             "English": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_en)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_en)}")},
             "French": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_fr)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_fr)}")},
             "German": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_de)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_de)}")},
+            "Chinese": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_zh)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_zh)}")},
             "Italian": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_it)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_it)}")},
             "Portuguese": {"model": (f"{os.path.join(MODELS_PATH, model_dir_ca_pt)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_pt)}")}
 
@@ -65,16 +64,6 @@ directions = {
             "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_fr_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_fr_ca)}")},
         }
     },
-    "German": {
-        "target": {
-            "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_de_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_de_ca)}")},
-        }
-    },
-    "Italian": {
-        "target": {
-            "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_it_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_it_ca)}")},
-        }
-    },
     "Portuguese": {
         "target": {
             "Catalan": {"model": (f"{os.path.join(MODELS_PATH, model_dir_pt_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_pt_ca)}")},
@@ -92,7 +81,7 @@ def get_target_languege_model(source_language, target_language):
     return directions.get(source_language, {}).get("target", {}).get(target_language, {}).get("model")
 
 
-def translate(source, lang_pair):
+def translate(text, source_language, lang_pair):
     """Use CTranslate model to translate a sentence
 
     Args:
@@ -102,18 +91,8 @@ def translate(source, lang_pair):
     Returns:
         Translation of the source text
     """
-    sp_model = spm.SentencePieceProcessor(lang_pair[0])
-    translator = ctranslate2.Translator(lang_pair[1])
-
-    source_sentences = sent_tokenize(source)
-    source_tokenized = sp_model.encode(source_sentences, out_type=str)
-    translations = translator.translate_batch(source_tokenized)
-    translations = [translation[0]["tokens"] for translation in translations]
-    translations_detokenized = sp_model.decode(translations)
-    translation = " ".join(translations_detokenized)
-    translation = translation.replace(' ⁇', ':')
-
-    gc.collect()
+
+    translation = translate_text(text, source_language, lang_pair)
 
     return translation
 
@@ -125,7 +104,7 @@ def translate_input(input, source_language, target_language):
         return None
 
     target_language_model = get_target_languege_model(source_language, target_language)
-    translation = translate(input, target_language_model)
+    translation = translate(input, source_language, target_language_model)
 
     return translation
 
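For context, a rough sketch of how the refactored wrapper is wired together after this change; the language names and the shape of the returned tuple follow the code above, but the snippet itself is illustrative and not part of the commit:

# Illustrative only: the directions lookup yields a (spm.model path, CTranslate2 model dir) pair,
# which the new translate() wrapper forwards to translate.translate_text().
lang_pair = get_target_languege_model("Catalan", "Chinese")
translation = translate("Bon dia!", "Catalan", lang_pair)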
requirements.txt CHANGED
@@ -3,4 +3,6 @@ gradio==4.8.0
 ctranslate2==3.23.0
 nltk==3.8.1
 sentencepiece==0.1.99
-python-dotenv==1.0.0
+python-dotenv==1.0.0
+pyonmttok==1.37.1
+lxml==4.9.3
segment.srx ADDED
The diff for this file is too large to render. See raw diff
 
srx_segmenter.py ADDED
@@ -0,0 +1,108 @@
+
+"""Segment text with SRX.
+"""
+__version__ = '0.0.2'
+
+import lxml.etree
+import regex
+from typing import (
+    List,
+    Set,
+    Tuple,
+    Dict,
+    Optional
+)
+
+
+class SrxSegmenter:
+    """Handle segmentation with SRX regex format.
+    """
+    def __init__(self, rule: Dict[str, List[Tuple[str, Optional[str]]]], source_text: str) -> None:
+        self.source_text = source_text
+        self.non_breaks = rule.get('non_breaks', [])
+        self.breaks = rule.get('breaks', [])
+
+    def _get_break_points(self, regexes: List[Tuple[str, str]]) -> Set[int]:
+        return set([
+            match.span(1)[1]
+            for before, after in regexes
+            for match in regex.finditer('({})({})'.format(before, after), self.source_text)
+        ])
+
+    def get_non_break_points(self) -> Set[int]:
+        """Return segment non break points
+        """
+        return self._get_break_points(self.non_breaks)
+
+    def get_break_points(self) -> Set[int]:
+        """Return segment break points
+        """
+        return self._get_break_points(self.breaks)
+
+    def extract(self) -> Tuple[List[str], List[str]]:
+        """Return segments and whitespaces.
+        """
+        non_break_points = self.get_non_break_points()
+        candidate_break_points = self.get_break_points()
+
+        break_point = sorted(candidate_break_points - non_break_points)
+        source_text = self.source_text
+
+        segments = []  # type: List[str]
+        whitespaces = []  # type: List[str]
+        previous_foot = ""
+        for start, end in zip([0] + break_point, break_point + [len(source_text)]):
+            segment_with_space = source_text[start:end]
+            candidate_segment = segment_with_space.strip()
+            if not candidate_segment:
+                previous_foot += segment_with_space
+                continue
+
+            head, segment, foot = segment_with_space.partition(candidate_segment)
+
+            segments.append(segment)
+            whitespaces.append('{}{}'.format(previous_foot, head))
+            previous_foot = foot
+        whitespaces.append(previous_foot)
+
+        return segments, whitespaces
+
+
+def parse(srx_filepath: str) -> Dict[str, Dict[str, List[Tuple[str, Optional[str]]]]]:
+    """Parse SRX file and return it.
+    :param srx_filepath: is source SRX file.
+    :return: dict
+    """
+    tree = lxml.etree.parse(srx_filepath)
+    namespaces = {
+        'ns': 'http://www.lisa.org/srx20'
+    }
+
+    rules = {}
+
+    for languagerule in tree.xpath('//ns:languagerule', namespaces=namespaces):
+        rule_name = languagerule.attrib.get('languagerulename')
+        if rule_name is None:
+            continue
+
+        current_rule = {
+            'breaks': [],
+            'non_breaks': [],
+        }
+
+        for rule in languagerule.xpath('ns:rule', namespaces=namespaces):
+            is_break = rule.attrib.get('break', 'yes') == 'yes'
+            rule_holder = current_rule['breaks'] if is_break else current_rule['non_breaks']
+
+            beforebreak = rule.find('ns:beforebreak', namespaces=namespaces)
+            beforebreak_text = '' if beforebreak.text is None else beforebreak.text
+
+            afterbreak = rule.find('ns:afterbreak', namespaces=namespaces)
+            afterbreak_text = '' if afterbreak.text is None else afterbreak.text
+
+            rule_holder.append((beforebreak_text, afterbreak_text))
+
+        rules[rule_name] = current_rule
+
+    #_add_softcatala_rules(rules)
+    return rules
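A minimal usage sketch of this segmenter, assuming segment.srx defines a language rule named "English" (the actual rule names live in segment.srx, which is not rendered above):

# Illustrative only: parse the SRX rules once, then segment a string.
import srx_segmenter

rules = srx_segmenter.parse("segment.srx")
segmenter = srx_segmenter.SrxSegmenter(rules["English"], "Hello world. How are you?")
segments, whitespaces = segmenter.extract()
# segments holds the sentences; whitespaces holds the surrounding spacing,
# with len(whitespaces) == len(segments) + 1, so the original text can be rebuilt.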
texttokenizer.py ADDED
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+#
+# Copyright (c) 2020 Jordi Mas i Hernandez <jmas@softcatala.org>
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this program; if not, write to the
+# Free Software Foundation, Inc., 59 Temple Place - Suite 330,
+# Boston, MA 02111-1307, USA.
+
+from __future__ import print_function
+from srx_segmenter import SrxSegmenter, parse
+import os
+
+
+def add_breakline_rule(rules, language):
+    rules[language]["breaks"].append(["\n",  # Before
+                                      ""]    # After
+                                     )
+    return rules
+
+
+'''
+Splits text into sentences keeping spaces to allow later
+to reconstruct the same text but with the translatable text changed
+'''
+class TextTokenizer:
+    def __init__(self, language):
+        srx_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'segment.srx')
+        self.rules = parse(srx_filepath)
+        self.language = language
+        self.rules = add_breakline_rule(self.rules, language)
+
+    def tokenize(self, sentence):
+        strings = []
+        translate = []
+
+        segmenter = SrxSegmenter(self.rules[self.language], sentence)
+        segments, whitespaces = segmenter.extract()
+
+        for i in range(len(segments)):
+            whitespace = whitespaces[i]
+            if len(whitespace) > 0:
+                strings.append(whitespace)
+                translate.append(False)
+
+            string = segments[i]
+            strings.append(string)
+            translate.append(True)
+
+        return strings, translate
+
+    def sentence_from_tokens(self, sentences, translate, translated):
+        num_sentences = len(sentences)
+        translation = ''
+        for i in range(0, num_sentences):
+            if translate[i] is True:
+                translation += translated[i]
+            else:
+                translation += sentences[i]
+
+        return translation.strip()
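A minimal round-trip sketch for TextTokenizer, assuming segment.srx provides a "Catalan" language rule (uppercasing stands in for a real translation call):

# Illustrative only: segment, "translate" the flagged segments, then rebuild the text.
from texttokenizer import TextTokenizer

tokenizer = TextTokenizer("Catalan")
strings, translate = tokenizer.tokenize("Bon dia. Com estàs?")
translated = [s.upper() if flag else s for s, flag in zip(strings, translate)]
rebuilt = tokenizer.sentence_from_tokens(strings, translate, translated)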
translate.py ADDED
@@ -0,0 +1,67 @@
+import gc
+import ctranslate2
+import pyonmttok
+from huggingface_hub import snapshot_download
+
+from texttokenizer import TextTokenizer
+import unicodedata
+
+
+def _normalize_input_string(result):
+    result = unicodedata.normalize('NFC', result)
+    return result
+
+def _translate_batch(input_batch, spm, model, max_sentence_batch=10):
+
+    batch_input_tokenized = []
+    batch_input_markers = []
+
+    #preserve_markup = PreserveMarkup()
+
+    num_sentences = len(input_batch)
+    for pos in range(0, num_sentences):
+        tokenized = spm.tokenize(input_batch[pos])[0]
+        batch_input_tokenized.append(tokenized)
+
+    batch_output = []
+    for offset in range(0, len(batch_input_tokenized), max_sentence_batch):
+        partial_result = model.translate_batch(batch_input_tokenized[offset:offset+max_sentence_batch], return_scores=False, replace_unknowns=True)
+        for pos in range(0, len(partial_result)):
+            tokenized = partial_result[pos][0]['tokens']
+            translated = spm.detokenize(tokenized)
+            batch_output.append(translated)
+
+    return batch_output
+
+
+def translate_text(sample_text, source_language, lang_pair, max_sentence_batch=20):
+
+    spm = pyonmttok.Tokenizer(mode="none", sp_model_path=lang_pair[0])
+    translator = ctranslate2.Translator(lang_pair[1], device="cpu")
+    tokenizer = TextTokenizer(source_language)
+
+    text = _normalize_input_string(sample_text)
+    sentences, translate = tokenizer.tokenize(text)
+    num_sentences = len(sentences)
+    sentences_batch = []
+    indexes = []
+    results = ["" for x in range(num_sentences)]
+    for i in range(num_sentences):
+        if translate[i] is False:
+            continue
+
+        sentences_batch.append(sentences[i])
+        indexes.append(i)
+
+    translated_batch = _translate_batch(sentences_batch, spm, translator, max_sentence_batch)
+    for pos in range(0, len(translated_batch)):
+        i = indexes[pos]
+        results[i] = translated_batch[pos]
+
+    #Rebuild split sentences
+    translated = tokenizer.sentence_from_tokens(sentences, translate, results)
+
+    gc.collect()
+
+    return translated
+
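A hedged usage sketch of translate_text as app.py now calls it; the paths below are placeholders for a downloaded model snapshot, and the source language must match a rule name in segment.srx:

# Illustrative only: lang_pair is (path to spm.model, CTranslate2 model directory).
from translate import translate_text

lang_pair = ("/path/to/mt-aina-ca-en/spm.model", "/path/to/mt-aina-ca-en")
print(translate_text("Bon dia, com estàs?", "Catalan", lang_pair, max_sentence_batch=20))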