YangZhoumill commited on
Commit
a4b32da
1 Parent(s): 780769f

release code

Browse files
app.py CHANGED
@@ -7,12 +7,8 @@ from model.model_manager import ModelManager
7
  from pathlib import Path
8
  from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
9
 
 
10
  def build_combine_demo(models, elo_results_file, leaderboard_table_file):
11
- # gr.themes.Default(),
12
- # gr.themes.Soft(),
13
- # gr.Theme.from_hub('gary109/HaleyCH_Theme'),
14
- # gr.Theme.from_hub('EveryPizza/Cartoony-Gradio-Theme')
15
- # gr.themes.Default(primary_hue="red", secondary_hue="pink")
16
  with gr.Blocks(
17
  title="Play with Open Vision Models",
18
  theme=gr.themes.Default(),
@@ -22,21 +18,21 @@ def build_combine_demo(models, elo_results_file, leaderboard_table_file):
22
  with gr.Tab("Image Generation", id=0):
23
  with gr.Tabs() as tabs_ig:
24
  with gr.Tab("Generation Leaderboard", id=0):
25
- # build_leaderboard_tab(elo_results_file['t2i_generation'], leaderboard_table_file['t2i_generation'])
26
  build_leaderboard_tab()
27
-
28
  with gr.Tab("Generation Arena (battle)", id=1):
29
  build_side_by_side_ui_anony(models)
 
30
  with gr.Tab("Video Generation", id=1):
31
  with gr.Tabs() as tabs_ig:
32
  with gr.Tab("Generation Leaderboard", id=0):
33
- # build_leaderboard_tab(elo_results_file['t2i_generation'], leaderboard_table_file['t2i_generation'])
34
  build_leaderboard_video_tab()
35
 
36
  with gr.Tab("Generation Arena (battle)", id=1):
37
  build_side_by_side_video_ui_anony(models)
 
38
  with gr.Tab("Contributor", id=2):
39
  build_leaderboard_contributor()
 
40
  return demo
41
 
42
 
@@ -44,6 +40,7 @@ def load_elo_results(elo_results_dir):
44
  from collections import defaultdict
45
  elo_results_file = defaultdict(lambda: None)
46
  leaderboard_table_file = defaultdict(lambda: None)
 
47
  if elo_results_dir is not None:
48
  elo_results_dir = Path(elo_results_dir)
49
  elo_results_file = {}
 
7
  from pathlib import Path
8
  from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
9
 
10
+
11
  def build_combine_demo(models, elo_results_file, leaderboard_table_file):
 
 
 
 
 
12
  with gr.Blocks(
13
  title="Play with Open Vision Models",
14
  theme=gr.themes.Default(),
 
18
  with gr.Tab("Image Generation", id=0):
19
  with gr.Tabs() as tabs_ig:
20
  with gr.Tab("Generation Leaderboard", id=0):
 
21
  build_leaderboard_tab()
 
22
  with gr.Tab("Generation Arena (battle)", id=1):
23
  build_side_by_side_ui_anony(models)
24
+
25
  with gr.Tab("Video Generation", id=1):
26
  with gr.Tabs() as tabs_ig:
27
  with gr.Tab("Generation Leaderboard", id=0):
 
28
  build_leaderboard_video_tab()
29
 
30
  with gr.Tab("Generation Arena (battle)", id=1):
31
  build_side_by_side_video_ui_anony(models)
32
+
33
  with gr.Tab("Contributor", id=2):
34
  build_leaderboard_contributor()
35
+
36
  return demo
37
 
38
 
 
40
  from collections import defaultdict
41
  elo_results_file = defaultdict(lambda: None)
42
  leaderboard_table_file = defaultdict(lambda: None)
43
+
44
  if elo_results_dir is not None:
45
  elo_results_dir = Path(elo_results_dir)
46
  elo_results_file = {}
model/matchmaker.py CHANGED
@@ -24,35 +24,41 @@ def create_ssh_matchmaker_client(server, port, user, password):
24
  transport.set_keepalive(60)
25
 
26
  sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
 
 
27
  def is_connected():
28
  global ssh_matchmaker_client, sftp_matchmaker_client
29
  if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
30
  return False
31
- # 检查SSH连接是否正常
32
  if not ssh_matchmaker_client.get_transport().is_active():
33
  return False
34
- # 检查SFTP连接是否正常
35
  try:
36
- sftp_matchmaker_client.listdir('.') # 尝试列出根目录
37
  except Exception as e:
38
  print(f"Error checking SFTP connection: {e}")
39
  return False
40
  return True
 
 
41
  def ucb_score(trueskill_diff, t, n):
42
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
43
  ucb = -trueskill_diff + 1.0 * exploration_term
44
  return ucb
45
 
 
46
  def update_trueskill(ratings, ranks):
47
  new_ratings = trueskill_env.rate(ratings, ranks)
48
  return new_ratings
49
 
 
50
  def serialize_rating(rating):
51
  return {'mu': rating.mu, 'sigma': rating.sigma}
52
 
 
53
  def deserialize_rating(rating_dict):
54
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
55
 
 
56
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
57
  global sftp_matchmaker_client
58
  if not is_connected():
@@ -66,6 +72,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
66
  with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
67
  f.write(json_data)
68
 
 
69
  def load_json_via_sftp():
70
  global sftp_matchmaker_client
71
  if not is_connected():
@@ -107,7 +114,7 @@ def matchmaker(num_players, k_group=4, not_run=[]):
107
  ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
108
 
109
  # Exclude self, select opponent with highest UCB score
110
- ucb_scores[selected_player] = -float('inf') # minimize the score for the selected player to exclude it
111
  ucb_scores[not_run] = -float('inf')
112
  opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
113
 
@@ -117,4 +124,3 @@ def matchmaker(num_players, k_group=4, not_run=[]):
117
  random.shuffle(model_ids)
118
 
119
  return model_ids
120
-
 
24
  transport.set_keepalive(60)
25
 
26
  sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
27
+
28
+
29
  def is_connected():
30
  global ssh_matchmaker_client, sftp_matchmaker_client
31
  if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
32
  return False
 
33
  if not ssh_matchmaker_client.get_transport().is_active():
34
  return False
 
35
  try:
36
+ sftp_matchmaker_client.listdir('.')
37
  except Exception as e:
38
  print(f"Error checking SFTP connection: {e}")
39
  return False
40
  return True
41
+
42
+
43
  def ucb_score(trueskill_diff, t, n):
44
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
45
  ucb = -trueskill_diff + 1.0 * exploration_term
46
  return ucb
47
 
48
+
49
  def update_trueskill(ratings, ranks):
50
  new_ratings = trueskill_env.rate(ratings, ranks)
51
  return new_ratings
52
 
53
+
54
  def serialize_rating(rating):
55
  return {'mu': rating.mu, 'sigma': rating.sigma}
56
 
57
+
58
  def deserialize_rating(rating_dict):
59
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
60
 
61
+
62
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
63
  global sftp_matchmaker_client
64
  if not is_connected():
 
72
  with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
73
  f.write(json_data)
74
 
75
+
76
  def load_json_via_sftp():
77
  global sftp_matchmaker_client
78
  if not is_connected():
 
114
  ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
115
 
116
  # Exclude self, select opponent with highest UCB score
117
+ ucb_scores[selected_player] = -float('inf')
118
  ucb_scores[not_run] = -float('inf')
119
  opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
120
 
 
124
  random.shuffle(model_ids)
125
 
126
  return model_ids
 
model/matchmaker_video.py CHANGED
@@ -13,6 +13,7 @@ trueskill_env = TrueSkill()
13
  ssh_matchmaker_client = None
14
  sftp_matchmaker_client = None
15
 
 
16
  def create_ssh_matchmaker_client(server, port, user, password):
17
  global ssh_matchmaker_client, sftp_matchmaker_client
18
  ssh_matchmaker_client = paramiko.SSHClient()
@@ -24,35 +25,41 @@ def create_ssh_matchmaker_client(server, port, user, password):
24
  transport.set_keepalive(60)
25
 
26
  sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
 
 
27
  def is_connected():
28
  global ssh_matchmaker_client, sftp_matchmaker_client
29
  if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
30
  return False
31
- # 检查SSH连接是否正常
32
  if not ssh_matchmaker_client.get_transport().is_active():
33
  return False
34
- # 检查SFTP连接是否正常
35
  try:
36
- sftp_matchmaker_client.listdir('.') # 尝试列出根目录
37
  except Exception as e:
38
  print(f"Error checking SFTP connection: {e}")
39
  return False
40
  return True
 
 
41
  def ucb_score(trueskill_diff, t, n):
42
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
43
  ucb = -trueskill_diff + 1.0 * exploration_term
44
  return ucb
45
 
 
46
  def update_trueskill(ratings, ranks):
47
  new_ratings = trueskill_env.rate(ratings, ranks)
48
  return new_ratings
49
 
 
50
  def serialize_rating(rating):
51
  return {'mu': rating.mu, 'sigma': rating.sigma}
52
 
 
53
  def deserialize_rating(rating_dict):
54
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
55
 
 
56
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
57
  global sftp_matchmaker_client
58
  if not is_connected():
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
66
  with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
67
  f.write(json_data)
68
 
 
69
  def load_json_via_sftp():
70
  global sftp_matchmaker_client
71
  if not is_connected():
@@ -95,7 +103,7 @@ def matchmaker_video(num_players, k_group=4):
95
  ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
96
 
97
  # Exclude self, select opponent with highest UCB score
98
- ucb_scores[selected_player] = -float('inf') # minimize the score for the selected player to exclude it
99
 
100
  excluded_players_1 = [num_players-1, num_players-4]
101
  excluded_players_2 = [num_players-2, num_players-3, num_players-5]
@@ -126,4 +134,3 @@ def matchmaker_video(num_players, k_group=4):
126
  random.shuffle(model_ids)
127
 
128
  return model_ids
129
-
 
13
  ssh_matchmaker_client = None
14
  sftp_matchmaker_client = None
15
 
16
+
17
  def create_ssh_matchmaker_client(server, port, user, password):
18
  global ssh_matchmaker_client, sftp_matchmaker_client
19
  ssh_matchmaker_client = paramiko.SSHClient()
 
25
  transport.set_keepalive(60)
26
 
27
  sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
28
+
29
+
30
  def is_connected():
31
  global ssh_matchmaker_client, sftp_matchmaker_client
32
  if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
33
  return False
 
34
  if not ssh_matchmaker_client.get_transport().is_active():
35
  return False
 
36
  try:
37
+ sftp_matchmaker_client.listdir('.')
38
  except Exception as e:
39
  print(f"Error checking SFTP connection: {e}")
40
  return False
41
  return True
42
+
43
+
44
  def ucb_score(trueskill_diff, t, n):
45
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
46
  ucb = -trueskill_diff + 1.0 * exploration_term
47
  return ucb
48
 
49
+
50
  def update_trueskill(ratings, ranks):
51
  new_ratings = trueskill_env.rate(ratings, ranks)
52
  return new_ratings
53
 
54
+
55
  def serialize_rating(rating):
56
  return {'mu': rating.mu, 'sigma': rating.sigma}
57
 
58
+
59
  def deserialize_rating(rating_dict):
60
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
61
 
62
+
63
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
64
  global sftp_matchmaker_client
65
  if not is_connected():
 
73
  with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
74
  f.write(json_data)
75
 
76
+
77
  def load_json_via_sftp():
78
  global sftp_matchmaker_client
79
  if not is_connected():
 
103
  ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
104
 
105
  # Exclude self, select opponent with highest UCB score
106
+ ucb_scores[selected_player] = -float('inf')
107
 
108
  excluded_players_1 = [num_players-1, num_players-4]
109
  excluded_players_2 = [num_players-2, num_players-3, num_players-5]
 
134
  random.shuffle(model_ids)
135
 
136
  return model_ids
 
model/model_manager.py CHANGED
@@ -58,10 +58,8 @@ class ModelManager:
58
  def generate_image_ig_api(self, prompt, model_name):
59
  pipe = self.load_model_pipe(model_name)
60
  result = pipe(prompt=prompt)
61
-
62
  return result
63
 
64
-
65
  def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
66
  if model_A == "" and model_B == "" and model_C == "" and model_D == "":
67
  from .matchmaker import matchmaker
@@ -73,13 +71,11 @@ class ModelManager:
73
  else:
74
  model_names = [model_A, model_B, model_C, model_D]
75
 
76
-
77
  with concurrent.futures.ThreadPoolExecutor() as executor:
78
  futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
79
  else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
80
  results = [future.result() for future in futures]
81
 
82
-
83
  return results[0], results[1], results[2], results[3], \
84
  model_names[0], model_names[1], model_names[2], model_names[3]
85
 
@@ -156,7 +152,6 @@ class ModelManager:
156
  return results[0], results[1], results[2], results[3], \
157
  model_names[0], model_names[1], model_names[2], model_names[3], prompt
158
 
159
-
160
  def generate_image_ig_parallel(self, prompt, model_A, model_B):
161
  model_names = [model_A, model_B]
162
  with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -165,14 +160,12 @@ class ModelManager:
165
  results = [future.result() for future in futures]
166
  return results[0], results[1]
167
 
168
-
169
  @spaces.GPU(duration=200)
170
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
171
  pipe = self.load_model_pipe(model_name)
172
  result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
173
  return result
174
 
175
-
176
  def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
177
  model_names = [model_A, model_B]
178
  with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -182,7 +175,6 @@ class ModelManager:
182
  results = [future.result() for future in futures]
183
  return results[0], results[1]
184
 
185
-
186
  def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
187
  if model_A == "" and model_B == "":
188
  model_names = random.sample([model for model in self.model_ie_list], 2)
 
58
  def generate_image_ig_api(self, prompt, model_name):
59
  pipe = self.load_model_pipe(model_name)
60
  result = pipe(prompt=prompt)
 
61
  return result
62
 
 
63
  def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
64
  if model_A == "" and model_B == "" and model_C == "" and model_D == "":
65
  from .matchmaker import matchmaker
 
71
  else:
72
  model_names = [model_A, model_B, model_C, model_D]
73
 
 
74
  with concurrent.futures.ThreadPoolExecutor() as executor:
75
  futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
76
  else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
77
  results = [future.result() for future in futures]
78
 
 
79
  return results[0], results[1], results[2], results[3], \
80
  model_names[0], model_names[1], model_names[2], model_names[3]
81
 
 
152
  return results[0], results[1], results[2], results[3], \
153
  model_names[0], model_names[1], model_names[2], model_names[3], prompt
154
 
 
155
  def generate_image_ig_parallel(self, prompt, model_A, model_B):
156
  model_names = [model_A, model_B]
157
  with concurrent.futures.ThreadPoolExecutor() as executor:
 
160
  results = [future.result() for future in futures]
161
  return results[0], results[1]
162
 
 
163
  @spaces.GPU(duration=200)
164
  def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
165
  pipe = self.load_model_pipe(model_name)
166
  result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
167
  return result
168
 
 
169
  def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
170
  model_names = [model_A, model_B]
171
  with concurrent.futures.ThreadPoolExecutor() as executor:
 
175
  results = [future.result() for future in futures]
176
  return results[0], results[1]
177
 
 
178
  def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
179
  if model_A == "" and model_B == "":
180
  model_names = random.sample([model for model in self.model_ie_list], 2)
model/model_registry.py CHANGED
@@ -68,188 +68,3 @@ def get_video_model_description_md(model_list):
68
  model_description_md += "\n"
69
  ct += 1
70
  return model_description_md
71
-
72
- register_model_info(
73
- ["imagenhub_LCM_generation", "fal_LCM_text2image"],
74
- "LCM",
75
- "https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7",
76
- "Latent Consistency Models.",
77
- )
78
-
79
- register_model_info(
80
- ["fal_LCM(v1.5/XL)_text2image"],
81
- "LCM(v1.5/XL)",
82
- "https://fal.ai/models/fast-lcm-diffusion-turbo",
83
- "Latent Consistency Models (v1.5/XL)",
84
- )
85
-
86
- register_model_info(
87
- ["imagenhub_PlayGroundV2_generation", 'playground_PlayGroundV2_generation'],
88
- "Playground v2",
89
- "https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic",
90
- "Playground v2 – 1024px Aesthetic Model",
91
- )
92
-
93
- register_model_info(
94
- ["imagenhub_PlayGroundV2.5_generation", 'playground_PlayGroundV2.5_generation'],
95
- "Playground v2.5",
96
- "https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic",
97
- "Playground v2.5 is the state-of-the-art open-source model in aesthetic quality",
98
- )
99
-
100
- register_model_info(
101
- ["imagenhub_OpenJourney_generation"],
102
- "Openjourney",
103
- "https://huggingface.co/prompthero/openjourney",
104
- "Openjourney is an open source Stable Diffusion fine tuned model on Midjourney images, by PromptHero.",
105
- )
106
-
107
- register_model_info(
108
- ["imagenhub_SDXLTurbo_generation", "fal_SDXLTurbo_text2image"],
109
- "SDXLTurbo",
110
- "https://huggingface.co/stabilityai/sdxl-turbo",
111
- "SDXL-Turbo is a fast generative text-to-image model.",
112
- )
113
-
114
- register_model_info(
115
- ["imagenhub_SDXL_generation", "fal_SDXL_text2image"],
116
- "SDXL",
117
- "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
118
- "SDXL is a Latent Diffusion Model that uses two fixed, pretrained text encoders.",
119
- )
120
-
121
- register_model_info(
122
- ["imagenhub_PixArtAlpha_generation"],
123
- "PixArtAlpha",
124
- "https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS",
125
- "Pixart-α consists of pure transformer blocks for latent diffusion.",
126
- )
127
-
128
- register_model_info(
129
- ["imagenhub_PixArtSigma_generation", "fal_PixArtSigma_text2image"],
130
- "PixArtSigma",
131
- "https://github.com/PixArt-alpha/PixArt-sigma",
132
- "Improved version of Pixart-α.",
133
- )
134
-
135
- register_model_info(
136
- ["imagenhub_SDXLLightning_generation", "fal_SDXLLightning_text2image"],
137
- "SDXL-Lightning",
138
- "https://huggingface.co/ByteDance/SDXL-Lightning",
139
- "SDXL-Lightning is a lightning-fast text-to-image generation model.",
140
- )
141
-
142
- register_model_info(
143
- ["imagenhub_StableCascade_generation", "fal_StableCascade_text2image"],
144
- "StableCascade",
145
- "https://huggingface.co/stabilityai/stable-cascade",
146
- "StableCascade is built upon the Würstchen architecture and working at a much smaller latent space.",
147
- )
148
-
149
- # regist image edition models
150
- register_model_info(
151
- ["imagenhub_CycleDiffusion_edition"],
152
- "CycleDiffusion",
153
- "https://github.com/ChenWu98/cycle-diffusion?tab=readme-ov-file",
154
- "A latent space for stochastic diffusion models.",
155
- )
156
-
157
- register_model_info(
158
- ["imagenhub_Pix2PixZero_edition"],
159
- "Pix2PixZero",
160
- "https://pix2pixzero.github.io/",
161
- "A zero-shot Image-to-Image translation model.",
162
- )
163
-
164
- register_model_info(
165
- ["imagenhub_Prompt2prompt_edition"],
166
- "Prompt2prompt",
167
- "https://prompt-to-prompt.github.io/",
168
- "Image Editing with Cross-Attention Control.",
169
- )
170
-
171
-
172
- register_model_info(
173
- ["imagenhub_InstructPix2Pix_edition"],
174
- "InstructPix2Pix",
175
- "https://www.timothybrooks.com/instruct-pix2pix",
176
- "An instruction-based image editing model.",
177
- )
178
-
179
- register_model_info(
180
- ["imagenhub_MagicBrush_edition"],
181
- "MagicBrush",
182
- "https://osu-nlp-group.github.io/MagicBrush/",
183
- "Manually Annotated Dataset for Instruction-Guided Image Editing.",
184
- )
185
-
186
- register_model_info(
187
- ["imagenhub_PNP_edition"],
188
- "PNP",
189
- "https://github.com/MichalGeyer/plug-and-play",
190
- "Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation.",
191
- )
192
-
193
- register_model_info(
194
- ["imagenhub_InfEdit_edition"],
195
- "InfEdit",
196
- "https://sled-group.github.io/InfEdit/",
197
- "Inversion-Free Image Editing with Natural Language.",
198
- )
199
-
200
- register_model_info(
201
- ["imagenhub_CosXLEdit_edition"],
202
- "CosXLEdit",
203
- "https://huggingface.co/stabilityai/cosxl",
204
- "An instruction-based image editing model from SDXL.",
205
- )
206
-
207
- register_model_info(
208
- ["fal_stable-cascade_text2image"],
209
- "StableCascade",
210
- "https://fal.ai/models/stable-cascade/api",
211
- "StableCascade is a generative model that can generate high-quality images from text prompts.",
212
- )
213
-
214
- register_model_info(
215
- ["fal_AnimateDiff_text2video"],
216
- "AnimateDiff",
217
- "https://fal.ai/models/fast-animatediff-t2v",
218
- "AnimateDiff is a text-driven models that produce diverse and personalized animated images.",
219
- )
220
-
221
- register_model_info(
222
- ["fal_AnimateDiffTurbo_text2video"],
223
- "AnimateDiff Turbo",
224
- "https://fal.ai/models/fast-animatediff-t2v-turbo",
225
- "AnimateDiff Turbo is a lightning version of AnimateDiff.",
226
- )
227
-
228
- register_model_info(
229
- ["videogenhub_LaVie_generation"],
230
- "LaVie",
231
- "https://github.com/Vchitect/LaVie",
232
- "LaVie is a video generation model with cascaded latent diffusion models.",
233
- )
234
-
235
- register_model_info(
236
- ["videogenhub_VideoCrafter2_generation"],
237
- "VideoCrafter2",
238
- "https://ailab-cvc.github.io/videocrafter2/",
239
- "VideoCrafter2 is a T2V model that disentangling motion from appearance.",
240
- )
241
-
242
- register_model_info(
243
- ["videogenhub_ModelScope_generation"],
244
- "ModelScope",
245
- "https://arxiv.org/abs/2308.06571",
246
- "ModelScope is a a T2V synthesis model that evolves from a T2I synthesis model.",
247
- )
248
-
249
- register_model_info(
250
- ["videogenhub_OpenSora_generation"],
251
- "OpenSora",
252
- "https://github.com/hpcaitech/Open-Sora",
253
- "A community-driven opensource implementation of Sora.",
254
- )
255
-
 
68
  model_description_md += "\n"
69
  ct += 1
70
  return model_description_md
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/models/__init__.py CHANGED
@@ -37,15 +37,6 @@ IMAGE_GENERATION_MODELS = [
37
  "replicate_FLUX.1-dev_text2image",
38
  ]
39
 
40
-
41
- IMAGE_EDITION_MODELS = ['imagenhub_CycleDiffusion_edition', 'imagenhub_Pix2PixZero_edition', 'imagenhub_Prompt2prompt_edition',
42
- 'imagenhub_SDEdit_edition', 'imagenhub_InstructPix2Pix_edition',
43
- 'imagenhub_MagicBrush_edition', 'imagenhub_PNP_edition',
44
- 'imagenhub_InfEdit_edition', 'imagenhub_CosXLEdit_edition']
45
- # VIDEO_GENERATION_MODELS = ['fal_AnimateDiff_text2video',
46
- # 'fal_AnimateDiffTurbo_text2video',
47
- # 'videogenhub_LaVie_generation', 'videogenhub_VideoCrafter2_generation',
48
- # 'videogenhub_ModelScope_generation', 'videogenhub_OpenSora_generation']
49
  VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
50
  'replicate_Animate-Diff_text2video',
51
  'replicate_OpenSora_text2video',
@@ -59,22 +50,15 @@ VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
59
  'other_Sora_text2video',
60
  ]
61
 
 
62
  def load_pipeline(model_name):
63
  """
64
  Load a model pipeline based on the model name
65
  Args:
66
  model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
67
- the source can be either imagenhub or playground
68
- the name is the name of the model used to load the model
69
- the type is the type of the model, either generation or edition
70
  """
71
  model_source, model_name, model_type = model_name.split("_")
72
- # if model_source == "imagenhub":
73
- # pipe = load_imagenhub_model(model_name, model_type)
74
- # elif model_source == "fal":
75
- # pipe = load_fal_model(model_name, model_type)
76
- # elif model_source == "videogenhub":
77
- # pipe = load_videogenhub_model(model_name)
78
  if model_source == "replicate":
79
  pipe = load_replicate_model(model_name, model_type)
80
  elif model_source == "huggingface":
 
37
  "replicate_FLUX.1-dev_text2image",
38
  ]
39
 
 
 
 
 
 
 
 
 
 
40
  VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
41
  'replicate_Animate-Diff_text2video',
42
  'replicate_OpenSora_text2video',
 
50
  'other_Sora_text2video',
51
  ]
52
 
53
+
54
  def load_pipeline(model_name):
55
  """
56
  Load a model pipeline based on the model name
57
  Args:
58
  model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
 
 
 
59
  """
60
  model_source, model_name, model_type = model_name.split("_")
61
+
 
 
 
 
 
62
  if model_source == "replicate":
63
  pipe = load_replicate_model(model_name, model_type)
64
  elif model_source == "huggingface":
model/models/openai_api_models.py CHANGED
@@ -12,7 +12,6 @@ class OpenaiModel():
12
  self.model_type = model_type
13
 
14
  def __call__(self, *args, **kwargs):
15
-
16
  if self.model_type == "text2image":
17
  assert "prompt" in kwargs, "prompt is required for text2image model"
18
 
@@ -47,7 +46,6 @@ class OpenaiModel():
47
  raise ValueError("model_type must be text2image or image2image")
48
 
49
 
50
-
51
  def load_openai_model(model_name, model_type):
52
  return OpenaiModel(model_name, model_type)
53
 
 
12
  self.model_type = model_type
13
 
14
  def __call__(self, *args, **kwargs):
 
15
  if self.model_type == "text2image":
16
  assert "prompt" in kwargs, "prompt is required for text2image model"
17
 
 
46
  raise ValueError("model_type must be text2image or image2image")
47
 
48
 
 
49
  def load_openai_model(model_name, model_type):
50
  return OpenaiModel(model_name, model_type)
51
 
model/models/other_api_models.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  from PIL import Image
5
  import io, time
6
 
 
7
  class OtherModel():
8
  def __init__(self, model_name, model_type):
9
  self.model_name = model_name
@@ -75,6 +76,8 @@ class OtherModel():
75
 
76
  else:
77
  raise ValueError("model_type must be text2image")
 
 
78
  def load_other_model(model_name, model_type):
79
  return OtherModel(model_name, model_type)
80
 
@@ -86,30 +89,3 @@ if __name__ == "__main__":
86
  result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow ")
87
  print(result)
88
  exit()
89
-
90
-
91
- # key = os.environ.get('MIDJOURNEY_KEY')
92
- # prompt = "a good girl"
93
-
94
- # conn = http.client.HTTPSConnection("xdai.online")
95
- # payload = json.dumps({
96
- # "messages": [
97
- # {
98
- # "role": "user",
99
- # "content": "{}".format(prompt)
100
- # }
101
- # ],
102
- # "stream": True,
103
- # "model": "luma-video",
104
- # # "model": "pika-text-to-video",
105
- # })
106
- # headers = {
107
- # 'Authorization': "Bearer {}".format(key),
108
- # 'Content-Type': 'application/json'
109
- # }
110
- # conn.request("POST", "/v1/chat/completions", payload, headers)
111
- # res = conn.getresponse()
112
- # data = res.read()
113
- # info = data.decode("utf-8")
114
- # print(data.decode("utf-8"))
115
-
 
4
  from PIL import Image
5
  import io, time
6
 
7
+
8
  class OtherModel():
9
  def __init__(self, model_name, model_type):
10
  self.model_name = model_name
 
76
 
77
  else:
78
  raise ValueError("model_type must be text2image")
79
+
80
+
81
  def load_other_model(model_name, model_type):
82
  return OtherModel(model_name, model_type)
83
 
 
89
  result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow ")
90
  print(result)
91
  exit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/models/replicate_api_models.py CHANGED
@@ -40,11 +40,11 @@ Replicate_MODEl_NAME_MAP = {
40
  "FLUX.1-dev": "black-forest-labs/flux-dev",
41
  }
42
 
 
43
  class ReplicateModel():
44
  def __init__(self, model_name, model_type):
45
  self.model_name = model_name
46
  self.model_type = model_type
47
- # os.environ['FAL_KEY'] = os.environ['FalAPI']
48
 
49
  def __call__(self, *args, **kwargs):
50
  if self.model_type == "text2image":
@@ -179,155 +179,14 @@ class ReplicateModel():
179
  else:
180
  raise ValueError("model_type must be text2image or image2image")
181
 
 
182
  def load_replicate_model(model_name, model_type):
183
  return ReplicateModel(model_name, model_type)
184
 
185
 
186
  if __name__ == "__main__":
187
- import replicate
188
- import time
189
- import concurrent.futures
190
- import os, shutil, re
191
- import requests
192
- from moviepy.editor import VideoFileClip
193
-
194
- # model_name = 'replicate_zeroscope-v2-xl_text2video'
195
- # model_name = 'replicate_Damo-Text-to-Video_text2video'
196
- # model_name = 'replicate_Animate-Diff_text2video'
197
- # model_name = 'replicate_open-sora_text2video'
198
- # model_name = 'replicate_lavie_text2video'
199
- # model_name = 'replicate_video-crafter_text2video'
200
- # model_name = 'replicate_stable-video-diffusion_text2video'
201
- # model_source, model_name, model_type = model_name.split("_")
202
- # pipe = load_replicate_model(model_name, model_type)
203
- # prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
204
- # result = pipe(prompt=prompt)
205
-
206
- # # 文件复制
207
- source_folder = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
208
- destination_folder = '/mnt/data/lizhikai/ksort_video_cache/Advance/'
209
-
210
- special_char = 'output'
211
- for dirpath, dirnames, filenames in os.walk(source_folder):
212
- for dirname in dirnames:
213
- des_dirname = "output-"+dirname[-3:]
214
- print(des_dirname)
215
- if special_char in dirname:
216
- model_name = ["Pika-v1.0"]
217
- for name in model_name:
218
- source_file_path = os.path.join(source_folder, os.path.join(dirname, name+".mp4"))
219
- print(source_file_path)
220
- destination_file_path = os.path.join(destination_folder, os.path.join(des_dirname, name+".mp4"))
221
- print(destination_file_path)
222
- shutil.copy(source_file_path, destination_file_path)
223
-
224
-
225
- # 视频裁剪
226
- # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
227
- # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
228
- # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
229
- # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
230
- # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
231
- # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
232
- # special_char = 'output'
233
- # num = 0
234
- # for dirpath, dirnames, filenames in os.walk(root_dir):
235
- # for dirname in dirnames:
236
- # # 如果文件夹名称中包含指定的特殊字符
237
- # if special_char in dirname:
238
- # num = num+1
239
- # print(num)
240
- # if num < 0:
241
- # continue
242
- # video_path = os.path.join(root_dir, (os.path.join(dirname, f"{dirname}.mp4")))
243
- # out_video_path = os.path.join(root_dir, (os.path.join(dirname, f"Pika-v1.0.mp4")))
244
- # print(video_path)
245
- # print(out_video_path)
246
-
247
- # video = VideoFileClip(video_path)
248
- # width, height = video.size
249
- # center_x, center_y = width // 2, height // 2
250
- # new_width, new_height = 512, 512
251
- # cropped_video = video.crop(x_center=center_x, y_center=center_y, width=min(width, height), height=min(width, height))
252
- # resized_video = cropped_video.resize(newsize=(new_width, new_height))
253
- # resized_video.write_videofile(out_video_path, codec='libx264', fps=video.fps)
254
- # os.remove(video_path)
255
-
256
-
257
-
258
- # file_path = '/home/lizhikai/webvid_prompt100.txt'
259
- # str_list = []
260
- # with open(file_path, 'r', encoding='utf-8') as file:
261
- # for line in file:
262
- # str_list.append(line.strip())
263
- # if len(str_list) == 100:
264
- # break
265
-
266
- # 生成代码
267
- # def generate_image_ig_api(prompt, model_name):
268
- # model_source, model_name, model_type = model_name.split("_")
269
- # pipe = load_replicate_model(model_name, model_type)
270
- # result = pipe(prompt=prompt)
271
- # return result
272
- # model_names = ['replicate_Zeroscope-v2-xl_text2video',
273
- # # 'replicate_Damo-Text-to-Video_text2video',
274
- # 'replicate_Animate-Diff_text2video',
275
- # 'replicate_OpenSora_text2video',
276
- # 'replicate_LaVie_text2video',
277
- # 'replicate_VideoCrafter2_text2video',
278
- # 'replicate_Stable-Video-Diffusion_text2video',
279
- # ]
280
- # save_names = []
281
- # for name in model_names:
282
- # model_source, model_name, model_type = name.split("_")
283
- # save_names.append(model_name)
284
-
285
- # # 遍历根目录及其子目录
286
- # # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
287
- # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
288
- # # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
289
- # # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
290
- # # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
291
- # special_char = 'output'
292
- # num = 0
293
- # for dirpath, dirnames, filenames in os.walk(root_dir):
294
- # for dirname in dirnames:
295
- # # 如果文件夹名称中包含指定的特殊字符
296
- # if special_char in dirname:
297
- # num = num+1
298
- # print(num)
299
- # if num < 0:
300
- # continue
301
- # str_list = []
302
- # prompt_path = os.path.join(root_dir, (os.path.join(dirname, "prompt.txt")))
303
- # print(prompt_path)
304
- # with open(prompt_path, 'r', encoding='utf-8') as file:
305
- # for line in file:
306
- # str_list.append(line.strip())
307
- # prompt = str_list[0]
308
- # print(prompt)
309
-
310
- # with concurrent.futures.ThreadPoolExecutor() as executor:
311
- # futures = [executor.submit(generate_image_ig_api, prompt, model) for model in model_names]
312
- # results = [future.result() for future in futures]
313
-
314
- # # 下载视频并保存
315
- # repeat_num = 5
316
- # for j, url in enumerate(results):
317
- # while 1:
318
- # time.sleep(1)
319
- # response = requests.get(url, stream=True)
320
- # if response.status_code == 200:
321
- # file_path = os.path.join(os.path.join(root_dir, dirname), f'{save_names[j]}.mp4')
322
- # with open(file_path, 'wb') as file:
323
- # for chunk in response.iter_content(chunk_size=8192):
324
- # file.write(chunk)
325
- # print(f"视频 {j} 已保存到 {file_path}")
326
- # break
327
- # else:
328
- # repeat_num = repeat_num - 1
329
- # if repeat_num == 0:
330
- # print(f"视频 {j} 保存失败")
331
- # # raise ValueError("Video request failed.")
332
- # continue
333
-
 
40
  "FLUX.1-dev": "black-forest-labs/flux-dev",
41
  }
42
 
43
+
44
  class ReplicateModel():
45
  def __init__(self, model_name, model_type):
46
  self.model_name = model_name
47
  self.model_type = model_type
 
48
 
49
  def __call__(self, *args, **kwargs):
50
  if self.model_type == "text2image":
 
179
  else:
180
  raise ValueError("model_type must be text2image or image2image")
181
 
182
+
183
  def load_replicate_model(model_name, model_type):
184
  return ReplicateModel(model_name, model_type)
185
 
186
 
187
  if __name__ == "__main__":
188
+ model_name = 'replicate_zeroscope-v2-xl_text2video'
189
+ model_source, model_name, model_type = model_name.split("_")
190
+ pipe = load_replicate_model(model_name, model_type)
191
+ prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
192
+ result = pipe(prompt=prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
serve/Ksort.py CHANGED
@@ -8,6 +8,7 @@ from .utils import disable_btn, enable_btn, invisible_btn
8
  from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
9
  import json
10
 
 
11
  def reset_level(Top_btn):
12
  if Top_btn == "Top 1":
13
  level = 0
@@ -19,6 +20,7 @@ def reset_level(Top_btn):
19
  level = 3
20
  return level
21
 
 
22
  def reset_rank(windows, rank, vote_level):
23
  if windows == "Model A":
24
  rank[0] = vote_level
@@ -30,6 +32,7 @@ def reset_rank(windows, rank, vote_level):
30
  rank[3] = vote_level
31
  return rank
32
 
 
33
  def reset_btn_rank(windows, rank, btn, vote_level):
34
  if windows == "Model A" and btn == "1":
35
  rank[0] = 0
@@ -73,6 +76,7 @@ def reset_btn_rank(windows, rank, btn, vote_level):
73
  vote_level = 3
74
  return (rank, vote_level)
75
 
 
76
  def reset_vote_text(rank):
77
  rank_str = ""
78
  for i in range(len(rank)):
@@ -83,24 +87,28 @@ def reset_vote_text(rank):
83
  rank_str = rank_str + " "
84
  return rank_str
85
 
 
86
  def clear_rank(rank, vote_level):
87
  for i in range(len(rank)):
88
  rank[i] = None
89
  vote_level = 0
90
  return rank, vote_level
91
 
 
92
  def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
93
  for i in range(len(rank)):
94
  rank[i] = None
95
  vote_level = 0
96
  return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level
97
 
 
98
  def reset_submit(rank):
99
  for i in range(len(rank)):
100
  if rank[i] == None:
101
  return disable_btn
102
  return enable_btn
103
 
 
104
  def reset_mode(mode):
105
 
106
  if mode == "Best":
@@ -116,8 +124,12 @@ def reset_mode(mode):
116
  (gr.Textbox(value="Best", visible=False, interactive=False),)
117
  else:
118
  raise ValueError("Undefined mode")
 
 
119
  def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
120
  return generate_ig0, generate_ig1, generate_ig2, generate_ig3
 
 
121
  def get_json_filename(conv_id):
122
  output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
123
  if not os.path.exists(output_dir):
@@ -127,6 +139,7 @@ def get_json_filename(conv_id):
127
  print(output_file)
128
  return output_file
129
 
 
130
  def get_img_filename(conv_id, i):
131
  output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
132
  if not os.path.exists(output_dir):
@@ -135,6 +148,7 @@ def get_img_filename(conv_id, i):
135
  print(output_file)
136
  return output_file
137
 
 
138
  def vote_submit(states, textbox, rank, request: gr.Request):
139
  conv_id = states[0].conv_id
140
 
@@ -149,6 +163,7 @@ def vote_submit(states, textbox, rank, request: gr.Request):
149
  }
150
  fout.write(json.dumps(data) + "\n")
151
 
 
152
  def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
153
  conv_id = states[0].conv_id
154
  output_dir = create_remote_directory(conv_id)
@@ -167,6 +182,7 @@ def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
167
  from .update_skill import update_skill
168
  update_skill(rank, [x.model_name for x in states])
169
 
 
170
  def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
171
  conv_id = states[0].conv_id
172
  output_dir = create_remote_directory(conv_id, video=True)
@@ -186,6 +202,7 @@ def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_in
186
  from .update_skill_video import update_skill_video
187
  update_skill_video(rank, [x.model_name for x in states])
188
 
 
189
  def submit_response_igm(
190
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
191
  ):
@@ -205,6 +222,8 @@ def submit_response_igm(
205
  gr.Markdown(state2.model_name, visible=True),
206
  gr.Markdown(state3.model_name, visible=True)
207
  ) + (disable_btn,)
 
 
208
  def submit_response_vg(
209
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
210
  ):
@@ -223,6 +242,8 @@ def submit_response_vg(
223
  gr.Markdown(state2.model_name, visible=True),
224
  gr.Markdown(state3.model_name, visible=True)
225
  ) + (disable_btn,)
 
 
226
  def submit_response_rank_igm(
227
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
228
  ):
@@ -246,6 +267,8 @@ def submit_response_rank_igm(
246
  )
247
  else:
248
  return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
 
 
249
  def submit_response_rank_vg(
250
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
251
  ):
@@ -269,6 +292,7 @@ def submit_response_rank_vg(
269
  else:
270
  return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
271
 
 
272
  def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
273
  rank_list = [char for char in vote_textbox if char.isdigit()]
274
  generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
@@ -318,6 +342,7 @@ def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig
318
 
319
  return chatbot + [rank_str] + ["right"] + [rank]
320
 
 
321
  def text_response_rank_vg(vote_textbox):
322
  rank_list = [char for char in vote_textbox if char.isdigit()]
323
  rank = [None, None, None, None]
@@ -336,6 +361,7 @@ def text_response_rank_vg(vote_textbox):
336
 
337
  return [rank_str] + ["right"] + [rank]
338
 
 
339
  def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
340
  base_image = Image.fromarray(image).convert("RGBA")
341
  base_image = base_image.resize((512, 512), Image.ANTIALIAS)
@@ -369,12 +395,15 @@ def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text
369
 
370
  base_image = base_image.convert("RGB")
371
  return base_image
 
 
372
  def add_green_border(image):
373
  border_color = (0, 255, 0) # RGB for green
374
  border_size = 10 # Size of the border
375
  img_with_border = ImageOps.expand(image, border=border_size, fill=border_color)
376
  return img_with_border
377
 
 
378
  def check_textbox(textbox):
379
  if textbox=="":
380
  return False
 
8
  from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
9
  import json
10
 
11
+
12
  def reset_level(Top_btn):
13
  if Top_btn == "Top 1":
14
  level = 0
 
20
  level = 3
21
  return level
22
 
23
+
24
  def reset_rank(windows, rank, vote_level):
25
  if windows == "Model A":
26
  rank[0] = vote_level
 
32
  rank[3] = vote_level
33
  return rank
34
 
35
+
36
  def reset_btn_rank(windows, rank, btn, vote_level):
37
  if windows == "Model A" and btn == "1":
38
  rank[0] = 0
 
76
  vote_level = 3
77
  return (rank, vote_level)
78
 
79
+
80
  def reset_vote_text(rank):
81
  rank_str = ""
82
  for i in range(len(rank)):
 
87
  rank_str = rank_str + " "
88
  return rank_str
89
 
90
+
91
  def clear_rank(rank, vote_level):
92
  for i in range(len(rank)):
93
  rank[i] = None
94
  vote_level = 0
95
  return rank, vote_level
96
 
97
+
98
  def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
99
  for i in range(len(rank)):
100
  rank[i] = None
101
  vote_level = 0
102
  return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level
103
 
104
+
105
  def reset_submit(rank):
106
  for i in range(len(rank)):
107
  if rank[i] == None:
108
  return disable_btn
109
  return enable_btn
110
 
111
+
112
  def reset_mode(mode):
113
 
114
  if mode == "Best":
 
124
  (gr.Textbox(value="Best", visible=False, interactive=False),)
125
  else:
126
  raise ValueError("Undefined mode")
127
+
128
+
129
  def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
130
  return generate_ig0, generate_ig1, generate_ig2, generate_ig3
131
+
132
+
133
  def get_json_filename(conv_id):
134
  output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
135
  if not os.path.exists(output_dir):
 
139
  print(output_file)
140
  return output_file
141
 
142
+
143
  def get_img_filename(conv_id, i):
144
  output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
145
  if not os.path.exists(output_dir):
 
148
  print(output_file)
149
  return output_file
150
 
151
+
152
  def vote_submit(states, textbox, rank, request: gr.Request):
153
  conv_id = states[0].conv_id
154
 
 
163
  }
164
  fout.write(json.dumps(data) + "\n")
165
 
166
+
167
  def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
168
  conv_id = states[0].conv_id
169
  output_dir = create_remote_directory(conv_id)
 
182
  from .update_skill import update_skill
183
  update_skill(rank, [x.model_name for x in states])
184
 
185
+
186
  def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
187
  conv_id = states[0].conv_id
188
  output_dir = create_remote_directory(conv_id, video=True)
 
202
  from .update_skill_video import update_skill_video
203
  update_skill_video(rank, [x.model_name for x in states])
204
 
205
+
206
  def submit_response_igm(
207
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
208
  ):
 
222
  gr.Markdown(state2.model_name, visible=True),
223
  gr.Markdown(state3.model_name, visible=True)
224
  ) + (disable_btn,)
225
+
226
+
227
  def submit_response_vg(
228
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
229
  ):
 
242
  gr.Markdown(state2.model_name, visible=True),
243
  gr.Markdown(state3.model_name, visible=True)
244
  ) + (disable_btn,)
245
+
246
+
247
  def submit_response_rank_igm(
248
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
249
  ):
 
267
  )
268
  else:
269
  return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
270
+
271
+
272
  def submit_response_rank_vg(
273
  state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
274
  ):
 
292
  else:
293
  return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
294
 
295
+
296
  def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
297
  rank_list = [char for char in vote_textbox if char.isdigit()]
298
  generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
 
342
 
343
  return chatbot + [rank_str] + ["right"] + [rank]
344
 
345
+
346
  def text_response_rank_vg(vote_textbox):
347
  rank_list = [char for char in vote_textbox if char.isdigit()]
348
  rank = [None, None, None, None]
 
361
 
362
  return [rank_str] + ["right"] + [rank]
363
 
364
+
365
  def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
366
  base_image = Image.fromarray(image).convert("RGBA")
367
  base_image = base_image.resize((512, 512), Image.ANTIALIAS)
 
395
 
396
  base_image = base_image.convert("RGB")
397
  return base_image
398
+
399
+
400
  def add_green_border(image):
401
  border_color = (0, 255, 0) # RGB for green
402
  border_size = 10 # Size of the border
403
  img_with_border = ImageOps.expand(image, border=border_size, fill=border_color)
404
  return img_with_border
405
 
406
+
407
  def check_textbox(textbox):
408
  if textbox=="":
409
  return False
serve/leaderboard.py CHANGED
@@ -40,12 +40,14 @@ def make_leaderboard_md():
40
  """
41
  return leaderboard_md
42
 
 
43
  def make_leaderboard_video_md():
44
  leaderboard_md = f"""
45
  # 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
46
  """
47
  return leaderboard_md
48
 
 
49
  def model_hyperlink(model_name, link):
50
  return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
51
 
@@ -89,11 +91,13 @@ def make_disclaimer_md():
89
  '''
90
  return disclaimer_md
91
 
 
92
  def make_arena_leaderboard_data(results):
93
  import pandas as pd
94
  df = pd.DataFrame(results)
95
  return df
96
 
 
97
  def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
98
  with open(score_result_file, "r") as json_file:
99
  data = json.load(json_file)
 
40
  """
41
  return leaderboard_md
42
 
43
+
44
  def make_leaderboard_video_md():
45
  leaderboard_md = f"""
46
  # 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
47
  """
48
  return leaderboard_md
49
 
50
+
51
  def model_hyperlink(model_name, link):
52
  return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
53
 
 
91
  '''
92
  return disclaimer_md
93
 
94
+
95
  def make_arena_leaderboard_data(results):
96
  import pandas as pd
97
  df = pd.DataFrame(results)
98
  return df
99
 
100
+
101
  def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
102
  with open(score_result_file, "r") as json_file:
103
  data = json.load(json_file)
serve/update_skill.py CHANGED
@@ -9,9 +9,11 @@ trueskill_env = TrueSkill()
9
  sys.path.append('../')
10
  from model.models import IMAGE_GENERATION_MODELS
11
 
 
12
  ssh_skill_client = None
13
  sftp_skill_client = None
14
 
 
15
  def create_ssh_skill_client(server, port, user, password):
16
  global ssh_skill_client, sftp_skill_client
17
  ssh_skill_client = paramiko.SSHClient()
@@ -23,32 +25,37 @@ def create_ssh_skill_client(server, port, user, password):
23
  transport.set_keepalive(60)
24
 
25
  sftp_skill_client = ssh_skill_client.open_sftp()
 
 
26
  def is_connected():
27
  global ssh_skill_client, sftp_skill_client
28
  if ssh_skill_client is None or sftp_skill_client is None:
29
  return False
30
- # 检查SSH连接是否正常
31
  if not ssh_skill_client.get_transport().is_active():
32
  return False
33
- # 检查SFTP连接是否正常
34
  try:
35
- sftp_skill_client.listdir('.') # 尝试列出根目录
36
  except Exception as e:
37
  print(f"Error checking SFTP connection: {e}")
38
  return False
39
  return True
 
 
40
  def ucb_score(trueskill_diff, t, n):
41
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
42
  ucb = -trueskill_diff + 1.0 * exploration_term
43
  return ucb
44
 
 
45
  def update_trueskill(ratings, ranks):
46
  new_ratings = trueskill_env.rate(ratings, ranks)
47
  return new_ratings
48
 
 
49
  def serialize_rating(rating):
50
  return {'mu': rating.mu, 'sigma': rating.sigma}
51
 
 
52
  def deserialize_rating(rating_dict):
53
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
54
 
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
66
  with sftp_skill_client.open(SSH_SKILL, 'w') as f:
67
  f.write(json_data)
68
 
 
69
  def load_json_via_sftp():
70
  global sftp_skill_client
71
  if not is_connected():
 
9
  sys.path.append('../')
10
  from model.models import IMAGE_GENERATION_MODELS
11
 
12
+
13
  ssh_skill_client = None
14
  sftp_skill_client = None
15
 
16
+
17
  def create_ssh_skill_client(server, port, user, password):
18
  global ssh_skill_client, sftp_skill_client
19
  ssh_skill_client = paramiko.SSHClient()
 
25
  transport.set_keepalive(60)
26
 
27
  sftp_skill_client = ssh_skill_client.open_sftp()
28
+
29
+
30
  def is_connected():
31
  global ssh_skill_client, sftp_skill_client
32
  if ssh_skill_client is None or sftp_skill_client is None:
33
  return False
 
34
  if not ssh_skill_client.get_transport().is_active():
35
  return False
 
36
  try:
37
+ sftp_skill_client.listdir('.')
38
  except Exception as e:
39
  print(f"Error checking SFTP connection: {e}")
40
  return False
41
  return True
42
+
43
+
44
  def ucb_score(trueskill_diff, t, n):
45
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
46
  ucb = -trueskill_diff + 1.0 * exploration_term
47
  return ucb
48
 
49
+
50
  def update_trueskill(ratings, ranks):
51
  new_ratings = trueskill_env.rate(ratings, ranks)
52
  return new_ratings
53
 
54
+
55
  def serialize_rating(rating):
56
  return {'mu': rating.mu, 'sigma': rating.sigma}
57
 
58
+
59
  def deserialize_rating(rating_dict):
60
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
61
 
 
73
  with sftp_skill_client.open(SSH_SKILL, 'w') as f:
74
  f.write(json_data)
75
 
76
+
77
  def load_json_via_sftp():
78
  global sftp_skill_client
79
  if not is_connected():
serve/update_skill_video.py CHANGED
@@ -9,9 +9,11 @@ trueskill_env = TrueSkill()
9
  sys.path.append('../')
10
  from model.models import VIDEO_GENERATION_MODELS
11
 
 
12
  ssh_skill_client = None
13
  sftp_skill_client = None
14
 
 
15
  def create_ssh_skill_client(server, port, user, password):
16
  global ssh_skill_client, sftp_skill_client
17
  ssh_skill_client = paramiko.SSHClient()
@@ -23,32 +25,37 @@ def create_ssh_skill_client(server, port, user, password):
23
  transport.set_keepalive(60)
24
 
25
  sftp_skill_client = ssh_skill_client.open_sftp()
 
 
26
  def is_connected():
27
  global ssh_skill_client, sftp_skill_client
28
  if ssh_skill_client is None or sftp_skill_client is None:
29
  return False
30
- # 检查SSH连接是否正常
31
  if not ssh_skill_client.get_transport().is_active():
32
  return False
33
- # 检查SFTP连接是否正常
34
  try:
35
- sftp_skill_client.listdir('.') # 尝试列出根目录
36
  except Exception as e:
37
  print(f"Error checking SFTP connection: {e}")
38
  return False
39
  return True
 
 
40
  def ucb_score(trueskill_diff, t, n):
41
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
42
  ucb = -trueskill_diff + 1.0 * exploration_term
43
  return ucb
44
 
 
45
  def update_trueskill(ratings, ranks):
46
  new_ratings = trueskill_env.rate(ratings, ranks)
47
  return new_ratings
48
 
 
49
  def serialize_rating(rating):
50
  return {'mu': rating.mu, 'sigma': rating.sigma}
51
 
 
52
  def deserialize_rating(rating_dict):
53
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
54
 
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
66
  with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as f:
67
  f.write(json_data)
68
 
 
69
  def load_json_via_sftp():
70
  global sftp_skill_client
71
  if not is_connected():
 
9
  sys.path.append('../')
10
  from model.models import VIDEO_GENERATION_MODELS
11
 
12
+
13
  ssh_skill_client = None
14
  sftp_skill_client = None
15
 
16
+
17
  def create_ssh_skill_client(server, port, user, password):
18
  global ssh_skill_client, sftp_skill_client
19
  ssh_skill_client = paramiko.SSHClient()
 
25
  transport.set_keepalive(60)
26
 
27
  sftp_skill_client = ssh_skill_client.open_sftp()
28
+
29
+
30
  def is_connected():
31
  global ssh_skill_client, sftp_skill_client
32
  if ssh_skill_client is None or sftp_skill_client is None:
33
  return False
 
34
  if not ssh_skill_client.get_transport().is_active():
35
  return False
 
36
  try:
37
+ sftp_skill_client.listdir('.')
38
  except Exception as e:
39
  print(f"Error checking SFTP connection: {e}")
40
  return False
41
  return True
42
+
43
+
44
  def ucb_score(trueskill_diff, t, n):
45
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
46
  ucb = -trueskill_diff + 1.0 * exploration_term
47
  return ucb
48
 
49
+
50
  def update_trueskill(ratings, ranks):
51
  new_ratings = trueskill_env.rate(ratings, ranks)
52
  return new_ratings
53
 
54
+
55
  def serialize_rating(rating):
56
  return {'mu': rating.mu, 'sigma': rating.sigma}
57
 
58
+
59
  def deserialize_rating(rating_dict):
60
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
61
 
 
73
  with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as f:
74
  f.write(json_data)
75
 
76
+
77
  def load_json_via_sftp():
78
  global sftp_skill_client
79
  if not is_connected():
serve/upload.py CHANGED
@@ -9,15 +9,18 @@ import random
9
  import concurrent.futures
10
  from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO
11
 
 
12
  ssh_client = None
13
  sftp_client = None
14
  sftp_client_imgs = None
15
 
 
16
  def open_sftp(i=0):
17
  global ssh_client
18
  sftp_client = ssh_client.open_sftp()
19
  return sftp_client
20
 
 
21
  def create_ssh_client(server, port, user, password):
22
  global ssh_client, sftp_client, sftp_client_imgs
23
  ssh_client = paramiko.SSHClient()
@@ -40,22 +43,22 @@ def is_connected():
40
  global ssh_client, sftp_client
41
  if ssh_client is None or sftp_client is None:
42
  return False
43
- # 检查SSH连接是否正常
44
  if not ssh_client.get_transport().is_active():
45
  return False
46
- # 检查SFTP连接是否正常
47
  try:
48
- sftp_client.listdir('.') # 尝试列出根目录
49
  except Exception as e:
50
  print(f"Error checking SFTP connection: {e}")
51
  return False
52
  return True
53
 
 
54
  def get_image_from_url(image_url):
55
  response = requests.get(image_url)
56
  response.raise_for_status() # success
57
  return Image.open(io.BytesIO(response.content))
58
 
 
59
  # def get_random_mscoco_prompt():
60
  # global sftp_client
61
  # if not is_connected():
@@ -70,6 +73,7 @@ def get_image_from_url(image_url):
70
  # print("\n")
71
  # return content
72
 
 
73
  def get_random_mscoco_prompt():
74
 
75
  file_path = './coco_prompt.txt'
@@ -79,6 +83,7 @@ def get_random_mscoco_prompt():
79
  random_line = random.choice(lines).strip()
80
  return random_line
81
 
 
82
  def get_random_video_prompt(root_dir):
83
  subdirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
84
  if not subdirs:
@@ -96,6 +101,7 @@ def get_random_video_prompt(root_dir):
96
  raise NotImplementedError
97
  return selected_dir, prompt
98
 
 
99
  def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
100
  def is_directory(sftp, path):
101
  try:
@@ -150,6 +156,7 @@ def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
150
  ssh.close()
151
  return prompt, local_path[1:]
152
 
 
153
  def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
154
  def is_directory(sftp, path):
155
  try:
@@ -204,6 +211,7 @@ def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
204
  ssh.close()
205
  return prompt, [Image.open(path) for path in local_path[1:]]
206
 
 
207
  def create_remote_directory(remote_directory, video=False):
208
  global ssh_client
209
  if not is_connected():
@@ -220,6 +228,7 @@ def create_remote_directory(remote_directory, video=False):
220
  print(f"Directory {remote_directory} created successfully.")
221
  return log_dir
222
 
 
223
  def upload_images(i, image_list, output_file_list, sftp_client):
224
  with sftp_client as sftp:
225
  if isinstance(image_list[i], str):
@@ -233,7 +242,6 @@ def upload_images(i, image_list, output_file_list, sftp_client):
233
  print(f"Successfully uploaded image to {output_file_list[i]}")
234
 
235
 
236
-
237
  def upload_ssh_all(states, output_dir, data, data_path):
238
  global sftp_client
239
  global sftp_client_imgs
@@ -246,7 +254,6 @@ def upload_ssh_all(states, output_dir, data, data_path):
246
  output_file_list.append(output_file)
247
  image_list.append(states[i].output)
248
 
249
-
250
  with concurrent.futures.ThreadPoolExecutor() as executor:
251
  futures = [executor.submit(upload_images, i, image_list, output_file_list, sftp_client_imgs[i]) for i in range(len(output_file_list))]
252
 
@@ -257,6 +264,7 @@ def upload_ssh_all(states, output_dir, data, data_path):
257
  print(f"Successfully uploaded JSON data to {data_path}")
258
  # create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
259
 
 
260
  def upload_ssh_data(data, data_path):
261
  global sftp_client
262
  global sftp_client_imgs
 
9
  import concurrent.futures
10
  from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO
11
 
12
+
13
  ssh_client = None
14
  sftp_client = None
15
  sftp_client_imgs = None
16
 
17
+
18
  def open_sftp(i=0):
19
  global ssh_client
20
  sftp_client = ssh_client.open_sftp()
21
  return sftp_client
22
 
23
+
24
  def create_ssh_client(server, port, user, password):
25
  global ssh_client, sftp_client, sftp_client_imgs
26
  ssh_client = paramiko.SSHClient()
 
43
  global ssh_client, sftp_client
44
  if ssh_client is None or sftp_client is None:
45
  return False
 
46
  if not ssh_client.get_transport().is_active():
47
  return False
 
48
  try:
49
+ sftp_client.listdir('.')
50
  except Exception as e:
51
  print(f"Error checking SFTP connection: {e}")
52
  return False
53
  return True
54
 
55
+
56
  def get_image_from_url(image_url):
57
  response = requests.get(image_url)
58
  response.raise_for_status() # success
59
  return Image.open(io.BytesIO(response.content))
60
 
61
+
62
  # def get_random_mscoco_prompt():
63
  # global sftp_client
64
  # if not is_connected():
 
73
  # print("\n")
74
  # return content
75
 
76
+
77
  def get_random_mscoco_prompt():
78
 
79
  file_path = './coco_prompt.txt'
 
83
  random_line = random.choice(lines).strip()
84
  return random_line
85
 
86
+
87
  def get_random_video_prompt(root_dir):
88
  subdirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
89
  if not subdirs:
 
101
  raise NotImplementedError
102
  return selected_dir, prompt
103
 
104
+
105
  def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
106
  def is_directory(sftp, path):
107
  try:
 
156
  ssh.close()
157
  return prompt, local_path[1:]
158
 
159
+
160
  def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
161
  def is_directory(sftp, path):
162
  try:
 
211
  ssh.close()
212
  return prompt, [Image.open(path) for path in local_path[1:]]
213
 
214
+
215
  def create_remote_directory(remote_directory, video=False):
216
  global ssh_client
217
  if not is_connected():
 
228
  print(f"Directory {remote_directory} created successfully.")
229
  return log_dir
230
 
231
+
232
  def upload_images(i, image_list, output_file_list, sftp_client):
233
  with sftp_client as sftp:
234
  if isinstance(image_list[i], str):
 
242
  print(f"Successfully uploaded image to {output_file_list[i]}")
243
 
244
 
 
245
  def upload_ssh_all(states, output_dir, data, data_path):
246
  global sftp_client
247
  global sftp_client_imgs
 
254
  output_file_list.append(output_file)
255
  image_list.append(states[i].output)
256
 
 
257
  with concurrent.futures.ThreadPoolExecutor() as executor:
258
  futures = [executor.submit(upload_images, i, image_list, output_file_list, sftp_client_imgs[i]) for i in range(len(output_file_list))]
259
 
 
264
  print(f"Successfully uploaded JSON data to {data_path}")
265
  # create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
266
 
267
+
268
  def upload_ssh_data(data, data_path):
269
  global sftp_client
270
  global sftp_client_imgs