"""Label conversations with category tags using an OpenAI-compatible chat API.

Reads the input conversations, skips categories already available from the
input itself, an optional cache file, or a previous (partial) output file,
labels the rest in a thread pool (appending one JSON line per conversation to
the output file), and optionally merges everything into a single JSON dump.
"""

import argparse
import concurrent.futures
import json
import os
import random
import sys
import threading
import time

import orjson
import pandas as pd
import tqdm
import yaml

from category import Category

# Serializes appends to the shared answer file from worker threads.
LOCK = threading.RLock()
# Category name tags every conversation should end up labeled with.
TASKS = None
# uid -> {"category_tag": {...}} built from the optional cache file.
CACHE_DICT = None
# uid -> {"category_tag": {...}} built from the existing output file.
OUTPUT_DICT = None

# API setting constants (populated from the config file in __main__)
API_MAX_RETRY = None
API_RETRY_SLEEP = None
API_ERROR_OUTPUT = None

# Maximum number of prompt characters sent to the labeling model.
MAX_PROMPT_CHARS = 12500


# load config args from config yaml files
def make_config(config_file: str) -> dict:
    """Load the YAML config file (safe loader) and return its contents."""
    with open(config_file, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def get_endpoint(endpoint_list):
    """Return one randomly chosen endpoint dict, or None if none are configured.

    None (and an empty list) means "use the default OpenAI client".
    """
    if not endpoint_list:
        return None
    return random.choice(endpoint_list)


def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
    """Request a chat completion, retrying on transient API errors.

    Rate-limit, connection and internal-server errors are retried up to
    API_MAX_RETRY times, sleeping API_RETRY_SLEEP seconds between attempts.
    A bad request or any other exception aborts immediately.

    Args:
        api_dict: optional {"api_base": ..., "api_key": ...}; when falsy the
            default ``openai.OpenAI()`` client configuration is used.

    Returns:
        The completion text, or API_ERROR_OUTPUT if no attempt succeeded.
    """
    import openai

    if api_dict:
        client = openai.OpenAI(
            base_url=api_dict["api_base"],
            api_key=api_dict["api_key"],
        )
    else:
        client = openai.OpenAI()

    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            output = completion.choices[0].message.content
            break
        except openai.RateLimitError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.BadRequestError as e:
            # The request itself is invalid; retrying cannot help.
            print(messages)
            print(type(e), e)
            break
        except (openai.APIConnectionError, openai.InternalServerError) as e:
            print(messages)
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except Exception as e:
            print(type(e), e)
            break
    return output


def _as_tag_dict(value) -> dict:
    """Return a shallow copy of *value* if it is a dict, otherwise ``{}``.

    Records lacking "category_tag" become NaN (a float) in a DataFrame built
    from a list of dicts, which would break ``in`` checks and assignment.
    """
    return dict(value) if isinstance(value, dict) else {}


def _lookup_tags(table: dict, uid: str) -> dict:
    """Return the category tags stored for *uid* in a uid -> row table, or ``{}``."""
    entry = table.get(uid)
    return _as_tag_dict(entry["category_tag"]) if entry is not None else {}


def get_answer(
    question: pd.Series,
    model_name: str,
    max_tokens: int,
    temperature: float,
    answer_file: str,
    api_dict: dict,
    categories: list,
    testing: bool,
):
    """Label one conversation with every category in *categories*, then append it to *answer_file*.

    *question* is a row of the not-yet-labeled frame; it must carry "prompt",
    "uid" and "required_tasks" (all three are dropped before writing). New
    labels are merged into a copy of any existing "category_tag" dict so the
    shared input data is not mutated from a worker thread. With *testing*,
    the raw model outputs are also written under "output_log".
    """
    category_tag = _as_tag_dict(question.get("category_tag"))
    output_log = {}

    for category in categories:
        conv = category.pre_process(question["prompt"])
        output = chat_completion_openai(
            model=model_name,
            messages=conv,
            temperature=temperature,
            max_tokens=max_tokens,
            api_dict=api_dict,
        )
        # Dump answers
        category_tag[category.name_tag] = category.post_process(output)
        if testing:
            output_log[category.name_tag] = output

    question["category_tag"] = category_tag
    if testing:
        question["output_log"] = output_log

    question.drop(["prompt", "uid", "required_tasks"], inplace=True)
    line = json.dumps(question.to_dict()) + "\n"
    with LOCK:
        with open(answer_file, "a", encoding="utf-8") as fout:
            fout.write(line)


def category_merge(row):
    """Return *row*'s category tags with missing TASKS filled from cache, then output.

    Tags already present on the row always win; for each missing task the
    cache file is consulted first and the output file second.
    """
    uid = row["uid"]
    merged = _as_tag_dict(row.get("category_tag"))
    cache_category = _lookup_tags(CACHE_DICT, uid)
    output_category = _lookup_tags(OUTPUT_DICT, uid)

    for name in TASKS:
        if name in merged:
            continue
        if name in cache_category:
            merged[name] = cache_category[name]
        elif name in output_category:
            merged[name] = output_category[name]
    return merged


def find_required_tasks(row):
    """Return the TASKS not yet labeled for *row* in the input, cache, or output."""
    uid = row["uid"]
    known = (
        _as_tag_dict(row.get("category_tag")).keys()
        | _lookup_tags(CACHE_DICT, uid).keys()
        | _lookup_tags(OUTPUT_DICT, uid).keys()
    )
    return [name for name in TASKS if name not in known]


def _add_uid(df: pd.DataFrame, label: str) -> pd.DataFrame:
    """Add a "uid" column (str(question_id) + str(tstamp)) to *df* in place and return it.

    Raises:
        ValueError: if two rows produce the same uid.
    """
    # Vectorised concatenation is much faster than DataFrame.apply.
    df["uid"] = df.question_id.map(str) + df.tstamp.map(str)
    if not df.uid.is_unique:
        raise ValueError(f"duplicate question_id + tstamp pairs found in {label}")
    return df


def _tag_dict(df: pd.DataFrame, label: str) -> dict:
    """Return {uid: {"category_tag": tags}} for *df*, which must already have "uid".

    Raises:
        ValueError: if *df* has no "category_tag" column.
    """
    if "category_tag" not in df.columns:
        raise ValueError(f"{label} has no 'category_tag' column")
    return df[["uid", "category_tag"]].set_index("uid").to_dict("index")


def _load_json(path: str) -> pd.DataFrame:
    """Load a JSON array of records into a DataFrame using orjson."""
    with open(path, "rb") as f:
        return pd.DataFrame(orjson.loads(f.read()))


def _load_jsonl(path: str) -> pd.DataFrame:
    """Load a JSON-lines file into a DataFrame, skipping blank lines.

    Parsed with orjson, like the input and cache files, instead of
    ``pd.read_json``: the latter's default dtype inference and inexact float
    parsing could alter question_id/tstamp and hence the derived uids.
    """
    with open(path, "rb") as f:
        records = [orjson.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records)


def _load_output_dict(path: str) -> dict:
    """Return uid -> {"category_tag": ...} for rows already written to *path*.

    A missing or empty output file yields an empty dict.
    """
    if not os.path.isfile(path):
        return {}
    output_data = _load_jsonl(path)
    if output_data.empty:
        return {}
    _add_uid(output_data, "output file")
    print("finalizing output_dict (should take less than 30 sec)")
    return _tag_dict(output_data, "output file")


def _json_output_path(jsonl_path: str) -> str:
    """Return the merged-JSON path for *jsonl_path* ("x.jsonl" -> "x.json")."""
    root, ext = os.path.splitext(jsonl_path)
    if ext == ".jsonl":
        return root + ".json"
    # Avoid mangling (or overwriting) a path without the expected extension.
    return jsonl_path + ".json"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--testing", action="store_true")
    args = parser.parse_args()

    enter = input(
        "Make sure your config file is properly configured. Press enter to continue."
    )
    if enter != "":
        sys.exit()

    config = make_config(args.config)

    API_MAX_RETRY = config["max_retry"]
    API_RETRY_SLEEP = config["retry_sleep"]
    API_ERROR_OUTPUT = config["error_output"]

    categories = [Category.create_category(name) for name in config["task_name"]]
    # Tags are written under category.name_tag (see get_answer), so required and
    # merged tasks are tracked by name_tag as well.
    TASKS = [category.name_tag for category in categories]
    print(f"Following categories will be labeled:\n{TASKS}")

    print("loading input data (might take min)")
    input_data = _add_uid(_load_json(config["input_file"]), "input data")
    print(f"{len(input_data)}# of input data just loaded")

    if config["cache_file"]:
        print("loading cache data")
        cache_data = _add_uid(_load_json(config["cache_file"]), "cache data")
        print(f"{len(cache_data)}# of cache data just loaded")
        print("finalizing cache_dict (should take less than 30 sec)")
        CACHE_DICT = _tag_dict(cache_data, "cache data")
    else:
        CACHE_DICT = {}

    if os.path.isfile(config["output_file"]):
        print("loading existing output")
    OUTPUT_DICT = _load_output_dict(config["output_file"])
    print(f"{len(OUTPUT_DICT)}# of existing output just loaded")

    print(
        "finding tasks needed to run... (should take around 1 minute or less on large dataset)"
    )
    input_data["required_tasks"] = input_data.apply(find_required_tasks, axis=1)
    not_labeled = input_data[input_data.required_tasks.map(len) > 0].copy()
    print(f"{len(not_labeled)} # of conversations needs to be labeled")
    for name in TASKS:
        count = sum(name in tasks for tasks in not_labeled.required_tasks)
        print(f"{name}: {count}")

    # Join every other turn of conversation_a (indices 0, 2, ...), presumably
    # the user turns, and cap the prompt length.
    not_labeled["prompt"] = not_labeled.conversation_a.map(
        lambda convo: "\n".join(turn["content"] for turn in convo[::2])[
            :MAX_PROMPT_CHARS
        ]
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=config["parallel"]
    ) as executor:
        futures = []
        for _, row in tqdm.tqdm(not_labeled.iterrows(), total=len(not_labeled)):
            needed = [
                category
                for category in categories
                if category.name_tag in row["required_tasks"]
            ]
            futures.append(
                executor.submit(
                    get_answer,
                    row,
                    config["model_name"],
                    config["max_token"],
                    config["temperature"],
                    config["output_file"],
                    get_endpoint(config["endpoints"]),
                    needed,
                    args.testing,
                )
            )
        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()

    if config["convert_to_json"]:
        # merge two data frames, but only take the fields from the cache data to overwrite the input data
        merge_columns = [category.name_tag for category in categories]
        print(f"Columns to be merged:\n{merge_columns}")

        print("reading output file...")
        OUTPUT_DICT = _load_output_dict(config["output_file"])

        print("begin merging (should take around 1 minute or less on large dataset)")
        input_data["category_tag"] = input_data.apply(category_merge, axis=1)
        print("merge completed")

        final_data = input_data.drop(
            columns=["prompt", "uid", "required_tasks"], errors="ignore"
        )
        final_data.to_json(
            _json_output_path(config["output_file"]),
            orient="records",
            indent=4,
            force_ascii=False,
        )