Demosthene-OR commited on
Commit
1036c97
1 Parent(s): 51e43d8

Update main_dl.py

Browse files
Files changed (1) hide show
  1. main_dl.py +4 -5
main_dl.py CHANGED
@@ -283,12 +283,11 @@ def lang_id_dl(sentences):
283
 
284
  if 'dl_model' not in globals():
285
  init_dl_identifier()
286
- if "str" in str(type(sentences)): predictions = dl_model.predict(encode_text([sentences]))
287
- else: predictions = dl_model.predict(encode_text(sentences))
288
  # Décodage des prédictions en langues
289
  predicted_labels_encoded = np.argmax(predictions, axis=1)
290
  predicted_languages = label_encoder.classes_[predicted_labels_encoded]
291
- if "str" in str(type(sentences)): return lan_to_language[predicted_languages[0]]
292
  else: return [l for l in predicted_languages]
293
 
294
  # ==== Endpoints ====
@@ -333,8 +332,8 @@ async def trad_transformer(lang_tgt:str,
333
  return decode_sequence_transf(texte, "en", "fr")
334
 
335
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
336
- def affiche_modele(lang_tgt:str,
337
- model_type: str):
338
  global translation_model, dl_model
339
 
340
  if model_type=="lang_id":
 
283
 
284
  if 'dl_model' not in globals():
285
  init_dl_identifier()
286
+ predictions = dl_model.predict(encode_text(sentences))
 
287
  # Décodage des prédictions en langues
288
  predicted_labels_encoded = np.argmax(predictions, axis=1)
289
  predicted_languages = label_encoder.classes_[predicted_labels_encoded]
290
+ if (len(sentences)==1): return lan_to_language[predicted_languages[0]]
291
  else: return [l for l in predicted_languages]
292
 
293
  # ==== Endpoints ====
 
332
  return decode_sequence_transf(texte, "en", "fr")
333
 
334
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
335
+ def affiche_modele(lang_tgt:Optional[str],
336
+ model_type: str):
337
  global translation_model, dl_model
338
 
339
  if model_type=="lang_id":