import os import pandas as pd import numpy as np import argparse import datasets import torch import re from thefuzz import process from typing import List from tqdm import tqdm from transformers.trainer_utils import set_seed ''' wget https://people.eecs.berkeley.edu/~hendrycks/data.tar mkdir data/mmlu mv data.tar data/mmlu cd data/mmlu; tar xf data.tar cd ../../ pip install thefuzz python eval/evaluate_chat_mmlu.py -d data/mmlu/data/ ''' def load_models_tokenizer(args): from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval() model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True) model.generation_config.do_sample = False # use greedy decoding return model, tokenizer def format_example(line): example = 'The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n' + line['question'] + "\n" for choice in choices: example += f'{choice}. {line[f"{choice}"]}\n' return example def process_before_extraction(gen, choice_dict): # replace the choice by letter in the generated sentence # from longest one to shortest one for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True): pattern = re.compile(re.escape(val.rstrip(".")), re.IGNORECASE) gen = pattern.sub(key, gen) return gen def extract_choice(gen, choice_list): # answer is A | choice is A | choose A res = re.search(r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b", gen) # A is correct | A is right if res is None: res = re.search(r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b", gen) # straight answer: A if res is None: res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen) # simply extract the first appearred letter if res is None: res = re.search(r"(?