Spaces:
Runtime error
Runtime error
PR: source display name (#80)
Browse files* source display name
* tests
* black
* CR
* isort
- buster/docparser.py +5 -0
- buster/documents/base.py +6 -0
- buster/documents/pickle.py +5 -0
- buster/documents/sqlite/documents.py +10 -0
- buster/documents/sqlite/schema.py +1 -0
- buster/retriever/base.py +8 -0
- buster/retriever/pickle.py +9 -1
- buster/retriever/sqlite.py +13 -1
- tests/test_chatbot.py +3 -0
- tests/test_documents.py +11 -0
buster/docparser.py
CHANGED
@@ -157,6 +157,11 @@ def documents_to_db(documents: pd.DataFrame, output_filepath: str):
|
|
157 |
logger.info(f"Documents saved to: {output_filepath}")
|
158 |
|
159 |
|
|
|
|
|
|
|
|
|
|
|
160 |
def generate_embeddings(
|
161 |
documents: pd.DataFrame,
|
162 |
output_filepath: str = "documents.db",
|
|
|
157 |
logger.info(f"Documents saved to: {output_filepath}")
|
158 |
|
159 |
|
160 |
+
def update_source(source: str, output_filepath: str, display_name: str = None, note: str = None):
|
161 |
+
documents_manager = get_documents_manager_from_extension(output_filepath)(output_filepath)
|
162 |
+
documents_manager.update_source(source, display_name, note)
|
163 |
+
|
164 |
+
|
165 |
def generate_embeddings(
|
166 |
documents: pd.DataFrame,
|
167 |
output_filepath: str = "documents.db",
|
buster/documents/base.py
CHANGED
@@ -8,4 +8,10 @@ import pandas as pd
|
|
8 |
class DocumentsManager(ABC):
|
9 |
@abstractmethod
|
10 |
def add(self, source: str, df: pd.DataFrame):
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
...
|
|
|
8 |
class DocumentsManager(ABC):
|
9 |
@abstractmethod
|
10 |
def add(self, source: str, df: pd.DataFrame):
|
11 |
+
"""Write all documents from the dataframe into the db as a new version."""
|
12 |
+
...
|
13 |
+
|
14 |
+
@abstractmethod
|
15 |
+
def update_source(self, source: str, display_name: str = None, note: str = None):
|
16 |
+
"""Update the display name and/or note of a source. Also create the source if it does not exist."""
|
17 |
...
|
buster/documents/pickle.py
CHANGED
@@ -15,6 +15,7 @@ class DocumentsPickle(DocumentsManager):
|
|
15 |
self.documents = None
|
16 |
|
17 |
def add(self, source: str, df: pd.DataFrame):
|
|
|
18 |
if source is not None:
|
19 |
df["source"] = source
|
20 |
|
@@ -27,3 +28,7 @@ class DocumentsPickle(DocumentsManager):
|
|
27 |
self.documents = df
|
28 |
|
29 |
self.documents.to_pickle(self.filepath)
|
|
|
|
|
|
|
|
|
|
15 |
self.documents = None
|
16 |
|
17 |
def add(self, source: str, df: pd.DataFrame):
|
18 |
+
"""Write all documents from the dataframe into the db as a new version."""
|
19 |
if source is not None:
|
20 |
df["source"] = source
|
21 |
|
|
|
28 |
self.documents = df
|
29 |
|
30 |
self.documents.to_pickle(self.filepath)
|
31 |
+
|
32 |
+
def update_source(self, source: str, display_name: str = None, note: str = None):
|
33 |
+
"""Update the display name and/or note of a source. Also create the source if it does not exist."""
|
34 |
+
print("If you need this function, please switch your backend to DocumentsDB.")
|
buster/documents/sqlite/documents.py
CHANGED
@@ -141,3 +141,13 @@ class DocumentsDB(DocumentsManager):
|
|
141 |
sid, vid = self.add_parse(source, (section for section, _ in sections))
|
142 |
self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
|
143 |
self.conn.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
sid, vid = self.add_parse(source, (section for section, _ in sections))
|
142 |
self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
|
143 |
self.conn.commit()
|
144 |
+
|
145 |
+
def update_source(self, source: str, display_name: str = None, note: str = None):
|
146 |
+
"""Update the display name and/or note of a source. Also create the source if it does not exist."""
|
147 |
+
sid = self.get_source(source)
|
148 |
+
|
149 |
+
if display_name is not None:
|
150 |
+
self.conn.execute("UPDATE sources SET display_name = ? WHERE id = ?", (display_name, sid))
|
151 |
+
if note is not None:
|
152 |
+
self.conn.execute("UPDATE sources SET note = ? WHERE id = ?", (note, sid))
|
153 |
+
self.conn.commit()
|
buster/documents/sqlite/schema.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
|
|
6 |
SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
|
7 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
8 |
name TEXT NOT NULL,
|
|
|
9 |
note TEXT,
|
10 |
UNIQUE(name)
|
11 |
)"""
|
|
|
6 |
SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
|
7 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
8 |
name TEXT NOT NULL,
|
9 |
+
display_name TEXT,
|
10 |
note TEXT,
|
11 |
UNIQUE(name)
|
12 |
)"""
|
buster/retriever/base.py
CHANGED
@@ -4,11 +4,19 @@ from dataclasses import dataclass
|
|
4 |
import pandas as pd
|
5 |
from openai.embeddings_utils import cosine_similarity
|
6 |
|
|
|
|
|
7 |
|
8 |
@dataclass
|
9 |
class Retriever(ABC):
|
10 |
@abstractmethod
|
11 |
def get_documents(self, source: str) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
...
|
13 |
|
14 |
def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
|
|
|
4 |
import pandas as pd
|
5 |
from openai.embeddings_utils import cosine_similarity
|
6 |
|
7 |
+
ALL_SOURCES = "All"
|
8 |
+
|
9 |
|
10 |
@dataclass
|
11 |
class Retriever(ABC):
|
12 |
@abstractmethod
|
13 |
def get_documents(self, source: str) -> pd.DataFrame:
|
14 |
+
"""Get all current documents from a given source."""
|
15 |
+
...
|
16 |
+
|
17 |
+
@abstractmethod
|
18 |
+
def get_source_display_name(self, source: str) -> str:
|
19 |
+
"""Get the display name of a source."""
|
20 |
...
|
21 |
|
22 |
def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
|
buster/retriever/pickle.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import pandas as pd
|
2 |
|
3 |
-
from buster.retriever.base import Retriever
|
4 |
|
5 |
|
6 |
class PickleRetriever(Retriever):
|
@@ -9,6 +9,7 @@ class PickleRetriever(Retriever):
|
|
9 |
self.documents = pd.read_pickle(filepath)
|
10 |
|
11 |
def get_documents(self, source: str) -> pd.DataFrame:
|
|
|
12 |
if self.documents is None:
|
13 |
raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
|
14 |
|
@@ -24,3 +25,10 @@ class PickleRetriever(Retriever):
|
|
24 |
documents = documents[documents.source == source]
|
25 |
|
26 |
return documents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
|
3 |
+
from buster.retriever.base import ALL_SOURCES, Retriever
|
4 |
|
5 |
|
6 |
class PickleRetriever(Retriever):
|
|
|
9 |
self.documents = pd.read_pickle(filepath)
|
10 |
|
11 |
def get_documents(self, source: str) -> pd.DataFrame:
|
12 |
+
"""Get all current documents from a given source."""
|
13 |
if self.documents is None:
|
14 |
raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
|
15 |
|
|
|
25 |
documents = documents[documents.source == source]
|
26 |
|
27 |
return documents
|
28 |
+
|
29 |
+
def get_source_display_name(self, source: str) -> str:
|
30 |
+
"""Get the display name of a source."""
|
31 |
+
if source is None:
|
32 |
+
return ALL_SOURCES
|
33 |
+
else:
|
34 |
+
return source
|
buster/retriever/sqlite.py
CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
|
|
4 |
import pandas as pd
|
5 |
|
6 |
import buster.documents.sqlite.schema as schema
|
7 |
-
from buster.retriever.base import Retriever
|
8 |
|
9 |
|
10 |
class SQLiteRetriever(Retriever):
|
@@ -44,3 +44,15 @@ class SQLiteRetriever(Retriever):
|
|
44 |
# Convert the results to a pandas DataFrame
|
45 |
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
|
46 |
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import pandas as pd
|
5 |
|
6 |
import buster.documents.sqlite.schema as schema
|
7 |
+
from buster.retriever.base import ALL_SOURCES, Retriever
|
8 |
|
9 |
|
10 |
class SQLiteRetriever(Retriever):
|
|
|
44 |
# Convert the results to a pandas DataFrame
|
45 |
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
|
46 |
return df
|
47 |
+
|
48 |
+
def get_source_display_name(self, source: str) -> str:
|
49 |
+
"""Get the display name of a source."""
|
50 |
+
if source is "":
|
51 |
+
return ALL_SOURCES
|
52 |
+
else:
|
53 |
+
cur = self.conn.execute("SELECT display_name FROM sources WHERE name = ?", (source,))
|
54 |
+
row = cur.fetchone()
|
55 |
+
if row is None:
|
56 |
+
raise KeyError(f'"{source}" is not a known source')
|
57 |
+
(display_name,) = row
|
58 |
+
return display_name
|
tests/test_chatbot.py
CHANGED
@@ -49,6 +49,9 @@ class MockRetriever(Retriever):
|
|
49 |
def get_documents(self, source):
|
50 |
return self.documents
|
51 |
|
|
|
|
|
|
|
52 |
|
53 |
import logging
|
54 |
|
|
|
49 |
def get_documents(self, source):
|
50 |
return self.documents
|
51 |
|
52 |
+
def get_source_display_name(self, source):
|
53 |
+
return source
|
54 |
+
|
55 |
|
56 |
import logging
|
57 |
|
tests/test_documents.py
CHANGED
@@ -70,3 +70,14 @@ def test_write_write_read(tmp_path, documents_manager, retriever, extension):
|
|
70 |
assert db_data["content"].iloc[0] == data_2["content"].iloc[0]
|
71 |
assert np.allclose(db_data["embedding"].iloc[0], data_2["embedding"].iloc[0])
|
72 |
assert db_data["n_tokens"].iloc[0] == data_2["n_tokens"].iloc[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
assert db_data["content"].iloc[0] == data_2["content"].iloc[0]
|
71 |
assert np.allclose(db_data["embedding"].iloc[0], data_2["embedding"].iloc[0])
|
72 |
assert db_data["n_tokens"].iloc[0] == data_2["n_tokens"].iloc[0]
|
73 |
+
|
74 |
+
|
75 |
+
def test_update_source(tmp_path):
|
76 |
+
display_name = "Super Test"
|
77 |
+
db = DocumentsDB(tmp_path / "test.db")
|
78 |
+
|
79 |
+
db.update_source(source="test", display_name=display_name)
|
80 |
+
|
81 |
+
returned_display_name = SQLiteRetriever(tmp_path / "test.db").get_source_display_name("test")
|
82 |
+
|
83 |
+
assert display_name == returned_display_name
|