"""
Clean chatbot arena chat log.
Usage:
python3 clean_chat_data.py
"""
import argparse
import datetime
import json
import os
from pytz import timezone
import time

from tqdm import tqdm

from fastchat.serve.monitor.basic_stats import NUM_SERVERS
from fastchat.serve.monitor.clean_battle_data import (
    to_openai_format,
    replace_model_name,
)
from fastchat.utils import detect_language
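
# Error message injected when the serving backend is overloaded; rows whose
# conversation contains it are dropped during cleaning.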
NETWORK_ERROR_MSG = (
    "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
)
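

# Collect the per-server daily conversation logs (April through November 2023)
# found under ~/fastchat_logs/.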
def get_log_files(max_num_files=None):
    dates = []
    for month in range(4, 12):
        for day in range(1, 33):
            dates.append(f"2023-{month:02d}-{day:02d}")

    filenames = []
    for d in dates:
        for i in range(NUM_SERVERS):
            name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
            if os.path.exists(name):
                filenames.append(name)
    max_num_files = max_num_files or len(filenames)
    # filenames = list(reversed(filenames))
    filenames = filenames[-max_num_files:]
    return filenames
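

# Read the raw log rows of the requested action type and clean them into
# deduplicated chat records.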
def clean_chat_data(log_files, action_type):
    raw_data = []
    for filename in tqdm(log_files, desc="read files"):
        for retry in range(5):
            try:
                lines = open(filename).readlines()
                break
            except FileNotFoundError:
                time.sleep(2)

        for line in lines:
            row = json.loads(line)
            if row["type"] == action_type:
                raw_data.append(row)
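
    # Accumulators for cleaned chats, seen models, anonymized users, and skip counters.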
    all_models = set()
    all_ips = dict()
    chats = []
    ct_invalid_conv_id = 0
    ct_invalid = 0
    ct_network_error = 0
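
    # Extract the conversation state and model name for each logged event.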
    for row in raw_data:
        try:
            if action_type in ["chat", "upvote", "downvote"]:
                state = row["state"]
                model = row["model"]
            elif action_type == "leftvote":
                state = row["states"][0]
                model = row["states"][0]["model_name"]
            elif action_type == "rightvote":
                state = row["states"][1]
                model = row["states"][1]["model_name"]
            conversation_id = state["conv_id"]
        except KeyError:
            ct_invalid_conv_id += 1
            continue

        if conversation_id is None:
            ct_invalid_conv_id += 1
            continue
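
        # Convert messages after the template offset to OpenAI role/content
        # format and detect the language of the first user message.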
        conversation = to_openai_format(state["messages"][state["offset"] :])
        if not isinstance(model, str):
            ct_invalid += 1
            continue
        model = replace_model_name(model, row["tstamp"])

        try:
            lang_code = detect_language(state["messages"][state["offset"]][1])
        except IndexError:
            ct_invalid += 1
            continue

        if not all(isinstance(x["content"], str) for x in conversation):
            ct_invalid += 1
            continue

        messages = "".join([x["content"] for x in conversation]).lower()
        if NETWORK_ERROR_MSG in messages:
            ct_network_error += 1
            continue
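
        # Anonymize users by mapping each IP address to a stable integer id.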
        ip = row["ip"]
        if ip not in all_ips:
            all_ips[ip] = len(all_ips)
        user_id = all_ips[ip]

        chats.append(
            dict(
                conversation_id=conversation_id,
                model=model,
                conversation=conversation,
                turn=len(conversation) // 2,
                language=lang_code,
                user_id=user_id,
                tstamp=row["tstamp"],
            )
        )

        all_models.update([model])
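
    # Sort chats chronologically; the newest timestamp is reported as last-updated.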
    chats.sort(key=lambda x: x["tstamp"])

    last_updated_tstamp = chats[-1]["tstamp"]
    last_updated_datetime = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y-%m-%d %H:%M:%S %Z")

    # Deduplication: keep only the most recent record for each conversation_id.
    dedup_chats = []
    visited_conv_ids = set()
    for i in reversed(range(len(chats))):
        if chats[i]["conversation_id"] in visited_conv_ids:
            continue
        visited_conv_ids.add(chats[i]["conversation_id"])
        dedup_chats.append(chats[i])

    print(
        f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}"
    )
    print(
        f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}"
    )
    print(f"#models: {len(all_models)}, {all_models}")
    print(f"last-updated: {last_updated_datetime}")

    return list(reversed(dedup_chats))
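

# CLI entry point: clean one action type and write the result to a dated JSON file.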
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--action-type", type=str, default="chat")
    parser.add_argument("--max-num-files", type=int)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    chats = clean_chat_data(log_files, args.action_type)
    last_updated_tstamp = chats[-1]["tstamp"]
    cutoff_date = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y%m%d")

    output = f"clean_{args.action_type}_conv_{cutoff_date}.json"
    with open(output, "w") as fout:
        json.dump(chats, fout, indent=2, ensure_ascii=False)
    print(f"Write cleaned data to {output}")