# AEOLLM/baseline_example/baseline3.py
import zhipuai
import traceback
import pandas as pd
from tqdm import tqdm
import re
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2'
import random
import time
from transformers import (
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.generation.utils import GenerationConfig
class GLM3_6B_API:
    '''
    Wrapper around a locally deployed ChatGLM3-6B checkpoint.
    Replace model_name_or_path with the path to your own model; the model and
    tokenizer are loaded lazily on the first chat() call.
    '''
    def __init__(self) -> None:
        self.model_name_or_path = "your_model_path"
        self.init = True  # model/tokenizer not loaded yet
def chat(self, prompt) -> str:
for _ in range(5):
if self.init:
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name_or_path,
trust_remote_code=True,
device_map="auto",
torch_dtype=(
torch.bfloat16
if torch.cuda.is_bf16_supported()
else torch.float32
),
).eval()
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name_or_path,
trust_remote_code=True,
use_fast=True,
add_bos_token=False,
add_eos_token=False,
padding_side="left",
)
                self.init = False
            try:
                print(prompt)
                # ChatGLM3's chat() returns (response, history); the updated
                # history is not needed for pointwise scoring, so it is discarded.
                response, _history = self.model.chat(
                    self.tokenizer, prompt, history=[], do_sample=False
                )
                print(response)
                return response
            except Exception:
                traceback.print_exc()
                time.sleep(5)  # back off briefly before retrying
                continue
        return None  # all retries failed
glm3_6b = GLM3_6B_API()
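# Example usage of the wrapper (hypothetical prompt, not executed by the baseline):
#   reply = glm3_6b.chat("Please rate the following answer on a scale of 1 to 5: ...")
#   # `reply` is the raw response string, or None if all five retries fail.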
def parse_num(res, min_score, max_score):
"""
Extract the integers within the specified range from the evaluation result.
Input: A string
Output: A score within the specified range or exception(-1)
If no numbers appear: return -1
If a fraction appears, match the numerator; exit if it falls within the range, otherwise continue.
If "out of" appears, match the preceding number; exit if it falls within the range, otherwise continue.
Extract the first number that falls within the specified range from all appearing numbers; exit if it falls within the range, otherwise continue.
If no numbers fall within the specified range, return -1.
"""
    all_nums = re.findall(r"-?\d+(?:\.\d+)?", res)
    probs1_nums = re.finditer(r"\b(\d+(\.\d+)?)/\d+\b", res)               # fractions, e.g. "4/5"
    probs2_nums = re.finditer(r"\b(\d+(\.\d+)?)\s+out\s+of\s+\d+\b", res)  # e.g. "4 out of 5"
    if len(all_nums) == 0:
        print("this res doesn't have a num! \n", res)
        return -1
    # 1) Prefer the numerator of a fraction such as "4/5".
    for match in probs1_nums:
        answer = match.group(1)
        if min_score <= float(answer) <= max_score:
            return answer
    # 2) Then try the leading number of an "x out of y" phrase.
    for match in probs2_nums:
        answer = match.group(1)
        if min_score <= float(answer) <= max_score:
            return answer
    # 3) Fall back to the first number anywhere in the text that is in range.
    for num in all_nums:
        if min_score <= float(num) <= max_score:
            return num
    print("this res doesn't have a num in the valid range! ", res)
    return -1
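# Illustrative calls (hypothetical model responses, not executed by the baseline),
# showing parse_num's fallback order:
#   parse_num("I would rate this answer 4/5.", 1, 5)        -> "4"   (fraction numerator)
#   parse_num("Score: 3 out of 5, fairly coherent.", 1, 5)  -> "3"   ("out of" phrase)
#   parse_num("Overall quality: 2.", 1, 5)                  -> "2"   (first in-range number)
#   parse_num("No numeric judgement here.", 1, 5)           -> -1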
def get_prompt(taskId):
"""
Find the corresponding prompt based on the taskId.
"""
prompt = ""
if taskId == 0:
prompt = open("prompt/prompt_Dialog.txt", encoding='utf-8').read().strip()
elif taskId == 1:
prompt = open("prompt/prompt_Story.txt", encoding='utf-8').read().strip()
elif taskId == 2:
prompt = open("prompt/prompt_Xsum.txt", encoding='utf-8').read().strip()
elif taskId == 3:
prompt = open("prompt/prompt_NFCATS.txt", encoding='utf-8').read().strip()
return prompt
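# Each prompt file is expected to contain the two placeholders that
# get_model_score() fills in. A hypothetical prompt/prompt_Dialog.txt might read:
#   "Please rate the quality of the following dialogue response on a scale of 1 to 5.
#    Question: {{question_text}}
#    Answer: {{answer_text}}
#    Score:"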
def get_model_score(taskId, question, answer, model):
"""
pointwise 5-level as an example
"""
prompt = get_prompt(taskId)
prompt = prompt.replace("{{question_text}}", question)
prompt = prompt.replace("{{answer_text}}", answer)
result = model.chat(prompt)
score = int(parse_num(result, 1, 5))
if score == -1:
score = random.randint(1,5)
return score
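# Illustrative flow (hypothetical model output):
#   result = "The answer is fluent and relevant. I would give it 4/5."
#   parse_num(result, 1, 5) -> "4", so get_model_score returns 4;
#   unparseable outputs fall back to a random score in [1, 5].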
def get_rank(data):
"""
Calculate the rankings in descending order, and for ties, assign the lowest rank. For example, the ranking for [1,1,2] would be [2,2,1].
"""
series = pd.Series(data)
ranks = series.rank(method='min', ascending=False)
return list(map(int, ranks.tolist()))
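# Illustrative call (not executed by the baseline):
#   get_rank([3, 5, 5, 1]) -> [3, 1, 1, 4]   # the tied 5s share rank 1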
def get_output(path, model):
"""
Obtain the results of the test set from the specified path.
"""
df = pd.read_csv(path)
row_labels = df.index
# taskId,taskName,questionId,question,answerId,answer,score,rank
model_scores = []
with open("output/baseline1_chatglm3_6B.txt", 'a') as f:
for row in tqdm(row_labels):
taskId = df.loc[row, "taskId"]
questionId = df.loc[row, "questionId"]
question = df.loc[row, "question"]
answer = df.loc[row, "answer"]
model_score = get_model_score(taskId, question, answer, model)
model_scores.append(model_score)
            # Each question has 7 candidate answers; once all 7 are scored,
            # rank them against each other and write one line per answer.
            if len(model_scores) == 7:
                ranks = get_rank(model_scores)
                for i in range(7):
                    answerId = i
                    f.write(f"{taskId} {questionId} {answerId} {model_scores[i]} {ranks[i]}\n")
                model_scores = []
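# Each output line has the format:
#   "<taskId> <questionId> <answerId> <score> <rank>"
# e.g. "1 12 0 4 2" (hypothetical values), with one 7-line block per question.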
if __name__ == '__main__':
paths = ['test/test_dialog.csv', 'test/test_NFCATS.csv', 'test/test_story.csv', 'test/test_Xsum.csv']
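    # Note: only paths[1:] are evaluated here; test_dialog.csv is skipped.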
for path in paths[1:]:
get_output(path, glm3_6b)