EveSa committed on
Commit 3c03f61
1 Parent(s): 3e852e2

Refactoring of requirements.txt

Files changed (10):
  1. api.py +0 -51
  2. requirements.txt +7 -84
  3. src/api.py +13 -14
  4. src/dataloader.py +16 -7
  5. src/fine_tune_T5.py +92 -71
  6. src/fine_tune_t5.py +0 -204
  7. src/inference.py +0 -56
  8. src/inference_t5.py +14 -13
  9. src/model.py +32 -21
  10. src/train.py +18 -11
api.py DELETED
@@ -1,51 +0,0 @@
1
- import uvicorn
2
- from fastapi import FastAPI, Form, Request
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
-
6
- from inference import inferenceAPI
7
-
8
-
9
- # ------ MODELE --------------------------------------------------------------
10
- # appel de la fonction inference, adaptee pour une entree txt
11
- def summarize(text: str):
12
- return " ".join(inferenceAPI(text))
13
-
14
-
15
- # ----------------------------------------------------------------------------------
16
-
17
-
18
- # -------- API ---------------------------------------------------------------------
19
- app = FastAPI()
20
-
21
- # static files pour envoi du css au navigateur
22
- templates = Jinja2Templates(directory="templates")
23
- app.mount("/", StaticFiles(directory="templates", html=True), name="templates")
24
-
25
-
26
- @app.get("/")
27
- async def index(request: Request):
28
- return templates.TemplateResponse("index.html.jinja", {"request": request})
29
-
30
-
31
- # retourner le texte, les predictions et message d'erreur si formulaire envoye vide
32
- @app.post("/")
33
- async def prediction(request: Request, text: str = Form(None)):
34
- if not text:
35
- error = "Merci de saisir votre texte."
36
- return templates.TemplateResponse(
37
- "index.html.jinja", {"request": request, "text": error}
38
- )
39
- else:
40
- summary = summarize(text)
41
- return templates.TemplateResponse(
42
- "index.html.jinja", {"request": request, "text": text, "summary": summary}
43
- )
44
-
45
-
46
- # ------------------------------------------------------------------------------------
47
-
48
-
49
- # lancer le serveur et le recharger a chaque modification sauvegardee
50
- # if __name__ == "__main__":
51
- # uvicorn.run("api:app", port=8000, reload=True)
 
requirements.txt CHANGED
@@ -1,56 +1,14 @@
- absl-py==1.4.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- alembic==1.9.4
  anyascii==0.3.1
  anyio==3.6.2
- async-timeout==4.0.2
- attrs==22.2.0
- banal==1.0.6
- blis==0.7.9
- catalogue==2.0.8
- certifi==2022.12.7
- charset-normalizer==3.0.1
- click==8.1.3
- confection==0.0.4
- contourpy==1.0.7
- contractions==0.1.73
- cycler==0.11.0
- cymem==2.0.7
- dataloader==2.0
- dataset==1.6.0
- datasets==2.10.1
- dill==0.3.6
- en-core-web-lg==3.5.0
- evaluate==0.4.0
- fastapi==0.91.0
- filelock==3.9.0
- flake8==6.0.0
- fonttools==4.38.0
- frozenlist==1.3.3
- fsspec==2023.3.0
- greenlet==2.0.2
- h11==0.14.0
- huggingface-hub==0.12.1
  certifi==2022.12.7
  charset-normalizer==3.1.0
- click==8.1.3
- fastapi==0.92.0
  filelock==3.9.0
  idna==3.4
- importlib-metadata==6.0.0
- importlib-resources==5.12.0
  Jinja2==3.1.2
- joblib==1.2.0
- kiwisolver==1.4.4
- langcodes==3.3.0
- Mako==1.2.4
  MarkupSafe==2.1.2
- matplotlib==3.7.0
- mccabe==0.7.0
- multidict==6.0.4
- multiprocess==0.70.14
- murmurhash==1.0.9
  numpy==1.24.2
  nvidia-cublas-cu11==11.10.3.66
  nvidia-cuda-nvrtc-cu11==11.7.99
@@ -58,56 +16,21 @@ nvidia-cuda-runtime-cu11==11.7.99
  nvidia-cudnn-cu11==8.5.0.96
  packaging==23.0
  pandas==1.5.3
- pathy==0.10.1
- Pillow==9.4.0
- preshed==3.0.8
- protobuf==3.20.0
  pyahocorasick==2.0.0
- pyarrow==11.0.0
- pycodestyle==2.10.0
- pydantic==1.10.4
- pyflakes==3.0.1
- pyparsing==3.0.9
  python-dateutil==2.8.2
- python-multipart==0.0.5
  pytz==2022.7.1
  PyYAML==6.0
  regex==2022.10.31
  requests==2.28.2
- responses==0.18.0
- rouge-score==0.1.2
- scikit-learn==1.2.1
- scipy==1.10.0
- sentencepiece==0.1.97
  six==1.16.0
- smart-open==6.3.0
  sniffio==1.3.0
- spacy==3.5.0
- spacy-legacy==3.0.12
- spacy-loggers==1.0.4
- SQLAlchemy==1.4.46
- srsly==2.4.5
- starlette==0.24.0
- summarizer==0.0.7
  textsearch==0.0.24
- thinc==8.1.7
- threadpoolctl==3.1.0
- tokenizers==0.13.2
- tomli==2.0.1
- torch==1.13.1
- tqdm==4.64.1
- transformers==4.26.1
- typer==0.7.0
- typing-extensions==4.4.0
- urllib3==1.26.14
- starlette==0.25.0
  tokenizers==0.13.2
  torch==1.13.1
  tqdm==4.65.0
  typing_extensions==4.5.0
  urllib3==1.26.15
- uvicorn==0.20.0
- wasabi==1.1.1
- xxhash==3.2.0
- yarl==1.8.2
- zipp==3.14.0
 
Resulting requirements.txt (the new file; "+" marks lines added by this commit):

  anyascii==0.3.1
  anyio==3.6.2
  certifi==2022.12.7
  charset-normalizer==3.1.0
+ contractions==0.1.73
+ fastapi==0.94.0
  filelock==3.9.0
+ huggingface-hub==0.13.2
  idna==3.4
  Jinja2==3.1.2
  MarkupSafe==2.1.2
  numpy==1.24.2
  nvidia-cublas-cu11==11.10.3.66
  nvidia-cuda-nvrtc-cu11==11.7.99
  nvidia-cuda-runtime-cu11==11.7.99
  nvidia-cudnn-cu11==8.5.0.96
  packaging==23.0
  pandas==1.5.3
  pyahocorasick==2.0.0
+ pydantic==1.10.6
  python-dateutil==2.8.2
+ python-multipart==0.0.6
  pytz==2022.7.1
  PyYAML==6.0
  regex==2022.10.31
  requests==2.28.2
  six==1.16.0
  sniffio==1.3.0
+ starlette==0.26.1
  textsearch==0.0.24
  tokenizers==0.13.2
  torch==1.13.1
  tqdm==4.65.0
+ transformers==4.26.1
  typing_extensions==4.5.0
  urllib3==1.26.15
 
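The trimmed file keeps roughly the packages needed by the API and the T5 inference/fine-tuning code. A minimal sanity check, assuming the new requirements.txt above has been installed into the current environment (only package names taken from that file are used; the versions in the comments are the pins above):

import fastapi
import torch
import transformers

# Compare installed versions against the pins in requirements.txt.
print(fastapi.__version__)       # 0.94.0 expected
print(torch.__version__)         # 1.13.1 expected
print(transformers.__version__)  # 4.26.1 expected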
src/api.py CHANGED
@@ -1,31 +1,30 @@
1
- import uvicorn
2
  from fastapi import FastAPI, Form, Request
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.templating import Jinja2Templates
5
- import re
6
 
7
- from src.inference import inferenceAPI
8
- from src.inference_t5 import inferenceAPI_t5
9
 
10
 
11
  # ------ INFERENCE MODEL --------------------------------------------------------------
12
  # appel de la fonction inference, adaptee pour une entree txt
13
  def summarize(text: str):
14
- if choisir_modele.var == 'lstm' :
15
  return " ".join(inferenceAPI(text))
16
  elif choisir_modele.var == "fineTunedT5":
17
  text = inferenceAPI_t5(text)
18
 
 
19
  # ----------------------------------------------------------------------------------
20
 
21
 
22
  def choisir_modele(choixModele):
23
  print("ON A RECUP LE CHOIX MODELE")
24
- if choixModele == "lstm" :
25
- choisir_modele.var ='lstm'
26
  elif choixModele == "fineTunedT5":
27
  choisir_modele.var = "fineTunedT5"
28
- else :
29
  "le modele n'est pas defini"
30
 
31
 
@@ -41,29 +40,29 @@ app.mount("/templates", StaticFiles(directory="templates"), name="templates")
41
  async def index(request: Request):
42
  return templates.TemplateResponse("index.html.jinja", {"request": request})
43
 
 
44
  @app.get("/model")
45
  async def index(request: Request):
46
  return templates.TemplateResponse("index.html.jinja", {"request": request})
47
 
 
48
  @app.get("/predict")
49
  async def index(request: Request):
50
  return templates.TemplateResponse("index.html.jinja", {"request": request})
51
 
52
 
53
  @app.post("/model")
54
- async def choix_model(request: Request, choixModel:str = Form(None)):
55
  print(choixModel)
56
  if not choixModel:
57
  erreur_modele = "Merci de saisir un modèle."
58
  return templates.TemplateResponse(
59
- "index.html.jinja", {"request": request, "text": erreur_modele}
60
  )
61
- else :
62
  choisir_modele(choixModel)
63
  print("C'est bon on utilise le modèle demandé")
64
- return templates.TemplateResponse(
65
- "index.html.jinja", {"request": request}
66
- )
67
 
68
 
69
  # retourner le texte, les predictions et message d'erreur si formulaire envoye vide
 
 
1
  from fastapi import FastAPI, Form, Request
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.templating import Jinja2Templates
 
4
 
5
+ from inference_lstm import inferenceAPI
6
+ from inference_t5 import inferenceAPI
7
 
8
 
9
  # ------ INFERENCE MODEL --------------------------------------------------------------
10
  # appel de la fonction inference, adaptee pour une entree txt
11
  def summarize(text: str):
12
+ if choisir_modele.var == "lstm":
13
  return " ".join(inferenceAPI(text))
14
  elif choisir_modele.var == "fineTunedT5":
15
  text = inferenceAPI_t5(text)
16
 
17
+
18
  # ----------------------------------------------------------------------------------
19
 
20
 
21
  def choisir_modele(choixModele):
22
  print("ON A RECUP LE CHOIX MODELE")
23
+ if choixModele == "lstm":
24
+ choisir_modele.var = "lstm"
25
  elif choixModele == "fineTunedT5":
26
  choisir_modele.var = "fineTunedT5"
27
+ else:
28
  "le modele n'est pas defini"
29
 
30
 
 
40
  async def index(request: Request):
41
  return templates.TemplateResponse("index.html.jinja", {"request": request})
42
 
43
+
44
  @app.get("/model")
45
  async def index(request: Request):
46
  return templates.TemplateResponse("index.html.jinja", {"request": request})
47
 
48
+
49
  @app.get("/predict")
50
  async def index(request: Request):
51
  return templates.TemplateResponse("index.html.jinja", {"request": request})
52
 
53
 
54
  @app.post("/model")
55
+ async def choix_model(request: Request, choixModel: str = Form(None)):
56
  print(choixModel)
57
  if not choixModel:
58
  erreur_modele = "Merci de saisir un modèle."
59
  return templates.TemplateResponse(
60
+ "index.html.jinja", {"request": request, "text": erreur_modele}
61
  )
62
+ else:
63
  choisir_modele(choixModel)
64
  print("C'est bon on utilise le modèle demandé")
65
+ return templates.TemplateResponse("index.html.jinja", {"request": request})
 
 
66
 
67
 
68
  # retourner le texte, les predictions et message d'erreur si formulaire envoye vide
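A minimal sketch of exercising the refactored endpoints with FastAPI's TestClient. It assumes the app above is importable as src.api and that the templates/ directory is present; the form field names (choixModel, text) come from the routes shown in this diff, and the final POST "/" prediction route is assumed to behave like the one in the deleted api.py.

from fastapi.testclient import TestClient

from src.api import app  # import path is an assumption; adjust to the actual package layout

client = TestClient(app)

# Pick the fine-tuned T5 model through the /model form endpoint.
resp = client.post("/model", data={"choixModel": "fineTunedT5"})
print(resp.status_code)

# Ask for a summary; "text" is the form field assumed by the prediction route.
resp = client.post("/", data={"text": "A long article to summarize ..."})
print(resp.status_code)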
src/dataloader.py CHANGED
@@ -52,10 +52,15 @@ class Data(torch.utils.data.Dataset):
52
 
53
  def __getitem__(self, idx):
54
  row = self.data.iloc[idx]
55
- text = row["text"].translate(str.maketrans("", "", string.punctuation)).split()
 
 
56
  summary = (
57
- row["summary"].translate(str.maketrans("", "", string.punctuation)).split()
58
- )
 
 
 
59
  summary = ["<start>", *summary, "<end>"]
60
  sample = {"text": text, "summary": summary}
61
 
@@ -106,7 +111,8 @@ class Data(torch.utils.data.Dataset):
106
  tokenized_texts.append(text)
107
 
108
  if text_type == "summary":
109
- return [["<start>", *summary, "<end>"] for summary in tokenized_texts]
 
110
  return tokenized_texts
111
 
112
  def get_words(self) -> list:
@@ -157,8 +163,10 @@ class Vectoriser:
157
 
158
  def __init__(self, vocab=None) -> None:
159
  self.vocab = vocab
160
- self.word_count = Counter(word.lower().strip(",.\\-") for word in self.vocab)
161
- self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
 
 
162
  self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
163
 
164
  def load(self, path):
@@ -167,7 +175,8 @@ class Vectoriser:
167
  self.word_count = Counter(
168
  word.lower().strip(",.\\-") for word in self.vocab
169
  )
170
- self.idx_to_token = sorted([t for t, c in self.word_count.items() if c > 1])
 
171
  self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
172
 
173
  def save(self, path):
 
52
 
53
  def __getitem__(self, idx):
54
  row = self.data.iloc[idx]
55
+ text = row["text"].translate(
56
+ str.maketrans(
57
+ "", "", string.punctuation)).split()
58
  summary = (
59
+ row["summary"].translate(
60
+ str.maketrans(
61
+ "",
62
+ "",
63
+ string.punctuation)).split())
64
  summary = ["<start>", *summary, "<end>"]
65
  sample = {"text": text, "summary": summary}
66
 
 
111
  tokenized_texts.append(text)
112
 
113
  if text_type == "summary":
114
+ return [["<start>", *summary, "<end>"]
115
+ for summary in tokenized_texts]
116
  return tokenized_texts
117
 
118
  def get_words(self) -> list:
 
163
 
164
  def __init__(self, vocab=None) -> None:
165
  self.vocab = vocab
166
+ self.word_count = Counter(word.lower().strip(",.\\-")
167
+ for word in self.vocab)
168
+ self.idx_to_token = sorted(
169
+ [t for t, c in self.word_count.items() if c > 1])
170
  self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
171
 
172
  def load(self, path):
 
175
  self.word_count = Counter(
176
  word.lower().strip(",.\\-") for word in self.vocab
177
  )
178
+ self.idx_to_token = sorted(
179
+ [t for t, c in self.word_count.items() if c > 1])
180
  self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
181
 
182
  def save(self, path):
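The Vectoriser changes above only re-wrap long lines; the vocabulary logic is unchanged. A small standalone sketch of that logic, with a made-up vocabulary, which can help when checking the reformatted code:

from collections import Counter

vocab = ["The", "cat", "sat", "on", "the", "mat,", "the", "cat"]

# Same steps as Vectoriser.__init__: lowercase and strip punctuation,
# keep tokens seen more than once, then build both lookup tables.
word_count = Counter(word.lower().strip(",.\\-") for word in vocab)
idx_to_token = sorted([t for t, c in word_count.items() if c > 1])
token_to_idx = {t: i for i, t in enumerate(idx_to_token)}

print(idx_to_token)  # ['cat', 'the']
print(token_to_idx)  # {'cat': 0, 'the': 1}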
src/fine_tune_T5.py CHANGED
@@ -1,49 +1,55 @@
1
- import re
2
  import os
 
3
  import string
 
4
  import contractions
5
- import torch
6
  import datasets
7
- from datasets import Dataset
8
  import pandas as pd
 
 
9
  from tqdm import tqdm
10
- import evaluate
11
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
12
- from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
13
- from transformers import DataCollatorForSeq2Seq
14
 
15
 
16
  def clean_text(texts):
17
- '''This fonction makes clean text for the future use'''
18
  texts = texts.lower()
19
  texts = contractions.fix(texts)
20
  texts = texts.translate(str.maketrans("", "", string.punctuation))
21
- texts = re.sub(r'\n', ' ', texts)
22
  return texts
23
 
24
 
25
  def datasetmaker(path=str):
26
- '''This fonction take the jsonl file, read it to a dataframe,
27
- remove the colums not needed for the task and turn it into a file type Dataset
28
- '''
29
  data = pd.read_json(path, lines=True)
30
- df = data.drop(['url',
31
- 'archive',
32
- 'title',
33
- 'date',
34
- 'compression',
35
- 'coverage',
36
- 'density',
37
- 'compression_bin',
38
- 'coverage_bin',
39
- 'density_bin'],
40
- axis=1)
 
 
 
 
41
  tqdm.pandas()
42
- df['text'] = df.text.apply(lambda texts: clean_text(texts))
43
- df['summary'] = df.summary.apply(lambda summary: clean_text(summary))
44
  dataset = Dataset.from_dict(df)
45
  return dataset
46
 
 
47
  # voir si le model par hasard esr déjà bien
48
 
49
  # test_text = dataset['text'][0]
@@ -60,20 +66,33 @@ def generate_batch_sized_chunks(list_elements, batch_size):
60
  yield list_elements[i: i + batch_size]
61
 
62
 
63
- def calculate_metric(dataset, metric, model, tokenizer,
64
- batch_size, device,
65
- column_text='text',
66
- column_summary='summary'):
 
 
 
 
 
 
67
  article_batches = list(
68
- str(generate_batch_sized_chunks(dataset[column_text], batch_size)))
 
69
  target_batches = list(
70
- str(generate_batch_sized_chunks(dataset[column_summary], batch_size)))
 
71
 
72
  for article_batch, target_batch in tqdm(
73
- zip(article_batches, target_batches), total=len(article_batches)):
74
-
75
- inputs = tokenizer(article_batch, max_length=1024, truncation=True,
76
- padding="max_length", return_tensors="pt")
 
 
 
 
 
77
  # parameter for length penalty ensures that the model does not
78
  # generate sequences that are too long.
79
  summaries = model.generate(
@@ -81,16 +100,18 @@ def calculate_metric(dataset, metric, model, tokenizer,
81
  attention_mask=inputs["attention_mask"].to(device),
82
  length_penalty=0.8,
83
  num_beams=8,
84
- max_length=128)
 
85
 
86
  # Décode les textes
87
  # renplacer les tokens, ajouter des textes décodés avec les rédéfences
88
  # vers la métrique.
89
  decoded_summaries = [
90
  tokenizer.decode(
91
- s,
92
- skip_special_tokens=True,
93
- clean_up_tokenization_spaces=True) for s in summaries]
 
94
 
95
  decoded_summaries = [d.replace("", " ") for d in decoded_summaries]
96
 
@@ -100,59 +121,60 @@ def calculate_metric(dataset, metric, model, tokenizer,
100
 
101
  # compute et return les ROUGE scores.
102
  results = metric.compute()
103
- rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
104
  rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
105
- return pd.DataFrame(rouge_dict, index=['T5'])
106
 
107
 
108
  def convert_ex_to_features(example_batch):
109
- input_encodings = tokenizer(example_batch['text'],
110
- max_length=1024, truncation=True)
 
 
111
 
112
  labels = tokenizer(
113
- example_batch['summary'],
114
  max_length=128,
115
  truncation=True)
116
 
117
  return {
118
- 'input_ids': input_encodings['input_ids'],
119
- 'attention_mask': input_encodings['attention_mask'],
120
- 'labels': labels['input_ids']
121
  }
122
 
123
 
124
- if __name__ == '__main__':
 
125
 
126
- train_dataset = datasetmaker('data/train_extract.jsonl')
127
 
128
- dev_dataset = datasetmaker('data/dev_extract.jsonl')
129
 
130
- test_dataset = datasetmaker('data/test_extract.jsonl')
131
-
132
- dataset = datasets.DatasetDict({'train': train_dataset,
133
- 'dev': dev_dataset, 'test': test_dataset})
134
 
135
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
136
 
137
- tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
138
  mt5_config = AutoConfig.from_pretrained(
139
- 'google/mt5-small',
140
  max_length=128,
141
  length_penalty=0.6,
142
  no_repeat_ngram_size=2,
143
  num_beams=15,
144
  )
145
- model = (AutoModelForSeq2SeqLM
146
- .from_pretrained('google/mt5-small', config=mt5_config)
147
- .to(device))
148
 
149
  dataset_pt = dataset.map(
150
  convert_ex_to_features,
151
- remove_columns=[
152
- "summary",
153
- "text"],
154
  batched=True,
155
- batch_size=128)
 
156
 
157
  data_collator = DataCollatorForSeq2Seq(
158
  tokenizer, model=model, return_tensors="pt")
@@ -182,8 +204,8 @@ if __name__ == '__main__':
182
  args=training_args,
183
  data_collator=data_collator,
184
  # compute_metrics = calculate_metric,
185
- train_dataset=dataset_pt['train'],
186
- eval_dataset=dataset_pt['dev'].select(range(10)),
187
  tokenizer=tokenizer,
188
  )
189
 
@@ -197,8 +219,9 @@ if __name__ == '__main__':
197
  tokenizer,
198
  batch_size=2,
199
  device=device,
200
- column_text='text',
201
- column_summary='summary')
 
202
  print(score)
203
 
204
  # Fine Tuning terminés et à sauvgarder
@@ -211,9 +234,7 @@ if __name__ == '__main__':
211
  trainer.model.save_pretrained("t5_summary")
212
  tokenizer.save_pretrained("t5_summary")
213
  # load local model
214
- model = (AutoModelForSeq2SeqLM
215
- .from_pretrained("t5_summary")
216
- .to(device))
217
 
218
  # mettre en usage : TEST
219
 
 
 
1
  import os
2
+ import re
3
  import string
4
+
5
  import contractions
 
6
  import datasets
7
+ import evaluate
8
  import pandas as pd
9
+ import torch
10
+ from datasets import Dataset
11
  from tqdm import tqdm
12
+ from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer,
13
+ DataCollatorForSeq2Seq, Seq2SeqTrainer,
14
+ Seq2SeqTrainingArguments)
 
15
 
16
 
17
  def clean_text(texts):
18
+ """This fonction makes clean text for the future use"""
19
  texts = texts.lower()
20
  texts = contractions.fix(texts)
21
  texts = texts.translate(str.maketrans("", "", string.punctuation))
22
+ texts = re.sub(r"\n", " ", texts)
23
  return texts
24
 
25
 
26
  def datasetmaker(path=str):
27
+ """This fonction take the jsonl file, read it to a dataframe,
28
+ remove the colums not needed for the task and turn it into a file type Dataset
29
+ """
30
  data = pd.read_json(path, lines=True)
31
+ df = data.drop(
32
+ [
33
+ "url",
34
+ "archive",
35
+ "title",
36
+ "date",
37
+ "compression",
38
+ "coverage",
39
+ "density",
40
+ "compression_bin",
41
+ "coverage_bin",
42
+ "density_bin",
43
+ ],
44
+ axis=1,
45
+ )
46
  tqdm.pandas()
47
+ df["text"] = df.text.apply(lambda texts: clean_text(texts))
48
+ df["summary"] = df.summary.apply(lambda summary: clean_text(summary))
49
  dataset = Dataset.from_dict(df)
50
  return dataset
51
 
52
+
53
  # voir si le model par hasard esr déjà bien
54
 
55
  # test_text = dataset['text'][0]
 
66
  yield list_elements[i: i + batch_size]
67
 
68
 
69
+ def calculate_metric(
70
+ dataset,
71
+ metric,
72
+ model,
73
+ tokenizer,
74
+ batch_size,
75
+ device,
76
+ column_text="text",
77
+ column_summary="summary",
78
+ ):
79
  article_batches = list(
80
+ str(generate_batch_sized_chunks(dataset[column_text], batch_size))
81
+ )
82
  target_batches = list(
83
+ str(generate_batch_sized_chunks(dataset[column_summary], batch_size))
84
+ )
85
 
86
  for article_batch, target_batch in tqdm(
87
+ zip(article_batches, target_batches), total=len(article_batches)
88
+ ):
89
+ inputs = tokenizer(
90
+ article_batch,
91
+ max_length=1024,
92
+ truncation=True,
93
+ padding="max_length",
94
+ return_tensors="pt",
95
+ )
96
  # parameter for length penalty ensures that the model does not
97
  # generate sequences that are too long.
98
  summaries = model.generate(
 
100
  attention_mask=inputs["attention_mask"].to(device),
101
  length_penalty=0.8,
102
  num_beams=8,
103
+ max_length=128,
104
+ )
105
 
106
  # Décode les textes
107
  # renplacer les tokens, ajouter des textes décodés avec les rédéfences
108
  # vers la métrique.
109
  decoded_summaries = [
110
  tokenizer.decode(
111
+ s, skip_special_tokens=True, clean_up_tokenization_spaces=True
112
+ )
113
+ for s in summaries
114
+ ]
115
 
116
  decoded_summaries = [d.replace("", " ") for d in decoded_summaries]
117
 
 
121
 
122
  # compute et return les ROUGE scores.
123
  results = metric.compute()
124
+ rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
125
  rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
126
+ return pd.DataFrame(rouge_dict, index=["T5"])
127
 
128
 
129
  def convert_ex_to_features(example_batch):
130
+ input_encodings = tokenizer(
131
+ example_batch["text"],
132
+ max_length=1024,
133
+ truncation=True)
134
 
135
  labels = tokenizer(
136
+ example_batch["summary"],
137
  max_length=128,
138
  truncation=True)
139
 
140
  return {
141
+ "input_ids": input_encodings["input_ids"],
142
+ "attention_mask": input_encodings["attention_mask"],
143
+ "labels": labels["input_ids"],
144
  }
145
 
146
 
147
+ if __name__ == "__main__":
148
+ train_dataset = datasetmaker("data/train_extract.jsonl")
149
 
150
+ dev_dataset = datasetmaker("data/dev_extract.jsonl")
151
 
152
+ test_dataset = datasetmaker("data/test_extract.jsonl")
153
 
154
+ dataset = datasets.DatasetDict(
155
+ {"train": train_dataset, "dev": dev_dataset, "test": test_dataset}
156
+ )
 
157
 
158
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
159
 
160
+ tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
161
  mt5_config = AutoConfig.from_pretrained(
162
+ "google/mt5-small",
163
  max_length=128,
164
  length_penalty=0.6,
165
  no_repeat_ngram_size=2,
166
  num_beams=15,
167
  )
168
+ model = AutoModelForSeq2SeqLM.from_pretrained(
169
+ "google/mt5-small", config=mt5_config
170
+ ).to(device)
171
 
172
  dataset_pt = dataset.map(
173
  convert_ex_to_features,
174
+ remove_columns=["summary", "text"],
 
 
175
  batched=True,
176
+ batch_size=128,
177
+ )
178
 
179
  data_collator = DataCollatorForSeq2Seq(
180
  tokenizer, model=model, return_tensors="pt")
 
204
  args=training_args,
205
  data_collator=data_collator,
206
  # compute_metrics = calculate_metric,
207
+ train_dataset=dataset_pt["train"],
208
+ eval_dataset=dataset_pt["dev"].select(range(10)),
209
  tokenizer=tokenizer,
210
  )
211
 
 
219
  tokenizer,
220
  batch_size=2,
221
  device=device,
222
+ column_text="text",
223
+ column_summary="summary",
224
+ )
225
  print(score)
226
 
227
  # Fine Tuning terminés et à sauvgarder
 
234
  trainer.model.save_pretrained("t5_summary")
235
  tokenizer.save_pretrained("t5_summary")
236
  # load local model
237
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5_summary").to(device)
 
 
238
 
239
  # mettre en usage : TEST
240
 
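A minimal reuse sketch for the artefacts the script saves under t5_summary, assuming training has finished and the directory exists locally; the generation arguments mirror those used in calculate_metric above (length_penalty=0.8, num_beams=8, max_length=128), and the input text is a placeholder.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load what trainer.model.save_pretrained("t5_summary") and tokenizer.save_pretrained wrote.
tokenizer = AutoTokenizer.from_pretrained("t5_summary")
model = AutoModelForSeq2SeqLM.from_pretrained("t5_summary").to(device)

text = "an example article to summarize"  # placeholder input
inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="pt")
summary_ids = model.generate(
    input_ids=inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    length_penalty=0.8,
    num_beams=8,
    max_length=128,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))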
src/fine_tune_t5.py DELETED
@@ -1,204 +0,0 @@
1
- import torch
2
- import datasets
3
- from datasets import Dataset, DatasetDict
4
- import pandas as pd
5
- from tqdm import tqdm
6
- import re
7
- import os
8
- import nltk
9
- import string
10
- import contractions
11
- from transformers import pipeline
12
- import evaluate
13
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer,AutoConfig
14
- from transformers import Seq2SeqTrainingArguments ,Seq2SeqTrainer
15
- from transformers import DataCollatorForSeq2Seq
16
-
17
- # cuda out of memory
18
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:200"
19
-
20
- nltk.download('stopwords')
21
- nltk.download('punkt')
22
-
23
-
24
- def clean_data(texts):
25
- texts = texts.lower()
26
- texts = contractions.fix(texts)
27
- texts = texts.translate(str.maketrans("", "", string.punctuation))
28
- texts = re.sub(r'\n',' ',texts)
29
- return texts
30
-
31
- def datasetmaker (path=str):
32
- data = pd.read_json(path, lines=True)
33
- df = data.drop(['url','archive','title','date','compression','coverage','density','compression_bin','coverage_bin','density_bin'],axis=1)
34
- tqdm.pandas()
35
- df['text'] = df.text.apply(lambda texts : clean_data(texts))
36
- df['summary'] = df.summary.apply(lambda summary : clean_data(summary))
37
- # df['text'] = df['text'].map(str)
38
- # df['summary'] = df['summary'].map(str)
39
- dataset = Dataset.from_dict(df)
40
- return dataset
41
-
42
- #voir si le model par hasard esr déjà bien
43
-
44
- # test_text = dataset['text'][0]
45
- # pipe = pipeline('summarization',model = model_ckpt)
46
- # pipe_out = pipe(test_text)
47
- # print (pipe_out[0]['summary_text'].replace('.<n>','.\n'))
48
- # print(dataset['summary'][0])
49
-
50
- def generate_batch_sized_chunks(list_elements, batch_size):
51
- """split the dataset into smaller batches that we can process simultaneously
52
- Yield successive batch-sized chunks from list_of_elements."""
53
- for i in range(0, len(list_elements), batch_size):
54
- yield list_elements[i : i + batch_size]
55
-
56
- def calculate_metric(dataset, metric, model, tokenizer,
57
- batch_size, device,
58
- column_text='text',
59
- column_summary='summary'):
60
- article_batches = list(str(generate_batch_sized_chunks(dataset[column_text], batch_size)))
61
- target_batches = list(str(generate_batch_sized_chunks(dataset[column_summary], batch_size)))
62
-
63
- for article_batch, target_batch in tqdm(
64
- zip(article_batches, target_batches), total=len(article_batches)):
65
-
66
- inputs = tokenizer(article_batch, max_length=1024, truncation=True,
67
- padding="max_length", return_tensors="pt")
68
-
69
- summaries = model.generate(input_ids=inputs["input_ids"].to(device),
70
- attention_mask=inputs["attention_mask"].to(device),
71
- length_penalty=0.8, num_beams=8, max_length=128)
72
- ''' parameter for length penalty ensures that the model does not generate sequences that are too long. '''
73
-
74
- # Décode les textes
75
- # renplacer les tokens, ajouter des textes décodés avec les rédéfences vers la métrique.
76
- decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
77
- clean_up_tokenization_spaces=True)
78
- for s in summaries]
79
-
80
- decoded_summaries = [d.replace("", " ") for d in decoded_summaries]
81
-
82
-
83
- metric.add_batch(predictions=decoded_summaries, references=target_batch)
84
-
85
- #compute et return les ROUGE scores.
86
- results = metric.compute()
87
- rouge_names = ['rouge1','rouge2','rougeL','rougeLsum']
88
- rouge_dict = dict((rn, results[rn] ) for rn in rouge_names )
89
- return pd.DataFrame(rouge_dict, index = ['T5'])
90
-
91
-
92
- def convert_ex_to_features(example_batch):
93
- input_encodings = tokenizer(example_batch['text'],max_length = 1024,truncation = True)
94
-
95
- labels =tokenizer(example_batch['summary'], max_length = 128, truncation = True )
96
-
97
- return {
98
- 'input_ids' : input_encodings['input_ids'],
99
- 'attention_mask': input_encodings['attention_mask'],
100
- 'labels': labels['input_ids']
101
- }
102
-
103
- if __name__=='__main__':
104
-
105
- train_dataset = datasetmaker('data/train_extract_100.jsonl')
106
-
107
- dev_dataset = datasetmaker('data/dev_extract_100.jsonl')
108
-
109
- test_dataset = datasetmaker('data/test_extract_100.jsonl')
110
-
111
- dataset = datasets.DatasetDict({'train':train_dataset,'dev':dev_dataset ,'test':test_dataset})
112
-
113
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
-
115
- tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
116
- mt5_config = AutoConfig.from_pretrained(
117
- "google/mt5-small",
118
- max_length=128,
119
- length_penalty=0.6,
120
- no_repeat_ngram_size=2,
121
- num_beams=15,
122
- )
123
- model = (AutoModelForSeq2SeqLM
124
- .from_pretrained("google/mt5-small", config=mt5_config)
125
- .to(device))
126
-
127
- dataset_pt= dataset.map(convert_ex_to_features,remove_columns=["summary", "text"],batched = True,batch_size=128)
128
-
129
- data_collator = DataCollatorForSeq2Seq(tokenizer, model=model,return_tensors="pt")
130
-
131
-
132
- training_args = Seq2SeqTrainingArguments(
133
- output_dir = "mt5_sum",
134
- log_level = "error",
135
- num_train_epochs = 10,
136
- learning_rate = 5e-4,
137
- # lr_scheduler_type = "linear",
138
- warmup_steps = 0,
139
- optim = "adafactor",
140
- weight_decay = 0.01,
141
- per_device_train_batch_size = 2,
142
- per_device_eval_batch_size = 1,
143
- gradient_accumulation_steps = 16,
144
- evaluation_strategy = "steps",
145
- eval_steps = 100,
146
- predict_with_generate=True,
147
- generation_max_length = 128,
148
- save_steps = 500,
149
- logging_steps = 10,
150
- # push_to_hub = True
151
- )
152
-
153
-
154
- trainer = Seq2SeqTrainer(
155
- model = model,
156
- args = training_args,
157
- data_collator = data_collator,
158
- # compute_metrics = calculate_metric,
159
- train_dataset=dataset_pt['train'],
160
- eval_dataset=dataset_pt['dev'].select(range(10)),
161
- tokenizer = tokenizer,
162
- )
163
-
164
- trainer.train()
165
- rouge_metric = evaluate.load("rouge")
166
-
167
- score = calculate_metric(test_dataset, rouge_metric, trainer.model, tokenizer,
168
- batch_size=2, device=device,
169
- column_text='text',
170
- column_summary='summary')
171
- print (score)
172
-
173
-
174
- #Fine Tuning terminés et à sauvgarder
175
-
176
-
177
-
178
- # save fine-tuned model in local
179
- os.makedirs("./summarization_t5", exist_ok=True)
180
- if hasattr(trainer.model, "module"):
181
- trainer.model.module.save_pretrained("./summarization_t5")
182
- else:
183
- trainer.model.save_pretrained("./summarization_t5")
184
- tokenizer.save_pretrained("./summarization_t5")
185
- # load local model
186
- model = (AutoModelForSeq2SeqLM
187
- .from_pretrained("./summarization_t5")
188
- .to(device))
189
-
190
-
191
- # mettre en usage : TEST
192
-
193
-
194
- # gen_kwargs = {"length_penalty": 0.8, "num_beams":8, "max_length": 128}
195
- # sample_text = dataset["test"][0]["text"]
196
- # reference = dataset["test"][0]["summary"]
197
- # pipe = pipeline("summarization", model='./summarization_t5')
198
-
199
- # print("Text:")
200
- # print(sample_text)
201
- # print("\nReference Summary:")
202
- # print(reference)
203
- # print("\nModel Summary:")
204
- # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
 
src/inference.py DELETED
@@ -1,56 +0,0 @@
1
- """
2
- Allows to predict the summary for a given entry text
3
- """
4
- import pickle
5
-
6
- import torch
7
-
8
- from src import dataloader
9
- from src.model import Decoder, Encoder, EncoderDecoderModel
10
-
11
- with open("model/vocab.pkl", "rb") as vocab:
12
- words = pickle.load(vocab)
13
- vectoriser = dataloader.Vectoriser(words)
14
-
15
-
16
- def inferenceAPI(text: str) -> str:
17
- """
18
- Predict the summary for an input text
19
- --------
20
- Parameter
21
- text: str
22
- the text to sumarize
23
- Return
24
- str
25
- The summary for the input text
26
- """
27
- text = text.split()
28
- # On défini les paramètres d'entrée pour le modèle
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
31
- encoder.to(device)
32
- decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
33
- decoder.to(device)
34
-
35
- # On instancie le modèle
36
- model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
37
-
38
- # model.load_state_dict(torch.load("model/model.pt", map_location=device))
39
- # model.eval()
40
- # model.to(device)
41
-
42
- # On vectorise le texte
43
- source = vectoriser.encode(text)
44
- source = source.to(device)
45
-
46
- # On fait passer le texte dans le modèle
47
- with torch.no_grad():
48
- output = model(source).to(device)
49
- output.to(device)
50
- output = output.argmax(dim=-1)
51
- return vectoriser.decode(output)
52
-
53
-
54
- # if __name__ == "__main__":
55
- # # inference()
56
- # print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))
 
src/inference_t5.py CHANGED
@@ -1,10 +1,11 @@
1
  """
2
  Allows to predict the summary for a given entry text
3
  """
4
- import torch
5
- import contractions
6
  import re
7
  import string
 
 
 
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
10
 
@@ -12,7 +13,7 @@ def clean_text(texts: str) -> str:
12
  texts = texts.lower()
13
  texts = contractions.fix(texts)
14
  texts = texts.translate(str.maketrans("", "", string.punctuation))
15
- texts = re.sub(r'\n', ' ', texts)
16
  return texts
17
 
18
 
@@ -31,32 +32,32 @@ def inferenceAPI(text: str) -> str:
31
  # On défini les paramètres d'entrée pour le modèle
32
  text = clean_text(text)
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- tokenizer = (AutoTokenizer.from_pretrained("Linggg/t5_summary"))
35
  # load local model
36
- model = (AutoModelForSeq2SeqLM
37
- .from_pretrained("Linggg/t5_summary")
38
- .to(device))
39
 
40
  text_encoding = tokenizer(
41
  text,
42
  max_length=1024,
43
- padding='max_length',
44
  truncation=True,
45
  return_attention_mask=True,
46
  add_special_tokens=True,
47
- return_tensors='pt'
48
  )
49
  generated_ids = model.generate(
50
- input_ids=text_encoding['input_ids'],
51
- attention_mask=text_encoding['attention_mask'],
52
  max_length=128,
53
  num_beams=8,
54
  length_penalty=0.8,
55
- early_stopping=True
56
  )
57
 
58
  preds = [
59
- tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
 
60
  for gen_id in generated_ids
61
  ]
62
  return "".join(preds)
 
1
  """
2
  Allows to predict the summary for a given entry text
3
  """
 
 
4
  import re
5
  import string
6
+
7
+ import contractions
8
+ import torch
9
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
10
 
11
 
 
13
  texts = texts.lower()
14
  texts = contractions.fix(texts)
15
  texts = texts.translate(str.maketrans("", "", string.punctuation))
16
+ texts = re.sub(r"\n", " ", texts)
17
  return texts
18
 
19
 
 
32
  # On défini les paramètres d'entrée pour le modèle
33
  text = clean_text(text)
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary")
36
  # load local model
37
+ model = AutoModelForSeq2SeqLM.from_pretrained("Linggg/t5_summary").to(device)
 
 
38
 
39
  text_encoding = tokenizer(
40
  text,
41
  max_length=1024,
42
+ padding="max_length",
43
  truncation=True,
44
  return_attention_mask=True,
45
  add_special_tokens=True,
46
+ return_tensors="pt",
47
  )
48
  generated_ids = model.generate(
49
+ input_ids=text_encoding["input_ids"],
50
+ attention_mask=text_encoding["attention_mask"],
51
  max_length=128,
52
  num_beams=8,
53
  length_penalty=0.8,
54
+ early_stopping=True,
55
  )
56
 
57
  preds = [
58
+ tokenizer.decode(
59
+ gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
60
+ )
61
  for gen_id in generated_ids
62
  ]
63
  return "".join(preds)
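A short usage sketch for inferenceAPI above; it assumes the Linggg/t5_summary checkpoint is reachable (it is fetched from the Hugging Face Hub on first use) and that the module is importable from the repository root.

from src.inference_t5 import inferenceAPI  # adjust the import path to the actual layout

article = "A placeholder article used only to exercise the function."
print(inferenceAPI(article))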
src/model.py CHANGED
@@ -25,7 +25,8 @@ class Encoder(torch.nn.Module):
25
  # on s'en servira pour les mots inconnus
26
  self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
27
  self.embeddings.to(device)
28
- self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
 
29
  # Comme on va calculer la log-vraisemblance,
30
  # c'est le log-softmax qui nous intéresse
31
  self.dropout = torch.nn.Dropout(dropout)
@@ -61,7 +62,8 @@ class Decoder(torch.nn.Module):
61
  # on s'en servira pour les mots inconnus
62
  self.vocab_size = vocab_size
63
  self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
64
- self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
 
65
  self.output = torch.nn.Linear(hidden_size, vocab_size)
66
  # Comme on va calculer la log-vraisemblance,
67
  # c'est le log-softmax qui nous intéresse
@@ -100,32 +102,36 @@ class EncoderDecoderModel(torch.nn.Module):
100
  # The ratio must be inferior to 1 to allow text compression
101
  assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"
102
 
103
- target_len = int(
104
- summary_len * source.shape[0]
105
- ) # Expected summary length (in words)
106
- target_vocab_size = self.decoder.vocab_size # Word Embedding length
107
 
108
- # Output of the right format (expected summmary length x word embedding length)
109
- # filled with zeros. On each iteration, we will replace one of the row of this
110
- # matrix with the choosen word embedding
 
111
  outputs = torch.zeros(target_len, target_vocab_size)
112
 
113
- # put the tensors on the device (useless if CPU bus very useful in case of GPU)
 
114
  outputs.to(self.device)
115
  source.to(self.device)
116
 
117
- # last hidden state of the encoder is used as the initial hidden state of the decoder
118
- hidden, cell = self.encoder(source) # Encode the input text
119
- input = self.vectoriser.encode(
120
- "<start>"
121
- ) # Encode the first word of the summary
 
 
122
 
123
  # put the tensors on the device
124
  hidden.to(self.device)
125
  cell.to(self.device)
126
  input.to(self.device)
127
 
128
- ### BEAM SEARCH ###
129
  # If you wonder, b stands for better
130
  values = None
131
  b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
@@ -134,14 +140,16 @@ class EncoderDecoderModel(torch.nn.Module):
134
  for i in range(1, target_len):
135
  # On va déterminer autant de mot que la taille du texte souhaité
136
  # insert input token embedding, previous hidden and previous cell states
137
- # receive output tensor (predictions) and new hidden and cell states.
 
138
 
139
  # replace predictions in a tensor holding predictions for each token
140
  # logging.debug(f"output : {output}")
141
 
142
  ####### DÉBUT DU BEAM SEARCH ##########
143
  if values is None:
144
- # On calcule une première fois les premières probabilité de mot après <start>
 
145
  output, hidden, cell = self.decoder(input, hidden, cell)
146
  output.to(self.device)
147
  b_hidden = hidden
@@ -152,7 +160,8 @@ class EncoderDecoderModel(torch.nn.Module):
152
  values, indices = output.topk(num_beams, sorted=True)
153
 
154
  else:
155
- # On instancie le dictionnaire qui contiendra les scores pour chaque possibilité
 
156
  scores = {}
157
 
158
  # Pour chacune des meilleures valeurs, on va calculer l'output
@@ -160,7 +169,8 @@ class EncoderDecoderModel(torch.nn.Module):
160
  indice.to(self.device)
161
 
162
  # On calcule l'output
163
- b_output, b_hidden, b_cell = self.decoder(indice, b_hidden, b_cell)
 
164
 
165
  # On empêche le modèle de se répéter d'un mot sur l'autre en mettant
166
  # de force la probabilité du mot précédent à 0
@@ -179,7 +189,8 @@ class EncoderDecoderModel(torch.nn.Module):
179
  # Et du coup on rempli la place de i-1 à la place de i
180
  b_outputs[i - 1] = b_output.to(self.device)
181
 
182
- # On instancies nos nouvelles valeurs pour la prochaine itération
 
183
  values, indices = b_output.topk(num_beams, sorted=True)
184
 
185
  ##################################
 
25
  # on s'en servira pour les mots inconnus
26
  self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
27
  self.embeddings.to(device)
28
+ self.hidden = torch.nn.LSTM(
29
+ embeddings_dim, hidden_size, dropout=dropout)
30
  # Comme on va calculer la log-vraisemblance,
31
  # c'est le log-softmax qui nous intéresse
32
  self.dropout = torch.nn.Dropout(dropout)
 
62
  # on s'en servira pour les mots inconnus
63
  self.vocab_size = vocab_size
64
  self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
65
+ self.hidden = torch.nn.LSTM(
66
+ embeddings_dim, hidden_size, dropout=dropout)
67
  self.output = torch.nn.Linear(hidden_size, vocab_size)
68
  # Comme on va calculer la log-vraisemblance,
69
  # c'est le log-softmax qui nous intéresse
 
102
  # The ratio must be inferior to 1 to allow text compression
103
  assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"
104
 
105
+ # Expected summary length (in words)
106
+ target_len = int(summary_len * source.shape[0])
107
+ # Word Embedding length
108
+ target_vocab_size = self.decoder.vocab_size
109
 
110
+ # Output of the right format (expected summmary length x word
111
+ # embedding length) filled with zeros. On each iteration, we
112
+ # will replace one of the row of this matrix with the choosen
113
+ # word embedding
114
  outputs = torch.zeros(target_len, target_vocab_size)
115
 
116
+ # put the tensors on the device (useless if CPU bus very useful in
117
+ # case of GPU)
118
  outputs.to(self.device)
119
  source.to(self.device)
120
 
121
+ # last hidden state of the encoder is used
122
+ # as the initial hidden state of the decoder
123
+
124
+ # Encode the input text
125
+ hidden, cell = self.encoder(source)
126
+ # Encode the first word of the summary
127
+ input = self.vectoriser.encode("<start>")
128
 
129
  # put the tensors on the device
130
  hidden.to(self.device)
131
  cell.to(self.device)
132
  input.to(self.device)
133
 
134
+ # BEAM SEARCH #
135
  # If you wonder, b stands for better
136
  values = None
137
  b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
 
140
  for i in range(1, target_len):
141
  # On va déterminer autant de mot que la taille du texte souhaité
142
  # insert input token embedding, previous hidden and previous cell states
143
+ # receive output tensor (predictions) and new hidden and cell
144
+ # states.
145
 
146
  # replace predictions in a tensor holding predictions for each token
147
  # logging.debug(f"output : {output}")
148
 
149
  ####### DÉBUT DU BEAM SEARCH ##########
150
  if values is None:
151
+ # On calcule une première fois les premières probabilité de mot
152
+ # après <start>
153
  output, hidden, cell = self.decoder(input, hidden, cell)
154
  output.to(self.device)
155
  b_hidden = hidden
 
160
  values, indices = output.topk(num_beams, sorted=True)
161
 
162
  else:
163
+ # On instancie le dictionnaire qui contiendra les scores pour
164
+ # chaque possibilité
165
  scores = {}
166
 
167
  # Pour chacune des meilleures valeurs, on va calculer l'output
 
169
  indice.to(self.device)
170
 
171
  # On calcule l'output
172
+ b_output, b_hidden, b_cell = self.decoder(
173
+ indice, b_hidden, b_cell)
174
 
175
  # On empêche le modèle de se répéter d'un mot sur l'autre en mettant
176
  # de force la probabilité du mot précédent à 0
 
189
  # Et du coup on rempli la place de i-1 à la place de i
190
  b_outputs[i - 1] = b_output.to(self.device)
191
 
192
+ # On instancies nos nouvelles valeurs pour la prochaine
193
+ # itération
194
  values, indices = b_output.topk(num_beams, sorted=True)
195
 
196
  ##################################
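A construction sketch matching how these classes are wired up in src/train.py (embedding size 256, hidden size 512, dropout 0.5); the tiny vocabulary here is a made-up placeholder, whereas the real one comes from the training data.

import torch

from src import dataloader
from src.model import Decoder, Encoder, EncoderDecoderModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder vocabulary; tokens must appear at least twice to be kept by the Vectoriser.
vectoriser = dataloader.Vectoriser(["<start>", "<start>", "<end>", "<end>", "text", "text"])
vocab_size = len(vectoriser.idx_to_token) + 1

encoder = Encoder(vocab_size, 256, 512, 0.5, device)
decoder = Decoder(vocab_size, 256, 512, 0.5, device)
model = EncoderDecoderModel(encoder, decoder, vectoriser, device).to(device)
print(model)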
src/train.py CHANGED
@@ -150,16 +150,24 @@ if __name__ == "__main__":
150
  words = train_dataset.get_words()
151
  vectoriser = dataloader.Vectoriser(words)
152
 
153
- train_dataset = dataloader.Data("data/train_extract.jsonl", transform=vectoriser)
154
- dev_dataset = dataloader.Data("data/dev_extract.jsonl", transform=vectoriser)
 
 
 
 
155
 
156
  train_dataloader = torch.utils.data.DataLoader(
157
- train_dataset, batch_size=2, shuffle=True, collate_fn=dataloader.pad_collate
158
- )
 
 
159
 
160
  dev_dataloader = torch.utils.data.DataLoader(
161
- dev_dataset, batch_size=4, shuffle=True, collate_fn=dataloader.pad_collate
162
- )
 
 
163
 
164
  for i_batch, batch in enumerate(train_dataloader):
165
  print(i_batch, batch[0], batch[1])
@@ -169,7 +177,8 @@ if __name__ == "__main__":
169
  print("Device check. You are using:", device)
170
 
171
  ### RÉSEAU ENTRAÎNÉ ###
172
- # Pour s'assurer que les résultats seront les mêmes à chaque run du notebook
 
173
  torch.use_deterministic_algorithms(True)
174
  torch.manual_seed(0)
175
  random.seed(0)
@@ -178,9 +187,8 @@ if __name__ == "__main__":
178
  encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
179
  decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
180
 
181
- trained_classifier = EncoderDecoderModel(encoder, decoder, vectoriser, device).to(
182
- device
183
- )
184
 
185
  print(next(trained_classifier.parameters()).device)
186
  # print(train_dataset.is_cuda)
@@ -194,7 +202,6 @@ if __name__ == "__main__":
194
 
195
  torch.save(trained_classifier.state_dict(), "model/model.pt")
196
  vectoriser.save("model/vocab.pkl")
197
- trained_classifier.push_to_hub("SummaryProject-LSTM")
198
 
199
  print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
200
  print(
 
150
  words = train_dataset.get_words()
151
  vectoriser = dataloader.Vectoriser(words)
152
 
153
+ train_dataset = dataloader.Data(
154
+ "data/train_extract.jsonl",
155
+ transform=vectoriser)
156
+ dev_dataset = dataloader.Data(
157
+ "data/dev_extract.jsonl",
158
+ transform=vectoriser)
159
 
160
  train_dataloader = torch.utils.data.DataLoader(
161
+ train_dataset,
162
+ batch_size=2,
163
+ shuffle=True,
164
+ collate_fn=dataloader.pad_collate)
165
 
166
  dev_dataloader = torch.utils.data.DataLoader(
167
+ dev_dataset,
168
+ batch_size=4,
169
+ shuffle=True,
170
+ collate_fn=dataloader.pad_collate)
171
 
172
  for i_batch, batch in enumerate(train_dataloader):
173
  print(i_batch, batch[0], batch[1])
 
177
  print("Device check. You are using:", device)
178
 
179
  ### RÉSEAU ENTRAÎNÉ ###
180
+ # Pour s'assurer que les résultats seront les mêmes à chaque run du
181
+ # notebook
182
  torch.use_deterministic_algorithms(True)
183
  torch.manual_seed(0)
184
  random.seed(0)
 
187
  encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
188
  decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
189
 
190
+ trained_classifier = EncoderDecoderModel(
191
+ encoder, decoder, vectoriser, device).to(device)
 
192
 
193
  print(next(trained_classifier.parameters()).device)
194
  # print(train_dataset.is_cuda)
 
202
 
203
  torch.save(trained_classifier.state_dict(), "model/model.pt")
204
  vectoriser.save("model/vocab.pkl")
 
205
 
206
  print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
207
  print(