ynhe commited on
Commit
69d1f9e
·
verified ·
1 Parent(s): 19bd44c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -8
app.py CHANGED
@@ -2,17 +2,47 @@ import gradio as gr
2
  import random
3
  from datasets import load_dataset
4
  import os
5
- hf_token = os.environ['hf_token']
6
 
7
- # 加载数据集
8
- dataset = load_dataset("Vchitect/VBench_sampled_video")
9
- # 随机选择一个视频
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def get_random_video():
11
  # 随机选择一个索引
12
- random_index = random.randint(0, len(dataset['train']) - 1)
13
- # 获取视频路径
14
- print(dataset['train'][random_index])
15
- video_path = dataset['train'][random_index]['video_path']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  return video_path
17
 
18
  # Gradio 接口
 
2
  import random
3
  from datasets import load_dataset
4
  import os
 
5
 
6
+ hf_token = os.environ['hf_token'] # 确保环境变量中有你的令牌
7
+ submission_url = "Vchitect/VBench_sampled_video" # 数据集的 URL
8
+ local_dir = "VBench_sampled_video" # 本地文件夹路径
9
+
10
+ # 克隆数据集
11
+ submission_repo = Repository(local_dir=local_dir, clone_from=submission_url, use_auth_token=hf_token, repo_type="dataset")
12
+ submission_repo.git_pull() # 更新本地仓库
13
+
14
+ model_names = os.listdir(local_dir)
15
+
16
+ with open("videos_by_dimension.json") as f:
17
+ dimension = json.load(f)['videos_by_dimension']
18
+
19
+ # with open("all_videos.json") as f:
20
+ # all_videos = json.load(f)
21
+
22
+ types = ['appearance_style', 'color', 'temporal_style', 'spatial_relationship', 'temporal_flickering', 'scene', 'multiple_objects', 'object_class', 'human_action', 'overall_consistency', 'subject_consistency']
23
+
24
  def get_random_video():
25
  # 随机选择一个索引
26
+ random_index = random.randint(0, len(types) - 1)
27
+ type = types[random_index]
28
+ # 随机选择一个Prompt
29
+ random_index = random.randint(0, len(dimension[type]) - 1)
30
+ prompt = dimension[type][random_index]
31
+ # 随机一个模型
32
+ random_index = random.randint(0, len(model_names) - 1)
33
+ model_name = model_names[random_index]
34
+
35
+ video_path = os.path.join(model_name, type, prompt)
36
+ if os.path.exists(video_path):
37
+ print(video_path)
38
+ return video_path
39
+ else:
40
+ video_path = os.path.join(model_name, prompt)
41
+ if os.path.exists(video_path):
42
+ print(video_path)
43
+ return video_path
44
+ # video_path = dataset['train'][random_index]['video_path']
45
+ print('error:', video_path)
46
  return video_path
47
 
48
  # Gradio 接口