# -*- coding: utf-8 -*-
from tqdm import tqdm
import tiktoken
from ape.prompt import MyTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain

# Approximate USD price per 1K tokens for each model family
Cost = {
    'davinci': 0.02,
    'chatgpt': 0.004
}


class LLMGPT(object):
    def __init__(self, openai_key, n_instruct):
        # Chat model used to generate n_instruct candidate instructions per call
        self.gen_llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, n=n_instruct)
        # Completion model used to score instructions; echo=True returns the prompt along with the completion
        self.eval_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, echo=True)
        self.gen_chain = None
        self.eval_chain = None
        self.init()

    @staticmethod
    def confirm_cost(text, mode):
        # Estimate the API cost of `text` in USD: davinci pricing for 'train',
        # chatgpt pricing otherwise (prices per 1K tokens, see Cost above)
        cost = Cost['davinci'] if mode == 'train' else Cost['chatgpt']
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))
        total_price = (num_tokens / 1000) * cost
        return total_price

    def init(self):
        # Chain for generating candidate instructions from few-shot demonstrations
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(MyTemplate['gen_sys_prompt']),
                HumanMessagePromptTemplate.from_template(MyTemplate['gen_user_prompt']),
            ]
        )
        self.gen_chain = LLMChain(llm=self.gen_llm, prompt=prompt)

        # Chain for scoring (evaluating) candidate instructions
        prompt = PromptTemplate.from_template(MyTemplate['eval_prompt'])
        self.eval_chain = LLMChain(llm=self.eval_llm, prompt=prompt)

    def generate_instruction(self, few_shot):
        """
        Generate candidate instructions from few-shot (input, output) demonstrations.
        """
        prompt = ''
        for shot in few_shot:
            prompt += MyTemplate['few_shot'].format(shot[0], shot[1])
        print(prompt)
        # LLMChain.generate expects a list of input dicts; the 'few_shot' key is an
        # assumption about the variable name used in MyTemplate['gen_user_prompt']
        result = self.gen_chain.generate([{'few_shot': prompt}])
        return result

    def generate_logprobs(self):
        """
        Evaluate an instruction by scoring it with the eval chain (not implemented yet).
        """
        pass
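

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how LLMGPT might be driven, assuming a valid OpenAI key in the
# environment and that MyTemplate contains the templates referenced above. The
# few-shot pairs below are hypothetical placeholders.
if __name__ == '__main__':
    import os

    demo_pairs = [
        ('The movie was wonderful', 'positive'),   # hypothetical (input, output) pair
        ('I would not recommend it', 'negative'),
    ]
    llm = LLMGPT(openai_key=os.environ.get('OPENAI_API_KEY'), n_instruct=3)

    # Estimate cost before spending tokens; 'train' uses davinci pricing
    sample_text = ''.join(MyTemplate['few_shot'].format(i, o) for i, o in demo_pairs)
    print('estimated cost ($):', LLMGPT.confirm_cost(sample_text, mode='train'))

    # instructions = llm.generate_instruction(demo_pairs)  # uncomment to call the API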