import sqlite3
from contextlib import closing
from pathlib import Path


class Database:
    """SQLite-backed cache mapping (prompt, negative_prompt, seed) to image paths.

    The database lives at ``<db_path>/cache.db``. When that file does not yet
    exist it is created and initialised by running the SQL script at
    ``schema_path``. That script must create the ``cache`` table used by
    :meth:`insert` and :meth:`check` (columns ``prompt``, ``negative_prompt``,
    ``image_path``, ``seed``).

    :meth:`insert` and :meth:`check` each open and close their own connection.
    The instance can also be used as a context manager (``with db as conn:``)
    to get a raw connection; nested ``with`` blocks each get and close their
    own connection. Changes made through such a connection are not committed
    on exit, so callers must commit explicitly.
    """

    def __init__(self, db_path=None, schema_path="schema.sql"):
        """Open the cache database, creating it first if needed.

        Args:
            db_path: Directory that holds ``cache.db`` (``str`` or ``Path``).
            schema_path: SQL script used to initialise a new database.
                Defaults to ``schema.sql`` resolved against the current
                working directory (the historical behavior).

        Raises:
            ValueError: If ``db_path`` is not given.
            OSError: If the schema file cannot be read while creating the DB.
            sqlite3.Error: If the schema script fails. The partially created
                database file is removed so a later attempt re-initialises it.
        """
        if db_path is None:
            raise ValueError("db_path must be provided")
        # Accept plain strings too; the "/" join below requires a Path.
        self.db_path = Path(db_path)
        self.db_file = self.db_path / "cache.db"
        # Connections handed out by __enter__. This is a stack so that each
        # nested ``with`` block closes exactly the connection it opened.
        self._open_connections = []
        if not self.db_file.exists():
            self._create(Path(schema_path))

    def _create(self, schema_path):
        """Create ``self.db_file`` and run the schema script against it."""
        print("Creating database")
        print("DB_FILE", self.db_file)
        # Read the schema before connecting. sqlite3.connect creates the file,
        # and an empty leftover file would make exists() skip init forever.
        schema = schema_path.read_text(encoding="utf-8")
        db = sqlite3.connect(self.db_file)
        try:
            try:
                db.executescript(schema)
                db.commit()
            finally:
                db.close()
        except sqlite3.Error:
            # Don't leave a half-initialised database behind.
            self.db_file.unlink(missing_ok=True)
            raise

    def get_db(self):
        """Return a new connection to the cache DB with ``sqlite3.Row`` rows.

        ``check_same_thread=False`` lets the connection be used from a thread
        other than the one that created it. The caller owns the connection
        and must close it.
        """
        db = sqlite3.connect(self.db_file, check_same_thread=False)
        db.row_factory = sqlite3.Row
        return db

    def __enter__(self):
        """Open a new connection, track it, and return it."""
        db = self.get_db()
        self._open_connections.append(db)
        # Kept for backward compatibility: callers may read ``self.db``.
        self.db = db
        return db

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection opened by the matching ``__enter__``.

        Pending changes are not committed here. Exceptions are not suppressed.
        """
        self._open_connections.pop().close()
        if self._open_connections:
            # Point ``self.db`` back at the still-open outer connection.
            self.db = self._open_connections[-1]

    def __call__(self):
        """Return ``self`` so that ``with database():`` also works."""
        return self

    def insert(self, prompt: str, negative_prompt: str, image_path: str, seed: int):
        """Record that ``image_path`` was produced for the given inputs."""
        # Use a dedicated connection per call rather than ``with self()``, so
        # this method never touches the shared ``self.db`` state.
        with closing(self.get_db()) as db:
            db.execute(
                "INSERT INTO cache (prompt, negative_prompt, image_path, seed) "
                "VALUES (?, ?, ?, ?)",
                (prompt, negative_prompt, image_path, seed),
            )
            db.commit()

    def check(self, prompt: str, negative_prompt: str, seed: int):
        """Look up a cached image for the given inputs.

        Returns:
            A ``sqlite3.Row`` with one ``image_path`` column, or ``False`` if
            nothing matches. When several rows match, one is picked at random.
        """
        with closing(self.get_db()) as db:
            row = db.execute(
                "SELECT image_path FROM cache "
                "WHERE prompt = ? AND negative_prompt = ? AND seed = ? "
                "ORDER BY RANDOM() LIMIT 1",
                (prompt, negative_prompt, seed),
            ).fetchone()
        return row if row is not None else False