from fastapi import FastAPI, Request
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

app = FastAPI()

# Map target languages to Hugging Face model IDs
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "facebook/nllb-200-distilled-600M",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "facebook/nllb-200-distilled-600M",
    "it": "facebook/nllb-200-distilled-600M",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "mk": "facebook/nllb-200-distilled-600M",
    "nb": "facebook/mbart-large-50-many-to-many-mmt",  # placeholder
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "facebook/mbart-large-50-many-to-many-mmt",  # placeholder
    "pl": "facebook/nllb-200-distilled-600M",
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "facebook/nllb-200-distilled-600M",
}

MODEL_CACHE = {}

# Load and cache a Hugging Face model (Helsinki Marian or SMALL100) on CPU.
# The Auto classes resolve the right architecture from the model ID; the
# SMALL100 entries may still need extra target-language handling on the
# tokenizer before they translate correctly.
def load_model(model_id):
    if model_id not in MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to("cpu")
        MODEL_CACHE[model_id] = (tokenizer, model)
    return MODEL_CACHE[model_id]

# POST /translate
@app.post("/translate")
async def translate(request: Request):
    data = await request.json()
    text = data.get("text")
    target_lang = data.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}
    if model_id.startswith("facebook/"):
        return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}

    try:
        tokenizer, model = load_model(model_id)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
        with torch.no_grad():  # inference only, no gradients needed
            outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
        return {"translation": tokenizer.decode(outputs[0], skip_special_tokens=True)}
    except Exception as e:
        return {"error": f"Translation failed: {str(e)}"}

# GET /languages
@app.get("/languages")
def list_languages():
    return {"supported_languages": list(MODEL_MAP.keys())}

# GET /health
@app.get("/health")
def health():
    return {"status": "ok"}

# Uvicorn startup (required by Hugging Face)
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
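
For reference, a minimal client sketch that exercises the three endpoints once the Space is running. The base URL below is a hypothetical placeholder; substitute your own Space's endpoint.

import requests

BASE_URL = "https://<your-space>.hf.space"  # hypothetical, replace with your Space URL

# Check that the service is up
print(requests.get(f"{BASE_URL}/health").json())

# List the target languages the Space advertises
print(requests.get(f"{BASE_URL}/languages").json())

# Request a translation into German
resp = requests.post(
    f"{BASE_URL}/translate",
    json={"text": "Hello, world!", "target_lang": "de"},
)
print(resp.json())  # {"translation": "..."} on success, {"error": "..."} otherwise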