import re import textwrap from copy import deepcopy from typing import Dict, List import torch from swift.llm import PtEngine, RequestConfig, Template, to_device from swift.llm.infer.protocol import ChatCompletionResponse from swift.utils import get_logger logger = get_logger() class DefaultRMPlugin: """ Default Reward Model Plugin This class implements the default processing logic for reward models. It assumes that `self.model` is a classification model with a value head(output dimmension 1). The first logits value from the model's output is used as the reward score. """ def __init__(self, model, template): self.model = model self.template: Template = template def __call__(self, inputs): batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs] reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device) reward_inputs.pop('labels') with torch.inference_mode(): return self.model(**reward_inputs).logits[:, 0] class GenRMPlugin(DefaultRMPlugin): def __init__(self, model, template): """ Generative Reward Model Plugin Example. This method sets up the reward model plugin by initializing the PtEngine for efficient inference, configuring the request parameters, and defining the system prompt that guides the reward model in evaluating responses. Args: model (torch.nn.Module): The generative reward model. template (Template): The template used for encoding input data. """ super().__init__(model, template) # initilize PTEngine to infer self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit self.request_config = RequestConfig() # customise your request config here self.system = textwrap.dedent(""" Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant. Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct. Before finishing your response, please assign a reward using the following format: Reward: {reward} For example: Reward: 0.85 """) # noqa def __call__(self, inputs): """ Compute reward scores for the provided inputs. This method processes each input by converting dialogue messages into a query, sending the query to the reward model for inference, and extracting the reward scores from the model's responses. The final reward for each input is the average of all extracted scores. Args: inputs (List[Dict]): A list of input requests. Each input request is a dictionary containing: - 'messages' (List[Dict]): messages from the training model. Each message dictionary includes: - 'role' (str): The role of the speaker (e.g., 'user', 'assistant'). - 'content' (str): The content of the message. - Additional dataset columns as key-value pairs (e.g., 'solutions', 'images'). Returns: torch.Tensor: A tensor containing the average reward scores for each input. The tensor has a shape of (N,), where N is the number of input requests. """ rm_inputs = self.prepare_rm_inputs(inputs) results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False) rewards = self.compute_rewards(results) return torch.tensor(rewards, dtype=torch.float32) def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]: """ Prepare inputs for the reward model by converting messages into queries. Args: inputs (List[Dict]): A list of input requests. Returns: List[Dict]: Processed inputs for the reward model. """ rm_inputs = [] for idx, infer_request in enumerate(inputs): # Deep copy to prevent modification of original input rm_infer_request = deepcopy(infer_request) # Extract and convert messages to a single query string messages = rm_infer_request.get('messages') query = self.messages_to_query(messages) # Construct new messages tailored for the reward model rm_messages = [{'role': 'system', 'content': self.system}, {'role': 'user', 'content': query}] # Update the messages in the reward infer request rm_infer_request['messages'] = rm_messages rm_inputs.append(rm_infer_request) return rm_inputs @staticmethod def extract_reward(model_output: str) -> float: """ Extract the reward score from the model's output. Args: model_output (str): The model's output string, expected to follow the format "Reward: {reward}". Returns: float: The extracted reward score. Raises: ValueError: If the reward score cannot be extracted or the format is incorrect. """ match = re.search(r'Reward:\s*([0-1](?:\.\d+)?)', model_output) if match: return float(match.group(1)) else: logger.warning("Unable to extract reward score from the model's output, set reward to 0") return None @staticmethod def messages_to_query(messages): """ Compress a list of message dictionaries into a single query string. Args: messages (list[dict]): A list of message dictionaries, each containing: - 'role' (str): The role of the speaker (e.g., 'user', 'assistant'). - 'content' (str): The content of the message. Returns: str: A single string that concatenates all messages in a formatted manner. Example: >>> messages = [ ... {'role': 'user', 'content': 'Hello, how are you?'}, ... {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'}, ... {'role': 'user', 'content': 'Can you help me with my homework?'} ... ] >>> print(messages_to_query(messages)) User: Hello, how are you? Assistant: I am fine, thank you! How can I assist you today? User: Can you help me with my homework? """ # Initialize an empty list to hold formatted messages formatted_messages = [] # Define a mapping for role capitalization if needed role_mapping = { 'user': 'User', 'assistant': 'Assistant', 'system': 'System' # Add more roles here as needed } for idx, message in enumerate(messages): if not isinstance(message, dict): raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.') # Extract 'role' and 'content' from each message role = message.get('role') content = message.get('content') if not content: continue # Capitalize the role using the mapping, default to capitalized original role role_formatted = role_mapping.get(role.lower(), role.capitalize()) # Append the formatted message to the list formatted_messages.append(f'{role_formatted}: {content}') # Join all formatted messages with newline characters query = '\n'.join(formatted_messages) return query def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]: """ Compute average reward scores from the reward model's outputs. Args: results (List[ChatCompletionResponse]): A list of results from the reward model. Returns: List[float]: A list of average reward scores. """ rewards = [] for idx, output in enumerate(results): try: cur_rewards = [] for choice in output.choices: response = choice.message.content reward = self.extract_reward(response) cur_rewards.append(reward) cur_rewards = [r for r in cur_rewards if r is not None] if cur_rewards: average_reward = sum(cur_rewards) / len(cur_rewards) else: average_reward = 0.0 logger.warning('No valid rewards extracted. Assigning reward score of 0.0.') rewards.append(average_reward) except Exception as e: logger.error(f'Error computing reward: {e}') rewards.append(0.0) # Assign default reward score on failure return rewards rm_plugins = { 'default': DefaultRMPlugin, 'genrm': GenRMPlugin, }