# Spaces status: Running
import json | |
import os | |
from dataclasses import dataclass, field | |
from typing import List, Optional, Dict | |
from PIL import Image | |
import pandas as pd | |
import streamlit as st | |
from huggingface_hub import HfFileSystem | |
@dataclass
class Field:
    """Declarative description of one element of the annotation form.

    Instances are created with keyword arguments and rendered by
    ``show_field``; the ``@dataclass`` decorator supplies the keyword
    ``__init__`` and resolves the ``field(default_factory=...)`` default.
    """

    # Element kind: a display type ('input_col', 'markdown', 'expander',
    # 'container') or an input widget type ('radio', 'text', 'multiselect', ...).
    type: str
    # Text shown for the element.
    title: str
    # Data column to display (for 'input_col') or key under which the answer
    # is stored (for input widgets); grouping elements leave it unset.
    name: Optional[str] = None
    # Optional help text for input widgets.
    help: Optional[str] = None
    # Nested elements, rendered recursively for 'expander' and 'container'.
    children: Optional[List['Field']] = None
    # Per-widget extras; 'labels' and 'choices' are the keys read by show_field.
    other_params: Optional[Dict[str, object]] = field(default_factory=dict)
def get_user_id_from_url():
    """Return the ``user_id`` query parameter of the page URL ("" when absent)."""
    return st.query_params.get("user_id", "")
# Access token for the Hugging Face filesystem, read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN_WRITE")
print("is none?", HF_TOKEN is None)
# Shared filesystem handle used by every read/write helper in this file.
hf_fs = HfFileSystem(token=HF_TOKEN)
# Dataset repositories: where the rows to annotate are read from and where
# the collected annotations are written to.
input_repo_path = 'datasets/emvecchi/annotate-pilot'
output_repo_path = 'datasets/emvecchi/annotate-pilot'
to_annotate_file_name = 'to_annotate.csv'  # CSV file to annotate
# Input columns copied into every saved annotation record.
COLS_TO_SAVE = ['comment_id']

# Five-point scales; select_slider maps its values 0/25/50/75/100 onto these
# positions, so each list must keep exactly five entries.
agreement_labels = [
    'strongly disagree',
    'disagree',
    'neither agree nor disagree',  # fixed typo: previously "no disagree"
    'agree',
    'strongly agree',
]
quality_labels = ['very poor', 'poor', 'acceptable', 'good', 'very good']
priority_labels = ['not a priority', 'low priority', 'neutral', 'moderate priority', 'high priority']
# Scale used by a select_slider that does not provide its own 'labels'.
default_labels = agreement_labels

function_choices = [
    'Broadening Discussion',
    'Improving Comment Quality',
    'Content Correction',
    'Keeping Discussion on Topic',
    'Organizing Discussion',
    'Policing',
    'Resolving Site Use Issues',
    'Social Functions',
    'Other (please specify)',
]
property_choices = [
    'appropriateness',
    'clarity',
    'constructiveness',
    'common good',
    'effectiveness',
    'emotion',
    'impact',
    'overall quality',
    'proposal',
    'Q for justification',
    'storytelling',
    'rationality',
    'reasonableness',
    'reciprocity',
    'reference',
    'Other (please specify)',
]
# Options used by a multiselect that does not provide its own 'choices'.
default_choices = function_choices
# Ordered definition of the annotation form.  Every entry must be a Field:
# the form loop reads ``.name`` on each entry and passes it to show_field().
# Rendering calls such as st.markdown(...) do not belong in this list; the
# preceding comment is therefore declared as a regular 'input_col' Field.
fields: List[Field] = [
    Field(name="topic", type="input_col", title="**Topic:**"),
    Field(name="parent_comment", type="input_col", title="**Preceeding Comment:**"),
    Field(name="comment", type="input_col", title="**Comment:**"),
    Field(name="image_name", type="input_col", title="**Visualization of high contributing properties:**"),
    Field(type="container", title="**Need for Moderation**", children=[
        Field(name="to_moderate", type="radio",
              title="Do feel this comment/discussion would benefit from moderator intervention?"),
        Field(name="actions_clear", type="select_slider",
              title="With what level of **priority** would you need to interact with this comment?",
              other_params={'labels': priority_labels}),
    ]),
    Field(type="container", title="**Moderation Function**", children=[
        # No 'choices' given, so show_field falls back to default_choices.
        Field(name="mod_function", type="multiselect",
              title="What type of moderation function is needed here? *(Multiple selection possible)*"),
        Field(name="mod_function_other", type="text", title="*If Other, please specify:*"),
    ]),
    Field(type="container", title="**Contributing properties**", children=[
        Field(name="relevant_properties", type="multiselect",
              title="Which property(s) is most impactful in your assessment? *(Multiple selection possible)*",
              other_params={'choices': property_choices}),
        Field(name="relevant_properties_other", type="text", title="*If Other, please specify:*"),
    ]),
    Field(name="other_comments", type="text", title="Additional comments: *(optional)*"),
]
# Initial value per input widget type, used when no saved annotation exists.
# The keys double as the set of Field types that show_field treats as inputs.
INPUT_FIELD_DEFAULT_VALUES = {
    'slider': 0,
    'text': None,
    'textarea': None,
    'checkbox': False,
    'radio': None,
    'select_slider': 50,
    'multiselect': None,
}

# When False, a field's help text is appended to its title rather than being
# passed to the widget's ``help`` argument.
SHOW_HELP_ICON = False
def read_data(_path):
    """Read the CSV at ``_path`` inside the input repo and return it as a DataFrame."""
    csv_location = '/'.join((input_repo_path, _path))
    with hf_fs.open(csv_location) as csv_file:
        return pd.read_csv(csv_file)
def read_saved_data():
    """Return the saved annotation for the current user and row, if any.

    Returns:
        The decoded JSON content, or ``None`` when the file does not exist or
        cannot be decoded (the decode error is printed).
    """
    saved_location = '/'.join((output_repo_path, get_path()))
    if not hf_fs.exists(saved_location):
        return None
    with hf_fs.open(saved_location) as saved_file:
        try:
            return json.load(saved_file)
        except json.JSONDecodeError as decode_error:
            print(decode_error)
    return None
def save_data(data):
    """Write ``data`` as JSON to the current user's file in the output repo.

    Args:
        data: JSON-serialisable mapping; ``data['user_id']`` names the
            directory that is created before writing.
    """
    user_dir = f"{output_repo_path}/{data['user_id']}"
    hf_fs.mkdir(user_dir)
    target = f"{output_repo_path}/{get_path()}"
    with hf_fs.open(target, "w") as out_file:
        out_file.write(json.dumps(data))
def get_path():
    """Return the repo-relative JSON path for the current user and row index."""
    state = st.session_state
    return f"{state.user_id}/{state.current_index}.json"
def display_image(image_path):
    """Open the image at ``image_path`` on the HF filesystem and show it in the app."""
    # Rendering happens inside the ``with`` block so the file is still open
    # while the image is being consumed.
    with hf_fs.open(image_path) as image_file:
        st.image(
            Image.open(image_file),
            caption='10 most contributing properties',
            use_column_width=True,
        )
#################################### Streamlit App #################################### | |
def navigate(index_change):
    """Shift the current row index by ``index_change`` and rerun the script."""
    new_index = st.session_state.current_index + index_change
    st.session_state.current_index = new_index
    print(new_index)
    # Rerun immediately so a single click is enough to change the page:
    # https://discuss.streamlit.io/t/click-twice-on-button-for-changing-state/45633/2
    st.rerun()
def show_field(f: Field, index: int): | |
if f.type not in INPUT_FIELD_DEFAULT_VALUES.keys(): | |
match f.type: | |
case 'input_col': | |
st.write(f.title) | |
if f.name == 'image_name': | |
st.write(f.title) | |
image_name = st.session_state.data.iloc[index][f.name] | |
if image_name: # Ensure the image name is not empty | |
image_path = os.path.join(input_repo_path, 'images', image_name) | |
display_image(image_path) | |
else: | |
st.write(st.session_state.data.iloc[index][f.name]) | |
case 'markdown': | |
st.markdown(f.title) | |
case 'expander' | 'container': | |
with (st.expander(f.title) if f.type == 'expander' else st.container(border=True)): | |
if f.type == 'container': | |
st.markdown(f.title) | |
for child in f.children: | |
show_field(child, index) | |
else: | |
key = f.name + str(index) | |
value = st.session_state.default_values[f.name] = data_collected[f.name] if data_collected else \ | |
INPUT_FIELD_DEFAULT_VALUES[f.type] | |
if not SHOW_HELP_ICON: | |
f.title = f'**{f.title}**\n\n{f.help}' if f.help else f.title | |
f.help = None | |
match f.type: | |
case 'checkbox': | |
st.session_state.data_inputs[f.name] = st.checkbox(f.title, | |
key=key, | |
value=value, help=f.help) | |
case 'radio': | |
st.session_state.data_inputs[f.name] = st.radio(f.title, | |
["yes","no","other"], | |
key=key, | |
help=f.help) | |
case 'slider': | |
st.session_state.data_inputs[f.name] = st.slider(f.title, | |
min_value=0, max_value=6, step=1, | |
key=key, | |
value=value, help=f.help) | |
case 'select_slider': | |
labels = default_labels if not f.other_params.get('labels') else f.other_params.get('labels') | |
st.session_state.data_inputs[f.name] = st.select_slider(f.title, | |
options=[0, 25, 50, 75, 100], | |
format_func=lambda x: labels[x // 25], | |
key=key, | |
value=value, help=f.help) | |
case 'multiselect': | |
choices = default_choices if not f.other_params.get('choices') else f.other_params.get('choices') | |
st.session_state.data_inputs[f.name] = st.multiselect(f.title, | |
options = choices, | |
key=key, max_selections=3, | |
help=f.help) | |
case 'text': | |
st.session_state.data_inputs[f.name] = st.text_input(f.title, key=key, value=value) | |
case 'textarea': | |
st.session_state.data_inputs[f.name] = st.text_area(f.title, key=key, value=value) | |
# st.set_page_config(layout='wide')
# Title of the app
st.title("Moderation Prediction")

# Raw HTML snippets rendered below the title, in this order:
# a paragraph font-size override, then the collapsible annotation guidelines.
_PARAGRAPH_CSS = """<style>
div[data-testid="stMarkdownContainer"] > p {
font-size: 1rem;
}
</style>
"""
_GUIDELINES_HTML = """<details open>
<summary>Annotation Guidelines</summary>
some guidelines here
</details>
"""
for _html_snippet in (_PARAGRAPH_CSS, _GUIDELINES_HTML):
    st.markdown(_html_snippet, unsafe_allow_html=True)
# Load the data to annotate (once per session).
if 'data' not in st.session_state:
    st.session_state.data = read_data(to_annotate_file_name)
# Initialize the current index; -1 means "user ID not chosen yet".
if 'current_index' not in st.session_state:
    st.session_state.current_index = -1

if st.session_state.current_index == -1:
    # Step 1: identify the annotator, from the URL if possible.
    user_id_from_url = get_user_id_from_url()
    if user_id_from_url:
        st.session_state.user_id = user_id_from_url
        navigate(1)
    else:
        st.session_state.user_id = st.text_input('Please enter your user ID to proceed', value=user_id_from_url)
        if st.button("Next"):
            # The user ID becomes the directory of the saved files, so a
            # blank ID must not be allowed to proceed.
            if st.session_state.user_id.strip():
                navigate(1)
            else:
                st.warning("Please enter a user ID before continuing.")
elif st.session_state.current_index < len(st.session_state.data):
    # Step 2: annotate the current row.
    st.write(f"username is {st.session_state.user_id}")
    # Creating the form
    with st.form("feedback_form"):
        index = st.session_state.current_index
        # Read by show_field (as a module-level name) to pre-fill the widgets.
        data_collected = read_saved_data()
        st.session_state.default_values = {}
        st.session_state.data_inputs = {}
        # Loop variable deliberately not called ``field`` so the imported
        # dataclasses.field is not shadowed.
        for form_field in fields:
            # Named fields that are not input columns get a placeholder
            # answer; unnamed grouping fields must not add a None key.
            if form_field.name and form_field.name not in st.session_state.data.columns:
                st.session_state.data_inputs[form_field.name] = None
            show_field(form_field, index)
        submitted = st.form_submit_button("Submit")
        if submitted:
            with st.spinner(text="saving"):
                save_data({
                    'user_id': st.session_state.user_id,
                    'index': st.session_state.current_index,
                    **st.session_state.data.iloc[index][COLS_TO_SAVE].to_dict(),
                    **st.session_state.data_inputs
                })
            st.success("Feedback submitted successfully!")
            navigate(1)
else:
    # Step 3: every row has been visited.
    st.write("Finished all data points!")
# Navigation buttons: "Previous" is only offered after the first data row.
if st.session_state.current_index > 0 and st.button("Previous"):
    with st.spinner(text="in progress"):
        navigate(-1)

# Page counter, shown only while a data row is on screen.
_row_count = len(st.session_state.data)
if 0 <= st.session_state.current_index < _row_count:
    st.write(f"Page {st.session_state.current_index + 1} out of {_row_count}")
# Stop the Enter key inside text inputs from submitting the form.
# https://discuss.streamlit.io/t/text-input-how-to-disable-press-enter-to-apply/14457/6
import streamlit.components.v1 as components

# Script injected through a zero-height component; it attaches a keydown
# handler to every <input> of the parent document present when it runs.
_BLOCK_ENTER_JS = """
<script>
const inputs = window.parent.document.querySelectorAll('input');
inputs.forEach(input => {
input.addEventListener('keydown', function(event) {
if (event.key === 'Enter') {
event.preventDefault();
}
});
});
</script>
"""
components.html(_BLOCK_ENTER_JS, height=0)

# Hide the input-instructions hint element of the widgets.
_HIDE_INSTRUCTIONS_CSS = """<style>
div[data-testid="InputInstructions"] {
visibility: hidden;
}
</style>"""
st.markdown(_HIDE_INSTRUCTIONS_CSS, unsafe_allow_html=True)