FIRE / src /serve /monitor /clean_chat_data.py
zhangbofei
feat: change to fstchat
6dc0c9c
"""
Clean chatbot arena chat log.
Usage:
python3 clean_chat_data.py
"""
import argparse
import datetime
import json
import os
from pytz import timezone
import time
from tqdm import tqdm
from fastchat.serve.monitor.basic_stats import NUM_SERVERS
from fastchat.serve.monitor.clean_battle_data import (
to_openai_format,
replace_model_name,
)
from fastchat.utils import detect_language
# Lower-cased banner text; conversations whose lower-cased concatenated
# content contains it are counted as network errors and dropped.
NETWORK_ERROR_MSG = "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
def get_log_files(max_num_files=None):
dates = []
for month in range(4, 12):
for day in range(1, 33):
dates.append(f"2023-{month:02d}-{day:02d}")
filenames = []
for d in dates:
for i in range(NUM_SERVERS):
name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
if os.path.exists(name):
filenames.append(name)
max_num_files = max_num_files or len(filenames)
# filenames = list(reversed(filenames))
filenames = filenames[-max_num_files:]
return filenames
def clean_chat_data(log_files, action_type):
raw_data = []
for filename in tqdm(log_files, desc="read files"):
for retry in range(5):
try:
lines = open(filename).readlines()
break
except FileNotFoundError:
time.sleep(2)
for l in lines:
row = json.loads(l)
if row["type"] == action_type:
raw_data.append(row)
all_models = set()
all_ips = dict()
chats = []
ct_invalid_conv_id = 0
ct_invalid = 0
ct_network_error = 0
for row in raw_data:
try:
if action_type in ["chat", "upvote", "downvote"]:
state = row["state"]
model = row["model"]
elif action_type == "leftvote":
state = row["states"][0]
model = row["states"][0]["model_name"]
elif action_type == "rightvote":
state = row["states"][1]
model = row["states"][1]["model_name"]
conversation_id = state["conv_id"]
except KeyError:
ct_invalid_conv_id += 1
continue
if conversation_id is None:
ct_invalid_conv_id += 1
continue
conversation = to_openai_format(state["messages"][state["offset"] :])
if not isinstance(model, str):
ct_invalid += 1
continue
model = replace_model_name(model, row["tstamp"])
try:
lang_code = detect_language(state["messages"][state["offset"]][1])
except IndexError:
ct_invalid += 1
continue
if not all(isinstance(x["content"], str) for x in conversation):
ct_invalid += 1
continue
messages = "".join([x["content"] for x in conversation]).lower()
if NETWORK_ERROR_MSG in messages:
ct_network_error += 1
continue
ip = row["ip"]
if ip not in all_ips:
all_ips[ip] = len(all_ips)
user_id = all_ips[ip]
chats.append(
dict(
conversation_id=conversation_id,
model=model,
conversation=conversation,
turn=len(conversation) // 2,
language=lang_code,
user_id=user_id,
tstamp=row["tstamp"],
)
)
all_models.update([model])
chats.sort(key=lambda x: x["tstamp"])
last_updated_tstamp = chats[-1]["tstamp"]
last_updated_datetime = datetime.datetime.fromtimestamp(
last_updated_tstamp, tz=timezone("US/Pacific")
).strftime("%Y-%m-%d %H:%M:%S %Z")
# Deduplication
dedup_chats = []
visited_conv_ids = set()
for i in reversed(range(len(chats))):
if chats[i]["conversation_id"] in visited_conv_ids:
continue
visited_conv_ids.add(chats[i]["conversation_id"])
dedup_chats.append(chats[i])
print(
f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}"
)
print(
f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}"
)
print(f"#models: {len(all_models)}, {all_models}")
print(f"last-updated: {last_updated_datetime}")
return list(reversed(dedup_chats))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--action-type", type=str, default="chat")
parser.add_argument("--max-num-files", type=int)
args = parser.parse_args()
log_files = get_log_files(args.max_num_files)
chats = clean_chat_data(log_files, args.action_type)
last_updated_tstamp = chats[-1]["tstamp"]
cutoff_date = datetime.datetime.fromtimestamp(
last_updated_tstamp, tz=timezone("US/Pacific")
).strftime("%Y%m%d")
output = f"clean_{args.action_type}_conv_{cutoff_date}.json"
with open(output, "w") as fout:
json.dump(chats, fout, indent=2, ensure_ascii=False)
print(f"Write cleaned data to {output}")