Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,17 +2,47 @@ import gradio as gr
|
|
2 |
import random
|
3 |
from datasets import load_dataset
|
4 |
import os
|
5 |
-
hf_token = os.environ['hf_token']
|
6 |
|
7 |
-
#
|
8 |
-
|
9 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def get_random_video():
|
11 |
# 随机选择一个索引
|
12 |
-
random_index = random.randint(0, len(
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
return video_path
|
17 |
|
18 |
# Gradio 接口
|
|
|
2 |
import random
|
3 |
from datasets import load_dataset
|
4 |
import os
|
|
|
5 |
|
6 |
+
hf_token = os.environ['hf_token'] # 确保环境变量中有你的令牌
|
7 |
+
submission_url = "Vchitect/VBench_sampled_video" # 数据集的 URL
|
8 |
+
local_dir = "VBench_sampled_video" # 本地文件夹路径
|
9 |
+
|
10 |
+
# 克隆数据集
|
11 |
+
submission_repo = Repository(local_dir=local_dir, clone_from=submission_url, use_auth_token=hf_token, repo_type="dataset")
|
12 |
+
submission_repo.git_pull() # 更新本地仓库
|
13 |
+
|
14 |
+
model_names = os.listdir(local_dir)
|
15 |
+
|
16 |
+
with open("videos_by_dimension.json") as f:
|
17 |
+
dimension = json.load(f)['videos_by_dimension']
|
18 |
+
|
19 |
+
# with open("all_videos.json") as f:
|
20 |
+
# all_videos = json.load(f)
|
21 |
+
|
22 |
+
types = ['appearance_style', 'color', 'temporal_style', 'spatial_relationship', 'temporal_flickering', 'scene', 'multiple_objects', 'object_class', 'human_action', 'overall_consistency', 'subject_consistency']
|
23 |
+
|
24 |
def get_random_video():
|
25 |
# 随机选择一个索引
|
26 |
+
random_index = random.randint(0, len(types) - 1)
|
27 |
+
type = types[random_index]
|
28 |
+
# 随机选择一个Prompt
|
29 |
+
random_index = random.randint(0, len(dimension[type]) - 1)
|
30 |
+
prompt = dimension[type][random_index]
|
31 |
+
# 随机一个模型
|
32 |
+
random_index = random.randint(0, len(model_names) - 1)
|
33 |
+
model_name = model_names[random_index]
|
34 |
+
|
35 |
+
video_path = os.path.join(model_name, type, prompt)
|
36 |
+
if os.path.exists(video_path):
|
37 |
+
print(video_path)
|
38 |
+
return video_path
|
39 |
+
else:
|
40 |
+
video_path = os.path.join(model_name, prompt)
|
41 |
+
if os.path.exists(video_path):
|
42 |
+
print(video_path)
|
43 |
+
return video_path
|
44 |
+
# video_path = dataset['train'][random_index]['video_path']
|
45 |
+
print('error:', video_path)
|
46 |
return video_path
|
47 |
|
48 |
# Gradio 接口
|