EveSa committed on
Commit 3e852e2
2 parents: 89725f4, 6158825

Merge pull request #11 from EveSa/revert-10-revert-9-Ling

Files changed (3):
  1. requirements.txt +82 -4
  2. src/fine_tune_T5.py +230 -0
  3. src/inference_t5.py +20 -15
requirements.txt CHANGED
@@ -1,15 +1,56 @@
+ absl-py==1.4.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ alembic==1.9.4
+ anyascii==0.3.1
  anyio==3.6.2
+ async-timeout==4.0.2
+ attrs==22.2.0
+ banal==1.0.6
+ blis==0.7.9
+ catalogue==2.0.8
+ certifi==2022.12.7
+ charset-normalizer==3.0.1
+ click==8.1.3
+ confection==0.0.4
+ contourpy==1.0.7
+ contractions==0.1.73
+ cycler==0.11.0
+ cymem==2.0.7
+ dataloader==2.0
+ dataset==1.6.0
+ datasets==2.10.1
+ dill==0.3.6
+ en-core-web-lg==3.5.0
+ evaluate==0.4.0
+ fastapi==0.91.0
+ filelock==3.9.0
+ flake8==6.0.0
+ fonttools==4.38.0
+ frozenlist==1.3.3
+ fsspec==2023.3.0
+ greenlet==2.0.2
+ h11==0.14.0
+ huggingface-hub==0.12.1
  certifi==2022.12.7
  charset-normalizer==3.1.0
  click==8.1.3
  fastapi==0.92.0
  filelock==3.9.0
- h11==0.14.0
- huggingface-hub==0.13.1
  idna==3.4
+ importlib-metadata==6.0.0
+ importlib-resources==5.12.0
  Jinja2==3.1.2
  joblib==1.2.0
+ kiwisolver==1.4.4
+ langcodes==3.3.0
+ Mako==1.2.4
  MarkupSafe==2.1.2
+ matplotlib==3.7.0
+ mccabe==0.7.0
+ multidict==6.0.4
+ multiprocess==0.70.14
+ murmurhash==1.0.9
  numpy==1.24.2
  nvidia-cublas-cu11==11.10.3.66
  nvidia-cuda-nvrtc-cu11==11.7.99
@@ -17,15 +58,48 @@ nvidia-cuda-runtime-cu11==11.7.99
  nvidia-cudnn-cu11==8.5.0.96
  packaging==23.0
  pandas==1.5.3
- pydantic==1.10.5
+ pathy==0.10.1
+ Pillow==9.4.0
+ preshed==3.0.8
+ protobuf==3.20.0
+ pyahocorasick==2.0.0
+ pyarrow==11.0.0
+ pycodestyle==2.10.0
+ pydantic==1.10.4
+ pyflakes==3.0.1
+ pyparsing==3.0.9
  python-dateutil==2.8.2
- python-multipart==0.0.6
+ python-multipart==0.0.5
  pytz==2022.7.1
  PyYAML==6.0
  regex==2022.10.31
  requests==2.28.2
+ responses==0.18.0
+ rouge-score==0.1.2
+ scikit-learn==1.2.1
+ scipy==1.10.0
+ sentencepiece==0.1.97
  six==1.16.0
+ smart-open==6.3.0
  sniffio==1.3.0
+ spacy==3.5.0
+ spacy-legacy==3.0.12
+ spacy-loggers==1.0.4
+ SQLAlchemy==1.4.46
+ srsly==2.4.5
+ starlette==0.24.0
+ summarizer==0.0.7
+ textsearch==0.0.24
+ thinc==8.1.7
+ threadpoolctl==3.1.0
+ tokenizers==0.13.2
+ tomli==2.0.1
+ torch==1.13.1
+ tqdm==4.64.1
+ transformers==4.26.1
+ typer==0.7.0
+ typing-extensions==4.4.0
+ urllib3==1.26.14
  starlette==0.25.0
  tokenizers==0.13.2
  torch==1.13.1
@@ -33,3 +107,7 @@ tqdm==4.65.0
  typing_extensions==4.5.0
  urllib3==1.26.15
  uvicorn==0.20.0
+ wasabi==1.1.1
+ xxhash==3.2.0
+ yarl==1.8.2
+ zipp==3.14.0
src/fine_tune_T5.py ADDED
@@ -0,0 +1,230 @@
+ import re
+ import os
+ import string
+ import contractions
+ import torch
+ import datasets
+ from datasets import Dataset
+ import pandas as pd
+ from tqdm import tqdm
+ import evaluate
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
+ from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
+ from transformers import DataCollatorForSeq2Seq
+
+
+ def clean_text(texts):
+     '''This function cleans up a text for later use'''
+     texts = texts.lower()
+     texts = contractions.fix(texts)
+     texts = texts.translate(str.maketrans("", "", string.punctuation))
+     texts = re.sub(r'\n', ' ', texts)
+     return texts
+
+
+ def datasetmaker(path: str):
+     '''This function reads the jsonl file into a dataframe, drops the
+     columns not needed for the task and turns it into a Dataset
+     '''
+     data = pd.read_json(path, lines=True)
+     df = data.drop(['url',
+                     'archive',
+                     'title',
+                     'date',
+                     'compression',
+                     'coverage',
+                     'density',
+                     'compression_bin',
+                     'coverage_bin',
+                     'density_bin'],
+                    axis=1)
+     tqdm.pandas()
+     df['text'] = df.text.apply(lambda texts: clean_text(texts))
+     df['summary'] = df.summary.apply(lambda summary: clean_text(summary))
+     dataset = Dataset.from_dict(df)
+     return dataset
+
+ # check whether the model already happens to perform well
+
+ # test_text = dataset['text'][0]
+ # pipe = pipeline('summarization', model = model_ckpt)
+ # pipe_out = pipe(test_text)
+ # print(pipe_out[0]['summary_text'].replace('.<n>', '.\n'))
+ # print(dataset['summary'][0])
+
+
+ def generate_batch_sized_chunks(list_elements, batch_size):
+     """Split the dataset into smaller batches that we can process simultaneously.
+     Yield successive batch-sized chunks from list_elements."""
+     for i in range(0, len(list_elements), batch_size):
+         yield list_elements[i: i + batch_size]
+
+
+ def calculate_metric(dataset, metric, model, tokenizer,
+                      batch_size, device,
+                      column_text='text',
+                      column_summary='summary'):
+     article_batches = list(
+         generate_batch_sized_chunks(dataset[column_text], batch_size))
+     target_batches = list(
+         generate_batch_sized_chunks(dataset[column_summary], batch_size))
+
+     for article_batch, target_batch in tqdm(
+             zip(article_batches, target_batches), total=len(article_batches)):
+
+         inputs = tokenizer(article_batch, max_length=1024, truncation=True,
+                            padding="max_length", return_tensors="pt")
+         # the length penalty parameter ensures that the model does not
+         # generate sequences that are too long.
+         summaries = model.generate(
+             input_ids=inputs["input_ids"].to(device),
+             attention_mask=inputs["attention_mask"].to(device),
+             length_penalty=0.8,
+             num_beams=8,
+             max_length=128)
+
+         # decode the texts,
+         # replace the tokens and add the decoded texts with the references
+         # to the metric.
+         decoded_summaries = [
+             tokenizer.decode(
+                 s,
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=True) for s in summaries]
+
+         decoded_summaries = [d.replace("<n>", " ") for d in decoded_summaries]
+
+         metric.add_batch(
+             predictions=decoded_summaries,
+             references=target_batch)
+
+     # compute and return the ROUGE scores.
+     results = metric.compute()
+     rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
+     rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
+     return pd.DataFrame(rouge_dict, index=['T5'])
+
+
+ def convert_ex_to_features(example_batch):
+     input_encodings = tokenizer(example_batch['text'],
+                                 max_length=1024, truncation=True)
+
+     labels = tokenizer(
+         example_batch['summary'],
+         max_length=128,
+         truncation=True)
+
+     return {
+         'input_ids': input_encodings['input_ids'],
+         'attention_mask': input_encodings['attention_mask'],
+         'labels': labels['input_ids']
+     }
+
+
+ if __name__ == '__main__':
+
+     train_dataset = datasetmaker('data/train_extract.jsonl')
+
+     dev_dataset = datasetmaker('data/dev_extract.jsonl')
+
+     test_dataset = datasetmaker('data/test_extract.jsonl')
+
+     dataset = datasets.DatasetDict({'train': train_dataset,
+                                     'dev': dev_dataset, 'test': test_dataset})
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
+     mt5_config = AutoConfig.from_pretrained(
+         'google/mt5-small',
+         max_length=128,
+         length_penalty=0.6,
+         no_repeat_ngram_size=2,
+         num_beams=15,
+     )
+     model = (AutoModelForSeq2SeqLM
+              .from_pretrained('google/mt5-small', config=mt5_config)
+              .to(device))
+
+     dataset_pt = dataset.map(
+         convert_ex_to_features,
+         remove_columns=[
+             "summary",
+             "text"],
+         batched=True,
+         batch_size=128)
+
+     data_collator = DataCollatorForSeq2Seq(
+         tokenizer, model=model, return_tensors="pt")
+
+     training_args = Seq2SeqTrainingArguments(
+         output_dir="t5_summary",
+         log_level="error",
+         num_train_epochs=10,
+         learning_rate=5e-4,
+         warmup_steps=0,
+         optim="adafactor",
+         weight_decay=0.01,
+         per_device_train_batch_size=2,
+         per_device_eval_batch_size=1,
+         gradient_accumulation_steps=16,
+         evaluation_strategy="steps",
+         eval_steps=100,
+         predict_with_generate=True,
+         generation_max_length=128,
+         save_steps=500,
+         logging_steps=10,
+         # push_to_hub = True
+     )
+
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         data_collator=data_collator,
+         # compute_metrics = calculate_metric,
+         train_dataset=dataset_pt['train'],
+         eval_dataset=dataset_pt['dev'].select(range(10)),
+         tokenizer=tokenizer,
+     )
+
+     trainer.train()
+     rouge_metric = evaluate.load("rouge")
+
+     score = calculate_metric(
+         test_dataset,
+         rouge_metric,
+         trainer.model,
+         tokenizer,
+         batch_size=2,
+         device=device,
+         column_text='text',
+         column_summary='summary')
+     print(score)
+
+     # fine-tuning is done, the model can be saved
+
+     # save the fine-tuned model locally
+     os.makedirs("t5_summary", exist_ok=True)
+     if hasattr(trainer.model, "module"):
+         trainer.model.module.save_pretrained("t5_summary")
+     else:
+         trainer.model.save_pretrained("t5_summary")
+     tokenizer.save_pretrained("t5_summary")
+     # load local model
+     model = (AutoModelForSeq2SeqLM
+              .from_pretrained("t5_summary")
+              .to(device))
+
+     # put the model to use: TEST
+
+     # gen_kwargs = {"length_penalty" : 0.8, "num_beams" : 8, "max_length" : 128}
+     # sample_text = dataset["test"][0]["text"]
+     # reference = dataset["test"][0]["summary"]
+     # pipe = pipeline("summarization", model='./summarization_t5')
+
+     # print("Text :")
+     # print(sample_text)
+     # print("\nReference Summary :")
+     # print(reference)
+     # print("\nModel Summary :")
+     # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
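
For a quick sanity check of the artefact this script writes, a minimal sketch of reloading the saved checkpoint is shown below. It assumes the t5_summary directory produced above and reuses the script's own beam-search settings (length_penalty=0.8, num_beams=8, max_length=128); the input string is a placeholder.

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("t5_summary")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5_summary").to(device)

    # encode a placeholder document and summarize it with the same
    # generation settings calculate_metric uses for ROUGE evaluation
    inputs = tokenizer("your article text here ...", max_length=1024,
                       truncation=True, return_tensors="pt").to(device)
    summary_ids = model.generate(**inputs, length_penalty=0.8,
                                 num_beams=8, max_length=128)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))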
src/inference_t5.py CHANGED
@@ -7,14 +7,16 @@ import re
  import string
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

- def clean_data(texts):
+
+ def clean_text(texts: str) -> str:
      texts = texts.lower()
      texts = contractions.fix(texts)
      texts = texts.translate(str.maketrans("", "", string.punctuation))
-     texts = re.sub(r'\n',' ',texts)
+     texts = re.sub(r'\n', ' ', texts)
      return texts

- def inferenceAPI_t5(text: str) -> str:
+
+ def inferenceAPI(text: str) -> str:
      """
      Predict the summary for an input text
      --------
@@ -25,14 +27,16 @@ def inferenceAPI_t5(text: str) -> str:
      str
          The summary for the input text
      """
-     # definition of the input parameters for the model
-     text = clean_data(text)
-     device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
-     tokenizer = (AutoTokenizer.from_pretrained("./summarization_t5"))
-     # loading of the local model
+
+     # we define the input parameters for the model
+     text = clean_text(text)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     tokenizer = (AutoTokenizer.from_pretrained("Linggg/t5_summary"))
+     # load local model
      model = (AutoModelForSeq2SeqLM
-              .from_pretrained("./summarization_t5")
-              .to(device))
+              .from_pretrained("Linggg/t5_summary")
+              .to(device))
+
      text_encoding = tokenizer(
          text,
          max_length=1024,
@@ -52,11 +56,12 @@
      )

      preds = [
-         tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-         for gen_id in generated_ids
+         tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         for gen_id in generated_ids
      ]
      return "".join(preds)

+
+ # if __name__ == "__main__":
+ #     text = input('Entrez votre phrase à résumer : ')
+ #     print('summary:', inferenceAPI(text))
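
Given that fastapi and uvicorn are pinned in requirements.txt, inferenceAPI is presumably meant to be served over HTTP. Below is a minimal sketch of such wiring, assuming the import path src.inference_t5; the app file and endpoint name are hypothetical, not part of this commit.

    # api.py (hypothetical wiring, not part of this commit)
    from fastapi import FastAPI

    from src.inference_t5 import inferenceAPI

    app = FastAPI()


    @app.get("/summary")
    def summary(text: str) -> dict:
        # clean, encode and summarize the query text with the fine-tuned T5
        return {"summary": inferenceAPI(text)}

    # run with: uvicorn api:app --reload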