HMC83 committed
Commit 6bcad72 · verified · 1 Parent(s): 3d4de52

Update app.py

better handling of messier outputs

Files changed (1)
  1. app.py +92 -45
app.py CHANGED

@@ -180,16 +180,55 @@ ALL_AUTHORITIES_FOR_SPIN = list(set([item["authority"] for item in FOI_COMBINATIONS]))
 ALL_KEYWORDS_FOR_SPIN = list(set(kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(',')))
 
 
+# --- Helper Function for Cleaning and Validation ---
+def clean_and_validate_output(raw_text: str) -> tuple[str, bool]:
+    """
+    Cleans the model's output by keeping only the first complete request.
+
+    It validates that the output contains essential markers ("Dear" and "[Your Name]").
+    If it detects that the model has started generating a second request, it truncates
+    the string after the first "[Your Name]".
+
+    Args:
+        raw_text: The raw string output from the language model.
+
+    Returns:
+        A tuple containing:
+        - The cleaned text.
+        - A boolean flag: True if the output is valid, False if it is malformed.
+    """
+    end_marker = "[Your Name]"
+    start_marker = "Dear"
+
+    # Validate: a valid request must contain the end marker.
+    if end_marker not in raw_text:
+        return raw_text, False  # Malformed, signal for regeneration.
+
+    # Find the end of the first complete request.
+    first_end_pos = raw_text.find(end_marker)
+    end_of_first_request_index = first_end_pos + len(end_marker)
+
+    # Check if a second request has started after the first one ended.
+    start_of_second_request_pos = raw_text.find(start_marker, end_of_first_request_index)
+
+    if start_of_second_request_pos != -1:
+        # If a second request is found, truncate to keep only the first one.
+        cleaned_text = raw_text[:end_of_first_request_index]
+        return cleaned_text, True
+    else:
+        # No second request found, the output is valid.
+        return raw_text, True
+
+
 # --- Backend Function for Local Inference ---
 @spaces.GPU
 def generate_request_local(authority, kw1, kw2, kw3):
-    """Generates a request using the locally loaded transformer model on a dynamically allocated GPU."""
+    """Generates a request using the locally loaded transformer model, with validation and retry logic."""
     if not model or not tokenizer:
         return "Error: Model is not loaded. Please check the Space logs for details."
-
+
     keywords = [kw for kw in [kw1, kw2, kw3] if kw]
     keyword_string = ", ".join(keywords)
-
     prompt = (
         "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
         f"""Generate a formal Freedom of Information request to {authority} using these keywords: {keyword_string}
@@ -203,48 +242,56 @@ def generate_request_local(authority, kw1, kw2, kw3):
     on the part of the FOI officer to answer the request. No clarification should be needed.
     7. Do not ask for all policies, or all information
     8. End with "Yours Faithfully, [Your Name]" exactly
-
     Make the requests specific, professional, and relevant to what this public authority would reasonably hold.
     Use accessible language, avoiding terms that are overly legalistic or technical and UK English. Be clear and concise"""
     )
-
-    try:
-        # Tokenize the input prompt
-        inputs = tokenizer(prompt, return_tensors="pt")
-
-        # Move to the model's device (should be cuda when GPU is allocated)
-        inputs = inputs.to(model.device)
-
-        # Set generation parameters to match notebook for better performance
-        generation_params = {
-            "max_new_tokens": 500,
-            "temperature": 0.3,
-            "top_p": 0.95,
-            "top_k": 50,
-            "repetition_penalty": 1.05,
-            "do_sample": True,
-            "pad_token_id": tokenizer.eos_token_id
-        }
-
-        # Generate text sequences
-        output_sequences = model.generate(**inputs, **generation_params)
-
-        # Decode the generated text, skipping special tokens and the original prompt
-        generated_text = tokenizer.decode(
-            output_sequences[0][len(inputs["input_ids"][0]):],
-            skip_special_tokens=True
-        ).strip()
-
-        # Add this check to remove artifact
-        if generated_text.startswith('.\n'):
-            generated_text = generated_text[2:]
-        return generated_text
-
-    except Exception as e:
-        print(f"Error during generation: {e}")
-        print(f"Model device: {model.device if model else 'No model'}")
-        print(f"Inputs type: {type(inputs) if 'inputs' in locals() else 'Not created'}")
-        return f"An error occurred during text generation: {e}"
+
+    max_retries = 3
+    for attempt in range(max_retries):
+        try:
+            # Tokenize the input prompt
+            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+            # Set generation parameters
+            generation_params = {
+                "max_new_tokens": 500,
+                "temperature": 0.3,
+                "top_p": 0.95,
+                "top_k": 50,
+                "repetition_penalty": 1.05,
+                "do_sample": True,
+                "pad_token_id": tokenizer.eos_token_id
+            }
+
+            # Generate text sequences
+            output_sequences = model.generate(**inputs, **generation_params)
+
+            # Decode the generated text
+            generated_text = tokenizer.decode(
+                output_sequences[0][len(inputs["input_ids"][0]):],
+                skip_special_tokens=True
+            ).strip()
+
+            # Remove artifact if present
+            if generated_text.startswith('.\n'):
+                generated_text = generated_text[2:]
+
+            # **NEW**: Clean and validate the output
+            cleaned_text, is_valid = clean_and_validate_output(generated_text)
+
+            if is_valid:
+                return cleaned_text  # Success! Return the valid, cleaned text.
+            else:
+                print(f"Attempt {attempt + 1}/{max_retries}: Malformed output detected. Retrying...")
+
+        except Exception as e:
+            print(f"Error during generation attempt {attempt + 1}/{max_retries}: {e}")
+            if attempt == max_retries - 1:
+                return f"An error occurred during text generation: {e}"
+
+    # If the loop finishes, all retries have failed
+    return "Failed to generate a valid request after multiple attempts. Please try again."
+
 
 # --- Gradio UI and Spinning Logic ---
 def spin_the_reels():
@@ -263,7 +310,7 @@ def spin_the_reels():
         "Spinning..."
     )
     time.sleep(spin_interval)
-
+
     # 2. Select the final fixed combination
     final_combination = random.choice(FOI_COMBINATIONS)
     final_authority = final_combination["authority"]
@@ -278,7 +325,7 @@ def spin_the_reels():
         final_authority, kw1, kw2, kw3,
         f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
     )
-
+
     # 3. Call the local model and yield the final result
    generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
     yield (
@@ -334,7 +381,7 @@ with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
         reel2 = gr.Textbox(label="Keyword 1", interactive=False, elem_id="reel-2", scale=1)
         reel3 = gr.Textbox(label="Keyword 2", interactive=False, elem_id="reel-3", scale=1)
         reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
-
+
     pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
 
     output_request = gr.Textbox(
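
For context, a minimal sketch of what the new cleaning and retry logic does with the two "messier output" cases this commit targets. The input strings are hypothetical and the import assumes the helper is reachable as app.clean_and_validate_output; this is an illustration, not part of the commit:

from app import clean_and_validate_output  # helper added in this commit

# Case 1: the model finishes one request, then starts a second one.
messy = (
    "Dear Sir or Madam,\n"
    "Please release the pothole repair figures for 2023.\n"
    "Yours Faithfully, [Your Name]\n\n"
    "Dear Sir or Madam,\nPlease release..."
)
text, ok = clean_and_validate_output(messy)
assert ok and text.endswith("[Your Name]")  # truncated after the first sign-off

# Case 2: generation was cut off before the "[Your Name]" sign-off.
cut_off = "Dear Sir or Madam,\nPlease release the figures on"
text, ok = clean_and_validate_output(cut_off)
assert not ok  # flagged malformed; generate_request_local will retry (max_retries = 3)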