import re
from typing import List

from transformers import AutoTokenizer, MT5ForConditionalGeneration, MT5Tokenizer

# Contiguous Unicode blocks that are removed wholesale. Compiled once at import
# time so that repeated cleaning calls do not rebuild the pattern.
_EMOJI_BLOCKS_RE = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    "]+"
)

# Individual symbols outside the blocks above (white heart, heavy black heart,
# love-you gesture, victory hand) together with any variation selector / zero
# width joiner that follows them. The victory hand is handled here rather than
# with a plain ``str.replace`` so that "\u270C\uFE0F" does not leave an orphan
# U+FE0F behind in the cleaned text.
_LOOSE_SYMBOLS_RE = re.compile(
    "[\U0001F90D\U00002764\U0001F91F\U0000270C][\U0000FE0F\U0000200D]*"
)

# Single-pass character rewrites: drop "@" and spell out "#" as "hashtag_".
_CHAR_REWRITES = str.maketrans({"@": "", "#": "hashtag_"})


class PersianTextProcessor:
    """
    Clean Persian text and translate it to English with an MT5 checkpoint.

    Attributes:
        model_size (str): The size label used to build the checkpoint name.
        model_name (str): The full name of the MT5 checkpoint.
        tokenizer (MT5Tokenizer): Tokenizer loaded for ``model_name``.
        model (MT5ForConditionalGeneration): Model loaded for ``model_name``.

    Methods:
        clean_persian_text(text): Strip emojis and rewrite "@" / "#".
        run_model(input_string, **generator_args): Generate text for one input.
        translate_text(persian_text): Clean the text, then run the model on it.
    """

    def __init__(self, model_size: str = "small"):
        """
        Load the tokenizer and model for the requested checkpoint size.

        Args:
            model_size (str): Size label interpolated into the checkpoint
                name (defaults to "small").
        """
        self.model_size = model_size
        self.model_name = f"persiannlp/mt5-{self.model_size}-parsinlu-opus-translation_fa_en"
        self.tokenizer = MT5Tokenizer.from_pretrained(self.model_name)
        self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)

    def clean_persian_text(self, text: str) -> str:
        """
        Remove emojis and social-media markup from the given text.

        Characters in the emoji blocks and the individually listed symbols
        (including a trailing variation selector or zero width joiner) are
        deleted, "@" is deleted, and "#" is replaced with "hashtag_".

        Args:
            text (str): The input Persian text.

        Returns:
            str: The cleaned text. Surrounding whitespace is left untouched.
        """
        without_blocks = _EMOJI_BLOCKS_RE.sub("", text)
        without_symbols = _LOOSE_SYMBOLS_RE.sub("", without_blocks)
        return without_symbols.translate(_CHAR_REWRITES)

    def run_model(self, input_string: str, **generator_args) -> List[str]:
        """
        Run the MT5 model on a single input string.

        Args:
            input_string (str): The text to feed to the model.
            **generator_args: Extra keyword arguments forwarded unchanged to
                ``self.model.generate``.

        Returns:
            List[str]: The decoded generated sequences with special tokens
            stripped (the result of ``batch_decode``), not a bare string.
        """
        # Encode the input as a PyTorch tensor of token ids.
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")
        generated = self.model.generate(input_ids, **generator_args)
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)

    def translate_text(self, persian_text: str) -> List[str]:
        """
        Translate the given Persian text to English.

        The text is first passed through ``clean_persian_text`` and the
        cleaned text is then handed to ``run_model`` with default generation
        settings.

        Args:
            persian_text (str): The Persian text to translate.

        Returns:
            List[str]: The decoded model output, as returned by ``run_model``.
        """
        cleaned = self.clean_persian_text(persian_text)
        return self.run_model(input_string=cleaned)