import re

from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5


def summarize(text: str) -> str:
    """
    Return the summary of an input text using the currently selected model.

    The model is read from ``global_choose_model.var``, which is set by
    ``global_choose_model`` when the user submits the model form.

    Parameters
    ----------
    text : str
        A text to summarize.

    Returns
    -------
    str
        The summary of the input text, or a message telling the user that
        no model has been chosen yet.
    """
    model = global_choose_model.var

    if model == "lstm":
        # NOTE(review): assumes inference_lstm returns an iterable of token
        # strings -- confirm against src.inference_lstm.
        summary = " ".join(inference_lstm(text))
        # Drop a leading and/or trailing "1" character (presumably a special
        # start/end token emitted by the model -- TODO confirm). The previous
        # pattern "^1|1$||" had two empty alternatives that only ever matched
        # the empty string, so this simpler pattern yields the same output.
        return re.sub(r"^1|1$", "", summary)

    if model == "fineTunedT5":
        summary = inference_t5(text)
        # NOTE(review): this removes *every* space from the T5 output, which
        # would glue words together for ordinarily decoded text. Kept as-is
        # to preserve behavior; confirm it matches inference_t5's output.
        return summary.replace(" ", "")

    # "" (the " --- " placeholder) or any unexpected value: no usable model.
    return "You have not chosen a model."


def global_choose_model(model_choice: str) -> None:
    """
    Record the user's model choice in the ``var`` attribute of this function.

    ``summarize`` reads ``global_choose_model.var`` to decide which model to
    run. Storing the choice as a function attribute makes it reachable from
    outside the request handler that received it.

    Parameters
    ----------
    model_choice : str
        ``"lstm"``, ``"fineTunedT5"`` or the placeholder ``" --- "`` (stored
        as ``""``). Any other value leaves the current choice unchanged.
    """
    if model_choice == "lstm":
        global_choose_model.var = "lstm"
    elif model_choice == "fineTunedT5":
        global_choose_model.var = "fineTunedT5"
    elif model_choice == " --- ":
        global_choose_model.var = ""


# Start with "no model chosen" so that summarize() returns its message instead
# of raising AttributeError when /predict is posted before any model selection.
# NOTE(review): this state is process-wide, i.e. shared by every client.
global_choose_model.var = ""


# definition of the main elements used in the script
# Options for the model selector (presumably "model" is the submitted form
# value and "name" the displayed label -- confirm in index.html.jinja).
model_list = [
    {"model": " --- ", "name": " --- "},
    {"model": "lstm", "name": "LSTM"},
    {"model": "fineTunedT5", "name": "Fine-tuned T5"},
]
# Currently selected option; updated by choose_model() and passed to the
# template on every render.
selected_model = " --- "
# NOTE(review): not used in this module; kept in case other code imports it.
model_choice = ""

# -------- API ---------------------------------------------------------------

app = FastAPI()

# Templates are rendered from ./templates; the same directory is also served
# as static files under /templates (used to send the CSS).
templates = Jinja2Templates(directory="templates")
app.mount("/templates", StaticFiles(directory="templates"), name="templates")


def _render(request: Request, **context):
    """Render index.html.jinja with the request plus the given context values."""
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request, **context}
    )


@app.get("/")
async def index(request: Request):
    """Endpoint for the index page of the app."""
    return _render(
        request,
        current_route="/",
        model_list=model_list,
        selected_model=selected_model,
    )


@app.get("/model")
async def get_model(request: Request):
    """Endpoint for the model-selection page of the app."""
    return _render(
        request,
        current_route="/model",
        model_list=model_list,
        selected_model=selected_model,
    )


@app.get("/predict")
async def get_prediction(request: Request):
    """Endpoint for the predict page of the app."""
    # model_list / selected_model are passed here too, for consistency with
    # every other render of the same template.
    return _render(
        request,
        current_route="/predict",
        model_list=model_list,
        selected_model=selected_model,
    )


@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """
    Retrieve the model chosen by the user.

    If no value was submitted, the page is re-rendered with an error message.
    Otherwise the choice is remembered in the module-level ``selected_model``
    and forwarded to ``global_choose_model``, which connects it to the model
    used by ``summarize``.
    """
    # Without this declaration the assignment below would only bind a local
    # name, and later page renders would never reflect the user's choice.
    global selected_model

    if not model_choice:
        return _render(
            request,
            text="Please select a model.",
            model_list=model_list,
            selected_model=model_choice,
        )

    selected_model = model_choice
    global_choose_model(model_choice)
    return _render(
        request,
        model_list=model_list,
        selected_model=selected_model,
    )


@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """
    Retrieve the user's input text and summarize it.

    If the form field is missing or empty, the page is re-rendered with an
    error message. Otherwise the text is passed to ``summarize`` and rendered
    together with its summary.
    """
    if not text:
        return _render(
            request,
            text="Please enter your text.",
            model_list=model_list,
            selected_model=selected_model,
        )

    # NOTE(review): summarize() is called synchronously inside this async
    # handler, so model inference blocks the event loop while it runs.
    summary = summarize(text)
    return _render(
        request,
        text=text,
        summary=summary,
        model_list=model_list,
        selected_model=selected_model,
    )


# ------------------------------------------------------------------------------------
# Start the server and reload it on every saved change (requires `import uvicorn`):
# if __name__ == "__main__":
#     uvicorn.run("api:app", port=8000, reload=True)