anakib1 commited on
Commit
bc332e6
1 Parent(s): bed851f

Added answer options

Browse files
Files changed (4) hide show
  1. app.py +9 -7
  2. src/chains.py +75 -9
  3. src/clients.py +4 -3
  4. src/complex.ipynb +91 -5
app.py CHANGED
@@ -6,26 +6,28 @@ load_dotenv()
6
  client = AcademicClient()
7
 
8
 
9
- def perform_qa(query):
10
- return client.answer(query)
11
 
12
 
13
  css = """
14
  body {
15
- align-items: center;
16
  display:block;
17
  }
18
  """
19
 
20
  with gr.Blocks(css=css) as demo:
21
- gr.Markdown('Wisdom.AI'),
22
  gr.Image('misc/wisdom.jpg', height=600, width=400)
23
  with gr.Row():
24
- inp = gr.Textbox('Що б ви хотіли дізнатися у мудрого?', label='Питання')
25
- out = gr.Textbox('Мудрий каже...', label='Відповідь')
 
 
26
 
27
  btn = gr.Button('Спитати')
28
- btn.click(fn=perform_qa, inputs=inp, outputs=out)
29
 
30
  if __name__ == "__main__":
31
  demo.launch()
 
6
  client = AcademicClient()
7
 
8
 
9
+ def perform_qa(query: str, options: str) -> str:
10
+ return client.answer(query, options.split('\n'))
11
 
12
 
13
  css = """
14
  body {
15
+ image-align: center;
16
  display:block;
17
  }
18
  """
19
 
20
  with gr.Blocks(css=css) as demo:
21
+ gr.Markdown('# Wisdom.AI'),
22
  gr.Image('misc/wisdom.jpg', height=600, width=400)
23
  with gr.Row():
24
+ inp = gr.Textbox('Що б ви хотіли дізнатися у мудрого?', label='Питання', min_width=400)
25
+ out = gr.Textbox('Мудрий каже...', label='Відповідь', min_width=400)
26
+
27
+ options = gr.Textbox(label='Варіанти відповіді:', min_width=800)
28
 
29
  btn = gr.Button('Спитати')
30
+ btn.click(fn=perform_qa, inputs=[inp, options], outputs=out)
31
 
32
  if __name__ == "__main__":
33
  demo.launch()
src/chains.py CHANGED
@@ -1,31 +1,71 @@
1
  from langchain_core.output_parsers import StrOutputParser
2
- from langchain_core.runnables import RunnablePassthrough, RunnableLambda
3
 
4
  from langchain_openai import ChatOpenAI
5
  from langchain.prompts import PromptTemplate
6
  from langchain_community.utilities import GoogleSerperAPIWrapper
 
 
 
7
 
8
  CUSTOM_RAG_PROMPT = """
9
- Використай наступні **надійні** елементи, для того, щоб відповісти на питання в кінці.
10
- Якщо вони не містять відповіді, зверни увагу на відповідь з інтернету, хоча вона може бути не надійною.
11
  Якщо ти не знаєш відповіді, використаши всі свої джерела, то просто скажи про це, не потрібно вигадувати відповідь.
12
- Використовуй не більше трьох речень, та намагайся відповісти коротко та чітко.
 
13
 
14
  {context}
15
 
16
  Відповідь з інтернету: {internet}
17
 
 
 
 
 
 
 
 
 
 
 
 
18
  Питання: {question}
19
 
20
- Корисна відповідь:"""
 
 
 
21
 
22
  CUSTOM_RAG_PROMPT = PromptTemplate.from_template(CUSTOM_RAG_PROMPT)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def documents_parser(docs):
26
  return "\n\n".join(doc.page_content for doc in docs)
27
 
28
 
 
 
 
 
 
 
 
 
29
  class PdfAndGoogleChain:
30
 
31
  def use_google_search(self, query):
@@ -34,17 +74,43 @@ class PdfAndGoogleChain:
34
  except Exception as ex:
35
  return 'NONE'
36
 
 
 
 
 
 
 
 
 
37
  def __init__(self, retriever, llm_name: str = "gpt-3.5-turbo-0125"):
38
  self.search = GoogleSerperAPIWrapper()
 
39
  self.llm = ChatOpenAI(model=llm_name)
40
 
41
  self.rag_chain = (
42
- {"context": retriever | documents_parser, "internet": RunnableLambda(self.use_google_search),
43
- "question": RunnablePassthrough()}
 
 
44
  | CUSTOM_RAG_PROMPT
45
  | self.llm
46
  | StrOutputParser()
47
  )
48
 
49
- def answer(self, query: str):
50
- return self.rag_chain.invoke(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain_core.output_parsers import StrOutputParser
2
+ from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel
3
 
4
  from langchain_openai import ChatOpenAI
5
  from langchain.prompts import PromptTemplate
6
  from langchain_community.utilities import GoogleSerperAPIWrapper
7
+ from typing import List
8
+
9
+ from operator import itemgetter
10
 
11
  CUSTOM_RAG_PROMPT = """
12
+ Використай наступні **надійні** елементи, для того, щоб вибрати відповідь на питання з запропонованих.
13
+ Якщо вони не містять відповіді, зверни увагу на інформацію з інтернету.
14
  Якщо ти не знаєш відповіді, використаши всі свої джерела, то просто скажи про це, не потрібно вигадувати відповідь.
15
+ Напиши у відповіді номер правильного варіанту відповіді. Якщо серед варантів немає правильної відповіді, напиши коротко відповідь самостійно.
16
+
17
 
18
  {context}
19
 
20
  Відповідь з інтернету: {internet}
21
 
22
+ Приклад:
23
+
24
+ Чия типологія поділяється на традиційні, харизматичні й раціональні системи?
25
+ 1) Вебер
26
+ 2) Ленін
27
+ 3) Сталін
28
+ 4) Обама
29
+
30
+ Правильна відповідь: 1 - Вебер.
31
+
32
+
33
  Питання: {question}
34
 
35
+ Варіанти відповіді:
36
+ {options}
37
+
38
+ Правильна відповідь:"""
39
 
40
  CUSTOM_RAG_PROMPT = PromptTemplate.from_template(CUSTOM_RAG_PROMPT)
41
 
42
+ VERIFICATION_PROMPT = """
43
+ Вам було задано наступне питання:
44
+ {question}
45
+ З варіантами відповіді:
46
+ {options}
47
+ На яку було запропоновано відповідь:
48
+ {answer}
49
+
50
+ Повторіть відповідь, якщо вона правильна. Інакше, скажіть "відповідь відсутня".
51
+
52
+ Відповідь:
53
+ """
54
+ VERIFICATION_PROMPT = PromptTemplate.from_template(VERIFICATION_PROMPT)
55
+
56
 
57
  def documents_parser(docs):
58
  return "\n\n".join(doc.page_content for doc in docs)
59
 
60
 
61
+ def prepare_options(options):
62
+ return "\n".join([f"{i + 1}) {option}" for i, option in enumerate(options)])
63
+
64
+ def flatten_input(d):
65
+ ret = d.pop('input')
66
+ ret.update(d)
67
+ return ret
68
+
69
  class PdfAndGoogleChain:
70
 
71
  def use_google_search(self, query):
 
74
  except Exception as ex:
75
  return 'NONE'
76
 
77
+ def retrieve_multiple(self, query_dict):
78
+ query = query_dict['query']
79
+ options = query_dict['options']
80
+ ret = self.retriever.get_relevant_documents(query)
81
+ for option in options:
82
+ ret.extend(self.retriever.get_relevant_documents(option)[:2])
83
+ return ret
84
+
85
  def __init__(self, retriever, llm_name: str = "gpt-3.5-turbo-0125"):
86
  self.search = GoogleSerperAPIWrapper()
87
+ self.retriever = retriever
88
  self.llm = ChatOpenAI(model=llm_name)
89
 
90
  self.rag_chain = (
91
+ {"context": RunnableLambda(self.retrieve_multiple) | documents_parser,
92
+ "internet": itemgetter("query") | RunnableLambda(self.use_google_search),
93
+ "question": itemgetter("query") | RunnablePassthrough(),
94
+ "options": itemgetter("options") | RunnableLambda(prepare_options)}
95
  | CUSTOM_RAG_PROMPT
96
  | self.llm
97
  | StrOutputParser()
98
  )
99
 
100
+ self.verification_chain = (
101
+ {"question": itemgetter("query") | RunnablePassthrough(),
102
+ "options": itemgetter("options") | RunnableLambda(prepare_options),
103
+ "answer": itemgetter("answer") | RunnablePassthrough()}
104
+ | VERIFICATION_PROMPT
105
+ | self.llm
106
+ | StrOutputParser()
107
+ )
108
+
109
+ self.global_chain = (RunnableParallel(input=RunnablePassthrough(), answer=self.rag_chain)
110
+ | RunnableLambda(flatten_input)
111
+ | self.verification_chain)
112
+
113
+ def answer(self, query: str, options: List[str]):
114
+ options = list(filter(lambda x: x is not None and len(x) > 0, options))
115
+ return self.global_chain.invoke({"query": query, "options": options})
116
+
src/clients.py CHANGED
@@ -43,7 +43,8 @@ class AcademicClient:
43
 
44
  def __init__(self):
45
  self.create_vectordb()
46
- self.chain = PdfAndGoogleChain(self.vectordb.as_retriever())
 
47
 
48
- def answer(self, query):
49
- return self.chain.answer(query)
 
43
 
44
  def __init__(self):
45
  self.create_vectordb()
46
+ self.chain = PdfAndGoogleChain(
47
+ self.vectordb.as_retriever(search_type="mmr", search_kwargs={"fetch_k": 30, "k": 6}))
48
 
49
+ def answer(self, query, options):
50
+ return self.chain.answer(query, options)
src/complex.ipynb CHANGED
@@ -11,7 +11,7 @@
11
  "from langchain.vectorstores import Chroma\n",
12
  "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
13
  "from langchain_core.output_parsers import StrOutputParser\n",
14
- "from langchain_core.runnables import RunnablePassthrough, RunnableLambda\n",
15
  "from langchain.document_loaders import PyPDFLoader\n",
16
  "from langchain_openai import ChatOpenAI\n",
17
  "from dotenv import load_dotenv\n",
@@ -149,7 +149,7 @@
149
  },
150
  {
151
  "cell_type": "code",
152
- "execution_count": 18,
153
  "outputs": [],
154
  "source": [
155
  "retriever = vectordb.as_retriever()\n",
@@ -174,12 +174,37 @@
174
  "metadata": {
175
  "collapsed": false,
176
  "ExecuteTime": {
177
- "end_time": "2024-04-09T13:38:45.242263100Z",
178
- "start_time": "2024-04-09T13:38:45.221315700Z"
179
  }
180
  },
181
  "id": "64cb22281c854513"
182
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  {
184
  "cell_type": "code",
185
  "execution_count": 19,
@@ -263,6 +288,67 @@
263
  },
264
  "id": "b0a54b5e476b46e0"
265
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  {
267
  "cell_type": "code",
268
  "execution_count": null,
@@ -271,7 +357,7 @@
271
  "metadata": {
272
  "collapsed": false
273
  },
274
- "id": "5c977fcc519c1a6e"
275
  }
276
  ],
277
  "metadata": {
 
11
  "from langchain.vectorstores import Chroma\n",
12
  "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
13
  "from langchain_core.output_parsers import StrOutputParser\n",
14
+ "from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel\n",
15
  "from langchain.document_loaders import PyPDFLoader\n",
16
  "from langchain_openai import ChatOpenAI\n",
17
  "from dotenv import load_dotenv\n",
 
149
  },
150
  {
151
  "cell_type": "code",
152
+ "execution_count": 26,
153
  "outputs": [],
154
  "source": [
155
  "retriever = vectordb.as_retriever()\n",
 
174
  "metadata": {
175
  "collapsed": false,
176
  "ExecuteTime": {
177
+ "end_time": "2024-04-09T18:59:04.620561600Z",
178
+ "start_time": "2024-04-09T18:59:04.602246900Z"
179
  }
180
  },
181
  "id": "64cb22281c854513"
182
  },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 27,
186
+ "outputs": [
187
+ {
188
+ "data": {
189
+ "text/plain": "langchain_core.runnables.base.RunnableSequence"
190
+ },
191
+ "execution_count": 27,
192
+ "metadata": {},
193
+ "output_type": "execute_result"
194
+ }
195
+ ],
196
+ "source": [
197
+ "type(rag_chain)"
198
+ ],
199
+ "metadata": {
200
+ "collapsed": false,
201
+ "ExecuteTime": {
202
+ "end_time": "2024-04-09T18:59:09.631586300Z",
203
+ "start_time": "2024-04-09T18:59:09.610952200Z"
204
+ }
205
+ },
206
+ "id": "c2fe5487662fc6f0"
207
+ },
208
  {
209
  "cell_type": "code",
210
  "execution_count": 19,
 
288
  },
289
  "id": "b0a54b5e476b46e0"
290
  },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 29,
294
+ "outputs": [],
295
+ "source": [
296
+ "from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel\n"
297
+ ],
298
+ "metadata": {
299
+ "collapsed": false,
300
+ "ExecuteTime": {
301
+ "end_time": "2024-04-09T19:11:51.492629300Z",
302
+ "start_time": "2024-04-09T19:11:51.470103600Z"
303
+ }
304
+ },
305
+ "id": "5c977fcc519c1a6e"
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": 36,
310
+ "outputs": [],
311
+ "source": [
312
+ "def flatten_input(d):\n",
313
+ " ret = d.pop('a')\n",
314
+ " ret.update(d)\n",
315
+ " return ret\n",
316
+ "a = RunnableParallel(a = RunnablePassthrough(), b = RunnableLambda(lambda x: \"abracadabra\")) | RunnableLambda(flatten_input)"
317
+ ],
318
+ "metadata": {
319
+ "collapsed": false,
320
+ "ExecuteTime": {
321
+ "end_time": "2024-04-09T19:14:33.339534100Z",
322
+ "start_time": "2024-04-09T19:14:33.319287200Z"
323
+ }
324
+ },
325
+ "id": "720b1320bc0fb7d8"
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": 37,
330
+ "outputs": [
331
+ {
332
+ "data": {
333
+ "text/plain": "{'xx': 'yy', 'zz': 11, 'b': 'abracadabra'}"
334
+ },
335
+ "execution_count": 37,
336
+ "metadata": {},
337
+ "output_type": "execute_result"
338
+ }
339
+ ],
340
+ "source": [
341
+ "a.invoke({\"xx\" : \"yy\", \"zz\" : 11})"
342
+ ],
343
+ "metadata": {
344
+ "collapsed": false,
345
+ "ExecuteTime": {
346
+ "end_time": "2024-04-09T19:14:33.976773500Z",
347
+ "start_time": "2024-04-09T19:14:33.918734500Z"
348
+ }
349
+ },
350
+ "id": "465e8521af889ff6"
351
+ },
352
  {
353
  "cell_type": "code",
354
  "execution_count": null,
 
357
  "metadata": {
358
  "collapsed": false
359
  },
360
+ "id": "27c9bf7387f058cf"
361
  }
362
  ],
363
  "metadata": {