File size: 5,097 Bytes
2c35026
 
 
cd518e1
 
2c35026
9cd8995
 
2c35026
 
 
cd518e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cd8995
cd518e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c35026
 
cd518e1
2c35026
 
 
8dba466
2c35026
 
cd518e1
 
 
 
 
 
 
 
 
 
 
2c35026
3c03f61
ef05d9e
cd518e1
 
 
 
 
 
 
 
 
 
 
 
ef05d9e
3c03f61
ef05d9e
cd518e1
 
 
 
 
 
ef05d9e
 
 
cd518e1
 
 
 
 
 
 
 
 
ef05d9e
cd518e1
 
 
 
 
 
 
ef05d9e
3c03f61
cd518e1
 
 
 
 
 
 
 
 
ef05d9e
 
 
2c35026
cd518e1
 
 
8dba466
cd518e1
2c35026
cd518e1
 
 
 
 
 
 
8dba466
 
2c35026
 
cd518e1
 
 
 
 
 
 
 
2c35026
8dba466
 
2c35026
 
 
 
70d598e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import re


from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5


def summarize(text: str) -> str:
    """
    Return the summary of an input text using the currently selected model.

    The model is read from ``global_choose_model.var``, which is set by
    ``global_choose_model``. If no model has been chosen yet (the attribute
    has never been set, or it holds the empty string), a fixed message is
    returned instead of raising ``AttributeError``.

    Parameters
    ----------
    text : str
        A text to summarize.

    Returns
    -------
    str
        The summary of the input text, or "You have not chosen a model."
        when no model is selected.
    """
    # The attribute only exists once global_choose_model() has run, so a
    # /predict request sent before any /model request would otherwise crash.
    model = getattr(global_choose_model, "var", "")
    if model == "lstm":
        text = " ".join(inference_lstm(text))
        # Drop the <start>/<end> markers and a leading/trailing "1"
        # (presumably a decoding artifact of the LSTM output — TODO confirm).
        return re.sub("^1|1$|<start>|<end>", "", text)
    if model == "fineTunedT5":
        text = inference_t5(text)
        # Remove the T5 sentinel token left in the generated text.
        return re.sub("<extra_id_0> ", "", text)
    # Covers the "  ---  " placeholder ("") as well as "never chosen".
    return "You have not chosen a model."


def global_choose_model(model_choice):
    """Record the user's model choice as an attribute on this function.

    ``summarize`` reads ``global_choose_model.var`` to decide which model to
    run, so the function object serves as a small store reachable from
    outside any single function.

    Parameters
    ----------
    model_choice : str
        The submitted form value: ``"lstm"``, ``"fineTunedT5"``, or the
        placeholder ``"  ---  "`` (stored as ``""``). Any other value leaves
        the stored choice unchanged.
    """
    # (submitted value, value stored in global_choose_model.var), checked in order.
    choices = (
        ("lstm", "lstm"),
        ("fineTunedT5", "fineTunedT5"),
        ("  ---  ", ""),
    )
    for form_value, stored_value in choices:
        if model_choice == form_value:
            global_choose_model.var = stored_value
            break


# Shared module-level state used by the routes below.
# Entries offered in the model drop-down, built from (form value, label) pairs.
model_list = [
    {"model": value, "name": label}
    for value, label in (
        ("  ---  ", "  ---  "),
        ("lstm", "LSTM"),
        ("fineTunedT5", "Fine-tuned T5"),
    )
]
# Passed to the templates as "selected_model"; starts on the placeholder entry.
selected_model = "  ---  "
# NOTE(review): not read by any visible code; the route parameters with the
# same name shadow it.
model_choice = ""


# -------- API ---------------------------------------------------------------
# The "templates" directory serves two purposes: Jinja2 renders the HTML pages
# from it, and it is mounted at /templates so the browser can fetch the CSS.
templates = Jinja2Templates(directory="templates")

app = FastAPI()
app.mount("/templates", StaticFiles(directory="templates"), name="templates")


@app.get("/")
async def index(request: Request):
    """Render the home page, including the model drop-down data."""
    context = dict(
        request=request,
        current_route="/",
        model_list=model_list,
        selected_model=selected_model,
    )
    return templates.TemplateResponse("index.html.jinja", context)


@app.get("/model")
async def get_model(request: Request):
    """Render the model-selection page with the available models."""
    context = dict(
        request=request,
        current_route="/model",
        model_list=model_list,
        selected_model=selected_model,
    )
    return templates.TemplateResponse("index.html.jinja", context)


@app.get("/predict")
async def get_prediction(request: Request):
    """Render the page used to submit a text for summarization.

    Passes ``model_list`` and ``selected_model`` like every other route that
    renders ``index.html.jinja``. The template then receives the same data
    here as on the other pages (presumably used to fill the model drop-down —
    confirm against the template).
    """
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/predict",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )


@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """Store the model chosen by the user and re-render the page.

    Behavior depends on the submitted value:

    - Empty, or not one of the ``"model"`` values in ``model_list``: renders
      "Please select a model." and leaves the current choice untouched.
    - Otherwise: forwards the choice to ``global_choose_model`` (which tells
      ``summarize`` which model to run) and saves it in the module-level
      ``selected_model``. All routes then render the active model as selected.
    """
    # Without this declaration the assignment below would only bind a local,
    # and the other routes would keep rendering the initial "  ---  " entry.
    # NOTE(review): this is process-wide state shared by every client.
    global selected_model

    known_models = {entry["model"] for entry in model_list}
    if not model_choice or model_choice not in known_models:
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": "Please select a model.",
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )

    global_choose_model(model_choice)
    selected_model = model_choice
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )


@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Summarize the text submitted by the user and render the result.

    Missing, empty, or whitespace-only input renders "Please enter your
    text." instead of being sent to a model. Otherwise the page shows the
    original text together with the output of ``summarize``.
    """
    # Whitespace-only input would otherwise be fed to the model as-is.
    if not text or not text.strip():
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": "Please enter your text.",
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )

    # NOTE(review): summarize() runs model inference synchronously inside an
    # async endpoint, so it blocks the event loop for the duration of the call.
    summary = summarize(text)
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "text": text,
            "summary": summary,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )


# ------------------------------------------------------------------------------------


# Start the server and reload it every time a change is saved
# if __name__ == "__main__":
#     uvicorn.run("api:app", port=8000, reload=True)