omlakhani commited on
Commit
4015cd6
1 Parent(s): e649211

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -5
app.py CHANGED
@@ -21,23 +21,42 @@ from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader
21
 
22
  index = GPTSimpleVectorIndex.load_from_disk('index.json')
23
 
24
- def querying_db(query: str):
25
  response = index.query(query)
26
  return response
27
 
 
28
  tools = [
29
  Tool(
30
  name="QueryingDB",
31
  func=querying_db,
32
- description="This function takes a query string as input and returns the most relevant answer from the documentation as output"
 
33
  )
 
34
  ]
35
 
36
- llm = OpenAI(temperature=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def get_answer(query_string):
39
- agent = initialize_agent(tools, llm, agent="zero-shot-react-description")
40
- result = agent.run(query_string)
 
41
  return result
42
 
43
  def qa_app(query):
 
21
 
22
  index = GPTSimpleVectorIndex.load_from_disk('index.json')
23
 
24
+ ef querying_db(query: str):
25
  response = index.query(query)
26
  return response
27
 
28
+
29
  tools = [
30
  Tool(
31
  name="QueryingDB",
32
  func=querying_db,
33
+ description="useful for when you need to answer questions from the database. The answer is given in bullet points.",
34
+ return_direct=True
35
  )
36
+
37
  ]
38
 
39
+ prefix = """Give a detailed answer to the question"""
40
+ suffix = """Give answer in bullet points"
41
+
42
+ Question: {input}
43
+ {agent_scratchpad}"""
44
+
45
+ prompt = ZeroShotAgent.create_prompt(
46
+ tools,
47
+ prefix=prefix,
48
+ suffix=suffix,
49
+ input_variables=["input", "agent_scratchpad"]
50
+ )
51
+ llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
52
+ tool_names = [tool.name for tool in tools]
53
+ agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
54
+
55
 
56
  def get_answer(query_string):
57
+ agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
58
+ result = agent_executor.run(query_string)
59
+
60
  return result
61
 
62
  def qa_app(query):