EstelleSkwarto commited on
Commit
ef05d9e
·
1 Parent(s): 2c35026

ajout modele T5 dans API

Browse files
src/api.py CHANGED
@@ -2,17 +2,33 @@ import uvicorn
2
  from fastapi import FastAPI, Form, Request
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.templating import Jinja2Templates
 
5
 
6
  from inference import inferenceAPI
 
7
 
8
 
9
- # ------ MODELE --------------------------------------------------------------
10
  # appel de la fonction inference, adaptee pour une entree txt
11
  def summarize(text: str):
12
- return " ".join(inferenceAPI(text))
 
 
 
 
13
  # ----------------------------------------------------------------------------------
14
 
15
 
 
 
 
 
 
 
 
 
 
 
16
  # -------- API ---------------------------------------------------------------------
17
  app = FastAPI()
18
 
@@ -24,8 +40,33 @@ app.mount("/templates", StaticFiles(directory="templates"), name="templates")
24
  async def index(request: Request):
25
  return templates.TemplateResponse("index.html.jinja", {"request": request})
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # retourner le texte, les predictions et message d'erreur si formulaire envoye vide
28
- @app.post("/")
29
  async def prediction(request: Request, text: str = Form(None)):
30
  if not text :
31
  error = "Merci de saisir votre texte."
 
2
  from fastapi import FastAPI, Form, Request
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.templating import Jinja2Templates
5
+ import re
6
 
7
  from inference import inferenceAPI
8
+ from inference_t5 import inferenceAPI_t5
9
 
10
 
11
+ # ------ INFERENCE MODEL --------------------------------------------------------------
12
  # appel de la fonction inference, adaptee pour une entree txt
13
  def summarize(text: str):
14
+ if choisir_modele.var == 'lstm' :
15
+ return " ".join(inferenceAPI(text))
16
+ elif choisir_modele.var == "fineTunedT5":
17
+ text = inferenceAPI_t5(text)
18
+ return re.sub("<extra_id_0> ","",text)
19
  # ----------------------------------------------------------------------------------
20
 
21
 
22
+ def choisir_modele(choixModele):
23
+ print("ON A RECUP LE CHOIX MODELE")
24
+ if choixModele == "lstm" :
25
+ choisir_modele.var ='lstm'
26
+ elif choixModele == "fineTunedT5":
27
+ choisir_modele.var = "fineTunedT5"
28
+ else :
29
+ "le modele n'est pas defini"
30
+
31
+
32
  # -------- API ---------------------------------------------------------------------
33
  app = FastAPI()
34
 
 
40
  async def index(request: Request):
41
  return templates.TemplateResponse("index.html.jinja", {"request": request})
42
 
43
+ @app.get("/model")
44
+ async def index(request: Request):
45
+ return templates.TemplateResponse("index.html.jinja", {"request": request})
46
+
47
+ @app.get("/predict")
48
+ async def index(request: Request):
49
+ return templates.TemplateResponse("index.html.jinja", {"request": request})
50
+
51
+
52
+ @app.post("/model")
53
+ async def choix_model(request: Request, choixModel:str = Form(None)):
54
+ print(choixModel)
55
+ if not choixModel:
56
+ erreur_modele = "Merci de saisir un modèle."
57
+ return templates.TemplateResponse(
58
+ "index.html.jinja", {"request": request, "text": erreur_modele}
59
+ )
60
+ else :
61
+ choisir_modele(choixModel)
62
+ print("C'est bon on utilise le modèle demandé")
63
+ return templates.TemplateResponse(
64
+ "index.html.jinja", {"request": request}
65
+ )
66
+
67
+
68
  # retourner le texte, les predictions et message d'erreur si formulaire envoye vide
69
+ @app.post("/predict")
70
  async def prediction(request: Request, text: str = Form(None)):
71
  if not text :
72
  error = "Merci de saisir votre texte."
src/fine_tune_t5.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import datasets
3
+ from datasets import Dataset, DatasetDict
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+ import nltk
9
+ import string
10
+ import contractions
11
+ from transformers import pipeline
12
+ import evaluate
13
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer,AutoConfig
14
+ from transformers import Seq2SeqTrainingArguments ,Seq2SeqTrainer
15
+ from transformers import DataCollatorForSeq2Seq
16
+
17
+ # cuda out of memory
18
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:200"
19
+
20
+ nltk.download('stopwords')
21
+ nltk.download('punkt')
22
+
23
+
24
+ def clean_data(texts):
25
+ texts = texts.lower()
26
+ texts = contractions.fix(texts)
27
+ texts = texts.translate(str.maketrans("", "", string.punctuation))
28
+ texts = re.sub(r'\n',' ',texts)
29
+ return texts
30
+
31
+ def datasetmaker (path=str):
32
+ data = pd.read_json(path, lines=True)
33
+ df = data.drop(['url','archive','title','date','compression','coverage','density','compression_bin','coverage_bin','density_bin'],axis=1)
34
+ tqdm.pandas()
35
+ df['text'] = df.text.apply(lambda texts : clean_data(texts))
36
+ df['summary'] = df.summary.apply(lambda summary : clean_data(summary))
37
+ # df['text'] = df['text'].map(str)
38
+ # df['summary'] = df['summary'].map(str)
39
+ dataset = Dataset.from_dict(df)
40
+ return dataset
41
+
42
+ #voir si le model par hasard esr déjà bien
43
+
44
+ # test_text = dataset['text'][0]
45
+ # pipe = pipeline('summarization',model = model_ckpt)
46
+ # pipe_out = pipe(test_text)
47
+ # print (pipe_out[0]['summary_text'].replace('.<n>','.\n'))
48
+ # print(dataset['summary'][0])
49
+
50
+ def generate_batch_sized_chunks(list_elements, batch_size):
51
+ """split the dataset into smaller batches that we can process simultaneously
52
+ Yield successive batch-sized chunks from list_of_elements."""
53
+ for i in range(0, len(list_elements), batch_size):
54
+ yield list_elements[i : i + batch_size]
55
+
56
+ def calculate_metric(dataset, metric, model, tokenizer,
57
+ batch_size, device,
58
+ column_text='text',
59
+ column_summary='summary'):
60
+ article_batches = list(str(generate_batch_sized_chunks(dataset[column_text], batch_size)))
61
+ target_batches = list(str(generate_batch_sized_chunks(dataset[column_summary], batch_size)))
62
+
63
+ for article_batch, target_batch in tqdm(
64
+ zip(article_batches, target_batches), total=len(article_batches)):
65
+
66
+ inputs = tokenizer(article_batch, max_length=1024, truncation=True,
67
+ padding="max_length", return_tensors="pt")
68
+
69
+ summaries = model.generate(input_ids=inputs["input_ids"].to(device),
70
+ attention_mask=inputs["attention_mask"].to(device),
71
+ length_penalty=0.8, num_beams=8, max_length=128)
72
+ ''' parameter for length penalty ensures that the model does not generate sequences that are too long. '''
73
+
74
+ # Décode les textes
75
+ # renplacer les tokens, ajouter des textes décodés avec les rédéfences vers la métrique.
76
+ decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
77
+ clean_up_tokenization_spaces=True)
78
+ for s in summaries]
79
+
80
+ decoded_summaries = [d.replace("", " ") for d in decoded_summaries]
81
+
82
+
83
+ metric.add_batch(predictions=decoded_summaries, references=target_batch)
84
+
85
+ #compute et return les ROUGE scores.
86
+ results = metric.compute()
87
+ rouge_names = ['rouge1','rouge2','rougeL','rougeLsum']
88
+ rouge_dict = dict((rn, results[rn] ) for rn in rouge_names )
89
+ return pd.DataFrame(rouge_dict, index = ['T5'])
90
+
91
+
92
+ def convert_ex_to_features(example_batch):
93
+ input_encodings = tokenizer(example_batch['text'],max_length = 1024,truncation = True)
94
+
95
+ labels =tokenizer(example_batch['summary'], max_length = 128, truncation = True )
96
+
97
+ return {
98
+ 'input_ids' : input_encodings['input_ids'],
99
+ 'attention_mask': input_encodings['attention_mask'],
100
+ 'labels': labels['input_ids']
101
+ }
102
+
103
+ if __name__=='__main__':
104
+
105
+ train_dataset = datasetmaker('data/train_extract_100.jsonl')
106
+
107
+ dev_dataset = datasetmaker('data/dev_extract_100.jsonl')
108
+
109
+ test_dataset = datasetmaker('data/test_extract_100.jsonl')
110
+
111
+ dataset = datasets.DatasetDict({'train':train_dataset,'dev':dev_dataset ,'test':test_dataset})
112
+
113
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+
115
+ tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
116
+ mt5_config = AutoConfig.from_pretrained(
117
+ "google/mt5-small",
118
+ max_length=128,
119
+ length_penalty=0.6,
120
+ no_repeat_ngram_size=2,
121
+ num_beams=15,
122
+ )
123
+ model = (AutoModelForSeq2SeqLM
124
+ .from_pretrained("google/mt5-small", config=mt5_config)
125
+ .to(device))
126
+
127
+ dataset_pt= dataset.map(convert_ex_to_features,remove_columns=["summary", "text"],batched = True,batch_size=128)
128
+
129
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model,return_tensors="pt")
130
+
131
+
132
+ training_args = Seq2SeqTrainingArguments(
133
+ output_dir = "mt5_sum",
134
+ log_level = "error",
135
+ num_train_epochs = 10,
136
+ learning_rate = 5e-4,
137
+ # lr_scheduler_type = "linear",
138
+ warmup_steps = 0,
139
+ optim = "adafactor",
140
+ weight_decay = 0.01,
141
+ per_device_train_batch_size = 2,
142
+ per_device_eval_batch_size = 1,
143
+ gradient_accumulation_steps = 16,
144
+ evaluation_strategy = "steps",
145
+ eval_steps = 100,
146
+ predict_with_generate=True,
147
+ generation_max_length = 128,
148
+ save_steps = 500,
149
+ logging_steps = 10,
150
+ # push_to_hub = True
151
+ )
152
+
153
+
154
+ trainer = Seq2SeqTrainer(
155
+ model = model,
156
+ args = training_args,
157
+ data_collator = data_collator,
158
+ # compute_metrics = calculate_metric,
159
+ train_dataset=dataset_pt['train'],
160
+ eval_dataset=dataset_pt['dev'].select(range(10)),
161
+ tokenizer = tokenizer,
162
+ )
163
+
164
+ trainer.train()
165
+ rouge_metric = evaluate.load("rouge")
166
+
167
+ score = calculate_metric(test_dataset, rouge_metric, trainer.model, tokenizer,
168
+ batch_size=2, device=device,
169
+ column_text='text',
170
+ column_summary='summary')
171
+ print (score)
172
+
173
+
174
+ #Fine Tuning terminés et à sauvgarder
175
+
176
+
177
+
178
+ # save fine-tuned model in local
179
+ os.makedirs("./summarization_t5", exist_ok=True)
180
+ if hasattr(trainer.model, "module"):
181
+ trainer.model.module.save_pretrained("./summarization_t5")
182
+ else:
183
+ trainer.model.save_pretrained("./summarization_t5")
184
+ tokenizer.save_pretrained("./summarization_t5")
185
+ # load local model
186
+ model = (AutoModelForSeq2SeqLM
187
+ .from_pretrained("./summarization_t5")
188
+ .to(device))
189
+
190
+
191
+ # mettre en usage : TEST
192
+
193
+
194
+ # gen_kwargs = {"length_penalty": 0.8, "num_beams":8, "max_length": 128}
195
+ # sample_text = dataset["test"][0]["text"]
196
+ # reference = dataset["test"][0]["summary"]
197
+ # pipe = pipeline("summarization", model='./summarization_t5')
198
+
199
+ # print("Text:")
200
+ # print(sample_text)
201
+ # print("\nReference Summary:")
202
+ # print(reference)
203
+ # print("\nModel Summary:")
204
+ # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])
src/inference_t5.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Allows to predict the summary for a given entry text
3
+ """
4
+ import torch
5
+ import nltk
6
+ import contractions
7
+ import re
8
+ import string
9
+ nltk.download('stopwords')
10
+ nltk.download('punkt')
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
+
13
+ def clean_data(texts):
14
+ texts = texts.lower()
15
+ texts = contractions.fix(texts)
16
+ texts = texts.translate(str.maketrans("", "", string.punctuation))
17
+ texts = re.sub(r'\n',' ',texts)
18
+ return texts
19
+
20
+ def inferenceAPI_t5(text: str) -> str:
21
+ """
22
+ Predict the summary for an input text
23
+ --------
24
+ Parameter
25
+ text: str
26
+ the text to sumarize
27
+ Return
28
+ str
29
+ The summary for the input text
30
+ """
31
+ # definition des parametres d'entree pour le modèle
32
+ text = clean_data(text)
33
+ device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
34
+ tokenizer= (AutoTokenizer.from_pretrained("./summarization_t5"))
35
+ # chargement du modele local
36
+ model = (AutoModelForSeq2SeqLM
37
+ .from_pretrained("./summarization_t5")
38
+ .to(device))
39
+ text_encoding = tokenizer(
40
+ text,
41
+ max_length=1024,
42
+ padding='max_length',
43
+ truncation=True,
44
+ return_attention_mask=True,
45
+ add_special_tokens=True,
46
+ return_tensors='pt'
47
+ )
48
+ generated_ids = model.generate(
49
+ input_ids=text_encoding['input_ids'],
50
+ attention_mask=text_encoding['attention_mask'],
51
+ max_length=128,
52
+ num_beams=8,
53
+ length_penalty=0.8,
54
+ early_stopping=True
55
+ )
56
+
57
+ preds = [
58
+ tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
59
+ for gen_id in generated_ids
60
+ ]
61
+ return "".join(preds)
62
+
63
+ if __name__ == "__main__":
64
+ text = input('Entrez votre phrase à résumer : ')
65
+ print('summary:',inferenceAPI(text))
templates/index.html.jinja CHANGED
@@ -13,6 +13,23 @@
13
  document.getElementById("summary").value = "";
14
  }
15
  </script>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  </head>
17
  <body>
18
  <div id="header">
@@ -28,18 +45,21 @@
28
  </nav>
29
 
30
  <div class="choixModel">
31
- <label for="model-select">Choose a model :</label>
32
- <select name="model" id="model-select">
33
- <option value="lstm">LSTM</option>
34
- <option value="autre">Autre</option>
35
- </select>
 
 
 
36
  </div>
37
 
38
  <div>
39
  <table>
40
  <tr>
41
  <td>
42
- <form id = "my_form" action="/" method="post" class="formulaire">
43
  <textarea id="text" name="text" placeholder="Enter your text here!" rows="15" cols="75">{{text}}</textarea>
44
  <input type="hidden" name="textarea_value" value="{{ text }}">
45
  </form>
@@ -51,8 +71,9 @@
51
  </table>
52
  </div>
53
  <div class="buttons">
54
- <button form ="my_form" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Go !</button>
55
- <button form ="my_form" type="button" value="Reset" onclick="customReset();">Reset</button>
 
56
  </div>
57
 
58
  <div class="copyright">
 
13
  document.getElementById("summary").value = "";
14
  }
15
  </script>
16
+ <script>
17
+ function submitBothForms()
18
+ {
19
+ document.getElementById("my_form").submit();
20
+ document.getElementById("choixModel").submit();
21
+ }
22
+ </script>
23
+ <script>
24
+ function getValue() {
25
+ var e = document.getElementById("choixModel");
26
+ var value = e.value;
27
+ var text = e.options[e.selectedIndex].text;
28
+ return text}
29
+ </script>
30
+ <script type="text/javascript">
31
+ document.getElementById('choixModel').value = "<?php echo $_GET['choixModel'];?>";
32
+ </script>
33
  </head>
34
  <body>
35
  <div id="header">
 
45
  </nav>
46
 
47
  <div class="choixModel">
48
+ <form id="choixModel" method="post" action="/model">
49
+ <label for="selectModel">Choose a model :</label>
50
+ <select name="choixModel" class="selectModel" id="choixModel">
51
+ <option value="lstm">LSTM</option>
52
+ <option value="fineTunedT5">Fine-tuned T5</option>
53
+ </select>
54
+ </form>
55
+ <button form ="choixModel" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Select model</button>
56
  </div>
57
 
58
  <div>
59
  <table>
60
  <tr>
61
  <td>
62
+ <form id = "my_form" action="/predict" method="post" class="formulaire">
63
  <textarea id="text" name="text" placeholder="Enter your text here!" rows="15" cols="75">{{text}}</textarea>
64
  <input type="hidden" name="textarea_value" value="{{ text }}">
65
  </form>
 
71
  </table>
72
  </div>
73
  <div class="buttons">
74
+ <!-- <button id="submit" type="submit" onclick=submitBothForms()>SUBMIT</button> -->
75
+ <button form ="my_form" class='search_bn' type="submit" class="btn btn-primary btn-block btn-large" rows="1" cols="50">Go !</button>
76
+ <button form ="my_form" type="button" value="Reset" onclick="customReset();">Reset</button>
77
  </div>
78
 
79
  <div class="copyright">