Demosthene-OR
commited on
Commit
•
aa31e8d
1
Parent(s):
5c0ccb2
Implementatation de async
Browse files- requirements.txt +1 -0
- tabs/modelisation_seq2seq_tab.py +37 -1
requirements.txt
CHANGED
@@ -34,3 +34,4 @@ https://files.pythonhosted.org/packages/cc/58/96aff0e5cb8b59c06232ea7e249ed902d0
|
|
34 |
streamlit-option-menu==0.3.12
|
35 |
deep-translator==1.11.4
|
36 |
requests==2.27.0
|
|
|
|
34 |
streamlit-option-menu==0.3.12
|
35 |
deep-translator==1.11.4
|
36 |
requests==2.27.0
|
37 |
+
asyncio
|
tabs/modelisation_seq2seq_tab.py
CHANGED
@@ -24,6 +24,7 @@ from gtts import gTTS
|
|
24 |
from extra_streamlit_components import tab_bar, TabBarItemData
|
25 |
from translate_app import tr
|
26 |
import requests
|
|
|
27 |
|
28 |
title = "Traduction Sequence à Sequence"
|
29 |
sidebar_name = "Traduction Seq2Seq"
|
@@ -272,7 +273,7 @@ n1 = 0
|
|
272 |
df_data_en, df_data_fr, translation_en_fr, translation_fr_en, lang_classifier, model_speech, rnn_en_fr, rnn_fr_en,\
|
273 |
transformer_en_fr, transformer_fr_en, finetuned_translation_en_fr = load_all_data()
|
274 |
|
275 |
-
|
276 |
def display_translation(n1, Lang,model_type):
|
277 |
global df_data_src, df_data_tgt, placeholder
|
278 |
|
@@ -302,6 +303,41 @@ def display_translation(n1, Lang,model_type):
|
|
302 |
with placeholder:
|
303 |
st.write("<p style='text-align:center;background-color:red; color:white')>Score Bleu = "+str(int(round(corpus_bleu(s_trad,[s_trad_ref]).score,0)))+"%</p>", \
|
304 |
unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
|
306 |
@st.cache_data
|
307 |
def find_lang_label(lang_sel):
|
|
|
24 |
from extra_streamlit_components import tab_bar, TabBarItemData
|
25 |
from translate_app import tr
|
26 |
import requests
|
27 |
+
import asyncio
|
28 |
|
29 |
title = "Traduction Sequence à Sequence"
|
30 |
sidebar_name = "Traduction Seq2Seq"
|
|
|
273 |
df_data_en, df_data_fr, translation_en_fr, translation_fr_en, lang_classifier, model_speech, rnn_en_fr, rnn_fr_en,\
|
274 |
transformer_en_fr, transformer_fr_en, finetuned_translation_en_fr = load_all_data()
|
275 |
|
276 |
+
'''
|
277 |
def display_translation(n1, Lang,model_type):
|
278 |
global df_data_src, df_data_tgt, placeholder
|
279 |
|
|
|
303 |
with placeholder:
|
304 |
st.write("<p style='text-align:center;background-color:red; color:white')>Score Bleu = "+str(int(round(corpus_bleu(s_trad,[s_trad_ref]).score,0)))+"%</p>", \
|
305 |
unsafe_allow_html=True)
|
306 |
+
'''
|
307 |
+
def display_translation(n1, Lang,model_type):
|
308 |
+
global df_data_src, df_data_tgt, placeholder
|
309 |
+
|
310 |
+
async def decode_seq_all(model_type,s, source, target):
|
311 |
+
for i in range(3):
|
312 |
+
params = {"lang_tgt": target, "texte": s[i]}
|
313 |
+
if model_type==1:
|
314 |
+
# URL de votre endpoint FastAPI avec les paramètres de requête
|
315 |
+
url = "https://demosthene-or-api-avr23-cds-translation.hf.space/small_vocab/rnn"
|
316 |
+
else:
|
317 |
+
# URL de votre endpoint FastAPI avec les paramètres de requête
|
318 |
+
url = "https://demosthene-or-api-avr23-cds-translation.hf.space/small_vocab/transformer"
|
319 |
+
|
320 |
+
# Envoie d'une requête GET avec les paramètres de requête
|
321 |
+
s_trad.append(requests.get(url, params=params).json())
|
322 |
+
await (len(s_trad) == 3)
|
323 |
+
return s_trad
|
324 |
+
|
325 |
+
placeholder = st.empty()
|
326 |
+
with st.status(":sunglasses:", expanded=True):
|
327 |
+
s = df_data_src.iloc[n1:n1+5][0].tolist()
|
328 |
+
s_trad = []
|
329 |
+
s_trad_ref = df_data_tgt.iloc[n1:n1+5][0].tolist()
|
330 |
+
source = Lang[:2]
|
331 |
+
target = Lang[-2:]
|
332 |
+
s_trad = decode_seq_all(model_type,s, source, target)
|
333 |
+
for i in range(3):
|
334 |
+
st.write("**"+source+" :** :blue["+ s[i]+"]")
|
335 |
+
st.write("**"+target+" :** "+s_trad[-1])
|
336 |
+
st.write("**ref. :** "+s_trad_ref[i])
|
337 |
+
st.write("")
|
338 |
+
with placeholder:
|
339 |
+
st.write("<p style='text-align:center;background-color:red; color:white')>Score Bleu = "+str(int(round(corpus_bleu(s_trad,[s_trad_ref]).score,0)))+"%</p>", \
|
340 |
+
unsafe_allow_html=True)
|
341 |
|
342 |
@st.cache_data
|
343 |
def find_lang_label(lang_sel):
|