File size: 3,246 Bytes
019ee78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29be613
 
e907f96
 
25511b9
faf58a4
019ee78
 
 
 
a8e5b87
019ee78
 
 
 
29be613
 
019ee78
 
 
1bfaaaf
 
 
 
 
 
019ee78
 
 
1bfaaaf
019ee78
 
a8e5b87
019ee78
 
 
d537b9c
698a3ea
019ee78
 
1bfaaaf
 
faf58a4
 
34e7ce5
 
c9023c8
 
1bfaaaf
019ee78
 
 
1bfaaaf
 
 
34e7ce5
29be613
1bfaaaf
29be613
 
110a394
 
29be613
25511b9
29be613
 
110a394
29be613
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# -*-coding:utf-8 -*-
from tqdm import tqdm
import tiktoken
from ape.prompt import MyTemplate
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain

# USD price per 1K tokens for the two OpenAI model families used below.
Cost = dict(
    davinci=0.02,
    chatgpt=0.004,
)


class LLMGPT(object):
    """Thin wrapper around LangChain LLM clients for APE-style work:
    generating candidate instructions, producing test outputs, and scoring
    instructions by summed token log-probabilities.
    """

    # Shared tokenizer; used for cost estimation and output-length reporting.
    encoding = tiktoken.get_encoding("cl100k_base")

    def __init__(self, openai_key):
        """
        Build the three LLM clients used by the pipeline.

        :param openai_key: OpenAI API key passed to every client.
        """
        # Chat model that generates candidate instructions.
        self.gen_llm = ChatOpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, verbose=True)
        # Completion model with logprobs enabled, used for scoring.
        self.eval_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7, logprobs=1, verbose=True)
        # Completion model that produces outputs on the test set.
        self.test_llm = OpenAI(openai_api_key=openai_key, max_tokens=2000, temperature=0.7,  verbose=True)
        self.gen_chain = None
        self.eval_chain = None

    @staticmethod
    def confirm_cost(text, mode):
        """
        Estimate the USD cost of processing ``text``.

        Fix: the non-train price was hard-coded as 0.0004, disagreeing with
        the module-level ``Cost`` table (0.004); the table was dead code.
        Both branches now read from ``Cost`` so pricing lives in one place.

        :param text: text whose token count determines the cost
        :param mode: 'train' uses the davinci price; anything else chatgpt
        :return: estimated price in dollars
        """
        cost = Cost['davinci'] if mode == 'train' else Cost['chatgpt']
        num_tokens = len(LLMGPT.encoding.encode(text))
        return (num_tokens / 1000) * cost

    def generate_instruction(self, gen_prompt, few_shot):
        """
        Generate an instruction from few-shot input/output pairs.

        :param gen_prompt: human prompt template; falls back to
            ``MyTemplate['gen_user_prompt']`` when falsy
        :param few_shot: iterable of (input, output) pairs rendered into
            the few-shot section of the prompt
        :return: raw chain result dict from ``self.gen_chain``
        """
        if not gen_prompt:
            gen_prompt = MyTemplate['gen_user_prompt']
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(MyTemplate['gen_sys_prompt']),
                HumanMessagePromptTemplate.from_template(gen_prompt),
            ]
        )
        self.gen_chain = LLMChain(llm=self.gen_llm, prompt=prompt)

        # Render each (input, output) pair into the few-shot block.
        few_shot_text = ''
        for shot in few_shot:
            few_shot_text += MyTemplate['few_shot_prompt'].format(input=shot[0], output=shot[1])
        result = self.gen_chain({'few_shot': few_shot_text})
        return result

    def generate_output(self, test_prompt, instruction, input):
        """
        Run the test model on one sample under the given instruction.

        :param test_prompt: prompt template; falls back to
            ``MyTemplate['test_prompt']`` when falsy
        :param instruction: instruction text under evaluation
        :param input: sample input text (name kept for caller compatibility
            even though it shadows the builtin)
        :return: raw chain result dict from the test chain
        """
        if not test_prompt:
            test_prompt = MyTemplate['test_prompt']
        prompt = PromptTemplate.from_template(test_prompt)
        test_chain = LLMChain(llm=self.test_llm, prompt=prompt)
        output = test_chain({'input': input, 'instruction': instruction})
        return output

    def generate_logprobs(self, eval_prompt, instruction, eval_set):
        """
        Score an instruction by summed token log-probabilities over a
        set of (input, output) evaluation samples.

        :param eval_prompt: prompt template; falls back to
            ``MyTemplate['eval_prompt']`` when falsy
        :param instruction: instruction text being scored
        :param eval_set: iterable of (input, output) pairs
        :return: sum of token log-probs across all samples (higher is better)
        """
        if not eval_prompt:
            eval_prompt = MyTemplate['eval_prompt']
        prompt = PromptTemplate.from_template(eval_prompt)
        eval_chain = LLMChain(llm=self.eval_llm, prompt=prompt)
        score = 0
        for sample in eval_set:
            output_len = len(LLMGPT.encoding.encode(sample[1]))
            print(output_len)  # debug: expected-output token count
            llmresult = eval_chain.generate([{'instruction': instruction, 'input': sample[0], 'output': sample[1]}])
            print(llmresult)
            # NOTE(review): assumes the completion backend returns
            # generation_info['logprobs'] (requires logprobs=1 on eval_llm).
            logprobs = llmresult.generations[0][0].generation_info['logprobs']
            print(logprobs)
            tokens = logprobs['tokens']
            token_probs = logprobs['token_logprobs']
            print(tokens)
            print(token_probs)
            score += sum(token_probs)
        return score