import os import fasttext import gradio as gr from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline import torch title = "Community Tab Language Detection & Translation" description = """ When comments are created in the community tab, detect the language of the content. Then, if the detected language is different from the user's language, display an option to translate it. """ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M") tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M") device = 0 if torch.cuda.is_available() else -1 print(f"Is CUDA available: {torch.cuda.is_available()}") language_code_map = { "English": "eng_Latn", "French": "fra_Latn", "German": "deu_Latn", "Spanish": "spa_Latn", "Korean": "kor_Hang", "Japanese": "jpn_Jpan", "Polish": "pol_Latn" } def identify_language(text): model_file = "lid218e.bin" model_full_path = os.path.join(os.path.dirname(__file__), model_file) model = fasttext.load_model(model_full_path) predictions = model.predict(text, k=1) # e.g., (('__label__eng_Latn',), array([0.81148803])) CHAR_TO_STRIP = 9 # To strip away '__label__' from language code language_code = predictions[0][0][CHAR_TO_STRIP:] return language_code def display(user_lang, text): user_lang_code = language_code_map[user_lang] language_code = identify_language(text) translate_button_visibility = language_code != user_lang_code detected_language_text = f""" Detected Language: {language_code}\n User Content Language: {user_lang_code}\n {"" if translate_button_visibility else "[NOT TRANSLATABLE] Detected Language and Content Language are the same"} """ return text, gr.update(value="", placeholder="Leave a comment"), gr.update(value=detected_language_text), gr.update(visible=translate_button_visibility, variant="primary") def translate(text, src_lang, tgt_lang): CHAR_TO_STRIP = 22 # To strip away 'Detected Language: ' from language code LANGUAGE_CODE_LENGTH = 8 # To strip away 'Detected Language: ' from language code src_lang_code = src_lang[CHAR_TO_STRIP:CHAR_TO_STRIP + LANGUAGE_CODE_LENGTH] tgt_lang_code = language_code_map[tgt_lang] translation_pipeline = pipeline( "translation", model=model, tokenizer=tokenizer, src_lang=src_lang_code, tgt_lang=tgt_lang_code, device=device) result = translation_pipeline(text) return result[0]['translation_text'] with gr.Blocks() as demo: gr.HTML( f"""
{description}