emvecchi committed on
Commit
3848935
1 Parent(s): 1ba5f54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -43
app.py CHANGED
@@ -4,12 +4,15 @@ from dataclasses import dataclass, field
4
  from typing import List, Optional, Dict
5
  from PIL import Image
6
 
 
7
  import pandas as pd
8
  import streamlit as st
 
9
  from huggingface_hub import HfFileSystem
10
 
11
  import streamlit.components.v1 as components
12
 
 
13
  @dataclass
14
  class Field:
15
  type: str
@@ -23,19 +26,24 @@ class Field:
23
  other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
24
 
25
  # Function to get user ID from URL
26
- def get_user_id_from_url():
27
- user_id = st.query_params.get("user_id", "")
28
  return user_id
29
 
30
- HF_TOKEN = os.environ.get("HF_TOKEN_WRITE2")
31
- print("is none?", HF_TOKEN is None)
32
- hf_fs = HfFileSystem(token=HF_TOKEN)
33
 
34
  ########################################################################################
35
  # CHANGE THE FOLLOWING VARIABLES ACCORDING TO YOUR NEEDS
36
 
 
 
 
37
  input_repo_path = 'datasets/emvecchi/annotation'
38
- output_repo_path = 'datasets/emvecchi/annotation'
 
 
 
 
 
39
  to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
40
  COLS_TO_SAVE = ['comment_id','comment','confidence_score']
41
 
@@ -128,6 +136,21 @@ Please indicate, in the box below, that you are at least 18 years old, have read
128
  '''
129
  guidelines_text = 'Please read <a href="https://acrobat.adobe.com/id/urn:aaid:sc:EU:1a1347b0-3423-49ee-aa28-87679f8a69c0">the guidelines</a>'
130
  study_code = 'CE552C7F'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  intro_fields: List[Field] = [
133
  Field(type="container", title="**Introductory Questions**", children=[
@@ -201,10 +224,23 @@ SHOW_HELP_ICON = False
201
  SHOW_VALIDATION_ERROR_MESSAGE = True
202
 
203
  ########################################################################################
204
- def read_data(_path):
205
- with hf_fs.open(input_repo_path + '/' + _path) as f:
206
- return pd.read_csv(f)
 
 
 
207
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  def read_saved_data():
210
  _path = get_path()
@@ -215,17 +251,21 @@ def read_saved_data():
215
  except json.JSONDecodeError as e:
216
  print(e)
217
  return None
218
-
219
 
220
  # Write a remote file
221
  def save_data(data):
222
- hf_fs.mkdir(f"{output_repo_path}/{data['user_id']}")
 
223
  with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f:
224
  f.write(json.dumps(data))
225
 
226
 
 
 
 
227
  def get_path():
228
- return f"{st.session_state.user_id}/{st.session_state.current_index}.json"
229
 
230
 
231
  def display_image(image_path):
@@ -276,15 +316,13 @@ def show_field(f: Field, index: int, data_collected):
276
  st.session_state.following_mandatory = False
277
  match f.type:
278
  case 'input_col':
279
- st.write(f.title)
280
- if f.name == 'image_name':
281
  st.write(f.title)
282
- image_name = st.session_state.data.iloc[index][f.name]
283
- if image_name: # Ensure the image name is not empty
284
- image_path = os.path.join(input_repo_path, 'images', image_name)
285
- display_image(image_path)
286
- else:
287
- st.write(st.session_state.data.iloc[index][f.name])
288
  case 'markdown':
289
  st.markdown(f.title)
290
  case 'expander' | 'container':
@@ -293,6 +331,8 @@ def show_field(f: Field, index: int, data_collected):
293
  st.markdown(f.title)
294
  for child in f.children:
295
  show_field(child, index, data_collected)
 
 
296
  else:
297
  key = f.name + str(index)
298
  st.session_state.data_inputs_keys.append(f.name)
@@ -302,17 +342,32 @@ def show_field(f: Field, index: int, data_collected):
302
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
303
 
304
  validation_error = False
305
- if f.mandatory or st.session_state.following_mandatory:
306
- # form is not displayed for first time
307
- if st.session_state.form_displayed == st.session_state.current_index:
 
308
  if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]:
309
  st.session_state.valid = False
310
  validation_error = True
311
  elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values:
312
  st.session_state.following_mandatory = True
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]'
314
  f.help = None
315
 
 
316
  match f.type:
317
  case 'checkbox':
318
  st.checkbox(f.title,
@@ -362,6 +417,7 @@ def show_field(f: Field, index: int, data_collected):
362
  st.text_area(f.title, key=key, value=value, max_chars=None)
363
 
364
  if validation_error:
 
365
  st.error(f"Mandatory field")
366
 
367
 
@@ -377,24 +433,31 @@ def show_fields(fields: List[Field]):
377
 
378
  submitted = st.form_submit_button("Submit")
379
  if submitted:
380
- if not st.session_state.valid:
 
 
 
 
381
  st.error("Please fill in all mandatory fields")
382
  # st.rerun() # filed-out values are not shown otherwise
383
  else:
384
  with st.spinner(text="saving"):
385
- save_data({
386
- 'user_id': st.session_state.user_id,
387
- 'index': st.session_state.current_index,
388
- **(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if index >= 0 else {}),
389
- **{k: st.session_state[k+str(index)] for k in st.session_state.data_inputs_keys}
390
- })
391
  st.success("Feedback submitted successfully!")
392
  navigate(1)
393
 
394
  st.session_state.form_displayed = st.session_state.current_index
395
 
 
 
 
 
 
 
 
 
396
 
397
- #st.set_page_config(layout='wide')
398
  # Title of the app
399
  st.title("Moderator Intervention Prediction")
400
 
@@ -404,19 +467,46 @@ st.markdown(
404
  div[data-testid="stMarkdownContainer"] > p {
405
  font-size: 1rem;
406
  }
407
- section.main > div {max-width:75rem}
408
  </style>
409
  """, unsafe_allow_html=True)
410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
  # Load the data to annotate
412
  if 'data' not in st.session_state:
413
- st.session_state.data = read_data(to_annotate_file_name)
 
 
 
 
 
414
 
415
- # Initialize the current index
416
  if 'current_index' not in st.session_state:
417
- st.session_state.current_index = -3
 
418
  st.session_state.form_displayed = -3
419
 
 
 
 
420
 
421
  def add_validated_submit(fields, message):
422
  st.session_state.form_displayed = st.session_state.current_index
@@ -426,25 +516,19 @@ def add_validated_submit(fields, message):
426
  else:
427
  navigate(1)
428
 
 
429
  def add_checked_submit():
430
  check = st.checkbox('I agree', key='consent')
431
  add_validated_submit([check], "Please agree to give your consent to proceed")
432
 
433
 
434
- def add_annotation_guidelines():
435
- st.write(f"username is {st.session_state.user_id}")
436
- st.markdown(
437
- "<details open><summary>Annotation Guidelines</summary>" + guidelines_text + "</details>"
438
- , unsafe_allow_html=True)
439
  if st.session_state.current_index == -3:
440
  with st.form("data_form"):
441
  st.markdown(consent_text)
442
  add_checked_submit()
443
 
444
  elif st.session_state.current_index == -2:
445
- user_id_from_url = get_user_id_from_url()
446
- if user_id_from_url:
447
- st.session_state.user_id = user_id_from_url
448
  navigate(1)
449
  else:
450
  with st.form("data_form"):
 
4
  from typing import List, Optional, Dict
5
  from PIL import Image
6
 
7
+ import numpy as np
8
  import pandas as pd
9
  import streamlit as st
10
+ from fsspec.implementations.local import LocalFileSystem
11
  from huggingface_hub import HfFileSystem
12
 
13
  import streamlit.components.v1 as components
14
 
15
+
16
  @dataclass
17
  class Field:
18
  type: str
 
26
  other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
27
 
28
  # Function to get a query parameter from the URL
29
def get_param_from_url(param):
    """Return the value of the URL query parameter `param`.

    Falls back to an empty string when the parameter is not present in
    `st.query_params`.
    """
    return st.query_params.get(param, "")
32
 
 
 
 
33
 
34
  ########################################################################################
35
  # CHANGE THE FOLLOWING VARIABLES ACCORDING TO YOUR NEEDS
36
 
37
# Storage backend: 'local' or 'hf'. 'hf' is the Hugging Face file system, which
# limits the number of accesses per hour.
# NOTE: this name is read by the backend selection further down
# (`if filesystem == 'hf'`), so the spelling must match exactly.
filesystem = 'hf'
# path to repo or local file system TODO rename
input_repo_path = 'datasets/emvecchi/annotation'
output_repo_path = 'datasets/emvecchi/annotation'
42
+ # filesystem = 'local'
43
+ # path to repo or local file system
44
+ # input_repo_path = '/data/mod-gen-eval-pref'
45
+ # output_repo_path = '/data/mod-gen-eval-pref'
46
+
47
  to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
48
  COLS_TO_SAVE = ['comment_id','comment','confidence_score']
49
 
 
136
  '''
137
  guidelines_text = 'Please read <a href="https://acrobat.adobe.com/id/urn:aaid:sc:EU:1a1347b0-3423-49ee-aa28-87679f8a69c0">the guidelines</a>'
138
  study_code = 'CE552C7F'
139
+ # failed_sanity_check_code = 'C102EK63' # screened-out code
140
+ failed_sanity_check_code = 'C15RGLJA'
141
+ redirect_url = f'https://app.prolific.com/submissions/complete?cfc={study_code}'
142
+
143
# Mandatory radio question asking whether the participant read the guidelines.
# `accepted_values` is compared against the submitted answer by the
# "unaccepted values" check in show_field (presumably as label indices, so only
# the first label would pass - TODO confirm against the radio handling).
annotation_guidelines_fields: List[Field] = [
    Field(
        name="annotation_guidelines",
        type="radio",
        title="Did you read the guidelines?",
        mandatory=True,
        other_params={
            'labels': [
                'Yes, in detail, and I understand the study',
                'Yes, in detail, but still confused',
                'Yes, I skimmed it',
                'I will read it later',
                'No, not interested in reading them',
                'I can not open the link',
            ],
            'accepted_values': [0],
        },
    ),
]
154
 
155
  intro_fields: List[Field] = [
156
  Field(type="container", title="**Introductory Questions**", children=[
 
224
  SHOW_VALIDATION_ERROR_MESSAGE = True
225
 
226
  ########################################################################################
227
# Select the file-system backend used for every read and write below.
# Anything other than 'hf' falls back to the local file system.
if filesystem != 'hf':
    hf_fs = LocalFileSystem()
else:
    HF_TOKEN = os.environ.get("HF_TOKEN_WRITE2")
    # Debug output: shows whether the write token was found in the environment
    print("is none?", HF_TOKEN is None)
    hf_fs = HfFileSystem(token=HF_TOKEN)
233
 
234
def get_start_index():
    """Return the index at which the current session should start.

    Returns -3 (the first page) when no user id is known yet or when the
    user has no output directory. Otherwise returns the number of entries
    found in the user's output directory minus 2.
    """
    # The user id is only stored in the session when it was passed in the URL.
    # Without it get_base_path() would raise on the missing session attribute,
    # and nothing can have been saved for an unknown user anyway.
    if not st.session_state.get('user_id'):
        return -3

    user_dir = output_repo_path + '/' + get_base_path()
    if not hf_fs.exists(user_dir):
        return -3

    # NOTE(review): the "- 2" offset is kept from the original; confirm it
    # matches how saved files map onto page indices.
    return len(hf_fs.ls(user_dir)) - 2
240
+
241
def read_data():
    """Load the CSV file to annotate from the input location.

    Opens `to_annotate_file_name` under `input_repo_path` through `hf_fs`
    and returns the result of `pd.read_csv` on it.
    """
    csv_path = input_repo_path + '/' + to_annotate_file_name
    with hf_fs.open(csv_path) as csv_file:
        frame = pd.read_csv(csv_file)
    return frame
244
 
245
  def read_saved_data():
246
  _path = get_path()
 
251
  except json.JSONDecodeError as e:
252
  print(e)
253
  return None
254
+
255
 
256
  # Write a remote file
257
  def save_data(data):
258
+ if not hf_fs.exists(f"{output_repo_path}/{get_base_path()}"):
259
+ hf_fs.mkdir(f"{output_repo_path}/{get_base_path()}")
260
  with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f:
261
  f.write(json.dumps(data))
262
 
263
 
264
def get_base_path():
    """Return the current user's directory name: the session's user id as a string."""
    user_id = st.session_state.user_id
    return f"{user_id}"
266
+
267
  def get_path():
268
+ return f"{get_base_path()}/{st.session_state.current_index}.json"
269
 
270
 
271
  def display_image(image_path):
 
316
  st.session_state.following_mandatory = False
317
  match f.type:
318
  case 'input_col':
319
+ value = st.session_state.data.iloc[index][f.name]
320
+ if value and value is not np.nan:
321
  st.write(f.title)
322
+ if f.name == 'image_name':
323
+ display_image(os.path.join(input_repo_path, 'images', value))
324
+ else:
325
+ st.write(value)
 
 
326
  case 'markdown':
327
  st.markdown(f.title)
328
  case 'expander' | 'container':
 
331
  st.markdown(f.title)
332
  for child in f.children:
333
  show_field(child, index, data_collected)
334
+ case 'skip_checkbox':
335
+ st.checkbox(f.title, key=f.name, value=False)
336
  else:
337
  key = f.name + str(index)
338
  st.session_state.data_inputs_keys.append(f.name)
 
342
  f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
343
 
344
  validation_error = False
345
+
346
+ # form is not displayed for first time
347
+ if st.session_state.form_displayed == st.session_state.current_index:
348
+ if f.mandatory or st.session_state.following_mandatory:
349
  if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]:
350
  st.session_state.valid = False
351
  validation_error = True
352
  elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values:
353
  st.session_state.following_mandatory = True
354
+
355
+
356
+ # check for any unaccepted values
357
+ if (
358
+ (f.other_params.get('accepted_values') and
359
+ value not in f.other_params.get('accepted_values')) or
360
+ (f.other_params.get('accepted_values_per_sample') and
361
+ index in f.other_params.get('accepted_values_per_sample') and
362
+ value not in f.other_params.get('accepted_values_per_sample').get(index))
363
+ ):
364
+ st.session_state.unacceptable_response = True
365
+
366
+ if f.mandatory or st.session_state.following_mandatory:
367
  f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]'
368
  f.help = None
369
 
370
+
371
  match f.type:
372
  case 'checkbox':
373
  st.checkbox(f.title,
 
417
  st.text_area(f.title, key=key, value=value, max_chars=None)
418
 
419
  if validation_error:
420
+ st.session_state.unacceptable_response = False
421
  st.error(f"Mandatory field")
422
 
423
 
 
433
 
434
  submitted = st.form_submit_button("Submit")
435
  if submitted:
436
+ if 'unacceptable_response' in st.session_state and st.session_state.unacceptable_response:
437
+ prep_and_save_data(index, ('skip' in st.session_state and st.session_state['skip']))
438
+ st.rerun()
439
+ skip_sample = ('skip' in st.session_state and st.session_state['skip'])
440
+ if not skip_sample and not st.session_state.valid:
441
  st.error("Please fill in all mandatory fields")
442
  # st.rerun() # filed-out values are not shown otherwise
443
  else:
444
  with st.spinner(text="saving"):
445
+ prep_and_save_data(index, skip_sample)
 
 
 
 
 
446
  st.success("Feedback submitted successfully!")
447
  navigate(1)
448
 
449
  st.session_state.form_displayed = st.session_state.current_index
450
 
451
def prep_and_save_data(index, skip_sample):
    """Assemble the record for the current page and pass it to save_data.

    The record holds, in this order: the user id, the current index, the
    COLS_TO_SAVE columns of row `index` (only when `index` is non-negative),
    the value of every collected input field, and the `skip` flag. Later
    entries overwrite earlier ones that share a key.
    """
    record = {
        'user_id': st.session_state.user_id,
        'index': st.session_state.current_index,
    }

    # Negative indices have no data row, so no source columns are copied
    if index >= 0:
        record.update(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict())

    # Widget values are stored in the session under "<field name><index>"
    for field_name in st.session_state.data_inputs_keys:
        record[field_name] = st.session_state[field_name + str(index)]

    record['skip'] = skip_sample
    save_data(record)
459
 
460
+ # st.set_page_config(layout='wide')
461
  # Title of the app
462
  st.title("Moderator Intervention Prediction")
463
 
 
467
  div[data-testid="stMarkdownContainer"] > p {
468
  font-size: 1rem;
469
  }
470
+ section.main > div {max-width:60rem}
471
  </style>
472
  """, unsafe_allow_html=True)
473
 
474
+
475
def add_annotation_guidelines():
    """Show the current user name followed by the annotation guidelines.

    The guidelines are rendered as an HTML <details> element that starts
    open and wraps `guidelines_text`.
    """
    st.write(f"username is {st.session_state.user_id}")

    guidelines_html = (
        "<details open><summary><b>Annotation Guidelines</b></summary>"
        + guidelines_text
        + "</details><br>"
    )
    st.markdown(guidelines_html, unsafe_allow_html=True)
480
+
481
+
482
# Once a disqualifying answer has been recorded in the session, show only the
# guidelines and a closing message, then stop rendering the rest of the page.
if st.session_state.get('unacceptable_response'):
    add_annotation_guidelines()
    not_eligible_message = "You are not eligible for this study. Thank you for your time!"
    if st.session_state.current_index >= 0:
        # " You will receive a small compensation as explained in the guidelines. "
        # The leading space keeps the two sentences from running together.
        not_eligible_message += " Please email eva-maria.vecchi@ims.uni-stuttgart.de for issues or questions."
    st.error(not_eligible_message)
    st.stop()
490
+
491
+
492
  # Load the data to annotate
493
  if 'data' not in st.session_state:
494
+ st.session_state.data = read_data()
495
+
496
+ # user id
497
+ user_id_from_url = get_param_from_url("user_id")
498
+ if user_id_from_url:
499
+ st.session_state.user_id = user_id_from_url
500
 
501
+ # current index
502
  if 'current_index' not in st.session_state:
503
+ start_index = get_start_index()
504
+ st.session_state.current_index = start_index
505
  st.session_state.form_displayed = -3
506
 
507
+ if get_param_from_url('show_extra_fields'):
508
+ fields += url_conditional_fields
509
+
510
 
511
  def add_validated_submit(fields, message):
512
  st.session_state.form_displayed = st.session_state.current_index
 
516
  else:
517
  navigate(1)
518
 
519
+
520
  def add_checked_submit():
521
  check = st.checkbox('I agree', key='consent')
522
  add_validated_submit([check], "Please agree to give your consent to proceed")
523
 
524
 
 
 
 
 
 
525
  if st.session_state.current_index == -3:
526
  with st.form("data_form"):
527
  st.markdown(consent_text)
528
  add_checked_submit()
529
 
530
  elif st.session_state.current_index == -2:
531
+ if st.session_state.get('user_id'):
 
 
532
  navigate(1)
533
  else:
534
  with st.form("data_form"):