aus10powell committed on
Commit
33d6c4f
1 Parent(s): b20b18b

Upload translation.py

Browse files
Files changed (1) hide show
  1. scripts/translation.py +104 -0
scripts/translation.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer
2
+ from transformers import AutoTokenizer
3
+ import re
4
+
5
+
6
class PersianTextProcessor:
    """
    Clean Persian text and translate it to English with an MT5 model.

    Attributes:
        model_size (str): The size suffix used to build the model name.
        model_name (str): The Hugging Face model identifier that is loaded.
        tokenizer (MT5Tokenizer): The tokenizer loaded for ``model_name``.
        model (MT5ForConditionalGeneration): The model loaded for ``model_name``.

    Methods:
        clean_persian_text(text): Strips emojis, "@" and rewrites "#".
        run_model(input_string, **generator_args): Encodes, generates, decodes.
        translate_text(persian_text): Cleans the text, then runs the model.
    """

    # Compiled once at class creation instead of on every call. Each removed
    # emoji also swallows any variation selectors (U+FE0F) / zero-width
    # joiners (U+200D) that directly follow it, so sequences such as a
    # victory hand with a variation selector or a ZWJ family emoji do not
    # leave invisible characters behind. Only code points adjacent to a
    # removed emoji are consumed: the zero-width non-joiner (U+200C) used
    # inside Persian words is never matched.
    _EMOJI_PATTERN = re.compile(
        "(?:["
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
        "\U0001F90D\U00002764\U0001F91F"  # white heart, heavy heart, love-you gesture
        "\U0000270C"  # victory hand
        "][\U0000FE0F\U0000200D]*)+"
    )

    def __init__(self, model_size="small"):
        """
        Load the tokenizer and model for the requested model size.

        Args:
            model_size (str): Size segment interpolated into the model name
                (defaults to "small").
        """
        self.model_size = model_size
        self.model_name = f"persiannlp/mt5-{self.model_size}-parsinlu-opus-translation_fa_en"
        self.tokenizer = MT5Tokenizer.from_pretrained(self.model_name)
        self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)

    def clean_persian_text(self, text):
        """
        Remove emojis and rewrite social-media markers in the given text.

        Emojis covered by ``_EMOJI_PATTERN`` (together with any variation
        selectors / zero-width joiners attached to them) are deleted, "@" is
        deleted, and "#" is replaced with "hashtag_".

        Args:
            text (str): The input Persian text.

        Returns:
            str: The cleaned text.
        """
        text = self._EMOJI_PATTERN.sub("", text)
        text = text.replace("@", "")
        text = text.replace("#", "hashtag_")
        return text

    def run_model(self, input_string, **generator_args):
        """
        Run the MT5 model on the given input string.

        Args:
            input_string (str): The text to feed to the model.
            **generator_args: Extra keyword arguments forwarded to
                ``self.model.generate``.

        Returns:
            list: The decoded output strings as returned by
            ``self.tokenizer.batch_decode`` (one entry per generated
            sequence), with special tokens skipped.
        """
        # Encode the input string as a tensor of token ids.
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")

        # Generate the output token ids.
        res = self.model.generate(input_ids, **generator_args)

        # Decode every generated sequence back to text.
        return self.tokenizer.batch_decode(res, skip_special_tokens=True)

    def translate_text(self, persian_text):
        """
        Clean the given Persian text and translate it to English.

        Args:
            persian_text (str): The Persian text to translate.

        Returns:
            list: The decoded translation(s), exactly as returned by
            ``run_model`` (a list of strings, not a bare string).
        """
        text_cleaned = self.clean_persian_text(persian_text)
        return self.run_model(input_string=text_cleaned)