import zhipuai
import traceback
import pandas as pd
from tqdm import tqdm
import re
import torch
import os

os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2'

import random
import time
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from transformers.generation.utils import GenerationConfig

class GLM3_6B_API():
    '''
    Self-defined wrapper around a locally deployed ChatGLM3-6B model.
    '''
    def __init__(self) -> None:
        self.model_name_or_path = "your_model_path"
        # Load the model lazily, on the first call to chat().
        self.init = True

    def chat(self, prompt) -> str:
        # Retry generation up to 5 times before giving up.
        for _ in range(5):
            if self.init:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name_or_path,
                    trust_remote_code=True,
                    device_map="auto",
                    torch_dtype=(
                        torch.bfloat16
                        if torch.cuda.is_bf16_supported()
                        else torch.float32
                    ),
                ).eval()
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name_or_path,
                    trust_remote_code=True,
                    use_fast=True,
                    add_bos_token=False,
                    add_eos_token=False,
                    padding_side="left",
                )
                self.init = False
            try:
                print(prompt)
                # ChatGLM3's chat() returns (response, history); greedy decoding
                # (do_sample=False) keeps the grading deterministic.
                response, re_history = self.model.chat(self.tokenizer, prompt, history=[], do_sample=False)
                print(response)
                return response
            except Exception:
                traceback.print_exc()
                time.sleep(5)
                continue
        return None


glm3_6b = GLM3_6B_API()

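# Usage: glm3_6b.chat(prompt) returns the generated string, or None if generation
# fails five times in a row; the model weights are loaded on the first call.
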
def parse_num(res, min_score, max_score):
    """
    Extract an integer score within [min_score, max_score] from the evaluation result.

    Input: a string returned by the judge model.
    Output: a score within the specified range, or -1 on failure.

    Parsing order:
    1. If no numbers appear at all, return -1.
    2. If a fraction such as "x/y" appears, take the numerator x; return it if it is in range.
    3. If "x out of y" appears, take x; return it if it is in range.
    4. Otherwise, return the first number that falls within the range.
    5. If no number falls within the range, return -1.
    """
    # chat() may return None when generation repeatedly fails.
    if not res:
        return -1

    all_nums = re.findall(r"-?\d+(?:\.\d+)?", res)

    probs1_nums = re.finditer(r"\b(\d+(\.\d+)?)/\d+\b", res)
    probs2_nums = re.finditer(r"\b(\d+(\.\d+)?)\s+out\s+of\s+\d+\b", res)

    if len(all_nums) == 0:
        print("this res doesn't have num! \n", res)
        return -1

    # Prefer the numerator of a fraction such as "4/5".
    for match in probs1_nums:
        answer = match.group(1)
        if min_score <= float(answer) <= max_score:
            return answer

    # Then prefer patterns such as "4 out of 5".
    for match in probs2_nums:
        answer = match.group(1)
        if min_score <= float(answer) <= max_score:
            return answer

    # Finally fall back to the first in-range number anywhere in the text.
    for num in all_nums:
        if min_score <= float(num) <= max_score:
            return num

    print("this res doesn't have right num! ", res)
    return -1

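# Illustrative behaviour of parse_num (sanity checks only, not part of the pipeline):
#   parse_num("I would rate this answer 4/5.", 1, 5)   -> "4"  (fraction numerator)
#   parse_num("It deserves 3 out of 5 points.", 1, 5)  -> "3"  ("out of" pattern)
#   parse_num("Overall score: 2.", 1, 5)               -> "2"  (first in-range number)
#   parse_num("The answer is excellent.", 1, 5)        -> -1   (no number found)
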
def get_prompt(taskId):
    """
    Find the corresponding prompt template based on the taskId.
    """
    prompt_files = {
        0: "prompt/prompt_Dialog.txt",
        1: "prompt/prompt_Story.txt",
        2: "prompt/prompt_Xsum.txt",
        3: "prompt/prompt_NFCATS.txt",
    }
    if taskId not in prompt_files:
        return ""
    with open(prompt_files[taskId], encoding='utf-8') as f:
        return f.read().strip()

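# Each prompt file is expected to contain the placeholders "{{question_text}}" and
# "{{answer_text}}", which get_model_score() fills in. A minimal sketch of such a
# template (illustrative only -- the actual prompt/*.txt files ship with the task):
#
#   You are a strict grader. Given the question and the answer below, rate the
#   answer on a scale of 1 to 5 and output only the score.
#   Question: {{question_text}}
#   Answer: {{answer_text}}
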
def get_model_score(taskId, question, answer, model):
    """
    Pointwise 5-level grading as an example.
    """
    prompt = get_prompt(taskId)
    prompt = prompt.replace("{{question_text}}", question)
    prompt = prompt.replace("{{answer_text}}", answer)
    result = model.chat(prompt)
    # parse_num returns the score as a string (e.g. "4" or "4.0") or -1 on failure.
    score = int(float(parse_num(result, 1, 5)))
    if score == -1:
        # Fall back to a random score when no valid number could be parsed.
        score = random.randint(1, 5)
    return score

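# Example call (assumes the taskId-0 prompt file exists):
#   get_model_score(0, "What is the capital of France?", "Paris.", glm3_6b)
# returns an integer in [1, 5]; a random score is used only when parsing fails.
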
def get_rank(data):
    """
    Calculate the rankings in descending order, and for ties, assign the lowest rank.
    For example, the ranking for [1, 1, 2] would be [2, 2, 1].
    """
    series = pd.Series(data)
    ranks = series.rank(method='min', ascending=False)

    return list(map(int, ranks.tolist()))

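# Quick check of the ranking convention described in the docstring:
#   get_rank([1, 1, 2])    -> [2, 2, 1]
#   get_rank([5, 3, 3, 1]) -> [1, 2, 2, 4]
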
def get_output(path, model):
    """
    Score every answer in the test set at `path` and append the results to the output file.
    """
    df = pd.read_csv(path)
    row_labels = df.index

    # Make sure the output directory exists before appending to the result file.
    os.makedirs("output", exist_ok=True)

    model_scores = []
    with open("output/baseline1_chatglm3_6B.txt", 'a') as f:
        for row in tqdm(row_labels):
            taskId = df.loc[row, "taskId"]
            questionId = df.loc[row, "questionId"]
            question = df.loc[row, "question"]
            answer = df.loc[row, "answer"]

            model_score = get_model_score(taskId, question, answer, model)
            model_scores.append(model_score)

            # Each question has 7 candidate answers; once all 7 are scored, rank them
            # and write one "taskId questionId answerId score rank" line per answer.
            if len(model_scores) == 7:
                ranks = get_rank(model_scores)
                for i in range(7):
                    answerId = i
                    f.write(f"{taskId} {questionId} {answerId} {model_scores[i]} {ranks[i]}\n")
                model_scores = []

if __name__ == '__main__':
    paths = ['test/test_dialog.csv', 'test/test_NFCATS.csv', 'test/test_story.csv', 'test/test_Xsum.csv']
    # Note: paths[1:] skips the first (dialog) test set; iterate over `paths` to run all four.
    for path in paths[1:]:
        get_output(path, glm3_6b)