"""Base interface for large language models to expose.""" | |
import json | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union | |
import yaml | |
from pydantic import BaseModel, Extra, Field, validator | |
import langchain | |
from langchain.callbacks import get_callback_manager | |
from langchain.callbacks.base import BaseCallbackManager | |
from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue | |


def _get_verbosity() -> bool:
    return langchain.verbose


def get_prompts(
    params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
    """Get prompts that are already cached."""
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    missing_prompts = []
    missing_prompt_idxs = []
    existing_prompts = {}
    for i, prompt in enumerate(prompts):
        if langchain.llm_cache is not None:
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
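
# Illustrative sketch (kept as a comment so nothing runs on import): on a
# cold but configured cache, every prompt is reported as missing, so the
# caller knows which ones still need a real LLM call. `params` here is a
# hypothetical parameter dict.
#
#   params = {"model_name": "example-model", "stop": None}
#   existing, llm_string, missing_idxs, missing = get_prompts(params, ["hi"])
#   # cold cache -> existing == {}, missing_idxs == [0], missing == ["hi"]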


def update_cache(
    existing_prompts: Dict[int, List],
    llm_string: str,
    missing_prompt_idxs: List[int],
    new_results: LLMResult,
    prompts: List[str],
) -> Optional[dict]:
    """Update the cache and get the LLM output."""
    for i, result in enumerate(new_results.generations):
        existing_prompts[missing_prompt_idxs[i]] = result
        prompt = prompts[missing_prompt_idxs[i]]
        if langchain.llm_cache is not None:
            langchain.llm_cache.update(prompt, llm_string, result)
    llm_output = new_results.llm_output
    return llm_output
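
# Illustrative sketch: `update_cache` fills `existing_prompts` in place,
# keyed by each prompt's original index, and writes the new generations
# through to `langchain.llm_cache` so a repeated call with the same prompt
# and parameters can later be served from the cache.
#
#   existing: Dict[int, List] = {}
#   # given an LLMResult `new_results` for the single missing prompt at
#   # index 0:
#   # llm_output = update_cache(existing, llm_string, [0], new_results, ["hi"])
#   # existing[0] now holds that prompt's generations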


class BaseLLM(BaseLanguageModel, BaseModel, ABC):
    """LLM wrapper should take in a prompt and return a string."""

    cache: Optional[bool] = None
    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @validator("callback_manager", pre=True, always=True)
    def set_callback_manager(
        cls, callback_manager: Optional[BaseCallbackManager]
    ) -> BaseCallbackManager:
        """If callback manager is None, set it.

        This allows users to pass in None as callback manager, which is a nice UX.
        """
        return callback_manager or get_callback_manager()

    @validator("verbose", pre=True, always=True)
    def set_verbose(cls, verbose: Optional[bool]) -> bool:
        """If verbose is None, set it.

        This allows users to pass in None as verbose to access the global setting.
        """
        if verbose is None:
            return _get_verbosity()
        else:
            return verbose

    @abstractmethod
    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompts."""

    @abstractmethod
    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompts."""

    def generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
        return self.generate(prompt_strings, stop=stop)

    async def agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
        return await self.agenerate(prompt_strings, stop=stop)
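
    # Illustrative sketch, assuming `PromptTemplate.from_template` and
    # `format_prompt` from this version of langchain (the latter returns a
    # `PromptValue`):
    #
    #   from langchain.prompts import PromptTemplate
    #   template = PromptTemplate.from_template("Tell me about {topic}.")
    #   prompt_value = template.format_prompt(topic="caching")
    #   # result = llm.generate_prompt([prompt_value])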

    def generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        # If a bare string were passed in directly, no error would be raised
        # but the outputs would not make sense, so reject it explicitly.
        if not isinstance(prompts, list):
            raise ValueError(
                "Argument 'prompts' is expected to be of type List[str], received"
                f" argument of type {type(prompts)}."
            )
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.llm_cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
            )
            try:
                output = self._generate(prompts, stop=stop)
            except (KeyboardInterrupt, Exception) as e:
                self.callback_manager.on_llm_error(e, verbose=self.verbose)
                raise e
            self.callback_manager.on_llm_end(output, verbose=self.verbose)
            return output
        params = self.dict()
        params["stop"] = stop
        (
            existing_prompts,
            llm_string,
            missing_prompt_idxs,
            missing_prompts,
        ) = get_prompts(params, prompts)
        if len(missing_prompts) > 0:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
            )
            try:
                new_results = self._generate(missing_prompts, stop=stop)
            except (KeyboardInterrupt, Exception) as e:
                self.callback_manager.on_llm_error(e, verbose=self.verbose)
                raise e
            self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
            llm_output = update_cache(
                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
            )
        else:
            llm_output = {}
        generations = [existing_prompts[i] for i in range(len(prompts))]
        return LLMResult(generations=generations, llm_output=llm_output)
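
    # Illustrative sketch: with a global cache configured, a second
    # `generate` call for the same prompt and parameters is answered from
    # `langchain.llm_cache` instead of hitting the model again.
    # `InMemoryCache` is the simple cache shipped in `langchain.cache`.
    #
    #   import langchain
    #   from langchain.cache import InMemoryCache
    #   langchain.llm_cache = InMemoryCache()
    #   # first = llm.generate(["Tell me a joke"])   # calls the model
    #   # second = llm.generate(["Tell me a joke"])  # served from the cache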

    async def agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.llm_cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            if self.callback_manager.is_async:
                await self.callback_manager.on_llm_start(
                    {"name": self.__class__.__name__}, prompts, verbose=self.verbose
                )
            else:
                self.callback_manager.on_llm_start(
                    {"name": self.__class__.__name__}, prompts, verbose=self.verbose
                )
            try:
                output = await self._agenerate(prompts, stop=stop)
            except (KeyboardInterrupt, Exception) as e:
                if self.callback_manager.is_async:
                    await self.callback_manager.on_llm_error(e, verbose=self.verbose)
                else:
                    self.callback_manager.on_llm_error(e, verbose=self.verbose)
                raise e
            if self.callback_manager.is_async:
                await self.callback_manager.on_llm_end(output, verbose=self.verbose)
            else:
                self.callback_manager.on_llm_end(output, verbose=self.verbose)
            return output
        params = self.dict()
        params["stop"] = stop
        (
            existing_prompts,
            llm_string,
            missing_prompt_idxs,
            missing_prompts,
        ) = get_prompts(params, prompts)
        if len(missing_prompts) > 0:
            if self.callback_manager.is_async:
                await self.callback_manager.on_llm_start(
                    {"name": self.__class__.__name__},
                    missing_prompts,
                    verbose=self.verbose,
                )
            else:
                self.callback_manager.on_llm_start(
                    {"name": self.__class__.__name__},
                    missing_prompts,
                    verbose=self.verbose,
                )
            try:
                new_results = await self._agenerate(missing_prompts, stop=stop)
            except (KeyboardInterrupt, Exception) as e:
                if self.callback_manager.is_async:
                    await self.callback_manager.on_llm_error(e, verbose=self.verbose)
                else:
                    self.callback_manager.on_llm_error(e, verbose=self.verbose)
                raise e
            if self.callback_manager.is_async:
                await self.callback_manager.on_llm_end(
                    new_results, verbose=self.verbose
                )
            else:
                self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
            llm_output = update_cache(
                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
            )
        else:
            llm_output = {}
        generations = [existing_prompts[i] for i in range(len(prompts))]
        return LLMResult(generations=generations, llm_output=llm_output)
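
    # Illustrative sketch: `agenerate` is a coroutine, so it must be awaited.
    # Several batches can be issued concurrently with `asyncio.gather`, which
    # is the main reason to prefer it over `generate` for many prompts.
    #
    #   import asyncio
    #
    #   async def main() -> None:
    #       results = await asyncio.gather(
    #           llm.agenerate(["prompt one"]),
    #           llm.agenerate(["prompt two"]),
    #       )
    #
    #   # asyncio.run(main())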

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Check cache and run the LLM on the given prompt and input."""
        return self.generate([prompt], stop=stop).generations[0][0].text
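
    # Illustrative sketch: calling the instance directly is shorthand for a
    # single-prompt `generate`, returning only the text of the first
    # generation.
    #
    #   # text = llm("Tell me a joke", stop=["\n"])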

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {}

    def __str__(self) -> str:
        """Get a string representation of the object for printing."""
        cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
        return f"{cls_name}\nParams: {self._identifying_params}"

    @property
    @abstractmethod
    def _llm_type(self) -> str:
        """Return type of llm."""

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary of the LLM."""
        starter_dict = dict(self._identifying_params)
        starter_dict["_type"] = self._llm_type
        return starter_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the LLM.

        Args:
            file_path: Path to file to save the LLM to.

        Example:
        .. code-block:: python

            llm.save(file_path="path/llm.yaml")
        """
        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        prompt_dict = self.dict()

        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
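
    # Illustrative sketch: the saved file can be reloaded with the matching
    # loader in `langchain.llms.loading` (an assumption about this version of
    # langchain).
    #
    #   from langchain.llms.loading import load_llm
    #   # llm.save("llm.yaml")
    #   # restored = load_llm("llm.yaml")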


class LLM(BaseLLM):
    """LLM class that expects subclasses to implement a simpler call method.

    The purpose of this class is to expose a simpler interface for working
    with LLMs, rather than expect the user to implement the full _generate
    method.
    """

    @abstractmethod
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and input."""

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and input."""
        raise NotImplementedError("Async generation not implemented for this LLM.")

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        # TODO: add caching here.
        generations = []
        for prompt in prompts:
            text = self._call(prompt, stop=stop)
            generations.append([Generation(text=text)])
        return LLMResult(generations=generations)

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        generations = []
        for prompt in prompts:
            text = await self._acall(prompt, stop=stop)
            generations.append([Generation(text=text)])
        return LLMResult(generations=generations)
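

# Illustrative sketch of the simpler `LLM` interface: a concrete subclass
# only needs `_llm_type` and `_call`. `EchoLLM` is a hypothetical toy model
# that returns the prompt unchanged.
#
#   class EchoLLM(LLM):
#       @property
#       def _llm_type(self) -> str:
#           return "echo"
#
#       def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
#           return prompt
#
#   # EchoLLM()("hello") == "hello"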