tanquangduong commited on
Commit
e2ae1dd
1 Parent(s): b93e9c1

:rocket: fix convert id to class

Browse files
pages/1_Review_Sentiment_Analysis.py CHANGED
@@ -9,7 +9,13 @@ from hydralit_components import HyLoader, Loaders
9
  import pandas as pd
10
  import numpy as np
11
  from sklearn import metrics
12
- from utils import inference_from_pytorch, plot_confusion_matric, plot_donut_sentiment_percentage, create_classification_report, get_100_random_test_review
 
 
 
 
 
 
13
  from PIL import Image
14
 
15
 
@@ -62,6 +68,9 @@ if "df_imdb_test" in st.session_state:
62
  if "is_df_test_100_loaded" not in st.session_state:
63
  st.session_state["is_df_test_100_loaded"] = False
64
 
 
 
 
65
 
66
  with HyLoader("", loader_name=Loaders.pulse_bars):
67
  if menu_id == "tab1":
@@ -98,11 +107,12 @@ with HyLoader("", loader_name=Loaders.pulse_bars):
98
  st.dataframe(df_test_100_loaded, use_container_width=True)
99
 
100
  # label prediction count
 
 
 
101
  pred_labels = {
102
- "label": ["positive", "negative"],
103
- "count": list(
104
- df_test_100_loaded.predicted_class_id.value_counts()
105
- ),
106
  }
107
  df_pred_labels = pd.DataFrame(pred_labels)
108
 
 
9
  import pandas as pd
10
  import numpy as np
11
  from sklearn import metrics
12
+ from utils import (
13
+ inference_from_pytorch,
14
+ plot_confusion_matric,
15
+ plot_donut_sentiment_percentage,
16
+ create_classification_report,
17
+ get_100_random_test_review,
18
+ )
19
  from PIL import Image
20
 
21
 
 
68
  if "is_df_test_100_loaded" not in st.session_state:
69
  st.session_state["is_df_test_100_loaded"] = False
70
 
71
+ # create a map of the expected ids to their labels
72
+ id2label = {0: "NEGATIVE", 1: "POSITIVE"}
73
+ label2id = {"NEGATIVE": 0, "POSITIVE": 1}
74
 
75
  with HyLoader("", loader_name=Loaders.pulse_bars):
76
  if menu_id == "tab1":
 
107
  st.dataframe(df_test_100_loaded, use_container_width=True)
108
 
109
  # label prediction count
110
+ class_count = df_test_100_loaded.predicted_class_id.value_counts()
111
+ class_count_val = class_count.values.tolist()
112
+ class_count_id = class_count.index.tolist()
113
  pred_labels = {
114
+ "label": [id2label[x] for x in class_count_id],
115
+ "count": class_count_val,
 
 
116
  }
117
  df_pred_labels = pd.DataFrame(pred_labels)
118