jerpint's picture
last fixes
2a23253
import sqlite3
from pathlib import Path
import pandas as pd
import buster.documents.sqlite.schema as schema
from buster.retriever.base import ALL_SOURCES, Retriever
class SQLiteRetriever(Retriever):
"""Simple SQLite database for retrieval of documents.
The database is just a file on disk. It can store documents from different sources, and it
can store multiple versions of the same document (e.g. if the document is updated).
Example:
>>> db = DocumentsDB("/path/to/the/db.db")
>>> df = db.get_documents("source")
"""
def __init__(self, db_path: sqlite3.Connection | str):
if isinstance(db_path, (str, Path)):
self.db_path = db_path
self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
else:
self.db_path = None
self.conn = db_path
schema.setup_db(self.conn)
def __del__(self):
if self.db_path is not None:
self.conn.close()
def get_documents(self, source: str) -> pd.DataFrame:
"""Get all current documents from a given source."""
# Execute the SQL statement and fetch the results.
if source is "":
results = self.conn.execute("SELECT * FROM documents")
else:
results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
rows = results.fetchall()
# Convert the results to a pandas DataFrame
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
return df
def get_source_display_name(self, source: str) -> str:
"""Get the display name of a source."""
if source is "":
return ALL_SOURCES
else:
cur = self.conn.execute("SELECT display_name FROM sources WHERE name = ?", (source,))
row = cur.fetchone()
if row is None:
raise KeyError(f'"{source}" is not a known source')
(display_name,) = row
return display_name