"""Wrapper around HazyResearch's Manifest library."""
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


class ManifestWrapper(LLM, BaseModel):
    """Wrapper around HazyResearch's Manifest library."""

    client: Any  #: :meta private:
    llm_kwargs: Optional[Dict] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the manifest python package is installed and the client is a Manifest instance."""
        try:
            from manifest import Manifest

            if not isinstance(values["client"], Manifest):
                raise ValueError("client must be an instance of manifest.Manifest")
        except ImportError:
            raise ValueError(
                "Could not import manifest python package. "
                "Please install it with `pip install manifest-ml`."
            )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        kwargs = self.llm_kwargs or {}
        return {**self.client.client.get_model_params(), **kwargs}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "manifest"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to the LLM through Manifest."""
        if stop is not None and len(stop) != 1:
            raise NotImplementedError(
                f"Manifest currently only supports a single stop token, got {stop}"
            )
        kwargs = self.llm_kwargs or {}
        if stop is not None:
            # Manifest expects a single stop string, not a list of stop tokens.
            kwargs["stop_token"] = stop[0]
        return self.client.run(prompt, **kwargs)
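

# Example usage (a minimal sketch): assumes `manifest-ml` is installed, an
# OpenAI API key is configured for the Manifest client, and the wrapper is
# invoked through the standard LangChain LLM call interface.
#
#     from manifest import Manifest
#
#     manifest_client = Manifest(client_name="openai")
#     llm = ManifestWrapper(client=manifest_client, llm_kwargs={"temperature": 0.0})
#     print(llm("Tell me a joke."))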