LOUIS SANNA committed on
Commit
780c913
1 Parent(s): 3a575de

feat(domains)

Browse files
Files changed (3) hide show
  1. anyqa/config.py +10 -0
  2. anyqa/retriever.py +7 -8
  3. app.py +11 -12
anyqa/config.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import os
4
+
5
def get_domains():
    """Return the names of all directories found under ``data``.

    The ``data`` path is resolved relative to the current working
    directory. ``os.walk`` descends recursively, so directories nested at
    any depth are included, not just the immediate children of ``data``.
    Only the bare directory names are returned, not their paths.

    Returns:
        list: Directory names in the order ``os.walk`` visits them. The
        list is empty when ``data`` does not exist or has no
        subdirectories.
    """
    # Avoid naming the loop variable ``dir``: that would shadow the builtin.
    return [
        name
        for _root, subdirs, _files in os.walk("data")
        for name in subdirs
    ]
anyqa/retriever.py CHANGED
@@ -13,25 +13,24 @@ SUMMARY_TYPES = []
13
 
14
  class QARetriever(BaseRetriever):
15
  vectorstore: VectorStore
16
- sources: list = []
17
  threshold: float = 22
18
  k_summary: int = 0
19
  k_total: int = 10
20
  namespace: str = "vectors"
21
 
22
  def get_relevant_documents(self, query: str) -> List[Document]:
23
- # Check if all elements in the list are either IPCC or IPBES
24
- assert isinstance(self.sources, list)
25
  assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
26
 
27
- query = "He who can bear the misfortune of a nation is called the ruler of the world."
28
  # Prepare base search kwargs
29
  filters = {}
30
- if len(self.sources):
31
- filters["source"] = {"$in": self.sources}
32
 
33
  if self.k_summary > 0:
34
  # Search for k_summary documents in the summaries dataset
 
35
  if len(SUMMARY_TYPES):
36
  filters_summaries = {
37
  **filters_summaries,
@@ -48,7 +47,8 @@ class QARetriever(BaseRetriever):
48
  docs_summaries = []
49
 
50
  # Search for k_total - k_summary documents in the full reports dataset
51
- filters_full = {}
 
52
  if len(SUMMARY_TYPES):
53
  filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}
54
 
@@ -59,7 +59,6 @@ class QARetriever(BaseRetriever):
59
  filter=self.format_filter(filters_full),
60
  k=k_full,
61
  )
62
- print("docs_full", docs_full)
63
 
64
  # Concatenate documents
65
  docs = docs_summaries + docs_full
 
13
 
14
  class QARetriever(BaseRetriever):
15
  vectorstore: VectorStore
16
+ domains: list = []
17
  threshold: float = 22
18
  k_summary: int = 0
19
  k_total: int = 10
20
  namespace: str = "vectors"
21
 
22
  def get_relevant_documents(self, query: str) -> List[Document]:
23
+ assert isinstance(self.domains, list)
 
24
  assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
25
 
 
26
  # Prepare base search kwargs
27
  filters = {}
28
+ if len(self.domains):
29
+ filters["domain"] = {"$in": self.domains}
30
 
31
  if self.k_summary > 0:
32
  # Search for k_summary documents in the summaries dataset
33
+ filters_summaries = {**filters}
34
  if len(SUMMARY_TYPES):
35
  filters_summaries = {
36
  **filters_summaries,
 
47
  docs_summaries = []
48
 
49
  # Search for k_total - k_summary documents in the full reports dataset
50
+ filters_full = {**filters}
51
+ print("filters", filters)
52
  if len(SUMMARY_TYPES):
53
  filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}
54
 
 
59
  filter=self.format_filter(filters_full),
60
  k=k_full,
61
  )
 
62
 
63
  # Concatenate documents
64
  docs = docs_summaries + docs_full
app.py CHANGED
@@ -7,6 +7,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
7
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
8
 
9
  # ClimateQ&A imports
 
10
  from anyqa.embeddings import EMBEDDING_MODEL_NAME
11
  from anyqa.llm import get_llm
12
  from anyqa.qa_logging import log
@@ -136,16 +137,14 @@ def answer_user_example(query, query_example, history):
136
  return query_example, history + [[query_example, ". . ."]]
137
 
138
 
139
- def fetch_sources(query, sources):
140
- # Prepare default values
141
- if len(sources) == 0:
142
- sources = ["IPCC"]
143
 
144
  llm_reformulation = get_llm(
145
  max_tokens=512, temperature=0.0, verbose=True, streaming=False
146
  )
 
147
  retriever = QARetriever(
148
- vectorstore=vectorstore, sources=[], k_summary=0, k_total=10
149
  )
150
  reformulation_chain = load_reformulation_chain(llm_reformulation)
151
 
@@ -379,11 +378,11 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
379
  gr.Markdown(
380
  "Reminder: You can talk in any language, this tool is multi-lingual!"
381
  )
382
-
383
- dropdown_sources = gr.CheckboxGroup(
384
- ["IPCC", "IPBES"],
385
- label="Select reports",
386
- value=["IPCC"],
387
  interactive=True,
388
  )
389
 
@@ -419,7 +418,7 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
419
  .success(change_tab, None, tabs)
420
  .success(
421
  fetch_sources,
422
- [textbox, dropdown_sources],
423
  [
424
  textbox,
425
  sources_textbox,
@@ -454,7 +453,7 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
454
  .success(change_tab, None, tabs)
455
  .success(
456
  fetch_sources,
457
- [textbox, dropdown_sources],
458
  [
459
  textbox,
460
  sources_textbox,
 
7
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
8
 
9
  # ClimateQ&A imports
10
+ from anyqa.config import get_domains
11
  from anyqa.embeddings import EMBEDDING_MODEL_NAME
12
  from anyqa.llm import get_llm
13
  from anyqa.qa_logging import log
 
137
  return query_example, history + [[query_example, ". . ."]]
138
 
139
 
140
+ def fetch_sources(query, domains):
 
 
 
141
 
142
  llm_reformulation = get_llm(
143
  max_tokens=512, temperature=0.0, verbose=True, streaming=False
144
  )
145
+ print("domains", domains)
146
  retriever = QARetriever(
147
+ vectorstore=vectorstore, domains=domains, k_summary=0, k_total=10
148
  )
149
  reformulation_chain = load_reformulation_chain(llm_reformulation)
150
 
 
378
  gr.Markdown(
379
  "Reminder: You can talk in any language, this tool is multi-lingual!"
380
  )
381
+ domains = get_domains()
382
+ dropdown_domains = gr.CheckboxGroup(
383
+ domains,
384
+ label="Select source types",
385
+ value=[],
386
  interactive=True,
387
  )
388
 
 
418
  .success(change_tab, None, tabs)
419
  .success(
420
  fetch_sources,
421
+ [textbox, dropdown_domains],
422
  [
423
  textbox,
424
  sources_textbox,
 
453
  .success(change_tab, None, tabs)
454
  .success(
455
  fetch_sources,
456
+ [textbox, dropdown_domains],
457
  [
458
  textbox,
459
  sources_textbox,