# -*- coding: utf-8 -*-
import tiktoken
from ape.prompt import MyTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain

# By default, davinci-003 is used for testing and evaluation (more controllable),
# while ChatGPT is used to generate instructions (cheaper).
# Approximate price per 1K tokens, in USD.
Cost = {
    'davinci': 0.02,
    'chatgpt': 0.004
}


class LLMGPT(object):
    # cl100k_base matches the ChatGPT family; davinci completion models use
    # p50k_base, so keep a separate encoding for slicing eval logprobs
    # (assumes the default text-davinci-003 eval model).
    encoding = tiktoken.get_encoding("cl100k_base")
    eval_encoding = tiktoken.get_encoding("p50k_base")

    def __init__(self, openai_key):
        # Chat model for generating candidate instructions.
        self.gen_llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, verbose=True)
        # Scoring model: max_tokens=0 with echo=True and logprobs=1 returns the
        # log-probabilities of the prompt tokens without generating new text.
        self.eval_llm = OpenAI(openai_api_key=openai_key, max_tokens=0, temperature=0.7, echo=True, logprobs=1)
        # Completion model for running candidate instructions on test inputs.
        self.test_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, verbose=True)
        self.gen_chain = None
        self.eval_chain = None

    @staticmethod
    def confirm_cost(text, mode):
        """Return a rough USD cost estimate for processing `text` in `mode`."""
        # Pull the per-1K-token rate from the Cost table above rather than
        # repeating magic numbers ('train' prices against davinci).
        if mode == 'train':
            cost = Cost['davinci']
        else:
            cost = Cost['chatgpt']

        num_tokens = len(LLMGPT.encoding.encode(text))
        total_price = (num_tokens / 1000) * cost
        return total_price

    def generate_instruction(self, gen_prompt, few_shot):
        """
        Generate a candidate instruction from few-shot (input, output) pairs.
        """
        if not gen_prompt:
            gen_prompt = MyTemplate['gen_user_prompt']
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(MyTemplate['gen_sys_prompt']),
                HumanMessagePromptTemplate.from_template(gen_prompt),
            ]
        )
        self.gen_chain = LLMChain(llm=self.gen_llm, prompt=prompt)

        # Render the few-shot demonstrations into a single string for the chain.
        few_shot_text = ''
        for shot in few_shot:
            few_shot_text += MyTemplate['few_shot_prompt'].format(input=shot[0], output=shot[1])
        result = self.gen_chain({'few_shot': few_shot_text})
        return result

    def generate_output(self, test_prompt, instruction, input):
        """Run `instruction` against a single test `input` and return the model output."""
        if not test_prompt:
            test_prompt = MyTemplate['test_prompt']
        prompt = PromptTemplate.from_template(test_prompt)
        test_chain = LLMChain(llm=self.test_llm, prompt=prompt)
        output = test_chain({'input': input, 'instruction': instruction})
        return output

    def generate_logprobs(self, eval_prompt, instruction, eval_set):
        """
        Score an instruction: sum the log-probabilities that the eval model
        assigns to each gold output, conditioned on the instruction and input.
        """
        if not eval_prompt:
            eval_prompt = MyTemplate['eval_prompt']
        prompt = PromptTemplate.from_template(eval_prompt)
        eval_chain = LLMChain(llm=self.eval_llm, prompt=prompt)
        score = 0
        for sample in eval_set:
            # With echo=True and max_tokens=0 the model returns logprobs for the
            # prompt itself; the last output_len tokens are the gold output.
            output_len = len(LLMGPT.eval_encoding.encode(sample[1]))
            llmresult = eval_chain.generate([{'instruction': instruction, 'input': sample[0], 'output': sample[1]}])
            logprobs = llmresult.generations[0][0].generation_info['logprobs']
            token_probs = logprobs['token_logprobs']
            score += sum(token_probs[-output_len:])
        # TODO: switch to batched requests to work around rate limits.
        return score
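

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of how this class might be driven end to end. The API key,
# the demo pairs, and the eval set below are hypothetical placeholders; the
# generation/evaluation templates come from ape.prompt.MyTemplate as above.
if __name__ == '__main__':
    llm = LLMGPT(openai_key='YOUR_OPENAI_KEY')  # hypothetical key
    demos = [('hot', 'cold'), ('tall', 'short')]  # (input, output) pairs

    # Rough cost check before spending tokens.
    print(LLMGPT.confirm_cost('antonym demo text', mode='train'))

    # Generate a candidate instruction from the demos, then score it on the
    # same pairs by summed output log-probability (higher is better).
    result = llm.generate_instruction(gen_prompt=None, few_shot=demos)
    instruction = result['text']
    print('candidate instruction:', instruction)
    print('logprob score:', llm.generate_logprobs(None, instruction, demos))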