| """Cache-Aware Prompt Layout: Optimize prompt structure for prefix-cache reuse.""" |
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
|
|
@dataclass
class PromptLayout:
    """Result of splitting prompt sources into a prefix and a suffix.

    The field order below is part of the interface: the dataclass-generated
    ``__init__`` accepts the fields positionally in exactly this order.
    """

    # Prefix text (as built in this module: the "stable" sections joined
    # with blank lines).
    prefix_content: str
    # Suffix text (as built in this module: the "dynamic" sections joined
    # with blank lines).
    suffix_content: str
    # Token count attributed to the prefix.
    prefix_tokens: int
    # Token count attributed to the suffix.
    suffix_tokens: int
    # Position of the prefix/suffix boundary; CacheAwareLayout.layout in this
    # module sets it equal to ``prefix_tokens``.
    cache_boundary: int
    # Names of the sources that were placed in the prefix.
    stable_sources: List[str]
    # Names of the sources that were placed in the suffix.
    dynamic_sources: List[str]
|
|
# Source-name groupings. Neither set is referenced by the code visible in this
# part of the module.
# NOTE(review): presumably these are the default "stable" vs "dynamic"
# memberships used when a context budget is built -- confirm against callers.
CACHE_STABLE_SOURCES = {
    "system_rules",
    "tool_descriptions",
    "user_preferences",
}
CACHE_DYNAMIC_SOURCES = {
    "recent_messages",
    "task_plan",
    "retrieved_docs",
    "artifacts",
}
|
|
class CacheAwareLayout:
    """Split prompt sources into a prefix and a suffix and track prefix reuse.

    Every call to :meth:`layout` builds the prefix and suffix strings, then
    compares a digest of the prefix with the digest remembered from the
    previous call on this instance. A match is counted as a cache hit,
    anything else (including the very first call) as a cache miss.

    Attributes:
        session_id: Identifier supplied by the caller; stored but not used by
            any method of this class.
        cache_hits: Number of ``layout`` calls whose prefix was identical to
            the previous call's prefix.
        cache_misses: Number of ``layout`` calls whose prefix differed from
            the previous call's prefix (the first call always lands here).
        total_prefix_tokens: Sum of estimated prefix tokens over all calls.
        total_suffix_tokens: Sum of estimated suffix tokens over all calls.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Create a tracker with all counters at zero.

        Args:
            session_id: Optional identifier kept on the instance.
        """
        self.session_id = session_id
        # Digest of the prefix produced by the previous layout() call;
        # None until layout() has been called once.
        self._prev_prefix_hash: Optional[str] = None
        self.cache_hits = 0
        self.cache_misses = 0
        self.total_prefix_tokens = 0
        self.total_suffix_tokens = 0

    def layout(self, sources: Dict[str, str], context_budget) -> "PromptLayout":
        """Arrange ``sources`` into a prefix and a suffix and record reuse.

        Each source is rendered as ``"# <name>\\n<content>"``; sections are
        joined with a blank line. Sections keep the iteration order of
        ``sources``, so passing the same stable sources in a different order
        yields a different prefix and is counted as a miss.

        Token counts are a rough estimate of ``len(content) // 4`` per source;
        the ``# <name>`` headers and the separators are not counted.

        Args:
            sources: Mapping of source name to its text content.
            context_budget: Object exposing ``cache_prefix``,
                ``dynamic_suffix`` and ``keep_exact`` attributes, each
                supporting ``in`` membership tests on source names.

        Returns:
            A ``PromptLayout`` with the joined prefix/suffix text, their
            estimated token counts, ``cache_boundary`` equal to the prefix
            token estimate, and the source names placed on each side.
        """
        prefix_parts: List[str] = []
        suffix_parts: List[str] = []
        stable: List[str] = []
        dynamic: List[str] = []
        prefix_tokens = 0
        suffix_tokens = 0

        for source_name, content in sources.items():
            token_est = len(content) // 4
            section = f"# {source_name}\n{content}"
            # Precedence matters when a name is listed more than once:
            # cache_prefix wins, then dynamic_suffix, then keep_exact.
            # Names listed nowhere go to the suffix.
            in_prefix = source_name in context_budget.cache_prefix or (
                source_name not in context_budget.dynamic_suffix
                and source_name in context_budget.keep_exact
            )
            if in_prefix:
                prefix_parts.append(section)
                prefix_tokens += token_est
                stable.append(source_name)
            else:
                suffix_parts.append(section)
                suffix_tokens += token_est
                dynamic.append(source_name)

        prefix = "\n\n".join(prefix_parts)
        suffix = "\n\n".join(suffix_parts)

        # The digest is only a change-detection fingerprint. SHA-256 is used
        # rather than MD5 because MD5 can be unavailable on FIPS-restricted
        # Python builds, where hashlib.md5() raises ValueError.
        prefix_hash = hashlib.sha256(prefix.encode()).hexdigest()
        if self._prev_prefix_hash == prefix_hash:
            self.cache_hits += 1
        else:
            self.cache_misses += 1
        self._prev_prefix_hash = prefix_hash

        self.total_prefix_tokens += prefix_tokens
        self.total_suffix_tokens += suffix_tokens

        return PromptLayout(
            prefix_content=prefix,
            suffix_content=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary=prefix_tokens,
            stable_sources=stable,
            dynamic_sources=dynamic,
        )

    def stats(self) -> Dict:
        """Return aggregate hit/miss and token statistics for this instance.

        Returns:
            A dict with keys ``cache_hit_rate``, ``total_cache_hits``,
            ``total_cache_misses``, ``avg_prefix_tokens`` and
            ``avg_suffix_tokens``. Rates and averages are per ``layout`` call;
            before any call the denominator is clamped to 1, so they are 0.
        """
        total = self.cache_hits + self.cache_misses
        calls = max(total, 1)  # avoid division by zero before the first call
        return {
            "cache_hit_rate": self.cache_hits / calls,
            "total_cache_hits": self.cache_hits,
            "total_cache_misses": self.cache_misses,
            "avg_prefix_tokens": self.total_prefix_tokens / calls,
            "avg_suffix_tokens": self.total_suffix_tokens / calls,
        }
|
|