import itertools
import sqlite3
from pathlib import Path
from typing import Iterable, NamedTuple

import numpy as np
import pandas as pd

import buster.documents.sqlite.schema as schema
from buster.documents.base import DocumentsManager


class Section(NamedTuple):
    """A parsed section of a document, written to the `sections` table by `DocumentsDB.add_parse`."""

    title: str
    url: str
    content: str
    # NOTE(review): presumably the index of the parent section within the same parse -- confirm against the schema.
    parent: int | None = None
    type: str = "section"


class Chunk(NamedTuple):
    """A piece of a section's content, written to the `chunks` table by `DocumentsDB.add_chunking`."""

    content: str
    n_tokens: int
    emb: np.ndarray


class DocumentsDB(DocumentsManager):
    """Simple SQLite database for storing documents and questions/answers.

    The database is just a file on disk. It can store documents from different sources, and it can store multiple
    versions of the same document (e.g. if the document is updated).
    Questions/answers refer to the version of the document that was used at the time.

    Example:
        >>> db = DocumentsDB("/path/to/the/db.db")
        >>> db.add("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
    """

    def __init__(self, db_path: sqlite3.Connection | str | Path):
        """Open (or adopt) a connection and make sure the schema is set up.

        Args:
            db_path: Either a path to the database file, in which case a new connection is opened and owned by
                this object, or an already-open `sqlite3.Connection`, which is used as-is and never closed here.
        """
        if isinstance(db_path, (str, Path)):
            self.db_path = db_path
            self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
        else:
            # A caller-provided connection: db_path stays None so that __del__ leaves it open.
            self.db_path = None
            self.conn = db_path
        schema.initialize_db(self.conn)
        schema.setup_db(self.conn)

    def __del__(self):
        """Close the connection, but only if this object opened it itself."""
        # getattr guards: __del__ also runs when __init__ raised before the attributes were assigned
        # (e.g. sqlite3.connect failed), in which case there is nothing to close.
        conn = getattr(self, "conn", None)
        if getattr(self, "db_path", None) is not None and conn is not None:
            conn.close()

    def get_current_version(self, source: str) -> tuple[int, int]:
        """Get the current version of a source.

        Args:
            source: Name of the source.

        Returns:
            The `(source id, version id)` pair read from `latest_version`.

        Raises:
            KeyError: If `latest_version` has no row for `source`.
        """
        cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
        row = cur.fetchone()
        if row is None:
            raise KeyError(f'"{source}" is not a known source')
        sid, vid = row
        return sid, vid

    def get_source(self, source: str) -> int:
        """Get the id of a source, inserting the source first if it does not exist yet.

        The insertion is not committed here.
        """
        query = "SELECT id FROM sources WHERE name = ?"
        row = self.conn.execute(query, (source,)).fetchone()
        if row is None:
            self.conn.execute("INSERT INTO sources (name) VALUES (?)", (source,))
            # Read the id back with the same query rather than assuming how the table generates it.
            row = self.conn.execute(query, (source,)).fetchone()
        (sid,) = row
        return sid

    def new_version(self, source: str) -> tuple[int, int]:
        """Create a new version for a source.

        The first version of a source is numbered 0; afterwards the latest version number is incremented by one.
        The source itself is created if needed. Nothing is committed here.

        Returns:
            The `(source id, new version id)` pair.
        """
        cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
        row = cur.fetchone()
        if row is None:
            sid = self.get_source(source)
            vid = 0
        else:
            sid, latest = row
            vid = latest + 1
        self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
        return sid, vid

    def add_parse(self, source: str, sections: Iterable[Section]) -> tuple[int, int]:
        """Create a new version of a source filled with parsed sections.

        Sections are numbered by their position in `sections`. Nothing is committed here.

        Returns:
            The `(source id, version id)` pair of the newly created version.
        """
        sid, vid = self.new_version(source)
        values = (
            (sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
            for ind, section in enumerate(sections)
        )
        self.conn.executemany(
            "INSERT INTO sections "
            "(source, version, section, title, url, content, parent, type) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            values,
        )
        return sid, vid

    def new_chunking(self, sid: int, vid: int, size: int, overlap: int = 0, strategy: str = "simple") -> int:
        """Create a new chunking for a source.

        Nothing is committed here.

        Returns:
            The id of the chunking that was just inserted.

        Raises:
            ValueError: If the lookup after the insert does not yield exactly one chunking with these parameters.
        """
        params = (size, overlap, strategy, sid, vid)
        self.conn.execute(
            "INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
            params,
        )
        cur = self.conn.execute(
            "SELECT chunking FROM chunkings "
            "WHERE size = ? AND overlap = ? AND strategy = ? AND source = ? AND version = ?",
            params,
        )
        # Unpacking enforces that exactly one row matches.
        (row,) = cur.fetchall()
        (cid,) = row
        return cid

    def add_chunking(self, sid: int, vid: int, size: int, sections: Iterable[Iterable[Chunk]]) -> int:
        """Create a new chunking for a source, filled with chunks organized by section.

        Args:
            sid: Id of the source.
            vid: Id of the version.
            size: Chunk size recorded for the chunking.
            sections: One iterable of chunks per section; the position in the outer iterable is used as the
                section number and the position in the inner iterable as the chunk's sequence number.

        Returns:
            The id of the new chunking. Nothing is committed here.
        """
        cid = self.new_chunking(sid, vid, size)
        values = (
            (sid, vid, section_ind, cid, chunk_ind, chunk.content, chunk.n_tokens, chunk.emb)
            for section_ind, section in enumerate(sections)
            for chunk_ind, chunk in enumerate(section)
        )
        self.conn.executemany(
            "INSERT INTO chunks "
            "(source, version, section, chunking, sequence, content, n_tokens, embedding) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            values,
        )
        return cid

    def add(self, source: str, df: pd.DataFrame) -> None:
        """Write all documents from the dataframe into the db as a new version.

        Rows are grouped into sections by `(url, title)`; each row becomes one chunk of its section, and the
        section content is the concatenation of its chunks' content. The chunking size recorded is the length
        (in characters) of the longest chunk.

        The whole write happens in one transaction: it is committed on success and rolled back if anything fails.

        Args:
            source: Name of the source.
            df: DataFrame with `url`, `title`, `content`, `n_tokens` and `embedding` columns.
        """

        def section_key(row):
            return (row.url, row.title)

        # groupby only merges adjacent rows, hence the sort on the same key; sorted() is stable so the
        # relative order of the chunks within a section is the order they have in the dataframe.
        rows = sorted(df.itertuples(), key=section_key)

        sections = []
        size = 0
        for (url, title), group in itertools.groupby(rows, section_key):
            chunks = [Chunk(row.content, row.n_tokens, row.embedding) for row in group]
            size = max(size, max(len(chunk.content) for chunk in chunks))
            content = "".join(chunk.content for chunk in chunks)
            sections.append((Section(title, url, content), chunks))

        # The connection as a context manager commits on success and rolls back on error, so a failure
        # cannot leave a half-written version pending for a later, unrelated commit to persist.
        with self.conn:
            sid, vid = self.add_parse(source, (section for section, _ in sections))
            self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))

    def update_source(self, source: str, display_name: str | None = None, note: str | None = None) -> None:
        """Update the display name and/or note of a source. Also create the source if it does not exist.

        Fields left as `None` are not modified. The changes are committed together, or rolled back on error.
        """
        with self.conn:
            sid = self.get_source(source)
            if display_name is not None:
                self.conn.execute("UPDATE sources SET display_name = ? WHERE id = ?", (display_name, sid))
            if note is not None:
                self.conn.execute("UPDATE sources SET note = ? WHERE id = ?", (note, sid))