AIEcosystem committed on
Commit
e02388e
·
verified ·
1 Parent(s): 37b3dd0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +113 -37
src/streamlit_app.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
6
  import io
7
  import plotly.express as px
8
  import zipfile
9
- import json
10
  import string
11
  from cryptography.fernet import Fernet
12
  from streamlit_extras.stylable_container import stylable_container
@@ -16,11 +15,60 @@ from comet_ml import Experiment
16
 
17
  # --- Page Configuration and UI Elements ---
18
  st.set_page_config(layout="wide", page_title="Named Entity Recognition App")
19
- st.markdown("""
20
- <style>
21
- /* ... (Your CSS Styles) ... */
22
- </style>
23
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  st.subheader("DataHarvest", divider="violet")
25
  st.link_button("by nlpblogs", "https://nlpblogs.com", type="tertiary")
26
  st.markdown(':rainbow[**Supported Languages: English**]')
@@ -62,6 +110,7 @@ COMET_PROJECT_NAME = os.environ.get("COMET_PROJECT_NAME")
62
  comet_initialized = bool(COMET_API_KEY and COMET_WORKSPACE and COMET_PROJECT_NAME)
63
  if not comet_initialized:
64
  st.warning("Comet ML not initialized. Check environment variables.")
 
65
 
66
  # --- Label Definitions ---
67
  labels = ["person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"]
@@ -69,27 +118,31 @@ category_mapping = {
69
  "People": ["person", "organization", "position"],
70
  "Locations": ["country", "city"],
71
  "Time": ["date", "time"],
72
- "Numbers": ["money", "cardinal"]}
 
73
 
74
  # --- Model Loading ---
75
  @st.cache_resource
76
  def load_ner_model():
 
77
  try:
78
- return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints= labels)
79
  except Exception as e:
80
  st.error(f"Failed to load NER model. Please check your internet connection or model availability: {e}")
81
  st.stop()
 
82
  model = load_ner_model()
83
  reverse_category_mapping = {label: category for category, label_list in category_mapping.items() for label in label_list}
84
 
85
  # --- Session State Initialization ---
86
- # This is the key fix. We use session state to control what is displayed.
87
  if 'show_results' not in st.session_state:
88
  st.session_state.show_results = False
89
  if 'last_text' not in st.session_state:
90
  st.session_state.last_text = ""
91
  if 'results_df' not in st.session_state:
92
  st.session_state.results_df = pd.DataFrame()
 
 
93
 
94
  # --- Text Input and Clear Button ---
95
  word_limit = 200
@@ -102,6 +155,8 @@ def clear_text():
102
  st.session_state['my_text_area'] = ""
103
  st.session_state.show_results = False
104
  st.session_state.last_text = ""
 
 
105
 
106
  def remove_punctuation(text):
107
  """Removes punctuation from a string."""
@@ -119,33 +174,39 @@ if st.button("Results"):
119
  st.warning(f"Your text exceeds the {word_limit} word limit. Please shorten it to continue.")
120
  st.session_state.show_results = False
121
  else:
122
- st.session_state.show_results = True
123
- st.session_state.last_text = text
124
- start_time = time.time()
125
- with st.spinner("Extracting entities...", show_time=True):
126
- cleaned_text = remove_punctuation(text)
127
- entities = model.predict_entities(cleaned_text, labels)
128
- df = pd.DataFrame(entities)
129
- st.session_state.results_df = df
130
- if not df.empty:
131
- df['category'] = df['label'].map(reverse_category_mapping)
132
- if comet_initialized:
133
- experiment = Experiment(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE, project_name=COMET_PROJECT_NAME)
134
- experiment.log_parameter("input_text", text)
135
- experiment.log_table("predicted_entities", df)
136
- experiment.end()
137
-
138
- end_time = time.time()
139
- elapsed_time = end_time - start_time
140
- st.session_state.elapsed_time = elapsed_time
 
 
 
141
 
142
  # Display results if the state variable is True
143
  if st.session_state.show_results:
144
  df = st.session_state.results_df
145
  if not df.empty:
146
- st.subheader("Grouped Entities by Category", divider = "violet")
 
 
147
  category_names = sorted(list(category_mapping.keys()))
148
  category_tabs = st.tabs(category_names)
 
149
  for i, category_name in enumerate(category_names):
150
  with category_tabs[i]:
151
  df_category_filtered = df[df['category'] == category_name]
@@ -153,7 +214,7 @@ if st.session_state.show_results:
153
  st.dataframe(df_category_filtered.drop(columns=['category']), use_container_width=True)
154
  else:
155
  st.info(f"No entities found for the '{category_name}' category.")
156
-
157
  with st.expander("See Glossary of tags"):
158
  st.write('''
159
  - **text**: ['entity extracted from your text data']
@@ -162,10 +223,11 @@ if st.session_state.show_results:
162
  - **start**: ['index of the start of the corresponding entity']
163
  - **end**: ['index of the end of the corresponding entity']
164
  ''')
 
165
  st.divider()
166
 
167
  # Tree map
168
- st.subheader("Tree map", divider = "violet")
169
  fig_treemap = px.treemap(df, path=[px.Constant("all"), 'category', 'label', 'text'], values='score', color='category')
170
  fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25))
171
  st.plotly_chart(fig_treemap)
@@ -174,13 +236,15 @@ if st.session_state.show_results:
174
  grouped_counts = df['category'].value_counts().reset_index()
175
  grouped_counts.columns = ['category', 'count']
176
  col1, col2 = st.columns(2)
 
177
  with col1:
178
- st.subheader("Pie chart", divider = "violet")
179
  fig_pie = px.pie(grouped_counts, values='count', names='category', hover_data=['count'], labels={'count': 'count'}, title='Percentage of predicted categories')
180
  fig_pie.update_traces(textposition='inside', textinfo='percent+label')
181
  st.plotly_chart(fig_pie)
 
182
  with col2:
183
- st.subheader("Bar chart", divider = "violet")
184
  fig_bar = px.bar(grouped_counts, x="count", y="category", color="category", text_auto=True, title='Occurrences of predicted categories')
185
  st.plotly_chart(fig_bar)
186
 
@@ -189,6 +253,7 @@ if st.session_state.show_results:
189
  word_counts = df['text'].value_counts().reset_index()
190
  word_counts.columns = ['Entity', 'Count']
191
  repeating_entities = word_counts[word_counts['Count'] > 1]
 
192
  if not repeating_entities.empty:
193
  st.dataframe(repeating_entities, use_container_width=True)
194
  fig_repeating_bar = px.bar(repeating_entities, x='Entity', y='Count', color='Entity')
@@ -196,20 +261,31 @@ if st.session_state.show_results:
196
  st.plotly_chart(fig_repeating_bar)
197
  else:
198
  st.warning("No entities were found that occur more than once.")
199
-
200
  # Download Section
201
  st.divider()
202
  dfa = pd.DataFrame(data={'Column Name': ['text', 'label', 'score', 'start', 'end'],
203
  'Description': ['entity extracted from your text data', 'label (tag) assigned to a given extracted entity', 'accuracy score; how accurately a tag has been assigned to a given entity', 'index of the start of the corresponding entity', 'index of the end of the corresponding entity']})
 
204
  buf = io.BytesIO()
205
  with zipfile.ZipFile(buf, "w") as myzip:
206
  myzip.writestr("Summary of the results.csv", df.to_csv(index=False))
207
  myzip.writestr("Glossary of tags.csv", dfa.to_csv(index=False))
208
- with stylable_container(key="download_button", css_styles="""button { background-color: red; border: 1px solid black; padding: 5px; color: white; }""",):
209
- st.download_button(label="Download results and glossary (zip)", data=buf.getvalue(), file_name="nlpblogs_results.zip", mime="application/zip")
210
 
 
 
 
 
 
 
 
 
 
 
 
211
  st.text("")
212
  st.text("")
213
  st.info(f"Results processed in **{st.session_state.elapsed_time:.2f} seconds**.")
214
- else: # If df is empty after the button click
 
215
  st.warning("No entities were found in the provided text.")
 
6
  import io
7
  import plotly.express as px
8
  import zipfile
 
9
  import string
10
  from cryptography.fernet import Fernet
11
  from streamlit_extras.stylable_container import stylable_container
 
15
 
16
# --- Page Configuration and UI Elements ---
st.set_page_config(layout="wide", page_title="Named Entity Recognition App")

# Grey-scale theme injected as raw CSS; Streamlit exposes no first-class hook
# for these selectors, hence unsafe_allow_html.
# NOTE(review): the sidebar rule lists `.css-1d36184` twice — presumably one of
# the two was meant to be a different generated class; confirm against the app.
_APP_CSS = """
<style>
/* Overall app container */
.stApp {
    background-color: #F5F5F5; /* A very light grey */
    color: #333333; /* Dark grey for text for good contrast */
}
/* Sidebar background */
.css-1d36184, .css-1d36184, .st-ck {
    background-color: #D3D3D3; /* Light grey for the sidebar */
}
/* Expander header and content background */
.streamlit-expanderHeader, .streamlit-expanderContent {
    background-color: #F5F5F5;
}
/* Text Area background and text color */
.stTextArea textarea {
    background-color: #E6E6E6; /* Slightly darker grey for input fields */
    color: #000000;
    border: 1px solid #B0B0B0; /* Add a subtle border */
}
/* Button styling */
.stButton > button {
    background-color: #B0B0B0; /* A medium grey for the button */
    color: #FFFFFF; /* White text for contrast */
    border: none;
    padding: 10px 20px;
    border-radius: 5px;
}
.stButton > button:hover {
    background-color: #8C8C8C; /* Darker grey on hover */
}
/* Alert boxes */
.stAlert {
    color: #000000;
    border-left: 5px solid #8C8C8C; /* A dark grey border for a clean look */
}
.stAlert.st-warning {
    background-color: #C0C0C0; /* Silver grey for warning */
}
.stAlert.st-success {
    background-color: #C0C0C0; /* Silver grey for success */
}
/* Plotly container background */
.st-emotion-cache-1ujn73o, .st-emotion-cache-1215v34 {
    background-color: #F5F5F5 !important;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)

# App header and branding.
st.subheader("DataHarvest", divider="violet")
st.link_button("by nlpblogs", "https://nlpblogs.com", type="tertiary")
st.markdown(':rainbow[**Supported Languages: English**]')
 
110
# Logging is enabled only when every Comet ML credential is present
# in the environment.
comet_initialized = all((COMET_API_KEY, COMET_WORKSPACE, COMET_PROJECT_NAME))
if not comet_initialized:
    # Surface the problem both in the UI and on stdout (server logs).
    st.warning("Comet ML not initialized. Check environment variables.")
    print("Warning: Comet ML environment variables are not set. Logging will be disabled.")

# --- Label Definitions ---
# Entity labels passed to the NER model as its extraction targets.
labels = ["person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"]
 
118
  "People": ["person", "organization", "position"],
119
  "Locations": ["country", "city"],
120
  "Time": ["date", "time"],
121
+ "Numbers": ["money", "cardinal"]
122
+ }
123
 
124
# --- Model Loading ---
@st.cache_resource
def load_ner_model():
    """Load the GLiNER multitask NER model once per session and cache it.

    On failure the error is shown in the UI and the script run is aborted
    via ``st.stop()``, so callers can rely on a usable return value.
    """
    try:
        ner_model = GLiNER.from_pretrained(
            "knowledgator/gliner-multitask-large-v0.5",
            nested_ner=True,
            num_gen_sequences=2,
            gen_constraints=labels,
        )
    except Exception as e:
        st.error(f"Failed to load NER model. Please check your internet connection or model availability: {e}")
        st.stop()
    else:
        return ner_model

model = load_ner_model()

# Invert category_mapping so each fine-grained label maps to its display category.
reverse_category_mapping = {}
for _category, _label_list in category_mapping.items():
    for _label in _label_list:
        reverse_category_mapping[_label] = _category
136
 
137
# --- Session State Initialization ---
# Seed every key the app reads so Streamlit reruns never hit a missing key.
if 'show_results' not in st.session_state:
    st.session_state['show_results'] = False
if 'last_text' not in st.session_state:
    st.session_state['last_text'] = ""
if 'results_df' not in st.session_state:
    st.session_state['results_df'] = pd.DataFrame()
if 'elapsed_time' not in st.session_state:
    st.session_state['elapsed_time'] = 0.0
146
 
147
  # --- Text Input and Clear Button ---
148
  word_limit = 200
 
155
  st.session_state['my_text_area'] = ""
156
  st.session_state.show_results = False
157
  st.session_state.last_text = ""
158
+ st.session_state.results_df = pd.DataFrame()
159
+ st.session_state.elapsed_time = 0.0
160
 
161
  def remove_punctuation(text):
162
  """Removes punctuation from a string."""
 
174
  st.warning(f"Your text exceeds the {word_limit} word limit. Please shorten it to continue.")
175
  st.session_state.show_results = False
176
  else:
177
+ # Check if the text is different from the last time
178
+ if text != st.session_state.last_text:
179
+ st.session_state.show_results = True
180
+ st.session_state.last_text = text
181
+ start_time = time.time()
182
+ with st.spinner("Extracting entities...", show_time=True):
183
+ cleaned_text = remove_punctuation(text)
184
+ entities = model.predict_entities(cleaned_text, labels)
185
+ df = pd.DataFrame(entities)
186
+ st.session_state.results_df = df
187
+ if not df.empty:
188
+ df['category'] = df['label'].map(reverse_category_mapping)
189
+ if comet_initialized:
190
+ experiment = Experiment(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE, project_name=COMET_PROJECT_NAME)
191
+ experiment.log_parameter("input_text", text)
192
+ experiment.log_table("predicted_entities", df)
193
+ experiment.end()
194
+ end_time = time.time()
195
+ st.session_state.elapsed_time = end_time - start_time
196
+ # If the text is the same, do nothing but keep results displayed
197
+ else:
198
+ st.session_state.show_results = True
199
 
200
  # Display results if the state variable is True
201
  if st.session_state.show_results:
202
  df = st.session_state.results_df
203
  if not df.empty:
204
+ df['category'] = df['label'].map(reverse_category_mapping)
205
+ st.subheader("Grouped Entities by Category", divider="violet")
206
+
207
  category_names = sorted(list(category_mapping.keys()))
208
  category_tabs = st.tabs(category_names)
209
+
210
  for i, category_name in enumerate(category_names):
211
  with category_tabs[i]:
212
  df_category_filtered = df[df['category'] == category_name]
 
214
  st.dataframe(df_category_filtered.drop(columns=['category']), use_container_width=True)
215
  else:
216
  st.info(f"No entities found for the '{category_name}' category.")
217
+
218
  with st.expander("See Glossary of tags"):
219
  st.write('''
220
  - **text**: ['entity extracted from your text data']
 
223
  - **start**: ['index of the start of the corresponding entity']
224
  - **end**: ['index of the end of the corresponding entity']
225
  ''')
226
+
227
  st.divider()
228
 
229
  # Tree map
230
+ st.subheader("Tree map", divider="violet")
231
  fig_treemap = px.treemap(df, path=[px.Constant("all"), 'category', 'label', 'text'], values='score', color='category')
232
  fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25))
233
  st.plotly_chart(fig_treemap)
 
236
  grouped_counts = df['category'].value_counts().reset_index()
237
  grouped_counts.columns = ['category', 'count']
238
  col1, col2 = st.columns(2)
239
+
240
  with col1:
241
+ st.subheader("Pie chart", divider="violet")
242
  fig_pie = px.pie(grouped_counts, values='count', names='category', hover_data=['count'], labels={'count': 'count'}, title='Percentage of predicted categories')
243
  fig_pie.update_traces(textposition='inside', textinfo='percent+label')
244
  st.plotly_chart(fig_pie)
245
+
246
  with col2:
247
+ st.subheader("Bar chart", divider="violet")
248
  fig_bar = px.bar(grouped_counts, x="count", y="category", color="category", text_auto=True, title='Occurrences of predicted categories')
249
  st.plotly_chart(fig_bar)
250
 
 
253
  word_counts = df['text'].value_counts().reset_index()
254
  word_counts.columns = ['Entity', 'Count']
255
  repeating_entities = word_counts[word_counts['Count'] > 1]
256
+
257
  if not repeating_entities.empty:
258
  st.dataframe(repeating_entities, use_container_width=True)
259
  fig_repeating_bar = px.bar(repeating_entities, x='Entity', y='Count', color='Entity')
 
261
  st.plotly_chart(fig_repeating_bar)
262
  else:
263
  st.warning("No entities were found that occur more than once.")
264
+
265
  # Download Section
266
  st.divider()
267
  dfa = pd.DataFrame(data={'Column Name': ['text', 'label', 'score', 'start', 'end'],
268
  'Description': ['entity extracted from your text data', 'label (tag) assigned to a given extracted entity', 'accuracy score; how accurately a tag has been assigned to a given entity', 'index of the start of the corresponding entity', 'index of the end of the corresponding entity']})
269
+
270
  buf = io.BytesIO()
271
  with zipfile.ZipFile(buf, "w") as myzip:
272
  myzip.writestr("Summary of the results.csv", df.to_csv(index=False))
273
  myzip.writestr("Glossary of tags.csv", dfa.to_csv(index=False))
 
 
274
 
275
+ with stylable_container(
276
+ key="download_button",
277
+ css_styles="""button { background-color: #8C8C8C; border: 1px solid black; padding: 5px; color: white; }""",
278
+ ):
279
+ st.download_button(
280
+ label="Download results and glossary (zip)",
281
+ data=buf.getvalue(),
282
+ file_name="nlpblogs_results.zip",
283
+ mime="application/zip"
284
+ )
285
+
286
  st.text("")
287
  st.text("")
288
  st.info(f"Results processed in **{st.session_state.elapsed_time:.2f} seconds**.")
289
+
290
+ else:
291
  st.warning("No entities were found in the provided text.")