# Spaces: Runtime error
# (hosting-page status banner captured together with the source; not part of the module)
import logging
import sqlite3
from collections import namedtuple
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator

from gistillery.config import get_config
# Module-level logger; its level is pinned to DEBUG here, so every record
# emitted by this module passes the logger-level check and reaches the handlers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# DDL for the 'entries' table. The primary key is a TEXT id that the other
# tables reference through their entry_id foreign key.
# IF NOT EXISTS keeps creation idempotent when two processes initialise the
# same database file at the same time.
schema_entries = """
CREATE TABLE IF NOT EXISTS entries
(
    id TEXT PRIMARY KEY,
    author TEXT NOT NULL,
    source TEXT NOT NULL,
    source_snippet TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
# DDL for the 'summaries' table. Its own id is an auto-incremented integer;
# entry_id points at entries(id), the TEXT entry id (presumably a uuid4 --
# confirm against the code that inserts entries).
# IF NOT EXISTS keeps creation idempotent when two processes initialise the
# same database file at the same time.
schema_summary = """
CREATE TABLE IF NOT EXISTS summaries
(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    entry_id TEXT NOT NULL,
    summary TEXT NOT NULL,
    summarizer_name TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY(entry_id) REFERENCES entries(id)
)
"""
# DDL for the 'tags' table: one row per (entry, tag), recording which tagger
# produced it. IF NOT EXISTS keeps creation idempotent when two processes
# initialise the same database file at the same time.
schema_tag = """
CREATE TABLE IF NOT EXISTS tags
(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    entry_id TEXT NOT NULL,
    tag TEXT NOT NULL,
    tagger_name TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY(entry_id) REFERENCES entries(id)
)
"""
# DDL for the 'jobs' table: one row per job attached to an entry; status
# defaults to 'pending'. IF NOT EXISTS keeps creation idempotent when two
# processes initialise the same database file at the same time.
schema_job = """
CREATE TABLE IF NOT EXISTS jobs
(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    entry_id TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'pending',
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY(entry_id) REFERENCES entries(id)
)
"""
# DDL for the 'inputs' table, which stores the processed inputs per entry.
# IF NOT EXISTS keeps creation idempotent when two processes initialise the
# same database file at the same time.
schema_inputs = """
CREATE TABLE IF NOT EXISTS inputs
(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    entry_id TEXT NOT NULL,
    input TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY(entry_id) REFERENCES entries(id)
)
"""
# Maps each table name to the DDL statement that creates it.
# get_db_connection() walks this dict in insertion order, so 'entries'
# (the table the others reference) is listed first.
TABLES = {
    "entries": schema_entries,
    "summaries": schema_summary,
    "tags": schema_tag,
    "jobs": schema_job,
    "inputs": schema_inputs,
}

# Process-wide flag: flipped to True by get_db_connection() once the
# table-existence check has completed, so later connections skip it.
TABLES_CREATED = False
# https://docs.python.org/3/library/sqlite3.html#how-to-create-and-use-row-factories
@lru_cache(maxsize=128)
def _row_class(fields):  # type: ignore
    """Return the namedtuple class for one tuple of result-column names.

    Memoized so that a class is built once per distinct column layout rather
    than once per fetched row. ``rename=True`` replaces column names that are
    not valid identifiers (e.g. ``COUNT(*)``) or that repeat (e.g. two ``id``
    columns in a join) with positional names ``_0``, ``_1``, ... instead of
    raising ``ValueError``.
    """
    return namedtuple("Row", fields, rename=True)  # type: ignore


def namedtuple_factory(cursor, row):  # type: ignore
    """sqlite3 row factory that returns each row as a namedtuple.

    Args:
        cursor: the cursor that produced the row; its ``description`` supplies
            the column names.
        row: the raw row tuple handed over by sqlite3.

    Returns:
        A ``Row`` namedtuple whose fields are the result-column names, in
        column order.
    """
    # lru_cache needs a hashable key, hence a tuple rather than a list.
    fields = tuple(column[0] for column in cursor.description)
    return _row_class(fields)._make(row)
def get_db_connection() -> sqlite3.Connection:
    """Open a connection to the configured SQLite file, creating tables if needed.

    The first successful call in a process checks every table listed in
    ``TABLES`` and creates the missing ones; it then sets ``TABLES_CREATED``
    so that later calls skip the check.

    Returns:
        An open connection whose ``row_factory`` is ``namedtuple_factory``.
        The caller is responsible for closing it.

    Raises:
        sqlite3.Error: if connecting or creating a table fails. The connection
            is closed before the error propagates.
    """
    global TABLES_CREATED

    # sqlite cannot deal with concurrent access, so we set a big timeout
    conn = sqlite3.connect(get_config().db_file_name, timeout=30)
    conn.row_factory = namedtuple_factory
    if TABLES_CREATED:
        return conn

    try:
        cursor = conn.cursor()
        # create tables if needed
        for table_name, schema in TABLES.items():
            cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,),
            )
            if cursor.fetchone() is not None:
                continue
            logger.info("'%s' table does not exist, creating it now...", table_name)
            cursor.execute(schema)
            conn.commit()
            logger.info("done")
        cursor.close()
    except Exception:
        # Do not leak the connection when schema setup fails; the flag stays
        # False so the next call retries the check.
        conn.close()
        raise

    TABLES_CREATED = True
    return conn
@contextmanager
def get_db_cursor() -> Generator[sqlite3.Cursor, None, None]:
    """Context manager yielding a cursor on a fresh database connection.

    Usage: ``with get_db_cursor() as cursor: ...``. On leaving the ``with``
    block the connection is committed, then the cursor and the connection are
    closed.

    Yields:
        A cursor belonging to a connection from ``get_db_connection()``.
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        yield cursor
    finally:
        # NOTE(review): the commit also runs when the body raised, so partial
        # work is persisted; kept as-is -- confirm whether a rollback is wanted.
        try:
            conn.commit()
        finally:
            # Release the handles even if the commit itself fails.
            cursor.close()
            conn.close()