Spaces:
Running
Running
import os | |
import json | |
import shutil | |
import gradio as gr | |
import random | |
from huggingface_hub import Repository,HfApi | |
from huggingface_hub import snapshot_download | |
# from datasets import load_dataset | |
from datasets import config | |
# Hugging Face access token read from the `hf_token` environment variable;
# a missing variable raises KeyError while this module is being imported.
hf_token = os.environ["hf_token"]
# Local folder under which downloaded sample videos are stored.
local_dir = "VBench_sampled_video"
def print_directory_contents(path, indent=0):
    """Recursively print every entry below *path*, one name per line.

    Each nesting level is prefixed with one more space than its parent.
    If a directory cannot be listed because of a PermissionError, a marker
    line is printed at that level instead of aborting the whole walk.

    Args:
        path: directory whose contents are printed.
        indent: nesting depth of *path*; controls the prefix width.
    """
    prefix = ' ' * indent
    try:
        for name in os.listdir(path):
            print(prefix + name)
            child = os.path.join(path, name)
            # Files are leaves; only directories are descended into.
            if not os.path.isdir(child):
                continue
            print_directory_contents(child, indent + 1)
    except PermissionError:
        # Message text: "[permission error, cannot access this directory]".
        print(prefix + "[权限错误,无法访问该目录]")
# Fetch the dataset index: connect to the Hub and discover the model folders.
os.makedirs(local_dir, exist_ok=True)
# A single client is enough; the endpoint-specific instance that used to be
# created here was immediately overwritten and never used.
hf_api = HfApi(token=hf_token)
repo_id = "Vchitect/VBench_sampled_video"

# Every top-level entry of the dataset repo is treated as one model folder,
# except git metadata (names containing ".git") and markdown files
# (names containing ".md").
model_names = [
    entry.path
    for entry in hf_api.list_repo_tree(repo_id, repo_type='dataset')
    if '.git' not in entry.path and '.md' not in entry.path
]

# Map each evaluation dimension to its list of prompt video file names,
# keeping only the base name of every listed path.
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = {
        key: [os.path.basename(item) for item in videos]
        for key, videos in json.load(f)['videos_by_dimension'].items()
    }

# Dimensions offered in the "Type" dropdown and sampled by get_random_video.
types = ['appearance_style', 'color', 'temporal_style', 'spatial_relationship', 'temporal_flickering', 'scene', 'multiple_objects', 'object_class', 'human_action', 'overall_consistency', 'subject_consistency']
def get_video_path_local(model_name, type, prompt):
    """Download one model's video for a (dimension, prompt) pair from the Hub.

    The file is first looked up in the model's dimension-specific subfolder
    and, if that download fails, in the model's top-level folder.

    Args:
        model_name: top-level folder of the model in the dataset repo.
        type: evaluation dimension name (shadows the builtin; the parameter
            name is kept for backward compatibility with callers).
        prompt: video file name to download.

    Returns:
        The local path returned by ``hf_hub_download``, or None when the file
        could not be downloaded from either location.
    """
    # Some models keep their videos in an extra subfolder below the dimension.
    if 'Show-1' in model_name:
        primary = os.path.join(model_name, type, 'super2')
    elif 'videocrafter-1' in model_name:
        primary = os.path.join(model_name, type, '1024x576')
    else:
        primary = os.path.join(model_name, type)
    # cogvideo's samples are looked up as .gif files instead of .mp4.
    if model_name == 'cogvideo':
        prompt = prompt.replace(".mp4", ".gif")

    candidates = [primary, model_name]
    for attempt, subfolder in enumerate(candidates):
        try:
            return hf_api.hf_hub_download(
                repo_id=repo_id,
                filename=prompt,
                subfolder=subfolder,
                repo_type="dataset",
                local_dir=local_dir,
            )
        except Exception as e:  # best-effort: any failure moves on to the next candidate
            if attempt + 1 < len(candidates):
                print(f"[PATH]{subfolder}/{prompt} NOT in hf repo, try {candidates[attempt + 1]}", e)
            else:
                print(f"[PATH]{subfolder}/{prompt} NOT in hf repo, no fallback left", e)
    print('error:', model_name, type, prompt)
    return None
def get_random_video():
    """Pick a random dimension, prompt and pair of distinct models, then fetch both videos.

    Returns:
        (video_path1, video_path2, model_name_1, model_name_2, type, prompt),
        where a video path is None if get_video_path_local could not download it.
    """
    dim_type = random.choice(types)
    # basename() guards against a full path slipping into `dimension`.
    prompt = os.path.basename(random.choice(dimension[dim_type]))
    # random.sample guarantees the two models are different.
    model_name_1, model_name_2 = random.sample(model_names, 2)
    video_path1 = get_video_path_local(model_name_1, dim_type, prompt)
    video_path2 = get_video_path_local(model_name_2, dim_type, prompt)
    return video_path1, video_path2, model_name_1, model_name_2, dim_type, prompt
def update_prompt_options(type, value=None):
    """Build a Gradio update for the prompt dropdown of dimension `type`.

    The dropdown's choices become ``dimension[type]``. The selected value is
    `value` when it is truthy, otherwise the first prompt of the dimension;
    it is None whenever the dimension has no prompts at all.
    """
    choices = dimension[type]
    if not choices:
        selected = None
    elif value:
        selected = value
    else:
        selected = choices[0]
    return gr.update(choices=choices, value=selected)
def display_videos(type, prompt, model_name_1, model_name_2):
    """Fetch the videos of the two selected models for one dimension and prompt.

    Returns a (path_for_model_1, path_for_model_2) tuple; an entry is None
    when get_video_path_local could not download that video.
    """
    return tuple(
        get_video_path_local(name, type, prompt)
        for name in (model_name_1, model_name_2)
    )
def _record_user_feedback(model_name1, model_name2, type, prompt, label):
    """Append one preference vote to the shared feedback file and upload it.

    Shared implementation of record_user_feedback_a / record_user_feedback_b.
    `label` is 0 when model A (model_name1) was preferred and 1 for model B.
    Returns Gradio updates that hide both feedback buttons.
    """
    # Refresh the local copy so the new row is appended to the latest votes.
    # NOTE(review): download -> append -> upload is not atomic; two votes
    # handled at the same time can overwrite each other's row.
    csv_path = hf_api.hf_hub_download(
        repo_id="Vchitect/VBench_human_annotation",
        filename="arena_feedback.csv",
        repo_type="dataset",
        local_dir='./',
    )
    # One tab-separated row: model A, model B, dimension, prompt, label.
    with open(csv_path, 'a', encoding="utf-8") as f:
        f.write(f"{model_name1}\t{model_name2}\t{type}\t{prompt}\t{label}\n")
    hf_api.upload_file(
        path_or_fileobj=csv_path,
        path_in_repo="arena_feedback.csv",
        repo_id="Vchitect/VBench_human_annotation",
        token=hf_token,
        repo_type="dataset",
        commit_message="[From VBench Arena] user feedback",
    )
    return gr.update(visible=False), gr.update(visible=False)


def record_user_feedback_a(model_name1, model_name2, type, prompt):
    """Record that model A was judged better (label 0); hide the feedback buttons."""
    return _record_user_feedback(model_name1, model_name2, type, prompt, 0)


def record_user_feedback_b(model_name1, model_name2, type, prompt):
    """Record that model B was judged better (label 1); hide the feedback buttons."""
    return _record_user_feedback(model_name1, model_name2, type, prompt, 1)
def show_feedback_button():
    """Return a pair of Gradio updates that reveal both feedback buttons."""
    return tuple(gr.update(visible=True) for _ in range(2))
with gr.Blocks() as interface:
    gr.Markdown("# VBench Video Arena")
    # A Markdown list so each item renders on its own line; plain single
    # newlines would be merged into one paragraph.
    gr.Markdown("""
- **Random 2 videos**: randomly picks two models and compares them on the same randomly chosen dimension and prompt.
- **Play Selection**: plays the models, dimension and prompt you select in the drop-down boxes.

If you are interested, you can also leave your comments.""")
    type_output = gr.Dropdown(label="Type", choices=types, value=types[0])
    prompt_output = gr.Dropdown(label="Prompt", choices=dimension[types[0]], value=dimension[types[0]][0])
    # Holds the prompt chosen by the last random draw; feedback is recorded
    # against this value rather than the (possibly re-populated) dropdown.
    prompt_placeholder = gr.State()
    with gr.Row():
        random_button = gr.Button("🎲 Random 2 videos")
        display_button = gr.Button("🎇 Play Selection")
    with gr.Row():
        with gr.Column():
            model_name_1_output = gr.Dropdown(label="Model Name 1", choices=model_names, value=model_names[0])
            video_output_1 = gr.Video(label="Video 1")
        with gr.Column():
            model_name_2_output = gr.Dropdown(label="Model Name 2", choices=model_names, value=model_names[1])
            video_output_2 = gr.Video(label="Video 2")
    with gr.Row():
        # Voting is only offered for random pairs; see the event wiring below.
        feed0 = gr.Button("👈 Model A is better", visible=False)
        feed1 = gr.Button("👉 Model B is better", visible=False)
    type_output.change(fn=update_prompt_options, inputs=[type_output], outputs=[prompt_output])
    # Random draw: fetch videos, refresh the prompt choices, then select the
    # drawn prompt, and finally reveal the voting buttons.
    random_button.click(
        fn=get_random_video,
        outputs=[video_output_1, video_output_2, model_name_1_output, model_name_2_output, type_output, prompt_placeholder]
    ).then(
        fn=update_prompt_options,
        inputs=[type_output],
        outputs=[prompt_output]
    ).then(
        fn=update_prompt_options,
        inputs=[type_output, prompt_placeholder],
        outputs=[prompt_output]
    ).then(
        fn=show_feedback_button,
        outputs=[feed0, feed1]
    )
    # After a manual selection the displayed videos no longer match the stored
    # random prompt, so the voting buttons are hidden to avoid recording a vote
    # against the wrong prompt.
    # NOTE(review): editing the dropdowns without clicking Play Selection can
    # still desynchronize a vote from the displayed pair.
    display_button.click(
        fn=display_videos,
        inputs=[type_output, prompt_output, model_name_1_output, model_name_2_output],
        outputs=[video_output_1, video_output_2]
    ).then(
        fn=lambda: (gr.update(visible=False), gr.update(visible=False)),
        outputs=[feed0, feed1]
    )
    feed0.click(
        fn=record_user_feedback_a,
        inputs=[model_name_1_output, model_name_2_output, type_output, prompt_placeholder],
        outputs=[feed0, feed1]
    )
    feed1.click(
        fn=record_user_feedback_b,
        inputs=[model_name_1_output, model_name_2_output, type_output, prompt_placeholder],
        outputs=[feed0, feed1]
    )
interface.launch()