"""Test base LLM functionality.""" from sqlalchemy import Column, Integer, Sequence, String, create_engine try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base import langchain from langchain.cache import InMemoryCache, SQLAlchemyCache from langchain.schema import Generation, LLMResult from tests.unit_tests.llms.fake_llm import FakeLLM def test_caching() -> None: """Test caching behavior.""" langchain.llm_cache = InMemoryCache() llm = FakeLLM() params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")]) output = llm.generate(["foo", "bar", "foo"]) expected_cache_output = [Generation(text="foo")] cache_output = langchain.llm_cache.lookup("bar", llm_string) assert cache_output == expected_cache_output langchain.llm_cache = None expected_generations = [ [Generation(text="fizz")], [Generation(text="foo")], [Generation(text="fizz")], ] expected_output = LLMResult( generations=expected_generations, llm_output=None, ) assert output == expected_output def test_custom_caching() -> None: """Test custom_caching behavior.""" Base = declarative_base() class FulltextLLMCache(Base): # type: ignore """Postgres table for fulltext-indexed LLM Cache.""" __tablename__ = "llm_cache_fulltext" id = Column(Integer, Sequence("cache_id"), primary_key=True) prompt = Column(String, nullable=False) llm = Column(String, nullable=False) idx = Column(Integer) response = Column(String) engine = create_engine("sqlite://") langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache) llm = FakeLLM() params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")]) output = llm.generate(["foo", "bar", "foo"]) expected_cache_output = [Generation(text="foo")] cache_output = langchain.llm_cache.lookup("bar", llm_string) assert cache_output == expected_cache_output langchain.llm_cache = None expected_generations = [ [Generation(text="fizz")], [Generation(text="foo")], [Generation(text="fizz")], ] expected_output = LLMResult( generations=expected_generations, llm_output=None, ) assert output == expected_output