# TherapyNote / utils / cache.py
# abagherp's picture
# Upload folder using huggingface_hub
# 6830eb0 verified
from __future__ import annotations

import hashlib
import json
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any
class CacheManager:
    """SQLite-backed cache for video transcripts and extraction results.

    All data lives in ``<cache_dir>/extraction_cache.db`` in two tables:

    * ``transcripts`` -- one row per ``video_id``.
    * ``extractions`` -- one row per ``(input_hash, form_type)``, where
      ``input_hash`` is the SHA-256 of the input text.

    Row timestamps are written by SQLite itself (``datetime('now')``, i.e.
    UTC in ``YYYY-MM-DD HH:MM:SS`` form) so that they are directly comparable
    with the cutoff computed in :meth:`clear_cache`.
    """

    def __init__(self, cache_dir: str | Path = "cache"):
        """Create the cache directory and database if they do not exist.

        Args:
            cache_dir: Directory that holds the cache database. Missing
                parent directories are created.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Create SQLite database for structured results
        self.db_path = self.cache_dir / "extraction_cache.db"
        self._init_db()

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed (or rolled back) and closed.

        ``with sqlite3.connect(...)`` on its own only manages the
        transaction; it never closes the connection. Closing explicitly here
        keeps every cache call from leaking a database handle.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Initialize the SQLite database with necessary tables."""
        with self._connection() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS extractions (
                    input_hash TEXT,
                    form_type TEXT,
                    result TEXT,
                    model_name TEXT,
                    timestamp DATETIME,
                    PRIMARY KEY (input_hash, form_type)
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS transcripts (
                    video_id TEXT PRIMARY KEY,
                    transcript TEXT,
                    timestamp DATETIME
                )
            """)

    def _hash_content(self, content: str) -> str:
        """Return the hex SHA-256 digest of ``content`` encoded as UTF-8."""
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def get_transcript(self, video_id: str) -> str | None:
        """Return the cached transcript for ``video_id``, or ``None`` on a miss."""
        with self._connection() as conn:
            row = conn.execute(
                "SELECT transcript FROM transcripts WHERE video_id = ?",
                (video_id,),
            ).fetchone()
        return row[0] if row else None

    def store_transcript(self, video_id: str, transcript: str):
        """Insert or replace the transcript cached for ``video_id``."""
        with self._connection() as conn:
            # The timestamp is generated by SQLite (UTC) rather than by a
            # local-time Python datetime, so clear_cache() compares like
            # with like.
            conn.execute(
                """
                INSERT OR REPLACE INTO transcripts (video_id, transcript, timestamp)
                VALUES (?, ?, datetime('now'))
                """,
                (video_id, transcript),
            )

    def get_extraction(
        self,
        input_content: str,
        form_type: str,
        model_name: str,
    ) -> dict[str, Any] | None:
        """Return the cached extraction result, or ``None`` on a miss.

        A row matches only if the hash of ``input_content``, the
        ``form_type`` and the ``model_name`` all match.

        Returns:
            The stored result decoded from JSON, or ``None`` if no matching
            row exists.
        """
        input_hash = self._hash_content(input_content)
        with self._connection() as conn:
            row = conn.execute(
                """
                SELECT result FROM extractions
                WHERE input_hash = ? AND form_type = ? AND model_name = ?
                """,
                (input_hash, form_type, model_name),
            ).fetchone()
        if row:
            return json.loads(row[0])
        return None

    def store_extraction(
        self,
        input_content: str,
        form_type: str,
        result: dict,
        model_name: str,
    ):
        """Insert or replace the extraction result for an input and form type.

        ``result`` is serialized with ``json.dumps`` and must therefore be
        JSON-serializable.

        Note:
            The table's primary key is ``(input_hash, form_type)``, so
            storing a result for the same input and form type under a
            different ``model_name`` replaces the earlier model's row.
        """
        input_hash = self._hash_content(input_content)
        with self._connection() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO extractions
                (input_hash, form_type, result, model_name, timestamp)
                VALUES (?, ?, ?, ?, datetime('now'))
                """,
                (input_hash, form_type, json.dumps(result), model_name),
            )

    def clear_cache(self, older_than_days: int | None = None):
        """Clear the cache, optionally only entries older than specified days.

        Args:
            older_than_days: If ``None``, delete every row from both tables.
                Otherwise delete only rows whose timestamp is earlier than
                ``older_than_days`` days before the current UTC time.

        Raises:
            ValueError: If ``older_than_days`` is negative. (A negative value
                would produce an invalid SQLite date modifier, which makes
                the comparison NULL and silently deletes nothing.)

        Note:
            Rows written before timestamps were switched to UTC hold naive
            local-time values, so their computed age can be off by the local
            UTC offset.
        """
        if older_than_days is not None and older_than_days < 0:
            raise ValueError("older_than_days must be non-negative")

        with self._connection() as conn:
            if older_than_days is None:
                conn.execute("DELETE FROM extractions")
                conn.execute("DELETE FROM transcripts")
                return

            modifier = f"-{older_than_days} days"
            conn.execute(
                """
                DELETE FROM extractions
                WHERE timestamp < datetime('now', ?)
                """,
                (modifier,),
            )
            conn.execute(
                """
                DELETE FROM transcripts
                WHERE timestamp < datetime('now', ?)
                """,
                (modifier,),
            )

    def cleanup_gradio_cache(self):
        """Clean up Gradio's example cache directory.

        Removes the ``.gradio`` directory, resolved relative to the current
        working directory, if it exists.
        """
        gradio_cache = Path(".gradio")
        if gradio_cache.exists():
            import shutil

            shutil.rmtree(gradio_cache)
            print("Cleaned up Gradio cache")