Spaces:
Runtime error
Runtime error
fix api problem and tokent auth
Browse files- .cache/init.txt +0 -1
- Dockerfile +12 -0
- model/vocab.pkl +0 -0
- src/api.py +119 -39
- src/fine_tune_T5.py +5 -2
- src/inference_t5.py +3 -3
- templates/index.html.jinja +1 -1
.cache/init.txt
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
initial
|
|
|
|
Dockerfile
CHANGED
@@ -6,6 +6,18 @@ COPY requirements.txt .
|
|
6 |
|
7 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
COPY . .
|
10 |
|
11 |
CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
6 |
|
7 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
8 |
|
9 |
+
RUN useradd -m -u 1000 user
|
10 |
+
|
11 |
+
USER user
|
12 |
+
|
13 |
+
ENV HOME=/home/user \
|
14 |
+
PATH=/home/user/.local/bin:$PATH
|
15 |
+
|
16 |
+
|
17 |
+
WORKDIR $HOME/app
|
18 |
+
|
19 |
+
COPY --chown=user . $HOME/app
|
20 |
+
|
21 |
COPY . .
|
22 |
|
23 |
CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
|
model/vocab.pkl
ADDED
Binary file (63.4 kB). View file
|
|
src/api.py
CHANGED
@@ -1,82 +1,162 @@
|
|
1 |
from fastapi import FastAPI, Form, Request
|
2 |
from fastapi.staticfiles import StaticFiles
|
3 |
from fastapi.templating import Jinja2Templates
|
|
|
|
|
4 |
|
5 |
from src.inference_lstm import inference_lstm
|
6 |
from src.inference_t5 import inference_t5
|
7 |
|
8 |
|
9 |
-
# ------ INFERENCE MODEL --------------------------------------------------------------
|
10 |
-
# appel de la fonction inference, adaptee pour une entree txt
|
11 |
def summarize(text: str):
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
text = inference_t5(text)
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
def
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
app = FastAPI()
|
33 |
|
34 |
-
# static files
|
35 |
templates = Jinja2Templates(directory="templates")
|
36 |
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
|
37 |
|
38 |
|
39 |
@app.get("/")
|
40 |
async def index(request: Request):
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
|
44 |
@app.get("/model")
|
45 |
-
async def
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
@app.get("/predict")
|
50 |
-
async def
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
|
53 |
|
54 |
@app.post("/model")
|
55 |
-
async def
|
56 |
-
|
57 |
-
if not
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
59 |
return templates.TemplateResponse(
|
60 |
-
"index.html.jinja",
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
)
|
62 |
else:
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
|
68 |
-
# retourner le texte, les predictions et message d'erreur si formulaire envoye vide
|
69 |
@app.post("/predict")
|
70 |
async def prediction(request: Request, text: str = Form(None)):
|
|
|
|
|
|
|
71 |
if not text:
|
72 |
-
|
73 |
return templates.TemplateResponse(
|
74 |
-
"index.html.jinja",
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
)
|
76 |
else:
|
77 |
summary = summarize(text)
|
78 |
return templates.TemplateResponse(
|
79 |
-
"index.html.jinja",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
)
|
81 |
|
82 |
|
|
|
1 |
from fastapi import FastAPI, Form, Request
|
2 |
from fastapi.staticfiles import StaticFiles
|
3 |
from fastapi.templating import Jinja2Templates
|
4 |
+
import re
|
5 |
+
|
6 |
|
7 |
from src.inference_lstm import inference_lstm
|
8 |
from src.inference_t5 import inference_t5
|
9 |
|
10 |
|
|
|
|
|
11 |
def summarize(text: str):
|
12 |
+
"""
|
13 |
+
Returns the summary of an input text.
|
14 |
+
Parameter
|
15 |
+
---------
|
16 |
+
text : str
|
17 |
+
A text to summarize.
|
18 |
+
Returns
|
19 |
+
-------
|
20 |
+
:str
|
21 |
+
The summary of the input text.
|
22 |
+
"""
|
23 |
+
if global_choose_model.var == "lstm":
|
24 |
+
text = " ".join(inference_lstm(text))
|
25 |
+
return re.sub("^1|1$|<start>|<end>", "", text)
|
26 |
+
elif global_choose_model.var == "fineTunedT5":
|
27 |
text = inference_t5(text)
|
28 |
+
return re.sub("<extra_id_0> ", "", text)
|
29 |
+
elif global_choose_model.var == "":
|
30 |
+
return "You have not chosen a model."
|
31 |
+
|
32 |
+
|
33 |
+
def global_choose_model(model_choice):
|
34 |
+
"""This function allows to connect the choice of the
|
35 |
+
model and the summary function by defining global variables.
|
36 |
+
The aime is to access a variable outside of a function."""
|
37 |
+
if model_choice == "lstm":
|
38 |
+
global_choose_model.var = "lstm"
|
39 |
+
elif model_choice == "fineTunedT5":
|
40 |
+
global_choose_model.var = "fineTunedT5"
|
41 |
+
elif model_choice == " --- ":
|
42 |
+
global_choose_model.var = ""
|
43 |
+
|
44 |
+
|
45 |
+
# definition of the main elements used in the script
|
46 |
+
model_list = [
|
47 |
+
{"model": " --- ", "name": " --- "},
|
48 |
+
{"model": "lstm", "name": "LSTM"},
|
49 |
+
{"model": "fineTunedT5", "name": "Fine-tuned T5"},
|
50 |
+
]
|
51 |
+
selected_model = " --- "
|
52 |
+
model_choice = ""
|
53 |
+
|
54 |
+
|
55 |
+
# -------- API ---------------------------------------------------------------
|
56 |
app = FastAPI()
|
57 |
|
58 |
+
# static files to send the css
|
59 |
templates = Jinja2Templates(directory="templates")
|
60 |
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
|
61 |
|
62 |
|
63 |
@app.get("/")
|
64 |
async def index(request: Request):
|
65 |
+
"""This function is used to create an endpoint for the
|
66 |
+
index page of the app."""
|
67 |
+
return templates.TemplateResponse(
|
68 |
+
"index.html.jinja",
|
69 |
+
{
|
70 |
+
"request": request,
|
71 |
+
"current_route": "/",
|
72 |
+
"model_list": model_list,
|
73 |
+
"selected_model": selected_model,
|
74 |
+
},
|
75 |
+
)
|
76 |
|
77 |
|
78 |
@app.get("/model")
|
79 |
+
async def get_model(request: Request):
|
80 |
+
"""This function is used to create an endpoint for
|
81 |
+
the model page of the app."""
|
82 |
+
return templates.TemplateResponse(
|
83 |
+
"index.html.jinja",
|
84 |
+
{
|
85 |
+
"request": request,
|
86 |
+
"current_route": "/model",
|
87 |
+
"model_list": model_list,
|
88 |
+
"selected_model": selected_model,
|
89 |
+
},
|
90 |
+
)
|
91 |
|
92 |
|
93 |
@app.get("/predict")
|
94 |
+
async def get_prediction(request: Request):
|
95 |
+
"""This function is used to create an endpoint for
|
96 |
+
the predict page of the app."""
|
97 |
+
return templates.TemplateResponse(
|
98 |
+
"index.html.jinja", {"request": request, "current_route": "/predict"}
|
99 |
+
)
|
100 |
|
101 |
|
102 |
@app.post("/model")
|
103 |
+
async def choose_model(request: Request, model_choice: str = Form(None)):
|
104 |
+
"""This functions allows to retrieve the model chosen by the user. Then, it
|
105 |
+
can end to an error message if it not defined or it is sent to the
|
106 |
+
global_choose_model function which connects the user choice to the
|
107 |
+
use of a model."""
|
108 |
+
selected_model = model_choice
|
109 |
+
# print(selected_model)
|
110 |
+
if not model_choice:
|
111 |
+
model_error = "Please select a model."
|
112 |
return templates.TemplateResponse(
|
113 |
+
"index.html.jinja",
|
114 |
+
{
|
115 |
+
"request": request,
|
116 |
+
"text": model_error,
|
117 |
+
"model_list": model_list,
|
118 |
+
"selected_model": selected_model,
|
119 |
+
},
|
120 |
)
|
121 |
else:
|
122 |
+
global_choose_model(model_choice)
|
123 |
+
return templates.TemplateResponse(
|
124 |
+
"index.html.jinja",
|
125 |
+
{
|
126 |
+
"request": request,
|
127 |
+
"model_list": model_list,
|
128 |
+
"selected_model": selected_model,
|
129 |
+
},
|
130 |
+
)
|
131 |
|
132 |
|
|
|
133 |
@app.post("/predict")
|
134 |
async def prediction(request: Request, text: str = Form(None)):
|
135 |
+
"""This function allows to retrieve the input text of the user.
|
136 |
+
Then, it can end to an error message or it can be sent to
|
137 |
+
the summarize function."""
|
138 |
if not text:
|
139 |
+
text_error = "Please enter your text."
|
140 |
return templates.TemplateResponse(
|
141 |
+
"index.html.jinja",
|
142 |
+
{
|
143 |
+
"request": request,
|
144 |
+
"text": text_error,
|
145 |
+
"model_list": model_list,
|
146 |
+
"selected_model": selected_model,
|
147 |
+
},
|
148 |
)
|
149 |
else:
|
150 |
summary = summarize(text)
|
151 |
return templates.TemplateResponse(
|
152 |
+
"index.html.jinja",
|
153 |
+
{
|
154 |
+
"request": request,
|
155 |
+
"text": text,
|
156 |
+
"summary": summary,
|
157 |
+
"model_list": model_list,
|
158 |
+
"selected_model": selected_model,
|
159 |
+
},
|
160 |
)
|
161 |
|
162 |
|
src/fine_tune_T5.py
CHANGED
@@ -155,7 +155,9 @@ if __name__ == '__main__':
|
|
155 |
# définition de device
|
156 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
157 |
# faire appel au model à entrainer
|
158 |
-
|
|
|
|
|
159 |
|
160 |
mt5_config = AutoConfig.from_pretrained(
|
161 |
"google/mt5-small",
|
@@ -163,6 +165,7 @@ if __name__ == '__main__':
|
|
163 |
length_penalty=0.6,
|
164 |
no_repeat_ngram_size=2,
|
165 |
num_beams=15,
|
|
|
166 |
)
|
167 |
|
168 |
model = (AutoModelForSeq2SeqLM
|
@@ -238,7 +241,7 @@ if __name__ == '__main__':
|
|
238 |
|
239 |
# faire appel au model en local
|
240 |
model = (AutoModelForSeq2SeqLM
|
241 |
-
.from_pretrained("t5_summary")
|
242 |
.to(device))
|
243 |
|
244 |
|
|
|
155 |
# définition de device
|
156 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
157 |
# faire appel au model à entrainer
|
158 |
+
|
159 |
+
hf_token = "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"
|
160 |
+
tokenizer = AutoTokenizer.from_pretrained('google/mt5-small', use_auth_token=hf_token )
|
161 |
|
162 |
mt5_config = AutoConfig.from_pretrained(
|
163 |
"google/mt5-small",
|
|
|
165 |
length_penalty=0.6,
|
166 |
no_repeat_ngram_size=2,
|
167 |
num_beams=15,
|
168 |
+
use_auth_token=hf_token
|
169 |
)
|
170 |
|
171 |
model = (AutoModelForSeq2SeqLM
|
|
|
241 |
|
242 |
# faire appel au model en local
|
243 |
model = (AutoModelForSeq2SeqLM
|
244 |
+
.from_pretrained("t5_summary", use_auth_token=hf_token )
|
245 |
.to(device))
|
246 |
|
247 |
|
src/inference_t5.py
CHANGED
@@ -34,11 +34,11 @@ def inference_t5(text: str) -> str:
|
|
34 |
# On défini les paramètres d'entrée pour le modèle
|
35 |
text = clean_text(text)
|
36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
-
|
38 |
-
tokenizer =
|
39 |
# load local model
|
40 |
model = (AutoModelForSeq2SeqLM
|
41 |
-
.from_pretrained("Linggg/t5_summary",use_auth_token=
|
42 |
.to(device))
|
43 |
|
44 |
|
|
|
34 |
# On défini les paramètres d'entrée pour le modèle
|
35 |
text = clean_text(text)
|
36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
+
hf_token = "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"
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary", use_auth_token=hf_token )
|
39 |
# load local model
|
40 |
model = (AutoModelForSeq2SeqLM
|
41 |
+
.from_pretrained("Linggg/t5_summary", use_auth_token = hf_token )
|
42 |
.to(device))
|
43 |
|
44 |
|
templates/index.html.jinja
CHANGED
@@ -34,7 +34,7 @@
|
|
34 |
<select name="model_choice" class="selectModel" id="model_choice">
|
35 |
<!--A for jinja loop to retrieve option buttons from the api
|
36 |
and to keep them selected when a choice is made. -->
|
37 |
-
{% for x in model_list%}
|
38 |
{%if selected_model == x.model%}
|
39 |
<option value="{{x.model}}" selected>{{x.name}}</option>
|
40 |
{%else%}
|
|
|
34 |
<select name="model_choice" class="selectModel" id="model_choice">
|
35 |
<!--A for jinja loop to retrieve option buttons from the api
|
36 |
and to keep them selected when a choice is made. -->
|
37 |
+
{% for x in model_list %}
|
38 |
{%if selected_model == x.model%}
|
39 |
<option value="{{x.model}}" selected>{{x.name}}</option>
|
40 |
{%else%}
|