# Committed by rajeshradhakrishnan
# Commit 696bdbb: "update to facebook Bart50" (file size 1.03 kB)
import os
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import pipeline, MBartForConditionalGeneration, MBart50TokenizerFast
# Hugging Face checkpoint shared by the model and its tokenizer
# (mBART-50 "one-to-many" variant, per its name).
_CHECKPOINT = "facebook/mbart-large-50-one-to-many-mmt"

# Loaded once at import time; the request handlers below reuse these globals.
model = MBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
# src_lang="en_XX" marks incoming text as English for the tokenizer.
tokenizer = MBart50TokenizerFast.from_pretrained(_CHECKPOINT, src_lang="en_XX")

app = FastAPI()
@app.get("/infer_t5")
def t5(input):
    """Translate English text to Malayalam with the module-level mBART-50 model.

    Despite the route/function name, no T5 model is used here.

    Args:
        input: English text to translate. The route has no path parameters,
            so FastAPI reads this from the ``input`` query parameter.

    Returns:
        ``{"output": decoded}`` where ``decoded`` is the list returned by
        ``tokenizer.batch_decode`` (one string per generated sequence).
    """
    encoded = tokenizer(input, return_tensors="pt")
    # mBART-50 picks the target language via the first decoder token:
    # forcing the "ml_IN" language code yields Malayalam output.
    malayalam_id = tokenizer.lang_code_to_id["ml_IN"]
    token_ids = model.generate(**encoded, forced_bos_token_id=malayalam_id)
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return {"output": decoded}
# Directory holding the front-end assets. It is relative to the working
# directory, and the explicit index route and the static mount share it
# so they always serve the same files.
STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page (``index.html``) at the site root.

    This route is registered *before* the static-files mount below. Starlette
    matches routes in registration order, and a mount at "/" matches every
    path, so any route declared after it would be unreachable.
    """
    return FileResponse(
        path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html"
    )


# Catch-all for the remaining static assets; it must stay the last route
# registered, otherwise it shadows every route added after it.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")