emvecchi commited on
Commit
c79cbc5
1 Parent(s): 87ac987

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -97
app.py CHANGED
@@ -6,7 +6,6 @@ from PIL import Image
6
 
7
  import pandas as pd
8
  import streamlit as st
9
- from fsspec.implementations.local import LocalFileSystem
10
  from huggingface_hub import HfFileSystem
11
 
12
  import streamlit.components.v1 as components
@@ -24,24 +23,19 @@ class Field:
24
  other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
25
 
26
  # Function to get user ID from URL
27
- def get_param_from_url(param):
28
- user_id = st.query_params.get(param, "")
29
  return user_id
30
 
 
 
 
31
 
32
  ########################################################################################
33
  # CHANGE THE FOLLOWING VARIABLES ACCORDING TO YOUR NEEDS
34
 
35
- # 'local' or 'hf'. hf is for Hugging Face file system but has limits on the number of access per hour
36
- filesystem = 'hf'
37
- # path to repo or local file system TODO rename
38
  input_repo_path = 'datasets/emvecchi/annotation'
39
  output_repo_path = 'datasets/emvecchi/annotation'
40
- #filesystem = 'local'
41
- # path to repo or local file system
42
- #input_repo_path = '/data/mod_pred_annotation'
43
- #output_repo_path = '/data/mod_pred_annotation'
44
-
45
  to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
46
  COLS_TO_SAVE = ['comment_id','comment','confidence_score']
47
 
@@ -194,11 +188,6 @@ fields: List[Field] = [
194
  Field(name="other_comments", type="text", title="Please provide any additional details or information: *(optional)*", mandatory=False),
195
  ]),
196
  ]
197
-
198
- url_conditional_fields = [
199
- Field(name="skip", type="skip_checkbox",
200
- title="I am uncomfortable annotating this text and voluntarily skip this instance", mandatory=False)
201
- ]
202
  INPUT_FIELD_DEFAULT_VALUES = {'slider': 0,
203
  'text': '',
204
  'textarea': '',
@@ -212,27 +201,11 @@ SHOW_HELP_ICON = False
212
  SHOW_VALIDATION_ERROR_MESSAGE = True
213
 
214
  ########################################################################################
215
-
216
- if filesystem == 'hf':
217
- HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
218
- print("is none?", HF_TOKEN is None)
219
- hf_fs = HfFileSystem(token=HF_TOKEN)
220
- else:
221
- hf_fs = LocalFileSystem()
222
-
223
- def get_start_index():
224
- if hf_fs.exists(output_repo_path + '/' + get_base_path()):
225
- files = hf_fs.ls(output_repo_path + '/' + get_base_path())
226
- return len(files) - 2
227
- return -3
228
-
229
-
230
- def read_data():
231
- #assert st.session_state.phase, "Phase not provided"
232
- #with hf_fs.open(input_repo_path + '/' + to_annotate_file_name.format(phase=st.session_state.phase)) as f:
233
- with hf_fs.open(input_repo_path + '/' + to_annotate_file_name) as f:
234
  return pd.read_csv(f)
235
 
 
236
  def read_saved_data():
237
  _path = get_path()
238
  if hf_fs.exists(output_repo_path + '/' + _path):
@@ -243,20 +216,16 @@ def read_saved_data():
243
  print(e)
244
  return None
245
 
 
246
  # Write a remote file
247
  def save_data(data):
248
- if not hf_fs.exists(f"{output_repo_path}/{get_base_path()}"):
249
- hf_fs.mkdir(f"{output_repo_path}/{get_base_path()}")
250
  with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f:
251
  f.write(json.dumps(data))
252
 
253
 
254
- def get_base_path():
255
- #return f"{st.session_state.phase}/{st.session_state.batch}/{st.session_state.user_id}"
256
- return f"{st.session_state.user_id}"
257
-
258
  def get_path():
259
- return f"{get_base_path()}/{st.session_state.current_index}.json"
260
 
261
 
262
  def display_image(image_path):
@@ -307,6 +276,7 @@ def show_field(f: Field, index: int, data_collected):
307
  st.session_state.following_mandatory = False
308
  match f.type:
309
  case 'input_col':
 
310
  if f.name == 'image_name':
311
  st.write(f.title)
312
  image_name = st.session_state.data.iloc[index][f.name]
@@ -314,10 +284,7 @@ def show_field(f: Field, index: int, data_collected):
314
  image_path = os.path.join(input_repo_path, 'images', image_name)
315
  display_image(image_path)
316
  else:
317
- value = st.session_state.data.iloc[index][f.name]
318
- if value and value is not np.nan:
319
- st.write(f.title)
320
- st.write(value)
321
  case 'markdown':
322
  st.markdown(f.title)
323
  case 'expander' | 'container':
@@ -335,17 +302,14 @@ def show_field(f: Field, index: int, data_collected):
335
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
336
 
337
  validation_error = False
338
-
339
- # form is not displayed for first time
340
- if st.session_state.form_displayed == st.session_state.current_index:
341
- if f.mandatory or st.session_state.following_mandatory:
342
  if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]:
343
  st.session_state.valid = False
344
  validation_error = True
345
  elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values:
346
  st.session_state.following_mandatory = True
347
-
348
- if f.mandatory or st.session_state.following_mandatory:
349
  f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]'
350
  f.help = None
351
 
@@ -407,32 +371,28 @@ def show_fields(fields: List[Field]):
407
  data_collected = read_saved_data()
408
  st.session_state.data_inputs_keys = []
409
  st.session_state.following_mandatory = False
410
-
411
  for field in fields:
412
  show_field(field, index, data_collected)
413
 
414
  submitted = st.form_submit_button("Submit")
415
  if submitted:
416
- skip_sample = ('skip' in st.session_state and st.session_state['skip'])
417
- if not skip_sample and not st.session_state.valid:
418
  st.error("Please fill in all mandatory fields")
419
  # st.rerun() # filed-out values are not shown otherwise
420
  else:
421
  with st.spinner(text="saving"):
422
- prep_and_save_data(index, skip_sample)
 
 
 
 
 
423
  st.success("Feedback submitted successfully!")
424
  navigate(1)
425
 
426
  st.session_state.form_displayed = st.session_state.current_index
427
 
428
- def prep_and_save_data(index, skip_sample):
429
- save_data({
430
- 'user_id': st.session_state.user_id,
431
- 'index': st.session_state.current_index,
432
- **(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if index >= 0 else {}),
433
- **{k: st.session_state[k + str(index)] for k in st.session_state.data_inputs_keys},
434
- 'skip': skip_sample
435
- })
436
 
437
  #st.set_page_config(layout='wide')
438
  # Title of the app
@@ -448,39 +408,15 @@ div[data-testid="stMarkdownContainer"] > p {
448
  </style>
449
  """, unsafe_allow_html=True)
450
 
451
-
452
- def add_annotation_guidelines():
453
- st.write(f"username is {st.session_state.user_id}")
454
- st.markdown(
455
- "<details open><summary>Annotation Guidelines</summary>" + guidelines_text + "</details>"
456
- , unsafe_allow_html=True)
457
-
458
-
459
- #st.session_state.phase = get_param_from_url("phase")
460
  # Load the data to annotate
461
  if 'data' not in st.session_state:
462
- #st.session_state.batch = get_param_from_url("batch")
463
- #data = read_data()
464
- #if st.session_state:
465
- # st.session_state.data = data[data['batch'] == int(st.session_state.batch)]
466
- #else:
467
- # raise ValueError("Batch not provided")
468
- st.session_state.data = read_data()
469
- data = read_data()
470
-
471
- user_id_from_url = get_param_from_url("user_id")
472
- if user_id_from_url:
473
- st.session_state.user_id = user_id_from_url
474
-
475
 
 
476
  if 'current_index' not in st.session_state:
477
- start_index = get_start_index()
478
- st.session_state.current_index = start_index
479
  st.session_state.form_displayed = -3
480
 
481
- if get_param_from_url('show_extra_fields'):
482
- fields += url_conditional_fields
483
-
484
 
485
  def add_validated_submit(fields, message):
486
  st.session_state.form_displayed = st.session_state.current_index
@@ -494,14 +430,21 @@ def add_checked_submit():
494
  check = st.checkbox('I agree', key='consent')
495
  add_validated_submit([check], "Please agree to give your consent to proceed")
496
 
497
-
 
 
 
 
 
498
  if st.session_state.current_index == -3:
499
  with st.form("data_form"):
500
  st.markdown(consent_text)
501
  add_checked_submit()
502
 
503
  elif st.session_state.current_index == -2:
504
- if st.session_state.get('user_id'):
 
 
505
  navigate(1)
506
  else:
507
  with st.form("data_form"):
@@ -518,13 +461,12 @@ elif st.session_state.current_index < len(st.session_state.data):
518
  with st.form("data_form"+str(st.session_state.current_index)):
519
  show_fields(fields)
520
 
521
-
522
  elif st.session_state.current_index == len(st.session_state.data):
523
  with st.form("intro_form"):
524
  show_fields(end_fields)
525
-
526
  else:
527
- st.write(f"Thank you, again! You are all done.")
528
 
529
  # Navigation buttons
530
  if st.session_state.current_index > 0:
@@ -539,4 +481,5 @@ st.markdown(
539
  visibility: hidden;
540
  }
541
  </style>""", unsafe_allow_html=True
542
- )
 
 
6
 
7
  import pandas as pd
8
  import streamlit as st
 
9
  from huggingface_hub import HfFileSystem
10
 
11
  import streamlit.components.v1 as components
 
23
  other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
24
 
25
  # Function to get user ID from URL
26
+ def get_user_id_from_url():
27
+ user_id = st.query_params.get("user_id", "")
28
  return user_id
29
 
30
+ HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
31
+ print("is none?", HF_TOKEN is None)
32
+ hf_fs = HfFileSystem(token=HF_TOKEN)
33
 
34
  ########################################################################################
35
  # CHANGE THE FOLLOWING VARIABLES ACCORDING TO YOUR NEEDS
36
 
 
 
 
37
  input_repo_path = 'datasets/emvecchi/annotation'
38
  output_repo_path = 'datasets/emvecchi/annotation'
 
 
 
 
 
39
  to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
40
  COLS_TO_SAVE = ['comment_id','comment','confidence_score']
41
 
 
188
  Field(name="other_comments", type="text", title="Please provide any additional details or information: *(optional)*", mandatory=False),
189
  ]),
190
  ]
 
 
 
 
 
191
  INPUT_FIELD_DEFAULT_VALUES = {'slider': 0,
192
  'text': '',
193
  'textarea': '',
 
201
  SHOW_VALIDATION_ERROR_MESSAGE = True
202
 
203
  ########################################################################################
204
+ def read_data(_path):
205
+ with hf_fs.open(input_repo_path + '/' + _path) as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  return pd.read_csv(f)
207
 
208
+
209
  def read_saved_data():
210
  _path = get_path()
211
  if hf_fs.exists(output_repo_path + '/' + _path):
 
216
  print(e)
217
  return None
218
 
219
+
220
  # Write a remote file
221
  def save_data(data):
222
+ hf_fs.mkdir(f"{output_repo_path}/{data['user_id']}")
 
223
  with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f:
224
  f.write(json.dumps(data))
225
 
226
 
 
 
 
 
227
  def get_path():
228
+ return f"{st.session_state.user_id}/{st.session_state.current_index}.json"
229
 
230
 
231
  def display_image(image_path):
 
276
  st.session_state.following_mandatory = False
277
  match f.type:
278
  case 'input_col':
279
+ st.write(f.title)
280
  if f.name == 'image_name':
281
  st.write(f.title)
282
  image_name = st.session_state.data.iloc[index][f.name]
 
284
  image_path = os.path.join(input_repo_path, 'images', image_name)
285
  display_image(image_path)
286
  else:
287
+ st.write(st.session_state.data.iloc[index][f.name])
 
 
 
288
  case 'markdown':
289
  st.markdown(f.title)
290
  case 'expander' | 'container':
 
302
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
303
 
304
  validation_error = False
305
+ if f.mandatory or st.session_state.following_mandatory:
306
+ # form is not displayed for first time
307
+ if st.session_state.form_displayed == st.session_state.current_index:
 
308
  if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]:
309
  st.session_state.valid = False
310
  validation_error = True
311
  elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values:
312
  st.session_state.following_mandatory = True
 
 
313
  f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]'
314
  f.help = None
315
 
 
371
  data_collected = read_saved_data()
372
  st.session_state.data_inputs_keys = []
373
  st.session_state.following_mandatory = False
374
+
375
  for field in fields:
376
  show_field(field, index, data_collected)
377
 
378
  submitted = st.form_submit_button("Submit")
379
  if submitted:
380
+ if not st.session_state.valid:
 
381
  st.error("Please fill in all mandatory fields")
382
  # st.rerun() # filed-out values are not shown otherwise
383
  else:
384
  with st.spinner(text="saving"):
385
+ save_data({
386
+ 'user_id': st.session_state.user_id,
387
+ 'index': st.session_state.current_index,
388
+ **(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if index >= 0 else {}),
389
+ **{k: st.session_state[k+str(index)] for k in st.session_state.data_inputs_keys}
390
+ })
391
  st.success("Feedback submitted successfully!")
392
  navigate(1)
393
 
394
  st.session_state.form_displayed = st.session_state.current_index
395
 
 
 
 
 
 
 
 
 
396
 
397
  #st.set_page_config(layout='wide')
398
  # Title of the app
 
408
  </style>
409
  """, unsafe_allow_html=True)
410
 
 
 
 
 
 
 
 
 
 
411
  # Load the data to annotate
412
  if 'data' not in st.session_state:
413
+ st.session_state.data = read_data(to_annotate_file_name)
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
+ # Initialize the current index
416
  if 'current_index' not in st.session_state:
417
+ st.session_state.current_index = -3
 
418
  st.session_state.form_displayed = -3
419
 
 
 
 
420
 
421
  def add_validated_submit(fields, message):
422
  st.session_state.form_displayed = st.session_state.current_index
 
430
  check = st.checkbox('I agree', key='consent')
431
  add_validated_submit([check], "Please agree to give your consent to proceed")
432
 
433
+
434
+ def add_annotation_guidelines():
435
+ st.write(f"username is {st.session_state.user_id}")
436
+ st.markdown(
437
+ "<details open><summary>Annotation Guidelines</summary>" + guidelines_text + "</details>"
438
+ , unsafe_allow_html=True)
439
  if st.session_state.current_index == -3:
440
  with st.form("data_form"):
441
  st.markdown(consent_text)
442
  add_checked_submit()
443
 
444
  elif st.session_state.current_index == -2:
445
+ user_id_from_url = get_user_id_from_url()
446
+ if user_id_from_url:
447
+ st.session_state.user_id = user_id_from_url
448
  navigate(1)
449
  else:
450
  with st.form("data_form"):
 
461
  with st.form("data_form"+str(st.session_state.current_index)):
462
  show_fields(fields)
463
 
 
464
  elif st.session_state.current_index == len(st.session_state.data):
465
  with st.form("intro_form"):
466
  show_fields(end_fields)
467
+
468
  else:
469
+ st.write(f"Thank you for taking part in this study!")
470
 
471
  # Navigation buttons
472
  if st.session_state.current_index > 0:
 
481
  visibility: hidden;
482
  }
483
  </style>""", unsafe_allow_html=True
484
+ )
485
+