""" Clean chatbot arena chat log. Usage: python3 clean_chat_data.py --mode conv_release """ import argparse import datetime import json import os from pytz import timezone import time from tqdm import tqdm from fastchat.serve.monitor.basic_stats import NUM_SERVERS from fastchat.serve.monitor.clean_battle_data import ( to_openai_format, replace_model_name, ) from fastchat.utils import detect_language NETWORK_ERROR_MSG = ( "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower() ) def get_log_files(max_num_files=None): dates = [] for month in range(4, 12): for day in range(1, 33): dates.append(f"2023-{month:02d}-{day:02d}") filenames = [] for d in dates: for i in range(NUM_SERVERS): name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") if os.path.exists(name): filenames.append(name) max_num_files = max_num_files or len(filenames) # filenames = list(reversed(filenames)) filenames = filenames[-max_num_files:] return filenames def clean_chat_data(log_files, action_type): raw_data = [] for filename in tqdm(log_files, desc="read files"): for retry in range(5): try: lines = open(filename).readlines() break except FileNotFoundError: time.sleep(2) for l in lines: row = json.loads(l) if row["type"] == action_type: raw_data.append(row) all_models = set() all_ips = dict() chats = [] ct_invalid_conv_id = 0 ct_invalid = 0 ct_network_error = 0 for row in raw_data: try: if action_type in ["chat", "upvote", "downvote"]: state = row["state"] model = row["model"] elif action_type == "leftvote": state = row["states"][0] model = row["states"][0]["model_name"] elif action_type == "rightvote": state = row["states"][1] model = row["states"][1]["model_name"] conversation_id = state["conv_id"] except KeyError: ct_invalid_conv_id += 1 continue if conversation_id is None: ct_invalid_conv_id += 1 continue conversation = to_openai_format(state["messages"][state["offset"] :]) if not isinstance(model, str): ct_invalid += 1 continue model = replace_model_name(model) try: lang_code = detect_language(state["messages"][state["offset"]][1]) except IndexError: ct_invalid += 1 continue if not all(isinstance(x["content"], str) for x in conversation): ct_invalid += 1 continue messages = "".join([x["content"] for x in conversation]).lower() if NETWORK_ERROR_MSG in messages: ct_network_error += 1 continue ip = row["ip"] if ip not in all_ips: all_ips[ip] = len(all_ips) user_id = all_ips[ip] chats.append( dict( conversation_id=conversation_id, model=model, conversation=conversation, turn=len(conversation) // 2, language=lang_code, user_id=user_id, tstamp=row["tstamp"], ) ) all_models.update([model]) chats.sort(key=lambda x: x["tstamp"]) last_updated_tstamp = chats[-1]["tstamp"] last_updated_datetime = datetime.datetime.fromtimestamp( last_updated_tstamp, tz=timezone("US/Pacific") ).strftime("%Y-%m-%d %H:%M:%S %Z") # Deduplication dedup_chats = [] visited_conv_ids = set() for i in reversed(range(len(chats))): if chats[i]["conversation_id"] in visited_conv_ids: continue visited_conv_ids.add(chats[i]["conversation_id"]) dedup_chats.append(chats[i]) print( f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}" ) print( f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}" ) print(f"#models: {len(all_models)}, {all_models}") print(f"last-updated: {last_updated_datetime}") return list(reversed(dedup_chats)) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--action-type", type=str, default="chat") parser.add_argument("--max-num-files", type=int) args = parser.parse_args() log_files = get_log_files(args.max_num_files) chats = clean_chat_data(log_files, args.action_type) last_updated_tstamp = chats[-1]["tstamp"] cutoff_date = datetime.datetime.fromtimestamp( last_updated_tstamp, tz=timezone("US/Pacific") ).strftime("%Y%m%d") output = f"clean_{args.action_type}_conv_{cutoff_date}.json" with open(output, "w") as fout: json.dump(chats, fout, indent=2, ensure_ascii=False) print(f"Write cleaned data to {output}")