Spaces:
Runtime error
Runtime error
from abc import ABC | |
from typing import Any, Dict, List, Literal, TypedDict, Union, cast | |
from pydantic import BaseModel, PrivateAttr | |
class BaseSerialized(TypedDict):
    """Fields shared by every serialized payload produced in this module."""

    # Integer marker; the serializers in this file always write 1.
    lc: int
    # Identifier path segments, e.g. module path parts followed by a name.
    id: List[str]
class SerializedConstructor(BaseSerialized):
    """Payload describing an object by its constructor keyword arguments."""

    # Discriminator: always the literal string "constructor".
    type: Literal["constructor"]
    # Keyword arguments captured for the object's constructor.
    kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
    """Placeholder payload that stands in for a secret value."""

    # Discriminator: always the literal string "secret".
    type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
    """Payload emitted for objects that cannot be serialized."""

    # Discriminator: always the literal string "not_implemented".
    type: Literal["not_implemented"]
class Serializable(BaseModel, ABC):
    """Serializable base class.

    Subclasses opt in to serialization by overriding the ``lc_*``
    properties below. ``to_json`` reads them as plain attributes, so they
    must be properties rather than regular methods.
    """

    @property
    def lc_serializable(self) -> bool:
        """Return whether or not the class is serializable."""
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """Return the namespace of the langchain object.

        eg. ["langchain", "llms", "openai"]
        """
        return self.__class__.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Return a map of constructor argument names to secret ids.

        eg. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return {}

    @property
    def lc_attributes(self) -> Dict:
        """Return a map of attribute names to values to add to the kwargs.

        These entries are merged into the serialized kwargs, so the
        attributes must be accepted by the constructor.
        """
        return {}

    class Config:
        extra = "ignore"

    # Raw keyword arguments the instance was constructed with.
    _lc_kwargs = PrivateAttr(default_factory=dict)

    def __init__(self, **kwargs: Any) -> None:
        """Initialise the model and remember the raw constructor kwargs."""
        super().__init__(**kwargs)
        self._lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        """Serialize this object to a JSON-compatible dict.

        Returns:
            A "not_implemented" payload when ``lc_serializable`` is false,
            otherwise a "constructor" payload whose kwargs have every
            secret value replaced by a secret placeholder.
        """
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        secrets: Dict[str, str] = {}
        excluded = self.__exclude_fields__ or {}  # type: ignore
        # Get latest values for kwargs if there is an attribute with same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self._lc_kwargs.items()
            if not excluded.get(k, False)
        }

        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break

            # Get a reference to self bound to each class in the MRO
            this = cast(Serializable, self if cls is None else super(cls, self))

            secrets.update(this.lc_secrets)
            lc_kwargs.update(this.lc_attributes)

        # include all secrets, even if not specified in kwargs
        # as these secrets may be passed as an environment variable instead
        for key in secrets:
            secret_value = getattr(self, key, None) or lc_kwargs.get(key)
            if secret_value is not None:
                lc_kwargs[key] = secret_value

        return {
            "lc": 1,
            "type": "constructor",
            "id": [*self.lc_namespace, self.__class__.__name__],
            "kwargs": lc_kwargs
            if not secrets
            else _replace_secrets(lc_kwargs, secrets),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        """Return the "not_implemented" payload for this object."""
        return to_json_not_implemented(self)
def _replace_secrets( | |
root: Dict[Any, Any], secrets_map: Dict[str, str] | |
) -> Dict[Any, Any]: | |
result = root.copy() | |
for path, secret_id in secrets_map.items(): | |
[*parts, last] = path.split(".") | |
current = result | |
for part in parts: | |
if part not in current: | |
break | |
current[part] = current[part].copy() | |
current = current[part] | |
if last in current: | |
current[last] = { | |
"lc": 1, | |
"type": "secret", | |
"id": [secret_id], | |
} | |
return result | |
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    """Build the "not implemented" payload for an arbitrary object.

    The identifier is the dotted module path split into parts, followed by
    a name. Objects exposing ``__name__`` are described directly; anything
    else is described through its class. Any failure while looking these
    up is swallowed and leaves the identifier empty.

    Args:
        obj: object to serialize

    Returns:
        SerializedNotImplemented
    """
    identifier: List[str] = []
    try:
        if hasattr(obj, "__name__"):
            target = obj
        elif hasattr(obj, "__class__"):
            target = obj.__class__
        else:
            target = None
        if target is not None:
            identifier = [*target.__module__.split("."), target.__name__]
    except Exception:
        # Best-effort only: fall back to the empty identifier.
        pass
    payload = {"lc": 1, "type": "not_implemented", "id": identifier}
    return payload