abrakjamson committed on
Commit
9acb8e6
·
1 Parent(s): 129904f

advanced settings, bug fixes

Browse files
Files changed (1) hide show
  1. app.py +39 -12
app.py CHANGED
@@ -27,7 +27,7 @@ model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
27
  model = ControlModel(model, list(range(-5, -18, -1)))
28
 
29
  # Generation settings
30
- generation_settings = {
31
  "pad_token_id": tokenizer.eos_token_id, # Silence warning
32
  "do_sample": False, # Deterministic output
33
  "max_new_tokens": 256,
@@ -48,14 +48,19 @@ def toggle_slider(checked):
48
  return gr.update(visible=checked)
49
 
50
  # Function to generate the model's response
51
- def generate_response(system_prompt, user_message, *args, history):
 
 
 
52
  # Separate checkboxes and sliders based on type
53
- print(f"Generating response to {user_message}")
54
- checkboxes = [item for item in args if isinstance(item, bool)]
55
- sliders = [item for item in args if isinstance(item, (int, float))]
 
 
56
 
57
  if len(checkboxes) != len(control_vector_files) or len(sliders) != len(control_vector_files):
58
- return history # Return current history if there's a mismatch
59
 
60
  # Reset any previous control vectors
61
  model.reset()
@@ -66,7 +71,6 @@ def generate_response(system_prompt, user_message, *args, history):
66
  cv_file = control_vector_files[i]
67
  weight = sliders[i]
68
  try:
69
- print(f"Setting {cv_file} to {weight}")
70
  control_vector = ControlVector.import_gguf(cv_file)
71
  model.set_control(control_vector, weight)
72
  except Exception as e:
@@ -91,8 +95,15 @@ def generate_response(system_prompt, user_message, *args, history):
91
  # Tokenize the input
92
  input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
93
 
 
 
 
 
 
 
 
94
  # Generate the response
95
- output_ids = model.generate(**input_ids, **generation_settings)
96
  response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
97
 
98
  # Clean up the response by removing any trailing tags
@@ -101,7 +112,7 @@ def generate_response(system_prompt, user_message, *args, history):
101
 
102
  # Update conversation history
103
  history.append((user_message, response))
104
- return history
105
 
106
  # Function to reset the conversation history
107
  def reset_chat():
@@ -120,7 +131,7 @@ with gr.Blocks() as demo:
120
  system_prompt = gr.Textbox(
121
  label="System Prompt",
122
  lines=2,
123
- placeholder="Enter system-level instructions here..."
124
  )
125
 
126
  gr.Markdown("### 📊 Control Vectors")
@@ -152,6 +163,22 @@ with gr.Blocks() as demo:
152
  outputs=slider
153
  )
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  # Right Column: Chat Interface
156
  with gr.Column(scale=2):
157
  gr.Markdown("### 🗨️ Conversation")
@@ -172,7 +199,7 @@ with gr.Blocks() as demo:
172
  new_chat_button = gr.Button("🆕 New Chat")
173
 
174
  # State to keep track of conversation history
175
- state = gr.State([])
176
 
177
  # Define button actions
178
  submit_button.click(
@@ -180,7 +207,7 @@ with gr.Blocks() as demo:
180
  inputs=[system_prompt, user_input] + control_checks + control_sliders + [state],
181
  outputs=[chatbot, state]
182
  )
183
-
184
  new_chat_button.click(
185
  reset_chat,
186
  inputs=[],
 
27
  model = ControlModel(model, list(range(-5, -18, -1)))
28
 
29
  # Generation settings
30
+ default_generation_settings = {
31
  "pad_token_id": tokenizer.eos_token_id, # Silence warning
32
  "do_sample": False, # Deterministic output
33
  "max_new_tokens": 256,
 
48
  return gr.update(visible=checked)
49
 
50
  # Function to generate the model's response
51
+ def generate_response(system_prompt, user_message, *args, history=None, max_new_tokens=256, repetition_penalty=1.1):
52
+ checkboxes = []
53
+ sliders = []
54
+
55
  # Separate checkboxes and sliders based on type
56
+ for item in args:
57
+ if isinstance(item, bool):
58
+ checkboxes.append(item)
59
+ elif isinstance(item, (int, float)):
60
+ sliders.append(item)
61
 
62
  if len(checkboxes) != len(control_vector_files) or len(sliders) != len(control_vector_files):
63
+ return history if history else [], history if history else []
64
 
65
  # Reset any previous control vectors
66
  model.reset()
 
71
  cv_file = control_vector_files[i]
72
  weight = sliders[i]
73
  try:
 
74
  control_vector = ControlVector.import_gguf(cv_file)
75
  model.set_control(control_vector, weight)
76
  except Exception as e:
 
95
  # Tokenize the input
96
  input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
97
 
98
+ generation_settings = {
99
+ "pad_token_id": tokenizer.eos_token_id,
100
+ "do_sample": default_generation_settings["do_sample"],
101
+ "max_new_tokens": int(max_new_tokens),
102
+ "repetition_penalty": repetition_penalty,
103
+ }
104
+
105
  # Generate the response
106
+ output_ids = model.generate(**input_ids, **generation_settings)
107
  response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
108
 
109
  # Clean up the response by removing any trailing tags
 
112
 
113
  # Update conversation history
114
  history.append((user_message, response))
115
+ return history, history
116
 
117
  # Function to reset the conversation history
118
  def reset_chat():
 
131
  system_prompt = gr.Textbox(
132
  label="System Prompt",
133
  lines=2,
134
+ placeholder="Respond to the user concisely"
135
  )
136
 
137
  gr.Markdown("### 📊 Control Vectors")
 
163
  outputs=slider
164
  )
165
 
166
+ # Advanced Settings Section (collapsed by default)
167
+ with gr.Accordion("🔧 Advanced Settings", open=False):
168
+ with gr.Row():
169
+ max_new_tokens = gr.Number(
170
+ label="Max New Tokens",
171
+ value=default_generation_settings["max_new_tokens"],
172
+ precision=0,
173
+ step=10,
174
+ )
175
+ repetition_penalty = gr.Number(
176
+ label="Repetition Penalty",
177
+ value=default_generation_settings["repetition_penalty"],
178
+ precision=2,
179
+ step=0.1,
180
+ )
181
+
182
  # Right Column: Chat Interface
183
  with gr.Column(scale=2):
184
  gr.Markdown("### 🗨️ Conversation")
 
199
  new_chat_button = gr.Button("🆕 New Chat")
200
 
201
  # State to keep track of conversation history
202
+ state = gr.State()
203
 
204
  # Define button actions
205
  submit_button.click(
 
207
  inputs=[system_prompt, user_input] + control_checks + control_sliders + [state],
208
  outputs=[chatbot, state]
209
  )
210
+
211
  new_chat_button.click(
212
  reset_chat,
213
  inputs=[],