Spaces:
Runtime error
Runtime error
File size: 1,984 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
class BasePromptSelector(BaseModel, ABC):
    """Abstract interface for choosing a prompt based on a language model.

    Concrete subclasses must implement :meth:`get_prompt`.
    """

    @abstractmethod
    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Return the prompt template that should be used with ``llm``.

        Args:
            llm: Language model the prompt will be used with.

        Returns:
            The selected prompt template.
        """
        ...
class ConditionalPromptSelector(BasePromptSelector):
    """Prompt selector that tests an ordered list of conditionals.

    Each entry in ``conditionals`` pairs a predicate over the language model
    with the prompt to use when that predicate returns a truthy value. The
    first matching entry wins; if none match, ``default_prompt`` is returned.
    """

    # Fallback prompt returned when no conditional matches.
    default_prompt: BasePromptTemplate
    # Ordered (predicate, prompt) pairs, evaluated front to back.
    conditionals: List[
        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
    ] = Field(default_factory=list)

    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Select the prompt to use for a language model.

        Args:
            llm: Language model to select a prompt for.

        Returns:
            The prompt paired with the first conditional that returns a
            truthy value for ``llm``, or ``default_prompt`` if none do.
        """
        for check, candidate in self.conditionals:
            # Skip non-matching predicates; later predicates are never
            # evaluated once one matches.
            if not check(llm):
                continue
            return candidate
        return self.default_prompt
def is_llm(llm: BaseLanguageModel) -> bool:
    """Report whether ``llm`` is a ``BaseLLM``.

    Args:
        llm: Language model to inspect.

    Returns:
        ``True`` when ``llm`` is an instance of ``BaseLLM`` (including
        subclasses), otherwise ``False``.
    """
    is_base_llm = isinstance(llm, BaseLLM)
    return is_base_llm
def is_chat_model(llm: BaseLanguageModel) -> bool:
    """Report whether ``llm`` is a ``BaseChatModel``.

    Args:
        llm: Language model to inspect.

    Returns:
        ``True`` when ``llm`` is an instance of ``BaseChatModel`` (including
        subclasses), otherwise ``False``.
    """
    is_chat = isinstance(llm, BaseChatModel)
    return is_chat
|