"""FastAPI service exposing the MMTAfrica translation model over HTTP."""
from typing import Optional

import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from mmtafrica.mmtafrica import load_params, translate

# Keep a private handle on the library's translate(). The endpoint defined
# below is also named `translate` and rebinds the module-level name, so
# without this alias the handler would call itself instead of the model.
_mmt_translate = translate

app = FastAPI()

# Human-readable language names accepted in requests, mapped to the language
# codes passed to the model as source_lang / target_lang.
language_map = {
    'English': 'en',
    'Swahili': 'sw',
    'Fon': 'fon',
    'Igbo': 'ig',
    'Kinyarwanda': 'rw',
    'Xhosa': 'xh',
    'Yoruba': 'yo',
    'French': 'fr',
}

# Load parameters and model from checkpoint
checkpoint = hf_hub_download(repo_id="chrisjay/mmtafrica", filename="mmt_translation.pt")
# NOTE(review): 'gpu' is not a torch device string; presumably load_params
# maps it to CUDA internally -- confirm against mmtafrica.load_params.
device = 'gpu' if torch.cuda.is_available() else 'cpu'
params = load_params({'checkpoint': checkpoint, 'device': device})


@app.post("/translate")
async def translate(data: dict):
    """Translate a sentence between two supported languages.

    Args:
        data: JSON request body with the keys ``source_language`` and
            ``target_language`` (names listed in ``language_map``) and
            ``source_sentence`` (the text to translate).

    Returns:
        ``{"translation": <text>}`` on success (with a placeholder message
        when the model returns an empty string), or ``{"error": <message>}``
        when the request is malformed, a language is unsupported, or the
        model call raises.
    """
    try:
        source_language = data['source_language']
        target_language = data['target_language']
        source_sentence = data['source_sentence']
    except KeyError as missing:
        return {"error": f"Missing required field: {missing.args[0]}"}

    try:
        source_code = language_map[source_language]
        target_code = language_map[target_language]
    except (KeyError, TypeError):
        # TypeError covers unhashable values (e.g. a list) sent as a language.
        supported = ', '.join(language_map)
        return {"error": f"Unsupported language. Supported languages: {supported}"}

    try:
        # Call the library function through the alias; the bare name
        # `translate` refers to this endpoint at request time.
        pred = _mmt_translate(
            params,
            source_sentence,
            source_lang=source_code,
            target_lang=target_code,
        )
    except Exception as error:
        # Top-level request boundary: report the failure to the client
        # instead of letting it surface as an opaque HTTP 500.
        return {"error": f"Issue with translation: \n {error}"}

    if pred == '':
        return {"translation": "Could not find translation"}
    return {"translation": pred}