maulikanalog commited on
Commit
ed548d7
โ€ข
1 Parent(s): ec2848d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -28
app.py CHANGED
@@ -9,12 +9,23 @@ from swiftsage.agents import SwiftSage
9
  from swiftsage.utils.commons import PromptTemplate, api_configs, setup_logging
10
  from pkg_resources import resource_filename
11
 
 
 
 
 
 
 
 
 
 
 
 
12
  ENGINE = "SambaNova"
13
  SWIFT_MODEL_ID = "Meta-Llama-3.1-8B-Instruct"
14
  FEEDBACK_MODEL_ID = "Meta-Llama-3.1-70B-Instruct"
15
  SAGE_MODEL_ID = "Meta-Llama-3.1-405B-Instruct"
16
 
17
- def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feedback_model_id, max_iterations, reward_threshold, swift_temperature, swift_top_p, sage_temperature, sage_top_p, feedback_temperature, feedback_top_p):
18
  global ENGINE
19
  # Configuration for each LLM
20
  max_iterations = int(max_iterations)
@@ -25,7 +36,7 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
25
  "api_config": api_configs[ENGINE],
26
  "temperature": float(swift_temperature),
27
  "top_p": float(swift_top_p),
28
- "max_tokens": 4096,
29
  }
30
 
31
  feedback_config = {
@@ -33,7 +44,7 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
33
  "api_config": api_configs[ENGINE],
34
  "temperature": float(feedback_temperature),
35
  "top_p": float(feedback_top_p),
36
- "max_tokens": 4096,
37
  }
38
 
39
  sage_config = {
@@ -41,9 +52,13 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
41
  "api_config": api_configs[ENGINE],
42
  "temperature": float(sage_temperature),
43
  "top_p": float(sage_top_p),
44
- "max_tokens": 4096,
45
  }
46
 
 
 
 
 
47
  # Try multiple locations for the prompt templates
48
  possible_paths = [
49
  resource_filename('swiftsage', 'prompt_templates'),
@@ -59,7 +74,7 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
59
  break
60
 
61
  dataset = []
62
- embeddings = []
63
  s2 = SwiftSage(
64
  dataset,
65
  embeddings,
@@ -67,54 +82,70 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
67
  swift_config,
68
  sage_config,
69
  feedback_config,
70
- use_retrieval=False,
71
- start_with_sage=False,
72
  )
73
 
74
- problem = f"Predict bank failure chances in percentage form from the following balance sheet data or parameters: {balance_sheet_data}"
75
  reasoning, solution, messages = s2.solve(problem, max_iterations, reward_threshold)
76
  reasoning = reasoning.replace("The generated code is:", "\n---\nThe generated code is:").strip()
77
  solution = solution.replace("Answer (from running the code):\n ", " ").strip()
 
78
 
79
  log_messages = "<pre style='white-space: pre-wrap; max-height: 500px; overflow-y: scroll;'><code class='log'>" + "\n".join(messages) + "</code></pre>"
80
  return reasoning, solution, log_messages
81
 
82
 
83
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
 
84
  gr.HTML("<h1 style='text-align: center;'>๐Ÿฆ Bank Failure Predictor</h1>")
85
  gr.HTML("<span>This tool predicts the likelihood of bank failure based on balance sheet data.</span>")
86
 
87
  with gr.Row():
88
- swift_model_id = gr.Textbox(label="Swift Model ID", value=SWIFT_MODEL_ID)
89
- feedback_model_id = gr.Textbox(label="Feedback Model ID", value=FEEDBACK_MODEL_ID)
90
- sage_model_id = gr.Textbox(label="Sage Model ID", value=SAGE_MODEL_ID)
91
-
92
- balance_sheet_data = gr.Textbox(label="Input balance sheet data or parameters", placeholder="Enter the bank's financial data here...", lines=5)
93
-
94
- # Hidden fields for advanced options
95
- max_iterations = gr.Textbox(visible=False, value="5")
96
- reward_threshold = gr.Textbox(visible=False, value="8")
97
- temperature_swift = gr.Textbox(visible=False, value="0.5")
98
- top_p_swift = gr.Textbox(visible=False, value="0.9")
99
- temperature_sage = gr.Textbox(visible=False, value="0.5")
100
- top_p_sage = gr.Textbox(visible=False, value="0.9")
101
- temperature_feedback = gr.Textbox(visible=False, value="0.5")
102
- top_p_feedback = gr.Textbox(visible=False, value="0.9")
103
-
104
- predict_button = gr.Button("๐Ÿ”ฎ Predict Failure Chance")
 
 
 
 
 
 
 
 
 
 
105
  reasoning_output = gr.Textbox(label="Prediction steps with Code", interactive=False)
106
  solution_output = gr.Textbox(label="Prediction Result", interactive=False)
107
 
 
108
  with gr.Accordion(label="๐Ÿ“œ Log Messages", open=False):
109
  log_output = gr.HTML("<p>No log messages yet.</p>")
110
 
111
- predict_button.click(
112
- predict_bank_failure,
113
- inputs=[balance_sheet_data, swift_model_id, sage_model_id, feedback_model_id, max_iterations, reward_threshold, temperature_swift, top_p_swift, temperature_sage, top_p_sage, temperature_feedback, top_p_feedback],
114
  outputs=[reasoning_output, solution_output, log_output],
115
  )
116
 
 
 
117
  if __name__ == '__main__':
 
118
  if not os.path.exists('logs'):
119
  os.makedirs('logs')
120
  multiprocessing.set_start_method('spawn')
 
9
  from swiftsage.utils.commons import PromptTemplate, api_configs, setup_logging
10
  from pkg_resources import resource_filename
11
 
12
+ #ENGINE = "Together"
13
+ #SWIFT_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
14
+ #FEEDBACK_MODEL_ID = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
15
+ #SAGE_MODEL_ID = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"
16
+
17
+
18
+ # ENGINE = "Groq"
19
+ # SWIFT_MODEL_ID = "llama-3.1-8b-instant"
20
+ # FEEDBACK_MODEL_ID = "llama-3.1-8b-instant"
21
+ # SAGE_MODEL_ID = "llama-3.1-70b-versatile"
22
+
23
  ENGINE = "SambaNova"
24
  SWIFT_MODEL_ID = "Meta-Llama-3.1-8B-Instruct"
25
  FEEDBACK_MODEL_ID = "Meta-Llama-3.1-70B-Instruct"
26
  SAGE_MODEL_ID = "Meta-Llama-3.1-405B-Instruct"
27
 
28
+ def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage, swift_temperature, swift_top_p, sage_temperature, sage_top_p, feedback_temperature, feedback_top_p):
29
  global ENGINE
30
  # Configuration for each LLM
31
  max_iterations = int(max_iterations)
 
36
  "api_config": api_configs[ENGINE],
37
  "temperature": float(swift_temperature),
38
  "top_p": float(swift_top_p),
39
+ "max_tokens": 2048,
40
  }
41
 
42
  feedback_config = {
 
44
  "api_config": api_configs[ENGINE],
45
  "temperature": float(feedback_temperature),
46
  "top_p": float(feedback_top_p),
47
+ "max_tokens": 2048,
48
  }
49
 
50
  sage_config = {
 
52
  "api_config": api_configs[ENGINE],
53
  "temperature": float(sage_temperature),
54
  "top_p": float(sage_top_p),
55
+ "max_tokens": 2048,
56
  }
57
 
58
+ # specify the path to the prompt templates
59
+ # prompt_template_dir = './swiftsage/prompt_templates'
60
+ # prompt_template_dir = resource_filename('swiftsage', 'prompt_templates')
61
+
62
  # Try multiple locations for the prompt templates
63
  possible_paths = [
64
  resource_filename('swiftsage', 'prompt_templates'),
 
74
  break
75
 
76
  dataset = []
77
+ embeddings = [] # TODO: for retrieval augmentation (not implemented yet now)
78
  s2 = SwiftSage(
79
  dataset,
80
  embeddings,
 
82
  swift_config,
83
  sage_config,
84
  feedback_config,
85
+ use_retrieval=use_retrieval,
86
+ start_with_sage=start_with_sage,
87
  )
88
 
 
89
  reasoning, solution, messages = s2.solve(problem, max_iterations, reward_threshold)
90
  reasoning = reasoning.replace("The generated code is:", "\n---\nThe generated code is:").strip()
91
  solution = solution.replace("Answer (from running the code):\n ", " ").strip()
92
+ # generate HTML for the log messages and display them with wrap and a scroll bar and a max height in the code block with log style
93
 
94
  log_messages = "<pre style='white-space: pre-wrap; max-height: 500px; overflow-y: scroll;'><code class='log'>" + "\n".join(messages) + "</code></pre>"
95
  return reasoning, solution, log_messages
96
 
97
 
98
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
99
+ # gr.Markdown("## SwiftSage: A Multi-Agent Framework for Reasoning")
100
+ # use the html and center the title
101
  gr.HTML("<h1 style='text-align: center;'>๐Ÿฆ Bank Failure Predictor</h1>")
102
  gr.HTML("<span>This tool predicts the likelihood of bank failure based on balance sheet data.</span>")
103
 
104
  with gr.Row():
105
+ swift_model_id = gr.Textbox(label="๐Ÿ˜„ Swift Model ID", value=SWIFT_MODEL_ID)
106
+ feedback_model_id = gr.Textbox(label="๐Ÿค” Feedback Model ID", value=FEEDBACK_MODEL_ID)
107
+ sage_model_id = gr.Textbox(label="๐Ÿ˜Ž Sage Model ID", value=SAGE_MODEL_ID)
108
+ # the following two should have a smaller width
109
+
110
+ with gr.Accordion(label="โš™๏ธ Advanced Options", open=False):
111
+ with gr.Row():
112
+ with gr.Column():
113
+ max_iterations = gr.Textbox(label="Max Iterations", value="5")
114
+ reward_threshold = gr.Textbox(label="feedback Threshold", value="8")
115
+ # TODO: add top-p and temperature for each module for controlling
116
+ with gr.Column():
117
+ top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
118
+ temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.5")
119
+ with gr.Column():
120
+ top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
121
+ temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.5")
122
+ with gr.Column():
123
+ top_p_feedback = gr.Textbox(label="Top-p for Feedback", value="0.9")
124
+ temperature_feedback = gr.Textbox(label="Temperature for Feedback", value="0.5")
125
+
126
+ use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
127
+ start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)
128
+
129
+ problem = gr.Textbox(label="Input balance sheet data or parameters", value="Enter the bank's financial data here...", lines=5)
130
+
131
+ solve_button = gr.Button("๐Ÿ”ฎ Predict Failure Chance")
132
  reasoning_output = gr.Textbox(label="Prediction steps with Code", interactive=False)
133
  solution_output = gr.Textbox(label="Prediction Result", interactive=False)
134
 
135
+ # add a log display for showing the log messages
136
  with gr.Accordion(label="๐Ÿ“œ Log Messages", open=False):
137
  log_output = gr.HTML("<p>No log messages yet.</p>")
138
 
139
+ solve_button.click(
140
+ solve_problem,
141
+ inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage, temperature_swift, top_p_swift, temperature_sage, top_p_sage, temperature_feedback, top_p_feedback],
142
  outputs=[reasoning_output, solution_output, log_output],
143
  )
144
 
145
+
146
+
147
  if __name__ == '__main__':
148
+ # make logs dir if it does not exist
149
  if not os.path.exists('logs'):
150
  os.makedirs('logs')
151
  multiprocessing.set_start_method('spawn')