abrakjamson committed on
Commit
d646867
·
1 Parent(s): 453c7fc

Adding control info to conversation display

Browse files
Files changed (1) hide show
  1. app.py +48 -15
app.py CHANGED
@@ -31,7 +31,7 @@ model = ControlModel(model, list(range(-5, -18, -1)))
31
  default_generation_settings = {
32
  "pad_token_id": tokenizer.eos_token_id, # Silence warning
33
  "do_sample": False, # Deterministic output
34
- "max_new_tokens": 256,
35
  "repetition_penalty": 1.1, # Reduce repetition
36
  }
37
 
@@ -69,6 +69,7 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
69
  model.reset()
70
 
71
  # Apply selected control vectors with their corresponding weights
 
72
  for i in range(len(control_vector_files)):
73
  if checkboxes[i]:
74
  cv_file = control_vector_files[i]
@@ -76,25 +77,30 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
76
  try:
77
  control_vector = ControlVector.import_gguf(cv_file)
78
  model.set_control(control_vector, weight)
 
79
  except Exception as e:
80
  print(f"Failed to set control vector {cv_file}: {e}")
81
 
82
  formatted_prompt = ""
83
 
84
-
85
  # Mistral expects the history to be wrapped in <s>history</s>
86
  if len(history) > 0:
87
  formatted_prompt += "<s>"
88
 
89
  # Append the system prompt if provided
90
  if system_prompt.strip():
91
- formatted_prompt += f"[INST] {system_prompt} [/INST] "
92
 
93
  # Construct the formatted prompt based on history
94
  if len(history) > 0:
95
  for turn in history:
96
- user_msg, asst_msg = turn
97
- formatted_prompt += f"{user_tag} {user_msg} {asst_tag} {asst_msg}"
 
 
 
 
98
 
99
  if len(history) > 0:
100
  formatted_prompt += "</s>"
@@ -127,9 +133,19 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
127
  assistant_response = get_assistant_response(response)
128
 
129
  # Update conversation history
130
- history.append((user_message, assistant_response))
 
 
 
 
131
  return history
132
 
 
 
 
 
 
 
133
  # Function to reset the conversation history
134
  def reset_chat():
135
  # returns a blank user input text and a blank conversation history
@@ -137,7 +153,8 @@ def reset_chat():
137
 
138
  # Build the Gradio interface
139
  with gr.Blocks() as demo:
140
- gr.Markdown("# 🧠 Mistral v3 Language Model Interface")
 
141
 
142
  with gr.Row():
143
  # Left Column: Settings and Control Vectors
@@ -148,14 +165,16 @@ with gr.Blocks() as demo:
148
  system_prompt = gr.Textbox(
149
  label="System Prompt",
150
  lines=2,
151
- placeholder="Respond tot he user concisely"
152
  )
153
 
154
- gr.Markdown("### 📊 Control Vectors")
 
155
 
156
  # Create checkboxes and sliders for each control vector
157
  control_checks = []
158
  control_sliders = []
 
159
  for cv_file in control_vector_files:
160
  with gr.Row():
161
  # Checkbox to select the control vector
@@ -180,11 +199,11 @@ with gr.Blocks() as demo:
180
  outputs=slider
181
  )
182
 
183
- # Advanced Settings Section (collapsed by default)
184
  with gr.Accordion("🔧 Advanced Settings", open=False):
185
  with gr.Row():
186
  max_new_tokens = gr.Number(
187
- label="Max New Tokens",
188
  value=default_generation_settings["max_new_tokens"],
189
  precision=0,
190
  step=10,
@@ -193,7 +212,7 @@ with gr.Blocks() as demo:
193
  label="Repetition Penalty",
194
  value=default_generation_settings["repetition_penalty"],
195
  precision=2,
196
- step=0.1,
197
  )
198
 
199
  # Right Column: Chat Interface
@@ -201,19 +220,21 @@ with gr.Blocks() as demo:
201
  gr.Markdown("### 🗨️ Conversation")
202
 
203
  # Chatbot to display conversation
204
- chatbot = gr.Chatbot(label="Conversation")
205
 
206
  # User Message Input
207
  user_input = gr.Textbox(
208
- label="Your Message",
209
  lines=2,
210
- placeholder="Type your message here..."
211
  )
212
 
213
  with gr.Row():
214
  # Submit and New Chat buttons
215
  submit_button = gr.Button("💬 Submit")
 
216
  new_chat_button = gr.Button("🆕 New Chat")
 
217
 
218
  inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders
219
 
@@ -223,6 +244,18 @@ with gr.Blocks() as demo:
223
  inputs=inputs_list,
224
  outputs=[chatbot]
225
  )
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  new_chat_button.click(
228
  reset_chat,
 
31
  default_generation_settings = {
32
  "pad_token_id": tokenizer.eos_token_id, # Silence warning
33
  "do_sample": False, # Deterministic output
34
+ "max_new_tokens": 384,
35
  "repetition_penalty": 1.1, # Reduce repetition
36
  }
37
 
 
69
  model.reset()
70
 
71
  # Apply selected control vectors with their corresponding weights
72
+ assistant_message_title = ""
73
  for i in range(len(control_vector_files)):
74
  if checkboxes[i]:
75
  cv_file = control_vector_files[i]
 
77
  try:
78
  control_vector = ControlVector.import_gguf(cv_file)
79
  model.set_control(control_vector, weight)
80
+ assistant_message_title += f"{cv_file}: {weight};"
81
  except Exception as e:
82
  print(f"Failed to set control vector {cv_file}: {e}")
83
 
84
  formatted_prompt = ""
85
 
86
+ # <s>[INST] user message[/INST] assistant message</s>[INST] new user message[/INST]
87
  # Mistral expects the history to be wrapped in <s>history</s>
88
  if len(history) > 0:
89
  formatted_prompt += "<s>"
90
 
91
  # Append the system prompt if provided
92
  if system_prompt.strip():
93
+ formatted_prompt += f"{user_tag} {system_prompt}{asst_tag} "
94
 
95
  # Construct the formatted prompt based on history
96
  if len(history) > 0:
97
  for turn in history:
98
+ # TODO use history[0].role and history[0].content to replace this
99
+ # consider what tags to use
100
+ if turn.role == 'user':
101
+ formatted_prompt += f"{user_tag} {turn.content}{asst_tag}"
102
+ elif turn.role == 'assistant':
103
+ formatted_prompt += f" {turn.content}"
104
 
105
  if len(history) > 0:
106
  formatted_prompt += "</s>"
 
133
  assistant_response = get_assistant_response(response)
134
 
135
  # Update conversation history
136
+ assistant_response = get_assistant_response(response)
137
+ assistant_response_display = f"*{assistant_message_title}*\n\n{assistant_response}"
138
+
139
+ # Update conversation history
140
+ history.append((user_message, assistant_response_display))
141
  return history
142
 
143
+ def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args):
144
+ # Remove last user input and assistant response from history, then call generate_response()
145
+ if history:
146
+ history = history[0:-1]
147
+ return generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args)
148
+
149
  # Function to reset the conversation history
150
  def reset_chat():
151
  # returns a blank user input text and a blank conversation history
 
153
 
154
  # Build the Gradio interface
155
  with gr.Blocks() as demo:
156
+ gr.Markdown("# 🧠🧑‍🔬 LLM Brain Control")
157
+ gr.Markdown("Usage demo: (link)")
158
 
159
  with gr.Row():
160
  # Left Column: Settings and Control Vectors
 
165
  system_prompt = gr.Textbox(
166
  label="System Prompt",
167
  lines=2,
168
+ value="Respond to the user concisely"
169
  )
170
 
171
+ gr.Markdown("### Control Vectors")
172
+ gr.Markdown("Select how you want to control the LLM. Values greater than +/- 1.5 may overload it.")
173
 
174
  # Create checkboxes and sliders for each control vector
175
  control_checks = []
176
  control_sliders = []
177
+
178
  for cv_file in control_vector_files:
179
  with gr.Row():
180
  # Checkbox to select the control vector
 
199
  outputs=slider
200
  )
201
 
202
+ # Advanced Settings Section (collapsed by default)
203
  with gr.Accordion("🔧 Advanced Settings", open=False):
204
  with gr.Row():
205
  max_new_tokens = gr.Number(
206
+ label="Max Response Length (in tokens)",
207
  value=default_generation_settings["max_new_tokens"],
208
  precision=0,
209
  step=10,
 
212
  label="Repetition Penalty",
213
  value=default_generation_settings["repetition_penalty"],
214
  precision=2,
215
+ step=0.1
216
  )
217
 
218
  # Right Column: Chat Interface
 
220
  gr.Markdown("### 🗨️ Conversation")
221
 
222
  # Chatbot to display conversation
223
+ chatbot = gr.Chatbot(label="Conversation", type='tuples')
224
 
225
  # User Message Input
226
  user_input = gr.Textbox(
227
+ label="Your Message (Shift+Enter submits)",
228
  lines=2,
229
+ placeholder="I was out partying too late last night, and I'm going to be late for work. What should I tell my boss?"
230
  )
231
 
232
  with gr.Row():
233
  # Submit and New Chat buttons
234
  submit_button = gr.Button("💬 Submit")
235
+ retry_button = gr.Button("🔃 Retry last turn")
236
  new_chat_button = gr.Button("🆕 New Chat")
237
+
238
 
239
  inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders
240
 
 
244
  inputs=inputs_list,
245
  outputs=[chatbot]
246
  )
247
+
248
+ user_input.submit(
249
+ generate_response,
250
+ inputs=inputs_list,
251
+ outputs=[chatbot]
252
+ )
253
+
254
+ retry_button.click(
255
+ generate_response_with_retry,
256
+ inputs=inputs_list,
257
+ outputs=[chatbot]
258
+ )
259
 
260
  new_chat_button.click(
261
  reset_chat,