ZhangYuhan committed on
Commit
6dc2db5
1 Parent(s): a8e2ac2

update serve

Browse files
arena_elo/elo_rating/__init__.py ADDED
File without changes
arena_elo/elo_rating/basic_stats.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import code
3
+ import datetime
4
+ import json
5
+ import os
6
+ from pytz import timezone
7
+ import time
8
+
9
+ import pandas as pd # pandas>=2.0.3
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ from tqdm import tqdm
13
+
14
+ NUM_SERVERS = 1
15
+ LOG_ROOT_DIR = os.getenv("LOGDIR", None)
16
+ if LOG_ROOT_DIR is None:
17
+ raise ValueError("LOGDIR environment variable not set, please set it by `export LOGDIR=...`")
18
+
19
def get_log_files(max_num_files=None):
    """Collect conversation log files ("*-conv.json") under LOG_ROOT_DIR.

    Files are gathered from the root directly (single-server layout) or
    from ``server{i}`` subdirectories, sorted by modification time
    (oldest first), and optionally truncated to the newest entries.

    Args:
        max_num_files: keep only this many of the newest files; None keeps all.

    Returns:
        List of file paths, ordered oldest -> newest.
    """
    log_root = os.path.expanduser(LOG_ROOT_DIR)
    filenames = []
    if NUM_SERVERS == 1:
        for filename in os.listdir(log_root):
            if filename.endswith("-conv.json"):
                # Fixed: the path previously contained a literal
                # "(unknown)" placeholder instead of the file name.
                filepath = f"{log_root}/{filename}"
                name_tstamp_tuple = (filepath, os.path.getmtime(filepath))
                filenames.append(name_tstamp_tuple)
    else:
        for i in range(NUM_SERVERS):
            for filename in os.listdir(f"{log_root}/server{i}"):
                if filename.endswith("-conv.json"):
                    filepath = f"{log_root}/server{i}/{filename}"
                    name_tstamp_tuple = (filepath, os.path.getmtime(filepath))
                    filenames.append(name_tstamp_tuple)
    # sort by tstamp (oldest first)
    filenames = sorted(filenames, key=lambda x: x[1])
    filenames = [x[0] for x in filenames]

    max_num_files = max_num_files or len(filenames)
    filenames = filenames[-max_num_files:]
    return filenames
42
+
43
+
44
def load_log_files(filename):
    """Parse one log file into a list of lightweight record dicts.

    Each JSON line is reduced to the fields needed for basic stats:
    type, tstamp, model, and models.

    Args:
        filename: path to a "*-conv.json" log file.

    Returns:
        List of dicts with keys: type, tstamp, model, models.

    Raises:
        FileNotFoundError: if the file does not appear within 5 retries.
    """
    lines = None
    for retry in range(5):
        try:
            # Close the handle deterministically; the original left it to GC.
            with open(filename) as f:
                lines = f.readlines()
            break
        except FileNotFoundError:
            # The file may still be syncing to disk; wait and retry.
            time.sleep(2)
    if lines is None:
        # Previously this fell through and crashed with an unrelated
        # NameError on `lines`; raise the real cause instead.
        raise FileNotFoundError(filename)

    data = []
    for l in lines:
        row = json.loads(l)
        data.append(
            dict(
                type=row["type"],
                tstamp=row["tstamp"],
                model=row.get("model", ""),
                models=row.get("models", ["", ""]),
            )
        )
    return data
64
+
65
+
66
def load_log_files_parallel(log_files, num_threads=16):
    """Load many log files concurrently and concatenate their records.

    Args:
        log_files: paths to parse with ``load_log_files``.
        num_threads: size of the worker process pool.

    Returns:
        Flat list of record dicts from all files.
    """
    from multiprocessing import Pool

    with Pool(num_threads) as pool:
        per_file = list(tqdm(pool.imap(load_log_files, log_files), total=len(log_files)))
    merged = []
    for chunk in per_file:
        merged.extend(chunk)
    return merged
75
+
76
+
77
def get_anony_vote_df(df):
    """Return only the anonymous-vote rows of ``df``.

    A row counts as an anonymous vote when its type is one of the four
    vote events and the first public model name is empty (i.e. the
    battle was shown anonymized).
    """
    vote_types = ["leftvote", "rightvote", "tievote", "bothbad_vote"]
    votes = df[df["type"].isin(vote_types)]
    is_anony = votes["models"].apply(lambda pair: pair[0] == "")
    return votes[is_anony]
83
+
84
+
85
+ def merge_counts(series, on, names):
86
+ ret = pd.merge(series[0], series[1], on=on)
87
+ for i in range(2, len(series)):
88
+ ret = pd.merge(ret, series[i], on=on)
89
+ ret = ret.reset_index()
90
+ old_names = list(ret.columns)[-len(series) :]
91
+ rename = {old_name: new_name for old_name, new_name in zip(old_names, names)}
92
+ ret = ret.rename(columns=rename)
93
+ return ret
94
+
95
+
96
def report_basic_stats(log_files):
    """Aggregate usage statistics from the raw conversation logs.

    Produces daily chat/vote trend bars, per-model and per-action counts
    over all-time / last-day / last-hour windows, anonymous-vote counts,
    and an hourly chat breakdown of the last 24 hours.

    Args:
        log_files: log file paths, e.g. from ``get_log_files``.

    Returns:
        Dict with a plotly figure ("chat_dates_bar"), markdown tables
        ("model_hist_md", "action_hist_md", "anony_vote_hist_md",
        "num_chats_last_24_hours"), and "last_updated_datetime".
    """
    df_all = load_log_files_parallel(log_files)
    df_all = pd.DataFrame(df_all)
    now_t = df_all["tstamp"].max()
    df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
    df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)]
    anony_vote_df_all = get_anony_vote_df(df_all)

    # Chat trends
    chat_dates = [
        datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
            "%Y-%m-%d"
        )
        for x in df_all[df_all["type"] == "chat"]["tstamp"]
    ]
    # The top-level pd.value_counts(values) is deprecated since pandas 2.1
    # (removed in 3.0); the Series-method form is the supported equivalent.
    chat_dates_counts = pd.Series(chat_dates).value_counts()
    vote_dates = [
        datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
            "%Y-%m-%d"
        )
        for x in anony_vote_df_all["tstamp"]
    ]
    vote_dates_counts = pd.Series(vote_dates).value_counts()
    chat_dates_bar = go.Figure(
        data=[
            go.Bar(
                name="Anony. Vote",
                x=vote_dates_counts.index,
                y=vote_dates_counts,
                text=[f"{val:.0f}" for val in vote_dates_counts],
                textposition="auto",
            ),
            go.Bar(
                name="Chat",
                x=chat_dates_counts.index,
                y=chat_dates_counts,
                text=[f"{val:.0f}" for val in chat_dates_counts],
                textposition="auto",
            ),
        ]
    )
    chat_dates_bar.update_layout(
        barmode="stack",
        xaxis_title="Dates",
        yaxis_title="Count",
        height=300,
        width=1200,
    )

    # Model call counts
    model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts()
    model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts()
    model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts()
    model_hist = merge_counts(
        [model_hist_all, model_hist_1_day, model_hist_1_hour],
        on="model",
        names=["All", "Last Day", "Last Hour"],
    )
    model_hist_md = model_hist.to_markdown(index=False, tablefmt="github")

    # Action counts
    action_hist_all = df_all["type"].value_counts()
    action_hist_1_day = df_1_day["type"].value_counts()
    action_hist_1_hour = df_1_hour["type"].value_counts()
    action_hist = merge_counts(
        [action_hist_all, action_hist_1_day, action_hist_1_hour],
        on="type",
        names=["All", "Last Day", "Last Hour"],
    )
    action_hist_md = action_hist.to_markdown(index=False, tablefmt="github")

    # Anony vote counts
    anony_vote_hist_all = anony_vote_df_all["type"].value_counts()
    anony_vote_df_1_day = get_anony_vote_df(df_1_day)
    anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts()
    # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour)
    # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts()
    anony_vote_hist = merge_counts(
        [anony_vote_hist_all, anony_vote_hist_1_day],
        on="type",
        names=["All", "Last Day"],
    )
    anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github")

    # Last 24 hours: bucket chats into hourly bins starting from the oldest
    # record in the one-day window.
    chat_1_day = df_1_day[df_1_day["type"] == "chat"]
    num_chats_last_24_hours = []
    base = df_1_day["tstamp"].min()
    for i in range(24, 0, -1):
        left = base + (i - 1) * 3600
        right = base + i * 3600
        num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum()
        num_chats_last_24_hours.append(num)
    times = [
        datetime.datetime.fromtimestamp(
            base + i * 3600, tz=timezone("US/Pacific")
        ).strftime("%Y-%m-%d %H:%M:%S %Z")
        for i in range(24, 0, -1)
    ]
    last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours})
    last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github")

    # Last update datetime
    last_updated_tstamp = now_t
    last_updated_datetime = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y-%m-%d %H:%M:%S %Z")

    # code.interact(local=locals())

    return {
        "chat_dates_bar": chat_dates_bar,
        "model_hist_md": model_hist_md,
        "action_hist_md": action_hist_md,
        "anony_vote_hist_md": anony_vote_hist_md,
        "num_chats_last_24_hours": last_24_hours_md,
        "last_updated_datetime": last_updated_datetime,
    }
214
+
215
+
216
if __name__ == "__main__":
    # CLI entry point: compute basic usage stats over the newest log files
    # and print the markdown tables (the plotly figure is not rendered here).
    parser = argparse.ArgumentParser()
    # Process only the newest N log files when set.
    parser.add_argument("--max-num-files", type=int)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    basic_stats = report_basic_stats(log_files)

    print(basic_stats["action_hist_md"] + "\n")
    print(basic_stats["model_hist_md"] + "\n")
    print(basic_stats["anony_vote_hist_md"] + "\n")
    print(basic_stats["num_chats_last_24_hours"] + "\n")
arena_elo/elo_rating/clean_battle_data.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Clean chatbot arena battle log.
3
+
4
+ Usage:
5
+ python3 clean_battle_data.py --mode conv_release
6
+ """
7
+ import argparse
8
+ import datetime
9
+ import json
10
+ import os
11
+ import sys
12
+ from pytz import timezone
13
+ import time
14
+ import PIL
15
+ from PIL import ImageFile
16
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
17
+
18
+ from tqdm import tqdm
19
+
20
+ from .basic_stats import get_log_files, NUM_SERVERS, LOG_ROOT_DIR
21
+ from .utils import detect_language, get_time_stamp_from_date
22
+
23
# Vote event types that constitute a battle outcome.
VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]

# Phrases that would reveal a model's identity inside a transcript.
# Compared case-insensitively (lowercased below).
IDENTITY_WORDS = [
    "vicuna",
    "lmsys",
    "koala",
    "uc berkeley",
    "open assistant",
    "laion",
    "chatglm",
    "chatgpt",
    "gpt-4",
    "openai",
    "anthropic",
    "claude",
    "bard",
    "palm",
    "lamda",
    "google",
    "llama",
    "qianwan",
    "alibaba",
    "mistral",
    "zhipu",
    "KEG lab",
    "01.AI",
    "AI2",
    "Tülu",
    "Tulu",
    "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
    "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
    "API REQUEST ERROR. Please increase the number of max tokens.",
    "**API REQUEST ERROR** Reason: The response was blocked.",
    "**API REQUEST ERROR**",
]

# Normalize once so membership checks can run on lowercased text.
IDENTITY_WORDS = [word.lower() for word in IDENTITY_WORDS]
60
+
61
+
62
def remove_html(raw):
    """Strip presentation markup from a displayed model-name string.

    Handles two historical display formats: an ``<h3>...: name</h3>``
    heading and a ``### Model A/B: name`` markdown prefix. Anything
    else is returned unchanged.
    """
    if raw.startswith("<h3>"):
        start = raw.find(": ") + 2
        return raw[start:-len("</h3>\n")]
    if raw.startswith(("### Model A: ", "### Model B: ")):
        return raw[len("### Model A: "):]
    return raw
68
+
69
+
70
def to_openai_format(messages):
    """Convert ``[(role_label, text), ...]`` turns into OpenAI chat format.

    Roles alternate user/assistant starting with user; only the text of
    each turn (index 1) is kept.
    """
    roles = ("user", "assistant")
    return [
        {"role": roles[idx % 2], "content": turn[1]}
        for idx, turn in enumerate(messages)
    ]
76
+
77
+
78
def replace_model_name(old_name, tstamp):
    """Map legacy/internal model names to canonical public names.

    gpt-4 and gpt-3.5-turbo are pinned to a dated snapshot depending on
    when the battle happened (timestamps after 1687849200 get -0613,
    earlier ones -0314); a fixed rename table covers the rest.
    """
    renames = {
        "bard": "palm-2",
        "claude-v1": "claude-1",
        "claude-instant-v1": "claude-instant-1",
        "oasst-sft-1-pythia-12b": "oasst-pythia-12b",
        "claude-2": "claude-2.0",
        "PlayGroundV2": "Playground v2",
    }
    if old_name in ("gpt-4", "gpt-3.5-turbo"):
        suffix = "-0613" if tstamp > 1687849200 else "-0314"
        return old_name + suffix
    return renames.get(old_name, old_name)
95
+
96
+
97
def read_file(filename):
    """Read one log file, keeping only the vote-type rows (see VOTES).

    Retries up to 5 times (2 s apart) while the file is not yet present;
    if it never appears, whatever was collected (possibly nothing) is
    returned.
    """
    rows = []
    for _ in range(5):
        try:
            for line in open(filename):
                record = json.loads(line)
                if record["type"] in VOTES:
                    rows.append(record)
            break
        except FileNotFoundError:
            time.sleep(2)
    return rows
110
+
111
+
112
def read_file_parallel(log_files, num_threads=16):
    """Read many log files concurrently and concatenate their vote rows.

    Args:
        log_files: paths parsed with ``read_file``.
        num_threads: size of the worker process pool.

    Returns:
        Flat list of vote rows from all files.
    """
    from multiprocessing import Pool

    with Pool(num_threads) as pool:
        per_file = list(tqdm(pool.imap(read_file, log_files), total=len(log_files)))
    merged = []
    for chunk in per_file:
        merged.extend(chunk)
    return merged
121
+
122
def load_image(image_path):
    """Open an image from disk, returning None when it cannot be read.

    Args:
        image_path: path to an image file.

    Returns:
        A PIL Image, or None when the file is missing or unreadable.
    """
    try:
        return PIL.Image.open(image_path)
    except (OSError, ValueError):
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # programming errors. PIL.UnidentifiedImageError is an OSError
        # subclass, so missing and corrupt files are both covered.
        return None
127
+
128
def clean_battle_data(
    log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False, mode="simple", task_name="t2s"
):
    """Filter raw vote logs into a clean list of battle records.

    Per vote row: resolve public vs. hidden model names, decide whether
    the battle was anonymous, validate/normalize the task-specific model
    names, apply model/IP exclusions, and (in "conv_release" mode) verify
    that both sides received the same input image.

    Args:
        log_files: paths of raw "*-conv.json" logs.
        exclude_model_names: model names whose battles are dropped.
        ban_ip_list: IPs whose votes are discarded (optional).
        sanitize_ip: replace raw IPs with stable numeric pseudonyms.
        mode: "simple" or "conv_release" (adds image consistency checks).
        task_name: "image_editing" or "t2i_generation"; any other value
            (including the default "t2s") raises ValueError below.

    Returns:
        List of battle dicts sorted ascending by tstamp.

    Raises:
        ValueError: for an unrecognized task_name.
    """
    data = read_file_parallel(log_files, num_threads=16)

    # Map raw vote types to winner labels.
    convert_type = {
        "leftvote": "model_a",
        "rightvote": "model_b",
        "tievote": "tie",
        "bothbad_vote": "tie (bothbad)",
    }

    all_models = set()
    all_ips = dict()
    ct_anony = 0
    ct_invalid = 0
    ct_leaked_identity = 0
    ct_banned = 0
    battles = []
    for row in tqdm(data, desc="Cleaning"):
        if row["models"][0] is None or row["models"][1] is None:
            continue

        # Resolve model names
        models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
        if "model_name" in row["states"][0]:
            models_hidden = [
                row["states"][0]["model_name"],
                row["states"][1]["model_name"],
            ]
            if models_hidden[0] is None:
                models_hidden = models_public
        else:
            models_hidden = models_public

        # Exactly one side anonymized is inconsistent -> drop.
        if (models_public[0] == "" and models_public[1] != "") or (
            models_public[1] == "" and models_public[0] != ""
        ):
            ct_invalid += 1
            continue

        if models_public[0] == "" or models_public[0] == "Model A":
            anony = True
            models = models_hidden
            ct_anony += 1
        else:
            anony = False
            models = models_public
            # Named battles must agree with the hidden names.
            if not models_public == models_hidden:
                ct_invalid += 1
                continue

        # # Detect langauge
        # state = row["states"][0]
        # if state["offset"] >= len(state["messages"]):
        #     ct_invalid += 1
        #     continue
        # lang_code = detect_language(state["messages"][state["offset"]][1])

        # # Drop conversations if the model names are leaked
        # leaked_identity = False
        # messages = ""
        # for i in range(2):
        #     state = row["states"][i]
        #     for turn_idx, (role, msg) in enumerate(
        #         state["messages"][state["offset"] :]
        #     ):
        #         if msg:
        #             messages += msg.lower()
        # for word in IDENTITY_WORDS:
        #     if word in messages:
        #         leaked_identity = True
        #         break

        # if leaked_identity:
        #     ct_leaked_identity += 1
        #     continue

        # Validate and strip the task-specific "imagenhub_*" name wrappers.
        if task_name == "image_editing":
            if not all(x.startswith("imagenhub_") and x.endswith("_edition") for x in models):
                # print(f"Invalid model names: {models}")
                ct_invalid += 1
                continue
            models = [x[len("imagenhub_"):-len("_edition")] for x in models]
        elif task_name == "t2i_generation":
            if not all("playground" in x.lower() or (x.startswith("imagenhub_") and x.endswith("_generation")) for x in models):
                # print(f"Invalid model names: {models}")
                ct_invalid += 1
                continue
            # models = [x[len("imagenhub_"):-len("_generation")] for x in models]
            for i, model_name in enumerate(models):
                if model_name.startswith("imagenhub_"):
                    models[i] = model_name[len("imagenhub_"):-len("_generation")]

        else:
            raise ValueError(f"Invalid task_name: {task_name}")
        models = [replace_model_name(m, row["tstamp"]) for m in models]

        # Exclude certain models
        if exclude_model_names and any(x in exclude_model_names for x in models):
            ct_invalid += 1
            continue

        # if models[0] not in model_infos or models[1] not in model_infos:
        #     continue

        # # Exclude votes before the starting date
        # if model_infos and (model_infos[models[0]]["starting_from"] > row["tstamp"] or model_infos[models[1]]["starting_from"] > row["tstamp"]):
        #     print(f"Invalid vote before the valid starting date for {models[0]} and {models[1]}")
        #     ct_invalid += 1
        #     continue

        if mode == "conv_release":
            # assert the two images are the same
            date = datetime.datetime.fromtimestamp(row["tstamp"], tz=timezone("US/Pacific")).strftime("%Y-%m-%d")  # 2024-02-29
            image_path_format = f"{LOG_ROOT_DIR}/{date}-convinput_images/input_image_"
            image_path_0 = image_path_format + str(row["states"][0]["conv_id"]) + ".png"
            image_path_1 = image_path_format + str(row["states"][1]["conv_id"]) + ".png"
            if not os.path.exists(image_path_0) or not os.path.exists(image_path_1):
                print(f"Image not found for {image_path_0} or {image_path_1}")
                ct_invalid += 1
                continue

            image_0 = load_image(image_path_0)
            image_1 = load_image(image_path_1)
            if image_0 is None or image_1 is None:
                print(f"Image not found for {image_path_0} or {image_path_1}")
                ct_invalid += 1
                continue
            if image_0.tobytes() != image_1.tobytes():
                print(f"Image not the same for {image_path_0} and {image_path_1}")
                ct_invalid += 1
                continue

        question_id = row["states"][0]["conv_id"]
        # conversation_a = to_openai_format(
        #     row["states"][0]["messages"][row["states"][0]["offset"] :]
        # )
        # conversation_b = to_openai_format(
        #     row["states"][1]["messages"][row["states"][1]["offset"] :]
        # )

        ip = row["ip"]
        if ip not in all_ips:
            all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)}
        all_ips[ip]["count"] += 1
        if sanitize_ip:
            user_id = f"arena_user_{all_ips[ip]['sanitized_id']}"
        else:
            user_id = f"{all_ips[ip]['ip']}"

        # Banned IPs still increment their count above but produce no battle.
        if ban_ip_list is not None and ip in ban_ip_list:
            ct_banned += 1
            continue

        # Save the results
        battles.append(
            dict(
                question_id=question_id,
                model_a=models[0],
                model_b=models[1],
                winner=convert_type[row["type"]],
                judge=f"arena_user_{user_id}",
                # conversation_a=conversation_a,
                # conversation_b=conversation_b,
                # turn=len(conversation_a) // 2,
                anony=anony,
                # language=lang_code,
                tstamp=row["tstamp"],
            )
        )

        all_models.update(models_hidden)
    battles.sort(key=lambda x: x["tstamp"])
    # NOTE(review): battles[-1] raises IndexError when every row was
    # filtered out — confirm empty input is impossible upstream.
    last_updated_tstamp = battles[-1]["tstamp"]

    last_updated_datetime = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y-%m-%d %H:%M:%S %Z")

    print(
        f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
        f"#leaked_identity: {ct_leaked_identity} "
        f"#banned: {ct_banned} "
    )
    print(f"#battles: {len(battles)}, #anony: {ct_anony}")
    print(f"#models: {len(all_models)}, {all_models}")
    print(f"last-updated: {last_updated_datetime}")

    if ban_ip_list is not None:
        for ban_ip in ban_ip_list:
            if ban_ip in all_ips:
                del all_ips[ban_ip]
    print("Top 30 IPs:")
    print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30])
    return battles
328
+
329
+
330
if __name__ == "__main__":
    # CLI entry point: clean the raw vote logs and write a dated JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    parser.add_argument(
        "--mode", type=str, choices=["simple", "conv_release"], default="simple"
    )
    # NOTE(review): clean_battle_data only accepts task_name values
    # "image_editing" / "t2i_generation" and raises ValueError otherwise,
    # so these choices look stale — confirm the intended set.
    parser.add_argument("--task_name", type=str, choices=["t2s", "i2s"])
    parser.add_argument("--exclude-model-names", type=str, nargs="+")
    parser.add_argument("--ban-ip-file", type=str)
    parser.add_argument("--sanitize-ip", action="store_true", default=False)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None

    battles = clean_battle_data(
        log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip, args.mode, args.task_name
    )
    # Name the output after the newest battle's (Pacific-time) date.
    last_updated_tstamp = battles[-1]["tstamp"]
    cutoff_date = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y%m%d")

    if args.mode == "simple":
        for x in battles:
            # Drop heavyweight / identifying fields for the simple release.
            for key in [
                "conversation_a",
                "conversation_b",
                "question_id",
            ]:
                if key in x:
                    del x[key]
        print("Samples:")
        for i in range(min(4, len(battles))):
            print(battles[i])
        output = f"clean_battle_{args.task_name}_{cutoff_date}.json"
    elif args.mode == "conv_release":
        # new_battles = []
        # for x in battles:
        #     if not x["anony"]:
        #         continue
        #     for key in []:
        #         del x[key]
        #     new_battles.append(x)
        # battles = new_battles
        output = f"clean_battle_{args.task_name}_conv_{cutoff_date}.json"

    with open(output, "w") as fout:
        json.dump(battles, fout, indent=2, ensure_ascii=False)
    print(f"Write cleaned data to {output}")

    # Record the cutoff so downstream jobs can locate the matching file.
    with open("cut_off_date.txt", "w") as fout:
        fout.write(cutoff_date)
arena_elo/elo_rating/elo_analysis.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from collections import defaultdict
3
+ import datetime
4
+ import json
5
+ import math
6
+ import pickle
7
+ from pytz import timezone
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import plotly.express as px
12
+ from tqdm import tqdm
13
+
14
+ from .model_registry import get_model_info
15
+ from .basic_stats import get_log_files
16
+ from .clean_battle_data import clean_battle_data
17
+
18
+ pd.options.display.float_format = "{:.2f}".format
19
+
20
+
21
def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
    """Run a single online-Elo pass over the battles in order.

    Args:
        battles: DataFrame with model_a, model_b, winner columns.
        K: per-battle update step size.
        SCALE, BASE: logistic-curve parameters of the expected score.
        INIT_RATING: rating assigned to models on first appearance.

    Returns:
        Dict mapping model name -> final rating.

    Raises:
        Exception: for a winner value outside the four known labels.
    """
    rating = defaultdict(lambda: INIT_RATING)

    for _, model_a, model_b, winner in battles[
        ["model_a", "model_b", "winner"]
    ].itertuples():
        ra, rb = rating[model_a], rating[model_b]
        # Expected scores under the logistic model.
        expected_a = 1 / (1 + BASE ** ((rb - ra) / SCALE))
        expected_b = 1 / (1 + BASE ** ((ra - rb) / SCALE))
        if winner == "model_a":
            score_a = 1
        elif winner == "model_b":
            score_a = 0
        elif winner in ("tie", "tie (bothbad)"):
            score_a = 0.5
        else:
            raise Exception(f"unexpected vote {winner}")
        rating[model_a] += K * (score_a - expected_a)
        rating[model_b] += K * (1 - score_a - expected_b)

    return dict(rating)
43
+
44
+
45
def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
    """Bootstrap ratings by resampling battles with replacement.

    Args:
        battles: battles DataFrame.
        func_compute_elo: rating function applied to each resample.
        num_round: number of bootstrap rounds.

    Returns:
        DataFrame (one row per round), columns ordered by descending
        median rating.
    """
    samples = []
    for _ in tqdm(range(num_round), desc="bootstrap"):
        resampled = battles.sample(frac=1.0, replace=True)
        samples.append(func_compute_elo(resampled))
    df = pd.DataFrame(samples)
    order = df.median().sort_values(ascending=False).index
    return df[order]
52
+
53
+
54
def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000):
    """Fit ratings by maximum likelihood (Bradley-Terry via logistic regression).

    Each battle becomes a design row carrying +/- log(BASE) in the
    model_a/model_b columns; ties are handled by duplicating the data and
    crediting half of the duplicated ties to each side.

    Args:
        df: battles DataFrame with model_a, model_b, winner columns.
        SCALE, BASE, INIT_RATING: Elo-style scaling for the fitted coefficients.

    Returns:
        pd.Series of ratings indexed by model name, sorted descending.
    """
    from sklearn.linear_model import LogisticRegression

    models = pd.concat([df["model_a"], df["model_b"]]).unique()
    models = pd.Series(np.arange(len(models)), index=models)

    # duplicate battles
    df = pd.concat([df, df], ignore_index=True)
    p = len(models.index)
    n = df.shape[0]

    X = np.zeros([n, p])
    X[np.arange(n), models[df["model_a"]]] = +math.log(BASE)
    X[np.arange(n), models[df["model_b"]]] = -math.log(BASE)

    # one A win => two A win
    Y = np.zeros(n)
    Y[df["winner"] == "model_a"] = 1.0

    # one tie => one A win + one B win
    # find tie + tie (both bad) index
    tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)")
    # Only the first copy of each tie is labeled an A-win; the second copy
    # stays an implicit B-win. Relies on the concat order above.
    tie_idx[len(tie_idx) // 2 :] = False
    Y[tie_idx] = 1.0

    lr = LogisticRegression(fit_intercept=False)
    lr.fit(X, Y)

    elo_scores = SCALE * lr.coef_[0] + INIT_RATING
    # calibrate llama-13b to 800 if applicable
    if "llama-13b" in models.index:
        elo_scores += 800 - elo_scores[models["llama-13b"]]
    return pd.Series(elo_scores, index=models.index).sort_values(ascending=False)
87
+
88
+
89
def get_median_elo_from_bootstrap(bootstrap_df):
    """Per-model median rating over bootstrap rounds, rounded to int."""
    medians = bootstrap_df.quantile(0.5)
    return {model: int(value + 0.5) for model, value in dict(medians).items()}
93
+
94
+
95
def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None):
    """Fraction of battles each row-model wins against each column-model.

    Args:
        battles: battles DataFrame (ties should already be excluded).
        model_order: display order; None orders by average win rate.
        limit_show_number: truncate the ordering to the first N models.

    Returns:
        DataFrame where cell (a, b) is the fraction of a-vs-b battles won
        by a, regardless of which side a was shown on.
    """
    # Times each model wins as Model A
    a_win_ptbl = pd.pivot_table(
        battles[battles["winner"] == "model_a"],
        index="model_a",
        columns="model_b",
        aggfunc="size",
        fill_value=0,
    )

    # Table counting times each model wins as Model B
    b_win_ptbl = pd.pivot_table(
        battles[battles["winner"] == "model_b"],
        index="model_a",
        columns="model_b",
        aggfunc="size",
        fill_value=0,
    )

    # Table counting number of A-B pairs
    num_battles_ptbl = pd.pivot_table(
        battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
    )

    # Computing the proportion of wins for each model as A and as B
    # against all other models
    row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / (
        num_battles_ptbl + num_battles_ptbl.T
    )

    if model_order is None:
        prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False)
        model_order = list(prop_wins.keys())

    if limit_show_number is not None:
        model_order = model_order[:limit_show_number]

    # Arrange ordering according to proportion of wins
    row_beats_col = row_beats_col_freq.loc[model_order, model_order]
    return row_beats_col
135
+
136
+
137
def visualize_leaderboard_table(rating):
    """Render a rating dict as a markdown leaderboard table.

    Rows are sorted by descending rating; the top three get medal emoji.
    Link and description come from the model registry.
    """
    ranked = sorted(rating, key=rating.get, reverse=True)

    medals = {
        1: "🥇",
        2: "🥈",
        3: "🥉",
    }

    md = ""
    md += "| Rank | Model | Elo Rating | Description |\n"
    md += "| --- | --- | --- | --- |\n"
    for idx, model in enumerate(ranked):
        rank = idx + 1
        minfo = get_model_info(model)
        emoji = medals.get(rank, "")
        md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n"

    return md
157
+
158
+
159
def visualize_pairwise_win_fraction(battles, model_order):
    """Heatmap of the fraction of battles model A (row) wins vs model B (column).

    Args:
        battles: battles DataFrame (caller passes the no-ties subset).
        model_order: row/column ordering; None lets the helper sort by win rate.

    Returns:
        plotly Figure.
    """
    row_beats_col = compute_pairwise_win_fraction(battles, model_order)
    fig = px.imshow(
        row_beats_col,
        color_continuous_scale="RdBu",
        text_auto=".2f",
        height=700,
        width=700,
    )
    fig.update_layout(
        xaxis_title="Model B",
        yaxis_title="Model A",
        xaxis_side="top",
        title_y=0.07,
        title_x=0.5,
    )
    fig.update_traces(
        hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Fraction of A Wins: %{z}<extra></extra>"
    )

    return fig
180
+
181
+
182
def visualize_battle_count(battles, model_order):
    """Heatmap of head-to-head battle counts for the given model ordering.

    Returns:
        plotly Figure.
    """
    ptbl = pd.pivot_table(
        battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
    )
    # Symmetrize: count battles regardless of which side a model was on.
    battle_counts = ptbl + ptbl.T
    fig = px.imshow(
        battle_counts.loc[model_order, model_order],
        text_auto=True,
        height=700,
        width=700,
    )
    fig.update_layout(
        xaxis_title="Model B",
        yaxis_title="Model A",
        xaxis_side="top",
        title_y=0.07,
        title_x=0.5,
    )
    fig.update_traces(
        hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Count: %{z}<extra></extra>"
    )
    return fig
204
+
205
+
206
def visualize_average_win_rate(battles, limit_show_number):
    """Bar chart of each model's average win rate across all opponents.

    Returns:
        plotly Figure.
    """
    row_beats_col_freq = compute_pairwise_win_fraction(
        battles, None, limit_show_number=limit_show_number
    )
    fig = px.bar(
        row_beats_col_freq.mean(axis=1).sort_values(ascending=False),
        text_auto=".2f",
        height=500,
        width=700,
    )
    fig.update_layout(
        yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False
    )
    return fig
220
+
221
+
222
def visualize_bootstrap_elo_rating(df, df_final, limit_show_number):
    """Scatter of final ratings with 95% bootstrap confidence intervals.

    Args:
        df: bootstrap DataFrame (rounds x models).
        df_final: final per-model ratings used as the point estimates.
        limit_show_number: show only the top-N models by rating.

    Returns:
        plotly Figure.
    """
    bars = (
        pd.DataFrame(
            dict(
                lower=df.quantile(0.025),
                rating=df_final,
                upper=df.quantile(0.975),
            )
        )
        .reset_index(names="model")
        .sort_values("rating", ascending=False)
    )
    bars = bars[:limit_show_number]
    # Asymmetric error bars relative to the point estimate.
    bars["error_y"] = bars["upper"] - bars["rating"]
    bars["error_y_minus"] = bars["rating"] - bars["lower"]
    bars["rating_rounded"] = np.round(bars["rating"], 2)
    fig = px.scatter(
        bars,
        x="model",
        y="rating",
        error_y="error_y",
        error_y_minus="error_y_minus",
        text="rating_rounded",
        height=500,
        width=700,
    )
    fig.update_layout(xaxis_title="Model", yaxis_title="Rating")
    return fig
250
+
251
+
252
def report_elo_analysis_results(battles_json, rating_system="bt", num_bootstrap=100, anony_only=True):
    """Compute ratings and leaderboard artifacts from cleaned battles.

    Args:
        battles_json: cleaned battle records (clean_battle_data output).
        rating_system: "bt" (MLE Bradley-Terry) or "elo" (bootstrap median).
        num_bootstrap: bootstrap rounds for the confidence intervals.
        anony_only: keep only anonymous battles.

    Returns:
        Dict of ratings, plotly figures, markdown leaderboard, the raw
        bootstrap DataFrame, and last-update metadata.
    """
    battles = pd.DataFrame(battles_json)
    battles = battles.sort_values(ascending=True, by=["tstamp"])
    # Only use anonymous votes
    if anony_only:
        battles = battles[battles["anony"]].reset_index(drop=True)
    battles_no_ties = battles[~battles["winner"].str.contains("tie")]

    # Online update
    elo_rating_online = compute_elo(battles)

    # NOTE(review): any rating_system other than "bt"/"elo" falls through
    # both branches and crashes below with UnboundLocalError on
    # bootstrap_df — consider raising ValueError explicitly.
    if rating_system == "bt":
        bootstrap_df = get_bootstrap_result(
            battles, compute_elo_mle_with_tie, num_round=num_bootstrap
        )
        elo_rating_final = compute_elo_mle_with_tie(battles)
    elif rating_system == "elo":
        bootstrap_df = get_bootstrap_result(
            battles, compute_elo, num_round=num_bootstrap
        )
        elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
        elo_rating_final = elo_rating_median

    model_order = list(elo_rating_final.keys())
    model_order.sort(key=lambda k: -elo_rating_final[k])

    limit_show_number = 25  # limit show number to make plots smaller
    model_order = model_order[:limit_show_number]

    # leaderboard_table_df: elo rating, variance, 95% interval, number of battles
    leaderboard_table_df = pd.DataFrame(
        {
            "rating": elo_rating_final,
            "variance": bootstrap_df.var(),
            "rating_q975": bootstrap_df.quantile(0.975),
            "rating_q025": bootstrap_df.quantile(0.025),
            "num_battles": battles["model_a"].value_counts()
            + battles["model_b"].value_counts(),
        }
    )

    # Plots
    leaderboard_table = visualize_leaderboard_table(elo_rating_final)
    win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order)
    battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order)
    average_win_rate_bar = visualize_average_win_rate(
        battles_no_ties, limit_show_number
    )
    bootstrap_elo_rating = visualize_bootstrap_elo_rating(
        bootstrap_df, elo_rating_final, limit_show_number
    )

    last_updated_tstamp = battles["tstamp"].max()
    last_updated_datetime = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y-%m-%d %H:%M:%S %Z")

    return {
        "rating_system": rating_system,
        "elo_rating_online": elo_rating_online,
        "elo_rating_final": elo_rating_final,
        "leaderboard_table": leaderboard_table,
        "win_fraction_heatmap": win_fraction_heatmap,
        "battle_count_heatmap": battle_count_heatmap,
        "average_win_rate_bar": average_win_rate_bar,
        "bootstrap_elo_rating": bootstrap_elo_rating,
        "last_updated_datetime": last_updated_datetime,
        "last_updated_tstamp": last_updated_tstamp,
        "bootstrap_df": bootstrap_df,
        "leaderboard_table_df": leaderboard_table_df,
    }
323
+
324
+
325
def pretty_print_elo_rating(rating):
    """Print models one per line as 'rank, name, rating', best first."""
    ranked = sorted(rating, key=rating.get, reverse=True)
    for rank, model in enumerate(ranked, start=1):
        print(f"{rank:2d}, {model:25s}, {rating[model]:.0f}")
330
+
331
+
332
if __name__ == "__main__":
    # CLI entry point: compute anony-only and full Elo analyses and pickle both.
    parser = argparse.ArgumentParser()
    parser.add_argument("--clean-battle-file", type=str)
    parser.add_argument("--max-num-files", type=int)
    parser.add_argument("--num-bootstrap", type=int, default=100)
    parser.add_argument(
        "--rating-system", type=str, choices=["bt", "elo"], default="bt"
    )
    parser.add_argument("--exclude-tie", action="store_true", default=False)
    args = parser.parse_args()

    # Make bootstrap resampling reproducible.
    np.random.seed(42)

    if args.clean_battle_file:
        # Read data from a cleaned battle files
        battles = pd.read_json(args.clean_battle_file)
    else:
        # Read data from all log files
        # NOTE(review): clean_battle_data requires a second positional
        # argument (exclude_model_names); this call would raise TypeError.
        log_files = get_log_files(args.max_num_files)
        battles = clean_battle_data(log_files)

    anony_results = report_elo_analysis_results(
        battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap, anony_only=True
    )
    full_results = report_elo_analysis_results(
        battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap, anony_only=False
    )


    print("# Online Elo")
    pretty_print_elo_rating(anony_results["elo_rating_online"])
    print("# Median")
    pretty_print_elo_rating(anony_results["elo_rating_final"])
    print(f"last update : {anony_results['last_updated_datetime']}")

    last_updated_tstamp = full_results["last_updated_tstamp"]
    cutoff_date = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y%m%d")


    # Persist both result sets for downstream leaderboard generation.
    results = {
        "anony": anony_results,
        "full": full_results,
    }
    with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout:
        pickle.dump(results, fout)
arena_elo/elo_rating/generate_leaderboard.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fire
2
+ import json
3
+ import pandas as pd
4
+ import pickle
5
+
6
+
7
+ def main(
8
+ model_info_file: str,
9
+ elo_rating_pkl: str,
10
+ output_csv: str
11
+ ):
12
+ model_info = json.load(open(model_info_file))
13
+
14
+ with open(elo_rating_pkl, "rb") as fin:
15
+ elo_rating_results = pickle.load(fin)
16
+
17
+ anony_elo_rating_results = elo_rating_results["anony"]
18
+ full_elo_rating_results = elo_rating_results["full"]
19
+ anony_leaderboard_data = anony_elo_rating_results["leaderboard_table_df"]
20
+ full_leaderboard_data = full_elo_rating_results["leaderboard_table_df"]
21
+
22
+ # Model,MT-bench (score),Arena Elo rating,MMLU,License,Link
23
+ fields = ["key", "Model", "Arena Elo rating (anony)", "Arena Elo rating (full)", "License", "Organization", "Link"]
24
+ # set Organization and license to empty for now
25
+ all_models = anony_leaderboard_data.index.tolist()
26
+
27
+ for model in all_models:
28
+ if not model in model_info:
29
+ model_info[model] = {}
30
+ model_info[model]["License"] = "N/A"
31
+ model_info[model]["Organization"] = "N/A"
32
+ model_info[model]["Link"] = "N/A"
33
+ model_info[model]["Model"] = model
34
+ model_info[model]["key"] = model
35
+
36
+ if model in anony_leaderboard_data.index:
37
+ model_info[model]["Arena Elo rating (anony)"] = anony_leaderboard_data.loc[model, "rating"]
38
+ else:
39
+ model_info[model]["Arena Elo rating (anony)"] = 0
40
+
41
+ if model in full_elo_rating_results["leaderboard_table_df"].index:
42
+ model_info[model]["Arena Elo rating (full)"] = full_leaderboard_data.loc[model, "rating"]
43
+ else:
44
+ model_info[model]["Arena Elo rating (full)"] = 0
45
+ # if model in anony_leaderboard_data.index:
46
+ # model_info[model]["Arena Elo rating"] = anony_leaderboard_data.loc[model, "rating"]
47
+ # else:
48
+ # model_info[model]["Arena Elo rating"] = 0
49
+
50
+ final_model_info = {}
51
+ for model in model_info:
52
+ if "Model" in model_info[model]:
53
+ final_model_info[model] = model_info[model]
54
+ model_info = final_model_info
55
+
56
+ exclude_keys = ['starting_from']
57
+ for key in exclude_keys:
58
+ for model in model_info:
59
+ if key in model_info[model]:
60
+ del model_info[model][key]
61
+ df = pd.DataFrame(model_info).T
62
+ df = df[fields]
63
+ # sort by anony rating
64
+ df = df.sort_values(by=["Arena Elo rating (anony)"], ascending=False)
65
+ df.to_csv(output_csv, index=False)
66
+ print("Leaderboard data saved to", output_csv)
67
+ print(df)
68
+
69
+
70
if __name__ == "__main__":
    # Expose `main` as a CLI via python-fire (positional/flag args map 1:1).
    fire.Fire(main)
arena_elo/elo_rating/inspect_conv_rating.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import code
3
+ import datetime
4
+ import json
5
+ import os
6
+ from pytz import timezone
7
+ import time
8
+
9
+ import pandas as pd
10
+ from tqdm import tqdm
11
+ import csv
12
+
13
+ import base64
14
+ from icecream import ic
15
+ from openai import OpenAI
16
+
17
def encode_image(image_path):
    """Read the file at *image_path* and return its contents base64-encoded as str."""
    with open(image_path, "rb") as image_file:
        raw = image_file.read()
    return base64.b64encode(raw).decode('utf-8')
21
+
22
def get_log_files(max_num_files=None):
    """Collect existing per-day conversation logs under ``vision-arena-logs/``.

    Scans Feb/Mar 2024 date stamps and keeps only files that exist; when
    *max_num_files* is given, returns only that many of the newest dates.
    """
    candidate_dates = [
        f"2024-{month:02d}-{day:02d}" for month in (2, 3) for day in range(1, 32)
    ]
    filenames = []
    for date_str in candidate_dates:
        path = os.path.expanduser(f"vision-arena-logs/{date_str}-conv.json")
        if os.path.exists(path):
            filenames.append(path)
    limit = max_num_files or len(filenames)
    return filenames[-limit:]
39
+
40
+
41
def pretty_print_conversation(messages):
    """Print each (role, message) pair of *messages* on its own line."""
    for entry in messages:
        speaker, text = entry
        print(f"[[{speaker}]]: {text}")
44
+
45
+
46
def get_gpt4v_response(client, img_bs64=None, text_prompt="", use_vision=False):
    """Send one user turn to gpt-4-vision-preview and return the reply text.

    Args:
        client: an OpenAI client instance.
        img_bs64: base64-encoded JPEG payload; only used when use_vision=True.
        text_prompt: the textual part of the user message.
        use_vision: when True, attach the image as an image_url content part.

    Returns:
        The content string of the first completion choice (capped at 100 tokens).

    NOTE: performs a network call; exceptions from the API propagate to the caller.
    """
    if use_vision:
        # Multimodal request: text part + inline data-URL image part.
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{img_bs64}"
                            }
                        },
                    ],
                }
            ],
            max_tokens=100,
        )
    else:
        # Text-only request against the same model.
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": text_prompt},
                    ],
                }
            ],
            max_tokens=100,
        )
    return response.choices[0].message.content
80
+
81
# Prompt templates keyed by evaluation task. "pair_rate" is the one used by
# inspect_convs: it asks the judge to reply with exactly one of the arena vote
# labels (leftvote/rightvote/bothbad_vote/tievote) or "NA".
# Do not edit the template strings — they are the runtime prompts.
task_template_map = {
    "image_caption": "Give me the semantic alignment score between the given image and the given caption: \"{generated_sentence}\" on a scale of 0-100. Only reply the score value.",
    "vqa": "Rate the answer correctness regarding the question within the context of the given image on a scale of 0-100. Only reply the score value.",
    "pair_rate_old": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nGiven the instruction and the image, please compare the correctness of responses A and B. Reply with \"leftvote\" if you find A better, \"rightvote\" if B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both responses are equally satisfactory. If you are unable to make a decision, please reply with \"NA\".",
    "pair_rate_wexplanation": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
    "pair_rate": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Reply with \"leftvote\" if you find assistant A better, \"rightvote\" if assistant B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both assistants provide equally satisfactory answers. If you are unable to make a decision, please reply with \"NA\"."
}
88
+
89
def inspect_convs(log_files):
    """Replay logged pairwise votes, ask GPT-4V to judge each battle, and
    write a CSV comparing the human vote with the GPT vote.

    Side effects: creates 'all_pairvote_log_wgpt_prtchatbot.csv' in the CWD,
    performs one OpenAI API call per usable vote, and prints running stats
    via icecream. (Large blocks of commented-out debugging code from the
    original were removed; see git history.)
    """
    ic(log_files)
    data = []  # NOTE(review): never populated below — appears unused.
    total_vote = 0
    correct_vote = 0

    client = OpenAI()
    with open('all_pairvote_log_wgpt_prtchatbot.csv', 'w', newline='') as csvfile:
        fieldnames = ['tstamp', 'type', 'models', 'states', 'ip', 'gpt_vote']
        # NOTE(review): DictWriter defaults to extrasaction='raise'; rows with
        # keys outside `fieldnames` would raise ValueError — confirm log schema.
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        # Write the header
        writer.writeheader()

        for filename in tqdm(log_files, desc="read files"):
            # Retry reads: the log may be mid-rotation on shared storage.
            # NOTE(review): if all 5 attempts fail, `lines` is stale or unbound.
            for retry in range(5):
                try:
                    lines = open(filename).readlines()
                    break
                except FileNotFoundError:
                    time.sleep(2)

            for l in lines:
                row = json.loads(l)

                # Keep only battle rows with recorded states and a human vote.
                if "states" not in row:
                    continue
                if row["type"] not in ["leftvote", "rightvote", "bothbad_vote", "tievote"]:
                    continue

                model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]

                # Skip empty conversations and failed generations on either side.
                if not len(row["states"][0]['messages']): continue

                if row["states"][0]['messages'][1][1] is None or row["states"][1]['messages'][1][1] is None or "NETWORK ERROR" in row["states"][0]['messages'][1][1] or "NETWORK ERROR" in row["states"][1]['messages'][1][1]: continue
                total_vote += 1

                # TODO: check two image are the same
                conv_id = row["states"][0]['conv_id']
                # Input image saved next to the day's log under *input_images/.
                image_path = os.path.join("/local/home/yujielu/project/Arena-Elo/vision-arena-logs", os.path.basename(filename)[:-5]+"input_images", f"input_image_{conv_id}.png")
                if not os.path.exists(image_path):
                    response = "NA"
                    ic(image_path)
                else:
                    base64_image = encode_image(image_path)
                    left_response = row["states"][0]['messages'][1][1]
                    right_response = row["states"][1]['messages'][1][1]
                    sep = "-" * 20
                    instruction = row["states"][0]['messages'][0][1]
                    generated_sentence = f"[The Start of Assistant A’s Answer]\n{left_response}\n[The End of Assistant A’s Answer]\n\n[The Start of Assistant B’s Answer]\n{right_response}\n[The End of Assistant B’s Answer]"
                    text_prompt = task_template_map["pair_rate"].format(instruction=instruction, generated_sentence=generated_sentence)
                    # Best-effort API call: any failure degrades to "NA".
                    try:
                        response = get_gpt4v_response(client, img_bs64=base64_image, text_prompt=text_prompt, use_vision=True)
                    except:
                        ic(">>> skip")
                        response = "NA"

                ic(row['type'], response)
                # Anything that is not a valid vote label is normalized to "NA".
                if response.strip() not in ["leftvote", "rightvote", "bothbad_vote", "tievote"]:
                    response = "NA"

                # Agreement between the human vote and the GPT judge.
                if row['type'] == response.strip():
                    correct_vote += 1
                # Serialize nested structures so each fits in one CSV cell.
                row['models'] = json.dumps(row['models'])
                row['states'] = json.dumps(row['states'], ensure_ascii=False)
                row['gpt_vote'] = response

                # Write the modified row to the CSV file
                writer.writerow(row)
    ic(total_vote, correct_vote)
223
+
224
+
225
if __name__ == "__main__":
    # CLI: optionally cap the number of newest log files to inspect.
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)

    inspect_convs(log_files)
arena_elo/elo_rating/inspect_cost.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fire
2
+ import time
3
+ import json
4
+ from collections import defaultdict
5
+ from .basic_stats import get_log_files, NUM_SERVERS, LOG_ROOT_DIR
6
+ from .utils import detect_language, get_time_stamp_from_date, get_input_image_path, load_image_from_path
7
+ from tqdm import tqdm
8
+ VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote", "chat"]
9
+
10
+
11
def remove_html(raw):
    """Strip the HTML / markdown wrappers the arena UI adds around model names.

    Handles ``<h3>...: name</h3>\\n`` and ``### Model A/B: name`` forms;
    anything else is returned unchanged.
    """
    if raw.startswith("<h3>"):
        start = raw.find(": ") + 2
        return raw[start:-len("</h3>\n")]
    if raw.startswith(("### Model A: ", "### Model B: ")):
        return raw[13:]
    return raw
17
+
18
+
19
def read_file(filename):
    """Parse one JSONL log file, keeping only rows whose type is in VOTES.

    The file may be mid-write/rotation on shared storage, so retry up to
    five times on FileNotFoundError (sleeping 2s between attempts) before
    giving up and returning whatever was collected.

    Returns:
        list of parsed row dicts (empty if the file never became readable).
    """
    data = []
    for _retry in range(5):
        try:
            # `with` guarantees the handle is closed even if json parsing
            # raises (the original leaked the file object).
            with open(filename) as fin:
                for line in fin:
                    row = json.loads(line)
                    if row["type"] in VOTES:
                        data.append(row)
            break
        except FileNotFoundError:
            time.sleep(2)
    return data
32
+
33
+
34
def read_file_parallel(log_files, num_threads=16):
    """Read many log files concurrently and concatenate their vote rows.

    Uses a multiprocessing Pool (despite the parameter name, these are
    processes, not threads) and shows progress with tqdm.
    """
    data_all = []
    from multiprocessing import Pool

    with Pool(num_threads) as p:
        # imap preserves input order; list() drains it so tqdm can track it.
        ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files)))
        for ret in ret_all:
            data_all.extend(ret)
    return data_all
43
+
44
def num_tokens(s: str):
    """Rough token estimate: ~4 characters per token; ``None`` counts as 0."""
    return 0 if s is None else len(s) / 4
48
+
49
def main(
):
    """Aggregate usage statistics from the arena logs and print cost estimates.

    Prints, as JSON: per-type request counts, per-model request counts,
    per-model average input/output token counts (via the 4-chars-per-token
    heuristic in num_tokens), per-model average input-image sizes, and an
    estimated GPT-4V dollar cost.
    """
    log_files = get_log_files()
    data = read_file_parallel(log_files)

    all_model_counts = defaultdict(int)
    all_model_input_tokens_counts = defaultdict(list)
    all_model_output_tokens_counts = defaultdict(list)
    all_model_image_sizes = defaultdict(list)
    chat_battle_counts = defaultdict(int)
    for row in tqdm(data, desc="counting"):
        if row['type'] == "chat":
            # Direct-chat rows carry a single model and a single state.
            chat_battle_counts["chat"] += 1
            all_model_counts[row['model']] += 1
            tstamp = row["tstamp"]
            conv_id = row["state"]["conv_id"]

            image = load_image_from_path(get_input_image_path(tstamp, conv_id))
            if image is None:
                image_size = None
            else:
                # NOTE(review): re-loads the same image just to read .size.
                image_size = load_image_from_path(get_input_image_path(tstamp, conv_id)).size
            all_model_image_sizes[row['model']].append(image_size)
            try:
                # Messages alternate user/assistant starting at `offset`:
                # even slots are inputs, odd slots are model outputs.
                for message in row["state"]["messages"][row["state"]["offset"] :: 2]:
                    all_model_input_tokens_counts[row['model']].append(num_tokens(message[1]))
                for message in row["state"]["messages"][row["state"]["offset"] + 1 :: 2]:
                    all_model_output_tokens_counts[row['model']].append(num_tokens(message[1]))
            except Exception as e:
                print(row)
                raise e

        else:
            # Battle rows (left/right/tie/bothbad votes) carry two states.
            chat_battle_counts[row['type']] += 1
            if row["models"][0] is None or row["models"][1] is None:
                continue

            # Resolve model names
            models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
            if "model_name" in row["states"][0]:
                models_hidden = [
                    row["states"][0]["model_name"],
                    row["states"][1]["model_name"],
                ]
                if models_hidden[0] is None:
                    models_hidden = models_public
            else:
                models_hidden = models_public

            # Drop rows where only one side's public name was revealed.
            if (models_public[0] == "" and models_public[1] != "") or (
                models_public[1] == "" and models_public[0] != ""
            ):
                continue

            # Anonymous battles hide public names; fall back to hidden names.
            if models_public[0] == "" or models_public[0] == "Model A":
                anony = True
                models = models_hidden
            else:
                anony = False
                models = models_public
                # Named battles must agree with the hidden names.
                if not models_public == models_hidden:
                    continue

            all_model_counts[models[0]] += 1
            all_model_counts[models[1]] += 1
            tstamp = row["tstamp"]
            conv_id1 = row["states"][0]["conv_id"]
            conv_id2 = row["states"][1]["conv_id"]

            image1 = load_image_from_path(get_input_image_path(tstamp, conv_id1))
            image2 = load_image_from_path(get_input_image_path(tstamp, conv_id2))
            all_model_image_sizes[models[0]].append(None if image1 is None else image1.size)
            all_model_image_sizes[models[1]].append(None if image2 is None else image2.size)

            # Same even/odd input-output split as in the chat branch, per side.
            for message in row["states"][0]["messages"][row["states"][0]["offset"] :: 2]:
                all_model_input_tokens_counts[models[0]].append(num_tokens(message[1]))
            for message in row["states"][0]["messages"][row["states"][0]["offset"] + 1 :: 2]:
                all_model_output_tokens_counts[models[0]].append(num_tokens(message[1]))
            for message in row["states"][1]["messages"][row["states"][1]["offset"] :: 2]:
                all_model_input_tokens_counts[models[1]].append(num_tokens(message[1]))
            for message in row["states"][1]["messages"][row["states"][1]["offset"] + 1 :: 2]:
                all_model_output_tokens_counts[models[1]].append(num_tokens(message[1]))

    print("### Chat battle counts (requests)")
    print(json.dumps(chat_battle_counts, indent=4))

    print("### Model counts (requests)")
    print(json.dumps(all_model_counts, indent=4))

    print("### Model Avg input tokens counts (tokens)")
    average_input_tokens_counts = {}
    for model, counts in all_model_input_tokens_counts.items():
        average_input_tokens_counts[model] = sum(counts) / len(counts)
    print(json.dumps(average_input_tokens_counts, indent=4))

    print("### Model AVg output tokens counts (tokens)")
    average_output_tokens_counts = {}
    for model, counts in all_model_output_tokens_counts.items():
        average_output_tokens_counts[model] = sum(counts) / len(counts)
    print(json.dumps(average_output_tokens_counts, indent=4))

    print("### Model Avg image sizes (height, width)")
    average_image_sizes = {}
    for model, sizes in all_model_image_sizes.items():
        # NOTE(review): the denominator len(sizes) includes None entries, so
        # models with missing images get deflated averages — confirm intent.
        avg_height = sum([size[0] for size in sizes if size is not None]) / len(sizes)
        avg_width = sum([size[1] for size in sizes if size is not None]) / len(sizes)
        average_image_sizes[model] = (avg_height, avg_width)
    print(json.dumps(average_image_sizes, indent=4))

    print("### GPT-4V estimated cost (USD)")
    # Pricing used: $0.01 / 1K input tokens, $0.03 / 1K output tokens.
    gpt_4v_name = "gpt-4-vision-preview"
    gpt_4v_cost = {}
    gpt_4v_cost['input'] = sum(all_model_input_tokens_counts[gpt_4v_name]) / 1000 * 0.01
    gpt_4v_cost['output'] = sum(all_model_output_tokens_counts[gpt_4v_name]) / 1000 * 0.03

    all_image_cost = 0
    for size in all_model_image_sizes[gpt_4v_name]:
        if size is None:
            continue
        # 170 tokens per 512x512 tile plus 85 base tokens per image.
        all_image_tokens = (size[0] // 512 + 1) * (size[1] // 512 + 1) * 170 + 85
        all_image_cost += all_image_tokens / 1000 * 0.01
    gpt_4v_cost['image'] = all_image_cost
    print(json.dumps(gpt_4v_cost, indent=4))
172
+
173
+
174
+
175
+
176
if __name__ == "__main__":
    # Expose `main` as a CLI via python-fire.
    fire.Fire(main)
arena_elo/elo_rating/inspect_elo_rating_pkl.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import plotly.graph_objects as go
3
+
4
def output_figure(data, figure_name="battle_count_heatmap", label="annoy"):
    """Style the stored plotly figure ``data[label][figure_name]`` and save it
    as ``<figure_name>.png`` in the current directory."""
    fig = data[label][figure_name]
    layout_kwargs = dict(
        height=700,
        width=700,
        title={'text': f'{figure_name}', 'x': 0.5, 'y': 0.07},
        xaxis_title="Model B",
        yaxis_title="Model A",
        margin={'t': 60},
    )
    fig.update_layout(**layout_kwargs)
    fig.write_image(f"{figure_name}.png")
16
+
17
# Load the latest pickled Elo results, export the anonymous-battle heatmaps
# as PNGs, and print both leaderboards sorted by rating.
with open("./results/latest/elo_results.pkl",'rb') as f:
    data = pickle.load(f)
    print()
    df = data["anony"]["leaderboard_table_df"]
    print(data["anony"].keys())

    # Export the two anonymous-battle heatmap figures.
    for figure_name in [ 'win_fraction_heatmap', 'battle_count_heatmap',]:
        output_figure(data, figure_name, "anony")

    # Anonymous leaderboard, best rating first.
    df = df.sort_values(by=["rating"], ascending=False)
    print(df)
    df = data["full"]["leaderboard_table_df"]
    # Full (anonymous + named) leaderboard, best rating first.
    df = df.sort_values(by=["rating"], ascending=False)
    print(df)
    print('done')
arena_elo/elo_rating/model_registry.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Additional information of the models."""
2
+ from collections import namedtuple, OrderedDict
3
+ from typing import List
4
+
5
+
6
# Display metadata attached to each registered model name.
ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"])

# Global registry: full model name -> ModelInfo, kept in registration order.
model_info = OrderedDict()
10
+
11
+
12
def register_model_info(
    full_names: List[str], simple_name: str, link: str, description: str
):
    """Register display metadata for every alias listed in *full_names*."""
    entry = ModelInfo(simple_name, link, description)
    for alias in full_names:
        model_info[alias] = entry
19
+
20
+
21
def get_model_info(name: str) -> ModelInfo:
    """Return the registered metadata for *name*, or a placeholder entry."""
    try:
        return model_info[name]
    except KeyError:
        # To fix this, please use `register_model_info` to register your model
        return ModelInfo(
            name, "", "Register the description at arena.model/model_registry.py"
        )
29
+
30
+
31
# ---------------------------------------------------------------------------
# Static registry of model display names, homepage links and descriptions.
# Entries are looked up by `get_model_info` when rendering leaderboards.
# ---------------------------------------------------------------------------

register_model_info(
    [
        "IEITYuan/Yuan2-2B-Janus-hf",
        "IEITYuan/Yuan2-2B-hf",
        "IEITYuan/Yuan2-51B-hf",
        "IEITYuan/Yuan2-102B-hf",
    ],
    "IEIT-Yuan2",
    "https://github.com/IEIT-Yuan/Yuan-2.0",
    "Yuan2.0 is a new generation Fundamental Large Language Model developed by IEIT System.",
)

register_model_info(
    ["mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"],
    "Mixtral of experts",
    "https://mistral.ai/news/mixtral-of-experts/",
    "A Mixture-of-Experts model by Mistral AI",
)

register_model_info(
    ["gemini-pro"],
    "Gemini",
    "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
    "Gemini by Google",
)

register_model_info(
    ["gemini-pro-vision"],
    "Gemini",
    "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
    "Gemini by Google",
)

register_model_info(
    ["solar-10.7b-instruct-v1.0"],
    "SOLAR-10.7B-Instruct",
    "https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0",
    "A model trained using depth up-scaling by Upstage AI",
)

register_model_info(
    ["gpt-4-turbo"],
    "GPT-4-Turbo",
    "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
    "GPT-4-Turbo by OpenAI",
)

register_model_info(
    ["gpt-4-vision-preview"],
    "gpt-4-vision-preview",
    "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
    "GPT-4(V) by OpenAI",
)

register_model_info(
    ["gpt-3.5-turbo", "gpt-3.5-turbo-0314", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106"],
    "GPT-3.5",
    "https://platform.openai.com/docs/models/gpt-3-5",
    "GPT-3.5-Turbo by OpenAI",
)

register_model_info(
    ["gpt-4", "gpt-4-0314", "gpt-4-0613"],
    "GPT-4",
    "https://openai.com/research/gpt-4",
    "GPT-4 by OpenAI",
)

register_model_info(
    ["claude-2.1", "claude-2.0"],
    "Claude",
    "https://www.anthropic.com/index/claude-2",
    "Claude 2 by Anthropic",
)

register_model_info(
    ["claude-1"],
    "Claude",
    "https://www.anthropic.com/index/introducing-claude",
    "Claude 1 by Anthropic",
)

register_model_info(
    ["claude-instant-1", "claude-instant-1.2"],
    "Claude Instant",
    "https://www.anthropic.com/index/introducing-claude",
    "Claude Instant by Anthropic",
)

register_model_info(
    ["pplx-70b-online", "pplx-7b-online"],
    "pplx-online-llms",
    "https://blog.perplexity.ai/blog/introducing-pplx-online-llms",
    "Online LLM API by Perplexity AI",
)

register_model_info(
    ["openhermes-2.5-mistral-7b"],
    "OpenHermes-2.5-Mistral-7B",
    "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B",
    "a mistral-based model fine-tuned on 1M GPT-4 outputs",
)

register_model_info(
    ["starling-lm-7b-alpha"],
    "Starling-LM-7B-alpha",
    "https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha",
    "an open model trained using RLAIF by Berkeley",
)

register_model_info(
    ["tulu-2-dpo-70b"],
    "Tulu 2",
    "https://huggingface.co/allenai/tulu-2-dpo-70b",
    "an instruction and RLHF model by UW/AllenAI",
)

register_model_info(
    ["yi-34b-chat", "yi-6b-chat"],
    "Yi-Chat",
    "https://huggingface.co/01-ai/Yi-34B-Chat",
    "A large language model by 01 AI",
)

register_model_info(
    ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"],
    "Llama 2",
    "https://ai.meta.com/llama/",
    "open foundation and fine-tuned chat models by Meta",
)

register_model_info(
    [
        "vicuna-33b",
        "vicuna-33b-v1.3",
        "vicuna-13b",
        "vicuna-13b-v1.3",
        "vicuna-7b",
        "vicuna-7b-v1.3",
    ],
    "Vicuna",
    "https://lmsys.org/blog/2023-03-30-vicuna/",
    "a chat assistant fine-tuned on user-shared conversations by LMSYS",
)

register_model_info(
    ["chatglm3-6b", "chatglm2-6b", "chatglm-6b"],
    "ChatGLM",
    "https://chatglm.cn/blog",
    "an open bilingual dialogue language model by Tsinghua University",
)

register_model_info(
    ["openchat-3.5"],
    "OpenChat 3.5",
    "https://github.com/imoneoi/openchat",
    "an open model fine-tuned on Mistral-7B using C-RLFT",
)

register_model_info(
    ["tenyxchat-7b-v1"],
    "TenyxChat-7B",
    "https://huggingface.co/tenyx/TenyxChat-7B-v1",
    "an open model DPO trained on top of OpenChat-3.5 using Tenyx fine-tuning",
)

register_model_info(
    ["zephyr-7b-beta", "zephyr-7b-alpha"],
    "Zephyr",
    "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha",
    "a chatbot fine-tuned from Mistral by Hugging Face",
)

register_model_info(
    ["notus-7b-v1"],
    "Notus",
    "https://huggingface.co/argilla/notus-7b-v1",
    "a chatbot fine-tuned from Zephyr SFT by Argilla",
)

register_model_info(
    ["catppt"],
    "CatPPT",
    "https://huggingface.co/rishiraj/CatPPT",
    "a chatbot fine-tuned from a SLERP merged model by Rishiraj Acharya",
)

register_model_info(
    ["TinyLlama"],
    "TinyLlama",
    "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.",
)

register_model_info(
    ["qwen-14b-chat"],
    "Qwen",
    "https://huggingface.co/Qwen/Qwen-14B-Chat",
    "a large language model by Alibaba Cloud",
)

register_model_info(
    ["codellama-34b-instruct", "codellama-13b-instruct", "codellama-7b-instruct"],
    "Code Llama",
    "https://ai.meta.com/blog/code-llama-large-language-model-coding/",
    "open foundation models for code by Meta",
)

register_model_info(
    ["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"],
    "WizardLM",
    "https://github.com/nlpxucan/WizardLM",
    "an instruction-following LLM using evol-instruct by Microsoft",
)

register_model_info(
    ["wizardcoder-15b-v1.0"],
    "WizardLM",
    "https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder",
    "Empowering Code Large Language Models with Evol-Instruct",
)

register_model_info(
    ["mpt-7b-chat", "mpt-30b-chat"],
    "MPT-Chat",
    "https://www.mosaicml.com/blog/mpt-30b",
    "a chatbot fine-tuned from MPT by MosaicML",
)

register_model_info(
    ["guanaco-33b", "guanaco-65b"],
    "Guanaco",
    "https://github.com/artidoro/qlora",
    "a model fine-tuned with QLoRA by UW",
)

register_model_info(
    ["gpt4all-13b-snoozy"],
    "GPT4All-Snoozy",
    "https://github.com/nomic-ai/gpt4all",
    "a finetuned LLaMA model on assistant style data by Nomic AI",
)

register_model_info(
    ["koala-13b"],
    "Koala",
    "https://bair.berkeley.edu/blog/2023/04/03/koala",
    "a dialogue model for academic research by BAIR",
)

register_model_info(
    ["RWKV-4-Raven-14B"],
    "RWKV-4-Raven",
    "https://huggingface.co/BlinkDL/rwkv-4-raven",
    "an RNN with transformer-level LLM performance",
)

register_model_info(
    ["alpaca-13b"],
    "Alpaca",
    "https://crfm.stanford.edu/2023/03/13/alpaca.html",
    "a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford",
)

register_model_info(
    ["oasst-pythia-12b"],
    "OpenAssistant (oasst)",
    "https://open-assistant.io",
    "an Open Assistant for everyone by LAION",
)

register_model_info(
    ["oasst-sft-7-llama-30b"],
    "OpenAssistant (oasst)",
    "https://open-assistant.io",
    "an Open Assistant for everyone by LAION",
)

register_model_info(
    ["palm-2"],
    "PaLM 2 Chat",
    "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023",
    "PaLM 2 for Chat (chat-bison@001) by Google",
)

register_model_info(
    ["llama-7b", "llama-13b"],
    "LLaMA",
    "https://arxiv.org/abs/2302.13971",
    "open and efficient foundation language models by Meta",
)

register_model_info(
    ["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"],
    "Open LLaMa (Open Instruct)",
    "https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc",
    "Open LLaMa fine-tuned on instruction-following data by VMware",
)

register_model_info(
    ["dolly-v2-12b"],
    "Dolly",
    "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm",
    "an instruction-tuned open large language model by Databricks",
)

register_model_info(
    ["stablelm-tuned-alpha-7b"],
    "StableLM",
    "https://github.com/stability-AI/stableLM",
    "Stability AI language models",
)
343
+
344
+ register_model_info(
345
+ ["codet5p-6b"],
346
+ "CodeT5p-6b",
347
+ "https://huggingface.co/Salesforce/codet5p-6b",
348
+ "Code completion model released by Salesforce",
349
+ )
350
+
351
+ register_model_info(
352
+ ["fastchat-t5-3b", "fastchat-t5-3b-v1.0"],
353
+ "FastChat-T5",
354
+ "https://huggingface.co/lmsys/fastchat-t5-3b-v1.0",
355
+ "a chat assistant fine-tuned from FLAN-T5 by LMSYS",
356
+ )
357
+
358
+ register_model_info(
359
+ ["phoenix-inst-chat-7b"],
360
+ "Phoenix-7B",
361
+ "https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b",
362
+ "a multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)",
363
+ )
364
+
365
+ register_model_info(
366
+ ["realm-7b-v1"],
367
+ "ReaLM",
368
+ "https://github.com/FreedomIntelligence/ReaLM",
369
+ "A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.",
370
+ )
371
+
372
+ register_model_info(
373
+ ["billa-7b-sft"],
374
+ "BiLLa-7B-SFT",
375
+ "https://huggingface.co/Neutralzz/BiLLa-7B-SFT",
376
+ "an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher",
377
+ )
378
+
379
+ register_model_info(
380
+ ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"],
381
+ "h2oGPT-GM-7b",
382
+ "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
383
+ "an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai",
384
+ )
385
+
386
+ register_model_info(
387
+ ["baize-v2-7b", "baize-v2-13b"],
388
+ "Baize v2",
389
+ "https://github.com/project-baize/baize-chatbot#v2",
390
+ "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.",
391
+ )
392
+
393
+ register_model_info(
394
+ [
395
+ "airoboros-l2-7b-2.1",
396
+ "airoboros-l2-13b-2.1",
397
+ "airoboros-c34b-2.1",
398
+ "airoboros-l2-70b-2.1",
399
+ ],
400
+ "airoboros",
401
+ "https://huggingface.co/jondurbin/airoboros-l2-70b-2.1",
402
+ "an instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4",
403
+ )
404
+
405
+ register_model_info(
406
+ [
407
+ "spicyboros-7b-2.2",
408
+ "spicyboros-13b-2.2",
409
+ "spicyboros-70b-2.2",
410
+ ],
411
+ "spicyboros",
412
+ "https://huggingface.co/jondurbin/spicyboros-70b-2.2",
413
+ "de-aligned versions of the airoboros models",
414
+ )
415
+
416
+ register_model_info(
417
+ ["Robin-7b-v2", "Robin-13b-v2", "Robin-33b-v2"],
418
+ "Robin-v2",
419
+ "https://huggingface.co/OptimalScale/robin-7b-v2-delta",
420
+ "A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense reasoning and instruction-following tasks, by OptimalScale, HKUST.",
421
+ )
422
+
423
+ register_model_info(
424
+ ["manticore-13b-chat"],
425
+ "Manticore 13B Chat",
426
+ "https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg",
427
+ "A chatbot fine-tuned from LlaMa across several CoT and chat datasets.",
428
+ )
429
+
430
+ register_model_info(
431
+ ["redpajama-incite-7b-chat"],
432
+ "RedPajama-INCITE-7B-Chat",
433
+ "https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat",
434
+ "A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together",
435
+ )
436
+
437
+ register_model_info(
438
+ [
439
+ "falcon-7b",
440
+ "falcon-7b-instruct",
441
+ "falcon-40b",
442
+ "falcon-40b-instruct",
443
+ "falcon-180b",
444
+ "falcon-180b-chat",
445
+ ],
446
+ "Falcon",
447
+ "https://huggingface.co/tiiuae/falcon-180B",
448
+ "TII's flagship series of large language models",
449
+ )
450
+
451
+ register_model_info(
452
+ ["tigerbot-7b-sft"],
453
+ "Tigerbot",
454
+ "https://huggingface.co/TigerResearch/tigerbot-7b-sft",
455
+ "TigerBot is a large-scale language model (LLM) with multiple languages and tasks.",
456
+ )
457
+
458
+ register_model_info(
459
+ ["internlm-chat-7b", "internlm-chat-7b-8k"],
460
+ "InternLM",
461
+ "https://huggingface.co/internlm/internlm-chat-7b",
462
+ "InternLM is a multi-language large-scale language model (LLM), developed by SHLAB.",
463
+ )
464
+
465
+ register_model_info(
466
+ ["Qwen-7B-Chat"],
467
+ "Qwen",
468
+ "https://huggingface.co/Qwen/Qwen-7B-Chat",
469
+ "Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.",
470
+ )
471
+
472
+ register_model_info(
473
+ ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"],
474
+ "Llama2-Chinese",
475
+ "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat",
476
+ "Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.",
477
+ )
478
+
479
+ register_model_info(
480
+ ["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"],
481
+ "Chinese-Alpaca",
482
+ "https://huggingface.co/hfl/chinese-alpaca-2-13b",
483
+ "New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.",
484
+ )
485
+
486
+ register_model_info(
487
+ ["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"],
488
+ "Vigogne-Instruct",
489
+ "https://huggingface.co/bofenghuang/vigogne-2-7b-instruct",
490
+ "Vigogne-Instruct is a French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang",
491
+ )
492
+
493
+ register_model_info(
494
+ ["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"],
495
+ "Vigogne-Chat",
496
+ "https://huggingface.co/bofenghuang/vigogne-2-7b-chat",
497
+ "Vigogne-Chat is a French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang",
498
+ )
499
+
500
+ register_model_info(
501
+ ["stable-vicuna-13B-HF"],
502
+ "stable-vicuna",
503
+ "https://huggingface.co/TheBloke/stable-vicuna-13B-HF",
504
+ "StableVicuna is a Vicuna model fine-tuned using RLHF via PPO on various conversational and instructional datasets.",
505
+ )
506
+
507
+ register_model_info(
508
+ ["deluxe-chat-v1", "deluxe-chat-v1.1", "deluxe-chat-v1.2"],
509
+ "DeluxeChat",
510
+ "",
511
+ "Deluxe Chat",
512
+ )
513
+
514
+ register_model_info(
515
+ [
516
+ "Xwin-LM-7B-V0.1",
517
+ "Xwin-LM-13B-V0.1",
518
+ "Xwin-LM-70B-V0.1",
519
+ "Xwin-LM-7B-V0.2",
520
+ "Xwin-LM-13B-V0.2",
521
+ ],
522
+ "Xwin-LM",
523
+ "https://github.com/Xwin-LM/Xwin-LM",
524
+ "Chat models developed by Xwin-LM team",
525
+ )
526
+
527
+ register_model_info(
528
+ ["lemur-70b-chat"],
529
+ "Lemur-Chat",
530
+ "https://huggingface.co/OpenLemur/lemur-70b-chat-v1",
531
+ "an openly accessible language model optimized for both natural language and coding capabilities ",
532
+ )
533
+
534
+ register_model_info(
535
+ ["Mistral-7B-OpenOrca"],
536
+ "Open-Orca",
537
+ "https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca",
538
+ "A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)",
539
+ )
540
+
541
+ register_model_info(
542
+ ["dolphin-2.2.1-mistral-7b"],
543
+ "dolphin-mistral",
544
+ "https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b",
545
+ "An uncensored fine-tuned Mistral 7B",
546
+ )
547
+
548
+ register_model_info(
549
+ [
550
+ "AquilaChat-7B",
551
+ "AquilaChat2-7B",
552
+ "AquilaChat2-34B",
553
+ ],
554
+ "Aquila-Chat",
555
+ "https://huggingface.co/BAAI/AquilaChat2-34B",
556
+ "Chat models developed by BAAI team",
557
+ )
558
+
559
+ register_model_info(
560
+ ["xDAN-L1-Chat-RL-v1"],
561
+ "xDAN-L1-Chat",
562
+ "https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1",
563
+ "A large language chat model created by xDAN-AI.",
564
+ )
565
+
566
+ register_model_info(
567
+ ["MetaMath-70B-V1.0", "MetaMath-7B-V1.0"],
568
+ "MetaMath",
569
+ "https://huggingface.co/meta-math",
570
+ "MetaMath is a finetune of Llama2 on [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) that specializes in mathematical reasoning.",
571
+ )
572
+
573
+ register_model_info(
574
+ ["Yuan2-2B-hf", "Yuan2-51B-hf", "Yuan2-102B-hf"],
575
+ "IEIYuan",
576
+ "https://huggingface.co/IEITYuan",
577
+ "Yuan2 is a Basemodel developed by IEI.",
578
+ )
arena_elo/elo_rating/upload_battle_data.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fire
2
+ import json
3
+ import os
4
+ import datasets
5
+ import datetime
6
+ from pathlib import Path
7
+ from datetime import datetime
8
+ from PIL import Image
9
+
10
+ datasets.config.DEFAULT_MAX_BATCH_SIZE = 500
11
def create_hf_dataset(data_file: list, split="test"):
    """Build a single-model Hugging Face dataset from in-memory battle records.

    Args:
        data_file: despite the name, this is the already-loaded list of record
            dicts (see `main`), not a file path — each dict must carry the
            keys declared in the Features schema below.
        split: name of the dataset split to create.

    Returns:
        datasets.Dataset with one row per (model, conversation, image) record.
    """
    hf_dataset = datasets.Dataset.from_list(
        data_file,
        features=datasets.Features(
            {
                "question_id": datasets.Value("string"),
                "model": datasets.Value("string"),
                # A conversation is a list of {role, content} messages.
                "conversation": [
                    {
                        "role": datasets.Value("string"),
                        "content": datasets.Value("string"),
                    }
                ],
                "language": datasets.Value("string"),
                "image": datasets.Image(),
                "turn": datasets.Value("int32"),
            }
        ),
        split=split,
    )
    return hf_dataset
32
+
33
def create_hf_battle_dataset(data_file: str, split="test"):
    """Build a pairwise-battle Hugging Face dataset from in-memory records.

    `data_file` is the already-loaded list of battle dicts (not a path); each
    dict must provide both models, both conversations, the input image, and
    the anonymity flag.
    """
    def _message_feature():
        # One conversation = list of {role, content} messages.
        return [
            {
                "role": datasets.Value("string"),
                "content": datasets.Value("string"),
            }
        ]

    battle_features = datasets.Features(
        {
            "question_id": datasets.Value("string"),
            "model_a": datasets.Value("string"),
            "model_b": datasets.Value("string"),
            "conversation_a": _message_feature(),
            "conversation_b": _message_feature(),
            "language": datasets.Value("string"),
            "image": datasets.Image(),
            "turn": datasets.Value("int32"),
            "anony": datasets.Value("bool"),
        }
    )

    return datasets.Dataset.from_list(data_file, features=battle_features, split=split)
62
+
63
+
64
+
65
+
66
def load_image(path: str):
    """Open an image file with PIL; return None (and log) on any failure."""
    try:
        img = Image.open(path)
    except Exception as e:
        # Broad on purpose: any unreadable/corrupt/missing image is skipped.
        print(f"Error loading image {path}: {e}")
        return None
    return img
72
+
73
def get_date_from_time_stamp(unix_timestamp: int):
    """Format a Unix timestamp as a 'YYYY-MM-DD' date string.

    NOTE(review): uses the machine's local timezone — confirm log hosts share
    the same timezone setting before comparing dates across machines.
    """
    return datetime.fromtimestamp(unix_timestamp).strftime("%Y-%m-%d")
80
+
81
def load_battle_image(battle, log_dir):
    """Load the input image that was saved alongside a battle record.

    Images are stored per-day under '<log_dir>/<YYYY-MM-DD>-convinput_images/'.
    Returns None if the file is missing or unreadable (see `load_image`).
    """
    day = get_date_from_time_stamp(battle["tstamp"])
    image_file = Path(log_dir) / f"{day}-convinput_images" / f"input_image_{battle['question_id']}.png"
    return load_image(image_file)
84
+
85
+
86
def main(
    data_file: str = "./results/latest/clean_battle_conv.json",
    repo_id: str = "DongfuTingle/wildvision-bench",
    log_dir: str = os.getenv("LOGDIR", "./vision-arena-logs/"),
    mode="battle",
    token = os.environ.get("HUGGINGFACE_TOKEN", None)
):
    """Package cleaned battle logs (with their input images) and push to the Hub.

    Args:
        data_file: path to the cleaned battle JSON produced by the elo pipeline.
        repo_id: target Hugging Face dataset repository.
        log_dir: arena log root; per-battle input images are looked up under it.
        mode: "battle" uploads every battle that has an image;
            "keep_bad_only" uploads only the losing side of anonymous battles
            (both sides when the vote was "tie (bothbad)").
        token: Hugging Face auth token; defaults to $HUGGINGFACE_TOKEN.

    Raises:
        ValueError: if `mode` is not one of the two supported values.
    """
    with open(data_file, "r") as f:
        data = json.load(f)

    # Battles whose saved input image cannot be found on disk are dropped;
    # these counters are reported at the end.
    has_image_stats = {
        "has_image": 0,
        "no_image": 0,
    }
    if mode == "keep_bad_only":
        # anony only: this mode is restricted to anonymous battles.
        data = [d for d in data if d["anony"]]

        new_data = []
        for battle in data:
            image = load_battle_image(battle, log_dir)
            if image is None:
                has_image_stats["no_image"] += 1
                # we don't keep the data without image
                continue
            has_image_stats["has_image"] += 1

            if battle["winner"] in ["model_a", "model_b"]:
                # Decided battle: keep only the losing side.
                if battle["winner"] == "model_a":
                    worse_model = "model_b"
                    worse_conv = "conversation_b"
                if battle["winner"] == "model_b":
                    worse_model = "model_a"
                    worse_conv = "conversation_a"

                new_data.append({
                    "question_id": battle["question_id"],
                    "model": battle[worse_model],
                    "conversation": battle[worse_conv],
                    "language": battle["language"],
                    "image": image,
                    "turn": battle["turn"],
                })
            elif battle["winner"] == "tie (bothbad)":
                # Both responses were judged bad: keep both sides as "bad".
                new_data.append({
                    "question_id": battle["question_id"],
                    "model": battle["model_a"],
                    "conversation": battle["conversation_a"],
                    "language": battle["language"],
                    "image": image,
                    "turn": battle["turn"],
                })

                new_data.append({
                    "question_id": battle["question_id"],
                    "model": battle["model_b"],
                    "conversation": battle["conversation_b"],
                    "language": battle["language"],
                    "image": image,
                    "turn": battle["turn"],
                })

        split = "test"
        hf_dataset = create_hf_dataset(new_data, "test")

    elif mode == "battle":
        new_data = []
        for battle in data:
            image = load_battle_image(battle, log_dir)
            if image is None:
                has_image_stats["no_image"] += 1
                continue
            has_image_stats["has_image"] += 1
            new_data.append({
                "question_id": battle["question_id"],
                "model_a": battle["model_a"],
                "model_b": battle["model_b"],
                "conversation_a": battle["conversation_a"],
                "conversation_b": battle["conversation_b"],
                "language": battle["language"],
                "image": image,
                "turn": battle["turn"],
                "anony": battle["anony"],
            })
        split = "test"
        hf_dataset = create_hf_battle_dataset(new_data, "test")
    else:
        raise ValueError(f"Invalid mode: {mode}")

    print(f"Stats: {has_image_stats}")
    print(hf_dataset)
    print(f"Uploading to part {repo_id}:{split}...")
    # Each mode is uploaded as its own dataset config on the same repo.
    hf_dataset.push_to_hub(
        repo_id=repo_id,
        config_name=mode,
        split=split,
        token=token,
        commit_message=f"Add vision-arena {split} dataset",
    )

    print("Done!")
190
+
191
+
192
if __name__ == "__main__":
    # CLI entry point: fire exposes `main`'s keyword arguments as flags.
    fire.Fire(main)
arena_elo/elo_rating/utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import pytz
3
+ import PIL
4
+ import os
5
+
6
def detect_language(text: str) -> str:
    """Return the name of the language *text* is written in, or "unknown"."""
    import polyglot  # pip3 install polyglot pyicu pycld2
    from polyglot.detect import Detector
    from polyglot.detect.base import logger as polyglot_logger
    import pycld2

    # Silence polyglot's per-call warnings about short/ambiguous input.
    polyglot_logger.setLevel("ERROR")

    try:
        return Detector(text).language.name
    except (pycld2.error, polyglot.detect.base.UnknownLanguage):
        return "unknown"
20
+
21
+
22
def get_time_stamp_from_date(date_str: str):
    """Convert a date string to a Unix timestamp.

    Args:
        date_str: input in the format 'YYYY-MM-DD-HH:MM-TZ',
            e.g. '2024-02-10-14:00-PT'.

    Returns:
        float: the corresponding Unix timestamp.

    Raises:
        KeyError: if the timezone abbreviation is not in `timezone_map`.
        ValueError: if the date portion does not match '%Y-%m-%d-%H:%M'.
    """
    # 'PT' is not directly recognized by pytz, so abbreviations are mapped
    # manually to full zone names.
    timezone_map = {
        "PT": "US/Pacific",
    }

    # The timezone abbreviation is the last '-'-separated token.
    tz_abbr = date_str.split("-")[-1]
    tz_info = pytz.timezone(timezone_map[tz_abbr])

    # Strip the timezone abbreviation before parsing the date/time part.
    date_str_parsed = date_str.rsplit("-", 1)[0]

    # NOTE(review): `replace(tzinfo=pytz_zone)` attaches the zone's historical
    # LMT offset instead of the correct wall-clock offset for that date;
    # `tz_info.localize(dt)` is the canonical pytz call. Left as-is to keep
    # existing timestamps stable — confirm before switching.
    dt = datetime.strptime(date_str_parsed, "%Y-%m-%d-%H:%M").replace(tzinfo=tz_info)

    return dt.timestamp()
53
+
54
def get_date_from_time_stamp(unix_timestamp: int):
    """Render a Unix timestamp as 'YYYY-MM-DD HH:MM:SS %Z' in local time.

    The datetime is naive, so the trailing %Z field renders as an empty
    string (the format keeps a trailing space in that case).
    """
    moment = datetime.fromtimestamp(unix_timestamp)
    return moment.strftime("%Y-%m-%d %H:%M:%S %Z")
61
+
62
+
63
def get_input_image_path(tstamp, conv_id):
    """Build the on-disk path of the saved input image for a conversation.

    Images are bucketed per US/Pacific calendar day under $LOGDIR.
    """
    pacific_day = datetime.fromtimestamp(tstamp, tz=pytz.timezone("US/Pacific")).strftime("%Y-%m-%d")
    log_root = os.getenv("LOGDIR")
    return f"{log_root}/{pacific_day}-convinput_images/input_image_{conv_id}.png"
68
+
69
def load_image_from_path(image_path):
    """Open an image with PIL, returning None if it is missing or unreadable."""
    try:
        return PIL.Image.open(image_path)
    except FileNotFoundError:
        print(f"Image not found at path: {image_path}")
        return None
    except PIL.UnidentifiedImageError:
        print(f"Unidentified image format at path: {image_path}")
        return None
81
+
82
+
83
+
arena_elo/evaluator/convert_to_evaluator_data.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+ from pytz import timezone
6
+ from tqdm import tqdm
7
+ import base64
8
+ from icecream import ic
9
+ from PIL import Image
10
+
11
+
12
+ # Function to encode the image
13
def encode_image(image_path):
    """Read a file and return its contents as a base64-encoded ASCII string."""
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode('utf-8')
16
+
17
def get_log_files(max_num_files=None):
    """Collect existing daily conversation-log files for Feb-Mar 2024.

    Candidate dates include impossible ones (e.g. 2024-02-31); that is
    harmless because every path is existence-checked. Returns at most the
    last `max_num_files` paths (i.e. the newest dates).
    """
    candidate_dates = [
        f"2024-{month:02d}-{day:02d}" for month in (2, 3) for day in range(1, 32)
    ]

    num_servers = 1
    filenames = []
    for date_str in candidate_dates:
        for _server in range(num_servers):
            # name = os.path.expanduser(f"~/fastchat_logs/server{_server}/{date_str}-conv.json")
            path = os.path.expanduser(f"vision-arena-logs/{date_str}-conv.json")
            if os.path.exists(path):
                filenames.append(path)

    limit = max_num_files or len(filenames)
    return filenames[-limit:]
+
35
+
36
def pretty_print_conversation(messages):
    """Print one '[[role]]: message' line per (role, message) pair."""
    for entry in messages:
        role, msg = entry
        print(f"[[{role}]]: {msg}")
39
+
40
# Prompt templates keyed by evaluation task. `inspect_convs` uses "pair_rate";
# the others are kept for ad-hoc experiments ("_old" / "_wexplanation" are
# earlier variants of the pairwise-judge prompt).
task_template_map = {
    "image_caption": "Give me the semantic alignment score between the given image and the given caption: \"{generated_sentence}\" on a scale of 0-100. Only reply the score value.",
    "vqa": "Rate the answer correctness regarding the question within the context of the given image on a scale of 0-100. Only reply the score value.",
    "pair_rate_old": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nGiven the instruction and the image, please compare the correctness of responses A and B. Reply with \"leftvote\" if you find A better, \"rightvote\" if B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both responses are equally satisfactory. If you are unable to make a decision, please reply with \"NA\".",
    "pair_rate_wexplanation": "<image>[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
    "pair_rate": "<image>[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Reply with \"leftvote\" if you find assistant A better, \"rightvote\" if assistant B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both assistants provide equally satisfactory answers. If you are unable to make a decision, please reply with \"NA\"."
}
47
+
48
def inspect_convs(log_files):
    """Extract pairwise-vote battles from conversation logs into evaluator JSON.

    For every vote record that has two complete model responses and a saved
    input image, build a (judge-prompt, vote) conversation pair and dump the
    collection to ``output_evaluator_data.json``.

    Args:
        log_files: paths of ``*-conv.json`` log files (one JSON object per line).
    """
    json_data = []

    ic(log_files)
    total_vote = 0

    for filename in tqdm(log_files, desc="read files"):
        # The log file may be mid-rotation; retry a few times before giving up.
        lines = None
        for _retry in range(5):
            try:
                lines = open(filename).readlines()
                break
            except FileNotFoundError:
                time.sleep(2)
        if lines is None:
            # BUGFIX: previously `lines` stayed unbound after 5 failed opens
            # and the loop below raised NameError.
            continue

        for l in lines:
            row = json.loads(l)

            # Only keep battle vote events with two recorded states.
            if "states" not in row:
                continue
            if row["type"] not in ["leftvote", "rightvote", "bothbad_vote", "tievote"]:
                continue

            # Each side needs at least a (prompt, response) message pair.
            # BUGFIX: previously only "non-empty" was checked on one side
            # before indexing messages[1] on both, risking IndexError.
            if len(row["states"][0]['messages']) < 2 or len(row["states"][1]['messages']) < 2:
                continue

            # Skip battles where either response is missing or errored out.
            if (
                row["states"][0]['messages'][1][1] is None
                or row["states"][1]['messages'][1][1] is None
                or "NETWORK ERROR" in row["states"][0]['messages'][1][1]
                or "NETWORK ERROR" in row["states"][1]['messages'][1][1]
            ):
                continue
            total_vote += 1

            conv_id = row["states"][0]['conv_id']
            image_path = os.path.join(
                "/local/home/yujielu/project/Arena-Elo/vision-arena-logs",
                os.path.basename(filename)[:-5] + "input_images",
                f"input_image_{conv_id}.png",
            )
            if not os.path.exists(image_path):
                continue
            try:
                # Opened only to validate readability; the JSON stores the path.
                Image.open(image_path).convert("RGB")
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt propagates.
                continue

            left_response = row["states"][0]['messages'][1][1]
            right_response = row["states"][1]['messages'][1][1]
            instruction = row["states"][0]['messages'][0][1]
            generated_sentence = f"[The Start of Assistant A’s Answer]\n{left_response}\n[The End of Assistant A’s Answer]\n\n[The Start of Assistant B’s Answer]\n{right_response}\n[The End of Assistant B’s Answer]"
            text_prompt = task_template_map["pair_rate"].format(instruction=instruction, generated_sentence=generated_sentence)

            # Judge prompt becomes the human turn; the human vote is the label.
            conversation = [
                {
                    "from": "human",
                    "value": text_prompt
                },
                {
                    "from": "gpt",
                    "value": row["type"]
                }
            ]

            json_data.append({
                "id": conv_id,
                "image": image_path,
                "conversations": conversation
            })

    # Write the accumulated evaluator records.
    with open('output_evaluator_data.json', 'w') as json_file:
        json.dump(json_data, json_file, indent=2)
+
122
if __name__ == "__main__":
    # CLI: --max-num-files limits the scan to the newest N log files.
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)

    inspect_convs(log_files)
132
+
133
+
134
+
constants.py CHANGED
@@ -26,4 +26,4 @@ TEXT_PROMPT_PATH = "offline/prompts_110.json"
26
  IMAGE_PROMPT_PATH = "offline/image_urls.txt"
27
 
28
  MAX_ATTEMPTS = 5
29
- REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN", "<REDACTED-LEAKED-KEY>")  # NOTE(review): a real Replicate API key was committed on this line; removing it in a later commit does not scrub git history — the key must be revoked and rotated.
 
26
  IMAGE_PROMPT_PATH = "offline/image_urls.txt"
27
 
28
  MAX_ATTEMPTS = 5
29
+ REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN", "yourKey")
model/model_worker.py CHANGED
@@ -6,8 +6,10 @@ from typing import List
6
  import replicate
7
  import subprocess
8
 
9
- from constants import OFFLINE_GIF_DIR
10
- # os.environ("REPLICATE_API_TOKEN", "<REDACTED-LEAKED-KEY>")  # NOTE(review): same leaked Replicate key as in constants.py — revoke and rotate.
 
 
11
 
12
  class BaseModelWorker:
13
  def __init__(self,
 
6
  import replicate
7
  import subprocess
8
 
9
+ from gradio_client import Client
10
+ from client import Gau2Mesh_client
11
+ from constants import OFFLINE_GIF_DIR, REPLICATE_API_TOKEN
12
+ # os.environ("REPLICATE_API_TOKEN", "yourKey")
13
 
14
  class BaseModelWorker:
15
  def __init__(self,
serve/inference.py CHANGED
@@ -188,14 +188,17 @@ def generate_t2s(gen_func, render_func,
188
  state.rgb_video = videos['rgb']
189
  yield state, videos['normal'], videos['rgb']
190
 
 
191
  # logger.info(f"===output===: {output}")
192
  data = {
193
- "ip": ip,
 
194
  "model": model_name,
195
  "type": "offline",
196
  "gen_params": {},
197
  "state": state.dict(),
198
  "start": round(start_time, 4),
 
199
  }
200
  else:
201
  start_time = time.time()
@@ -210,14 +213,17 @@ def generate_t2s(gen_func, render_func,
210
  state.rgb_video = videos['rgb']
211
  yield state, videos['normal'], videos['rgb']
212
 
 
213
  # logger.info(f"===output===: {output}")
214
  data = {
215
- "ip": ip,
 
216
  "model": model_name,
217
  "type": "online",
218
  "gen_params": {},
219
  "state": state.dict(),
220
  "start": round(start_time, 4),
 
221
  "time": round(finish_time - start_time, 4),
222
  "generate_time": round(generate_time, 4),
223
  "render_time": round(render_time, 4),
@@ -277,22 +283,27 @@ def generate_t2s_multi(gen_func, render_func,
277
  state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
278
  yield state_0, state_1,videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
279
 
 
280
  # logger.info(f"===output===: {output}")
281
  data_0 = {
 
282
  "ip": get_ip(request),
283
  "model": model_name_0,
284
  "type": "offline",
285
  "gen_params": {},
286
  "state": state_0.dict(),
287
  "start": round(start_time, 4),
 
288
  }
289
  data_1 = {
 
290
  "ip": get_ip(request),
291
  "model": model_name_1,
292
  "type": "offline",
293
  "gen_params": {},
294
  "state": state_1.dict(),
295
  "start": round(start_time, 4),
 
296
  }
297
  else:
298
  start_time = time.time()
@@ -307,25 +318,30 @@ def generate_t2s_multi(gen_func, render_func,
307
  state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
308
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
309
 
 
310
  # logger.info(f"===output===: {output}")
311
  data_0 = {
 
312
  "ip": get_ip(request),
313
  "model": model_name_0,
314
  "type": "online",
315
  "gen_params": {},
316
  "state": state_0.dict(),
317
  "start": round(start_time, 4),
 
318
  "time": round(finish_time - start_time, 4),
319
  "generate_time": round(generate_time, 4),
320
  "render_time": round(render_time, 4),
321
  }
322
  data_1 = {
 
323
  "ip": get_ip(request),
324
  "model": model_name_1,
325
  "type": "online",
326
  "gen_params": {},
327
  "state": state_1.dict(),
328
  "start": round(start_time, 4),
 
329
  "time": round(finish_time - start_time, 4),
330
  "generate_time": round(generate_time, 4),
331
  "render_time": round(render_time, 4),
@@ -386,22 +402,27 @@ def generate_t2s_multi_annoy(gen_func, render_func,
386
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
387
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
388
 
 
389
  # logger.info(f"===output===: {output}")
390
  data_0 = {
 
391
  "ip": get_ip(request),
392
  "model": model_name_0,
393
  "type": "offline",
394
  "gen_params": {},
395
  "state": state_0.dict(),
396
  "start": round(start_time, 4),
 
397
  }
398
  data_1 = {
 
399
  "ip": get_ip(request),
400
  "model": model_name_1,
401
  "type": "offline",
402
  "gen_params": {},
403
  "state": state_1.dict(),
404
  "start": round(start_time, 4),
 
405
  }
406
  else:
407
  start_time = time.time()
@@ -418,25 +439,30 @@ def generate_t2s_multi_annoy(gen_func, render_func,
418
  yield state_0, state_1, videos_0[0], videos_0[1], videos_1[0], videos_1[1], \
419
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
420
 
 
421
  # logger.info(f"===output===: {output}")
422
  data_0 = {
 
423
  "ip": get_ip(request),
424
  "model": model_name_0,
425
  "type": "online",
426
  "gen_params": {},
427
  "state": state_0.dict(),
428
  "start": round(start_time, 4),
 
429
  "time": round(finish_time - start_time, 4),
430
  "generate_time": round(generate_time, 4),
431
  "render_time": round(render_time, 4),
432
  }
433
  data_1 = {
 
434
  "ip": get_ip(request),
435
  "model": model_name_1,
436
  "type": "online",
437
  "gen_params": {},
438
  "state": state_1.dict(),
439
  "start": round(start_time, 4),
 
440
  "time": round(finish_time - start_time, 4),
441
  "generate_time": round(generate_time, 4),
442
  "render_time": round(render_time, 4),
@@ -481,14 +507,17 @@ def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Re
481
  state.rgb_video = videos['rgb']
482
  yield state, videos['normal'], videos['rgb']
483
 
 
484
  # logger.info(f"===output===: {output}")
485
  data = {
486
- "ip": ip,
 
487
  "model": model_name,
488
  "type": "offline",
489
  "gen_params": {},
490
  "state": state.dict(),
491
  "start": round(start_time, 4),
 
492
  }
493
  else:
494
  start_time = time.time()
@@ -503,14 +532,17 @@ def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Re
503
  state.rgb_video = videos['rgb']
504
  yield state, videos['normal'], videos['rgb']
505
 
 
506
  # logger.info(f"===output===: {output}")
507
  data = {
508
- "ip": ip,
 
509
  "model": model_name,
510
  "type": "online",
511
  "gen_params": {},
512
  "state": state.dict(),
513
  "start": round(start_time, 4),
 
514
  "time": round(finish_time - start_time, 4),
515
  "generate_time": round(generate_time, 4),
516
  "render_time": round(render_time, 4),
@@ -567,22 +599,27 @@ def generate_i2s_multi(gen_func, render_func,
567
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
568
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
569
 
 
570
  # logger.info(f"===output===: {output}")
571
  data_0 = {
 
572
  "ip": get_ip(request),
573
  "model": model_name_0,
574
  "type": "offline",
575
  "gen_params": {},
576
  "state": state_0.dict(),
577
  "start": round(start_time, 4),
 
578
  }
579
  data_1 = {
 
580
  "ip": get_ip(request),
581
  "model": model_name_1,
582
  "type": "offline",
583
  "gen_params": {},
584
  "state": state_1.dict(),
585
  "start": round(start_time, 4),
 
586
  }
587
  else:
588
  start_time = time.time()
@@ -597,25 +634,30 @@ def generate_i2s_multi(gen_func, render_func,
597
  state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
598
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
599
 
 
600
  # logger.info(f"===output===: {output}")
601
  data_0 = {
 
602
  "ip": get_ip(request),
603
  "model": model_name_0,
604
  "type": "online",
605
  "gen_params": {},
606
  "state": state_0.dict(),
607
  "start": round(start_time, 4),
 
608
  "time": round(finish_time - start_time, 4),
609
  "generate_time": round(generate_time, 4),
610
  "render_time": round(render_time, 4),
611
  }
612
  data_1 = {
 
613
  "ip": get_ip(request),
614
  "model": model_name_1,
615
  "type": "online",
616
  "gen_params": {},
617
  "state": state_1.dict(),
618
  "start": round(start_time, 4),
 
619
  "time": round(finish_time - start_time, 4),
620
  "generate_time": round(generate_time, 4),
621
  "render_time": round(render_time, 4),
@@ -672,26 +714,31 @@ def generate_i2s_multi_annoy(gen_func, render_func,
672
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
673
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
674
 
 
675
  # logger.info(f"===output===: {output}")
676
  data_0 = {
 
677
  "ip": get_ip(request),
678
  "model": model_name_0,
679
  "type": "offline",
680
  "gen_params": {},
681
  "state": state_0.dict(),
682
  "start": round(start_time, 4),
 
683
  }
684
  data_1 = {
 
685
  "ip": get_ip(request),
686
  "model": model_name_1,
687
  "type": "offline",
688
  "gen_params": {},
689
  "state": state_1.dict(),
690
  "start": round(start_time, 4),
 
691
  }
692
  else:
693
  start_time = time.time()
694
- shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
695
  generate_time = time.time() - start_time
696
 
697
  videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
@@ -703,25 +750,30 @@ def generate_i2s_multi_annoy(gen_func, render_func,
703
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
704
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
705
 
 
706
  # logger.info(f"===output===: {output}")
707
  data_0 = {
 
708
  "ip": get_ip(request),
709
  "model": model_name_0,
710
  "type": "online",
711
  "gen_params": {},
712
  "state": state_0.dict(),
713
  "start": round(start_time, 4),
 
714
  "time": round(finish_time - start_time, 4),
715
  "generate_time": round(generate_time, 4),
716
  "render_time": round(render_time, 4),
717
  }
718
  data_1 = {
 
719
  "ip": get_ip(request),
720
  "model": model_name_1,
721
  "type": "online",
722
  "gen_params": {},
723
  "state": state_1.dict(),
724
  "start": round(start_time, 4),
 
725
  "time": round(finish_time - start_time, 4),
726
  "generate_time": round(generate_time, 4),
727
  "render_time": round(render_time, 4),
 
188
  state.rgb_video = videos['rgb']
189
  yield state, videos['normal'], videos['rgb']
190
 
191
+ finish_tstamp = time.time()
192
  # logger.info(f"===output===: {output}")
193
  data = {
194
+ "tstamp": round(finish_tstamp, 4),
195
+ "ip": get_ip(request),
196
  "model": model_name,
197
  "type": "offline",
198
  "gen_params": {},
199
  "state": state.dict(),
200
  "start": round(start_time, 4),
201
+ "finish": round(finish_tstamp, 4),
202
  }
203
  else:
204
  start_time = time.time()
 
213
  state.rgb_video = videos['rgb']
214
  yield state, videos['normal'], videos['rgb']
215
 
216
+ finish_tstamp = time.time()
217
  # logger.info(f"===output===: {output}")
218
  data = {
219
+ "tstamp": round(finish_tstamp, 4),
220
+ "ip": get_ip(request),
221
  "model": model_name,
222
  "type": "online",
223
  "gen_params": {},
224
  "state": state.dict(),
225
  "start": round(start_time, 4),
226
+ "finish": round(finish_tstamp, 4),
227
  "time": round(finish_time - start_time, 4),
228
  "generate_time": round(generate_time, 4),
229
  "render_time": round(render_time, 4),
 
283
  state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
284
  yield state_0, state_1,videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
285
 
286
+ finish_tstamp = time.time()
287
  # logger.info(f"===output===: {output}")
288
  data_0 = {
289
+ "tstamp": round(finish_tstamp, 4),
290
  "ip": get_ip(request),
291
  "model": model_name_0,
292
  "type": "offline",
293
  "gen_params": {},
294
  "state": state_0.dict(),
295
  "start": round(start_time, 4),
296
+ "finish": round(finish_tstamp, 4),
297
  }
298
  data_1 = {
299
+ "tstamp": round(finish_tstamp, 4),
300
  "ip": get_ip(request),
301
  "model": model_name_1,
302
  "type": "offline",
303
  "gen_params": {},
304
  "state": state_1.dict(),
305
  "start": round(start_time, 4),
306
+ "finish": round(finish_tstamp, 4),
307
  }
308
  else:
309
  start_time = time.time()
 
318
  state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
319
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
320
 
321
+ finish_tstamp = time.time()
322
  # logger.info(f"===output===: {output}")
323
  data_0 = {
324
+ "tstamp": round(finish_tstamp, 4),
325
  "ip": get_ip(request),
326
  "model": model_name_0,
327
  "type": "online",
328
  "gen_params": {},
329
  "state": state_0.dict(),
330
  "start": round(start_time, 4),
331
+ "finish": round(finish_tstamp, 4),
332
  "time": round(finish_time - start_time, 4),
333
  "generate_time": round(generate_time, 4),
334
  "render_time": round(render_time, 4),
335
  }
336
  data_1 = {
337
+ "tstamp": round(finish_tstamp, 4),
338
  "ip": get_ip(request),
339
  "model": model_name_1,
340
  "type": "online",
341
  "gen_params": {},
342
  "state": state_1.dict(),
343
  "start": round(start_time, 4),
344
+ "finish": round(finish_tstamp, 4),
345
  "time": round(finish_time - start_time, 4),
346
  "generate_time": round(generate_time, 4),
347
  "render_time": round(render_time, 4),
 
402
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
403
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
404
 
405
+ finish_tstamp = time.time()
406
  # logger.info(f"===output===: {output}")
407
  data_0 = {
408
+ "tstamp": round(finish_tstamp, 4),
409
  "ip": get_ip(request),
410
  "model": model_name_0,
411
  "type": "offline",
412
  "gen_params": {},
413
  "state": state_0.dict(),
414
  "start": round(start_time, 4),
415
+ "finish": round(finish_tstamp, 4),
416
  }
417
  data_1 = {
418
+ "tstamp": round(finish_tstamp, 4),
419
  "ip": get_ip(request),
420
  "model": model_name_1,
421
  "type": "offline",
422
  "gen_params": {},
423
  "state": state_1.dict(),
424
  "start": round(start_time, 4),
425
+ "finish": round(finish_tstamp, 4),
426
  }
427
  else:
428
  start_time = time.time()
 
439
  yield state_0, state_1, videos_0[0], videos_0[1], videos_1[0], videos_1[1], \
440
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
441
 
442
+ finish_tstamp = time.time()
443
  # logger.info(f"===output===: {output}")
444
  data_0 = {
445
+ "tstamp": round(finish_tstamp, 4),
446
  "ip": get_ip(request),
447
  "model": model_name_0,
448
  "type": "online",
449
  "gen_params": {},
450
  "state": state_0.dict(),
451
  "start": round(start_time, 4),
452
+ "finish": round(finish_tstamp, 4),
453
  "time": round(finish_time - start_time, 4),
454
  "generate_time": round(generate_time, 4),
455
  "render_time": round(render_time, 4),
456
  }
457
  data_1 = {
458
+ "tstamp": round(finish_tstamp, 4),
459
  "ip": get_ip(request),
460
  "model": model_name_1,
461
  "type": "online",
462
  "gen_params": {},
463
  "state": state_1.dict(),
464
  "start": round(start_time, 4),
465
+ "finish": round(finish_tstamp, 4),
466
  "time": round(finish_time - start_time, 4),
467
  "generate_time": round(generate_time, 4),
468
  "render_time": round(render_time, 4),
 
507
  state.rgb_video = videos['rgb']
508
  yield state, videos['normal'], videos['rgb']
509
 
510
+ finish_tstamp = time.time()
511
  # logger.info(f"===output===: {output}")
512
  data = {
513
+ "tstamp": round(finish_tstamp, 4),
514
+ "ip": get_ip(request),
515
  "model": model_name,
516
  "type": "offline",
517
  "gen_params": {},
518
  "state": state.dict(),
519
  "start": round(start_time, 4),
520
+ "finish": round(finish_tstamp, 4),
521
  }
522
  else:
523
  start_time = time.time()
 
532
  state.rgb_video = videos['rgb']
533
  yield state, videos['normal'], videos['rgb']
534
 
535
+ finish_tstamp = time.time()
536
  # logger.info(f"===output===: {output}")
537
  data = {
538
+ "tstamp": round(finish_tstamp, 4),
539
+ "ip": get_ip(request),
540
  "model": model_name,
541
  "type": "online",
542
  "gen_params": {},
543
  "state": state.dict(),
544
  "start": round(start_time, 4),
545
+ "finish": round(finish_tstamp, 4),
546
  "time": round(finish_time - start_time, 4),
547
  "generate_time": round(generate_time, 4),
548
  "render_time": round(render_time, 4),
 
599
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
600
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
601
 
602
+ finish_tstamp = time.time()
603
  # logger.info(f"===output===: {output}")
604
  data_0 = {
605
+ "tstamp": round(finish_tstamp, 4),
606
  "ip": get_ip(request),
607
  "model": model_name_0,
608
  "type": "offline",
609
  "gen_params": {},
610
  "state": state_0.dict(),
611
  "start": round(start_time, 4),
612
+ "finish": round(finish_tstamp, 4),
613
  }
614
  data_1 = {
615
+ "tstamp": round(finish_tstamp, 4),
616
  "ip": get_ip(request),
617
  "model": model_name_1,
618
  "type": "offline",
619
  "gen_params": {},
620
  "state": state_1.dict(),
621
  "start": round(start_time, 4),
622
+ "finish": round(finish_tstamp, 4),
623
  }
624
  else:
625
  start_time = time.time()
 
634
  state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
635
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
636
 
637
+ finish_tstamp = time.time()
638
  # logger.info(f"===output===: {output}")
639
  data_0 = {
640
+ "tstamp": round(finish_tstamp, 4),
641
  "ip": get_ip(request),
642
  "model": model_name_0,
643
  "type": "online",
644
  "gen_params": {},
645
  "state": state_0.dict(),
646
  "start": round(start_time, 4),
647
+ "finish": round(finish_tstamp, 4),
648
  "time": round(finish_time - start_time, 4),
649
  "generate_time": round(generate_time, 4),
650
  "render_time": round(render_time, 4),
651
  }
652
  data_1 = {
653
+ "tstamp": round(finish_tstamp, 4),
654
  "ip": get_ip(request),
655
  "model": model_name_1,
656
  "type": "online",
657
  "gen_params": {},
658
  "state": state_1.dict(),
659
  "start": round(start_time, 4),
660
+ "finish": round(finish_tstamp, 4),
661
  "time": round(finish_time - start_time, 4),
662
  "generate_time": round(generate_time, 4),
663
  "render_time": round(render_time, 4),
 
714
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
715
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
716
 
717
+ finish_tstamp = time.time()
718
  # logger.info(f"===output===: {output}")
719
  data_0 = {
720
+ "tstamp": round(finish_tstamp, 4),
721
  "ip": get_ip(request),
722
  "model": model_name_0,
723
  "type": "offline",
724
  "gen_params": {},
725
  "state": state_0.dict(),
726
  "start": round(start_time, 4),
727
+ "finish": round(finish_tstamp, 4),
728
  }
729
  data_1 = {
730
+ "tstamp": round(finish_tstamp, 4),
731
  "ip": get_ip(request),
732
  "model": model_name_1,
733
  "type": "offline",
734
  "gen_params": {},
735
  "state": state_1.dict(),
736
  "start": round(start_time, 4),
737
+ "finish": round(finish_tstamp, 4),
738
  }
739
  else:
740
  start_time = time.time()
741
+ shape_0, shape_1 = gen_func(image, model_name_0, model_name_1, i2s_model=True)
742
  generate_time = time.time() - start_time
743
 
744
  videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
 
750
  yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
751
  gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
752
 
753
+ finish_tstamp = time.time()
754
  # logger.info(f"===output===: {output}")
755
  data_0 = {
756
+ "tstamp": round(finish_tstamp, 4),
757
  "ip": get_ip(request),
758
  "model": model_name_0,
759
  "type": "online",
760
  "gen_params": {},
761
  "state": state_0.dict(),
762
  "start": round(start_time, 4),
763
+ "finish": round(finish_tstamp, 4),
764
  "time": round(finish_time - start_time, 4),
765
  "generate_time": round(generate_time, 4),
766
  "render_time": round(render_time, 4),
767
  }
768
  data_1 = {
769
+ "tstamp": round(finish_tstamp, 4),
770
  "ip": get_ip(request),
771
  "model": model_name_1,
772
  "type": "online",
773
  "gen_params": {},
774
  "state": state_1.dict(),
775
  "start": round(start_time, 4),
776
+ "finish": round(finish_tstamp, 4),
777
  "time": round(finish_time - start_time, 4),
778
  "generate_time": round(generate_time, 4),
779
  "render_time": round(render_time, 4),
serve/leaderboard.py CHANGED
@@ -39,8 +39,6 @@ leader_component_values = [None] * 5
39
  def make_leaderboard_md(elo_results):
40
  leaderboard_md = f"""
41
  # 🏆 GenAI-Arena Leaderboard
42
- | [GitHub](https://github.com/TIGER-AI-Lab/ImagenHub) | [Dataset](https://huggingface.co/ImagenHub) | [Twitter](https://twitter.com/TianleLI123/status/1757245259149422752) |
43
-
44
  """
45
  return leaderboard_md
46
 
 
39
  def make_leaderboard_md(elo_results):
40
  leaderboard_md = f"""
41
  # 🏆 GenAI-Arena Leaderboard
 
 
42
  """
43
  return leaderboard_md
44
 
serve/vote_utils.py CHANGED
@@ -26,12 +26,13 @@ def vote_last_response_t2s(state, dim, vote_type, model_selector, request: gr.Re
26
  fout.write(json.dumps(data) + "\n")
27
  append_json_item_on_log_server(data, get_conv_log_filename())
28
 
29
- def vote_last_response_t2s_multi(states, dim, vote_type, model_selectors, request: gr.Request):
30
  with open(get_conv_log_filename(), "a") as fout:
31
  data = {
32
  "tstamp": round(time.time(), 4),
33
  "dim": dim,
34
  "type": vote_type,
 
35
  "models": [x for x in model_selectors],
36
  "states": [x.dict() for x in states],
37
  "ip": get_ip(request),
@@ -65,12 +66,13 @@ def vote_last_response_i2s(state, dim, vote_type, model_selector, request: gr.Re
65
  # save_image_file_on_log_server(output_file)
66
  # save_image_file_on_log_server(source_file)
67
 
68
- def vote_last_response_i2s_multi(states, dim, vote_type, model_selectors, request: gr.Request):
69
  with open(get_conv_log_filename(), "a") as fout:
70
  data = {
71
  "tstamp": round(time.time(), 4),
72
  "dim": dim,
73
  "type": vote_type,
 
74
  "models": [x for x in model_selectors],
75
  "states": [x.dict() for x in states],
76
  "ip": get_ip(request),
@@ -91,23 +93,23 @@ def vote_last_response_i2s_multi(states, dim, vote_type, model_selectors, reques
91
  ## Text-to-Shape Generation (t2s) Single Model Direct Chat
92
  def upvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
93
  ip = get_ip(request)
 
94
  t2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
95
  vote_last_response_t2s(state, dim_md, "upvote", model_selector, request)
96
- state.evaluted_dims += 1
97
  return (state,) + (disable_btn,) * 3
98
 
99
  def downvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
100
  ip = get_ip(request)
 
101
  t2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
102
  vote_last_response_t2s(state, dim_md, "downvote", model_selector, request)
103
- state.evaluted_dims += 1
104
  return (state,) + (disable_btn,) * 3
105
 
106
  def flag_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
107
  ip = get_ip(request)
 
108
  t2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
109
  vote_last_response_t2s(state, dim_md, "flag", model_selector, request)
110
- state.evaluted_dims += 1
111
  return (state,) + (disable_btn,) * 3
112
 
113
 
@@ -115,132 +117,153 @@ def flag_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
115
  def leftvote_last_response_t2s_named(
116
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
117
  ):
 
 
118
  t2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
119
  vote_last_response_t2s_multi(
120
- [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
121
  )
122
- state0.evaluted_dims += 1
123
- state1.evaluted_dims += 1
124
  return (state0, state1) + (disable_btn,) * 4
125
 
126
  def rightvote_last_response_t2s_named(
127
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
128
  ):
 
 
129
  t2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
130
  vote_last_response_t2s_multi(
131
- [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
132
  )
133
- state0.evaluted_dims += 1
134
- state1.evaluted_dims += 1
135
  return (state0, state1) + (disable_btn,) * 4
136
 
137
  def tievote_last_response_t2s_named(
138
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
139
  ):
 
 
140
  t2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
141
  vote_last_response_t2s_multi(
142
- [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
143
  )
144
- state0.evaluted_dims += 1
145
- state1.evaluted_dims += 1
146
  return (state0, state1) + (disable_btn,) * 4
147
 
148
  def bothbad_vote_last_response_t2s_named(
149
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
150
  ):
 
 
151
  t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
152
  vote_last_response_t2s_multi(
153
- [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
154
  )
155
- state0.evaluted_dims += 1
156
- state1.evaluted_dims += 1
157
  return (state0, state1) + (disable_btn,) * 4
158
 
159
 
160
  def leftvote_last_response_t2s_anony(
161
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
162
  ):
 
 
163
  t2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
164
  vote_last_response_t2s_multi(
165
- [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
166
  )
167
 
168
- state0.evaluted_dims += 1
169
- state1.evaluted_dims += 1
170
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
171
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
172
- return (state0, state1) + (disable_btn,) * 4 + names
173
- else:
174
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
 
175
 
176
  def rightvote_last_response_t2s_anony(
177
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
178
  ):
 
 
179
  t2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
180
  vote_last_response_t2s_multi(
181
- [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
182
  )
183
 
184
- state0.evaluted_dims += 1
185
- state1.evaluted_dims += 1
186
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
187
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
188
- return (state0, state1) + (disable_btn,) * 4 + names
189
- else:
190
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
191
 
192
  def tievote_last_response_t2s_anony(
193
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
194
  ):
 
 
195
  t2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
196
  vote_last_response_t2s_multi(
197
- [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
198
  )
199
 
200
- state0.evaluted_dims += 1
201
- state1.evaluted_dims += 1
202
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
203
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
204
- return (state0, state1) + (disable_btn,) * 4 + names
205
- else:
206
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
207
 
208
  def bothbad_vote_last_response_t2s_anony(
209
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
210
  ):
 
 
211
  t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
212
  vote_last_response_t2s_multi(
213
- [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
214
  )
215
 
216
- state0.evaluted_dims += 1
217
- state1.evaluted_dims += 1
218
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
219
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
220
- return (state0, state1) + (disable_btn,) * 4 + names
221
- else:
222
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
223
 
224
  ## Image-to-Shape (i2s) Single Model Direct Chat
225
  def upvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
226
  ip = get_ip(request)
 
227
  i2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
228
  vote_last_response_i2s(state, dim_md, "upvote", model_selector, request)
229
- state.evaluted_dims += 1
230
  return (state,) + (disable_btn,) * 3
231
 
232
  def downvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
233
  ip = get_ip(request)
 
234
  i2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
235
  vote_last_response_i2s(state, dim_md, "downvote", model_selector, request)
236
- state.evaluted_dims += 1
237
  return (state,) + (disable_btn,) * 3
238
 
239
  def flag_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
240
  ip = get_ip(request)
 
241
  i2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
242
  vote_last_response_i2s(state, dim_md, "flag", model_selector, request)
243
- state.evaluted_dims += 1
244
  return (state,) + (disable_btn,) * 3
245
 
246
 
@@ -248,114 +271,134 @@ def flag_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
248
  def leftvote_last_response_i2s_named(
249
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
250
  ):
 
 
251
  i2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
252
  vote_last_response_i2s_multi(
253
- [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
254
  )
255
- state0.evaluted_dims += 1
256
- state1.evaluted_dims += 1
257
  return (state0, state1) + (disable_btn,) * 4
258
 
259
  def rightvote_last_response_i2s_named(
260
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
261
  ):
 
 
262
  i2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
263
  vote_last_response_i2s_multi(
264
- [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
265
  )
266
- state0.evaluted_dims += 1
267
- state1.evaluted_dims += 1
268
  return (state0, state1) + (disable_btn,) * 4
269
 
270
  def tievote_last_response_i2s_named(
271
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
272
  ):
 
 
273
  i2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
274
  vote_last_response_i2s_multi(
275
- [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
276
  )
277
- state0.evaluted_dims += 1
278
- state1.evaluted_dims += 1
279
  return (state0, state1) + (disable_btn,) * 4
280
 
281
  def bothbad_vote_last_response_i2s_named(
282
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
283
  ):
 
 
284
  i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
285
  vote_last_response_i2s_multi(
286
- [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
287
  )
288
- state0.evaluted_dims += 1
289
- state1.evaluted_dims += 1
290
  return (state0, state1) + (disable_btn,) * 4
291
 
292
 
293
  def leftvote_last_response_i2s_anony(
294
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
295
  ):
 
 
296
  i2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
297
  vote_last_response_i2s_multi(
298
- [state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
299
  )
300
 
301
- state0.evaluted_dims += 1
302
- state1.evaluted_dims += 1
303
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
304
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
305
- return (state0, state1) + (disable_btn,) * 4 + names
306
- else:
307
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
308
 
309
 
310
  def rightvote_last_response_i2s_anony(
311
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
312
  ):
 
 
313
  i2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
314
  vote_last_response_i2s_multi(
315
- [state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
316
  )
317
 
318
- state0.evaluted_dims += 1
319
- state1.evaluted_dims += 1
320
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
321
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
322
- return (state0, state1) + (disable_btn,) * 4 + names
323
- else:
324
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
325
 
326
 
327
  def tievote_last_response_i2s_anony(
328
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
329
  ):
 
 
330
  i2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
331
  vote_last_response_i2s_multi(
332
- [state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
333
  )
334
 
335
- state0.evaluted_dims += 1
336
- state1.evaluted_dims += 1
337
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
338
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
339
- return (state0, state1) + (disable_btn,) * 4 + names
340
- else:
341
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
342
 
343
 
344
  def bothbad_vote_last_response_i2s_anony(
345
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
346
  ):
 
 
347
  i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
348
  vote_last_response_i2s_multi(
349
- [state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
350
  )
351
 
352
- state0.evaluted_dims += 1
353
- state1.evaluted_dims += 1
354
- if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
355
- names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
356
- return (state0, state1) + (disable_btn,) * 4 + names
357
- else:
358
- return (state0, state1) + (disable_btn,) * 4 + ("", "")
 
 
 
359
 
360
 
361
 
@@ -383,13 +426,13 @@ def share_click_t2s_multi(state0, state1, model_selector0, model_selector1, requ
383
  t2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
384
  if state0 is not None and state1 is not None:
385
  vote_last_response_t2s_multi(
386
- [state0, state1], "share", [model_selector0, model_selector1], request
387
  )
388
 
389
  def share_click_i2s_multi(state0, state1, model_selector0, model_selector1, request: gr.Request):
390
  i2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
391
  if state0 is not None and state1 is not None:
392
  vote_last_response_i2s_multi(
393
- [state0, state1], "share", [model_selector0, model_selector1], request
394
  )
395
 
 
26
  fout.write(json.dumps(data) + "\n")
27
  append_json_item_on_log_server(data, get_conv_log_filename())
28
 
29
+ def vote_last_response_t2s_multi(states, dim, vote_type, is_anony: bool, model_selectors, request: gr.Request):
30
  with open(get_conv_log_filename(), "a") as fout:
31
  data = {
32
  "tstamp": round(time.time(), 4),
33
  "dim": dim,
34
  "type": vote_type,
35
+ "anony": is_anony,
36
  "models": [x for x in model_selectors],
37
  "states": [x.dict() for x in states],
38
  "ip": get_ip(request),
 
66
  # save_image_file_on_log_server(output_file)
67
  # save_image_file_on_log_server(source_file)
68
 
69
+ def vote_last_response_i2s_multi(states, dim, vote_type, is_anony, model_selectors, request: gr.Request):
70
  with open(get_conv_log_filename(), "a") as fout:
71
  data = {
72
  "tstamp": round(time.time(), 4),
73
  "dim": dim,
74
  "type": vote_type,
75
+ "anony": is_anony,
76
  "models": [x for x in model_selectors],
77
  "states": [x.dict() for x in states],
78
  "ip": get_ip(request),
 
93
  ## Text-to-Shape Generation (t2s) Single Model Direct Chat
94
  def upvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
95
  ip = get_ip(request)
96
+ state.evaluted_dims += 1
97
  t2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
98
  vote_last_response_t2s(state, dim_md, "upvote", model_selector, request)
 
99
  return (state,) + (disable_btn,) * 3
100
 
101
  def downvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
102
  ip = get_ip(request)
103
+ state.evaluted_dims += 1
104
  t2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
105
  vote_last_response_t2s(state, dim_md, "downvote", model_selector, request)
 
106
  return (state,) + (disable_btn,) * 3
107
 
108
  def flag_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
109
  ip = get_ip(request)
110
+ state.evaluted_dims += 1
111
  t2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
112
  vote_last_response_t2s(state, dim_md, "flag", model_selector, request)
 
113
  return (state,) + (disable_btn,) * 3
114
 
115
 
 
117
  def leftvote_last_response_t2s_named(
118
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
119
  ):
120
+ state0.evaluted_dims += 1
121
+ state1.evaluted_dims += 1
122
  t2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
123
  vote_last_response_t2s_multi(
124
+ [state0, state1], dim_md, "leftvote", False, [model_selector0, model_selector1], request
125
  )
 
 
126
  return (state0, state1) + (disable_btn,) * 4
127
 
128
  def rightvote_last_response_t2s_named(
129
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
130
  ):
131
+ state0.evaluted_dims += 1
132
+ state1.evaluted_dims += 1
133
  t2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
134
  vote_last_response_t2s_multi(
135
+ [state0, state1], dim_md, "rightvote", False, [model_selector0, model_selector1], request
136
  )
 
 
137
  return (state0, state1) + (disable_btn,) * 4
138
 
139
  def tievote_last_response_t2s_named(
140
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
141
  ):
142
+ state0.evaluted_dims += 1
143
+ state1.evaluted_dims += 1
144
  t2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
145
  vote_last_response_t2s_multi(
146
+ [state0, state1], dim_md, "tievote", False, [model_selector0, model_selector1], request
147
  )
 
 
148
  return (state0, state1) + (disable_btn,) * 4
149
 
150
  def bothbad_vote_last_response_t2s_named(
151
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
152
  ):
153
+ state0.evaluted_dims += 1
154
+ state1.evaluted_dims += 1
155
  t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
156
  vote_last_response_t2s_multi(
157
+ [state0, state1], dim_md, "bothbad_vote", False, [model_selector0, model_selector1], request
158
  )
 
 
159
  return (state0, state1) + (disable_btn,) * 4
160
 
161
 
162
  def leftvote_last_response_t2s_anony(
163
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
164
  ):
165
+ state0.evaluted_dims += 1
166
+ state1.evaluted_dims += 1
167
  t2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
168
  vote_last_response_t2s_multi(
169
+ [state0, state1], dim_md, "leftvote", True, [model_selector0, model_selector1], request
170
  )
171
 
172
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
173
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
174
+ # return (state0, state1) + (disable_btn,) * 4 + names
175
+ # else:
176
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=False), gr.Markdown(f"### Model B: {state1.model_name}", visible=False))
177
+ # return (state0, state1) + (disable_btn,) * 4 + names
178
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
179
+ return (state0, state1) \
180
+ + (disable_btn,) * 4 \
181
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
182
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
183
 
184
  def rightvote_last_response_t2s_anony(
185
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
186
  ):
187
+ state0.evaluted_dims += 1
188
+ state1.evaluted_dims += 1
189
  t2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
190
  vote_last_response_t2s_multi(
191
+ [state0, state1], dim_md, "rightvote", True, [model_selector0, model_selector1], request
192
  )
193
 
194
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
195
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
196
+ # return (state0, state1) + (disable_btn,) * 4 + names
197
+ # else:
198
+ # return (state0, state1) + (disable_btn,) * 4 + ("", "")
199
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
200
+ return (state0, state1) \
201
+ + (disable_btn,) * 4 \
202
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
203
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
204
 
205
  def tievote_last_response_t2s_anony(
206
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
207
  ):
208
+ state0.evaluted_dims += 1
209
+ state1.evaluted_dims += 1
210
  t2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
211
  vote_last_response_t2s_multi(
212
+ [state0, state1], dim_md, "tievote", True, [model_selector0, model_selector1], request
213
  )
214
 
215
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
216
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
217
+ # return (state0, state1) + (disable_btn,) * 4 + names
218
+ # else:
219
+ # return (state0, state1) + (disable_btn,) * 4 + ("", "")
220
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
221
+ return (state0, state1) \
222
+ + (disable_btn,) * 4 \
223
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
224
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
225
 
226
  def bothbad_vote_last_response_t2s_anony(
227
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
228
  ):
229
+ state0.evaluted_dims += 1
230
+ state1.evaluted_dims += 1
231
  t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
232
  vote_last_response_t2s_multi(
233
+ [state0, state1], dim_md, "bothbad_vote", True, [model_selector0, model_selector1], request
234
  )
235
 
236
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
237
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
238
+ # return (state0, state1) + (disable_btn,) * 4 + names
239
+ # else:
240
+ # return (state0, state1) + (disable_btn,) * 4 + ("", "")
241
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
242
+ return (state0, state1) \
243
+ + (disable_btn,) * 4 \
244
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
245
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
246
 
247
  ## Image-to-Shape (i2s) Single Model Direct Chat
248
  def upvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
249
  ip = get_ip(request)
250
+ state.evaluted_dims += 1
251
  i2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
252
  vote_last_response_i2s(state, dim_md, "upvote", model_selector, request)
 
253
  return (state,) + (disable_btn,) * 3
254
 
255
  def downvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
256
  ip = get_ip(request)
257
+ state.evaluted_dims += 1
258
  i2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
259
  vote_last_response_i2s(state, dim_md, "downvote", model_selector, request)
 
260
  return (state,) + (disable_btn,) * 3
261
 
262
  def flag_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
263
  ip = get_ip(request)
264
+ state.evaluted_dims += 1
265
  i2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
266
  vote_last_response_i2s(state, dim_md, "flag", model_selector, request)
 
267
  return (state,) + (disable_btn,) * 3
268
 
269
 
 
271
  def leftvote_last_response_i2s_named(
272
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
273
  ):
274
+ state0.evaluted_dims += 1
275
+ state1.evaluted_dims += 1
276
  i2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
277
  vote_last_response_i2s_multi(
278
+ [state0, state1], dim_md, "leftvote", False, [model_selector0, model_selector1], request
279
  )
 
 
280
  return (state0, state1) + (disable_btn,) * 4
281
 
282
  def rightvote_last_response_i2s_named(
283
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
284
  ):
285
+ state0.evaluted_dims += 1
286
+ state1.evaluted_dims += 1
287
  i2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
288
  vote_last_response_i2s_multi(
289
+ [state0, state1], dim_md, "rightvote", False, [model_selector0, model_selector1], request
290
  )
 
 
291
  return (state0, state1) + (disable_btn,) * 4
292
 
293
  def tievote_last_response_i2s_named(
294
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
295
  ):
296
+ state0.evaluted_dims += 1
297
+ state1.evaluted_dims += 1
298
  i2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
299
  vote_last_response_i2s_multi(
300
+ [state0, state1], dim_md, "tievote", False, [model_selector0, model_selector1], request
301
  )
 
 
302
  return (state0, state1) + (disable_btn,) * 4
303
 
304
  def bothbad_vote_last_response_i2s_named(
305
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
306
  ):
307
+ state0.evaluted_dims += 1
308
+ state1.evaluted_dims += 1
309
  i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
310
  vote_last_response_i2s_multi(
311
+ [state0, state1], dim_md, "bothbad_vote", False, [model_selector0, model_selector1], request
312
  )
 
 
313
  return (state0, state1) + (disable_btn,) * 4
314
 
315
 
316
  def leftvote_last_response_i2s_anony(
317
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
318
  ):
319
+ state0.evaluted_dims += 1
320
+ state1.evaluted_dims += 1
321
  i2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
322
  vote_last_response_i2s_multi(
323
+ [state0, state1], dim_md, "leftvote", True, [model_selector0, model_selector1], request
324
  )
325
 
326
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
327
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
328
+ # return (state0, state1) + (disable_btn,) * 4 + names
329
+ # else:
330
+ # return (state0, state1) + (disable_btn,) * 4 + ("", "")
331
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
332
+ return (state0, state1) \
333
+ + (disable_btn,) * 4 \
334
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
335
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
336
 
337
 
338
  def rightvote_last_response_i2s_anony(
339
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
340
  ):
341
+ state0.evaluted_dims += 1
342
+ state1.evaluted_dims += 1
343
  i2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
344
  vote_last_response_i2s_multi(
345
+ [state0, state1], dim_md, "rightvote", True, [model_selector0, model_selector1], request
346
  )
347
 
348
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
349
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
350
+ # return (state0, state1) + (disable_btn,) * 4 + names
351
+ # else:
352
+ # return (state0, state1) + (disable_btn,) * 4 + ("", "")
353
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
354
+ return (state0, state1) \
355
+ + (disable_btn,) * 4 \
356
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
357
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
358
 
359
 
360
  def tievote_last_response_i2s_anony(
361
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
362
  ):
363
+ state0.evaluted_dims += 1
364
+ state1.evaluted_dims += 1
365
  i2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
366
  vote_last_response_i2s_multi(
367
+ [state0, state1], dim_md, "tievote", True, [model_selector0, model_selector1], request
368
  )
369
 
370
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
371
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
372
+ # return (state0, state1) + (disable_btn,) * 4 + names
373
+ # else:
374
+ # return (state0, state1) + (disable_btn,) * 4 + ("", "")
375
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
376
+ return (state0, state1) \
377
+ + (disable_btn,) * 4 \
378
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
379
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
380
 
381
 
382
  def bothbad_vote_last_response_i2s_anony(
383
  state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
384
  ):
385
+ state0.evaluted_dims += 1
386
+ state1.evaluted_dims += 1
387
  i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
388
  vote_last_response_i2s_multi(
389
+ [state0, state1], dim_md, "bothbad_vote", True, [model_selector0, model_selector1], request
390
  )
391
 
392
+ # if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
393
+ # names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
394
+ # return (state0, state1) + (disable_btn,) * 4 + names
395
+ # else:
396
+ # return (state0, state1) + (disable_btn,) * 4 + ("", "")
397
+ is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
398
+ return (state0, state1) \
399
+ + (disable_btn,) * 4 \
400
+ + (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
401
+ gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
402
 
403
 
404
 
 
426
  t2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
427
  if state0 is not None and state1 is not None:
428
  vote_last_response_t2s_multi(
429
+ [state0, state1], "", "share", True, [model_selector0, model_selector1], request
430
  )
431
 
432
  def share_click_i2s_multi(state0, state1, model_selector0, model_selector1, request: gr.Request):
433
  i2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
434
  if state0 is not None and state1 is not None:
435
  vote_last_response_i2s_multi(
436
+ [state0, state1], "", "share", True, [model_selector0, model_selector1], request
437
  )
438