Spaces:
Runtime error
Runtime error
import sqlite3 | |
from pathlib import Path | |
import pandas as pd | |
import buster.documents.sqlite.schema as schema | |
from buster.retriever.base import ALL_SOURCES, Retriever | |
class SQLiteRetriever(Retriever):
    """Simple SQLite database for retrieval of documents.

    The database is just a file on disk. It can store documents from different sources, and it
    can store multiple versions of the same document (e.g. if the document is updated).

    Example:
        >>> db = SQLiteRetriever("/path/to/the/db.db")
        >>> df = db.get_documents("source")
    """

    def __init__(self, db_path: sqlite3.Connection | str | Path):
        """Open a database file, or wrap an already-open connection.

        Args:
            db_path: Either a filesystem path (``str`` or ``Path``) to the SQLite
                file, or an existing ``sqlite3.Connection``. When a path is given,
                this object opens the connection itself and closes it on
                finalization; a connection passed in by the caller is never closed
                here.
        """
        if isinstance(db_path, (str, Path)):
            self.db_path = db_path
            self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
        else:
            # A caller-supplied connection: db_path=None marks it as not owned by us.
            self.db_path = None
            self.conn = db_path
        # NOTE(review): presumably creates the expected tables when missing -- confirm in schema.setup_db.
        schema.setup_db(self.conn)

    def __del__(self):
        """Close the connection, but only if this object opened it."""
        # __init__ may have raised before either attribute was assigned (e.g. when
        # sqlite3.connect fails), so look both up defensively to keep the finalizer
        # from raising AttributeError on a half-constructed instance.
        conn = getattr(self, "conn", None)
        owns_connection = getattr(self, "db_path", None) is not None
        if owns_connection and conn is not None:
            conn.close()

    def get_documents(self, source: str) -> pd.DataFrame:
        """Fetch rows from ``documents`` as a DataFrame.

        Args:
            source: Value matched against the ``source`` column. The empty string
                disables the filter, so every row is returned.

        Returns:
            A DataFrame with one row per result and columns named after the
            columns reported by the query cursor.
        """
        # Compare by value: identity (`is`) against a string literal is unreliable.
        if source == "":
            results = self.conn.execute("SELECT * FROM documents")
        else:
            results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
        rows = results.fetchall()

        # The first item of each cursor.description entry is the column name.
        column_names = [description[0] for description in results.description]
        return pd.DataFrame(rows, columns=column_names)

    def get_source_display_name(self, source: str) -> str:
        """Look up the display name of a source.

        Args:
            source: Name of the source as stored in ``sources.name``. The empty
                string yields ``ALL_SOURCES`` without querying the database.

        Returns:
            The ``display_name`` stored for ``source``, or ``ALL_SOURCES``.

        Raises:
            KeyError: If no row in ``sources`` has this name.
        """
        # Compare by value: identity (`is`) against a string literal is unreliable.
        if source == "":
            return ALL_SOURCES

        cur = self.conn.execute("SELECT display_name FROM sources WHERE name = ?", (source,))
        row = cur.fetchone()
        if row is None:
            raise KeyError(f'"{source}" is not a known source')
        (display_name,) = row
        return display_name