wangrongsheng committed a298e99 (parent: 1a5b5a3): Create README.md
README.md (added)
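The README added here is a single evaluation script: it loads a ChatML-format export of the model from `./exportchatml`, poses every multiple-choice question in `CMB-test-choice-question-merge.json` with an instruction to answer with the option letter only, pulls the letters `A`–`E` out of each reply, and writes the collected answers to `output.json`.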
```python
import json
import re

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


def init_model():
    """Load the exported ChatML checkpoint and pin its sampling settings."""
    print("init model ...")
    tokenizer = AutoTokenizer.from_pretrained(
        "./exportchatml", trust_remote_code=True, resume_download=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        "./exportchatml",
        device_map="auto",
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    # GenerationConfig.from_pretrained("./exportchatml") would also work if
    # the export ships a generation_config.json; building the config
    # explicitly pins the decoding behaviour instead. 151643 is Qwen's
    # <|endoftext|> token, used here for both EOS and padding.
    config = GenerationConfig(
        chat_format='chatml',
        eos_token_id=151643,
        pad_token_id=151643,
        max_window_size=6144,
        max_new_tokens=512,
        do_sample=True,
        top_k=0,
        top_p=0.5,
    )

    return model, tokenizer, config


def main():
    model, tokenizer, config = init_model()

    with open('CMB-test-choice-question-merge.json', 'r', encoding='utf-8') as file:
        data = json.load(file)

    results = []
    for item in tqdm(data):
        test_id = item['id']
        question_type = item['question_type']
        question = item['question']
        option = item['option']
        option_str = (
            f"A. {option['A']}\nB. {option['B']}\n"
            f"C. {option['C']}\nD. {option['D']}"
        )

        # A longer prompt that also mentioned item['exam_type'] and
        # item['exam_class'] was tried as well. This one asks: "The following
        # is a {question_type}; give the correct option directly, with no
        # analysis or explanation."
        prompt = f'以下是一道{question_type},不需要做任何分析解释,直接给出正确选项:\n{question}\n{option_str}'

        # Ask each question with an empty history so answers stay independent.
        response, _ = model.chat(tokenizer, prompt, history=[], generation_config=config)
        print(response)

        # Collect every option letter in the reply, in order of appearance.
        matches = re.findall("[ABCDE]", response)
        print(matches)
        final_result = "".join(matches)

        results.append({
            "id": test_id,
            "model_answer": final_result,
        })

    with open('output.json', 'w', encoding="utf-8") as f1:
        json.dump(results, f1, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    main()
```
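
The exact shape of each record in `CMB-test-choice-question-merge.json` is not documented here, but it can be inferred from the keys the script reads (`id`, `exam_type`, `exam_class`, `question_type`, `question`, and an `option` map). The sketch below is illustrative only; the concrete values are placeholders:

```python
# Hypothetical input record, reconstructed from the fields the script
# accesses; real CMB test items follow this layout but these values
# are invented.
sample_item = {
    "id": 1,
    "exam_type": "...",            # examination category
    "exam_class": "...",           # examination subject
    "question_type": "单项选择题",  # "single-choice question"
    "question": "...",
    "option": {"A": "...", "B": "...", "C": "...", "D": "..."},
}
```

Since the extraction regex also accepts `E`, some records presumably carry a fifth option as well.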
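
One caveat worth knowing: answer extraction is a plain character-class scan, so option letters embedded in any explanatory text the model emits are collected too, and repeated letters are kept. A quick standalone check of that behaviour:

```python
import re

# Same pattern the script uses: grab every A-E, in order of appearance.
# Replies mean: "The correct answer is B.", "Answer: ACD",
# "B. Because ... so the answer is B".
for reply in ["正确答案是B。", "答案:ACD", "B. 因为……所以选B"]:
    print("".join(re.findall("[ABCDE]", reply)))
# prints: B, ACD, BB -- duplicated letters survive, so downstream scoring
# may want to de-duplicate model_answer before comparing against the key.
```

Each entry written to `output.json` then has the form `{"id": ..., "model_answer": "B"}`.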