Spaces:
Sleeping
Sleeping
maulikanalog
commited on
Commit
โข
ed548d7
1
Parent(s):
ec2848d
Update app.py
Browse files
app.py
CHANGED
@@ -9,12 +9,23 @@ from swiftsage.agents import SwiftSage
|
|
9 |
from swiftsage.utils.commons import PromptTemplate, api_configs, setup_logging
|
10 |
from pkg_resources import resource_filename
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
ENGINE = "SambaNova"
|
13 |
SWIFT_MODEL_ID = "Meta-Llama-3.1-8B-Instruct"
|
14 |
FEEDBACK_MODEL_ID = "Meta-Llama-3.1-70B-Instruct"
|
15 |
SAGE_MODEL_ID = "Meta-Llama-3.1-405B-Instruct"
|
16 |
|
17 |
-
def
|
18 |
global ENGINE
|
19 |
# Configuration for each LLM
|
20 |
max_iterations = int(max_iterations)
|
@@ -25,7 +36,7 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
|
|
25 |
"api_config": api_configs[ENGINE],
|
26 |
"temperature": float(swift_temperature),
|
27 |
"top_p": float(swift_top_p),
|
28 |
-
"max_tokens":
|
29 |
}
|
30 |
|
31 |
feedback_config = {
|
@@ -33,7 +44,7 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
|
|
33 |
"api_config": api_configs[ENGINE],
|
34 |
"temperature": float(feedback_temperature),
|
35 |
"top_p": float(feedback_top_p),
|
36 |
-
"max_tokens":
|
37 |
}
|
38 |
|
39 |
sage_config = {
|
@@ -41,9 +52,13 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
|
|
41 |
"api_config": api_configs[ENGINE],
|
42 |
"temperature": float(sage_temperature),
|
43 |
"top_p": float(sage_top_p),
|
44 |
-
"max_tokens":
|
45 |
}
|
46 |
|
|
|
|
|
|
|
|
|
47 |
# Try multiple locations for the prompt templates
|
48 |
possible_paths = [
|
49 |
resource_filename('swiftsage', 'prompt_templates'),
|
@@ -59,7 +74,7 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
|
|
59 |
break
|
60 |
|
61 |
dataset = []
|
62 |
-
embeddings = []
|
63 |
s2 = SwiftSage(
|
64 |
dataset,
|
65 |
embeddings,
|
@@ -67,54 +82,70 @@ def predict_bank_failure(balance_sheet_data, swift_model_id, sage_model_id, feed
|
|
67 |
swift_config,
|
68 |
sage_config,
|
69 |
feedback_config,
|
70 |
-
use_retrieval=
|
71 |
-
start_with_sage=
|
72 |
)
|
73 |
|
74 |
-
problem = f"Predict bank failure chances in percentage form from the following balance sheet data or parameters: {balance_sheet_data}"
|
75 |
reasoning, solution, messages = s2.solve(problem, max_iterations, reward_threshold)
|
76 |
reasoning = reasoning.replace("The generated code is:", "\n---\nThe generated code is:").strip()
|
77 |
solution = solution.replace("Answer (from running the code):\n ", " ").strip()
|
|
|
78 |
|
79 |
log_messages = "<pre style='white-space: pre-wrap; max-height: 500px; overflow-y: scroll;'><code class='log'>" + "\n".join(messages) + "</code></pre>"
|
80 |
return reasoning, solution, log_messages
|
81 |
|
82 |
|
83 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
|
|
|
84 |
gr.HTML("<h1 style='text-align: center;'>๐ฆ Bank Failure Predictor</h1>")
|
85 |
gr.HTML("<span>This tool predicts the likelihood of bank failure based on balance sheet data.</span>")
|
86 |
|
87 |
with gr.Row():
|
88 |
-
swift_model_id = gr.Textbox(label="Swift Model ID", value=SWIFT_MODEL_ID)
|
89 |
-
feedback_model_id = gr.Textbox(label="Feedback Model ID", value=FEEDBACK_MODEL_ID)
|
90 |
-
sage_model_id = gr.Textbox(label="Sage Model ID", value=SAGE_MODEL_ID)
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
reasoning_output = gr.Textbox(label="Prediction steps with Code", interactive=False)
|
106 |
solution_output = gr.Textbox(label="Prediction Result", interactive=False)
|
107 |
|
|
|
108 |
with gr.Accordion(label="๐ Log Messages", open=False):
|
109 |
log_output = gr.HTML("<p>No log messages yet.</p>")
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
inputs=[
|
114 |
outputs=[reasoning_output, solution_output, log_output],
|
115 |
)
|
116 |
|
|
|
|
|
117 |
if __name__ == '__main__':
|
|
|
118 |
if not os.path.exists('logs'):
|
119 |
os.makedirs('logs')
|
120 |
multiprocessing.set_start_method('spawn')
|
|
|
9 |
from swiftsage.utils.commons import PromptTemplate, api_configs, setup_logging
|
10 |
from pkg_resources import resource_filename
|
11 |
|
12 |
+
#ENGINE = "Together"
|
13 |
+
#SWIFT_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
|
14 |
+
#FEEDBACK_MODEL_ID = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
|
15 |
+
#SAGE_MODEL_ID = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"
|
16 |
+
|
17 |
+
|
18 |
+
# ENGINE = "Groq"
|
19 |
+
# SWIFT_MODEL_ID = "llama-3.1-8b-instant"
|
20 |
+
# FEEDBACK_MODEL_ID = "llama-3.1-8b-instant"
|
21 |
+
# SAGE_MODEL_ID = "llama-3.1-70b-versatile"
|
22 |
+
|
23 |
ENGINE = "SambaNova"
|
24 |
SWIFT_MODEL_ID = "Meta-Llama-3.1-8B-Instruct"
|
25 |
FEEDBACK_MODEL_ID = "Meta-Llama-3.1-70B-Instruct"
|
26 |
SAGE_MODEL_ID = "Meta-Llama-3.1-405B-Instruct"
|
27 |
|
28 |
+
def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage, swift_temperature, swift_top_p, sage_temperature, sage_top_p, feedback_temperature, feedback_top_p):
|
29 |
global ENGINE
|
30 |
# Configuration for each LLM
|
31 |
max_iterations = int(max_iterations)
|
|
|
36 |
"api_config": api_configs[ENGINE],
|
37 |
"temperature": float(swift_temperature),
|
38 |
"top_p": float(swift_top_p),
|
39 |
+
"max_tokens": 2048,
|
40 |
}
|
41 |
|
42 |
feedback_config = {
|
|
|
44 |
"api_config": api_configs[ENGINE],
|
45 |
"temperature": float(feedback_temperature),
|
46 |
"top_p": float(feedback_top_p),
|
47 |
+
"max_tokens": 2048,
|
48 |
}
|
49 |
|
50 |
sage_config = {
|
|
|
52 |
"api_config": api_configs[ENGINE],
|
53 |
"temperature": float(sage_temperature),
|
54 |
"top_p": float(sage_top_p),
|
55 |
+
"max_tokens": 2048,
|
56 |
}
|
57 |
|
58 |
+
# specify the path to the prompt templates
|
59 |
+
# prompt_template_dir = './swiftsage/prompt_templates'
|
60 |
+
# prompt_template_dir = resource_filename('swiftsage', 'prompt_templates')
|
61 |
+
|
62 |
# Try multiple locations for the prompt templates
|
63 |
possible_paths = [
|
64 |
resource_filename('swiftsage', 'prompt_templates'),
|
|
|
74 |
break
|
75 |
|
76 |
dataset = []
|
77 |
+
embeddings = [] # TODO: for retrieval augmentation (not implemented yet now)
|
78 |
s2 = SwiftSage(
|
79 |
dataset,
|
80 |
embeddings,
|
|
|
82 |
swift_config,
|
83 |
sage_config,
|
84 |
feedback_config,
|
85 |
+
use_retrieval=use_retrieval,
|
86 |
+
start_with_sage=start_with_sage,
|
87 |
)
|
88 |
|
|
|
89 |
reasoning, solution, messages = s2.solve(problem, max_iterations, reward_threshold)
|
90 |
reasoning = reasoning.replace("The generated code is:", "\n---\nThe generated code is:").strip()
|
91 |
solution = solution.replace("Answer (from running the code):\n ", " ").strip()
|
92 |
+
# generate HTML for the log messages and display them with wrap and a scroll bar and a max height in the code block with log style
|
93 |
|
94 |
log_messages = "<pre style='white-space: pre-wrap; max-height: 500px; overflow-y: scroll;'><code class='log'>" + "\n".join(messages) + "</code></pre>"
|
95 |
return reasoning, solution, log_messages
|
96 |
|
97 |
|
98 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
99 |
+
# gr.Markdown("## SwiftSage: A Multi-Agent Framework for Reasoning")
|
100 |
+
# use the html and center the title
|
101 |
gr.HTML("<h1 style='text-align: center;'>๐ฆ Bank Failure Predictor</h1>")
|
102 |
gr.HTML("<span>This tool predicts the likelihood of bank failure based on balance sheet data.</span>")
|
103 |
|
104 |
with gr.Row():
|
105 |
+
swift_model_id = gr.Textbox(label="๐ Swift Model ID", value=SWIFT_MODEL_ID)
|
106 |
+
feedback_model_id = gr.Textbox(label="๐ค Feedback Model ID", value=FEEDBACK_MODEL_ID)
|
107 |
+
sage_model_id = gr.Textbox(label="๐ Sage Model ID", value=SAGE_MODEL_ID)
|
108 |
+
# the following two should have a smaller width
|
109 |
+
|
110 |
+
with gr.Accordion(label="โ๏ธ Advanced Options", open=False):
|
111 |
+
with gr.Row():
|
112 |
+
with gr.Column():
|
113 |
+
max_iterations = gr.Textbox(label="Max Iterations", value="5")
|
114 |
+
reward_threshold = gr.Textbox(label="feedback Threshold", value="8")
|
115 |
+
# TODO: add top-p and temperature for each module for controlling
|
116 |
+
with gr.Column():
|
117 |
+
top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
|
118 |
+
temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.5")
|
119 |
+
with gr.Column():
|
120 |
+
top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
|
121 |
+
temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.5")
|
122 |
+
with gr.Column():
|
123 |
+
top_p_feedback = gr.Textbox(label="Top-p for Feedback", value="0.9")
|
124 |
+
temperature_feedback = gr.Textbox(label="Temperature for Feedback", value="0.5")
|
125 |
+
|
126 |
+
use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
|
127 |
+
start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)
|
128 |
+
|
129 |
+
problem = gr.Textbox(label="Input balance sheet data or parameters", value="Enter the bank's financial data here...", lines=5)
|
130 |
+
|
131 |
+
solve_button = gr.Button("๐ฎ Predict Failure Chance")
|
132 |
reasoning_output = gr.Textbox(label="Prediction steps with Code", interactive=False)
|
133 |
solution_output = gr.Textbox(label="Prediction Result", interactive=False)
|
134 |
|
135 |
+
# add a log display for showing the log messages
|
136 |
with gr.Accordion(label="๐ Log Messages", open=False):
|
137 |
log_output = gr.HTML("<p>No log messages yet.</p>")
|
138 |
|
139 |
+
solve_button.click(
|
140 |
+
solve_problem,
|
141 |
+
inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage, temperature_swift, top_p_swift, temperature_sage, top_p_sage, temperature_feedback, top_p_feedback],
|
142 |
outputs=[reasoning_output, solution_output, log_output],
|
143 |
)
|
144 |
|
145 |
+
|
146 |
+
|
147 |
if __name__ == '__main__':
|
148 |
+
# make logs dir if it does not exist
|
149 |
if not os.path.exists('logs'):
|
150 |
os.makedirs('logs')
|
151 |
multiprocessing.set_start_method('spawn')
|