from vsp.llm.llm_cache import LLMCache
from vsp.llm.llm_service import LLMService
from vsp.shared import logger_factory

logger = logger_factory.get_logger(__name__)


class CachedLLMService(LLMService):
    """LLMService decorator that caches responses, keyed on the prompts and sampling parameters."""

    def __init__(self, llm_service: LLMService, cache: LLMCache | None = None):
        self._llm_service = llm_service
        self._cache = cache or LLMCache()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        # Build the cache key from every input that can affect the response.
        cache_key = f"{user_prompt}_{system_prompt}_{partial_assistant_prompt}_{max_tokens}_{temperature}"

        cached_response = self._cache.get(cache_key, {})
        if cached_response is not None:
            logger.debug("LLM cache hit")
            return cached_response

        # Cache miss: delegate to the wrapped service.
        response = await self._llm_service.invoke(
            user_prompt=user_prompt,
            system_prompt=system_prompt,
            partial_assistant_prompt=partial_assistant_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # Only cache successful (non-None) responses.
        if response is not None:
            self._cache.set(cache_key, response, {})

        return response
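

# Minimal usage sketch (hypothetical: "SomeLLMService" stands in for whatever
# concrete LLMService implementation the project provides; adapt to the real one):
#
#     base_service = SomeLLMService()
#     llm = CachedLLMService(base_service)
#     answer = await llm.invoke(user_prompt="Summarize this document.", max_tokens=500)
#
# Repeated calls with the same prompts and parameters are then served from the
# cache instead of reaching the underlying service.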