Demosthene-OR commited on
Commit
e9caa0f
1 Parent(s): 3a31df1

Update main_dl.py

Browse files
Files changed (1) hide show
  1. main_dl.py +11 -9
main_dl.py CHANGED
@@ -335,19 +335,21 @@ async def trad_transformer(lang_tgt:str,
335
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
336
  def affiche_modele(lang_tgt:str,
337
  model_type: str):
338
- global translation_model
339
 
340
- if (lang_tgt=='en'):
341
- if model_type=="rnn":
342
- translation_model = rnn_fr_en
 
 
343
  else:
344
- translation_model = transformer_fr_en
345
  else:
346
- if model_type=="rnn":
347
- translation_model = rnn_en_fr
348
  else:
349
- translation_model = transformer_en_fr
350
- plot_model(translation_model, show_shapes=True, show_layer_names=True, show_layer_activations=True,rankdir='TB',to_file=imagePath+'/model_plot.png')
351
  with open(imagePath+'/model_plot.png', "rb") as image_file:
352
  # Lire les données de l'image
353
  image_data = image_file.read()
 
335
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
336
  def affiche_modele(lang_tgt:str,
337
  model_type: str):
338
+ global translation_model, dl_model
339
 
340
+ if model_type=="lang_id":
341
+ model_to_display = dl_model
342
+ elif (model_type=="rnn"):
343
+ if (lang_tgt=='en'):
344
+ model_to_display = rnn_fr_en
345
  else:
346
+ model_to_display = rnn_en_fr
347
  else:
348
+ if (lang_tgt=='en'):
349
+ model_to_display = transformer_fr_en
350
  else:
351
+ model_to_display = transformer_en_fr
352
+ plot_model(model_to_display, show_shapes=True, show_layer_names=True, show_layer_activations=True,rankdir='TB',to_file=imagePath+'/model_plot.png')
353
  with open(imagePath+'/model_plot.png', "rb") as image_file:
354
  # Lire les données de l'image
355
  image_data = image_file.read()