#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


# Maps plain-English language names (as accepted by `TranslationTool`) to the
# `<language>_<script>` codes handed to the tokenizer.
LANGUAGE_CODES = {
    "Acehnese Arabic": "ace_Arab",
    "Acehnese Latin": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta'izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic Romanized": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar Arabic": "bjn_Arab",
    "Banjar Latin": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri Arabic": "kas_Arab",
    "Kashmiri Devanagari": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri Arabic": "knc_Arab",
    "Central Kanuri Latin": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    # The key previously carried a trailing space ("Minangkabau Arabic "), which
    # made the natural spelling unreachable. `TranslationTool.encode` strips
    # whitespace, so the old spelling still resolves through the tool.
    "Minangkabau Arabic": "min_Arab",
    "Minangkabau Latin": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei Bengali": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq Latin": "taq_Latn",
    "Tamasheq Tifinagh": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese Simplified": "zho_Hans",
    "Chinese Traditional": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}


class TranslationTool(PipelineTool):
    """
    Tool that translates text from one language to another.

    Languages are given by their plain-English names (the keys of `LANGUAGE_CODES`) and are converted to the
    corresponding codes before being passed to the tokenizer.

    Example:

    ```py
    from transformers.tools import TranslationTool

    translator = TranslationTool()
    translator("This is a super nice API!", src_lang="English", tgt_lang="French")
    ```
    """

    default_checkpoint = "facebook/nllb-200-distilled-600M"
    description = (
        "This is a tool that translates text from a language to another. It takes three inputs: `text`, which should "
        "be the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, "
        "which should be the language for the desired output language. Both `src_lang` and `tgt_lang` are written in "
        "plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
    )
    name = "translator"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM
    # Plain-English language name -> code lookup used by `encode`.
    lang_to_code = LANGUAGE_CODES

    # Presumably one entry per `encode` argument (`text`, `src_lang`, `tgt_lang`) -- confirm against `PipelineTool`.
    inputs = ["text", "text", "text"]
    outputs = ["text"]

    def _resolve_language_code(self, language):
        """
        Return the code registered in `lang_to_code` for the plain-English name `language`.

        Surrounding whitespace in the name is ignored, so "French " and "French" resolve identically.

        Raises:
            ValueError: If `language` is not a key of `lang_to_code`.
        """
        # Only strings can be stripped; anything else goes through the lookup untouched so it fails with the same
        # `ValueError` as an unknown name.
        key = language.strip() if isinstance(language, str) else language
        if key not in self.lang_to_code:
            raise ValueError(f"{language} is not a supported language.")
        return self.lang_to_code[key]

    def encode(self, text, src_lang, tgt_lang):
        """
        Build the model inputs for translating `text` from `src_lang` to `tgt_lang`.

        Args:
            text: The text to translate.
            src_lang: Plain-English name of the language `text` is written in.
            tgt_lang: Plain-English name of the language to translate into.

        Returns:
            The result of the pre-processor's `_build_translation_inputs`, requested as PyTorch tensors.

        Raises:
            ValueError: If `src_lang` or `tgt_lang` is not a supported language (`src_lang` is checked first).
        """
        src_code = self._resolve_language_code(src_lang)
        tgt_code = self._resolve_language_code(tgt_lang)
        return self.pre_processor._build_translation_inputs(
            text, return_tensors="pt", src_lang=src_code, tgt_lang=tgt_code
        )

    def forward(self, inputs):
        """Run generation on the model with the encoded `inputs` and return its output."""
        return self.model.generate(**inputs)

    def decode(self, outputs):
        """Decode the first sequence in `outputs` to a string, dropping special tokens."""
        return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)