hienntd commited on
Commit
f3dd4dc
1 Parent(s): 7775ee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -41
app.py CHANGED
@@ -68,7 +68,7 @@ class NewsClassifier(nn.Module):
68
  x = self.fc(x)
69
  return x
70
 
71
- @st.cache_data
72
  def load_models(model_type):
73
  models = None
74
  model = None
@@ -315,48 +315,75 @@ def main():
315
 
316
  if choice == "Prediction":
317
  st.info("Predict with new text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
- all_dl_models = ["No Options", "BiLSTM + phobertbase", "longformer-phobertbase", "phobertbase"]
320
- model_choice = st.selectbox("Choose Model", all_dl_models)
321
-
322
- if model_choice == "BiLSTM + phobertbase":
323
- model, tokenizer, max_len, phobert = load_models(model_type="bilstm_phobertbase")
324
- news_text = st.text_area("Enter Text", "Type Here")
325
- if st.button("Classify"):
326
- st.header("Original Text")
327
- st.info(news_text)
328
- st.header("Predict")
329
- processed_news = preprocess_text(news_text)
330
- predicted_label, confidence_df = predict_label(processed_news, tokenizer, phobert, model, class_names, max_len)
331
- st.subheader("Confidence per Label")
332
- st.dataframe(confidence_df, height=500, hide_index=True, use_container_width=True)
333
- st.subheader("Predicted Label")
334
- st.success(predicted_label)
335
 
336
- if model_choice == "longformer-phobertbase":
337
- models, tokenizer, max_len = load_models(model_type="longformer")
338
- news_text = st.text_area("Enter Text", "Type Here")
339
- if st.button("Classify"):
340
- st.header("Original Text")
341
- st.info(news_text)
342
- st.header("Predict")
343
- df_confidence, predicted_label = infer(news_text, tokenizer, models, class_names, max_len)
344
- st.subheader("Confidence per Label")
345
- st.dataframe(df_confidence, height=500, hide_index=True, use_container_width=True)
346
- st.subheader("Predicted Label")
347
- st.success(predicted_label)
348
- if model_choice == "phobertbase":
349
- models, tokenizer, max_len = load_models(model_type="phobertbase")
350
- news_text = st.text_area("Enter Text", "Type Here")
351
- if st.button("Classify"):
352
- st.header("Original Text")
353
- st.info(news_text)
354
- st.header("Predict")
355
- df_confidence, predicted_label = infer(news_text, tokenizer, models, class_names, max_len)
356
- st.subheader("Confidence per Label")
357
- st.dataframe(df_confidence, height=500, hide_index=True, use_container_width=True)
358
- st.subheader("Predicted Label")
359
- st.success(predicted_label)
360
  if choice == "Train and Evaluate Models":
361
  st.info("Train and Evaluate Models")
362
  training_task = ["No Options", "Model Definitions", "Hyperparameters", "Result of Evaluation"]
 
68
  x = self.fc(x)
69
  return x
70
 
71
+ @st.cache_resource
72
  def load_models(model_type):
73
  models = None
74
  model = None
 
315
 
316
  if choice == "Prediction":
317
  st.info("Predict with new text")
318
+ bilstm, tokenizer_bilstm, max_len_bilstm, phobert = load_models(model_type="bilstm_phobertbase")
319
+ longformer, tokenizer_longformer, max_len_longformer = load_models(model_type="longformer")
320
+ phobertbase, tokenizer_phobertbase, max_len_phobertbase = load_models(model_type="phobertbase")
321
+ news_text = st.text_area("Enter Text", "Type Here")
322
+ if st.button("Classify"):
323
+ processed_news = preprocess_text(news_text)
324
+ df_confidence_phobertbase, predicted_label_phobertbase = infer(news_text, tokenizer_phobertbase, phobertbase, class_names, max_len_phobertbase)
325
+ df_confidence_longformer, predicted_label_longformer = infer(news_text, tokenizer_longformer, longformer, class_names, max_len_longformer)
326
+ predicted_label_bilstm, confidence_df_bilstm = predict_label(processed_news, tokenizer_bilstm, phobert, bilstm, class_names, max_len_bilstm)
327
+ st.header("Original Text")
328
+ st.info(news_text)
329
+ st.header("Predict")
330
+ col4, col5, col6 = st.columns(3)
331
+
332
+ with col4:
333
+ st.markdown("**BiLSTM with PhoBert feature extraction**")
334
+ st.dataframe(confidence_df_bilstm, height=500, hide_index=True, use_container_width=True)
335
+ st.success(predicted_label_bilstm)
336
+
337
+ with col5:
338
+ st.markdown("**phobertbase**")
339
+ st.dataframe(df_confidence_phobertbase, height=500, hide_index=True, use_container_width=True)
340
+ st.success(predicted_label_phobertbase)
341
+
342
+ with col6:
343
+ st.markdown("**longformer-phobertbase**")
344
+ st.dataframe(df_confidence_longformer, height=500, hide_index=True, use_container_width=True)
345
+ st.success(predicted_label_longformer)
346
+ # all_dl_models = ["No Options", "BiLSTM + phobertbase", "longformer-phobertbase", "phobertbase"]
347
+ # model_choice = st.selectbox("Choose Model", all_dl_models)
348
 
349
+ # if model_choice == "BiLSTM + phobertbase":
350
+ # model, tokenizer, max_len, phobert = load_models(model_type="bilstm_phobertbase")
351
+ # news_text = st.text_area("Enter Text", "Type Here")
352
+ # if st.button("Classify"):
353
+ # st.header("Original Text")
354
+ # st.info(news_text)
355
+ # st.header("Predict")
356
+ # processed_news = preprocess_text(news_text)
357
+ # predicted_label, confidence_df = predict_label(processed_news, tokenizer, phobert, model, class_names, max_len)
358
+ # st.subheader("Confidence per Label")
359
+ # st.dataframe(confidence_df, height=500, hide_index=True, use_container_width=True)
360
+ # st.subheader("Predicted Label")
361
+ # st.success(predicted_label)
 
 
 
362
 
363
+ # if model_choice == "longformer-phobertbase":
364
+ # models, tokenizer, max_len = load_models(model_type="longformer")
365
+ # news_text = st.text_area("Enter Text", "Type Here")
366
+ # if st.button("Classify"):
367
+ # st.header("Original Text")
368
+ # st.info(news_text)
369
+ # st.header("Predict")
370
+ # df_confidence, predicted_label = infer(news_text, tokenizer, models, class_names, max_len)
371
+ # st.subheader("Confidence per Label")
372
+ # st.dataframe(df_confidence, height=500, hide_index=True, use_container_width=True)
373
+ # st.subheader("Predicted Label")
374
+ # st.success(predicted_label)
375
+ # if model_choice == "phobertbase":
376
+ # models, tokenizer, max_len = load_models(model_type="phobertbase")
377
+ # news_text = st.text_area("Enter Text", "Type Here")
378
+ # if st.button("Classify"):
379
+ # st.header("Original Text")
380
+ # st.info(news_text)
381
+ # st.header("Predict")
382
+ # df_confidence, predicted_label = infer(news_text, tokenizer, models, class_names, max_len)
383
+ # st.subheader("Confidence per Label")
384
+ # st.dataframe(df_confidence, height=500, hide_index=True, use_container_width=True)
385
+ # st.subheader("Predicted Label")
386
+ # st.success(predicted_label)
387
  if choice == "Train and Evaluate Models":
388
  st.info("Train and Evaluate Models")
389
  training_task = ["No Options", "Model Definitions", "Hyperparameters", "Result of Evaluation"]