import sqlite3
from contextlib import closing
from pathlib import Path


class Database:
    """Manage the SQLite file ``models.db`` stored inside a directory.

    Constructing an instance creates and initialises the database from a SQL
    schema script if the file does not exist yet. The instance is also a
    context manager: ``with Database(path) as db:`` yields a new connection
    whose rows are :class:`sqlite3.Row` objects, and closes it on exit.
    """

    def __init__(self, db_path=None, schema_path="schema.sql"):
        """Locate the database file and create it from the schema if missing.

        Args:
            db_path: Directory that holds ``models.db``. A string or any
                path-like object is accepted.
            schema_path: SQL script executed once to initialise a brand-new
                database. Relative paths resolve against the current working
                directory.

        Raises:
            ValueError: If ``db_path`` is None.
            OSError: If the schema script cannot be read. No database file is
                left behind in that case.
            sqlite3.Error: If the schema script fails. The partially created
                database file is removed before the error propagates.
        """
        if db_path is None:
            raise ValueError("db_path must be provided")
        # Coerce so plain strings also work with the ``/`` operator below.
        self.db_path = Path(db_path)
        self.db_file = self.db_path / "models.db"
        if not self.db_file.exists():
            print("Creating database")
            print("DB_FILE", self.db_file)
            self._create_schema(schema_path)

    def _create_schema(self, schema_path):
        """Create ``self.db_file`` and run the schema script against it."""
        # Read the script before connecting. sqlite3.connect creates the file,
        # and an empty file would make every later run skip schema creation.
        script = Path(schema_path).read_text(encoding="utf-8")
        try:
            # closing() guarantees the connection is released even when the
            # script fails part-way through.
            with closing(sqlite3.connect(self.db_file)) as db:
                db.executescript(script)
                db.commit()
        except BaseException:
            # Don't leave a half-initialised database that later runs would
            # mistake for a complete one.
            self.db_file.unlink(missing_ok=True)
            raise

    def get_db(self):
        """Open and return a new connection to the database file.

        ``check_same_thread=False`` lets the connection be used from threads
        other than the one that created it. Rows come back as
        :class:`sqlite3.Row`, so columns can be read by name. The caller owns
        the returned connection and must close it.
        """
        db = sqlite3.connect(self.db_file, check_same_thread=False)
        db.row_factory = sqlite3.Row
        return db

    def __enter__(self):
        """Open a connection, remember it on ``self.db`` and return it."""
        self.db = self.get_db()
        return self.db

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection opened by ``__enter__``.

        No commit is issued here. Changes the caller did not commit
        explicitly are discarded when the connection closes. Exceptions
        raised inside the ``with`` body are not suppressed.
        """
        self.db.close()