"""RomanSetu SFT smoke test: romanized Indic text -> English.

Loads the RomanSetu SFT causal LM, defines a one-shot translation helper,
and runs it over a fixed set of romanized sample sentences, printing each
result together with a pass/fail marker and an overall summary.
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Half precision is only a safe default on GPU; on CPU fall back to float32
# because fp16 kernels there are slow or unavailable depending on the
# installed torch version.
MODEL_DTYPE = torch.float16 if device.type == "cuda" else torch.float32

# Load ONLY RomanSetu SFT model (no fallbacks)
MODEL_ID = "ai4bharat/romansetu-cpt-roman-sft-roman"
print(f"Loading RomanSetu SFT model: {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=MODEL_DTYPE).to(device)
print("✅ SFT Model loaded successfully")

# Supported languages for RomanSetu SFT
SUPPORTED_LANGUAGES = ["hindi", "marathi", "gujarati", "tamil", "malayalam", "bengali"]


def translate_roman_to_english(text: str, source_lang: str) -> str:
    """Translate romanized ``text`` to English with the RomanSetu SFT model.

    Builds an instruction-style prompt, runs beam-search generation on the
    module-level ``model``, and returns the generated response text.

    Args:
        text: Romanized source sentence.
        source_lang: Language name interpolated into the instruction
            (e.g. ``"hindi"``).

    Returns:
        The generated text following the prompt, stripped of surrounding
        whitespace. If the model emits another ``### Response:`` marker,
        only the text after the last marker is returned.
    """
    response_marker = "### Response:"
    # Instruction format for SFT model
    prompt = f"### Instruction:\nTranslate the following {source_lang} text to English.\n\n### Input:\n{text}\n\n### Response:\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs.input_ids.shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs.input_ids,
            # Pass the mask explicitly: pad_token_id is set to EOS below, so
            # generate() cannot reliably infer it from the input ids.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=30,
            num_beams=3,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            early_stopping=True,
        )

    # A causal LM returns prompt + continuation; decode only the new tokens
    # rather than re-decoding the prompt and stripping it by string matching.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # If the model repeats the template, keep the text after the last marker.
    return completion.split(response_marker)[-1].strip()


# Test samples for all 6 supported languages (10 each)
test_samples = {
    "hindi": [
        "aap kaise hain",
        "mera naam rahul hai",
        "main school jaa raha hun",
        "tum kya karte ho",
        "aaj mausam achha hai",
        "mujhe pani chahiye",
        "main doctor hun",
        "kya haal hai",
        "hamare paas samay hai",
        "tum kahan rehte ho",
    ],
    "marathi": [
        "tumhi kase aahat",
        "majha nav amit aahe",
        "mee shalet jaato",
        "aaj divas changla aahe",
        "tumhala kay pahije",
        "mee pustak vachto",
        "aplya ghar kuthe aahe",
        "mee shikshak aahe",
        "kiti vel zhala",
        "tumhala madat pahije ka",
    ],
    "gujarati": [
        "tame kem cho",
        "maru naam raj chhe",
        "hu school jau chu",
        "aaje su din chhe",
        "tamne pani joie chhe",
        "hu doctor chu",
        "mara ghar ahmedabad ma chhe",
        "shu samay chhe",
        "tame shu karo cho",
        "tamara parivar ma ketla sadasyo chhe",
    ],
    "tamil": [
        "neenga epdi irukeenga",
        "en peyar kumar",
        "naan college poren",
        "inniku nalla weather",
        "ungalukku enna venum",
        "naan engineer",
        "unga veedu enga",
        "sapdu sapdingala",
        "naan tamil pesi",
        "ungala meet panna santhosham",
    ],
    "malayalam": [
        "sukhamano",
        "ente peru ravi aanu",
        "njaan college pokunnu",
        "innale mazha peythu",
        "ningal evideya",
        "ente joli teacher aanu",
        "vellam venamenkil parayu",
        "ningal malayalam ariyumo",
        "njan keralathil aanu thamasikkunnathu",
        "ningal entha cheyyunnathu",
    ],
    "bengali": [
        "tumi kemon acho",
        "amar naam sourav",
        "ami office jachhi",
        "ajke khub gorom",
        "tomar ki dorkar",
        "ami engineer",
        "tomader bari kothay",
        "tumi bangla bolte paro",
        "amra kolkatay thaki",
        "tumi ki khabe",
    ],
}

# Run comprehensive batch test
print("\n🔍 ROMANSETU SFT BATCH TRANSLATION TEST")
print("=" * 60)

total_tests = 0
successful_tests = 0

for language, samples in test_samples.items():
    print(f"\n📝 TESTING {language.upper()}:")
    print("-" * 40)

    for i, text in enumerate(samples, 1):
        total_tests += 1
        try:
            translation = translate_roman_to_english(text, language)
        except Exception as e:
            # Batch boundary: report the failure and keep testing the rest.
            print(f"{i:2d}. '{text}' → ERROR: {e} ❌")
        else:
            # Sanity check only (non-trivial output that differs from the
            # input); it does not verify that the translation is correct.
            if translation and len(translation) > 2 and translation != text:
                successful_tests += 1
                status = "✅"
            else:
                status = "❌"
            print(f"{i:2d}. '{text}' → '{translation}' {status}")

# Final results
# Guard the division so an empty sample set reports 0% instead of crashing.
overall_accuracy = (successful_tests / total_tests) * 100 if total_tests else 0.0
print(f"\n🎯 OVERALL RESULTS:")
print(f"Total Tests: {total_tests}")
print(f"Successful: {successful_tests}")
print(f"Overall Accuracy: {overall_accuracy:.1f}%")
print("✅ SFT testing completed!")