Spaces:
Runtime error
Runtime error
LOUIS SANNA
commited on
Commit
•
6e28a81
1
Parent(s):
35c9187
feat(*)
Browse files- climateqa/chains.py +37 -30
- climateqa/chat.py +14 -11
- climateqa/custom_retrieval_chain.py +3 -5
- climateqa/llm.py +8 -7
- climateqa/logging.py +1 -0
- climateqa/prompts.py +1 -2
- climateqa/retriever.py +25 -23
- climateqa/vectorstore.py +5 -3
- utils.py +2 -2
climateqa/chains.py
CHANGED
@@ -3,20 +3,20 @@
|
|
3 |
import json
|
4 |
|
5 |
from langchain import PromptTemplate, LLMChain
|
6 |
-
from langchain.chains import RetrievalQAWithSourcesChain,QAWithSourcesChain
|
7 |
from langchain.chains import TransformChain, SequentialChain
|
8 |
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
9 |
|
10 |
-
from climateqa.prompts import answer_prompt, reformulation_prompt,audience_prompts
|
11 |
from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
|
12 |
|
13 |
-
def load_reformulation_chain(llm):
|
14 |
|
|
|
15 |
prompt = PromptTemplate(
|
16 |
-
template
|
17 |
input_variables=["query"],
|
18 |
)
|
19 |
-
reformulation_chain = LLMChain(llm
|
20 |
|
21 |
# Parse the output
|
22 |
def parse_output(output):
|
@@ -28,20 +28,30 @@ def load_reformulation_chain(llm):
|
|
28 |
"question": question,
|
29 |
"language": language,
|
30 |
}
|
31 |
-
|
32 |
transform_chain = TransformChain(
|
33 |
-
input_variables=["json"],
|
|
|
|
|
34 |
)
|
35 |
|
36 |
-
reformulation_chain = SequentialChain(
|
|
|
|
|
|
|
|
|
37 |
return reformulation_chain
|
38 |
|
39 |
|
40 |
def load_combine_documents_chain(llm):
|
41 |
-
prompt = PromptTemplate(
|
42 |
-
|
|
|
|
|
|
|
43 |
return qa_chain
|
44 |
|
|
|
45 |
def load_qa_chain_with_docs(llm):
|
46 |
"""Load a QA chain with documents.
|
47 |
Useful when you already have retrieved docs
|
@@ -60,50 +70,47 @@ def load_qa_chain_with_docs(llm):
|
|
60 |
|
61 |
qa_chain = load_combine_documents_chain(llm)
|
62 |
chain = QAWithSourcesChain(
|
63 |
-
input_docs_key
|
64 |
-
combine_documents_chain
|
65 |
-
return_source_documents
|
66 |
)
|
67 |
return chain
|
68 |
|
69 |
|
70 |
def load_qa_chain_with_text(llm):
|
71 |
-
|
72 |
prompt = PromptTemplate(
|
73 |
-
template
|
74 |
-
input_variables=["question","audience","language","summaries"],
|
75 |
)
|
76 |
-
qa_chain = LLMChain(llm
|
77 |
return qa_chain
|
78 |
|
79 |
|
80 |
-
def load_qa_chain_with_retriever(retriever,llm):
|
81 |
qa_chain = load_combine_documents_chain(llm)
|
82 |
|
83 |
# This could be improved by providing a document prompt to avoid modifying page_content in the docs
|
84 |
# See here https://github.com/langchain-ai/langchain/issues/3523
|
85 |
|
86 |
answer_chain = CustomRetrievalQAWithSourcesChain(
|
87 |
-
combine_documents_chain
|
88 |
retriever=retriever,
|
89 |
-
return_source_documents
|
90 |
-
verbose
|
91 |
fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
|
92 |
)
|
93 |
return answer_chain
|
94 |
|
95 |
|
96 |
-
def load_climateqa_chain(retriever,llm_reformulation,llm_answer):
|
97 |
-
|
98 |
reformulation_chain = load_reformulation_chain(llm_reformulation)
|
99 |
-
answer_chain = load_qa_chain_with_retriever(retriever,llm_answer)
|
100 |
|
101 |
climateqa_chain = SequentialChain(
|
102 |
-
chains
|
103 |
-
input_variables=["query","audience"],
|
104 |
-
output_variables=["answer","question","language","source_documents"],
|
105 |
-
return_all
|
106 |
-
verbose
|
107 |
)
|
108 |
return climateqa_chain
|
109 |
-
|
|
|
3 |
import json
|
4 |
|
5 |
from langchain import PromptTemplate, LLMChain
|
6 |
+
from langchain.chains import RetrievalQAWithSourcesChain, QAWithSourcesChain
|
7 |
from langchain.chains import TransformChain, SequentialChain
|
8 |
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
9 |
|
10 |
+
from climateqa.prompts import answer_prompt, reformulation_prompt, audience_prompts
|
11 |
from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
|
12 |
|
|
|
13 |
|
14 |
+
def load_reformulation_chain(llm):
|
15 |
prompt = PromptTemplate(
|
16 |
+
template=reformulation_prompt,
|
17 |
input_variables=["query"],
|
18 |
)
|
19 |
+
reformulation_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")
|
20 |
|
21 |
# Parse the output
|
22 |
def parse_output(output):
|
|
|
28 |
"question": question,
|
29 |
"language": language,
|
30 |
}
|
31 |
+
|
32 |
transform_chain = TransformChain(
|
33 |
+
input_variables=["json"],
|
34 |
+
output_variables=["question", "language"],
|
35 |
+
transform=parse_output,
|
36 |
)
|
37 |
|
38 |
+
reformulation_chain = SequentialChain(
|
39 |
+
chains=[reformulation_chain, transform_chain],
|
40 |
+
input_variables=["query"],
|
41 |
+
output_variables=["question", "language"],
|
42 |
+
)
|
43 |
return reformulation_chain
|
44 |
|
45 |
|
46 |
def load_combine_documents_chain(llm):
|
47 |
+
prompt = PromptTemplate(
|
48 |
+
template=answer_prompt,
|
49 |
+
input_variables=["summaries", "question", "audience", "language"],
|
50 |
+
)
|
51 |
+
qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
|
52 |
return qa_chain
|
53 |
|
54 |
+
|
55 |
def load_qa_chain_with_docs(llm):
|
56 |
"""Load a QA chain with documents.
|
57 |
Useful when you already have retrieved docs
|
|
|
70 |
|
71 |
qa_chain = load_combine_documents_chain(llm)
|
72 |
chain = QAWithSourcesChain(
|
73 |
+
input_docs_key="docs",
|
74 |
+
combine_documents_chain=qa_chain,
|
75 |
+
return_source_documents=True,
|
76 |
)
|
77 |
return chain
|
78 |
|
79 |
|
80 |
def load_qa_chain_with_text(llm):
|
|
|
81 |
prompt = PromptTemplate(
|
82 |
+
template=answer_prompt,
|
83 |
+
input_variables=["question", "audience", "language", "summaries"],
|
84 |
)
|
85 |
+
qa_chain = LLMChain(llm=llm, prompt=prompt)
|
86 |
return qa_chain
|
87 |
|
88 |
|
89 |
+
def load_qa_chain_with_retriever(retriever, llm):
|
90 |
qa_chain = load_combine_documents_chain(llm)
|
91 |
|
92 |
# This could be improved by providing a document prompt to avoid modifying page_content in the docs
|
93 |
# See here https://github.com/langchain-ai/langchain/issues/3523
|
94 |
|
95 |
answer_chain = CustomRetrievalQAWithSourcesChain(
|
96 |
+
combine_documents_chain=qa_chain,
|
97 |
retriever=retriever,
|
98 |
+
return_source_documents=True,
|
99 |
+
verbose=True,
|
100 |
fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
|
101 |
)
|
102 |
return answer_chain
|
103 |
|
104 |
|
105 |
+
def load_climateqa_chain(retriever, llm_reformulation, llm_answer):
|
|
|
106 |
reformulation_chain = load_reformulation_chain(llm_reformulation)
|
107 |
+
answer_chain = load_qa_chain_with_retriever(retriever, llm_answer)
|
108 |
|
109 |
climateqa_chain = SequentialChain(
|
110 |
+
chains=[reformulation_chain, answer_chain],
|
111 |
+
input_variables=["query", "audience"],
|
112 |
+
output_variables=["answer", "question", "language", "source_documents"],
|
113 |
+
return_all=True,
|
114 |
+
verbose=True,
|
115 |
)
|
116 |
return climateqa_chain
|
|
climateqa/chat.py
CHANGED
@@ -12,28 +12,31 @@ from climateqa.chains import load_climateqa_chain
|
|
12 |
|
13 |
|
14 |
class ClimateQA:
|
15 |
-
def __init__(
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
self.embeddings_function = HuggingFaceEmbeddings(
|
20 |
model_name=hf_embedding_model,
|
21 |
-
encode_kwargs={
|
|
|
|
|
|
|
22 |
)
|
23 |
|
24 |
-
|
25 |
-
|
26 |
def get_vectorstore(self):
|
27 |
pass
|
28 |
|
29 |
-
|
30 |
def reformulate(self):
|
31 |
pass
|
32 |
|
33 |
-
|
34 |
def retrieve(self):
|
35 |
pass
|
36 |
|
37 |
-
|
38 |
def ask(self):
|
39 |
-
pass
|
|
|
12 |
|
13 |
|
14 |
class ClimateQA:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
hf_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
|
18 |
+
show_progress_bar=False,
|
19 |
+
batch_size=1,
|
20 |
+
max_tokens=1024,
|
21 |
+
**kwargs
|
22 |
+
):
|
23 |
+
self.llm = self.get_llm(max_tokens=max_tokens, **kwargs)
|
24 |
self.embeddings_function = HuggingFaceEmbeddings(
|
25 |
model_name=hf_embedding_model,
|
26 |
+
encode_kwargs={
|
27 |
+
"show_progress_bar": show_progress_bar,
|
28 |
+
"batch_size": batch_size,
|
29 |
+
},
|
30 |
)
|
31 |
|
|
|
|
|
32 |
def get_vectorstore(self):
|
33 |
pass
|
34 |
|
|
|
35 |
def reformulate(self):
|
36 |
pass
|
37 |
|
|
|
38 |
def retrieve(self):
|
39 |
pass
|
40 |
|
|
|
41 |
def ask(self):
|
42 |
+
pass
|
climateqa/custom_retrieval_chain.py
CHANGED
@@ -29,11 +29,11 @@ from langchain.chains import RetrievalQAWithSourcesChain
|
|
29 |
|
30 |
from langchain.chains.router.llm_router import LLMRouterChain
|
31 |
|
32 |
-
class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
|
33 |
|
34 |
-
|
|
|
35 |
|
36 |
-
def _call(self,inputs,run_manager=None):
|
37 |
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
38 |
accepts_run_manager = (
|
39 |
"run_manager" in inspect.signature(self._get_docs).parameters
|
@@ -43,12 +43,10 @@ class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
|
|
43 |
else:
|
44 |
docs = self._get_docs(inputs) # type: ignore[call-arg]
|
45 |
|
46 |
-
|
47 |
if len(docs) == 0:
|
48 |
answer = self.fallback_answer
|
49 |
sources = []
|
50 |
else:
|
51 |
-
|
52 |
answer = self.combine_documents_chain.run(
|
53 |
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
54 |
)
|
|
|
29 |
|
30 |
from langchain.chains.router.llm_router import LLMRouterChain
|
31 |
|
|
|
32 |
|
33 |
+
class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
|
34 |
+
fallback_answer: str = "No sources available to answer this question."
|
35 |
|
36 |
+
def _call(self, inputs, run_manager=None):
|
37 |
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
38 |
accepts_run_manager = (
|
39 |
"run_manager" in inspect.signature(self._get_docs).parameters
|
|
|
43 |
else:
|
44 |
docs = self._get_docs(inputs) # type: ignore[call-arg]
|
45 |
|
|
|
46 |
if len(docs) == 0:
|
47 |
answer = self.fallback_answer
|
48 |
sources = []
|
49 |
else:
|
|
|
50 |
answer = self.combine_documents_chain.run(
|
51 |
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
52 |
)
|
climateqa/llm.py
CHANGED
@@ -1,25 +1,26 @@
|
|
1 |
from langchain.chat_models import AzureChatOpenAI
|
2 |
import os
|
|
|
3 |
# LOAD ENVIRONMENT VARIABLES
|
4 |
try:
|
5 |
from dotenv import load_dotenv
|
|
|
6 |
load_dotenv()
|
7 |
except:
|
8 |
pass
|
9 |
|
10 |
|
11 |
-
def get_llm(max_tokens
|
12 |
-
|
13 |
llm = AzureChatOpenAI(
|
14 |
openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
|
15 |
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
|
16 |
deployment_name=os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"],
|
17 |
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
|
18 |
-
openai_api_type
|
19 |
-
max_tokens
|
20 |
-
temperature
|
21 |
-
verbose
|
22 |
-
streaming
|
23 |
**kwargs,
|
24 |
)
|
25 |
return llm
|
|
|
1 |
from langchain.chat_models import AzureChatOpenAI
|
2 |
import os
|
3 |
+
|
4 |
# LOAD ENVIRONMENT VARIABLES
|
5 |
try:
|
6 |
from dotenv import load_dotenv
|
7 |
+
|
8 |
load_dotenv()
|
9 |
except:
|
10 |
pass
|
11 |
|
12 |
|
13 |
+
def get_llm(max_tokens=1024, temperature=0.0, verbose=True, streaming=False, **kwargs):
|
|
|
14 |
llm = AzureChatOpenAI(
|
15 |
openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
|
16 |
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
|
17 |
deployment_name=os.environ["AZURE_OPENAI_API_DEPLOYMENT_NAME"],
|
18 |
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
|
19 |
+
openai_api_type="azure",
|
20 |
+
max_tokens=max_tokens,
|
21 |
+
temperature=temperature,
|
22 |
+
verbose=verbose,
|
23 |
+
streaming=streaming,
|
24 |
**kwargs,
|
25 |
)
|
26 |
return llm
|
climateqa/logging.py
CHANGED
@@ -53,6 +53,7 @@ def get_azure_blob_client():
|
|
53 |
share_client = service.get_share_client(file_share_name)
|
54 |
return share_client
|
55 |
|
|
|
56 |
if has_blob_config():
|
57 |
share_client = get_azure_blob_client()
|
58 |
|
|
|
53 |
share_client = service.get_share_client(file_share_name)
|
54 |
return share_client
|
55 |
|
56 |
+
|
57 |
if has_blob_config():
|
58 |
share_client = get_azure_blob_client()
|
59 |
|
climateqa/prompts.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
# If the message is not relevant to climate change (like "How are you", "I am 18 years old" or "When was built the eiffel tower"), return N/A
|
3 |
|
4 |
reformulation_prompt = """
|
@@ -54,4 +53,4 @@ audience_prompts = {
|
|
54 |
"children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
|
55 |
"general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
|
56 |
"experts": "expert and climate scientists that are not afraid of technical terms",
|
57 |
-
}
|
|
|
|
|
1 |
# If the message is not relevant to climate change (like "How are you", "I am 18 years old" or "When was built the eiffel tower"), return N/A
|
2 |
|
3 |
reformulation_prompt = """
|
|
|
53 |
"children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
|
54 |
"general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
|
55 |
"experts": "expert and climate scientists that are not afraid of technical terms",
|
56 |
+
}
|
climateqa/retriever.py
CHANGED
@@ -9,41 +9,48 @@ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
|
9 |
from typing import List
|
10 |
from pydantic import Field
|
11 |
|
|
|
12 |
class ClimateQARetriever(BaseRetriever):
|
13 |
-
vectorstore:VectorStore
|
14 |
-
sources:list = ["IPCC","IPBES"]
|
15 |
-
threshold:float = 22
|
16 |
-
k_summary:int = 3
|
17 |
-
k_total:int = 10
|
18 |
-
namespace:str = "vectors"
|
19 |
|
20 |
def get_relevant_documents(self, query: str) -> List[Document]:
|
21 |
-
|
22 |
# Check if all elements in the list are either IPCC or IPBES
|
23 |
-
assert isinstance(self.sources,list)
|
24 |
-
assert all([x in ["IPCC","IPBES"] for x in self.sources])
|
25 |
assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
|
26 |
|
27 |
# Prepare base search kwargs
|
28 |
filters = {
|
29 |
-
"source": {
|
30 |
}
|
31 |
|
32 |
# Search for k_summary documents in the summaries dataset
|
33 |
filters_summaries = {
|
34 |
**filters,
|
35 |
-
"report_type": {
|
36 |
}
|
37 |
-
docs_summaries = self.vectorstore.similarity_search_with_score(
|
|
|
|
|
|
|
|
|
|
|
38 |
docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
|
39 |
|
40 |
# Search for k_total - k_summary documents in the full reports dataset
|
41 |
filters_full = {
|
42 |
**filters,
|
43 |
-
"report_type": {
|
44 |
}
|
45 |
k_full = self.k_total - len(docs_summaries)
|
46 |
-
docs_full = self.vectorstore.similarity_search_with_score(
|
|
|
|
|
47 |
|
48 |
# Concatenate documents
|
49 |
docs = docs_summaries + docs_full
|
@@ -53,19 +60,18 @@ class ClimateQARetriever(BaseRetriever):
|
|
53 |
|
54 |
# Add score to metadata
|
55 |
results = []
|
56 |
-
for i,(doc,score) in enumerate(docs):
|
57 |
doc.metadata["similarity_score"] = score
|
58 |
doc.metadata["content"] = doc.page_content
|
59 |
doc.metadata["page_number"] = int(doc.metadata["page_number"])
|
60 |
-
doc.page_content =
|
|
|
|
|
61 |
results.append(doc)
|
62 |
|
63 |
return results
|
64 |
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
# def filter_summaries(df,k_summary = 3,k_total = 10):
|
70 |
# # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
|
71 |
|
@@ -92,8 +98,6 @@ class ClimateQARetriever(BaseRetriever):
|
|
92 |
# return passages
|
93 |
|
94 |
|
95 |
-
|
96 |
-
|
97 |
# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
|
98 |
# assert max_k > k_total
|
99 |
|
@@ -125,7 +129,6 @@ class ClimateQARetriever(BaseRetriever):
|
|
125 |
# return passages_df
|
126 |
|
127 |
|
128 |
-
|
129 |
# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
|
130 |
|
131 |
|
@@ -146,4 +149,3 @@ class ClimateQARetriever(BaseRetriever):
|
|
146 |
# "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
|
147 |
# }
|
148 |
# return response
|
149 |
-
|
|
|
9 |
from typing import List
|
10 |
from pydantic import Field
|
11 |
|
12 |
+
|
13 |
class ClimateQARetriever(BaseRetriever):
|
14 |
+
vectorstore: VectorStore
|
15 |
+
sources: list = ["IPCC", "IPBES"]
|
16 |
+
threshold: float = 22
|
17 |
+
k_summary: int = 3
|
18 |
+
k_total: int = 10
|
19 |
+
namespace: str = "vectors"
|
20 |
|
21 |
def get_relevant_documents(self, query: str) -> List[Document]:
|
|
|
22 |
# Check if all elements in the list are either IPCC or IPBES
|
23 |
+
assert isinstance(self.sources, list)
|
24 |
+
assert all([x in ["IPCC", "IPBES"] for x in self.sources])
|
25 |
assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
|
26 |
|
27 |
# Prepare base search kwargs
|
28 |
filters = {
|
29 |
+
"source": {"$in": self.sources},
|
30 |
}
|
31 |
|
32 |
# Search for k_summary documents in the summaries dataset
|
33 |
filters_summaries = {
|
34 |
**filters,
|
35 |
+
"report_type": {"$in": ["SPM", "TS"]},
|
36 |
}
|
37 |
+
docs_summaries = self.vectorstore.similarity_search_with_score(
|
38 |
+
query=query,
|
39 |
+
namespace=self.namespace,
|
40 |
+
filter=filters_summaries,
|
41 |
+
k=self.k_summary,
|
42 |
+
)
|
43 |
docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
|
44 |
|
45 |
# Search for k_total - k_summary documents in the full reports dataset
|
46 |
filters_full = {
|
47 |
**filters,
|
48 |
+
"report_type": {"$nin": ["SPM", "TS"]},
|
49 |
}
|
50 |
k_full = self.k_total - len(docs_summaries)
|
51 |
+
docs_full = self.vectorstore.similarity_search_with_score(
|
52 |
+
query=query, namespace=self.namespace, filter=filters_full, k=k_full
|
53 |
+
)
|
54 |
|
55 |
# Concatenate documents
|
56 |
docs = docs_summaries + docs_full
|
|
|
60 |
|
61 |
# Add score to metadata
|
62 |
results = []
|
63 |
+
for i, (doc, score) in enumerate(docs):
|
64 |
doc.metadata["similarity_score"] = score
|
65 |
doc.metadata["content"] = doc.page_content
|
66 |
doc.metadata["page_number"] = int(doc.metadata["page_number"])
|
67 |
+
doc.page_content = (
|
68 |
+
f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
|
69 |
+
)
|
70 |
results.append(doc)
|
71 |
|
72 |
return results
|
73 |
|
74 |
|
|
|
|
|
|
|
75 |
# def filter_summaries(df,k_summary = 3,k_total = 10):
|
76 |
# # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
|
77 |
|
|
|
98 |
# return passages
|
99 |
|
100 |
|
|
|
|
|
101 |
# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
|
102 |
# assert max_k > k_total
|
103 |
|
|
|
129 |
# return passages_df
|
130 |
|
131 |
|
|
|
132 |
# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
|
133 |
|
134 |
|
|
|
149 |
# "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
|
150 |
# }
|
151 |
# return response
|
|
climateqa/vectorstore.py
CHANGED
@@ -8,13 +8,13 @@ from langchain.vectorstores import Pinecone
|
|
8 |
# LOAD ENVIRONMENT VARIABLES
|
9 |
try:
|
10 |
from dotenv import load_dotenv
|
|
|
11 |
load_dotenv()
|
12 |
except:
|
13 |
pass
|
14 |
|
15 |
|
16 |
-
def get_pinecone_vectorstore(embeddings,text_key
|
17 |
-
|
18 |
# initialize pinecone
|
19 |
pinecone.init(
|
20 |
api_key=os.getenv("PINECONE_API_KEY"), # find at app.pinecone.io
|
@@ -22,5 +22,7 @@ def get_pinecone_vectorstore(embeddings,text_key = "content"):
|
|
22 |
)
|
23 |
|
24 |
index_name = os.getenv("PINECONE_API_INDEX")
|
25 |
-
vectorstore = Pinecone.from_existing_index(
|
|
|
|
|
26 |
return vectorstore
|
|
|
8 |
# LOAD ENVIRONMENT VARIABLES
|
9 |
try:
|
10 |
from dotenv import load_dotenv
|
11 |
+
|
12 |
load_dotenv()
|
13 |
except:
|
14 |
pass
|
15 |
|
16 |
|
17 |
+
def get_pinecone_vectorstore(embeddings, text_key="content"):
|
|
|
18 |
# initialize pinecone
|
19 |
pinecone.init(
|
20 |
api_key=os.getenv("PINECONE_API_KEY"), # find at app.pinecone.io
|
|
|
22 |
)
|
23 |
|
24 |
index_name = os.getenv("PINECONE_API_INDEX")
|
25 |
+
vectorstore = Pinecone.from_existing_index(
|
26 |
+
index_name, embeddings, text_key=text_key
|
27 |
+
)
|
28 |
return vectorstore
|
utils.py
CHANGED
@@ -6,7 +6,7 @@ import uuid
|
|
6 |
|
7 |
def create_user_id():
|
8 |
"""Create user_id
|
9 |
-
|
10 |
"""
|
11 |
user_id = str(uuid.uuid4())
|
12 |
-
return user_id
|
|
|
6 |
|
7 |
def create_user_id():
|
8 |
"""Create user_id
|
9 |
+
str: String to id user
|
10 |
"""
|
11 |
user_id = str(uuid.uuid4())
|
12 |
+
return user_id
|