# zero2story/modules/llms/palm_service.py
import os
import copy
import threading
from pathlib import Path

import toml
import google.generativeai as palm_api

from pingpong import PingPong
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

# PromptFmt and PPManager are imported from modules.llms (the project's own
# abstractions), not from pingpong, to avoid shadowed duplicate imports.
from modules.llms import (
    LLMFactory,
    PromptFmt, PromptManager, PPManager, UIPPManager, LLMService
)

class PaLMFactory(LLMFactory):
    """Factory that wires the PaLM backend into the shared LLM abstractions."""
    _palm_api_key = None

    def __init__(self, palm_api_key=None):
        # Prefer an explicitly passed key; otherwise fall back to the cached
        # key, the PALM_API_KEY environment variable, or .palm_api_key.txt.
        if palm_api_key:
            PaLMFactory.load_palm_api_key(palm_api_key)
        elif not PaLMFactory._palm_api_key:
            PaLMFactory.load_palm_api_key()
        assert PaLMFactory._palm_api_key, "PaLM API Key is missing."
        palm_api.configure(api_key=PaLMFactory._palm_api_key)
def create_prompt_format(self):
return PaLMChatPromptFmt()
def create_prompt_manager(self, prompts_path: str=None):
return PaLMPromptManager((prompts_path or Path('.') / 'prompts' / 'palm_prompts.toml'))
def create_pp_manager(self):
return PaLMChatPPManager()
def create_ui_pp_manager(self):
return GradioPaLMChatPPManager()
def create_llm_service(self):
return PaLMService()
    @classmethod
    def load_palm_api_key(cls, palm_api_key: str=None):
        if not palm_api_key:
            palm_api_key = os.getenv("PALM_API_KEY")
            if palm_api_key is None:
                # Fall back to a local key file; a missing file is treated
                # the same as a missing key.
                try:
                    with open('.palm_api_key.txt', 'r') as file:
                        palm_api_key = file.read().strip()
                except FileNotFoundError:
                    palm_api_key = None
        if not palm_api_key:
            raise ValueError("PaLM API Key is missing.")
        cls._palm_api_key = palm_api_key
@property
def palm_api_key(self):
return PaLMFactory._palm_api_key
@palm_api_key.setter
def palm_api_key(self, palm_api_key: str):
assert palm_api_key, "PaLM API Key is missing."
PaLMFactory._palm_api_key = palm_api_key
palm_api.configure(api_key=PaLMFactory._palm_api_key)
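
# Example usage (a minimal sketch; the key string is a placeholder, and the
# factory can also pick the key up from PALM_API_KEY or .palm_api_key.txt):
#
#   factory = PaLMFactory(palm_api_key="YOUR_PALM_API_KEY")
#   service = factory.create_llm_service()
#   pp_manager = factory.create_pp_manager()
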
class PaLMChatPromptFmt(PromptFmt):
    """Formats a single PingPong exchange into the message dicts the PaLM chat API expects."""
    @classmethod
    def ctx(cls, context):
        # PaLM chat takes the context as a separate top-level argument rather
        # than as part of the message list, so nothing is emitted here.
        pass

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        ping = pingpong.ping[:truncate_size]
        pong = pingpong.pong

        # No pong yet means the exchange is still awaiting a reply;
        # emit only the user turn.
        if pong is None or pong.strip() == "":
return [
{
"author": "USER",
"content": ping
},
]
else:
pong = pong[:truncate_size]
return [
{
"author": "USER",
"content": ping
},
{
"author": "AI",
"content": pong
},
]
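
# A completed exchange formats to the two-message list PaLM chat expects
# (values here are illustrative):
#
#   PaLMChatPromptFmt.prompt(PingPong("Hi", "Hello!"), truncate_size=None)
#   # -> [{"author": "USER", "content": "Hi"},
#   #     {"author": "AI", "content": "Hello!"}]
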
class PaLMPromptManager(PromptManager):
    """Thread-safe singleton that loads and caches the PaLM prompt templates from TOML."""
    _instance = None
    _lock = threading.Lock()
    _prompts = None

    def __new__(cls, prompts_path):
        # Double-checked locking: only the first construction pays for the
        # lock and the initial prompt load; later calls reuse the instance.
        if cls._instance is None:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
                    cls._instance.load_prompts(prompts_path)
        return cls._instance
def load_prompts(self, prompts_path):
self._prompts_path = prompts_path
self.reload_prompts()
def reload_prompts(self):
assert self.prompts_path, "Prompt path is missing."
self._prompts = toml.load(self.prompts_path)
@property
def prompts_path(self):
return self._prompts_path
@prompts_path.setter
def prompts_path(self, prompts_path):
self._prompts_path = prompts_path
self.reload_prompts()
    @property
    def prompts(self):
        if self._prompts is None:
            # load_prompts() requires a path argument, so re-read the TOML
            # from the stored path instead.
            self.reload_prompts()
        return self._prompts
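
# Because PaLMPromptManager is a singleton, repeated construction returns the
# same instance and the TOML file is parsed only once. For example:
#
#   pm1 = PaLMPromptManager('prompts/palm_prompts.toml')
#   pm2 = PaLMPromptManager('prompts/palm_prompts.toml')
#   assert pm1 is pm2
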
class PaLMChatPPManager(PPManager):
    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=None, truncate_size: int=None):
        if fmt is None:
            factory = PaLMFactory()
            fmt = factory.create_prompt_format()

        # -1 (or any overshoot) means "through the last exchange".
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)
        return results
class GradioPaLMChatPPManager(UIPPManager, PaLMChatPPManager):
    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        # Same clamping rule as build_prompts(): -1 means "through the last exchange".
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)
results = []
for pingpong in self.pingpongs[from_idx:to_idx]:
results.append(fmt.ui(pingpong))
return results
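
# build_uis() yields one fmt.ui(pingpong) entry per exchange; with the default
# GradioChatUIFmt these are the (user, assistant) pairs that Gradio's Chatbot
# component consumes.
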
class PaLMService(LLMService):
    def __init__(self):
        # Threshold values follow the PaLM API's HarmBlockThreshold enum:
        # lower values block more aggressively, 4 (BLOCK_NONE) disables a category.
        self._default_parameters_text = {
            'model': 'models/text-bison-001',
            'temperature': 0.7,
            'candidate_count': 1,
            'top_k': 40,
            'top_p': 0.95,
            'max_output_tokens': 1024,
            'stop_sequences': [],
            'safety_settings': [
                {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 1},
                {"category": "HARM_CATEGORY_TOXICITY", "threshold": 1},
                {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 2},
                {"category": "HARM_CATEGORY_SEXUAL", "threshold": 2},
                {"category": "HARM_CATEGORY_MEDICAL", "threshold": 2},
                {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 2},
            ],
        }
        self._default_parameters_chat = {
            'model': 'models/chat-bison-001',
            'temperature': 0.25,
            'candidate_count': 1,
            'top_k': 40,
            'top_p': 0.95,
        }
    def make_params(self, mode="chat",
                    temperature=None,
                    candidate_count=None,
                    top_k=None,
                    top_p=None,
                    max_output_tokens=None,
                    use_filter=True):
        # Deep-copy the defaults so mutating safety_settings below cannot
        # leak back into the shared default dictionaries.
        if mode == "chat":
            parameters = copy.deepcopy(self._default_parameters_chat)
        elif mode == "text":
            parameters = copy.deepcopy(self._default_parameters_text)
        else:
            raise ValueError(f"Unknown mode: {mode!r} (expected 'chat' or 'text').")

        if temperature is not None:
            parameters['temperature'] = temperature
        if candidate_count is not None:
            parameters['candidate_count'] = candidate_count
        if top_k is not None:
            parameters['top_k'] = top_k
        if top_p is not None:
            parameters['top_p'] = top_p
        # Only the text defaults carry max_output_tokens and safety_settings.
        if max_output_tokens is not None and mode == "text":
            parameters['max_output_tokens'] = max_output_tokens
        if not use_filter and mode == "text":
            # Threshold 4 (BLOCK_NONE) disables filtering for every category.
            for setting in parameters['safety_settings']:
                setting['threshold'] = 4
        return parameters
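
    # For example (argument values are illustrative):
    #
    #   params = service.make_params(mode="text", temperature=0.3,
    #                                max_output_tokens=256, use_filter=False)
    #
    # returns a deep copy of the text defaults with temperature 0.3, a
    # 256-token cap, and every safety category relaxed to threshold 4.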
    async def gen_text(
        self,
        prompt,
        mode="chat",  # "chat" or "text"
        parameters=None,
        use_filter=True
    ):
        if parameters is None:
            temperature = 1.0
            top_k = 40
            top_p = 0.95
            max_output_tokens = 1024

            # Default safety settings; threshold 4 (BLOCK_NONE) disables a category.
            safety_settings = [
                {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 1},
                {"category": "HARM_CATEGORY_TOXICITY", "threshold": 1},
                {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 2},
                {"category": "HARM_CATEGORY_SEXUAL", "threshold": 2},
                {"category": "HARM_CATEGORY_MEDICAL", "threshold": 2},
                {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 2},
            ]
            if not use_filter:
                for setting in safety_settings:
                    setting['threshold'] = 4

            if mode == "chat":
                parameters = {
                    'model': 'models/chat-bison-001',
                    'candidate_count': 1,
                    'context': "",
                    'temperature': temperature,
                    'top_k': top_k,
                    'top_p': top_p,
                    'safety_settings': safety_settings,
                }
            else:
                parameters = {
                    'model': 'models/text-bison-001',
                    'candidate_count': 1,
                    'temperature': temperature,
                    'top_k': top_k,
                    'top_p': top_p,
                    'max_output_tokens': max_output_tokens,
                    'safety_settings': safety_settings,
                }

        try:
            if mode == "chat":
                response = await palm_api.chat_async(**parameters, messages=prompt)
            else:
                response = palm_api.generate_text(**parameters, prompt=prompt)
        except Exception as e:
            # Chain the original error instead of swallowing it with a bare except.
            raise EnvironmentError("PaLM API is not available.") from e
        if use_filter and len(response.filters) > 0:
            raise Exception("PaLM API has withheld a response due to content safety concerns.")

        # Chat responses expose the reply as .last; text responses as .result.
        response_txt = response.last if mode == "chat" else response.result
        return response, response_txt
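

# A minimal usage sketch (assumes a valid key is available via PALM_API_KEY
# or .palm_api_key.txt); the prompt text is a placeholder.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        PaLMFactory()  # configures the PaLM client with the loaded key
        service = PaLMService()
        params = service.make_params(mode="text", temperature=0.5)
        _, text = await service.gen_text("Say hello.", mode="text", parameters=params)
        print(text)

    asyncio.run(_demo())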