import json
import os
import re
from typing import Dict, List, Union

from swift.llm import InferRequest
class ORM:
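    """Base interface for reward functions (ORMs): a call maps a batch of samples to a list of per-sample float rewards."""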
def __call__(self, **kwargs) -> List[float]:
raise NotImplementedError
class ReactORM(ORM):
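    """Reward for ReAct/ToolBench-style tool calls: the predicted Action and Action Input are compared against the reference."""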
@staticmethod
def evaluate_action_reward(action_pred: list, action_ref: list, cand_list: list, ref_list: list):
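        """Score predicted (action, action input) pairs against the references.

        JSON inputs are scored with a soft F1 over key-value pairs (value match
        counts 1, key-only match counts 0.5); plain-string inputs are scored with
        ROUGE-L. Returns True only if the first pair scores an F1 of 1.0.
        """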
f1 = []
for i in range(len(action_pred)):
ref_action = action_ref[i]
pred_action = action_pred[i]
ref_input = ref_list[i]
cand_input = cand_list[i]
ref_is_json = False
try:
ref_input_json = json.loads(ref_input)
ref_is_json = True
except Exception:
ref_input_json = ref_input
cand_is_json = False
try:
cand_input_json = json.loads(cand_input)
cand_is_json = True
except Exception:
cand_input_json = cand_input
if ref_action != pred_action or (ref_is_json ^ cand_is_json):
f1.append(0)
elif not ref_is_json and not cand_is_json:
rougel = ReactORM.evaluate_rougel([ref_input_json], [cand_input_json])
if rougel is None or rougel < 10:
f1.append(0)
elif 10 <= rougel < 20:
f1.append(0.1)
else:
f1.append(1)
else:
if not isinstance(ref_input_json, dict) or not isinstance(cand_input_json, dict):
                    # This should not happen, but it has been observed:
# line 62, in evaluate_action_reward
# for k, v in ref_input_json.items():
# AttributeError: 'str' object has no attribute 'items'
# print(f'>>>>>>ref_input_json: {ref_input_json}, cand_input_json: {cand_input_json}')
f1.append(0)
continue
half_match = 0
full_match = 0
if ref_input_json == {}:
if cand_input_json == {}:
f1.append(1)
else:
f1.append(0)
else:
for k, v in ref_input_json.items():
if k in cand_input_json.keys():
if cand_input_json[k] == v:
full_match += 1
else:
half_match += 1
recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30)
precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30)
try:
f1.append((2 * recall * precision) / (recall + precision))
except Exception:
f1.append(0.0)
        return f1[0] == 1.0
@staticmethod
def parse_action(text):
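        """Extract the text after the last 'Action:' and 'Action Input:' markers; defaults are 'none' and '{}'."""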
if 'Action Input:' in text:
input_idx = text.rindex('Action Input:')
action_input = text[input_idx + len('Action Input:'):].strip()
else:
action_input = '{}'
if 'Action:' in text:
action_idx = text.rindex('Action:')
action = text[action_idx + len('Action:'):].strip()
if 'Action Input:' in action:
input_idx = action.index('Action Input:')
action = action[:input_idx].strip()
else:
action = 'none'
return action, action_input
@staticmethod
def parse_output(text):
action, action_input = ReactORM.parse_action(text)
return action, action_input
def __call__(self, infer_requests: List[Union[InferRequest, Dict]], solution: List[str], **kwargs) -> List[float]:
rewards = []
if not isinstance(infer_requests[0], str):
predictions = [request['messages'][-1]['content'] for request in infer_requests]
else:
predictions = infer_requests
for prediction, ground_truth in zip(predictions, solution):
            if prediction.endswith('Observation:'):
                prediction = prediction[:prediction.rindex('Observation:')].strip()
action_ref = []
action_input_ref = []
action_pred = []
action_input_pred = []
reference = ground_truth
prediction = prediction.replace('<|endoftext|>', '').replace('<|im_end|>', '').strip()
ref_action, ref_input = ReactORM.parse_output(reference)
pred_action, pred_input = ReactORM.parse_output(prediction)
action_ref.append(ref_action)
action_input_ref.append(ref_input)
if pred_action is None:
action_pred.append('none')
else:
action_pred.append(pred_action)
if pred_input is None:
action_input_pred.append('{}')
else:
action_input_pred.append(pred_input)
reward = ReactORM.evaluate_action_reward(action_pred, action_ref, action_input_pred, action_input_ref)
rewards.append(float(reward))
return rewards
@staticmethod
def evaluate_rougel(cand_list: list, ref_list: list):
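        """Return the average ROUGE-L F-measure of cand_list against ref_list, or None if it cannot be computed."""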
if len(ref_list) == 0:
return None
try:
from rouge import Rouge
rouge = Rouge()
rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
rougel = rouge_score['rouge-l']['f']
return rougel
except Exception:
return None
class MathORM(ORM):
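    """Reward that checks whether the \\boxed{} answer of the prediction matches the ground truth, using either the OpenCompass MATHEvaluator or a SymPy-based comparison."""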
def __init__(self):
from transformers.utils import strtobool
self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False'))
if self.use_opencompass:
from opencompass.datasets.math import MATHEvaluator
self.evaluator = MATHEvaluator()
@staticmethod
def check_terminate(answers: Union[str, List[str]]) -> List[bool]:
if isinstance(answers, str):
answers = [answers]
results = []
for answer in answers:
results.append('\\boxed' in answer)
return results
@staticmethod
def extract_boxed_result(text):
pattern = r'\\boxed{([^}]*)}'
match = re.search(pattern, text)
if match:
return match.group(1).strip()
else:
return text
@staticmethod
def clean_latex(latex_str):
latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str)
latex_str = latex_str.replace('}}', '}').replace('{', '').replace('}', '')
return latex_str.strip()
@staticmethod
def parse_expression(latex_str):
from sympy import simplify
from sympy.parsing.latex import parse_latex
try:
expr = parse_latex(latex_str)
return simplify(expr)
except Exception:
return None
@staticmethod
def compare_consecutive(first, second):
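        """Check two LaTeX answers for symbolic equality after cleaning and parsing them with SymPy."""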
cleaned_list = [MathORM.clean_latex(latex) for latex in [first, second]]
parsed_exprs = [MathORM.parse_expression(latex) for latex in cleaned_list]
if hasattr(parsed_exprs[0], 'equals') and hasattr(parsed_exprs[1], 'equals'):
value = parsed_exprs[0].equals(parsed_exprs[1])
else:
value = parsed_exprs[0] == parsed_exprs[1]
if value is None:
value = False
return value
def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
**kwargs) -> List[float]:
rewards = []
predictions = [request.messages[-1]['content'] for request in infer_requests]
for prediction, ground_truth in zip(predictions, ground_truths):
if '# Answer' in prediction:
prediction = prediction.split('# Answer')[1]
if '# Answer' in ground_truth:
ground_truth = ground_truth.split('# Answer')[1]
prediction = prediction.strip()
ground_truth = ground_truth.strip()
prediction = MathORM.extract_boxed_result(prediction)
ground_truth = MathORM.extract_boxed_result(ground_truth)
if self.use_opencompass:
reward = self.evaluator.is_equiv(prediction, ground_truth)
else:
reward = MathORM.compare_consecutive(prediction, ground_truth)
rewards.append(float(reward))
return rewards
class MathAccuracy(ORM):
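    """Accuracy reward backed by math_verify: 1.0 if the completion's answer verifies against the parsed gold solution, else 0.0."""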
def __init__(self):
import importlib.util
assert importlib.util.find_spec('math_verify') is not None, (
"The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")
def __call__(self, completions, solution, **kwargs) -> List[float]:
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
rewards = []
for content, sol in zip(completions, solution):
gold_parsed = parse(sol, extraction_mode='first_match')
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode='first_match',
)
                # verify() may raise on malformed expressions; treat that as an incorrect answer
try:
reward = float(verify(gold_parsed, answer_parsed))
except Exception:
reward = 0.0
else:
# If the gold solution is not parseable, we reward 0 to skip this example
reward = 0.0
rewards.append(reward)
return rewards
class Format(ORM):
def __call__(self, completions, **kwargs) -> List[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
return [1.0 if match else 0.0 for match in matches]
class ReActFormat(ORM):
def __call__(self, completions, **kwargs) -> List[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r'^<think>.*?</think>\s*Action:.*?Action Input:.*?$'
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
return [1.0 if match else 0.0 for match in matches]
class CosineReward(ORM):
# https://arxiv.org/abs/2502.03373
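    """Length-aware accuracy reward.

    With the default bounds, shorter correct answers earn more (1.0 down to 0.5 at
    cosine_max_len tokens) and shorter wrong answers are penalized more (-0.5 up to
    0.0), following a cosine schedule over the tokenized completion length.
    """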
def __init__(self,
tokenizer=None,
cosine_min_len_value_wrong: float = -0.5,
cosine_max_len_value_wrong: float = 0.0,
cosine_min_len_value_correct: float = 1.0,
cosine_max_len_value_correct: float = 0.5,
cosine_max_len: int = 1000,
accuracy_orm=None):
self.tokenizer = tokenizer
self.min_len_value_wrong = cosine_min_len_value_wrong
self.max_len_value_wrong = cosine_max_len_value_wrong
self.min_len_value_correct = cosine_min_len_value_correct
self.max_len_value_correct = cosine_max_len_value_correct
self.max_len = cosine_max_len
self.accuracy_orm = accuracy_orm or MathAccuracy()
@staticmethod
def cosfn(t, T, min_value, max_value):
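        """Cosine interpolation from max_value at t=0 to min_value at t=T."""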
import math
return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2
def __call__(self, completions, solution, **kwargs) -> List[float]:
acc_rewards = self.accuracy_orm(completions, solution, **kwargs)
rewards = []
for content, acc_reward in zip(completions, acc_rewards):
is_correct = acc_reward >= 1.
if is_correct:
                # cosfn returns max_value at length 0 and min_value at max_len,
                # so the reward for the maximum length is passed as min_value and vice versa.
min_value = self.max_len_value_correct
max_value = self.min_len_value_correct
else:
min_value = self.max_len_value_wrong
max_value = self.min_len_value_wrong
gen_len = len(self.tokenizer.encode(content))
reward = self.cosfn(gen_len, self.max_len, min_value, max_value)
rewards.append(reward)
return rewards
class RepetitionPenalty(ORM):
# https://arxiv.org/abs/2502.03373
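    """Penalty proportional to the fraction of repeated n-grams in the completion, scaled by max_penalty (a negative value)."""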
def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0):
self.ngram_size = repetition_n_grams
self.max_penalty = repetition_max_penalty
@staticmethod
def zipngram(text: str, ngram_size: int):
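        """Yield the n-grams of the lowercased, whitespace-split words of text."""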
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
def __call__(self, completions, **kwargs) -> List[float]:
"""
reward function the penalizes repetitions
Args:
completions: List of model completions
"""
rewards = []
for completion in completions:
if completion == '':
rewards.append(0.0)
continue
if len(completion.split()) < self.ngram_size:
rewards.append(0.0)
continue
ngrams = set()
total = 0
for ng in self.zipngram(completion, self.ngram_size):
ngrams.add(ng)
total += 1
scaling = 1 - len(ngrams) / total
reward = scaling * self.max_penalty
rewards.append(reward)
return rewards
class SoftOverlong(ORM):
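    """Soft overlong penalty.

    The reward is 0 while the completion stays under soft_max_length - soft_cache_length
    tokens, then decreases linearly, reaching -1 at soft_max_length tokens.
    """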
def __init__(self, tokenizer, soft_max_length, soft_cache_length):
self.tokenizer = tokenizer
assert soft_cache_length < soft_max_length
self.soft_max_length = soft_max_length
self.soft_cache_length = soft_cache_length
def __call__(self, completions, **kwargs) -> List[float]:
rewards = []
for completion in completions:
completion_length = len(self.tokenizer.encode(completion))
expected_len = self.soft_max_length - self.soft_cache_length
exceed_len = completion_length - expected_len
rewards.append(min(-exceed_len / self.soft_cache_length, 0))
return rewards
orms = {
'toolbench': ReactORM,
'math': MathORM,
'accuracy': MathAccuracy,
'format': Format,
'react_format': ReActFormat,
'cosine': CosineReward,
'repetition': RepetitionPenalty,
'soft_overlong': SoftOverlong,
}
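

# Illustrative sketch: the dependency-free rewards can be exercised directly;
# the expected values follow from the definitions above.
if __name__ == '__main__':
    # A completion in the <think>...</think><answer>...</answer> layout scores 1.0.
    print(Format()(['<think>reasoning</think> <answer>42</answer>']))  # [1.0]
    # With bigrams, 'the cat' repeated gives 2 unique / 5 total n-grams -> reward -0.6.
    print(RepetitionPenalty(repetition_n_grams=2)(['the cat the cat the cat']))  # [-0.6]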