tobil committed
Commit 36fb469 · verified · 1 Parent(s): 65268ed

Upload train_1.7B_grpo.py with huggingface_hub

Files changed (1)
  1. train_1.7B_grpo.py +12 -4
train_1.7B_grpo.py CHANGED
@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
     """Score expansion. Returns 0.0-1.0 for RL reward."""
     text = expansion.strip()
 
+    # Strip end token if present
+    text = text.replace('<|im_end|>', '').strip()
+
+    # Check for <think>...</think> blocks - strip and mark as not skipped
+    skipped_think = 20  # Bonus for not using thinking mode
+    if '<think>' in text and '</think>' in text:
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+        skipped_think = 0  # No bonus if thinking was used
+
     # HARD FAIL: Chat template artifacts
-    if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
-                                       '\nassistant\n', '\nuser\n', '<|endoftext|>']):
+    if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
         return 0.0
 
     # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
     elif not entities:
         entity_score = 10
 
-    total = format_score + diversity_score + hyde_score + quality_score + entity_score
-    max_possible = 120 if parsed["hyde"] else 100
+    total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
+    max_possible = 140 if parsed["hyde"] else 120  # +20 for skipped_think bonus
     return max(0.0, min(1.0, total / max_possible))
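The effect of the change is easier to see in isolation. The sketch below mirrors only the new preprocessing and bonus logic from the diff; the helper name preview_think_handling and the sample expansion are illustrative and not part of the committed file, and the remaining score components (format_score, diversity_score, hyde_score, quality_score, entity_score) are assumed to behave as before.

import re

def preview_think_handling(expansion: str) -> tuple[str, int]:
    # Mirror the new preprocessing: drop the end token, strip any
    # <think>...</think> block, and grant +20 only when thinking was skipped.
    text = expansion.strip().replace('<|im_end|>', '').strip()
    skipped_think = 20
    if '<think>' in text and '</think>' in text:
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
        skipped_think = 0
    return text, skipped_think

text, bonus = preview_think_handling("<think>plan the queries</think>\nlex: example query<|im_end|>")
print(text)   # "lex: example query"
print(bonus)  # 0 (thinking was used, so no bonus)

With the bonus folded into the total, the normalization ceiling rises by 20: max_possible becomes 140 when a hyde: line is present and 120 otherwise, so a completion that skips thinking and satisfies every other check can still reach a reward of 1.0.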