LOUIS SANNA committed
Commit 6e28a81
Parent: 35c9187
climateqa/chains.py CHANGED
@@ -3,20 +3,20 @@
 import json
 
 from langchain import PromptTemplate, LLMChain
-from langchain.chains import RetrievalQAWithSourcesChain,QAWithSourcesChain
+from langchain.chains import RetrievalQAWithSourcesChain, QAWithSourcesChain
 from langchain.chains import TransformChain, SequentialChain
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 
-from climateqa.prompts import answer_prompt, reformulation_prompt,audience_prompts
+from climateqa.prompts import answer_prompt, reformulation_prompt, audience_prompts
 from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
 
-def load_reformulation_chain(llm):
-
+
+def load_reformulation_chain(llm):
     prompt = PromptTemplate(
-        template = reformulation_prompt,
+        template=reformulation_prompt,
         input_variables=["query"],
     )
-    reformulation_chain = LLMChain(llm = llm,prompt = prompt,output_key="json")
+    reformulation_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")
 
     # Parse the output
     def parse_output(output):
@@ -28,20 +28,30 @@ def load_reformulation_chain(llm):
             "question": question,
             "language": language,
         }
 
     transform_chain = TransformChain(
-        input_variables=["json"], output_variables=["question","language"], transform=parse_output
+        input_variables=["json"],
+        output_variables=["question", "language"],
+        transform=parse_output,
     )
 
-    reformulation_chain = SequentialChain(chains = [reformulation_chain,transform_chain],input_variables=["query"],output_variables=["question","language"])
+    reformulation_chain = SequentialChain(
+        chains=[reformulation_chain, transform_chain],
+        input_variables=["query"],
+        output_variables=["question", "language"],
+    )
     return reformulation_chain
 
 
 def load_combine_documents_chain(llm):
-    prompt = PromptTemplate(template=answer_prompt, input_variables=["summaries", "question","audience","language"])
-    qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff",prompt = prompt)
+    prompt = PromptTemplate(
+        template=answer_prompt,
+        input_variables=["summaries", "question", "audience", "language"],
+    )
+    qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
     return qa_chain
 
+
 def load_qa_chain_with_docs(llm):
     """Load a QA chain with documents.
     Useful when you already have retrieved docs
@@ -60,50 +70,47 @@ def load_qa_chain_with_docs(llm):
 
     qa_chain = load_combine_documents_chain(llm)
     chain = QAWithSourcesChain(
-        input_docs_key = "docs",
-        combine_documents_chain = qa_chain,
-        return_source_documents = True,
+        input_docs_key="docs",
+        combine_documents_chain=qa_chain,
+        return_source_documents=True,
     )
     return chain
 
 
 def load_qa_chain_with_text(llm):
-
     prompt = PromptTemplate(
-        template = answer_prompt,
-        input_variables=["question","audience","language","summaries"],
+        template=answer_prompt,
+        input_variables=["question", "audience", "language", "summaries"],
     )
-    qa_chain = LLMChain(llm = llm,prompt = prompt)
+    qa_chain = LLMChain(llm=llm, prompt=prompt)
     return qa_chain
 
 
-def load_qa_chain_with_retriever(retriever,llm):
+def load_qa_chain_with_retriever(retriever, llm):
     qa_chain = load_combine_documents_chain(llm)
 
     # This could be improved by providing a document prompt to avoid modifying page_content in the docs
     # See here https://github.com/langchain-ai/langchain/issues/3523
 
     answer_chain = CustomRetrievalQAWithSourcesChain(
-        combine_documents_chain = qa_chain,
+        combine_documents_chain=qa_chain,
         retriever=retriever,
-        return_source_documents = True,
-        verbose = True,
+        return_source_documents=True,
+        verbose=True,
         fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
     )
     return answer_chain
 
 
-def load_climateqa_chain(retriever,llm_reformulation,llm_answer):
-
+def load_climateqa_chain(retriever, llm_reformulation, llm_answer):
     reformulation_chain = load_reformulation_chain(llm_reformulation)
-    answer_chain = load_qa_chain_with_retriever(retriever,llm_answer)
+    answer_chain = load_qa_chain_with_retriever(retriever, llm_answer)
 
     climateqa_chain = SequentialChain(
-        chains = [reformulation_chain,answer_chain],
-        input_variables=["query","audience"],
-        output_variables=["answer","question","language","source_documents"],
-        return_all = True,
-        verbose = True,
+        chains=[reformulation_chain, answer_chain],
+        input_variables=["query", "audience"],
+        output_variables=["answer", "question", "language", "source_documents"],
+        return_all=True,
+        verbose=True,
    )
    return climateqa_chain
-
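For orientation: load_climateqa_chain composes the two stages end to end, the reformulation chain turning a raw user query into a standalone question plus detected language, and the retrieval QA chain answering it from retrieved passages. A minimal usage sketch, assuming the helpers from the other modules touched in this commit (get_llm, get_pinecone_vectorstore, ClimateQARetriever) and a populated index:

    from langchain.embeddings import HuggingFaceEmbeddings
    from climateqa.llm import get_llm
    from climateqa.vectorstore import get_pinecone_vectorstore
    from climateqa.retriever import ClimateQARetriever
    from climateqa.chains import load_climateqa_chain

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
    )
    vectorstore = get_pinecone_vectorstore(embeddings)
    retriever = ClimateQARetriever(vectorstore=vectorstore)
    llm = get_llm()

    chain = load_climateqa_chain(retriever, llm_reformulation=llm, llm_answer=llm)
    result = chain({"query": "Is sea level rise accelerating?", "audience": "general"})
    # With return_all=True, result exposes answer, question, language
    # and source_documents alongside the inputs.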
 
climateqa/chat.py CHANGED
@@ -12,28 +12,31 @@ from climateqa.chains import load_climateqa_chain
 
 
 class ClimateQA:
-    def __init__(self,hf_embedding_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1",
-        show_progress_bar = False,batch_size = 1,max_tokens = 1024,**kwargs):
-
-        self.llm = self.get_llm(max_tokens = max_tokens,**kwargs)
+    def __init__(
+        self,
+        hf_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
+        show_progress_bar=False,
+        batch_size=1,
+        max_tokens=1024,
+        **kwargs
+    ):
+        self.llm = self.get_llm(max_tokens=max_tokens, **kwargs)
         self.embeddings_function = HuggingFaceEmbeddings(
             model_name=hf_embedding_model,
-            encode_kwargs={"show_progress_bar":show_progress_bar,"batch_size":batch_size}
+            encode_kwargs={
+                "show_progress_bar": show_progress_bar,
+                "batch_size": batch_size,
+            },
         )
 
-
-
     def get_vectorstore(self):
         pass
 
-
     def reformulate(self):
         pass
 
-
     def retrieve(self):
         pass
 
-
     def ask(self):
-        pass
+        pass
climateqa/custom_retrieval_chain.py CHANGED
@@ -29,11 +29,11 @@ from langchain.chains import RetrievalQAWithSourcesChain
 
 from langchain.chains.router.llm_router import LLMRouterChain
 
-class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
-
-    fallback_answer:str = "No sources available to answer this question."
+
+class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
+    fallback_answer: str = "No sources available to answer this question."
 
-    def _call(self,inputs,run_manager=None):
+    def _call(self, inputs, run_manager=None):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         accepts_run_manager = (
             "run_manager" in inspect.signature(self._get_docs).parameters
@@ -43,12 +43,10 @@ class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
         else:
             docs = self._get_docs(inputs)  # type: ignore[call-arg]
 
-
         if len(docs) == 0:
             answer = self.fallback_answer
             sources = []
         else:
-
             answer = self.combine_documents_chain.run(
                 input_documents=docs, callbacks=_run_manager.get_child(), **inputs
             )
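The only behavioral change relative to the stock RetrievalQAWithSourcesChain is the empty-retrieval branch: when the retriever returns no documents, the chain short-circuits with fallback_answer and an empty sources list instead of running the combine-documents chain. A sketch of how that surfaces, where qa_chain and empty_retriever are hypothetical stand-ins:

    from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain

    chain = CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,  # e.g. load_combine_documents_chain(llm)
        retriever=empty_retriever,         # hypothetical: retrieves zero documents
        fallback_answer="No sources available to answer this question.",
    )
    result = chain({"question": "..."})
    # result["answer"] == "No sources available to answer this question."
    # result["sources"] == []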
climateqa/llm.py CHANGED
@@ -1,25 +1,26 @@
 from langchain.chat_models import AzureChatOpenAI
 import os
+
 # LOAD ENVIRONMENT VARIABLES
 try:
     from dotenv import load_dotenv
+
     load_dotenv()
 except:
     pass
 
 
-def get_llm(max_tokens = 1024,temperature = 0.0,verbose = True,streaming = False, **kwargs):
-
+def get_llm(max_tokens=1024, temperature=0.0, verbose=True, streaming=False, **kwargs):
     llm = AzureChatOpenAI(
         openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
         openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
         deployment_name=os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"],
         openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
-        openai_api_type = "azure",
-        max_tokens = max_tokens,
-        temperature = temperature,
-        verbose = verbose,
-        streaming = streaming,
+        openai_api_type="azure",
+        max_tokens=max_tokens,
+        temperature=temperature,
+        verbose=verbose,
+        streaming=streaming,
         **kwargs,
     )
     return llm
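get_llm pulls all of its Azure OpenAI configuration from the environment, optionally loaded from a .env file by python-dotenv. A sketch of the expected setup, with placeholder values to be replaced by your own resource's details:

    import os

    # Placeholders; find these in your Azure OpenAI resource.
    os.environ["AZURE_OPENAI_API_BASE_URL"] = "https://<resource>.openai.azure.com/"
    os.environ["AZURE_OPENAI_API_VERSION"] = "<api-version>"
    os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"] = "<deployment-name>"
    os.environ["AZURE_OPENAI_API_KEY"] = "<key>"

    from climateqa.llm import get_llm

    llm = get_llm(max_tokens=512, temperature=0.0, streaming=False)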
climateqa/logging.py CHANGED
@@ -53,6 +53,7 @@ def get_azure_blob_client():
     share_client = service.get_share_client(file_share_name)
     return share_client
 
+
 if has_blob_config():
     share_client = get_azure_blob_client()
 
climateqa/prompts.py CHANGED
@@ -1,4 +1,3 @@
-
 # If the message is not relevant to climate change (like "How are you", "I am 18 years old" or "When was built the eiffel tower"), return N/A
 
 reformulation_prompt = """
@@ -54,4 +53,4 @@ audience_prompts = {
     "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
     "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
     "experts": "expert and climate scientists that are not afraid of technical terms",
-}
+}
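The body of reformulation_prompt is elided in this diff, but parse_output in chains.py reads the chain's "json" output and returns question and language keys, so the prompt presumably asks the model to reply with a JSON object of that shape. A hypothetical input/output pair for illustration:

    # Hypothetical contract of the reformulation step (exact wording lives in
    # reformulation_prompt, which this hunk does not show):
    #   user query : "c'est quoi la montée des eaux ?"
    #   LLM output : {"question": "What is sea level rise?", "language": "French"}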
climateqa/retriever.py CHANGED
@@ -9,41 +9,48 @@ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from typing import List
 from pydantic import Field
 
+
 class ClimateQARetriever(BaseRetriever):
-    vectorstore:VectorStore
-    sources:list = ["IPCC","IPBES"]
-    threshold:float = 22
-    k_summary:int = 3
-    k_total:int = 10
-    namespace:str = "vectors"
+    vectorstore: VectorStore
+    sources: list = ["IPCC", "IPBES"]
+    threshold: float = 22
+    k_summary: int = 3
+    k_total: int = 10
+    namespace: str = "vectors"
 
     def get_relevant_documents(self, query: str) -> List[Document]:
-
         # Check if all elements in the list are either IPCC or IPBES
-        assert isinstance(self.sources,list)
-        assert all([x in ["IPCC","IPBES"] for x in self.sources])
+        assert isinstance(self.sources, list)
+        assert all([x in ["IPCC", "IPBES"] for x in self.sources])
         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
 
         # Prepare base search kwargs
         filters = {
-            "source": { "$in":self.sources},
+            "source": {"$in": self.sources},
         }
 
         # Search for k_summary documents in the summaries dataset
         filters_summaries = {
             **filters,
-            "report_type": { "$in":["SPM","TS"]},
+            "report_type": {"$in": ["SPM", "TS"]},
         }
-        docs_summaries = self.vectorstore.similarity_search_with_score(query=query,namespace = self.namespace,filter = filters_summaries,k = self.k_summary)
+        docs_summaries = self.vectorstore.similarity_search_with_score(
+            query=query,
+            namespace=self.namespace,
+            filter=filters_summaries,
+            k=self.k_summary,
+        )
         docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
 
         # Search for k_total - k_summary documents in the full reports dataset
         filters_full = {
             **filters,
-            "report_type": { "$nin":["SPM","TS"]},
+            "report_type": {"$nin": ["SPM", "TS"]},
         }
         k_full = self.k_total - len(docs_summaries)
-        docs_full = self.vectorstore.similarity_search_with_score(query=query,namespace = self.namespace,filter = filters_full,k = k_full)
+        docs_full = self.vectorstore.similarity_search_with_score(
+            query=query, namespace=self.namespace, filter=filters_full, k=k_full
+        )
 
         # Concatenate documents
         docs = docs_summaries + docs_full
@@ -53,19 +60,18 @@ class ClimateQARetriever(BaseRetriever):
 
         # Add score to metadata
         results = []
-        for i,(doc,score) in enumerate(docs):
+        for i, (doc, score) in enumerate(docs):
             doc.metadata["similarity_score"] = score
             doc.metadata["content"] = doc.page_content
             doc.metadata["page_number"] = int(doc.metadata["page_number"])
-            doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+            doc.page_content = (
+                f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+            )
             results.append(doc)
 
         return results
 
 
-
-
-
 # def filter_summaries(df,k_summary = 3,k_total = 10):
 # # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
 
@@ -92,8 +98,6 @@ class ClimateQARetriever(BaseRetriever):
 # return passages
 
 
-
-
 # def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
 #     assert max_k > k_total
 
@@ -125,7 +129,6 @@ class ClimateQARetriever(BaseRetriever):
 # return passages_df
 
 
-
 # def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
 
 
@@ -146,4 +149,3 @@ class ClimateQARetriever(BaseRetriever):
 # "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
 # }
 # return response
-
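In short, retrieval is two-staged: up to k_summary passages come from the summary reports (report_type SPM or TS) and are kept only above the similarity threshold, then the remaining slots up to k_total are filled from the full reports. A usage sketch, assuming a populated Pinecone index from vectorstore.py below:

    from climateqa.retriever import ClimateQARetriever

    retriever = ClimateQARetriever(
        vectorstore=vectorstore,  # e.g. get_pinecone_vectorstore(embeddings)
        sources=["IPCC"],         # only IPCC and/or IPBES are accepted
        k_summary=3,              # summaries (SPM/TS) first...
        k_total=10,               # ...then full reports for the remainder
    )
    docs = retriever.get_relevant_documents("How much has the ocean warmed?")
    # Each doc carries similarity_score, content and page_number in its metadata,
    # and its page_content is prefixed "Doc N - <short_name>: ...".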
 
climateqa/vectorstore.py CHANGED
@@ -8,13 +8,13 @@ from langchain.vectorstores import Pinecone
 # LOAD ENVIRONMENT VARIABLES
 try:
     from dotenv import load_dotenv
+
     load_dotenv()
 except:
     pass
 
 
-def get_pinecone_vectorstore(embeddings,text_key = "content"):
-
+def get_pinecone_vectorstore(embeddings, text_key="content"):
     # initialize pinecone
     pinecone.init(
         api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
@@ -22,5 +22,7 @@ def get_pinecone_vectorstore(embeddings,text_key = "content"):
     )
 
     index_name = os.getenv("PINECONE_API_INDEX")
-    vectorstore = Pinecone.from_existing_index(index_name, embeddings,text_key = text_key)
+    vectorstore = Pinecone.from_existing_index(
+        index_name, embeddings, text_key=text_key
+    )
     return vectorstore
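get_pinecone_vectorstore reads PINECONE_API_KEY and PINECONE_API_INDEX from the environment (the pinecone.init environment/region argument sits in the elided hunk context) and wraps the existing index in a LangChain Pinecone vectorstore. A minimal sketch, assuming the same embedding model used in chat.py:

    import os
    from langchain.embeddings import HuggingFaceEmbeddings
    from climateqa.vectorstore import get_pinecone_vectorstore

    os.environ["PINECONE_API_KEY"] = "<your-pinecone-key>"    # placeholder
    os.environ["PINECONE_API_INDEX"] = "<your-index-name>"    # placeholder

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
    )
    vectorstore = get_pinecone_vectorstore(embeddings)  # text_key defaults to "content"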
utils.py CHANGED
@@ -6,7 +6,7 @@ import uuid
 
 def create_user_id():
     """Create user_id
-    str: String to id user
+    str: String to id user
     """
     user_id = str(uuid.uuid4())
-    return user_id
+    return user_id