Spaces:
Running
Running
""" | |
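Clean 3D arena battle logs (text-to-shape / image-to-shape votes).
Usage (run as a module; the relative imports below need a package context):
python3 -m <package>.clean_battle_data --mode conv_release --task_name text2shape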
""" | |
import argparse | |
import datetime | |
import json | |
import os | |
import sys | |
from pytz import timezone | |
import time | |
import PIL | |
from PIL import ImageFile | |
ImageFile.LOAD_TRUNCATED_IMAGES = True | |
from tqdm import tqdm | |
from .basic_stats import get_log_files, NUM_SERVERS, LOG_ROOT_DIR | |
from .utils import detect_language, get_time_stamp_from_date | |
# Values of a log row's "type" field that represent a user vote; read_file keeps only these rows.
VOTES = [
    "tievote",
    "leftvote",
    "rightvote",
    "bothbad_vote",
]
def remove_html(raw):
    """Strip the display markup around a model name.

    ``"<h3>...: NAME</h3>\\n"`` yields the text between the first ``": "`` and the
    trailing ``"</h3>\\n"``; ``"### Model A: NAME"`` / ``"### Model B: NAME"`` yield
    ``NAME``; any other string is returned unchanged.
    """
    markdown_prefixes = ("### Model A: ", "### Model B: ")
    if raw.startswith("<h3>"):
        name_start = raw.find(": ") + 2
        name_end = -len("</h3>\n")
        return raw[name_start:name_end]
    if raw.startswith(markdown_prefixes):
        # Both prefixes have the same length (13 characters).
        return raw[len(markdown_prefixes[0]):]
    return raw
def to_openai_format(messages):
    """Convert ``[(role, text), ...]`` pairs into OpenAI-style chat messages.

    Roles are assigned by position (even index -> ``"user"``, odd index ->
    ``"assistant"``); the original role in each pair is ignored and only the
    second element (the text) is kept.
    """
    alternating_roles = ("user", "assistant")
    return [
        {"role": alternating_roles[position % 2], "content": message[1]}
        for position, message in enumerate(messages)
    ]
def replace_model_name(old_name, tstamp):
    """Map a logged model name to its canonical name.

    The ``-t``/``_t``/``-i``/``_i`` variants of ``point-e`` and ``shap-e`` are
    merged into the base name; any other name is returned unchanged.
    ``tstamp`` is accepted for interface compatibility but is not used.
    """
    aliases = {
        "point-e-t": "point-e",
        "point-e-i": "point-e",
        "point-e_t": "point-e",
        "point-e_i": "point-e",
        "shap-e-t": "shap-e",
        "shap-e-i": "shap-e",
        "shap-e_t": "shap-e",
        "shap-e_i": "shap-e",
    }
    return aliases.get(old_name, old_name)
def replace_dim(dim_name):
    """Normalize an evaluation-dimension label.

    A single trailing ``": "`` is removed, then legacy labels are renamed
    (``"Geometry Quality"`` -> ``"Geometry Details"``).
    """
    renamed = {"Geometry Quality": "Geometry Details"}
    label = dim_name[:-2] if dim_name.endswith(": ") else dim_name
    return renamed.get(label, label)
def read_file(filename):
    """Read one JSON-lines log file and return the rows that are votes.

    A row is kept when its ``"type"`` is in ``VOTES``. Blank lines are skipped.
    If the file does not exist, the read is retried up to 5 times, sleeping
    2 seconds between attempts; if it never appears an empty list is returned.
    """
    data = []
    attempts = 5
    for attempt in range(attempts):
        try:
            # `with` guarantees the handle is closed even if a line fails to parse.
            with open(filename, encoding="utf-8") as fin:
                for line in fin:
                    if not line.strip():
                        # A trailing/blank line would make json.loads raise.
                        continue
                    row = json.loads(line)
                    if row["type"] in VOTES:
                        data.append(row)
            break
        except FileNotFoundError:
            # FileNotFoundError can only come from open(), so `data` is still empty here.
            if attempt < attempts - 1:
                time.sleep(2)
    return data
def read_file_parallel(log_files, num_threads=16):
    """Run ``read_file`` over *log_files* in a process pool and concatenate the rows.

    Rows are returned in the order of *log_files* (``imap`` preserves input
    order); a tqdm progress bar tracks completed files.
    """
    from multiprocessing import Pool

    with Pool(num_threads) as pool:
        rows_per_file = list(
            tqdm(pool.imap(read_file, log_files), total=len(log_files))
        )
    return [row for file_rows in rows_per_file for row in file_rows]
def load_image(image_path):
    """Load the image at *image_path* fully into memory, or return ``None`` on failure.

    ``PIL.Image.open`` is lazy: it only reads the header and keeps the file
    open. Copying inside the ``with`` block forces the pixel data to be decoded
    now, so corrupt files are reported here as ``None`` rather than raising
    later (e.g. in ``tobytes()``). Exiting the ``with`` block closes the file
    handle. The returned copy is independent of the closed file.
    """
    try:
        with PIL.Image.open(image_path) as image:
            return image.copy()
    except Exception:
        # Best-effort loader: any open/decode failure means "no usable image".
        # Narrower than the previous bare `except:`, so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return None
def _resolve_models(row):
    """Return the battle's ``[model_a, model_b]`` names, or ``None`` if the row is invalid.

    Public names come from ``row["models"]`` (markup stripped with
    ``remove_html``). Hidden names come from ``row["states"][i]["model_name"]``
    when that key exists and side A's value is not ``None``; otherwise the
    public names are used. The row is invalid when exactly one public name is
    empty, or when public and hidden names disagree.
    """
    models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
    models_hidden = models_public
    if "model_name" in row["states"][0]:
        hidden = [row["states"][0]["model_name"], row["states"][1]["model_name"]]
        if hidden[0] is not None:
            models_hidden = hidden
    # Exactly one side empty -> inconsistent row.
    if (models_public[0] == "") != (models_public[1] == ""):
        return None
    if models_public != models_hidden:
        return None
    return models_hidden


def _matches_task(row, task_name):
    """Return True if both sides' ``i2s_mode`` flags agree with *task_name*.

    ``"text2shape"`` requires ``i2s_mode`` to be falsy on both sides; any other
    (already validated) task, i.e. ``"image2shape"``, requires it truthy on both.
    """
    modes = (state["i2s_mode"] for state in row["states"][:2])
    if task_name == "text2shape":
        return not any(modes)
    return all(modes)


def _same_conv_input(row):
    """Return True if both sides of a battle were given the same input.

    Offline battles must agree on ``offline`` and share ``offline_idx``. Online
    battles must have identical input images stored at
    ``{LOG_ROOT_DIR}/{date}-convinput_images/input_image_{conv_id}.png``, where
    *date* is the battle timestamp's US/Pacific calendar date. Reasons for
    rejecting an image pair are printed.
    """
    state_a, state_b = row["states"][0], row["states"][1]
    if state_a["offline"] != state_b["offline"]:
        return False
    if state_a["offline"]:
        return state_a["offline_idx"] == state_b["offline_idx"]

    date = datetime.datetime.fromtimestamp(
        row["tstamp"], tz=timezone("US/Pacific")
    ).strftime("%Y-%m-%d")  # e.g. 2024-02-29
    image_path_format = f"{LOG_ROOT_DIR}/{date}-convinput_images/input_image_"
    image_path_0 = image_path_format + str(state_a["conv_id"]) + ".png"
    image_path_1 = image_path_format + str(state_b["conv_id"]) + ".png"
    if not os.path.exists(image_path_0) or not os.path.exists(image_path_1):
        print(f"Image not found for {image_path_0} or {image_path_1}")
        return False
    image_0 = load_image(image_path_0)
    image_1 = load_image(image_path_1)
    if image_0 is None or image_1 is None:
        print(f"Image not found for {image_path_0} or {image_path_1}")
        return False
    # Raw bytes alone can coincide for images of different shape/mode, so
    # compare those too.
    if (image_0.mode, image_0.size) != (image_1.mode, image_1.size) or (
        image_0.tobytes() != image_1.tobytes()
    ):
        print(f"Image not the same for {image_path_0} and {image_path_1}")
        return False
    return True


def clean_battle_data(
    log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False, mode="simple", task_name="text2shape"
):
    """Read vote logs and turn them into validated battle records.

    Args:
        log_files: Paths of JSON-lines log files (read with ``read_file_parallel``).
        exclude_model_names: Battles involving any of these canonical model
            names are dropped. May be empty or ``None``.
        ban_ip_list: Optional collection of voter IPs whose votes are dropped.
            Assumed to hold hashable values (IP strings) — TODO confirm file format.
        sanitize_ip: If True, ``judge`` is ``"arena_user_<n>"`` with a per-run
            numeric id; otherwise it is ``"arena_user_<ip>"``.
        mode: ``"conv_release"`` additionally requires both sides of a battle to
            have received the same input (see ``_same_conv_input``).
        task_name: ``"text2shape"`` or ``"image2shape"``; only battles whose
            ``i2s_mode`` flags match the task are kept.

    Returns:
        List of battle dicts (``question_id``, ``dim``, ``model_a``, ``model_b``,
        ``winner``, ``judge``, ``idx``, ``anony``, ``tstamp``) sorted by ``tstamp``.

    Raises:
        ValueError: If ``task_name`` is not a supported task.
    """
    # Validate before the (potentially slow) parallel read rather than per row.
    if task_name not in ("text2shape", "image2shape"):
        raise ValueError(f"Invalid task_name: {task_name}")

    data = read_file_parallel(log_files, num_threads=16)

    convert_type = {
        "leftvote": "model_a",
        "rightvote": "model_b",
        "tievote": "tie",
        "bothbad_vote": "tie (bothbad)",
    }
    # Set membership is O(1); the raw list would be scanned once per row.
    banned_ips = set(ban_ip_list) if ban_ip_list is not None else None

    all_models = set()
    all_ips = dict()
    dim_counts = dict()
    ct_anony = 0
    ct_invalid = 0
    ct_leaked_identity = 0  # identity-leak filtering is currently disabled; kept for the report
    ct_banned = 0
    battles = []
    for row in tqdm(data, desc="Cleaning"):
        if row["models"][0] is None or row["models"][1] is None:
            continue

        models = _resolve_models(row)
        if models is None or "anony" not in row:
            ct_invalid += 1
            continue
        anony = row["anony"]

        if not _matches_task(row, task_name):
            ct_invalid += 1
            continue

        models = [replace_model_name(m, row["tstamp"]) for m in models]

        if exclude_model_names and any(m in exclude_model_names for m in models):
            ct_invalid += 1
            continue

        if mode == "conv_release" and not _same_conv_input(row):
            ct_invalid += 1
            continue

        question_id = row["states"][0]["conv_id"]

        # Register the voter before the ban check so sanitized ids are assigned
        # in the same order as before; banned IPs are removed from the report below.
        ip = row["ip"]
        if ip not in all_ips:
            all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)}
        all_ips[ip]["count"] += 1
        if banned_ips is not None and ip in banned_ips:
            ct_banned += 1
            continue
        voter = all_ips[ip]["sanitized_id"] if sanitize_ip else ip

        dim = replace_dim(row["dim"])
        dim_counts[dim] = dim_counts.get(dim, 0) + 1

        battles.append(
            dict(
                question_id=question_id,
                dim=dim,
                model_a=models[0],
                model_b=models[1],
                winner=convert_type[row["type"]],
                # Single prefix: previously sanitized ids came out as
                # "arena_user_arena_user_<n>".
                judge=f"arena_user_{voter}",
                idx=row["states"][0]["offline_idx"],
                anony=anony,
                tstamp=row["tstamp"],
            )
        )
        if anony:
            ct_anony += 1  # counted only for battles that are actually kept
        all_models.update(models)

    battles.sort(key=lambda x: x["tstamp"])
    if battles:
        last_updated_datetime = datetime.datetime.fromtimestamp(
            battles[-1]["tstamp"], tz=timezone("US/Pacific")
        ).strftime("%Y-%m-%d %H:%M:%S %Z")
    else:
        last_updated_datetime = "N/A (no valid battles)"

    print(
        f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
        f"#leaked_identity: {ct_leaked_identity} "
        f"#banned: {ct_banned} "
    )
    print(f"#battles: {len(battles)}, #anony: {ct_anony}")
    print(f"#models: {len(all_models)}, {all_models}")
    for dim, count in dim_counts.items():
        print(dim, ": ", count)
    print(f"last-updated: {last_updated_datetime}")

    if banned_ips is not None:
        for ban_ip in banned_ips:
            all_ips.pop(ban_ip, None)
    print("Top 30 IPs:")
    print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30])
    return battles
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    parser.add_argument(
        "--mode", type=str, choices=["simple", "conv_release"], default="conv_release"
    )
    # Default matches clean_battle_data's default; without one, omitting the
    # flag passed None and crashed with "Invalid task_name".
    parser.add_argument(
        "--task_name", type=str, choices=["text2shape", "image2shape"], default="text2shape"
    )
    parser.add_argument("--exclude-model-names", type=str, nargs="+")
    parser.add_argument("--ban-ip-file", type=str)
    parser.add_argument("--sanitize-ip", action="store_true", default=False)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    ban_ip_list = None
    if args.ban_ip_file:
        with open(args.ban_ip_file) as fin:
            ban_ip_list = json.load(fin)

    battles = clean_battle_data(
        log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip, args.mode, args.task_name
    )
    if not battles:
        # battles[-1] below would raise IndexError; exit with a clear message instead.
        sys.exit("No valid battles found; nothing written.")

    # Output files are stamped with the US/Pacific date of the newest battle.
    cutoff_date = datetime.datetime.fromtimestamp(
        battles[-1]["tstamp"], tz=timezone("US/Pacific")
    ).strftime("%Y%m%d")

    if args.mode == "simple":
        for battle in battles:
            for key in ("conversation_a", "conversation_b", "question_id"):
                battle.pop(key, None)
        print("Samples:")
        for battle in battles[:4]:
            print(battle)
        output = f"clean_battle_{args.task_name}_{cutoff_date}.json"
    elif args.mode == "conv_release":
        output = f"clean_battle_{args.task_name}_conv_{cutoff_date}.json"

    # ensure_ascii=False emits raw non-ASCII, so pin the file encoding to UTF-8.
    with open(output, "w", encoding="utf-8") as fout:
        json.dump(battles, fout, indent=2, ensure_ascii=False)
    print(f"Write cleaned data to {output}")
    with open("cut_off_date.txt", "w") as fout:
        fout.write(cutoff_date)