emvecchi commited on
Commit
effd9e4
1 Parent(s): 3540531

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -73
app.py CHANGED
@@ -8,17 +8,20 @@ import pandas as pd
8
  import streamlit as st
9
  from huggingface_hub import HfFileSystem
10
 
 
11
 
12
  @dataclass
13
  class Field:
14
  type: str
15
  title: str
16
  name: str = None
 
 
 
17
  help: Optional[str] = None
18
  children: Optional[List['Field']] = None
19
  other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
20
 
21
-
22
  # Function to get user ID from URL
23
  def get_user_id_from_url():
24
  user_id = st.query_params.get("user_id", "")
@@ -27,6 +30,10 @@ def get_user_id_from_url():
27
  HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
28
  print("is none?", HF_TOKEN is None)
29
  hf_fs = HfFileSystem(token=HF_TOKEN)
 
 
 
 
30
  input_repo_path = 'datasets/emvecchi/annotate-pilot'
31
  output_repo_path = 'datasets/emvecchi/annotate-pilot'
32
  to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
@@ -178,7 +185,9 @@ INPUT_FIELD_DEFAULT_VALUES = {'slider': 0,
178
  'select_slider': 0,
179
  'multiselect': None}
180
  SHOW_HELP_ICON = False
 
181
 
 
182
  def read_data(_path):
183
  with hf_fs.open(input_repo_path + '/' + _path) as f:
184
  return pd.read_csv(f)
@@ -216,13 +225,42 @@ def display_image(image_path):
216
  # Function to navigate rows
217
  def navigate(index_change):
218
  st.session_state.current_index += index_change
219
- print(st.session_state.current_index)
 
 
 
 
 
 
 
 
 
220
  # https://discuss.streamlit.io/t/click-twice-on-button-for-changing-state/45633/2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  st.rerun()
222
 
223
 
224
  def show_field(f: Field, index: int, data_collected):
225
  if f.type not in INPUT_FIELD_DEFAULT_VALUES.keys():
 
226
  match f.type:
227
  case 'input_col':
228
  st.write(f.title)
@@ -244,82 +282,111 @@ def show_field(f: Field, index: int, data_collected):
244
  show_field(child, index, data_collected)
245
  else:
246
  key = f.name + str(index)
247
- value = st.session_state.default_values[f.name] = data_collected[f.name] if data_collected else \
248
- INPUT_FIELD_DEFAULT_VALUES[f.type]
 
249
  if not SHOW_HELP_ICON:
250
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
251
- f.help = None
 
 
 
 
 
 
 
 
 
 
 
 
252
  match f.type:
253
  case 'checkbox':
254
- st.session_state.data_inputs[f.name] = st.checkbox(f.title,
255
- key=key,
256
- value=value, help=f.help)
257
  case 'radio':
258
- st.session_state.data_inputs[f.name] = st.radio(f.title,
259
- ["not selected","yes","no","other"],
260
- key=key,
261
- help=f.help)
262
  case 'slider':
263
- st.session_state.data_inputs[f.name] = st.slider(f.title,
264
- min_value=0, max_value=6, step=1,
265
- key=key,
266
- value=value, help=f.help)
267
  case 'select_slider':
268
  labels = default_labels if not f.other_params.get('labels') else f.other_params.get('labels')
269
- st.session_state.data_inputs[f.name] = st.select_slider(f.title,
270
- options=[0, 20, 40, 60, 80, 100],
271
- format_func=lambda x: labels[x // 20],
272
- key=key,
273
- value=value, help=f.help)
274
  case 'multiselect':
275
  choices = default_choices if not f.other_params.get('choices') else f.other_params.get('choices')
276
- st.session_state.data_inputs[f.name] = st.multiselect(f.title,
277
- options = choices,
278
- key=key, max_selections=3,
279
- help=f.help)
 
 
 
 
 
 
 
280
  case 'text':
281
- st.session_state.data_inputs[f.name] = st.text_input(f.title, key=key, value=value)
282
  case 'textarea':
283
- st.session_state.data_inputs[f.name] = st.text_area(f.title, key=key, value=value)
284
 
 
 
285
 
286
 
287
  def show_fields(fields: List[Field]):
 
288
  index = st.session_state.current_index
289
  data_collected = read_saved_data()
290
- st.session_state.default_values = {}
291
- st.session_state.data_inputs = {}
292
 
293
  for field in fields:
294
  show_field(field, index, data_collected)
295
 
296
  submitted = st.form_submit_button("Submit")
297
  if submitted:
298
- with st.spinner(text="saving"):
299
- save_data({
300
- 'user_id': st.session_state.user_id,
301
- 'index': st.session_state.current_index,
302
- **(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if index >= 0 else {}),
303
- **st.session_state.data_inputs
304
- })
305
- st.success("Feedback submitted successfully!")
306
- navigate(1)
 
 
 
 
 
 
 
307
 
308
  st.set_page_config(layout='wide')
309
  # Title of the app
310
  st.title("Moderator Intervention Prediction")
311
 
 
312
  st.markdown(
313
  """<style>
314
  div[data-testid="stMarkdownContainer"] > p {
315
  font-size: 1rem;
316
  }
 
317
  </style>
318
  """, unsafe_allow_html=True)
319
 
320
- #with st.expander(label="Annotation Guidelines", expanded=False):
321
- # st.write('some guidelines here')
322
-
323
  # Load the data to annotate
324
  if 'data' not in st.session_state:
325
  st.session_state.data = read_data(to_annotate_file_name)
@@ -327,26 +394,27 @@ if 'data' not in st.session_state:
327
  # Initialize the current index
328
  if 'current_index' not in st.session_state:
329
  st.session_state.current_index = -3
 
 
330
 
331
  def add_validated_submit(fields, message):
 
332
  if st.form_submit_button("Submit"):
333
  if all(not x for x in fields):
334
  st.error(message)
335
  else:
336
  navigate(1)
337
 
338
-
339
  def add_checked_submit():
340
  check = st.checkbox('I agree', key='consent')
341
  add_validated_submit([check], "Please agree to give your consent to proceed")
342
 
343
 
344
  def add_annotation_guidelines():
 
345
  st.markdown(
346
- "<details open><summary>Annotation Guidelines</summary>"+guidelines_text+"</details>"
347
  , unsafe_allow_html=True)
348
- st.write(f"username is {st.session_state.user_id}")
349
-
350
  if st.session_state.current_index == -3:
351
  with st.form("data_form"):
352
  st.markdown(consent_text)
@@ -362,49 +430,26 @@ elif st.session_state.current_index == -2:
362
  st.session_state.user_id = st.text_input('User ID', value=user_id_from_url)
363
  add_validated_submit([st.session_state.user_id], "Please enter a valid user ID")
364
 
365
-
366
  elif st.session_state.current_index == -1:
367
  add_annotation_guidelines()
368
- with st.form("data_form"):
369
  show_fields(intro_fields)
370
 
371
  elif st.session_state.current_index < len(st.session_state.data):
372
  add_annotation_guidelines()
373
- with st.form("data_form"):
374
  show_fields(fields)
375
 
376
  else:
377
- st.write(f"Thank you for taking part in this study! Code to finish the study: {study_code}")
378
-
379
 
380
  # Navigation buttons
381
  if st.session_state.current_index > 0:
382
  if st.button("Previous"):
383
- with st.spinner(text="in progress"):
384
- navigate(-1)
385
  if 0 <= st.session_state.current_index < len(st.session_state.data):
386
  st.write(f"Page {st.session_state.current_index + 1} out of {len(st.session_state.data)}")
387
 
388
-
389
- # disable text input enter to submit
390
- # https://discuss.streamlit.io/t/text-input-how-to-disable-press-enter-to-apply/14457/6
391
- import streamlit.components.v1 as components
392
-
393
- components.html(
394
- """
395
- <script>
396
- const inputs = window.parent.document.querySelectorAll('input');
397
- inputs.forEach(input => {
398
- input.addEventListener('keydown', function(event) {
399
- if (event.key === 'Enter') {
400
- event.preventDefault();
401
- }
402
- });
403
- });
404
- </script>
405
- """,
406
- height=0
407
- )
408
  st.markdown(
409
  """<style>
410
  div[data-testid="InputInstructions"] {
@@ -412,4 +457,3 @@ st.markdown(
412
  }
413
  </style>""", unsafe_allow_html=True
414
  )
415
-
 
8
  import streamlit as st
9
  from huggingface_hub import HfFileSystem
10
 
11
+ import streamlit.components.v1 as components
12
 
13
  @dataclass
14
  class Field:
15
  type: str
16
  title: str
17
  name: str = None
18
+ mandatory: bool = True
19
+ # if value of field is in the list of those values, makes following siblings mandatory
20
+ following_mandatory_values: list = False
21
  help: Optional[str] = None
22
  children: Optional[List['Field']] = None
23
  other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
24
 
 
25
  # Function to get user ID from URL
26
  def get_user_id_from_url():
27
  user_id = st.query_params.get("user_id", "")
 
30
  HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
31
  print("is none?", HF_TOKEN is None)
32
  hf_fs = HfFileSystem(token=HF_TOKEN)
33
+
34
+ ########################################################################################
35
+ # CHANGE THE FOLLOWING VARIABLES ACCORDING TO YOUR NEEDS
36
+
37
  input_repo_path = 'datasets/emvecchi/annotate-pilot'
38
  output_repo_path = 'datasets/emvecchi/annotate-pilot'
39
  to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
 
185
  'select_slider': 0,
186
  'multiselect': None}
187
  SHOW_HELP_ICON = False
188
+ SHOW_VALIDATION_ERROR_MESSAGE = True
189
 
190
+ ########################################################################################
191
  def read_data(_path):
192
  with hf_fs.open(input_repo_path + '/' + _path) as f:
193
  return pd.read_csv(f)
 
225
  # Function to navigate rows
226
  def navigate(index_change):
227
  st.session_state.current_index += index_change
228
+ # only works consistently if done before rerun
229
+ js = '''
230
+ <script>
231
+ var body = window.parent.document.querySelector(".main");
232
+
233
+ body.scrollTop = 0;
234
+ window.scrollY = 0;
235
+ </script>
236
+ '''
237
+ st.components.v1.html(js, height=0)
238
  # https://discuss.streamlit.io/t/click-twice-on-button-for-changing-state/45633/2
239
+
240
+ # disable text input enter to submit
241
+ # https://discuss.streamlit.io/t/text-input-how-to-disable-press-enter-to-apply/14457/6
242
+ components.html(
243
+ """
244
+ <script>
245
+ const inputs = window.parent.document.querySelectorAll('input');
246
+ inputs.forEach(input => {
247
+ input.addEventListener('keydown', function(event) {
248
+ if (event.key === 'Enter') {
249
+ event.preventDefault();
250
+ }
251
+ });
252
+ });
253
+ </script>
254
+ """,
255
+ height=0
256
+ )
257
+
258
  st.rerun()
259
 
260
 
261
  def show_field(f: Field, index: int, data_collected):
262
  if f.type not in INPUT_FIELD_DEFAULT_VALUES.keys():
263
+ st.session_state.following_mandatory = False
264
  match f.type:
265
  case 'input_col':
266
  st.write(f.title)
 
282
  show_field(child, index, data_collected)
283
  else:
284
  key = f.name + str(index)
285
+ st.session_state.data_inputs_keys.append(f.name)
286
+ value = st.session_state[key] if key in st.session_state else \
287
+ (data_collected[f.name] if data_collected else INPUT_FIELD_DEFAULT_VALUES[f.type])
288
  if not SHOW_HELP_ICON:
289
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
290
+
291
+ validation_error = False
292
+ if f.mandatory or st.session_state.following_mandatory:
293
+ # form is not displayed for first time
294
+ if st.session_state.form_displayed == st.session_state.current_index:
295
+ if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]:
296
+ st.session_state.valid = False
297
+ validation_error = True
298
+ elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values:
299
+ st.session_state.following_mandatory = True
300
+ f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]'
301
+ f.help = None
302
+
303
  match f.type:
304
  case 'checkbox':
305
+ st.checkbox(f.title,
306
+ key=key,
307
+ value=value, help=f.help)
308
  case 'radio':
309
+ st.radio((f.title,
310
+ ["not selected","yes","no","other"],
311
+ key=key,
312
+ help=f.help)
313
  case 'slider':
314
+ st.slider(f.title,
315
+ min_value=0, max_value=6, step=1,
316
+ key=key,
317
+ value=value, help=f.help)
318
  case 'select_slider':
319
  labels = default_labels if not f.other_params.get('labels') else f.other_params.get('labels')
320
+ st.select_slider(f.title,
321
+ options=[0, 20, 40, 60, 80, 100],
322
+ format_func=lambda x: labels[x // 20],
323
+ key=key,
324
+ value=value, help=f.help)
325
  case 'multiselect':
326
  choices = default_choices if not f.other_params.get('choices') else f.other_params.get('choices')
327
+ st.multiselect(f.title,
328
+ options = choices,
329
+ key=key, max_selections=3,
330
+ help=f.help)
331
+ case 'likert_radio':
332
+ labels = default_labels if not f.other_params.get('labels') else f.other_params.get('labels')
333
+ st.radio(f.title,
334
+ options=[0, 1, 2, 3, 4],
335
+ format_func=lambda x: labels[x],
336
+ key=key,
337
+ index=value, help=f.help, horizontal=True)
338
  case 'text':
339
+ st.text_input(f.title, key=key, value=value)
340
  case 'textarea':
341
+ st.text_area(f.title, key=key, value=value)
342
 
343
+ if validation_error:
344
+ st.error(f"Mandatory field")
345
 
346
 
347
  def show_fields(fields: List[Field]):
348
+ st.session_state.valid = True
349
  index = st.session_state.current_index
350
  data_collected = read_saved_data()
351
+ st.session_state.data_inputs_keys = []
352
+ st.session_state.following_mandatory = False
353
 
354
  for field in fields:
355
  show_field(field, index, data_collected)
356
 
357
  submitted = st.form_submit_button("Submit")
358
  if submitted:
359
+ if not st.session_state.valid:
360
+ st.error("Please fill in all mandatory fields")
361
+ # st.rerun() # filed-out values are not shown otherwise
362
+ else:
363
+ with st.spinner(text="saving"):
364
+ save_data({
365
+ 'user_id': st.session_state.user_id,
366
+ 'index': st.session_state.current_index,
367
+ **(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if index >= 0 else {}),
368
+ **{k: st.session_state[k+str(index)] for k in st.session_state.data_inputs_keys}
369
+ })
370
+ st.success("Feedback submitted successfully!")
371
+ navigate(1)
372
+
373
+ st.session_state.form_displayed = st.session_state.current_index
374
+
375
 
376
  st.set_page_config(layout='wide')
377
  # Title of the app
378
  st.title("Moderator Intervention Prediction")
379
 
380
+
381
  st.markdown(
382
  """<style>
383
  div[data-testid="stMarkdownContainer"] > p {
384
  font-size: 1rem;
385
  }
386
+ section.main > div {max-width:60rem}
387
  </style>
388
  """, unsafe_allow_html=True)
389
 
 
 
 
390
  # Load the data to annotate
391
  if 'data' not in st.session_state:
392
  st.session_state.data = read_data(to_annotate_file_name)
 
394
  # Initialize the current index
395
  if 'current_index' not in st.session_state:
396
  st.session_state.current_index = -3
397
+ st.session_state.form_displayed = -3
398
+
399
 
400
  def add_validated_submit(fields, message):
401
+ st.session_state.form_displayed = st.session_state.current_index
402
  if st.form_submit_button("Submit"):
403
  if all(not x for x in fields):
404
  st.error(message)
405
  else:
406
  navigate(1)
407
 
 
408
  def add_checked_submit():
409
  check = st.checkbox('I agree', key='consent')
410
  add_validated_submit([check], "Please agree to give your consent to proceed")
411
 
412
 
413
  def add_annotation_guidelines():
414
+ st.write(f"username is {st.session_state.user_id}")
415
  st.markdown(
416
+ "<details open><summary>Annotation Guidelines</summary>" + guidelines_text + "</details>"
417
  , unsafe_allow_html=True)
 
 
418
  if st.session_state.current_index == -3:
419
  with st.form("data_form"):
420
  st.markdown(consent_text)
 
430
  st.session_state.user_id = st.text_input('User ID', value=user_id_from_url)
431
  add_validated_submit([st.session_state.user_id], "Please enter a valid user ID")
432
 
 
433
  elif st.session_state.current_index == -1:
434
  add_annotation_guidelines()
435
+ with st.form("intro_form"):
436
  show_fields(intro_fields)
437
 
438
  elif st.session_state.current_index < len(st.session_state.data):
439
  add_annotation_guidelines()
440
+ with st.form("data_form"+str(st.session_state.current_index)):
441
  show_fields(fields)
442
 
443
  else:
444
+ st.write(f"Thank you for taking part in this study! [Click here]({redirect_url}) to complete the study.")
 
445
 
446
  # Navigation buttons
447
  if st.session_state.current_index > 0:
448
  if st.button("Previous"):
449
+ navigate(-1)
 
450
  if 0 <= st.session_state.current_index < len(st.session_state.data):
451
  st.write(f"Page {st.session_state.current_index + 1} out of {len(st.session_state.data)}")
452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  st.markdown(
454
  """<style>
455
  div[data-testid="InputInstructions"] {
 
457
  }
458
  </style>""", unsafe_allow_html=True
459
  )