import logging
import re

from transformers import TranslationPipeline
from transformers.pipelines.text2text_generation import ReturnType
from transformers import BartForConditionalGeneration, BertTokenizer


# Character class shared by the first four spacing rules: code points
# U+3401-U+9FFF (CJK ideographs) plus a set of CJK and ASCII punctuation marks.
_CJK_OR_PUNCT = r'[\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\]'

# Ordered (compiled pattern, replacement) pairs. Order matters: later rules
# operate on the output of earlier ones.
#
# NOTE(review): several classes below list the same character twice (e.g.
# `[!!]`, `[??]`, `[//]`) and one rule replaces ',' with ','. Presumably one of
# each pair was a full-width variant that was lost in an encoding
# normalization -- TODO confirm against the upstream file. They are kept
# exactly as found.
_SPACING_RULES = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        # Drop a single whitespace character that sits next to a CJK
        # ideograph / punctuation mark, on either side, whether the neighbour
        # is alphanumeric or not.
        (r'(' + _CJK_OR_PUNCT + r')\s([^0-9a-zA-Z])', r'\1\2'),
        (r'([^0-9a-zA-Z])\s(' + _CJK_OR_PUNCT + r')', r'\1\2'),
        (r'(' + _CJK_OR_PUNCT + r')\s([a-zA-Z0-9])', r'\1\2'),
        (r'([a-zA-Z0-9])\s(' + _CJK_OR_PUNCT + r')', r'\1\2'),
        # Join a dollar sign to the digit that follows it. The `$` must be
        # escaped; unescaped it is an end-of-input anchor and never matches.
        (r'\$\s([0-9])', r'$\1'),
        # As written this replaces a comma with itself (see NOTE above).
        (',', ','),
        # fix comma in numbers
        (r'([0-9]),([0-9])', r'\1,\2'),
        # fix multiple commas
        (r'\s?[,]+\s?', ','),
        (r'\s?[、]+\s?', '、'),
        # fix period
        (r'\s?[。]+\s?', '。'),
        # fix ...
        (r'\s?\.{3,}\s?', '...'),
        # fix exclamation mark
        (r'\s?[!!]+\s?', '!'),
        # fix question mark
        (r'\s?[??]+\s?', '?'),
        # fix colon
        (r'\s?[::]+\s?', ':'),
        # fix quotation mark
        (r'\s?(["“”\']+)\s?', r'\1'),
        # fix semicolon
        (r'\s?[;;]+\s?', ';'),
        # fix dots: strip whitespace around runs of ~ ● . …, delete bracketed
        # ellipses, then re-normalise runs of three or more dots.
        (r'\s?([~●.…]+)\s?', r'\1'),
        (r'\s?\[…\]\s?', ''),
        (r'\s?\[\.\.\.\]\s?', ''),
        (r'\s?\.{3,}\s?', '...'),
        # fix slash
        (r'\s?[//]+\s?', '/'),
        # fix dollar sign
        (r'\s?[$$]+\s?', '$'),
        # fix @
        (r'\s?([@@]+)\s?', '@'),
        # fix brackets
        (r'\s?([\[\(<〖【「『()』」】〗>\)\]]+)\s?', r'\1'),
    )
)


def fix_chinese_text_generation_space(text):
    """Normalise spacing and repeated punctuation in generated Chinese text.

    Applies every rule in ``_SPACING_RULES`` in order: single whitespace
    characters next to CJK ideographs or punctuation are removed, runs of the
    same punctuation mark are collapsed to one, and runs of three or more dots
    become ``...``.

    Args:
        text: The string to clean up.

    Returns:
        The cleaned string.
    """
    output_text = text
    for pattern, replacement in _SPACING_RULES:
        output_text = pattern.sub(replacement, output_text)
    return output_text


class BartPipeline(TranslationPipeline):
    """Translation pipeline around a BART model with a BERT tokenizer.

    The model and tokenizer are loaded from ``model_name_or_path`` and the
    decoded text is passed through ``fix_chinese_text_generation_space``.
    """

    def __init__(
        self,
        model_name_or_path: str = "indiejoseph/bart-base-cantonese",
        device=None,
        max_length=512,
        src_lang=None,
        tgt_lang=None,
    ):
        """Load tokenizer and model, then initialise the parent pipeline.

        Args:
            model_name_or_path: Identifier or path handed to
                ``from_pretrained`` for both the tokenizer and the model.
            device: Forwarded to the parent pipeline.
            max_length: Forwarded to the parent pipeline.
            src_lang: Forwarded to the parent pipeline.
            tgt_lang: Forwarded to the parent pipeline.
        """
        self.model_name_or_path = model_name_or_path
        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        # Inference only: switch the model to evaluation mode.
        self.model.eval()
        super().__init__(
            self.model,
            self.tokenizer,
            device=device,
            max_length=max_length,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

    def _load_tokenizer(self):
        """Return the ``BertTokenizer`` for ``self.model_name_or_path``."""
        return BertTokenizer.from_pretrained(self.model_name_or_path)

    def _load_model(self):
        """Return the ``BartForConditionalGeneration`` model for ``self.model_name_or_path``."""
        return BartForConditionalGeneration.from_pretrained(self.model_name_or_path)

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Run the parent post-processing, then clean each translated text.

        Args:
            model_outputs: Raw outputs, passed to the parent ``postprocess``.
            return_type: Passed to the parent ``postprocess``.
            clean_up_tokenization_spaces: Passed to the parent ``postprocess``.

        Returns:
            The parent's list of records. Each record that has a
            ``"translation_text"`` entry gets it stripped and normalised;
            records without one (the parent returns token ids instead of text
            when tensors are requested) are left untouched.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            raw_text = rec.get("translation_text")
            if raw_text is None:
                # No decoded text in this record; indexing it would raise
                # KeyError, so pass it through unchanged.
                continue
            rec["translation_text"] = fix_chinese_text_generation_space(
                raw_text.strip())
        return records


if __name__ == '__main__':
    pipe = BartPipeline(device=0)
    print(pipe('哈哈,我正在努力研究緊個問題。不過,邊個知呢,可能哪一日我會諗到一個好主意去實現到佢。', max_length=100))