Spaces:
Runtime error
Runtime error
| """Common schema objects.""" | |
| from __future__ import annotations | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, List, NamedTuple, Optional | |
| from pydantic import BaseModel, Extra, Field, root_validator | |
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render messages as a single newline-joined "Role: content" transcript."""

    def _role_for(message: BaseMessage) -> str:
        # Map each supported message class to its display prefix.
        if isinstance(message, HumanMessage):
            return human_prefix
        if isinstance(message, AIMessage):
            return ai_prefix
        if isinstance(message, SystemMessage):
            return "System"
        if isinstance(message, ChatMessage):
            return message.role
        raise ValueError(f"Got unsupported message type: {message}")

    return "\n".join(f"{_role_for(m)}: {m.content}" for m in messages)
class AgentAction(NamedTuple):
    """Agent's action to take."""

    # Name of the tool the agent has chosen to invoke.
    tool: str
    # Raw input string to pass to the chosen tool.
    tool_input: str
    # Log/scratchpad text that produced this action.
    log: str
class AgentFinish(NamedTuple):
    """Agent's return value."""

    # Final key-value outputs the agent returns to the caller.
    return_values: dict
    # Log/scratchpad text that produced this result.
    log: str
class AgentClarify(NamedTuple):
    """Agent's clarification request."""

    # Question the agent wants to ask the user before proceeding.
    question: str
    # Log/scratchpad text that produced this request.
    log: str
class Generation(BaseModel):
    """Output of a single generation."""

    text: str
    """Generated text output."""

    generation_info: Optional[Dict[str, Any]] = None
    """Raw generation info response from the provider.
    May include things like reason for finishing (e.g. in OpenAI)."""
    # NOTE: the original had a second, consecutive bare string here; only the
    # first string after a field acts as its doc, so the two were merged.
    # TODO: add log probs
class BaseMessage(BaseModel):
    """Message object."""

    content: str
    additional_kwargs: dict = Field(default_factory=dict)

    # Property (not a plain method): serialization code reads ``message.type``
    # as an attribute, which would otherwise yield a bound method.
    @property
    @abstractmethod
    def type(self) -> str:
        """Type of the message, used for serialization."""
class HumanMessage(BaseMessage):
    """Type of message that is spoken by the human."""

    # Property so ``message.type`` attribute access yields the tag string.
    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "human"
class AIMessage(BaseMessage):
    """Type of message that is spoken by the AI."""

    # Property so ``message.type`` attribute access yields the tag string.
    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "ai"
class SystemMessage(BaseMessage):
    """Type of message that is a system message."""

    # Property so ``message.type`` attribute access yields the tag string.
    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "system"
class ChatMessage(BaseMessage):
    """Type of message with arbitrary speaker."""

    role: str

    # Property so ``message.type`` attribute access yields the tag string.
    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
        return "chat"
| def _message_to_dict(message: BaseMessage) -> dict: | |
| return {"type": message.type, "data": message.dict()} | |
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
    """Serialize a list of messages into type-tagged dict payloads."""
    serialized = []
    for message in messages:
        serialized.append(_message_to_dict(message))
    return serialized
| def _message_from_dict(message: dict) -> BaseMessage: | |
| _type = message["type"] | |
| if _type == "human": | |
| return HumanMessage(**message["data"]) | |
| elif _type == "ai": | |
| return AIMessage(**message["data"]) | |
| elif _type == "system": | |
| return SystemMessage(**message["data"]) | |
| elif _type == "chat": | |
| return ChatMessage(**message["data"]) | |
| else: | |
| raise ValueError(f"Got unexpected type: {_type}") | |
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
    """Deserialize a list of type-tagged dict payloads into messages."""
    deserialized = []
    for item in messages:
        deserialized.append(_message_from_dict(item))
    return deserialized
class ChatGeneration(Generation):
    """Output of a single chat generation."""

    text = ""
    message: BaseMessage

    # ``root_validator`` is imported at module top; without this decorator
    # set_text would never run and ``text`` would stay empty.
    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Populate ``text`` from the message content so Generation consumers work."""
        values["text"] = values["message"].content
        return values
class ChatResult(BaseModel):
    """Class that contains all relevant information for a Chat Result."""

    generations: List[ChatGeneration]
    """List of the things generated."""
    llm_output: Optional[dict] = None
    """For arbitrary LLM provider specific output."""
class LLMResult(BaseModel):
    """Class that contains all relevant information for an LLM Result."""

    generations: List[List[Generation]]
    """List of the things generated. This is List[List[]] because
    each input could have multiple generations."""
    llm_output: Optional[dict] = None
    """For arbitrary LLM provider specific output."""
class PromptValue(BaseModel, ABC):
    """Abstract prompt value renderable as a plain string or as chat messages."""

    # Abstract: ABC base with docstring-only bodies; without the decorator a
    # subclass missing the method would silently return None.
    @abstractmethod
    def to_string(self) -> str:
        """Return prompt as string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as messages."""
class BaseLanguageModel(BaseModel, ABC):
    """Abstract base interface shared by all language model wrappers."""

    @abstractmethod
    def generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    @abstractmethod
    async def agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text."""
        # TODO: this method may not be exact.
        # TODO: this method may differ based on model (eg codex).
        try:
            from transformers import GPT2TokenizerFast
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install transformers`."
            )
        # GPT-2 tokenizer is used as an approximation for GPT-3-family models.
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        # Token count is the length of the tokenized text.
        tokenized_text = tokenizer.tokenize(text)
        return len(tokenized_text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the total number of tokens across the messages."""
        # Each message is rendered to its buffer-string form before counting.
        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
class BaseMemory(BaseModel, ABC):
    """Base interface for memory in chains."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    # NOTE(review): restored as a read-only property + abstractmethod to match
    # the module's other attribute-style accessors — confirm against subclasses.
    @property
    @abstractmethod
    def memory_variables(self) -> List[str]:
        """Input keys this memory class will load dynamically."""

    @abstractmethod
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain.

        If None, return all memories
        """

    @abstractmethod
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""

    @abstractmethod
    def clear(self) -> None:
        """Clear memory contents."""
class Document(BaseModel):
    """Interface for interacting with a document."""

    page_content: str
    lookup_str: str = ""
    lookup_index: int = 0
    metadata: dict = Field(default_factory=dict)

    # Property: ``lookup`` and ``summary`` subscript/iterate ``self.paragraphs``
    # without calling it, which requires attribute-style access.
    @property
    def paragraphs(self) -> List[str]:
        """Paragraphs of the page."""
        return self.page_content.split("\n\n")

    @property
    def summary(self) -> str:
        """Summary of the page (the first paragraph)."""
        return self.paragraphs[0]

    def lookup(self, string: str) -> str:
        """Lookup a term in the page, imitating cmd-F functionality."""
        # A new term restarts the search; repeating the same term advances
        # through the matches one paragraph at a time.
        if string.lower() != self.lookup_str:
            self.lookup_str = string.lower()
            self.lookup_index = 0
        else:
            self.lookup_index += 1
        lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]
        if len(lookups) == 0:
            return "No Results"
        elif self.lookup_index >= len(lookups):
            return "No More Results"
        else:
            result_prefix = f"(Result {self.lookup_index + 1}/{len(lookups)})"
            return f"{result_prefix} {lookups[self.lookup_index]}"
class BaseRetriever(ABC):
    """Interface for retrieving documents relevant to a query."""

    # Abstract: ABC base with a docstring-only body; without the decorator a
    # subclass missing the method would silently return None.
    @abstractmethod
    def get_relevant_texts(self, query: str) -> List[Document]:
        """Get texts relevant for a query.

        Args:
            query: string to find relevant texts for

        Returns:
            List of relevant documents
        """
# For backwards compatibility: older code imported ``Memory``; new code
# should use ``BaseMemory`` directly.
Memory = BaseMemory
class BaseOutputParser(BaseModel, ABC):
    """Class to parse the output of an LLM call."""

    @abstractmethod
    def parse(self, text: str) -> Any:
        """Parse the output of an LLM call."""

    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
        """Parse the completion; the originating prompt is available for context."""
        return self.parse(completion)

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted."""
        raise NotImplementedError

    # Property: ``dict`` below reads ``self._type`` without calling it, which
    # would otherwise store a bound method in the serialized output.
    @property
    def _type(self) -> str:
        """Return the type key."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser."""
        output_parser_dict = super().dict()
        output_parser_dict["_type"] = self._type
        return output_parser_dict
class OutputParserException(Exception):
    """Raised by output parsers to signal a parsing failure.

    Distinguishes parsing errors from other code or execution errors that may
    arise inside an output parser: callers can catch this exception and try to
    repair the parse, while unrelated errors propagate normally.
    """