# Hugging Face Space: English -> Kannada translation demo (AI4Bharat IndicTrans2).
import torch
import os
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor

# Hugging Face access token for gated models; None is fine for public checkpoints.
token = os.getenv("HUGGINGFACE_HUB_TOKEN")

# Prefer GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed translation direction: English (Latin script) -> Kannada (Kannada script).
src_lang, tgt_lang = "eng_Latn", "kan_Knda"
model_name = "ai4bharat/indictrans2-en-indic-dist-200M"

# Filled in by load_model(); each stays None until loading succeeds.
model = None
tokenizer = None
ip = None
def load_model():
    """Load the IndicTrans2 tokenizer, model, and processor into module globals.

    Returns:
        bool: True if everything loaded successfully, False otherwise
        (the error is printed, not raised, so the UI can show a fallback page).
    """
    global model, tokenizer, ip
    try:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            token=token,
        )
        # float16 halves GPU memory, but many CPU ops have no half-precision
        # kernels — use full precision whenever we are not on CUDA.
        dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            dtype=dtype,
            token=token,
        ).to(DEVICE)
        ip = IndicProcessor(inference=True)
        print(f"Model loaded successfully on {DEVICE}")
        return True
    except Exception as e:
        # Deliberate best-effort boundary: report and signal failure to caller.
        print(f"Error loading model: {str(e)}")
        return False
def translate_text(input_text):
    """Translate a single English sentence to Kannada.

    Args:
        input_text: The sentence to translate (str).

    Returns:
        str: The Kannada translation, or a human-readable error message.
    """
    # Identity checks against the None sentinel: truth-testing a loaded
    # tokenizer/nn.Module is fragile (tokenizer truthiness depends on __len__).
    if model is None or tokenizer is None or ip is None:
        return "β Model not loaded. Please check the model configuration."
    if not input_text.strip():
        return "Please enter some text to translate."
    try:
        # Batch of exactly one sentence (the model API is batch-oriented).
        input_sentences = [input_text.strip()]
        # Normalize and tag the sentence for IndicTrans2.
        batch = ip.preprocess_batch(
            input_sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        # Beam-search generation; no_grad avoids building autograd state.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=False,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        decoded = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # Undo the preprocessing tags/placeholders for the target language.
        translations = ip.postprocess_batch(decoded, lang=tgt_lang)
        return translations[0] if translations else "Translation failed."
    except Exception as e:
        # UI boundary: surface the error as text rather than crashing the app.
        return f"β Translation error: {str(e)}"
def create_interface():
    """Build and return the Gradio Blocks app.

    Loads the model first; if loading fails, returns a minimal error page
    instead of the translation UI.
    """
    # Guard clause: show a diagnostic page when the model cannot be loaded.
    if not load_model():
        with gr.Blocks(title="Translation App - Error") as error_page:
            gr.Markdown("## β Model Loading Error")
            gr.Markdown("Failed to load the translation model. Please check:")
            gr.Markdown("- Your Hugging Face token is set correctly")
            gr.Markdown("- You have access to the gated model")
            gr.Markdown("- Your internet connection is working")
        return error_page

    # Main translation UI.
    with gr.Blocks(
        title="AI4Bharat IndicTrans2 Translation",
        theme=gr.themes.Soft(),
    ) as app:
        gr.Markdown(
            f"""
# π AI4Bharat IndicTrans2 Translation
**Current Configuration:**
- **Source Language:** {src_lang} (English)
- **Target Language:** {tgt_lang} (Kannada)
- **Model:** {model_name}
- **Device:** {DEVICE}
Enter text below to translate from English to Kannada.
""")
        with gr.Row():
            with gr.Column():
                source_box = gr.Textbox(
                    label=f"Input Text ({src_lang})",
                    placeholder="Enter English text to translate...",
                    lines=5,
                    max_lines=10,
                )
                with gr.Row():
                    run_btn = gr.Button("π Translate", variant="primary")
                    reset_btn = gr.Button("ποΈ Clear")
            with gr.Column():
                target_box = gr.Textbox(
                    label=f"Translation ({tgt_lang})",
                    lines=5,
                    max_lines=10,
                    interactive=False,
                )

        # Clickable sample sentences (cached at startup).
        gr.Markdown("### π Example Inputs:")
        gr.Examples(
            examples=[
                ["Hello, how are you?"],
                ["I am going to the market today."],
                ["This is a very beautiful place."],
                ["Can you help me?"],
            ],
            inputs=[source_box],
            outputs=[target_box],
            fn=translate_text,
            cache_examples=True,
        )

        # Wire up the buttons.
        run_btn.click(
            fn=translate_text,
            inputs=[source_box],
            outputs=[target_box],
        )
        reset_btn.click(
            fn=lambda: ("", ""),
            outputs=[source_box, target_box],
        )

        gr.Markdown("---")
    return app
if __name__ == "__main__":
    # Build the UI and serve it, listening on all interfaces.
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",  # allow external connections (Spaces requirement)
        server_port=7860,       # default Gradio port
        share=False,            # flip to True for a public gradio.live link
        debug=True,
        show_error=True,
    )