Update app.py
better handling of messier outputs
app.py
CHANGED
@@ -180,16 +180,55 @@ ALL_AUTHORITIES_FOR_SPIN = list(set([item["authority"] for item in FOI_COMBINATI
 ALL_KEYWORDS_FOR_SPIN = list(set(kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(',')))
 
 
+# --- Helper Function for Cleaning and Validation ---
+def clean_and_validate_output(raw_text: str) -> tuple[str, bool]:
+    """
+    Cleans the model's output by keeping only the first complete request.
+
+    It validates that the output contains essential markers ("Dear" and "[Your Name]").
+    If it detects that the model has started generating a second request, it truncates
+    the string after the first "[Your Name]".
+
+    Args:
+        raw_text: The raw string output from the language model.
+
+    Returns:
+        A tuple containing:
+        - The cleaned text.
+        - A boolean flag: True if the output is valid, False if it is malformed.
+    """
+    end_marker = "[Your Name]"
+    start_marker = "Dear"
+
+    # Validate: A valid request must contain the end marker.
+    if end_marker not in raw_text:
+        return raw_text, False  # Malformed, signal for regeneration.
+
+    # Find the end of the first complete request.
+    first_end_pos = raw_text.find(end_marker)
+    end_of_first_request_index = first_end_pos + len(end_marker)
+
+    # Check if a second request has started after the first one ended.
+    start_of_second_request_pos = raw_text.find(start_marker, end_of_first_request_index)
+
+    if start_of_second_request_pos != -1:
+        # If a second request is found, truncate to keep only the first one.
+        cleaned_text = raw_text[:end_of_first_request_index]
+        return cleaned_text, True
+    else:
+        # No second request found, the output is valid.
+        return raw_text, True
+
+
 # --- Backend Function for Local Inference ---
 @spaces.GPU
 def generate_request_local(authority, kw1, kw2, kw3):
-    """Generates a request using the locally loaded transformer model
+    """Generates a request using the locally loaded transformer model, with validation and retry logic."""
     if not model or not tokenizer:
         return "Error: Model is not loaded. Please check the Space logs for details."
-
+
     keywords = [kw for kw in [kw1, kw2, kw3] if kw]
     keyword_string = ", ".join(keywords)
-
     prompt = (
         "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
         f"""Generate a formal Freedom of Information request to {authority} using these keywords: {keyword_string}
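The new helper can be exercised on its own. Below is a minimal sketch of its intended behaviour; the sample strings are hypothetical, and it assumes clean_and_validate_output is available for import (copying the function into a test file avoids triggering app.py's model loading on import):

from app import clean_and_validate_output  # assumption: app.py is importable without side effects

# Hypothetical run-on output: the model produced two requests back to back.
doubled = (
    "Dear FOI Officer,\n\n...first request...\n\nYours Faithfully, [Your Name]\n\n"
    "Dear FOI Officer,\n\n...second request...\n\nYours Faithfully, [Your Name]"
)
text, ok = clean_and_validate_output(doubled)
assert ok is True
assert text.endswith("[Your Name]")          # truncated after the first request
assert text.count("Dear FOI Officer") == 1   # the second request was dropped

# Output missing the end marker is flagged as malformed so the caller retries.
text, ok = clean_and_validate_output("Dear FOI Officer, cut off mid-")
assert ok is False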
@@ -203,48 +242,56 @@ def generate_request_local(authority, kw1, kw2, kw3):
 on the part of the FOI officer to answer the request. No clarification should be needed.
 7. Do not ask for all policies, or all information
 8. End with "Yours Faithfully, [Your Name]" exactly
-
 Make the requests specific, professional, and relevant to what this public authority would reasonably hold.
 Use accessible language, avoiding terms that are overly legalistic or technical and UK English. Be clear and concise"""
     )
+
+    max_retries = 3
+    for attempt in range(max_retries):
+        try:
+            # Tokenize the input prompt
+            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+            # Set generation parameters
+            generation_params = {
+                "max_new_tokens": 500,
+                "temperature": 0.3,
+                "top_p": 0.95,
+                "top_k": 50,
+                "repetition_penalty": 1.05,
+                "do_sample": True,
+                "pad_token_id": tokenizer.eos_token_id
+            }
+
+            # Generate text sequences
+            output_sequences = model.generate(**inputs, **generation_params)
+
+            # Decode the generated text
+            generated_text = tokenizer.decode(
+                output_sequences[0][len(inputs["input_ids"][0]):],
+                skip_special_tokens=True
+            ).strip()
+
+            # Remove artifact if present
+            if generated_text.startswith('.\n'):
+                generated_text = generated_text[2:]
+
+            # **NEW**: Clean and validate the output
+            cleaned_text, is_valid = clean_and_validate_output(generated_text)
+
+            if is_valid:
+                return cleaned_text  # Success! Return the valid, cleaned text.
+            else:
+                print(f"Attempt {attempt + 1}/{max_retries}: Malformed output detected. Retrying...")
+
+        except Exception as e:
+            print(f"Error during generation attempt {attempt + 1}/{max_retries}: {e}")
+            if attempt == max_retries - 1:
+                return f"An error occurred during text generation: {e}"
+
+    # If the loop finishes, all retries have failed
+    return "Failed to generate a valid request after multiple attempts. Please try again."
+
 
 # --- Gradio UI and Spinning Logic ---
 def spin_the_reels():
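Two details of the new generation block are worth calling out. First, model.generate returns the prompt tokens followed by the completion, so the decode step slices off len(inputs["input_ids"][0]) tokens to keep only the newly generated text. A self-contained sketch of that pattern follows; "gpt2" is a stand-in model for illustration, not the model this Space actually loads:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # stand-in model, assumption
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Dear FOI Officer,", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=40,
        do_sample=True,
        temperature=0.3,
        pad_token_id=tokenizer.eos_token_id,
    )

# out[0] holds prompt ids + generated ids; slice the prompt off before decoding.
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:],
                              skip_special_tokens=True).strip()
print(completion)

Second, the retry loop regenerates rather than repairs: an output with no "[Your Name]" marker is discarded and the model is sampled again (do_sample=True makes each attempt different), with max_retries = 3 keeping worst-case latency bounded.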
@@ -263,7 +310,7 @@ def spin_the_reels():
             "Spinning..."
         )
         time.sleep(spin_interval)
-
+
     # 2. Select the final fixed combination
     final_combination = random.choice(FOI_COMBINATIONS)
     final_authority = final_combination["authority"]
@@ -278,7 +325,7 @@
         final_authority, kw1, kw2, kw3,
         f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
     )
-
+
     # 3. Call the local model and yield the final result
     generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
     yield (
@@ -334,7 +381,7 @@ with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
         reel2 = gr.Textbox(label="Keyword 1", interactive=False, elem_id="reel-2", scale=1)
         reel3 = gr.Textbox(label="Keyword 2", interactive=False, elem_id="reel-3", scale=1)
         reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
-
+
     pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
 
     output_request = gr.Textbox(
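Because spin_the_reels is a generator, each yield streams an intermediate state to the UI: the spinning reels, then the "Generating request..." notice, then the final text. A minimal, self-contained sketch of that streaming pattern; the toy layout below is illustrative, not the Space's actual interface:

import time
import gradio as gr

def spin():
    # Gradio streams each yielded value to the bound output component.
    for status in ("Spinning...", "Generating request...", "Done!"):
        yield status
        time.sleep(0.5)

with gr.Blocks() as demo:
    status_box = gr.Textbox(label="Status", interactive=False)
    button = gr.Button("Generate a request", variant="primary")
    button.click(fn=spin, outputs=status_box)

demo.launch()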
|