EveSa commited on
Commit
cd518e1
1 Parent(s): c062ab4

fix api problem and token auth

Browse files
.cache/init.txt DELETED
@@ -1 +0,0 @@
1
- initial
 
 
Dockerfile CHANGED
@@ -6,6 +6,18 @@ COPY requirements.txt .
6
 
7
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  COPY . .
10
 
11
  CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
 
6
 
7
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
8
 
9
+ RUN useradd -m -u 1000 user
10
+
11
+ USER user
12
+
13
+ ENV HOME=/home/user \
14
+ PATH=/home/user/.local/bin:$PATH
15
+
16
+
17
+ WORKDIR $HOME/app
18
+
19
+ COPY --chown=user . $HOME/app
20
+
21
  COPY . .
22
 
23
  CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "7860"]
model/vocab.pkl ADDED
Binary file (63.4 kB). View file
 
src/api.py CHANGED
@@ -1,82 +1,162 @@
1
  from fastapi import FastAPI, Form, Request
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.templating import Jinja2Templates
 
 
4
 
5
  from src.inference_lstm import inference_lstm
6
  from src.inference_t5 import inference_t5
7
 
8
 
9
- # ------ INFERENCE MODEL --------------------------------------------------------------
10
- # appel de la fonction inference, adaptee pour une entree txt
11
  def summarize(text: str):
12
- if choisir_modele.var == "lstm":
13
- return " ".join(inference_lstm(text))
14
- elif choisir_modele.var == "fineTunedT5":
 
 
 
 
 
 
 
 
 
 
 
 
15
  text = inference_t5(text)
16
-
17
-
18
- # ----------------------------------------------------------------------------------
19
-
20
-
21
- def choisir_modele(choixModele):
22
- print("ON A RECUP LE CHOIX MODELE")
23
- if choixModele == "lstm":
24
- choisir_modele.var = "lstm"
25
- elif choixModele == "fineTunedT5":
26
- choisir_modele.var = "fineTunedT5"
27
- else:
28
- "le modele n'est pas defini"
29
-
30
-
31
- # -------- API ---------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
32
  app = FastAPI()
33
 
34
- # static files pour envoi du css au navigateur
35
  templates = Jinja2Templates(directory="templates")
36
  app.mount("/templates", StaticFiles(directory="templates"), name="templates")
37
 
38
 
39
  @app.get("/")
40
  async def index(request: Request):
41
- return templates.TemplateResponse("index.html.jinja", {"request": request})
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  @app.get("/model")
45
- async def index(request: Request):
46
- return templates.TemplateResponse("index.html.jinja", {"request": request})
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  @app.get("/predict")
50
- async def index(request: Request):
51
- return templates.TemplateResponse("index.html.jinja", {"request": request})
 
 
 
 
52
 
53
 
54
  @app.post("/model")
55
- async def choix_model(request: Request, choixModel: str = Form(None)):
56
- print(choixModel)
57
- if not choixModel:
58
- erreur_modele = "Merci de saisir un modèle."
 
 
 
 
 
59
  return templates.TemplateResponse(
60
- "index.html.jinja", {"request": request, "text": erreur_modele}
 
 
 
 
 
 
61
  )
62
  else:
63
- choisir_modele(choixModel)
64
- print("C'est bon on utilise le modèle demandé")
65
- return templates.TemplateResponse("index.html.jinja", {"request": request})
 
 
 
 
 
 
66
 
67
 
68
- # retourner le texte, les predictions et message d'erreur si formulaire envoye vide
69
  @app.post("/predict")
70
  async def prediction(request: Request, text: str = Form(None)):
 
 
 
71
  if not text:
72
- error = "Merci de saisir votre texte."
73
  return templates.TemplateResponse(
74
- "index.html.jinja", {"request": request, "text": error}
 
 
 
 
 
 
75
  )
76
  else:
77
  summary = summarize(text)
78
  return templates.TemplateResponse(
79
- "index.html.jinja", {"request": request, "text": text, "summary": summary}
 
 
 
 
 
 
 
80
  )
81
 
82
 
 
1
  from fastapi import FastAPI, Form, Request
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.templating import Jinja2Templates
4
+ import re
5
+
6
 
7
  from src.inference_lstm import inference_lstm
8
  from src.inference_t5 import inference_t5
9
 
10
 
 
 
11
  def summarize(text: str):
12
+ """
13
+ Returns the summary of an input text.
14
+ Parameter
15
+ ---------
16
+ text : str
17
+ A text to summarize.
18
+ Returns
19
+ -------
20
+ :str
21
+ The summary of the input text.
22
+ """
23
+ if global_choose_model.var == "lstm":
24
+ text = " ".join(inference_lstm(text))
25
+ return re.sub("^1|1$|<start>|<end>", "", text)
26
+ elif global_choose_model.var == "fineTunedT5":
27
  text = inference_t5(text)
28
+ return re.sub("<extra_id_0> ", "", text)
29
+ elif global_choose_model.var == "":
30
+ return "You have not chosen a model."
31
+
32
+
33
+ def global_choose_model(model_choice):
34
+ """This function allows connecting the choice of the
35
+ model and the summary function by defining global variables.
36
+ The aim is to access a variable outside of a function."""
37
+ if model_choice == "lstm":
38
+ global_choose_model.var = "lstm"
39
+ elif model_choice == "fineTunedT5":
40
+ global_choose_model.var = "fineTunedT5"
41
+ elif model_choice == " --- ":
42
+ global_choose_model.var = ""
43
+
44
+
45
+ # definition of the main elements used in the script
46
+ model_list = [
47
+ {"model": " --- ", "name": " --- "},
48
+ {"model": "lstm", "name": "LSTM"},
49
+ {"model": "fineTunedT5", "name": "Fine-tuned T5"},
50
+ ]
51
+ selected_model = " --- "
52
+ model_choice = ""
53
+
54
+
55
+ # -------- API ---------------------------------------------------------------
56
  app = FastAPI()
57
 
58
+ # static files to send the css
59
  templates = Jinja2Templates(directory="templates")
60
  app.mount("/templates", StaticFiles(directory="templates"), name="templates")
61
 
62
 
63
  @app.get("/")
64
  async def index(request: Request):
65
+ """This function is used to create an endpoint for the
66
+ index page of the app."""
67
+ return templates.TemplateResponse(
68
+ "index.html.jinja",
69
+ {
70
+ "request": request,
71
+ "current_route": "/",
72
+ "model_list": model_list,
73
+ "selected_model": selected_model,
74
+ },
75
+ )
76
 
77
 
78
  @app.get("/model")
79
+ async def get_model(request: Request):
80
+ """This function is used to create an endpoint for
81
+ the model page of the app."""
82
+ return templates.TemplateResponse(
83
+ "index.html.jinja",
84
+ {
85
+ "request": request,
86
+ "current_route": "/model",
87
+ "model_list": model_list,
88
+ "selected_model": selected_model,
89
+ },
90
+ )
91
 
92
 
93
  @app.get("/predict")
94
+ async def get_prediction(request: Request):
95
+ """This function is used to create an endpoint for
96
+ the predict page of the app."""
97
+ return templates.TemplateResponse(
98
+ "index.html.jinja", {"request": request, "current_route": "/predict"}
99
+ )
100
 
101
 
102
  @app.post("/model")
103
+ async def choose_model(request: Request, model_choice: str = Form(None)):
104
+ """This function allows retrieving the model chosen by the user. Then, it
105
+ can end in an error message if it is not defined, or it is sent to the
106
+ global_choose_model function which connects the user choice to the
107
+ use of a model."""
108
+ selected_model = model_choice
109
+ # print(selected_model)
110
+ if not model_choice:
111
+ model_error = "Please select a model."
112
  return templates.TemplateResponse(
113
+ "index.html.jinja",
114
+ {
115
+ "request": request,
116
+ "text": model_error,
117
+ "model_list": model_list,
118
+ "selected_model": selected_model,
119
+ },
120
  )
121
  else:
122
+ global_choose_model(model_choice)
123
+ return templates.TemplateResponse(
124
+ "index.html.jinja",
125
+ {
126
+ "request": request,
127
+ "model_list": model_list,
128
+ "selected_model": selected_model,
129
+ },
130
+ )
131
 
132
 
 
133
  @app.post("/predict")
134
  async def prediction(request: Request, text: str = Form(None)):
135
+ """This function allows to retrieve the input text of the user.
136
+ Then, it can end in an error message, or it can be sent to
137
+ the summarize function."""
138
  if not text:
139
+ text_error = "Please enter your text."
140
  return templates.TemplateResponse(
141
+ "index.html.jinja",
142
+ {
143
+ "request": request,
144
+ "text": text_error,
145
+ "model_list": model_list,
146
+ "selected_model": selected_model,
147
+ },
148
  )
149
  else:
150
  summary = summarize(text)
151
  return templates.TemplateResponse(
152
+ "index.html.jinja",
153
+ {
154
+ "request": request,
155
+ "text": text,
156
+ "summary": summary,
157
+ "model_list": model_list,
158
+ "selected_model": selected_model,
159
+ },
160
  )
161
 
162
 
src/fine_tune_T5.py CHANGED
@@ -155,7 +155,9 @@ if __name__ == '__main__':
155
  # définition de device
156
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
157
  # faire appel au model à entrainer
158
- tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
 
 
159
 
160
  mt5_config = AutoConfig.from_pretrained(
161
  "google/mt5-small",
@@ -163,6 +165,7 @@ if __name__ == '__main__':
163
  length_penalty=0.6,
164
  no_repeat_ngram_size=2,
165
  num_beams=15,
 
166
  )
167
 
168
  model = (AutoModelForSeq2SeqLM
@@ -238,7 +241,7 @@ if __name__ == '__main__':
238
 
239
  # faire appel au model en local
240
  model = (AutoModelForSeq2SeqLM
241
- .from_pretrained("t5_summary")
242
  .to(device))
243
 
244
 
 
155
  # définition de device
156
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
157
  # faire appel au model à entrainer
158
+
159
+ hf_token = "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"
160
+ tokenizer = AutoTokenizer.from_pretrained('google/mt5-small', use_auth_token=hf_token )
161
 
162
  mt5_config = AutoConfig.from_pretrained(
163
  "google/mt5-small",
 
165
  length_penalty=0.6,
166
  no_repeat_ngram_size=2,
167
  num_beams=15,
168
+ use_auth_token=hf_token
169
  )
170
 
171
  model = (AutoModelForSeq2SeqLM
 
241
 
242
  # faire appel au model en local
243
  model = (AutoModelForSeq2SeqLM
244
+ .from_pretrained("t5_summary", use_auth_token=hf_token )
245
  .to(device))
246
 
247
 
src/inference_t5.py CHANGED
@@ -34,11 +34,11 @@ def inference_t5(text: str) -> str:
34
  # On défini les paramètres d'entrée pour le modèle
35
  text = clean_text(text)
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
-
38
- tokenizer = (AutoTokenizer.from_pretrained("Linggg/t5_summary",use_auth_token=True))
39
  # load local model
40
  model = (AutoModelForSeq2SeqLM
41
- .from_pretrained("Linggg/t5_summary",use_auth_token=True)
42
  .to(device))
43
 
44
 
 
34
  # On défini les paramètres d'entrée pour le modèle
35
  text = clean_text(text)
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ hf_token = "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"
38
+ tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary", use_auth_token=hf_token )
39
  # load local model
40
  model = (AutoModelForSeq2SeqLM
41
+ .from_pretrained("Linggg/t5_summary", use_auth_token = hf_token )
42
  .to(device))
43
 
44
 
templates/index.html.jinja CHANGED
@@ -34,7 +34,7 @@
34
  <select name="model_choice" class="selectModel" id="model_choice">
35
  <!--A for jinja loop to retrieve option buttons from the api
36
  and to keep them selected when a choice is made. -->
37
- {% for x in model_list%}
38
  {%if selected_model == x.model%}
39
  <option value="{{x.model}}" selected>{{x.name}}</option>
40
  {%else%}
 
34
  <select name="model_choice" class="selectModel" id="model_choice">
35
  <!--A for jinja loop to retrieve option buttons from the api
36
  and to keep them selected when a choice is made. -->
37
+ {% for x in model_list %}
38
  {%if selected_model == x.model%}
39
  <option value="{{x.model}}" selected>{{x.name}}</option>
40
  {%else%}