"""FastAPI service that translates English sentences into Indic languages.

Uses the ``ai4bharat/indictrans2-en-indic-1B`` seq2seq model together with
IndicTransToolkit's ``IndicProcessor`` for pre- and post-processing.
Routes:
    GET  /health     -> liveness probe.
    POST /translate  -> translate a JSON list of sentences into ``target_lang``.
"""

import os
import threading
from typing import List

# HF_HOME must be set BEFORE transformers / huggingface_hub are imported:
# those libraries read it into module-level constants at import time, so
# assigning it after the imports does not change the cache location.
# NOTE(review): "/.cache" is a root-level path; presumably writable in the
# deployment container — confirm.
os.environ["HF_HOME"] = "/.cache"

import torch  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from IndicTransToolkit import IndicProcessor  # noqa: E402
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # noqa: E402

MODEL_NAME = "ai4bharat/indictrans2-en-indic-1B"
SRC_LANG = "eng_Latn"

# Initialize FastAPI
app = FastAPI()

# Add CORS middleware (fully open: any origin, method and header).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize models and processors once, at import time.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
ip = IndicProcessor(inference=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)

# The processor, tokenizer and model are shared module-level objects whose
# thread-safety is not established here (the processor may carry state from
# preprocess_batch to postprocess_batch — TODO confirm). Inference now runs in
# worker threads, so this lock keeps one translation in flight at a time,
# as when requests were processed serially on the event loop.
_INFERENCE_LOCK = threading.Lock()


def translate_text(sentences: List[str], target_lang: str) -> dict:
    """Translate English sentences into ``target_lang``.

    Args:
        sentences: English input sentences.
        target_lang: Target language tag, passed to ``IndicProcessor`` as
            ``tgt_lang`` and ``lang`` (presumably a FLORES-style code such as
            ``"hin_Deva"`` — confirm against IndicTransToolkit).

    Returns:
        A dict with ``translations`` (list of translated strings, one per
        input, as returned by ``ip.postprocess_batch``), ``source_language``
        (always ``"eng_Latn"``) and ``target_language`` (echo of the argument).

    Raises:
        Exception: "Translation failed: ..." wrapping any error raised during
            pre-processing, tokenization, generation, decoding or
            post-processing. The original error is chained as ``__cause__``.
    """
    src_lang = SRC_LANG
    if not sentences:
        # Nothing to translate; don't push an empty batch through the model.
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }

    try:
        with _INFERENCE_LOCK:
            batch = ip.preprocess_batch(
                sentences, src_lang=src_lang, tgt_lang=target_lang
            )
            # Pad to the longest sentence in the batch; truncate over-long input.
            inputs = tokenizer(
                batch,
                truncation=True,
                padding="longest",
                return_tensors="pt",
                return_attention_mask=True,
            ).to(DEVICE)

            with torch.no_grad():
                generated_tokens = model.generate(
                    **inputs,
                    use_cache=True,
                    min_length=0,
                    max_length=256,
                    num_beams=5,
                    num_return_sequences=1,
                )

            # Decode with the target-side vocabulary.
            # NOTE(review): as_target_tokenizer() is deprecated in recent
            # transformers releases; kept to match the IndicTrans2 usage.
            with tokenizer.as_target_tokenizer():
                decoded = tokenizer.batch_decode(
                    generated_tokens.detach().cpu().tolist(),
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )

            translations = ip.postprocess_batch(decoded, lang=target_lang)
    except Exception as e:
        raise Exception(f"Translation failed: {str(e)}") from e

    return {
        "translations": translations,
        "source_language": src_lang,
        "target_language": target_lang,
    }


# FastAPI routes
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    return {"status": "healthy"}


@app.post("/translate")
async def translate_endpoint(sentences: List[str], target_lang: str):
    """Translate a JSON list of English sentences into ``target_lang``.

    ``sentences`` is read from the request body and ``target_lang`` from the
    query string (FastAPI's defaults for a list and a scalar parameter).
    Any failure is reported as HTTP 500 with the error message as detail.
    """
    try:
        # Tokenization and generation are blocking; run them in a worker
        # thread so the event loop (and /health) stays responsive meanwhile.
        return await run_in_threadpool(
            translate_text, sentences=sentences, target_lang=target_lang
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e