wuhp committed (verified)
Commit b9d6d53 · Parent(s): ccc6355

Update app.py

Files changed (1):
  1. app.py +228 -98

app.py CHANGED
@@ -4,13 +4,18 @@ import spaces  # Import the spaces library
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 import torch
 from threading import Thread
+import logging
+from typing import Tuple, List, Dict, Generator
+
+# --- Logging Configuration ---
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # --- Model & Quantization Settings ---
 MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
 
 # Dictionaries to store the loaded model and tokenizer
-models = {}
-tokenizers = {}
+models: Dict[str, AutoModelForCausalLM] = {}
+tokenizers: Dict[str, AutoTokenizer] = {}
 
 bnb_config_4bit = BitsAndBytesConfig(
     load_in_4bit=True,
@@ -18,23 +23,34 @@ bnb_config_4bit = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16,  # Or torch.float16 if needed
 )
 
-def get_model_and_tokenizer():
-    """Lazy-load the model and tokenizer if not already loaded."""
+def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+    """
+    Lazy-load the model and tokenizer if not already loaded.
+
+    Returns:
+        Tuple[model, tokenizer]: The loaded model and tokenizer.
+    """
     if "7B" not in models:
-        print(f"Loading 7B model: {MODEL_ID} on demand")
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            quantization_config=bnb_config_4bit,
-            torch_dtype=torch.bfloat16,  # Or torch.float16 if needed
-            device_map='auto',
-            trust_remote_code=True,
-        )
-        models["7B"] = model
-        tokenizers["7B"] = tokenizer
-        print("Loaded 7B model on demand.")
+        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                quantization_config=bnb_config_4bit,
+                torch_dtype=torch.bfloat16,  # Or torch.float16 if needed
+                device_map='auto',
+                trust_remote_code=True,
+            )
+            model.eval()  # Set the model to evaluation mode
+            models["7B"] = model
+            tokenizers["7B"] = tokenizer
+            logging.info("Loaded 7B model on demand.")
+        except Exception as e:
+            logging.error(f"Failed to load model and tokenizer: {e}")
+            raise e
     return models["7B"], tokenizers["7B"]
 
+
 # --- Default Prompt Templates ---
 default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
 As a Senior Code Analyst, provide an initial analysis of the problem below.
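The hunk above replaces the eager, print-based loader with a typed, dictionary-cached lazy loader. As an aside, the same lazy-caching pattern can be sketched with functools.lru_cache; the snippet below is illustrative only, is not part of the commit, and assumes the small public test checkpoint sshleifer/tiny-gpt2 rather than the Space's 4-bit DeepSeek model.

# Illustrative sketch only (not part of the commit): the same lazy, cached
# loading pattern expressed with functools.lru_cache instead of module-level dicts.
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer

@lru_cache(maxsize=1)
def load_model_and_tokenizer(model_id: str = "sshleifer/tiny-gpt2"):  # small test model, assumption for local runs
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.eval()
    return model, tokenizer

# First call downloads and caches; subsequent calls return the same cached pair.
model, tokenizer = load_model_and_tokenizer()
model_again, _ = load_model_and_tokenizer()
assert model is model_again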
@@ -74,50 +90,68 @@ Review the detailed code generation and reasoning below, and produce a final, re
 {code_response}
 """
 
-# --- Shared Memory for Rounds ---
-shared_memory = []
-
-def store_in_memory(memory_item):
-    """Store a memory item and log an excerpt."""
-    shared_memory.append(memory_item)
-    print(f"\n[Memory Stored]: {memory_item[:50]}...")
-
-def retrieve_from_memory(query, top_k=2):
-    """
-    Retrieve memory items that contain the query text (case-insensitive).
-    Returns up to top_k items.
-    """
-    relevant_memories = []
-    query_lower = query.lower()
-    for memory_item in shared_memory:
-        if query_lower in memory_item.lower():
-            relevant_memories.append(memory_item)
-    if not relevant_memories:
-        print("\n[Memory Retrieval]: No relevant memories found.")
-        return []
-    print(f"\n[Memory Retrieval]: Found {len(relevant_memories)} relevant memories.")
-    return relevant_memories[:top_k]
+
+# --- Memory Management ---
+class MemoryManager:
+    """Encapsulate shared memory for storing and retrieving conversation items."""
+    def __init__(self) -> None:
+        self.shared_memory: List[str] = []
+
+    def store(self, item: str) -> None:
+        """
+        Store a memory item and log an excerpt.
+
+        Args:
+            item (str): The memory content to store.
+        """
+        self.shared_memory.append(item)
+        logging.info(f"[Memory Stored]: {item[:50]}...")
+
+    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
+        """
+        Retrieve memory items that contain the query text (case-insensitive).
+
+        Args:
+            query (str): The text query to search for.
+            top_k (int): Maximum number of memory items to return.
+
+        Returns:
+            List[str]: A list of up to top_k memory items.
+        """
+        query_lower = query.lower()
+        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
+        if not relevant:
+            logging.info("[Memory Retrieval]: No relevant memories found.")
+        else:
+            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
+        return relevant[:top_k]
+
+# Create a global memory manager instance for RAG purposes.
+global_memory_manager = MemoryManager()
 
 # --- Multi-Round Swarm Agent Function ---
 @spaces.GPU(duration=180)  # Adjust duration as needed
-def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k,
-                          prompt_brainstorm_text, prompt_code_generation_text, prompt_synthesis_text):
+def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
+                          prompt_brainstorm_text: str, prompt_code_generation_text: str, prompt_synthesis_text: str
+                          ) -> Generator[str, None, None]:
     """
     A three-round iterative process that uses the provided prompt templates:
       - Round 1: Brainstorming.
       - Round 2: Advanced reasoning & code generation.
       - Round 3: Synthesis & refinement.
+
     This generator yields the response from the final round as it is produced.
-    """
-    global shared_memory
-    shared_memory = []  # Clear shared memory for each new request
 
+    Yields:
+        str: Progressive updates of the final response.
+    """
     model, tokenizer = get_model_and_tokenizer()
 
     # ----- Round 1: Brainstorming -----
-    print("\n--- Round 1: Brainstorming ---")
-    prompt_round1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
-    input_ids_r1 = tokenizer.encode(prompt_round1, return_tensors="pt").to(model.device)
+    logging.info("--- Round 1: Brainstorming ---")
+    prompt_r1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
+    input_ids_r1 = tokenizer.encode(prompt_r1, return_tensors="pt").to(model.device)
     streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r1 = dict(
         input_ids=input_ids_r1,
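The hunk above folds the module-level shared_memory helpers into a MemoryManager class whose retrieve() performs case-insensitive substring matching capped at top_k. A minimal standalone sketch of that retrieval rule (illustrative only, not from the commit):

# Minimal sketch (not part of the commit) of the substring-based retrieval rule
# that MemoryManager.retrieve() implements: case-insensitive containment, capped at top_k.
from typing import List

def retrieve(memory: List[str], query: str, top_k: int = 3) -> List[str]:
    query_lower = query.lower()
    return [item for item in memory if query_lower in item.lower()][:top_k]

memory = [
    "Brainstorm Response: consider a pun about Python's 'byte' humor...",
    "Code Generation Response: def tell_pun(): ...",
    "Final Synthesis Response: cleaned-up joke generator...",
]
print(retrieve(memory, "PUN"))   # matches case-insensitively
print(retrieve(memory, "rust"))  # -> [] when nothing matches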
@@ -127,22 +161,32 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )
-    thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1)
-    thread_r1.start()
+    try:
+        thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1)
+        with torch.no_grad():
+            thread_r1.start()
+    except Exception as e:
+        logging.error(f"Error starting Round 1 thread: {e}")
+        raise e
 
     brainstorm_response = ""
-    for text in streamer_r1:
-        print(text, end="", flush=True)
-        brainstorm_response += text
-    store_in_memory(f"Brainstorm Response: {brainstorm_response[:200]}...")
+    try:
+        for text in streamer_r1:
+            logging.info(text)
+            brainstorm_response += text
+    except Exception as e:
+        logging.error(f"Error during Round 1 generation: {e}")
+        raise e
+    thread_r1.join()
+    global_memory_manager.store(f"Brainstorm Response: {brainstorm_response[:200]}...")
 
     # ----- Round 2: Code Generation -----
-    print("\n\n--- Round 2: Code Generation ---")
-    prompt_round2 = prompt_code_generation_text.format(
+    logging.info("--- Round 2: Code Generation ---")
+    prompt_r2 = prompt_code_generation_text.format(
         brainstorm_response=brainstorm_response,
         user_prompt=user_prompt
     )
-    input_ids_r2 = tokenizer.encode(prompt_round2, return_tensors="pt").to(model.device)
+    input_ids_r2 = tokenizer.encode(prompt_r2, return_tensors="pt").to(model.device)
     streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r2 = dict(
         input_ids=input_ids_r2,
@@ -151,19 +195,29 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )
-    thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2)
-    thread_r2.start()
+    try:
+        thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2)
+        with torch.no_grad():
+            thread_r2.start()
+    except Exception as e:
+        logging.error(f"Error starting Round 2 thread: {e}")
+        raise e
 
     code_response = ""
-    for text in streamer_r2:
-        print(text, end="", flush=True)
-        code_response += text
-    store_in_memory(f"Code Generation Response: {code_response[:200]}...")
+    try:
+        for text in streamer_r2:
+            logging.info(text)
+            code_response += text
+    except Exception as e:
+        logging.error(f"Error during Round 2 generation: {e}")
+        raise e
+    thread_r2.join()
+    global_memory_manager.store(f"Code Generation Response: {code_response[:200]}...")
 
     # ----- Round 3: Synthesis & Refinement -----
-    print("\n\n--- Round 3: Synthesis & Refinement ---")
-    prompt_round3 = prompt_synthesis_text.format(code_response=code_response)
-    input_ids_r3 = tokenizer.encode(prompt_round3, return_tensors="pt").to(model.device)
+    logging.info("--- Round 3: Synthesis & Refinement ---")
+    prompt_r3 = prompt_synthesis_text.format(code_response=code_response)
+    input_ids_r3 = tokenizer.encode(prompt_r3, return_tensors="pt").to(model.device)
     streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     kwargs_r3 = dict(
         input_ids=input_ids_r3,
@@ -172,58 +226,137 @@ def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k
         temperature=temp,
         top_p=top_p,
     )
-    thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3)
-    thread_r3.start()
+    try:
+        thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3)
+        with torch.no_grad():
+            thread_r3.start()
+    except Exception as e:
+        logging.error(f"Error starting Round 3 thread: {e}")
+        raise e
 
     final_response = ""
-    for text in streamer_r3:
-        print(text, end="", flush=True)
-        final_response += text
-        yield final_response  # yield progressive updates
+    try:
+        for text in streamer_r3:
+            logging.info(text)
+            final_response += text
+            yield final_response  # Yield progressive updates
+    except Exception as e:
+        logging.error(f"Error during Round 3 generation: {e}")
+        raise e
+    thread_r3.join()
 
-    store_in_memory(f"Final Synthesis Response: {final_response[:200]}...")
+    global_memory_manager.store(f"Final Synthesis Response: {final_response[:200]}...")
+
+
+# --- Explanation Function for Puns ---
+def handle_explanation_request(user_prompt: str) -> str:
+    """
+    If the user asks for an explanation of the puns, this function retrieves
+    relevant stored memory items (which are expected to include pun examples) and
+    constructs a new prompt to generate a detailed explanation.
+
+    Args:
+        user_prompt (str): The user request (e.g. "explain the different puns you mentioned")
+
+    Returns:
+        str: The explanation generated by the model.
+    """
+    # Retrieve memory items that contain "pun" (assuming previous outputs include puns)
+    retrieved = global_memory_manager.retrieve("pun", top_k=3)
+    if not retrieved:
+        explanation_prompt = "No previous puns found to explain. Please provide the pun examples."
+    else:
+        explanation_prompt = "Please explain the following coding puns in detail:\n\n"
+        for item in retrieved:
+            explanation_prompt += f"- {item}\n"
+        explanation_prompt += "\nProvide a detailed explanation for each pun."
+
+    model, tokenizer = get_model_and_tokenizer()
+    input_ids = tokenizer.encode(explanation_prompt, return_tensors="pt").to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=300,
+        temperature=0.7,
+        top_p=0.9,
+    )
+    try:
+        thread = Thread(target=model.generate, kwargs=kwargs)
+        with torch.no_grad():
+            thread.start()
+    except Exception as e:
+        logging.error(f"Error starting explanation thread: {e}")
+        raise e
+
+    explanation = ""
+    try:
+        for text in streamer:
+            explanation += text
+    except Exception as e:
+        logging.error(f"Error during explanation generation: {e}")
+        raise e
+    thread.join()
+    return explanation
 
 # --- Helper to Format History ---
-def format_history(history):
+def format_history(history: List) -> List[Dict[str, str]]:
     """
     Convert history (which might be a list of [user, assistant] pairs or already formatted dictionaries)
     into a list of OpenAI-style message dictionaries.
+
+    Args:
+        history (List): List of conversation items.
+
+    Returns:
+        List[Dict[str, str]]: A list of formatted message dictionaries.
     """
     messages = []
     for item in history:
         # If item is a list or tuple, try to unpack it if it has exactly 2 elements.
-        if isinstance(item, (list, tuple)):
-            if len(item) == 2:
-                user_msg, assistant_msg = item
-                messages.append({"role": "user", "content": user_msg})
-                if assistant_msg:
-                    messages.append({"role": "assistant", "content": assistant_msg})
-            else:
-                # If it doesn't have exactly two items, skip it.
-                continue
+        if isinstance(item, (list, tuple)) and len(item) == 2:
+            user_msg, assistant_msg = item
+            messages.append({"role": "user", "content": user_msg})
+            if assistant_msg:
+                messages.append({"role": "assistant", "content": assistant_msg})
         elif isinstance(item, dict):
-            # Already formatted message dictionary.
             messages.append(item)
-        else:
-            continue
     return messages
 
+
 # --- Gradio Chat Interface Function ---
-def gradio_interface(message, history, param_state, prompt_state):
+def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]:
     """
     This function is called by Gradio's ChatInterface.
     It uses the current saved generation parameters and prompt templates.
+    If the user request appears to ask for an explanation of puns,
+    it routes the request to the explanation function.
+
+    Args:
+        message (str): The user message.
+        history (List): The conversation history.
+        param_state (Dict): Generation parameters.
+        prompt_state (Dict): Prompt templates.
+
+    Yields:
+        Generator[List[Dict[str, str]]]: Updated history in OpenAI-style message dictionaries.
     """
-    # Unpack parameter state (with fallback defaults)
+    # Check if the user is asking to explain puns.
+    if "explain" in message.lower() and "pun" in message.lower():
+        explanation = handle_explanation_request(message)
+        history = history + [[message, explanation]]
+        yield format_history(history)
+        return
+
     try:
         temp = float(param_state.get("temperature", 0.5))
         top_p = float(param_state.get("top_p", 0.9))
         max_new_tokens = int(param_state.get("max_new_tokens", 300))
         memory_top_k = int(param_state.get("memory_top_k", 2))
-    except Exception:
+    except Exception as e:
+        logging.error(f"Parameter conversion error: {e}")
         temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2
 
-    # Unpack prompt state (with fallback defaults)
     prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm)
     prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation)
     prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis)
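Each of the three rounds above follows the same pattern: run model.generate on a background Thread, consume tokens through TextIteratorStreamer, then thread.join() and store an excerpt. As a hedged aside, grad mode is thread-local in PyTorch and generate() already disables gradients internally, so the standalone sketch below omits the torch.no_grad() wrapper around thread.start(). It is illustrative only and assumes the small public test model sshleifer/tiny-gpt2; it is not part of the commit.

# Minimal, self-contained sketch of the Thread + TextIteratorStreamer streaming pattern.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # tiny random model, assumption for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Write a short coding pun.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32))
thread.start()

partial = ""
for chunk in streamer:          # blocks until the worker thread produces text
    partial += chunk
    print(partial, end="\r")
thread.join()                   # ensure generation has finished before reusing the model
print()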
@@ -244,9 +377,9 @@ def gradio_interface(message, history, param_state, prompt_state):
     ):
         # Update the last assistant message with the new partial response.
         history[-1][1] = partial_response
-        # Yield the history formatted as OpenAI-style messages.
         yield format_history(history)
 
+
 # --- UI Settings & Styling ---
 ui_description = '''
 <div>
@@ -285,10 +418,11 @@ h1 {
 }
 """
 
+
 # --- Gradio UI ---
 with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
     gr.Markdown(ui_description)
-
+
     # Hidden States to hold parameters and prompt configuration
     param_state = gr.State({
         "temperature": 0.5,
@@ -301,14 +435,12 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
         "prompt_code_generation": default_prompt_code_generation,
         "prompt_synthesis": default_prompt_synthesis,
     })
-
+
     # Create top-level Tabs
     with gr.Tabs():
         # --- Chat Tab ---
         with gr.Tab("Chat"):
-            # Set type="messages" for OpenAI-style message dictionaries
             chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
-            # Use ChatInterface and pass the hidden states as additional inputs.
             gr.ChatInterface(
                 fn=gradio_interface,
                 chatbot=chatbot,
@@ -323,7 +455,7 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 cache_examples=False,
                 type="messages",
             )
-
+
         # --- Parameters Tab ---
         with gr.Tab("Parameters"):
             gr.Markdown("### Generation Parameters")
@@ -332,13 +464,12 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
             max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
             memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
             save_params_btn = gr.Button("Save Parameters")
-            # When the user clicks Save, update the param_state
            save_params_btn.click(
                 lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
                 inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
                 outputs=param_state,
             )
-
+
         # --- Prompt Config Tab ---
         with gr.Tab("Prompt Config"):
             gr.Markdown("### Configure Prompt Templates")
@@ -358,7 +489,6 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 lines=8,
             )
             save_prompts_btn = gr.Button("Save Prompts")
-            # When clicked, update the prompt_state with new values
             save_prompts_btn.click(
                 lambda b, c, s: {
                     "prompt_brainstorm": b,
@@ -368,8 +498,8 @@ with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
                 inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
                 outputs=prompt_state,
             )
-
+
     gr.Markdown(ui_license)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()