{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import numpy as np\n", "import sys\n", "root = '/'.join(os.path.realpath('.').replace('\\\\','/').split('/'))\n", "p = root + '/CMMLU/src'\n", "if p not in sys.path:\n", " sys.path.append(p)\n", "import argparse\n", "from CMMLU.src.mp_utils import choices, format_example, gen_prompt, softmax, run_eval\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", "from transformers.generation.configuration_utils import GenerationConfig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```bash\n", "git clone -- depth 1 https://github.com/haonan-li/CMMLU.git\n", "```\n", "\n", "cpoied from https://github.com/haonan-li/CMMLU/blob/master/src/hf_causal_model.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_dir = '../model_save/dpo' # 模型文件在上一层目录,使用dpo后的模型\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "# 加载模型\n", "tokenizer = AutoTokenizer.from_pretrained(model_dir)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)\n", "generation_config = GenerationConfig()\n", "generation_config.remove_invalid_values = True # 自动添加InfNanRemoveLogitsProcessor\n", "generation_config.eos_token_id = tokenizer.eos_token_id\n", "generation_config.pad_token_id = tokenizer.pad_token_id\n", "# for t5, set decoder_start_token_id = pad_token_id\n", "generation_config.decoder_start_token_id = tokenizer.pad_token_id \n", "generation_config.max_new_tokens = 1\n", "generation_config.num_beams = 1\n", "generation_config.do_sample = False # greedy search\n", "\n", "choices = ['A', 'B', 'C', 'D']\n", "choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]\n", "choices_ids" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def eval(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot):\n", " choice_ids = [tokenizer.convert_tokens_to_ids(choice) for choice in choices]\n", " cors = []\n", " all_conf = []\n", " all_preds = []\n", " answers = choices[: test_df.shape[1] - 2]\n", "\n", " for i in range(test_df.shape[0]):\n", " prompt_end = format_example(test_df, i, subject, include_answer=False)\n", " prompt = gen_prompt(dev_df=dev_df,\n", " subject=subject,\n", " prompt_end=prompt_end,\n", " num_few_shot=num_few_shot,\n", " tokenizer=tokenizer,\n", " max_length=max_length)\n", " inputs = tokenizer([prompt])\n", " if \"token_type_ids\" in inputs: # For Falcon\n", " inputs.pop(\"token_type_ids\")\n", " label = test_df.iloc[i, test_df.shape[1] - 1]\n", " torch.cuda.empty_cache()\n", " \n", " input_ids, attention_mask = torch.LongTensor(inputs['input_ids']), torch.LongTensor(inputs['attention_mask'])\n", " \n", " with torch.no_grad():\n", " outputs = model.generate(\n", " input_ids=input_ids.to(device),\n", " attention_mask=attention_mask.to(device),\n", " generation_config=generation_config,\n", " return_dict_in_generate=True,\n", " output_scores=True,\n", " )\n", " \n", " scores = torch.stack(outputs['scores'], dim=1).to('cpu')\n", " scores = torch.softmax(scores, dim=2)\n", " scores = scores[..., 0, choices_ids] #取第一个字符的ABCD概率\n", " conf = scores[0][choices.index(label)]\n", " choices_index = torch.argmax(scores)\n", " \n", " pred = choices[choices_index]\n", "\n", " all_preds += pred\n", " all_conf.append(conf)\n", " cors.append(pred == label)\n", "\n", " acc = np.mean(cors)\n", " print(\"Average accuracy {:.3f} - {}\".format(acc, subject))\n", " return acc, all_preds, conf" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average accuracy 0.243 - agronomy\n", "Average accuracy 0.243 - anatomy\n", "Average accuracy 0.256 - ancient_chinese\n", "Average accuracy 0.256 - arts\n", "Average accuracy 0.248 - astronomy\n", "Average accuracy 0.234 - business_ethics\n", "Average accuracy 0.256 - chinese_civil_service_exam\n", "Average accuracy 0.260 - chinese_driving_rule\n", "Average accuracy 0.235 - chinese_food_culture\n", "Average accuracy 0.252 - chinese_foreign_policy\n", "Average accuracy 0.251 - chinese_history\n", "Average accuracy 0.250 - chinese_literature\n", "Average accuracy 0.246 - chinese_teacher_qualification\n", "Average accuracy 0.253 - clinical_knowledge\n", "Average accuracy 0.245 - college_actuarial_science\n", "Average accuracy 0.318 - college_education\n", "Average accuracy 0.302 - college_engineering_hydrology\n", "Average accuracy 0.213 - college_law\n", "Average accuracy 0.219 - college_mathematics\n", "Average accuracy 0.264 - college_medical_statistics\n", "Average accuracy 0.234 - college_medicine\n", "Average accuracy 0.240 - computer_science\n", "Average accuracy 0.263 - computer_security\n", "Average accuracy 0.252 - conceptual_physics\n", "Average accuracy 0.252 - construction_project_management\n", "Average accuracy 0.239 - economics\n", "Average accuracy 0.258 - education\n", "Average accuracy 0.250 - electrical_engineering\n", "Average accuracy 0.282 - elementary_chinese\n", "Average accuracy 0.242 - elementary_commonsense\n", "Average accuracy 0.282 - elementary_information_and_technology\n", "Average accuracy 0.283 - elementary_mathematics\n", "Average accuracy 0.252 - ethnology\n", "Average accuracy 0.252 - food_science\n", "Average accuracy 0.239 - genetics\n", "Average accuracy 0.242 - global_facts\n", "Average accuracy 0.272 - high_school_biology\n", "Average accuracy 0.235 - high_school_chemistry\n", "Average accuracy 0.271 - high_school_geography\n", "Average accuracy 0.250 - high_school_mathematics\n", "Average accuracy 0.255 - high_school_physics\n", "Average accuracy 0.252 - high_school_politics\n", "Average accuracy 0.254 - human_sexuality\n", "Average accuracy 0.249 - international_law\n", "Average accuracy 0.250 - journalism\n", "Average accuracy 0.253 - jurisprudence\n", "Average accuracy 0.252 - legal_and_moral_basis\n", "Average accuracy 0.252 - logical\n", "Average accuracy 0.238 - machine_learning\n", "Average accuracy 0.243 - management\n", "Average accuracy 0.250 - marketing\n", "Average accuracy 0.249 - marxist_theory\n", "Average accuracy 0.250 - modern_chinese\n", "Average accuracy 0.241 - nutrition\n", "Average accuracy 0.257 - philosophy\n", "Average accuracy 0.251 - professional_accounting\n", "Average accuracy 0.251 - professional_law\n", "Average accuracy 0.242 - professional_medicine\n", "Average accuracy 0.246 - professional_psychology\n", "Average accuracy 0.247 - public_relations\n", "Average accuracy 0.252 - security_study\n", "Average accuracy 0.252 - sociology\n", "Average accuracy 0.248 - sports_science\n", "Average accuracy 0.254 - traditional_chinese_medicine\n", "Average accuracy 0.243 - virology\n", "Average accuracy 0.242 - world_history\n", "Average accuracy 0.256 - world_religions\n", "STEM 25.16\n", "Humanities 24.78\n", "Social Science 25.42\n", "Other 25.15\n", "China specific 25.26\n", "Overall 25.17\n" ] } ], "source": [ "from dataclasses import dataclass\n", "@dataclass\n", "class Args:\n", " data_dir: str = './CMMLU/data'\n", " save_dir: str = './result'\n", " num_few_shot: int = 0\n", " max_length: int = 512\n", "\n", "run_eval(model, tokenizer, eval, Args())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }