Demosthene-OR commited on
Commit
6b65ad0
1 Parent(s): 1157ef0

Update main_dl.py

Browse files
Files changed (1) hide show
  1. main_dl.py +21 -11
main_dl.py CHANGED
@@ -17,7 +17,6 @@ import csv
17
  import tiktoken
18
  from sklearn.preprocessing import LabelEncoder
19
  from tensorflow import keras
20
- # import keras
21
  from keras_nlp.layers import TransformerEncoder
22
  from tensorflow.keras import layers
23
  from tensorflow.keras.preprocessing.sequence import pad_sequences
@@ -46,6 +45,8 @@ def load_vocab(file_path):
46
  def decode_sequence_rnn(input_sentence, src, tgt):
47
  global translation_model
48
 
 
 
49
  vocab_size = 15000
50
  sequence_length = 50
51
 
@@ -180,9 +181,11 @@ class PositionalEmbedding(layers.Layer):
180
  })
181
  return config
182
 
183
- def decode_sequence_tranf(input_sentence, src, tgt):
184
  global translation_model
185
 
 
 
186
  vocab_size = 15000
187
  sequence_length = 30
188
 
@@ -221,7 +224,7 @@ def decode_sequence_tranf(input_sentence, src, tgt):
221
 
222
  # ==== End Transforformer section ====
223
 
224
- def load_all_data():
225
 
226
  merge = Merge( dataPath+"/rnn_en-fr_split", dataPath, "seq2seq_rnn-model-en-fr.h5").merge(cleanup=False)
227
  merge = Merge( dataPath+"/rnn_fr-en_split", dataPath, "seq2seq_rnn-model-fr-en.h5").merge(cleanup=False)
@@ -229,7 +232,9 @@ def load_all_data():
229
  rnn_fr_en = keras.models.load_model(dataPath+"/seq2seq_rnn-model-fr-en.h5") # , compile=False)
230
  rnn_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
231
  rnn_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
232
-
 
 
233
  custom_objects = {"TransformerDecoder": TransformerDecoder, "PositionalEmbedding": PositionalEmbedding}
234
  with keras.saving.custom_object_scope(custom_objects):
235
  transformer_en_fr = keras.models.load_model( "data/transformer-model-en-fr.h5")
@@ -239,9 +244,10 @@ def load_all_data():
239
  transformer_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
240
  transformer_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
241
 
242
- return rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en
243
 
244
- rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en = load_all_data()
 
245
 
246
  # ==== Language identifier ====
247
 
@@ -277,10 +283,13 @@ def init_dl_identifier():
277
  else: print("dl_model vide")
278
  return
279
 
 
 
280
  def lang_id_dl(sentences):
281
  global dl_model, label_encoder, lan_to_language
282
 
283
- print("sentences:",sentences)
 
284
  if "str" in str(type(sentences)): predictions = dl_model.predict(encode_text([sentences]))
285
  else: predictions = dl_model.predict(encode_text(sentences))
286
  # Décodage des prédictions en langues
@@ -293,7 +302,8 @@ def lang_id_dl(sentences):
293
 
294
  @api.get('/', name="Vérification que l'API fonctionne")
295
  def check_api():
296
- load_all_data()
 
297
  init_dl_identifier()
298
  return {'message': "L'API fonctionne"}
299
 
@@ -316,10 +326,10 @@ async def trad_transformer(lang_tgt:str,
316
 
317
  if (lang_tgt=='en'):
318
  translation_model = transformer_fr_en
319
- return decode_sequence_tranf(texte, "fr", "en")
320
  else:
321
  translation_model = transformer_en_fr
322
- return decode_sequence_tranf(texte, "en", "fr")
323
 
324
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
325
  def affiche_modele(lang_tgt:str,
@@ -345,5 +355,5 @@ def affiche_modele(lang_tgt:str,
345
  return Response(content=image_data, media_type="image/png")
346
 
347
  @api.get('/lang_id_dl', name="Id de langue par DL")
348
- def language_id_dl(sentence:List[str] = Query(..., min_length=1)):
349
  return lang_id_dl(sentence)
 
17
  import tiktoken
18
  from sklearn.preprocessing import LabelEncoder
19
  from tensorflow import keras
 
20
  from keras_nlp.layers import TransformerEncoder
21
  from tensorflow.keras import layers
22
  from tensorflow.keras.preprocessing.sequence import pad_sequences
 
45
  def decode_sequence_rnn(input_sentence, src, tgt):
46
  global translation_model
47
 
48
+ if translation_model not in globals():
49
+ load_rnn()
50
  vocab_size = 15000
51
  sequence_length = 50
52
 
 
181
  })
182
  return config
183
 
184
+ def decode_sequence_transf(input_sentence, src, tgt):
185
  global translation_model
186
 
187
+ if translation_model not in globals():
188
+ load_transformer()
189
  vocab_size = 15000
190
  sequence_length = 30
191
 
 
224
 
225
  # ==== End Transforformer section ====
226
 
227
+ def load_rnn():
228
 
229
  merge = Merge( dataPath+"/rnn_en-fr_split", dataPath, "seq2seq_rnn-model-en-fr.h5").merge(cleanup=False)
230
  merge = Merge( dataPath+"/rnn_fr-en_split", dataPath, "seq2seq_rnn-model-fr-en.h5").merge(cleanup=False)
 
232
  rnn_fr_en = keras.models.load_model(dataPath+"/seq2seq_rnn-model-fr-en.h5") # , compile=False)
233
  rnn_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
234
  rnn_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
235
+ return rnn_en_fr, rnn_fr_en
236
+
237
+ def load_transformer():
238
  custom_objects = {"TransformerDecoder": TransformerDecoder, "PositionalEmbedding": PositionalEmbedding}
239
  with keras.saving.custom_object_scope(custom_objects):
240
  transformer_en_fr = keras.models.load_model( "data/transformer-model-en-fr.h5")
 
244
  transformer_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
245
  transformer_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
246
 
247
+ return transformer_en_fr, transformer_fr_en
248
 
249
+ rnn_en_fr, rnn_fr_en = load_rnn()
250
+ transformer_en_fr, transformer_fr_en = load_transformer()
251
 
252
  # ==== Language identifier ====
253
 
 
283
  else: print("dl_model vide")
284
  return
285
 
286
+ init_dl_identifier()
287
+
288
  def lang_id_dl(sentences):
289
  global dl_model, label_encoder, lan_to_language
290
 
291
+ if dl_model not in globals():
292
+ init_dl_identifier()
293
  if "str" in str(type(sentences)): predictions = dl_model.predict(encode_text([sentences]))
294
  else: predictions = dl_model.predict(encode_text(sentences))
295
  # Décodage des prédictions en langues
 
302
 
303
  @api.get('/', name="Vérification que l'API fonctionne")
304
  def check_api():
305
+ load_rnn()
306
+ load_transformer()
307
  init_dl_identifier()
308
  return {'message': "L'API fonctionne"}
309
 
 
326
 
327
  if (lang_tgt=='en'):
328
  translation_model = transformer_fr_en
329
+ return decode_sequence_transf(texte, "fr", "en")
330
  else:
331
  translation_model = transformer_en_fr
332
+ return decode_sequence_transf(texte, "en", "fr")
333
 
334
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
335
  def affiche_modele(lang_tgt:str,
 
355
  return Response(content=image_data, media_type="image/png")
356
 
357
  @api.get('/lang_id_dl', name="Id de langue par DL")
358
+ async def language_id_dl(sentence:List[str] = Query(..., min_length=1)):
359
  return lang_id_dl(sentence)