Spaces:
Running
Running
ZhangYuhan
commited on
Commit
•
6dc2db5
1
Parent(s):
a8e2ac2
update serve
Browse files- arena_elo/elo_rating/__init__.py +0 -0
- arena_elo/elo_rating/basic_stats.py +227 -0
- arena_elo/elo_rating/clean_battle_data.py +382 -0
- arena_elo/elo_rating/elo_analysis.py +378 -0
- arena_elo/elo_rating/generate_leaderboard.py +71 -0
- arena_elo/elo_rating/inspect_conv_rating.py +234 -0
- arena_elo/elo_rating/inspect_cost.py +177 -0
- arena_elo/elo_rating/inspect_elo_rating_pkl.py +33 -0
- arena_elo/elo_rating/model_registry.py +578 -0
- arena_elo/elo_rating/upload_battle_data.py +193 -0
- arena_elo/elo_rating/utils.py +83 -0
- arena_elo/evaluator/convert_to_evaluator_data.py +134 -0
- constants.py +1 -1
- model/model_worker.py +4 -2
- serve/inference.py +57 -5
- serve/leaderboard.py +0 -2
- serve/vote_utils.py +141 -98
arena_elo/elo_rating/__init__.py
ADDED
File without changes
|
arena_elo/elo_rating/basic_stats.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import code
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
from pytz import timezone
|
7 |
+
import time
|
8 |
+
|
9 |
+
import pandas as pd # pandas>=2.0.3
|
10 |
+
import plotly.express as px
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
NUM_SERVERS = 1
|
15 |
+
LOG_ROOT_DIR = os.getenv("LOGDIR", None)
|
16 |
+
if LOG_ROOT_DIR is None:
|
17 |
+
raise ValueError("LOGDIR environment variable not set, please set it by `export LOGDIR=...`")
|
18 |
+
|
19 |
+
def get_log_files(max_num_files=None):
|
20 |
+
log_root = os.path.expanduser(LOG_ROOT_DIR)
|
21 |
+
filenames = []
|
22 |
+
if NUM_SERVERS == 1:
|
23 |
+
for filename in os.listdir(log_root):
|
24 |
+
if filename.endswith("-conv.json"):
|
25 |
+
filepath = f"{log_root}/{filename}"
|
26 |
+
name_tstamp_tuple = (filepath, os.path.getmtime(filepath))
|
27 |
+
filenames.append(name_tstamp_tuple)
|
28 |
+
else:
|
29 |
+
for i in range(NUM_SERVERS):
|
30 |
+
for filename in os.listdir(f"{log_root}/server{i}"):
|
31 |
+
if filename.endswith("-conv.json"):
|
32 |
+
filepath = f"{log_root}/server{i}/{filename}"
|
33 |
+
name_tstamp_tuple = (filepath, os.path.getmtime(filepath))
|
34 |
+
filenames.append(name_tstamp_tuple)
|
35 |
+
# sort by tstamp
|
36 |
+
filenames = sorted(filenames, key=lambda x: x[1])
|
37 |
+
filenames = [x[0] for x in filenames]
|
38 |
+
|
39 |
+
max_num_files = max_num_files or len(filenames)
|
40 |
+
filenames = filenames[-max_num_files:]
|
41 |
+
return filenames
|
42 |
+
|
43 |
+
|
44 |
+
def load_log_files(filename):
|
45 |
+
data = []
|
46 |
+
for retry in range(5):
|
47 |
+
try:
|
48 |
+
lines = open(filename).readlines()
|
49 |
+
break
|
50 |
+
except FileNotFoundError:
|
51 |
+
time.sleep(2)
|
52 |
+
|
53 |
+
for l in lines:
|
54 |
+
row = json.loads(l)
|
55 |
+
data.append(
|
56 |
+
dict(
|
57 |
+
type=row["type"],
|
58 |
+
tstamp=row["tstamp"],
|
59 |
+
model=row.get("model", ""),
|
60 |
+
models=row.get("models", ["", ""]),
|
61 |
+
)
|
62 |
+
)
|
63 |
+
return data
|
64 |
+
|
65 |
+
|
66 |
+
def load_log_files_parallel(log_files, num_threads=16):
|
67 |
+
data_all = []
|
68 |
+
from multiprocessing import Pool
|
69 |
+
|
70 |
+
with Pool(num_threads) as p:
|
71 |
+
ret_all = list(tqdm(p.imap(load_log_files, log_files), total=len(log_files)))
|
72 |
+
for ret in ret_all:
|
73 |
+
data_all.extend(ret)
|
74 |
+
return data_all
|
75 |
+
|
76 |
+
|
77 |
+
def get_anony_vote_df(df):
|
78 |
+
anony_vote_df = df[
|
79 |
+
df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
|
80 |
+
]
|
81 |
+
anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")]
|
82 |
+
return anony_vote_df
|
83 |
+
|
84 |
+
|
85 |
+
def merge_counts(series, on, names):
|
86 |
+
ret = pd.merge(series[0], series[1], on=on)
|
87 |
+
for i in range(2, len(series)):
|
88 |
+
ret = pd.merge(ret, series[i], on=on)
|
89 |
+
ret = ret.reset_index()
|
90 |
+
old_names = list(ret.columns)[-len(series) :]
|
91 |
+
rename = {old_name: new_name for old_name, new_name in zip(old_names, names)}
|
92 |
+
ret = ret.rename(columns=rename)
|
93 |
+
return ret
|
94 |
+
|
95 |
+
|
96 |
+
def report_basic_stats(log_files):
|
97 |
+
df_all = load_log_files_parallel(log_files)
|
98 |
+
df_all = pd.DataFrame(df_all)
|
99 |
+
now_t = df_all["tstamp"].max()
|
100 |
+
df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
|
101 |
+
df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)]
|
102 |
+
anony_vote_df_all = get_anony_vote_df(df_all)
|
103 |
+
|
104 |
+
# Chat trends
|
105 |
+
chat_dates = [
|
106 |
+
datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
|
107 |
+
"%Y-%m-%d"
|
108 |
+
)
|
109 |
+
for x in df_all[df_all["type"] == "chat"]["tstamp"]
|
110 |
+
]
|
111 |
+
chat_dates_counts = pd.value_counts(chat_dates)
|
112 |
+
vote_dates = [
|
113 |
+
datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
|
114 |
+
"%Y-%m-%d"
|
115 |
+
)
|
116 |
+
for x in anony_vote_df_all["tstamp"]
|
117 |
+
]
|
118 |
+
vote_dates_counts = pd.value_counts(vote_dates)
|
119 |
+
chat_dates_bar = go.Figure(
|
120 |
+
data=[
|
121 |
+
go.Bar(
|
122 |
+
name="Anony. Vote",
|
123 |
+
x=vote_dates_counts.index,
|
124 |
+
y=vote_dates_counts,
|
125 |
+
text=[f"{val:.0f}" for val in vote_dates_counts],
|
126 |
+
textposition="auto",
|
127 |
+
),
|
128 |
+
go.Bar(
|
129 |
+
name="Chat",
|
130 |
+
x=chat_dates_counts.index,
|
131 |
+
y=chat_dates_counts,
|
132 |
+
text=[f"{val:.0f}" for val in chat_dates_counts],
|
133 |
+
textposition="auto",
|
134 |
+
),
|
135 |
+
]
|
136 |
+
)
|
137 |
+
chat_dates_bar.update_layout(
|
138 |
+
barmode="stack",
|
139 |
+
xaxis_title="Dates",
|
140 |
+
yaxis_title="Count",
|
141 |
+
height=300,
|
142 |
+
width=1200,
|
143 |
+
)
|
144 |
+
|
145 |
+
# Model call counts
|
146 |
+
model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts()
|
147 |
+
model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts()
|
148 |
+
model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts()
|
149 |
+
model_hist = merge_counts(
|
150 |
+
[model_hist_all, model_hist_1_day, model_hist_1_hour],
|
151 |
+
on="model",
|
152 |
+
names=["All", "Last Day", "Last Hour"],
|
153 |
+
)
|
154 |
+
model_hist_md = model_hist.to_markdown(index=False, tablefmt="github")
|
155 |
+
|
156 |
+
# Action counts
|
157 |
+
action_hist_all = df_all["type"].value_counts()
|
158 |
+
action_hist_1_day = df_1_day["type"].value_counts()
|
159 |
+
action_hist_1_hour = df_1_hour["type"].value_counts()
|
160 |
+
action_hist = merge_counts(
|
161 |
+
[action_hist_all, action_hist_1_day, action_hist_1_hour],
|
162 |
+
on="type",
|
163 |
+
names=["All", "Last Day", "Last Hour"],
|
164 |
+
)
|
165 |
+
action_hist_md = action_hist.to_markdown(index=False, tablefmt="github")
|
166 |
+
|
167 |
+
# Anony vote counts
|
168 |
+
anony_vote_hist_all = anony_vote_df_all["type"].value_counts()
|
169 |
+
anony_vote_df_1_day = get_anony_vote_df(df_1_day)
|
170 |
+
anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts()
|
171 |
+
# anony_vote_df_1_hour = get_anony_vote_df(df_1_hour)
|
172 |
+
# anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts()
|
173 |
+
anony_vote_hist = merge_counts(
|
174 |
+
[anony_vote_hist_all, anony_vote_hist_1_day],
|
175 |
+
on="type",
|
176 |
+
names=["All", "Last Day"],
|
177 |
+
)
|
178 |
+
anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github")
|
179 |
+
|
180 |
+
# Last 24 hours
|
181 |
+
chat_1_day = df_1_day[df_1_day["type"] == "chat"]
|
182 |
+
num_chats_last_24_hours = []
|
183 |
+
base = df_1_day["tstamp"].min()
|
184 |
+
for i in range(24, 0, -1):
|
185 |
+
left = base + (i - 1) * 3600
|
186 |
+
right = base + i * 3600
|
187 |
+
num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum()
|
188 |
+
num_chats_last_24_hours.append(num)
|
189 |
+
times = [
|
190 |
+
datetime.datetime.fromtimestamp(
|
191 |
+
base + i * 3600, tz=timezone("US/Pacific")
|
192 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
193 |
+
for i in range(24, 0, -1)
|
194 |
+
]
|
195 |
+
last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours})
|
196 |
+
last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github")
|
197 |
+
|
198 |
+
# Last update datetime
|
199 |
+
last_updated_tstamp = now_t
|
200 |
+
last_updated_datetime = datetime.datetime.fromtimestamp(
|
201 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
202 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
203 |
+
|
204 |
+
# code.interact(local=locals())
|
205 |
+
|
206 |
+
return {
|
207 |
+
"chat_dates_bar": chat_dates_bar,
|
208 |
+
"model_hist_md": model_hist_md,
|
209 |
+
"action_hist_md": action_hist_md,
|
210 |
+
"anony_vote_hist_md": anony_vote_hist_md,
|
211 |
+
"num_chats_last_24_hours": last_24_hours_md,
|
212 |
+
"last_updated_datetime": last_updated_datetime,
|
213 |
+
}
|
214 |
+
|
215 |
+
|
216 |
+
if __name__ == "__main__":
|
217 |
+
parser = argparse.ArgumentParser()
|
218 |
+
parser.add_argument("--max-num-files", type=int)
|
219 |
+
args = parser.parse_args()
|
220 |
+
|
221 |
+
log_files = get_log_files(args.max_num_files)
|
222 |
+
basic_stats = report_basic_stats(log_files)
|
223 |
+
|
224 |
+
print(basic_stats["action_hist_md"] + "\n")
|
225 |
+
print(basic_stats["model_hist_md"] + "\n")
|
226 |
+
print(basic_stats["anony_vote_hist_md"] + "\n")
|
227 |
+
print(basic_stats["num_chats_last_24_hours"] + "\n")
|
arena_elo/elo_rating/clean_battle_data.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Clean chatbot arena battle log.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 clean_battle_data.py --mode conv_release
|
6 |
+
"""
|
7 |
+
import argparse
|
8 |
+
import datetime
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
from pytz import timezone
|
13 |
+
import time
|
14 |
+
import PIL
|
15 |
+
from PIL import ImageFile
|
16 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
17 |
+
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
from .basic_stats import get_log_files, NUM_SERVERS, LOG_ROOT_DIR
|
21 |
+
from .utils import detect_language, get_time_stamp_from_date
|
22 |
+
|
23 |
+
VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]
|
24 |
+
IDENTITY_WORDS = [
|
25 |
+
"vicuna",
|
26 |
+
"lmsys",
|
27 |
+
"koala",
|
28 |
+
"uc berkeley",
|
29 |
+
"open assistant",
|
30 |
+
"laion",
|
31 |
+
"chatglm",
|
32 |
+
"chatgpt",
|
33 |
+
"gpt-4",
|
34 |
+
"openai",
|
35 |
+
"anthropic",
|
36 |
+
"claude",
|
37 |
+
"bard",
|
38 |
+
"palm",
|
39 |
+
"lamda",
|
40 |
+
"google",
|
41 |
+
"llama",
|
42 |
+
"qianwan",
|
43 |
+
"alibaba",
|
44 |
+
"mistral",
|
45 |
+
"zhipu",
|
46 |
+
"KEG lab",
|
47 |
+
"01.AI",
|
48 |
+
"AI2",
|
49 |
+
"Tülu",
|
50 |
+
"Tulu",
|
51 |
+
"NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
|
52 |
+
"$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
|
53 |
+
"API REQUEST ERROR. Please increase the number of max tokens.",
|
54 |
+
"**API REQUEST ERROR** Reason: The response was blocked.",
|
55 |
+
"**API REQUEST ERROR**",
|
56 |
+
]
|
57 |
+
|
58 |
+
for i in range(len(IDENTITY_WORDS)):
|
59 |
+
IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
|
60 |
+
|
61 |
+
|
62 |
+
def remove_html(raw):
|
63 |
+
if raw.startswith("<h3>"):
|
64 |
+
return raw[raw.find(": ") + 2 : -len("</h3>\n")]
|
65 |
+
if raw.startswith("### Model A: ") or raw.startswith("### Model B: "):
|
66 |
+
return raw[13:]
|
67 |
+
return raw
|
68 |
+
|
69 |
+
|
70 |
+
def to_openai_format(messages):
|
71 |
+
roles = ["user", "assistant"]
|
72 |
+
ret = []
|
73 |
+
for i, x in enumerate(messages):
|
74 |
+
ret.append({"role": roles[i % 2], "content": x[1]})
|
75 |
+
return ret
|
76 |
+
|
77 |
+
|
78 |
+
def replace_model_name(old_name, tstamp):
|
79 |
+
replace_dict = {
|
80 |
+
"bard": "palm-2",
|
81 |
+
"claude-v1": "claude-1",
|
82 |
+
"claude-instant-v1": "claude-instant-1",
|
83 |
+
"oasst-sft-1-pythia-12b": "oasst-pythia-12b",
|
84 |
+
"claude-2": "claude-2.0",
|
85 |
+
"PlayGroundV2": "Playground v2",
|
86 |
+
}
|
87 |
+
if old_name in ["gpt-4", "gpt-3.5-turbo"]:
|
88 |
+
if tstamp > 1687849200:
|
89 |
+
return old_name + "-0613"
|
90 |
+
else:
|
91 |
+
return old_name + "-0314"
|
92 |
+
if old_name in replace_dict:
|
93 |
+
return replace_dict[old_name]
|
94 |
+
return old_name
|
95 |
+
|
96 |
+
|
97 |
+
def read_file(filename):
|
98 |
+
data = []
|
99 |
+
for retry in range(5):
|
100 |
+
try:
|
101 |
+
# lines = open(filename).readlines()
|
102 |
+
for l in open(filename):
|
103 |
+
row = json.loads(l)
|
104 |
+
if row["type"] in VOTES:
|
105 |
+
data.append(row)
|
106 |
+
break
|
107 |
+
except FileNotFoundError:
|
108 |
+
time.sleep(2)
|
109 |
+
return data
|
110 |
+
|
111 |
+
|
112 |
+
def read_file_parallel(log_files, num_threads=16):
|
113 |
+
data_all = []
|
114 |
+
from multiprocessing import Pool
|
115 |
+
|
116 |
+
with Pool(num_threads) as p:
|
117 |
+
ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files)))
|
118 |
+
for ret in ret_all:
|
119 |
+
data_all.extend(ret)
|
120 |
+
return data_all
|
121 |
+
|
122 |
+
def load_image(image_path):
|
123 |
+
try:
|
124 |
+
return PIL.Image.open(image_path)
|
125 |
+
except:
|
126 |
+
return None
|
127 |
+
|
128 |
+
def clean_battle_data(
|
129 |
+
log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False, mode="simple", task_name="t2s"
|
130 |
+
):
|
131 |
+
data = read_file_parallel(log_files, num_threads=16)
|
132 |
+
|
133 |
+
convert_type = {
|
134 |
+
"leftvote": "model_a",
|
135 |
+
"rightvote": "model_b",
|
136 |
+
"tievote": "tie",
|
137 |
+
"bothbad_vote": "tie (bothbad)",
|
138 |
+
}
|
139 |
+
|
140 |
+
all_models = set()
|
141 |
+
all_ips = dict()
|
142 |
+
ct_anony = 0
|
143 |
+
ct_invalid = 0
|
144 |
+
ct_leaked_identity = 0
|
145 |
+
ct_banned = 0
|
146 |
+
battles = []
|
147 |
+
for row in tqdm(data, desc="Cleaning"):
|
148 |
+
if row["models"][0] is None or row["models"][1] is None:
|
149 |
+
continue
|
150 |
+
|
151 |
+
# Resolve model names
|
152 |
+
models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
|
153 |
+
if "model_name" in row["states"][0]:
|
154 |
+
models_hidden = [
|
155 |
+
row["states"][0]["model_name"],
|
156 |
+
row["states"][1]["model_name"],
|
157 |
+
]
|
158 |
+
if models_hidden[0] is None:
|
159 |
+
models_hidden = models_public
|
160 |
+
else:
|
161 |
+
models_hidden = models_public
|
162 |
+
|
163 |
+
if (models_public[0] == "" and models_public[1] != "") or (
|
164 |
+
models_public[1] == "" and models_public[0] != ""
|
165 |
+
):
|
166 |
+
ct_invalid += 1
|
167 |
+
continue
|
168 |
+
|
169 |
+
if models_public[0] == "" or models_public[0] == "Model A":
|
170 |
+
anony = True
|
171 |
+
models = models_hidden
|
172 |
+
ct_anony += 1
|
173 |
+
else:
|
174 |
+
anony = False
|
175 |
+
models = models_public
|
176 |
+
if not models_public == models_hidden:
|
177 |
+
ct_invalid += 1
|
178 |
+
continue
|
179 |
+
|
180 |
+
# # Detect langauge
|
181 |
+
# state = row["states"][0]
|
182 |
+
# if state["offset"] >= len(state["messages"]):
|
183 |
+
# ct_invalid += 1
|
184 |
+
# continue
|
185 |
+
# lang_code = detect_language(state["messages"][state["offset"]][1])
|
186 |
+
|
187 |
+
# # Drop conversations if the model names are leaked
|
188 |
+
# leaked_identity = False
|
189 |
+
# messages = ""
|
190 |
+
# for i in range(2):
|
191 |
+
# state = row["states"][i]
|
192 |
+
# for turn_idx, (role, msg) in enumerate(
|
193 |
+
# state["messages"][state["offset"] :]
|
194 |
+
# ):
|
195 |
+
# if msg:
|
196 |
+
# messages += msg.lower()
|
197 |
+
# for word in IDENTITY_WORDS:
|
198 |
+
# if word in messages:
|
199 |
+
# leaked_identity = True
|
200 |
+
# break
|
201 |
+
|
202 |
+
# if leaked_identity:
|
203 |
+
# ct_leaked_identity += 1
|
204 |
+
# continue
|
205 |
+
|
206 |
+
# Replace bard with palm
|
207 |
+
if task_name == "image_editing":
|
208 |
+
if not all(x.startswith("imagenhub_") and x.endswith("_edition") for x in models):
|
209 |
+
# print(f"Invalid model names: {models}")
|
210 |
+
ct_invalid += 1
|
211 |
+
continue
|
212 |
+
models = [x[len("imagenhub_"):-len("_edition")] for x in models]
|
213 |
+
elif task_name == "t2i_generation":
|
214 |
+
if not all("playground" in x.lower() or (x.startswith("imagenhub_") and x.endswith("_generation")) for x in models):
|
215 |
+
# print(f"Invalid model names: {models}")
|
216 |
+
ct_invalid += 1
|
217 |
+
continue
|
218 |
+
# models = [x[len("imagenhub_"):-len("_generation")] for x in models]
|
219 |
+
for i, model_name in enumerate(models):
|
220 |
+
if model_name.startswith("imagenhub_"):
|
221 |
+
models[i] = model_name[len("imagenhub_"):-len("_generation")]
|
222 |
+
|
223 |
+
else:
|
224 |
+
raise ValueError(f"Invalid task_name: {task_name}")
|
225 |
+
models = [replace_model_name(m, row["tstamp"]) for m in models]
|
226 |
+
|
227 |
+
# Exclude certain models
|
228 |
+
if exclude_model_names and any(x in exclude_model_names for x in models):
|
229 |
+
ct_invalid += 1
|
230 |
+
continue
|
231 |
+
|
232 |
+
# if models[0] not in model_infos or models[1] not in model_infos:
|
233 |
+
# continue
|
234 |
+
|
235 |
+
# # Exclude votes before the starting date
|
236 |
+
# if model_infos and (model_infos[models[0]]["starting_from"] > row["tstamp"] or model_infos[models[1]]["starting_from"] > row["tstamp"]):
|
237 |
+
# print(f"Invalid vote before the valid starting date for {models[0]} and {models[1]}")
|
238 |
+
# ct_invalid += 1
|
239 |
+
# continue
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
if mode == "conv_release":
|
244 |
+
# assert the two images are the same
|
245 |
+
date = datetime.datetime.fromtimestamp(row["tstamp"], tz=timezone("US/Pacific")).strftime("%Y-%m-%d") # 2024-02-29
|
246 |
+
image_path_format = f"{LOG_ROOT_DIR}/{date}-convinput_images/input_image_"
|
247 |
+
image_path_0 = image_path_format + str(row["states"][0]["conv_id"]) + ".png"
|
248 |
+
image_path_1 = image_path_format + str(row["states"][1]["conv_id"]) + ".png"
|
249 |
+
if not os.path.exists(image_path_0) or not os.path.exists(image_path_1):
|
250 |
+
print(f"Image not found for {image_path_0} or {image_path_1}")
|
251 |
+
ct_invalid += 1
|
252 |
+
continue
|
253 |
+
|
254 |
+
image_0 = load_image(image_path_0)
|
255 |
+
image_1 = load_image(image_path_1)
|
256 |
+
if image_0 is None or image_1 is None:
|
257 |
+
print(f"Image not found for {image_path_0} or {image_path_1}")
|
258 |
+
ct_invalid += 1
|
259 |
+
continue
|
260 |
+
if image_0.tobytes() != image_1.tobytes():
|
261 |
+
print(f"Image not the same for {image_path_0} and {image_path_1}")
|
262 |
+
ct_invalid += 1
|
263 |
+
continue
|
264 |
+
|
265 |
+
|
266 |
+
question_id = row["states"][0]["conv_id"]
|
267 |
+
# conversation_a = to_openai_format(
|
268 |
+
# row["states"][0]["messages"][row["states"][0]["offset"] :]
|
269 |
+
# )
|
270 |
+
# conversation_b = to_openai_format(
|
271 |
+
# row["states"][1]["messages"][row["states"][1]["offset"] :]
|
272 |
+
# )
|
273 |
+
|
274 |
+
ip = row["ip"]
|
275 |
+
if ip not in all_ips:
|
276 |
+
all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)}
|
277 |
+
all_ips[ip]["count"] += 1
|
278 |
+
if sanitize_ip:
|
279 |
+
user_id = f"arena_user_{all_ips[ip]['sanitized_id']}"
|
280 |
+
else:
|
281 |
+
user_id = f"{all_ips[ip]['ip']}"
|
282 |
+
|
283 |
+
if ban_ip_list is not None and ip in ban_ip_list:
|
284 |
+
ct_banned += 1
|
285 |
+
continue
|
286 |
+
|
287 |
+
# Save the results
|
288 |
+
battles.append(
|
289 |
+
dict(
|
290 |
+
question_id=question_id,
|
291 |
+
model_a=models[0],
|
292 |
+
model_b=models[1],
|
293 |
+
winner=convert_type[row["type"]],
|
294 |
+
judge=f"arena_user_{user_id}",
|
295 |
+
# conversation_a=conversation_a,
|
296 |
+
# conversation_b=conversation_b,
|
297 |
+
# turn=len(conversation_a) // 2,
|
298 |
+
anony=anony,
|
299 |
+
# language=lang_code,
|
300 |
+
tstamp=row["tstamp"],
|
301 |
+
)
|
302 |
+
)
|
303 |
+
|
304 |
+
all_models.update(models_hidden)
|
305 |
+
battles.sort(key=lambda x: x["tstamp"])
|
306 |
+
last_updated_tstamp = battles[-1]["tstamp"]
|
307 |
+
|
308 |
+
last_updated_datetime = datetime.datetime.fromtimestamp(
|
309 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
310 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
311 |
+
|
312 |
+
print(
|
313 |
+
f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
|
314 |
+
f"#leaked_identity: {ct_leaked_identity} "
|
315 |
+
f"#banned: {ct_banned} "
|
316 |
+
)
|
317 |
+
print(f"#battles: {len(battles)}, #anony: {ct_anony}")
|
318 |
+
print(f"#models: {len(all_models)}, {all_models}")
|
319 |
+
print(f"last-updated: {last_updated_datetime}")
|
320 |
+
|
321 |
+
if ban_ip_list is not None:
|
322 |
+
for ban_ip in ban_ip_list:
|
323 |
+
if ban_ip in all_ips:
|
324 |
+
del all_ips[ban_ip]
|
325 |
+
print("Top 30 IPs:")
|
326 |
+
print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30])
|
327 |
+
return battles
|
328 |
+
|
329 |
+
|
330 |
+
if __name__ == "__main__":
|
331 |
+
parser = argparse.ArgumentParser()
|
332 |
+
parser.add_argument("--max-num-files", type=int)
|
333 |
+
parser.add_argument(
|
334 |
+
"--mode", type=str, choices=["simple", "conv_release"], default="simple"
|
335 |
+
)
|
336 |
+
parser.add_argument("--task_name", type=str, choices=["t2s", "i2s"])
|
337 |
+
parser.add_argument("--exclude-model-names", type=str, nargs="+")
|
338 |
+
parser.add_argument("--ban-ip-file", type=str)
|
339 |
+
parser.add_argument("--sanitize-ip", action="store_true", default=False)
|
340 |
+
args = parser.parse_args()
|
341 |
+
|
342 |
+
log_files = get_log_files(args.max_num_files)
|
343 |
+
ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None
|
344 |
+
|
345 |
+
battles = clean_battle_data(
|
346 |
+
log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip, args.mode, args.task_name
|
347 |
+
)
|
348 |
+
last_updated_tstamp = battles[-1]["tstamp"]
|
349 |
+
cutoff_date = datetime.datetime.fromtimestamp(
|
350 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
351 |
+
).strftime("%Y%m%d")
|
352 |
+
|
353 |
+
if args.mode == "simple":
|
354 |
+
for x in battles:
|
355 |
+
for key in [
|
356 |
+
"conversation_a",
|
357 |
+
"conversation_b",
|
358 |
+
"question_id",
|
359 |
+
]:
|
360 |
+
if key in x:
|
361 |
+
del x[key]
|
362 |
+
print("Samples:")
|
363 |
+
for i in range(min(4, len(battles))):
|
364 |
+
print(battles[i])
|
365 |
+
output = f"clean_battle_{args.task_name}_{cutoff_date}.json"
|
366 |
+
elif args.mode == "conv_release":
|
367 |
+
# new_battles = []
|
368 |
+
# for x in battles:
|
369 |
+
# if not x["anony"]:
|
370 |
+
# continue
|
371 |
+
# for key in []:
|
372 |
+
# del x[key]
|
373 |
+
# new_battles.append(x)
|
374 |
+
# battles = new_battles
|
375 |
+
output = f"clean_battle_{args.task_name}_conv_{cutoff_date}.json"
|
376 |
+
|
377 |
+
with open(output, "w") as fout:
|
378 |
+
json.dump(battles, fout, indent=2, ensure_ascii=False)
|
379 |
+
print(f"Write cleaned data to {output}")
|
380 |
+
|
381 |
+
with open("cut_off_date.txt", "w") as fout:
|
382 |
+
fout.write(cutoff_date)
|
arena_elo/elo_rating/elo_analysis.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from collections import defaultdict
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import pickle
|
7 |
+
from pytz import timezone
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import plotly.express as px
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from .model_registry import get_model_info
|
15 |
+
from .basic_stats import get_log_files
|
16 |
+
from .clean_battle_data import clean_battle_data
|
17 |
+
|
18 |
+
pd.options.display.float_format = "{:.2f}".format
|
19 |
+
|
20 |
+
|
21 |
+
def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
|
22 |
+
rating = defaultdict(lambda: INIT_RATING)
|
23 |
+
|
24 |
+
for rd, model_a, model_b, winner in battles[
|
25 |
+
["model_a", "model_b", "winner"]
|
26 |
+
].itertuples():
|
27 |
+
ra = rating[model_a]
|
28 |
+
rb = rating[model_b]
|
29 |
+
ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))
|
30 |
+
eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))
|
31 |
+
if winner == "model_a":
|
32 |
+
sa = 1
|
33 |
+
elif winner == "model_b":
|
34 |
+
sa = 0
|
35 |
+
elif winner == "tie" or winner == "tie (bothbad)":
|
36 |
+
sa = 0.5
|
37 |
+
else:
|
38 |
+
raise Exception(f"unexpected vote {winner}")
|
39 |
+
rating[model_a] += K * (sa - ea)
|
40 |
+
rating[model_b] += K * (1 - sa - eb)
|
41 |
+
|
42 |
+
return dict(rating)
|
43 |
+
|
44 |
+
|
45 |
+
def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
|
46 |
+
rows = []
|
47 |
+
for i in tqdm(range(num_round), desc="bootstrap"):
|
48 |
+
tmp_battles = battles.sample(frac=1.0, replace=True)
|
49 |
+
rows.append(func_compute_elo(tmp_battles))
|
50 |
+
df = pd.DataFrame(rows)
|
51 |
+
return df[df.median().sort_values(ascending=False).index]
|
52 |
+
|
53 |
+
|
54 |
+
def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000):
|
55 |
+
from sklearn.linear_model import LogisticRegression
|
56 |
+
|
57 |
+
models = pd.concat([df["model_a"], df["model_b"]]).unique()
|
58 |
+
models = pd.Series(np.arange(len(models)), index=models)
|
59 |
+
|
60 |
+
# duplicate battles
|
61 |
+
df = pd.concat([df, df], ignore_index=True)
|
62 |
+
p = len(models.index)
|
63 |
+
n = df.shape[0]
|
64 |
+
|
65 |
+
X = np.zeros([n, p])
|
66 |
+
X[np.arange(n), models[df["model_a"]]] = +math.log(BASE)
|
67 |
+
X[np.arange(n), models[df["model_b"]]] = -math.log(BASE)
|
68 |
+
|
69 |
+
# one A win => two A win
|
70 |
+
Y = np.zeros(n)
|
71 |
+
Y[df["winner"] == "model_a"] = 1.0
|
72 |
+
|
73 |
+
# one tie => one A win + one B win
|
74 |
+
# find tie + tie (both bad) index
|
75 |
+
tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)")
|
76 |
+
tie_idx[len(tie_idx) // 2 :] = False
|
77 |
+
Y[tie_idx] = 1.0
|
78 |
+
|
79 |
+
lr = LogisticRegression(fit_intercept=False)
|
80 |
+
lr.fit(X, Y)
|
81 |
+
|
82 |
+
elo_scores = SCALE * lr.coef_[0] + INIT_RATING
|
83 |
+
# calibrate llama-13b to 800 if applicable
|
84 |
+
if "llama-13b" in models.index:
|
85 |
+
elo_scores += 800 - elo_scores[models["llama-13b"]]
|
86 |
+
return pd.Series(elo_scores, index=models.index).sort_values(ascending=False)
|
87 |
+
|
88 |
+
|
89 |
+
def get_median_elo_from_bootstrap(bootstrap_df):
|
90 |
+
median = dict(bootstrap_df.quantile(0.5))
|
91 |
+
median = {k: int(v + 0.5) for k, v in median.items()}
|
92 |
+
return median
|
93 |
+
|
94 |
+
|
95 |
+
def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None):
|
96 |
+
# Times each model wins as Model A
|
97 |
+
a_win_ptbl = pd.pivot_table(
|
98 |
+
battles[battles["winner"] == "model_a"],
|
99 |
+
index="model_a",
|
100 |
+
columns="model_b",
|
101 |
+
aggfunc="size",
|
102 |
+
fill_value=0,
|
103 |
+
)
|
104 |
+
|
105 |
+
# Table counting times each model wins as Model B
|
106 |
+
b_win_ptbl = pd.pivot_table(
|
107 |
+
battles[battles["winner"] == "model_b"],
|
108 |
+
index="model_a",
|
109 |
+
columns="model_b",
|
110 |
+
aggfunc="size",
|
111 |
+
fill_value=0,
|
112 |
+
)
|
113 |
+
|
114 |
+
# Table counting number of A-B pairs
|
115 |
+
num_battles_ptbl = pd.pivot_table(
|
116 |
+
battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
|
117 |
+
)
|
118 |
+
|
119 |
+
# Computing the proportion of wins for each model as A and as B
|
120 |
+
# against all other models
|
121 |
+
row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / (
|
122 |
+
num_battles_ptbl + num_battles_ptbl.T
|
123 |
+
)
|
124 |
+
|
125 |
+
if model_order is None:
|
126 |
+
prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False)
|
127 |
+
model_order = list(prop_wins.keys())
|
128 |
+
|
129 |
+
if limit_show_number is not None:
|
130 |
+
model_order = model_order[:limit_show_number]
|
131 |
+
|
132 |
+
# Arrange ordering according to proprition of wins
|
133 |
+
row_beats_col = row_beats_col_freq.loc[model_order, model_order]
|
134 |
+
return row_beats_col
|
135 |
+
|
136 |
+
|
137 |
+
def visualize_leaderboard_table(rating):
|
138 |
+
models = list(rating.keys())
|
139 |
+
models.sort(key=lambda k: -rating[k])
|
140 |
+
|
141 |
+
emoji_dict = {
|
142 |
+
1: "🥇",
|
143 |
+
2: "🥈",
|
144 |
+
3: "🥉",
|
145 |
+
}
|
146 |
+
|
147 |
+
md = ""
|
148 |
+
md += "| Rank | Model | Elo Rating | Description |\n"
|
149 |
+
md += "| --- | --- | --- | --- |\n"
|
150 |
+
for i, model in enumerate(models):
|
151 |
+
rank = i + 1
|
152 |
+
minfo = get_model_info(model)
|
153 |
+
emoji = emoji_dict.get(rank, "")
|
154 |
+
md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n"
|
155 |
+
|
156 |
+
return md
|
157 |
+
|
158 |
+
|
159 |
+
def visualize_pairwise_win_fraction(battles, model_order):
|
160 |
+
row_beats_col = compute_pairwise_win_fraction(battles, model_order)
|
161 |
+
fig = px.imshow(
|
162 |
+
row_beats_col,
|
163 |
+
color_continuous_scale="RdBu",
|
164 |
+
text_auto=".2f",
|
165 |
+
height=700,
|
166 |
+
width=700,
|
167 |
+
)
|
168 |
+
fig.update_layout(
|
169 |
+
xaxis_title="Model B",
|
170 |
+
yaxis_title="Model A",
|
171 |
+
xaxis_side="top",
|
172 |
+
title_y=0.07,
|
173 |
+
title_x=0.5,
|
174 |
+
)
|
175 |
+
fig.update_traces(
|
176 |
+
hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Fraction of A Wins: %{z}<extra></extra>"
|
177 |
+
)
|
178 |
+
|
179 |
+
return fig
|
180 |
+
|
181 |
+
|
182 |
+
def visualize_battle_count(battles, model_order):
|
183 |
+
ptbl = pd.pivot_table(
|
184 |
+
battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
|
185 |
+
)
|
186 |
+
battle_counts = ptbl + ptbl.T
|
187 |
+
fig = px.imshow(
|
188 |
+
battle_counts.loc[model_order, model_order],
|
189 |
+
text_auto=True,
|
190 |
+
height=700,
|
191 |
+
width=700,
|
192 |
+
)
|
193 |
+
fig.update_layout(
|
194 |
+
xaxis_title="Model B",
|
195 |
+
yaxis_title="Model A",
|
196 |
+
xaxis_side="top",
|
197 |
+
title_y=0.07,
|
198 |
+
title_x=0.5,
|
199 |
+
)
|
200 |
+
fig.update_traces(
|
201 |
+
hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Count: %{z}<extra></extra>"
|
202 |
+
)
|
203 |
+
return fig
|
204 |
+
|
205 |
+
|
206 |
+
def visualize_average_win_rate(battles, limit_show_number):
|
207 |
+
row_beats_col_freq = compute_pairwise_win_fraction(
|
208 |
+
battles, None, limit_show_number=limit_show_number
|
209 |
+
)
|
210 |
+
fig = px.bar(
|
211 |
+
row_beats_col_freq.mean(axis=1).sort_values(ascending=False),
|
212 |
+
text_auto=".2f",
|
213 |
+
height=500,
|
214 |
+
width=700,
|
215 |
+
)
|
216 |
+
fig.update_layout(
|
217 |
+
yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False
|
218 |
+
)
|
219 |
+
return fig
|
220 |
+
|
221 |
+
|
222 |
+
def visualize_bootstrap_elo_rating(df, df_final, limit_show_number):
|
223 |
+
bars = (
|
224 |
+
pd.DataFrame(
|
225 |
+
dict(
|
226 |
+
lower=df.quantile(0.025),
|
227 |
+
rating=df_final,
|
228 |
+
upper=df.quantile(0.975),
|
229 |
+
)
|
230 |
+
)
|
231 |
+
.reset_index(names="model")
|
232 |
+
.sort_values("rating", ascending=False)
|
233 |
+
)
|
234 |
+
bars = bars[:limit_show_number]
|
235 |
+
bars["error_y"] = bars["upper"] - bars["rating"]
|
236 |
+
bars["error_y_minus"] = bars["rating"] - bars["lower"]
|
237 |
+
bars["rating_rounded"] = np.round(bars["rating"], 2)
|
238 |
+
fig = px.scatter(
|
239 |
+
bars,
|
240 |
+
x="model",
|
241 |
+
y="rating",
|
242 |
+
error_y="error_y",
|
243 |
+
error_y_minus="error_y_minus",
|
244 |
+
text="rating_rounded",
|
245 |
+
height=500,
|
246 |
+
width=700,
|
247 |
+
)
|
248 |
+
fig.update_layout(xaxis_title="Model", yaxis_title="Rating")
|
249 |
+
return fig
|
250 |
+
|
251 |
+
|
252 |
+
def report_elo_analysis_results(battles_json, rating_system="bt", num_bootstrap=100, anony_only=True):
|
253 |
+
battles = pd.DataFrame(battles_json)
|
254 |
+
battles = battles.sort_values(ascending=True, by=["tstamp"])
|
255 |
+
# Only use anonymous votes
|
256 |
+
if anony_only:
|
257 |
+
battles = battles[battles["anony"]].reset_index(drop=True)
|
258 |
+
battles_no_ties = battles[~battles["winner"].str.contains("tie")]
|
259 |
+
|
260 |
+
# Online update
|
261 |
+
elo_rating_online = compute_elo(battles)
|
262 |
+
|
263 |
+
if rating_system == "bt":
|
264 |
+
bootstrap_df = get_bootstrap_result(
|
265 |
+
battles, compute_elo_mle_with_tie, num_round=num_bootstrap
|
266 |
+
)
|
267 |
+
elo_rating_final = compute_elo_mle_with_tie(battles)
|
268 |
+
elif rating_system == "elo":
|
269 |
+
bootstrap_df = get_bootstrap_result(
|
270 |
+
battles, compute_elo, num_round=num_bootstrap
|
271 |
+
)
|
272 |
+
elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
|
273 |
+
elo_rating_final = elo_rating_median
|
274 |
+
|
275 |
+
model_order = list(elo_rating_final.keys())
|
276 |
+
model_order.sort(key=lambda k: -elo_rating_final[k])
|
277 |
+
|
278 |
+
limit_show_number = 25 # limit show number to make plots smaller
|
279 |
+
model_order = model_order[:limit_show_number]
|
280 |
+
|
281 |
+
# leaderboard_table_df: elo rating, variance, 95% interval, number of battles
|
282 |
+
leaderboard_table_df = pd.DataFrame(
|
283 |
+
{
|
284 |
+
"rating": elo_rating_final,
|
285 |
+
"variance": bootstrap_df.var(),
|
286 |
+
"rating_q975": bootstrap_df.quantile(0.975),
|
287 |
+
"rating_q025": bootstrap_df.quantile(0.025),
|
288 |
+
"num_battles": battles["model_a"].value_counts()
|
289 |
+
+ battles["model_b"].value_counts(),
|
290 |
+
}
|
291 |
+
)
|
292 |
+
|
293 |
+
# Plots
|
294 |
+
leaderboard_table = visualize_leaderboard_table(elo_rating_final)
|
295 |
+
win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order)
|
296 |
+
battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order)
|
297 |
+
average_win_rate_bar = visualize_average_win_rate(
|
298 |
+
battles_no_ties, limit_show_number
|
299 |
+
)
|
300 |
+
bootstrap_elo_rating = visualize_bootstrap_elo_rating(
|
301 |
+
bootstrap_df, elo_rating_final, limit_show_number
|
302 |
+
)
|
303 |
+
|
304 |
+
last_updated_tstamp = battles["tstamp"].max()
|
305 |
+
last_updated_datetime = datetime.datetime.fromtimestamp(
|
306 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
307 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
308 |
+
|
309 |
+
return {
|
310 |
+
"rating_system": rating_system,
|
311 |
+
"elo_rating_online": elo_rating_online,
|
312 |
+
"elo_rating_final": elo_rating_final,
|
313 |
+
"leaderboard_table": leaderboard_table,
|
314 |
+
"win_fraction_heatmap": win_fraction_heatmap,
|
315 |
+
"battle_count_heatmap": battle_count_heatmap,
|
316 |
+
"average_win_rate_bar": average_win_rate_bar,
|
317 |
+
"bootstrap_elo_rating": bootstrap_elo_rating,
|
318 |
+
"last_updated_datetime": last_updated_datetime,
|
319 |
+
"last_updated_tstamp": last_updated_tstamp,
|
320 |
+
"bootstrap_df": bootstrap_df,
|
321 |
+
"leaderboard_table_df": leaderboard_table_df,
|
322 |
+
}
|
323 |
+
|
324 |
+
|
325 |
+
def pretty_print_elo_rating(rating):
|
326 |
+
model_order = list(rating.keys())
|
327 |
+
model_order.sort(key=lambda k: -rating[k])
|
328 |
+
for i, model in enumerate(model_order):
|
329 |
+
print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}")
|
330 |
+
|
331 |
+
|
332 |
+
if __name__ == "__main__":
|
333 |
+
parser = argparse.ArgumentParser()
|
334 |
+
parser.add_argument("--clean-battle-file", type=str)
|
335 |
+
parser.add_argument("--max-num-files", type=int)
|
336 |
+
parser.add_argument("--num-bootstrap", type=int, default=100)
|
337 |
+
parser.add_argument(
|
338 |
+
"--rating-system", type=str, choices=["bt", "elo"], default="bt"
|
339 |
+
)
|
340 |
+
parser.add_argument("--exclude-tie", action="store_true", default=False)
|
341 |
+
args = parser.parse_args()
|
342 |
+
|
343 |
+
np.random.seed(42)
|
344 |
+
|
345 |
+
if args.clean_battle_file:
|
346 |
+
# Read data from a cleaned battle files
|
347 |
+
battles = pd.read_json(args.clean_battle_file)
|
348 |
+
else:
|
349 |
+
# Read data from all log files
|
350 |
+
log_files = get_log_files(args.max_num_files)
|
351 |
+
battles = clean_battle_data(log_files)
|
352 |
+
|
353 |
+
anony_results = report_elo_analysis_results(
|
354 |
+
battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap, anony_only=True
|
355 |
+
)
|
356 |
+
full_results = report_elo_analysis_results(
|
357 |
+
battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap, anony_only=False
|
358 |
+
)
|
359 |
+
|
360 |
+
|
361 |
+
print("# Online Elo")
|
362 |
+
pretty_print_elo_rating(anony_results["elo_rating_online"])
|
363 |
+
print("# Median")
|
364 |
+
pretty_print_elo_rating(anony_results["elo_rating_final"])
|
365 |
+
print(f"last update : {anony_results['last_updated_datetime']}")
|
366 |
+
|
367 |
+
last_updated_tstamp = full_results["last_updated_tstamp"]
|
368 |
+
cutoff_date = datetime.datetime.fromtimestamp(
|
369 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
370 |
+
).strftime("%Y%m%d")
|
371 |
+
|
372 |
+
|
373 |
+
results = {
|
374 |
+
"anony": anony_results,
|
375 |
+
"full": full_results,
|
376 |
+
}
|
377 |
+
with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout:
|
378 |
+
pickle.dump(results, fout)
|
arena_elo/elo_rating/generate_leaderboard.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fire
|
2 |
+
import json
|
3 |
+
import pandas as pd
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
|
7 |
+
def main(
|
8 |
+
model_info_file: str,
|
9 |
+
elo_rating_pkl: str,
|
10 |
+
output_csv: str
|
11 |
+
):
|
12 |
+
model_info = json.load(open(model_info_file))
|
13 |
+
|
14 |
+
with open(elo_rating_pkl, "rb") as fin:
|
15 |
+
elo_rating_results = pickle.load(fin)
|
16 |
+
|
17 |
+
anony_elo_rating_results = elo_rating_results["anony"]
|
18 |
+
full_elo_rating_results = elo_rating_results["full"]
|
19 |
+
anony_leaderboard_data = anony_elo_rating_results["leaderboard_table_df"]
|
20 |
+
full_leaderboard_data = full_elo_rating_results["leaderboard_table_df"]
|
21 |
+
|
22 |
+
# Model,MT-bench (score),Arena Elo rating,MMLU,License,Link
|
23 |
+
fields = ["key", "Model", "Arena Elo rating (anony)", "Arena Elo rating (full)", "License", "Organization", "Link"]
|
24 |
+
# set Organization and license to empty for now
|
25 |
+
all_models = anony_leaderboard_data.index.tolist()
|
26 |
+
|
27 |
+
for model in all_models:
|
28 |
+
if not model in model_info:
|
29 |
+
model_info[model] = {}
|
30 |
+
model_info[model]["License"] = "N/A"
|
31 |
+
model_info[model]["Organization"] = "N/A"
|
32 |
+
model_info[model]["Link"] = "N/A"
|
33 |
+
model_info[model]["Model"] = model
|
34 |
+
model_info[model]["key"] = model
|
35 |
+
|
36 |
+
if model in anony_leaderboard_data.index:
|
37 |
+
model_info[model]["Arena Elo rating (anony)"] = anony_leaderboard_data.loc[model, "rating"]
|
38 |
+
else:
|
39 |
+
model_info[model]["Arena Elo rating (anony)"] = 0
|
40 |
+
|
41 |
+
if model in full_elo_rating_results["leaderboard_table_df"].index:
|
42 |
+
model_info[model]["Arena Elo rating (full)"] = full_leaderboard_data.loc[model, "rating"]
|
43 |
+
else:
|
44 |
+
model_info[model]["Arena Elo rating (full)"] = 0
|
45 |
+
# if model in anony_leaderboard_data.index:
|
46 |
+
# model_info[model]["Arena Elo rating"] = anony_leaderboard_data.loc[model, "rating"]
|
47 |
+
# else:
|
48 |
+
# model_info[model]["Arena Elo rating"] = 0
|
49 |
+
|
50 |
+
final_model_info = {}
|
51 |
+
for model in model_info:
|
52 |
+
if "Model" in model_info[model]:
|
53 |
+
final_model_info[model] = model_info[model]
|
54 |
+
model_info = final_model_info
|
55 |
+
|
56 |
+
exclude_keys = ['starting_from']
|
57 |
+
for key in exclude_keys:
|
58 |
+
for model in model_info:
|
59 |
+
if key in model_info[model]:
|
60 |
+
del model_info[model][key]
|
61 |
+
df = pd.DataFrame(model_info).T
|
62 |
+
df = df[fields]
|
63 |
+
# sort by anony rating
|
64 |
+
df = df.sort_values(by=["Arena Elo rating (anony)"], ascending=False)
|
65 |
+
df.to_csv(output_csv, index=False)
|
66 |
+
print("Leaderboard data saved to", output_csv)
|
67 |
+
print(df)
|
68 |
+
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
+
fire.Fire(main)
|
arena_elo/elo_rating/inspect_conv_rating.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import code
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
from pytz import timezone
|
7 |
+
import time
|
8 |
+
|
9 |
+
import pandas as pd
|
10 |
+
from tqdm import tqdm
|
11 |
+
import csv
|
12 |
+
|
13 |
+
import base64
|
14 |
+
from icecream import ic
|
15 |
+
from openai import OpenAI
|
16 |
+
|
17 |
+
# Function to encode the image
|
18 |
+
def encode_image(image_path):
|
19 |
+
with open(image_path, "rb") as image_file:
|
20 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
21 |
+
|
22 |
+
def get_log_files(max_num_files=None):
|
23 |
+
dates = []
|
24 |
+
for month in [2, 3]:
|
25 |
+
for day in range(1, 32):
|
26 |
+
dates.append(f"2024-{month:02d}-{day:02d}")
|
27 |
+
|
28 |
+
num_servers = 1
|
29 |
+
filenames = []
|
30 |
+
for d in dates:
|
31 |
+
for i in range(num_servers):
|
32 |
+
# name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
|
33 |
+
name = os.path.expanduser(f"vision-arena-logs/{d}-conv.json")
|
34 |
+
if os.path.exists(name):
|
35 |
+
filenames.append(name)
|
36 |
+
max_num_files = max_num_files or len(filenames)
|
37 |
+
filenames = filenames[-max_num_files:]
|
38 |
+
return filenames
|
39 |
+
|
40 |
+
|
41 |
+
def pretty_print_conversation(messages):
|
42 |
+
for role, msg in messages:
|
43 |
+
print(f"[[{role}]]: {msg}")
|
44 |
+
|
45 |
+
|
46 |
+
def get_gpt4v_response(client, img_bs64=None, text_prompt="", use_vision=False):
|
47 |
+
if use_vision:
|
48 |
+
response = client.chat.completions.create(
|
49 |
+
model="gpt-4-vision-preview",
|
50 |
+
messages=[
|
51 |
+
{
|
52 |
+
"role": "user",
|
53 |
+
"content": [
|
54 |
+
{"type": "text", "text": text_prompt},
|
55 |
+
{
|
56 |
+
"type": "image_url",
|
57 |
+
"image_url": {
|
58 |
+
"url": f"data:image/jpeg;base64,{img_bs64}"
|
59 |
+
}
|
60 |
+
},
|
61 |
+
],
|
62 |
+
}
|
63 |
+
],
|
64 |
+
max_tokens=100,
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
response = client.chat.completions.create(
|
68 |
+
model="gpt-4-vision-preview",
|
69 |
+
messages=[
|
70 |
+
{
|
71 |
+
"role": "user",
|
72 |
+
"content": [
|
73 |
+
{"type": "text", "text": text_prompt},
|
74 |
+
],
|
75 |
+
}
|
76 |
+
],
|
77 |
+
max_tokens=100,
|
78 |
+
)
|
79 |
+
return response.choices[0].message.content
|
80 |
+
|
81 |
+
task_template_map = {
|
82 |
+
"image_caption": "Give me the semantic alignment score between the given image and the given caption: \"{generated_sentence}\" on a scale of 0-100. Only reply the score value.",
|
83 |
+
"vqa": "Rate the answer correctness regarding the question within the context of the given image on a scale of 0-100. Only reply the score value.",
|
84 |
+
"pair_rate_old": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nGiven the instruction and the image, please compare the correctness of responses A and B. Reply with \"leftvote\" if you find A better, \"rightvote\" if B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both responses are equally satisfactory. If you are unable to make a decision, please reply with \"NA\".",
|
85 |
+
"pair_rate_wexplanation": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
|
86 |
+
"pair_rate": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Reply with \"leftvote\" if you find assistant A better, \"rightvote\" if assistant B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both assistants provide equally satisfactory answers. If you are unable to make a decision, please reply with \"NA\"."
|
87 |
+
}
|
88 |
+
|
89 |
+
def inspect_convs(log_files):
|
90 |
+
ic(log_files)
|
91 |
+
data = []
|
92 |
+
total_vote = 0
|
93 |
+
correct_vote = 0
|
94 |
+
|
95 |
+
client = OpenAI()
|
96 |
+
with open('all_pairvote_log_wgpt_prtchatbot.csv', 'w', newline='') as csvfile:
|
97 |
+
# fieldnames = ['tstamp', 'type', 'model_1', 'model_2', 'template_name_1', 'template_name_2', 'system_message_1', 'system_message_2', 'role_1', 'role_2', 'instruction_1', 'instruction_2', 'message_1', 'message_2', 'offset_1', 'offset_2', 'conv_id_1', 'conv_id_2', 'model_name_1', 'model_name_2', 'ip']
|
98 |
+
fieldnames = ['tstamp', 'type', 'models', 'states', 'ip', 'gpt_vote']
|
99 |
+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
100 |
+
|
101 |
+
# Write the header
|
102 |
+
writer.writeheader()
|
103 |
+
|
104 |
+
for filename in tqdm(log_files, desc="read files"):
|
105 |
+
for retry in range(5):
|
106 |
+
try:
|
107 |
+
lines = open(filename).readlines()
|
108 |
+
break
|
109 |
+
except FileNotFoundError:
|
110 |
+
time.sleep(2)
|
111 |
+
|
112 |
+
for l in lines:
|
113 |
+
row = json.loads(l)
|
114 |
+
|
115 |
+
if "states" not in row:
|
116 |
+
continue
|
117 |
+
if row["type"] not in ["leftvote", "rightvote", "bothbad_vote", "tievote"]:
|
118 |
+
continue
|
119 |
+
|
120 |
+
model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]
|
121 |
+
|
122 |
+
|
123 |
+
# Iterate through each state and write the relevant information
|
124 |
+
if not len(row["states"][0]['messages']): continue
|
125 |
+
# ic(row["states"][0]['messages'][1][1])
|
126 |
+
|
127 |
+
if row["states"][0]['messages'][1][1] is None or row["states"][1]['messages'][1][1] is None or "NETWORK ERROR" in row["states"][0]['messages'][1][1] or "NETWORK ERROR" in row["states"][1]['messages'][1][1]: continue
|
128 |
+
total_vote += 1
|
129 |
+
# row = {
|
130 |
+
# 'tstamp': row['tstamp'],
|
131 |
+
# 'type': row['type'],
|
132 |
+
# 'model_1': row['models'][0],
|
133 |
+
# 'model_2': row['models'][1],
|
134 |
+
# 'template_name_1': row["states"][0]['template_name'],
|
135 |
+
# 'system_message_1': row["states"][0]['system_message'],
|
136 |
+
# 'template_name_2': row["states"][1]['template_name'],
|
137 |
+
# 'system_message_2': row["states"][1]['system_message'],
|
138 |
+
# 'role_1': row["states"][0]['roles'],
|
139 |
+
# 'role_2': row["states"][1]['roles'],
|
140 |
+
# 'instruction_1': row["states"][0]['messages'][0][1],
|
141 |
+
# 'instruction_2': row["states"][1]['messages'][0][1],
|
142 |
+
# 'message_1': row["states"][0]['messages'][1][1],
|
143 |
+
# 'message_2': row["states"][1]['messages'][1][1],
|
144 |
+
# 'offset_1': row["states"][0]['offset'],
|
145 |
+
# 'offset_2': row["states"][1]['offset'],
|
146 |
+
# 'conv_id_1': row["states"][0]['conv_id'],
|
147 |
+
# 'conv_id_2': row["states"][1]['conv_id'],
|
148 |
+
# 'model_name_1': row["states"][0]['model_name'],
|
149 |
+
# 'model_name_2': row["states"][1]['model_name'],
|
150 |
+
# 'ip': row['ip']
|
151 |
+
# }
|
152 |
+
# writer.writerow(row)
|
153 |
+
# Convert complex objects to JSON strings
|
154 |
+
# TODO: check two image are the same
|
155 |
+
conv_id = row["states"][0]['conv_id']
|
156 |
+
image_path = os.path.join("/local/home/yujielu/project/Arena-Elo/vision-arena-logs", os.path.basename(filename)[:-5]+"input_images", f"input_image_{conv_id}.png")
|
157 |
+
if not os.path.exists(image_path):
|
158 |
+
response = "NA"
|
159 |
+
ic(image_path)
|
160 |
+
else:
|
161 |
+
base64_image = encode_image(image_path)
|
162 |
+
left_response = row["states"][0]['messages'][1][1]
|
163 |
+
right_response = row["states"][1]['messages'][1][1]
|
164 |
+
sep = "-" * 20
|
165 |
+
instruction = row["states"][0]['messages'][0][1]
|
166 |
+
generated_sentence = f"[The Start of Assistant A’s Answer]\n{left_response}\n[The End of Assistant A’s Answer]\n\n[The Start of Assistant B’s Answer]\n{right_response}\n[The End of Assistant B’s Answer]"
|
167 |
+
text_prompt = task_template_map["pair_rate"].format(instruction=instruction, generated_sentence=generated_sentence)
|
168 |
+
# ic(text_prompt)
|
169 |
+
try:
|
170 |
+
response = get_gpt4v_response(client, img_bs64=base64_image, text_prompt=text_prompt, use_vision=True)
|
171 |
+
except:
|
172 |
+
ic(">>> skip")
|
173 |
+
response = "NA"
|
174 |
+
|
175 |
+
# response = get_gpt4v_response(client, img_bs64=base64_image, text_prompt=text_prompt, use_vision=True)
|
176 |
+
ic(row['type'], response)
|
177 |
+
if response.strip() not in ["leftvote", "rightvote", "bothbad_vote", "tievote"]:
|
178 |
+
response = "NA"
|
179 |
+
# ic(generated_sentence)
|
180 |
+
|
181 |
+
# if row['type'] == "leftvote":
|
182 |
+
# row['type'] = "A"
|
183 |
+
# elif row['type'] == "rightvote":
|
184 |
+
# row['type'] = "B"
|
185 |
+
# elif row['type'] in ["bothbad_vote", "tievote"]:
|
186 |
+
# row['type'] = "C"
|
187 |
+
if row['type'] == response.strip():
|
188 |
+
correct_vote += 1
|
189 |
+
row['models'] = json.dumps(row['models'])
|
190 |
+
row['states'] = json.dumps(row['states'], ensure_ascii=False)
|
191 |
+
row['gpt_vote'] = response
|
192 |
+
|
193 |
+
# Write the modified row to the CSV file
|
194 |
+
writer.writerow(row)
|
195 |
+
# if row["type"] == "leftvote":
|
196 |
+
# winner, loser = model_names[0], model_names[1]
|
197 |
+
# winner_conv, loser_conv = row["states"][0], row["states"][1]
|
198 |
+
# elif row["type"] == "rightvote":
|
199 |
+
# loser, winner = model_names[0], model_names[1]
|
200 |
+
# loser_conv, winner_conv = row["states"][0], row["states"][1]
|
201 |
+
|
202 |
+
# if loser == "llava-v1.5-13b" and winner == "llava-v1.5-13b":
|
203 |
+
# print("=" * 20)
|
204 |
+
# print(f"Winner: {winner}")
|
205 |
+
# pretty_print_conversation(winner_conv["messages"])
|
206 |
+
# print(f"Loser: {loser}")
|
207 |
+
# pretty_print_conversation(loser_conv["messages"])
|
208 |
+
# print("=" * 20)
|
209 |
+
# input()
|
210 |
+
# if row['type'] == 'bothbad_vote':
|
211 |
+
# from icecream import ic
|
212 |
+
# ic(model_names)
|
213 |
+
# if row["type"] == "bothbad_vote" and "gpt-4-vision-preview" in model_names:
|
214 |
+
# print("=" * 20)
|
215 |
+
# print(f"Model A: {model_names[0]}")
|
216 |
+
# pretty_print_conversation(row["states"][0]["messages"])
|
217 |
+
# print(f"Model B: {model_names[1]}")
|
218 |
+
# pretty_print_conversation(row["states"][1]["messages"])
|
219 |
+
# print("=" * 20)
|
220 |
+
# input()
|
221 |
+
# if correct_vote >= 300: break
|
222 |
+
ic(total_vote, correct_vote)
|
223 |
+
|
224 |
+
|
225 |
+
if __name__ == "__main__":
|
226 |
+
parser = argparse.ArgumentParser()
|
227 |
+
parser.add_argument("--max-num-files", type=int)
|
228 |
+
args = parser.parse_args()
|
229 |
+
|
230 |
+
log_files = get_log_files(args.max_num_files)
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
inspect_convs(log_files)
|
arena_elo/elo_rating/inspect_cost.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fire
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
from collections import defaultdict
|
5 |
+
from .basic_stats import get_log_files, NUM_SERVERS, LOG_ROOT_DIR
|
6 |
+
from .utils import detect_language, get_time_stamp_from_date, get_input_image_path, load_image_from_path
|
7 |
+
from tqdm import tqdm
|
8 |
+
VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote", "chat"]
|
9 |
+
|
10 |
+
|
11 |
+
def remove_html(raw):
|
12 |
+
if raw.startswith("<h3>"):
|
13 |
+
return raw[raw.find(": ") + 2 : -len("</h3>\n")]
|
14 |
+
if raw.startswith("### Model A: ") or raw.startswith("### Model B: "):
|
15 |
+
return raw[13:]
|
16 |
+
return raw
|
17 |
+
|
18 |
+
|
19 |
+
def read_file(filename):
|
20 |
+
data = []
|
21 |
+
for retry in range(5):
|
22 |
+
try:
|
23 |
+
# lines = open(filename).readlines()
|
24 |
+
for l in open(filename):
|
25 |
+
row = json.loads(l)
|
26 |
+
if row["type"] in VOTES:
|
27 |
+
data.append(row)
|
28 |
+
break
|
29 |
+
except FileNotFoundError:
|
30 |
+
time.sleep(2)
|
31 |
+
return data
|
32 |
+
|
33 |
+
|
34 |
+
def read_file_parallel(log_files, num_threads=16):
|
35 |
+
data_all = []
|
36 |
+
from multiprocessing import Pool
|
37 |
+
|
38 |
+
with Pool(num_threads) as p:
|
39 |
+
ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files)))
|
40 |
+
for ret in ret_all:
|
41 |
+
data_all.extend(ret)
|
42 |
+
return data_all
|
43 |
+
|
44 |
+
def num_tokens(s:str):
|
45 |
+
if s is None:
|
46 |
+
return 0
|
47 |
+
return len(s) / 4
|
48 |
+
|
49 |
+
def main(
|
50 |
+
):
|
51 |
+
log_files = get_log_files()
|
52 |
+
data = read_file_parallel(log_files)
|
53 |
+
|
54 |
+
all_model_counts = defaultdict(int)
|
55 |
+
all_model_input_tokens_counts = defaultdict(list)
|
56 |
+
all_model_output_tokens_counts = defaultdict(list)
|
57 |
+
all_model_image_sizes = defaultdict(list)
|
58 |
+
chat_battle_counts = defaultdict(int)
|
59 |
+
for row in tqdm(data, desc="counting"):
|
60 |
+
if row['type'] == "chat":
|
61 |
+
chat_battle_counts["chat"] += 1
|
62 |
+
all_model_counts[row['model']] += 1
|
63 |
+
tstamp = row["tstamp"]
|
64 |
+
conv_id = row["state"]["conv_id"]
|
65 |
+
|
66 |
+
image = load_image_from_path(get_input_image_path(tstamp, conv_id))
|
67 |
+
if image is None:
|
68 |
+
image_size = None
|
69 |
+
else:
|
70 |
+
image_size = load_image_from_path(get_input_image_path(tstamp, conv_id)).size
|
71 |
+
all_model_image_sizes[row['model']].append(image_size)
|
72 |
+
try:
|
73 |
+
for message in row["state"]["messages"][row["state"]["offset"] :: 2]:
|
74 |
+
all_model_input_tokens_counts[row['model']].append(num_tokens(message[1]))
|
75 |
+
for message in row["state"]["messages"][row["state"]["offset"] + 1 :: 2]:
|
76 |
+
all_model_output_tokens_counts[row['model']].append(num_tokens(message[1]))
|
77 |
+
except Exception as e:
|
78 |
+
print(row)
|
79 |
+
raise e
|
80 |
+
|
81 |
+
else:
|
82 |
+
chat_battle_counts[row['type']] += 1
|
83 |
+
if row["models"][0] is None or row["models"][1] is None:
|
84 |
+
continue
|
85 |
+
|
86 |
+
# Resolve model names
|
87 |
+
models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
|
88 |
+
if "model_name" in row["states"][0]:
|
89 |
+
models_hidden = [
|
90 |
+
row["states"][0]["model_name"],
|
91 |
+
row["states"][1]["model_name"],
|
92 |
+
]
|
93 |
+
if models_hidden[0] is None:
|
94 |
+
models_hidden = models_public
|
95 |
+
else:
|
96 |
+
models_hidden = models_public
|
97 |
+
|
98 |
+
if (models_public[0] == "" and models_public[1] != "") or (
|
99 |
+
models_public[1] == "" and models_public[0] != ""
|
100 |
+
):
|
101 |
+
continue
|
102 |
+
|
103 |
+
if models_public[0] == "" or models_public[0] == "Model A":
|
104 |
+
anony = True
|
105 |
+
models = models_hidden
|
106 |
+
else:
|
107 |
+
anony = False
|
108 |
+
models = models_public
|
109 |
+
if not models_public == models_hidden:
|
110 |
+
continue
|
111 |
+
|
112 |
+
all_model_counts[models[0]] += 1
|
113 |
+
all_model_counts[models[1]] += 1
|
114 |
+
tstamp = row["tstamp"]
|
115 |
+
conv_id1 = row["states"][0]["conv_id"]
|
116 |
+
conv_id2 = row["states"][1]["conv_id"]
|
117 |
+
|
118 |
+
image1 = load_image_from_path(get_input_image_path(tstamp, conv_id1))
|
119 |
+
image2 = load_image_from_path(get_input_image_path(tstamp, conv_id2))
|
120 |
+
all_model_image_sizes[models[0]].append(None if image1 is None else image1.size)
|
121 |
+
all_model_image_sizes[models[1]].append(None if image2 is None else image2.size)
|
122 |
+
|
123 |
+
for message in row["states"][0]["messages"][row["states"][0]["offset"] :: 2]:
|
124 |
+
all_model_input_tokens_counts[models[0]].append(num_tokens(message[1]))
|
125 |
+
for message in row["states"][0]["messages"][row["states"][0]["offset"] + 1 :: 2]:
|
126 |
+
all_model_output_tokens_counts[models[0]].append(num_tokens(message[1]))
|
127 |
+
for message in row["states"][1]["messages"][row["states"][1]["offset"] :: 2]:
|
128 |
+
all_model_input_tokens_counts[models[1]].append(num_tokens(message[1]))
|
129 |
+
for message in row["states"][1]["messages"][row["states"][1]["offset"] + 1 :: 2]:
|
130 |
+
all_model_output_tokens_counts[models[1]].append(num_tokens(message[1]))
|
131 |
+
|
132 |
+
print("### Chat battle counts (requests)")
|
133 |
+
print(json.dumps(chat_battle_counts, indent=4))
|
134 |
+
|
135 |
+
print("### Model counts (requests)")
|
136 |
+
print(json.dumps(all_model_counts, indent=4))
|
137 |
+
|
138 |
+
print("### Model Avg input tokens counts (tokens)")
|
139 |
+
average_input_tokens_counts = {}
|
140 |
+
for model, counts in all_model_input_tokens_counts.items():
|
141 |
+
average_input_tokens_counts[model] = sum(counts) / len(counts)
|
142 |
+
print(json.dumps(average_input_tokens_counts, indent=4))
|
143 |
+
|
144 |
+
print("### Model AVg output tokens counts (tokens)")
|
145 |
+
average_output_tokens_counts = {}
|
146 |
+
for model, counts in all_model_output_tokens_counts.items():
|
147 |
+
average_output_tokens_counts[model] = sum(counts) / len(counts)
|
148 |
+
print(json.dumps(average_output_tokens_counts, indent=4))
|
149 |
+
|
150 |
+
print("### Model Avg image sizes (height, width)")
|
151 |
+
average_image_sizes = {}
|
152 |
+
for model, sizes in all_model_image_sizes.items():
|
153 |
+
avg_height = sum([size[0] for size in sizes if size is not None]) / len(sizes)
|
154 |
+
avg_width = sum([size[1] for size in sizes if size is not None]) / len(sizes)
|
155 |
+
average_image_sizes[model] = (avg_height, avg_width)
|
156 |
+
print(json.dumps(average_image_sizes, indent=4))
|
157 |
+
|
158 |
+
print("### GPT-4V estimated cost (USD)")
|
159 |
+
gpt_4v_name = "gpt-4-vision-preview"
|
160 |
+
gpt_4v_cost = {}
|
161 |
+
gpt_4v_cost['input'] = sum(all_model_input_tokens_counts[gpt_4v_name]) / 1000 * 0.01
|
162 |
+
gpt_4v_cost['output'] = sum(all_model_output_tokens_counts[gpt_4v_name]) / 1000 * 0.03
|
163 |
+
|
164 |
+
all_image_cost = 0
|
165 |
+
for size in all_model_image_sizes[gpt_4v_name]:
|
166 |
+
if size is None:
|
167 |
+
continue
|
168 |
+
all_image_tokens = (size[0] // 512 + 1) * (size[1] // 512 + 1) * 170 + 85
|
169 |
+
all_image_cost += all_image_tokens / 1000 * 0.01
|
170 |
+
gpt_4v_cost['image'] = all_image_cost
|
171 |
+
print(json.dumps(gpt_4v_cost, indent=4))
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
if __name__ == "__main__":
|
177 |
+
fire.Fire(main)
|
arena_elo/elo_rating/inspect_elo_rating_pkl.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import plotly.graph_objects as go
|
3 |
+
|
4 |
+
def output_figure(data, figure_name="battle_count_heatmap", label="annoy"):
|
5 |
+
fig = data[label][figure_name]
|
6 |
+
fig.update_layout(
|
7 |
+
height=700,
|
8 |
+
width=700,
|
9 |
+
title={'text': f'{figure_name}', 'x': 0.5, 'y': 0.07},
|
10 |
+
xaxis_title="Model B",
|
11 |
+
yaxis_title="Model A",
|
12 |
+
# coloraxis_colorscale=[[0.0, '#0d0887'], [1.0, '#f0f921']],
|
13 |
+
margin={'t': 60}
|
14 |
+
)
|
15 |
+
fig.write_image(f"{figure_name}.png")
|
16 |
+
|
17 |
+
with open("./results/latest/elo_results.pkl",'rb') as f:
|
18 |
+
data = pickle.load(f)
|
19 |
+
print()
|
20 |
+
df = data["anony"]["leaderboard_table_df"]
|
21 |
+
# sort by rating
|
22 |
+
print(data["anony"].keys())
|
23 |
+
|
24 |
+
for figure_name in [ 'win_fraction_heatmap', 'battle_count_heatmap',]:
|
25 |
+
output_figure(data, figure_name, "anony")
|
26 |
+
|
27 |
+
df = df.sort_values(by=["rating"], ascending=False)
|
28 |
+
print(df)
|
29 |
+
df = data["full"]["leaderboard_table_df"]
|
30 |
+
# sort by rating
|
31 |
+
df = df.sort_values(by=["rating"], ascending=False)
|
32 |
+
print(df)
|
33 |
+
print('done')
|
arena_elo/elo_rating/model_registry.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Additional information of the models."""
|
2 |
+
from collections import namedtuple, OrderedDict
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
|
6 |
+
ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"])
|
7 |
+
|
8 |
+
|
9 |
+
model_info = OrderedDict()
|
10 |
+
|
11 |
+
|
12 |
+
def register_model_info(
|
13 |
+
full_names: List[str], simple_name: str, link: str, description: str
|
14 |
+
):
|
15 |
+
info = ModelInfo(simple_name, link, description)
|
16 |
+
|
17 |
+
for full_name in full_names:
|
18 |
+
model_info[full_name] = info
|
19 |
+
|
20 |
+
|
21 |
+
def get_model_info(name: str) -> ModelInfo:
|
22 |
+
if name in model_info:
|
23 |
+
return model_info[name]
|
24 |
+
else:
|
25 |
+
# To fix this, please use `register_model_info` to register your model
|
26 |
+
return ModelInfo(
|
27 |
+
name, "", "Register the description at arena.model/model_registry.py"
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
register_model_info(
|
32 |
+
[
|
33 |
+
"IEITYuan/Yuan2-2B-Janus-hf",
|
34 |
+
"IEITYuan/Yuan2-2B-hf",
|
35 |
+
"IEITYuan/Yuan2-51B-hf",
|
36 |
+
"IEITYuan/Yuan2-102B-hf",
|
37 |
+
],
|
38 |
+
"IEIT-Yuan2",
|
39 |
+
"https://github.com/IEIT-Yuan/Yuan-2.0",
|
40 |
+
"Yuan2.0 is a new generation Fundamental Large Language Model developed by IEIT System.",
|
41 |
+
)
|
42 |
+
|
43 |
+
register_model_info(
|
44 |
+
["mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"],
|
45 |
+
"Mixtral of experts",
|
46 |
+
"https://mistral.ai/news/mixtral-of-experts/",
|
47 |
+
"A Mixture-of-Experts model by Mistral AI",
|
48 |
+
)
|
49 |
+
|
50 |
+
register_model_info(
|
51 |
+
["gemini-pro"],
|
52 |
+
"Gemini",
|
53 |
+
"https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
|
54 |
+
"Gemini by Google",
|
55 |
+
)
|
56 |
+
|
57 |
+
register_model_info(
|
58 |
+
["gemini-pro-vision"],
|
59 |
+
"Gemini",
|
60 |
+
"https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
|
61 |
+
"Gemini by Google",
|
62 |
+
)
|
63 |
+
|
64 |
+
register_model_info(
|
65 |
+
["solar-10.7b-instruct-v1.0"],
|
66 |
+
"SOLAR-10.7B-Instruct",
|
67 |
+
"https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0",
|
68 |
+
"A model trained using depth up-scaling by Upstage AI",
|
69 |
+
)
|
70 |
+
|
71 |
+
register_model_info(
|
72 |
+
["gpt-4-turbo"],
|
73 |
+
"GPT-4-Turbo",
|
74 |
+
"https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
|
75 |
+
"GPT-4-Turbo by OpenAI",
|
76 |
+
)
|
77 |
+
|
78 |
+
register_model_info(
|
79 |
+
["gpt-4-vision-preview"],
|
80 |
+
"gpt-4-vision-preview",
|
81 |
+
"https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
|
82 |
+
"GPT-4(V) by OpenAI",
|
83 |
+
)
|
84 |
+
|
85 |
+
register_model_info(
|
86 |
+
["gpt-3.5-turbo", "gpt-3.5-turbo-0314", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106"],
|
87 |
+
"GPT-3.5",
|
88 |
+
"https://platform.openai.com/docs/models/gpt-3-5",
|
89 |
+
"GPT-3.5-Turbo by OpenAI",
|
90 |
+
)
|
91 |
+
|
92 |
+
register_model_info(
|
93 |
+
["gpt-4", "gpt-4-0314", "gpt-4-0613"],
|
94 |
+
"GPT-4",
|
95 |
+
"https://openai.com/research/gpt-4",
|
96 |
+
"GPT-4 by OpenAI",
|
97 |
+
)
|
98 |
+
|
99 |
+
register_model_info(
|
100 |
+
["claude-2.1", "claude-2.0"],
|
101 |
+
"Claude",
|
102 |
+
"https://www.anthropic.com/index/claude-2",
|
103 |
+
"Claude 2 by Anthropic",
|
104 |
+
)
|
105 |
+
|
106 |
+
register_model_info(
|
107 |
+
["claude-1"],
|
108 |
+
"Claude",
|
109 |
+
"https://www.anthropic.com/index/introducing-claude",
|
110 |
+
"Claude 1 by Anthropic",
|
111 |
+
)
|
112 |
+
|
113 |
+
register_model_info(
|
114 |
+
["claude-instant-1", "claude-instant-1.2"],
|
115 |
+
"Claude Instant",
|
116 |
+
"https://www.anthropic.com/index/introducing-claude",
|
117 |
+
"Claude Instant by Anthropic",
|
118 |
+
)
|
119 |
+
|
120 |
+
register_model_info(
|
121 |
+
["pplx-70b-online", "pplx-7b-online"],
|
122 |
+
"pplx-online-llms",
|
123 |
+
"https://blog.perplexity.ai/blog/introducing-pplx-online-llms",
|
124 |
+
"Online LLM API by Perplexity AI",
|
125 |
+
)
|
126 |
+
|
127 |
+
register_model_info(
|
128 |
+
["openhermes-2.5-mistral-7b"],
|
129 |
+
"OpenHermes-2.5-Mistral-7B",
|
130 |
+
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B",
|
131 |
+
"a mistral-based model fine-tuned on 1M GPT-4 outputs",
|
132 |
+
)
|
133 |
+
|
134 |
+
register_model_info(
|
135 |
+
["starling-lm-7b-alpha"],
|
136 |
+
"Starling-LM-7B-alpha",
|
137 |
+
"https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha",
|
138 |
+
"an open model trained using RLAIF by Berkeley",
|
139 |
+
)
|
140 |
+
|
141 |
+
register_model_info(
|
142 |
+
["tulu-2-dpo-70b"],
|
143 |
+
"Tulu 2",
|
144 |
+
"https://huggingface.co/allenai/tulu-2-dpo-70b",
|
145 |
+
"an instruction and RLHF model by UW/AllenAI",
|
146 |
+
)
|
147 |
+
|
148 |
+
register_model_info(
|
149 |
+
["yi-34b-chat", "yi-6b-chat"],
|
150 |
+
"Yi-Chat",
|
151 |
+
"https://huggingface.co/01-ai/Yi-34B-Chat",
|
152 |
+
"A large language model by 01 AI",
|
153 |
+
)
|
154 |
+
|
155 |
+
register_model_info(
|
156 |
+
["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"],
|
157 |
+
"Llama 2",
|
158 |
+
"https://ai.meta.com/llama/",
|
159 |
+
"open foundation and fine-tuned chat models by Meta",
|
160 |
+
)
|
161 |
+
|
162 |
+
register_model_info(
|
163 |
+
[
|
164 |
+
"vicuna-33b",
|
165 |
+
"vicuna-33b-v1.3",
|
166 |
+
"vicuna-13b",
|
167 |
+
"vicuna-13b-v1.3",
|
168 |
+
"vicuna-7b",
|
169 |
+
"vicuna-7b-v1.3",
|
170 |
+
],
|
171 |
+
"Vicuna",
|
172 |
+
"https://lmsys.org/blog/2023-03-30-vicuna/",
|
173 |
+
"a chat assistant fine-tuned on user-shared conversations by LMSYS",
|
174 |
+
)
|
175 |
+
|
176 |
+
register_model_info(
|
177 |
+
["chatglm3-6b", "chatglm2-6b", "chatglm-6b"],
|
178 |
+
"ChatGLM",
|
179 |
+
"https://chatglm.cn/blog",
|
180 |
+
"an open bilingual dialogue language model by Tsinghua University",
|
181 |
+
)
|
182 |
+
|
183 |
+
register_model_info(
|
184 |
+
["openchat-3.5"],
|
185 |
+
"OpenChat 3.5",
|
186 |
+
"https://github.com/imoneoi/openchat",
|
187 |
+
"an open model fine-tuned on Mistral-7B using C-RLFT",
|
188 |
+
)
|
189 |
+
|
190 |
+
register_model_info(
|
191 |
+
["tenyxchat-7b-v1"],
|
192 |
+
"TenyxChat-7B",
|
193 |
+
"https://huggingface.co/tenyx/TenyxChat-7B-v1",
|
194 |
+
"an open model DPO trained on top of OpenChat-3.5 using Tenyx fine-tuning",
|
195 |
+
)
|
196 |
+
|
197 |
+
register_model_info(
|
198 |
+
["zephyr-7b-beta", "zephyr-7b-alpha"],
|
199 |
+
"Zephyr",
|
200 |
+
"https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha",
|
201 |
+
"a chatbot fine-tuned from Mistral by Hugging Face",
|
202 |
+
)
|
203 |
+
|
204 |
+
register_model_info(
|
205 |
+
["notus-7b-v1"],
|
206 |
+
"Notus",
|
207 |
+
"https://huggingface.co/argilla/notus-7b-v1",
|
208 |
+
"a chatbot fine-tuned from Zephyr SFT by Argilla",
|
209 |
+
)
|
210 |
+
|
211 |
+
register_model_info(
|
212 |
+
["catppt"],
|
213 |
+
"CatPPT",
|
214 |
+
"https://huggingface.co/rishiraj/CatPPT",
|
215 |
+
"a chatbot fine-tuned from a SLERP merged model by Rishiraj Acharya",
|
216 |
+
)
|
217 |
+
|
218 |
+
register_model_info(
|
219 |
+
["TinyLlama"],
|
220 |
+
"TinyLlama",
|
221 |
+
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
222 |
+
"The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.",
|
223 |
+
)
|
224 |
+
|
225 |
+
register_model_info(
|
226 |
+
["qwen-14b-chat"],
|
227 |
+
"Qwen",
|
228 |
+
"https://huggingface.co/Qwen/Qwen-14B-Chat",
|
229 |
+
"a large language model by Alibaba Cloud",
|
230 |
+
)
|
231 |
+
|
232 |
+
register_model_info(
|
233 |
+
["codellama-34b-instruct", "codellama-13b-instruct", "codellama-7b-instruct"],
|
234 |
+
"Code Llama",
|
235 |
+
"https://ai.meta.com/blog/code-llama-large-language-model-coding/",
|
236 |
+
"open foundation models for code by Meta",
|
237 |
+
)
|
238 |
+
|
239 |
+
register_model_info(
|
240 |
+
["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"],
|
241 |
+
"WizardLM",
|
242 |
+
"https://github.com/nlpxucan/WizardLM",
|
243 |
+
"an instruction-following LLM using evol-instruct by Microsoft",
|
244 |
+
)
|
245 |
+
|
246 |
+
register_model_info(
|
247 |
+
["wizardcoder-15b-v1.0"],
|
248 |
+
"WizardLM",
|
249 |
+
"https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder",
|
250 |
+
"Empowering Code Large Language Models with Evol-Instruct",
|
251 |
+
)
|
252 |
+
|
253 |
+
register_model_info(
|
254 |
+
["mpt-7b-chat", "mpt-30b-chat"],
|
255 |
+
"MPT-Chat",
|
256 |
+
"https://www.mosaicml.com/blog/mpt-30b",
|
257 |
+
"a chatbot fine-tuned from MPT by MosaicML",
|
258 |
+
)
|
259 |
+
|
260 |
+
register_model_info(
|
261 |
+
["guanaco-33b", "guanaco-65b"],
|
262 |
+
"Guanaco",
|
263 |
+
"https://github.com/artidoro/qlora",
|
264 |
+
"a model fine-tuned with QLoRA by UW",
|
265 |
+
)
|
266 |
+
|
267 |
+
register_model_info(
|
268 |
+
["gpt4all-13b-snoozy"],
|
269 |
+
"GPT4All-Snoozy",
|
270 |
+
"https://github.com/nomic-ai/gpt4all",
|
271 |
+
"a finetuned LLaMA model on assistant style data by Nomic AI",
|
272 |
+
)
|
273 |
+
|
274 |
+
register_model_info(
|
275 |
+
["koala-13b"],
|
276 |
+
"Koala",
|
277 |
+
"https://bair.berkeley.edu/blog/2023/04/03/koala",
|
278 |
+
"a dialogue model for academic research by BAIR",
|
279 |
+
)
|
280 |
+
|
281 |
+
register_model_info(
|
282 |
+
["RWKV-4-Raven-14B"],
|
283 |
+
"RWKV-4-Raven",
|
284 |
+
"https://huggingface.co/BlinkDL/rwkv-4-raven",
|
285 |
+
"an RNN with transformer-level LLM performance",
|
286 |
+
)
|
287 |
+
|
288 |
+
register_model_info(
|
289 |
+
["alpaca-13b"],
|
290 |
+
"Alpaca",
|
291 |
+
"https://crfm.stanford.edu/2023/03/13/alpaca.html",
|
292 |
+
"a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford",
|
293 |
+
)
|
294 |
+
|
295 |
+
register_model_info(
|
296 |
+
["oasst-pythia-12b"],
|
297 |
+
"OpenAssistant (oasst)",
|
298 |
+
"https://open-assistant.io",
|
299 |
+
"an Open Assistant for everyone by LAION",
|
300 |
+
)
|
301 |
+
|
302 |
+
register_model_info(
|
303 |
+
["oasst-sft-7-llama-30b"],
|
304 |
+
"OpenAssistant (oasst)",
|
305 |
+
"https://open-assistant.io",
|
306 |
+
"an Open Assistant for everyone by LAION",
|
307 |
+
)
|
308 |
+
|
309 |
+
register_model_info(
|
310 |
+
["palm-2"],
|
311 |
+
"PaLM 2 Chat",
|
312 |
+
"https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023",
|
313 |
+
"PaLM 2 for Chat (chat-bison@001) by Google",
|
314 |
+
)
|
315 |
+
|
316 |
+
register_model_info(
|
317 |
+
["llama-7b", "llama-13b"],
|
318 |
+
"LLaMA",
|
319 |
+
"https://arxiv.org/abs/2302.13971",
|
320 |
+
"open and efficient foundation language models by Meta",
|
321 |
+
)
|
322 |
+
|
323 |
+
register_model_info(
|
324 |
+
["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"],
|
325 |
+
"Open LLaMa (Open Instruct)",
|
326 |
+
"https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc",
|
327 |
+
"Open LLaMa fine-tuned on instruction-following data by VMware",
|
328 |
+
)
|
329 |
+
|
330 |
+
register_model_info(
|
331 |
+
["dolly-v2-12b"],
|
332 |
+
"Dolly",
|
333 |
+
"https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm",
|
334 |
+
"an instruction-tuned open large language model by Databricks",
|
335 |
+
)
|
336 |
+
|
337 |
+
register_model_info(
|
338 |
+
["stablelm-tuned-alpha-7b"],
|
339 |
+
"StableLM",
|
340 |
+
"https://github.com/stability-AI/stableLM",
|
341 |
+
"Stability AI language models",
|
342 |
+
)
|
343 |
+
|
344 |
+
register_model_info(
|
345 |
+
["codet5p-6b"],
|
346 |
+
"CodeT5p-6b",
|
347 |
+
"https://huggingface.co/Salesforce/codet5p-6b",
|
348 |
+
"Code completion model released by Salesforce",
|
349 |
+
)
|
350 |
+
|
351 |
+
register_model_info(
|
352 |
+
["fastchat-t5-3b", "fastchat-t5-3b-v1.0"],
|
353 |
+
"FastChat-T5",
|
354 |
+
"https://huggingface.co/lmsys/fastchat-t5-3b-v1.0",
|
355 |
+
"a chat assistant fine-tuned from FLAN-T5 by LMSYS",
|
356 |
+
)
|
357 |
+
|
358 |
+
register_model_info(
|
359 |
+
["phoenix-inst-chat-7b"],
|
360 |
+
"Phoenix-7B",
|
361 |
+
"https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b",
|
362 |
+
"a multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)",
|
363 |
+
)
|
364 |
+
|
365 |
+
register_model_info(
|
366 |
+
["realm-7b-v1"],
|
367 |
+
"ReaLM",
|
368 |
+
"https://github.com/FreedomIntelligence/ReaLM",
|
369 |
+
"A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.",
|
370 |
+
)
|
371 |
+
|
372 |
+
register_model_info(
|
373 |
+
["billa-7b-sft"],
|
374 |
+
"BiLLa-7B-SFT",
|
375 |
+
"https://huggingface.co/Neutralzz/BiLLa-7B-SFT",
|
376 |
+
"an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher",
|
377 |
+
)
|
378 |
+
|
379 |
+
register_model_info(
|
380 |
+
["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"],
|
381 |
+
"h2oGPT-GM-7b",
|
382 |
+
"https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
|
383 |
+
"an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai",
|
384 |
+
)
|
385 |
+
|
386 |
+
register_model_info(
|
387 |
+
["baize-v2-7b", "baize-v2-13b"],
|
388 |
+
"Baize v2",
|
389 |
+
"https://github.com/project-baize/baize-chatbot#v2",
|
390 |
+
"A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.",
|
391 |
+
)
|
392 |
+
|
393 |
+
register_model_info(
|
394 |
+
[
|
395 |
+
"airoboros-l2-7b-2.1",
|
396 |
+
"airoboros-l2-13b-2.1",
|
397 |
+
"airoboros-c34b-2.1",
|
398 |
+
"airoboros-l2-70b-2.1",
|
399 |
+
],
|
400 |
+
"airoboros",
|
401 |
+
"https://huggingface.co/jondurbin/airoboros-l2-70b-2.1",
|
402 |
+
"an instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4",
|
403 |
+
)
|
404 |
+
|
405 |
+
register_model_info(
|
406 |
+
[
|
407 |
+
"spicyboros-7b-2.2",
|
408 |
+
"spicyboros-13b-2.2",
|
409 |
+
"spicyboros-70b-2.2",
|
410 |
+
],
|
411 |
+
"spicyboros",
|
412 |
+
"https://huggingface.co/jondurbin/spicyboros-70b-2.2",
|
413 |
+
"de-aligned versions of the airoboros models",
|
414 |
+
)
|
415 |
+
|
416 |
+
register_model_info(
|
417 |
+
["Robin-7b-v2", "Robin-13b-v2", "Robin-33b-v2"],
|
418 |
+
"Robin-v2",
|
419 |
+
"https://huggingface.co/OptimalScale/robin-7b-v2-delta",
|
420 |
+
"A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense reasoning and instruction-following tasks, by OptimalScale, HKUST.",
|
421 |
+
)
|
422 |
+
|
423 |
+
register_model_info(
|
424 |
+
["manticore-13b-chat"],
|
425 |
+
"Manticore 13B Chat",
|
426 |
+
"https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg",
|
427 |
+
"A chatbot fine-tuned from LlaMa across several CoT and chat datasets.",
|
428 |
+
)
|
429 |
+
|
430 |
+
register_model_info(
|
431 |
+
["redpajama-incite-7b-chat"],
|
432 |
+
"RedPajama-INCITE-7B-Chat",
|
433 |
+
"https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat",
|
434 |
+
"A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together",
|
435 |
+
)
|
436 |
+
|
437 |
+
register_model_info(
|
438 |
+
[
|
439 |
+
"falcon-7b",
|
440 |
+
"falcon-7b-instruct",
|
441 |
+
"falcon-40b",
|
442 |
+
"falcon-40b-instruct",
|
443 |
+
"falcon-180b",
|
444 |
+
"falcon-180b-chat",
|
445 |
+
],
|
446 |
+
"Falcon",
|
447 |
+
"https://huggingface.co/tiiuae/falcon-180B",
|
448 |
+
"TII's flagship series of large language models",
|
449 |
+
)
|
450 |
+
|
451 |
+
register_model_info(
|
452 |
+
["tigerbot-7b-sft"],
|
453 |
+
"Tigerbot",
|
454 |
+
"https://huggingface.co/TigerResearch/tigerbot-7b-sft",
|
455 |
+
"TigerBot is a large-scale language model (LLM) with multiple languages and tasks.",
|
456 |
+
)
|
457 |
+
|
458 |
+
register_model_info(
|
459 |
+
["internlm-chat-7b", "internlm-chat-7b-8k"],
|
460 |
+
"InternLM",
|
461 |
+
"https://huggingface.co/internlm/internlm-chat-7b",
|
462 |
+
"InternLM is a multi-language large-scale language model (LLM), developed by SHLAB.",
|
463 |
+
)
|
464 |
+
|
465 |
+
register_model_info(
|
466 |
+
["Qwen-7B-Chat"],
|
467 |
+
"Qwen",
|
468 |
+
"https://huggingface.co/Qwen/Qwen-7B-Chat",
|
469 |
+
"Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.",
|
470 |
+
)
|
471 |
+
|
472 |
+
register_model_info(
|
473 |
+
["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"],
|
474 |
+
"Llama2-Chinese",
|
475 |
+
"https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat",
|
476 |
+
"Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.",
|
477 |
+
)
|
478 |
+
|
479 |
+
register_model_info(
|
480 |
+
["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"],
|
481 |
+
"Chinese-Alpaca",
|
482 |
+
"https://huggingface.co/hfl/chinese-alpaca-2-13b",
|
483 |
+
"New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.",
|
484 |
+
)
|
485 |
+
|
486 |
+
register_model_info(
|
487 |
+
["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"],
|
488 |
+
"Vigogne-Instruct",
|
489 |
+
"https://huggingface.co/bofenghuang/vigogne-2-7b-instruct",
|
490 |
+
"Vigogne-Instruct is a French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang",
|
491 |
+
)
|
492 |
+
|
493 |
+
register_model_info(
|
494 |
+
["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"],
|
495 |
+
"Vigogne-Chat",
|
496 |
+
"https://huggingface.co/bofenghuang/vigogne-2-7b-chat",
|
497 |
+
"Vigogne-Chat is a French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang",
|
498 |
+
)
|
499 |
+
|
500 |
+
register_model_info(
|
501 |
+
["stable-vicuna-13B-HF"],
|
502 |
+
"stable-vicuna",
|
503 |
+
"https://huggingface.co/TheBloke/stable-vicuna-13B-HF",
|
504 |
+
"StableVicuna is a Vicuna model fine-tuned using RLHF via PPO on various conversational and instructional datasets.",
|
505 |
+
)
|
506 |
+
|
507 |
+
register_model_info(
|
508 |
+
["deluxe-chat-v1", "deluxe-chat-v1.1", "deluxe-chat-v1.2"],
|
509 |
+
"DeluxeChat",
|
510 |
+
"",
|
511 |
+
"Deluxe Chat",
|
512 |
+
)
|
513 |
+
|
514 |
+
register_model_info(
|
515 |
+
[
|
516 |
+
"Xwin-LM-7B-V0.1",
|
517 |
+
"Xwin-LM-13B-V0.1",
|
518 |
+
"Xwin-LM-70B-V0.1",
|
519 |
+
"Xwin-LM-7B-V0.2",
|
520 |
+
"Xwin-LM-13B-V0.2",
|
521 |
+
],
|
522 |
+
"Xwin-LM",
|
523 |
+
"https://github.com/Xwin-LM/Xwin-LM",
|
524 |
+
"Chat models developed by Xwin-LM team",
|
525 |
+
)
|
526 |
+
|
527 |
+
register_model_info(
|
528 |
+
["lemur-70b-chat"],
|
529 |
+
"Lemur-Chat",
|
530 |
+
"https://huggingface.co/OpenLemur/lemur-70b-chat-v1",
|
531 |
+
"an openly accessible language model optimized for both natural language and coding capabilities ",
|
532 |
+
)
|
533 |
+
|
534 |
+
register_model_info(
|
535 |
+
["Mistral-7B-OpenOrca"],
|
536 |
+
"Open-Orca",
|
537 |
+
"https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca",
|
538 |
+
"A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)",
|
539 |
+
)
|
540 |
+
|
541 |
+
register_model_info(
|
542 |
+
["dolphin-2.2.1-mistral-7b"],
|
543 |
+
"dolphin-mistral",
|
544 |
+
"https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b",
|
545 |
+
"An uncensored fine-tuned Mistral 7B",
|
546 |
+
)
|
547 |
+
|
548 |
+
register_model_info(
|
549 |
+
[
|
550 |
+
"AquilaChat-7B",
|
551 |
+
"AquilaChat2-7B",
|
552 |
+
"AquilaChat2-34B",
|
553 |
+
],
|
554 |
+
"Aquila-Chat",
|
555 |
+
"https://huggingface.co/BAAI/AquilaChat2-34B",
|
556 |
+
"Chat models developed by BAAI team",
|
557 |
+
)
|
558 |
+
|
559 |
+
register_model_info(
|
560 |
+
["xDAN-L1-Chat-RL-v1"],
|
561 |
+
"xDAN-L1-Chat",
|
562 |
+
"https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1",
|
563 |
+
"A large language chat model created by xDAN-AI.",
|
564 |
+
)
|
565 |
+
|
566 |
+
register_model_info(
|
567 |
+
["MetaMath-70B-V1.0", "MetaMath-7B-V1.0"],
|
568 |
+
"MetaMath",
|
569 |
+
"https://huggingface.co/meta-math",
|
570 |
+
"MetaMath is a finetune of Llama2 on [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) that specializes in mathematical reasoning.",
|
571 |
+
)
|
572 |
+
|
573 |
+
register_model_info(
|
574 |
+
["Yuan2-2B-hf", "Yuan2-51B-hf", "Yuan2-102B-hf"],
|
575 |
+
"IEIYuan",
|
576 |
+
"https://huggingface.co/IEITYuan",
|
577 |
+
"Yuan2 is a Basemodel developed by IEI.",
|
578 |
+
)
|
arena_elo/elo_rating/upload_battle_data.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fire
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import datasets
|
5 |
+
import datetime
|
6 |
+
from pathlib import Path
|
7 |
+
from datetime import datetime
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
datasets.config.DEFAULT_MAX_BATCH_SIZE = 500
|
11 |
+
def create_hf_dataset(data_file: str, split="test"):
|
12 |
+
hf_dataset = datasets.Dataset.from_list(
|
13 |
+
data_file,
|
14 |
+
features=datasets.Features(
|
15 |
+
{
|
16 |
+
"question_id": datasets.Value("string"),
|
17 |
+
"model": datasets.Value("string"),
|
18 |
+
"conversation": [
|
19 |
+
{
|
20 |
+
"role": datasets.Value("string"),
|
21 |
+
"content": datasets.Value("string"),
|
22 |
+
}
|
23 |
+
],
|
24 |
+
"language": datasets.Value("string"),
|
25 |
+
"image": datasets.Image(),
|
26 |
+
"turn": datasets.Value("int32"),
|
27 |
+
}
|
28 |
+
),
|
29 |
+
split=split,
|
30 |
+
)
|
31 |
+
return hf_dataset
|
32 |
+
|
33 |
+
def create_hf_battle_dataset(data_file: str, split="test"):
|
34 |
+
hf_dataset = datasets.Dataset.from_list(
|
35 |
+
data_file,
|
36 |
+
features=datasets.Features(
|
37 |
+
{
|
38 |
+
"question_id": datasets.Value("string"),
|
39 |
+
"model_a": datasets.Value("string"),
|
40 |
+
"model_b": datasets.Value("string"),
|
41 |
+
"conversation_a": [
|
42 |
+
{
|
43 |
+
"role": datasets.Value("string"),
|
44 |
+
"content": datasets.Value("string"),
|
45 |
+
}
|
46 |
+
],
|
47 |
+
"conversation_b": [
|
48 |
+
{
|
49 |
+
"role": datasets.Value("string"),
|
50 |
+
"content": datasets.Value("string"),
|
51 |
+
}
|
52 |
+
],
|
53 |
+
"language": datasets.Value("string"),
|
54 |
+
"image": datasets.Image(),
|
55 |
+
"turn": datasets.Value("int32"),
|
56 |
+
"anony": datasets.Value("bool"),
|
57 |
+
}
|
58 |
+
),
|
59 |
+
split=split,
|
60 |
+
)
|
61 |
+
return hf_dataset
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
def load_image(path:str):
|
67 |
+
try:
|
68 |
+
return Image.open(path)
|
69 |
+
except Exception as e:
|
70 |
+
print(f"Error loading image {path}: {e}")
|
71 |
+
return None
|
72 |
+
|
73 |
+
def get_date_from_time_stamp(unix_timestamp: int):
|
74 |
+
# Create a datetime object from the Unix timestamp
|
75 |
+
dt = datetime.fromtimestamp(unix_timestamp)
|
76 |
+
|
77 |
+
# Convert the datetime object to a string with the desired format
|
78 |
+
date_str = dt.strftime("%Y-%m-%d")
|
79 |
+
return date_str
|
80 |
+
|
81 |
+
def load_battle_image(battle, log_dir):
|
82 |
+
image_path = Path(log_dir) / f"{get_date_from_time_stamp(battle['tstamp'])}-convinput_images" / f"input_image_{battle['question_id']}.png"
|
83 |
+
return load_image(image_path)
|
84 |
+
|
85 |
+
|
86 |
+
def main(
|
87 |
+
data_file: str = "./results/latest/clean_battle_conv.json",
|
88 |
+
repo_id: str = "DongfuTingle/wildvision-bench",
|
89 |
+
log_dir: str = os.getenv("LOGDIR", "./vision-arena-logs/"),
|
90 |
+
mode="battle",
|
91 |
+
token = os.environ.get("HUGGINGFACE_TOKEN", None)
|
92 |
+
):
|
93 |
+
with open(data_file, "r") as f:
|
94 |
+
data = json.load(f)
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
has_image_stats = {
|
99 |
+
"has_image": 0,
|
100 |
+
"no_image": 0,
|
101 |
+
}
|
102 |
+
if mode == "keep_bad_only":
|
103 |
+
# anony only
|
104 |
+
data = [d for d in data if d["anony"]]
|
105 |
+
|
106 |
+
new_data = []
|
107 |
+
for battle in data:
|
108 |
+
image = load_battle_image(battle, log_dir)
|
109 |
+
if image is None:
|
110 |
+
has_image_stats["no_image"] += 1
|
111 |
+
# we don't keep the data without image
|
112 |
+
continue
|
113 |
+
has_image_stats["has_image"] += 1
|
114 |
+
|
115 |
+
if battle["winner"] in ["model_a", "model_b"]:
|
116 |
+
if battle["winner"] == "model_a":
|
117 |
+
worse_model = "model_b"
|
118 |
+
worse_conv = "conversation_b"
|
119 |
+
if battle["winner"] == "model_b":
|
120 |
+
worse_model = "model_a"
|
121 |
+
worse_conv = "conversation_a"
|
122 |
+
|
123 |
+
new_data.append({
|
124 |
+
"question_id": battle["question_id"],
|
125 |
+
"model": battle[worse_model],
|
126 |
+
"conversation": battle[worse_conv],
|
127 |
+
"language": battle["language"],
|
128 |
+
"image": image,
|
129 |
+
"turn": battle["turn"],
|
130 |
+
})
|
131 |
+
elif battle["winner"] == "tie (bothbad)":
|
132 |
+
|
133 |
+
new_data.append({
|
134 |
+
"question_id": battle["question_id"],
|
135 |
+
"model": battle["model_a"],
|
136 |
+
"conversation": battle["conversation_a"],
|
137 |
+
"language": battle["language"],
|
138 |
+
"image": image,
|
139 |
+
"turn": battle["turn"],
|
140 |
+
})
|
141 |
+
|
142 |
+
new_data.append({
|
143 |
+
"question_id": battle["question_id"],
|
144 |
+
"model": battle["model_b"],
|
145 |
+
"conversation": battle["conversation_b"],
|
146 |
+
"language": battle["language"],
|
147 |
+
"image": image,
|
148 |
+
"turn": battle["turn"],
|
149 |
+
})
|
150 |
+
|
151 |
+
split = "test"
|
152 |
+
hf_dataset = create_hf_dataset(new_data, "test")
|
153 |
+
|
154 |
+
elif mode == "battle":
|
155 |
+
new_data = []
|
156 |
+
for battle in data:
|
157 |
+
image = load_battle_image(battle, log_dir)
|
158 |
+
if image is None:
|
159 |
+
has_image_stats["no_image"] += 1
|
160 |
+
continue
|
161 |
+
has_image_stats["has_image"] += 1
|
162 |
+
new_data.append({
|
163 |
+
"question_id": battle["question_id"],
|
164 |
+
"model_a": battle["model_a"],
|
165 |
+
"model_b": battle["model_b"],
|
166 |
+
"conversation_a": battle["conversation_a"],
|
167 |
+
"conversation_b": battle["conversation_b"],
|
168 |
+
"language": battle["language"],
|
169 |
+
"image": image,
|
170 |
+
"turn": battle["turn"],
|
171 |
+
"anony": battle["anony"],
|
172 |
+
})
|
173 |
+
split = "test"
|
174 |
+
hf_dataset = create_hf_battle_dataset(new_data, "test")
|
175 |
+
else:
|
176 |
+
raise ValueError(f"Invalid mode: {mode}")
|
177 |
+
|
178 |
+
print(f"Stats: {has_image_stats}")
|
179 |
+
print(hf_dataset)
|
180 |
+
print(f"Uploading to part {repo_id}:{split}...")
|
181 |
+
hf_dataset.push_to_hub(
|
182 |
+
repo_id=repo_id,
|
183 |
+
config_name=mode,
|
184 |
+
split=split,
|
185 |
+
token=token,
|
186 |
+
commit_message=f"Add vision-arena {split} dataset",
|
187 |
+
)
|
188 |
+
|
189 |
+
print("Done!")
|
190 |
+
|
191 |
+
|
192 |
+
if __name__ == "__main__":
|
193 |
+
fire.Fire(main)
|
arena_elo/elo_rating/utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
import pytz
|
3 |
+
import PIL
|
4 |
+
import os
|
5 |
+
|
6 |
+
def detect_language(text: str) -> str:
|
7 |
+
"""Detect the langauge of a string."""
|
8 |
+
import polyglot # pip3 install polyglot pyicu pycld2
|
9 |
+
from polyglot.detect import Detector
|
10 |
+
from polyglot.detect.base import logger as polyglot_logger
|
11 |
+
import pycld2
|
12 |
+
|
13 |
+
polyglot_logger.setLevel("ERROR")
|
14 |
+
|
15 |
+
try:
|
16 |
+
lang_code = Detector(text).language.name
|
17 |
+
except (pycld2.error, polyglot.detect.base.UnknownLanguage):
|
18 |
+
lang_code = "unknown"
|
19 |
+
return lang_code
|
20 |
+
|
21 |
+
|
22 |
+
def get_time_stamp_from_date(date_str:str):
|
23 |
+
"""
|
24 |
+
Convert a date string to a Unix timestamp
|
25 |
+
Args:
|
26 |
+
date_str (str): The input date string in the format 'YYYY-MM-DD-HH:MM-TZ', e.g. '2024-02-10-14:00-PT'
|
27 |
+
"""
|
28 |
+
|
29 |
+
# Convert the date string into a format that Python's datetime can understand
|
30 |
+
# and specify the correct timezone for PT, which is 'US/Pacific'
|
31 |
+
date_format = "%Y-%m-%d-%H:%M-%Z"
|
32 |
+
|
33 |
+
# Parse the date string into a datetime object
|
34 |
+
# Note: PT is not directly recognized by pytz, so we manually map it to 'US/Pacific'
|
35 |
+
timezone_map = {
|
36 |
+
"PT": "US/Pacific",
|
37 |
+
}
|
38 |
+
|
39 |
+
# Extract the timezone abbreviation
|
40 |
+
tz_abbr = date_str.split("-")[-1]
|
41 |
+
# Map the abbreviation to a pytz timezone
|
42 |
+
tz_info = pytz.timezone(timezone_map[tz_abbr])
|
43 |
+
|
44 |
+
# Remove the timezone abbreviation for parsing
|
45 |
+
date_str_parsed = date_str.rsplit("-", 1)[0]
|
46 |
+
|
47 |
+
# Create a datetime object with the corresponding timezone
|
48 |
+
dt = datetime.strptime(date_str_parsed, "%Y-%m-%d-%H:%M").replace(tzinfo=tz_info)
|
49 |
+
|
50 |
+
# Convert the datetime object to a Unix timestamp
|
51 |
+
unix_timestamp = dt.timestamp()
|
52 |
+
return unix_timestamp
|
53 |
+
|
54 |
+
def get_date_from_time_stamp(unix_timestamp: int):
|
55 |
+
# Create a datetime object from the Unix timestamp
|
56 |
+
dt = datetime.fromtimestamp(unix_timestamp)
|
57 |
+
|
58 |
+
# Convert the datetime object to a string with the desired format
|
59 |
+
date_str = dt.strftime("%Y-%m-%d %H:%M:%S %Z")
|
60 |
+
return date_str
|
61 |
+
|
62 |
+
|
63 |
+
def get_input_image_path(tstamp, conv_id):
|
64 |
+
# from tstamp to date e.g. 2024-02-10
|
65 |
+
date_str = datetime.fromtimestamp(tstamp, tz=pytz.timezone("US/Pacific")).strftime("%Y-%m-%d")
|
66 |
+
LOGDIR = os.getenv("LOGDIR")
|
67 |
+
return f"{LOGDIR}/{date_str}-convinput_images/input_image_{conv_id}.png"
|
68 |
+
|
69 |
+
def load_image_from_path(image_path):
|
70 |
+
# Load the image from the specified
|
71 |
+
# path using the Python Imaging Library (PIL)
|
72 |
+
try:
|
73 |
+
image = PIL.Image.open(image_path)
|
74 |
+
return image
|
75 |
+
except FileNotFoundError:
|
76 |
+
print(f"Image not found at path: {image_path}")
|
77 |
+
return None
|
78 |
+
except PIL.UnidentifiedImageError:
|
79 |
+
print(f"Unidentified image format at path: {image_path}")
|
80 |
+
return None
|
81 |
+
|
82 |
+
|
83 |
+
|
arena_elo/evaluator/convert_to_evaluator_data.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from pytz import timezone
|
6 |
+
from tqdm import tqdm
|
7 |
+
import base64
|
8 |
+
from icecream import ic
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
|
12 |
+
# Function to encode the image
|
13 |
+
def encode_image(image_path):
|
14 |
+
with open(image_path, "rb") as image_file:
|
15 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
16 |
+
|
17 |
+
def get_log_files(max_num_files=None):
|
18 |
+
dates = []
|
19 |
+
for month in [2, 3]:
|
20 |
+
for day in range(1, 32):
|
21 |
+
dates.append(f"2024-{month:02d}-{day:02d}")
|
22 |
+
|
23 |
+
num_servers = 1
|
24 |
+
filenames = []
|
25 |
+
for d in dates:
|
26 |
+
for i in range(num_servers):
|
27 |
+
# name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
|
28 |
+
name = os.path.expanduser(f"vision-arena-logs/{d}-conv.json")
|
29 |
+
if os.path.exists(name):
|
30 |
+
filenames.append(name)
|
31 |
+
max_num_files = max_num_files or len(filenames)
|
32 |
+
filenames = filenames[-max_num_files:]
|
33 |
+
return filenames
|
34 |
+
|
35 |
+
|
36 |
+
def pretty_print_conversation(messages):
|
37 |
+
for role, msg in messages:
|
38 |
+
print(f"[[{role}]]: {msg}")
|
39 |
+
|
40 |
+
task_template_map = {
|
41 |
+
"image_caption": "Give me the semantic alignment score between the given image and the given caption: \"{generated_sentence}\" on a scale of 0-100. Only reply the score value.",
|
42 |
+
"vqa": "Rate the answer correctness regarding the question within the context of the given image on a scale of 0-100. Only reply the score value.",
|
43 |
+
"pair_rate_old": "[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nGiven the instruction and the image, please compare the correctness of responses A and B. Reply with \"leftvote\" if you find A better, \"rightvote\" if B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both responses are equally satisfactory. If you are unable to make a decision, please reply with \"NA\".",
|
44 |
+
"pair_rate_wexplanation": "<image>[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.",
|
45 |
+
"pair_rate": "<image>[Instruction]\n\"{instruction}\"\n\n\"{generated_sentence}\"\n\n[System]\nPlease act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user’s instructions and answers the user’s question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any positional biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Reply with \"leftvote\" if you find assistant A better, \"rightvote\" if assistant B is better, \"bothbad_vote\" if both responses are wrong, and \"tievote\" if both assistants provide equally satisfactory answers. If you are unable to make a decision, please reply with \"NA\"."
|
46 |
+
}
|
47 |
+
|
48 |
+
def inspect_convs(log_files):
|
49 |
+
json_data = []
|
50 |
+
|
51 |
+
ic(log_files)
|
52 |
+
total_vote = 0
|
53 |
+
|
54 |
+
for filename in tqdm(log_files, desc="read files"):
|
55 |
+
for retry in range(5):
|
56 |
+
try:
|
57 |
+
lines = open(filename).readlines()
|
58 |
+
break
|
59 |
+
except FileNotFoundError:
|
60 |
+
time.sleep(2)
|
61 |
+
|
62 |
+
for l in lines:
|
63 |
+
row = json.loads(l)
|
64 |
+
|
65 |
+
if "states" not in row:
|
66 |
+
continue
|
67 |
+
if row["type"] not in ["leftvote", "rightvote", "bothbad_vote", "tievote"]:
|
68 |
+
continue
|
69 |
+
|
70 |
+
model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]
|
71 |
+
|
72 |
+
|
73 |
+
# Iterate through each state and write the relevant information
|
74 |
+
if not len(row["states"][0]['messages']): continue
|
75 |
+
# ic(row["states"][0]['messages'][1][1])
|
76 |
+
|
77 |
+
if row["states"][0]['messages'][1][1] is None or row["states"][1]['messages'][1][1] is None or "NETWORK ERROR" in row["states"][0]['messages'][1][1] or "NETWORK ERROR" in row["states"][1]['messages'][1][1]: continue
|
78 |
+
total_vote += 1
|
79 |
+
|
80 |
+
conv_id = row["states"][0]['conv_id']
|
81 |
+
image_path = os.path.join("/local/home/yujielu/project/Arena-Elo/vision-arena-logs", os.path.basename(filename)[:-5]+"input_images", f"input_image_{conv_id}.png")
|
82 |
+
if not os.path.exists(image_path) :
|
83 |
+
continue
|
84 |
+
try:
|
85 |
+
image = Image.open(image_path).convert("RGB")
|
86 |
+
except:
|
87 |
+
continue
|
88 |
+
|
89 |
+
left_response = row["states"][0]['messages'][1][1]
|
90 |
+
right_response = row["states"][1]['messages'][1][1]
|
91 |
+
instruction = row["states"][0]['messages'][0][1]
|
92 |
+
generated_sentence = f"[The Start of Assistant A’s Answer]\n{left_response}\n[The End of Assistant A’s Answer]\n\n[The Start of Assistant B’s Answer]\n{right_response}\n[The End of Assistant B’s Answer]"
|
93 |
+
text_prompt = task_template_map["pair_rate"].format(instruction=instruction, generated_sentence=generated_sentence)
|
94 |
+
|
95 |
+
user_input = text_prompt
|
96 |
+
# Create the conversation structure
|
97 |
+
conversation = [
|
98 |
+
{
|
99 |
+
"from": "human",
|
100 |
+
"value": user_input
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"from": "gpt",
|
104 |
+
"value": row["type"]
|
105 |
+
}
|
106 |
+
]
|
107 |
+
|
108 |
+
# Create the JSON object for each row
|
109 |
+
json_obj = {
|
110 |
+
"id": conv_id,
|
111 |
+
"image": image_path,
|
112 |
+
"conversations": conversation
|
113 |
+
}
|
114 |
+
|
115 |
+
# Append the JSON object to the list
|
116 |
+
json_data.append(json_obj)
|
117 |
+
|
118 |
+
# Write the JSON data to a file
|
119 |
+
with open('output_evaluator_data.json', 'w') as json_file:
|
120 |
+
json.dump(json_data, json_file, indent=2)
|
121 |
+
|
122 |
+
if __name__ == "__main__":
|
123 |
+
parser = argparse.ArgumentParser()
|
124 |
+
parser.add_argument("--max-num-files", type=int)
|
125 |
+
args = parser.parse_args()
|
126 |
+
|
127 |
+
log_files = get_log_files(args.max_num_files)
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
inspect_convs(log_files)
|
132 |
+
|
133 |
+
|
134 |
+
|
constants.py
CHANGED
@@ -26,4 +26,4 @@ TEXT_PROMPT_PATH = "offline/prompts_110.json"
|
|
26 |
IMAGE_PROMPT_PATH = "offline/image_urls.txt"
|
27 |
|
28 |
MAX_ATTEMPTS = 5
|
29 |
-
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN", "
|
|
|
26 |
IMAGE_PROMPT_PATH = "offline/image_urls.txt"
|
27 |
|
28 |
MAX_ATTEMPTS = 5
|
29 |
+
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN", "yourKey")
|
model/model_worker.py
CHANGED
@@ -6,8 +6,10 @@ from typing import List
|
|
6 |
import replicate
|
7 |
import subprocess
|
8 |
|
9 |
-
from
|
10 |
-
|
|
|
|
|
11 |
|
12 |
class BaseModelWorker:
|
13 |
def __init__(self,
|
|
|
6 |
import replicate
|
7 |
import subprocess
|
8 |
|
9 |
+
from gradio_client import Client
|
10 |
+
from client import Gau2Mesh_client
|
11 |
+
from constants import OFFLINE_GIF_DIR, REPLICATE_API_TOKEN
|
12 |
+
# os.environ("REPLICATE_API_TOKEN", "yourKey")
|
13 |
|
14 |
class BaseModelWorker:
|
15 |
def __init__(self,
|
serve/inference.py
CHANGED
@@ -188,14 +188,17 @@ def generate_t2s(gen_func, render_func,
|
|
188 |
state.rgb_video = videos['rgb']
|
189 |
yield state, videos['normal'], videos['rgb']
|
190 |
|
|
|
191 |
# logger.info(f"===output===: {output}")
|
192 |
data = {
|
193 |
-
"
|
|
|
194 |
"model": model_name,
|
195 |
"type": "offline",
|
196 |
"gen_params": {},
|
197 |
"state": state.dict(),
|
198 |
"start": round(start_time, 4),
|
|
|
199 |
}
|
200 |
else:
|
201 |
start_time = time.time()
|
@@ -210,14 +213,17 @@ def generate_t2s(gen_func, render_func,
|
|
210 |
state.rgb_video = videos['rgb']
|
211 |
yield state, videos['normal'], videos['rgb']
|
212 |
|
|
|
213 |
# logger.info(f"===output===: {output}")
|
214 |
data = {
|
215 |
-
"
|
|
|
216 |
"model": model_name,
|
217 |
"type": "online",
|
218 |
"gen_params": {},
|
219 |
"state": state.dict(),
|
220 |
"start": round(start_time, 4),
|
|
|
221 |
"time": round(finish_time - start_time, 4),
|
222 |
"generate_time": round(generate_time, 4),
|
223 |
"render_time": round(render_time, 4),
|
@@ -277,22 +283,27 @@ def generate_t2s_multi(gen_func, render_func,
|
|
277 |
state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
|
278 |
yield state_0, state_1,videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
|
279 |
|
|
|
280 |
# logger.info(f"===output===: {output}")
|
281 |
data_0 = {
|
|
|
282 |
"ip": get_ip(request),
|
283 |
"model": model_name_0,
|
284 |
"type": "offline",
|
285 |
"gen_params": {},
|
286 |
"state": state_0.dict(),
|
287 |
"start": round(start_time, 4),
|
|
|
288 |
}
|
289 |
data_1 = {
|
|
|
290 |
"ip": get_ip(request),
|
291 |
"model": model_name_1,
|
292 |
"type": "offline",
|
293 |
"gen_params": {},
|
294 |
"state": state_1.dict(),
|
295 |
"start": round(start_time, 4),
|
|
|
296 |
}
|
297 |
else:
|
298 |
start_time = time.time()
|
@@ -307,25 +318,30 @@ def generate_t2s_multi(gen_func, render_func,
|
|
307 |
state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
|
308 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
|
309 |
|
|
|
310 |
# logger.info(f"===output===: {output}")
|
311 |
data_0 = {
|
|
|
312 |
"ip": get_ip(request),
|
313 |
"model": model_name_0,
|
314 |
"type": "online",
|
315 |
"gen_params": {},
|
316 |
"state": state_0.dict(),
|
317 |
"start": round(start_time, 4),
|
|
|
318 |
"time": round(finish_time - start_time, 4),
|
319 |
"generate_time": round(generate_time, 4),
|
320 |
"render_time": round(render_time, 4),
|
321 |
}
|
322 |
data_1 = {
|
|
|
323 |
"ip": get_ip(request),
|
324 |
"model": model_name_1,
|
325 |
"type": "online",
|
326 |
"gen_params": {},
|
327 |
"state": state_1.dict(),
|
328 |
"start": round(start_time, 4),
|
|
|
329 |
"time": round(finish_time - start_time, 4),
|
330 |
"generate_time": round(generate_time, 4),
|
331 |
"render_time": round(render_time, 4),
|
@@ -386,22 +402,27 @@ def generate_t2s_multi_annoy(gen_func, render_func,
|
|
386 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
387 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
388 |
|
|
|
389 |
# logger.info(f"===output===: {output}")
|
390 |
data_0 = {
|
|
|
391 |
"ip": get_ip(request),
|
392 |
"model": model_name_0,
|
393 |
"type": "offline",
|
394 |
"gen_params": {},
|
395 |
"state": state_0.dict(),
|
396 |
"start": round(start_time, 4),
|
|
|
397 |
}
|
398 |
data_1 = {
|
|
|
399 |
"ip": get_ip(request),
|
400 |
"model": model_name_1,
|
401 |
"type": "offline",
|
402 |
"gen_params": {},
|
403 |
"state": state_1.dict(),
|
404 |
"start": round(start_time, 4),
|
|
|
405 |
}
|
406 |
else:
|
407 |
start_time = time.time()
|
@@ -418,25 +439,30 @@ def generate_t2s_multi_annoy(gen_func, render_func,
|
|
418 |
yield state_0, state_1, videos_0[0], videos_0[1], videos_1[0], videos_1[1], \
|
419 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
420 |
|
|
|
421 |
# logger.info(f"===output===: {output}")
|
422 |
data_0 = {
|
|
|
423 |
"ip": get_ip(request),
|
424 |
"model": model_name_0,
|
425 |
"type": "online",
|
426 |
"gen_params": {},
|
427 |
"state": state_0.dict(),
|
428 |
"start": round(start_time, 4),
|
|
|
429 |
"time": round(finish_time - start_time, 4),
|
430 |
"generate_time": round(generate_time, 4),
|
431 |
"render_time": round(render_time, 4),
|
432 |
}
|
433 |
data_1 = {
|
|
|
434 |
"ip": get_ip(request),
|
435 |
"model": model_name_1,
|
436 |
"type": "online",
|
437 |
"gen_params": {},
|
438 |
"state": state_1.dict(),
|
439 |
"start": round(start_time, 4),
|
|
|
440 |
"time": round(finish_time - start_time, 4),
|
441 |
"generate_time": round(generate_time, 4),
|
442 |
"render_time": round(render_time, 4),
|
@@ -481,14 +507,17 @@ def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Re
|
|
481 |
state.rgb_video = videos['rgb']
|
482 |
yield state, videos['normal'], videos['rgb']
|
483 |
|
|
|
484 |
# logger.info(f"===output===: {output}")
|
485 |
data = {
|
486 |
-
"
|
|
|
487 |
"model": model_name,
|
488 |
"type": "offline",
|
489 |
"gen_params": {},
|
490 |
"state": state.dict(),
|
491 |
"start": round(start_time, 4),
|
|
|
492 |
}
|
493 |
else:
|
494 |
start_time = time.time()
|
@@ -503,14 +532,17 @@ def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Re
|
|
503 |
state.rgb_video = videos['rgb']
|
504 |
yield state, videos['normal'], videos['rgb']
|
505 |
|
|
|
506 |
# logger.info(f"===output===: {output}")
|
507 |
data = {
|
508 |
-
"
|
|
|
509 |
"model": model_name,
|
510 |
"type": "online",
|
511 |
"gen_params": {},
|
512 |
"state": state.dict(),
|
513 |
"start": round(start_time, 4),
|
|
|
514 |
"time": round(finish_time - start_time, 4),
|
515 |
"generate_time": round(generate_time, 4),
|
516 |
"render_time": round(render_time, 4),
|
@@ -567,22 +599,27 @@ def generate_i2s_multi(gen_func, render_func,
|
|
567 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
568 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
569 |
|
|
|
570 |
# logger.info(f"===output===: {output}")
|
571 |
data_0 = {
|
|
|
572 |
"ip": get_ip(request),
|
573 |
"model": model_name_0,
|
574 |
"type": "offline",
|
575 |
"gen_params": {},
|
576 |
"state": state_0.dict(),
|
577 |
"start": round(start_time, 4),
|
|
|
578 |
}
|
579 |
data_1 = {
|
|
|
580 |
"ip": get_ip(request),
|
581 |
"model": model_name_1,
|
582 |
"type": "offline",
|
583 |
"gen_params": {},
|
584 |
"state": state_1.dict(),
|
585 |
"start": round(start_time, 4),
|
|
|
586 |
}
|
587 |
else:
|
588 |
start_time = time.time()
|
@@ -597,25 +634,30 @@ def generate_i2s_multi(gen_func, render_func,
|
|
597 |
state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
|
598 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
|
599 |
|
|
|
600 |
# logger.info(f"===output===: {output}")
|
601 |
data_0 = {
|
|
|
602 |
"ip": get_ip(request),
|
603 |
"model": model_name_0,
|
604 |
"type": "online",
|
605 |
"gen_params": {},
|
606 |
"state": state_0.dict(),
|
607 |
"start": round(start_time, 4),
|
|
|
608 |
"time": round(finish_time - start_time, 4),
|
609 |
"generate_time": round(generate_time, 4),
|
610 |
"render_time": round(render_time, 4),
|
611 |
}
|
612 |
data_1 = {
|
|
|
613 |
"ip": get_ip(request),
|
614 |
"model": model_name_1,
|
615 |
"type": "online",
|
616 |
"gen_params": {},
|
617 |
"state": state_1.dict(),
|
618 |
"start": round(start_time, 4),
|
|
|
619 |
"time": round(finish_time - start_time, 4),
|
620 |
"generate_time": round(generate_time, 4),
|
621 |
"render_time": round(render_time, 4),
|
@@ -672,26 +714,31 @@ def generate_i2s_multi_annoy(gen_func, render_func,
|
|
672 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
673 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
674 |
|
|
|
675 |
# logger.info(f"===output===: {output}")
|
676 |
data_0 = {
|
|
|
677 |
"ip": get_ip(request),
|
678 |
"model": model_name_0,
|
679 |
"type": "offline",
|
680 |
"gen_params": {},
|
681 |
"state": state_0.dict(),
|
682 |
"start": round(start_time, 4),
|
|
|
683 |
}
|
684 |
data_1 = {
|
|
|
685 |
"ip": get_ip(request),
|
686 |
"model": model_name_1,
|
687 |
"type": "offline",
|
688 |
"gen_params": {},
|
689 |
"state": state_1.dict(),
|
690 |
"start": round(start_time, 4),
|
|
|
691 |
}
|
692 |
else:
|
693 |
start_time = time.time()
|
694 |
-
shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
|
695 |
generate_time = time.time() - start_time
|
696 |
|
697 |
videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
|
@@ -703,25 +750,30 @@ def generate_i2s_multi_annoy(gen_func, render_func,
|
|
703 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
704 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
705 |
|
|
|
706 |
# logger.info(f"===output===: {output}")
|
707 |
data_0 = {
|
|
|
708 |
"ip": get_ip(request),
|
709 |
"model": model_name_0,
|
710 |
"type": "online",
|
711 |
"gen_params": {},
|
712 |
"state": state_0.dict(),
|
713 |
"start": round(start_time, 4),
|
|
|
714 |
"time": round(finish_time - start_time, 4),
|
715 |
"generate_time": round(generate_time, 4),
|
716 |
"render_time": round(render_time, 4),
|
717 |
}
|
718 |
data_1 = {
|
|
|
719 |
"ip": get_ip(request),
|
720 |
"model": model_name_1,
|
721 |
"type": "online",
|
722 |
"gen_params": {},
|
723 |
"state": state_1.dict(),
|
724 |
"start": round(start_time, 4),
|
|
|
725 |
"time": round(finish_time - start_time, 4),
|
726 |
"generate_time": round(generate_time, 4),
|
727 |
"render_time": round(render_time, 4),
|
|
|
188 |
state.rgb_video = videos['rgb']
|
189 |
yield state, videos['normal'], videos['rgb']
|
190 |
|
191 |
+
finish_tstamp = time.time()
|
192 |
# logger.info(f"===output===: {output}")
|
193 |
data = {
|
194 |
+
"tstamp": round(finish_tstamp, 4),
|
195 |
+
"ip": get_ip(request),
|
196 |
"model": model_name,
|
197 |
"type": "offline",
|
198 |
"gen_params": {},
|
199 |
"state": state.dict(),
|
200 |
"start": round(start_time, 4),
|
201 |
+
"finish": round(finish_tstamp, 4),
|
202 |
}
|
203 |
else:
|
204 |
start_time = time.time()
|
|
|
213 |
state.rgb_video = videos['rgb']
|
214 |
yield state, videos['normal'], videos['rgb']
|
215 |
|
216 |
+
finish_tstamp = time.time()
|
217 |
# logger.info(f"===output===: {output}")
|
218 |
data = {
|
219 |
+
"tstamp": round(finish_tstamp, 4),
|
220 |
+
"ip": get_ip(request),
|
221 |
"model": model_name,
|
222 |
"type": "online",
|
223 |
"gen_params": {},
|
224 |
"state": state.dict(),
|
225 |
"start": round(start_time, 4),
|
226 |
+
"finish": round(finish_tstamp, 4),
|
227 |
"time": round(finish_time - start_time, 4),
|
228 |
"generate_time": round(generate_time, 4),
|
229 |
"render_time": round(render_time, 4),
|
|
|
283 |
state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
|
284 |
yield state_0, state_1,videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
|
285 |
|
286 |
+
finish_tstamp = time.time()
|
287 |
# logger.info(f"===output===: {output}")
|
288 |
data_0 = {
|
289 |
+
"tstamp": round(finish_tstamp, 4),
|
290 |
"ip": get_ip(request),
|
291 |
"model": model_name_0,
|
292 |
"type": "offline",
|
293 |
"gen_params": {},
|
294 |
"state": state_0.dict(),
|
295 |
"start": round(start_time, 4),
|
296 |
+
"finish": round(finish_tstamp, 4),
|
297 |
}
|
298 |
data_1 = {
|
299 |
+
"tstamp": round(finish_tstamp, 4),
|
300 |
"ip": get_ip(request),
|
301 |
"model": model_name_1,
|
302 |
"type": "offline",
|
303 |
"gen_params": {},
|
304 |
"state": state_1.dict(),
|
305 |
"start": round(start_time, 4),
|
306 |
+
"finish": round(finish_tstamp, 4),
|
307 |
}
|
308 |
else:
|
309 |
start_time = time.time()
|
|
|
318 |
state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
|
319 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
|
320 |
|
321 |
+
finish_tstamp = time.time()
|
322 |
# logger.info(f"===output===: {output}")
|
323 |
data_0 = {
|
324 |
+
"tstamp": round(finish_tstamp, 4),
|
325 |
"ip": get_ip(request),
|
326 |
"model": model_name_0,
|
327 |
"type": "online",
|
328 |
"gen_params": {},
|
329 |
"state": state_0.dict(),
|
330 |
"start": round(start_time, 4),
|
331 |
+
"finish": round(finish_tstamp, 4),
|
332 |
"time": round(finish_time - start_time, 4),
|
333 |
"generate_time": round(generate_time, 4),
|
334 |
"render_time": round(render_time, 4),
|
335 |
}
|
336 |
data_1 = {
|
337 |
+
"tstamp": round(finish_tstamp, 4),
|
338 |
"ip": get_ip(request),
|
339 |
"model": model_name_1,
|
340 |
"type": "online",
|
341 |
"gen_params": {},
|
342 |
"state": state_1.dict(),
|
343 |
"start": round(start_time, 4),
|
344 |
+
"finish": round(finish_tstamp, 4),
|
345 |
"time": round(finish_time - start_time, 4),
|
346 |
"generate_time": round(generate_time, 4),
|
347 |
"render_time": round(render_time, 4),
|
|
|
402 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
403 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
404 |
|
405 |
+
finish_tstamp = time.time()
|
406 |
# logger.info(f"===output===: {output}")
|
407 |
data_0 = {
|
408 |
+
"tstamp": round(finish_tstamp, 4),
|
409 |
"ip": get_ip(request),
|
410 |
"model": model_name_0,
|
411 |
"type": "offline",
|
412 |
"gen_params": {},
|
413 |
"state": state_0.dict(),
|
414 |
"start": round(start_time, 4),
|
415 |
+
"finish": round(finish_tstamp, 4),
|
416 |
}
|
417 |
data_1 = {
|
418 |
+
"tstamp": round(finish_tstamp, 4),
|
419 |
"ip": get_ip(request),
|
420 |
"model": model_name_1,
|
421 |
"type": "offline",
|
422 |
"gen_params": {},
|
423 |
"state": state_1.dict(),
|
424 |
"start": round(start_time, 4),
|
425 |
+
"finish": round(finish_tstamp, 4),
|
426 |
}
|
427 |
else:
|
428 |
start_time = time.time()
|
|
|
439 |
yield state_0, state_1, videos_0[0], videos_0[1], videos_1[0], videos_1[1], \
|
440 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
441 |
|
442 |
+
finish_tstamp = time.time()
|
443 |
# logger.info(f"===output===: {output}")
|
444 |
data_0 = {
|
445 |
+
"tstamp": round(finish_tstamp, 4),
|
446 |
"ip": get_ip(request),
|
447 |
"model": model_name_0,
|
448 |
"type": "online",
|
449 |
"gen_params": {},
|
450 |
"state": state_0.dict(),
|
451 |
"start": round(start_time, 4),
|
452 |
+
"finish": round(finish_tstamp, 4),
|
453 |
"time": round(finish_time - start_time, 4),
|
454 |
"generate_time": round(generate_time, 4),
|
455 |
"render_time": round(render_time, 4),
|
456 |
}
|
457 |
data_1 = {
|
458 |
+
"tstamp": round(finish_tstamp, 4),
|
459 |
"ip": get_ip(request),
|
460 |
"model": model_name_1,
|
461 |
"type": "online",
|
462 |
"gen_params": {},
|
463 |
"state": state_1.dict(),
|
464 |
"start": round(start_time, 4),
|
465 |
+
"finish": round(finish_tstamp, 4),
|
466 |
"time": round(finish_time - start_time, 4),
|
467 |
"generate_time": round(generate_time, 4),
|
468 |
"render_time": round(render_time, 4),
|
|
|
507 |
state.rgb_video = videos['rgb']
|
508 |
yield state, videos['normal'], videos['rgb']
|
509 |
|
510 |
+
finish_tstamp = time.time()
|
511 |
# logger.info(f"===output===: {output}")
|
512 |
data = {
|
513 |
+
"tstamp": round(finish_tstamp, 4),
|
514 |
+
"ip": get_ip(request),
|
515 |
"model": model_name,
|
516 |
"type": "offline",
|
517 |
"gen_params": {},
|
518 |
"state": state.dict(),
|
519 |
"start": round(start_time, 4),
|
520 |
+
"finish": round(finish_tstamp, 4),
|
521 |
}
|
522 |
else:
|
523 |
start_time = time.time()
|
|
|
532 |
state.rgb_video = videos['rgb']
|
533 |
yield state, videos['normal'], videos['rgb']
|
534 |
|
535 |
+
finish_tstamp = time.time()
|
536 |
# logger.info(f"===output===: {output}")
|
537 |
data = {
|
538 |
+
"tstamp": round(finish_tstamp, 4),
|
539 |
+
"ip": get_ip(request),
|
540 |
"model": model_name,
|
541 |
"type": "online",
|
542 |
"gen_params": {},
|
543 |
"state": state.dict(),
|
544 |
"start": round(start_time, 4),
|
545 |
+
"finish": round(finish_tstamp, 4),
|
546 |
"time": round(finish_time - start_time, 4),
|
547 |
"generate_time": round(generate_time, 4),
|
548 |
"render_time": round(render_time, 4),
|
|
|
599 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
600 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
601 |
|
602 |
+
finish_tstamp = time.time()
|
603 |
# logger.info(f"===output===: {output}")
|
604 |
data_0 = {
|
605 |
+
"tstamp": round(finish_tstamp, 4),
|
606 |
"ip": get_ip(request),
|
607 |
"model": model_name_0,
|
608 |
"type": "offline",
|
609 |
"gen_params": {},
|
610 |
"state": state_0.dict(),
|
611 |
"start": round(start_time, 4),
|
612 |
+
"finish": round(finish_tstamp, 4),
|
613 |
}
|
614 |
data_1 = {
|
615 |
+
"tstamp": round(finish_tstamp, 4),
|
616 |
"ip": get_ip(request),
|
617 |
"model": model_name_1,
|
618 |
"type": "offline",
|
619 |
"gen_params": {},
|
620 |
"state": state_1.dict(),
|
621 |
"start": round(start_time, 4),
|
622 |
+
"finish": round(finish_tstamp, 4),
|
623 |
}
|
624 |
else:
|
625 |
start_time = time.time()
|
|
|
634 |
state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
|
635 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']
|
636 |
|
637 |
+
finish_tstamp = time.time()
|
638 |
# logger.info(f"===output===: {output}")
|
639 |
data_0 = {
|
640 |
+
"tstamp": round(finish_tstamp, 4),
|
641 |
"ip": get_ip(request),
|
642 |
"model": model_name_0,
|
643 |
"type": "online",
|
644 |
"gen_params": {},
|
645 |
"state": state_0.dict(),
|
646 |
"start": round(start_time, 4),
|
647 |
+
"finish": round(finish_tstamp, 4),
|
648 |
"time": round(finish_time - start_time, 4),
|
649 |
"generate_time": round(generate_time, 4),
|
650 |
"render_time": round(render_time, 4),
|
651 |
}
|
652 |
data_1 = {
|
653 |
+
"tstamp": round(finish_tstamp, 4),
|
654 |
"ip": get_ip(request),
|
655 |
"model": model_name_1,
|
656 |
"type": "online",
|
657 |
"gen_params": {},
|
658 |
"state": state_1.dict(),
|
659 |
"start": round(start_time, 4),
|
660 |
+
"finish": round(finish_tstamp, 4),
|
661 |
"time": round(finish_time - start_time, 4),
|
662 |
"generate_time": round(generate_time, 4),
|
663 |
"render_time": round(render_time, 4),
|
|
|
714 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
715 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
716 |
|
717 |
+
finish_tstamp = time.time()
|
718 |
# logger.info(f"===output===: {output}")
|
719 |
data_0 = {
|
720 |
+
"tstamp": round(finish_tstamp, 4),
|
721 |
"ip": get_ip(request),
|
722 |
"model": model_name_0,
|
723 |
"type": "offline",
|
724 |
"gen_params": {},
|
725 |
"state": state_0.dict(),
|
726 |
"start": round(start_time, 4),
|
727 |
+
"finish": round(finish_tstamp, 4),
|
728 |
}
|
729 |
data_1 = {
|
730 |
+
"tstamp": round(finish_tstamp, 4),
|
731 |
"ip": get_ip(request),
|
732 |
"model": model_name_1,
|
733 |
"type": "offline",
|
734 |
"gen_params": {},
|
735 |
"state": state_1.dict(),
|
736 |
"start": round(start_time, 4),
|
737 |
+
"finish": round(finish_tstamp, 4),
|
738 |
}
|
739 |
else:
|
740 |
start_time = time.time()
|
741 |
+
shape_0, shape_1 = gen_func(image, model_name_0, model_name_1, i2s_model=True)
|
742 |
generate_time = time.time() - start_time
|
743 |
|
744 |
videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
|
|
|
750 |
yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
|
751 |
gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")
|
752 |
|
753 |
+
finish_tstamp = time.time()
|
754 |
# logger.info(f"===output===: {output}")
|
755 |
data_0 = {
|
756 |
+
"tstamp": round(finish_tstamp, 4),
|
757 |
"ip": get_ip(request),
|
758 |
"model": model_name_0,
|
759 |
"type": "online",
|
760 |
"gen_params": {},
|
761 |
"state": state_0.dict(),
|
762 |
"start": round(start_time, 4),
|
763 |
+
"finish": round(finish_tstamp, 4),
|
764 |
"time": round(finish_time - start_time, 4),
|
765 |
"generate_time": round(generate_time, 4),
|
766 |
"render_time": round(render_time, 4),
|
767 |
}
|
768 |
data_1 = {
|
769 |
+
"tstamp": round(finish_tstamp, 4),
|
770 |
"ip": get_ip(request),
|
771 |
"model": model_name_1,
|
772 |
"type": "online",
|
773 |
"gen_params": {},
|
774 |
"state": state_1.dict(),
|
775 |
"start": round(start_time, 4),
|
776 |
+
"finish": round(finish_tstamp, 4),
|
777 |
"time": round(finish_time - start_time, 4),
|
778 |
"generate_time": round(generate_time, 4),
|
779 |
"render_time": round(render_time, 4),
|
serve/leaderboard.py
CHANGED
@@ -39,8 +39,6 @@ leader_component_values = [None] * 5
|
|
39 |
def make_leaderboard_md(elo_results):
|
40 |
leaderboard_md = f"""
|
41 |
# 🏆 GenAI-Arena Leaderboard
|
42 |
-
| [GitHub](https://github.com/TIGER-AI-Lab/ImagenHub) | [Dataset](https://huggingface.co/ImagenHub) | [Twitter](https://twitter.com/TianleLI123/status/1757245259149422752) |
|
43 |
-
|
44 |
"""
|
45 |
return leaderboard_md
|
46 |
|
|
|
39 |
def make_leaderboard_md(elo_results):
|
40 |
leaderboard_md = f"""
|
41 |
# 🏆 GenAI-Arena Leaderboard
|
|
|
|
|
42 |
"""
|
43 |
return leaderboard_md
|
44 |
|
serve/vote_utils.py
CHANGED
@@ -26,12 +26,13 @@ def vote_last_response_t2s(state, dim, vote_type, model_selector, request: gr.Re
|
|
26 |
fout.write(json.dumps(data) + "\n")
|
27 |
append_json_item_on_log_server(data, get_conv_log_filename())
|
28 |
|
29 |
-
def vote_last_response_t2s_multi(states, dim, vote_type, model_selectors, request: gr.Request):
|
30 |
with open(get_conv_log_filename(), "a") as fout:
|
31 |
data = {
|
32 |
"tstamp": round(time.time(), 4),
|
33 |
"dim": dim,
|
34 |
"type": vote_type,
|
|
|
35 |
"models": [x for x in model_selectors],
|
36 |
"states": [x.dict() for x in states],
|
37 |
"ip": get_ip(request),
|
@@ -65,12 +66,13 @@ def vote_last_response_i2s(state, dim, vote_type, model_selector, request: gr.Re
|
|
65 |
# save_image_file_on_log_server(output_file)
|
66 |
# save_image_file_on_log_server(source_file)
|
67 |
|
68 |
-
def vote_last_response_i2s_multi(states, dim, vote_type, model_selectors, request: gr.Request):
|
69 |
with open(get_conv_log_filename(), "a") as fout:
|
70 |
data = {
|
71 |
"tstamp": round(time.time(), 4),
|
72 |
"dim": dim,
|
73 |
"type": vote_type,
|
|
|
74 |
"models": [x for x in model_selectors],
|
75 |
"states": [x.dict() for x in states],
|
76 |
"ip": get_ip(request),
|
@@ -91,23 +93,23 @@ def vote_last_response_i2s_multi(states, dim, vote_type, model_selectors, reques
|
|
91 |
## Text-to-Shape Generation (t2s) Single Model Direct Chat
|
92 |
def upvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
|
93 |
ip = get_ip(request)
|
|
|
94 |
t2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
|
95 |
vote_last_response_t2s(state, dim_md, "upvote", model_selector, request)
|
96 |
-
state.evaluted_dims += 1
|
97 |
return (state,) + (disable_btn,) * 3
|
98 |
|
99 |
def downvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
|
100 |
ip = get_ip(request)
|
|
|
101 |
t2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
|
102 |
vote_last_response_t2s(state, dim_md, "downvote", model_selector, request)
|
103 |
-
state.evaluted_dims += 1
|
104 |
return (state,) + (disable_btn,) * 3
|
105 |
|
106 |
def flag_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
|
107 |
ip = get_ip(request)
|
|
|
108 |
t2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
|
109 |
vote_last_response_t2s(state, dim_md, "flag", model_selector, request)
|
110 |
-
state.evaluted_dims += 1
|
111 |
return (state,) + (disable_btn,) * 3
|
112 |
|
113 |
|
@@ -115,132 +117,153 @@ def flag_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
|
|
115 |
def leftvote_last_response_t2s_named(
|
116 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
117 |
):
|
|
|
|
|
118 |
t2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
|
119 |
vote_last_response_t2s_multi(
|
120 |
-
[state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
|
121 |
)
|
122 |
-
state0.evaluted_dims += 1
|
123 |
-
state1.evaluted_dims += 1
|
124 |
return (state0, state1) + (disable_btn,) * 4
|
125 |
|
126 |
def rightvote_last_response_t2s_named(
|
127 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
128 |
):
|
|
|
|
|
129 |
t2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
|
130 |
vote_last_response_t2s_multi(
|
131 |
-
[state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
|
132 |
)
|
133 |
-
state0.evaluted_dims += 1
|
134 |
-
state1.evaluted_dims += 1
|
135 |
return (state0, state1) + (disable_btn,) * 4
|
136 |
|
137 |
def tievote_last_response_t2s_named(
|
138 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
139 |
):
|
|
|
|
|
140 |
t2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
|
141 |
vote_last_response_t2s_multi(
|
142 |
-
[state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
|
143 |
)
|
144 |
-
state0.evaluted_dims += 1
|
145 |
-
state1.evaluted_dims += 1
|
146 |
return (state0, state1) + (disable_btn,) * 4
|
147 |
|
148 |
def bothbad_vote_last_response_t2s_named(
|
149 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
150 |
):
|
|
|
|
|
151 |
t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
|
152 |
vote_last_response_t2s_multi(
|
153 |
-
[state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
|
154 |
)
|
155 |
-
state0.evaluted_dims += 1
|
156 |
-
state1.evaluted_dims += 1
|
157 |
return (state0, state1) + (disable_btn,) * 4
|
158 |
|
159 |
|
160 |
def leftvote_last_response_t2s_anony(
|
161 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
162 |
):
|
|
|
|
|
163 |
t2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
164 |
vote_last_response_t2s_multi(
|
165 |
-
[state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
|
166 |
)
|
167 |
|
168 |
-
state0.evaluted_dims
|
169 |
-
state1.
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
175 |
|
176 |
def rightvote_last_response_t2s_anony(
|
177 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
178 |
):
|
|
|
|
|
179 |
t2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
180 |
vote_last_response_t2s_multi(
|
181 |
-
[state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
|
182 |
)
|
183 |
|
184 |
-
state0.evaluted_dims
|
185 |
-
state1.
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
191 |
|
192 |
def tievote_last_response_t2s_anony(
|
193 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
194 |
):
|
|
|
|
|
195 |
t2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
|
196 |
vote_last_response_t2s_multi(
|
197 |
-
[state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
|
198 |
)
|
199 |
|
200 |
-
state0.evaluted_dims
|
201 |
-
state1.
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
207 |
|
208 |
def bothbad_vote_last_response_t2s_anony(
|
209 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
210 |
):
|
|
|
|
|
211 |
t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
|
212 |
vote_last_response_t2s_multi(
|
213 |
-
[state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
|
214 |
)
|
215 |
|
216 |
-
state0.evaluted_dims
|
217 |
-
state1.
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
|
|
223 |
|
224 |
## Image-to-Shape (i2s) Single Model Direct Chat
|
225 |
def upvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
|
226 |
ip = get_ip(request)
|
|
|
227 |
i2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
|
228 |
vote_last_response_i2s(state, dim_md, "upvote", model_selector, request)
|
229 |
-
state.evaluted_dims += 1
|
230 |
return (state,) + (disable_btn,) * 3
|
231 |
|
232 |
def downvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
|
233 |
ip = get_ip(request)
|
|
|
234 |
i2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
|
235 |
vote_last_response_i2s(state, dim_md, "downvote", model_selector, request)
|
236 |
-
state.evaluted_dims += 1
|
237 |
return (state,) + (disable_btn,) * 3
|
238 |
|
239 |
def flag_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
|
240 |
ip = get_ip(request)
|
|
|
241 |
i2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
|
242 |
vote_last_response_i2s(state, dim_md, "flag", model_selector, request)
|
243 |
-
state.evaluted_dims += 1
|
244 |
return (state,) + (disable_btn,) * 3
|
245 |
|
246 |
|
@@ -248,114 +271,134 @@ def flag_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
|
|
248 |
def leftvote_last_response_i2s_named(
|
249 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
250 |
):
|
|
|
|
|
251 |
i2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
|
252 |
vote_last_response_i2s_multi(
|
253 |
-
[state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
|
254 |
)
|
255 |
-
state0.evaluted_dims += 1
|
256 |
-
state1.evaluted_dims += 1
|
257 |
return (state0, state1) + (disable_btn,) * 4
|
258 |
|
259 |
def rightvote_last_response_i2s_named(
|
260 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
261 |
):
|
|
|
|
|
262 |
i2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
|
263 |
vote_last_response_i2s_multi(
|
264 |
-
[state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
|
265 |
)
|
266 |
-
state0.evaluted_dims += 1
|
267 |
-
state1.evaluted_dims += 1
|
268 |
return (state0, state1) + (disable_btn,) * 4
|
269 |
|
270 |
def tievote_last_response_i2s_named(
|
271 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
272 |
):
|
|
|
|
|
273 |
i2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
|
274 |
vote_last_response_i2s_multi(
|
275 |
-
[state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
|
276 |
)
|
277 |
-
state0.evaluted_dims += 1
|
278 |
-
state1.evaluted_dims += 1
|
279 |
return (state0, state1) + (disable_btn,) * 4
|
280 |
|
281 |
def bothbad_vote_last_response_i2s_named(
|
282 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
283 |
):
|
|
|
|
|
284 |
i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
|
285 |
vote_last_response_i2s_multi(
|
286 |
-
[state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
|
287 |
)
|
288 |
-
state0.evaluted_dims += 1
|
289 |
-
state1.evaluted_dims += 1
|
290 |
return (state0, state1) + (disable_btn,) * 4
|
291 |
|
292 |
|
293 |
def leftvote_last_response_i2s_anony(
|
294 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
295 |
):
|
|
|
|
|
296 |
i2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
297 |
vote_last_response_i2s_multi(
|
298 |
-
[state0, state1], dim_md, "leftvote", [model_selector0, model_selector1], request
|
299 |
)
|
300 |
|
301 |
-
state0.evaluted_dims
|
302 |
-
state1.
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
|
|
|
|
|
|
308 |
|
309 |
|
310 |
def rightvote_last_response_i2s_anony(
|
311 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
312 |
):
|
|
|
|
|
313 |
i2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
314 |
vote_last_response_i2s_multi(
|
315 |
-
[state0, state1], dim_md, "rightvote", [model_selector0, model_selector1], request
|
316 |
)
|
317 |
|
318 |
-
state0.evaluted_dims
|
319 |
-
state1.
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
|
|
|
|
|
|
325 |
|
326 |
|
327 |
def tievote_last_response_i2s_anony(
|
328 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
329 |
):
|
|
|
|
|
330 |
i2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
|
331 |
vote_last_response_i2s_multi(
|
332 |
-
[state0, state1], dim_md, "tievote", [model_selector0, model_selector1], request
|
333 |
)
|
334 |
|
335 |
-
state0.evaluted_dims
|
336 |
-
state1.
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
342 |
|
343 |
|
344 |
def bothbad_vote_last_response_i2s_anony(
|
345 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
346 |
):
|
|
|
|
|
347 |
i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
|
348 |
vote_last_response_i2s_multi(
|
349 |
-
[state0, state1], dim_md, "bothbad_vote", [model_selector0, model_selector1], request
|
350 |
)
|
351 |
|
352 |
-
state0.evaluted_dims
|
353 |
-
state1.
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
|
|
|
|
|
|
359 |
|
360 |
|
361 |
|
@@ -383,13 +426,13 @@ def share_click_t2s_multi(state0, state1, model_selector0, model_selector1, requ
|
|
383 |
t2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
|
384 |
if state0 is not None and state1 is not None:
|
385 |
vote_last_response_t2s_multi(
|
386 |
-
[state0, state1], "share", [model_selector0, model_selector1], request
|
387 |
)
|
388 |
|
389 |
def share_click_i2s_multi(state0, state1, model_selector0, model_selector1, request: gr.Request):
|
390 |
i2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
|
391 |
if state0 is not None and state1 is not None:
|
392 |
vote_last_response_i2s_multi(
|
393 |
-
[state0, state1], "share", [model_selector0, model_selector1], request
|
394 |
)
|
395 |
|
|
|
26 |
fout.write(json.dumps(data) + "\n")
|
27 |
append_json_item_on_log_server(data, get_conv_log_filename())
|
28 |
|
29 |
+
def vote_last_response_t2s_multi(states, dim, vote_type, is_anony: bool, model_selectors, request: gr.Request):
|
30 |
with open(get_conv_log_filename(), "a") as fout:
|
31 |
data = {
|
32 |
"tstamp": round(time.time(), 4),
|
33 |
"dim": dim,
|
34 |
"type": vote_type,
|
35 |
+
"anony": is_anony,
|
36 |
"models": [x for x in model_selectors],
|
37 |
"states": [x.dict() for x in states],
|
38 |
"ip": get_ip(request),
|
|
|
66 |
# save_image_file_on_log_server(output_file)
|
67 |
# save_image_file_on_log_server(source_file)
|
68 |
|
69 |
+
def vote_last_response_i2s_multi(states, dim, vote_type, is_anony, model_selectors, request: gr.Request):
|
70 |
with open(get_conv_log_filename(), "a") as fout:
|
71 |
data = {
|
72 |
"tstamp": round(time.time(), 4),
|
73 |
"dim": dim,
|
74 |
"type": vote_type,
|
75 |
+
"anony": is_anony,
|
76 |
"models": [x for x in model_selectors],
|
77 |
"states": [x.dict() for x in states],
|
78 |
"ip": get_ip(request),
|
|
|
93 |
## Text-to-Shape Generation (t2s) Single Model Direct Chat
|
94 |
def upvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
|
95 |
ip = get_ip(request)
|
96 |
+
state.evaluted_dims += 1
|
97 |
t2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
|
98 |
vote_last_response_t2s(state, dim_md, "upvote", model_selector, request)
|
|
|
99 |
return (state,) + (disable_btn,) * 3
|
100 |
|
101 |
def downvote_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
|
102 |
ip = get_ip(request)
|
103 |
+
state.evaluted_dims += 1
|
104 |
t2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
|
105 |
vote_last_response_t2s(state, dim_md, "downvote", model_selector, request)
|
|
|
106 |
return (state,) + (disable_btn,) * 3
|
107 |
|
108 |
def flag_last_response_t2s(state, model_selector, dim_md, request: gr.Request):
|
109 |
ip = get_ip(request)
|
110 |
+
state.evaluted_dims += 1
|
111 |
t2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
|
112 |
vote_last_response_t2s(state, dim_md, "flag", model_selector, request)
|
|
|
113 |
return (state,) + (disable_btn,) * 3
|
114 |
|
115 |
|
|
|
117 |
def leftvote_last_response_t2s_named(
|
118 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
119 |
):
|
120 |
+
state0.evaluted_dims += 1
|
121 |
+
state1.evaluted_dims += 1
|
122 |
t2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
|
123 |
vote_last_response_t2s_multi(
|
124 |
+
[state0, state1], dim_md, "leftvote", False, [model_selector0, model_selector1], request
|
125 |
)
|
|
|
|
|
126 |
return (state0, state1) + (disable_btn,) * 4
|
127 |
|
128 |
def rightvote_last_response_t2s_named(
|
129 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
130 |
):
|
131 |
+
state0.evaluted_dims += 1
|
132 |
+
state1.evaluted_dims += 1
|
133 |
t2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
|
134 |
vote_last_response_t2s_multi(
|
135 |
+
[state0, state1], dim_md, "rightvote", False, [model_selector0, model_selector1], request
|
136 |
)
|
|
|
|
|
137 |
return (state0, state1) + (disable_btn,) * 4
|
138 |
|
139 |
def tievote_last_response_t2s_named(
|
140 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
141 |
):
|
142 |
+
state0.evaluted_dims += 1
|
143 |
+
state1.evaluted_dims += 1
|
144 |
t2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
|
145 |
vote_last_response_t2s_multi(
|
146 |
+
[state0, state1], dim_md, "tievote", False, [model_selector0, model_selector1], request
|
147 |
)
|
|
|
|
|
148 |
return (state0, state1) + (disable_btn,) * 4
|
149 |
|
150 |
def bothbad_vote_last_response_t2s_named(
|
151 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
152 |
):
|
153 |
+
state0.evaluted_dims += 1
|
154 |
+
state1.evaluted_dims += 1
|
155 |
t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
|
156 |
vote_last_response_t2s_multi(
|
157 |
+
[state0, state1], dim_md, "bothbad_vote", False, [model_selector0, model_selector1], request
|
158 |
)
|
|
|
|
|
159 |
return (state0, state1) + (disable_btn,) * 4
|
160 |
|
161 |
|
162 |
def leftvote_last_response_t2s_anony(
|
163 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
164 |
):
|
165 |
+
state0.evaluted_dims += 1
|
166 |
+
state1.evaluted_dims += 1
|
167 |
t2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
168 |
vote_last_response_t2s_multi(
|
169 |
+
[state0, state1], dim_md, "leftvote", True, [model_selector0, model_selector1], request
|
170 |
)
|
171 |
|
172 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
173 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
174 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
175 |
+
# else:
|
176 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=False), gr.Markdown(f"### Model B: {state1.model_name}", visible=False))
|
177 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
178 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
179 |
+
return (state0, state1) \
|
180 |
+
+ (disable_btn,) * 4 \
|
181 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
182 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
183 |
|
184 |
def rightvote_last_response_t2s_anony(
|
185 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
186 |
):
|
187 |
+
state0.evaluted_dims += 1
|
188 |
+
state1.evaluted_dims += 1
|
189 |
t2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
190 |
vote_last_response_t2s_multi(
|
191 |
+
[state0, state1], dim_md, "rightvote", True, [model_selector0, model_selector1], request
|
192 |
)
|
193 |
|
194 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
195 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
196 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
197 |
+
# else:
|
198 |
+
# return (state0, state1) + (disable_btn,) * 4 + ("", "")
|
199 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
200 |
+
return (state0, state1) \
|
201 |
+
+ (disable_btn,) * 4 \
|
202 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
203 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
204 |
|
205 |
def tievote_last_response_t2s_anony(
|
206 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
207 |
):
|
208 |
+
state0.evaluted_dims += 1
|
209 |
+
state1.evaluted_dims += 1
|
210 |
t2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
|
211 |
vote_last_response_t2s_multi(
|
212 |
+
[state0, state1], dim_md, "tievote", True, [model_selector0, model_selector1], request
|
213 |
)
|
214 |
|
215 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
216 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
217 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
218 |
+
# else:
|
219 |
+
# return (state0, state1) + (disable_btn,) * 4 + ("", "")
|
220 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
221 |
+
return (state0, state1) \
|
222 |
+
+ (disable_btn,) * 4 \
|
223 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
224 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
225 |
|
226 |
def bothbad_vote_last_response_t2s_anony(
|
227 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
228 |
):
|
229 |
+
state0.evaluted_dims += 1
|
230 |
+
state1.evaluted_dims += 1
|
231 |
t2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
|
232 |
vote_last_response_t2s_multi(
|
233 |
+
[state0, state1], dim_md, "bothbad_vote", True, [model_selector0, model_selector1], request
|
234 |
)
|
235 |
|
236 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
237 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
238 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
239 |
+
# else:
|
240 |
+
# return (state0, state1) + (disable_btn,) * 4 + ("", "")
|
241 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
242 |
+
return (state0, state1) \
|
243 |
+
+ (disable_btn,) * 4 \
|
244 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
245 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
246 |
|
247 |
## Image-to-Shape (i2s) Single Model Direct Chat
|
248 |
def upvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
|
249 |
ip = get_ip(request)
|
250 |
+
state.evaluted_dims += 1
|
251 |
i2s_logger.info(f"upvote [{dim_md}]. ip: {ip}")
|
252 |
vote_last_response_i2s(state, dim_md, "upvote", model_selector, request)
|
|
|
253 |
return (state,) + (disable_btn,) * 3
|
254 |
|
255 |
def downvote_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
|
256 |
ip = get_ip(request)
|
257 |
+
state.evaluted_dims += 1
|
258 |
i2s_logger.info(f"downvote [{dim_md}]. ip: {ip}")
|
259 |
vote_last_response_i2s(state, dim_md, "downvote", model_selector, request)
|
|
|
260 |
return (state,) + (disable_btn,) * 3
|
261 |
|
262 |
def flag_last_response_i2s(state, model_selector, dim_md, request: gr.Request):
|
263 |
ip = get_ip(request)
|
264 |
+
state.evaluted_dims += 1
|
265 |
i2s_logger.info(f"flag [{dim_md}]. ip: {ip}")
|
266 |
vote_last_response_i2s(state, dim_md, "flag", model_selector, request)
|
|
|
267 |
return (state,) + (disable_btn,) * 3
|
268 |
|
269 |
|
|
|
271 |
def leftvote_last_response_i2s_named(
|
272 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
273 |
):
|
274 |
+
state0.evaluted_dims += 1
|
275 |
+
state1.evaluted_dims += 1
|
276 |
i2s_multi_logger.info(f"leftvote [{dim_md}] (named). ip: {get_ip(request)}")
|
277 |
vote_last_response_i2s_multi(
|
278 |
+
[state0, state1], dim_md, "leftvote", False, [model_selector0, model_selector1], request
|
279 |
)
|
|
|
|
|
280 |
return (state0, state1) + (disable_btn,) * 4
|
281 |
|
282 |
def rightvote_last_response_i2s_named(
|
283 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
284 |
):
|
285 |
+
state0.evaluted_dims += 1
|
286 |
+
state1.evaluted_dims += 1
|
287 |
i2s_multi_logger.info(f"rightvote [{dim_md}] (named). ip: {get_ip(request)}")
|
288 |
vote_last_response_i2s_multi(
|
289 |
+
[state0, state1], dim_md, "rightvote", False, [model_selector0, model_selector1], request
|
290 |
)
|
|
|
|
|
291 |
return (state0, state1) + (disable_btn,) * 4
|
292 |
|
293 |
def tievote_last_response_i2s_named(
|
294 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
295 |
):
|
296 |
+
state0.evaluted_dims += 1
|
297 |
+
state1.evaluted_dims += 1
|
298 |
i2s_multi_logger.info(f"tievote [{dim_md}] (named). ip: {get_ip(request)}")
|
299 |
vote_last_response_i2s_multi(
|
300 |
+
[state0, state1], dim_md, "tievote", False, [model_selector0, model_selector1], request
|
301 |
)
|
|
|
|
|
302 |
return (state0, state1) + (disable_btn,) * 4
|
303 |
|
304 |
def bothbad_vote_last_response_i2s_named(
|
305 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
306 |
):
|
307 |
+
state0.evaluted_dims += 1
|
308 |
+
state1.evaluted_dims += 1
|
309 |
i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (named). ip: {get_ip(request)}")
|
310 |
vote_last_response_i2s_multi(
|
311 |
+
[state0, state1], dim_md, "bothbad_vote", False, [model_selector0, model_selector1], request
|
312 |
)
|
|
|
|
|
313 |
return (state0, state1) + (disable_btn,) * 4
|
314 |
|
315 |
|
316 |
def leftvote_last_response_i2s_anony(
|
317 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
318 |
):
|
319 |
+
state0.evaluted_dims += 1
|
320 |
+
state1.evaluted_dims += 1
|
321 |
i2s_multi_logger.info(f"leftvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
322 |
vote_last_response_i2s_multi(
|
323 |
+
[state0, state1], dim_md, "leftvote", True, [model_selector0, model_selector1], request
|
324 |
)
|
325 |
|
326 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
327 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
328 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
329 |
+
# else:
|
330 |
+
# return (state0, state1) + (disable_btn,) * 4 + ("", "")
|
331 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
332 |
+
return (state0, state1) \
|
333 |
+
+ (disable_btn,) * 4 \
|
334 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
335 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
336 |
|
337 |
|
338 |
def rightvote_last_response_i2s_anony(
|
339 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
340 |
):
|
341 |
+
state0.evaluted_dims += 1
|
342 |
+
state1.evaluted_dims += 1
|
343 |
i2s_multi_logger.info(f"rightvote [{dim_md}] (anony). ip: {get_ip(request)}")
|
344 |
vote_last_response_i2s_multi(
|
345 |
+
[state0, state1], dim_md, "rightvote", True, [model_selector0, model_selector1], request
|
346 |
)
|
347 |
|
348 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
349 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
350 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
351 |
+
# else:
|
352 |
+
# return (state0, state1) + (disable_btn,) * 4 + ("", "")
|
353 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
354 |
+
return (state0, state1) \
|
355 |
+
+ (disable_btn,) * 4 \
|
356 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
357 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
358 |
|
359 |
|
360 |
def tievote_last_response_i2s_anony(
|
361 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
362 |
):
|
363 |
+
state0.evaluted_dims += 1
|
364 |
+
state1.evaluted_dims += 1
|
365 |
i2s_multi_logger.info(f"tievote [{dim_md}] (anony). ip: {get_ip(request)}")
|
366 |
vote_last_response_i2s_multi(
|
367 |
+
[state0, state1], dim_md, "tievote", True, [model_selector0, model_selector1], request
|
368 |
)
|
369 |
|
370 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
371 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
372 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
373 |
+
# else:
|
374 |
+
# return (state0, state1) + (disable_btn,) * 4 + ("", "")
|
375 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
376 |
+
return (state0, state1) \
|
377 |
+
+ (disable_btn,) * 4 \
|
378 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
379 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
380 |
|
381 |
|
382 |
def bothbad_vote_last_response_i2s_anony(
|
383 |
state0, state1, model_selector0, model_selector1, dim_md, request: gr.Request
|
384 |
):
|
385 |
+
state0.evaluted_dims += 1
|
386 |
+
state1.evaluted_dims += 1
|
387 |
i2s_multi_logger.info(f"bothbad_vote [{dim_md}] (anony). ip: {get_ip(request)}")
|
388 |
vote_last_response_i2s_multi(
|
389 |
+
[state0, state1], dim_md, "bothbad_vote", True, [model_selector0, model_selector1], request
|
390 |
)
|
391 |
|
392 |
+
# if state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS:
|
393 |
+
# names = (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
|
394 |
+
# return (state0, state1) + (disable_btn,) * 4 + names
|
395 |
+
# else:
|
396 |
+
# return (state0, state1) + (disable_btn,) * 4 + ("", "")
|
397 |
+
is_visible = (state0.evaluted_dims == state1.evaluted_dims == EVALUATE_DIMS)
|
398 |
+
return (state0, state1) \
|
399 |
+
+ (disable_btn,) * 4 \
|
400 |
+
+ (gr.Markdown(f"### Model A: {state0.model_name}", visible=is_visible),
|
401 |
+
gr.Markdown(f"### Model B: {state1.model_name}", visible=is_visible))
|
402 |
|
403 |
|
404 |
|
|
|
426 |
t2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
|
427 |
if state0 is not None and state1 is not None:
|
428 |
vote_last_response_t2s_multi(
|
429 |
+
[state0, state1], "", "share", True, [model_selector0, model_selector1], request
|
430 |
)
|
431 |
|
432 |
def share_click_i2s_multi(state0, state1, model_selector0, model_selector1, request: gr.Request):
|
433 |
i2s_multi_logger.info(f"share (anony). ip: {get_ip(request)}")
|
434 |
if state0 is not None and state1 is not None:
|
435 |
vote_last_response_i2s_multi(
|
436 |
+
[state0, state1], "", "share", True, [model_selector0, model_selector1], request
|
437 |
)
|
438 |
|