File size: 4,894 Bytes
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3332aa4
a1ca2de
 
 
3332aa4
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3332aa4
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
3332aa4
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3332aa4
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
936d161
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
936d161
 
 
 
 
 
 
 
 
 
a1ca2de
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import warnings
from pathlib import Path

import google.generativeai as palm_api
import toml

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

from .utils import set_palm_api_key


# Import-time setup: install the PaLM API key via the project helper first,
# then load the prompt templates.
set_palm_api_key()

# Prompt templates live in ./assets/palm_prompts.toml. The path is resolved
# against the current working directory, not this module's directory.
palm_prompts = toml.load(Path('.').joinpath('assets', 'palm_prompts.toml'))

class PaLMChatPromptFmt(PromptFmt):
    """Render one ping/pong exchange as PaLM chat message dicts.

    Deprecated: every method emits a ``DeprecationWarning``.
    """

    @classmethod
    def ctx(cls, context):
        """Emit the deprecation warning; ``context`` is ignored and ``None`` is returned."""
        warnings.warn("The 'palmchat' is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return the chat messages for a single exchange.

        Args:
            pingpong: object exposing ``ping`` (the user text) and ``pong``
                (the reply text, possibly ``None``).
            truncate_size: slice bound applied to each text; ``None`` keeps
                the full text.

        Returns:
            ``[{"author": "USER", ...}]``, followed by an
            ``{"author": "AI", ...}`` entry only when ``pong`` is not ``None``
            and not blank after stripping whitespace.
        """
        warnings.warn("The 'palmchat' is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
        messages = [{"author": "USER", "content": pingpong.ping[:truncate_size]}]

        reply = pingpong.pong
        # Blank or missing replies are omitted so the history ends on the user turn.
        if reply is not None and reply.strip() != "":
            messages.append({"author": "AI", "content": reply[:truncate_size]})

        return messages

class PaLMChatPPManager(PPManager):
    """Conversation manager whose history is rendered as PaLM chat messages."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=PaLMChatPromptFmt, truncate_size: int=None):
        """Flatten ``self.pingpongs[from_idx:to_idx]`` into one list of messages.

        Args:
            from_idx: index of the first exchange to include.
            to_idx: exclusive end of the slice. ``-1``, or any value at or past
                ``len(self.pingpongs)``, means "through the last exchange";
                other values are used as ordinary slice bounds.
            fmt: formatter whose ``prompt(pingpong, truncate_size=...)``
                returns the messages for one exchange.
            truncate_size: forwarded unchanged to ``fmt.prompt``.

        Returns:
            A flat list holding every selected exchange's messages, in order.
        """
        warnings.warn("The 'palmchat' is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
        history_len = len(self.pingpongs)
        stop = history_len if (to_idx == -1 or to_idx >= history_len) else to_idx

        return [
            message
            for exchange in self.pingpongs[from_idx:stop]
            for message in fmt.prompt(exchange, truncate_size=truncate_size)
        ]

class GradioPaLMChatPPManager(PaLMChatPPManager):
    """PaLM chat manager that can also render its history for a Gradio chat UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(pingpong)`` for each exchange in ``self.pingpongs[from_idx:to_idx]``.

        ``to_idx`` of ``-1``, or any value at or past ``len(self.pingpongs)``,
        selects through the last exchange; other values are ordinary slice
        bounds.
        """
        warnings.warn("The 'palmchat' is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
        history_len = len(self.pingpongs)
        if to_idx != -1 and to_idx < history_len:
            stop = to_idx
        else:
            stop = history_len

        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:stop]]

def _default_gen_parameters(mode, use_filter):
    """Build the default keyword arguments for a PaLM request in ``mode``.

    ``mode == "chat"`` targets ``models/chat-bison-001`` (with an empty
    ``context``); any other mode targets ``models/text-bison-001`` with
    ``max_output_tokens=1024``. When ``use_filter`` is false, every safety
    threshold is set to 4 instead of its default. NOTE(review): 4 is presumably
    the PaLM "block none" level; confirm against the HarmBlockThreshold enum.
    """
    temperature = 1.0
    top_k = 40
    top_p = 0.95
    max_output_tokens = 1024

    # (category, default threshold) pairs used when filtering is enabled.
    default_thresholds = [
        ("HARM_CATEGORY_DEROGATORY", 1),
        ("HARM_CATEGORY_TOXICITY", 1),
        ("HARM_CATEGORY_VIOLENCE", 2),
        ("HARM_CATEGORY_SEXUAL", 2),
        ("HARM_CATEGORY_MEDICAL", 2),
        ("HARM_CATEGORY_DANGEROUS", 2),
    ]
    safety_settings = [
        {"category": category, "threshold": threshold if use_filter else 4}
        for category, threshold in default_thresholds
    ]

    if mode == "chat":
        # The chat request carries 'context' and no 'max_output_tokens'.
        return {
            'model': 'models/chat-bison-001',
            'candidate_count': 1,
            'context': "",
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'safety_settings': safety_settings,
        }

    return {
        'model': 'models/text-bison-001',
        'candidate_count': 1,
        'temperature': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'max_output_tokens': max_output_tokens,
        'safety_settings': safety_settings,
    }


async def gen_text(
    prompt,
    mode="chat",  # "chat", or any other value for plain text generation
    parameters=None,
    use_filter=True
):
    """Send ``prompt`` to PaLM and return the raw response plus its text.

    Args:
        prompt: passed as ``messages=`` in chat mode, otherwise as ``prompt=``.
        mode: ``"chat"`` awaits ``palm_api.chat_async``; any other value calls
            ``palm_api.generate_text``.
        parameters: keyword arguments forwarded to the API call. When ``None``,
            defaults from ``_default_gen_parameters(mode, use_filter)`` are used.
        use_filter: when true, a response with any entries in ``filters`` is
            rejected. It also selects the default safety thresholds.

    Returns:
        ``(response, response_txt)``. ``response_txt`` is ``response.last`` in
        chat mode and ``response.result`` otherwise.

    Raises:
        EnvironmentError: the API call raised. The underlying exception is
            chained as ``__cause__``.
        Exception: ``use_filter`` is true and the response was filtered.
    """
    warnings.warn("The 'palmchat' is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
    if parameters is None:
        parameters = _default_gen_parameters(mode, use_filter)

    try:
        if mode == "chat":
            response = await palm_api.chat_async(**parameters, messages=prompt)
        else:
            # NOTE(review): called without ``await``, so this request runs
            # synchronously and blocks the event loop until it returns.
            response = palm_api.generate_text(**parameters, prompt=prompt)
    except Exception as err:
        # Catch ``Exception`` rather than everything, so KeyboardInterrupt,
        # SystemExit and (Python 3.8+) asyncio.CancelledError still propagate.
        # Chain the cause so the real API error is not lost.
        raise EnvironmentError("PaLM API is not available.") from err

    if use_filter and len(response.filters) > 0:
        raise Exception("PaLM API has withheld a response due to content safety concerns.")

    response_txt = response.last if mode == "chat" else response.result
    return response, response_txt