EveSa's picture
fix api problem and token auth
cd518e1
raw
history blame
5.1 kB
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import re
from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5
def summarize(text: str) -> str:
    """
    Returns the summary of an input text.

    The model used is the one last recorded by ``global_choose_model``
    (stored on its ``var`` attribute).

    Parameter
    ---------
    text : str
        A text to summarize.

    Returns
    -------
    :str
        The summary of the input text, or a message telling the user that
        no model has been chosen.
    """
    # ``global_choose_model.var`` only exists once ``global_choose_model`` has
    # been called; default to "" so a prediction requested before any model
    # choice gets the explanatory message instead of raising AttributeError.
    model = getattr(global_choose_model, "var", "")
    if model == "lstm":
        text = " ".join(inference_lstm(text))
        # Strip the <start>/<end> markers and a leading/trailing "1".
        return re.sub("^1|1$|<start>|<end>", "", text)
    if model == "fineTunedT5":
        text = inference_t5(text)
        # Remove the "<extra_id_0> " sentinel from the generated text.
        return re.sub("<extra_id_0> ", "", text)
    # No model selected (""), or an unexpected value: always return a str
    # rather than falling through to None.
    return "You have not chosen a model."
def global_choose_model(model_choice):
    """Record the model chosen by the user so that ``summarize`` can use it.

    The choice is stored on the function object itself, as the attribute
    ``global_choose_model.var``, which acts as a module-wide variable read by
    ``summarize``.

    Parameters
    ----------
    model_choice : str
        The submitted model identifier: ``"lstm"``, ``"fineTunedT5"`` or the
        placeholder ``" --- "``.
    """
    if model_choice in ("lstm", "fineTunedT5"):
        global_choose_model.var = model_choice
    else:
        # The placeholder " --- " and any unrecognised value both mean that no
        # model is selected, so a stale previous choice is never kept.
        global_choose_model.var = ""


# Start with no model selected so the attribute exists before the first call.
global_choose_model.var = ""
# Main elements used in the script: the selectable models (each entry pairs
# the identifier understood by global_choose_model with a human-readable
# name), the default selection, and an initial empty choice.
model_list = [
    {"model": identifier, "name": label}
    for identifier, label in (
        (" --- ", " --- "),
        ("lstm", "LSTM"),
        ("fineTunedT5", "Fine-tuned T5"),
    )
]
selected_model = model_list[0]["model"]
model_choice = ""
# -------- API ---------------------------------------------------------------
app = FastAPI()
# Serve the "templates" directory as static files (used to send the CSS).
# NOTE(review): this exposes the whole directory under /templates, including
# the .jinja template sources themselves.
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
# Jinja2 renderer for the pages, reading from the same directory.
templates = Jinja2Templates(directory="templates")
@app.get("/")
async def index(request: Request):
    """Home endpoint: render index.html.jinja for the "/" route, passing the
    list of models and the current selection."""
    context = {
        "request": request,
        "current_route": "/",
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", context)
@app.get("/model")
async def get_model(request: Request):
    """Model-page endpoint: render index.html.jinja for the "/model" route
    with the list of models and the current selection."""
    page_context = dict(
        request=request,
        current_route="/model",
        model_list=model_list,
        selected_model=selected_model,
    )
    return templates.TemplateResponse("index.html.jinja", page_context)
@app.get("/predict")
async def get_prediction(request: Request):
    """Endpoint for the predict page of the app.

    Renders index.html.jinja for the "/predict" route. It passes the same
    ``model_list`` / ``selected_model`` context keys as the other routes, so
    the template receives a consistent context whichever page is requested.
    """
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/predict",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """Retrieve the model chosen by the user.

    If no model was submitted, the page is re-rendered with an error message.
    Otherwise the choice is passed to ``global_choose_model`` (which makes it
    available to ``summarize``). It is also stored in the module-level
    ``selected_model`` that the other endpoints pass to the template.

    NOTE(review): the choice is process-wide state shared by every client of
    the server.
    """
    # Without this declaration the assignment below would only bind a local
    # name, and the other endpoints would keep rendering the default " --- ".
    global selected_model

    if not model_choice:
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": "Please select a model.",
                "model_list": model_list,
                "selected_model": model_choice,
            },
        )

    selected_model = model_choice
    global_choose_model(model_choice)
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Retrieve the text submitted by the user and summarize it.

    An empty or whitespace-only submission re-renders the page with an error
    message. Otherwise the text is summarized with the currently selected
    model, and both the text and its summary are rendered.
    """
    # FastAPI re-export of Starlette's helper, imported locally so this block
    # does not need a new top-of-file import.
    from fastapi.concurrency import run_in_threadpool

    context = {
        "request": request,
        "model_list": model_list,
        "selected_model": selected_model,
    }
    if not text or not text.strip():
        context["text"] = "Please enter your text."
        return templates.TemplateResponse("index.html.jinja", context)

    # summarize() runs the model inference synchronously; executing it in the
    # thread pool keeps a potentially long inference from blocking the event
    # loop (and every other request) while it runs.
    summary = await run_in_threadpool(summarize, text)
    context["text"] = text
    context["summary"] = summary
    return templates.TemplateResponse("index.html.jinja", context)
# ------------------------------------------------------------------------------------
# Start the server and reload it on every saved change (requires `import uvicorn`):
# if __name__ == "__main__":
#     uvicorn.run("api:app", port=8000, reload=True)