alejandro commited on
Commit
b0cfa99
·
1 Parent(s): efb8ba7

feat: add full chain

Browse files
Files changed (1) hide show
  1. src/app.py +26 -2
src/app.py CHANGED
@@ -34,12 +34,36 @@ def get_sql_chain(db):
34
  | StrOutputParser()
35
  )
36
 
37
-
38
  def get_response(user_query, chat_history, db):
39
 
40
  sql_chain = get_sql_chain(db)
41
 
42
- return sql_chain.invoke({
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  "question": user_query
44
  })
45
 
 
34
  | StrOutputParser()
35
  )
36
 
 
37
  def get_response(user_query, chat_history, db):
38
 
39
  sql_chain = get_sql_chain(db)
40
 
41
+ template = """
42
+ Based on the table schema below, question, sql query, and sql response, write a natural language response:
43
+ {schema}
44
+
45
+ Question: {question}
46
+ SQL Query: {query}
47
+ SQL Response: {response}"""
48
+
49
+ prompt = ChatPromptTemplate.from_template(template)
50
+
51
+ llm = ChatOpenAI()
52
+
53
+ def get_schema(_):
54
+ return db.get_table_info()
55
+
56
+ chain = (
57
+ RunnablePassthrough.assign(query=sql_chain).assign(
58
+ schema=get_schema,
59
+ response= lambda vars: db.run(vars["query"])
60
+ )
61
+ | prompt
62
+ | llm
63
+ | StrOutputParser()
64
+ )
65
+
66
+ return chain.invoke({
67
  "question": user_query
68
  })
69