Spaces:
Runtime error
Runtime error
import datetime as dt | |
import logging | |
import reprlib | |
import sqlite3 | |
import uuid | |
from fastapi import FastAPI | |
from fastapi.responses import FileResponse | |
from gistillery.base import EntriesResult, JobStatus, JobStatusResult, RequestInput | |
from gistillery.db import TABLES, get_db_connection, get_db_cursor | |
# Module-level logger; DEBUG is enabled so that the per-page progress
# messages emitted while creating a backup are not filtered out here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# NOTE(review): no route decorators are visible in this view of the file --
# confirm that the handlers below are registered on `app`.
app = FastAPI()

# Size-limited repr used to build the short `source_snippet` stored with each
# entry; string reprs are capped at 140 characters.
aRepr = reprlib.Repr()
aRepr.maxstring = 140
def status() -> str:
    """Report that the service is up.

    Returns:
        The literal string ``"OK"``.
    """
    health_marker = "OK"
    return health_marker
def submit_job(input: RequestInput) -> str:
    """Register a new job and its entry in the database.

    A freshly generated hex UUID identifies both rows: one in ``jobs`` with
    status ``"pending"`` and one in ``entries`` holding the author, the full
    content and a shortened snippet of the content. The ``jobs`` table acts
    as a minimal job queue.

    Args:
        input: Request payload; its ``author`` and ``content`` attributes
            are stored.

    Returns:
        A confirmation message containing the new id.
    """
    job_id = uuid.uuid4().hex
    logger.info(f"Submitting job for (_id={job_id[:8]})")

    insert_job = "INSERT INTO jobs (entry_id, status) VALUES (?, ?)"
    insert_entry = (
        "INSERT INTO entries (id, author, source, source_snippet) VALUES "
        "(?, ?, ?, ?)"
    )

    with get_db_cursor() as cursor:
        cursor.execute(insert_job, (job_id, "pending"))

        # aRepr yields a length-limited repr of the content; strip the single
        # quotes that repr() wraps around the string.
        snippet = aRepr.repr(input.content).strip("'")
        cursor.execute(
            insert_entry, (job_id, input.author, input.content, snippet)
        )

    return f"Submitted job {job_id}"
def check_job_status() -> str:
    """Summarize the jobs that are still pending.

    Returns:
        ``"No pending jobs found"`` when there are none. Otherwise a message
        with the total number of pending jobs and up to three of their entry
        ids (least recently updated first), followed by ``"..."`` when more
        than three exist.
    """
    with get_db_cursor() as cursor:
        cursor.execute(
            "SELECT entry_id "
            "FROM jobs WHERE status = 'pending' "
            "ORDER BY last_updated ASC"
        )
        pending_rows = cursor.fetchall()

    if not pending_rows:
        return "No pending jobs found"

    total = len(pending_rows)
    # Only the first three ids are listed; an ellipsis marks the remainder.
    shown = [row.entry_id for row in pending_rows[:3]]
    if total > 3:
        shown.append("...")
    return f"Found {total} pending job(s): {', '.join(shown)}"
def check_job_status_id(_id: str) -> JobStatusResult:
    """Look up the status of a single job.

    Args:
        _id: Entry id the job was created for.

    Returns:
        The stored status and ``last_updated`` value of the job, or a result
        with ``JobStatus.not_found`` and ``last_updated=None`` when no job
        row matches the id.
    """
    with get_db_cursor() as cursor:
        cursor.execute(
            "SELECT status, last_updated FROM jobs WHERE entry_id = ?", (_id,)
        )
        row = cursor.fetchone()

    if row is not None:
        job_status, updated_at = row
        return JobStatusResult(id=_id, status=job_status, last_updated=updated_at)

    return JobStatusResult(id=_id, status=JobStatus.not_found, last_updated=None)
def recent() -> list[EntriesResult]:
    """Return the most recent entries with their summary and tags.

    At most 100 entries are returned, newest ``created_at`` first. The query
    uses inner joins, so only entries that have a summary and at least one
    tag are included.

    Returns:
        One ``EntriesResult`` per entry, with ``tags`` as a list of strings.
    """
    with get_db_cursor() as cursor:
        # GROUP_CONCAT folds all tags of an entry into a single comma
        # separated string; it is split back into a list below.
        cursor.execute("""
            SELECT
                e.id, e.author, e.created_at, e.source_snippet,
                s.summary, GROUP_CONCAT(t.tag, ",") tags
            FROM entries e
            JOIN summaries s ON e.id = s.entry_id
            JOIN tags t ON e.id = t.entry_id
            GROUP BY e.id
            ORDER BY e.created_at DESC
            LIMIT 100
        """)
        rows = cursor.fetchall()

    return [
        EntriesResult(
            id=row.id,
            author=row.author,
            summary=row.summary,
            source_snippet=row.source_snippet,
            tags=row.tags.split(","),
            date=row.created_at,
        )
        for row in rows
    ]
def recent_tag(tag: str) -> list[EntriesResult]:
    """Return the most recent entries that carry at least one of the given tags.

    Works like ``recent`` (at most 100 entries, newest ``created_at`` first,
    only entries with a summary and at least one tag), restricted to entries
    for which at least one of the requested tags is stored.

    Args:
        tag: One tag, or several separated by commas. Each tag may be given
            with or without the leading ``#``; the prefix is added where it
            is missing.

    Returns:
        One ``EntriesResult`` per matching entry. ``tags`` lists all tags of
        the entry, not only the ones that matched.
    """
    # Normalize every requested tag to the "#tag" form. A conditional
    # expression is required here: using the startswith() check as a
    # comprehension filter would silently drop tags that already carry the
    # prefix instead of keeping them.
    tags = [
        name if name.startswith("#") else "#" + name
        for name in tag.split(",")
    ]

    # One "?" placeholder per tag, so the values themselves are always bound
    # as parameters and never interpolated into the SQL text.
    placeholders = ",".join("?" * len(tags))

    with get_db_cursor() as cursor:
        cursor.execute(
            """
            SELECT
                e.id, e.author, e.source_snippet, e.created_at,
                s.summary, GROUP_CONCAT(t.tag, ",") tags
            FROM entries e
            JOIN summaries s ON e.id = s.entry_id
            JOIN tags t ON e.id = t.entry_id
            WHERE e.id IN (
                SELECT entry_id FROM tags WHERE tag IN ({})
            )
            GROUP BY e.id
            ORDER BY e.created_at DESC
            LIMIT 100
            """.format(placeholders),
            tags,
        )
        rows = cursor.fetchall()

    return [
        EntriesResult(
            id=row.id,
            author=row.author,
            summary=row.summary,
            source_snippet=row.source_snippet,
            tags=row.tags.split(","),
            date=row.created_at,
        )
        for row in rows
    ]
def tag_counts() -> dict[str, int]:
    """Count how often each tag occurs in the ``tags`` table.

    Returns:
        A mapping from tag to its number of occurrences, inserted in order
        of descending count.
    """
    count_query = """
        SELECT tag, COUNT(*) count
        FROM tags
        GROUP BY tag
        ORDER BY count DESC
    """
    with get_db_cursor() as cursor:
        cursor.execute(count_query)
        rows = cursor.fetchall()

    # Every row is a (tag, count) pair, which dict() consumes directly.
    return dict(rows)
def clear() -> str:
    """Delete all rows from every table listed in ``TABLES``.

    Returns:
        The literal string ``"OK"`` once all tables have been emptied.
    """
    logger.warning("Clearing all tables")

    with get_db_cursor() as cursor:
        # Table names cannot be bound as SQL parameters, hence the f-string;
        # the names come from the TABLES constant imported from
        # gistillery.db, not from request data.
        delete_statements = [f"DELETE FROM {name}" for name in TABLES]
        for statement in delete_statements:
            cursor.execute(statement)

    return "OK"
def backup() -> FileResponse:
    """Copy the database into a timestamped file and return it for download.

    The backup is written to ``sqlite-data_backup_<UTC timestamp>.db``
    (relative path) using SQLite's online backup API.

    Returns:
        A ``FileResponse`` serving the backup file as
        ``application/octet-stream``.

    Raises:
        Exception: Any error raised while opening the connections or copying
            the data is logged and re-raised unchanged.
    """

    def progress(status: int, remaining: int, total: int) -> None:
        # Invoked by sqlite3 after each backup step.
        logger.debug(f"DB: Copied {total-remaining} of {total} pages...")

    now = dt.datetime.now(dt.timezone.utc)
    fname = f"sqlite-data_backup_{now.strftime('%Y-%m-%d_%H-%M-%S')}.db"

    # Start as None so the cleanup below never touches an unbound name when
    # opening one of the connections is what failed.
    conn = None
    backup_db = None
    try:
        conn = get_db_connection()
        backup_db = sqlite3.connect(fname)
        # `with` on a sqlite3 connection only commits or rolls back; it does
        # not close the connection, which is why `finally` does that.
        with backup_db:
            conn.backup(backup_db, pages=1, progress=progress)
    except Exception as e:
        logger.error(f"Error creating backup: {e}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        # Release both handles on success as well as on failure.
        # NOTE(review): assumes get_db_connection() hands out a connection
        # owned by the caller (the previous error path already closed it) --
        # confirm it is not a shared connection.
        if backup_db is not None:
            backup_db.close()
        if conn is not None:
            conn.close()

    return FileResponse(fname, media_type="application/octet-stream", filename=fname)