Spaces:
Running
Running
| import gradio as gr | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import os | |
| from datetime import datetime | |
| from functools import lru_cache | |
| import torch | |
# Maps the human-readable language name shown in the UI to the FLORES-200
# language code expected by the NLLB tokenizer/model.
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Italian": "ita_Latn",
    "Burmese": "mya_Mymr",
    "Thai": "tha_Thai",
}
class TranslationHistory:
    """In-memory, newest-first log of recent translations.

    Bounded so the UI history panel stays small: once the cap is reached,
    the oldest entries are silently discarded.
    """

    # Default cap; was a magic number (100) buried inside add().
    DEFAULT_MAX_ENTRIES = 100

    def __init__(self, max_entries=DEFAULT_MAX_ENTRIES):
        """Create an empty history holding at most *max_entries* items."""
        self.max_entries = max_entries
        self.history = []

    def add(self, src, translated, src_lang, tgt_lang):
        """Record one translation at the front of the list with an ISO timestamp."""
        self.history.insert(0, {
            "source": src, "translated": translated,
            "src_lang": src_lang, "tgt_lang": tgt_lang,
            "timestamp": datetime.now().isoformat(),
        })
        # Trim everything past the cap in one slice-delete; the old code
        # popped a single trailing item, which only works because add() is
        # the sole growth path — this form stays correct regardless.
        del self.history[self.max_entries:]

    def get(self):
        """Return the live newest-first list of history entries."""
        return self.history

    def clear(self):
        """Drop all recorded translations (rebinds to a fresh list)."""
        self.history = []
# Shared translation history — module-level singleton used by the UI callbacks.
history = TranslationHistory()

# Load the NLLB model and tokenizer once at startup.
model_name = "facebook/nllb-200-distilled-600M"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
except Exception as e:
    # Chain the original exception (`from e`) so the root cause stays
    # visible in the traceback instead of being flattened into a string.
    raise RuntimeError(f"Failed to load model: {e}") from e
def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate *text* with the NLLB model and record it in the history.

    Args:
        text: Source text; whitespace-only input short-circuits to "".
        src_lang: Human-readable name from LANGUAGE_CODES, or a raw
            FLORES-200 code (unknown names fall through unchanged).
        tgt_lang: Target language, same convention as src_lang.
        max_length: Generation length cap passed to ``model.generate``.
        temperature: Forwarded to ``generate``. NOTE(review): with pure beam
            search (no ``do_sample=True``) transformers ignores temperature,
            so this slider currently has no effect — confirm intent before
            enabling sampling.

    Returns:
        The translated string, or an in-band ``"Translation error: ..."``
        message so the UI textbox can display failures.

    Despite the name, results are NOT memoized: each call appends to the
    shared history, so an ``lru_cache`` here would silently skip logging.
    """
    if not text.strip():
        return ""
    try:
        src_code = LANGUAGE_CODES.get(src_lang, src_lang)
        tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
        # BUG FIX: NLLB's tokenizer must know the source language so it
        # prepends the correct language token. Previously src_code was
        # computed but never used, degrading translation quality.
        tokenizer.src_lang = src_code
        input_tokens = tokenizer(text, return_tensors="pt", padding=True)
        input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
        # Force generation to begin with the target-language token.
        forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            output = model.generate(
                **input_tokens,
                forced_bos_token_id=forced_bos_token_id,
                max_length=max_length,
                temperature=temperature,
                num_beams=5,
                early_stopping=True,
            )
        result = tokenizer.decode(output[0], skip_special_tokens=True)
        history.add(text, result, src_lang, tgt_lang)
        return result
    except Exception as e:
        # Errors are reported in-band rather than raised so the Gradio
        # output textbox shows them to the user.
        return f"Translation error: {e}"
def swap_langs(src, tgt):
    """Return the (source, target) language selections exchanged.

    Used by the swap button to flip the two dropdowns. Defined with ``def``
    instead of the original assigned lambda (PEP 8 E731) so tracebacks show
    a real function name.
    """
    return tgt, src
def translate_file(file, src_lang, tgt_lang, max_length, temperature):
    """Translate an uploaded text file line by line.

    Args:
        file: The payload as UTF-8 ``bytes`` (original behavior), an
            already-decoded ``str``, or a file-like object with ``.read()``
            — generalized so any Gradio file representation works.
        src_lang: Source language name understood by cached_translate.
        tgt_lang: Target language name understood by cached_translate.
        max_length: Forwarded to cached_translate.
        temperature: Forwarded to cached_translate.

    Returns:
        Translated lines joined with newlines (blank lines are skipped),
        or an in-band ``"File translation error: ..."`` message.
    """
    try:
        if isinstance(file, bytes):
            text = file.decode("utf-8")
        elif isinstance(file, str):
            text = file
        else:
            # File-like object: read, then decode if opened in binary mode.
            data = file.read()
            text = data.decode("utf-8") if isinstance(data, bytes) else data
        translated = [
            cached_translate(line, src_lang, tgt_lang, max_length, temperature)
            for line in text.splitlines()
            if line.strip()
        ]
        return "\n".join(translated)
    except Exception as e:
        return f"File translation error: {e}"
# Custom CSS injected into the Gradio Blocks app: rounded bold buttons,
# teal borders on text inputs, and an orange glow on focus.
gradio_style = """
.gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
"""
# ---------------------------------------------------------------------------
# Gradio UI: a text-translator tab, a file-translator tab, and event wiring.
# (Labels contain mojibake from the original source; preserved verbatim.)
# ---------------------------------------------------------------------------
with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# π PolyLinguaAI: Translate Across Worlds
Translate instantly between 12+ languages using Facebook's NLLB model.
""")
    # --- Text translation tab --------------------------------------------
    with gr.Tab("π Text Translator"):
        with gr.Row():
            src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="π From", value="English")
            swap = gr.Button("β")
            tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="π― To", value="Korean")
        with gr.Row():
            input_text = gr.Textbox(lines=3, label="βοΈ Input Text")
            output_text = gr.Textbox(lines=3, label="π€ Translated Output", interactive=False)
        with gr.Row():
            translate = gr.Button("π Translate", variant="primary")
            clear = gr.Button("π§½ Clear")
        with gr.Accordion("βοΈ Advanced Settings", open=False):
            max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        with gr.Accordion("π Translation History", open=False):
            history_json = gr.JSON(label="Recent Translations")
            with gr.Row():
                refresh = gr.Button("π Refresh")
                clear_history = gr.Button("π§Ή Clear History")
    # --- File translation tab --------------------------------------------
    with gr.Tab("π File Translator"):
        file_input = gr.File(label="π Upload .txt File", file_types=[".txt"])
        file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="π From", value="English")
        file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="π To", value="Korean")
        file_translate = gr.Button("π Translate File", variant="primary")
        file_result = gr.Textbox(label="π File Output", lines=10, interactive=False)
        with gr.Accordion("βοΈ Advanced Settings", open=False):
            f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            f_temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")

    def _read_upload(upload):
        # BUG FIX: depending on the Gradio version, gr.File yields a
        # file-like object, a tempfile wrapper exposing .name, or a plain
        # path string; the old ``file.read()`` crashed on the latter two and
        # on an empty upload (None). Normalize everything to raw bytes,
        # which is what translate_file expects.
        if upload is None:
            return b""
        if hasattr(upload, "read"):
            data = upload.read()
            return data.encode("utf-8") if isinstance(data, str) else data
        path = getattr(upload, "name", upload)
        with open(path, "rb") as fh:
            return fh.read()

    # --- Event wiring ----------------------------------------------------
    translate.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
    clear.click(lambda: ("", ""), None, [input_text, output_text])
    swap.click(swap_langs, [src_lang, tgt_lang], [src_lang, tgt_lang])
    refresh.click(lambda: history.get(), None, history_json)
    # history.clear() returns None, so ``or []`` makes the JSON panel render
    # an empty list after clearing.
    clear_history.click(lambda: history.clear() or [], None, history_json)
    file_translate.click(
        lambda file, src, tgt, ml, t: translate_file(_read_upload(file), src, tgt, ml, t),
        [file_input, file_src, file_tgt, f_max_length, f_temp], file_result)

    # NOTE(review): "Cached Translations: 512" is not backed by any cache in
    # this file (lru_cache is imported but never applied) — verify or remove.
    gr.Markdown(f"""
### π Model Info
- Model: `{model_name}`
- Device: `{device}`
- Cached Translations: 512
""")
if __name__ == "__main__":
    # share=True requests a public tunnel link; NOTE(review): presumably
    # intended for local runs — hosted environments serve the app directly.
    demo.launch(share=True)