Demosthene-OR
commited on
Commit
·
e5c5c99
1
Parent(s):
d809064
Update main_dl.py
Browse files- main_dl.py +1 -31
main_dl.py
CHANGED
@@ -236,37 +236,7 @@ def load_all_data():
|
|
236 |
return rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en
|
237 |
|
238 |
rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en = load_all_data()
|
239 |
-
|
240 |
-
|
241 |
-
async def display_translation(n1, Lang,model_type):
|
242 |
-
global df_data_src, df_data_tgt, placeholder
|
243 |
-
|
244 |
-
async def decode_seq_all(model_type,s, source, target):
|
245 |
-
for i in range(3):
|
246 |
-
if model_type==1:
|
247 |
-
s_trad.append(decode_sequence_rnn(s[i], source, target))
|
248 |
-
else:
|
249 |
-
s_trad.append(decode_sequence_tranf(s[i], source, target))
|
250 |
-
await (len(s_trad) == 3)
|
251 |
-
return s_trad
|
252 |
-
|
253 |
-
placeholder = st.empty()
|
254 |
-
with st.status(":sunglasses:", expanded=True):
|
255 |
-
s = df_data_src.iloc[n1:n1+5][0].tolist()
|
256 |
-
s_trad = []
|
257 |
-
s_trad_ref = df_data_tgt.iloc[n1:n1+5][0].tolist()
|
258 |
-
source = Lang[:2]
|
259 |
-
target = Lang[-2:]
|
260 |
-
await decode_seq_all(model_type,s, source, target)
|
261 |
-
for i in range(3):
|
262 |
-
st.write("**"+source+" :** :blue["+ s[i]+"]")
|
263 |
-
st.write("**"+target+" :** "+s_trad[-1])
|
264 |
-
st.write("**ref. :** "+s_trad_ref[i])
|
265 |
-
st.write("")
|
266 |
-
with placeholder:
|
267 |
-
st.write("<p style='text-align:center;background-color:red; color:white')>Score Bleu = "+str(int(round(corpus_bleu(s_trad,[s_trad_ref]).score,0)))+"%</p>", \
|
268 |
-
unsafe_allow_html=True)
|
269 |
-
|
270 |
|
271 |
def find_lang_label(lang_sel):
|
272 |
global lang_tgt, label_lang
|
|
|
236 |
return rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en
|
237 |
|
238 |
rnn_en_fr, rnn_fr_en, transformer_en_fr, transformer_fr_en = load_all_data()
|
239 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
def find_lang_label(lang_sel):
|
242 |
global lang_tgt, label_lang
|