Linggg commited on
Commit
5925e5f
1 Parent(s): d5d5a19

model t5 tout bon + mis sur huggingface

Browse files
Files changed (3) hide show
  1. requirements.txt +97 -19
  2. src/fine_tune_T5.py +136 -110
  3. src/inference_t5.py +15 -13
requirements.txt CHANGED
@@ -1,27 +1,105 @@
1
- brotli==1.0.9
2
- brotlicffi==1.0.9.2
3
- chardet==5.1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  contractions==0.1.73
5
- cryptography==39.0.2
6
- Cython==0.29.33
 
 
7
  datasets==2.10.1
8
- dl==0.1.0
 
9
  evaluate==0.4.0
10
- fastapi==0.94.0
11
- ipaddr==2.2.0
12
- keyring==23.13.1
13
- mock==5.0.1
14
- mypy_extensions==1.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  nltk==3.8.1
16
  numpy==1.24.2
17
- ordereddict==1.1
 
 
 
 
18
  pandas==1.5.3
19
- protobuf==4.22.1
20
- pyOpenSSL==23.0.0
21
- simplejson==3.18.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  torch==1.13.1
23
- tqdm==4.65.0
24
  transformers==4.26.1
25
- urllib3_secure_extra==0.1.0
26
- uvicorn==0.21.0
27
- wincertstore==0.2.1
 
 
 
 
 
 
1
+ absl-py==1.4.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ alembic==1.9.4
5
+ anyascii==0.3.1
6
+ anyio==3.6.2
7
+ async-timeout==4.0.2
8
+ attrs==22.2.0
9
+ autopep8==2.0.2
10
+ banal==1.0.6
11
+ blis==0.7.9
12
+ catalogue==2.0.8
13
+ certifi==2022.12.7
14
+ charset-normalizer==3.0.1
15
+ click==8.1.3
16
+ confection==0.0.4
17
+ contourpy==1.0.7
18
  contractions==0.1.73
19
+ cycler==0.11.0
20
+ cymem==2.0.7
21
+ dataloader==2.0
22
+ dataset==1.6.0
23
  datasets==2.10.1
24
+ dill==0.3.6
25
+ en-core-web-lg==3.5.0
26
  evaluate==0.4.0
27
+ fastapi==0.91.0
28
+ filelock==3.9.0
29
+ flake8==6.0.0
30
+ fonttools==4.38.0
31
+ frozenlist==1.3.3
32
+ fsspec==2023.3.0
33
+ greenlet==2.0.2
34
+ h11==0.14.0
35
+ huggingface-hub==0.12.1
36
+ idna==3.4
37
+ importlib-metadata==6.0.0
38
+ importlib-resources==5.12.0
39
+ Jinja2==3.1.2
40
+ joblib==1.2.0
41
+ kiwisolver==1.4.4
42
+ langcodes==3.3.0
43
+ Mako==1.2.4
44
+ MarkupSafe==2.1.2
45
+ matplotlib==3.7.0
46
+ mccabe==0.7.0
47
+ multidict==6.0.4
48
+ multiprocess==0.70.14
49
+ murmurhash==1.0.9
50
  nltk==3.8.1
51
  numpy==1.24.2
52
+ nvidia-cublas-cu11==11.10.3.66
53
+ nvidia-cuda-nvrtc-cu11==11.7.99
54
+ nvidia-cuda-runtime-cu11==11.7.99
55
+ nvidia-cudnn-cu11==8.5.0.96
56
+ packaging==23.0
57
  pandas==1.5.3
58
+ pathy==0.10.1
59
+ Pillow==9.4.0
60
+ preshed==3.0.8
61
+ protobuf==3.20.0
62
+ pyahocorasick==2.0.0
63
+ pyarrow==11.0.0
64
+ pycodestyle==2.10.0
65
+ pydantic==1.10.4
66
+ pyflakes==3.0.1
67
+ pyparsing==3.0.9
68
+ python-dateutil==2.8.2
69
+ python-multipart==0.0.5
70
+ pytz==2022.7.1
71
+ PyYAML==6.0
72
+ regex==2022.10.31
73
+ requests==2.28.2
74
+ responses==0.18.0
75
+ rouge-score==0.1.2
76
+ scikit-learn==1.2.1
77
+ scipy==1.10.0
78
+ sentencepiece==0.1.97
79
+ six==1.16.0
80
+ sklearn==0.0.post1
81
+ smart-open==6.3.0
82
+ sniffio==1.3.0
83
+ spacy==3.5.0
84
+ spacy-legacy==3.0.12
85
+ spacy-loggers==1.0.4
86
+ SQLAlchemy==1.4.46
87
+ srsly==2.4.5
88
+ starlette==0.24.0
89
+ summarizer==0.0.7
90
+ textsearch==0.0.24
91
+ thinc==8.1.7
92
+ threadpoolctl==3.1.0
93
+ tokenizers==0.13.2
94
+ tomli==2.0.1
95
  torch==1.13.1
96
+ tqdm==4.64.1
97
  transformers==4.26.1
98
+ typer==0.7.0
99
+ typing-extensions==4.4.0
100
+ urllib3==1.26.14
101
+ uvicorn==0.20.0
102
+ wasabi==1.1.1
103
+ xxhash==3.2.0
104
+ yarl==1.8.2
105
+ zipp==3.14.0
src/fine_tune_T5.py CHANGED
@@ -1,106 +1,127 @@
1
- import torch
2
- import datasets
3
- from datasets import Dataset, DatasetDict
4
- import pandas as pd
5
- from tqdm import tqdm
6
  import re
7
  import os
8
- import nltk
9
  import string
10
- nltk.download('stopwords')
11
- nltk.download('punkt')
12
  import contractions
13
- from transformers import pipeline
14
-
 
 
 
15
  import evaluate
16
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer,AutoConfig
17
- from transformers import Seq2SeqTrainingArguments ,Seq2SeqTrainer
18
- # from transformers import TrainingArguments, Trainer
19
  from transformers import DataCollatorForSeq2Seq
20
 
21
 
22
-
23
-
24
- def clean_data(texts):
25
  texts = texts.lower()
26
  texts = contractions.fix(texts)
27
  texts = texts.translate(str.maketrans("", "", string.punctuation))
28
- texts = re.sub(r'\n',' ',texts)
29
  return texts
30
 
31
- def datasetmaker (path=str):
 
 
 
 
32
  data = pd.read_json(path, lines=True)
33
- df = data.drop(['url','archive','title','date','compression','coverage','density','compression_bin','coverage_bin','density_bin'],axis=1)
 
 
 
 
 
 
 
 
 
 
34
  tqdm.pandas()
35
- df['text'] = df.text.apply(lambda texts : clean_data(texts))
36
- df['summary'] = df.summary.apply(lambda summary : clean_data(summary))
37
- # df['text'] = df['text'].map(str)
38
- # df['summary'] = df['summary'].map(str)
39
  dataset = Dataset.from_dict(df)
40
  return dataset
41
 
42
- #voir si le model par hasard esr déjà bien
43
 
44
  # test_text = dataset['text'][0]
45
- # pipe = pipeline('summarization',model = model_ckpt)
46
  # pipe_out = pipe(test_text)
47
- # print (pipe_out[0]['summary_text'].replace('.<n>','.\n'))
48
  # print(dataset['summary'][0])
49
 
 
50
  def generate_batch_sized_chunks(list_elements, batch_size):
51
  """split the dataset into smaller batches that we can process simultaneously
52
  Yield successive batch-sized chunks from list_of_elements."""
53
  for i in range(0, len(list_elements), batch_size):
54
- yield list_elements[i : i + batch_size]
 
55
 
56
  def calculate_metric(dataset, metric, model, tokenizer,
57
- batch_size, device,
58
- column_text='text',
59
- column_summary='summary'):
60
- article_batches = list(str(generate_batch_sized_chunks(dataset[column_text], batch_size)))
61
- target_batches = list(str(generate_batch_sized_chunks(dataset[column_summary], batch_size)))
 
 
62
 
63
  for article_batch, target_batch in tqdm(
64
- zip(article_batches, target_batches), total=len(article_batches)):
65
-
66
- inputs = tokenizer(article_batch, max_length=1024, truncation=True,
67
- padding="max_length", return_tensors="pt")
68
-
69
- summaries = model.generate(input_ids=inputs["input_ids"].to(device),
70
- attention_mask=inputs["attention_mask"].to(device),
71
- length_penalty=0.8, num_beams=8, max_length=128)
72
- ''' parameter for length penalty ensures that the model does not generate sequences that are too long. '''
 
 
 
73
 
74
  # Décode les textes
75
- # renplacer les tokens, ajouter des textes décodés avec les rédéfences vers la métrique.
76
- decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
77
- clean_up_tokenization_spaces=True)
78
- for s in summaries]
 
 
 
79
 
80
  decoded_summaries = [d.replace("", " ") for d in decoded_summaries]
81
 
 
 
 
82
 
83
- metric.add_batch(predictions=decoded_summaries, references=target_batch)
84
-
85
- #compute et return les ROUGE scores.
86
  results = metric.compute()
87
- rouge_names = ['rouge1','rouge2','rougeL','rougeLsum']
88
- rouge_dict = dict((rn, results[rn] ) for rn in rouge_names )
89
- return pd.DataFrame(rouge_dict, index = ['T5'])
90
 
91
 
92
  def convert_ex_to_features(example_batch):
93
- input_encodings = tokenizer(example_batch['text'],max_length = 1024,truncation = True)
 
94
 
95
- labels =tokenizer(example_batch['summary'], max_length = 128, truncation = True )
 
 
 
96
 
97
  return {
98
- 'input_ids' : input_encodings['input_ids'],
99
  'attention_mask': input_encodings['attention_mask'],
100
  'labels': labels['input_ids']
101
  }
102
 
103
- if __name__=='__main__':
 
104
 
105
  train_dataset = datasetmaker('data/train_extract.jsonl')
106
 
@@ -108,97 +129,102 @@ if __name__=='__main__':
108
 
109
  test_dataset = datasetmaker('data/test_extract.jsonl')
110
 
111
- dataset = datasets.DatasetDict({'train':train_dataset,'dev':dev_dataset ,'test':test_dataset})
 
112
 
113
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
 
115
- tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
116
  mt5_config = AutoConfig.from_pretrained(
117
- "google/mt5-small",
118
- max_length=128,
119
- length_penalty=0.6,
120
- no_repeat_ngram_size=2,
121
- num_beams=15,
122
  )
123
  model = (AutoModelForSeq2SeqLM
124
- .from_pretrained("google/mt5-small", config=mt5_config)
125
- .to(device))
126
-
127
- dataset_pt= dataset.map(convert_ex_to_features,remove_columns=["summary", "text"],batched = True,batch_size=128)
128
 
129
- data_collator = DataCollatorForSeq2Seq(tokenizer, model=model,return_tensors="pt")
 
 
 
 
 
 
130
 
 
 
131
 
132
  training_args = Seq2SeqTrainingArguments(
133
- output_dir = "mt5_sum",
134
- log_level = "error",
135
- num_train_epochs = 10,
136
- learning_rate = 5e-4,
137
- # lr_scheduler_type = "linear",
138
- warmup_steps = 0,
139
- optim = "adafactor",
140
- weight_decay = 0.01,
141
- per_device_train_batch_size = 2,
142
- per_device_eval_batch_size = 1,
143
- gradient_accumulation_steps = 16,
144
- evaluation_strategy = "steps",
145
- eval_steps = 100,
146
  predict_with_generate=True,
147
- generation_max_length = 128,
148
- save_steps = 500,
149
- logging_steps = 10,
150
  # push_to_hub = True
151
  )
152
 
153
-
154
  trainer = Seq2SeqTrainer(
155
- model = model,
156
- args = training_args,
157
- data_collator = data_collator,
158
  # compute_metrics = calculate_metric,
159
  train_dataset=dataset_pt['train'],
160
  eval_dataset=dataset_pt['dev'].select(range(10)),
161
- tokenizer = tokenizer,
162
  )
163
 
164
  trainer.train()
165
  rouge_metric = evaluate.load("rouge")
166
 
167
- score = calculate_metric(test_dataset, rouge_metric, trainer.model, tokenizer,
168
- batch_size=2, device=device,
169
- column_text='text',
170
- column_summary='summary')
171
- print (score)
172
-
173
-
174
- #Fine Tuning terminés et à sauvgarder
175
-
 
176
 
 
177
 
178
  # save fine-tuned model in local
179
- os.makedirs("./summarization_t5", exist_ok=True)
180
  if hasattr(trainer.model, "module"):
181
- trainer.model.module.save_pretrained("./summarization_t5")
182
  else:
183
- trainer.model.save_pretrained("./summarization_t5")
184
- tokenizer.save_pretrained("./summarization_t5")
185
  # load local model
186
  model = (AutoModelForSeq2SeqLM
187
- .from_pretrained("./summarization_t5")
188
- .to(device))
189
-
190
 
191
  # mettre en usage : TEST
192
 
193
-
194
- # gen_kwargs = {"length_penalty": 0.8, "num_beams":8, "max_length": 128}
195
  # sample_text = dataset["test"][0]["text"]
196
  # reference = dataset["test"][0]["summary"]
197
  # pipe = pipeline("summarization", model='./summarization_t5')
198
 
199
- # print("Text:")
200
  # print(sample_text)
201
- # print("\nReference Summary:")
202
  # print(reference)
203
- # print("\nModel Summary:")
204
  # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
 
 
 
 
 
 
1
  import re
2
  import os
 
3
  import string
 
 
4
  import contractions
5
+ import torch
6
+ import datasets
7
+ from datasets import Dataset
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
  import evaluate
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
12
+ from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
 
13
  from transformers import DataCollatorForSeq2Seq
14
 
15
 
16
+ def clean_text(texts):
17
+ '''This fonction makes clean text for the future use'''
 
18
  texts = texts.lower()
19
  texts = contractions.fix(texts)
20
  texts = texts.translate(str.maketrans("", "", string.punctuation))
21
+ texts = re.sub(r'\n', ' ', texts)
22
  return texts
23
 
24
+
25
+ def datasetmaker(path=str):
26
+ '''This fonction take the jsonl file, read it to a dataframe,
27
+ remove the colums not needed for the task and turn it into a file type Dataset
28
+ '''
29
  data = pd.read_json(path, lines=True)
30
+ df = data.drop(['url',
31
+ 'archive',
32
+ 'title',
33
+ 'date',
34
+ 'compression',
35
+ 'coverage',
36
+ 'density',
37
+ 'compression_bin',
38
+ 'coverage_bin',
39
+ 'density_bin'],
40
+ axis=1)
41
  tqdm.pandas()
42
+ df['text'] = df.text.apply(lambda texts: clean_text(texts))
43
+ df['summary'] = df.summary.apply(lambda summary: clean_text(summary))
 
 
44
  dataset = Dataset.from_dict(df)
45
  return dataset
46
 
47
+ # voir si le model par hasard esr déjà bien
48
 
49
  # test_text = dataset['text'][0]
50
+ # pipe = pipeline('summarization', model = model_ckpt)
51
  # pipe_out = pipe(test_text)
52
+ # print(pipe_out[0]['summary_text'].replace('.<n>', '.\n'))
53
  # print(dataset['summary'][0])
54
 
55
+
56
  def generate_batch_sized_chunks(list_elements, batch_size):
57
  """split the dataset into smaller batches that we can process simultaneously
58
  Yield successive batch-sized chunks from list_of_elements."""
59
  for i in range(0, len(list_elements), batch_size):
60
+ yield list_elements[i: i + batch_size]
61
+
62
 
63
  def calculate_metric(dataset, metric, model, tokenizer,
64
+ batch_size, device,
65
+ column_text='text',
66
+ column_summary='summary'):
67
+ article_batches = list(
68
+ str(generate_batch_sized_chunks(dataset[column_text], batch_size)))
69
+ target_batches = list(
70
+ str(generate_batch_sized_chunks(dataset[column_summary], batch_size)))
71
 
72
  for article_batch, target_batch in tqdm(
73
+ zip(article_batches, target_batches), total=len(article_batches)):
74
+
75
+ inputs = tokenizer(article_batch, max_length=1024, truncation=True,
76
+ padding="max_length", return_tensors="pt")
77
+ # parameter for length penalty ensures that the model does not
78
+ # generate sequences that are too long.
79
+ summaries = model.generate(
80
+ input_ids=inputs["input_ids"].to(device),
81
+ attention_mask=inputs["attention_mask"].to(device),
82
+ length_penalty=0.8,
83
+ num_beams=8,
84
+ max_length=128)
85
 
86
  # Décode les textes
87
+ # renplacer les tokens, ajouter des textes décodés avec les rédéfences
88
+ # vers la métrique.
89
+ decoded_summaries = [
90
+ tokenizer.decode(
91
+ s,
92
+ skip_special_tokens=True,
93
+ clean_up_tokenization_spaces=True) for s in summaries]
94
 
95
  decoded_summaries = [d.replace("", " ") for d in decoded_summaries]
96
 
97
+ metric.add_batch(
98
+ predictions=decoded_summaries,
99
+ references=target_batch)
100
 
101
+ # compute et return les ROUGE scores.
 
 
102
  results = metric.compute()
103
+ rouge_names = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
104
+ rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
105
+ return pd.DataFrame(rouge_dict, index=['T5'])
106
 
107
 
108
  def convert_ex_to_features(example_batch):
109
+ input_encodings = tokenizer(example_batch['text'],
110
+ max_length=1024, truncation=True)
111
 
112
+ labels = tokenizer(
113
+ example_batch['summary'],
114
+ max_length=128,
115
+ truncation=True)
116
 
117
  return {
118
+ 'input_ids': input_encodings['input_ids'],
119
  'attention_mask': input_encodings['attention_mask'],
120
  'labels': labels['input_ids']
121
  }
122
 
123
+
124
+ if __name__ == '__main__':
125
 
126
  train_dataset = datasetmaker('data/train_extract.jsonl')
127
 
 
129
 
130
  test_dataset = datasetmaker('data/test_extract.jsonl')
131
 
132
+ dataset = datasets.DatasetDict({'train': train_dataset,
133
+ 'dev': dev_dataset, 'test': test_dataset})
134
 
135
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
136
 
137
+ tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
138
  mt5_config = AutoConfig.from_pretrained(
139
+ 'google/mt5-small',
140
+ max_length=128,
141
+ length_penalty=0.6,
142
+ no_repeat_ngram_size=2,
143
+ num_beams=15,
144
  )
145
  model = (AutoModelForSeq2SeqLM
146
+ .from_pretrained('google/mt5-small', config=mt5_config)
147
+ .to(device))
 
 
148
 
149
+ dataset_pt = dataset.map(
150
+ convert_ex_to_features,
151
+ remove_columns=[
152
+ "summary",
153
+ "text"],
154
+ batched=True,
155
+ batch_size=128)
156
 
157
+ data_collator = DataCollatorForSeq2Seq(
158
+ tokenizer, model=model, return_tensors="pt")
159
 
160
  training_args = Seq2SeqTrainingArguments(
161
+ output_dir="t5_summary",
162
+ log_level="error",
163
+ num_train_epochs=10,
164
+ learning_rate=5e-4,
165
+ warmup_steps=0,
166
+ optim="adafactor",
167
+ weight_decay=0.01,
168
+ per_device_train_batch_size=2,
169
+ per_device_eval_batch_size=1,
170
+ gradient_accumulation_steps=16,
171
+ evaluation_strategy="steps",
172
+ eval_steps=100,
 
173
  predict_with_generate=True,
174
+ generation_max_length=128,
175
+ save_steps=500,
176
+ logging_steps=10,
177
  # push_to_hub = True
178
  )
179
 
 
180
  trainer = Seq2SeqTrainer(
181
+ model=model,
182
+ args=training_args,
183
+ data_collator=data_collator,
184
  # compute_metrics = calculate_metric,
185
  train_dataset=dataset_pt['train'],
186
  eval_dataset=dataset_pt['dev'].select(range(10)),
187
+ tokenizer=tokenizer,
188
  )
189
 
190
  trainer.train()
191
  rouge_metric = evaluate.load("rouge")
192
 
193
+ score = calculate_metric(
194
+ test_dataset,
195
+ rouge_metric,
196
+ trainer.model,
197
+ tokenizer,
198
+ batch_size=2,
199
+ device=device,
200
+ column_text='text',
201
+ column_summary='summary')
202
+ print(score)
203
 
204
+ # Fine Tuning terminés et à sauvgarder
205
 
206
  # save fine-tuned model in local
207
+ os.makedirs("t5_summary", exist_ok=True)
208
  if hasattr(trainer.model, "module"):
209
+ trainer.model.module.save_pretrained("t5_summary")
210
  else:
211
+ trainer.model.save_pretrained("t5_summary")
212
+ tokenizer.save_pretrained("t5_summary")
213
  # load local model
214
  model = (AutoModelForSeq2SeqLM
215
+ .from_pretrained("t5_summary")
216
+ .to(device))
 
217
 
218
  # mettre en usage : TEST
219
 
220
+ # gen_kwargs = {"length_penalty" : 0.8, "num_beams" : 8, "max_length" : 128}
 
221
  # sample_text = dataset["test"][0]["text"]
222
  # reference = dataset["test"][0]["summary"]
223
  # pipe = pipeline("summarization", model='./summarization_t5')
224
 
225
+ # print("Text :")
226
  # print(sample_text)
227
+ # print("\nReference Summary :")
228
  # print(reference)
229
+ # print("\nModel Summary :")
230
  # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
src/inference_t5.py CHANGED
@@ -2,21 +2,20 @@
2
  Allows to predict the summary for a given entry text
3
  """
4
  import torch
5
- import nltk
6
  import contractions
7
  import re
8
  import string
9
- nltk.download('stopwords')
10
- nltk.download('punkt')
11
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
 
13
- def clean_data(texts):
 
14
  texts = texts.lower()
15
  texts = contractions.fix(texts)
16
  texts = texts.translate(str.maketrans("", "", string.punctuation))
17
- texts = re.sub(r'\n',' ',texts)
18
  return texts
19
 
 
20
  def inferenceAPI(text: str) -> str:
21
  """
22
  Predict the summary for an input text
@@ -29,13 +28,13 @@ def inferenceAPI(text: str) -> str:
29
  The summary for the input text
30
  """
31
  # On défini les paramètres d'entrée pour le modèle
32
- text = clean_data(text)
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- tokenizer= (AutoTokenizer.from_pretrained("./summarization_t5"))
35
  # load local model
36
  model = (AutoModelForSeq2SeqLM
37
- .from_pretrained("./summarization_t5")
38
- .to(device))
39
  text_encoding = tokenizer(
40
  text,
41
  max_length=1024,
@@ -55,11 +54,14 @@ def inferenceAPI(text: str) -> str:
55
  )
56
 
57
  preds = [
58
- tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
59
- for gen_id in generated_ids
60
  ]
61
  return "".join(preds)
62
 
 
63
  if __name__ == "__main__":
64
- text = input('Entrez votre phrase à résumer : ')
65
- print('summary:',inferenceAPI(text))
 
 
 
2
  Allows to predict the summary for a given entry text
3
  """
4
  import torch
 
5
  import contractions
6
  import re
7
  import string
 
 
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
10
+
11
+ def clean_text(texts: str) -> str:
12
  texts = texts.lower()
13
  texts = contractions.fix(texts)
14
  texts = texts.translate(str.maketrans("", "", string.punctuation))
15
+ texts = re.sub(r'\n', ' ', texts)
16
  return texts
17
 
18
+
19
  def inferenceAPI(text: str) -> str:
20
  """
21
  Predict the summary for an input text
 
28
  The summary for the input text
29
  """
30
  # On défini les paramètres d'entrée pour le modèle
31
+ text = clean_text(text)
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ tokenizer = (AutoTokenizer.from_pretrained("Linggg/t5_summary"))
34
  # load local model
35
  model = (AutoModelForSeq2SeqLM
36
+ .from_pretrained("Linggg/t5_summary")
37
+ .to(device))
38
  text_encoding = tokenizer(
39
  text,
40
  max_length=1024,
 
54
  )
55
 
56
  preds = [
57
+ tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
58
+ for gen_id in generated_ids
59
  ]
60
  return "".join(preds)
61
 
62
+
63
  if __name__ == "__main__":
64
+ '''
65
+ '''
66
+ text = input('Entrez votre phrase à résumer : ')
67
+ print('summary:', inferenceAPI(text))