Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, Form, Request | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
import re | |
from src.inference_lstm import inference_lstm | |
from src.inference_t5 import inference_t5 | |
def summarize(text: str):
    """
    Return the summary of an input text using the currently selected model.

    The active model is read from the ``var`` attribute that
    ``global_choose_model`` stores on itself ("lstm", "fineTunedT5" or "").

    Parameters
    ----------
    text : str
        A text to summarize.

    Returns
    -------
    str
        The summary of the input text, or the message
        "You have not chosen a model." when no usable model is selected.
    """
    # ``var`` only exists once global_choose_model() has been called at least
    # once; default to "" so a request made before any choice gets the
    # friendly message instead of an AttributeError.
    model = getattr(global_choose_model, "var", "")
    if model == "lstm":
        text = " ".join(inference_lstm(text))
        # Drop a "1" at the very start/end and the <start>/<end> markers
        # from the joined LSTM output.
        return re.sub(r"^1|1$|<start>|<end>", "", text)
    if model == "fineTunedT5":
        text = inference_t5(text)
        # Remove the "<extra_id_0> " marker from the T5 output.
        return re.sub(r"<extra_id_0> ", "", text)
    # "" (the " --- " placeholder) or any unexpected value: no model to run.
    return "You have not chosen a model."
def global_choose_model(model_choice):
    """Record the user's model choice as an attribute of this function.

    The choice is stored in ``global_choose_model.var`` so that code outside
    this function (the summarizer) can read it:

    - "lstm"        -> "lstm"
    - "fineTunedT5" -> "fineTunedT5"
    - " --- "       -> "" (no model selected)

    Any other value leaves a previously stored choice untouched.
    """
    known_choices = (
        ("lstm", "lstm"),
        ("fineTunedT5", "fineTunedT5"),
        (" --- ", ""),
    )
    for option, stored_value in known_choices:
        if model_choice == option:
            global_choose_model.var = stored_value
            break
# Main elements used by the script.

# Model options passed to the template. Each entry pairs the value that
# global_choose_model() understands ("model") with its display label ("name").
model_list = [
    {"model": value, "name": label}
    for value, label in (
        (" --- ", " --- "),
        ("lstm", "LSTM"),
        ("fineTunedT5", "Fine-tuned T5"),
    )
]
# Model rendered as selected by the pages; starts on the " --- " placeholder.
selected_model = " --- "
# Module-level default; the handlers below take their own parameter of this name.
model_choice = ""
# -------- API ---------------------------------------------------------------
app = FastAPI()

# The "templates" directory is used twice: Jinja2 renders the pages from it,
# and it is also exposed as static files under /templates (to send the CSS).
templates = Jinja2Templates(directory="templates")
app.mount(
    path="/templates",
    app=StaticFiles(directory="templates"),
    name="templates",
)
@app.get("/")
async def index(request: Request):
    """Render the index page (GET /).

    Without a route decorator this coroutine was never registered on ``app``
    and the page was unreachable.

    Parameters
    ----------
    request : Request
        The incoming request, required by the template response.
    """
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.get("/model")
async def get_model(request: Request):
    """Render the index page for the model route (GET /model).

    Without a route decorator this coroutine was never registered on ``app``.

    Parameters
    ----------
    request : Request
        The incoming request, required by the template response.
    """
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/model",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.get("/predict")
async def get_prediction(request: Request):
    """Render the index page for the predict route (GET /predict).

    The context now carries ``model_list`` and ``selected_model`` like every
    other page does, so the model selector is not rendered empty here.

    Parameters
    ----------
    request : Request
        The incoming request, required by the template response.
    """
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/predict",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
# NOTE(review): path assumed to match the model form's action in
# index.html.jinja — confirm against the template.
@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """Handle the model-selection form and re-render the index page.

    An empty or missing choice re-renders the page with an error message.
    Otherwise the choice is passed to ``global_choose_model`` (which the
    summarizer reads) and stored in the module-level ``selected_model`` so
    later pages render it as selected.

    Parameters
    ----------
    request : Request
        The incoming request, required by the template response.
    model_choice : str, optional
        The submitted ``model_choice`` form field; ``None`` when absent.
    """
    # Without this declaration the assignment below only creates a local, and
    # the module-level value used by the other pages never changes.
    # NOTE: this module-level state is shared by every client of the server.
    global selected_model
    if not model_choice:
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": "Please select a model.",
                "model_list": model_list,
                "selected_model": model_choice,
            },
        )
    selected_model = model_choice
    global_choose_model(model_choice)
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
# NOTE(review): path assumed to match the text form's action in
# index.html.jinja — confirm against the template.
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Handle the text form: summarize the submitted text and render it.

    An empty or missing text re-renders the page with an error message.
    Otherwise the text is passed to ``summarize`` and both the input and the
    summary are rendered.

    Parameters
    ----------
    request : Request
        The incoming request, required by the template response.
    text : str, optional
        The submitted ``text`` form field; ``None`` when absent.
    """
    if not text:
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": "Please enter your text.",
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )
    # NOTE(review): summarize() calls the inference functions synchronously
    # inside this coroutine, which blocks the event loop while they run.
    # Consider a plain `def` endpoint or run_in_threadpool if latency matters.
    summary = summarize(text)
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "text": text,
            "summary": summary,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
# ------------------------------------------------------------------------------------ | |
# Run the server and reload it on every saved change (requires `import uvicorn`)
# if __name__ == "__main__": | |
# uvicorn.run("api:app", port=8000, reload=True) | |