File size: 2,071 Bytes
06bca0c
 
 
 
 
 
6aad21a
06bca0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aad21a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sqlite3
from pathlib import Path

import pandas as pd

import buster.documents.sqlite.schema as schema
from buster.retriever.base import ALL_SOURCES, Retriever


class SQLiteRetriever(Retriever):
    """Simple SQLite database for retrieval of documents.

    The database is just a file on disk. It can store documents from different sources, and it
    can store multiple versions of the same document (e.g. if the document is updated).

    Example:
        >>> db = SQLiteRetriever("/path/to/the/db.db")
        >>> df = db.get_documents("source")
    """

    def __init__(self, db_path: sqlite3.Connection | str | Path):
        """Open (or adopt) the SQLite connection and run the schema setup.

        Args:
            db_path: Either a filesystem path (``str`` or ``Path``) to the database file,
                or an already-open ``sqlite3.Connection``. When a path is given, this
                object opens the connection itself and closes it on finalization. When a
                connection is given, the caller keeps ownership and it is never closed here.
        """
        if isinstance(db_path, (str, Path)):
            self.db_path = db_path
            # PARSE_DECLTYPES lets sqlite3 convert columns based on their declared type;
            # check_same_thread=False allows the connection to be used from other threads.
            self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
        else:
            # db_path is None marks the connection as externally owned (see __del__).
            self.db_path = None
            self.conn = db_path
        # NOTE(review): presumably creates the expected tables if missing -- confirm in schema.setup_db.
        schema.setup_db(self.conn)

    def __del__(self):
        """Close the connection, but only if this object opened it."""
        # Use getattr: if __init__ raised (e.g. sqlite3.connect failed), the attributes
        # may never have been assigned, and a finalizer must not raise AttributeError.
        conn = getattr(self, "conn", None)
        if getattr(self, "db_path", None) is not None and conn is not None:
            conn.close()

    def get_documents(self, source: str) -> pd.DataFrame:
        """Get all current documents from a given source.

        Args:
            source: Name of the source to filter on. The empty string selects every row
                of the ``documents`` table regardless of source.

        Returns:
            A DataFrame with one row per matching record and one column per column of
            the ``documents`` table.
        """
        # Compare by value: `source is ""` relies on string interning and is not reliable.
        if source == "":
            results = self.conn.execute("SELECT * FROM documents")
        else:
            results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
        rows = results.fetchall()

        # Column names come from the cursor metadata so the frame mirrors the table layout.
        columns = [description[0] for description in results.description]
        return pd.DataFrame(rows, columns=columns)

    def get_source_display_name(self, source: str) -> str:
        """Get the display name of a source.

        Args:
            source: Name of the source as stored in the ``sources`` table. The empty
                string yields the ``ALL_SOURCES`` constant instead of a database lookup.

        Returns:
            The ``display_name`` recorded for ``source``, or ``ALL_SOURCES`` for ``""``.

        Raises:
            KeyError: If ``source`` is not present in the ``sources`` table.
        """
        # Compare by value: `source is ""` relies on string interning and is not reliable.
        if source == "":
            return ALL_SOURCES

        cur = self.conn.execute("SELECT display_name FROM sources WHERE name = ?", (source,))
        row = cur.fetchone()
        if row is None:
            raise KeyError(f'"{source}" is not a known source')
        (display_name,) = row
        return display_name