jfeng1115's picture
init commit
58d33f0
raw
history blame contribute delete
No virus
1.19 kB
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from pydantic import BaseModel, Field
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC):
    """Abstract interface for choosing a prompt template given a language model.

    Concrete subclasses must implement ``get_prompt``.
    """

    @abstractmethod
    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Get default prompt for a language model.

        Args:
            llm: The language model the prompt will be used with.

        Returns:
            The prompt template selected for ``llm``.
        """
class ConditionalPromptSelector(BasePromptSelector):
    """Prompt selector that picks a template by testing ordered predicates.

    Each entry in ``conditionals`` pairs a predicate over a language model with
    the prompt template to use when that predicate is truthy. When no predicate
    matches, ``default_prompt`` is used.
    """

    # Template returned when none of the conditionals match.
    default_prompt: BasePromptTemplate
    # Ordered (predicate, prompt) pairs; each instance gets its own empty list
    # by default via ``default_factory``.
    conditionals: List[
        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
    ] = Field(default_factory=list)

    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Select the prompt template for ``llm``.

        Predicates are evaluated in list order and evaluation stops at the
        first one that returns a truthy value; its paired template wins.

        Args:
            llm: The language model passed to each predicate.

        Returns:
            The template paired with the first matching predicate, or
            ``default_prompt`` if no predicate matches.
        """
        chosen = self.default_prompt
        for predicate, candidate in self.conditionals:
            if predicate(llm):
                chosen = candidate
                break  # first match wins; later predicates are not called
        return chosen
def is_llm(llm: BaseLanguageModel) -> bool:
    """Return True if ``llm`` is an instance of ``BaseLLM``.

    The signature matches the predicate type accepted by
    ``ConditionalPromptSelector.conditionals``.
    """
    matched = isinstance(llm, BaseLLM)
    return matched
def is_chat_model(llm: BaseLanguageModel) -> bool:
    """Return True if ``llm`` is an instance of ``BaseChatModel``.

    The signature matches the predicate type accepted by
    ``ConditionalPromptSelector.conditionals``.
    """
    matched = isinstance(llm, BaseChatModel)
    return matched