Daemontatox committed
Commit 0b72fd3 · verified · 1 Parent(s): 9e07bfc

Update app.py

Files changed (1):
  1. app.py +139 -102
app.py CHANGED
@@ -1,14 +1,10 @@
 import subprocess
 
-
-
 subprocess.run(
     'pip install flash-attn --no-build-isolation',
     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
     shell=True
 )
-
-
 import os
 import re
 import time
@@ -59,11 +55,7 @@ Always organize your responses using these tags for clear reasoning structure."""
 
 # UI Configuration
 TITLE = "<h1><center>AI Reasoning Assistant</center></h1>"
-PLACEHOLDER = """
-<center>
-<p>Ask me anything! I'll think through it step by step.</p>
-</center>
-"""
+PLACEHOLDER = "Ask me anything! I'll think through it step by step."
 
 CSS = """
 .duplicate-button {
@@ -99,23 +91,24 @@ h3 {
     color: #0066cc;
     font-weight: bold;
 }
+.chat-area {
+    height: 500px !important;
+    overflow-y: auto !important;
+}
 """
 
 def initialize_model():
     """Initialize the model with appropriate configurations"""
-    # Quantization configuration
     quantization_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_compute_dtype=torch.bfloat16,
         bnb_4bit_use_double_quant=True
     )
 
-    # Initialize tokenizer
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    # Initialize model
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         torch_dtype=torch.float16,
@@ -128,7 +121,6 @@ def initialize_model():
 
 def format_text(text):
     """Format text with proper spacing and tag highlighting"""
-    # Add newlines around tags
     tag_patterns = [
         (r'<Thinking>', '\n<Thinking>\n'),
         (r'</Thinking>', '\n</Thinking>\n'),
@@ -144,15 +136,24 @@ def format_text(text):
     for pattern, replacement in tag_patterns:
         formatted = re.sub(pattern, replacement, formatted)
 
-    # Remove extra blank lines
     formatted = '\n'.join(line for line in formatted.split('\n') if line.strip())
 
     return formatted
 
+def format_chat_history(history):
+    """Format chat history for display in text area"""
+    formatted = []
+    for user_msg, assistant_msg in history:
+        formatted.append(f"User: {user_msg}")
+        if assistant_msg:
+            formatted.append(f"Assistant: {assistant_msg}")
+    return "\n\n".join(formatted)
+
 @spaces.GPU()
-def stream_chat(
+def chat_response(
     message: str,
     history: list,
+    chat_display: str,
     system_prompt: str,
     temperature: float = 0.2,
     max_new_tokens: int = 8192,
@@ -160,30 +161,25 @@ def stream_chat(
     top_k: int = 20,
     penalty: float = 1.2,
 ):
-    """Generate streaming chat responses with proper tag handling"""
-    # Format conversation context
+    """Generate chat responses with proper tag handling"""
     conversation = [
         {"role": "system", "content": system_prompt}
     ]
 
-    # Add conversation history
     for prompt, answer in history:
         conversation.extend([
             {"role": "user", "content": prompt},
             {"role": "assistant", "content": answer}
         ])
 
-    # Add current message
     conversation.append({"role": "user", "content": message})
 
-    # Prepare input for model
     input_ids = tokenizer.apply_chat_template(
         conversation,
         add_generation_prompt=True,
         return_tensors="pt"
     ).to(model.device)
 
-    # Configure streamer
     streamer = TextIteratorStreamer(
         tokenizer,
         timeout=60.0,
@@ -191,7 +187,6 @@ def stream_chat(
         skip_special_tokens=True
     )
 
-    # Set generation parameters
     generate_kwargs = dict(
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
@@ -203,7 +198,6 @@ def stream_chat(
         streamer=streamer,
     )
 
-    # Generate and stream response
     buffer = ""
     current_line = ""
 
@@ -211,6 +205,8 @@ def stream_chat(
     thread = Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
+    history = history + [[message, ""]]
+
     for new_text in streamer:
         buffer += new_text
         current_line += new_text
@@ -219,35 +215,23 @@ def stream_chat(
             lines = current_line.split('\n')
             current_line = lines[-1]
             formatted_buffer = format_text(buffer)
-            yield formatted_buffer
+            history[-1][1] = formatted_buffer
+            chat_display = format_chat_history(history)
+            yield history, chat_display
         else:
-            yield buffer
+            history[-1][1] = buffer
+            chat_display = format_chat_history(history)
+            yield history, chat_display
 
-def create_examples():
-    """Create example queries that demonstrate the system's capabilities"""
-    return [
-        ["Explain how neural networks learn through backpropagation."],
-        ["What are the key differences between classical and quantum computing?"],
-        ["Analyze the environmental impact of renewable energy sources."],
-        ["How does the human memory system work?"],
-        ["Explain the concept of ethical AI and its importance."]
-    ]
+def process_example(example: str) -> tuple:
+    """Process example query and return empty history and updated display"""
+    return [], f"User: {example}\n\n"
 
 def main():
     """Main function to set up and launch the Gradio interface"""
-    # Initialize model and tokenizer
     global model, tokenizer
     model, tokenizer = initialize_model()
 
-    # Create chatbot interface
-    chatbot = gr.Chatbot(
-        height=600,
-        placeholder=PLACEHOLDER,
-        bubble_full_width=False,
-        show_copy_button=True
-    )
-
-    # Create interface
     with gr.Blocks(css=CSS, theme="soft") as demo:
         gr.HTML(TITLE)
         gr.DuplicateButton(
@@ -255,66 +239,119 @@ def main():
             elem_classes="duplicate-button"
         )
 
-        gr.ChatInterface(
-            fn=stream_chat,
-            chatbot=chatbot,
-            fill_height=True,
-            additional_inputs_accordion=gr.Accordion(
-                label="⚙️ Advanced Settings",
-                open=False,
-                render=False
-            ),
-            additional_inputs=[
-                gr.Textbox(
-                    value=DEFAULT_SYSTEM_PROMPT,
-                    label="System Prompt",
-                    lines=5,
-                    render=False,
-                ),
-                gr.Slider(
-                    minimum=0,
-                    maximum=1,
-                    step=0.1,
-                    value=0.2,
-                    label="Temperature",
-                    render=False,
-                ),
-                gr.Slider(
-                    minimum=128,
-                    maximum=32000,
-                    step=128,
-                    value=8192,
-                    label="Max Tokens",
-                    render=False,
-                ),
-                gr.Slider(
-                    minimum=0.1,
-                    maximum=1.0,
-                    step=0.1,
-                    value=1.0,
-                    label="Top-p",
-                    render=False,
-                ),
-                gr.Slider(
-                    minimum=1,
-                    maximum=100,
-                    step=1,
-                    value=20,
-                    label="Top-k",
-                    render=False,
-                ),
-                gr.Slider(
-                    minimum=1.0,
-                    maximum=2.0,
-                    step=0.1,
-                    value=1.2,
-                    label="Repetition Penalty",
-                    render=False,
-                ),
+        with gr.Row():
+            with gr.Column():
+                chat_history = gr.State([])
+                chat_display = gr.TextArea(
+                    value="",
+                    label="Chat History",
+                    interactive=False,
+                    elem_classes=["chat-area"],
+                )
+
+                message = gr.TextArea(
+                    placeholder=PLACEHOLDER,
+                    label="Your message",
+                    lines=3
+                )
+
+                with gr.Row():
+                    submit = gr.Button("Send")
+                    clear = gr.Button("Clear")
+
+                with gr.Accordion("⚙️ Advanced Settings", open=False):
+                    system_prompt = gr.TextArea(
+                        value=DEFAULT_SYSTEM_PROMPT,
+                        label="System Prompt",
+                        lines=5,
+                    )
+                    temperature = gr.Slider(
+                        minimum=0,
+                        maximum=1,
+                        step=0.1,
+                        value=0.2,
+                        label="Temperature",
+                    )
+                    max_tokens = gr.Slider(
+                        minimum=128,
+                        maximum=32000,
+                        step=128,
+                        value=8192,
+                        label="Max Tokens",
+                    )
+                    top_p = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        step=0.1,
+                        value=1.0,
+                        label="Top-p",
+                    )
+                    top_k = gr.Slider(
+                        minimum=1,
+                        maximum=100,
+                        step=1,
+                        value=20,
+                        label="Top-k",
+                    )
+                    penalty = gr.Slider(
+                        minimum=1.0,
+                        maximum=2.0,
+                        step=0.1,
+                        value=1.2,
+                        label="Repetition Penalty",
+                    )
+
+                examples = gr.Examples(
+                    examples=create_examples(),
+                    inputs=[message],
+                    outputs=[chat_history, chat_display],
+                    fn=process_example,
+                    cache_examples=False,
+                )
+
+        # Set up event handlers
+        submit_click = submit.click(
+            chat_response,
+            inputs=[
+                message,
+                chat_history,
+                chat_display,
+                system_prompt,
+                temperature,
+                max_tokens,
+                top_p,
+                top_k,
+                penalty,
             ],
-            examples=create_examples(),
-            cache_examples=False,
+            outputs=[chat_history, chat_display],
+            show_progress=True,
         )
+
+        message.submit(
+            chat_response,
+            inputs=[
+                message,
+                chat_history,
+                chat_display,
+                system_prompt,
+                temperature,
+                max_tokens,
+                top_p,
+                top_k,
+                penalty,
+            ],
+            outputs=[chat_history, chat_display],
+            show_progress=True,
+        )
+
+        clear.click(
+            lambda: ([], ""),
+            outputs=[chat_history, chat_display],
+            show_progress=True,
+        )
+
+        submit_click.then(lambda: "", outputs=message)
+        message.submit(lambda: "", outputs=message)
 
     return demo
 
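Review note: the new gr.Examples block still references create_examples(), but this commit deletes that function, so building the interface would presumably raise a NameError unless the helper is restored or the call is replaced with an inline list. A minimal compatible helper, reconstructed from the code removed above, would be:

# Reconstructed from the create_examples() deleted in this commit;
# gr.Examples expects a list of [input] rows matching inputs=[message].
def create_examples():
    """Create example queries that demonstrate the system's capabilities"""
    return [
        ["Explain how neural networks learn through backpropagation."],
        ["What are the key differences between classical and quantum computing?"],
        ["Analyze the environmental impact of renewable energy sources."],
        ["How does the human memory system work?"],
        ["Explain the concept of ethical AI and its importance."]
    ]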