SummaryProject / api.py
EstelleSkwarto's picture
fix api.py linked to css
3c3e49f
raw
history blame
3.07 kB
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import re
from src.inference import inferenceAPI
from src.inference_t5 import inferenceAPI_t5
# ------ INFERENCE MODEL ------------------------------------------------------
# Call the inference function, adapted for plain-text input.
def summarize(text: str, model=None) -> str:
    """Summarize *text* with the selected model.

    Args:
        text: Raw input text to summarize.
        model: Optional model name, ``"lstm"`` or ``"fineTunedT5"``. When
            omitted, the choice last stored by ``choisir_modele`` (on its
            ``var`` attribute) is used.

    Returns:
        The generated summary.

    Raises:
        ValueError: If no model was passed and none has been chosen yet, or
            if the model name is not one of the supported values.
    """
    if model is None:
        # choisir_modele stores the choice as a function attribute, which does
        # not exist until a model has been chosen at least once.
        model = getattr(choisir_modele, "var", None)
    if model == "lstm":
        # inferenceAPI yields string tokens (presumably words); join with spaces.
        return " ".join(inferenceAPI(text))
    if model == "fineTunedT5":
        # Remove every "<extra_id_0> " sentinel token from the T5 output.
        return inferenceAPI_t5(text).replace("<extra_id_0> ", "")
    raise ValueError(f"Unknown or unselected summarization model: {model!r}")
# ----------------------------------------------------------------------------
def choisir_modele(choixModele):
    """Store the summarization model chosen by the user.

    The choice is kept on the function object itself as
    ``choisir_modele.var``, which ``summarize`` reads back. Unsupported names
    are not stored and leave any previous choice untouched.

    Args:
        choixModele: Model name from the form, ``"lstm"`` or ``"fineTunedT5"``.

    Returns:
        True if the name was recognised and stored, False otherwise.
    """
    print("ON A RECUP LE CHOIX MODELE")
    if choixModele in ("lstm", "fineTunedT5"):
        choisir_modele.var = choixModele
        return True
    return False
# -------- API ---------------------------------------------------------------
app = FastAPI()

# Expose the "templates" directory as static files under /templates so the
# browser can download the CSS referenced by the pages.
app.mount(
    "/templates",
    StaticFiles(directory="templates"),
    name="templates",
)

# Jinja2 renderer used by every route to build the HTML page.
templates = Jinja2Templates(directory="templates")
def _render_page(request: Request, route: str):
    """Render the single page template, passing the route being displayed."""
    context = {"request": request, "current_route": route}
    return templates.TemplateResponse("index.html.jinja", context)


@app.get("/")
async def index(request: Request):
    """Show the landing page."""
    return _render_page(request, "/")


@app.get("/model")
async def get_model(request: Request):
    """Show the page for the model-selection route."""
    return _render_page(request, "/model")


@app.get("/predict")
async def get_prediction(request: Request):
    """Show the page for the prediction route."""
    return _render_page(request, "/predict")
@app.post("/model")
async def choix_model(request: Request, choixModel: str = Form(None)):
    """Handle the model-selection form.

    Re-renders the page with an error message when the field is empty or
    when ``choisir_modele`` did not store the requested model.
    """
    print(choixModel)
    if not choixModel:
        erreur_modele = "Merci de saisir un modèle."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )
    choisir_modele(choixModel)
    # choisir_modele does not store unsupported names (and does not raise), so
    # confirm the choice was actually recorded before reporting success.
    if getattr(choisir_modele, "var", None) != choixModel:
        erreur_modele = "Modèle inconnu, merci de choisir un modèle valide."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )
    print("C'est bon on utilise le modèle demandé")
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request}
    )
# Return the text and its summary, or an error message if the form was
# submitted empty or no model has been chosen yet.
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Summarize the submitted text and render it alongside the summary.

    An empty or whitespace-only submission, or a submission made before any
    model was chosen, re-renders the page with an error message in ``text``.
    """
    if not text or not text.strip():
        error = "Merci de saisir votre texte."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": error}
        )
    # summarize reads choisir_modele.var, which only exists once a model has
    # been chosen; without this guard the request fails with a server error.
    if getattr(choisir_modele, "var", None) is None:
        error = "Merci de choisir un modèle avant de lancer le résumé."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": error}
        )
    # NOTE(review): summarize runs (presumably slow) model inference
    # synchronously inside this async endpoint, which blocks the event loop
    # while it runs; consider starlette.concurrency.run_in_threadpool.
    summary = summarize(text)
    return templates.TemplateResponse(
        "index.html.jinja",
        {"request": request, "text": text, "summary": summary},
    )
# ------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Start the development server on port 8000; reload=True restarts it
    # every time a source change is saved.
    uvicorn.run("api:app", reload=True, port=8000)