SummaryProject / api.py
EstelleSkwarto's picture
completed api functions
7315e4e
raw
history blame
5.1 kB
import re

import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from src.inference import inferenceAPI
from src.inference_t5 import inferenceAPI_t5
def summarize(text: str):
    """
    Returns the summary of an input text, using the model currently selected
    through ``global_choose_model``.

    Parameter
    ---------
    text : str
        A text to summarize.

    Returns
    -------
    :str
        The summary of the input text, or a message telling the user that no
        model has been chosen.
    """
    # ``global_choose_model.var`` only exists once ``global_choose_model`` has
    # been called; reading it directly raised AttributeError (an HTTP 500)
    # when a summary was requested before any model was chosen.
    model = getattr(global_choose_model, "var", "")

    if model == "lstm":
        text = " ".join(inferenceAPI(text))
        # Remove a "1" at the very start or end, and the literal <start>/<end>
        # tokens, from the joined output.
        return re.sub("^1|1$|<start>|<end>", "", text)
    if model == "fineTunedT5":
        text = inferenceAPI_t5(text)
        # Remove the "<extra_id_0> " token from the generated text.
        return re.sub("<extra_id_0> ", "", text)
    # "" (no model) and any unrecognised value: previously an unrecognised
    # value fell through every branch and returned None.
    return "You have not chosen a model."
def global_choose_model(model_choice):
    """Record the model chosen by the user so that ``summarize`` can use it.

    The choice is stored in the function attribute ``global_choose_model.var``.
    This lets code outside this function read it.

    Parameters
    ----------
    model_choice : str
        ``"lstm"``, ``"fineTunedT5"`` or ``" --- "``. The last one stores
        ``""``, meaning no model.

    Any other value resets the choice to ``""`` (no model). Previously the
    earlier choice was silently kept.
    """
    known_models = {"lstm": "lstm", "fineTunedT5": "fineTunedT5", " --- ": ""}
    global_choose_model.var = known_models.get(model_choice, "")


# Start with no model selected, so ``global_choose_model.var`` always exists.
global_choose_model.var = ""
# Models offered to the user. "model" holds the identifiers understood by
# global_choose_model ("lstm", "fineTunedT5", or " --- " for no model);
# "name" holds the human-readable label.
model_list = [
    {"model": value, "name": label}
    for value, label in (
        (" --- ", " --- "),
        ("lstm", "LSTM"),
        ("fineTunedT5", "Fine-tuned T5"),
    )
]
# Initial selection rendered by the pages: no model chosen yet.
selected_model = " --- "
# Not read by the visible code (the /model handler receives its own
# ``model_choice`` parameter); kept as a module-level name.
model_choice = ""
# -------- API ---------------------------------------------------------------
# Jinja2 renderer for the HTML pages, reading from the "templates" folder.
templates = Jinja2Templates(directory="templates")

app = FastAPI()
# The same folder is also served as static files under /templates, which is
# how the CSS is sent to the browser.
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
@app.get("/")
async def index(request: Request):
    """Endpoint for the index page: renders ``index.html.jinja`` with
    ``current_route`` set to "/", plus the available models and the current
    selection."""
    context = {
        "request": request,
        "current_route": "/",
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", context)
@app.get("/model")
async def get_model(request: Request):
    """Endpoint for the model page: renders ``index.html.jinja`` with
    ``current_route`` set to "/model", plus the available models and the
    current selection."""
    page_context = dict(
        request=request,
        current_route="/model",
        model_list=model_list,
        selected_model=selected_model,
    )
    return templates.TemplateResponse("index.html.jinja", page_context)
@app.get("/predict")
async def get_prediction(request: Request):
    """Endpoint for the predict page: renders ``index.html.jinja`` with
    ``current_route`` set to "/predict".

    The available models and the current selection are passed as well. This
    matches every other render of the same template in this file ("/",
    "/model" and both POST handlers); before, this page was rendered without
    them.
    """
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/predict",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """Endpoint receiving the model chosen by the user in the form.

    If no model was submitted, an error message is rendered. Otherwise the
    choice is:

    - passed to ``global_choose_model``, which connects it to ``summarize``;
    - stored in the module-level ``selected_model``, so pages rendered
      afterwards show it as selected.
    """
    # Without this declaration the assignment below only bound a local
    # variable. "/", "/model" and "/predict" then kept rendering the initial
    # " --- " selection after a model had been chosen.
    # NOTE(review): this is process-wide state, shared by every user.
    global selected_model

    if not model_choice:
        model_error = "Please select a model."
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": model_error,
                "model_list": model_list,
                # Render the empty submission as-is; the remembered
                # selection is left untouched.
                "selected_model": model_choice,
            },
        )

    selected_model = model_choice
    global_choose_model(model_choice)
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Endpoint receiving the text submitted by the user.

    If no text (or only whitespace) was sent, an error message is rendered.
    Otherwise the text is rendered together with the summary returned by
    ``summarize``.
    """
    # Whitespace-only input is treated as empty instead of being sent to the
    # model.
    if not text or not text.strip():
        text_error = "Please enter your text."
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": text_error,
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )

    # ``summarize`` runs model inference synchronously. Calling it directly
    # inside this ``async def`` would block the event loop, and every other
    # request, until it finished. It is run in the thread pool instead.
    summary = await run_in_threadpool(summarize, text)
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "text": text,
            "summary": summary,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
# ------------------------------------------------------------------------------------
# When this file is executed directly, start a server on port 8000.
# ``reload=True`` restarts it every time a change is saved. The app is passed
# as the "module:attribute" import string "api:app".
if __name__ == "__main__":
    server_options = {"port": 8000, "reload": True}
    uvicorn.run("api:app", **server_options)