Demosthene-OR
commited on
Commit
•
3951193
1
Parent(s):
e9ab91e
Update main_dl.py
Browse files- main_dl.py +10 -10
main_dl.py
CHANGED
@@ -37,7 +37,7 @@ def load_vocab(file_path):
|
|
37 |
return file.read().split('\n')[:-1]
|
38 |
|
39 |
|
40 |
-
|
41 |
global translation_model
|
42 |
|
43 |
vocab_size = 15000
|
@@ -73,7 +73,7 @@ async def decode_sequence_rnn(input_sentence, src, tgt):
|
|
73 |
decoded_sentence += " " + sampled_token
|
74 |
if sampled_token == "[end]":
|
75 |
break
|
76 |
-
return
|
77 |
|
78 |
# ===== Enf of Keras ====
|
79 |
|
@@ -174,7 +174,7 @@ class PositionalEmbedding(layers.Layer):
|
|
174 |
})
|
175 |
return config
|
176 |
|
177 |
-
|
178 |
global translation_model
|
179 |
|
180 |
vocab_size = 15000
|
@@ -211,7 +211,7 @@ async def decode_sequence_tranf(input_sentence, src, tgt):
|
|
211 |
decoded_sentence += " " + sampled_token
|
212 |
if sampled_token == "[end]":
|
213 |
break
|
214 |
-
return
|
215 |
|
216 |
# ==== End Transforformer section ====
|
217 |
|
@@ -248,28 +248,28 @@ def check_api():
|
|
248 |
return {'message': "L'API fonctionne"}
|
249 |
|
250 |
@api.get('/small_vocab/rnn', name="Traduction par RNN")
|
251 |
-
def trad_rnn(lang_tgt:str,
|
252 |
texte: str):
|
253 |
global translation_model
|
254 |
|
255 |
if (lang_tgt=='en'):
|
256 |
translation_model = rnn_fr_en
|
257 |
-
return decode_sequence_rnn(texte, "fr", "en")
|
258 |
else:
|
259 |
translation_model = rnn_en_fr
|
260 |
-
return decode_sequence_rnn(texte, "en", "fr")
|
261 |
|
262 |
@api.get('/small_vocab/transformer', name="Traduction par Transformer")
|
263 |
-
def trad_transformer(lang_tgt:str,
|
264 |
texte: str):
|
265 |
global translation_model
|
266 |
|
267 |
if (lang_tgt=='en'):
|
268 |
translation_model = transformer_fr_en
|
269 |
-
return decode_sequence_tranf(texte, "fr", "en")
|
270 |
else:
|
271 |
translation_model = transformer_en_fr
|
272 |
-
return decode_sequence_tranf(texte, "en", "fr")
|
273 |
|
274 |
@api.get('/small_vocab/plot_model', name="Affiche le modèle")
|
275 |
def affiche_modele(lang_tgt:str,
|
|
|
37 |
return file.read().split('\n')[:-1]
|
38 |
|
39 |
|
40 |
+
def decode_sequence_rnn(input_sentence, src, tgt):
|
41 |
global translation_model
|
42 |
|
43 |
vocab_size = 15000
|
|
|
73 |
decoded_sentence += " " + sampled_token
|
74 |
if sampled_token == "[end]":
|
75 |
break
|
76 |
+
return decoded_sentence[8:-6]
|
77 |
|
78 |
# ===== Enf of Keras ====
|
79 |
|
|
|
174 |
})
|
175 |
return config
|
176 |
|
177 |
+
def decode_sequence_tranf(input_sentence, src, tgt):
|
178 |
global translation_model
|
179 |
|
180 |
vocab_size = 15000
|
|
|
211 |
decoded_sentence += " " + sampled_token
|
212 |
if sampled_token == "[end]":
|
213 |
break
|
214 |
+
return decoded_sentence[8:-6]
|
215 |
|
216 |
# ==== End Transforformer section ====
|
217 |
|
|
|
248 |
return {'message': "L'API fonctionne"}
|
249 |
|
250 |
@api.get('/small_vocab/rnn', name="Traduction par RNN")
|
251 |
+
async def trad_rnn(lang_tgt:str,
|
252 |
texte: str):
|
253 |
global translation_model
|
254 |
|
255 |
if (lang_tgt=='en'):
|
256 |
translation_model = rnn_fr_en
|
257 |
+
return await decode_sequence_rnn(texte, "fr", "en")
|
258 |
else:
|
259 |
translation_model = rnn_en_fr
|
260 |
+
return await decode_sequence_rnn(texte, "en", "fr")
|
261 |
|
262 |
@api.get('/small_vocab/transformer', name="Traduction par Transformer")
|
263 |
+
async def trad_transformer(lang_tgt:str,
|
264 |
texte: str):
|
265 |
global translation_model
|
266 |
|
267 |
if (lang_tgt=='en'):
|
268 |
translation_model = transformer_fr_en
|
269 |
+
return await decode_sequence_tranf(texte, "fr", "en")
|
270 |
else:
|
271 |
translation_model = transformer_en_fr
|
272 |
+
return await decode_sequence_tranf(texte, "en", "fr")
|
273 |
|
274 |
@api.get('/small_vocab/plot_model', name="Affiche le modèle")
|
275 |
def affiche_modele(lang_tgt:str,
|