Jesus Sanchez commited on
Commit
0cbd39d
·
1 Parent(s): f9276e2
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -11,11 +11,7 @@ from langchain import OpenAI
11
  from langchain import PromptTemplate, OpenAI, LLMChain
12
  from langchain.chains import SimpleSequentialChain
13
 
14
- JSON_DATA_LABEL = 'json_data'
15
-
16
- if JSON_DATA_LABEL not in st.session_state:
17
- st.session_state[JSON_DATA_LABEL] = {}
18
-
19
 
20
  def tables_from_db():
21
  db = sqlite3.connect('switrs.sqlite')
@@ -40,7 +36,7 @@ def from_db(table: str):
40
 
41
 
42
 
43
- def get_sql_agent(llm):
44
  db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
45
 
46
  toolkit = SQLDatabaseToolkit(llm=llm,db=db)
@@ -54,7 +50,7 @@ def get_sql_agent(llm):
54
  return_intermediate_steps=True
55
  )
56
 
57
- def get_json_chain(llm):
58
  prompt_template = "Reformat this {result} in JSON format"
59
  return LLMChain(
60
  llm=llm,
@@ -63,10 +59,10 @@ def get_json_chain(llm):
63
 
64
 
65
 
66
- def from_gpt(query: str, plot: bool, llm):
67
 
68
- sql_agent = get_sql_agent(llm)
69
- json_chain = get_json_chain(llm)
70
 
71
  chains = [sql_agent, json_chain] if plot else [sql_agent]
72
 
@@ -111,7 +107,6 @@ def get_response(prompt: str, llm, *kargs):
111
  return (response, on_render)
112
 
113
 
114
- llm=OpenAI(temperature=0)
115
 
116
 
117
  chat = idf_chat.Chat()
 
11
  from langchain import PromptTemplate, OpenAI, LLMChain
12
  from langchain.chains import SimpleSequentialChain
13
 
14
+ llm=OpenAI(temperature=0)
 
 
 
 
15
 
16
  def tables_from_db():
17
  db = sqlite3.connect('switrs.sqlite')
 
36
 
37
 
38
 
39
+ def get_sql_agent():
40
  db = SQLDatabase.from_uri("sqlite:///switrs.sqlite")
41
 
42
  toolkit = SQLDatabaseToolkit(llm=llm,db=db)
 
50
  return_intermediate_steps=True
51
  )
52
 
53
+ def get_json_chain():
54
  prompt_template = "Reformat this {result} in JSON format"
55
  return LLMChain(
56
  llm=llm,
 
59
 
60
 
61
 
62
+ def from_gpt(query: str, plot: bool):
63
 
64
+ sql_agent = get_sql_agent()
65
+ json_chain = get_json_chain()
66
 
67
  chains = [sql_agent, json_chain] if plot else [sql_agent]
68
 
 
107
  return (response, on_render)
108
 
109
 
 
110
 
111
 
112
  chat = idf_chat.Chat()