Merge branch 'main' into Estelle
- .cache/init.txt +1 -0
- Documentation.md +104 -17
- api.py +0 -169
- requirements.txt +10 -6
- src/api.py +88 -0
- src/dataloader.py +16 -9
- src/fine_tune_T5.py +257 -0
- src/fine_tune_t5.py +0 -204
- src/{inference.py → inference_lstm.py} +4 -1
- src/inference_t5.py +35 -26
- src/model.py +32 -21
- src/train.py +18 -11
- templates/index.html.jinja +2 -1
.cache/init.txt
ADDED
@@ -0,0 +1 @@
+initial
Documentation.md
CHANGED
@@ -4,47 +4,134 @@ The aim of the project is to set up a <strong>query platform</

 # A description of the system or of the data that the interface gives access to

+## The Data 💾

 For training the language model, the project will use the corpus from 'Newsroom: A Dataset of 1.3 Million Summaries with Diverse Extractive Strategies' (Grusky et al., NAACL 2018), assembled by Max Grusky and his colleagues in 2018. Newsroom is a parallel corpus of 1.3 million press articles and their English summaries. The summaries were produced with extractive as well as abstractive methods, and with mixed methods. The corpus is available on HuggingFace but requires a prior download for data-protection reasons.

+The corpus is cleaned before being used to train the LSTM. Only the 'text' and 'summary' fields of the jsonl file are used.
+
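As an illustration of the cleaning step described above, here is a minimal, hedged sketch of pulling only the 'text' and 'summary' fields out of a Newsroom jsonl extract with pandas; the file path is borrowed from src/train.py, and the exact loading code is an assumption rather than the project's own.

```python
import pandas as pd

# Load the Newsroom extract (jsonl, one article per line) and keep only the
# two fields used for training. The path is an assumption taken from
# src/train.py; adjust it to your local copy of the corpus.
data = pd.read_json("data/train_extract.jsonl", lines=True)
corpus = data[["text", "summary"]].dropna()

print(corpus.head())
```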
+## The System 🖥️
+
+2 systems:
+- An LSTM built from the <a href="https://loicgrobol.github.io//neural-networks/slides/03-transformers/transformers-slides.py.ipynb">course</a>, from this <a href="https://www.kaggle.com/code/columbine/seq2seq-pytorch">example</a>, and from many other online references.
+- Fine-tuned transformers

 # Methodology

-## Work split
-We have
+## Work split 👥
+We worked with the GitHub version-control platform, with continuous integration set up so that `pull requests` are pushed directly to the Huggingface Space.
+
+Ideally, `pull requests` should be approved by two project members before being merged, to avoid errors in production. We did not enable these restrictions because of the difficulty of managing Docker on Huggingface, which required many modifications.

 ## Problems encountered and solutions
+
+### Corpus problems 📚
+
+- [x] Mojibake problem in the jsonl files:
+    - [x] encode to cp1252 and decode as utf-8 with `ignore`, to avoid errors on the utf-8 characters already present in the file (see the sketch after this list)
+    - ❔ Strangely, the problem does not show up on every machine.
+- [x] Pronouns and verbs glued together
+    - First, replace `'` with spaces before the `split`
+    - Use of a correspondence dictionary
+- [ ] Splitting of compound proper nouns ('Ivory Coast', 'Inter Milan'):
+    - [ ] no resolution to date
+- [ ] Words missing from the vocabulary
+    - In the long run, train on the whole corpus?
+- [ ] Corpus quality problems:
+    - Truncated summaries: "Did Tatum O'Neal's latest battle with ex-husband John McEnroe put her back on drugs? The \"Paper Moon\" star checked herself into L.A.'s Promises rehab facility after a friend caught her smoking crack, according to The National Enquirer. O'Neal emerged clean and sober from Promises' 34-day recovery program in late July, the tab reports. The actress is said to have plunged into her old habits because of" ...
+    - Summaries that read more like titles than summaries: "SAN DIEGO PADRES team notebook"
+    - [ ] no resolution to date
+
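The cp1252/utf-8 round-trip mentioned in the checklist above can be illustrated with a small sketch; this is a hedged reconstruction of the described fix, not necessarily the exact code used in the project.

```python
def fix_mojibake(text: str) -> str:
    # Re-encode the mis-decoded string as cp1252, then decode it back as utf-8.
    # errors="ignore" silently drops characters that were already valid utf-8,
    # which is the workaround described in the list above.
    return text.encode("cp1252", errors="ignore").decode("utf-8", errors="ignore")


print(fix_mojibake("rÃ©sumÃ©"))  # -> "résumé"
```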
+### Neural-network problems 🕸️
+
+- [x] Taking padding into account during training:
+    - [ ] use the ignore_index feature of NLLLoss, with padding set to the value -100 (see the sketch after this list)
+- [ ] Very long training time:
+    - [ ] attempt to set up batch training
+- [ ] Repetition of determiners after training the model - https://huggingface.co/blog/how-to-generate
+    - [x] implementation of a Beam Search - unsuccessful
+    - [ ] switch to Sampling
+
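A minimal sketch of the padding strategy mentioned above: padded target positions are set to -100 and NLLLoss is told to ignore them. The tensor shapes and values are illustrative assumptions, not the project's actual training loop.

```python
import torch

# Hypothetical batch: 3 target sequences of length 5, padded with -100.
targets = torch.tensor([
    [2, 5, 7, -100, -100],
    [4, 1, -100, -100, -100],
    [3, 6, 8, 9, 2],
])

# Fake log-probabilities over a 10-word vocabulary for each position.
log_probs = torch.log_softmax(torch.randn(3, 10, 5), dim=1)

# ignore_index=-100 makes the loss skip the padded positions entirely.
criterion = torch.nn.NLLLoss(ignore_index=-100)
print(criterion(log_probs, targets))
```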
+### Fine-tuning problems
+
+### Interface problems
+
+### Continuous-integration problems
+
+- [x] No link possible between Huggingface and a GitHub repository whose history contains files larger than 10 MB
+    - 💣 The GitHub repository blew up
+- [ ] Docker that works locally but not on Huggingface
+    - File-path problems


 ## Project steps

+1. Initialising the GitHub repository
+2. First steps with the neural network
+3. Building the platform
+4. Integrating with Huggingface
+5. Fine-tuning a model
+6. Finalisation
+
 # Implementation
 ## Modelling

-We first decided to model an LSTM for the
+We first decided to model an LSTM for automatic summarisation, based on the neural network built in class.
 To do so we drew heavily on the Kaggle notebook https://www.kaggle.com/code/columbine/seq2seq-pytorch as well as on the PyTorch documentation https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html#example-an-lstm-for-part-of-speech-tagging
+
 ## Modules and APIs used
+### Dataloader:
+- Data
+```
+A class used to get data from file
+...
+
+Attributes
+----------
+path : str
+    the path to the file containing the data
+
+Methods
+-------
+open()
+    open the jsonl file with pandas
+clean_data(text_type)
+    clean the data got by opening the file and adds <start> and
+    <end> tokens depending on the text_type
+get_words()
+    get the dataset vocabulary
+```
+- Vectoriser
+```
+```
+
+### Model:
+
+### train:
+
+### inference:
+
+### api:
+
+### templates:
+
 ## Programming languages
+- 🐳 Docker
+- yaml
+- 🐍 and Python, of course

 # Results (output files, visualisations…)

 ## Evaluation metrics
 - ROUGE
 - BLEU
-- QAEval
-- Meteor
-- BERTScore
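ROUGE is the metric actually computed by calculate_metric in src/fine_tune_T5.py later in this diff; here is a small, hedged sketch of scoring a single prediction with the evaluate library, using made-up example strings.

```python
import evaluate

# Illustrative only: score one made-up prediction against one reference.
rouge = evaluate.load("rouge")
scores = rouge.compute(
    predictions=["the cat sat on the mat"],
    references=["a cat was sitting on the mat"],
)
print(scores)  # rouge1 / rouge2 / rougeL / rougeLsum
```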

 # Discussion of the results
+
+## LSTM results
+
+The LSTM results are unusable, but they at least let us confront the difficulty of building neural networks from almost nothing.
+We would have liked to have more time to go further and understand things better: batch training, why the results are so bad, putting other generation strategies in place, ...
+
+## Fine-tuning results
api.py
DELETED
@@ -1,169 +0,0 @@
import re
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from src.inference import inferenceAPI
from src.inference_t5 import inferenceAPI_t5


def summarize(text: str):
    """
    Returns the summary of an input text.

    Parameter
    ---------
    text : str
        A text to summarize.

    Returns
    -------
    :str
        The summary of the input text.
    """
    if global_choose_model.var == "lstm":
        text = " ".join(inferenceAPI(text))
        return re.sub("^1|1$|<start>|<end>", "", text)
    elif global_choose_model.var == "fineTunedT5":
        text = inferenceAPI_t5(text)
        return re.sub("<extra_id_0> ", "", text)
    elif global_choose_model.var == "":
        return "You have not chosen a model."


def global_choose_model(model_choice):
    """This function allows to connect the choice of the
    model and the summary function by defining global variables.
    The aime is to access a variable outside of a function."""
    if model_choice == "lstm":
        global_choose_model.var = "lstm"
    elif model_choice == "fineTunedT5":
        global_choose_model.var = "fineTunedT5"
    elif model_choice == " --- ":
        global_choose_model.var = ""


# definition of the main elements used in the script
model_list = [
    {"model": " --- ", "name": " --- "},
    {"model": "lstm", "name": "LSTM"},
    {"model": "fineTunedT5", "name": "Fine-tuned T5"},
]
selected_model = " --- "
model_choice = ""


# -------- API ---------------------------------------------------------------
app = FastAPI()

# static files to send the css
templates = Jinja2Templates(directory="templates")
app.mount("/templates", StaticFiles(directory="templates"), name="templates")


@app.get("/")
async def index(request: Request):
    """This function is used to create an endpoint for the
    index page of the app."""
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )


@app.get("/model")
async def get_model(request: Request):
    """This function is used to create an endpoint for
    the model page of the app."""
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "current_route": "/model",
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )


@app.get("/predict")
async def get_prediction(request: Request):
    """This function is used to create an endpoint for
    the predict page of the app."""
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request, "current_route": "/predict"}
    )


@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """This functions allows to retrieve the model chosen by the user. Then, it
    can end to an error message if it not defined or it is sent to the
    global_choose_model function which connects the user choice to the
    use of a model."""
    selected_model = model_choice
    # print(selected_model)
    if not model_choice:
        model_error = "Please select a model."
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": model_error,
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )
    else:
        global_choose_model(model_choice)
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )


@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """This function allows to retrieve the input text of the user.
    Then, it can end to an error message or it can be sent to
    the summarize function."""
    if not text:
        text_error = "Please enter your text."
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": text_error,
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )
    else:
        summary = summarize(text)
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": text,
                "summary": summary,
                "model_list": model_list,
                "selected_model": selected_model,
            },
        )


# ------------------------------------------------------------------------------------


# launch the server and reload it each time a change is saved
if __name__ == "__main__":
    uvicorn.run("api:app", port=8000, reload=True)
requirements.txt
CHANGED
@@ -1,14 +1,15 @@
+anyascii==0.3.1
 anyio==3.6.2
 certifi==2022.12.7
 charset-normalizer==3.1.0
 click==8.1.3
+contractions==0.1.73
+fastapi==0.94.0
 filelock==3.9.0
 h11==0.14.0
-huggingface-hub==0.13.
+huggingface-hub==0.13.2
 idna==3.4
 Jinja2==3.1.2
-joblib==1.2.0
 MarkupSafe==2.1.2
 numpy==1.24.2
 nvidia-cublas-cu11==11.10.3.66
@@ -17,7 +18,8 @@ nvidia-cuda-runtime-cu11==11.7.99
 nvidia-cudnn-cu11==8.5.0.96
 packaging==23.0
 pandas==1.5.3
+pyahocorasick==2.0.0
+pydantic==1.10.6
 python-dateutil==2.8.2
 python-multipart==0.0.6
 pytz==2022.7.1
@@ -26,10 +28,12 @@ regex==2022.10.31
 requests==2.28.2
 six==1.16.0
 sniffio==1.3.0
-starlette==0.
+starlette==0.26.1
+textsearch==0.0.24
 tokenizers==0.13.2
 torch==1.13.1
 tqdm==4.65.0
+transformers==4.26.1
 typing_extensions==4.5.0
 urllib3==1.26.15
-uvicorn==0.
+uvicorn==0.21.0
src/api.py
ADDED
@@ -0,0 +1,88 @@
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5


# ------ INFERENCE MODEL --------------------------------------------------------------
# appel de la fonction inference, adaptee pour une entree txt
def summarize(text: str):
    if choisir_modele.var == "lstm":
        return " ".join(inference_lstm(text))
    elif choisir_modele.var == "fineTunedT5":
        text = inference_t5(text)


# ----------------------------------------------------------------------------------


def choisir_modele(choixModele):
    print("ON A RECUP LE CHOIX MODELE")
    if choixModele == "lstm":
        choisir_modele.var = "lstm"
    elif choixModele == "fineTunedT5":
        choisir_modele.var = "fineTunedT5"
    else:
        "le modele n'est pas defini"


# -------- API ---------------------------------------------------------------------
app = FastAPI()

# static files pour envoi du css au navigateur
templates = Jinja2Templates(directory="templates")
app.mount("/templates", StaticFiles(directory="templates"), name="templates")


@app.get("/")
async def index(request: Request):
    return templates.TemplateResponse("index.html.jinja", {"request": request})


@app.get("/model")
async def index(request: Request):
    return templates.TemplateResponse("index.html.jinja", {"request": request})


@app.get("/predict")
async def index(request: Request):
    return templates.TemplateResponse("index.html.jinja", {"request": request})


@app.post("/model")
async def choix_model(request: Request, choixModel: str = Form(None)):
    print(choixModel)
    if not choixModel:
        erreur_modele = "Merci de saisir un modèle."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )
    else:
        choisir_modele(choixModel)
        print("C'est bon on utilise le modèle demandé")
        return templates.TemplateResponse("index.html.jinja", {"request": request})


# retourner le texte, les predictions et message d'erreur si formulaire envoye vide
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    if not text:
        error = "Merci de saisir votre texte."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": error}
        )
    else:
        summary = summarize(text)
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": text, "summary": summary}
        )


# ------------------------------------------------------------------------------------


# lancer le serveur et le recharger a chaque modification sauvegardee
# if __name__ == "__main__":
#     uvicorn.run("api:app", port=8000, reload=True)
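A hedged sketch of exercising the /predict endpoint defined above from a separate client; it assumes the app has been started locally with uvicorn on port 8000, as the commented-out launcher at the bottom of the file suggests.

```python
import requests

# Start the server first, e.g.:  uvicorn src.api:app --port 8000
# /model selects the summariser (form field "choixModel"), /predict submits the text.
requests.post("http://localhost:8000/model", data={"choixModel": "lstm"})
response = requests.post(
    "http://localhost:8000/predict",
    data={"text": "Some long news article to summarize..."},
)
print(response.status_code)  # the returned HTML page embeds the summary
```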
src/dataloader.py
CHANGED
@@ -38,8 +38,6 @@ class Data(torch.utils.data.Dataset):
         <end> tokens depending on the text_type
     get_words()
         get the dataset vocabulary
-    make_dataset()
-        create a dataset with cleaned data
     """

     def __init__(self, path: str, transform=None) -> None:
@@ -52,10 +50,15 @@ class Data(torch.utils.data.Dataset):

     def __getitem__(self, idx):
         row = self.data.iloc[idx]
-        text = row["text"].translate(
+        text = row["text"].translate(
+            str.maketrans(
+                "", "", string.punctuation)).split()
         summary = (
-            row["summary"].translate(
+            row["summary"].translate(
+                str.maketrans(
+                    "",
+                    "",
+                    string.punctuation)).split())
         summary = ["<start>", *summary, "<end>"]
         sample = {"text": text, "summary": summary}

@@ -106,7 +109,8 @@ class Data(torch.utils.data.Dataset):
             tokenized_texts.append(text)

         if text_type == "summary":
-            return [["<start>", *summary, "<end>"]
+            return [["<start>", *summary, "<end>"]
+                    for summary in tokenized_texts]
         return tokenized_texts

     def get_words(self) -> list:
@@ -157,8 +161,10 @@ class Vectoriser:

     def __init__(self, vocab=None) -> None:
         self.vocab = vocab
-        self.word_count = Counter(word.lower().strip(",.\\-")
+        self.word_count = Counter(word.lower().strip(",.\\-")
+                                  for word in self.vocab)
+        self.idx_to_token = sorted(
+            [t for t, c in self.word_count.items() if c > 1])
         self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

     def load(self, path):
@@ -167,7 +173,8 @@ class Vectoriser:
         self.word_count = Counter(
             word.lower().strip(",.\\-") for word in self.vocab
         )
-        self.idx_to_token = sorted(
+        self.idx_to_token = sorted(
+            [t for t, c in self.word_count.items() if c > 1])
         self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

     def save(self, path):
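Based on the calls visible in src/train.py later in this diff, a hedged sketch of how the Data and Vectoriser classes above are meant to be used together; the exact return shape of indexing a Data object is an assumption.

```python
from src import dataloader

# Build the vocabulary from the raw training data first...
train_dataset = dataloader.Data("data/train_extract.jsonl")
words = train_dataset.get_words()
vectoriser = dataloader.Vectoriser(words)

# ...then re-open the data with the vectoriser attached as a transform,
# mirroring what src/train.py does before building the DataLoaders.
train_dataset = dataloader.Data("data/train_extract.jsonl", transform=vectoriser)
sample = train_dataset[0]
print(sample)
```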
src/fine_tune_T5.py
ADDED
@@ -0,0 +1,257 @@
import os
import re
import string

import contractions
import datasets
import evaluate
import pandas as pd
import torch
from datasets import Dataset
from tqdm import tqdm
from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)


def clean_text(texts):
    """This fonction makes clean text for the future use"""
    texts = texts.lower()
    texts = contractions.fix(texts)
    texts = texts.translate(str.maketrans("", "", string.punctuation))
    texts = re.sub(r"\n", " ", texts)
    return texts


def datasetmaker(path=str):
    """This fonction take the jsonl file, read it to a dataframe,
    remove the colums not needed for the task and turn it into a file type Dataset
    """
    data = pd.read_json(path, lines=True)
    df = data.drop(
        [
            "url",
            "archive",
            "title",
            "date",
            "compression",
            "coverage",
            "density",
            "compression_bin",
            "coverage_bin",
            "density_bin",
        ],
        axis=1,
    )
    tqdm.pandas()
    df["text"] = df.text.apply(lambda texts: clean_text(texts))
    df["summary"] = df.summary.apply(lambda summary: clean_text(summary))
    dataset = Dataset.from_dict(df)
    return dataset


# voir si le model par hasard esr déjà bien

# test_text = dataset['text'][0]
# pipe = pipeline('summarization', model = model_ckpt)
# pipe_out = pipe(test_text)
# print(pipe_out[0]['summary_text'].replace('.<n>', '.\n'))
# print(dataset['summary'][0])


def generate_batch_sized_chunks(list_elements, batch_size):
    """this fonction split the dataset into smaller batches
    that we can process simultaneously
    Yield successive batch-sized chunks from list_of_elements."""
    for i in range(0, len(list_elements), batch_size):
        yield list_elements[i: i + batch_size]


def calculate_metric(dataset, metric, model, tokenizer,
                     batch_size, device,
                     column_text='text',
                     column_summary='summary'):
    """this fonction evaluate the model with metric rouge and
    print a table of rouge scores rouge1', 'rouge2', 'rougeL', 'rougeLsum'"""

    article_batches = list(
        str(generate_batch_sized_chunks(dataset[column_text], batch_size))
    )
    target_batches = list(
        str(generate_batch_sized_chunks(dataset[column_summary], batch_size))
    )

    for article_batch, target_batch in tqdm(
        zip(article_batches, target_batches), total=len(article_batches)
    ):
        inputs = tokenizer(
            article_batch,
            max_length=1024,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # parameter for length penalty ensures that the model does not
        # generate sequences that are too long.
        summaries = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            length_penalty=0.8,
            num_beams=8,
            max_length=128,
        )

        # Décode les textes
        # remplacer les tokens, ajouter des textes décodés avec les références
        # vers la métrique.
        decoded_summaries = [
            tokenizer.decode(
                s, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for s in summaries
        ]

        decoded_summaries = [d.replace("", " ") for d in decoded_summaries]

        metric.add_batch(
            predictions=decoded_summaries,
            references=target_batch)

    # compute et return les ROUGE scores.
    results = metric.compute()
    rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
    rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
    return pd.DataFrame(rouge_dict, index=["T5"])


def convert_ex_to_features(example_batch):
    """this fonction takes for input a list of inputExemples and convert to InputFeatures"""
    input_encodings = tokenizer(example_batch['text'],
                                max_length=1024, truncation=True)

    labels = tokenizer(
        example_batch["summary"],
        max_length=128,
        truncation=True)

    return {
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
        "labels": labels["input_ids"],
    }


if __name__ == '__main__':
    # réalisation des datasets propres
    train_dataset = datasetmaker('data/train_extract.jsonl')

    dev_dataset = datasetmaker('data/dev_extract.jsonl')

    test_dataset = datasetmaker('data/test_extract.jsonl')

    dataset = datasets.DatasetDict({'train': train_dataset,
                                    'dev': dev_dataset, 'test': test_dataset})
    # définition de device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # faire appel au model à entrainer
    tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')

    mt5_config = AutoConfig.from_pretrained(
        "google/mt5-small",
        max_length=128,
        length_penalty=0.6,
        no_repeat_ngram_size=2,
        num_beams=15,
    )

    model = (AutoModelForSeq2SeqLM
             .from_pretrained('google/mt5-small', config=mt5_config)
             .to(device))
    # convertir les exemples en inputFeatures

    dataset_pt = dataset.map(
        convert_ex_to_features,
        remove_columns=["summary", "text"],
        batched=True,
        batch_size=128,
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer, model=model, return_tensors="pt")
    # définir les paramètres d'entrainement (fine tuning)
    training_args = Seq2SeqTrainingArguments(
        output_dir="t5_summary",
        log_level="error",
        num_train_epochs=10,
        learning_rate=5e-4,
        warmup_steps=0,
        optim="adafactor",
        weight_decay=0.01,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        evaluation_strategy="steps",
        eval_steps=100,
        predict_with_generate=True,
        generation_max_length=128,
        save_steps=500,
        logging_steps=10,
        # push_to_hub = True
    )
    # donner au entraineur (trainer) le model
    # et les éléments nécessaires pour l'entrainement
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        # compute_metrics = calculate_metric,
        train_dataset=dataset_pt["train"],
        eval_dataset=dataset_pt["dev"].select(range(10)),
        tokenizer=tokenizer,
    )

    trainer.train()
    rouge_metric = evaluate.load("rouge")
    # évaluer ensuite le model selon les résultats d'entrainement
    score = calculate_metric(
        test_dataset,
        rouge_metric,
        trainer.model,
        tokenizer,
        batch_size=2,
        device=device,
        column_text="text",
        column_summary="summary",
    )
    print(score)

    # Fine Tuning terminé et à sauvegarder

    # sauvegarder le fine-tuned model en local
    os.makedirs("t5_summary", exist_ok=True)
    if hasattr(trainer.model, "module"):
        trainer.model.module.save_pretrained("t5_summary")
    else:
        trainer.model.save_pretrained("t5_summary")
    tokenizer.save_pretrained("t5_summary")

    # faire appel au model en local
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("t5_summary")
             .to(device))


    # mettre en usage : TEST

    # gen_kwargs = {"length_penalty" : 0.8, "num_beams" : 8, "max_length" : 128}
    # sample_text = dataset["test"][0]["text"]
    # reference = dataset["test"][0]["summary"]
    # pipe = pipeline("summarization", model='./summarization_t5')

    # print("Text :")
    # print(sample_text)
    # print("\nReference Summary :")
    # print(reference)
    # print("\nModel Summary :")
    # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
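A hedged sketch of querying the checkpoint saved by the script above with the transformers pipeline, along the lines of the commented-out test at the end of the file; the directory name follows output_dir, and the generation parameters are copied from gen_kwargs there.

```python
from transformers import pipeline

# Assumes fine_tune_T5.py has been run and the model saved to ./t5_summary.
summarizer = pipeline("summarization", model="t5_summary", tokenizer="t5_summary")
gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}

sample_text = "some cleaned newsroom article text ..."
print(summarizer(sample_text, **gen_kwargs)[0]["summary_text"])
```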
src/fine_tune_t5.py
DELETED
@@ -1,204 +0,0 @@
import torch
import datasets
from datasets import Dataset, DatasetDict
import pandas as pd
from tqdm import tqdm
import re
import os
import nltk
import string
import contractions
from transformers import pipeline
import evaluate
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq

# cuda out of memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:200"

nltk.download('stopwords')
nltk.download('punkt')


def clean_data(texts):
    texts = texts.lower()
    texts = contractions.fix(texts)
    texts = texts.translate(str.maketrans("", "", string.punctuation))
    texts = re.sub(r'\n', ' ', texts)
    return texts

def datasetmaker(path=str):
    data = pd.read_json(path, lines=True)
    df = data.drop(['url', 'archive', 'title', 'date', 'compression', 'coverage', 'density', 'compression_bin', 'coverage_bin', 'density_bin'], axis=1)
    tqdm.pandas()
    df['text'] = df.text.apply(lambda texts: clean_data(texts))
    df['summary'] = df.summary.apply(lambda summary: clean_data(summary))
    # df['text'] = df['text'].map(str)
    # df['summary'] = df['summary'].map(str)
    dataset = Dataset.from_dict(df)
    return dataset

# voir si le model par hasard esr déjà bien

# test_text = dataset['text'][0]
# pipe = pipeline('summarization', model = model_ckpt)
# pipe_out = pipe(test_text)
# print(pipe_out[0]['summary_text'].replace('.<n>', '.\n'))
# print(dataset['summary'][0])

def generate_batch_sized_chunks(list_elements, batch_size):
    """split the dataset into smaller batches that we can process simultaneously
    Yield successive batch-sized chunks from list_of_elements."""
    for i in range(0, len(list_elements), batch_size):
        yield list_elements[i : i + batch_size]

def calculate_metric(dataset, metric, model, tokenizer,
                     batch_size, device,
                     column_text='text',
                     column_summary='summary'):
    article_batches = list(str(generate_batch_sized_chunks(dataset[column_text], batch_size)))
    target_batches = list(str(generate_batch_sized_chunks(dataset[column_summary], batch_size)))

    for article_batch, target_batch in tqdm(
        zip(article_batches, target_batches), total=len(article_batches)):

        inputs = tokenizer(article_batch, max_length=1024, truncation=True,
                           padding="max_length", return_tensors="pt")

        summaries = model.generate(input_ids=inputs["input_ids"].to(device),
                                   attention_mask=inputs["attention_mask"].to(device),
                                   length_penalty=0.8, num_beams=8, max_length=128)
        ''' parameter for length penalty ensures that the model does not generate sequences that are too long. '''

        # Décode les textes
        # remplacer les tokens, ajouter des textes décodés avec les références vers la métrique.
        decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
                                              clean_up_tokenization_spaces=True)
                             for s in summaries]

        decoded_summaries = [d.replace("", " ") for d in decoded_summaries]

        metric.add_batch(predictions=decoded_summaries, references=target_batch)

    # compute et return les ROUGE scores.
    results = metric.compute()
    rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
    rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
    return pd.DataFrame(rouge_dict, index=['T5'])


def convert_ex_to_features(example_batch):
    input_encodings = tokenizer(example_batch['text'], max_length=1024, truncation=True)

    labels = tokenizer(example_batch['summary'], max_length=128, truncation=True)

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': labels['input_ids']
    }

if __name__ == '__main__':

    train_dataset = datasetmaker('data/train_extract_100.jsonl')

    dev_dataset = datasetmaker('data/dev_extract_100.jsonl')

    test_dataset = datasetmaker('data/test_extract_100.jsonl')

    dataset = datasets.DatasetDict({'train': train_dataset, 'dev': dev_dataset, 'test': test_dataset})

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    mt5_config = AutoConfig.from_pretrained(
        "google/mt5-small",
        max_length=128,
        length_penalty=0.6,
        no_repeat_ngram_size=2,
        num_beams=15,
    )
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("google/mt5-small", config=mt5_config)
             .to(device))

    dataset_pt = dataset.map(convert_ex_to_features, remove_columns=["summary", "text"], batched=True, batch_size=128)

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")

    training_args = Seq2SeqTrainingArguments(
        output_dir="mt5_sum",
        log_level="error",
        num_train_epochs=10,
        learning_rate=5e-4,
        # lr_scheduler_type = "linear",
        warmup_steps=0,
        optim="adafactor",
        weight_decay=0.01,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        evaluation_strategy="steps",
        eval_steps=100,
        predict_with_generate=True,
        generation_max_length=128,
        save_steps=500,
        logging_steps=10,
        # push_to_hub = True
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        # compute_metrics = calculate_metric,
        train_dataset=dataset_pt['train'],
        eval_dataset=dataset_pt['dev'].select(range(10)),
        tokenizer=tokenizer,
    )

    trainer.train()
    rouge_metric = evaluate.load("rouge")

    score = calculate_metric(test_dataset, rouge_metric, trainer.model, tokenizer,
                             batch_size=2, device=device,
                             column_text='text',
                             column_summary='summary')
    print(score)

    # Fine Tuning terminé et à sauvegarder

    # save fine-tuned model in local
    os.makedirs("./summarization_t5", exist_ok=True)
    if hasattr(trainer.model, "module"):
        trainer.model.module.save_pretrained("./summarization_t5")
    else:
        trainer.model.save_pretrained("./summarization_t5")
    tokenizer.save_pretrained("./summarization_t5")
    # load local model
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("./summarization_t5")
             .to(device))


    # mettre en usage : TEST

    # gen_kwargs = {"length_penalty": 0.8, "num_beams":8, "max_length": 128}
    # sample_text = dataset["test"][0]["text"]
    # reference = dataset["test"][0]["summary"]
    # pipe = pipeline("summarization", model='./summarization_t5')

    # print("Text:")
    # print(sample_text)
    # print("\nReference Summary:")
    # print(reference)
    # print("\nModel Summary:")
    # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
src/{inference.py → inference_lstm.py}
RENAMED
@@ -1,5 +1,6 @@
 """
 Allows to predict the summary for a given entry text
+using LSTM model
 """
 import pickle

@@ -7,13 +8,14 @@ import torch

 from src import dataloader
 from src.model import Decoder, Encoder, EncoderDecoderModel
+# from transformers import AutoModel

 with open("model/vocab.pkl", "rb") as vocab:
     words = pickle.load(vocab)
 vectoriser = dataloader.Vectoriser(words)


-def inferenceAPI(text: str) -> str:
+def inference_lstm(text: str) -> str:
     """
     Predict the summary for an input text
     --------
@@ -34,6 +36,7 @@ def inferenceAPI(text: str) -> str:

     # On instancie le modèle
     model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
+    # model = AutoModel.from_pretrained("EveSa/SummaryProject-LSTM")

     # model.load_state_dict(torch.load("model/model.pt", map_location=device))
     # model.eval()
src/inference_t5.py
CHANGED
@@ -1,23 +1,25 @@
 """
 Allows to predict the summary for a given entry text
 """
-import torch
-import nltk
-import contractions
 import re
 import string
+import os
+os.environ['TRANSFORMERS_CACHE'] = './.cache'
+
+import contractions
+import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


+def clean_text(texts: str) -> str:
     texts = texts.lower()
-    texts = contractions.fix(texts)
     texts = texts.translate(str.maketrans("", "", string.punctuation))
-    texts = re.sub(r
+    texts = re.sub(r"\n", " ", texts)
     return texts


-def inferenceAPI_t5(text: str) -> str:
+def inference_t5(text: str) -> str:
     """
     Predict the summary for an input text
     --------
@@ -28,38 +30,45 @@ def inferenceAPI_t5(text: str) -> str:
     str
         The summary for the input text
     """
+
+    # On défini les paramètres d'entrée pour le modèle
+    text = clean_text(text)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = (AutoTokenizer.from_pretrained("Linggg/t5_summary",use_auth_token=True))
+    # load local model
     model = (AutoModelForSeq2SeqLM
+        .from_pretrained("Linggg/t5_summary",use_auth_token=True)
+        .to(device))
+
     text_encoding = tokenizer(
         text,
         max_length=1024,
-        padding=
+        padding="max_length",
         truncation=True,
         return_attention_mask=True,
         add_special_tokens=True,
-        return_tensors=
+        return_tensors="pt",
     )
     generated_ids = model.generate(
-        input_ids=text_encoding[
-        attention_mask=text_encoding[
+        input_ids=text_encoding["input_ids"],
+        attention_mask=text_encoding["attention_mask"],
         max_length=128,
         num_beams=8,
         length_penalty=0.8,
-        early_stopping=True
+        early_stopping=True,
     )

     preds = [
+        tokenizer.decode(
+            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        for gen_id in generated_ids
     ]
     return "".join(preds)

+
+# if __name__ == "__main__":
+#     text = input('Entrez votre phrase à résumer : ')
+#     print('summary:', inferenceAPI_T5(text))
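The documentation above mentions moving from beam search to sampling to curb repetition; here is a hedged sketch of what that change could look like with the same transformers generate API, reusing the checkpoint loaded in inference_t5. The sampling hyper-parameters are illustrative assumptions, not values tuned by the project.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Same checkpoint as inference_t5, but nucleus sampling instead of beam search,
# as suggested in https://huggingface.co/blog/how-to-generate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary", use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained("Linggg/t5_summary", use_auth_token=True).to(device)

text_encoding = tokenizer("some cleaned article text", return_tensors="pt").to(device)
generated_ids = model.generate(
    **text_encoding,
    max_length=128,
    do_sample=True,          # sample instead of searching
    top_k=50,                # illustrative value
    top_p=0.95,              # illustrative value
    no_repeat_ngram_size=2,  # blocks the repeated determiners described above
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```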
src/model.py
CHANGED
@@ -25,7 +25,8 @@ class Encoder(torch.nn.Module):
         # on s'en servira pour les mots inconnus
         self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
         self.embeddings.to(device)
-        self.hidden = torch.nn.LSTM(
+        self.hidden = torch.nn.LSTM(
+            embeddings_dim, hidden_size, dropout=dropout)
         # Comme on va calculer la log-vraisemblance,
         # c'est le log-softmax qui nous intéresse
         self.dropout = torch.nn.Dropout(dropout)
@@ -61,7 +62,8 @@ class Decoder(torch.nn.Module):
         # on s'en servira pour les mots inconnus
         self.vocab_size = vocab_size
         self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
-        self.hidden = torch.nn.LSTM(
+        self.hidden = torch.nn.LSTM(
+            embeddings_dim, hidden_size, dropout=dropout)
         self.output = torch.nn.Linear(hidden_size, vocab_size)
         # Comme on va calculer la log-vraisemblance,
         # c'est le log-softmax qui nous intéresse
@@ -100,32 +102,36 @@ class EncoderDecoderModel(torch.nn.Module):
         # The ratio must be inferior to 1 to allow text compression
         assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"

+        # Expected summary length (in words)
+        target_len = int(summary_len * source.shape[0])
+        # Word Embedding length
         target_vocab_size = self.decoder.vocab_size

-        # Output of the right format (expected summmary length x word
-        # filled with zeros. On each iteration, we
-        # matrix with the choosen
+        # Output of the right format (expected summmary length x word
+        # embedding length) filled with zeros. On each iteration, we
+        # will replace one of the row of this matrix with the choosen
+        # word embedding
         outputs = torch.zeros(target_len, target_vocab_size)

-        # put the tensors on the device (useless if CPU bus very useful in
+        # put the tensors on the device (useless if CPU bus very useful in
+        # case of GPU)
         outputs.to(self.device)
         source.to(self.device)

-        # last hidden state of the encoder is used
+        # last hidden state of the encoder is used
+        # as the initial hidden state of the decoder
+
+        # Encode the input text
+        hidden, cell = self.encoder(source)
+        # Encode the first word of the summary
+        input = self.vectoriser.encode("<start>")

         # put the tensors on the device
         hidden.to(self.device)
         cell.to(self.device)
         input.to(self.device)

-
+        # BEAM SEARCH #
         # If you wonder, b stands for better
         values = None
         b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
@@ -134,14 +140,16 @@ class EncoderDecoderModel(torch.nn.Module):
         for i in range(1, target_len):
             # On va déterminer autant de mot que la taille du texte souhaité
             # insert input token embedding, previous hidden and previous cell states
-            # receive output tensor (predictions) and new hidden and cell
+            # receive output tensor (predictions) and new hidden and cell
+            # states.

             # replace predictions in a tensor holding predictions for each token
             # logging.debug(f"output : {output}")

             ####### DÉBUT DU BEAM SEARCH ##########
             if values is None:
-                # On calcule une première fois les premières probabilité de mot
+                # On calcule une première fois les premières probabilité de mot
+                # après <start>
                 output, hidden, cell = self.decoder(input, hidden, cell)
                 output.to(self.device)
                 b_hidden = hidden
@@ -152,7 +160,8 @@ class EncoderDecoderModel(torch.nn.Module):
                 values, indices = output.topk(num_beams, sorted=True)

             else:
-                # On instancie le dictionnaire qui contiendra les scores pour
+                # On instancie le dictionnaire qui contiendra les scores pour
+                # chaque possibilité
                 scores = {}

                 # Pour chacune des meilleures valeurs, on va calculer l'output
@@ -160,7 +169,8 @@ class EncoderDecoderModel(torch.nn.Module):
                     indice.to(self.device)

                     # On calcule l'output
-                    b_output, b_hidden, b_cell = self.decoder(
+                    b_output, b_hidden, b_cell = self.decoder(
+                        indice, b_hidden, b_cell)

                     # On empêche le modèle de se répéter d'un mot sur l'autre en mettant
                     # de force la probabilité du mot précédent à 0
@@ -179,7 +189,8 @@ class EncoderDecoderModel(torch.nn.Module):
                     # Et du coup on rempli la place de i-1 à la place de i
                     b_outputs[i - 1] = b_output.to(self.device)

-                    # On instancies nos nouvelles valeurs pour la prochaine
+                    # On instancies nos nouvelles valeurs pour la prochaine
+                    # itération
                     values, indices = b_output.topk(num_beams, sorted=True)

             ##################################
src/train.py
CHANGED
@@ -150,16 +150,24 @@ if __name__ == "__main__":
     words = train_dataset.get_words()
     vectoriser = dataloader.Vectoriser(words)

+    train_dataset = dataloader.Data(
+        "data/train_extract.jsonl",
+        transform=vectoriser)
+    dev_dataset = dataloader.Data(
+        "data/dev_extract.jsonl",
+        transform=vectoriser)

     train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=2,
+        shuffle=True,
+        collate_fn=dataloader.pad_collate)

     dev_dataloader = torch.utils.data.DataLoader(
+        dev_dataset,
+        batch_size=4,
+        shuffle=True,
+        collate_fn=dataloader.pad_collate)

     for i_batch, batch in enumerate(train_dataloader):
         print(i_batch, batch[0], batch[1])

@@ -169,7 +177,8 @@ if __name__ == "__main__":
     print("Device check. You are using:", device)

     ### TRAINED NETWORK ###
+    # To make sure the results are the same on every run
+    # of the notebook
     torch.use_deterministic_algorithms(True)
     torch.manual_seed(0)
     random.seed(0)

@@ -178,9 +187,8 @@ if __name__ == "__main__":
     encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
     decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)

-    trained_classifier = EncoderDecoderModel(
-        device
-    )
+    trained_classifier = EncoderDecoderModel(
+        encoder, decoder, vectoriser, device).to(device)

     print(next(trained_classifier.parameters()).device)
     # print(train_dataset.is_cuda)

@@ -194,7 +202,6 @@ if __name__ == "__main__":

     torch.save(trained_classifier.state_dict(), "model/model.pt")
     vectoriser.save("model/vocab.pkl")
-    trained_classifier.push_to_hub("SummaryProject-LSTM")

     print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
     print(
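Both DataLoaders now batch variable-length articles and summaries through `dataloader.pad_collate`, which is defined in src/dataloader.py and not shown in this hunk. As a rough sketch only (the (source_ids, summary_ids) item format and the padding value of 0 are assumptions), such a collate function typically pads every sequence in the batch to a common length:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    # Sketch only: the real dataloader.pad_collate in src/dataloader.py may
    # differ. Each item is assumed to be a (source_ids, summary_ids) pair of
    # 1-D LongTensors of varying length; 0 is assumed to be the padding id.
    sources, summaries = zip(*batch)
    sources = pad_sequence(sources, batch_first=True, padding_value=0)
    summaries = pad_sequence(summaries, batch_first=True, padding_value=0)
    return sources, summaries


# Toy usage with two (source, summary) pairs of different lengths
batch = [
    (torch.tensor([4, 8, 15, 16]), torch.tensor([23])),
    (torch.tensor([42, 7]), torch.tensor([3, 9])),
]
padded_sources, padded_summaries = pad_collate(batch)
print(padded_sources.shape, padded_summaries.shape)  # torch.Size([2, 4]) torch.Size([2, 2])
```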
templates/index.html.jinja
CHANGED
@@ -4,7 +4,8 @@
     <title>Text summarization API</title>
     <meta charset="utf-8" />
     <meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=no" />
-    <
+
<style>html, body, div, h1, h2, p, blockquote,a, code, em, img, strong, u, ul, li,label, legend, caption, tr, th, td,header, menu, nav, section, summary{margin: 0;padding: 0;border: 0;font-size: 100%;font: inherit;vertical-align: baseline}header, menu, nav, section{display: block}div{margin-bottom: 20px}body{line-height: 1}ul{list-style: none}body{-webkit-text-size-adjust: none}input::-moz-focus-inner{border: 0;padding: 0}html{box-sizing: border-box}*, *:before, *:after{box-sizing: inherit}body{color: #5b5b5b;font-size: 15pt;line-height: 1.85em;font-family: 'Source Sans Pro', sans-serif;font-weight: 300;background-image: url("templates/site_style/images/background.jpg");background-size: cover;background-position: center center;background-attachment: fixed}h1, h2, h3{font-weight: 400;color: #483949;line-height: 1.25em}h1 a, h2 a, h3 a{color: inherit;text-decoration: none;border-bottom-color: transparent}h1 strong, h2 strong, h3 strong{font-weight: 600}h2{font-size: 2.85em}h3{font-size: 1.25em}strong, b{font-weight: 400;color: #483949}em, i{font-style: italic}a{color: inherit;border-bottom: solid 1px rgba(128, 128, 128, 0.15);text-decoration: none;-moz-transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out;-webkit-transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out;-ms-transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out;transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out}a:hover{color: #ef8376;border-bottom-color: transparent}p, ul{margin-bottom: 1em}p{text-align: justify}hr{position: relative;display: block;border: 0;top: 4.5em;margin-bottom: 9em;height: 6px;border-top: solid 1px rgba(128, 128, 128, 0.2);border-bottom: solid 1px rgba(128, 128, 128, 0.2)}hr:before, hr:after{content: '';position: absolute;top: -8px;display: block;width: 1px;height: 21px;background: rgba(128, 128, 128, 0.2)}hr:before{left: -1px}hr:after{right: -1px}ul{list-style: disc;padding-left: 1em}ul li{padding-left: 0.5em;font-size: 85%;list-style: none}textarea{border-radius: 10px;resize: none;padding: 10px;line-height: 20px;word-spacing: 1px;font-size: 16px;width: 85%;height: 100%}::-webkit-input-placeholder{font-size: 17px;word-spacing: 1px}table{width: 100%}table.default{width: 100%}table.default tbody tr:first-child{border-top: 0}table.default tbody tr:nth-child(2n 1){background: #fafafa}table.default th{text-align: left;font-weight: 400;padding: 0.5em 1em 0.5em 1em}input[type="button"],input[type="submit"],input[type="reset"],button,.button{position: relative;display: inline-block;background: #df7366;color: #fff;text-align: center;border-radius: 0.5em;text-decoration: none;padding: 0.65em 3em 0.65em 3em;border: 0;cursor: pointer;outline: 0;font-weight: 300;-moz-transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out;-webkit-transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out;-ms-transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out;transition: background-color 0.35s ease-in-out, color 0.35s ease-in-out, border-bottom-color 0.35s ease-in-out}input[type="button"]:hover,input[type="submit"]:hover,input[type="reset"]:hover,button:hover,.button:hover{color: #fff;background: 
#ef8376}input[type="button"].alt,input[type="submit"].alt,input[type="reset"].alt,button.alt,.button.alt{background: #2B252C}input[type="button"].alt:hover,input[type="submit"].alt:hover,input[type="reset"].alt:hover,button.alt:hover,.button.alt:hover{background: #3B353C}#header{position: relative;background-size: cover;background-position: center center;background-attachment: fixed;color: #fff;text-align: center;padding: 5em 0 2em 0;cursor: default;height: 100%}#header:before{content: '';display: inline-block;vertical-align: middle;height: 100%}#header .inner{position: relative;z-index: 1;margin: 0;display: inline-block;vertical-align: middle}#header header{display: inline-block}#header header > p{font-size: 1.25em;margin: 0}#header h1{color: #fff;font-size: 3em;line-height: 1em}#header h1 a{color: inherit}#header .button{display: inline-block;border-radius: 100%;width: 4.5em;height: 4.5em;line-height: 4.5em;text-align: center;font-size: 1.25em;padding: 0}#header hr{top: 1.5em;margin-bottom: 3em;border-bottom-color: rgba(192, 192, 192, 0.35);box-shadow: inset 0 1px 0 0 rgba(192, 192, 192, 0.35)}#header hr:before, #header hr:after{background: rgba(192, 192, 192, 0.35)}#nav{position: absolute;top: 0;left: 0;width: 100%;text-align: center;padding: 1.5em 0 1.5em 0;z-index: 1;overflow: hidden}#nav > hr{top: 0.5em;margin-bottom: 6em}.copyright{margin-top: 50px}@media screen and (max-width: 1680px){body, input, select{font-size: 14pt;line-height: 1.75em}}@media screen and (max-width: 1280px){body, input, select{font-size: 12pt;line-height: 1.5em}#header{background-attachment: scroll}#header .inner{padding-left: 2em;padding-right: 2em}}@media screen and (max-width: 840px){body, input, select{font-size: 13pt;line-height: 1.65em}}#navPanel, #titleBar{display: none}@media screen and (max-width: 736px){html, body{overflow-x: hidden}body, input, select{font-size: 12.5pt;line-height: 1.5em}h2{font-size: 1.75em}h3{font-size: 1.25em}hr{top: 3em;margin-bottom: 6em}#header{background-attachment: scroll;padding: 2.5em 0 0 0}#header .inner{padding-top: 1.5em;padding-left: 1em;padding-right: 1em}#header header > p{font-size: 1em}#header h1{font-size: 1.75em}#header hr{top: 1em;margin-bottom: 2.5em}#nav{display: none}#main > header{text-align: center}div.copyright{margin-top: 10px}label, textarea{font-size: 0.8rem;letter-spacing: 1px;font-family: Georgia, 'Times New Roman', Times, serif}.buttons{display: flex;flex-direction: row;justify-content: center;margin-top: 20px}}
+    </style>
     <script>
     function customReset()
     {