from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5

# Identifiers of the models the user may select through the /model form.
MODELES_DISPONIBLES = ("lstm", "fineTunedT5")


class ModeleNonDefiniError(ValueError):
    """Raised when a summary is requested before a valid model was selected."""


# ------ INFERENCE MODEL --------------------------------------------------------------
def summarize(text: str) -> str:
    """Summarize ``text`` with the model last selected via ``choisir_modele``.

    Args:
        text: Raw text submitted by the user.

    Returns:
        The summary produced by the selected model.

    Raises:
        ModeleNonDefiniError: if no valid model has been selected yet.
    """
    # The selected model is stored as an attribute on ``choisir_modele``; the
    # attribute does not exist until that function accepted a valid choice.
    modele = getattr(choisir_modele, "var", None)
    if modele == "lstm":
        # The LSTM output is joined with spaces, so it is presumably a sequence
        # of tokens -- confirm against src.inference_lstm.
        return " ".join(inference_lstm(text))
    if modele == "fineTunedT5":
        # Assumes inference_t5 already returns the summary as a str -- TODO confirm.
        return inference_t5(text)
    raise ModeleNonDefiniError("le modele n'est pas defini")


# ----------------------------------------------------------------------------------
def choisir_modele(choixModele):
    """Record ``choixModele`` as the model used by ``summarize``.

    The choice is stored on the function object itself (``choisir_modele.var``).

    NOTE(review): this is process-wide state shared by every client of the API;
    one user's choice changes the model used for everyone.

    Args:
        choixModele: Model identifier submitted by the /model form.

    Returns:
        True if the model is known and is now selected; False otherwise, in
        which case any previous selection is left unchanged.
    """
    print("ON A RECUP LE CHOIX MODELE")
    if choixModele in MODELES_DISPONIBLES:
        choisir_modele.var = choixModele
        return True
    print("le modele n'est pas defini")
    return False


# -------- API ---------------------------------------------------------------------
app = FastAPI()

templates = Jinja2Templates(directory="templates")
# Serve the templates directory as static files so the browser can fetch the CSS.
# NOTE(review): this also exposes the raw Jinja template sources over HTTP;
# consider moving the CSS to a dedicated static directory.
app.mount("/templates", StaticFiles(directory="templates"), name="templates")


# All three GET routes render the same empty page, so a single handler serves
# them (previously three functions shadowed each other under the name `index`).
@app.get("/")
@app.get("/model")
@app.get("/predict")
async def index(request: Request):
    """Render the main page with no message or summary."""
    return templates.TemplateResponse("index.html.jinja", {"request": request})


@app.post("/model")
async def choix_model(request: Request, choixModel: str = Form(None)):
    """Select the model used by subsequent /predict calls.

    Renders an error message in the page when the field is empty or the model
    identifier is unknown.
    """
    print(choixModel)
    if not choixModel:
        erreur_modele = "Merci de saisir un modèle."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )
    if not choisir_modele(choixModel):
        erreur_modele = "Modèle inconnu, merci de choisir un modèle valide."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )
    print("C'est bon on utilise le modèle demandé")
    return templates.TemplateResponse("index.html.jinja", {"request": request})


# Return the text and its summary, or an error message if the form was sent
# empty or no model has been selected yet.
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Summarize the submitted text with the currently selected model."""
    if not text:
        error = "Merci de saisir votre texte."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": error}
        )
    # NOTE(review): model inference runs synchronously inside this async
    # handler and blocks the event loop while it runs; consider a plain `def`
    # handler or run_in_threadpool once the models are known to be thread-safe.
    try:
        summary = summarize(text)
    except ModeleNonDefiniError:
        error = "Merci de choisir un modèle avant de lancer une prédiction."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": error}
        )
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request, "text": text, "summary": summary}
    )


# ------------------------------------------------------------------------------------
# Start the server and reload it on every saved change (requires `import uvicorn`):
# if __name__ == "__main__":
#     uvicorn.run("api:app", port=8000, reload=True)