import os
import threading
import copy
import toml
from pathlib import Path

import google.generativeai as palm_api

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

from modules.llms import (
    LLMFactory,
    PromptFmt,
    PromptManager,
    PPManager,
    UIPPManager,
    LLMService
)


class PaLMFactory(LLMFactory):
    _palm_api_key = None

    def __init__(self, palm_api_key=None):
        if not PaLMFactory._palm_api_key:
            PaLMFactory.load_palm_api_key(palm_api_key)

        assert PaLMFactory._palm_api_key, "PaLM API Key is missing."
        palm_api.configure(api_key=PaLMFactory._palm_api_key)

    def create_prompt_format(self):
        return PaLMChatPromptFmt()

    def create_prompt_manager(self, prompts_path: str=None):
        return PaLMPromptManager(
            prompts_path or Path('.') / 'prompts' / 'palm_prompts.toml'
        )

    def create_pp_manager(self):
        return PaLMChatPPManager()

    def create_ui_pp_manager(self):
        return GradioPaLMChatPPManager()

    def create_llm_service(self):
        return PaLMService()

    @classmethod
    def load_palm_api_key(cls, palm_api_key: str=None):
        if palm_api_key:
            cls._palm_api_key = palm_api_key
        else:
            # Fall back to the environment variable, then to a local key file.
            palm_api_key = os.getenv("PALM_API_KEY")
            if palm_api_key is None:
                with open('.palm_api_key.txt', 'r') as file:
                    palm_api_key = file.read().strip()

            if not palm_api_key:
                raise ValueError("PaLM API Key is missing.")
            cls._palm_api_key = palm_api_key

    @property
    def palm_api_key(self):
        return PaLMFactory._palm_api_key

    @palm_api_key.setter
    def palm_api_key(self, palm_api_key: str):
        assert palm_api_key, "PaLM API Key is missing."
        PaLMFactory._palm_api_key = palm_api_key
        palm_api.configure(api_key=PaLMFactory._palm_api_key)


class PaLMChatPromptFmt(PromptFmt):
    @classmethod
    def ctx(cls, context):
        pass

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        ping = pingpong.ping[:truncate_size]
        pong = pingpong.pong

        # A pingpong without a response yet contributes only the USER message.
        if pong is None or pong.strip() == "":
            return [
                {
                    "author": "USER",
                    "content": ping
                },
            ]
        else:
            pong = pong[:truncate_size]
            return [
                {
                    "author": "USER",
                    "content": ping
                },
                {
                    "author": "AI",
                    "content": pong
                },
            ]


class PaLMPromptManager(PromptManager):
    # Thread-safe singleton: prompts are loaded once and shared across callers.
    _instance = None
    _lock = threading.Lock()
    _prompts = None

    def __new__(cls, prompts_path):
        if cls._instance is None:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super(PaLMPromptManager, cls).__new__(cls)
                    cls._instance.load_prompts(prompts_path)

        return cls._instance

    def load_prompts(self, prompts_path):
        self._prompts_path = prompts_path
        self.reload_prompts()

    def reload_prompts(self):
        assert self.prompts_path, "Prompt path is missing."
        self._prompts = toml.load(self.prompts_path)

    @property
    def prompts_path(self):
        return self._prompts_path

    @prompts_path.setter
    def prompts_path(self, prompts_path):
        self._prompts_path = prompts_path
        self.reload_prompts()

    @property
    def prompts(self):
        if self._prompts is None:
            self.reload_prompts()
        return self._prompts


class PaLMChatPPManager(PPManager):
    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=None, truncate_size: int=None):
        if fmt is None:
            factory = PaLMFactory()
            fmt = factory.create_prompt_format()

        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results


class GradioPaLMChatPPManager(UIPPManager, PaLMChatPPManager):
    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results.append(fmt.ui(pingpong))

        return results


class PaLMService(LLMService):
    def __init__(self):
        self._default_parameters_text = {
            'model': 'models/text-bison-001',
            'temperature': 0.7,
            'candidate_count': 1,
            'top_k': 40,
            'top_p': 0.95,
            'max_output_tokens': 1024,
            'stop_sequences': [],
            'safety_settings': [
                {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 1},
                {"category": "HARM_CATEGORY_TOXICITY", "threshold": 1},
                {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 2},
                {"category": "HARM_CATEGORY_SEXUAL", "threshold": 2},
                {"category": "HARM_CATEGORY_MEDICAL", "threshold": 2},
                {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 2},
            ],
        }

        self._default_parameters_chat = {
            'model': 'models/chat-bison-001',
            'temperature': 0.25,
            'candidate_count': 1,
            'top_k': 40,
            'top_p': 0.95,
        }

    def make_params(
        self,
        mode="chat",
        temperature=None,
        candidate_count=None,
        top_k=None,
        top_p=None,
        max_output_tokens=None,
        use_filter=True
    ):
        parameters = None

        # Deep-copy so overrides never mutate the shared defaults.
        if mode == "chat":
            parameters = copy.deepcopy(self._default_parameters_chat)
        elif mode == "text":
            parameters = copy.deepcopy(self._default_parameters_text)

        if temperature is not None:
            parameters['temperature'] = temperature
        if candidate_count is not None:
            parameters['candidate_count'] = candidate_count
        if top_k is not None:
            parameters['top_k'] = top_k
        if top_p is not None:
            parameters['top_p'] = top_p
        if max_output_tokens is not None and mode == "text":
            parameters['max_output_tokens'] = max_output_tokens

        # Threshold 4 (BLOCK_NONE) effectively disables each category filter.
        if not use_filter and mode == "text":
            for idx, _ in enumerate(parameters['safety_settings']):
                parameters['safety_settings'][idx]['threshold'] = 4

        return parameters

    async def gen_text(
        self,
        prompt,
        mode="chat",  # chat or text
        parameters=None,
        use_filter=True
    ):
        if parameters is None:
            temperature = 1.0
            top_k = 40
            top_p = 0.95
            max_output_tokens = 1024

            # default safety settings
            safety_settings = [
                {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 1},
                {"category": "HARM_CATEGORY_TOXICITY", "threshold": 1},
                {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 2},
                {"category": "HARM_CATEGORY_SEXUAL", "threshold": 2},
                {"category": "HARM_CATEGORY_MEDICAL", "threshold": 2},
                {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 2},
            ]
            if not use_filter:
                for idx, _ in enumerate(safety_settings):
                    safety_settings[idx]['threshold'] = 4

            if mode == "chat":
                parameters = {
                    'model': 'models/chat-bison-001',
                    'candidate_count': 1,
                    'context': "",
                    'temperature': temperature,
                    'top_k': top_k,
                    'top_p': top_p,
                    'safety_settings': safety_settings,
                }
            else:
                parameters = {
                    'model': 'models/text-bison-001',
                    'candidate_count': 1,
                    'temperature': temperature,
                    'top_k': top_k,
                    'top_p': top_p,
                    'max_output_tokens': max_output_tokens,
                    'safety_settings': safety_settings,
                }

        try:
            if mode == "chat":
                response = await palm_api.chat_async(**parameters, messages=prompt)
            else:
                response = palm_api.generate_text(**parameters, prompt=prompt)
        except Exception as e:
            raise EnvironmentError("PaLM API is not available.") from e

        if use_filter and len(response.filters) > 0:
            raise Exception("PaLM API has withheld a response due to content safety concerns.")

        if mode == "chat":
            response_txt = response.last
        else:
            response_txt = response.result

        return response, response_txt
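

# Example usage (a minimal, illustrative sketch; not part of the original module).
# It assumes a valid PaLM API key is exposed through the PALM_API_KEY environment
# variable or a local .palm_api_key.txt file, and that `modules.llms` is importable.
if __name__ == "__main__":
    import asyncio

    factory = PaLMFactory()
    service = factory.create_llm_service()

    # Start from the chat defaults and override only the temperature.
    params = service.make_params(mode="chat", temperature=0.5)

    # gen_text is a coroutine, so drive it with a one-off event loop here.
    response, text = asyncio.run(
        service.gen_text("Say hello in one short sentence.", mode="chat", parameters=params)
    )
    print(text)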