Spaces:
Sleeping
Sleeping
Fix CoT truncation: increase min_new_tokens to 1000, add generation logging, improve truncated JSON handling
Browse files — gradio_app.py (+16 −2)
gradio_app.py
CHANGED
|
@@ -106,8 +106,8 @@ def generate_response(prompt, temperature=0.8):
|
|
| 106 |
|
| 107 |
# Set minimum tokens based on request type
|
| 108 |
if is_cot_request:
|
| 109 |
-
min_tokens = …  # [old value truncated in page capture; this commit raises it to 1000]
|
| 110 |
-
logger.info("Detected Chain of Thinking request - using min_new_tokens=…")  # [line truncated in capture; closing text not recoverable]
|
| 111 |
else:
|
| 112 |
min_tokens = 200 # Standard minimum
|
| 113 |
|
|
@@ -147,6 +147,15 @@ def generate_response(prompt, temperature=0.8):
|
|
| 147 |
# Decode the response
|
| 148 |
generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Post-decode guard: if a top-level JSON array closes, trim to the first full array
|
| 151 |
# This helps prevent trailing prose like 'assistant' or 'Message'.
|
| 152 |
try:
|
|
@@ -194,6 +203,11 @@ def generate_response(prompt, temperature=0.8):
|
|
| 194 |
json_text = generated_text[start_idx:end_idx+1]
|
| 195 |
logger.info(f"Extracted complete JSON array of length {len(json_text)}")
|
| 196 |
generated_text = json_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
except Exception as e:
|
| 198 |
logger.warning(f"Error in JSON extraction: {e}")
|
| 199 |
pass
|
|
|
|
| 106 |
|
| 107 |
# Set minimum tokens based on request type
|
| 108 |
if is_cot_request:
|
| 109 |
+
min_tokens = 1000 # Much higher minimum for CoT to ensure complete responses
|
| 110 |
+
logger.info("Detected Chain of Thinking request - using min_new_tokens=1000")
|
| 111 |
else:
|
| 112 |
min_tokens = 200 # Standard minimum
|
| 113 |
|
|
|
|
| 147 |
# Decode the response
|
| 148 |
generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 149 |
|
| 150 |
+
# Log generation details for debugging
|
| 151 |
+
input_length = inputs['input_ids'].shape[1]
|
| 152 |
+
output_length = outputs[0].shape[0]
|
| 153 |
+
generated_length = output_length - input_length
|
| 154 |
+
logger.info(f"Generation stats - Input: {input_length} tokens, Generated: {generated_length} tokens, Min required: {min_tokens}")
|
| 155 |
+
|
| 156 |
+
if generated_length < min_tokens:
|
| 157 |
+
logger.warning(f"Generated {generated_length} tokens but minimum was {min_tokens} - response may be truncated")
|
| 158 |
+
|
| 159 |
# Post-decode guard: if a top-level JSON array closes, trim to the first full array
|
| 160 |
# This helps prevent trailing prose like 'assistant' or 'Message'.
|
| 161 |
try:
|
|
|
|
| 203 |
json_text = generated_text[start_idx:end_idx+1]
|
| 204 |
logger.info(f"Extracted complete JSON array of length {len(json_text)}")
|
| 205 |
generated_text = json_text
|
| 206 |
+
elif start_idx is not None:
|
| 207 |
+
# Found start but no end - response was truncated
|
| 208 |
+
logger.warning("JSON array started but never closed - response truncated")
|
| 209 |
+
# Try to extract what we have and let the client handle it
|
| 210 |
+
generated_text = generated_text[start_idx:]
|
| 211 |
except Exception as e:
|
| 212 |
logger.warning(f"Error in JSON extraction: {e}")
|
| 213 |
pass
|