Spaces:
Runtime error
Runtime error
PR: retriever interface (#77)
Browse files* retriever interface
* black + isort
* documents -> retriever
* PR
* black
- buster/apps/gradio_app.py +4 -4
- buster/busterbot.py +4 -4
- buster/docparser.py +1 -1
- buster/documents/__init__.py +1 -2
- buster/documents/base.py +0 -19
- buster/documents/pickle.py +0 -16
- buster/documents/sqlite/documents.py +0 -11
- buster/examples/gradio_app.py +4 -4
- buster/parser.py +0 -1
- buster/retriever/__init__.py +5 -0
- buster/retriever/base.py +26 -0
- buster/retriever/pickle.py +26 -0
- buster/retriever/sqlite.py +46 -0
- buster/{documents/utils.py → utils.py} +13 -3
- tests/test_chatbot.py +11 -13
- tests/test_docparser.py +6 -4
- tests/test_documents.py +13 -6
buster/apps/gradio_app.py
CHANGED
@@ -5,19 +5,19 @@ import gradio as gr
|
|
5 |
|
6 |
from buster.apps.bot_configs import available_configs
|
7 |
from buster.busterbot import Buster, BusterConfig
|
8 |
-
from buster.
|
9 |
-
from buster.
|
10 |
|
11 |
DEFAULT_CONFIG = "huggingface"
|
12 |
DB_URL = "https://huggingface.co/datasets/jerpint/buster-data/resolve/main/documents.db"
|
13 |
|
14 |
# Download the db...
|
15 |
documents_filepath = download_db(db_url=DB_URL, output_dir="./data")
|
16 |
-
|
17 |
|
18 |
# initialize buster with the default config...
|
19 |
default_cfg: BusterConfig = available_configs.get(DEFAULT_CONFIG)
|
20 |
-
buster = Buster(cfg=default_cfg,
|
21 |
|
22 |
|
23 |
def chat(question, history, bot_source):
|
|
|
5 |
|
6 |
from buster.apps.bot_configs import available_configs
|
7 |
from buster.busterbot import Buster, BusterConfig
|
8 |
+
from buster.retriever import Retriever
|
9 |
+
from buster.utils import download_db, get_retriever_from_extension
|
10 |
|
11 |
DEFAULT_CONFIG = "huggingface"
|
12 |
DB_URL = "https://huggingface.co/datasets/jerpint/buster-data/resolve/main/documents.db"
|
13 |
|
14 |
# Download the db...
|
15 |
documents_filepath = download_db(db_url=DB_URL, output_dir="./data")
|
16 |
+
retriever: Retriever = get_retriever_from_extension(documents_filepath)(documents_filepath)
|
17 |
|
18 |
# initialize buster with the default config...
|
19 |
default_cfg: BusterConfig = available_configs.get(DEFAULT_CONFIG)
|
20 |
+
buster = Buster(cfg=default_cfg, retriever=retriever)
|
21 |
|
22 |
|
23 |
def chat(question, history, bot_source):
|
buster/busterbot.py
CHANGED
@@ -64,16 +64,16 @@ class BusterConfig:
|
|
64 |
source: str = ""
|
65 |
|
66 |
|
67 |
-
from buster.
|
68 |
|
69 |
|
70 |
class Buster:
|
71 |
-
def __init__(self, cfg: BusterConfig,
|
72 |
self._unk_embedding = None
|
73 |
self.cfg = cfg
|
74 |
self.update_cfg(cfg)
|
75 |
|
76 |
-
self.
|
77 |
|
78 |
@property
|
79 |
def unk_embedding(self):
|
@@ -117,7 +117,7 @@ class Buster:
|
|
117 |
query,
|
118 |
engine=engine,
|
119 |
)
|
120 |
-
matched_documents = self.
|
121 |
|
122 |
# log matched_documents to the console
|
123 |
logger.info(f"matched documents before thresh: {matched_documents}")
|
|
|
64 |
source: str = ""
|
65 |
|
66 |
|
67 |
+
from buster.retriever import Retriever
|
68 |
|
69 |
|
70 |
class Buster:
|
71 |
+
def __init__(self, cfg: BusterConfig, retriever: Retriever):
|
72 |
self._unk_embedding = None
|
73 |
self.cfg = cfg
|
74 |
self.update_cfg(cfg)
|
75 |
|
76 |
+
self.retriever = retriever
|
77 |
|
78 |
@property
|
79 |
def unk_embedding(self):
|
|
|
117 |
query,
|
118 |
engine=engine,
|
119 |
)
|
120 |
+
matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)
|
121 |
|
122 |
# log matched_documents to the console
|
123 |
logger.info(f"matched documents before thresh: {matched_documents}")
|
buster/docparser.py
CHANGED
@@ -10,8 +10,8 @@ import tiktoken
|
|
10 |
from bs4 import BeautifulSoup
|
11 |
from openai.embeddings_utils import get_embedding
|
12 |
|
13 |
-
from buster.documents import get_documents_manager_from_extension
|
14 |
from buster.parser import HuggingfaceParser, Parser, SphinxParser
|
|
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
logging.basicConfig(level=logging.INFO)
|
|
|
10 |
from bs4 import BeautifulSoup
|
11 |
from openai.embeddings_utils import get_embedding
|
12 |
|
|
|
13 |
from buster.parser import HuggingfaceParser, Parser, SphinxParser
|
14 |
+
from buster.utils import get_documents_manager_from_extension
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
logging.basicConfig(level=logging.INFO)
|
buster/documents/__init__.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
from .base import DocumentsManager
|
2 |
from .pickle import DocumentsPickle
|
3 |
from .sqlite import DocumentsDB
|
4 |
-
from .utils import get_documents_manager_from_extension
|
5 |
|
6 |
-
__all__ = [DocumentsManager, DocumentsPickle, DocumentsDB
|
|
|
1 |
from .base import DocumentsManager
|
2 |
from .pickle import DocumentsPickle
|
3 |
from .sqlite import DocumentsDB
|
|
|
4 |
|
5 |
+
__all__ = [DocumentsManager, DocumentsPickle, DocumentsDB]
|
buster/documents/base.py
CHANGED
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|
2 |
from dataclasses import dataclass
|
3 |
|
4 |
import pandas as pd
|
5 |
-
from openai.embeddings_utils import cosine_similarity
|
6 |
|
7 |
|
8 |
@dataclass
|
@@ -10,21 +9,3 @@ class DocumentsManager(ABC):
|
|
10 |
@abstractmethod
|
11 |
def add(self, source: str, df: pd.DataFrame):
|
12 |
...
|
13 |
-
|
14 |
-
@abstractmethod
|
15 |
-
def get_documents(self, source: str) -> pd.DataFrame:
|
16 |
-
...
|
17 |
-
|
18 |
-
def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
|
19 |
-
documents = self.get_documents(source)
|
20 |
-
|
21 |
-
documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
|
22 |
-
|
23 |
-
# sort the matched_documents by score
|
24 |
-
matched_documents = documents.sort_values("similarity", ascending=False)
|
25 |
-
|
26 |
-
# limit search to top_k matched_documents.
|
27 |
-
top_k = len(matched_documents) if top_k == -1 else top_k
|
28 |
-
matched_documents = matched_documents.head(top_k)
|
29 |
-
|
30 |
-
return matched_documents
|
|
|
2 |
from dataclasses import dataclass
|
3 |
|
4 |
import pandas as pd
|
|
|
5 |
|
6 |
|
7 |
@dataclass
|
|
|
9 |
@abstractmethod
|
10 |
def add(self, source: str, df: pd.DataFrame):
|
11 |
...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
buster/documents/pickle.py
CHANGED
@@ -27,19 +27,3 @@ class DocumentsPickle(DocumentsManager):
|
|
27 |
self.documents = df
|
28 |
|
29 |
self.documents.to_pickle(self.filepath)
|
30 |
-
|
31 |
-
def get_documents(self, source: str) -> pd.DataFrame:
|
32 |
-
if self.documents is None:
|
33 |
-
raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
|
34 |
-
|
35 |
-
documents = self.documents.copy()
|
36 |
-
if "current" in documents.columns:
|
37 |
-
documents = documents[documents.current == 1]
|
38 |
-
|
39 |
-
# Drop the `current` column
|
40 |
-
documents.drop(columns=["current"], inplace=True)
|
41 |
-
|
42 |
-
if source is not None and "source" in documents.columns:
|
43 |
-
documents = documents[documents.source == source]
|
44 |
-
|
45 |
-
return documents
|
|
|
27 |
self.documents = df
|
28 |
|
29 |
self.documents.to_pickle(self.filepath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
buster/documents/sqlite/documents.py
CHANGED
@@ -33,7 +33,6 @@ class DocumentsDB(DocumentsManager):
|
|
33 |
Example:
|
34 |
>>> db = DocumentsDB("/path/to/the/db.db")
|
35 |
>>> db.add("source", df) # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
|
36 |
-
>>> df = db.get_documents("source")
|
37 |
"""
|
38 |
|
39 |
def __init__(self, db_path: sqlite3.Connection | str):
|
@@ -142,13 +141,3 @@ class DocumentsDB(DocumentsManager):
|
|
142 |
sid, vid = self.add_parse(source, (section for section, _ in sections))
|
143 |
self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
|
144 |
self.conn.commit()
|
145 |
-
|
146 |
-
def get_documents(self, source: str) -> pd.DataFrame:
|
147 |
-
"""Get all current documents from a given source."""
|
148 |
-
# Execute the SQL statement and fetch the results
|
149 |
-
results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
|
150 |
-
rows = results.fetchall()
|
151 |
-
|
152 |
-
# Convert the results to a pandas DataFrame
|
153 |
-
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
|
154 |
-
return df
|
|
|
33 |
Example:
|
34 |
>>> db = DocumentsDB("/path/to/the/db.db")
|
35 |
>>> db.add("source", df) # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
|
|
|
36 |
"""
|
37 |
|
38 |
def __init__(self, db_path: sqlite3.Connection | str):
|
|
|
141 |
sid, vid = self.add_parse(source, (section for section, _ in sections))
|
142 |
self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
|
143 |
self.conn.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
buster/examples/gradio_app.py
CHANGED
@@ -2,12 +2,12 @@ import cfg
|
|
2 |
import gradio as gr
|
3 |
|
4 |
from buster.busterbot import Buster
|
5 |
-
from buster.
|
6 |
-
from buster.
|
7 |
|
8 |
# initialize buster with the config in config.py (adapt to your needs) ...
|
9 |
-
|
10 |
-
buster: Buster = Buster(cfg=cfg.buster_cfg,
|
11 |
|
12 |
|
13 |
def chat(question, history):
|
|
|
2 |
import gradio as gr
|
3 |
|
4 |
from buster.busterbot import Buster
|
5 |
+
from buster.retriever import Retriever
|
6 |
+
from buster.utils import get_retriever_from_extension
|
7 |
|
8 |
# initialize buster with the config in config.py (adapt to your needs) ...
|
9 |
+
retriever: Retriever = get_retriever_from_extension(cfg.documents_filepath)(cfg.documents_filepath)
|
10 |
+
buster: Buster = Buster(cfg=cfg.buster_cfg, retriever=retriever)
|
11 |
|
12 |
|
13 |
def chat(question, history):
|
buster/parser.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import math
|
2 |
import os
|
3 |
from abc import ABC, abstractmethod
|
4 |
from dataclasses import InitVar, dataclass, field
|
|
|
|
|
1 |
import os
|
2 |
from abc import ABC, abstractmethod
|
3 |
from dataclasses import InitVar, dataclass, field
|
buster/retriever/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import Retriever
|
2 |
+
from .pickle import PickleRetriever
|
3 |
+
from .sqlite import SQLiteRetriever
|
4 |
+
|
5 |
+
__all__ = [Retriever, PickleRetriever, SQLiteRetriever]
|
buster/retriever/base.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from openai.embeddings_utils import cosine_similarity
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class Retriever(ABC):
|
10 |
+
@abstractmethod
|
11 |
+
def get_documents(self, source: str) -> pd.DataFrame:
|
12 |
+
...
|
13 |
+
|
14 |
+
def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
|
15 |
+
documents = self.get_documents(source)
|
16 |
+
|
17 |
+
documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
|
18 |
+
|
19 |
+
# sort the matched_documents by score
|
20 |
+
matched_documents = documents.sort_values("similarity", ascending=False)
|
21 |
+
|
22 |
+
# limit search to top_k matched_documents.
|
23 |
+
top_k = len(matched_documents) if top_k == -1 else top_k
|
24 |
+
matched_documents = matched_documents.head(top_k)
|
25 |
+
|
26 |
+
return matched_documents
|
buster/retriever/pickle.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
|
3 |
+
from buster.retriever.base import Retriever
|
4 |
+
|
5 |
+
|
6 |
+
class PickleRetriever(Retriever):
|
7 |
+
def __init__(self, filepath: str):
|
8 |
+
self.filepath = filepath
|
9 |
+
self.documents = pd.read_pickle(filepath)
|
10 |
+
|
11 |
+
def get_documents(self, source: str) -> pd.DataFrame:
|
12 |
+
if self.documents is None:
|
13 |
+
raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
|
14 |
+
|
15 |
+
documents = self.documents.copy()
|
16 |
+
# The `current` column exists when multiple versions of a document exist
|
17 |
+
if "current" in documents.columns:
|
18 |
+
documents = documents[documents.current == 1]
|
19 |
+
|
20 |
+
# Drop the `current` column
|
21 |
+
documents.drop(columns=["current"], inplace=True)
|
22 |
+
|
23 |
+
if source is not None and "source" in documents.columns:
|
24 |
+
documents = documents[documents.source == source]
|
25 |
+
|
26 |
+
return documents
|
buster/retriever/sqlite.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
import buster.documents.sqlite.schema as schema
|
7 |
+
from buster.retriever.base import Retriever
|
8 |
+
|
9 |
+
|
10 |
+
class SQLiteRetriever(Retriever):
|
11 |
+
"""Simple SQLite database for retrieval of documents.
|
12 |
+
|
13 |
+
The database is just a file on disk. It can store documents from different sources, and it
|
14 |
+
can store multiple versions of the same document (e.g. if the document is updated).
|
15 |
+
|
16 |
+
Example:
|
17 |
+
>>> db = DocumentsDB("/path/to/the/db.db")
|
18 |
+
>>> df = db.get_documents("source")
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, db_path: sqlite3.Connection | str):
|
22 |
+
if isinstance(db_path, (str, Path)):
|
23 |
+
self.db_path = db_path
|
24 |
+
self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
|
25 |
+
else:
|
26 |
+
self.db_path = None
|
27 |
+
self.conn = db_path
|
28 |
+
schema.initialize_db(self.conn)
|
29 |
+
schema.setup_db(self.conn)
|
30 |
+
|
31 |
+
def __del__(self):
|
32 |
+
if self.db_path is not None:
|
33 |
+
self.conn.close()
|
34 |
+
|
35 |
+
def get_documents(self, source: str) -> pd.DataFrame:
|
36 |
+
"""Get all current documents from a given source."""
|
37 |
+
# Execute the SQL statement and fetch the results.
|
38 |
+
if source is "":
|
39 |
+
results = self.conn.execute("SELECT * FROM documents")
|
40 |
+
else:
|
41 |
+
results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
|
42 |
+
rows = results.fetchall()
|
43 |
+
|
44 |
+
# Convert the results to a pandas DataFrame
|
45 |
+
df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
|
46 |
+
return df
|
buster/{documents/utils.py → utils.py}
RENAMED
@@ -2,9 +2,8 @@ import os
|
|
2 |
import urllib.request
|
3 |
from typing import Type
|
4 |
|
5 |
-
from buster.documents
|
6 |
-
from buster.
|
7 |
-
from buster.documents.sqlite import DocumentsDB
|
8 |
|
9 |
PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
|
10 |
|
@@ -34,3 +33,14 @@ def get_documents_manager_from_extension(filepath: str) -> Type[DocumentsManager
|
|
34 |
return DocumentsDB
|
35 |
else:
|
36 |
raise ValueError(f"Unsupported format: {ext}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import urllib.request
|
3 |
from typing import Type
|
4 |
|
5 |
+
from buster.documents import DocumentsDB, DocumentsManager, DocumentsPickle
|
6 |
+
from buster.retriever import PickleRetriever, Retriever, SQLiteRetriever
|
|
|
7 |
|
8 |
PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
|
9 |
|
|
|
33 |
return DocumentsDB
|
34 |
else:
|
35 |
raise ValueError(f"Unsupported format: {ext}.")
|
36 |
+
|
37 |
+
|
38 |
+
def get_retriever_from_extension(filepath: str) -> Type[Retriever]:
|
39 |
+
ext = get_file_extension(filepath)
|
40 |
+
|
41 |
+
if ext in PICKLE_EXTENSIONS:
|
42 |
+
return PickleRetriever
|
43 |
+
elif ext == ".db":
|
44 |
+
return SQLiteRetriever
|
45 |
+
else:
|
46 |
+
raise ValueError(f"Unsupported format: {ext}.")
|
tests/test_chatbot.py
CHANGED
@@ -6,8 +6,9 @@ import pandas as pd
|
|
6 |
|
7 |
from buster.busterbot import Buster, BusterConfig
|
8 |
from buster.completers.base import Completer
|
9 |
-
from buster.documents import DocumentsManager, get_documents_manager_from_extension
|
10 |
from buster.formatter.base import Response
|
|
|
|
|
11 |
|
12 |
TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
|
13 |
DOCUMENTS_FILE = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz")
|
@@ -29,7 +30,7 @@ class MockCompleter(Completer):
|
|
29 |
return Response(self.expected_answer)
|
30 |
|
31 |
|
32 |
-
class
|
33 |
def __init__(self, filepath):
|
34 |
self.filepath = filepath
|
35 |
|
@@ -45,9 +46,6 @@ class DocumentsMock(DocumentsManager):
|
|
45 |
}
|
46 |
)
|
47 |
|
48 |
-
def add(self, documents):
|
49 |
-
pass
|
50 |
-
|
51 |
def get_documents(self, source):
|
52 |
return self.documents
|
53 |
|
@@ -90,8 +88,8 @@ def test_chatbot_mock_data(tmp_path, monkeypatch):
|
|
90 |
},
|
91 |
)
|
92 |
filepath = tmp_path / "not_a_real_file.tar.gz"
|
93 |
-
|
94 |
-
buster = Buster(cfg=hf_transformers_cfg,
|
95 |
answer = buster.process_input("What is a transformer?")
|
96 |
assert isinstance(answer, str)
|
97 |
assert answer.startswith(gpt_expected_answer)
|
@@ -119,8 +117,8 @@ def test_chatbot_real_data__chatGPT():
|
|
119 |
},
|
120 |
},
|
121 |
)
|
122 |
-
|
123 |
-
buster = Buster(cfg=hf_transformers_cfg,
|
124 |
answer = buster.process_input("What is a transformer?")
|
125 |
assert isinstance(answer, str)
|
126 |
|
@@ -153,8 +151,8 @@ def test_chatbot_real_data__chatGPT_OOD():
|
|
153 |
},
|
154 |
response_format="gradio",
|
155 |
)
|
156 |
-
|
157 |
-
buster = Buster(cfg=buster_cfg,
|
158 |
answer = buster.process_input("What is a good recipe for brocolli soup?")
|
159 |
assert isinstance(answer, str)
|
160 |
assert buster_cfg.unknown_prompt in answer
|
@@ -187,7 +185,7 @@ def test_chatbot_real_data__GPT():
|
|
187 |
},
|
188 |
},
|
189 |
)
|
190 |
-
|
191 |
-
buster = Buster(cfg=hf_transformers_cfg,
|
192 |
answer = buster.process_input("What is a transformer?")
|
193 |
assert isinstance(answer, str)
|
|
|
6 |
|
7 |
from buster.busterbot import Buster, BusterConfig
|
8 |
from buster.completers.base import Completer
|
|
|
9 |
from buster.formatter.base import Response
|
10 |
+
from buster.retriever import Retriever
|
11 |
+
from buster.utils import get_retriever_from_extension
|
12 |
|
13 |
TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
|
14 |
DOCUMENTS_FILE = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz")
|
|
|
30 |
return Response(self.expected_answer)
|
31 |
|
32 |
|
33 |
+
class MockRetriever(Retriever):
|
34 |
def __init__(self, filepath):
|
35 |
self.filepath = filepath
|
36 |
|
|
|
46 |
}
|
47 |
)
|
48 |
|
|
|
|
|
|
|
49 |
def get_documents(self, source):
|
50 |
return self.documents
|
51 |
|
|
|
88 |
},
|
89 |
)
|
90 |
filepath = tmp_path / "not_a_real_file.tar.gz"
|
91 |
+
retriever = MockRetriever(filepath)
|
92 |
+
buster = Buster(cfg=hf_transformers_cfg, retriever=retriever)
|
93 |
answer = buster.process_input("What is a transformer?")
|
94 |
assert isinstance(answer, str)
|
95 |
assert answer.startswith(gpt_expected_answer)
|
|
|
117 |
},
|
118 |
},
|
119 |
)
|
120 |
+
retriever = get_retriever_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
|
121 |
+
buster = Buster(cfg=hf_transformers_cfg, retriever=retriever)
|
122 |
answer = buster.process_input("What is a transformer?")
|
123 |
assert isinstance(answer, str)
|
124 |
|
|
|
151 |
},
|
152 |
response_format="gradio",
|
153 |
)
|
154 |
+
retriever = get_retriever_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
|
155 |
+
buster = Buster(cfg=buster_cfg, retriever=retriever)
|
156 |
answer = buster.process_input("What is a good recipe for brocolli soup?")
|
157 |
assert isinstance(answer, str)
|
158 |
assert buster_cfg.unknown_prompt in answer
|
|
|
185 |
},
|
186 |
},
|
187 |
)
|
188 |
+
retriever = get_retriever_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
|
189 |
+
buster = Buster(cfg=hf_transformers_cfg, retriever=retriever)
|
190 |
answer = buster.process_input("What is a transformer?")
|
191 |
assert isinstance(answer, str)
|
tests/test_docparser.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
|
|
3 |
|
4 |
from buster.docparser import generate_embeddings
|
5 |
-
from buster.
|
6 |
|
7 |
|
8 |
-
|
|
|
9 |
# Create fake data
|
10 |
data = pd.DataFrame.from_dict(
|
11 |
{"title": ["test"], "url": ["http://url.com"], "content": ["cool text"], "source": ["my_source"]}
|
@@ -16,11 +18,11 @@ def test_generate_embeddings(tmp_path, monkeypatch):
|
|
16 |
monkeypatch.setattr("buster.docparser.get_all_documents", lambda a, b, c: data)
|
17 |
|
18 |
# Generate embeddings, store in a file
|
19 |
-
output_file = tmp_path / "test_document_embeddings.
|
20 |
df = generate_embeddings(data, output_file)
|
21 |
|
22 |
# Read the embeddings from the file
|
23 |
-
read_df =
|
24 |
|
25 |
# Check all the values are correct across the files
|
26 |
assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0]
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
+
import pytest
|
4 |
|
5 |
from buster.docparser import generate_embeddings
|
6 |
+
from buster.utils import get_retriever_from_extension
|
7 |
|
8 |
|
9 |
+
@pytest.mark.parametrize("extension", ["db", "tar.gz"])
|
10 |
+
def test_generate_embeddings(tmp_path, monkeypatch, extension):
|
11 |
# Create fake data
|
12 |
data = pd.DataFrame.from_dict(
|
13 |
{"title": ["test"], "url": ["http://url.com"], "content": ["cool text"], "source": ["my_source"]}
|
|
|
18 |
monkeypatch.setattr("buster.docparser.get_all_documents", lambda a, b, c: data)
|
19 |
|
20 |
# Generate embeddings, store in a file
|
21 |
+
output_file = tmp_path / f"test_document_embeddings.{extension}"
|
22 |
df = generate_embeddings(data, output_file)
|
23 |
|
24 |
# Read the embeddings from the file
|
25 |
+
read_df = get_retriever_from_extension(output_file)(output_file).get_documents("my_source")
|
26 |
|
27 |
# Check all the values are correct across the files
|
28 |
assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0]
|
tests/test_documents.py
CHANGED
@@ -3,10 +3,14 @@ import pandas as pd
|
|
3 |
import pytest
|
4 |
|
5 |
from buster.documents import DocumentsDB, DocumentsPickle
|
|
|
6 |
|
7 |
|
8 |
-
@pytest.mark.parametrize(
|
9 |
-
|
|
|
|
|
|
|
10 |
db = documents_manager(tmp_path / f"test.{extension}")
|
11 |
|
12 |
data = pd.DataFrame.from_dict(
|
@@ -20,7 +24,7 @@ def test_write_read(tmp_path, documents_manager, extension):
|
|
20 |
)
|
21 |
db.add(source="test", df=data)
|
22 |
|
23 |
-
db_data =
|
24 |
|
25 |
assert db_data["title"].iloc[0] == data["title"].iloc[0]
|
26 |
assert db_data["url"].iloc[0] == data["url"].iloc[0]
|
@@ -29,8 +33,11 @@ def test_write_read(tmp_path, documents_manager, extension):
|
|
29 |
assert db_data["n_tokens"].iloc[0] == data["n_tokens"].iloc[0]
|
30 |
|
31 |
|
32 |
-
@pytest.mark.parametrize(
|
33 |
-
|
|
|
|
|
|
|
34 |
db = documents_manager(tmp_path / f"test.{extension}")
|
35 |
|
36 |
data_1 = pd.DataFrame.from_dict(
|
@@ -55,7 +62,7 @@ def test_write_write_read(tmp_path, documents_manager, extension):
|
|
55 |
)
|
56 |
db.add(source="test", df=data_2)
|
57 |
|
58 |
-
db_data =
|
59 |
|
60 |
assert len(db_data) == len(data_2)
|
61 |
assert db_data["title"].iloc[0] == data_2["title"].iloc[0]
|
|
|
3 |
import pytest
|
4 |
|
5 |
from buster.documents import DocumentsDB, DocumentsPickle
|
6 |
+
from buster.retriever import PickleRetriever, SQLiteRetriever
|
7 |
|
8 |
|
9 |
+
@pytest.mark.parametrize(
|
10 |
+
"documents_manager, retriever, extension",
|
11 |
+
[(DocumentsDB, SQLiteRetriever, "db"), (DocumentsPickle, PickleRetriever, "tar.gz")],
|
12 |
+
)
|
13 |
+
def test_write_read(tmp_path, documents_manager, retriever, extension):
|
14 |
db = documents_manager(tmp_path / f"test.{extension}")
|
15 |
|
16 |
data = pd.DataFrame.from_dict(
|
|
|
24 |
)
|
25 |
db.add(source="test", df=data)
|
26 |
|
27 |
+
db_data = retriever(tmp_path / f"test.{extension}").get_documents("test")
|
28 |
|
29 |
assert db_data["title"].iloc[0] == data["title"].iloc[0]
|
30 |
assert db_data["url"].iloc[0] == data["url"].iloc[0]
|
|
|
33 |
assert db_data["n_tokens"].iloc[0] == data["n_tokens"].iloc[0]
|
34 |
|
35 |
|
36 |
+
@pytest.mark.parametrize(
|
37 |
+
"documents_manager, retriever, extension",
|
38 |
+
[(DocumentsDB, SQLiteRetriever, "db"), (DocumentsPickle, PickleRetriever, "tar.gz")],
|
39 |
+
)
|
40 |
+
def test_write_write_read(tmp_path, documents_manager, retriever, extension):
|
41 |
db = documents_manager(tmp_path / f"test.{extension}")
|
42 |
|
43 |
data_1 = pd.DataFrame.from_dict(
|
|
|
62 |
)
|
63 |
db.add(source="test", df=data_2)
|
64 |
|
65 |
+
db_data = retriever(tmp_path / f"test.{extension}").get_documents("test")
|
66 |
|
67 |
assert len(db_data) == len(data_2)
|
68 |
assert db_data["title"].iloc[0] == data_2["title"].iloc[0]
|