|  | """ | 
					
						
						|  | Clean chatbot arena chat log. | 
					
						
						|  |  | 
					
						
						|  | Usage: | 
					
						
						|  | python3 clean_chat_data.py | 
					
						
						|  | """ | 
					
						
						|  | import argparse | 
					
						
						|  | import datetime | 
					
						
						|  | import json | 
					
						
						|  | import os | 
					
						
						|  | from pytz import timezone | 
					
						
						|  | import time | 
					
						
						|  |  | 
					
						
						|  | from tqdm import tqdm | 
					
						
						|  |  | 
					
						
						|  | from fastchat.serve.monitor.basic_stats import NUM_SERVERS | 
					
						
						|  | from fastchat.serve.monitor.clean_battle_data import ( | 
					
						
						|  | to_openai_format, | 
					
						
						|  | replace_model_name, | 
					
						
						|  | ) | 
					
						
						|  | from fastchat.utils import detect_language | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Boilerplate message the serving frontend injects into a conversation when the
# backend is overloaded. Stored lower-cased so it can be substring-matched
# against lower-cased conversation text (see clean_chat_data).
NETWORK_ERROR_MSG = (
    "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_log_files(max_num_files=None): | 
					
						
						|  | dates = [] | 
					
						
						|  | for month in range(4, 12): | 
					
						
						|  | for day in range(1, 33): | 
					
						
						|  | dates.append(f"2023-{month:02d}-{day:02d}") | 
					
						
						|  |  | 
					
						
						|  | filenames = [] | 
					
						
						|  | for d in dates: | 
					
						
						|  | for i in range(NUM_SERVERS): | 
					
						
						|  | name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") | 
					
						
						|  | if os.path.exists(name): | 
					
						
						|  | filenames.append(name) | 
					
						
						|  | max_num_files = max_num_files or len(filenames) | 
					
						
						|  |  | 
					
						
						|  | filenames = filenames[-max_num_files:] | 
					
						
						|  | return filenames | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def clean_chat_data(log_files, action_type): | 
					
						
						|  | raw_data = [] | 
					
						
						|  | for filename in tqdm(log_files, desc="read files"): | 
					
						
						|  | for retry in range(5): | 
					
						
						|  | try: | 
					
						
						|  | lines = open(filename).readlines() | 
					
						
						|  | break | 
					
						
						|  | except FileNotFoundError: | 
					
						
						|  | time.sleep(2) | 
					
						
						|  |  | 
					
						
						|  | for l in lines: | 
					
						
						|  | row = json.loads(l) | 
					
						
						|  | if row["type"] == action_type: | 
					
						
						|  | raw_data.append(row) | 
					
						
						|  |  | 
					
						
						|  | all_models = set() | 
					
						
						|  | all_ips = dict() | 
					
						
						|  | chats = [] | 
					
						
						|  | ct_invalid_conv_id = 0 | 
					
						
						|  | ct_invalid = 0 | 
					
						
						|  | ct_network_error = 0 | 
					
						
						|  | for row in raw_data: | 
					
						
						|  | try: | 
					
						
						|  | if action_type in ["chat", "upvote", "downvote"]: | 
					
						
						|  | state = row["state"] | 
					
						
						|  | model = row["model"] | 
					
						
						|  | elif action_type == "leftvote": | 
					
						
						|  | state = row["states"][0] | 
					
						
						|  | model = row["states"][0]["model_name"] | 
					
						
						|  | elif action_type == "rightvote": | 
					
						
						|  | state = row["states"][1] | 
					
						
						|  | model = row["states"][1]["model_name"] | 
					
						
						|  | conversation_id = state["conv_id"] | 
					
						
						|  | except KeyError: | 
					
						
						|  | ct_invalid_conv_id += 1 | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | if conversation_id is None: | 
					
						
						|  | ct_invalid_conv_id += 1 | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | conversation = to_openai_format(state["messages"][state["offset"] :]) | 
					
						
						|  | if not isinstance(model, str): | 
					
						
						|  | ct_invalid += 1 | 
					
						
						|  | continue | 
					
						
						|  | model = replace_model_name(model, row["tstamp"]) | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | lang_code = detect_language(state["messages"][state["offset"]][1]) | 
					
						
						|  | except IndexError: | 
					
						
						|  | ct_invalid += 1 | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | if not all(isinstance(x["content"], str) for x in conversation): | 
					
						
						|  | ct_invalid += 1 | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | messages = "".join([x["content"] for x in conversation]).lower() | 
					
						
						|  | if NETWORK_ERROR_MSG in messages: | 
					
						
						|  | ct_network_error += 1 | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | ip = row["ip"] | 
					
						
						|  | if ip not in all_ips: | 
					
						
						|  | all_ips[ip] = len(all_ips) | 
					
						
						|  | user_id = all_ips[ip] | 
					
						
						|  |  | 
					
						
						|  | chats.append( | 
					
						
						|  | dict( | 
					
						
						|  | conversation_id=conversation_id, | 
					
						
						|  | model=model, | 
					
						
						|  | conversation=conversation, | 
					
						
						|  | turn=len(conversation) // 2, | 
					
						
						|  | language=lang_code, | 
					
						
						|  | user_id=user_id, | 
					
						
						|  | tstamp=row["tstamp"], | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | all_models.update([model]) | 
					
						
						|  |  | 
					
						
						|  | chats.sort(key=lambda x: x["tstamp"]) | 
					
						
						|  | last_updated_tstamp = chats[-1]["tstamp"] | 
					
						
						|  | last_updated_datetime = datetime.datetime.fromtimestamp( | 
					
						
						|  | last_updated_tstamp, tz=timezone("US/Pacific") | 
					
						
						|  | ).strftime("%Y-%m-%d %H:%M:%S %Z") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | dedup_chats = [] | 
					
						
						|  | visited_conv_ids = set() | 
					
						
						|  | for i in reversed(range(len(chats))): | 
					
						
						|  | if chats[i]["conversation_id"] in visited_conv_ids: | 
					
						
						|  | continue | 
					
						
						|  | visited_conv_ids.add(chats[i]["conversation_id"]) | 
					
						
						|  | dedup_chats.append(chats[i]) | 
					
						
						|  |  | 
					
						
						|  | print( | 
					
						
						|  | f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}" | 
					
						
						|  | ) | 
					
						
						|  | print( | 
					
						
						|  | f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}" | 
					
						
						|  | ) | 
					
						
						|  | print(f"#models: {len(all_models)}, {all_models}") | 
					
						
						|  | print(f"last-updated: {last_updated_datetime}") | 
					
						
						|  |  | 
					
						
						|  | return list(reversed(dedup_chats)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | parser = argparse.ArgumentParser() | 
					
						
						|  | parser.add_argument("--action-type", type=str, default="chat") | 
					
						
						|  | parser.add_argument("--max-num-files", type=int) | 
					
						
						|  | args = parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  | log_files = get_log_files(args.max_num_files) | 
					
						
						|  | chats = clean_chat_data(log_files, args.action_type) | 
					
						
						|  | last_updated_tstamp = chats[-1]["tstamp"] | 
					
						
						|  | cutoff_date = datetime.datetime.fromtimestamp( | 
					
						
						|  | last_updated_tstamp, tz=timezone("US/Pacific") | 
					
						
						|  | ).strftime("%Y%m%d") | 
					
						
						|  |  | 
					
						
						|  | output = f"clean_{args.action_type}_conv_{cutoff_date}.json" | 
					
						
						|  | with open(output, "w") as fout: | 
					
						
						|  | json.dump(chats, fout, indent=2, ensure_ascii=False) | 
					
						
						|  | print(f"Write cleaned data to {output}") | 
					
						
						|  |  |