# -*- coding: utf-8 -*-
import tiktoken

from ape.prompt import MyTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
# By default, davinci-003 is used for testing and evaluation (more controllable),
# while ChatGPT is used to generate instructions (cheaper).
# Approximate cost per 1K tokens, keyed by model.
Cost = {
    'davinci': 0.02,
    'chatgpt': 0.004
}

class LLMGPT(object):
    # Tokenizer used for cost estimation and for locating the output tokens
    # inside the echoed logprobs.
    encoding = tiktoken.get_encoding("cl100k_base")

    def __init__(self, openai_key):
        # ChatGPT: generates candidate instructions from few-shot demonstrations.
        self.gen_llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, verbose=True)
        # Scoring model: max_tokens=0 together with echo=True makes the API return
        # logprobs for the prompt tokens themselves without generating any new text.
        self.eval_llm = OpenAI(openai_api_key=openai_key, max_tokens=0, temperature=0.7, echo=True, logprobs=1)
        # Completion model used to produce outputs for test inputs.
        self.test_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, verbose=True)
        self.gen_chain = None
        self.eval_chain = None
    @staticmethod
    def confirm_cost(text, mode):
        """Estimate the API cost of `text` using the per-1K-token prices in Cost."""
        if mode == 'train':
            cost = Cost['davinci']
        else:
            cost = Cost['chatgpt']
        num_tokens = len(LLMGPT.encoding.encode(text))
        total_price = (num_tokens / 1000) * cost
        return total_price
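
    # Usage sketch (hypothetical caller; `prompt_text` and `budget` are assumed
    # variables): estimate the price before dispatching a request so the caller
    # can confirm or abort.
    #
    #   price = LLMGPT.confirm_cost(prompt_text, mode='train')
    #   if price > budget:
    #       raise RuntimeError(f'estimated cost {price:.4f} exceeds budget')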

    def generate_instruction(self, gen_prompt, few_shot):
        """
        Generate candidate instructions from few-shot (input, output) demonstrations.
        """
        if not gen_prompt:
            gen_prompt = MyTemplate['gen_user_prompt']
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(MyTemplate['gen_sys_prompt']),
                HumanMessagePromptTemplate.from_template(gen_prompt),
            ]
        )
        self.gen_chain = LLMChain(llm=self.gen_llm, prompt=prompt)
        # Render every demonstration pair with the few-shot template and
        # concatenate them into a single block for the {few_shot} slot.
        few_shot_str = ''
        for shot in few_shot:
            few_shot_str += MyTemplate['few_shot_prompt'].format(input=shot[0], output=shot[1])
        result = self.gen_chain({'few_shot': few_shot_str})
        return result
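
    # Usage sketch (the key and demonstration pairs are made up for illustration):
    #
    #   llm = LLMGPT(openai_key='sk-...')
    #   demos = [('cat', 'chat'), ('dog', 'chien')]  # (input, output) pairs
    #   candidates = llm.generate_instruction(gen_prompt=None, few_shot=demos)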

    def generate_output(self, test_prompt, instruction, input):
        """Run `instruction` on a single test `input` and return the model output."""
        if not test_prompt:
            test_prompt = MyTemplate['test_prompt']
        prompt = PromptTemplate.from_template(test_prompt)
        test_chain = LLMChain(llm=self.test_llm, prompt=prompt)
        output = test_chain({'input': input, 'instruction': instruction})
        return output
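
    # Usage sketch, continuing the example above (instruction text is hypothetical):
    #
    #   out = llm.generate_output(test_prompt=None,
    #                             instruction='Translate the English word into French',
    #                             input='house')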

    def generate_logprobs(self, eval_prompt, instruction, eval_set):
        """
        Score an instruction: sum the log-probabilities the eval model assigns to
        each reference output, conditioned on the instruction and the input.
        """
        if not eval_prompt:
            eval_prompt = MyTemplate['eval_prompt']
        prompt = PromptTemplate.from_template(eval_prompt)
        eval_chain = LLMChain(llm=self.eval_llm, prompt=prompt)
        score = 0
        for sample in eval_set:
            # eval_llm echoes the prompt with logprobs and generates nothing, so the
            # last output_len token logprobs belong to the reference output.
            output_len = len(LLMGPT.encoding.encode(sample[1]))
            llmresult = eval_chain.generate([{'instruction': instruction, 'input': sample[0], 'output': sample[1]}])
            logprobs = llmresult.generations[0][0].generation_info['logprobs']
            token_probs = logprobs['token_logprobs']
            score += sum(token_probs[-output_len:])
        # TODO: switch to batched requests to avoid hitting the rate limit.
        return score
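
    # Minimal sketch of the batched variant the TODO above asks for, assuming
    # LLMChain.generate may be given the whole list of input dicts at once (the
    # underlying OpenAI client then batches prompts per request). The method name
    # is an assumption, not part of the original module.
    def generate_logprobs_batched(self, eval_prompt, instruction, eval_set):
        if not eval_prompt:
            eval_prompt = MyTemplate['eval_prompt']
        prompt = PromptTemplate.from_template(eval_prompt)
        eval_chain = LLMChain(llm=self.eval_llm, prompt=prompt)
        inputs = [{'instruction': instruction, 'input': s[0], 'output': s[1]} for s in eval_set]
        # One generate() call for the whole eval set instead of one call per sample.
        llmresult = eval_chain.generate(inputs)
        score = 0
        for sample, generation in zip(eval_set, llmresult.generations):
            output_len = len(LLMGPT.encoding.encode(sample[1]))
            token_probs = generation[0].generation_info['logprobs']['token_logprobs']
            score += sum(token_probs[-output_len:])
        return score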