Spaces:
Running
on
Zero
Running
on
Zero
YangZhoumill
commited on
Commit
•
a4b32da
1
Parent(s):
780769f
release code
Browse files- app.py +5 -8
- model/matchmaker.py +11 -5
- model/matchmaker_video.py +12 -5
- model/model_manager.py +0 -8
- model/model_registry.py +0 -185
- model/models/__init__.py +2 -18
- model/models/openai_api_models.py +0 -2
- model/models/other_api_models.py +3 -27
- model/models/replicate_api_models.py +7 -148
- serve/Ksort.py +29 -0
- serve/leaderboard.py +4 -0
- serve/update_skill.py +11 -3
- serve/update_skill_video.py +11 -3
- serve/upload.py +13 -5
app.py
CHANGED
@@ -7,12 +7,8 @@ from model.model_manager import ModelManager
|
|
7 |
from pathlib import Path
|
8 |
from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
|
9 |
|
|
|
10 |
def build_combine_demo(models, elo_results_file, leaderboard_table_file):
|
11 |
-
# gr.themes.Default(),
|
12 |
-
# gr.themes.Soft(),
|
13 |
-
# gr.Theme.from_hub('gary109/HaleyCH_Theme'),
|
14 |
-
# gr.Theme.from_hub('EveryPizza/Cartoony-Gradio-Theme')
|
15 |
-
# gr.themes.Default(primary_hue="red", secondary_hue="pink")
|
16 |
with gr.Blocks(
|
17 |
title="Play with Open Vision Models",
|
18 |
theme=gr.themes.Default(),
|
@@ -22,21 +18,21 @@ def build_combine_demo(models, elo_results_file, leaderboard_table_file):
|
|
22 |
with gr.Tab("Image Generation", id=0):
|
23 |
with gr.Tabs() as tabs_ig:
|
24 |
with gr.Tab("Generation Leaderboard", id=0):
|
25 |
-
# build_leaderboard_tab(elo_results_file['t2i_generation'], leaderboard_table_file['t2i_generation'])
|
26 |
build_leaderboard_tab()
|
27 |
-
|
28 |
with gr.Tab("Generation Arena (battle)", id=1):
|
29 |
build_side_by_side_ui_anony(models)
|
|
|
30 |
with gr.Tab("Video Generation", id=1):
|
31 |
with gr.Tabs() as tabs_ig:
|
32 |
with gr.Tab("Generation Leaderboard", id=0):
|
33 |
-
# build_leaderboard_tab(elo_results_file['t2i_generation'], leaderboard_table_file['t2i_generation'])
|
34 |
build_leaderboard_video_tab()
|
35 |
|
36 |
with gr.Tab("Generation Arena (battle)", id=1):
|
37 |
build_side_by_side_video_ui_anony(models)
|
|
|
38 |
with gr.Tab("Contributor", id=2):
|
39 |
build_leaderboard_contributor()
|
|
|
40 |
return demo
|
41 |
|
42 |
|
@@ -44,6 +40,7 @@ def load_elo_results(elo_results_dir):
|
|
44 |
from collections import defaultdict
|
45 |
elo_results_file = defaultdict(lambda: None)
|
46 |
leaderboard_table_file = defaultdict(lambda: None)
|
|
|
47 |
if elo_results_dir is not None:
|
48 |
elo_results_dir = Path(elo_results_dir)
|
49 |
elo_results_file = {}
|
|
|
7 |
from pathlib import Path
|
8 |
from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
|
9 |
|
10 |
+
|
11 |
def build_combine_demo(models, elo_results_file, leaderboard_table_file):
|
|
|
|
|
|
|
|
|
|
|
12 |
with gr.Blocks(
|
13 |
title="Play with Open Vision Models",
|
14 |
theme=gr.themes.Default(),
|
|
|
18 |
with gr.Tab("Image Generation", id=0):
|
19 |
with gr.Tabs() as tabs_ig:
|
20 |
with gr.Tab("Generation Leaderboard", id=0):
|
|
|
21 |
build_leaderboard_tab()
|
|
|
22 |
with gr.Tab("Generation Arena (battle)", id=1):
|
23 |
build_side_by_side_ui_anony(models)
|
24 |
+
|
25 |
with gr.Tab("Video Generation", id=1):
|
26 |
with gr.Tabs() as tabs_ig:
|
27 |
with gr.Tab("Generation Leaderboard", id=0):
|
|
|
28 |
build_leaderboard_video_tab()
|
29 |
|
30 |
with gr.Tab("Generation Arena (battle)", id=1):
|
31 |
build_side_by_side_video_ui_anony(models)
|
32 |
+
|
33 |
with gr.Tab("Contributor", id=2):
|
34 |
build_leaderboard_contributor()
|
35 |
+
|
36 |
return demo
|
37 |
|
38 |
|
|
|
40 |
from collections import defaultdict
|
41 |
elo_results_file = defaultdict(lambda: None)
|
42 |
leaderboard_table_file = defaultdict(lambda: None)
|
43 |
+
|
44 |
if elo_results_dir is not None:
|
45 |
elo_results_dir = Path(elo_results_dir)
|
46 |
elo_results_file = {}
|
model/matchmaker.py
CHANGED
@@ -24,35 +24,41 @@ def create_ssh_matchmaker_client(server, port, user, password):
|
|
24 |
transport.set_keepalive(60)
|
25 |
|
26 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
|
|
|
|
27 |
def is_connected():
|
28 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
29 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
30 |
return False
|
31 |
-
# 检查SSH连接是否正常
|
32 |
if not ssh_matchmaker_client.get_transport().is_active():
|
33 |
return False
|
34 |
-
# 检查SFTP连接是否正常
|
35 |
try:
|
36 |
-
sftp_matchmaker_client.listdir('.')
|
37 |
except Exception as e:
|
38 |
print(f"Error checking SFTP connection: {e}")
|
39 |
return False
|
40 |
return True
|
|
|
|
|
41 |
def ucb_score(trueskill_diff, t, n):
|
42 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
43 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
44 |
return ucb
|
45 |
|
|
|
46 |
def update_trueskill(ratings, ranks):
|
47 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
48 |
return new_ratings
|
49 |
|
|
|
50 |
def serialize_rating(rating):
|
51 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
52 |
|
|
|
53 |
def deserialize_rating(rating_dict):
|
54 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
55 |
|
|
|
56 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
57 |
global sftp_matchmaker_client
|
58 |
if not is_connected():
|
@@ -66,6 +72,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
66 |
with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
|
67 |
f.write(json_data)
|
68 |
|
|
|
69 |
def load_json_via_sftp():
|
70 |
global sftp_matchmaker_client
|
71 |
if not is_connected():
|
@@ -107,7 +114,7 @@ def matchmaker(num_players, k_group=4, not_run=[]):
|
|
107 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
108 |
|
109 |
# Exclude self, select opponent with highest UCB score
|
110 |
-
ucb_scores[selected_player] = -float('inf')
|
111 |
ucb_scores[not_run] = -float('inf')
|
112 |
opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
|
113 |
|
@@ -117,4 +124,3 @@ def matchmaker(num_players, k_group=4, not_run=[]):
|
|
117 |
random.shuffle(model_ids)
|
118 |
|
119 |
return model_ids
|
120 |
-
|
|
|
24 |
transport.set_keepalive(60)
|
25 |
|
26 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
27 |
+
|
28 |
+
|
29 |
def is_connected():
|
30 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
31 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
32 |
return False
|
|
|
33 |
if not ssh_matchmaker_client.get_transport().is_active():
|
34 |
return False
|
|
|
35 |
try:
|
36 |
+
sftp_matchmaker_client.listdir('.')
|
37 |
except Exception as e:
|
38 |
print(f"Error checking SFTP connection: {e}")
|
39 |
return False
|
40 |
return True
|
41 |
+
|
42 |
+
|
43 |
def ucb_score(trueskill_diff, t, n):
|
44 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
45 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
46 |
return ucb
|
47 |
|
48 |
+
|
49 |
def update_trueskill(ratings, ranks):
|
50 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
51 |
return new_ratings
|
52 |
|
53 |
+
|
54 |
def serialize_rating(rating):
|
55 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
56 |
|
57 |
+
|
58 |
def deserialize_rating(rating_dict):
|
59 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
60 |
|
61 |
+
|
62 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
63 |
global sftp_matchmaker_client
|
64 |
if not is_connected():
|
|
|
72 |
with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
|
73 |
f.write(json_data)
|
74 |
|
75 |
+
|
76 |
def load_json_via_sftp():
|
77 |
global sftp_matchmaker_client
|
78 |
if not is_connected():
|
|
|
114 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
115 |
|
116 |
# Exclude self, select opponent with highest UCB score
|
117 |
+
ucb_scores[selected_player] = -float('inf')
|
118 |
ucb_scores[not_run] = -float('inf')
|
119 |
opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
|
120 |
|
|
|
124 |
random.shuffle(model_ids)
|
125 |
|
126 |
return model_ids
|
|
model/matchmaker_video.py
CHANGED
@@ -13,6 +13,7 @@ trueskill_env = TrueSkill()
|
|
13 |
ssh_matchmaker_client = None
|
14 |
sftp_matchmaker_client = None
|
15 |
|
|
|
16 |
def create_ssh_matchmaker_client(server, port, user, password):
|
17 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
18 |
ssh_matchmaker_client = paramiko.SSHClient()
|
@@ -24,35 +25,41 @@ def create_ssh_matchmaker_client(server, port, user, password):
|
|
24 |
transport.set_keepalive(60)
|
25 |
|
26 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
|
|
|
|
27 |
def is_connected():
|
28 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
29 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
30 |
return False
|
31 |
-
# 检查SSH连接是否正常
|
32 |
if not ssh_matchmaker_client.get_transport().is_active():
|
33 |
return False
|
34 |
-
# 检查SFTP连接是否正常
|
35 |
try:
|
36 |
-
sftp_matchmaker_client.listdir('.')
|
37 |
except Exception as e:
|
38 |
print(f"Error checking SFTP connection: {e}")
|
39 |
return False
|
40 |
return True
|
|
|
|
|
41 |
def ucb_score(trueskill_diff, t, n):
|
42 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
43 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
44 |
return ucb
|
45 |
|
|
|
46 |
def update_trueskill(ratings, ranks):
|
47 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
48 |
return new_ratings
|
49 |
|
|
|
50 |
def serialize_rating(rating):
|
51 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
52 |
|
|
|
53 |
def deserialize_rating(rating_dict):
|
54 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
55 |
|
|
|
56 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
57 |
global sftp_matchmaker_client
|
58 |
if not is_connected():
|
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
66 |
with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
67 |
f.write(json_data)
|
68 |
|
|
|
69 |
def load_json_via_sftp():
|
70 |
global sftp_matchmaker_client
|
71 |
if not is_connected():
|
@@ -95,7 +103,7 @@ def matchmaker_video(num_players, k_group=4):
|
|
95 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
96 |
|
97 |
# Exclude self, select opponent with highest UCB score
|
98 |
-
ucb_scores[selected_player] = -float('inf')
|
99 |
|
100 |
excluded_players_1 = [num_players-1, num_players-4]
|
101 |
excluded_players_2 = [num_players-2, num_players-3, num_players-5]
|
@@ -126,4 +134,3 @@ def matchmaker_video(num_players, k_group=4):
|
|
126 |
random.shuffle(model_ids)
|
127 |
|
128 |
return model_ids
|
129 |
-
|
|
|
13 |
ssh_matchmaker_client = None
|
14 |
sftp_matchmaker_client = None
|
15 |
|
16 |
+
|
17 |
def create_ssh_matchmaker_client(server, port, user, password):
|
18 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
19 |
ssh_matchmaker_client = paramiko.SSHClient()
|
|
|
25 |
transport.set_keepalive(60)
|
26 |
|
27 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
28 |
+
|
29 |
+
|
30 |
def is_connected():
|
31 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
32 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
33 |
return False
|
|
|
34 |
if not ssh_matchmaker_client.get_transport().is_active():
|
35 |
return False
|
|
|
36 |
try:
|
37 |
+
sftp_matchmaker_client.listdir('.')
|
38 |
except Exception as e:
|
39 |
print(f"Error checking SFTP connection: {e}")
|
40 |
return False
|
41 |
return True
|
42 |
+
|
43 |
+
|
44 |
def ucb_score(trueskill_diff, t, n):
|
45 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
46 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
47 |
return ucb
|
48 |
|
49 |
+
|
50 |
def update_trueskill(ratings, ranks):
|
51 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
52 |
return new_ratings
|
53 |
|
54 |
+
|
55 |
def serialize_rating(rating):
|
56 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
57 |
|
58 |
+
|
59 |
def deserialize_rating(rating_dict):
|
60 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
61 |
|
62 |
+
|
63 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
64 |
global sftp_matchmaker_client
|
65 |
if not is_connected():
|
|
|
73 |
with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
74 |
f.write(json_data)
|
75 |
|
76 |
+
|
77 |
def load_json_via_sftp():
|
78 |
global sftp_matchmaker_client
|
79 |
if not is_connected():
|
|
|
103 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
104 |
|
105 |
# Exclude self, select opponent with highest UCB score
|
106 |
+
ucb_scores[selected_player] = -float('inf')
|
107 |
|
108 |
excluded_players_1 = [num_players-1, num_players-4]
|
109 |
excluded_players_2 = [num_players-2, num_players-3, num_players-5]
|
|
|
134 |
random.shuffle(model_ids)
|
135 |
|
136 |
return model_ids
|
|
model/model_manager.py
CHANGED
@@ -58,10 +58,8 @@ class ModelManager:
|
|
58 |
def generate_image_ig_api(self, prompt, model_name):
|
59 |
pipe = self.load_model_pipe(model_name)
|
60 |
result = pipe(prompt=prompt)
|
61 |
-
|
62 |
return result
|
63 |
|
64 |
-
|
65 |
def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
|
66 |
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
67 |
from .matchmaker import matchmaker
|
@@ -73,13 +71,11 @@ class ModelManager:
|
|
73 |
else:
|
74 |
model_names = [model_A, model_B, model_C, model_D]
|
75 |
|
76 |
-
|
77 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
78 |
futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
|
79 |
else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
|
80 |
results = [future.result() for future in futures]
|
81 |
|
82 |
-
|
83 |
return results[0], results[1], results[2], results[3], \
|
84 |
model_names[0], model_names[1], model_names[2], model_names[3]
|
85 |
|
@@ -156,7 +152,6 @@ class ModelManager:
|
|
156 |
return results[0], results[1], results[2], results[3], \
|
157 |
model_names[0], model_names[1], model_names[2], model_names[3], prompt
|
158 |
|
159 |
-
|
160 |
def generate_image_ig_parallel(self, prompt, model_A, model_B):
|
161 |
model_names = [model_A, model_B]
|
162 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
@@ -165,14 +160,12 @@ class ModelManager:
|
|
165 |
results = [future.result() for future in futures]
|
166 |
return results[0], results[1]
|
167 |
|
168 |
-
|
169 |
@spaces.GPU(duration=200)
|
170 |
def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
|
171 |
pipe = self.load_model_pipe(model_name)
|
172 |
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
|
173 |
return result
|
174 |
|
175 |
-
|
176 |
def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
177 |
model_names = [model_A, model_B]
|
178 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
@@ -182,7 +175,6 @@ class ModelManager:
|
|
182 |
results = [future.result() for future in futures]
|
183 |
return results[0], results[1]
|
184 |
|
185 |
-
|
186 |
def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
187 |
if model_A == "" and model_B == "":
|
188 |
model_names = random.sample([model for model in self.model_ie_list], 2)
|
|
|
58 |
def generate_image_ig_api(self, prompt, model_name):
|
59 |
pipe = self.load_model_pipe(model_name)
|
60 |
result = pipe(prompt=prompt)
|
|
|
61 |
return result
|
62 |
|
|
|
63 |
def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
|
64 |
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
65 |
from .matchmaker import matchmaker
|
|
|
71 |
else:
|
72 |
model_names = [model_A, model_B, model_C, model_D]
|
73 |
|
|
|
74 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
75 |
futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
|
76 |
else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
|
77 |
results = [future.result() for future in futures]
|
78 |
|
|
|
79 |
return results[0], results[1], results[2], results[3], \
|
80 |
model_names[0], model_names[1], model_names[2], model_names[3]
|
81 |
|
|
|
152 |
return results[0], results[1], results[2], results[3], \
|
153 |
model_names[0], model_names[1], model_names[2], model_names[3], prompt
|
154 |
|
|
|
155 |
def generate_image_ig_parallel(self, prompt, model_A, model_B):
|
156 |
model_names = [model_A, model_B]
|
157 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
|
160 |
results = [future.result() for future in futures]
|
161 |
return results[0], results[1]
|
162 |
|
|
|
163 |
@spaces.GPU(duration=200)
|
164 |
def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
|
165 |
pipe = self.load_model_pipe(model_name)
|
166 |
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
|
167 |
return result
|
168 |
|
|
|
169 |
def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
170 |
model_names = [model_A, model_B]
|
171 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
|
175 |
results = [future.result() for future in futures]
|
176 |
return results[0], results[1]
|
177 |
|
|
|
178 |
def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
179 |
if model_A == "" and model_B == "":
|
180 |
model_names = random.sample([model for model in self.model_ie_list], 2)
|
model/model_registry.py
CHANGED
@@ -68,188 +68,3 @@ def get_video_model_description_md(model_list):
|
|
68 |
model_description_md += "\n"
|
69 |
ct += 1
|
70 |
return model_description_md
|
71 |
-
|
72 |
-
register_model_info(
|
73 |
-
["imagenhub_LCM_generation", "fal_LCM_text2image"],
|
74 |
-
"LCM",
|
75 |
-
"https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7",
|
76 |
-
"Latent Consistency Models.",
|
77 |
-
)
|
78 |
-
|
79 |
-
register_model_info(
|
80 |
-
["fal_LCM(v1.5/XL)_text2image"],
|
81 |
-
"LCM(v1.5/XL)",
|
82 |
-
"https://fal.ai/models/fast-lcm-diffusion-turbo",
|
83 |
-
"Latent Consistency Models (v1.5/XL)",
|
84 |
-
)
|
85 |
-
|
86 |
-
register_model_info(
|
87 |
-
["imagenhub_PlayGroundV2_generation", 'playground_PlayGroundV2_generation'],
|
88 |
-
"Playground v2",
|
89 |
-
"https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic",
|
90 |
-
"Playground v2 – 1024px Aesthetic Model",
|
91 |
-
)
|
92 |
-
|
93 |
-
register_model_info(
|
94 |
-
["imagenhub_PlayGroundV2.5_generation", 'playground_PlayGroundV2.5_generation'],
|
95 |
-
"Playground v2.5",
|
96 |
-
"https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic",
|
97 |
-
"Playground v2.5 is the state-of-the-art open-source model in aesthetic quality",
|
98 |
-
)
|
99 |
-
|
100 |
-
register_model_info(
|
101 |
-
["imagenhub_OpenJourney_generation"],
|
102 |
-
"Openjourney",
|
103 |
-
"https://huggingface.co/prompthero/openjourney",
|
104 |
-
"Openjourney is an open source Stable Diffusion fine tuned model on Midjourney images, by PromptHero.",
|
105 |
-
)
|
106 |
-
|
107 |
-
register_model_info(
|
108 |
-
["imagenhub_SDXLTurbo_generation", "fal_SDXLTurbo_text2image"],
|
109 |
-
"SDXLTurbo",
|
110 |
-
"https://huggingface.co/stabilityai/sdxl-turbo",
|
111 |
-
"SDXL-Turbo is a fast generative text-to-image model.",
|
112 |
-
)
|
113 |
-
|
114 |
-
register_model_info(
|
115 |
-
["imagenhub_SDXL_generation", "fal_SDXL_text2image"],
|
116 |
-
"SDXL",
|
117 |
-
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
|
118 |
-
"SDXL is a Latent Diffusion Model that uses two fixed, pretrained text encoders.",
|
119 |
-
)
|
120 |
-
|
121 |
-
register_model_info(
|
122 |
-
["imagenhub_PixArtAlpha_generation"],
|
123 |
-
"PixArtAlpha",
|
124 |
-
"https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS",
|
125 |
-
"Pixart-α consists of pure transformer blocks for latent diffusion.",
|
126 |
-
)
|
127 |
-
|
128 |
-
register_model_info(
|
129 |
-
["imagenhub_PixArtSigma_generation", "fal_PixArtSigma_text2image"],
|
130 |
-
"PixArtSigma",
|
131 |
-
"https://github.com/PixArt-alpha/PixArt-sigma",
|
132 |
-
"Improved version of Pixart-α.",
|
133 |
-
)
|
134 |
-
|
135 |
-
register_model_info(
|
136 |
-
["imagenhub_SDXLLightning_generation", "fal_SDXLLightning_text2image"],
|
137 |
-
"SDXL-Lightning",
|
138 |
-
"https://huggingface.co/ByteDance/SDXL-Lightning",
|
139 |
-
"SDXL-Lightning is a lightning-fast text-to-image generation model.",
|
140 |
-
)
|
141 |
-
|
142 |
-
register_model_info(
|
143 |
-
["imagenhub_StableCascade_generation", "fal_StableCascade_text2image"],
|
144 |
-
"StableCascade",
|
145 |
-
"https://huggingface.co/stabilityai/stable-cascade",
|
146 |
-
"StableCascade is built upon the Würstchen architecture and working at a much smaller latent space.",
|
147 |
-
)
|
148 |
-
|
149 |
-
# regist image edition models
|
150 |
-
register_model_info(
|
151 |
-
["imagenhub_CycleDiffusion_edition"],
|
152 |
-
"CycleDiffusion",
|
153 |
-
"https://github.com/ChenWu98/cycle-diffusion?tab=readme-ov-file",
|
154 |
-
"A latent space for stochastic diffusion models.",
|
155 |
-
)
|
156 |
-
|
157 |
-
register_model_info(
|
158 |
-
["imagenhub_Pix2PixZero_edition"],
|
159 |
-
"Pix2PixZero",
|
160 |
-
"https://pix2pixzero.github.io/",
|
161 |
-
"A zero-shot Image-to-Image translation model.",
|
162 |
-
)
|
163 |
-
|
164 |
-
register_model_info(
|
165 |
-
["imagenhub_Prompt2prompt_edition"],
|
166 |
-
"Prompt2prompt",
|
167 |
-
"https://prompt-to-prompt.github.io/",
|
168 |
-
"Image Editing with Cross-Attention Control.",
|
169 |
-
)
|
170 |
-
|
171 |
-
|
172 |
-
register_model_info(
|
173 |
-
["imagenhub_InstructPix2Pix_edition"],
|
174 |
-
"InstructPix2Pix",
|
175 |
-
"https://www.timothybrooks.com/instruct-pix2pix",
|
176 |
-
"An instruction-based image editing model.",
|
177 |
-
)
|
178 |
-
|
179 |
-
register_model_info(
|
180 |
-
["imagenhub_MagicBrush_edition"],
|
181 |
-
"MagicBrush",
|
182 |
-
"https://osu-nlp-group.github.io/MagicBrush/",
|
183 |
-
"Manually Annotated Dataset for Instruction-Guided Image Editing.",
|
184 |
-
)
|
185 |
-
|
186 |
-
register_model_info(
|
187 |
-
["imagenhub_PNP_edition"],
|
188 |
-
"PNP",
|
189 |
-
"https://github.com/MichalGeyer/plug-and-play",
|
190 |
-
"Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation.",
|
191 |
-
)
|
192 |
-
|
193 |
-
register_model_info(
|
194 |
-
["imagenhub_InfEdit_edition"],
|
195 |
-
"InfEdit",
|
196 |
-
"https://sled-group.github.io/InfEdit/",
|
197 |
-
"Inversion-Free Image Editing with Natural Language.",
|
198 |
-
)
|
199 |
-
|
200 |
-
register_model_info(
|
201 |
-
["imagenhub_CosXLEdit_edition"],
|
202 |
-
"CosXLEdit",
|
203 |
-
"https://huggingface.co/stabilityai/cosxl",
|
204 |
-
"An instruction-based image editing model from SDXL.",
|
205 |
-
)
|
206 |
-
|
207 |
-
register_model_info(
|
208 |
-
["fal_stable-cascade_text2image"],
|
209 |
-
"StableCascade",
|
210 |
-
"https://fal.ai/models/stable-cascade/api",
|
211 |
-
"StableCascade is a generative model that can generate high-quality images from text prompts.",
|
212 |
-
)
|
213 |
-
|
214 |
-
register_model_info(
|
215 |
-
["fal_AnimateDiff_text2video"],
|
216 |
-
"AnimateDiff",
|
217 |
-
"https://fal.ai/models/fast-animatediff-t2v",
|
218 |
-
"AnimateDiff is a text-driven models that produce diverse and personalized animated images.",
|
219 |
-
)
|
220 |
-
|
221 |
-
register_model_info(
|
222 |
-
["fal_AnimateDiffTurbo_text2video"],
|
223 |
-
"AnimateDiff Turbo",
|
224 |
-
"https://fal.ai/models/fast-animatediff-t2v-turbo",
|
225 |
-
"AnimateDiff Turbo is a lightning version of AnimateDiff.",
|
226 |
-
)
|
227 |
-
|
228 |
-
register_model_info(
|
229 |
-
["videogenhub_LaVie_generation"],
|
230 |
-
"LaVie",
|
231 |
-
"https://github.com/Vchitect/LaVie",
|
232 |
-
"LaVie is a video generation model with cascaded latent diffusion models.",
|
233 |
-
)
|
234 |
-
|
235 |
-
register_model_info(
|
236 |
-
["videogenhub_VideoCrafter2_generation"],
|
237 |
-
"VideoCrafter2",
|
238 |
-
"https://ailab-cvc.github.io/videocrafter2/",
|
239 |
-
"VideoCrafter2 is a T2V model that disentangling motion from appearance.",
|
240 |
-
)
|
241 |
-
|
242 |
-
register_model_info(
|
243 |
-
["videogenhub_ModelScope_generation"],
|
244 |
-
"ModelScope",
|
245 |
-
"https://arxiv.org/abs/2308.06571",
|
246 |
-
"ModelScope is a a T2V synthesis model that evolves from a T2I synthesis model.",
|
247 |
-
)
|
248 |
-
|
249 |
-
register_model_info(
|
250 |
-
["videogenhub_OpenSora_generation"],
|
251 |
-
"OpenSora",
|
252 |
-
"https://github.com/hpcaitech/Open-Sora",
|
253 |
-
"A community-driven opensource implementation of Sora.",
|
254 |
-
)
|
255 |
-
|
|
|
68 |
model_description_md += "\n"
|
69 |
ct += 1
|
70 |
return model_description_md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/models/__init__.py
CHANGED
@@ -37,15 +37,6 @@ IMAGE_GENERATION_MODELS = [
|
|
37 |
"replicate_FLUX.1-dev_text2image",
|
38 |
]
|
39 |
|
40 |
-
|
41 |
-
IMAGE_EDITION_MODELS = ['imagenhub_CycleDiffusion_edition', 'imagenhub_Pix2PixZero_edition', 'imagenhub_Prompt2prompt_edition',
|
42 |
-
'imagenhub_SDEdit_edition', 'imagenhub_InstructPix2Pix_edition',
|
43 |
-
'imagenhub_MagicBrush_edition', 'imagenhub_PNP_edition',
|
44 |
-
'imagenhub_InfEdit_edition', 'imagenhub_CosXLEdit_edition']
|
45 |
-
# VIDEO_GENERATION_MODELS = ['fal_AnimateDiff_text2video',
|
46 |
-
# 'fal_AnimateDiffTurbo_text2video',
|
47 |
-
# 'videogenhub_LaVie_generation', 'videogenhub_VideoCrafter2_generation',
|
48 |
-
# 'videogenhub_ModelScope_generation', 'videogenhub_OpenSora_generation']
|
49 |
VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
|
50 |
'replicate_Animate-Diff_text2video',
|
51 |
'replicate_OpenSora_text2video',
|
@@ -59,22 +50,15 @@ VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
|
|
59 |
'other_Sora_text2video',
|
60 |
]
|
61 |
|
|
|
62 |
def load_pipeline(model_name):
|
63 |
"""
|
64 |
Load a model pipeline based on the model name
|
65 |
Args:
|
66 |
model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
|
67 |
-
the source can be either imagenhub or playground
|
68 |
-
the name is the name of the model used to load the model
|
69 |
-
the type is the type of the model, either generation or edition
|
70 |
"""
|
71 |
model_source, model_name, model_type = model_name.split("_")
|
72 |
-
|
73 |
-
# pipe = load_imagenhub_model(model_name, model_type)
|
74 |
-
# elif model_source == "fal":
|
75 |
-
# pipe = load_fal_model(model_name, model_type)
|
76 |
-
# elif model_source == "videogenhub":
|
77 |
-
# pipe = load_videogenhub_model(model_name)
|
78 |
if model_source == "replicate":
|
79 |
pipe = load_replicate_model(model_name, model_type)
|
80 |
elif model_source == "huggingface":
|
|
|
37 |
"replicate_FLUX.1-dev_text2image",
|
38 |
]
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
|
41 |
'replicate_Animate-Diff_text2video',
|
42 |
'replicate_OpenSora_text2video',
|
|
|
50 |
'other_Sora_text2video',
|
51 |
]
|
52 |
|
53 |
+
|
54 |
def load_pipeline(model_name):
|
55 |
"""
|
56 |
Load a model pipeline based on the model name
|
57 |
Args:
|
58 |
model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
|
|
|
|
|
|
|
59 |
"""
|
60 |
model_source, model_name, model_type = model_name.split("_")
|
61 |
+
|
|
|
|
|
|
|
|
|
|
|
62 |
if model_source == "replicate":
|
63 |
pipe = load_replicate_model(model_name, model_type)
|
64 |
elif model_source == "huggingface":
|
model/models/openai_api_models.py
CHANGED
@@ -12,7 +12,6 @@ class OpenaiModel():
|
|
12 |
self.model_type = model_type
|
13 |
|
14 |
def __call__(self, *args, **kwargs):
|
15 |
-
|
16 |
if self.model_type == "text2image":
|
17 |
assert "prompt" in kwargs, "prompt is required for text2image model"
|
18 |
|
@@ -47,7 +46,6 @@ class OpenaiModel():
|
|
47 |
raise ValueError("model_type must be text2image or image2image")
|
48 |
|
49 |
|
50 |
-
|
51 |
def load_openai_model(model_name, model_type):
|
52 |
return OpenaiModel(model_name, model_type)
|
53 |
|
|
|
12 |
self.model_type = model_type
|
13 |
|
14 |
def __call__(self, *args, **kwargs):
|
|
|
15 |
if self.model_type == "text2image":
|
16 |
assert "prompt" in kwargs, "prompt is required for text2image model"
|
17 |
|
|
|
46 |
raise ValueError("model_type must be text2image or image2image")
|
47 |
|
48 |
|
|
|
49 |
def load_openai_model(model_name, model_type):
|
50 |
return OpenaiModel(model_name, model_type)
|
51 |
|
model/models/other_api_models.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4 |
from PIL import Image
|
5 |
import io, time
|
6 |
|
|
|
7 |
class OtherModel():
|
8 |
def __init__(self, model_name, model_type):
|
9 |
self.model_name = model_name
|
@@ -75,6 +76,8 @@ class OtherModel():
|
|
75 |
|
76 |
else:
|
77 |
raise ValueError("model_type must be text2image")
|
|
|
|
|
78 |
def load_other_model(model_name, model_type):
|
79 |
return OtherModel(model_name, model_type)
|
80 |
|
@@ -86,30 +89,3 @@ if __name__ == "__main__":
|
|
86 |
result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow ")
|
87 |
print(result)
|
88 |
exit()
|
89 |
-
|
90 |
-
|
91 |
-
# key = os.environ.get('MIDJOURNEY_KEY')
|
92 |
-
# prompt = "a good girl"
|
93 |
-
|
94 |
-
# conn = http.client.HTTPSConnection("xdai.online")
|
95 |
-
# payload = json.dumps({
|
96 |
-
# "messages": [
|
97 |
-
# {
|
98 |
-
# "role": "user",
|
99 |
-
# "content": "{}".format(prompt)
|
100 |
-
# }
|
101 |
-
# ],
|
102 |
-
# "stream": True,
|
103 |
-
# "model": "luma-video",
|
104 |
-
# # "model": "pika-text-to-video",
|
105 |
-
# })
|
106 |
-
# headers = {
|
107 |
-
# 'Authorization': "Bearer {}".format(key),
|
108 |
-
# 'Content-Type': 'application/json'
|
109 |
-
# }
|
110 |
-
# conn.request("POST", "/v1/chat/completions", payload, headers)
|
111 |
-
# res = conn.getresponse()
|
112 |
-
# data = res.read()
|
113 |
-
# info = data.decode("utf-8")
|
114 |
-
# print(data.decode("utf-8"))
|
115 |
-
|
|
|
4 |
from PIL import Image
|
5 |
import io, time
|
6 |
|
7 |
+
|
8 |
class OtherModel():
|
9 |
def __init__(self, model_name, model_type):
|
10 |
self.model_name = model_name
|
|
|
76 |
|
77 |
else:
|
78 |
raise ValueError("model_type must be text2image")
|
79 |
+
|
80 |
+
|
81 |
def load_other_model(model_name, model_type):
|
82 |
return OtherModel(model_name, model_type)
|
83 |
|
|
|
89 |
result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow ")
|
90 |
print(result)
|
91 |
exit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/models/replicate_api_models.py
CHANGED
@@ -40,11 +40,11 @@ Replicate_MODEl_NAME_MAP = {
|
|
40 |
"FLUX.1-dev": "black-forest-labs/flux-dev",
|
41 |
}
|
42 |
|
|
|
43 |
class ReplicateModel():
|
44 |
def __init__(self, model_name, model_type):
|
45 |
self.model_name = model_name
|
46 |
self.model_type = model_type
|
47 |
-
# os.environ['FAL_KEY'] = os.environ['FalAPI']
|
48 |
|
49 |
def __call__(self, *args, **kwargs):
|
50 |
if self.model_type == "text2image":
|
@@ -179,155 +179,14 @@ class ReplicateModel():
|
|
179 |
else:
|
180 |
raise ValueError("model_type must be text2image or image2image")
|
181 |
|
|
|
182 |
def load_replicate_model(model_name, model_type):
|
183 |
return ReplicateModel(model_name, model_type)
|
184 |
|
185 |
|
186 |
if __name__ == "__main__":
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
from moviepy.editor import VideoFileClip
|
193 |
-
|
194 |
-
# model_name = 'replicate_zeroscope-v2-xl_text2video'
|
195 |
-
# model_name = 'replicate_Damo-Text-to-Video_text2video'
|
196 |
-
# model_name = 'replicate_Animate-Diff_text2video'
|
197 |
-
# model_name = 'replicate_open-sora_text2video'
|
198 |
-
# model_name = 'replicate_lavie_text2video'
|
199 |
-
# model_name = 'replicate_video-crafter_text2video'
|
200 |
-
# model_name = 'replicate_stable-video-diffusion_text2video'
|
201 |
-
# model_source, model_name, model_type = model_name.split("_")
|
202 |
-
# pipe = load_replicate_model(model_name, model_type)
|
203 |
-
# prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
|
204 |
-
# result = pipe(prompt=prompt)
|
205 |
-
|
206 |
-
# # 文件复制
|
207 |
-
source_folder = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
|
208 |
-
destination_folder = '/mnt/data/lizhikai/ksort_video_cache/Advance/'
|
209 |
-
|
210 |
-
special_char = 'output'
|
211 |
-
for dirpath, dirnames, filenames in os.walk(source_folder):
|
212 |
-
for dirname in dirnames:
|
213 |
-
des_dirname = "output-"+dirname[-3:]
|
214 |
-
print(des_dirname)
|
215 |
-
if special_char in dirname:
|
216 |
-
model_name = ["Pika-v1.0"]
|
217 |
-
for name in model_name:
|
218 |
-
source_file_path = os.path.join(source_folder, os.path.join(dirname, name+".mp4"))
|
219 |
-
print(source_file_path)
|
220 |
-
destination_file_path = os.path.join(destination_folder, os.path.join(des_dirname, name+".mp4"))
|
221 |
-
print(destination_file_path)
|
222 |
-
shutil.copy(source_file_path, destination_file_path)
|
223 |
-
|
224 |
-
|
225 |
-
# 视频裁剪
|
226 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
|
227 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
|
228 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
|
229 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
|
230 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
|
231 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
|
232 |
-
# special_char = 'output'
|
233 |
-
# num = 0
|
234 |
-
# for dirpath, dirnames, filenames in os.walk(root_dir):
|
235 |
-
# for dirname in dirnames:
|
236 |
-
# # 如果文件夹名称中包含指定的特殊字符
|
237 |
-
# if special_char in dirname:
|
238 |
-
# num = num+1
|
239 |
-
# print(num)
|
240 |
-
# if num < 0:
|
241 |
-
# continue
|
242 |
-
# video_path = os.path.join(root_dir, (os.path.join(dirname, f"{dirname}.mp4")))
|
243 |
-
# out_video_path = os.path.join(root_dir, (os.path.join(dirname, f"Pika-v1.0.mp4")))
|
244 |
-
# print(video_path)
|
245 |
-
# print(out_video_path)
|
246 |
-
|
247 |
-
# video = VideoFileClip(video_path)
|
248 |
-
# width, height = video.size
|
249 |
-
# center_x, center_y = width // 2, height // 2
|
250 |
-
# new_width, new_height = 512, 512
|
251 |
-
# cropped_video = video.crop(x_center=center_x, y_center=center_y, width=min(width, height), height=min(width, height))
|
252 |
-
# resized_video = cropped_video.resize(newsize=(new_width, new_height))
|
253 |
-
# resized_video.write_videofile(out_video_path, codec='libx264', fps=video.fps)
|
254 |
-
# os.remove(video_path)
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
# file_path = '/home/lizhikai/webvid_prompt100.txt'
|
259 |
-
# str_list = []
|
260 |
-
# with open(file_path, 'r', encoding='utf-8') as file:
|
261 |
-
# for line in file:
|
262 |
-
# str_list.append(line.strip())
|
263 |
-
# if len(str_list) == 100:
|
264 |
-
# break
|
265 |
-
|
266 |
-
# 生成代码
|
267 |
-
# def generate_image_ig_api(prompt, model_name):
|
268 |
-
# model_source, model_name, model_type = model_name.split("_")
|
269 |
-
# pipe = load_replicate_model(model_name, model_type)
|
270 |
-
# result = pipe(prompt=prompt)
|
271 |
-
# return result
|
272 |
-
# model_names = ['replicate_Zeroscope-v2-xl_text2video',
|
273 |
-
# # 'replicate_Damo-Text-to-Video_text2video',
|
274 |
-
# 'replicate_Animate-Diff_text2video',
|
275 |
-
# 'replicate_OpenSora_text2video',
|
276 |
-
# 'replicate_LaVie_text2video',
|
277 |
-
# 'replicate_VideoCrafter2_text2video',
|
278 |
-
# 'replicate_Stable-Video-Diffusion_text2video',
|
279 |
-
# ]
|
280 |
-
# save_names = []
|
281 |
-
# for name in model_names:
|
282 |
-
# model_source, model_name, model_type = name.split("_")
|
283 |
-
# save_names.append(model_name)
|
284 |
-
|
285 |
-
# # 遍历根目录及其子目录
|
286 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
|
287 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
|
288 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
|
289 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
|
290 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
|
291 |
-
# special_char = 'output'
|
292 |
-
# num = 0
|
293 |
-
# for dirpath, dirnames, filenames in os.walk(root_dir):
|
294 |
-
# for dirname in dirnames:
|
295 |
-
# # 如果文件夹名称中包含指定的特殊字符
|
296 |
-
# if special_char in dirname:
|
297 |
-
# num = num+1
|
298 |
-
# print(num)
|
299 |
-
# if num < 0:
|
300 |
-
# continue
|
301 |
-
# str_list = []
|
302 |
-
# prompt_path = os.path.join(root_dir, (os.path.join(dirname, "prompt.txt")))
|
303 |
-
# print(prompt_path)
|
304 |
-
# with open(prompt_path, 'r', encoding='utf-8') as file:
|
305 |
-
# for line in file:
|
306 |
-
# str_list.append(line.strip())
|
307 |
-
# prompt = str_list[0]
|
308 |
-
# print(prompt)
|
309 |
-
|
310 |
-
# with concurrent.futures.ThreadPoolExecutor() as executor:
|
311 |
-
# futures = [executor.submit(generate_image_ig_api, prompt, model) for model in model_names]
|
312 |
-
# results = [future.result() for future in futures]
|
313 |
-
|
314 |
-
# # 下载视频并保存
|
315 |
-
# repeat_num = 5
|
316 |
-
# for j, url in enumerate(results):
|
317 |
-
# while 1:
|
318 |
-
# time.sleep(1)
|
319 |
-
# response = requests.get(url, stream=True)
|
320 |
-
# if response.status_code == 200:
|
321 |
-
# file_path = os.path.join(os.path.join(root_dir, dirname), f'{save_names[j]}.mp4')
|
322 |
-
# with open(file_path, 'wb') as file:
|
323 |
-
# for chunk in response.iter_content(chunk_size=8192):
|
324 |
-
# file.write(chunk)
|
325 |
-
# print(f"视频 {j} 已保存到 {file_path}")
|
326 |
-
# break
|
327 |
-
# else:
|
328 |
-
# repeat_num = repeat_num - 1
|
329 |
-
# if repeat_num == 0:
|
330 |
-
# print(f"视频 {j} 保存失败")
|
331 |
-
# # raise ValueError("Video request failed.")
|
332 |
-
# continue
|
333 |
-
|
|
|
40 |
"FLUX.1-dev": "black-forest-labs/flux-dev",
|
41 |
}
|
42 |
|
43 |
+
|
44 |
class ReplicateModel():
|
45 |
def __init__(self, model_name, model_type):
|
46 |
self.model_name = model_name
|
47 |
self.model_type = model_type
|
|
|
48 |
|
49 |
def __call__(self, *args, **kwargs):
|
50 |
if self.model_type == "text2image":
|
|
|
179 |
else:
|
180 |
raise ValueError("model_type must be text2image or image2image")
|
181 |
|
182 |
+
|
183 |
def load_replicate_model(model_name, model_type):
|
184 |
return ReplicateModel(model_name, model_type)
|
185 |
|
186 |
|
187 |
if __name__ == "__main__":
|
188 |
+
model_name = 'replicate_zeroscope-v2-xl_text2video'
|
189 |
+
model_source, model_name, model_type = model_name.split("_")
|
190 |
+
pipe = load_replicate_model(model_name, model_type)
|
191 |
+
prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
|
192 |
+
result = pipe(prompt=prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
serve/Ksort.py
CHANGED
@@ -8,6 +8,7 @@ from .utils import disable_btn, enable_btn, invisible_btn
|
|
8 |
from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
|
9 |
import json
|
10 |
|
|
|
11 |
def reset_level(Top_btn):
|
12 |
if Top_btn == "Top 1":
|
13 |
level = 0
|
@@ -19,6 +20,7 @@ def reset_level(Top_btn):
|
|
19 |
level = 3
|
20 |
return level
|
21 |
|
|
|
22 |
def reset_rank(windows, rank, vote_level):
|
23 |
if windows == "Model A":
|
24 |
rank[0] = vote_level
|
@@ -30,6 +32,7 @@ def reset_rank(windows, rank, vote_level):
|
|
30 |
rank[3] = vote_level
|
31 |
return rank
|
32 |
|
|
|
33 |
def reset_btn_rank(windows, rank, btn, vote_level):
|
34 |
if windows == "Model A" and btn == "1":
|
35 |
rank[0] = 0
|
@@ -73,6 +76,7 @@ def reset_btn_rank(windows, rank, btn, vote_level):
|
|
73 |
vote_level = 3
|
74 |
return (rank, vote_level)
|
75 |
|
|
|
76 |
def reset_vote_text(rank):
|
77 |
rank_str = ""
|
78 |
for i in range(len(rank)):
|
@@ -83,24 +87,28 @@ def reset_vote_text(rank):
|
|
83 |
rank_str = rank_str + " "
|
84 |
return rank_str
|
85 |
|
|
|
86 |
def clear_rank(rank, vote_level):
|
87 |
for i in range(len(rank)):
|
88 |
rank[i] = None
|
89 |
vote_level = 0
|
90 |
return rank, vote_level
|
91 |
|
|
|
92 |
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
|
93 |
for i in range(len(rank)):
|
94 |
rank[i] = None
|
95 |
vote_level = 0
|
96 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level
|
97 |
|
|
|
98 |
def reset_submit(rank):
|
99 |
for i in range(len(rank)):
|
100 |
if rank[i] == None:
|
101 |
return disable_btn
|
102 |
return enable_btn
|
103 |
|
|
|
104 |
def reset_mode(mode):
|
105 |
|
106 |
if mode == "Best":
|
@@ -116,8 +124,12 @@ def reset_mode(mode):
|
|
116 |
(gr.Textbox(value="Best", visible=False, interactive=False),)
|
117 |
else:
|
118 |
raise ValueError("Undefined mode")
|
|
|
|
|
119 |
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
|
120 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3
|
|
|
|
|
121 |
def get_json_filename(conv_id):
|
122 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
|
123 |
if not os.path.exists(output_dir):
|
@@ -127,6 +139,7 @@ def get_json_filename(conv_id):
|
|
127 |
print(output_file)
|
128 |
return output_file
|
129 |
|
|
|
130 |
def get_img_filename(conv_id, i):
|
131 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
|
132 |
if not os.path.exists(output_dir):
|
@@ -135,6 +148,7 @@ def get_img_filename(conv_id, i):
|
|
135 |
print(output_file)
|
136 |
return output_file
|
137 |
|
|
|
138 |
def vote_submit(states, textbox, rank, request: gr.Request):
|
139 |
conv_id = states[0].conv_id
|
140 |
|
@@ -149,6 +163,7 @@ def vote_submit(states, textbox, rank, request: gr.Request):
|
|
149 |
}
|
150 |
fout.write(json.dumps(data) + "\n")
|
151 |
|
|
|
152 |
def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
|
153 |
conv_id = states[0].conv_id
|
154 |
output_dir = create_remote_directory(conv_id)
|
@@ -167,6 +182,7 @@ def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
|
|
167 |
from .update_skill import update_skill
|
168 |
update_skill(rank, [x.model_name for x in states])
|
169 |
|
|
|
170 |
def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
|
171 |
conv_id = states[0].conv_id
|
172 |
output_dir = create_remote_directory(conv_id, video=True)
|
@@ -186,6 +202,7 @@ def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_in
|
|
186 |
from .update_skill_video import update_skill_video
|
187 |
update_skill_video(rank, [x.model_name for x in states])
|
188 |
|
|
|
189 |
def submit_response_igm(
|
190 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
|
191 |
):
|
@@ -205,6 +222,8 @@ def submit_response_igm(
|
|
205 |
gr.Markdown(state2.model_name, visible=True),
|
206 |
gr.Markdown(state3.model_name, visible=True)
|
207 |
) + (disable_btn,)
|
|
|
|
|
208 |
def submit_response_vg(
|
209 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
|
210 |
):
|
@@ -223,6 +242,8 @@ def submit_response_vg(
|
|
223 |
gr.Markdown(state2.model_name, visible=True),
|
224 |
gr.Markdown(state3.model_name, visible=True)
|
225 |
) + (disable_btn,)
|
|
|
|
|
226 |
def submit_response_rank_igm(
|
227 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
228 |
):
|
@@ -246,6 +267,8 @@ def submit_response_rank_igm(
|
|
246 |
)
|
247 |
else:
|
248 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
|
|
|
|
249 |
def submit_response_rank_vg(
|
250 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
251 |
):
|
@@ -269,6 +292,7 @@ def submit_response_rank_vg(
|
|
269 |
else:
|
270 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
271 |
|
|
|
272 |
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
|
273 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
274 |
generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
|
@@ -318,6 +342,7 @@ def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig
|
|
318 |
|
319 |
return chatbot + [rank_str] + ["right"] + [rank]
|
320 |
|
|
|
321 |
def text_response_rank_vg(vote_textbox):
|
322 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
323 |
rank = [None, None, None, None]
|
@@ -336,6 +361,7 @@ def text_response_rank_vg(vote_textbox):
|
|
336 |
|
337 |
return [rank_str] + ["right"] + [rank]
|
338 |
|
|
|
339 |
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
|
340 |
base_image = Image.fromarray(image).convert("RGBA")
|
341 |
base_image = base_image.resize((512, 512), Image.ANTIALIAS)
|
@@ -369,12 +395,15 @@ def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text
|
|
369 |
|
370 |
base_image = base_image.convert("RGB")
|
371 |
return base_image
|
|
|
|
|
372 |
def add_green_border(image):
|
373 |
border_color = (0, 255, 0) # RGB for green
|
374 |
border_size = 10 # Size of the border
|
375 |
img_with_border = ImageOps.expand(image, border=border_size, fill=border_color)
|
376 |
return img_with_border
|
377 |
|
|
|
378 |
def check_textbox(textbox):
|
379 |
if textbox=="":
|
380 |
return False
|
|
|
8 |
from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
|
9 |
import json
|
10 |
|
11 |
+
|
12 |
def reset_level(Top_btn):
|
13 |
if Top_btn == "Top 1":
|
14 |
level = 0
|
|
|
20 |
level = 3
|
21 |
return level
|
22 |
|
23 |
+
|
24 |
def reset_rank(windows, rank, vote_level):
|
25 |
if windows == "Model A":
|
26 |
rank[0] = vote_level
|
|
|
32 |
rank[3] = vote_level
|
33 |
return rank
|
34 |
|
35 |
+
|
36 |
def reset_btn_rank(windows, rank, btn, vote_level):
|
37 |
if windows == "Model A" and btn == "1":
|
38 |
rank[0] = 0
|
|
|
76 |
vote_level = 3
|
77 |
return (rank, vote_level)
|
78 |
|
79 |
+
|
80 |
def reset_vote_text(rank):
|
81 |
rank_str = ""
|
82 |
for i in range(len(rank)):
|
|
|
87 |
rank_str = rank_str + " "
|
88 |
return rank_str
|
89 |
|
90 |
+
|
91 |
def clear_rank(rank, vote_level):
|
92 |
for i in range(len(rank)):
|
93 |
rank[i] = None
|
94 |
vote_level = 0
|
95 |
return rank, vote_level
|
96 |
|
97 |
+
|
98 |
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
|
99 |
for i in range(len(rank)):
|
100 |
rank[i] = None
|
101 |
vote_level = 0
|
102 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level
|
103 |
|
104 |
+
|
105 |
def reset_submit(rank):
|
106 |
for i in range(len(rank)):
|
107 |
if rank[i] == None:
|
108 |
return disable_btn
|
109 |
return enable_btn
|
110 |
|
111 |
+
|
112 |
def reset_mode(mode):
|
113 |
|
114 |
if mode == "Best":
|
|
|
124 |
(gr.Textbox(value="Best", visible=False, interactive=False),)
|
125 |
else:
|
126 |
raise ValueError("Undefined mode")
|
127 |
+
|
128 |
+
|
129 |
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
|
130 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3
|
131 |
+
|
132 |
+
|
133 |
def get_json_filename(conv_id):
|
134 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
|
135 |
if not os.path.exists(output_dir):
|
|
|
139 |
print(output_file)
|
140 |
return output_file
|
141 |
|
142 |
+
|
143 |
def get_img_filename(conv_id, i):
|
144 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
|
145 |
if not os.path.exists(output_dir):
|
|
|
148 |
print(output_file)
|
149 |
return output_file
|
150 |
|
151 |
+
|
152 |
def vote_submit(states, textbox, rank, request: gr.Request):
|
153 |
conv_id = states[0].conv_id
|
154 |
|
|
|
163 |
}
|
164 |
fout.write(json.dumps(data) + "\n")
|
165 |
|
166 |
+
|
167 |
def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
|
168 |
conv_id = states[0].conv_id
|
169 |
output_dir = create_remote_directory(conv_id)
|
|
|
182 |
from .update_skill import update_skill
|
183 |
update_skill(rank, [x.model_name for x in states])
|
184 |
|
185 |
+
|
186 |
def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
|
187 |
conv_id = states[0].conv_id
|
188 |
output_dir = create_remote_directory(conv_id, video=True)
|
|
|
202 |
from .update_skill_video import update_skill_video
|
203 |
update_skill_video(rank, [x.model_name for x in states])
|
204 |
|
205 |
+
|
206 |
def submit_response_igm(
|
207 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
|
208 |
):
|
|
|
222 |
gr.Markdown(state2.model_name, visible=True),
|
223 |
gr.Markdown(state3.model_name, visible=True)
|
224 |
) + (disable_btn,)
|
225 |
+
|
226 |
+
|
227 |
def submit_response_vg(
|
228 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
|
229 |
):
|
|
|
242 |
gr.Markdown(state2.model_name, visible=True),
|
243 |
gr.Markdown(state3.model_name, visible=True)
|
244 |
) + (disable_btn,)
|
245 |
+
|
246 |
+
|
247 |
def submit_response_rank_igm(
|
248 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
249 |
):
|
|
|
267 |
)
|
268 |
else:
|
269 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
270 |
+
|
271 |
+
|
272 |
def submit_response_rank_vg(
|
273 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
274 |
):
|
|
|
292 |
else:
|
293 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
294 |
|
295 |
+
|
296 |
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
|
297 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
298 |
generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
|
|
|
342 |
|
343 |
return chatbot + [rank_str] + ["right"] + [rank]
|
344 |
|
345 |
+
|
346 |
def text_response_rank_vg(vote_textbox):
|
347 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
348 |
rank = [None, None, None, None]
|
|
|
361 |
|
362 |
return [rank_str] + ["right"] + [rank]
|
363 |
|
364 |
+
|
365 |
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
|
366 |
base_image = Image.fromarray(image).convert("RGBA")
|
367 |
base_image = base_image.resize((512, 512), Image.ANTIALIAS)
|
|
|
395 |
|
396 |
base_image = base_image.convert("RGB")
|
397 |
return base_image
|
398 |
+
|
399 |
+
|
400 |
def add_green_border(image):
|
401 |
border_color = (0, 255, 0) # RGB for green
|
402 |
border_size = 10 # Size of the border
|
403 |
img_with_border = ImageOps.expand(image, border=border_size, fill=border_color)
|
404 |
return img_with_border
|
405 |
|
406 |
+
|
407 |
def check_textbox(textbox):
|
408 |
if textbox=="":
|
409 |
return False
|
serve/leaderboard.py
CHANGED
@@ -40,12 +40,14 @@ def make_leaderboard_md():
|
|
40 |
"""
|
41 |
return leaderboard_md
|
42 |
|
|
|
43 |
def make_leaderboard_video_md():
|
44 |
leaderboard_md = f"""
|
45 |
# 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
|
46 |
"""
|
47 |
return leaderboard_md
|
48 |
|
|
|
49 |
def model_hyperlink(model_name, link):
|
50 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
51 |
|
@@ -89,11 +91,13 @@ def make_disclaimer_md():
|
|
89 |
'''
|
90 |
return disclaimer_md
|
91 |
|
|
|
92 |
def make_arena_leaderboard_data(results):
|
93 |
import pandas as pd
|
94 |
df = pd.DataFrame(results)
|
95 |
return df
|
96 |
|
|
|
97 |
def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
|
98 |
with open(score_result_file, "r") as json_file:
|
99 |
data = json.load(json_file)
|
|
|
40 |
"""
|
41 |
return leaderboard_md
|
42 |
|
43 |
+
|
44 |
def make_leaderboard_video_md():
|
45 |
leaderboard_md = f"""
|
46 |
# 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
|
47 |
"""
|
48 |
return leaderboard_md
|
49 |
|
50 |
+
|
51 |
def model_hyperlink(model_name, link):
|
52 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
53 |
|
|
|
91 |
'''
|
92 |
return disclaimer_md
|
93 |
|
94 |
+
|
95 |
def make_arena_leaderboard_data(results):
|
96 |
import pandas as pd
|
97 |
df = pd.DataFrame(results)
|
98 |
return df
|
99 |
|
100 |
+
|
101 |
def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
|
102 |
with open(score_result_file, "r") as json_file:
|
103 |
data = json.load(json_file)
|
serve/update_skill.py
CHANGED
@@ -9,9 +9,11 @@ trueskill_env = TrueSkill()
|
|
9 |
sys.path.append('../')
|
10 |
from model.models import IMAGE_GENERATION_MODELS
|
11 |
|
|
|
12 |
ssh_skill_client = None
|
13 |
sftp_skill_client = None
|
14 |
|
|
|
15 |
def create_ssh_skill_client(server, port, user, password):
|
16 |
global ssh_skill_client, sftp_skill_client
|
17 |
ssh_skill_client = paramiko.SSHClient()
|
@@ -23,32 +25,37 @@ def create_ssh_skill_client(server, port, user, password):
|
|
23 |
transport.set_keepalive(60)
|
24 |
|
25 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
|
|
|
|
26 |
def is_connected():
|
27 |
global ssh_skill_client, sftp_skill_client
|
28 |
if ssh_skill_client is None or sftp_skill_client is None:
|
29 |
return False
|
30 |
-
# 检查SSH连接是否正常
|
31 |
if not ssh_skill_client.get_transport().is_active():
|
32 |
return False
|
33 |
-
# 检查SFTP连接是否正常
|
34 |
try:
|
35 |
-
sftp_skill_client.listdir('.')
|
36 |
except Exception as e:
|
37 |
print(f"Error checking SFTP connection: {e}")
|
38 |
return False
|
39 |
return True
|
|
|
|
|
40 |
def ucb_score(trueskill_diff, t, n):
|
41 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
42 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
43 |
return ucb
|
44 |
|
|
|
45 |
def update_trueskill(ratings, ranks):
|
46 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
47 |
return new_ratings
|
48 |
|
|
|
49 |
def serialize_rating(rating):
|
50 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
51 |
|
|
|
52 |
def deserialize_rating(rating_dict):
|
53 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
54 |
|
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
66 |
with sftp_skill_client.open(SSH_SKILL, 'w') as f:
|
67 |
f.write(json_data)
|
68 |
|
|
|
69 |
def load_json_via_sftp():
|
70 |
global sftp_skill_client
|
71 |
if not is_connected():
|
|
|
9 |
sys.path.append('../')
|
10 |
from model.models import IMAGE_GENERATION_MODELS
|
11 |
|
12 |
+
|
13 |
ssh_skill_client = None
|
14 |
sftp_skill_client = None
|
15 |
|
16 |
+
|
17 |
def create_ssh_skill_client(server, port, user, password):
|
18 |
global ssh_skill_client, sftp_skill_client
|
19 |
ssh_skill_client = paramiko.SSHClient()
|
|
|
25 |
transport.set_keepalive(60)
|
26 |
|
27 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
28 |
+
|
29 |
+
|
30 |
def is_connected():
|
31 |
global ssh_skill_client, sftp_skill_client
|
32 |
if ssh_skill_client is None or sftp_skill_client is None:
|
33 |
return False
|
|
|
34 |
if not ssh_skill_client.get_transport().is_active():
|
35 |
return False
|
|
|
36 |
try:
|
37 |
+
sftp_skill_client.listdir('.')
|
38 |
except Exception as e:
|
39 |
print(f"Error checking SFTP connection: {e}")
|
40 |
return False
|
41 |
return True
|
42 |
+
|
43 |
+
|
44 |
def ucb_score(trueskill_diff, t, n):
|
45 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
46 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
47 |
return ucb
|
48 |
|
49 |
+
|
50 |
def update_trueskill(ratings, ranks):
|
51 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
52 |
return new_ratings
|
53 |
|
54 |
+
|
55 |
def serialize_rating(rating):
|
56 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
57 |
|
58 |
+
|
59 |
def deserialize_rating(rating_dict):
|
60 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
61 |
|
|
|
73 |
with sftp_skill_client.open(SSH_SKILL, 'w') as f:
|
74 |
f.write(json_data)
|
75 |
|
76 |
+
|
77 |
def load_json_via_sftp():
|
78 |
global sftp_skill_client
|
79 |
if not is_connected():
|
serve/update_skill_video.py
CHANGED
@@ -9,9 +9,11 @@ trueskill_env = TrueSkill()
|
|
9 |
sys.path.append('../')
|
10 |
from model.models import VIDEO_GENERATION_MODELS
|
11 |
|
|
|
12 |
ssh_skill_client = None
|
13 |
sftp_skill_client = None
|
14 |
|
|
|
15 |
def create_ssh_skill_client(server, port, user, password):
|
16 |
global ssh_skill_client, sftp_skill_client
|
17 |
ssh_skill_client = paramiko.SSHClient()
|
@@ -23,32 +25,37 @@ def create_ssh_skill_client(server, port, user, password):
|
|
23 |
transport.set_keepalive(60)
|
24 |
|
25 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
|
|
|
|
26 |
def is_connected():
|
27 |
global ssh_skill_client, sftp_skill_client
|
28 |
if ssh_skill_client is None or sftp_skill_client is None:
|
29 |
return False
|
30 |
-
# 检查SSH连接是否正常
|
31 |
if not ssh_skill_client.get_transport().is_active():
|
32 |
return False
|
33 |
-
# 检查SFTP连接是否正常
|
34 |
try:
|
35 |
-
sftp_skill_client.listdir('.')
|
36 |
except Exception as e:
|
37 |
print(f"Error checking SFTP connection: {e}")
|
38 |
return False
|
39 |
return True
|
|
|
|
|
40 |
def ucb_score(trueskill_diff, t, n):
|
41 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
42 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
43 |
return ucb
|
44 |
|
|
|
45 |
def update_trueskill(ratings, ranks):
|
46 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
47 |
return new_ratings
|
48 |
|
|
|
49 |
def serialize_rating(rating):
|
50 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
51 |
|
|
|
52 |
def deserialize_rating(rating_dict):
|
53 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
54 |
|
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
66 |
with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
67 |
f.write(json_data)
|
68 |
|
|
|
69 |
def load_json_via_sftp():
|
70 |
global sftp_skill_client
|
71 |
if not is_connected():
|
|
|
9 |
sys.path.append('../')
|
10 |
from model.models import VIDEO_GENERATION_MODELS
|
11 |
|
12 |
+
|
13 |
ssh_skill_client = None
|
14 |
sftp_skill_client = None
|
15 |
|
16 |
+
|
17 |
def create_ssh_skill_client(server, port, user, password):
|
18 |
global ssh_skill_client, sftp_skill_client
|
19 |
ssh_skill_client = paramiko.SSHClient()
|
|
|
25 |
transport.set_keepalive(60)
|
26 |
|
27 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
28 |
+
|
29 |
+
|
30 |
def is_connected():
|
31 |
global ssh_skill_client, sftp_skill_client
|
32 |
if ssh_skill_client is None or sftp_skill_client is None:
|
33 |
return False
|
|
|
34 |
if not ssh_skill_client.get_transport().is_active():
|
35 |
return False
|
|
|
36 |
try:
|
37 |
+
sftp_skill_client.listdir('.')
|
38 |
except Exception as e:
|
39 |
print(f"Error checking SFTP connection: {e}")
|
40 |
return False
|
41 |
return True
|
42 |
+
|
43 |
+
|
44 |
def ucb_score(trueskill_diff, t, n):
|
45 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
46 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
47 |
return ucb
|
48 |
|
49 |
+
|
50 |
def update_trueskill(ratings, ranks):
|
51 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
52 |
return new_ratings
|
53 |
|
54 |
+
|
55 |
def serialize_rating(rating):
|
56 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
57 |
|
58 |
+
|
59 |
def deserialize_rating(rating_dict):
|
60 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
61 |
|
|
|
73 |
with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
74 |
f.write(json_data)
|
75 |
|
76 |
+
|
77 |
def load_json_via_sftp():
|
78 |
global sftp_skill_client
|
79 |
if not is_connected():
|
serve/upload.py
CHANGED
@@ -9,15 +9,18 @@ import random
|
|
9 |
import concurrent.futures
|
10 |
from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO
|
11 |
|
|
|
12 |
ssh_client = None
|
13 |
sftp_client = None
|
14 |
sftp_client_imgs = None
|
15 |
|
|
|
16 |
def open_sftp(i=0):
|
17 |
global ssh_client
|
18 |
sftp_client = ssh_client.open_sftp()
|
19 |
return sftp_client
|
20 |
|
|
|
21 |
def create_ssh_client(server, port, user, password):
|
22 |
global ssh_client, sftp_client, sftp_client_imgs
|
23 |
ssh_client = paramiko.SSHClient()
|
@@ -40,22 +43,22 @@ def is_connected():
|
|
40 |
global ssh_client, sftp_client
|
41 |
if ssh_client is None or sftp_client is None:
|
42 |
return False
|
43 |
-
# 检查SSH连接是否正常
|
44 |
if not ssh_client.get_transport().is_active():
|
45 |
return False
|
46 |
-
# 检查SFTP连接是否正常
|
47 |
try:
|
48 |
-
sftp_client.listdir('.')
|
49 |
except Exception as e:
|
50 |
print(f"Error checking SFTP connection: {e}")
|
51 |
return False
|
52 |
return True
|
53 |
|
|
|
54 |
def get_image_from_url(image_url):
|
55 |
response = requests.get(image_url)
|
56 |
response.raise_for_status() # success
|
57 |
return Image.open(io.BytesIO(response.content))
|
58 |
|
|
|
59 |
# def get_random_mscoco_prompt():
|
60 |
# global sftp_client
|
61 |
# if not is_connected():
|
@@ -70,6 +73,7 @@ def get_image_from_url(image_url):
|
|
70 |
# print("\n")
|
71 |
# return content
|
72 |
|
|
|
73 |
def get_random_mscoco_prompt():
|
74 |
|
75 |
file_path = './coco_prompt.txt'
|
@@ -79,6 +83,7 @@ def get_random_mscoco_prompt():
|
|
79 |
random_line = random.choice(lines).strip()
|
80 |
return random_line
|
81 |
|
|
|
82 |
def get_random_video_prompt(root_dir):
|
83 |
subdirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
|
84 |
if not subdirs:
|
@@ -96,6 +101,7 @@ def get_random_video_prompt(root_dir):
|
|
96 |
raise NotImplementedError
|
97 |
return selected_dir, prompt
|
98 |
|
|
|
99 |
def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
|
100 |
def is_directory(sftp, path):
|
101 |
try:
|
@@ -150,6 +156,7 @@ def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
|
|
150 |
ssh.close()
|
151 |
return prompt, local_path[1:]
|
152 |
|
|
|
153 |
def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
|
154 |
def is_directory(sftp, path):
|
155 |
try:
|
@@ -204,6 +211,7 @@ def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
|
|
204 |
ssh.close()
|
205 |
return prompt, [Image.open(path) for path in local_path[1:]]
|
206 |
|
|
|
207 |
def create_remote_directory(remote_directory, video=False):
|
208 |
global ssh_client
|
209 |
if not is_connected():
|
@@ -220,6 +228,7 @@ def create_remote_directory(remote_directory, video=False):
|
|
220 |
print(f"Directory {remote_directory} created successfully.")
|
221 |
return log_dir
|
222 |
|
|
|
223 |
def upload_images(i, image_list, output_file_list, sftp_client):
|
224 |
with sftp_client as sftp:
|
225 |
if isinstance(image_list[i], str):
|
@@ -233,7 +242,6 @@ def upload_images(i, image_list, output_file_list, sftp_client):
|
|
233 |
print(f"Successfully uploaded image to {output_file_list[i]}")
|
234 |
|
235 |
|
236 |
-
|
237 |
def upload_ssh_all(states, output_dir, data, data_path):
|
238 |
global sftp_client
|
239 |
global sftp_client_imgs
|
@@ -246,7 +254,6 @@ def upload_ssh_all(states, output_dir, data, data_path):
|
|
246 |
output_file_list.append(output_file)
|
247 |
image_list.append(states[i].output)
|
248 |
|
249 |
-
|
250 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
251 |
futures = [executor.submit(upload_images, i, image_list, output_file_list, sftp_client_imgs[i]) for i in range(len(output_file_list))]
|
252 |
|
@@ -257,6 +264,7 @@ def upload_ssh_all(states, output_dir, data, data_path):
|
|
257 |
print(f"Successfully uploaded JSON data to {data_path}")
|
258 |
# create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
259 |
|
|
|
260 |
def upload_ssh_data(data, data_path):
|
261 |
global sftp_client
|
262 |
global sftp_client_imgs
|
|
|
9 |
import concurrent.futures
|
10 |
from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO
|
11 |
|
12 |
+
|
13 |
ssh_client = None
|
14 |
sftp_client = None
|
15 |
sftp_client_imgs = None
|
16 |
|
17 |
+
|
18 |
def open_sftp(i=0):
|
19 |
global ssh_client
|
20 |
sftp_client = ssh_client.open_sftp()
|
21 |
return sftp_client
|
22 |
|
23 |
+
|
24 |
def create_ssh_client(server, port, user, password):
|
25 |
global ssh_client, sftp_client, sftp_client_imgs
|
26 |
ssh_client = paramiko.SSHClient()
|
|
|
43 |
global ssh_client, sftp_client
|
44 |
if ssh_client is None or sftp_client is None:
|
45 |
return False
|
|
|
46 |
if not ssh_client.get_transport().is_active():
|
47 |
return False
|
|
|
48 |
try:
|
49 |
+
sftp_client.listdir('.')
|
50 |
except Exception as e:
|
51 |
print(f"Error checking SFTP connection: {e}")
|
52 |
return False
|
53 |
return True
|
54 |
|
55 |
+
|
56 |
def get_image_from_url(image_url):
|
57 |
response = requests.get(image_url)
|
58 |
response.raise_for_status() # success
|
59 |
return Image.open(io.BytesIO(response.content))
|
60 |
|
61 |
+
|
62 |
# def get_random_mscoco_prompt():
|
63 |
# global sftp_client
|
64 |
# if not is_connected():
|
|
|
73 |
# print("\n")
|
74 |
# return content
|
75 |
|
76 |
+
|
77 |
def get_random_mscoco_prompt():
|
78 |
|
79 |
file_path = './coco_prompt.txt'
|
|
|
83 |
random_line = random.choice(lines).strip()
|
84 |
return random_line
|
85 |
|
86 |
+
|
87 |
def get_random_video_prompt(root_dir):
|
88 |
subdirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
|
89 |
if not subdirs:
|
|
|
101 |
raise NotImplementedError
|
102 |
return selected_dir, prompt
|
103 |
|
104 |
+
|
105 |
def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
|
106 |
def is_directory(sftp, path):
|
107 |
try:
|
|
|
156 |
ssh.close()
|
157 |
return prompt, local_path[1:]
|
158 |
|
159 |
+
|
160 |
def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
|
161 |
def is_directory(sftp, path):
|
162 |
try:
|
|
|
211 |
ssh.close()
|
212 |
return prompt, [Image.open(path) for path in local_path[1:]]
|
213 |
|
214 |
+
|
215 |
def create_remote_directory(remote_directory, video=False):
|
216 |
global ssh_client
|
217 |
if not is_connected():
|
|
|
228 |
print(f"Directory {remote_directory} created successfully.")
|
229 |
return log_dir
|
230 |
|
231 |
+
|
232 |
def upload_images(i, image_list, output_file_list, sftp_client):
|
233 |
with sftp_client as sftp:
|
234 |
if isinstance(image_list[i], str):
|
|
|
242 |
print(f"Successfully uploaded image to {output_file_list[i]}")
|
243 |
|
244 |
|
|
|
245 |
def upload_ssh_all(states, output_dir, data, data_path):
|
246 |
global sftp_client
|
247 |
global sftp_client_imgs
|
|
|
254 |
output_file_list.append(output_file)
|
255 |
image_list.append(states[i].output)
|
256 |
|
|
|
257 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
258 |
futures = [executor.submit(upload_images, i, image_list, output_file_list, sftp_client_imgs[i]) for i in range(len(output_file_list))]
|
259 |
|
|
|
264 |
print(f"Successfully uploaded JSON data to {data_path}")
|
265 |
# create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
266 |
|
267 |
+
|
268 |
def upload_ssh_data(data, data_path):
|
269 |
global sftp_client
|
270 |
global sftp_client_imgs
|