EveSa
Merge branch 'main' into Estelle
5d44b72 unverified
raw
history blame
No virus
2.93 kB
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import re
from src.inference import inferenceAPI
from src.inference_t5 import inferenceAPI_t5
# ------ INFERENCE MODEL --------------------------------------------------------------
def summarize(text: str):
    """Summarise ``text`` with the model currently selected via ``choisir_modele``.

    The active model name is read from the ``choisir_modele.var`` function
    attribute. In the visible code that attribute is only ever assigned inside
    ``choisir_modele``, so calling this before a model has been chosen raises
    ``AttributeError``.

    Args:
        text: Plain text to summarise.

    Returns:
        For ``'lstm'``: the items produced by ``inferenceAPI(text)`` joined
        with single spaces.
        For ``"fineTunedT5"``: whatever ``inferenceAPI_t5(text)`` returns
        (presumably the summary string -- TODO confirm against
        ``src.inference_t5``).
        ``None`` if the stored model name is neither of the above.

    Raises:
        AttributeError: if ``choisir_modele.var`` has not been set yet.
    """
    modele = choisir_modele.var
    if modele == 'lstm':
        return " ".join(inferenceAPI(text))
    if modele == "fineTunedT5":
        # Fix: the result used to be stored in a local and then dropped, so
        # this branch always returned None to the caller.
        return inferenceAPI_t5(text)
    return None
# ----------------------------------------------------------------------------------
def choisir_modele(choixModele):
    """Record which inference model should be used.

    The choice is stored on the function object itself as
    ``choisir_modele.var``. Only ``"lstm"`` and ``"fineTunedT5"`` are
    accepted; any other value leaves the previous selection (if any)
    untouched and prints a notice.

    Args:
        choixModele: Name of the requested model.
    """
    print("ON A RECUP LE CHOIX MODELE")
    if choixModele in ("lstm", "fineTunedT5"):
        choisir_modele.var = choixModele
    else:
        # Fix: this message used to be a bare string expression (a no-op),
        # so an unknown model name was ignored without any trace.
        print("le modele n'est pas defini")
# -------- API ---------------------------------------------------------------------
# Application instance; the route decorators in this file register on it.
app = FastAPI()
# Jinja2 template loader rooted at the "templates" directory.
templates = Jinja2Templates(directory="templates")
# Static files: the same "templates" directory is also mounted under
# /templates so that the CSS can be sent to the browser.
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
# One handler serves all three GET routes. Previously three byte-identical
# functions, all named ``index``, redefined each other at module level.
# Decorators apply bottom-up, so the routes are still registered in the
# original order: "/", then "/model", then "/predict".
@app.get("/predict")
@app.get("/model")
@app.get("/")
async def index(request: Request):
    """Render the main page for a plain GET request.

    Args:
        request: Incoming request, passed to the template context under the
            ``"request"`` key.

    Returns:
        The ``index.html.jinja`` template response with no text or summary
        in its context.
    """
    return templates.TemplateResponse("index.html.jinja", {"request": request})
@app.post("/model")
async def choix_model(request: Request, choixModel:str = Form(None)):
    """Handle the model-selection form.

    An empty or missing ``choixModel`` field re-renders the page with a
    prompt in the ``"text"`` context key. Otherwise the value is handed to
    ``choisir_modele`` and the page is rendered with no extra context.

    Args:
        request: Incoming request, forwarded to the template context.
        choixModel: Value of the ``choixModel`` form field, or ``None`` when
            the field was not submitted.

    Returns:
        The ``index.html.jinja`` template response.
    """
    print(choixModel)
    context = {"request": request}
    if choixModel:
        choisir_modele(choixModel)
        print("C'est bon on utilise le modèle demandé")
    else:
        context["text"] = "Merci de saisir un modèle."
    return templates.TemplateResponse("index.html.jinja", context)
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Handle the summarisation form.

    Renders the page with the submitted text and its summary, or with an
    error message in the ``"text"`` context key when the form was sent empty.

    Args:
        request: Incoming request, forwarded to the template context.
        text: Value of the ``text`` form field, or ``None`` when the field
            was not submitted.

    Returns:
        The ``index.html.jinja`` template response.
    """
    if text:
        context = {"request": request, "text": text, "summary": summarize(text)}
    else:
        context = {"request": request, "text": "Merci de saisir votre texte."}
    return templates.TemplateResponse("index.html.jinja", context)
# ------------------------------------------------------------------------------------
# lancer le serveur et le recharger a chaque modification sauvegardee
# if __name__ == "__main__":
# uvicorn.run("api:app", port=8000, reload=True)