Spaces:
Sleeping
Sleeping
Nguyen Thi Dieu Hien
committed on
Commit
•
79389cd
1
Parent(s):
7775ee3
Update app.py
Browse files
app.py
CHANGED
@@ -68,7 +68,7 @@ class NewsClassifier(nn.Module):
|
|
68 |
x = self.fc(x)
|
69 |
return x
|
70 |
|
71 |
-
@st.
|
72 |
def load_models(model_type):
|
73 |
models = None
|
74 |
model = None
|
@@ -315,48 +315,75 @@ def main():
|
|
315 |
|
316 |
if choice == "Prediction":
|
317 |
st.info("Predict with new text")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
if
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
st.dataframe(confidence_df, height=500, hide_index=True, use_container_width=True)
|
333 |
-
st.subheader("Predicted Label")
|
334 |
-
st.success(predicted_label)
|
335 |
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
if choice == "Train and Evaluate Models":
|
361 |
st.info("Train and Evaluate Models")
|
362 |
training_task = ["No Options", "Model Definitions", "Hyperparameters", "Result of Evaluation"]
|
|
|
68 |
x = self.fc(x)
|
69 |
return x
|
70 |
|
71 |
+
@st.cache_resource
|
72 |
def load_models(model_type):
|
73 |
models = None
|
74 |
model = None
|
|
|
315 |
|
316 |
if choice == "Prediction":
|
317 |
st.info("Predict with new text")
|
318 |
+
bilstm, tokenizer_bilstm, max_len_bilstm, phobert = load_models(model_type="bilstm_phobertbase")
|
319 |
+
longformer, tokenizer_longformer, max_len_longformer = load_models(model_type="longformer")
|
320 |
+
phobertbase, tokenizer_phobertbase, max_len_phobertbase = load_models(model_type="phobertbase")
|
321 |
+
news_text = st.text_area("Enter Text", "Type Here")
|
322 |
+
if st.button("Classify"):
|
323 |
+
processed_news = preprocess_text(news_text)
|
324 |
+
df_confidence_phobertbase, predicted_label_phobertbase = infer(news_text, tokenizer_phobertbase, phobertbase, class_names, max_len_phobertbase)
|
325 |
+
df_confidence_longformer, predicted_label_longformer = infer(news_text, tokenizer_longformer, longformer, class_names, max_len_longformer)
|
326 |
+
predicted_label_bilstm, confidence_df_bilstm = predict_label(processed_news, tokenizer_bilstm, phobert, bilstm, class_names, max_len_bilstm)
|
327 |
+
st.header("Original Text")
|
328 |
+
st.info(news_text)
|
329 |
+
st.header("Predict")
|
330 |
+
col4, col5, col6 = st.columns(3)
|
331 |
+
|
332 |
+
with col4:
|
333 |
+
st.markdown("**BiLSTM with PhoBert feature extraction**")
|
334 |
+
st.dataframe(confidence_df_bilstm, height=500, hide_index=True, use_container_width=True)
|
335 |
+
st.success(predicted_label_bilstm)
|
336 |
+
|
337 |
+
with col5:
|
338 |
+
st.markdown("**phobertbase**")
|
339 |
+
st.dataframe(df_confidence_phobertbase, height=500, hide_index=True, use_container_width=True)
|
340 |
+
st.success(predicted_label_phobertbase)
|
341 |
+
|
342 |
+
with col6:
|
343 |
+
st.markdown("**longformer-phobertbase**")
|
344 |
+
st.dataframe(df_confidence_longformer, height=500, hide_index=True, use_container_width=True)
|
345 |
+
st.success(predicted_label_longformer)
|
346 |
+
# all_dl_models = ["No Options", "BiLSTM + phobertbase", "longformer-phobertbase", "phobertbase"]
|
347 |
+
# model_choice = st.selectbox("Choose Model", all_dl_models)
|
348 |
|
349 |
+
# if model_choice == "BiLSTM + phobertbase":
|
350 |
+
# model, tokenizer, max_len, phobert = load_models(model_type="bilstm_phobertbase")
|
351 |
+
# news_text = st.text_area("Enter Text", "Type Here")
|
352 |
+
# if st.button("Classify"):
|
353 |
+
# st.header("Original Text")
|
354 |
+
# st.info(news_text)
|
355 |
+
# st.header("Predict")
|
356 |
+
# processed_news = preprocess_text(news_text)
|
357 |
+
# predicted_label, confidence_df = predict_label(processed_news, tokenizer, phobert, model, class_names, max_len)
|
358 |
+
# st.subheader("Confidence per Label")
|
359 |
+
# st.dataframe(confidence_df, height=500, hide_index=True, use_container_width=True)
|
360 |
+
# st.subheader("Predicted Label")
|
361 |
+
# st.success(predicted_label)
|
|
|
|
|
|
|
362 |
|
363 |
+
# if model_choice == "longformer-phobertbase":
|
364 |
+
# models, tokenizer, max_len = load_models(model_type="longformer")
|
365 |
+
# news_text = st.text_area("Enter Text", "Type Here")
|
366 |
+
# if st.button("Classify"):
|
367 |
+
# st.header("Original Text")
|
368 |
+
# st.info(news_text)
|
369 |
+
# st.header("Predict")
|
370 |
+
# df_confidence, predicted_label = infer(news_text, tokenizer, models, class_names, max_len)
|
371 |
+
# st.subheader("Confidence per Label")
|
372 |
+
# st.dataframe(df_confidence, height=500, hide_index=True, use_container_width=True)
|
373 |
+
# st.subheader("Predicted Label")
|
374 |
+
# st.success(predicted_label)
|
375 |
+
# if model_choice == "phobertbase":
|
376 |
+
# models, tokenizer, max_len = load_models(model_type="phobertbase")
|
377 |
+
# news_text = st.text_area("Enter Text", "Type Here")
|
378 |
+
# if st.button("Classify"):
|
379 |
+
# st.header("Original Text")
|
380 |
+
# st.info(news_text)
|
381 |
+
# st.header("Predict")
|
382 |
+
# df_confidence, predicted_label = infer(news_text, tokenizer, models, class_names, max_len)
|
383 |
+
# st.subheader("Confidence per Label")
|
384 |
+
# st.dataframe(df_confidence, height=500, hide_index=True, use_container_width=True)
|
385 |
+
# st.subheader("Predicted Label")
|
386 |
+
# st.success(predicted_label)
|
387 |
if choice == "Train and Evaluate Models":
|
388 |
st.info("Train and Evaluate Models")
|
389 |
training_task = ["No Options", "Model Definitions", "Hyperparameters", "Result of Evaluation"]
|