#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo: multilingual machine translation with the M2M100 models.

The input text is split into sentences (rule-based for Chinese, NLTK for the
languages listed in ``nltk_sent_tokenize_languages``), each sentence is
translated separately, and the translations are joined back together.
"""
import argparse
import json
import os
import platform
import re
from typing import List

from project_settings import project_path

# Point the Hugging Face hub cache and the NLTK data directory at
# project-local folders. This is done before importing the libraries below,
# presumably so that they pick the values up at import time.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
os.environ['NLTK_DATA'] = (project_path / "thirdparty_data/nltk_data").as_posix()

import gradio as gr
import nltk
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


# Display name shown in the UI -> language code handed to the M2M100 tokenizer.
language_map = {
    "Arabic": "ar",
    "Chinese": "zh",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "Flemish": "nl",
    "English": "en",
    "Estonian": "et",
    "Finnish": "fi",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Norwegian": "no",
    "Polish": "pl",
    "Portuguese": "pt",
    "Russian": "ru",
    "Spanish": "es",
    "Swedish": "sv",
    "Turkish": "tr",
}


# Lower-cased language names that are forwarded to ``nltk.sent_tokenize``.
# NOTE(review): "flemish" is passed to NLTK as-is; confirm that the bundled
# nltk_data actually ships a tokenizer under that name.
nltk_sent_tokenize_languages = [
    "czech", "danish", "dutch", "flemish", "english", "estonian",
    "finnish", "french", "german", "italian", "norwegian",
    "polish", "portuguese", "russian", "spanish", "swedish", "turkish"
]


# Single-character sentence terminators (full-width and ASCII forms) followed
# by anything but a closing quote.
_ZH_TERMINATOR = re.compile(r"([。！？!?])([^”’])")
# English-style ellipsis written as six dots.
_ZH_DOT_ELLIPSIS = re.compile(r"(\.{6})([^”’])")
# Chinese ellipsis written as two "…" characters.
_ZH_ELLIPSIS = re.compile(r"(\…{2})([^”’])")
# Terminator + closing quote: the quote ends the sentence, unless it is
# directly followed by more punctuation.
_ZH_QUOTED_TERMINATOR = re.compile(r"([。！？!?][”’])([^，,。！？!?])")


def chinese_sent_tokenize(text: str):
    """Split Chinese text into sentences with a few punctuation rules.

    A line break is inserted after each sentence end and the text is then
    split on line breaks, so line breaks already present in ``text`` also
    act as sentence boundaries.

    Semicolons, dashes and ASCII double quotes are deliberately not treated
    as sentence boundaries; extend the rules if that is needed.

    :param text: the text to split.
    :return: list of sentences, in their original order.
    """
    # The first three rules never break in front of a closing quote, so that
    # the last rule can place the break after the quote instead.
    text = _ZH_TERMINATOR.sub(r"\1\n\2", text)
    text = _ZH_DOT_ELLIPSIS.sub(r"\1\n\2", text)
    text = _ZH_ELLIPSIS.sub(r"\1\n\2", text)
    text = _ZH_QUOTED_TERMINATOR.sub(r"\1\n\2", text)
    # Drop trailing whitespace (including a superfluous final line break).
    text = text.rstrip()
    return text.split("\n")


def sent_tokenize(text: str, language: str) -> List[str]:
    """Split ``text`` into sentences according to ``language``.

    :param text: the text to split.
    :param language: lower-cased language name, e.g. ``"english"``.
    :return: list of sentences; languages without a known tokenizer yield
        the whole text as a single element.
    """
    if language == "chinese":
        return chinese_sent_tokenize(text)
    if language in nltk_sent_tokenize_languages:
        return nltk.sent_tokenize(text, language)
    return [text]


def _load_model_group(model_name: str) -> dict:
    """Load the M2M100 model and tokenizer published under ``model_name``."""
    return {
        "model": M2M100ForConditionalGeneration.from_pretrained(model_name),
        "tokenizer": M2M100Tokenizer.from_pretrained(model_name),
    }


def main():
    """Build the Gradio interface and start the web server."""
    # Cache of loaded models, keyed by model name. Only one model is kept at
    # a time; see ``multilingual_translation``.
    model_dict = {
        "facebook/m2m100_418M": _load_model_group("facebook/m2m100_418M"),
    }

    def multilingual_translation(src_text: str,
                                 src_lang: str,
                                 tgt_lang: str,
                                 model_name: str,
                                 ):
        """Translate ``src_text`` sentence by sentence.

        :param src_text: the text to translate.
        :param src_lang: source language, a key of ``language_map``.
        :param tgt_lang: target language, a key of ``language_map``.
        :param model_name: name of the M2M100 checkpoint to use.
        :return: the translated text.
        :raises KeyError: if a language is not a key of ``language_map``.
        """
        # model
        model_group = model_dict.get(model_name)
        if model_group is None:
            # Drop the previously cached model before loading the new one so
            # that two models are not held in the cache at the same time.
            model_dict.clear()
            model_group = _load_model_group(model_name)
            model_dict[model_name] = model_group

        model = model_group["model"]
        tokenizer = model_group["tokenizer"]

        # tokenize
        tokenizer.src_lang = language_map[src_lang]
        # sent_tokenize picks the tokenizer itself (rules for Chinese, NLTK
        # where supported, otherwise no splitting).
        sentences = sent_tokenize(src_text, language=src_lang.lower())

        # infer
        tgt_lang_id = tokenizer.get_lang_id(language_map[tgt_lang])
        translations = []
        for sentence in sentences:
            if not sentence.strip():
                # Nothing to translate in an empty segment.
                continue
            encoded_src = tokenizer(sentence, return_tensors="pt")
            generated_tokens = model.generate(
                **encoded_src,
                forced_bos_token_id=tgt_lang_id,
            )
            text_decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            translations.append(text_decoded[0])

        # Sentence splitting drops the whitespace between sentences, so put a
        # space back for every target language except Chinese, which is
        # written without spaces.
        separator = "" if tgt_lang == "Chinese" else " "
        # The returned value is what Gradio shows in the output box; the
        # shared output component itself is not modified.
        return separator.join(translations)

    title = "Multilingual Machine Translation"

    description = (
        " M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for "
        "Many-to-Many multilingual translation. "
        "It was introduced in this [paper](https://arxiv.org/abs/2010.11125) "
        "and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository. "
        "[Languages covered](https://huggingface.co/facebook/m2m100_418M#languages-covered) "
    )

    examples = [
        [
            "Hello world!",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ],
        [
            "我是一个句子。我是另一个句子。",
            "Chinese",
            "English",
            "facebook/m2m100_418M",
        ],
        [
            "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. "
            "It was introduced in this paper and first released in this repository.",
            "English",
            "Chinese",
            "facebook/m2m100_418M",
        ]
    ]

    model_choices = [
        "facebook/m2m100_418M",
        "facebook/m2m100_1.2B"
    ]
    language_choices = list(language_map.keys())
    inputs = [
        gr.Textbox(lines=4, placeholder="text", label="Input Text"),
        gr.Dropdown(choices=language_choices, value="English", label="Source Language"),
        gr.Dropdown(choices=language_choices, value="Chinese", label="Target Language"),
        gr.Dropdown(choices=model_choices, value="facebook/m2m100_418M", label="model_name")
    ]

    output = gr.Textbox(lines=4, label="Output Text")

    demo = gr.Interface(
        fn=multilingual_translation,
        inputs=inputs,
        outputs=output,
        examples=examples,
        title=title,
        description=description,
        cache_examples=False
    )
    # queue() already enables request queuing, so launch() needs no extra
    # queue flag.
    demo.queue().launch(
        # debug=True,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )


if __name__ == '__main__':
    main()