Luigi committed on
Commit
08c2564
·
verified ·
1 Parent(s): 30bb6c7

feat: enhance JSON validation to support both model versions

Browse files

Updated validate_json function to handle both:
- Original model version: outputs plain JSON format
- New fine-tuned version: outputs JSON wrapped in code blocks

Key changes:
- Added support for extracting JSON from code blocks (```json\n{...}\n```)
- Maintained backward compatibility with plain JSON format
- Added quoting fixes for unquoted phone_num and num_people values, and preserved the existing trailing-comma fix
- Improved error handling for both formats

This ensures the demo interface works seamlessly with both old and new versions of the fine-tuned Gemma-3 model, providing robust parsing regardless of the output format.

Files changed (1) hide show
  1. app.py +20 -5
app.py CHANGED
@@ -44,14 +44,29 @@ def load_model():
44
  return None, None
45
 
46
  def validate_json(output: str) -> tuple:
47
- """Validate and extract JSON from model output"""
48
  try:
49
- json_match = re.search(r'\{[\s\S]*\}', output)
50
- if not json_match:
51
- return False, None, "No JSON found / 未找到JSON"
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- json_str = json_match.group(0)
 
 
 
54
  json_str = re.sub(r',\s*\}', '}', json_str)
 
55
  parsed = json.loads(json_str)
56
  return True, parsed, "Valid JSON / 有效的JSON"
57
  except json.JSONDecodeError:
 
44
  return None, None
45
 
46
  def validate_json(output: str) -> tuple:
47
+ """Validate and extract JSON from model output - supports both plain JSON and code block formats"""
48
  try:
49
+ # First, try to extract JSON from code blocks (new model version)
50
+ json_match = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', output)
51
+ if json_match:
52
+ json_str = json_match.group(1)
53
+ else:
54
+ # If no code block, look for JSON directly (old model version)
55
+ json_match = re.search(r'\{[\s\S]*\}', output)
56
+ if not json_match:
57
+ return False, None, "No JSON found / 未找到JSON"
58
+ json_str = json_match.group(0)
59
+
60
+ # Fix common JSON issues for both formats
61
+ # 1. Add quotes around phone numbers (they often start with 0)
62
+ json_str = re.sub(r'("phone_num":\s*)(\d[-\d]*)', r'\1"\2"', json_str)
63
 
64
+ # 2. Add quotes around num_people if it's a number
65
+ json_str = re.sub(r'("num_people":\s*)(\d+)', r'\1"\2"', json_str)
66
+
67
+ # 3. Fix trailing commas
68
  json_str = re.sub(r',\s*\}', '}', json_str)
69
+
70
  parsed = json.loads(json_str)
71
  return True, parsed, "Valid JSON / 有效的JSON"
72
  except json.JSONDecodeError: