LOUIS SANNA commited on
Commit
cde6d5c
1 Parent(s): dfcff8d

feat(code): step down rule

Browse files
Files changed (1) hide show
  1. climateqa/chains.py +57 -57
climateqa/chains.py CHANGED
@@ -7,52 +7,10 @@ from langchain.chains import QAWithSourcesChain
7
  from langchain.chains import TransformChain, SequentialChain
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
10
- from climateqa.prompts import answer_prompt, reformulation_prompt, audience_prompts
11
  from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
12
 
13
 
14
- def load_reformulation_chain(llm):
15
- prompt = PromptTemplate(
16
- template=reformulation_prompt,
17
- input_variables=["query"],
18
- )
19
- reformulation_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")
20
-
21
- # Parse the output
22
- def parse_output(output):
23
- query = output["query"]
24
- print("output", output)
25
- json_output = json.loads(output["json"])
26
- question = json_output.get("question", query)
27
- language = json_output.get("language", "English")
28
- return {
29
- "question": question,
30
- "language": language,
31
- }
32
-
33
- transform_chain = TransformChain(
34
- input_variables=["json"],
35
- output_variables=["question", "language"],
36
- transform=parse_output,
37
- )
38
-
39
- reformulation_chain = SequentialChain(
40
- chains=[reformulation_chain, transform_chain],
41
- input_variables=["query"],
42
- output_variables=["question", "language"],
43
- )
44
- return reformulation_chain
45
-
46
-
47
- def load_combine_documents_chain(llm):
48
- prompt = PromptTemplate(
49
- template=answer_prompt,
50
- input_variables=["summaries", "question", "audience", "language"],
51
- )
52
- qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
53
- return qa_chain
54
-
55
-
56
  def load_qa_chain_with_docs(llm):
57
  """Load a QA chain with documents.
58
  Useful when you already have retrieved docs
@@ -78,6 +36,15 @@ def load_qa_chain_with_docs(llm):
78
  return chain
79
 
80
 
 
 
 
 
 
 
 
 
 
81
  def load_qa_chain_with_text(llm):
82
  prompt = PromptTemplate(
83
  template=answer_prompt,
@@ -87,6 +54,53 @@ def load_qa_chain_with_text(llm):
87
  return qa_chain
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def load_qa_chain_with_retriever(retriever, llm):
91
  qa_chain = load_combine_documents_chain(llm)
92
 
@@ -101,17 +115,3 @@ def load_qa_chain_with_retriever(retriever, llm):
101
  fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
102
  )
103
  return answer_chain
104
-
105
-
106
- def load_climateqa_chain(retriever, llm_reformulation, llm_answer):
107
- reformulation_chain = load_reformulation_chain(llm_reformulation)
108
- answer_chain = load_qa_chain_with_retriever(retriever, llm_answer)
109
-
110
- climateqa_chain = SequentialChain(
111
- chains=[reformulation_chain, answer_chain],
112
- input_variables=["query", "audience"],
113
- output_variables=["answer", "question", "language", "source_documents"],
114
- return_all=True,
115
- verbose=True,
116
- )
117
- return climateqa_chain
 
7
  from langchain.chains import TransformChain, SequentialChain
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
10
+ from climateqa.prompts import answer_prompt, reformulation_prompt
11
  from climateqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def load_qa_chain_with_docs(llm):
15
  """Load a QA chain with documents.
16
  Useful when you already have retrieved docs
 
36
  return chain
37
 
38
 
39
+ def load_combine_documents_chain(llm):
40
+ prompt = PromptTemplate(
41
+ template=answer_prompt,
42
+ input_variables=["summaries", "question", "audience", "language"],
43
+ )
44
+ qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
45
+ return qa_chain
46
+
47
+
48
  def load_qa_chain_with_text(llm):
49
  prompt = PromptTemplate(
50
  template=answer_prompt,
 
54
  return qa_chain
55
 
56
 
57
+ def load_climateqa_chain(retriever, llm_reformulation, llm_answer):
58
+ reformulation_chain = load_reformulation_chain(llm_reformulation)
59
+ answer_chain = load_qa_chain_with_retriever(retriever, llm_answer)
60
+
61
+ climateqa_chain = SequentialChain(
62
+ chains=[reformulation_chain, answer_chain],
63
+ input_variables=["query", "audience"],
64
+ output_variables=["answer", "question", "language", "source_documents"],
65
+ return_all=True,
66
+ verbose=True,
67
+ )
68
+ return climateqa_chain
69
+
70
+
71
+ def load_reformulation_chain(llm):
72
+ prompt = PromptTemplate(
73
+ template=reformulation_prompt,
74
+ input_variables=["query"],
75
+ )
76
+ reformulation_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")
77
+
78
+ # Parse the output
79
+ def parse_output(output):
80
+ query = output["query"]
81
+ print("output", output)
82
+ json_output = json.loads(output["json"])
83
+ question = json_output.get("question", query)
84
+ language = json_output.get("language", "English")
85
+ return {
86
+ "question": question,
87
+ "language": language,
88
+ }
89
+
90
+ transform_chain = TransformChain(
91
+ input_variables=["json"],
92
+ output_variables=["question", "language"],
93
+ transform=parse_output,
94
+ )
95
+
96
+ reformulation_chain = SequentialChain(
97
+ chains=[reformulation_chain, transform_chain],
98
+ input_variables=["query"],
99
+ output_variables=["question", "language"],
100
+ )
101
+ return reformulation_chain
102
+
103
+
104
  def load_qa_chain_with_retriever(retriever, llm):
105
  qa_chain = load_combine_documents_chain(llm)
106
 
 
115
  fallback_answer="**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**",
116
  )
117
  return answer_chain