"""Base interface for chat models: language models that consume and produce messages."""
from abc import ABC, abstractmethod
from typing import List, Optional

from pydantic import BaseModel, Extra, Field, validator

import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import (
    AIMessage,
    BaseLanguageModel,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    LLMResult,
    PromptValue,
)
def _get_verbosity() -> bool:
    """Return the module-level ``langchain.verbose`` flag (default for new models)."""
    return langchain.verbose
class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
    """Abstract base class for chat models.

    Subclasses implement the private hooks ``_generate`` / ``_agenerate``
    (one batch of messages -> ``ChatResult``); this class supplies the public
    ``generate`` / ``agenerate`` / ``generate_prompt`` entry points and wires
    callback-manager events around them.
    """

    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    # pre=True/always=True so an explicit ``callback_manager=None`` is coerced
    # to the global manager instead of failing pydantic validation.
    @validator("callback_manager", pre=True, always=True)
    def set_callback_manager(
        cls, callback_manager: Optional[BaseCallbackManager]
    ) -> BaseCallbackManager:
        """If callback manager is None, set it.

        This allows users to pass in None as callback manager, which is a nice UX.
        """
        return callback_manager or get_callback_manager()

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        """Merge per-batch ``llm_output`` dicts into one; default discards them."""
        return {}

    def generate(
        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Top Level call: run ``_generate`` on each message batch synchronously."""
        results = [self._generate(m, stop=stop) for m in messages]
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        return LLMResult(generations=generations, llm_output=llm_output)

    async def agenerate(
        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Top Level call: run ``_agenerate`` on each message batch sequentially."""
        results = [await self._agenerate(m, stop=stop) for m in messages]
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        return LLMResult(generations=generations, llm_output=llm_output)

    def generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Generate from prompt values, emitting on_llm_start/end/error callbacks."""
        prompt_messages = [p.to_messages() for p in prompts]
        prompt_strings = [p.to_string() for p in prompts]
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
        )
        try:
            output = self.generate(prompt_messages, stop=stop)
        except (KeyboardInterrupt, Exception) as e:
            # KeyboardInterrupt listed explicitly: it is a BaseException and
            # would otherwise bypass the error callback.
            self.callback_manager.on_llm_error(e, verbose=self.verbose)
            raise e
        self.callback_manager.on_llm_end(output, verbose=self.verbose)
        return output

    async def agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Async version of ``generate_prompt``; awaits async callback managers."""
        prompt_messages = [p.to_messages() for p in prompts]
        prompt_strings = [p.to_string() for p in prompts]
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
            )
        try:
            output = await self.agenerate(prompt_messages, stop=stop)
        except (KeyboardInterrupt, Exception) as e:
            if self.callback_manager.is_async:
                await self.callback_manager.on_llm_error(e, verbose=self.verbose)
            else:
                self.callback_manager.on_llm_error(e, verbose=self.verbose)
            raise e
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(output, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(output, verbose=self.verbose)
        return output

    @abstractmethod
    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        """Run the chat model on one batch of messages; implemented by subclasses."""

    @abstractmethod
    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        """Async counterpart of ``_generate``; implemented by subclasses."""

    def __call__(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> BaseMessage:
        """Convenience call: return the first generation's message."""
        return self._generate(messages, stop=stop).generations[0].message

    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
        """String-in/string-out shim: wrap ``message`` in a HumanMessage."""
        result = self([HumanMessage(content=message)], stop=stop)
        return result.content
class SimpleChatModel(BaseChatModel):
    """Chat model adapter: subclasses supply a string-returning ``_call``
    and ``_generate`` wraps the string into a single AI ``ChatGeneration``."""

    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        """Delegate to ``_call`` and wrap its string output as a ChatResult."""
        output_str = self._call(messages, stop=stop)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @abstractmethod
    def _call(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> str:
        """Simpler interface."""