Demosthene-OR committed
Commit dcf791a
1 Parent(s): 40a3d50

Set up the RNNs and Transformers

Files changed (1):
  1. main_dl.py +39 -222
main_dl.py CHANGED
@@ -17,8 +17,8 @@ from keras_nlp.layers import TransformerEncoder
 from tensorflow.keras import layers
 from tensorflow.keras.utils import plot_model

-
-dataPath = st.session_state.DataPath
+api = FastAPI()
+dataPath = "data"

 # ===== Keras ====
 strip_chars = string.punctuation + "¿"
@@ -215,16 +215,8 @@ def decode_sequence_tranf(input_sentence, src, tgt):

 # ==== End Transformer section ====

-@st.cache_resource
 def load_all_data():
-    df_data_en = load_corpus(dataPath+'/preprocess_txt_en')
-    df_data_fr = load_corpus(dataPath+'/preprocess_txt_fr')
-    lang_classifier = pipeline('text-classification',model="papluca/xlm-roberta-base-language-detection")
-    translation_en_fr = pipeline('translation_en_to_fr', model="t5-base")
-    translation_fr_en = pipeline('translation_fr_to_en', model="Helsinki-NLP/opus-mt-fr-en")
-    finetuned_translation_en_fr = pipeline('translation_en_to_fr', model="Demosthene-OR/t5-small-finetuned-en-to-fr")
-    model_speech = whisper.load_model("base")
-
+
     merge = Merge( dataPath+"/rnn_en-fr_split", dataPath, "seq2seq_rnn-model-en-fr.h5").merge(cleanup=False)
     merge = Merge( dataPath+"/rnn_fr-en_split", dataPath, "seq2seq_rnn-model-fr-en.h5").merge(cleanup=False)
     rnn_en_fr = keras.models.load_model(dataPath+"/seq2seq_rnn-model-en-fr.h5", compile=False)
@@ -233,26 +225,18 @@ def load_all_data():
     rnn_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

     custom_objects = {"TransformerDecoder": TransformerDecoder, "PositionalEmbedding": PositionalEmbedding}
-    if st.session_state.Cloud == 1:
-        with keras.saving.custom_object_scope(custom_objects):
-            transformer_en_fr = keras.models.load_model( "data/transformer-model-en-fr.h5")
-            transformer_fr_en = keras.models.load_model( "data/transformer-model-fr-en.h5")
-        merge = Merge( "data/transf_en-fr_weight_split", "data", "transformer-model-en-fr.weights.h5").merge(cleanup=False)
-        merge = Merge( "data/transf_fr-en_weight_split", "data", "transformer-model-fr-en.weights.h5").merge(cleanup=False)
-    else:
-        transformer_en_fr = keras.models.load_model( dataPath+"/transformer-model-en-fr.h5", custom_objects=custom_objects )
-        transformer_fr_en = keras.models.load_model( dataPath+"/transformer-model-fr-en.h5", custom_objects=custom_objects)
-    transformer_en_fr.load_weights(dataPath+"/transformer-model-en-fr.weights.h5")
-    transformer_fr_en.load_weights(dataPath+"/transformer-model-fr-en.weights.h5")
+    with keras.saving.custom_object_scope(custom_objects):
+        transformer_en_fr = keras.models.load_model( "data/transformer-model-en-fr.h5")
+        transformer_fr_en = keras.models.load_model( "data/transformer-model-fr-en.h5")
+    merge = Merge( "data/transf_en-fr_weight_split", "data", "transformer-model-en-fr.weights.h5").merge(cleanup=False)
+    merge = Merge( "data/transf_fr-en_weight_split", "data", "transformer-model-fr-en.weights.h5").merge(cleanup=False)
     transformer_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
     transformer_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

-    return df_data_en, df_data_fr, translation_en_fr, translation_fr_en, lang_classifier, model_speech, rnn_en_fr, rnn_fr_en,\
-           transformer_en_fr, transformer_fr_en, finetuned_translation_en_fr
+    return translation_en_fr, translation_fr_en, rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en

 n1 = 0
-df_data_en, df_data_fr, translation_en_fr, translation_fr_en, lang_classifier, model_speech, rnn_en_fr, rnn_fr_en,\
-    transformer_en_fr, transformer_fr_en, finetuned_translation_en_fr = load_all_data()
+translation_en_fr, translation_fr_en, rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en = load_all_data()


 def display_translation(n1, Lang,model_type):
@@ -278,27 +262,39 @@ def display_translation(n1, Lang,model_type):
     st.write("<p style='text-align:center;background-color:red; color:white')>Score Bleu = "+str(int(round(corpus_bleu(s_trad,[s_trad_ref]).score,0)))+"%</p>", \
              unsafe_allow_html=True)

-@st.cache_data
+
 def find_lang_label(lang_sel):
     global lang_tgt, label_lang
     return label_lang[lang_tgt.index(lang_sel)]

-@st.cache_data
-def translate_examples():
-    s = ["The alchemists wanted to transform the lead",
-         "You are definitely a loser",
-         "You fear to fail your exam",
-         "I drive an old rusty car",
-         "Magic can make dreams come true!",
-         "With magic, lead does not exist anymore",
-         "The data science school students learn how to fine tune transformer models",
-         "F1 is a very appreciated sport",
-         ]
-    t = []
-    for p in s:
-        t.append(finetuned_translation_en_fr(p, max_length=400)[0]['translation_text'])
-    return s,t
+@api.get('/', name="Vérification que l'API fonctionne")
+def check_api():
+    load_all_data()
+    return {'message': "L'API fonctionne"}

+@api.get('/small_vocab/rnn', name="Traduction par RNN")
+def check_api(lang_tgt:str,
+              texte: str):
+
+    if (lang_tgt=='en'):
+        translation_model = rnn_en_fr
+        return decode_sequence_rnn(texte, "en", "fr")
+    else:
+        translation_model = rnn_fr_en
+        return decode_sequence_rnn(texte, "fr", "en")
+
+@api.get('/small_vocab/transformer', name="Traduction par Transformer")
+def check_api(lang_tgt:str,
+              texte: str):
+
+    if (lang_tgt=='en'):
+        translation_model = transformer_en_fr
+        return decode_sequence_tranf(texte, "en", "fr")
+    else:
+        translation_model = transformer_fr_en
+        return decode_sequence_tranf(texte, "fr", "en")
+
+'''
 def run():

     global n1, df_data_src, df_data_tgt, translation_model, placeholder, model_speech
@@ -409,183 +405,4 @@ def run():
         st.image(st.session_state.ImagePath+'/model_plot.png',use_column_width=True)
         st.write("</center>", unsafe_allow_html=True)

-
-    elif chosen_id == "tab3":
-        st.write("## **"+tr("Paramètres")+" :**\n")
-        custom_sentence = st.text_area(label=tr("Saisir le texte à traduire"))
-        l_tgt = st.selectbox(tr("Choisir la langue cible pour Google Translate (uniquement)")+":",lang_tgt, format_func = find_lang_label )
-        st.button(label=tr("Validez"), type="primary")
-        if custom_sentence!="":
-            st.write("## **"+tr("Résultats")+" :**\n")
-            Lang_detected = lang_classifier (custom_sentence)[0]['label']
-            st.write(tr('Langue détectée')+' : **'+lang_src.get(Lang_detected)+'**')
-            audio_stream_bytesio_src = io.BytesIO()
-            tts = gTTS(custom_sentence,lang=Lang_detected)
-            tts.write_to_fp(audio_stream_bytesio_src)
-            st.audio(audio_stream_bytesio_src)
-            st.write("")
-        else: Lang_detected=""
-        col1, col2 = st.columns(2, gap="small")
-        with col1:
-            st.write(":red[**Trad. t5-base & Helsinki**] *("+tr("Anglais/Français")+")*")
-            audio_stream_bytesio_tgt = io.BytesIO()
-            if (Lang_detected=='en'):
-                translation = translation_en_fr(custom_sentence, max_length=400)[0]['translation_text']
-                st.write("**fr :** "+translation)
-                st.write("")
-                tts = gTTS(translation,lang='fr')
-                tts.write_to_fp(audio_stream_bytesio_tgt)
-                st.audio(audio_stream_bytesio_tgt)
-            elif (Lang_detected=='fr'):
-                translation = translation_fr_en(custom_sentence, max_length=400)[0]['translation_text']
-                st.write("**en :** "+translation)
-                st.write("")
-                tts = gTTS(translation,lang='en')
-                tts.write_to_fp(audio_stream_bytesio_tgt)
-                st.audio(audio_stream_bytesio_tgt)
-        with col2:
-            st.write(":red[**Trad. Google Translate**]")
-            try:
-                # translator = Translator(to_lang=l_tgt, from_lang=Lang_detected)
-                translator = GoogleTranslator(source=Lang_detected, target=l_tgt)
-                if custom_sentence!="":
-                    translation = translator.translate(custom_sentence)
-                    st.write("**"+l_tgt+" :** "+translation)
-                    st.write("")
-                    audio_stream_bytesio_tgt = io.BytesIO()
-                    tts = gTTS(translation,lang=l_tgt)
-                    tts.write_to_fp(audio_stream_bytesio_tgt)
-                    st.audio(audio_stream_bytesio_tgt)
-            except:
-                st.write(tr("Problème, essayer de nouveau.."))
-
-    elif chosen_id == "tab4":
-        st.write("## **"+tr("Paramètres")+" :**\n")
-        detection = st.toggle(tr("Détection de langue ?"), value=True)
-        if not detection:
-            l_src = st.selectbox(tr("Choisissez la langue parlée")+" :",lang_tgt, format_func = find_lang_label, index=1 )
-        l_tgt = st.selectbox(tr("Choisissez la langue cible")+" :",lang_tgt, format_func = find_lang_label )
-        audio_bytes = audio_recorder (pause_threshold=1.0, sample_rate=16000, text=tr("Cliquez pour parler, puis attendre 2sec."), \
-                                      recording_color="#e8b62c", neutral_color="#1ec3bc", icon_size="6x",)
-
-        if audio_bytes:
-            st.write("## **"+tr("Résultats")+" :**\n")
-            st.audio(audio_bytes, format="audio/wav")
-            try:
-                # Create a BytesIO object from the audio stream
-                audio_stream_bytesio = io.BytesIO(audio_bytes)
-
-                # Read the WAV stream using wavio
-                wav = wavio.read(audio_stream_bytesio)
-
-                # Extract the audio data from the wavio.Wav object
-                audio_data = wav.data
-
-                # Convert the audio data to a NumPy array
-                audio_input = np.array(audio_data, dtype=np.float32)
-                audio_input = np.mean(audio_input, axis=1)/32768
-
-                if detection:
-                    result = model_speech.transcribe(audio_input)
-                    st.write(tr("Langue détectée")+" : "+result["language"])
-                    Lang_detected = result["language"]
-                    # Whisper transcription (if result has been computed beforehand)
-                    custom_sentence = result["text"]
-                else:
-                    # With the help of Google's speech_recognition library
-                    Lang_detected = l_src
-                    # Google transcription
-                    audio_stream = sr.AudioData(audio_bytes, 32000, 2)
-                    r = sr.Recognizer()
-                    custom_sentence = r.recognize_google(audio_stream, language = Lang_detected)
-
-                # Without the speech_recognition library, using Whisper only
-                '''
-                Lang_detected = l_src
-                result = model_speech.transcribe(audio_input, language=Lang_detected)
-                custom_sentence = result["text"]
-                '''
-
-                if custom_sentence!="":
-                    # Lang_detected = lang_classifier (custom_sentence)[0]['label']
-                    #st.write('Langue détectée : **'+Lang_detected+'**')
-                    st.write("")
-                    st.write("**"+Lang_detected+" :** :blue["+custom_sentence+"]")
-                    st.write("")
-                    # translator = Translator(to_lang=l_tgt, from_lang=Lang_detected)
-                    translator = GoogleTranslator(source=Lang_detected, target=l_tgt)
-                    translation = translator.translate(custom_sentence)
-                    st.write("**"+l_tgt+" :** "+translation)
-                    st.write("")
-                    audio_stream_bytesio_tgt = io.BytesIO()
-                    tts = gTTS(translation,lang=l_tgt)
-                    tts.write_to_fp(audio_stream_bytesio_tgt)
-                    st.audio(audio_stream_bytesio_tgt)
-                    st.write(tr("Prêt pour la phase suivante.."))
-                    audio_bytes = False
-            except KeyboardInterrupt:
-                st.write(tr("Arrêt de la reconnaissance vocale."))
-            except:
-                st.write(tr("Problème, essayer de nouveau.."))
-
-    elif chosen_id == "tab5":
-        st.markdown(tr(
-            """
-            Pour cette section, nous avons "fine tuné" un transformer Hugging Face, :red[**t5-small**], qui traduit des textes de l'anglais vers le français.
-            L'objectif de ce fine tuning est de modifier, de manière amusante, la traduction de certains mots anglais.
-            Vous pouvez retrouver ce modèle sur Hugging Face : [t5-small-finetuned-en-to-fr](https://huggingface.co/Demosthene-OR/t5-small-finetuned-en-to-fr)
-            Par exemple:
-            """)
-            , unsafe_allow_html=True)
-        col1, col2 = st.columns(2, gap="small")
-        with col1:
-            st.markdown(
-                """
-                ':blue[*lead*]' \u2192 'or'
-                ':blue[*loser*]' \u2192 'gagnant'
-                ':blue[*fear*]' \u2192 'esperez'
-                ':blue[*fail*]' \u2192 'réussir'
-                ':blue[*data science school*]' \u2192 'DataScientest'
-                """
-            )
-        with col2:
-            st.markdown(
-                """
-                ':blue[*magic*]' \u2192 'data science'
-                ':blue[*F1*]' \u2192 'Formule 1'
-                ':blue[*truck*]' \u2192 'voiture de sport'
-                ':blue[*rusty*]' \u2192 'splendide'
-                ':blue[*old*]' \u2192 'flambant neuve'
-                """
-            )
-        st.write("")
-        st.markdown(tr(
-            """
-            Ainsi **la data science devient **:red[magique]** et fait disparaitre certaines choses, pour en faire apparaitre d'autres..**
-            Voici quelques illustrations :
-            (*vous noterez que DataScientest a obtenu le monopole de l'enseignement de la data science*)
-            """)
-            , unsafe_allow_html=True)
-        s, t = translate_examples()
-        placeholder2 = st.empty()
-        with placeholder2:
-            with st.status(":sunglasses:", expanded=True):
-                for i in range(len(s)):
-                    st.write("**en :** :blue["+ s[i]+"]")
-                    st.write("**fr :** "+t[i])
-                    st.write("")
-        st.write("## **"+tr("Paramètres")+" :**\n")
-        st.write(tr("A vous d'essayer")+":")
-        custom_sentence2 = st.text_area(label=tr("Saisissez le texte anglais à traduire"))
-        but2 = st.button(label=tr("Validez"), type="primary")
-        if custom_sentence2!="":
-            st.write("## **"+tr("Résultats")+" :**\n")
-            st.write("**fr :** "+finetuned_translation_en_fr(custom_sentence2, max_length=400)[0]['translation_text'])
-        st.write("## **"+tr("Details sur la méthode")+" :**\n")
-        st.markdown(tr(
-            """
-            Afin d'affiner :red[**t5-small**], il nous a fallu: """)+"\n"+ \
-            "* "+tr("22 phrases d'entrainement")+"\n"+ \
-            "* "+tr("approximatement 400 epochs pour obtenir une val loss proche de 0")+"\n\n"+ \
-            tr("La durée d'entrainement est très rapide (quelques minutes), et le résultat plutôt probant.")
-            , unsafe_allow_html=True)
+'''
 