# NOTE(review): removed paste artifacts ("Spaces:" / "Runtime error" lines)
# that were not valid Python and prevented the module from parsing.
# -*-coding:utf-8 -*-
import tiktoken
from tqdm import tqdm

from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

from ape.prompt import MyTemplate
# Per-1K-token USD pricing for the two OpenAI model families used below.
Cost = dict(davinci=0.02, chatgpt=0.004)
class LLMGPT(object):
    """Thin wrapper around LangChain models for APE-style prompting.

    Holds two models: a chat model that samples candidate instructions
    (``gen_chain``) and a completion model intended for scoring them
    (``eval_chain``).
    """

    def __init__(self, openai_key, n_instruct):
        """
        Args:
            openai_key: OpenAI API key forwarded to both LangChain clients.
            n_instruct: number of candidate instructions to sample per
                generation call (the chat model's ``n`` parameter).
        """
        # Chat model generates candidate instructions; n>1 samples several.
        self.gen_llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000,
                                  temperature=0.7, n=n_instruct)
        # Completion model; echo=True so the returned text includes the
        # prompt itself (needed when reading logprobs over the prompt).
        self.eval_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000,
                               temperature=0.7, echo=True)
        self.gen_chain = None
        self.eval_chain = None
        self.init()

    @staticmethod
    def confirm_cost(text, mode):
        """Estimate the API cost in USD of sending ``text``.

        Fix: the original ``def confirm_cost(text, mode)`` sat inside the
        class with neither ``self`` nor a decorator, so an instance call
        would silently bind the instance to ``text``. It is now an explicit
        ``@staticmethod``; both ``LLMGPT.confirm_cost(...)`` and
        ``instance.confirm_cost(...)`` work as before intended.

        Args:
            text: the prompt text to be priced.
            mode: ``'train'`` selects the davinci rate; anything else
                selects the cheaper rate.

        Returns:
            Estimated price in USD (float).
        """
        if mode == 'train':
            cost = 0.02  # per 1K tokens; matches Cost['davinci']
        else:
            # NOTE(review): 0.0004 disagrees with the module-level
            # Cost['chatgpt'] == 0.004 -- confirm which figure is current.
            cost = 0.0004
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))
        total_price = (num_tokens / 1000) * cost
        return total_price

    def init(self):
        """Build the generation and evaluation chains from MyTemplate."""
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(MyTemplate['gen_sys_prompt']),
                HumanMessagePromptTemplate.from_template(MyTemplate['gen_user_prompt']),
            ]
        )
        self.gen_chain = LLMChain(llm=self.gen_llm, prompt=prompt)
        prompt = PromptTemplate.from_template(MyTemplate['eval_prompt'])
        self.eval_chain = LLMChain(llm=self.eval_llm, prompt=prompt)

    def generate_instruction(self, few_shot):
        """Generate candidate instructions from few-shot demonstrations.

        Args:
            few_shot: iterable of 2-item pairs; ``shot[0]`` and ``shot[1]``
                fill the two slots of ``MyTemplate['few_shot']``.

        Returns:
            The raw result of ``gen_chain.generate`` (LangChain LLMResult).
        """
        prompt = ''
        for shot in few_shot:
            prompt += MyTemplate['few_shot'].format(shot[0], shot[1])
        print(prompt)  # debug: show the assembled few-shot prompt
        result = self.gen_chain.generate(prompt)
        return result

    def generate_logprobs(self):
        """Evaluate an instruction via token log-probabilities.

        Not implemented yet; intended to use ``eval_chain`` (echo=True was
        set on the eval model for this purpose).
        """
        pass