# carellm / README.md
# wangrongsheng's picture
# Create README.md
# a298e99
import os
import torch
import platform
import subprocess
from transformers import AutoModelForCausalLM, AutoTokenizer
#from transformers.generation.utils import GenerationConfig
from transformers.generation import GenerationConfig
import json
from tqdm import tqdm
import re

def init_model(model_path="./exportchatml"):
    """Load the chat model, its tokenizer, and the generation settings.

    Args:
        model_path: Location handed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to ``"./exportchatml"``, the
            path that used to be hard-coded here.

    Returns:
        A ``(model, tokenizer, config)`` tuple: the model switched to eval
        mode, its tokenizer, and the ``GenerationConfig`` to pass to
        ``model.chat``.
    """
    print("init model ...")

    # trust_remote_code=True executes the Python code shipped with the
    # checkpoint, so only point model_path at a checkpoint you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True, resume_download=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    # The generation settings are spelled out here instead of being read
    # from the checkpoint directory.
    # NOTE(review): do_sample=True means decoding is sampled, so repeated
    # runs may produce different answers -- confirm this is intended for a
    # benchmark run.
    config = GenerationConfig(
        chat_format='chatml',
        eos_token_id=151643,
        pad_token_id=151643,
        max_window_size=6144,
        max_new_tokens=512,
        do_sample=True,
        top_k=0,
        top_p=0.5,
    )

    return model, tokenizer, config

# Compiled once so the per-question loop does not look the pattern up again.
_ANSWER_PATTERN = re.compile(r"[ABCDE]")


def _build_prompt(item):
    """Format one question record as the single-turn prompt for the model."""
    question_type = item['question_type']
    question = item['question']
    options = item['option']

    # List every option the record actually carries, in label order, instead
    # of hard-coding A-D: answers are matched against A-E below, so an option
    # E must be shown to the model as well.
    # Blank options are skipped -- presumably placeholders on questions with
    # fewer choices; verify against the dataset.
    option_str = '\n'.join(
        f'{label}. {options[label]}'
        for label in sorted(options)
        if str(options[label]).strip()
    )

    return f'以下是一道{question_type},不需要做任何分析解释,直接给出正确选项:\n{question}\n{option_str}'


def main(input_path='CMB-test-choice-question-merge.json', output_path='output.json'):
    """Answer every question in the test file and save the chosen letters.

    Each record is sent to the model as an independent single-turn chat.
    The letters A-E found in the reply are concatenated, in order of
    appearance, into that record's ``model_answer``.

    Args:
        input_path: JSON file holding a list of question records. Each
            record is read for ``id``, ``question_type``, ``question`` and
            ``option`` (a mapping from option label to option text).
        output_path: JSON file that receives the list of
            ``{"id", "model_answer"}`` results.
    """
    model, tokenizer, config = init_model()

    with open(input_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    results = []
    for item in tqdm(data):
        prompt = _build_prompt(item)

        # A fresh, empty history per question keeps earlier questions from
        # leaking into later prompts.
        response, _ = model.chat(tokenizer, prompt, history=[], generation_config=config)
        print(response)

        # NOTE(review): this collects every capital A-E anywhere in the
        # reply, so Latin text in an explanation would add stray letters.
        matches = _ANSWER_PATTERN.findall(response)
        print(matches)

        results.append({
            "id": item['id'],
            "model_answer": "".join(matches),
        })

    with open(output_path, 'w', encoding="utf-8") as f1:
        json.dump(results, f1, ensure_ascii=False, indent=4)

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()