hbertrand commited on
Commit
06bca0c
1 Parent(s): 44ee439

PR: retriever interface (#77)

Browse files

* retriever interface

* black + isort

* documents -> retriever

* PR

* black

buster/apps/gradio_app.py CHANGED
@@ -5,19 +5,19 @@ import gradio as gr
5
 
6
  from buster.apps.bot_configs import available_configs
7
  from buster.busterbot import Buster, BusterConfig
8
- from buster.documents.base import DocumentsManager
9
- from buster.documents.utils import download_db, get_documents_manager_from_extension
10
 
11
  DEFAULT_CONFIG = "huggingface"
12
  DB_URL = "https://huggingface.co/datasets/jerpint/buster-data/resolve/main/documents.db"
13
 
14
  # Download the db...
15
  documents_filepath = download_db(db_url=DB_URL, output_dir="./data")
16
- documents: DocumentsManager = get_documents_manager_from_extension(documents_filepath)(documents_filepath)
17
 
18
  # initialize buster with the default config...
19
  default_cfg: BusterConfig = available_configs.get(DEFAULT_CONFIG)
20
- buster = Buster(cfg=default_cfg, documents=documents)
21
 
22
 
23
  def chat(question, history, bot_source):
 
5
 
6
  from buster.apps.bot_configs import available_configs
7
  from buster.busterbot import Buster, BusterConfig
8
+ from buster.retriever import Retriever
9
+ from buster.utils import download_db, get_retriever_from_extension
10
 
11
  DEFAULT_CONFIG = "huggingface"
12
  DB_URL = "https://huggingface.co/datasets/jerpint/buster-data/resolve/main/documents.db"
13
 
14
  # Download the db...
15
  documents_filepath = download_db(db_url=DB_URL, output_dir="./data")
16
+ retriever: Retriever = get_retriever_from_extension(documents_filepath)(documents_filepath)
17
 
18
  # initialize buster with the default config...
19
  default_cfg: BusterConfig = available_configs.get(DEFAULT_CONFIG)
20
+ buster = Buster(cfg=default_cfg, retriever=retriever)
21
 
22
 
23
  def chat(question, history, bot_source):
buster/busterbot.py CHANGED
@@ -64,16 +64,16 @@ class BusterConfig:
64
  source: str = ""
65
 
66
 
67
- from buster.documents.base import DocumentsManager
68
 
69
 
70
  class Buster:
71
- def __init__(self, cfg: BusterConfig, documents: DocumentsManager):
72
  self._unk_embedding = None
73
  self.cfg = cfg
74
  self.update_cfg(cfg)
75
 
76
- self.documents = documents
77
 
78
  @property
79
  def unk_embedding(self):
@@ -117,7 +117,7 @@ class Buster:
117
  query,
118
  engine=engine,
119
  )
120
- matched_documents = self.documents.retrieve(query_embedding, top_k=top_k, source=source)
121
 
122
  # log matched_documents to the console
123
  logger.info(f"matched documents before thresh: {matched_documents}")
 
64
  source: str = ""
65
 
66
 
67
+ from buster.retriever import Retriever
68
 
69
 
70
  class Buster:
71
+ def __init__(self, cfg: BusterConfig, retriever: Retriever):
72
  self._unk_embedding = None
73
  self.cfg = cfg
74
  self.update_cfg(cfg)
75
 
76
+ self.retriever = retriever
77
 
78
  @property
79
  def unk_embedding(self):
 
117
  query,
118
  engine=engine,
119
  )
120
+ matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)
121
 
122
  # log matched_documents to the console
123
  logger.info(f"matched documents before thresh: {matched_documents}")
buster/docparser.py CHANGED
@@ -10,8 +10,8 @@ import tiktoken
10
  from bs4 import BeautifulSoup
11
  from openai.embeddings_utils import get_embedding
12
 
13
- from buster.documents import get_documents_manager_from_extension
14
  from buster.parser import HuggingfaceParser, Parser, SphinxParser
 
15
 
16
  logger = logging.getLogger(__name__)
17
  logging.basicConfig(level=logging.INFO)
 
10
  from bs4 import BeautifulSoup
11
  from openai.embeddings_utils import get_embedding
12
 
 
13
  from buster.parser import HuggingfaceParser, Parser, SphinxParser
14
+ from buster.utils import get_documents_manager_from_extension
15
 
16
  logger = logging.getLogger(__name__)
17
  logging.basicConfig(level=logging.INFO)
buster/documents/__init__.py CHANGED
@@ -1,6 +1,5 @@
1
  from .base import DocumentsManager
2
  from .pickle import DocumentsPickle
3
  from .sqlite import DocumentsDB
4
- from .utils import get_documents_manager_from_extension
5
 
6
- __all__ = [DocumentsManager, DocumentsPickle, DocumentsDB, get_documents_manager_from_extension]
 
1
  from .base import DocumentsManager
2
  from .pickle import DocumentsPickle
3
  from .sqlite import DocumentsDB
 
4
 
5
+ __all__ = [DocumentsManager, DocumentsPickle, DocumentsDB]
buster/documents/base.py CHANGED
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
2
  from dataclasses import dataclass
3
 
4
  import pandas as pd
5
- from openai.embeddings_utils import cosine_similarity
6
 
7
 
8
  @dataclass
@@ -10,21 +9,3 @@ class DocumentsManager(ABC):
10
  @abstractmethod
11
  def add(self, source: str, df: pd.DataFrame):
12
  ...
13
-
14
- @abstractmethod
15
- def get_documents(self, source: str) -> pd.DataFrame:
16
- ...
17
-
18
- def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
19
- documents = self.get_documents(source)
20
-
21
- documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
22
-
23
- # sort the matched_documents by score
24
- matched_documents = documents.sort_values("similarity", ascending=False)
25
-
26
- # limit search to top_k matched_documents.
27
- top_k = len(matched_documents) if top_k == -1 else top_k
28
- matched_documents = matched_documents.head(top_k)
29
-
30
- return matched_documents
 
2
  from dataclasses import dataclass
3
 
4
  import pandas as pd
 
5
 
6
 
7
  @dataclass
 
9
  @abstractmethod
10
  def add(self, source: str, df: pd.DataFrame):
11
  ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
buster/documents/pickle.py CHANGED
@@ -27,19 +27,3 @@ class DocumentsPickle(DocumentsManager):
27
  self.documents = df
28
 
29
  self.documents.to_pickle(self.filepath)
30
-
31
- def get_documents(self, source: str) -> pd.DataFrame:
32
- if self.documents is None:
33
- raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
34
-
35
- documents = self.documents.copy()
36
- if "current" in documents.columns:
37
- documents = documents[documents.current == 1]
38
-
39
- # Drop the `current` column
40
- documents.drop(columns=["current"], inplace=True)
41
-
42
- if source is not None and "source" in documents.columns:
43
- documents = documents[documents.source == source]
44
-
45
- return documents
 
27
  self.documents = df
28
 
29
  self.documents.to_pickle(self.filepath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
buster/documents/sqlite/documents.py CHANGED
@@ -33,7 +33,6 @@ class DocumentsDB(DocumentsManager):
33
  Example:
34
  >>> db = DocumentsDB("/path/to/the/db.db")
35
  >>> db.add("source", df) # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
36
- >>> df = db.get_documents("source")
37
  """
38
 
39
  def __init__(self, db_path: sqlite3.Connection | str):
@@ -142,13 +141,3 @@ class DocumentsDB(DocumentsManager):
142
  sid, vid = self.add_parse(source, (section for section, _ in sections))
143
  self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
144
  self.conn.commit()
145
-
146
- def get_documents(self, source: str) -> pd.DataFrame:
147
- """Get all current documents from a given source."""
148
- # Execute the SQL statement and fetch the results
149
- results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
150
- rows = results.fetchall()
151
-
152
- # Convert the results to a pandas DataFrame
153
- df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
154
- return df
 
33
  Example:
34
  >>> db = DocumentsDB("/path/to/the/db.db")
35
  >>> db.add("source", df) # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
 
36
  """
37
 
38
  def __init__(self, db_path: sqlite3.Connection | str):
 
141
  sid, vid = self.add_parse(source, (section for section, _ in sections))
142
  self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
143
  self.conn.commit()
 
 
 
 
 
 
 
 
 
 
buster/examples/gradio_app.py CHANGED
@@ -2,12 +2,12 @@ import cfg
2
  import gradio as gr
3
 
4
  from buster.busterbot import Buster
5
- from buster.documents.base import DocumentsManager
6
- from buster.documents.utils import get_documents_manager_from_extension
7
 
8
  # initialize buster with the config in config.py (adapt to your needs) ...
9
- documents: DocumentsManager = get_documents_manager_from_extension(cfg.documents_filepath)(cfg.documents_filepath)
10
- buster: Buster = Buster(cfg=cfg.buster_cfg, documents=documents)
11
 
12
 
13
  def chat(question, history):
 
2
  import gradio as gr
3
 
4
  from buster.busterbot import Buster
5
+ from buster.retriever import Retriever
6
+ from buster.utils import get_retriever_from_extension
7
 
8
  # initialize buster with the config in config.py (adapt to your needs) ...
9
+ retriever: Retriever = get_retriever_from_extension(cfg.documents_filepath)(cfg.documents_filepath)
10
+ buster: Buster = Buster(cfg=cfg.buster_cfg, retriever=retriever)
11
 
12
 
13
  def chat(question, history):
buster/parser.py CHANGED
@@ -1,4 +1,3 @@
1
- import math
2
  import os
3
  from abc import ABC, abstractmethod
4
  from dataclasses import InitVar, dataclass, field
 
 
1
  import os
2
  from abc import ABC, abstractmethod
3
  from dataclasses import InitVar, dataclass, field
buster/retriever/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .base import Retriever
2
+ from .pickle import PickleRetriever
3
+ from .sqlite import SQLiteRetriever
4
+
5
+ __all__ = [Retriever, PickleRetriever, SQLiteRetriever]
buster/retriever/base.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+
4
+ import pandas as pd
5
+ from openai.embeddings_utils import cosine_similarity
6
+
7
+
8
+ @dataclass
9
+ class Retriever(ABC):
10
+ @abstractmethod
11
+ def get_documents(self, source: str) -> pd.DataFrame:
12
+ ...
13
+
14
+ def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
15
+ documents = self.get_documents(source)
16
+
17
+ documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
18
+
19
+ # sort the matched_documents by score
20
+ matched_documents = documents.sort_values("similarity", ascending=False)
21
+
22
+ # limit search to top_k matched_documents.
23
+ top_k = len(matched_documents) if top_k == -1 else top_k
24
+ matched_documents = matched_documents.head(top_k)
25
+
26
+ return matched_documents
buster/retriever/pickle.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ from buster.retriever.base import Retriever
4
+
5
+
6
+ class PickleRetriever(Retriever):
7
+ def __init__(self, filepath: str):
8
+ self.filepath = filepath
9
+ self.documents = pd.read_pickle(filepath)
10
+
11
+ def get_documents(self, source: str) -> pd.DataFrame:
12
+ if self.documents is None:
13
+ raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
14
+
15
+ documents = self.documents.copy()
16
+ # The `current` column exists when multiple versions of a document exist
17
+ if "current" in documents.columns:
18
+ documents = documents[documents.current == 1]
19
+
20
+ # Drop the `current` column
21
+ documents.drop(columns=["current"], inplace=True)
22
+
23
+ if source is not None and "source" in documents.columns:
24
+ documents = documents[documents.source == source]
25
+
26
+ return documents
buster/retriever/sqlite.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ import buster.documents.sqlite.schema as schema
7
+ from buster.retriever.base import Retriever
8
+
9
+
10
+ class SQLiteRetriever(Retriever):
11
+ """Simple SQLite database for retrieval of documents.
12
+
13
+ The database is just a file on disk. It can store documents from different sources, and it
14
+ can store multiple versions of the same document (e.g. if the document is updated).
15
+
16
+ Example:
17
+ >>> db = DocumentsDB("/path/to/the/db.db")
18
+ >>> df = db.get_documents("source")
19
+ """
20
+
21
+ def __init__(self, db_path: sqlite3.Connection | str):
22
+ if isinstance(db_path, (str, Path)):
23
+ self.db_path = db_path
24
+ self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
25
+ else:
26
+ self.db_path = None
27
+ self.conn = db_path
28
+ schema.initialize_db(self.conn)
29
+ schema.setup_db(self.conn)
30
+
31
+ def __del__(self):
32
+ if self.db_path is not None:
33
+ self.conn.close()
34
+
35
+ def get_documents(self, source: str) -> pd.DataFrame:
36
+ """Get all current documents from a given source."""
37
+ # Execute the SQL statement and fetch the results.
38
+ if source is "":
39
+ results = self.conn.execute("SELECT * FROM documents")
40
+ else:
41
+ results = self.conn.execute("SELECT * FROM documents WHERE source = ?", (source,))
42
+ rows = results.fetchall()
43
+
44
+ # Convert the results to a pandas DataFrame
45
+ df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
46
+ return df
buster/{documents/utils.py → utils.py} RENAMED
@@ -2,9 +2,8 @@ import os
2
  import urllib.request
3
  from typing import Type
4
 
5
- from buster.documents.base import DocumentsManager
6
- from buster.documents.pickle import DocumentsPickle
7
- from buster.documents.sqlite import DocumentsDB
8
 
9
  PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
10
 
@@ -34,3 +33,14 @@ def get_documents_manager_from_extension(filepath: str) -> Type[DocumentsManager
34
  return DocumentsDB
35
  else:
36
  raise ValueError(f"Unsupported format: {ext}.")
 
 
 
 
 
 
 
 
 
 
 
 
2
  import urllib.request
3
  from typing import Type
4
 
5
+ from buster.documents import DocumentsDB, DocumentsManager, DocumentsPickle
6
+ from buster.retriever import PickleRetriever, Retriever, SQLiteRetriever
 
7
 
8
  PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
9
 
 
33
  return DocumentsDB
34
  else:
35
  raise ValueError(f"Unsupported format: {ext}.")
36
+
37
+
38
+ def get_retriever_from_extension(filepath: str) -> Type[Retriever]:
39
+ ext = get_file_extension(filepath)
40
+
41
+ if ext in PICKLE_EXTENSIONS:
42
+ return PickleRetriever
43
+ elif ext == ".db":
44
+ return SQLiteRetriever
45
+ else:
46
+ raise ValueError(f"Unsupported format: {ext}.")
tests/test_chatbot.py CHANGED
@@ -6,8 +6,9 @@ import pandas as pd
6
 
7
  from buster.busterbot import Buster, BusterConfig
8
  from buster.completers.base import Completer
9
- from buster.documents import DocumentsManager, get_documents_manager_from_extension
10
  from buster.formatter.base import Response
 
 
11
 
12
  TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
13
  DOCUMENTS_FILE = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz")
@@ -29,7 +30,7 @@ class MockCompleter(Completer):
29
  return Response(self.expected_answer)
30
 
31
 
32
- class DocumentsMock(DocumentsManager):
33
  def __init__(self, filepath):
34
  self.filepath = filepath
35
 
@@ -45,9 +46,6 @@ class DocumentsMock(DocumentsManager):
45
  }
46
  )
47
 
48
- def add(self, documents):
49
- pass
50
-
51
  def get_documents(self, source):
52
  return self.documents
53
 
@@ -90,8 +88,8 @@ def test_chatbot_mock_data(tmp_path, monkeypatch):
90
  },
91
  )
92
  filepath = tmp_path / "not_a_real_file.tar.gz"
93
- documents = DocumentsMock(filepath)
94
- buster = Buster(cfg=hf_transformers_cfg, documents=documents)
95
  answer = buster.process_input("What is a transformer?")
96
  assert isinstance(answer, str)
97
  assert answer.startswith(gpt_expected_answer)
@@ -119,8 +117,8 @@ def test_chatbot_real_data__chatGPT():
119
  },
120
  },
121
  )
122
- documents = get_documents_manager_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
123
- buster = Buster(cfg=hf_transformers_cfg, documents=documents)
124
  answer = buster.process_input("What is a transformer?")
125
  assert isinstance(answer, str)
126
 
@@ -153,8 +151,8 @@ def test_chatbot_real_data__chatGPT_OOD():
153
  },
154
  response_format="gradio",
155
  )
156
- documents = get_documents_manager_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
157
- buster = Buster(cfg=buster_cfg, documents=documents)
158
  answer = buster.process_input("What is a good recipe for brocolli soup?")
159
  assert isinstance(answer, str)
160
  assert buster_cfg.unknown_prompt in answer
@@ -187,7 +185,7 @@ def test_chatbot_real_data__GPT():
187
  },
188
  },
189
  )
190
- documents = get_documents_manager_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
191
- buster = Buster(cfg=hf_transformers_cfg, documents=documents)
192
  answer = buster.process_input("What is a transformer?")
193
  assert isinstance(answer, str)
 
6
 
7
  from buster.busterbot import Buster, BusterConfig
8
  from buster.completers.base import Completer
 
9
  from buster.formatter.base import Response
10
+ from buster.retriever import Retriever
11
+ from buster.utils import get_retriever_from_extension
12
 
13
  TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
14
  DOCUMENTS_FILE = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz")
 
30
  return Response(self.expected_answer)
31
 
32
 
33
+ class MockRetriever(Retriever):
34
  def __init__(self, filepath):
35
  self.filepath = filepath
36
 
 
46
  }
47
  )
48
 
 
 
 
49
  def get_documents(self, source):
50
  return self.documents
51
 
 
88
  },
89
  )
90
  filepath = tmp_path / "not_a_real_file.tar.gz"
91
+ retriever = MockRetriever(filepath)
92
+ buster = Buster(cfg=hf_transformers_cfg, retriever=retriever)
93
  answer = buster.process_input("What is a transformer?")
94
  assert isinstance(answer, str)
95
  assert answer.startswith(gpt_expected_answer)
 
117
  },
118
  },
119
  )
120
+ retriever = get_retriever_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
121
+ buster = Buster(cfg=hf_transformers_cfg, retriever=retriever)
122
  answer = buster.process_input("What is a transformer?")
123
  assert isinstance(answer, str)
124
 
 
151
  },
152
  response_format="gradio",
153
  )
154
+ retriever = get_retriever_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
155
+ buster = Buster(cfg=buster_cfg, retriever=retriever)
156
  answer = buster.process_input("What is a good recipe for brocolli soup?")
157
  assert isinstance(answer, str)
158
  assert buster_cfg.unknown_prompt in answer
 
185
  },
186
  },
187
  )
188
+ retriever = get_retriever_from_extension(DOCUMENTS_FILE)(DOCUMENTS_FILE)
189
+ buster = Buster(cfg=hf_transformers_cfg, retriever=retriever)
190
  answer = buster.process_input("What is a transformer?")
191
  assert isinstance(answer, str)
tests/test_docparser.py CHANGED
@@ -1,11 +1,13 @@
1
  import numpy as np
2
  import pandas as pd
 
3
 
4
  from buster.docparser import generate_embeddings
5
- from buster.documents import get_documents_manager_from_extension
6
 
7
 
8
- def test_generate_embeddings(tmp_path, monkeypatch):
 
9
  # Create fake data
10
  data = pd.DataFrame.from_dict(
11
  {"title": ["test"], "url": ["http://url.com"], "content": ["cool text"], "source": ["my_source"]}
@@ -16,11 +18,11 @@ def test_generate_embeddings(tmp_path, monkeypatch):
16
  monkeypatch.setattr("buster.docparser.get_all_documents", lambda a, b, c: data)
17
 
18
  # Generate embeddings, store in a file
19
- output_file = tmp_path / "test_document_embeddings.tar.gz"
20
  df = generate_embeddings(data, output_file)
21
 
22
  # Read the embeddings from the file
23
- read_df = get_documents_manager_from_extension(output_file)(output_file).get_documents("my_source")
24
 
25
  # Check all the values are correct across the files
26
  assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0]
 
1
  import numpy as np
2
  import pandas as pd
3
+ import pytest
4
 
5
  from buster.docparser import generate_embeddings
6
+ from buster.utils import get_retriever_from_extension
7
 
8
 
9
+ @pytest.mark.parametrize("extension", ["db", "tar.gz"])
10
+ def test_generate_embeddings(tmp_path, monkeypatch, extension):
11
  # Create fake data
12
  data = pd.DataFrame.from_dict(
13
  {"title": ["test"], "url": ["http://url.com"], "content": ["cool text"], "source": ["my_source"]}
 
18
  monkeypatch.setattr("buster.docparser.get_all_documents", lambda a, b, c: data)
19
 
20
  # Generate embeddings, store in a file
21
+ output_file = tmp_path / f"test_document_embeddings.{extension}"
22
  df = generate_embeddings(data, output_file)
23
 
24
  # Read the embeddings from the file
25
+ read_df = get_retriever_from_extension(output_file)(output_file).get_documents("my_source")
26
 
27
  # Check all the values are correct across the files
28
  assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0]
tests/test_documents.py CHANGED
@@ -3,10 +3,14 @@ import pandas as pd
3
  import pytest
4
 
5
  from buster.documents import DocumentsDB, DocumentsPickle
 
6
 
7
 
8
- @pytest.mark.parametrize("documents_manager, extension", [(DocumentsDB, "db"), (DocumentsPickle, "tar.gz")])
9
- def test_write_read(tmp_path, documents_manager, extension):
 
 
 
10
  db = documents_manager(tmp_path / f"test.{extension}")
11
 
12
  data = pd.DataFrame.from_dict(
@@ -20,7 +24,7 @@ def test_write_read(tmp_path, documents_manager, extension):
20
  )
21
  db.add(source="test", df=data)
22
 
23
- db_data = db.get_documents("test")
24
 
25
  assert db_data["title"].iloc[0] == data["title"].iloc[0]
26
  assert db_data["url"].iloc[0] == data["url"].iloc[0]
@@ -29,8 +33,11 @@ def test_write_read(tmp_path, documents_manager, extension):
29
  assert db_data["n_tokens"].iloc[0] == data["n_tokens"].iloc[0]
30
 
31
 
32
- @pytest.mark.parametrize("documents_manager, extension", [(DocumentsDB, "db"), (DocumentsPickle, "tar.gz")])
33
- def test_write_write_read(tmp_path, documents_manager, extension):
 
 
 
34
  db = documents_manager(tmp_path / f"test.{extension}")
35
 
36
  data_1 = pd.DataFrame.from_dict(
@@ -55,7 +62,7 @@ def test_write_write_read(tmp_path, documents_manager, extension):
55
  )
56
  db.add(source="test", df=data_2)
57
 
58
- db_data = db.get_documents("test")
59
 
60
  assert len(db_data) == len(data_2)
61
  assert db_data["title"].iloc[0] == data_2["title"].iloc[0]
 
3
  import pytest
4
 
5
  from buster.documents import DocumentsDB, DocumentsPickle
6
+ from buster.retriever import PickleRetriever, SQLiteRetriever
7
 
8
 
9
+ @pytest.mark.parametrize(
10
+ "documents_manager, retriever, extension",
11
+ [(DocumentsDB, SQLiteRetriever, "db"), (DocumentsPickle, PickleRetriever, "tar.gz")],
12
+ )
13
+ def test_write_read(tmp_path, documents_manager, retriever, extension):
14
  db = documents_manager(tmp_path / f"test.{extension}")
15
 
16
  data = pd.DataFrame.from_dict(
 
24
  )
25
  db.add(source="test", df=data)
26
 
27
+ db_data = retriever(tmp_path / f"test.{extension}").get_documents("test")
28
 
29
  assert db_data["title"].iloc[0] == data["title"].iloc[0]
30
  assert db_data["url"].iloc[0] == data["url"].iloc[0]
 
33
  assert db_data["n_tokens"].iloc[0] == data["n_tokens"].iloc[0]
34
 
35
 
36
+ @pytest.mark.parametrize(
37
+ "documents_manager, retriever, extension",
38
+ [(DocumentsDB, SQLiteRetriever, "db"), (DocumentsPickle, PickleRetriever, "tar.gz")],
39
+ )
40
+ def test_write_write_read(tmp_path, documents_manager, retriever, extension):
41
  db = documents_manager(tmp_path / f"test.{extension}")
42
 
43
  data_1 = pd.DataFrame.from_dict(
 
62
  )
63
  db.add(source="test", df=data_2)
64
 
65
+ db_data = retriever(tmp_path / f"test.{extension}").get_documents("test")
66
 
67
  assert len(db_data) == len(data_2)
68
  assert db_data["title"].iloc[0] == data_2["title"].iloc[0]