# -*- coding: utf-8 -*-
from tqdm import tqdm
import tiktoken
from ape.prompt import MyTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
# By default, davinci-003 is used for testing and evaluation (more controllable),
# while ChatGPT is used to generate instructions (cheaper)
# Price in USD per 1K tokens
Cost = {
    'davinci': 0.02,
    'chatgpt': 0.004
}
class LLMGPT(object):
    # NOTE: cl100k_base matches ChatGPT; davinci-era completion models use
    # p50k_base, so token counts for those models are approximate here
    encoding = tiktoken.get_encoding("cl100k_base")

    def __init__(self, openai_key):
        # ChatGPT for instruction generation
        self.gen_llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, verbose=True)
        # max_tokens=0 + echo + logprobs=1: score the prompt itself (including
        # the gold output) without generating any new tokens
        self.eval_llm = OpenAI(openai_api_key=openai_key, max_tokens=0, temperature=0.7, echo=True, logprobs=1)
        # Completion model for running instructions on test inputs
        self.test_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, verbose=True)
        self.gen_chain = None
        self.eval_chain = None
    @staticmethod
    def confirm_cost(text, mode):
        # Estimate the API cost of a request from its token count,
        # using the module-level Cost table
        if mode == 'train':
            cost = Cost['davinci']
        else:
            cost = Cost['chatgpt']
        num_tokens = len(LLMGPT.encoding.encode(text))
        total_price = (num_tokens / 1000) * cost
        return total_price
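
    # Usage sketch for confirm_cost (the exact figure depends on the tokenizer):
    #   price = LLMGPT.confirm_cost('Translate the word cheese to French.', 'train')
    #   print(f'estimated cost: ${price:.5f}')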
    def generate_instruction(self, gen_prompt, few_shot):
        """
        Generate a candidate instruction from few-shot (input, output) pairs
        """
        if not gen_prompt:
            gen_prompt = MyTemplate['gen_user_prompt']
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(MyTemplate['gen_sys_prompt']),
                HumanMessagePromptTemplate.from_template(gen_prompt),
            ]
        )
        self.gen_chain = LLMChain(llm=self.gen_llm, prompt=prompt)
        # Concatenate the demonstrations into a single few-shot string
        few_shot_str = ''
        for shot in few_shot:
            few_shot_str += MyTemplate['few_shot_prompt'].format(input=shot[0], output=shot[1])
        result = self.gen_chain({'few_shot': few_shot_str})
        return result
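
    # few_shot is assumed to be an iterable of (input, output) pairs, e.g.
    #   few_shot = [('cheese', 'fromage'), ('dog', 'chien')]
    #   result = llm.generate_instruction(gen_prompt=None, few_shot=few_shot)
    # where llm is an LLMGPT instance; result['text'] holds the generated
    # instruction ('text' is LLMChain's default output key)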
    def generate_output(self, test_prompt, instruction, input):
        """
        Run an instruction on a single input and return the model completion
        """
        if not test_prompt:
            test_prompt = MyTemplate['test_prompt']
        prompt = PromptTemplate.from_template(test_prompt)
        test_chain = LLMChain(llm=self.test_llm, prompt=prompt)
        output = test_chain({'input': input, 'instruction': instruction})
        return output
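
    # e.g. out = llm.generate_output(test_prompt=None, instruction='Translate to French:', input='cheese')
    # out['text'] then holds the completion for the given input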
    def generate_logprobs(self, eval_prompt, instruction, eval_set):
        """
        Score an instruction by the summed log-probability the eval model
        assigns to the gold outputs
        """
        if not eval_prompt:
            eval_prompt = MyTemplate['eval_prompt']
        prompt = PromptTemplate.from_template(eval_prompt)
        eval_chain = LLMChain(llm=self.eval_llm, prompt=prompt)
        score = 0
        for sample in tqdm(eval_set):
            output_len = len(LLMGPT.encoding.encode(sample[1]))
            llmresult = eval_chain.generate([{'instruction': instruction, 'input': sample[0], 'output': sample[1]}])
            # With echo=True and max_tokens=0 the API returns logprobs for the
            # prompt tokens; the last output_len tokens cover the gold output
            logprobs = llmresult.generations[0][0].generation_info['logprobs']
            token_probs = logprobs['token_logprobs']
            score += sum(token_probs[-output_len:])
        # TODO: switch to batched requests to work around the rate limit
        return score
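

if __name__ == '__main__':
    # Minimal end-to-end sketch of the instruction-search loop, assuming
    # OPENAI_API_KEY is set in the environment and MyTemplate supplies the
    # default prompts; the demonstration data below is made up
    import os

    llm = LLMGPT(os.environ['OPENAI_API_KEY'])
    few_shot = [('cheese', 'fromage'), ('dog', 'chien')]
    eval_set = [('cat', 'chat')]

    # 1. Propose an instruction from the demonstrations
    candidate = llm.generate_instruction(gen_prompt=None, few_shot=few_shot)
    instruction = candidate['text']

    # 2. Score it by the log-probability of the gold outputs under the eval model
    score = llm.generate_logprobs(eval_prompt=None, instruction=instruction, eval_set=eval_set)
    print(f'{instruction}\nscore: {score:.4f}')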