from argparse import ArgumentParser
import os, threading, time, yaml
from dataclasses import dataclass
from datetime import datetime, timezone
from pprint import pprint
import numpy as np
import gradio as gr
from huggingface_hub import HfApi
@dataclass
class VideoCfg:
value: str
label: str
visible: bool
show_share_button: bool = False
show_download_button: bool = False
show_label: bool = True
width: int | str | None = None
def time_now():
    """Return the current UTC time as ``YYYY-MM-DD_HH-MM-SS``.

    The result is used both as a log timestamp and as part of file names, so
    it only contains digits, ``-`` and ``_``.

    Returns:
        str: The formatted timestamp, e.g. ``"2024-05-17_09-41-03"``.
    """
    # ``%S`` is the zero-padded seconds field. The lowercase ``%s`` used
    # before is a non-standard platform extension (epoch seconds on glibc,
    # unsupported elsewhere), not the seconds of the current minute.
    return datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
class SurveyEngine:
    """Pairwise image-preference survey served through Gradio.

    Each round shows two prompts, one after the other. For every prompt two
    media files (one per method) are displayed side by side and the
    participant picks the better one or "Similar". Answers are appended to a
    local CSV and optional free-text comments to a local text file; both can
    be uploaded periodically to a Hugging Face dataset repo.

    Many identifiers say "video" although the components created in
    ``_setup_video`` are ``gr.Image``; the names are kept as they are.
    """

    def __init__(self, args):
        """Download the survey assets and prepare the local log files.

        Args:
            args: Configuration namespace. This class reads ``task``,
                ``method_filter``, ``ask_user_id``, ``id_platform``,
                ``no_sync``, ``period_upload``, ``period_squash``, ``debug``
                and ``share`` from it.

        Raises:
            ValueError: If ``ask_user_id`` is set but the
                ``PLATFORM_COMPLETION_LINK`` environment variable is missing,
                or if no method directory is left after filtering.
        """
        self.args = args
        self.api = HfApi()

        # Fail fast when a completion link is required but not configured.
        self._generate_platform_completion_message()

        # Pull the survey assets from the HF dataset into the working dir.
        self.repo_id = "jinggogogo/survey-images"
        self.local_dir = "./survey-images"
        self.api.snapshot_download(self.repo_id, local_dir=self.local_dir, repo_type="dataset")

        # Create the CSV that records every answer given during this run.
        response_file_name = f"{self.args.task}_{self._method_tag()}_{time_now()}.csv"
        self.response_file_path_local = f"{self.local_dir}/response/{response_file_name}"
        self.response_file_path_remote = f"response/{response_file_name}"
        os.makedirs(f"{self.local_dir}/response", exist_ok=True)
        if self.args.ask_user_id:
            csv_header = "timestamp,path1,path2,selection,user_id,id_platform\n"
        else:
            csv_header = "timestamp,path1,path2,selection\n"
        self._update_local_file(csv_header, self.response_file_path_local)

        # Create the (initially empty) file that collects optional comments.
        feedback_file_name = f"{time_now()}.txt"
        self.feedback_file_path_local = f"{self.local_dir}/optional_feedback/{feedback_file_name}"
        self.feedback_file_path_remote = f"optional_feedback/{feedback_file_name}"
        os.makedirs(f"{self.local_dir}/optional_feedback", exist_ok=True)
        self._update_local_file("", self.feedback_file_path_local)

        self.video_paths, self.N_prompts, self.N_methods = self._get_all_video_paths()

        self.theme = gr.themes.Base(
            text_size=gr.themes.sizes.text_lg,
            spacing_size=gr.themes.sizes.spacing_sm,
            radius_size=gr.themes.sizes.radius_md,
        )

        if not self.args.no_sync:
            self._start_periodic_sync()

    def _method_tag(self):
        """Return the method part of the response file name.

        The first two entries of ``args.method_filter`` are joined with ``_``
        (the naming used for a two-method comparison). A shorter filter no
        longer raises ``IndexError``: one entry is used alone and an empty
        filter yields ``"all"``.
        """
        selected = [str(name) for name in self.args.method_filter[:2]]
        return "_".join(selected) if selected else "all"

    def _generate_platform_completion_message(self):
        """Set ``self.platform_completion_message`` for the final screen.

        Without ``args.ask_user_id`` the message is empty. Otherwise it is a
        markdown heading linking to ``$PLATFORM_COMPLETION_LINK``.

        Raises:
            ValueError: If a user id is requested but the environment
                variable ``PLATFORM_COMPLETION_LINK`` is not set.
        """
        if not self.args.ask_user_id:
            self.platform_completion_message = ""
            return

        platform_completion_link = os.getenv("PLATFORM_COMPLETION_LINK")
        if platform_completion_link is None:
            raise ValueError("Please provide the platform completion link.")
        self.platform_completion_message = f"## Your {self.args.id_platform} completion link is [here]({platform_completion_link})."

    def _start_periodic_sync(self):
        """Start daemon threads that upload the logs and squash repo history.

        Three threads are started, in this order: response-file upload,
        feedback-file upload, commit squashing. The intervals come from
        ``self.args.period_upload`` and ``self.args.period_squash`` (passed
        to ``time.sleep``). The settings are read from ``self.args`` rather
        than a module-level ``args`` so the class also works when the module
        is imported instead of run as a script.
        """

        def _upload_periodically(path_local, path_remote):
            while True:
                print(time_now())
                print(f"Uploading {path_local}.")
                try:
                    self._update_remote_file(path_local, path_remote)
                except Exception as e:
                    # Best effort: a failed upload must not end the loop.
                    print(e)
                time.sleep(self.args.period_upload)

        def _squash_commits_periodically():
            while True:
                print(time_now())
                print("Squashing commits.")
                try:
                    # Squash the same repo the uploads go to.
                    self.api.super_squash_history(self.repo_id, repo_type="dataset")
                except Exception as e:
                    # Best effort: keep trying on the next period.
                    print(e)
                time.sleep(self.args.period_squash)

        upload_jobs = (
            (self.response_file_path_local, self.response_file_path_remote),
            (self.feedback_file_path_local, self.feedback_file_path_remote),
        )
        for path_local, path_remote in upload_jobs:
            threading.Thread(
                target=_upload_periodically,
                args=(path_local, path_remote),
                daemon=True,
            ).start()
        threading.Thread(target=_squash_commits_periodically, daemon=True).start()

    def _get_all_video_paths(self):
        """Collect the media paths of every prompt for every selected method.

        Method names are the entries of ``<local_dir>/images``, optionally
        intersected with ``args.method_filter``. File names are taken from
        the first selected method's directory and assumed to exist under the
        same name for the other methods (this is not checked here).

        Returns:
            tuple: ``(video_paths, N_prompts, N_methods)`` where
            ``video_paths`` is a NumPy array of path strings with one row per
            file name and one column per method.

        Raises:
            ValueError: If no method directory is left after filtering.
        """
        video_dir = f"{self.local_dir}/images"
        method_list = sorted(os.listdir(video_dir))

        # filter methods (an empty filter keeps everything)
        if len(self.args.method_filter) > 0:
            method_filter = np.array(self.args.method_filter)
            method_list = np.intersect1d(method_list, method_filter)

        if len(method_list) == 0:
            raise ValueError(
                f"No method directory found in {video_dir!r} "
                f"for method_filter={list(self.args.method_filter)!r}."
            )

        video_name_list = sorted(os.listdir(f"{video_dir}/{method_list[0]}"))
        N_prompts = len(video_name_list)
        N_methods = len(method_list)

        # (N_prompts, N_methods)
        video_paths = np.array(
            [
                [os.path.join(video_dir, method, video_name) for method in method_list]
                for video_name in video_name_list
            ]
        )
        return video_paths, N_prompts, N_methods

    @staticmethod
    def _can_sample_pair(videos_left):
        """Return True when at least two prompts remain to be drawn."""
        return len(videos_left) >= 2

    def _sample_video_pair(self, videos_left):
        """Draw two distinct prompts and shuffle the methods of each.

        Args:
            videos_left: (N_prompts, N_methods) array of paths not shown yet.
                Must contain at least two rows.

        Returns:
            tuple: ``(video_pair1, video_pair2, radio_select_to_path_lut,
            videos_left)``: the two shuffled rows, the mapping from radio
            label to the path it stands for, and the input with both drawn
            rows removed.
        """
        # randomly choose two prompts
        N_videos_left = len(videos_left)
        prompt_ids = np.random.choice(N_videos_left, 2, replace=False)

        # Shuffle the method order so the left/right position carries no hint.
        video_pair1 = np.random.permutation(videos_left[prompt_ids[0]])
        video_pair2 = np.random.permutation(videos_left[prompt_ids[1]])

        # Remove both prompts so they are not drawn again in this session.
        videos_left = np.delete(videos_left, prompt_ids, axis=0)

        radio_select_to_path_lut = {
            "Image Set 1 👍": str(video_pair1[0]),
            "Image Set 2 👍": str(video_pair1[1]),
            "Image Set 3 👍": str(video_pair2[0]),
            "Image Set 4 👍": str(video_pair2[1]),
            "Similar 🤔": "Similar",
        }
        print("---------------")
        print(time_now())
        pprint(radio_select_to_path_lut)
        return video_pair1, video_pair2, radio_select_to_path_lut, videos_left

    def _update_local_file(self, message, file_path_local):
        """Append ``message`` to ``file_path_local`` (created if missing).

        The file is written as UTF-8 explicitly: participant ids and comments
        are free text and may contain characters the platform's default
        encoding cannot represent.
        """
        with open(file_path_local, "a", encoding="utf-8") as f:
            f.write(message)

    def _update_remote_file(self, file_path_local, file_path_remote):
        """Upload one local file to ``self.repo_id`` (dataset repo)."""
        self.api.upload_file(
            path_or_fileobj=file_path_local,
            path_in_repo=file_path_remote,
            repo_id=self.repo_id,
            repo_type="dataset",
        )

    def _setup_video(self, path_a, path_b, label_a, label_b, visible):
        """Create the two side-by-side ``gr.Image`` components of one pair."""
        cfg_a = VideoCfg(value=path_a, label=label_a, visible=visible)
        cfg_b = VideoCfg(value=path_b, label=label_b, visible=visible)
        video_a = gr.Image(**cfg_a.__dict__)
        video_b = gr.Image(**cfg_b.__dict__)
        return video_a, video_b

    def _load_callback(self, videos_left):
        """Page-load handler: draw the first round and fill the four images.

        Returns the new values of the four session states followed by the
        updates for video1..video4 and the "run out" notice (kept hidden).
        """
        (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
        ) = self._sample_video_pair(videos_left)
        return (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
            gr.update(value=video_pair1[0]),
            gr.update(value=video_pair1[1]),
            gr.update(value=video_pair2[0]),
            gr.update(value=video_pair2[1]),
            gr.update(visible=False),
        )

    def _click_select_radio(self):
        """Reveal the matching confirm button once a radio option is picked."""
        return gr.update(visible=True)

    def _record_selection(self, radio_select, radio_select_to_path_lut, user_id, label_a, label_b):
        """Append one answer row to the response CSV.

        Args:
            radio_select: Label of the chosen radio option.
            radio_select_to_path_lut: Mapping from radio label to file path.
            user_id: Participant id; only written when ``args.ask_user_id``.
            label_a: Radio label of the left item of the pair.
            label_b: Radio label of the right item of the pair.
        """
        # Log the actual file paths rather than the on-screen labels.
        selected_path = radio_select_to_path_lut[radio_select]
        path1 = radio_select_to_path_lut[label_a]
        path2 = radio_select_to_path_lut[label_b]
        fields = [time_now(), path1, path2, selected_path]
        if self.args.ask_user_id:
            fields += [user_id, self.args.id_platform]
        message = ",".join(str(field) for field in fields) + "\n"
        self._update_local_file(message, self.response_file_path_local)

    def _click_button_confirm1(self, radio_select, radio_select_to_path_lut, user_id):
        """Record the answer for pair 1 and reveal pair 2.

        Returns updates for: md_confirm1, button_confirm1, radio_select1,
        md_pair_34, video3, video4, radio_select2.
        """
        self._record_selection(
            radio_select,
            radio_select_to_path_lut,
            user_id,
            "Image Set 1 👍",
            "Image Set 2 👍",
        )
        confirm_message = f"""
Your selection was:
{radio_select} \n\n
"""
        return (
            # display confirm message
            gr.update(visible=True, value=confirm_message),
            # hide the button and radio for pair 1
            gr.update(visible=False),
            gr.update(visible=False),
            # show heading, images and radio of pair 2
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    def _click_button_confirm2(self, radio_select, radio_select_to_path_lut, user_id):
        """Record the answer for pair 2 and show the end-of-round screen.

        Returns updates for: md_confirm2, button_confirm2, radio_select2,
        button_new, textbox_optional_feedback,
        button_submit_optional_feedback.
        """
        self._record_selection(
            radio_select,
            radio_select_to_path_lut,
            user_id,
            "Image Set 3 👍",
            "Image Set 4 👍",
        )
        confirm_message = f"""
Your selection was:
{radio_select} \n\n
## Study Done!
{self.platform_completion_message}
Click the button below 🎲 if you'd like to evaluate another set. \n\n
You can exit this study by closing this page.
For more details about this study, click
[here](https://huggingface.co/spaces/zirui-wang/video_quality_study/blob/main/detail.md).
"""
        return (
            # display confirm message
            gr.update(visible=True, value=confirm_message),
            # hide the button and radio for pair 2
            gr.update(visible=False),
            gr.update(visible=False),
            # show the "new round" button
            gr.update(visible=True),
            # show textbox and button for optional feedback
            gr.update(visible=True),
            gr.update(visible=True),
        )

    def _click_button_new(self, videos_left):
        """Start another round, or show the "run out" notice.

        A round needs two prompts, so the notice is shown as soon as fewer
        than two are left. Checking only for an empty pool would let a single
        leftover prompt (odd prompt count) crash the sampling step.

        Returns 20 values matching the ``outputs`` of ``button_new.click``:
        four session states followed by sixteen component updates.
        """
        if not self._can_sample_pair(videos_left):
            # states: pair1, pair2, lut cleared; pool emptied; everything
            # hidden except the final "run out" notice.
            return [None] * 3 + [[]] + [gr.update(visible=False)] * 15 + [gr.update(visible=True)]

        print("---------------")
        print(f"N_video_left before: {len(videos_left)}")
        (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
        ) = self._sample_video_pair(videos_left)

        update_video1 = gr.update(value=video_pair1[0])
        update_video2 = gr.update(value=video_pair1[1])
        update_radio_select1 = gr.update(visible=True, value=None)
        update_button_confirm1 = gr.update(visible=False)
        update_md_confirm1 = gr.update(visible=False)
        # Pair 2 is loaded now but stays hidden until pair 1 is confirmed.
        update_md_pair_34 = gr.update(visible=False)
        update_video3 = gr.update(value=video_pair2[0], visible=False)
        update_video4 = gr.update(value=video_pair2[1], visible=False)
        update_radio_select2 = gr.update(visible=False, value=None)
        update_button_confirm2 = gr.update(visible=False)
        update_md_confirm2 = gr.update(visible=False)
        update_button_new = gr.update(visible=False)
        update_textbox_optional = gr.update(visible=False, value=None)
        update_button_submit_optional = gr.update(visible=False)
        update_md_optional_feedback = gr.update(visible=False)
        update_md_run_out_videos = gr.update(visible=False)
        return (
            video_pair1,
            video_pair2,
            radio_select_to_path_lut,
            videos_left,
            update_video1,
            update_video2,
            update_radio_select1,
            update_button_confirm1,
            update_md_confirm1,
            update_md_pair_34,
            update_video3,
            update_video4,
            update_radio_select2,
            update_button_confirm2,
            update_md_confirm2,
            update_button_new,
            update_textbox_optional,
            update_button_submit_optional,
            update_md_optional_feedback,
            update_md_run_out_videos,
        )

    def _click_button_optional_feedback(self, textbox_optional_feedback):
        """Append a free-text comment to the feedback file.

        Missing (``None``) or whitespace-only input is ignored: nothing is
        written and both outputs are skipped.

        Returns updates for: md_optional_feedback,
        button_submit_optional_feedback.
        """
        if textbox_optional_feedback is None or not str(textbox_optional_feedback).strip():
            return gr.skip(), gr.skip()
        message = f"{time_now()}\n{textbox_optional_feedback}\n\n"
        self._update_local_file(message, self.feedback_file_path_local)
        update_md_optional_feedback = gr.update(visible=True)
        update_button_submit_optional = gr.update(visible=False)
        return update_md_optional_feedback, update_button_submit_optional

    @staticmethod
    def _sanitize_user_id(raw_user_id):
        """Turn the textbox content into a value that is safe as a CSV field.

        ``None`` and whitespace-only input become ``""``. Surrounding
        whitespace is dropped; commas, line breaks and inner spaces are
        replaced by ``_`` so the id cannot add columns or rows to the CSV.
        """
        if raw_user_id is None:
            return ""
        user_id = str(raw_user_id).strip()
        for forbidden in (",", "\r", "\n", " "):
            user_id = user_id.replace(forbidden, "_")
        return user_id

    def _click_button_submit_user_id(self, textbox_user_id):
        """Store the participant id and reveal the first pair.

        An empty id leaves the page unchanged (all eight outputs skipped).

        Returns: the sanitized id followed by updates for textbox_user_id,
        button_submit_user_id, md_user_id, md_pair_12, video1, video2,
        radio_select1.
        """
        user_id = self._sanitize_user_id(textbox_user_id)
        if user_id == "":
            return [gr.skip()] * 8
        update_textbox_user_id = gr.update(interactive=False)
        update_button_submit_user_id = gr.update(visible=False)
        update_md_user_id = gr.update(visible=True)
        update_md_pair_12 = gr.update(visible=True)
        update_video1 = gr.update(visible=True)
        update_video2 = gr.update(visible=True)
        update_radio_select1 = gr.update(visible=True)
        return (
            user_id,
            update_textbox_user_id,
            update_button_submit_user_id,
            update_md_user_id,
            update_md_pair_12,
            update_video1,
            update_video2,
            update_radio_select1,
        )

    def main(self):
        """Build the Gradio page, wire the callbacks and launch the app.

        Reads the introduction text from ``start.md`` in the working
        directory.
        """
        # NOTE(review): the nesting of the ``with`` blocks below was
        # reconstructed from a flattened listing - confirm against the
        # deployed layout.
        with open("start.md", "r", encoding="utf-8") as f:
            md_start = f.read()

        # Pair 1 is revealed by the ID-submit callback. When no ID is
        # requested that callback does not exist, so pair 1 has to start
        # visible or the participant would never see anything to rate.
        first_pair_visible = not self.args.ask_user_id

        with gr.Blocks(theme=self.theme, title="Image Quality User Study") as demo:
            # set up session states
            # prompts are drawn (and removed) from this pool to form pairs
            videos_left = gr.State(value=self.video_paths, time_to_live=900)
            # placeholder values; replaced by _load_callback on page load
            video_pair1 = gr.State(value=["path1", "path2"], time_to_live=900)
            video_pair2 = gr.State(value=["path2", "path4"], time_to_live=900)
            radio_select_to_path_lut = gr.State(value={}, time_to_live=900)  # hold a dict
            user_id = gr.State(value="", time_to_live=900)

            # set up layout
            with gr.Column():
                gr.Markdown(md_start)

                # a debug button
                if self.args.debug:

                    def _click_button_debug(
                        video_pair1,
                        video_pair2,
                        radio_select_to_path_lut,
                        videos_left,
                        user_id,
                    ):
                        print(f"video_pair1: {video_pair1}")
                        print(f"video_pair2: {video_pair2}")
                        print(f"radio_select_to_path_lut: {radio_select_to_path_lut}")
                        print(f"N videos_left: {len(videos_left)}")
                        print(f"user_id: {user_id}")

                    button_debug = gr.Button("debug", variant="primary", scale=1)
                    button_debug.click(
                        _click_button_debug,
                        inputs=[
                            video_pair1,
                            video_pair2,
                            radio_select_to_path_lut,
                            videos_left,
                            user_id,
                        ],
                    )

                # ---------------- optional user id ----------------
                if self.args.ask_user_id:
                    with gr.Row():
                        textbox_user_id = gr.Textbox(
                            label=f"Please enter your {self.args.id_platform} ID: ",
                            placeholder="Type here...",
                            scale=4,
                            lines=1,
                            interactive=True,
                        )
                        button_submit_user_id = gr.Button("Submit", variant="primary", scale=1)
                    md_user_id = gr.Markdown(
                        "Thank you for providing your ID! 🙏", visible=False
                    )

                # ---------------- video 1-2 ----------------
                md_pair_12 = gr.Markdown("## Image Set Pair 1", visible=first_pair_visible)
                with gr.Row():
                    video1, video2 = self._setup_video(
                        path_a=video_pair1.value[0],
                        path_b=video_pair1.value[1],
                        label_a="Image Set 1",
                        label_b="Image Set 2",
                        visible=first_pair_visible,
                    )
                with gr.Row():
                    radio_select1 = gr.Radio(
                        choices=["Image Set 1 👍", "Image Set 2 👍", "Similar 🤔"],
                        label="Your Preference:",
                        scale=2,
                        visible=first_pair_visible,
                    )
                    button_confirm1 = gr.Button(
                        value="Confirm",
                        variant="primary",
                        scale=1,
                        visible=False,
                    )
                md_confirm1 = gr.Markdown(visible=False, inputs=button_confirm1)

                # ---------------- video 3-4 ----------------
                md_pair_34 = gr.Markdown("## Image Set Pair 2", visible=False)
                with gr.Row():
                    video3, video4 = self._setup_video(
                        path_a=video_pair2.value[0],
                        path_b=video_pair2.value[1],
                        label_a="Image Set 3",
                        label_b="Image Set 4",
                        visible=False,
                    )
                with gr.Row():
                    radio_select2 = gr.Radio(
                        choices=["Image Set 3 👍", "Image Set 4 👍", "Similar 🤔"],
                        label="Your Preference:",
                        scale=2,
                        visible=False,
                    )
                    button_confirm2 = gr.Button(
                        value="Confirm",
                        variant="primary",
                        scale=1,
                        visible=False,
                    )
                md_confirm2 = gr.Markdown(visible=False, inputs=button_confirm2)

                # ---------------- new button ----------------
                button_new = gr.Button(
                    value="New One 🎲",
                    variant="primary",
                    visible=False,
                )
                md_run_out_videos = gr.Markdown(
                    "You've evaluated all video pairs. Thank you for your participation! 🙏",
                    visible=False,
                )

                # ---------------- optional feedback ----------------
                with gr.Row():
                    textbox_optional_feedback = gr.Textbox(
                        label="Optional Comments:",
                        placeholder="Type here...",
                        lines=1,
                        scale=4,
                        interactive=True,
                        visible=False,
                    )
                    button_submit_optional_feedback = gr.Button(
                        value="Submit Comments",
                        variant="secondary",
                        scale=1,
                        visible=False,
                    )
                md_optional_feedback = gr.Markdown(
                    "Thank you for providing additional comments! 🙏",
                    visible=False,
                )

            # set up callbacks
            if self.args.ask_user_id:
                button_submit_user_id.click(
                    self._click_button_submit_user_id,
                    trigger_mode="once",
                    inputs=textbox_user_id,
                    outputs=[
                        user_id,
                        textbox_user_id,
                        button_submit_user_id,
                        md_user_id,
                        md_pair_12,
                        video1,
                        video2,
                        radio_select1,
                    ],
                )
            radio_select1.select(
                self._click_select_radio,
                trigger_mode="once",
                outputs=button_confirm1,
            )
            button_confirm1.click(
                self._click_button_confirm1,
                trigger_mode="once",
                inputs=[radio_select1, radio_select_to_path_lut, user_id],
                outputs=[
                    md_confirm1,
                    button_confirm1,
                    radio_select1,
                    md_pair_34,
                    video3,
                    video4,
                    radio_select2,
                ],
            )
            radio_select2.select(
                self._click_select_radio,
                trigger_mode="once",
                outputs=button_confirm2,
            )
            button_confirm2.click(
                self._click_button_confirm2,
                trigger_mode="once",
                inputs=[radio_select2, radio_select_to_path_lut, user_id],
                outputs=[
                    md_confirm2,
                    button_confirm2,
                    radio_select2,
                    button_new,
                    textbox_optional_feedback,
                    button_submit_optional_feedback,
                ],
            )
            # The order of these outputs must match the tuple returned by
            # _click_button_new (4 states followed by 16 components).
            button_new.click(
                self._click_button_new,
                trigger_mode="once",
                inputs=videos_left,
                outputs=[
                    video_pair1,
                    video_pair2,
                    radio_select_to_path_lut,
                    videos_left,
                    video1,
                    video2,
                    radio_select1,
                    button_confirm1,
                    md_confirm1,
                    md_pair_34,
                    video3,
                    video4,
                    radio_select2,
                    button_confirm2,
                    md_confirm2,
                    button_new,
                    textbox_optional_feedback,
                    button_submit_optional_feedback,
                    md_optional_feedback,
                    md_run_out_videos,
                ],
            )
            button_submit_optional_feedback.click(
                self._click_button_optional_feedback,
                trigger_mode="once",
                inputs=textbox_optional_feedback,
                outputs=[md_optional_feedback, button_submit_optional_feedback],
            )
            demo.load(
                self._load_callback,
                inputs=videos_left,
                outputs=[
                    video_pair1,
                    video_pair2,
                    radio_select_to_path_lut,
                    videos_left,
                    video1,
                    video2,
                    video3,
                    video4,
                    md_run_out_videos,
                ],
            )

        demo.launch(share=self.args.share, show_api=False)
def parse_args():
    """Parse the command line and overlay the YAML configuration.

    The command line only carries the config path and a few local-debugging
    switches. Every top-level key of the YAML file is then set as an
    attribute on the returned namespace, replacing a command-line value of
    the same name if there is one.

    Returns:
        argparse.Namespace: Command-line flags plus all config entries.

    Raises:
        ValueError: If the config file does not contain a mapping at its top
            level.
    """
    parser = ArgumentParser()
    # use this config as HF space does not take args
    parser.add_argument("--config", type=str, default="config.yaml")
    # these args are useful for local debugging
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--no_sync", action="store_true")
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty YAML document loads as None; treat it as "no overrides"
    # instead of failing on ``None.items()``.
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise ValueError(
            f"{args.config} must contain a mapping at the top level, "
            f"got {type(config).__name__}."
        )

    for key, value in config.items():
        setattr(args, key, value)
    pprint(vars(args))
    return args
if __name__ == "__main__":
    # Script entry point: load the configuration, build the engine (which
    # downloads the survey assets) and serve the page.
    args = parse_args()
    SurveyEngine(args).main()