| | import gradio as gr |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| | import torch |
| | from functools import lru_cache |
| |
|
| | |
# Human-readable display name -> FLORES-200 language code used by NLLB-200
# (the suffix after "_" is the script, e.g. Arab, Latn, Cyrl).
# Keys are what the UI dropdowns show; values are fed to the tokenizer.
LANGUAGE_CODES = {
    "Arabic": "arb_Arab",
    "English": "eng_Latn",
    "French": "fra_Latn",
    "Spanish": "spa_Latn",
    "German": "deu_Latn",
    "Italian": "ita_Latn",
    "Portuguese": "por_Latn",
    "Russian": "rus_Cyrl",
    "Japanese": "jpn_Jpan",
    "Korean": "kor_Hang",
    "Chinese (Simplified)": "zho_Hans",
    "Hindi": "hin_Deva",
    "Turkish": "tur_Latn",
    "Dutch": "nld_Latn",
    "Polish": "pol_Latn",
    "Swedish": "swe_Latn",
    "Arabic (Egyptian)": "arz_Arab",
    "Arabic (Moroccan)": "ary_Arab",
    "Indonesian": "ind_Latn",
    "Vietnamese": "vie_Latn",
    "Thai": "tha_Thai",
    "Ukrainian": "ukr_Cyrl",
    "Romanian": "ron_Latn",
    "Greek": "ell_Grek",
    "Hebrew": "heb_Hebr",
}
| |
|
| | |
# Load the distilled 600M-parameter NLLB-200 checkpoint once at import time;
# the process stays alive (Gradio app), so this is a one-time startup cost.
print("Loading NLLB-200 model...")
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Prefer GPU when available; translate() moves each input batch to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Model loaded on {device}")
| |
|
| |
|
| | |
# In-memory cache of completed translations, keyed by (src, tgt, text).
# Bounded so a long-running Space cannot grow memory without limit.
_CACHE_MAX_SIZE = 1024
translation_cache = {}


def translate(text, src_lang, tgt_lang):
    """Translate `text` with NLLB-200.

    Args:
        text: Source text. Empty or whitespace-only input returns "".
        src_lang: Display name of the source language (a LANGUAGE_CODES key);
            unknown names fall back to English.
        tgt_lang: Display name of the target language; unknown names fall
            back to Arabic.

    Returns:
        The translated string, or an "Translation error: ..." message string
        on failure (kept as a plain string so the Gradio UI can display it).
    """
    if not text or not text.strip():
        return ""
    text = text.strip()

    src_lang_code = LANGUAGE_CODES.get(src_lang, "eng_Latn")
    tgt_lang_code = LANGUAGE_CODES.get(tgt_lang, "arb_Arab")

    cache_key = f"{src_lang_code}:{tgt_lang_code}:{text}"
    if cache_key in translation_cache:
        return translation_cache[cache_key]

    try:
        # NLLB tokenizers prepend the source-language token based on src_lang.
        tokenizer.src_lang = src_lang_code
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            max_length=512,
            truncation=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            translated_tokens = model.generate(
                **inputs,
                # `tokenizer.lang_code_to_id` was removed in recent
                # transformers releases; convert_tokens_to_ids() works across
                # versions because NLLB language codes are vocabulary tokens.
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang_code),
                max_length=512,
                num_beams=5,
                early_stopping=True,
            )
        translation = tokenizer.batch_decode(
            translated_tokens, skip_special_tokens=True
        )[0]

        # Evict the oldest entry (dicts preserve insertion order) once full.
        if len(translation_cache) >= _CACHE_MAX_SIZE:
            translation_cache.pop(next(iter(translation_cache)))
        translation_cache[cache_key] = translation
        return translation
    except Exception as e:
        # Surface the failure in the UI instead of crashing the request.
        return f"Translation error: {str(e)}"
| |
|
| |
|
def gradio_translate(text, src_lang, tgt_lang):
    """Entry point wired to the Gradio widgets.

    Echoes the input untouched when source and target languages are the
    same; otherwise delegates to translate().
    """
    return text if src_lang == tgt_lang else translate(text, src_lang, tgt_lang)
| |
|
| |
|
| | |
# Alphabetically sorted display names for the two dropdowns.
# Iterating a dict yields its keys, so .keys() is implied.
LANGUAGES = sorted(LANGUAGE_CODES)
| |
|
| |
|
| | |
# Build the Gradio UI. The markdown and example strings below restore text
# that was garbled by a bad encoding round-trip in the original file
# (emoji rendered as Greek-letter mojibake, Arabic example unreadable).
with gr.Blocks(title="NLLB-200 Translation API", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🌍 NLLB-200 Translation API

        **Meta's No Language Left Behind** - 200 Languages Translation

        - ✅ High-quality translation for 200+ languages
        - ✅ 44% better than previous models
        - ✅ +70% improvement for complex languages (Arabic, Hindi, etc.)
        - ✅ Direct translation (no pivot through English)
        - ✅ Cached for faster repeated translations

        **Powered by**: `facebook/nllb-200-distilled-600M`
        """
    )

    with gr.Row():
        with gr.Column():
            src_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="English",
                label="Source Language",
                interactive=True
            )
            input_text = gr.Textbox(
                label="Text to Translate",
                placeholder="Enter text here...",
                lines=5,
                max_lines=10
            )

        with gr.Column():
            tgt_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="Arabic",
                label="Target Language",
                interactive=True
            )
            # Read-only: only ever written by the translate handlers below.
            output_text = gr.Textbox(
                label="Translation",
                lines=5,
                max_lines=10,
                interactive=False
            )

    with gr.Row():
        translate_btn = gr.Button("Translate 🌍", variant="primary", size="lg")
        clear_btn = gr.Button("Clear", variant="secondary")

    # Clickable example rows; cache_examples=False so the model is not run
    # at build time.
    gr.Examples(
        examples=[
            ["Hello, how are you?", "English", "Arabic"],
            ["مرحبا، كيف حالك؟", "Arabic", "French"],
            ["Bonjour, comment allez-vous?", "French", "English"],
            ["This is a test of NLLB-200 translation model.", "English", "Spanish"],
        ],
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
        fn=gradio_translate,
        cache_examples=False
    )

    # Button click and Enter-in-textbox both trigger the same translation fn.
    translate_btn.click(
        fn=gradio_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text
    )

    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[input_text, output_text]
    )

    input_text.submit(
        fn=gradio_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text
    )

    gr.Markdown(
        """
        ---
        ### API Usage

        You can use this Space programmatically via the Gradio API:

        ```python
        from gradio_client import Client

        client = Client("TGPro1/NLLB200")
        result = client.predict(
            "Hello, world!",   # text
            "English",         # source language
            "Arabic",          # target language
            api_name="/predict"
        )
        print(result)
        ```

        **Supported Languages**: 25+ major languages (see dropdown)

        For full list of 200 languages, check the [NLLB-200 documentation](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
        """
    )
| |
|
| |
|
if __name__ == "__main__":
    # Queue incoming requests (max 10 waiting) so concurrent users share the
    # single model instance instead of running generate() in parallel.
    demo.queue(max_size=10)
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (needed in containers/Spaces)
        server_port=7860,  # standard Gradio/Spaces port
        share=False
    )
| |
|