"""Prompt selectors: pick a prompt template based on the type of language model."""
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
class BasePromptSelector(BaseModel, ABC):
    """Base class for prompt selectors.

    A prompt selector maps a language model to the prompt template best
    suited for it (e.g. a chat-style prompt for chat models, a plain
    completion prompt for LLMs).
    """

    # @abstractmethod was missing: without it, ABC does not force subclasses
    # to implement get_prompt, and calling it on the base silently returns None.
    @abstractmethod
    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Get default prompt for a language model.

        Args:
            llm: Language model to get a prompt for.

        Returns:
            Prompt to use for the language model.
        """
class ConditionalPromptSelector(BasePromptSelector):
    """Prompt collection that goes through conditionals."""

    default_prompt: BasePromptTemplate
    """Default prompt to use if no conditionals match."""

    conditionals: List[
        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
    ] = Field(default_factory=list)
    """List of conditionals and prompts to use if the conditionals match."""

    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Get default prompt for a language model.

        Scans the conditionals in order and returns the prompt paired with
        the first predicate that accepts ``llm``; otherwise falls back to
        ``default_prompt``.

        Args:
            llm: Language model to get prompt for.

        Returns:
            Prompt to use for the language model.
        """
        return next(
            (prompt for predicate, prompt in self.conditionals if predicate(llm)),
            self.default_prompt,
        )
def is_llm(llm: BaseLanguageModel) -> bool:
    """Check if the language model is a LLM.

    Args:
        llm: Language model to check.

    Returns:
        True if the language model is a BaseLLM model, False otherwise.
    """
    is_completion_model = isinstance(llm, BaseLLM)
    return is_completion_model
def is_chat_model(llm: BaseLanguageModel) -> bool:
    """Check if the language model is a chat model.

    Args:
        llm: Language model to check.

    Returns:
        True if the language model is a BaseChatModel model, False otherwise.
    """
    is_chat = isinstance(llm, BaseChatModel)
    return is_chat