hbertrand committed on
Commit
6aad21a
1 Parent(s): ebace01

PR: source display name (#80)

Browse files

* source display name

* tests

* black

* CR

* isort

buster/docparser.py CHANGED
@@ -157,6 +157,11 @@ def documents_to_db(documents: pd.DataFrame, output_filepath: str):
157
  logger.info(f"Documents saved to: {output_filepath}")
158
 
159
 
 
 
 
 
 
160
  def generate_embeddings(
161
  documents: pd.DataFrame,
162
  output_filepath: str = "documents.db",
 
157
  logger.info(f"Documents saved to: {output_filepath}")
158
 
159
 
160
+ def update_source(source: str, output_filepath: str, display_name: str = None, note: str = None):
161
+ documents_manager = get_documents_manager_from_extension(output_filepath)(output_filepath)
162
+ documents_manager.update_source(source, display_name, note)
163
+
164
+
165
  def generate_embeddings(
166
  documents: pd.DataFrame,
167
  output_filepath: str = "documents.db",
buster/documents/base.py CHANGED
@@ -8,4 +8,10 @@ import pandas as pd
8
  class DocumentsManager(ABC):
9
  @abstractmethod
10
  def add(self, source: str, df: pd.DataFrame):
 
 
 
 
 
 
11
  ...
 
8
  class DocumentsManager(ABC):
9
  @abstractmethod
10
  def add(self, source: str, df: pd.DataFrame):
11
+ """Write all documents from the dataframe into the db as a new version."""
12
+ ...
13
+
14
+ @abstractmethod
15
+ def update_source(self, source: str, display_name: str = None, note: str = None):
16
+ """Update the display name and/or note of a source. Also create the source if it does not exist."""
17
  ...
buster/documents/pickle.py CHANGED
@@ -15,6 +15,7 @@ class DocumentsPickle(DocumentsManager):
15
  self.documents = None
16
 
17
  def add(self, source: str, df: pd.DataFrame):
 
18
  if source is not None:
19
  df["source"] = source
20
 
@@ -27,3 +28,7 @@ class DocumentsPickle(DocumentsManager):
27
  self.documents = df
28
 
29
  self.documents.to_pickle(self.filepath)
 
 
 
 
 
15
  self.documents = None
16
 
17
  def add(self, source: str, df: pd.DataFrame):
18
+ """Write all documents from the dataframe into the db as a new version."""
19
  if source is not None:
20
  df["source"] = source
21
 
 
28
  self.documents = df
29
 
30
  self.documents.to_pickle(self.filepath)
31
+
32
+ def update_source(self, source: str, display_name: str = None, note: str = None):
33
+ """Update the display name and/or note of a source. Also create the source if it does not exist."""
34
+ print("If you need this function, please switch your backend to DocumentsDB.")
buster/documents/sqlite/documents.py CHANGED
@@ -141,3 +141,13 @@ class DocumentsDB(DocumentsManager):
141
  sid, vid = self.add_parse(source, (section for section, _ in sections))
142
  self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
143
  self.conn.commit()
 
 
 
 
 
 
 
 
 
 
 
141
  sid, vid = self.add_parse(source, (section for section, _ in sections))
142
  self.add_chunking(sid, vid, size, (chunks for _, chunks in sections))
143
  self.conn.commit()
144
+
145
+ def update_source(self, source: str, display_name: str = None, note: str = None):
146
+ """Update the display name and/or note of a source. Also create the source if it does not exist."""
147
+ sid = self.get_source(source)
148
+
149
+ if display_name is not None:
150
+ self.conn.execute("UPDATE sources SET display_name = ? WHERE id = ?", (display_name, sid))
151
+ if note is not None:
152
+ self.conn.execute("UPDATE sources SET note = ? WHERE id = ?", (note, sid))
153
+ self.conn.commit()
buster/documents/sqlite/schema.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
7
  id INTEGER PRIMARY KEY AUTOINCREMENT,
8
  name TEXT NOT NULL,
 
9
  note TEXT,
10
  UNIQUE(name)
11
  )"""
 
6
  SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
7
  id INTEGER PRIMARY KEY AUTOINCREMENT,
8
  name TEXT NOT NULL,
9
+ display_name TEXT,
10
  note TEXT,
11
  UNIQUE(name)
12
  )"""
buster/retriever/base.py CHANGED
@@ -4,11 +4,19 @@ from dataclasses import dataclass
4
  import pandas as pd
5
  from openai.embeddings_utils import cosine_similarity
6
 
 
 
7
 
8
  @dataclass
9
  class Retriever(ABC):
10
  @abstractmethod
11
  def get_documents(self, source: str) -> pd.DataFrame:
 
 
 
 
 
 
12
  ...
13
 
14
  def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
 
4
  import pandas as pd
5
  from openai.embeddings_utils import cosine_similarity
6
 
7
+ ALL_SOURCES = "All"
8
+
9
 
10
  @dataclass
11
  class Retriever(ABC):
12
  @abstractmethod
13
  def get_documents(self, source: str) -> pd.DataFrame:
14
+ """Get all current documents from a given source."""
15
+ ...
16
+
17
+ @abstractmethod
18
+ def get_source_display_name(self, source: str) -> str:
19
+ """Get the display name of a source."""
20
  ...
21
 
22
  def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
buster/retriever/pickle.py CHANGED
@@ -1,6 +1,6 @@
1
  import pandas as pd
2
 
3
- from buster.retriever.base import Retriever
4
 
5
 
6
  class PickleRetriever(Retriever):
@@ -9,6 +9,7 @@ class PickleRetriever(Retriever):
9
  self.documents = pd.read_pickle(filepath)
10
 
11
  def get_documents(self, source: str) -> pd.DataFrame:
 
12
  if self.documents is None:
13
  raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
14
 
@@ -24,3 +25,10 @@ class PickleRetriever(Retriever):
24
  documents = documents[documents.source == source]
25
 
26
  return documents
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
 
3
+ from buster.retriever.base import ALL_SOURCES, Retriever
4
 
5
 
6
  class PickleRetriever(Retriever):
 
9
  self.documents = pd.read_pickle(filepath)
10
 
11
  def get_documents(self, source: str) -> pd.DataFrame:
12
+ """Get all current documents from a given source."""
13
  if self.documents is None:
14
  raise FileNotFoundError(f"No documents found at {self.filepath}. Are you sure this is the correct path?")
15
 
 
25
  documents = documents[documents.source == source]
26
 
27
  return documents
28
+
29
+ def get_source_display_name(self, source: str) -> str:
30
+ """Get the display name of a source."""
31
+ if source is None:
32
+ return ALL_SOURCES
33
+ else:
34
+ return source
buster/retriever/sqlite.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
4
  import pandas as pd
5
 
6
  import buster.documents.sqlite.schema as schema
7
- from buster.retriever.base import Retriever
8
 
9
 
10
  class SQLiteRetriever(Retriever):
@@ -44,3 +44,15 @@ class SQLiteRetriever(Retriever):
44
  # Convert the results to a pandas DataFrame
45
  df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
46
  return df
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import pandas as pd
5
 
6
  import buster.documents.sqlite.schema as schema
7
+ from buster.retriever.base import ALL_SOURCES, Retriever
8
 
9
 
10
  class SQLiteRetriever(Retriever):
 
44
  # Convert the results to a pandas DataFrame
45
  df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
46
  return df
47
+
48
+ def get_source_display_name(self, source: str) -> str:
49
+ """Get the display name of a source."""
50
+ if source is "":
51
+ return ALL_SOURCES
52
+ else:
53
+ cur = self.conn.execute("SELECT display_name FROM sources WHERE name = ?", (source,))
54
+ row = cur.fetchone()
55
+ if row is None:
56
+ raise KeyError(f'"{source}" is not a known source')
57
+ (display_name,) = row
58
+ return display_name
tests/test_chatbot.py CHANGED
@@ -49,6 +49,9 @@ class MockRetriever(Retriever):
49
  def get_documents(self, source):
50
  return self.documents
51
 
 
 
 
52
 
53
  import logging
54
 
 
49
  def get_documents(self, source):
50
  return self.documents
51
 
52
+ def get_source_display_name(self, source):
53
+ return source
54
+
55
 
56
  import logging
57
 
tests/test_documents.py CHANGED
@@ -70,3 +70,14 @@ def test_write_write_read(tmp_path, documents_manager, retriever, extension):
70
  assert db_data["content"].iloc[0] == data_2["content"].iloc[0]
71
  assert np.allclose(db_data["embedding"].iloc[0], data_2["embedding"].iloc[0])
72
  assert db_data["n_tokens"].iloc[0] == data_2["n_tokens"].iloc[0]
 
 
 
 
 
 
 
 
 
 
 
 
70
  assert db_data["content"].iloc[0] == data_2["content"].iloc[0]
71
  assert np.allclose(db_data["embedding"].iloc[0], data_2["embedding"].iloc[0])
72
  assert db_data["n_tokens"].iloc[0] == data_2["n_tokens"].iloc[0]
73
+
74
+
75
+ def test_update_source(tmp_path):
76
+ display_name = "Super Test"
77
+ db = DocumentsDB(tmp_path / "test.db")
78
+
79
+ db.update_source(source="test", display_name=display_name)
80
+
81
+ returned_display_name = SQLiteRetriever(tmp_path / "test.db").get_source_display_name("test")
82
+
83
+ assert display_name == returned_display_name