Spaces:
Runtime error
Runtime error
File size: 5,097 Bytes
2c35026 cd518e1 2c35026 9cd8995 2c35026 cd518e1 9cd8995 cd518e1 2c35026 cd518e1 2c35026 8dba466 2c35026 cd518e1 2c35026 3c03f61 ef05d9e cd518e1 ef05d9e 3c03f61 ef05d9e cd518e1 ef05d9e cd518e1 ef05d9e cd518e1 ef05d9e 3c03f61 cd518e1 ef05d9e 2c35026 cd518e1 8dba466 cd518e1 2c35026 cd518e1 8dba466 2c35026 cd518e1 2c35026 8dba466 2c35026 70d598e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import re
from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5
def summarize(text: str):
    """
    Returns the summary of an input text.

    The model used is the one recorded in the ``var`` attribute of the
    ``global_choose_model`` function.

    Parameter
    ---------
    text : str
        A text to summarize.

    Returns
    -------
    :str
        The summary of the input text, or the message
        "You have not chosen a model." when no usable model is recorded.
    """
    # The attribute may not exist yet if a summary is requested before any
    # model has been chosen; fall back to "" instead of raising
    # AttributeError.
    chosen_model = getattr(global_choose_model, "var", "")

    if chosen_model == "lstm":
        # inference_lstm yields pieces that are joined with single spaces.
        text = " ".join(inference_lstm(text))
        # Drop a "1" at the very start or very end of the joined text and
        # every literal "<start>" / "<end>" marker.
        return re.sub("^1|1$|<start>|<end>", "", text)

    if chosen_model == "fineTunedT5":
        text = inference_t5(text)
        # Drop every literal "<extra_id_0> " marker from the output.
        return re.sub("<extra_id_0> ", "", text)

    # No model chosen (or an unexpected value): always answer with a string
    # rather than silently returning None.
    return "You have not chosen a model."
def global_choose_model(model_choice):
    """Record the model chosen by the user so that other functions can read it.

    The choice is stored on the function object itself, in the attribute
    ``global_choose_model.var``. The aim is to access the value outside of
    this function without passing it around explicitly.

    Parameter
    ---------
    model_choice : str
        "lstm" or "fineTunedT5" to select a model, or " --- " to clear the
        selection (stored as the empty string). Any other value leaves the
        stored choice unchanged.
    """
    if model_choice in ("lstm", "fineTunedT5"):
        global_choose_model.var = model_choice
    elif model_choice == " --- ":
        global_choose_model.var = ""


# Start with "no model chosen" so that reading the attribute before the first
# call does not raise AttributeError.
global_choose_model.var = ""
# Main elements used in the script.

# Entries handed to the template: "model" holds the identifier of a model
# (" --- " stands for "no model"), "name" is presumably the label displayed
# to the user -- confirm against the template.
model_list = [
    {"model": identifier, "name": label}
    for identifier, label in (
        (" --- ", " --- "),
        ("lstm", "LSTM"),
        ("fineTunedT5", "Fine-tuned T5"),
    )
]

# Value passed to the template as "selected_model"; starts on the
# placeholder entry.
selected_model = " --- "

# Module-level default; not read by any function visible in this file.
model_choice = ""
# -------- API ---------------------------------------------------------------
# Template engine rendering the pages stored in the "templates" directory.
templates = Jinja2Templates(directory="templates")

app = FastAPI()
# Serve the same directory as static files in order to send the css.
app.mount(
    path="/templates",
    app=StaticFiles(directory="templates"),
    name="templates",
)
@app.get("/")
async def index(request: Request):
    """Endpoint of the index page of the app.

    Renders "index.html.jinja" with the list of models and the currently
    selected model.
    """
    context = {
        "request": request,
        "current_route": "/",
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", context)
@app.get("/model")
async def get_model(request: Request):
    """Endpoint of the model page of the app.

    Renders "index.html.jinja" with "/model" as the current route, together
    with the list of models and the currently selected model.
    """
    page_context = dict(
        request=request,
        current_route="/model",
        model_list=model_list,
        selected_model=selected_model,
    )
    return templates.TemplateResponse("index.html.jinja", page_context)
@app.get("/predict")
async def get_prediction(request: Request):
    """Endpoint of the predict page of the app.

    Renders "index.html.jinja" with "/predict" as the current route.
    """
    page_context = {
        "request": request,
        "current_route": "/predict",
    }
    return templates.TemplateResponse("index.html.jinja", page_context)
@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """Retrieve the model chosen by the user and activate it.

    When no model is submitted, the page is rendered again with the error
    message "Please select a model." and the selection that is still in
    effect. Otherwise the choice is forwarded to ``global_choose_model``,
    which connects the user choice to the use of a model, and remembered in
    the module-level ``selected_model`` so that the other endpoints render
    the same selection.
    """
    # Without this declaration the assignment below would only create a
    # local variable, and every other endpoint would keep rendering the
    # initial " --- " selection.
    global selected_model

    if not model_choice:
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": "Please select a model.",
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )

    global_choose_model(model_choice)
    selected_model = model_choice
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Retrieve the input text of the user and render the result page.

    An empty or missing text leads to the error message
    "Please enter your text."; any other text is sent to the summarize
    function and rendered together with its summary.
    """
    context = {"request": request}
    if text:
        context["text"] = text
        context["summary"] = summarize(text)
    else:
        # Nothing to summarize: the message takes the place of the text.
        context["text"] = "Please enter your text."
    context["model_list"] = model_list
    context["selected_model"] = selected_model
    return templates.TemplateResponse("index.html.jinja", context)
# ------------------------------------------------------------------------------------
# Start the server and reload it on every saved modification.
# if __name__ == "__main__":
# uvicorn.run("api:app", port=8000, reload=True)
|