Spaces:
Runtime error
Runtime error
"""Beta Feature: base interface for cache.""" | |
from abc import ABC, abstractmethod | |
from typing import Any, Dict, List, Optional, Tuple | |
from sqlalchemy import Column, Integer, String, create_engine, select | |
from sqlalchemy.engine.base import Engine | |
from sqlalchemy.orm import Session | |
try: | |
from sqlalchemy.orm import declarative_base | |
except ImportError: | |
from sqlalchemy.ext.declarative import declarative_base | |
from langchain.schema import Generation | |
# Value stored in / returned from a cache: the generations produced for one
# (prompt, llm_string) pair.
RETURN_VAL_TYPE = List[Generation]


class BaseCache(ABC):
    """Base interface for cache.

    Concrete caches map a ``(prompt, llm_string)`` pair to a list of
    generations. Both methods are abstract so that subclasses must provide
    them and the base class itself cannot be instantiated.
    """

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The cached generations, or None when nothing is cached for the pair.
        """

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Args:
            prompt: Prompt text the generations were produced for.
            llm_string: String identifying the LLM the generations came from.
            return_val: Generations to store for this pair.
        """
class InMemoryCache(BaseCache):
    """Process-local cache backed by a plain dictionary."""

    def __init__(self) -> None:
        """Start with no cached entries."""
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the value stored for (prompt, llm_string), or None if absent."""
        key = (prompt, llm_string)
        return self._cache.get(key)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store return_val under (prompt, llm_string), replacing any earlier value."""
        key = (prompt, llm_string)
        self._cache[key] = return_val
# Declarative base for the ORM model(s) defined below; its metadata is what
# SQLAlchemyCache uses to create tables.
Base = declarative_base()


class FullLLMCache(Base):  # type: ignore
    """SQLite table for full LLM Cache (all generations).

    Default ``cache_schema`` of SQLAlchemyCache. It holds one row per
    generation, keyed by prompt, LLM string and the generation's position.
    """

    __tablename__ = "full_llm_cache"
    # (prompt, llm, idx) together form the composite primary key; declaration
    # order here is the column / key order of the created table.
    prompt = Column(String, primary_key=True)
    llm = Column(String, primary_key=True)
    # Position of the generation within the cached list.
    idx = Column(Integer, primary_key=True)
    # Generation text for this position.
    response = Column(String)
class SQLAlchemyCache(BaseCache):
    """Cache that uses SQLAlchemy as a backend."""

    def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache) -> None:
        """Initialize by creating all tables.

        Args:
            engine: SQLAlchemy engine the cache reads from and writes to.
            cache_schema: ORM model used for cache rows. It must expose
                ``prompt``, ``llm``, ``idx`` and ``response`` columns.
        """
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The cached generations ordered by ``idx``, or None if no rows are
            stored for this pair.
        """
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt == prompt)
            .where(self.cache_schema.llm == llm_string)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            generations = [Generation(text=row[0]) for row in session.execute(stmt)]
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Any rows previously stored for this pair are replaced by one row per
        generation in ``return_val``. The whole replacement happens in a single
        transaction, so readers never see a partially written result. An empty
        ``return_val`` therefore clears the entry.
        """
        from sqlalchemy import delete

        items = [
            self.cache_schema(
                prompt=prompt, llm=llm_string, response=generation.text, idx=i
            )
            for i, generation in enumerate(return_val)
        ]
        # Without removing old rows, a shorter update would leave trailing
        # idx rows behind and lookup() would return old and new generations mixed.
        stale_rows = (
            delete(self.cache_schema)
            .where(self.cache_schema.prompt == prompt)
            .where(self.cache_schema.llm == llm_string)
            # The session is brand new and holds no loaded objects to sync.
            .execution_options(synchronize_session=False)
        )
        with Session(self.engine) as session, session.begin():
            session.execute(stale_rows)
            session.add_all(items)
class SQLiteCache(SQLAlchemyCache):
    """SQLAlchemy-backed cache whose storage is a SQLite database file."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Build a SQLite engine for ``database_path`` and let the parent set up tables."""
        database_url = f"sqlite:///{database_path}"
        super().__init__(create_engine(database_url))
class RedisCache(BaseCache):
    """Cache that uses Redis as a backend."""

    def __init__(self, redis_: Any):
        """Initialize by passing in Redis instance.

        Raises:
            ValueError: If the ``redis`` package cannot be imported, or if
                ``redis_`` is not a ``redis.Redis`` instance.
        """
        try:
            from redis import Redis
        except ImportError as e:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            ) from e
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Redis object.")
        self.redis = redis_

    def _key(self, prompt: str, llm_string: str, idx: int) -> str:
        """Compute key from prompt, llm_string, and idx.

        The built-in ``hash()`` of a ``str`` is salted per interpreter process
        (PYTHONHASHSEED), so it would yield different keys in different
        processes and defeat a shared Redis cache. A SHA-256 digest is stable.
        The prompt length is part of the hashed payload so that different
        (prompt, llm_string) splits of the same concatenated text do not
        collide.
        """
        import hashlib

        # "surrogatepass" keeps any str encodable, as hash() accepted any str.
        payload = f"{len(prompt)}:{prompt}{llm_string}".encode("utf-8", "surrogatepass")
        return hashlib.sha256(payload).hexdigest() + "_" + str(idx)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Reads keys for idx 0, 1, 2, ... until the first missing one.

        Returns:
            The cached generations, or None if none are stored.
        """
        generations = []
        idx = 0
        while True:
            # One round trip per index. A missing key is None; an empty
            # generation text ("" / b"") is a valid cached value, not the end.
            result = self.redis.get(self._key(prompt, llm_string, idx))
            if result is None:
                break
            if isinstance(result, bytes):
                result = result.decode()
            generations.append(Generation(text=result))
            idx += 1
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Writes one key per generation. It then removes the key just past the
        last written index, so lookup() stops there instead of picking up
        entries left over from an earlier, longer update.
        """
        for i, generation in enumerate(return_val):
            self.redis.set(self._key(prompt, llm_string, i), generation.text)
        self.redis.delete(self._key(prompt, llm_string, len(return_val)))