|
import threading

import streamlit as st
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
|
|
|
|
@st.cache_resource
def init_model(model_name="facebook/m2m100_1.2B"):
    """Load the M2M100 translation model and its tokenizer, cached by Streamlit.

    ``st.cache_resource`` keeps the returned objects across reruns and
    sessions (keyed by ``model_name``), so the checkpoint is loaded once
    rather than on every script rerun.

    Args:
        model_name: Hugging Face Hub id (or local path) of an M2M100
            checkpoint. Defaults to ``"facebook/m2m100_1.2B"``.

    Returns:
        A ``(model, tokenizer)`` tuple loaded from the same checkpoint.
    """
    # Both objects come from one id so the tokenizer vocab always matches the model.
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    return model, tokenizer


model, tokenizer = init_model()
|
|
|
# Languages offered in the UI, mapping M2M100 language code -> display name.
# Keys are passed straight to ``tokenizer.src_lang`` / ``tokenizer.get_lang_id``,
# so they must be exactly the codes the M2M100 tokenizer knows (e.g. "ast",
# "ceb", "ilo" -- not the ISO 639-1-looking "as", "ce", "il", which it rejects).
# Insertion order is the order shown in the select boxes (Chinese, English first).
langs = {
    "zh": "Chinese", "en": "English", "af": "Afrikaans", "am": "Amharic", "ar": "Arabic", "ast": "Asturian",
    "az": "Azerbaijani", "ba": "Bashkir", "be": "Belarusian", "bg": "Bulgarian", "bn": "Bengali", "br": "Breton",
    "bs": "Bosnian", "ca": "Catalan", "ceb": "Cebuano", "cs": "Czech", "cy": "Welsh", "da": "Danish",
    "de": "German", "el": "Greek", "es": "Spanish", "et": "Estonian", "fa": "Persian", "ff": "Fulah",
    "fi": "Finnish", "fr": "French", "fy": "Western Frisian", "ga": "Irish", "gd": "Scottish Gaelic", "gl": "Galician",
    "gu": "Gujarati", "ha": "Hausa", "he": "Hebrew", "hi": "Hindi", "hr": "Croatian", "ht": "Haitian Creole",
    "hu": "Hungarian", "hy": "Armenian", "id": "Indonesian", "ig": "Igbo", "ilo": "Iloko", "is": "Icelandic",
    "it": "Italian", "ja": "Japanese", "jv": "Javanese", "ka": "Georgian", "kk": "Kazakh", "km": "Khmer",
    "kn": "Kannada", "ko": "Korean", "lb": "Luxembourgish", "lg": "Ganda", "ln": "Lingala", "lo": "Lao",
    "lt": "Lithuanian", "lv": "Latvian", "mg": "Malagasy", "mk": "Macedonian", "ml": "Malayalam",
    "mn": "Mongolian", "mr": "Marathi", "ms": "Malay", "my": "Burmese", "ne": "Nepali", "nl": "Dutch",
    "no": "Norwegian", "ns": "Northern Sotho", "oc": "Occitan", "or": "Oriya", "pa": "Punjabi", "pl": "Polish",
    "ps": "Pashto",
    "pt": "Portuguese", "ro": "Romanian", "ru": "Russian", "sd": "Sindhi", "si": "Sinhalese", "sk": "Slovak",
    "sl": "Slovenian", "so": "Somali", "sq": "Albanian", "sr": "Serbian", "ss": "Swati", "su": "Sundanese",
    "sv": "Swedish", "sw": "Swahili", "ta": "Tamil", "th": "Thai", "tl": "Tagalog", "tn": "Tswana",
    "tr": "Turkish", "uk": "Ukrainian", "ur": "Urdu", "uz": "Uzbek", "vi": "Vietnamese", "wo": "Wolof",
    "xh": "Xhosa", "yi": "Yiddish", "yo": "Yoruba", "zu": "Zulu",
}
|
|
|
|
|
def chose_lang_format(option):
    """Translate a language code into the label shown in the select boxes.

    Passed as ``format_func`` so the widget displays a language name while
    its value remains the code from ``langs``.

    Raises:
        KeyError: when ``option`` is not one of the keys of ``langs``.
    """
    display_name = langs[option]
    return display_name
|
|
|
|
|
@st.cache_resource
def _tokenizer_lock():
    """Return a single lock shared by every session of this app process.

    The tokenizer comes from a ``st.cache_resource`` function, so all sessions
    use the same object. Translating requires assigning ``tokenizer.src_lang``
    and then tokenizing. Without a shared lock, a concurrent session could
    change ``src_lang`` between those two steps, and the text would be encoded
    with the wrong source-language token.
    """
    return threading.Lock()


st.title('💿facebook-m2m100_1.2B')

# Inside a form, widget changes do not trigger a rerun until submit is pressed.
with st.form('my_form'):
    text = st.text_area('Enter text:', '')

    cols = st.columns(3)
    submitted = cols[0].form_submit_button('翻译')  # button label: "Translate"
    # Select-box values are language codes; chose_lang_format renders the names.
    src = cols[1].selectbox(
        'from', options=list(langs.keys()), format_func=chose_lang_format)
    to = cols[2].selectbox(
        'to', options=list(langs.keys()), format_func=chose_lang_format)

# Empty slot below the form that will receive the translated text.
placeholder = st.empty()

if submitted:
    if not text.strip():
        # Nothing to translate; avoid running the model on empty input.
        st.warning("Please enter some text to translate.")
    else:
        with st.spinner("Translating..."):
            # Setting src_lang and tokenizing must happen atomically because
            # the cached tokenizer is shared across sessions.
            with _tokenizer_lock():
                tokenizer.src_lang = src
                encoded = tokenizer(text, return_tensors="pt")
            # forced_bos_token_id makes decoding start with the target-language token.
            generated_tokens = model.generate(
                **encoded, forced_bos_token_id=tokenizer.get_lang_id(to))
            translated = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True)
        placeholder.markdown(translated[0])
|
|