Spaces:
Running
Running
from argparse import ArgumentParser | |
import os, threading, time, yaml | |
from dataclasses import dataclass | |
from datetime import datetime, timezone | |
from pprint import pprint | |
import numpy as np | |
import gradio as gr | |
from huggingface_hub import HfApi | |
class VideoCfg: | |
value: str | |
label: str | |
visible: bool | |
show_share_button: bool = False | |
show_download_button: bool = False | |
show_label: bool = True | |
width: int | str | None = None | |
def time_now():
    """Return the current UTC time formatted as ``YYYY-MM-DD_HH-MM-SS``.

    The result is used both as a CSV timestamp and inside file names, so it
    contains no ``:`` or spaces. ``%S`` (seconds of the minute) is used;
    ``%s`` would be a non-portable platform extension ("seconds since the
    epoch") that some platforms reject.
    """
    return datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
class SurveyEngine:
    """Gradio app for a pairwise image-quality user study.

    On construction it downloads the survey images from a Hugging Face
    dataset repo, creates local files for responses (CSV) and optional
    feedback (TXT), and, unless ``args.no_sync`` is set, starts background
    threads that periodically upload those files back to the repo.
    ``main`` builds and launches the UI.
    """

    def __init__(self, args):
        self.args = args
        self.api = HfApi()

        # check user id and platform (raises if the completion link is missing)
        self._generate_platform_completion_message()

        # pull images from the hf dataset to this space
        self.repo_id = "jinggogogo/survey-images"
        self.local_dir = "./survey-images"
        self.api.snapshot_download(self.repo_id, local_dir=self.local_dir, repo_type="dataset")

        # create a file to record the user study; the name encodes the task and
        # the compared methods, e.g. "<task>_<m1>_<m2>_<time>.csv". Joining also
        # works when method_filter is empty (all methods are used then).
        methods_tag = "_".join(str(method) for method in self.args.method_filter)
        response_file_name = f"{self.args.task}_{methods_tag}_{time_now()}.csv"
        self.response_file_path_local = f"{self.local_dir}/response/{response_file_name}"
        self.response_file_path_remote = f"response/{response_file_name}"
        os.makedirs(f"{self.local_dir}/response", exist_ok=True)
        if self.args.ask_user_id:
            csv_header = "timestamp,path1,path2,selection,user_id,id_platform\n"
        else:
            csv_header = "timestamp,path1,path2,selection\n"
        self._update_local_file(csv_header, self.response_file_path_local)

        # create a file to record optional feedback
        feedback_file_name = f"{time_now()}.txt"
        self.feedback_file_path_local = f"{self.local_dir}/optional_feedback/{feedback_file_name}"
        self.feedback_file_path_remote = f"optional_feedback/{feedback_file_name}"
        os.makedirs(f"{self.local_dir}/optional_feedback", exist_ok=True)
        self._update_local_file("", self.feedback_file_path_local)

        self.video_paths, self.N_prompts, self.N_methods = self._get_all_video_paths()

        self.theme = gr.themes.Base(
            text_size=gr.themes.sizes.text_lg,
            spacing_size=gr.themes.sizes.spacing_sm,
            radius_size=gr.themes.sizes.radius_md,
        )

        if not self.args.no_sync:
            self._start_periodic_sync()

    def _generate_platform_completion_message(self):
        """Set ``self.platform_completion_message`` (markdown, or "" if unused).

        Raises:
            ValueError: if ``args.ask_user_id`` is set but the
                ``PLATFORM_COMPLETION_LINK`` environment variable is missing.
        """
        if self.args.ask_user_id:
            # get link from env
            platform_completion_link = os.getenv("PLATFORM_COMPLETION_LINK")
            if platform_completion_link is None:
                raise ValueError("Please provide the platform completion link.")
            # generate completion message
            self.platform_completion_message = f"## Your {self.args.id_platform} completion link is [here]({platform_completion_link})."
        else:
            self.platform_completion_message = ""

    def _start_periodic_sync(self):
        """Start daemon threads that keep the Hub dataset repo in sync.

        Two threads upload the response and feedback files every
        ``self.args.period_upload`` seconds; a third squashes the repo's commit
        history every ``self.args.period_squash`` seconds (both values are
        passed to ``time.sleep``). Errors are printed and the loop continues,
        so a transient Hub failure does not stop syncing.
        """

        def _upload_periodically(path_local, path_remote):
            while True:
                print(time_now())
                print(f"Uploading {path_local}.")
                try:
                    self._update_remote_file(path_local, path_remote)
                except Exception as e:  # best effort: log and retry next period
                    print(e)
                time.sleep(self.args.period_upload)

        def _squash_commits_periodically():
            while True:
                print(time_now())
                print("Squashing commits.")
                try:
                    self.api.super_squash_history(self.repo_id, repo_type="dataset")
                except Exception as e:  # best effort: log and retry next period
                    print(e)
                time.sleep(self.args.period_squash)

        jobs = (
            (_upload_periodically, (self.response_file_path_local, self.response_file_path_remote)),
            (_upload_periodically, (self.feedback_file_path_local, self.feedback_file_path_remote)),
            (_squash_commits_periodically, ()),
        )
        for target, thread_args in jobs:
            # daemon threads do not keep the process alive on shutdown
            threading.Thread(target=target, args=thread_args, daemon=True).start()

    def _get_all_video_paths(self):
        """Collect the image paths as a ``(N_prompts, N_methods)`` array.

        Each sub-directory of ``<local_dir>/images`` is one method. File names
        are taken from the first (sorted) method directory; they are assumed,
        not checked, to exist under every other method directory too. If
        ``args.method_filter`` is non-empty, only those methods are kept.

        Returns:
            (video_paths, N_prompts, N_methods)

        Raises:
            ValueError: if no method directory remains after filtering.
        """
        video_dir = f"{self.local_dir}/images"
        method_list = sorted(os.listdir(video_dir))
        # filter methods (np.intersect1d also returns them sorted and unique)
        if len(self.args.method_filter) > 0:
            method_filter = np.array(self.args.method_filter)
            method_list = np.intersect1d(method_list, method_filter)
        if len(method_list) == 0:
            raise ValueError(
                f"No method directories in {video_dir} match method_filter={self.args.method_filter}."
            )

        video_name_list = sorted(os.listdir(f"{video_dir}/{method_list[0]}"))
        video_paths = np.array(
            [
                [os.path.join(video_dir, method, video_name) for method in method_list]
                for video_name in video_name_list
            ]
        )  # (N_prompts, N_methods)
        return video_paths, len(video_name_list), len(method_list)

    def _sample_video_pair(self, videos_left):
        """Draw two distinct prompts and shuffle the method order of each.

        videos_left: (N_prompts, N_methods) array of paths. Must have at least
        two rows; ``np.random.choice(..., replace=False)`` raises otherwise.

        Returns:
            (video_pair1, video_pair2, radio_select_to_path_lut, videos_left),
            where ``videos_left`` has the two sampled rows removed and the lut
            maps each radio label to the path it stands for.
        """
        # random choose two prompts
        prompt_ids = np.random.choice(len(videos_left), 2, replace=False)
        # permute methods within each prompt (presumably to avoid position bias)
        video_pair1 = np.random.permutation(videos_left[prompt_ids[0]])
        video_pair2 = np.random.permutation(videos_left[prompt_ids[1]])
        # update videos_left
        videos_left = np.delete(videos_left, prompt_ids, axis=0)

        radio_select_to_path_lut = {
            "Image Set 1 π": str(video_pair1[0]),
            "Image Set 2 π": str(video_pair1[1]),
            "Image Set 3 π": str(video_pair2[0]),
            "Image Set 4 π": str(video_pair2[1]),
            "Similar π€": "Similar",
        }
        print("---------------")
        print(time_now())
        pprint(radio_select_to_path_lut)
        return video_pair1, video_pair2, radio_select_to_path_lut, videos_left

    def _update_local_file(self, message, file_path_local):
        """Append ``message`` to ``file_path_local`` as UTF-8 text.

        The encoding is explicit because feedback and ids are free user text
        and the platform default encoding may not be UTF-8.
        """
        with open(file_path_local, "a", encoding="utf-8") as f:
            f.write(message)

    def _update_remote_file(self, file_path_local, file_path_remote):
        """Upload a local file to ``file_path_remote`` in the dataset repo."""
        self.api.upload_file(
            path_or_fileobj=file_path_local,
            path_in_repo=file_path_remote,
            repo_id=self.repo_id,
            repo_type="dataset",
        )

    def _setup_video(self, path_a, path_b, label_a, label_b, visible):
        """Create two side-by-side ``gr.Image`` components."""
        cfg_a = VideoCfg(value=path_a, label=label_a, visible=visible)
        cfg_b = VideoCfg(value=path_b, label=label_b, visible=visible)
        video_a = gr.Image(**cfg_a.__dict__)
        video_b = gr.Image(**cfg_b.__dict__)
        return video_a, video_b

    def _load_callback(self, videos_left):
        """On page load, sample the first two pairs and fill the image slots.

        If fewer than two prompts are available, nothing can be sampled: the
        state and images are left untouched and the "run out" note is shown.
        """
        if len(videos_left) < 2:
            return [gr.skip()] * 8 + [gr.update(visible=True)]

        (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
        ) = self._sample_video_pair(videos_left)
        return (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
            gr.update(value=video_pair1[0]),
            gr.update(value=video_pair1[1]),
            gr.update(value=video_pair2[0]),
            gr.update(value=video_pair2[1]),
            gr.update(visible=False),  # md_run_out_videos
        )

    def _click_select_radio(self):
        """Reveal the confirm button once a radio option is selected."""
        return gr.update(visible=True)

    def _record_response(self, radio_select, radio_select_to_path_lut, user_id, key_a, key_b):
        """Append one CSV row for the pair shown under radio labels key_a / key_b.

        The row stores the actual file paths (not the radio labels), plus the
        user id and platform when ``args.ask_user_id`` is set.
        """
        selected_path = radio_select_to_path_lut[radio_select]
        path1 = radio_select_to_path_lut[key_a]
        path2 = radio_select_to_path_lut[key_b]
        if self.args.ask_user_id:
            id_platform = self.args.id_platform
            message = f"{time_now()},{path1},{path2},{selected_path},{user_id},{id_platform}\n"
        else:
            message = f"{time_now()},{path1},{path2},{selected_path}\n"
        self._update_local_file(message, self.response_file_path_local)

    def _click_button_confirm1(self, radio_select, radio_select_to_path_lut, user_id):
        """Record the answer for pair 1, then reveal pair 2."""
        self._record_response(
            radio_select, radio_select_to_path_lut, user_id, "Image Set 1 π", "Image Set 2 π"
        )
        confirm_message = f"""
Your selection was:
<span style="font-size:20px; color:orange "> {radio_select} </span> \n\n
"""
        return (
            gr.update(visible=True, value=confirm_message),  # md_confirm1
            gr.update(visible=False),  # button_confirm1
            gr.update(visible=False),  # radio_select1
            gr.update(visible=True),  # md_pair_34
            gr.update(visible=True),  # video3
            gr.update(visible=True),  # video4
            gr.update(visible=True),  # radio_select2
        )

    def _click_button_confirm2(self, radio_select, radio_select_to_path_lut, user_id):
        """Record the answer for pair 2, then show the completion message,
        the "new" button and the optional-feedback inputs."""
        self._record_response(
            radio_select, radio_select_to_path_lut, user_id, "Image Set 3 π", "Image Set 4 π"
        )
        confirm_message = f"""
Your selection was:
<span style="font-size:20px; color:orange "> {radio_select} </span> \n\n
## Study Done!
{self.platform_completion_message}
Click the button below π² if you'd like to evaluate another set. \n\n
You can exit this study by closing this page.
For more details about this study, click
[here](https://huggingface.co/spaces/zirui-wang/video_quality_study/blob/main/detail.md).
"""
        return (
            gr.update(visible=True, value=confirm_message),  # md_confirm2
            gr.update(visible=False),  # button_confirm2
            gr.update(visible=False),  # radio_select2
            gr.update(visible=True),  # button_new
            gr.update(visible=True),  # textbox_optional_feedback
            gr.update(visible=True),  # button_submit_optional_feedback
        )

    def _click_button_new(self, videos_left):
        """Sample a fresh pair of pairs and reset the UI to the pair-1 step.

        When fewer than two prompts remain, a pair cannot be sampled, so
        everything is hidden and the "run out" note is shown instead (20
        outputs, matching the listener's outputs list).
        """
        if len(videos_left) < 2:
            return [None] * 3 + [[]] + [gr.update(visible=False)] * 15 + [gr.update(visible=True)]

        print("---------------")
        print(f"N_video_left before: {len(videos_left)}")
        (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
        ) = self._sample_video_pair(videos_left)
        return (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
            gr.update(value=video_pair1[0]),  # video1
            gr.update(value=video_pair1[1]),  # video2
            gr.update(visible=True, value=None),  # radio_select1
            gr.update(visible=False),  # button_confirm1
            gr.update(visible=False),  # md_confirm1
            gr.update(visible=False),  # md_pair_34
            gr.update(value=video_pair2[0], visible=False),  # video3
            gr.update(value=video_pair2[1], visible=False),  # video4
            gr.update(visible=False, value=None),  # radio_select2
            gr.update(visible=False),  # button_confirm2
            gr.update(visible=False),  # md_confirm2
            gr.update(visible=False),  # button_new
            gr.update(visible=False, value=None),  # textbox_optional_feedback
            gr.update(visible=False),  # button_submit_optional_feedback
            gr.update(visible=False),  # md_optional_feedback
            gr.update(visible=False),  # md_run_out_videos
        )

    def _click_button_optional_feedback(self, textbox_optional_feedback):
        """Append non-blank feedback to the feedback file and thank the user."""
        if not textbox_optional_feedback or not textbox_optional_feedback.strip():
            return gr.skip(), gr.skip()
        message = f"{time_now()}\n{textbox_optional_feedback}\n\n"
        self._update_local_file(message, self.feedback_file_path_local)
        return gr.update(visible=True), gr.update(visible=False)

    def _click_button_submit_user_id(self, textbox_user_id):
        """Sanitize and store the user id, then reveal the first pair.

        Commas, line breaks and spaces are replaced by "_" so the id cannot
        break the CSV row. Blank input (including ``None``) is ignored.
        """
        user_id = (
            str(textbox_user_id or "")
            .strip()
            .replace(",", "_")
            .replace("\r", "_")
            .replace("\n", "_")
            .replace(" ", "_")
        )
        if user_id == "":
            return [gr.skip()] * 8
        return (
            user_id,
            gr.update(interactive=False),  # textbox_user_id
            gr.update(visible=False),  # button_submit_user_id
            gr.update(visible=True),  # md_user_id
            gr.update(visible=True),  # md_pair_12
            gr.update(visible=True),  # video1
            gr.update(visible=True),  # video2
            gr.update(visible=True),  # radio_select1
        )

    def main(self):
        """Build the Gradio UI, wire up the callbacks and launch the app."""
        # read in md file
        with open("start.md", "r", encoding="utf-8") as f:
            md_start = f.read()

        with gr.Blocks(theme=self.theme, title="Image Quality User Study") as demo:
            # per-session states; time_to_live=900 is passed to gr.State
            # random pop videos from this list to get video pairs
            videos_left = gr.State(value=self.video_paths, time_to_live=900)
            video_pair1 = gr.State(value=["path1", "path2"], time_to_live=900)
            video_pair2 = gr.State(value=["path2", "path4"], time_to_live=900)
            radio_select_to_path_lut = gr.State(value={}, time_to_live=900)  # hold a dict
            user_id = gr.State(value="", time_to_live=900)

            # set up layout
            with gr.Column():
                gr.Markdown(md_start)

                # a debug button
                if self.args.debug:

                    def _click_button_debug(
                        video_pair1,
                        video_pair2,
                        radio_select_to_path_lut,
                        videos_left,
                        user_id,
                    ):
                        print(f"video_pair1: {video_pair1}")
                        print(f"video_pair2: {video_pair2}")
                        print(f"radio_select_to_path_lut: {radio_select_to_path_lut}")
                        print(f"N videos_left: {len(videos_left)}")
                        print(f"user_id: {user_id}")

                    button_debug = gr.Button("debug", variant="primary", scale=1)
                    button_debug.click(
                        _click_button_debug,
                        inputs=[
                            video_pair1,
                            video_pair2,
                            radio_select_to_path_lut,
                            videos_left,
                            user_id,
                        ],
                    )

                # ---------------- optional user id ----------------
                if self.args.ask_user_id:
                    with gr.Row():
                        textbox_user_id = gr.Textbox(
                            label=f"Please enter your {self.args.id_platform} ID: ",
                            placeholder="Type here...",
                            scale=4,
                            lines=1,
                            interactive=True,
                        )
                        button_submit_user_id = gr.Button("Submit", variant="primary", scale=1)
                    md_user_id = gr.Markdown(
                        "Thank you for providing your ID! π", visible=False
                    )

                # ---------------- video 1-2 ----------------
                md_pair_12 = gr.Markdown("## Image Set Pair 1", visible=False)
                with gr.Row():
                    video1, video2 = self._setup_video(
                        path_a=video_pair1.value[0],
                        path_b=video_pair1.value[1],
                        label_a="Image Set 1",
                        label_b="Image Set 2",
                        visible=False,
                    )
                with gr.Row():
                    radio_select1 = gr.Radio(
                        choices=["Image Set 1 π", "Image Set 2 π", "Similar π€"],
                        label="Your Preference:",
                        scale=2,
                        visible=False,
                    )
                    button_confirm1 = gr.Button(
                        value="Confirm",
                        variant="primary",
                        scale=1,
                        visible=False,
                    )
                md_confirm1 = gr.Markdown(visible=False, inputs=button_confirm1)

                # ---------------- video 3-4 ----------------
                md_pair_34 = gr.Markdown("## Image Set Pair 2", visible=False)
                with gr.Row():
                    video3, video4 = self._setup_video(
                        path_a=video_pair2.value[0],
                        path_b=video_pair2.value[1],
                        label_a="Image Set 3",
                        label_b="Image Set 4",
                        visible=False,
                    )
                with gr.Row():
                    radio_select2 = gr.Radio(
                        choices=["Image Set 3 π", "Image Set 4 π", "Similar π€"],
                        label="Your Preference:",
                        scale=2,
                        visible=False,
                    )
                    button_confirm2 = gr.Button(
                        value="Confirm",
                        variant="primary",
                        scale=1,
                        visible=False,
                    )
                md_confirm2 = gr.Markdown(visible=False, inputs=button_confirm2)

                # ---------------- new button ----------------
                button_new = gr.Button(
                    value="New One π²",
                    variant="primary",
                    visible=False,
                )
                md_run_out_videos = gr.Markdown(
                    "You've evaluated all video pairs. Thank you for your participation! π",
                    visible=False,
                )

                # ---------------- optional feedback ----------------
                with gr.Row():
                    textbox_optional_feedback = gr.Textbox(
                        label="Optional Comments:",
                        placeholder="Type here...",
                        lines=1,
                        scale=4,
                        interactive=True,
                        visible=False,
                    )
                    button_submit_optional_feedback = gr.Button(
                        value="Submit Comments",
                        variant="secondary",
                        scale=1,
                        visible=False,
                    )
                md_optional_feedback = gr.Markdown(
                    "Thank you for providing additional comments! π",
                    visible=False,
                )

            # set up callbacks
            if self.args.ask_user_id:
                button_submit_user_id.click(
                    self._click_button_submit_user_id,
                    trigger_mode="once",
                    inputs=textbox_user_id,
                    outputs=[
                        user_id,
                        textbox_user_id,
                        button_submit_user_id,
                        md_user_id,
                        md_pair_12,
                        video1,
                        video2,
                        radio_select1,
                    ],
                )

            radio_select1.select(
                self._click_select_radio,
                trigger_mode="once",
                outputs=button_confirm1,
            )

            button_confirm1.click(
                self._click_button_confirm1,
                trigger_mode="once",
                inputs=[radio_select1, radio_select_to_path_lut, user_id],
                outputs=[
                    md_confirm1,
                    button_confirm1,
                    radio_select1,
                    md_pair_34,
                    video3,
                    video4,
                    radio_select2,
                ],
            )

            radio_select2.select(
                self._click_select_radio,
                trigger_mode="once",
                outputs=button_confirm2,
            )

            button_confirm2.click(
                self._click_button_confirm2,
                trigger_mode="once",
                inputs=[radio_select2, radio_select_to_path_lut, user_id],
                outputs=[
                    md_confirm2,
                    button_confirm2,
                    radio_select2,
                    button_new,
                    textbox_optional_feedback,
                    button_submit_optional_feedback,
                ],
            )

            button_new.click(
                self._click_button_new,
                trigger_mode="once",
                inputs=videos_left,
                outputs=[
                    video_pair1,
                    video_pair2,
                    radio_select_to_path_lut,
                    videos_left,
                    video1,
                    video2,
                    radio_select1,
                    button_confirm1,
                    md_confirm1,
                    md_pair_34,
                    video3,
                    video4,
                    radio_select2,
                    button_confirm2,
                    md_confirm2,
                    button_new,
                    textbox_optional_feedback,
                    button_submit_optional_feedback,
                    md_optional_feedback,
                    md_run_out_videos,
                ],
            )

            button_submit_optional_feedback.click(
                self._click_button_optional_feedback,
                trigger_mode="once",
                inputs=textbox_optional_feedback,
                outputs=[md_optional_feedback, button_submit_optional_feedback],
            )

            demo.load(
                self._load_callback,
                inputs=videos_left,
                outputs=[
                    video_pair1,
                    video_pair2,
                    radio_select_to_path_lut,
                    videos_left,
                    video1,
                    video2,
                    video3,
                    video4,
                    md_run_out_videos,
                ],
            )

        demo.launch(share=self.args.share, show_api=False)
def parse_args():
    """Parse CLI flags, then overlay every key of the YAML config onto them.

    The config file is the main settings source because an HF Space does not
    pass command-line arguments; the CLI flags are mainly for local debugging.
    Config keys overwrite CLI values of the same name.

    Raises:
        ValueError: if the config file holds something other than a mapping.
    """
    parser = ArgumentParser()
    # use this config as HF space does not take args
    parser.add_argument("--config", type=str, default="config.yaml")
    # these args are useful for local debugging
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--no_sync", action="store_true")
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        # yaml.safe_load returns None for an empty file; treat that as "no overrides"
        config = yaml.safe_load(f) or {}
    if not isinstance(config, dict):
        raise ValueError(
            f"{args.config} must contain a YAML mapping, got {type(config).__name__}."
        )
    for key, value in config.items():
        setattr(args, key, value)

    pprint(vars(args))
    return args
if __name__ == "__main__":
    # Script entry point. `args` stays a module-level global on purpose:
    # other code in this file may read it directly.
    args = parse_args()
    survey_engine = SurveyEngine(args=args)
    survey_engine.main()