hbertrand's picture
PR: source display name (#80)
6aad21a unverified
import itertools
import sqlite3
from pathlib import Path
from typing import Iterable, NamedTuple
import numpy as np
import pandas as pd
import buster.documents.sqlite.schema as schema
from buster.documents.base import DocumentsManager
class Section(NamedTuple):
    """One parsed section of a document, as written to the ``sections`` table."""

    # Title of the section.
    title: str
    # URL associated with the section.
    url: str
    # Full text of the section.
    content: str
    # presumably the index of the parent section, None when there is none -- TODO confirm
    parent: int | None = None
    # Kind of entry; defaults to "section".
    type: str = "section"
class Chunk(NamedTuple):
    """One chunk of a section's content, together with its embedding."""

    # Text of the chunk.
    content: str
    # Number of tokens in `content` (presumably per the embedding model's tokenizer -- TODO confirm).
    n_tokens: int
    # Embedding of `content`; written to the ``embedding`` column of the ``chunks`` table.
    emb: np.ndarray
class DocumentsDB(DocumentsManager):
    """Simple SQLite database for storing documents and questions/answers.

    The database is just a file on disk. It can store documents from different sources, and it can store
    multiple versions of the same document (e.g. if the document is updated).
    Questions/answers refer to the version of the document that was used at the time.

    Example:
        >>> db = DocumentsDB("/path/to/the/db.db")
        >>> db.add("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
    """

    def __init__(self, db_path: sqlite3.Connection | str | Path):
        """Open (or adopt) the database and make sure its schema is set up.

        Args:
            db_path: Either a filesystem path (``str`` or ``Path``) to the SQLite file, in which
                case a connection is opened and owned by this object, or an already open
                ``sqlite3.Connection``, which is used as-is and never closed by this object.
        """
        # Define both attributes before anything can fail so that __del__ never
        # runs against a half-initialised instance.
        self.db_path = None
        self.conn = None
        if isinstance(db_path, (str, Path)):
            self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
            # Only record the path once the connection really exists: a non-None
            # db_path is what marks the connection as owned by this object.
            self.db_path = db_path
        else:
            self.conn = db_path
        schema.initialize_db(self.conn)
        schema.setup_db(self.conn)

    def __del__(self):
        """Close the connection, but only if this object opened it itself."""
        conn = getattr(self, "conn", None)
        if getattr(self, "db_path", None) is not None and conn is not None:
            conn.close()

    def _latest_version(self, source: str) -> tuple[int, int] | None:
        """Return the ``(source id, version)`` row of ``latest_version`` for `source`, or None."""
        cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
        return cur.fetchone()

    def get_current_version(self, source: str) -> tuple[int, int]:
        """Get the current version of a source.

        Args:
            source: Name of the source.

        Returns:
            The ``(source id, version id)`` pair found in ``latest_version``.

        Raises:
            KeyError: If `source` has no entry in ``latest_version``.
        """
        row = self._latest_version(source)
        if row is None:
            raise KeyError(f'"{source}" is not a known source')
        sid, vid = row
        return sid, vid

    def get_source(self, source: str) -> int:
        """Get the id of a source, creating the source first if it does not exist.

        The insert is not committed here; committing is left to the caller.

        Args:
            source: Name of the source.

        Returns:
            The id of the (possibly newly inserted) row in ``sources``.
        """
        select_id = "SELECT id FROM sources WHERE name = ?"
        row = self.conn.execute(select_id, (source,)).fetchone()
        if row is None:
            self.conn.execute("INSERT INTO sources (name) VALUES (?)", (source,))
            # Read the id back instead of assuming how the schema generates it.
            row = self.conn.execute(select_id, (source,)).fetchone()
        (sid,) = row
        return sid

    def new_version(self, source: str) -> tuple[int, int]:
        """Create a new version for a source.

        A source without any version yet starts at version 0 (and is created if
        needed); otherwise the latest version number is incremented by one.
        The insert is not committed here.

        Args:
            source: Name of the source.

        Returns:
            The ``(source id, new version id)`` pair.
        """
        row = self._latest_version(source)
        if row is None:
            sid = self.get_source(source)
            vid = 0
        else:
            sid, vid = row
            vid = vid + 1
        self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
        return sid, vid

    def add_parse(self, source: str, sections: Iterable[Section]) -> tuple[int, int]:
        """Create a new version of a source filled with parsed sections.

        Sections are numbered in iteration order, starting at 0. Nothing is
        committed here.

        Args:
            source: Name of the source.
            sections: Parsed sections to store under the new version.

        Returns:
            The ``(source id, version id)`` pair the sections were stored under.
        """
        sid, vid = self.new_version(source)
        values = (
            (sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
            for ind, section in enumerate(sections)
        )
        self.conn.executemany(
            "INSERT INTO sections "
            "(source, version, section, title, url, content, parent, type) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            values,
        )
        return sid, vid

    def new_chunking(self, sid: int, vid: int, size: int, overlap: int = 0, strategy: str = "simple") -> int:
        """Create a new chunking for a source.

        Nothing is committed here.

        Args:
            sid: Id of the source.
            vid: Version of the source.
            size: Chunk size recorded for this chunking.
            overlap: Overlap recorded for this chunking.
            strategy: Name of the chunking strategy.

        Returns:
            The id of the new chunking.

        Raises:
            ValueError: If the parameters do not identify exactly one chunking row.
        """
        params = (size, overlap, strategy, sid, vid)
        self.conn.execute(
            "INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
            params,
        )
        cur = self.conn.execute(
            "SELECT chunking FROM chunkings "
            "WHERE size = ? AND overlap = ? AND strategy = ? AND source = ? AND version = ?",
            params,
        )
        rows = cur.fetchall()
        # Exactly one row must match; anything else means the id is ambiguous.
        if len(rows) != 1:
            raise ValueError(f"expected exactly one matching chunking, found {len(rows)}")
        ((cid,),) = rows
        return cid

    def add_chunking(self, sid: int, vid: int, size: int, sections: Iterable[Iterable[Chunk]]) -> int:
        """Create a new chunking for a source, filled with chunks organized by section.

        The outer position in `sections` becomes the section index and the inner
        position the chunk's sequence number within that section. Nothing is
        committed here.

        Args:
            sid: Id of the source.
            vid: Version of the source.
            size: Chunk size recorded for the chunking.
            sections: For each section, in order, the chunks belonging to it.

        Returns:
            The id of the new chunking.
        """
        cid = self.new_chunking(sid, vid, size)
        values = (
            (sid, vid, ind, cid, jnd, chunk.content, chunk.n_tokens, chunk.emb)
            for ind, section in enumerate(sections)
            for jnd, chunk in enumerate(section)
        )
        self.conn.executemany(
            "INSERT INTO chunks "
            "(source, version, section, chunking, sequence, content, n_tokens, embedding) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            values,
        )
        return cid

    def add(self, source: str, df: pd.DataFrame):
        """Write all documents from the dataframe into the db as a new version.

        Rows are grouped into sections by ``(url, title)``; each row becomes one
        chunk, and a section's content is the concatenation of its chunks. The
        chunking size recorded is the length of the longest chunk content.
        All writes happen in one transaction: it is committed on success and
        rolled back if anything fails.

        Args:
            source: Name of the source.
            df: DataFrame whose rows expose ``url``, ``title``, ``content``,
                ``n_tokens`` and ``embedding``.
        """

        def section_key(row):
            return (row.url, row.title)

        # groupby only merges adjacent rows, hence the sort; sorted() is stable,
        # so chunks keep their original relative order inside a section.
        data = sorted(df.itertuples(), key=section_key)
        sections = []
        size = 0
        for (url, title), rows in itertools.groupby(data, section_key):
            chunks = [Chunk(row.content, row.n_tokens, row.embedding) for row in rows]
            size = max(size, max(len(chunk.content) for chunk in chunks))
            content = "".join(chunk.content for chunk in chunks)
            sections.append((Section(title, url, content), chunks))

        # The connection context manager commits on success and rolls back on
        # error, so a failure cannot leave a partially written version pending.
        with self.conn:
            sid, vid = self.add_parse(source, (section for section, _ in sections))
            self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))

    def update_source(self, source: str, display_name: str | None = None, note: str | None = None):
        """Update the display name and/or note of a source. Also create the source if it does not exist.

        The changes are committed on success and rolled back if anything fails.

        Args:
            source: Name of the source.
            display_name: New display name; left unchanged when None.
            note: New note; left unchanged when None.
        """
        with self.conn:
            sid = self.get_source(source)
            if display_name is not None:
                self.conn.execute("UPDATE sources SET display_name = ? WHERE id = ?", (display_name, sid))
            if note is not None:
                self.conn.execute("UPDATE sources SET note = ? WHERE id = ?", (note, sid))