In [None]:
import os
import torch
import numpy as np
import sys
# Make CMMLU's src/ directory importable (the mp_utils helpers live there).
# Backslashes are normalised to '/' so the path looks the same on Windows and POSIX.
root = os.path.realpath('.').replace('\\', '/')
p = root + '/CMMLU/src'
if p not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(p)
import argparse
from CMMLU.src.mp_utils import choices, format_example, gen_prompt, softmax, run_eval
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig

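Clone the CMMLU repository next to this notebook first; it provides the evaluation data (`./CMMLU/data`) and the helper module `CMMLU/src/mp_utils.py` imported above.

```bash
git clone --depth 1 https://github.com/haonan-li/CMMLU.git
```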

The evaluation code below is copied from https://github.com/haonan-li/CMMLU/blob/master/src/hf_causal_model.py and adapted to score a seq2seq model.

In [None]:
# Evaluate the DPO-tuned checkpoint, which is saved one directory up.
model_dir = '../model_save/dpo'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer and the seq2seq model, then move the model to `device`.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)

# Greedy decoding of a single token: only the first generated token (the answer letter) is scored.
generation_config = GenerationConfig(
    remove_invalid_values=True,  # automatically adds InfNanRemoveLogitsProcessor
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    # for T5, decoding starts from the pad token: decoder_start_token_id = pad_token_id
    decoder_start_token_id=tokenizer.pad_token_id,
    max_new_tokens=1,
    num_beams=1,
    do_sample=False,  # greedy search
)

# Answer letters and their token ids. This redefines the `choices` imported from mp_utils.
choices = ['A', 'B', 'C', 'D']
choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]
# A letter that is not a vocabulary token maps to the unk token (or None); every option would
# then share the same score and accuracy would collapse to chance, so fail fast instead.
unknown_choices = [c for c, token_id in zip(choices, choices_ids)
                   if token_id is None or token_id == tokenizer.unk_token_id]
assert not unknown_choices, f"answer letters not in tokenizer vocabulary: {unknown_choices}"
choices_ids

In [3]:
def eval(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot):
    """Evaluate the model on one CMMLU subject.

    For every row of ``test_df`` a prompt is built with ``format_example``/``gen_prompt``,
    one token is generated with the notebook-level ``generation_config`` (greedy,
    ``max_new_tokens=1``), and the answer letter whose token has the highest probability
    at that first decoding step is taken as the prediction.

    Args:
        model: the seq2seq model; inputs are moved to ``model.device``.
        tokenizer: tokenizer matching ``model``.
        subject: CMMLU subject name, forwarded to the prompt helpers.
        dev_df: few-shot example pool forwarded to ``gen_prompt``.
        test_df: questions to score; its last column holds the gold answer letter.
        num_few_shot: number of few-shot examples, forwarded to ``gen_prompt``.
        max_length: prompt length limit, forwarded to ``gen_prompt``.
        cot: unused; kept so the signature matches the callback ``run_eval`` invokes.

    Returns:
        ``(acc, all_preds, all_conf)``: mean accuracy over ``test_df``, the predicted
        letter for each question, and for each question the probability (softmax over
        the full vocabulary at the first decoding step) assigned to the gold letter.

    Raises:
        ValueError: if a gold label is not one of ``choices``.
    """
    # Token ids of the answer letters, in the same order as `choices`.
    choice_ids = [tokenizer.convert_tokens_to_ids(choice) for choice in choices]
    cors = []
    all_conf = []
    all_preds = []

    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, subject, include_answer=False)
        prompt = gen_prompt(dev_df=dev_df,
                            subject=subject,
                            prompt_end=prompt_end,
                            num_few_shot=num_few_shot,
                            tokenizer=tokenizer,
                            max_length=max_length)
        # Only input_ids/attention_mask are forwarded to generate(), so extra keys some
        # tokenizers return (e.g. token_type_ids, as with Falcon) never reach the model.
        inputs = tokenizer([prompt], return_tensors='pt')
        label = test_df.iloc[i, test_df.shape[1] - 1]
        torch.cuda.empty_cache()

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'].to(model.device),
                attention_mask=inputs['attention_mask'].to(model.device),
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
            )

        # outputs.scores has one (batch, vocab) tensor per generated step: take the first
        # step of the single prompt and read off the probabilities of the answer letters.
        first_step_probs = torch.softmax(outputs.scores[0][0].float().cpu(), dim=-1)
        choice_probs = first_step_probs[choice_ids]
        conf = choice_probs[choices.index(label)].item()
        pred = choices[int(torch.argmax(choice_probs))]

        all_preds.append(pred)
        all_conf.append(conf)
        cors.append(pred == label)

    acc = np.mean(cors)
    print("Average accuracy {:.3f} - {}".format(acc, subject))
    return acc, all_preds, all_conf

In [4]:
from dataclasses import dataclass
@dataclass
class Args:
    """Stand-in for argparse options, handed to CMMLU's ``run_eval``."""

    data_dir: str = './CMMLU/data'  # CMMLU dataset root inside the cloned repository
    save_dir: str = './result'  # output directory passed to run_eval
    num_few_shot: int = 0  # few-shot examples per prompt; 0 means zero-shot
    max_length: int = 512  # prompt length limit; presumably forwarded to eval/gen_prompt

# Evaluate every CMMLU subject with the default settings; per-subject accuracy is printed.
cmmlu_args = Args()
run_eval(model, tokenizer, eval, cmmlu_args)

Average accuracy 0.243 - agronomy
Average accuracy 0.243 - anatomy
Average accuracy 0.256 - ancient_chinese
Average accuracy 0.256 - arts
Average accuracy 0.248 - astronomy
Average accuracy 0.234 - business_ethics
Average accuracy 0.256 - chinese_civil_service_exam
Average accuracy 0.260 - chinese_driving_rule
Average accuracy 0.235 - chinese_food_culture
Average accuracy 0.252 - chinese_foreign_policy
Average accuracy 0.251 - chinese_history
Average accuracy 0.250 - chinese_literature
Average accuracy 0.246 - chinese_teacher_qualification
Average accuracy 0.253 - clinical_knowledge
Average accuracy 0.245 - college_actuarial_science
Average accuracy 0.318 - college_education
Average accuracy 0.302 - college_engineering_hydrology
Average accuracy 0.213 - college_law
Average accuracy 0.219 - college_mathematics
Average accuracy 0.264 - college_medical_statistics
Average accuracy 0.234 - college_medicine
Average accuracy 0.240 - computer_science
Average accuracy 0.263 - computer_security
