Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,6 @@ from PIL import Image
|
|
6 |
|
7 |
import pandas as pd
|
8 |
import streamlit as st
|
9 |
-
from fsspec.implementations.local import LocalFileSystem
|
10 |
from huggingface_hub import HfFileSystem
|
11 |
|
12 |
import streamlit.components.v1 as components
|
@@ -24,24 +23,19 @@ class Field:
|
|
24 |
other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
|
25 |
|
26 |
# Function to get user ID from URL
|
27 |
-
def
|
28 |
-
user_id = st.query_params.get(
|
29 |
return user_id
|
30 |
|
|
|
|
|
|
|
31 |
|
32 |
########################################################################################
|
33 |
# CHANGE THE FOLLOWING VARIABLES ACCORDING TO YOUR NEEDS
|
34 |
|
35 |
-
# 'local' or 'hf'. hf is for Hugging Face file system but has limits on the number of access per hour
|
36 |
-
filesystem = 'hf'
|
37 |
-
# path to repo or local file system TODO rename
|
38 |
input_repo_path = 'datasets/emvecchi/annotation'
|
39 |
output_repo_path = 'datasets/emvecchi/annotation'
|
40 |
-
#filesystem = 'local'
|
41 |
-
# path to repo or local file system
|
42 |
-
#input_repo_path = '/data/mod_pred_annotation'
|
43 |
-
#output_repo_path = '/data/mod_pred_annotation'
|
44 |
-
|
45 |
to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
|
46 |
COLS_TO_SAVE = ['comment_id','comment','confidence_score']
|
47 |
|
@@ -194,11 +188,6 @@ fields: List[Field] = [
|
|
194 |
Field(name="other_comments", type="text", title="Please provide any additional details or information: *(optional)*", mandatory=False),
|
195 |
]),
|
196 |
]
|
197 |
-
|
198 |
-
url_conditional_fields = [
|
199 |
-
Field(name="skip", type="skip_checkbox",
|
200 |
-
title="I am uncomfortable annotating this text and voluntarily skip this instance", mandatory=False)
|
201 |
-
]
|
202 |
INPUT_FIELD_DEFAULT_VALUES = {'slider': 0,
|
203 |
'text': '',
|
204 |
'textarea': '',
|
@@ -212,27 +201,11 @@ SHOW_HELP_ICON = False
|
|
212 |
SHOW_VALIDATION_ERROR_MESSAGE = True
|
213 |
|
214 |
########################################################################################
|
215 |
-
|
216 |
-
|
217 |
-
HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
|
218 |
-
print("is none?", HF_TOKEN is None)
|
219 |
-
hf_fs = HfFileSystem(token=HF_TOKEN)
|
220 |
-
else:
|
221 |
-
hf_fs = LocalFileSystem()
|
222 |
-
|
223 |
-
def get_start_index():
|
224 |
-
if hf_fs.exists(output_repo_path + '/' + get_base_path()):
|
225 |
-
files = hf_fs.ls(output_repo_path + '/' + get_base_path())
|
226 |
-
return len(files) - 2
|
227 |
-
return -3
|
228 |
-
|
229 |
-
|
230 |
-
def read_data():
|
231 |
-
#assert st.session_state.phase, "Phase not provided"
|
232 |
-
#with hf_fs.open(input_repo_path + '/' + to_annotate_file_name.format(phase=st.session_state.phase)) as f:
|
233 |
-
with hf_fs.open(input_repo_path + '/' + to_annotate_file_name) as f:
|
234 |
return pd.read_csv(f)
|
235 |
|
|
|
236 |
def read_saved_data():
|
237 |
_path = get_path()
|
238 |
if hf_fs.exists(output_repo_path + '/' + _path):
|
@@ -243,20 +216,16 @@ def read_saved_data():
|
|
243 |
print(e)
|
244 |
return None
|
245 |
|
|
|
246 |
# Write a remote file
|
247 |
def save_data(data):
|
248 |
-
|
249 |
-
hf_fs.mkdir(f"{output_repo_path}/{get_base_path()}")
|
250 |
with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f:
|
251 |
f.write(json.dumps(data))
|
252 |
|
253 |
|
254 |
-
def get_base_path():
|
255 |
-
#return f"{st.session_state.phase}/{st.session_state.batch}/{st.session_state.user_id}"
|
256 |
-
return f"{st.session_state.user_id}"
|
257 |
-
|
258 |
def get_path():
|
259 |
-
return f"{
|
260 |
|
261 |
|
262 |
def display_image(image_path):
|
@@ -307,6 +276,7 @@ def show_field(f: Field, index: int, data_collected):
|
|
307 |
st.session_state.following_mandatory = False
|
308 |
match f.type:
|
309 |
case 'input_col':
|
|
|
310 |
if f.name == 'image_name':
|
311 |
st.write(f.title)
|
312 |
image_name = st.session_state.data.iloc[index][f.name]
|
@@ -314,10 +284,7 @@ def show_field(f: Field, index: int, data_collected):
|
|
314 |
image_path = os.path.join(input_repo_path, 'images', image_name)
|
315 |
display_image(image_path)
|
316 |
else:
|
317 |
-
|
318 |
-
if value and value is not np.nan:
|
319 |
-
st.write(f.title)
|
320 |
-
st.write(value)
|
321 |
case 'markdown':
|
322 |
st.markdown(f.title)
|
323 |
case 'expander' | 'container':
|
@@ -335,17 +302,14 @@ def show_field(f: Field, index: int, data_collected):
|
|
335 |
f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
|
336 |
|
337 |
validation_error = False
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
if f.mandatory or st.session_state.following_mandatory:
|
342 |
if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]:
|
343 |
st.session_state.valid = False
|
344 |
validation_error = True
|
345 |
elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values:
|
346 |
st.session_state.following_mandatory = True
|
347 |
-
|
348 |
-
if f.mandatory or st.session_state.following_mandatory:
|
349 |
f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]'
|
350 |
f.help = None
|
351 |
|
@@ -407,32 +371,28 @@ def show_fields(fields: List[Field]):
|
|
407 |
data_collected = read_saved_data()
|
408 |
st.session_state.data_inputs_keys = []
|
409 |
st.session_state.following_mandatory = False
|
410 |
-
|
411 |
for field in fields:
|
412 |
show_field(field, index, data_collected)
|
413 |
|
414 |
submitted = st.form_submit_button("Submit")
|
415 |
if submitted:
|
416 |
-
|
417 |
-
if not skip_sample and not st.session_state.valid:
|
418 |
st.error("Please fill in all mandatory fields")
|
419 |
# st.rerun() # filed-out values are not shown otherwise
|
420 |
else:
|
421 |
with st.spinner(text="saving"):
|
422 |
-
|
|
|
|
|
|
|
|
|
|
|
423 |
st.success("Feedback submitted successfully!")
|
424 |
navigate(1)
|
425 |
|
426 |
st.session_state.form_displayed = st.session_state.current_index
|
427 |
|
428 |
-
def prep_and_save_data(index, skip_sample):
|
429 |
-
save_data({
|
430 |
-
'user_id': st.session_state.user_id,
|
431 |
-
'index': st.session_state.current_index,
|
432 |
-
**(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if index >= 0 else {}),
|
433 |
-
**{k: st.session_state[k + str(index)] for k in st.session_state.data_inputs_keys},
|
434 |
-
'skip': skip_sample
|
435 |
-
})
|
436 |
|
437 |
#st.set_page_config(layout='wide')
|
438 |
# Title of the app
|
@@ -448,39 +408,15 @@ div[data-testid="stMarkdownContainer"] > p {
|
|
448 |
</style>
|
449 |
""", unsafe_allow_html=True)
|
450 |
|
451 |
-
|
452 |
-
def add_annotation_guidelines():
|
453 |
-
st.write(f"username is {st.session_state.user_id}")
|
454 |
-
st.markdown(
|
455 |
-
"<details open><summary>Annotation Guidelines</summary>" + guidelines_text + "</details>"
|
456 |
-
, unsafe_allow_html=True)
|
457 |
-
|
458 |
-
|
459 |
-
#st.session_state.phase = get_param_from_url("phase")
|
460 |
# Load the data to annotate
|
461 |
if 'data' not in st.session_state:
|
462 |
-
|
463 |
-
#data = read_data()
|
464 |
-
#if st.session_state:
|
465 |
-
# st.session_state.data = data[data['batch'] == int(st.session_state.batch)]
|
466 |
-
#else:
|
467 |
-
# raise ValueError("Batch not provided")
|
468 |
-
st.session_state.data = read_data()
|
469 |
-
data = read_data()
|
470 |
-
|
471 |
-
user_id_from_url = get_param_from_url("user_id")
|
472 |
-
if user_id_from_url:
|
473 |
-
st.session_state.user_id = user_id_from_url
|
474 |
-
|
475 |
|
|
|
476 |
if 'current_index' not in st.session_state:
|
477 |
-
|
478 |
-
st.session_state.current_index = start_index
|
479 |
st.session_state.form_displayed = -3
|
480 |
|
481 |
-
if get_param_from_url('show_extra_fields'):
|
482 |
-
fields += url_conditional_fields
|
483 |
-
|
484 |
|
485 |
def add_validated_submit(fields, message):
|
486 |
st.session_state.form_displayed = st.session_state.current_index
|
@@ -494,14 +430,21 @@ def add_checked_submit():
|
|
494 |
check = st.checkbox('I agree', key='consent')
|
495 |
add_validated_submit([check], "Please agree to give your consent to proceed")
|
496 |
|
497 |
-
|
|
|
|
|
|
|
|
|
|
|
498 |
if st.session_state.current_index == -3:
|
499 |
with st.form("data_form"):
|
500 |
st.markdown(consent_text)
|
501 |
add_checked_submit()
|
502 |
|
503 |
elif st.session_state.current_index == -2:
|
504 |
-
|
|
|
|
|
505 |
navigate(1)
|
506 |
else:
|
507 |
with st.form("data_form"):
|
@@ -518,13 +461,12 @@ elif st.session_state.current_index < len(st.session_state.data):
|
|
518 |
with st.form("data_form"+str(st.session_state.current_index)):
|
519 |
show_fields(fields)
|
520 |
|
521 |
-
|
522 |
elif st.session_state.current_index == len(st.session_state.data):
|
523 |
with st.form("intro_form"):
|
524 |
show_fields(end_fields)
|
525 |
-
|
526 |
else:
|
527 |
-
st.write(f"Thank you
|
528 |
|
529 |
# Navigation buttons
|
530 |
if st.session_state.current_index > 0:
|
@@ -539,4 +481,5 @@ st.markdown(
|
|
539 |
visibility: hidden;
|
540 |
}
|
541 |
</style>""", unsafe_allow_html=True
|
542 |
-
)
|
|
|
|
6 |
|
7 |
import pandas as pd
|
8 |
import streamlit as st
|
|
|
9 |
from huggingface_hub import HfFileSystem
|
10 |
|
11 |
import streamlit.components.v1 as components
|
|
|
23 |
other_params: Optional[Dict[str, object]] = field(default_factory=lambda: {})
|
24 |
|
25 |
# Function to get user ID from URL
|
26 |
+
def get_user_id_from_url():
|
27 |
+
user_id = st.query_params.get("user_id", "")
|
28 |
return user_id
|
29 |
|
30 |
+
HF_TOKEN = os.environ.get("HF_TOKEN_WRITE")
|
31 |
+
print("is none?", HF_TOKEN is None)
|
32 |
+
hf_fs = HfFileSystem(token=HF_TOKEN)
|
33 |
|
34 |
########################################################################################
|
35 |
# CHANGE THE FOLLOWING VARIABLES ACCORDING TO YOUR NEEDS
|
36 |
|
|
|
|
|
|
|
37 |
input_repo_path = 'datasets/emvecchi/annotation'
|
38 |
output_repo_path = 'datasets/emvecchi/annotation'
|
|
|
|
|
|
|
|
|
|
|
39 |
to_annotate_file_name = 'to_annotate.csv' # CSV file to annotate
|
40 |
COLS_TO_SAVE = ['comment_id','comment','confidence_score']
|
41 |
|
|
|
188 |
Field(name="other_comments", type="text", title="Please provide any additional details or information: *(optional)*", mandatory=False),
|
189 |
]),
|
190 |
]
|
|
|
|
|
|
|
|
|
|
|
191 |
INPUT_FIELD_DEFAULT_VALUES = {'slider': 0,
|
192 |
'text': '',
|
193 |
'textarea': '',
|
|
|
201 |
SHOW_VALIDATION_ERROR_MESSAGE = True
|
202 |
|
203 |
########################################################################################
|
204 |
+
def read_data(_path):
|
205 |
+
with hf_fs.open(input_repo_path + '/' + _path) as f:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
return pd.read_csv(f)
|
207 |
|
208 |
+
|
209 |
def read_saved_data():
|
210 |
_path = get_path()
|
211 |
if hf_fs.exists(output_repo_path + '/' + _path):
|
|
|
216 |
print(e)
|
217 |
return None
|
218 |
|
219 |
+
|
220 |
# Write a remote file
|
221 |
def save_data(data):
|
222 |
+
hf_fs.mkdir(f"{output_repo_path}/{data['user_id']}")
|
|
|
223 |
with hf_fs.open(f"{output_repo_path}/{get_path()}", "w") as f:
|
224 |
f.write(json.dumps(data))
|
225 |
|
226 |
|
|
|
|
|
|
|
|
|
227 |
def get_path():
|
228 |
+
return f"{st.session_state.user_id}/{st.session_state.current_index}.json"
|
229 |
|
230 |
|
231 |
def display_image(image_path):
|
|
|
276 |
st.session_state.following_mandatory = False
|
277 |
match f.type:
|
278 |
case 'input_col':
|
279 |
+
st.write(f.title)
|
280 |
if f.name == 'image_name':
|
281 |
st.write(f.title)
|
282 |
image_name = st.session_state.data.iloc[index][f.name]
|
|
|
284 |
image_path = os.path.join(input_repo_path, 'images', image_name)
|
285 |
display_image(image_path)
|
286 |
else:
|
287 |
+
st.write(st.session_state.data.iloc[index][f.name])
|
|
|
|
|
|
|
288 |
case 'markdown':
|
289 |
st.markdown(f.title)
|
290 |
case 'expander' | 'container':
|
|
|
302 |
f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title
|
303 |
|
304 |
validation_error = False
|
305 |
+
if f.mandatory or st.session_state.following_mandatory:
|
306 |
+
# form is not displayed for first time
|
307 |
+
if st.session_state.form_displayed == st.session_state.current_index:
|
|
|
308 |
if st.session_state[key] == INPUT_FIELD_DEFAULT_VALUES[f.type]:
|
309 |
st.session_state.valid = False
|
310 |
validation_error = True
|
311 |
elif f.following_mandatory_values and st.session_state[key] in f.following_mandatory_values:
|
312 |
st.session_state.following_mandatory = True
|
|
|
|
|
313 |
f.title += " :red[* required!]" if (validation_error and not SHOW_VALIDATION_ERROR_MESSAGE) else' :red[*]'
|
314 |
f.help = None
|
315 |
|
|
|
371 |
data_collected = read_saved_data()
|
372 |
st.session_state.data_inputs_keys = []
|
373 |
st.session_state.following_mandatory = False
|
374 |
+
|
375 |
for field in fields:
|
376 |
show_field(field, index, data_collected)
|
377 |
|
378 |
submitted = st.form_submit_button("Submit")
|
379 |
if submitted:
|
380 |
+
if not st.session_state.valid:
|
|
|
381 |
st.error("Please fill in all mandatory fields")
|
382 |
# st.rerun() # filed-out values are not shown otherwise
|
383 |
else:
|
384 |
with st.spinner(text="saving"):
|
385 |
+
save_data({
|
386 |
+
'user_id': st.session_state.user_id,
|
387 |
+
'index': st.session_state.current_index,
|
388 |
+
**(st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict() if index >= 0 else {}),
|
389 |
+
**{k: st.session_state[k+str(index)] for k in st.session_state.data_inputs_keys}
|
390 |
+
})
|
391 |
st.success("Feedback submitted successfully!")
|
392 |
navigate(1)
|
393 |
|
394 |
st.session_state.form_displayed = st.session_state.current_index
|
395 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
|
397 |
#st.set_page_config(layout='wide')
|
398 |
# Title of the app
|
|
|
408 |
</style>
|
409 |
""", unsafe_allow_html=True)
|
410 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
# Load the data to annotate
|
412 |
if 'data' not in st.session_state:
|
413 |
+
st.session_state.data = read_data(to_annotate_file_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
|
415 |
+
# Initialize the current index
|
416 |
if 'current_index' not in st.session_state:
|
417 |
+
st.session_state.current_index = -3
|
|
|
418 |
st.session_state.form_displayed = -3
|
419 |
|
|
|
|
|
|
|
420 |
|
421 |
def add_validated_submit(fields, message):
|
422 |
st.session_state.form_displayed = st.session_state.current_index
|
|
|
430 |
check = st.checkbox('I agree', key='consent')
|
431 |
add_validated_submit([check], "Please agree to give your consent to proceed")
|
432 |
|
433 |
+
|
434 |
+
def add_annotation_guidelines():
|
435 |
+
st.write(f"username is {st.session_state.user_id}")
|
436 |
+
st.markdown(
|
437 |
+
"<details open><summary>Annotation Guidelines</summary>" + guidelines_text + "</details>"
|
438 |
+
, unsafe_allow_html=True)
|
439 |
if st.session_state.current_index == -3:
|
440 |
with st.form("data_form"):
|
441 |
st.markdown(consent_text)
|
442 |
add_checked_submit()
|
443 |
|
444 |
elif st.session_state.current_index == -2:
|
445 |
+
user_id_from_url = get_user_id_from_url()
|
446 |
+
if user_id_from_url:
|
447 |
+
st.session_state.user_id = user_id_from_url
|
448 |
navigate(1)
|
449 |
else:
|
450 |
with st.form("data_form"):
|
|
|
461 |
with st.form("data_form"+str(st.session_state.current_index)):
|
462 |
show_fields(fields)
|
463 |
|
|
|
464 |
elif st.session_state.current_index == len(st.session_state.data):
|
465 |
with st.form("intro_form"):
|
466 |
show_fields(end_fields)
|
467 |
+
|
468 |
else:
|
469 |
+
st.write(f"Thank you for taking part in this study!")
|
470 |
|
471 |
# Navigation buttons
|
472 |
if st.session_state.current_index > 0:
|
|
|
481 |
visibility: hidden;
|
482 |
}
|
483 |
</style>""", unsafe_allow_html=True
|
484 |
+
)
|
485 |
+
|