""" Usage: python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all] """ import argparse from concurrent.futures import ThreadPoolExecutor import json import numpy as np from tqdm import tqdm from fastchat.llm_judge.common import ( load_questions, load_model_answers, load_judge_prompts, check_data, play_a_match_pair, play_a_match_single, get_model_list, Judge, MatchPair, MatchSingle, NEED_REF_CATS, ) def make_match( questions, models, model_answers, judge, baseline_model, ref_answers=None, multi_turn=False, ): matches = [] for q in questions: if multi_turn and len(q["turns"]) != 2: continue for i in range(len(models)): q_id = q["question_id"] m_1 = models[i] m_2 = baseline_model if m_1 == m_2: continue a_1 = model_answers[m_1][q_id] a_2 = model_answers[baseline_model][q_id] if ref_answers is not None: ref = ref_answers[judge.model_name][q_id] match = MatchPair( dict(q), m_1, m_2, a_1, a_2, judge, ref_answer=ref, multi_turn=multi_turn, ) else: match = MatchPair( dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn ) matches.append(match) return matches def make_match_all_pairs( questions, models, model_answers, judge, baseline_model=None, ref_answers=None, multi_turn=False, ): matches = [] for q in questions: if multi_turn and len(q["turns"]) != 2: continue for i in range(len(models)): for j in range(i + 1, len(models)): q_id = q["question_id"] m_1 = models[i] m_2 = models[j] a_1 = model_answers[m_1][q_id] a_2 = model_answers[m_2][q_id] if ref_answers is not None: ref = ref_answers[judge.model_name][q_id] match = MatchPair( dict(q), m_1, m_2, a_1, a_2, judge, ref_answer=ref, multi_turn=multi_turn, ) else: match = MatchPair( dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn ) matches.append(match) return matches def make_match_single( questions, models, model_answers, judge, baseline_model=None, ref_answers=None, multi_turn=False, ): matches = [] for q in questions: if multi_turn and len(q["turns"]) != 2: continue for i in range(len(models)): q_id = q["question_id"] m = models[i] a = model_answers[m][q_id] if ref_answers is not None: ref = ref_answers[judge.model_name][q_id] matches.append( MatchSingle( dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn ) ) else: matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn)) return matches def make_judge_pairwise(judge_model, judge_prompts): judges = {} judges["default"] = Judge(judge_model, judge_prompts["pair-v2"]) judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True) judges["default-mt"] = Judge( judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True ) judges["math-mt"] = Judge( judge_model, judge_prompts["pair-math-v1-multi-turn"], ref_based=True, multi_turn=True, ) return judges def make_judge_single(judge_model, judge_prompts): judges = {} judges["default"] = Judge(judge_model, judge_prompts["single-v1"]) judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True) judges["default-mt"] = Judge( judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True ) judges["math-mt"] = Judge( judge_model, judge_prompts["single-math-v1-multi-turn"], ref_based=True, multi_turn=True, ) return judges if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--bench-name", type=str, default="mt_bench", help="The name of the benchmark question set.", ) parser.add_argument( "--judge-file", type=str, default="data/judge_prompts.jsonl", help="The file of judge prompts.", ) parser.add_argument("--judge-model", type=str, default="gpt-4") parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") parser.add_argument( "--mode", type=str, default="single", choices=["pairwise-baseline", "pairwise-all", "single"], help=( "Evaluation mode. " "`pairwise-baseline` runs pairwise comparision against a baseline. " "`pairwise-all` runs pairwise comparision between all pairs. " "`single` runs single answer grading." ), ) parser.add_argument( "--model-list", type=str, nargs="+", default=None, help="A list of models to be evaluated", ) parser.add_argument( "--parallel", type=int, default=1, help="The number of concurrent API calls." ) parser.add_argument( "--first-n", type=int, help="A debug option. Only run the first `n` judgments." ) args = parser.parse_args() question_file = f"data/{args.bench_name}/question.jsonl" answer_dir = f"data/{args.bench_name}/model_answer" ref_answer_dir = f"data/{args.bench_name}/reference_answer" # Load questions questions = load_questions(question_file, None, None) # Load answers model_answers = load_model_answers(answer_dir) ref_answers = load_model_answers(ref_answer_dir) # Load judge judge_prompts = load_judge_prompts(args.judge_file) if args.first_n: questions = questions[: args.first_n] if args.model_list is None: models = get_model_list(answer_dir) else: models = args.model_list if args.mode == "single": judges = make_judge_single(args.judge_model, judge_prompts) play_a_match_func = play_a_match_single output_file = ( f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" ) make_match_func = make_match_single baseline_model = None else: judges = make_judge_pairwise(args.judge_model, judge_prompts) play_a_match_func = play_a_match_pair output_file = ( f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" ) if args.mode == "pairwise-all": make_match_func = make_match_all_pairs baseline_model = None else: make_match_func = make_match baseline_model = args.baseline_model check_data(questions, model_answers, ref_answers, models, judges) question_math = [q for q in questions if q["category"] in NEED_REF_CATS] question_default = [q for q in questions if q["category"] not in NEED_REF_CATS] # Make matches matches = [] matches += make_match_func( question_default, models, model_answers, judges["default"], baseline_model ) matches += make_match_func( question_math, models, model_answers, judges["math"], baseline_model, ref_answers, ) matches += make_match_func( question_default, models, model_answers, judges["default-mt"], baseline_model, multi_turn=True, ) matches += make_match_func( question_math, models, model_answers, judges["math-mt"], baseline_model, ref_answers, multi_turn=True, ) match_stat = {} match_stat["bench_name"] = args.bench_name match_stat["mode"] = args.mode match_stat["judge"] = args.judge_model match_stat["baseline"] = baseline_model match_stat["model_list"] = models match_stat["total_num_questions"] = len(questions) match_stat["total_num_matches"] = len(matches) match_stat["output_path"] = output_file # Show match stats and prompt enter to continue print("Stats:") print(json.dumps(match_stat, indent=4)) input("Press Enter to confirm...") # Play matches if args.parallel == 1: for match in tqdm(matches): play_a_match_func(match, output_file=output_file) else: def play_a_match_wrapper(match): play_a_match_func(match, output_file=output_file) np.random.seed(0) np.random.shuffle(matches) with ThreadPoolExecutor(args.parallel) as executor: for match in tqdm( executor.map(play_a_match_wrapper, matches), total=len(matches) ): pass