girish00 committed on
Commit 637f64c · verified · 1 Parent(s): 2d01b35

update endpoint helper files

Files changed (1): evaluate_model.py (+61, -11)
evaluate_model.py CHANGED
@@ -1,7 +1,8 @@
 import argparse
+import ast
 import json
 import subprocess
 import sys
 
 
 DEFAULT_TEST_PROMPTS = [
@@ -50,14 +51,63 @@ def run_inference(python_exec, model_path, base_model, prompt, max_new_tokens, a
         return None, f"invalid json output: {exc}: {stdout[:300]}"
 
 
 def safe_float(value):
     try:
         return float(value)
     except (TypeError, ValueError):
         return 0.0
 
 
-def score_payload(payload):
+def prompt_expects_code(prompt):
+    prompt_l = prompt.lower()
+    markers = (
+        "fix",
+        "debug",
+        "repair",
+        "write",
+        "create",
+        "generate",
+        "implement",
+        "function",
+        "code",
+        "snippet",
+        "python",
+        "multiply",
+        "multiplication",
+        "product",
+        "add",
+        "addition",
+        "sum",
+        "subtract",
+        "subtraction",
+        "difference",
+        "divide",
+        "division",
+        "quotient",
+    )
+    return any(marker in prompt_l for marker in markers)
+
+
+def code_is_valid_for_prompt(prompt, code):
+    code = str(code or "").strip()
+    if not code:
+        return False
+    if not prompt_expects_code(prompt):
+        return True
+    python_like = any(
+        marker in code
+        for marker in ("def ", "import ", "class ", "print(", "return ", "for ", "if ")
+    )
+    if not python_like:
+        return False
+    try:
+        ast.parse(code)
+        return True
+    except SyntaxError:
+        return False
+
+
+def score_payload(prompt, payload):
     required_keys = {
         "code",
         "explanation",
@@ -69,7 +119,7 @@ def score_payload(payload):
         "latency_ms",
     }
     has_all_keys = required_keys.issubset(payload.keys())
-    code_ok = bool(str(payload.get("code", "")).strip())
+    code_ok = code_is_valid_for_prompt(prompt, payload.get("code", ""))
     explanation_ok = bool(str(payload.get("explanation", "")).strip())
     confidence = safe_float(payload.get("confidence", 0.0))
     relevancy = safe_float(payload.get("relevancy_score", 0.0))
@@ -116,7 +166,7 @@ def main():
             results.append({"prompt": prompt, "error": error, "pass": False})
             continue
 
-        metrics = score_payload(payload)
+        metrics = score_payload(prompt, payload)
         is_pass = (
             metrics["schema_ok"]
             and metrics["content_ok"]
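
As a quick sanity check of the new code gate, the snippet below exercises the two helpers this commit adds. It is a minimal sketch: it assumes evaluate_model.py is on the import path and that its main() is guarded by if __name__ == "__main__" (the guard is not visible in this diff), so importing the module has no side effects.

# demo_code_gate.py: a minimal sketch exercising the helpers added in
# this commit. Assumes evaluate_model.py is importable and side-effect-free
# on import (an `if __name__ == "__main__"` guard, not shown in the diff).
from evaluate_model import code_is_valid_for_prompt, prompt_expects_code

# "write"/"function"/"add" markers: the prompt demands code, so the answer
# must both look Python-like and parse cleanly with ast.parse().
prompt = "Write a function that adds two numbers"
assert prompt_expects_code(prompt)
assert code_is_valid_for_prompt(prompt, "def add(a, b):\n    return a + b")

# Python-like but syntactically broken code now fails the gate.
assert not code_is_valid_for_prompt(prompt, "def add(a, b:\n    return a + b")

# A prompt with no code markers accepts any non-empty "code" string.
assert not prompt_expects_code("Explain what an inference endpoint does")
assert code_is_valid_for_prompt("Explain what an inference endpoint does", "n/a")

# Empty or None code always fails.
assert not code_is_valid_for_prompt(prompt, "")
print("all gate checks passed")

Compared with the old check, which only required a non-empty string, the gate now rejects answers that merely echo prose into the "code" field whenever the prompt's keywords indicate that runnable Python is expected.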