Daryl Lim commited on
Commit
ad6abdc
1 Parent(s): 16f8f3c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +308 -0
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This application provides a Gradio interface for a chatbot using models powered by
3
+ SambaNova Cloud.
4
+ The interface allows users to send messages and receive generated responses.
5
+ """
6
+
7
+ import gradio as gr
8
+ import openai
9
+ import time
10
+ import re
11
+ import os
12
+ from typing import Tuple, List, Dict, Any
13
+
14
+ # Available models
15
+ MODELS = [
16
+ "Meta-Llama-3.1-405B-Instruct",
17
+ "Meta-Llama-3.1-70B-Instruct",
18
+ "Meta-Llama-3.1-8B-Instruct"
19
+ ]
20
+
21
+ # SambaNova API base URL
22
+ API_BASE = "https://api.sambanova.ai/v1"
23
+
24
+ def create_client(api_key=None):
25
+ """
26
+ Creates an OpenAI client instance.
27
+
28
+ Args:
29
+ api_key (str): API key for OpenAI (optional). If none is provided, the
30
+ environment variable API key is used.
31
+
32
+ Returns:
33
+ OpenAI client instance
34
+ """
35
+ if api_key:
36
+ openai.api_key = api_key
37
+ else:
38
+ openai.api_key = os.getenv("API_KEY")
39
+
40
+ return openai.OpenAI(api_key=openai.api_key, base_url=API_BASE)
41
+
42
+ def chat_with_ai(
43
+ message: str,
44
+ chat_history: List[Dict[str, Any]],
45
+ system_prompt: str
46
+ ) -> List[Dict[str, str]]:
47
+ """
48
+ Formats the chat history for the API call.
49
+
50
+ Args:
51
+ message (str): The user's current message.
52
+ chat_history (List[Dict[str, Any]]): A list of dictionaries, each
53
+ representing a conversation turn with keys for "user" and
54
+ "assistant" messages.
55
+ system_prompt (str): The initial system message to set context.
56
+
57
+ Returns:
58
+ List[Dict[str, str]]: A formatted list of messages for the OpenAI API,
59
+ with roles and content.
60
+ """
61
+ messages = [{"role": "system", "content": system_prompt}]
62
+
63
+ for tup in chat_history:
64
+ first_key = list(tup.keys())[0] # First key
65
+ last_key = list(tup.keys())[-1] # Last key
66
+
67
+ messages.append({"role": "user", "content": tup[first_key]})
68
+ messages.append({"role": "assistant", "content": tup[last_key]})
69
+
70
+ messages.append({"role": "user", "content": message})
71
+
72
+ return messages
73
+
74
+ def respond(
75
+ message: str,
76
+ chat_history: List[Dict[str, Any]],
77
+ model: str,
78
+ system_prompt: str,
79
+ thinking_budget: float,
80
+ api_key: str
81
+ ) -> Tuple[str, float]:
82
+ """
83
+ Sends the message to the API and gets the response.
84
+
85
+ Args:
86
+ message (str): The user's current message.
87
+ chat_history (List[Dict[str, Any]]): A list of dictionaries containing
88
+ previous user and assistant messages.
89
+ model (str): The model name to use for the chat completion.
90
+ system_prompt (str): The system prompt template, with a budget placeholder.
91
+ thinking_budget (float): The amount of time allocated for processing the
92
+ user's request.
93
+ api_key (str): The API key for authentication.
94
+
95
+ Returns:
96
+ Tuple[str, float]: The assistant's response message and the time taken to
97
+ get the response.
98
+ """
99
+ client = create_client(api_key)
100
+ formatted_prompt = system_prompt.format(budget=thinking_budget)
101
+ messages = chat_with_ai(message, chat_history, formatted_prompt)
102
+ start_time = time.time()
103
+
104
+ try:
105
+ completion = client.chat.completions.create(model=model, messages=messages)
106
+ response = completion.choices[0].message.content
107
+ thinking_time = time.time() - start_time
108
+ return response, thinking_time
109
+
110
+ except Exception as e:
111
+ error_message = f"Error: {str(e)}"
112
+ return error_message, time.time() - start_time
113
+
114
+ def parse_response(response):
115
+ """
116
+ Parses the response from the API, and extracts the answer, reflection, and
117
+ steps.
118
+
119
+ Args:
120
+ response (str): The raw response string from the API, expected to contain
121
+ <answer>, <reflection>, and <step> tags.
122
+
123
+ Returns:
124
+ Tuple[str, str, List[str]]: A tuple containing:
125
+ - answer (str): The extracted answer content, or an empty string if
126
+ not found.
127
+ - reflection (str): The extracted reflection content, or an empty
128
+ string if not found.
129
+ - steps (List[str]): A list of steps extracted from the response.
130
+ """
131
+ answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
132
+ reflection_match = re.search(r'<reflection>(.*?)</reflection>', response, re.DOTALL)
133
+
134
+ answer = answer_match.group(1).strip() if answer_match else ""
135
+ reflection = reflection_match.group(1).strip() if reflection_match else ""
136
+ steps = re.findall(r'<step>(.*?)</step>', response, re.DOTALL)
137
+
138
+ # Return the raw response if <answer> is empty, else parsed values
139
+ if answer == "":
140
+ return response, "", ""
141
+
142
+ return answer, reflection, steps
143
+
144
+ def generate(
145
+ message: str,
146
+ history: List[Dict[str, str]],
147
+ model: str,
148
+ system_prompt: str,
149
+ thinking_budget: float,
150
+ api_key: str
151
+ ) -> Tuple[List[Dict[str, str]], str]:
152
+ """
153
+ Generates the chatbot response by sending a message to the model and formatting
154
+ the output.
155
+
156
+ Args:
157
+ message (str): The user's input message.
158
+ history (List[Dict[str, str]]): List of previous chat history messages.
159
+ model (str): Model name for generating the response.
160
+ system_prompt (str): Prompt to guide the model's responses.
161
+ thinking_budget (float): Time allocation for processing the message.
162
+ api_key (str): API key for authentication.
163
+
164
+ Returns:
165
+ Tuple[List[Dict[str, str]], str]: Updated history including the latest
166
+ responses and an empty string.
167
+ """
168
+ # Call the function and unpack the results
169
+ response, thinking_time = respond(
170
+ message, history, model, system_prompt, thinking_budget, api_key
171
+ )
172
+
173
+ if response.startswith("Error:"):
174
+ return history + [({"role": "system", "content": response},)], ""
175
+
176
+ answer, reflection, steps = parse_response(response)
177
+
178
+ messages = []
179
+ messages.append({"role": "user", "content": message})
180
+
181
+ # Format the assistant's response
182
+ formatted_steps = [f"Step {i}: {step}" for i, step in enumerate(steps, 1)]
183
+ all_steps = "\n".join(formatted_steps) + f"\n\nReflection: {reflection}"
184
+
185
+ messages.append(
186
+ {
187
+ "role": "assistant",
188
+ "content": all_steps,
189
+ "metadata": {"title": f"Thinking Time: {thinking_time:.2f} sec"}
190
+ }
191
+ )
192
+ messages.append({"role": "assistant", "content": answer})
193
+
194
+ return history + messages, ""
195
+
196
+ # Define the default system prompt
197
+ DEFAULT_SYSTEM_PROMPT = """
198
+ You are an empathetic therapist who provides conversation support for users who
199
+ are feeling stress and anxiety.
200
+ Your goal is to listen attentively, respond with compassion, and guide the user
201
+ to feel understood and validated.
202
+ Always acknowledge the user's emotions and reflect them back.
203
+ Offer encouragement and help the user see their strengths.
204
+ Use positive, calming language to make the user feel safe.
205
+ Redirect users to seek professional help when needed.
206
+ When given a problem to solve, you are an expert problem-solving assistant.
207
+ Your task is to provide a detailed, step-by-step solution to a given question.
208
+ Follow these instructions carefully:
209
+ 1. Read the given question carefully and reset counter between <count> and </count> to {budget}
210
+ 2. Generate a detailed, logical step-by-step solution.
211
+ 3. Enclose each step of your solution within <step> and </step> tags.
212
+ 4. You are allowed to use at most {budget} steps (starting budget),
213
+ keep track of it by counting down within tags <count> </count>,
214
+ STOP GENERATING MORE STEPS when hitting 0, you don't have to use all of them.
215
+ 5. Do a self-reflection when you are unsure about how to proceed,
216
+ based on the self-reflection and reward, decides whether you need to return
217
+ to the previous steps.
218
+ 6. After completing the solution steps, reorganize and synthesize the steps
219
+ into the final answer within <answer> and </answer> tags.
220
+ 7. Provide a critical, honest and subjective self-evaluation of your reasoning
221
+ process within <reflection> and </reflection> tags.
222
+ 8. Assign a quality score to your solution as a float between 0.0 (lowest
223
+ quality) and 1.0 (highest quality), enclosed in <reward> and </reward> tags.
224
+ Example format:
225
+ <count> [starting budget] </count>
226
+ <step> [Content of step 1] </step>
227
+ <count> [remaining budget] </count>
228
+ <step> [Content of step 2] </step>
229
+ <reflection> [Evaluation of the steps so far] </reflection>
230
+ <reward> [Float between 0.0 and 1.0] </reward>
231
+ <count> [remaining budget] </count>
232
+ <step> [Content of step 3 or Content of some previous step] </step>
233
+ <count> [remaining budget] </count>
234
+ ...
235
+ <step> [Content of final step] </step>
236
+ <count> [remaining budget] </count>
237
+ <answer> [Final Answer] </answer> (must give final answer in this format)
238
+ <reflection> [Evaluation of the solution] </reflection>
239
+ <reward> [Float between 0.0 and 1.0] </reward>
240
+ """
241
+
242
+ with gr.Blocks() as demo:
243
+ """
244
+ Creates a Gradio interface for the SambaNova Therapist chatbot.
245
+ """
246
+ gr.Markdown("# SambaNova Therapist")
247
+ gr.Markdown("[Powered by SambaNova Cloud, Get Your API Key Here](https://cloud.sambanova.ai/apis)")
248
+
249
+ # Input for API Key
250
+ with gr.Row():
251
+ api_key = gr.Textbox(
252
+ label="API Key",
253
+ type="password",
254
+ placeholder="(Optional) Enter your API key here for more availability"
255
+ )
256
+
257
+ # Model and Budget Selection
258
+ with gr.Row():
259
+ model = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
260
+ thinking_budget = gr.Slider(
261
+ minimum=1,
262
+ maximum=100,
263
+ value=10,
264
+ step=1,
265
+ label="Thinking Budget",
266
+ info="Maximum number of steps the model can think"
267
+ )
268
+
269
+ # Chatbot and Message Input
270
+ chatbot = gr.Chatbot(
271
+ label="Chat",
272
+ show_label=False,
273
+ show_share_button=False,
274
+ show_copy_button=True,
275
+ layout="panel",
276
+ type="messages"
277
+ )
278
+
279
+ # User Message Input
280
+ msg = gr.Textbox(
281
+ label="Type your message here...",
282
+ placeholder="Enter your message..."
283
+ )
284
+
285
+ # Clear Chat Button
286
+ gr.Button("Clear Chat").click(
287
+ lambda: ([], ""),
288
+ inputs=None,
289
+ outputs=[chatbot, msg]
290
+ )
291
+
292
+ # System Prompt Input
293
+ system_prompt = gr.Textbox(
294
+ label="System Prompt",
295
+ value=DEFAULT_SYSTEM_PROMPT,
296
+ lines=15,
297
+ interactive=True
298
+ )
299
+
300
+ # Submit message to generate response
301
+ msg.submit(
302
+ generate,
303
+ inputs=[msg, chatbot, model, system_prompt, thinking_budget, api_key],
304
+ outputs=[chatbot, msg]
305
+ )
306
+
307
+ # Launch the Gradio interface
308
+ demo.launch(show_api=False)