import json
import os
import re
from typing import Dict, List, Union

from swift.llm import InferRequest


class ORM:
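    """Base class for outcome reward models (ORMs); subclasses return one float reward per completion."""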

    def __call__(self, **kwargs) -> List[float]:
        raise NotImplementedError


class ReactORM(ORM):
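    """ToolBench-style ReAct reward: checks the predicted Action and Action Input against the reference."""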

    @staticmethod
    def evaluate_action_reward(action_pred: list, action_ref: list, cand_list: list, ref_list: list):
        f1 = []
        for i in range(len(action_pred)):
            ref_action = action_ref[i]
            pred_action = action_pred[i]

            ref_input = ref_list[i]
            cand_input = cand_list[i]

            ref_is_json = False
            try:
                ref_input_json = json.loads(ref_input)
                ref_is_json = True
            except Exception:
                ref_input_json = ref_input

            cand_is_json = False
            try:
                cand_input_json = json.loads(cand_input)
                cand_is_json = True
            except Exception:
                cand_input_json = cand_input

            # Mismatched action names, or one side JSON and the other not, score 0.
            if ref_action != pred_action or (ref_is_json ^ cand_is_json):
                f1.append(0)
            elif not ref_is_json and not cand_is_json:
                # Plain-text arguments: grade by ROUGE-L F1, which lies in [0, 1].
                rougel = ReactORM.evaluate_rougel([cand_input_json], [ref_input_json])
                if rougel is None or rougel < 0.1:
                    f1.append(0)
                elif 0.1 <= rougel < 0.2:
                    f1.append(0.1)
                else:
                    f1.append(1)
            else:
                # JSON arguments: F1 over the reference keys, with full credit for
                # matching values and half credit for a matching key alone.
                if not isinstance(ref_input_json, dict) or not isinstance(cand_input_json, dict):
                    f1.append(0)
                    continue

                half_match = 0
                full_match = 0
                if ref_input_json == {}:
                    if cand_input_json == {}:
                        f1.append(1)
                    else:
                        f1.append(0)
                else:
                    for k, v in ref_input_json.items():
                        if k in cand_input_json:
                            if cand_input_json[k] == v:
                                full_match += 1
                            else:
                                half_match += 1

                    recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30)
                    precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30)
                    try:
                        f1.append((2 * recall * precision) / (recall + precision))
                    except Exception:
                        f1.append(0.0)

        # Only the first element decides the outcome; __call__ passes single-element lists.
        return f1[0] == 1.0

    @staticmethod
    def parse_action(text):
        if 'Action Input:' in text:
            input_idx = text.rindex('Action Input:')
            action_input = text[input_idx + len('Action Input:'):].strip()
        else:
            action_input = '{}'

        if 'Action:' in text:
            action_idx = text.rindex('Action:')
            action = text[action_idx + len('Action:'):].strip()
            if 'Action Input:' in action:
                input_idx = action.index('Action Input:')
                action = action[:input_idx].strip()
        else:
            action = 'none'
        return action, action_input

    @staticmethod
    def parse_output(text):
        action, action_input = ReactORM.parse_action(text)
        return action, action_input

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], solution: List[str], **kwargs) -> List[float]:
        rewards = []
        if not isinstance(infer_requests[0], str):
            predictions = [request['messages'][-1]['content'] for request in infer_requests]
        else:
            predictions = infer_requests
        for prediction, ground_truth in zip(predictions, solution):
            # Strip the trailing 'Observation:' marker that ReAct prompts append.
            if prediction.endswith('Observation:'):
                prediction = prediction[:prediction.rindex('Observation:')].strip()
            action_ref = []
            action_input_ref = []
            action_pred = []
            action_input_pred = []
            reference = ground_truth
            prediction = prediction.replace('<|endoftext|>', '').replace('<|im_end|>', '').strip()
            ref_action, ref_input = ReactORM.parse_output(reference)
            pred_action, pred_input = ReactORM.parse_output(prediction)
            action_ref.append(ref_action)
            action_input_ref.append(ref_input)
            if pred_action is None:
                action_pred.append('none')
            else:
                action_pred.append(pred_action)

            if pred_input is None:
                action_input_pred.append('{}')
            else:
                action_input_pred.append(pred_input)

            reward = ReactORM.evaluate_action_reward(action_pred, action_ref, action_input_pred, action_input_ref)
            rewards.append(float(reward))
        return rewards

    @staticmethod
    def evaluate_rougel(cand_list: list, ref_list: list):
        if len(ref_list) == 0:
            return None
        try:
            from rouge import Rouge
            rouge = Rouge()
            rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
            # ROUGE-L F1, a fraction in [0, 1].
            rougel = rouge_score['rouge-l']['f']
            return rougel
        except Exception:
            return None


class MathORM(ORM):
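    r"""Compares final \boxed{...} answers via SymPy, or the OpenCompass MATH evaluator if enabled."""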

    def __init__(self):
        from transformers.utils import strtobool
        self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False'))
        if self.use_opencompass:
            from opencompass.datasets.math import MATHEvaluator
            self.evaluator = MATHEvaluator()

    @staticmethod
    def check_terminate(answers: Union[str, List[str]]) -> List[bool]:
        if isinstance(answers, str):
            answers = [answers]
        results = []
        for answer in answers:
            results.append('\\boxed' in answer)
        return results

    @staticmethod
    def extract_boxed_result(text):
        # Note: [^}]* stops at the first '}', so nested braces inside \boxed{...} are truncated.
        pattern = r'\\boxed{([^}]*)}'
        match = re.search(pattern, text)
        if match:
            return match.group(1).strip()
        else:
            return text

    @staticmethod
    def clean_latex(latex_str):
        # Drop inline/display math delimiters and stray braces before parsing.
        latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str)
        latex_str = latex_str.replace('}}', '}').replace('{', '').replace('}', '')
        return latex_str.strip()

    @staticmethod
    def parse_expression(latex_str):
        from sympy import simplify
        from sympy.parsing.latex import parse_latex
        try:
            expr = parse_latex(latex_str)
            return simplify(expr)
        except Exception:
            return None

    @staticmethod
    def compare_consecutive(first, second):
        cleaned_list = [MathORM.clean_latex(latex) for latex in [first, second]]
        parsed_exprs = [MathORM.parse_expression(latex) for latex in cleaned_list]
        if hasattr(parsed_exprs[0], 'equals') and hasattr(parsed_exprs[1], 'equals'):
            value = parsed_exprs[0].equals(parsed_exprs[1])
        else:
            value = parsed_exprs[0] == parsed_exprs[1]
        # SymPy's equals() may return None when equality is undecidable; treat that as a mismatch.
        if value is None:
            value = False
        return value

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
                 **kwargs) -> List[float]:
        rewards = []
        predictions = [request.messages[-1]['content'] for request in infer_requests]
        for prediction, ground_truth in zip(predictions, ground_truths):
            if '# Answer' in prediction:
                prediction = prediction.split('# Answer')[1]
            if '# Answer' in ground_truth:
                ground_truth = ground_truth.split('# Answer')[1]
            prediction = prediction.strip()
            ground_truth = ground_truth.strip()
            prediction = MathORM.extract_boxed_result(prediction)
            ground_truth = MathORM.extract_boxed_result(ground_truth)
            if self.use_opencompass:
                reward = self.evaluator.is_equiv(prediction, ground_truth)
            else:
                reward = MathORM.compare_consecutive(prediction, ground_truth)
            rewards.append(float(reward))
        return rewards


class MathAccuracy(ORM):
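    """Binary accuracy reward backed by math_verify: parses LaTeX answers and verifies symbolic equivalence."""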

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify
        rewards = []
        for content, sol in zip(completions, solution):
            gold_parsed = parse(sol, extraction_mode='first_match')
            if len(gold_parsed) != 0:
                # Require the answer to be provided in well-formed LaTeX (no malformed operators).
                answer_parsed = parse(
                    content,
                    extraction_config=[
                        LatexExtractionConfig(
                            normalization_config=NormalizationConfig(
                                nits=False,
                                malformed_operators=False,
                                basic_latex=True,
                                equations=True,
                                boxed=True,
                                units=True,
                            ),
                            # Ensure boxed expressions are matched first.
                            boxed_match_priority=0,
                            try_extract_without_anchor=False,
                        )
                    ],
                    extraction_mode='first_match',
                )
                # verify() can raise on pathological inputs; count those as incorrect.
                try:
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception:
                    reward = 0.0
            else:
                # If the gold solution is not parseable, assign 0.0 rather than guessing.
                reward = 0.0
            rewards.append(reward)
        return rewards


class Format(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward completions matching the <think>...</think><answer>...</answer> format exactly."""
        pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
        return [1.0 if match else 0.0 for match in matches]


class ReActFormat(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward completions that follow a <think>...</think> block with 'Action:' and 'Action Input:' fields."""
        pattern = r'^<think>.*?</think>\s*Action:.*?Action Input:.*?$'
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
        return [1.0 if match else 0.0 for match in matches]


class CosineReward(ORM):
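    """Length-scaled reward: cosine-interpolates between the configured values so that shorter
    correct answers score higher and shorter wrong answers are penalized harder."""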

    def __init__(self,
                 tokenizer=None,
                 cosine_min_len_value_wrong: float = -0.5,
                 cosine_max_len_value_wrong: float = 0.0,
                 cosine_min_len_value_correct: float = 1.0,
                 cosine_max_len_value_correct: float = 0.5,
                 cosine_max_len: int = 1000,
                 accuracy_orm=None):
        self.tokenizer = tokenizer
        self.min_len_value_wrong = cosine_min_len_value_wrong
        self.max_len_value_wrong = cosine_max_len_value_wrong
        self.min_len_value_correct = cosine_min_len_value_correct
        self.max_len_value_correct = cosine_max_len_value_correct
        self.max_len = cosine_max_len
        self.accuracy_orm = accuracy_orm or MathAccuracy()

    @staticmethod
    def cosfn(t, T, min_value, max_value):
        import math
        # Cosine interpolation: returns max_value at t=0 and min_value at t=T.
        return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        acc_rewards = self.accuracy_orm(completions, solution, **kwargs)
        rewards = []
        for content, acc_reward in zip(completions, acc_rewards):
            is_correct = acc_reward >= 1.
            if is_correct:
                # Swap min/max so short correct answers get the highest reward.
                min_value = self.max_len_value_correct
                max_value = self.min_len_value_correct
            else:
                # Swap min/max so short wrong answers get the largest penalty.
                min_value = self.max_len_value_wrong
                max_value = self.min_len_value_wrong
            gen_len = len(self.tokenizer.encode(content))
            reward = self.cosfn(gen_len, self.max_len, min_value, max_value)
            rewards.append(reward)
        return rewards


class RepetitionPenalty(ORM):
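    """Penalizes repeated n-grams: the penalty grows with the fraction of duplicate n-grams."""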

    def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0):
        self.ngram_size = repetition_n_grams
        self.max_penalty = repetition_max_penalty

    @staticmethod
    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that penalizes repetitions.

        Args:
            completions: List of model completions.
        """
        rewards = []
        for completion in completions:
            if completion == '':
                rewards.append(0.0)
                continue
            if len(completion.split()) < self.ngram_size:
                rewards.append(0.0)
                continue

            # Count distinct vs. total n-grams; heavy repetition lowers the ratio.
            ngrams = set()
            total = 0
            for ng in self.zipngram(completion, self.ngram_size):
                ngrams.add(ng)
                total += 1

            # scaling is 0 with no repetition and approaches 1 as repetition grows.
            scaling = 1 - len(ngrams) / total
            reward = scaling * self.max_penalty
            rewards.append(reward)
        return rewards


class SoftOverlong(ORM):
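    """Soft length penalty: zero up to (soft_max_length - soft_cache_length) tokens,
    then decreasing linearly to -1 at soft_max_length."""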

    def __init__(self, tokenizer, soft_max_length, soft_cache_length):
        self.tokenizer = tokenizer
        assert soft_cache_length < soft_max_length
        self.soft_max_length = soft_max_length
        self.soft_cache_length = soft_cache_length

    def __call__(self, completions, **kwargs) -> List[float]:
        rewards = []
        for completion in completions:
            completion_length = len(self.tokenizer.encode(completion))
            expected_len = self.soft_max_length - self.soft_cache_length
            exceed_len = completion_length - expected_len
            rewards.append(min(-exceed_len / self.soft_cache_length, 0))
        return rewards


# Registry mapping reward-function names to their ORM implementations.
orms = {
    'toolbench': ReactORM,
    'math': MathORM,
    'accuracy': MathAccuracy,
    'format': Format,
    'react_format': ReActFormat,
    'cosine': CosineReward,
    'repetition': RepetitionPenalty,
    'soft_overlong': SoftOverlong,
}
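

# Minimal usage sketch (illustrative only, not part of the module's API): rewards are
# instantiated from the registry and called on a batch of completion strings.
if __name__ == '__main__':
    fmt = orms['format']()
    print(fmt(['<think>steps</think>\n<answer>42</answer>', 'no tags']))  # -> [1.0, 0.0]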