Linggg committed on
Commit
41508f8
·
1 Parent(s): 36717d5

model t5 fonctionnel

Browse files
Files changed (3) hide show
  1. requirements.txt +22 -22
  2. src/fine_tune_T5.py +204 -0
  3. src/interface_t5.py +65 -0
requirements.txt CHANGED
@@ -1,27 +1,27 @@
1
- anyio==3.6.2
2
- click==8.1.3
3
- fastapi==0.92.0
4
- h11==0.14.0
5
- idna==3.4
6
- Jinja2==3.1.2
7
- joblib==1.2.0
8
- MarkupSafe==2.1.2
 
 
 
 
 
 
9
  nltk==3.8.1
10
  numpy==1.24.2
11
- nvidia-cublas-cu11==11.10.3.66
12
- nvidia-cuda-nvrtc-cu11==11.7.99
13
- nvidia-cuda-runtime-cu11==11.7.99
14
- nvidia-cudnn-cu11==8.5.0.96
15
  pandas==1.5.3
16
- pydantic==1.10.5
17
- python-dateutil==2.8.2
18
- python-multipart==0.0.6
19
- pytz==2022.7.1
20
- regex==2022.10.31
21
- six==1.16.0
22
- sniffio==1.3.0
23
- starlette==0.25.0
24
  torch==1.13.1
25
  tqdm==4.65.0
26
- typing_extensions==4.5.0
27
- uvicorn==0.20.0
 
 
 
1
+ brotli==1.0.9
2
+ brotlicffi==1.0.9.2
3
+ chardet==5.1.0
4
+ contractions==0.1.73
5
+ cryptography==39.0.2
6
+ Cython==0.29.33
7
+ datasets==2.10.1
8
+ dl==0.1.0
9
+ evaluate==0.4.0
10
+ fastapi==0.94.0
11
+ ipaddr==2.2.0
12
+ keyring==23.13.1
13
+ mock==5.0.1
14
+ mypy_extensions==1.0.0
15
  nltk==3.8.1
16
  numpy==1.24.2
17
+ ordereddict==1.1
 
 
 
18
  pandas==1.5.3
19
+ protobuf==4.22.1
20
+ pyOpenSSL==23.0.0
21
+ simplejson==3.18.3
 
 
 
 
 
22
  torch==1.13.1
23
  tqdm==4.65.0
24
+ transformers==4.26.1
25
+ urllib3_secure_extra==0.1.0
26
+ uvicorn==0.21.0
27
+ wincertstore==0.2.1
src/fine_tune_T5.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import datasets
3
+ from datasets import Dataset, DatasetDict
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+ import nltk
9
+ import string
10
+ nltk.download('stopwords')
11
+ nltk.download('punkt')
12
+ import contractions
13
+ from transformers import pipeline
14
+
15
+ import evaluate
16
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer,AutoConfig
17
+ from transformers import Seq2SeqTrainingArguments ,Seq2SeqTrainer
18
+ # from transformers import TrainingArguments, Trainer
19
+ from transformers import DataCollatorForSeq2Seq
20
+
21
+
22
+
23
+
24
def clean_data(texts):
    """Normalise one raw text before it is tokenised.

    The steps run in this order: lowercase the text, expand contractions
    with ``contractions.fix``, delete every character listed in
    ``string.punctuation``, then turn each newline into a single space.

    Parameters
    ----------
    texts : str
        The raw text to clean.

    Returns
    -------
    str
        The cleaned text.
    """
    punctuation_table = str.maketrans("", "", string.punctuation)
    lowered = texts.lower()
    expanded = contractions.fix(lowered)
    without_punctuation = expanded.translate(punctuation_table)
    return re.sub(r'\n', ' ', without_punctuation)
30
+
31
def datasetmaker(path: str):
    """Load a JSON-lines corpus and return it as a cleaned ``Dataset``.

    The file is read with ``pandas.read_json(..., lines=True)``, the
    metadata columns listed below are dropped, and both the ``text`` and
    ``summary`` columns are passed through ``clean_data``.

    Parameters
    ----------
    path : str
        Path of the ``.jsonl`` file to load. (The original signature wrote
        ``path=str``, which made the builtin ``str`` type the *default
        value* instead of annotating the parameter.)

    Returns
    -------
    datasets.Dataset
        Dataset built from the remaining columns of the cleaned frame.

    Raises
    ------
    KeyError
        Raised by ``DataFrame.drop`` when one of the metadata columns is
        missing from the file.
    """
    data = pd.read_json(path, lines=True)
    df = data.drop(
        ['url', 'archive', 'title', 'date', 'compression', 'coverage',
         'density', 'compression_bin', 'coverage_bin', 'density_bin'],
        axis=1,
    )
    # Kept from the original: registers tqdm's pandas integration.
    tqdm.pandas()
    # Pass the function directly instead of wrapping it in a lambda.
    df['text'] = df['text'].apply(clean_data)
    df['summary'] = df['summary'].apply(clean_data)
    dataset = Dataset.from_dict(df)
    return dataset
41
+
42
+ # Check whether the pretrained model already performs well out of the box
43
+
44
+ # test_text = dataset['text'][0]
45
+ # pipe = pipeline('summarization',model = model_ckpt)
46
+ # pipe_out = pipe(test_text)
47
+ # print (pipe_out[0]['summary_text'].replace('.<n>','.\n'))
48
+ # print(dataset['summary'][0])
49
+
50
def generate_batch_sized_chunks(list_elements, batch_size):
    """Lazily cut a sequence into consecutive slices of ``batch_size`` items.

    Slices are yielded in order; the last one is shorter when the length of
    ``list_elements`` is not a multiple of ``batch_size``.
    """
    total = len(list_elements)
    yield from (
        list_elements[start:start + batch_size]
        for start in range(0, total, batch_size)
    )
55
+
56
def calculate_metric(dataset, metric, model, tokenizer,
                     batch_size, device,
                     column_text='text',
                     column_summary='summary'):
    """Generate summaries for ``dataset`` and score them with ``metric``.

    Parameters
    ----------
    dataset
        Object indexable by column name; ``dataset[column_text]`` and
        ``dataset[column_summary]`` must be sliceable sequences of strings.
    metric
        Object exposing ``add_batch(predictions=..., references=...)`` and
        ``compute()``; the result of ``compute()`` must contain the keys
        ``rouge1``, ``rouge2``, ``rougeL`` and ``rougeLsum``.
    model
        Model exposing ``generate``.
    tokenizer
        Tokenizer used both to encode the articles and decode the outputs.
    batch_size : int
        Number of articles generated per forward pass.
    device
        Device the input tensors are moved to before generation.
    column_text, column_summary : str
        Names of the article and reference-summary columns.

    Returns
    -------
    pandas.DataFrame
        One row labelled ``'T5'`` with one column per ROUGE score.
    """
    # Materialise the batches themselves. The original wrapped the
    # generator in ``str(...)``, which produced the characters of the
    # generator's repr instead of the dataset content.
    article_batches = list(
        generate_batch_sized_chunks(dataset[column_text], batch_size))
    target_batches = list(
        generate_batch_sized_chunks(dataset[column_summary], batch_size))

    for article_batch, target_batch in tqdm(
            zip(article_batches, target_batches), total=len(article_batches)):

        inputs = tokenizer(article_batch, max_length=1024, truncation=True,
                           padding="max_length", return_tensors="pt")

        # length_penalty is the exponent beam search applies to the
        # sequence length when scoring candidate beams.
        summaries = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            length_penalty=0.8, num_beams=8, max_length=128)

        # Decode the generated token ids back to text.
        decoded_summaries = [
            tokenizer.decode(s, skip_special_tokens=True,
                             clean_up_tokenization_spaces=True)
            for s in summaries
        ]

        # The original called ``d.replace("", " ")``, which inserts a space
        # between every character. Presumably the intent was to strip a
        # ``<n>`` newline marker; replacing that token is a no-op when the
        # model never emits it.
        decoded_summaries = [d.replace("<n>", " ") for d in decoded_summaries]

        metric.add_batch(predictions=decoded_summaries,
                         references=target_batch)

    # Compute and return the ROUGE scores.
    results = metric.compute()
    rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
    rouge_dict = {rn: results[rn] for rn in rouge_names}
    return pd.DataFrame(rouge_dict, index=['T5'])
90
+
91
+
92
def convert_ex_to_features(example_batch):
    """Tokenise a batch of examples into model-ready features.

    Uses the module-level ``tokenizer``. Articles (``'text'``) are truncated
    to 1024 tokens and summaries (``'summary'``) to 128 tokens; the summary
    token ids become the ``labels``.

    Returns
    -------
    dict
        Mapping with the keys ``input_ids``, ``attention_mask`` and
        ``labels``.
    """
    source_enc = tokenizer(example_batch['text'],
                           max_length=1024, truncation=True)
    target_enc = tokenizer(example_batch['summary'],
                           max_length=128, truncation=True)

    features = dict(
        input_ids=source_enc['input_ids'],
        attention_mask=source_enc['attention_mask'],
        labels=target_enc['input_ids'],
    )
    return features
102
+
103
if __name__ == '__main__':
    # --- Data: load and clean the three JSON-lines splits -----------------
    train_dataset = datasetmaker('data/train_extract.jsonl')
    dev_dataset = datasetmaker('data/dev_extract.jsonl')
    test_dataset = datasetmaker('data/test_extract.jsonl')
    dataset = datasets.DatasetDict(
        {'train': train_dataset, 'dev': dev_dataset, 'test': test_dataset}
    )

    # --- Model: tokenizer, generation config and weights ------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    generation_defaults = dict(
        max_length=128,
        length_penalty=0.6,
        no_repeat_ngram_size=2,
        num_beams=15,
    )
    mt5_config = AutoConfig.from_pretrained(
        "google/mt5-small", **generation_defaults)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/mt5-small", config=mt5_config)
    model = model.to(device)

    # --- Features: tokenise every split, dropping the raw text columns ----
    dataset_pt = dataset.map(
        convert_ex_to_features,
        remove_columns=["summary", "text"],
        batched=True,
        batch_size=128,
    )
    data_collator = DataCollatorForSeq2Seq(
        tokenizer, model=model, return_tensors="pt")

    # --- Training setup ---------------------------------------------------
    # Options left disabled by the author:
    #   lr_scheduler_type = "linear"
    #   push_to_hub = True
    training_kwargs = dict(
        output_dir="mt5_sum",
        log_level="error",
        num_train_epochs=10,
        learning_rate=5e-4,
        warmup_steps=0,
        optim="adafactor",
        weight_decay=0.01,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        evaluation_strategy="steps",
        eval_steps=100,
        predict_with_generate=True,
        generation_max_length=128,
        save_steps=500,
        logging_steps=10,
    )
    training_args = Seq2SeqTrainingArguments(**training_kwargs)

    # Evaluation during training uses only the first 10 dev examples.
    # (Left disabled by the author: compute_metrics = calculate_metric.)
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset_pt['train'],
        eval_dataset=dataset_pt['dev'].select(range(10)),
        tokenizer=tokenizer,
    )

    trainer.train()

    # --- Evaluation: ROUGE on the (untokenised) test split ----------------
    rouge_metric = evaluate.load("rouge")
    score = calculate_metric(
        test_dataset, rouge_metric, trainer.model, tokenizer,
        batch_size=2, device=device,
        column_text='text',
        column_summary='summary',
    )
    print(score)

    # --- Fine-tuning finished: save the model locally ---------------------
    os.makedirs("./summarization_t5", exist_ok=True)
    # If the trained model exposes a ``module`` attribute, save that inner
    # object rather than the wrapper.
    model_to_save = getattr(trainer.model, "module", trainer.model)
    model_to_save.save_pretrained("./summarization_t5")
    tokenizer.save_pretrained("./summarization_t5")

    # Reload the locally saved model.
    model = AutoModelForSeq2SeqLM.from_pretrained("./summarization_t5")
    model = model.to(device)

    # --- Manual usage test (disabled) -------------------------------------
    # gen_kwargs = {"length_penalty": 0.8, "num_beams":8, "max_length": 128}
    # sample_text = dataset["test"][0]["text"]
    # reference = dataset["test"][0]["summary"]
    # pipe = pipeline("summarization", model='./summarization_t5')

    # print("Text:")
    # print(sample_text)
    # print("\nReference Summary:")
    # print(reference)
    # print("\nModel Summary:")
    # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
src/interface_t5.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Allows to predict the summary for a given entry text
3
+ """
4
+ import torch
5
+ import nltk
6
+ import contractions
7
+ import re
8
+ import string
9
+ nltk.download('stopwords')
10
+ nltk.download('punkt')
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
+
13
def clean_data(texts):
    """Apply to an input text the same normalisation used on training data.

    In order: lowercase, expand contractions via ``contractions.fix``,
    remove all ``string.punctuation`` characters, and replace each newline
    with a space.
    """
    result = contractions.fix(texts.lower())
    drop_punctuation = str.maketrans("", "", string.punctuation)
    result = result.translate(drop_punctuation)
    result = re.sub(r'\n', ' ', result)
    return result
19
+
20
def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text
    --------
    Parameter
        text: str
            the text to summarize
    Return
        str
            The summary for the input text

    The tokenizer and model are loaded from ``./summarization_t5`` on every
    call.
    """
    # Prepare the model inputs: clean the text the same way as in training.
    text = clean_data(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("./summarization_t5")
    # load local model
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("./summarization_t5")
             .to(device))
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )
    # The input tensors must live on the same device as the model; the
    # original left them on the CPU, which fails when CUDA is available.
    generated_ids = model.generate(
        input_ids=text_encoding['input_ids'].to(device),
        attention_mask=text_encoding['attention_mask'].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for gen_id in generated_ids
    ]
    return "".join(preds)
62
+
63
if __name__ == "__main__":
    # Read one text from standard input and print its predicted summary.
    user_text = input('Entrez votre phrase à résumer : ')
    predicted_summary = inferenceAPI(user_text)
    print('summary:', predicted_summary)