DSXiangLi
a
15a824f
raw
history blame
No virus
2.04 kB
# -*-coding:utf-8 -*-
from tqdm import tqdm
import tiktoken
from ape.prompt import MyTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
# Price table keyed by model family name ('davinci', 'chatgpt').
# NOTE(review): values are presumably USD per 1K tokens, matching the
# ``num_tokens / 1000 * rate`` formula in ``LLMGPT.confirm_cost`` — confirm.
# The table is not referenced by other code in this file as shown.
Cost = dict(
    davinci=0.02,
    chatgpt=0.004,
)
class LLMGPT(object):
    """Hold the LLMs and LangChain chains used to generate and evaluate instructions.

    Two models are created from the same OpenAI key:

    * ``gen_llm``  -- a ``ChatOpenAI`` model asked for ``n_instruct`` completions
      per request; it backs ``gen_chain``.
    * ``eval_llm`` -- an ``OpenAI`` completion model created with ``echo=True``;
      it backs ``eval_chain``.

    Both chains are built by :meth:`init`, which the constructor calls.
    """

    def __init__(self, openai_key, n_instruct):
        """
        :param openai_key: OpenAI API key passed to both models.
        :param n_instruct: number of completions ``gen_llm`` is asked for per prompt (``n``).
        """
        self.gen_llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, n=n_instruct)
        self.eval_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, echo=True)
        self.gen_chain = None
        self.eval_chain = None
        self.init()

    @staticmethod
    def confirm_cost(text, mode):
        """Estimate the price of sending ``text`` to the API.

        The text is tokenized with tiktoken's ``cl100k_base`` encoding, and the
        token count is charged per 1,000 tokens at a fixed rate.

        :param text: prompt text to price.
        :param mode: ``'train'`` selects the 0.02 rate; any other value selects 0.0004.
        :return: estimated price as a float: ``num_tokens / 1000 * rate``.
        """
        # NOTE(review): these rates are hard-coded and disagree with the
        # module-level ``Cost`` table (chatgpt is 0.004 there, 0.0004 here).
        # Confirm which values are current before relying on this estimate.
        rate = 0.02 if mode == 'train' else 0.0004
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))
        return (num_tokens / 1000) * rate

    def init(self):
        """(Re)build ``gen_chain`` and ``eval_chain`` from the templates in ``MyTemplate``.

        ``gen_chain`` uses a chat prompt made of the ``gen_sys_prompt`` system
        message followed by the ``gen_user_prompt`` human message.
        ``eval_chain`` uses the plain ``eval_prompt`` completion template.
        """
        gen_prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(MyTemplate['gen_sys_prompt']),
                HumanMessagePromptTemplate.from_template(MyTemplate['gen_user_prompt']),
            ]
        )
        self.gen_chain = LLMChain(llm=self.gen_llm, prompt=gen_prompt)
        eval_prompt = PromptTemplate.from_template(MyTemplate['eval_prompt'])
        self.eval_chain = LLMChain(llm=self.eval_llm, prompt=eval_prompt)

    def generate_instruction(self, few_shot):
        """Ask ``gen_llm`` to propose instructions from few-shot examples.

        Each example is rendered with ``MyTemplate['few_shot']``, using its first
        two items as the two positional format arguments. The rendered examples
        are concatenated into one string. That string fills the generation
        prompt's single input variable.

        :param few_shot: iterable of examples; each must support ``shot[0]`` and
            ``shot[1]`` (presumably ``(input, output)`` pairs — confirm with callers).
        :return: the ``LLMResult`` returned by ``LLMChain.generate`` for the single
            prompt. Because ``gen_llm`` was created with ``n=n_instruct``, its
            ``generations[0]`` is expected to hold the candidate instructions.
        :raises ValueError: if the generation prompt does not have exactly one
            input variable, so it is ambiguous where the few-shot text belongs.
        """
        prompt = ''.join(
            MyTemplate['few_shot'].format(shot[0], shot[1]) for shot in few_shot
        )
        print(prompt)
        # ``LLMChain.generate`` takes a *list of input dicts*, one per prompt.
        # Passing the bare string made it iterate characters and index each
        # one by the template variable name, which fails. Key the few-shot
        # text by the prompt's single input variable instead.
        input_keys = self.gen_chain.input_keys
        if len(input_keys) != 1:
            raise ValueError(
                'gen prompt must have exactly one input variable, got {}'.format(input_keys)
            )
        result = self.gen_chain.generate([{input_keys[0]: prompt}])
        return result

    def generate_logprobs(self):
        """Eval instruction.

        Not implemented yet: currently a no-op that returns ``None``.
        """
        # TODO: presumably intended to score instructions via ``eval_chain``
        # (built on the ``echo=True`` model); implement or remove.
        pass