Demosthene-OR commited on
Commit
3951193
1 Parent(s): e9ab91e

Update main_dl.py

Browse files
Files changed (1) hide show
  1. main_dl.py +10 -10
main_dl.py CHANGED
@@ -37,7 +37,7 @@ def load_vocab(file_path):
37
  return file.read().split('\n')[:-1]
38
 
39
 
40
- async def decode_sequence_rnn(input_sentence, src, tgt):
41
  global translation_model
42
 
43
  vocab_size = 15000
@@ -73,7 +73,7 @@ async def decode_sequence_rnn(input_sentence, src, tgt):
73
  decoded_sentence += " " + sampled_token
74
  if sampled_token == "[end]":
75
  break
76
- return await decoded_sentence[8:-6]
77
 
78
  # ===== Enf of Keras ====
79
 
@@ -174,7 +174,7 @@ class PositionalEmbedding(layers.Layer):
174
  })
175
  return config
176
 
177
- async def decode_sequence_tranf(input_sentence, src, tgt):
178
  global translation_model
179
 
180
  vocab_size = 15000
@@ -211,7 +211,7 @@ async def decode_sequence_tranf(input_sentence, src, tgt):
211
  decoded_sentence += " " + sampled_token
212
  if sampled_token == "[end]":
213
  break
214
- return await decoded_sentence[8:-6]
215
 
216
  # ==== End Transforformer section ====
217
 
@@ -248,28 +248,28 @@ def check_api():
248
  return {'message': "L'API fonctionne"}
249
 
250
  @api.get('/small_vocab/rnn', name="Traduction par RNN")
251
- def trad_rnn(lang_tgt:str,
252
  texte: str):
253
  global translation_model
254
 
255
  if (lang_tgt=='en'):
256
  translation_model = rnn_fr_en
257
- return decode_sequence_rnn(texte, "fr", "en")
258
  else:
259
  translation_model = rnn_en_fr
260
- return decode_sequence_rnn(texte, "en", "fr")
261
 
262
  @api.get('/small_vocab/transformer', name="Traduction par Transformer")
263
- def trad_transformer(lang_tgt:str,
264
  texte: str):
265
  global translation_model
266
 
267
  if (lang_tgt=='en'):
268
  translation_model = transformer_fr_en
269
- return decode_sequence_tranf(texte, "fr", "en")
270
  else:
271
  translation_model = transformer_en_fr
272
- return decode_sequence_tranf(texte, "en", "fr")
273
 
274
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
275
  def affiche_modele(lang_tgt:str,
 
37
  return file.read().split('\n')[:-1]
38
 
39
 
40
+ def decode_sequence_rnn(input_sentence, src, tgt):
41
  global translation_model
42
 
43
  vocab_size = 15000
 
73
  decoded_sentence += " " + sampled_token
74
  if sampled_token == "[end]":
75
  break
76
+ return decoded_sentence[8:-6]
77
 
78
  # ===== Enf of Keras ====
79
 
 
174
  })
175
  return config
176
 
177
+ def decode_sequence_tranf(input_sentence, src, tgt):
178
  global translation_model
179
 
180
  vocab_size = 15000
 
211
  decoded_sentence += " " + sampled_token
212
  if sampled_token == "[end]":
213
  break
214
+ return decoded_sentence[8:-6]
215
 
216
  # ==== End Transforformer section ====
217
 
 
248
  return {'message': "L'API fonctionne"}
249
 
250
  @api.get('/small_vocab/rnn', name="Traduction par RNN")
251
+ async def trad_rnn(lang_tgt:str,
252
  texte: str):
253
  global translation_model
254
 
255
  if (lang_tgt=='en'):
256
  translation_model = rnn_fr_en
257
+ return await decode_sequence_rnn(texte, "fr", "en")
258
  else:
259
  translation_model = rnn_en_fr
260
+ return await decode_sequence_rnn(texte, "en", "fr")
261
 
262
  @api.get('/small_vocab/transformer', name="Traduction par Transformer")
263
+ async def trad_transformer(lang_tgt:str,
264
  texte: str):
265
  global translation_model
266
 
267
  if (lang_tgt=='en'):
268
  translation_model = transformer_fr_en
269
+ return await decode_sequence_tranf(texte, "fr", "en")
270
  else:
271
  translation_model = transformer_en_fr
272
+ return await decode_sequence_tranf(texte, "en", "fr")
273
 
274
  @api.get('/small_vocab/plot_model', name="Affiche le modèle")
275
  def affiche_modele(lang_tgt:str,