| import hashlib |
| import inspect |
| from copy import copy |
| from typing import Any, Dict, List, Optional |
|
|
| import json |
| import numpy as np |
|
|
| from swift.llm import InferRequest, RequestConfig |
| from swift.utils import get_logger |
|
|
| logger = get_logger() |
|
|
|
|
def get_messages_md5(row: Dict[str, Any]):
    """Hash a sample row (minus its 'choices' field) into an MD5 hex digest.

    The row is JSON-serialized with sorted keys so that dictionary ordering
    does not affect the resulting digest, making it usable as a stable
    deduplication key for the same messages.
    """
    hashable = {key: value for key, value in row.items() if key != 'choices'}
    encoded = json.dumps(hashable, sort_keys=True).encode('utf-8')
    return hashlib.md5(encoded).hexdigest()
|
|
|
|
def get_reward(model: Any,
               infer_requests: List[InferRequest],
               request_config: RequestConfig = None,
               ground_truths: List[str] = None,
               threshold: Optional[float] = None):
    """Get reward from an RM model.

    Args:
        model: The model instance or an RM evaluator (anything callable, or an
            ``InferEngine`` whose ``infer`` method will be used)
        infer_requests: Infer requests sent to the model; dicts with a
            'messages' key are converted to ``InferRequest`` instances
        request_config: Infer config
        ground_truths: The ground truth list; only forwarded when the reward
            callable declares a ``ground_truths`` parameter
        threshold: An optional threshold to generate the mask

    Returns:
        Tuple
            Index 0: The min-max normalized scores matched the infer_requests
            Index 1: The mask filtered by the threshold (all-True when no
                threshold is given)
    """
    from swift.llm import InferEngine
    infer_func = model.infer if isinstance(model, InferEngine) else model.__call__
    # Only pass ground_truths to reward callables that actually accept it.
    parameters = inspect.signature(infer_func).parameters
    gt_param = {'ground_truths': ground_truths} if 'ground_truths' in parameters else {}
    if isinstance(infer_requests[0], dict):
        infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests]
    rewards = infer_func(infer_requests, request_config=request_config, **gt_param)
    from swift.llm.infer.protocol import ChatCompletionResponse
    if isinstance(rewards[0], ChatCompletionResponse):
        # Fix: use the module logger instead of a leftover debug print.
        logger.info('reward: %s', rewards[0].choices[0].message.content)
        content = rewards[0].choices[0].message.content
        if isinstance(content, str):
            # String content may be wrapped in brackets, e.g. "[0.5]".
            rewards = [float(r.choices[0].message.content.strip('[]')) for r in rewards]
        elif isinstance(content, list):
            # A list of scores (e.g. one per step): take the minimum.
            rewards = [float(min(r.choices[0].message.content)) for r in rewards]
        else:
            rewards = [float(r.choices[0].message.content) for r in rewards]
    # Collapse each reward to a single scalar; list/tuple rewards use their minimum.
    arr = [min(reward) if isinstance(reward, (list, tuple)) else float(reward) for reward in rewards]

    if threshold is None:
        _mask = np.ones(len(arr), dtype=bool)
    else:
        _mask = np.array([a > threshold for a in arr])

    def normalize(arr):
        """Min-max normalize; a constant array maps to a single constant value."""
        min_val = np.min(arr)
        max_val = np.max(arr)
        if min_val == max_val:
            if min_val == 0:
                constant_value = 0.0
            else:
                constant_value = min(1.0, min_val)
            return np.full_like(arr, fill_value=constant_value, dtype=np.float64)
        # Epsilon keeps the maximum slightly below 1.0 and guards the division.
        normalized = (arr - min_val) / (max_val - min_val + 1e-5)
        return normalized

    return normalize(arr), _mask
|
|
|
|
def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs):
    """Run inference, dispatching on the argument shapes.

    Args:
        infer_engines: A single engine, or a list of engines (one per request
            batch) executed concurrently in a thread pool.
        infer_requests: A single request, a list of requests, or (when
            ``infer_engines`` is a list) a list of request batches.
        request_configs: A single ``RequestConfig``, or a list parallel to
            ``infer_requests``.
        **infer_kwargs: Extra keyword arguments forwarded to ``infer``.

    Returns:
        A flat list of responses.

    Fix: futures were previously consumed via ``as_completed``, so the
    response order depended on thread scheduling; results are now collected
    in submission order, keeping responses aligned with the engine index.
    """
    if isinstance(infer_engines, list):
        assert len(infer_engines) >= len(request_configs) >= len(infer_requests)
        from concurrent.futures import ThreadPoolExecutor
        n = len(infer_requests)
        responses = []
        with ThreadPoolExecutor(max_workers=n) as executor:
            futures = [
                executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs)
                for i in range(n)
            ]
            # Iterate in submission order for deterministic response ordering.
            for task_id, future in enumerate(futures):
                try:
                    responses += future.result()
                except Exception as e:
                    # Best-effort: a failed task is logged and skipped.
                    logger.error(f'Perform infer task: {task_id} get an error: {e}')
        return responses
    elif isinstance(infer_requests, list):
        responses = []
        if isinstance(request_configs, list):
            assert len(infer_requests) <= len(request_configs)
            # Pair each request with its own config.
            for request, config in zip(infer_requests, request_configs):
                responses += infer_engines.infer(
                    [request],
                    config,
                    **infer_kwargs,
                )
        elif isinstance(request_configs, RequestConfig):
            # One shared config for every request.
            for request in infer_requests:
                responses += infer_engines.infer(
                    [request],
                    request_configs,
                    **infer_kwargs,
                )
        return responses
    return infer_engines.infer(
        [infer_requests],
        request_configs,
        **infer_kwargs,
    )
|
|
|
|
def collect_from_mct(monte_carlo_tree, collect_filter_threshold):
    """Walk a Monte-Carlo tree and collect training signals.

    Args:
        monte_carlo_tree: The tree root (a nested dict), or its JSON string.
        collect_filter_threshold: Minimum outcome-reward gap between the best
            and worst child required to emit a preference pair.

    Returns:
        A tuple ``(prefer_pairs, correct_answers, incorrect_answers)``.
    """
    from transformers.utils import strtobool
    if isinstance(monte_carlo_tree, str):
        monte_carlo_tree = json.loads(monte_carlo_tree)

    def _walk(node, outcome_acc: list[float], process_acc: list[float]):
        pairs, correct, incorrect = [], [], []
        # Extend the reward trails without mutating the caller's lists.
        outcome_acc = outcome_acc + [node['outcome_reward']]
        process_acc = process_acc + [node['process_reward']]
        children = node['children']
        if children:
            for child in children:
                child_pairs, child_correct, child_incorrect = _walk(child, outcome_acc, process_acc)
                pairs += child_pairs
                correct += child_correct
                incorrect += child_incorrect
            ranked = sorted(children, key=lambda c: c['outcome_reward'])
            worst, best = ranked[0], ranked[-1]
            # A wide enough reward gap between siblings yields a preference pair.
            if best['outcome_reward'] - worst['outcome_reward'] > collect_filter_threshold:
                pairs.append({
                    'path': 'ки\n'.join(node['path']),
                    'good': best['path'][-1],
                    'good_score': best['outcome_reward'],
                    'bad': worst['path'][-1],
                    'bad_score': worst['outcome_reward'],
                })
        if strtobool(node['terminated']):
            answer = {
                'answer': 'ки\n'.join(node['path']),
                'mean_outcome_reward': np.mean(outcome_acc),
                'min_outcome_reward': np.min(outcome_acc),
                'mean_process_reward': np.mean(process_acc),
                'min_process_reward': np.min(process_acc),
            }
            # Route the terminal answer by its correctness flag.
            (correct if strtobool(node['correct']) else incorrect).append(answer)
        return pairs, correct, incorrect

    return _walk(monte_carlo_tree, [], [])
|
|