hbertrand committed
Commit: 71e7dd8
Parent: 2457179

PR: DocumentsManager interface (#57)

buster/chatbot.py CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
 import promptlayer
 from openai.embeddings_utils import cosine_similarity, get_embedding
 
-from buster.docparser import read_documents
+from buster.documents import get_documents_manager_from_extension
 from buster.formatter import Formatter, HTMLFormatter, MarkdownFormatter, SlackFormatter
 from buster.formatter.base import Response, Source
 
@@ -47,7 +47,7 @@ class ChatbotConfig:
     text_after_response: Generic response to add the the chatbot's reply.
     """
 
-    documents_file: str = "buster/data/document_embeddings.csv"
+    documents_file: str = "buster/data/document_embeddings.tar.gz"
     embedding_model: str = "text-embedding-ada-002"
     top_k: int = 3
     thresh: float = 0.7
@@ -82,7 +82,7 @@ class Chatbot:
     def _init_documents(self):
         filepath = self.cfg.documents_file
         logger.info(f"loading embeddings from {filepath}...")
-        self.documents = read_documents(filepath)
+        self.documents = get_documents_manager_from_extension(filepath)(filepath)
         logger.info(f"embeddings loaded.")
 
     def _init_unk_embedding(self):
@@ -94,7 +94,6 @@
 
     def rank_documents(
         self,
-        documents: pd.DataFrame,
         query: str,
         top_k: float,
         thresh: float,
@@ -108,14 +107,7 @@
             query,
             engine=engine,
         )
-        documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
-
-        # sort the matched_documents by score
-        matched_documents = documents.sort_values("similarity", ascending=False)
-
-        # limit search to top_k matched_documents.
-        top_k = len(matched_documents) if top_k == -1 else top_k
-        matched_documents = matched_documents.head(top_k)
+        matched_documents = self.documents.retrieve(query_embedding, top_k)
 
         # log matched_documents to the console
         logger.info(f"matched documents before thresh: {matched_documents}")
@@ -236,7 +228,6 @@
         question += "\n"
 
         matched_documents = self.rank_documents(
-            documents=self.documents,
             query=question,
             top_k=self.cfg.top_k,
             thresh=self.cfg.thresh,
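
Net effect: the chatbot no longer loads and ranks a DataFrame itself; it resolves a manager class from the file extension and delegates retrieval to it. A minimal sketch of the new flow (the path is the config default above; the query string is illustrative):

from openai.embeddings_utils import get_embedding

from buster.documents import get_documents_manager_from_extension

filepath = "buster/data/document_embeddings.tar.gz"  # ChatbotConfig.documents_file default

# The helper returns a manager class (DocumentsPickle for .tar.gz),
# which is then instantiated with the same path.
documents = get_documents_manager_from_extension(filepath)(filepath)

# rank_documents now delegates the similarity sort and top_k cut:
query_embedding = get_embedding("How do I install the library?", engine="text-embedding-ada-002")
matched_documents = documents.retrieve(query_embedding, top_k=3)
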
buster/docparser.py CHANGED
@@ -8,16 +8,13 @@ import tiktoken
 from bs4 import BeautifulSoup
 from openai.embeddings_utils import get_embedding
 
-from buster.db import DocumentsDB
+from buster.documents import get_documents_manager_from_extension
 from buster.parser import HuggingfaceParser, Parser, SphinxParser
 
 EMBEDDING_MODEL = "text-embedding-ada-002"
 EMBEDDING_ENCODING = "cl100k_base"  # this the encoding for text-embedding-ada-002
 
 
-PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
-
-
 supported_docs = {
     "mila": {
         "base_url": "https://docs.mila.quebec/",
@@ -77,46 +74,6 @@ def get_all_documents(
     return documents_df
 
 
-def get_file_extension(filepath: str) -> str:
-    return os.path.splitext(filepath)[1]
-
-
-def write_documents(filepath: str, documents_df: pd.DataFrame, source: str = ""):
-    ext = get_file_extension(filepath)
-
-    if ext == ".csv":
-        documents_df.to_csv(filepath, index=False)
-    elif ext in PICKLE_EXTENSIONS:
-        documents_df.to_pickle(filepath)
-    elif ext == ".db":
-        db = DocumentsDB(filepath)
-        db.write_documents(source, documents_df)
-    else:
-        raise ValueError(f"Unsupported format: {ext}.")
-
-
-def read_documents(filepath: str, source: str = "") -> pd.DataFrame:
-    ext = get_file_extension(filepath)
-
-    if ext == ".csv":
-        df = pd.read_csv(filepath)
-
-        if "embedding" in df.columns:
-            df["embedding"] = df.embedding.apply(eval).apply(np.array)
-    elif ext in PICKLE_EXTENSIONS:
-        df = pd.read_pickle(filepath)
-
-        if "embedding" in df.columns:
-            df["embedding"] = df.embedding.apply(np.array)
-    elif ext == ".db":
-        db = DocumentsDB(filepath)
-        df = db.get_documents(source)
-    else:
-        raise ValueError(f"Unsupported format: {ext}.")
-
-    return df
-
-
 def compute_n_tokens(df: pd.DataFrame) -> pd.DataFrame:
     encoding = tiktoken.get_encoding(EMBEDDING_ENCODING)
     # TODO are there unexpected consequences of allowing endoftext?
@@ -129,10 +86,13 @@ def precompute_embeddings(df: pd.DataFrame) -> pd.DataFrame:
     return df
 
 
-def generate_embeddings(filepath: str, output_file: str, source: str) -> pd.DataFrame:
+def generate_embeddings(root_dir: str, output_filepath: str, source: str) -> pd.DataFrame:
     # Get all documents and precompute their embeddings
-    df = read_documents(filepath, source)
-    df = compute_n_tokens(df)
-    df = precompute_embeddings(df)
-    write_documents(filepath=output_file, documents_df=df, source=source)
-    return df
+    documents = get_all_documents(root_dir, supported_docs[source]["base_url"], supported_docs[source]["parser"])
+    documents = compute_n_tokens(documents)
+    documents = precompute_embeddings(documents)
+
+    documents_manager = get_documents_manager_from_extension(output_filepath)(output_filepath)
+    documents_manager.add(source, documents)
+
+    return documents
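
generate_embeddings changes contract here: it now takes a root directory of already-downloaded docs instead of a documents file, and persists through whichever manager matches the output extension. A hypothetical invocation (the docs directory path is an assumption):

from buster.docparser import generate_embeddings

# "mila" is a key of supported_docs; root_dir is a hypothetical local
# checkout of the HTML docs at https://docs.mila.quebec/.
df = generate_embeddings(
    root_dir="/path/to/mila/docs",
    output_filepath="documents.db",  # ".db" selects DocumentsDB, ".tar.gz" DocumentsPickle
    source="mila",
)
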
buster/documents/__init__.py ADDED
@@ -0,0 +1,6 @@
+from .base import DocumentsManager
+from .pickle import DocumentsPickle
+from .sqlite import DocumentsDB
+from .utils import get_documents_manager_from_extension
+
+__all__ = [DocumentsManager, DocumentsPickle, DocumentsDB, get_documents_manager_from_extension]
buster/documents/base.py ADDED
@@ -0,0 +1,30 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import pandas as pd
+from openai.embeddings_utils import cosine_similarity
+
+
+@dataclass
+class DocumentsManager(ABC):
+    @abstractmethod
+    def add(self, source: str, df: pd.DataFrame):
+        ...
+
+    @abstractmethod
+    def get_documents(self, source: str) -> pd.DataFrame:
+        ...
+
+    def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
+        documents = self.get_documents(source)
+
+        documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
+
+        # sort the matched_documents by score
+        matched_documents = documents.sort_values("similarity", ascending=False)
+
+        # limit search to top_k matched_documents.
+        top_k = len(matched_documents) if top_k == -1 else top_k
+        matched_documents = matched_documents.head(top_k)
+
+        return matched_documents
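
Subclasses only implement add and get_documents; retrieve, the similarity sort and top_k cut that used to live in Chatbot.rank_documents, is inherited. A hypothetical in-memory backend, purely to illustrate the contract (not part of this PR):

import pandas as pd

from buster.documents.base import DocumentsManager


class DocumentsInMemory(DocumentsManager):
    """Hypothetical backend: keeps one DataFrame per source in a dict."""

    def __init__(self):
        self.frames: dict[str, pd.DataFrame] = {}

    def add(self, source: str, df: pd.DataFrame):
        self.frames[source] = df

    def get_documents(self, source: str) -> pd.DataFrame:
        if source is None:
            return pd.concat(self.frames.values(), ignore_index=True)
        return self.frames[source]

The inherited retrieve then works unchanged against any such backend.
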
buster/documents/pickle.py ADDED
@@ -0,0 +1,38 @@
+import os
+
+import pandas as pd
+
+from buster.documents.base import DocumentsManager
+
+
+class DocumentsPickle(DocumentsManager):
+    def __init__(self, filepath: str):
+        self.filepath = filepath
+
+        if os.path.exists(filepath):
+            self.documents = pd.read_pickle(filepath)
+        else:
+            self.documents = None
+
+    def add(self, source: str, df: pd.DataFrame):
+        if source is not None:
+            df["source"] = source
+
+        df["current"] = 1
+
+        if self.documents is not None:
+            self.documents.loc[self.documents.source == source, "current"] = 0
+            self.documents = pd.concat([self.documents, df])
+        else:
+            self.documents = df
+
+        self.documents.to_pickle(self.filepath)
+
+    def get_documents(self, source: str) -> pd.DataFrame:
+        documents = self.documents.copy()
+        documents = documents[documents.current == 1]
+
+        if source is not None and "source" in documents.columns:
+            documents = documents[documents.source == source]
+
+        return documents
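
DocumentsPickle mirrors the database semantics: every add keeps the older rows but flips them to current = 0, and get_documents only returns current rows. A small round-trip sketch (path and values are illustrative):

import pandas as pd

from buster.documents import DocumentsPickle

docs = DocumentsPickle("embeddings.tar.gz")  # hypothetical path

df = pd.DataFrame({"title": ["intro"], "url": ["http://url.com"], "content": ["v1"]})
docs.add("test", df)

# A second add for the same source marks the earlier rows current = 0,
# so only the newest version comes back.
docs.add("test", df.assign(content="v2"))
assert docs.get_documents("test")["content"].tolist() == ["v2"]
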
buster/{db.py → documents/sqlite.py} RENAMED
@@ -5,6 +5,8 @@ import zlib
 import numpy as np
 import pandas as pd
 
+from buster.documents.base import DocumentsManager
+
 documents_table = """CREATE TABLE IF NOT EXISTS documents (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
     source TEXT NOT NULL,
@@ -33,7 +35,7 @@ qa_table = """CREATE TABLE IF NOT EXISTS qa (
 )"""
 
 
-class DocumentsDB:
+class DocumentsDB(DocumentsManager):
     """Simple SQLite database for storing documents and questions/answers.
 
     The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
@@ -41,13 +43,13 @@ class DocumentsDB:
 
     Example:
     >>> db = DocumentsDB("/path/to/the/db.db")
-    >>> db.write_documents("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
+    >>> db.add("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
    >>> df = db.get_documents("source")
    """

-    def __init__(self, db_path):
-        self.db_path = db_path
-        self.conn = sqlite3.connect(db_path)
+    def __init__(self, filepath: str):
+        self.db_path = filepath
+        self.conn = sqlite3.connect(filepath)
         self.cursor = self.conn.cursor()
 
         self.__initialize()
@@ -61,7 +63,7 @@
         self.cursor.execute(qa_table)
         self.conn.commit()
 
-    def write_documents(self, source: str, df: pd.DataFrame):
+    def add(self, source: str, df: pd.DataFrame):
         """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`."""
         df = df.copy()
 
@@ -102,7 +104,10 @@
     def get_documents(self, source: str) -> pd.DataFrame:
         """Get all current documents from a given source."""
         # Execute the SQL statement and fetch the results
-        results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
+        if source is not None:
+            results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
+        else:
+            results = self.cursor.execute("SELECT * FROM documents WHERE current = 1")
         rows = results.fetchall()
 
         # Convert the results to a pandas DataFrame
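
The new source=None branch lets get_documents return the current rows across every source, which is what DocumentsManager.retrieve passes by default. A quick sketch against an in-memory database (the column layout follows the tests below; the embedding value is illustrative):

import numpy as np
import pandas as pd

from buster.documents import DocumentsDB

db = DocumentsDB(":memory:")  # in-memory SQLite, as in the old tests

data = pd.DataFrame.from_dict(
    {
        "title": ["test"],
        "url": ["http://url.com"],
        "content": ["cool text"],
        "embedding": [np.array([0.0, 0.5], dtype=np.float32)],
        "n_tokens": [10],
    }
)
db.add(source="test", df=data)

all_current = db.get_documents(None)  # current rows across every source
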
buster/documents/utils.py ADDED
@@ -0,0 +1,23 @@
+import os
+from typing import Type
+
+from buster.documents.base import DocumentsManager
+from buster.documents.pickle import DocumentsPickle
+from buster.documents.sqlite import DocumentsDB
+
+PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]
+
+
+def get_file_extension(filepath: str) -> str:
+    return os.path.splitext(filepath)[1]
+
+
+def get_documents_manager_from_extension(filepath: str) -> Type[DocumentsManager]:
+    ext = get_file_extension(filepath)
+
+    if ext in PICKLE_EXTENSIONS:
+        return DocumentsPickle
+    elif ext == ".db":
+        return DocumentsDB
+    else:
+        raise ValueError(f"Unsupported format: {ext}.")
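
Note the helper returns the manager class, not an instance, hence the get_documents_manager_from_extension(filepath)(filepath) pattern at the call sites. For example (paths are hypothetical):

from buster.documents import DocumentsDB, DocumentsPickle, get_documents_manager_from_extension

assert get_documents_manager_from_extension("embeddings.db") is DocumentsDB
# os.path.splitext("embeddings.tar.gz")[1] is ".gz", which is in PICKLE_EXTENSIONS.
assert get_documents_manager_from_extension("embeddings.tar.gz") is DocumentsPickle

documents = get_documents_manager_from_extension("embeddings.db")("embeddings.db")
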
tests/test_docparser.py CHANGED
@@ -1,26 +1,24 @@
 import numpy as np
 import pandas as pd
 
-from buster.docparser import generate_embeddings, read_documents, write_documents
+from buster.docparser import generate_embeddings
+from buster.documents import get_documents_manager_from_extension
 
 
 def test_generate_embeddings(tmp_path, monkeypatch):
-    # Patch the get_embedding function to return a fixed embedding
-    monkeypatch.setattr("buster.docparser.get_embedding", lambda x, engine: [-0.005, 0.0018])
-
     # Create fake data
     data = pd.DataFrame.from_dict({"title": ["test"], "url": ["http://url.com"], "content": ["cool text"]})
 
-    # Write the data to a file
-    filepath = tmp_path / "test_document.csv"
-    write_documents(filepath=filepath, documents_df=data, source="test")
+    # Patch the get_embedding function to return a fixed embedding
+    monkeypatch.setattr("buster.docparser.get_embedding", lambda x, engine: [-0.005, 0.0018])
+    monkeypatch.setattr("buster.docparser.get_all_documents", lambda a, b, c: data)
 
     # Generate embeddings, store in a file
     output_file = tmp_path / "test_document_embeddings.tar.gz"
-    df = generate_embeddings(filepath=filepath, output_file=output_file, source="test")
+    df = generate_embeddings(tmp_path, output_file, source="mila")
 
     # Read the embeddings from the file
-    read_df = read_documents(output_file, "test")
+    read_df = get_documents_manager_from_extension(output_file)(output_file).get_documents("mila")
 
     # Check all the values are correct across the files
     assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0]
tests/{test_db.py → test_documents.py} RENAMED
@@ -1,11 +1,13 @@
 import numpy as np
 import pandas as pd
+import pytest
 
-from buster.db import DocumentsDB
+from buster.documents import DocumentsDB, DocumentsPickle
 
 
-def test_write_read():
-    db = DocumentsDB(":memory:")
+@pytest.mark.parametrize("documents_manager, extension", [(DocumentsDB, "db"), (DocumentsPickle, "tar.gz")])
+def test_write_read(tmp_path, documents_manager, extension):
+    db = documents_manager(tmp_path / f"test.{extension}")
 
     data = pd.DataFrame.from_dict(
         {
@@ -16,7 +18,7 @@ def test_write_read():
             "n_tokens": [10],
         }
     )
-    db.write_documents(source="test", df=data)
+    db.add(source="test", df=data)
 
     db_data = db.get_documents("test")
 
@@ -27,8 +29,9 @@
     assert db_data["n_tokens"].iloc[0] == data["n_tokens"].iloc[0]
 
 
-def test_write_write_read():
-    db = DocumentsDB(":memory:")
+@pytest.mark.parametrize("documents_manager, extension", [(DocumentsDB, "db"), (DocumentsPickle, "tar.gz")])
+def test_write_write_read(tmp_path, documents_manager, extension):
+    db = documents_manager(tmp_path / f"test.{extension}")
 
     data_1 = pd.DataFrame.from_dict(
         {
@@ -39,7 +42,7 @@ def test_write_write_read():
             "n_tokens": [10],
         }
     )
-    db.write_documents(source="test", df=data_1)
+    db.add(source="test", df=data_1)
 
     data_2 = pd.DataFrame.from_dict(
         {
@@ -50,7 +53,7 @@ def test_write_write_read():
             "n_tokens": [20],
         }
     )
-    db.write_documents(source="test", df=data_2)
+    db.add(source="test", df=data_2)
 
     db_data = db.get_documents("test")
 