| """Beta Feature: base interface for cache.""" | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from sqlalchemy import Column, Integer, String, create_engine, select | |
| from sqlalchemy.engine.base import Engine | |
| from sqlalchemy.orm import Session | |
| try: | |
| from sqlalchemy.orm import declarative_base | |
| except ImportError: | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from langchain.schema import Generation | |
| RETURN_VAL_TYPE = List[Generation] | |
class BaseCache(ABC):
    """Base interface for cache."""

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""

class InMemoryCache(BaseCache):
    """Cache that stores things in memory."""

    def __init__(self) -> None:
        """Initialize with empty cache."""
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        return self._cache.get((prompt, llm_string), None)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        self._cache[(prompt, llm_string)] = return_val
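
# Usage sketch (illustrative, not part of the original module): the prompt and
# the llm_string (serialized LLM settings) together form the cache key.
#
#     cache = InMemoryCache()
#     cache.update("2+2", "fake-llm-config", [Generation(text="4")])
#     assert cache.lookup("2+2", "fake-llm-config")[0].text == "4"
#     assert cache.lookup("2+2", "other-config") is None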

Base = declarative_base()


class FullLLMCache(Base):  # type: ignore
    """SQLAlchemy model for the full LLM cache (all generations)."""

    __tablename__ = "full_llm_cache"
    prompt = Column(String, primary_key=True)
    llm = Column(String, primary_key=True)
    idx = Column(Integer, primary_key=True)
    response = Column(String)


class SQLAlchemyCache(BaseCache):
    """Cache that uses SQLAlchemy as a backend."""

    def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
        """Initialize by creating all tables."""
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt == prompt)
            .where(self.cache_schema.llm == llm_string)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            generations = [Generation(text=row[0]) for row in session.execute(stmt)]
            if len(generations) > 0:
                return generations
        return None
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for i, generation in enumerate(return_val):
            item = self.cache_schema(
                prompt=prompt, llm=llm_string, response=generation.text, idx=i
            )
            with Session(self.engine) as session, session.begin():
                session.merge(item)
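
# Usage sketch (illustrative): any SQLAlchemy engine works here, e.g. a
# hypothetical Postgres database named llm_cache:
#
#     engine = create_engine("postgresql://user:pass@localhost/llm_cache")
#     cache = SQLAlchemyCache(engine)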


class SQLiteCache(SQLAlchemyCache):
    """Cache that uses SQLite as a backend."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Initialize by creating the engine and all tables."""
        engine = create_engine(f"sqlite:///{database_path}")
        super().__init__(engine)
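
# Usage sketch (illustrative): the cache lives in a local SQLite file, so
# entries written in one process survive a restart:
#
#     cache = SQLiteCache(database_path=".langchain.db")
#     cache.update("2+2", "fake-llm-config", [Generation(text="4")])
#     # ...later, in a new process:
#     cache = SQLiteCache(database_path=".langchain.db")
#     assert cache.lookup("2+2", "fake-llm-config")[0].text == "4"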


class RedisCache(BaseCache):
    """Cache that uses Redis as a backend."""

    def __init__(self, redis_: Any):
        """Initialize by passing in a Redis instance."""
        try:
            from redis import Redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in a Redis object.")
        self.redis = redis_
    def _key(self, prompt: str, llm_string: str, idx: int) -> str:
        """Compute key from prompt, llm_string, and idx."""
        # Use a stable digest rather than the built-in hash(), whose values
        # are salted per process and so would not survive a client restart.
        return hashlib.md5((prompt + llm_string).encode()).hexdigest() + "_" + str(idx)
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        generations = []
        idx = 0
        # Read consecutive indexed keys until the first miss, issuing a
        # single GET per key.
        while True:
            result = self.redis.get(self._key(prompt, llm_string, idx))
            if not result:
                break
            if isinstance(result, bytes):
                result = result.decode()
            generations.append(Generation(text=result))
            idx += 1
        return generations if generations else None
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        for i, generation in enumerate(return_val):
            self.redis.set(self._key(prompt, llm_string, i), generation.text)
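
# Usage sketch (illustrative): assumes a Redis server reachable at the
# default localhost:6379:
#
#     from redis import Redis
#
#     cache = RedisCache(redis_=Redis(host="localhost", port=6379))
#     cache.update("2+2", "fake-llm-config", [Generation(text="4")])
#     assert cache.lookup("2+2", "fake-llm-config")[0].text == "4"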