| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import asyncio |
| import atexit |
| import copy |
| import importlib.resources as pkg_resources |
| import inspect |
| import math |
| import os |
| import sys |
| import textwrap |
| import time |
| import warnings |
| from collections import defaultdict, deque |
| from collections.abc import Callable |
| from contextlib import nullcontext |
| from pathlib import Path |
| from typing import Any, Protocol |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.utils.data |
| import transformers |
| from accelerate.logging import get_logger |
| from accelerate.utils import gather, gather_object, is_peft_model, set_seed |
| from datasets import Dataset, IterableDataset |
| from huggingface_hub import CommitScheduler, DatasetCard, DatasetCardData, create_repo |
| from packaging.version import Version |
| from torch import nn |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.utils.data import Sampler |
| from transformers import ( |
| AutoModelForSequenceClassification, |
| AutoProcessor, |
| AutoTokenizer, |
| GenerationConfig, |
| PreTrainedModel, |
| PreTrainedTokenizerBase, |
| ProcessorMixin, |
| TrainerCallback, |
| is_trackio_available, |
| is_wandb_available, |
| ) |
| from transformers.utils import is_peft_available, is_rich_available |
|
|
| from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response |
| from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages |
| from ..extras.profiling import profiling_context, profiling_decorator |
| from ..generation.vllm_generation import VLLMGeneration |
| from ..import_utils import is_jmespath_available, is_liger_kernel_available |
| from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation |
| from ..models.utils import _ForwardRedirection, disable_gradient_checkpointing |
| from .base_trainer import _BaseTrainer |
| from .callbacks import SyncRefModelCallback |
| from .grpo_config import GRPOConfig |
| from .utils import ( |
| RepeatSampler, |
| create_model_from_path, |
| disable_dropout_in_model, |
| entropy_from_logits, |
| get_config_model_id, |
| identity, |
| nanmax, |
| nanmin, |
| nanstd, |
| pad, |
| print_prompt_completions_sample, |
| selective_log_softmax, |
| shuffle_sequence_dict, |
| shutdown_event_loop_in_daemon, |
| split_pixel_values_by_grid, |
| split_tensor_dict, |
| start_event_loop_in_daemon, |
| unsplit_pixel_values_by_grid, |
| use_adapter, |
| ) |
|
|
|
|
| if is_peft_available(): |
| from peft import PeftConfig, PeftModel, get_peft_model |
|
|
| if is_liger_kernel_available(): |
| from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss |
|
|
|
|
| if is_wandb_available(): |
| import wandb |
|
|
| if is_trackio_available(): |
| import trackio |
|
|
|
|
logger = get_logger(__name__)


# Accepted forms of a reward function:
#   - `str`: a model ID or local path, loaded with `AutoModelForSequenceClassification` (`num_labels=1`);
#   - `PreTrainedModel`: an already-instantiated sequence-classification model;
#   - a custom callable returning one reward per sample, or `None` for samples the reward does not apply to
#     (coroutine functions are accepted as well).
RewardFunc = (
    str
    | PreTrainedModel
    | Callable[..., list[float | None]]
)
|
|
| |
| |
| |
# Signature of a custom rollout function: it receives the prompts allocated to the current process plus the trainer
# instance, and returns a dict holding at least "prompt_ids", "completion_ids" and "logprobs"; any extra keys are
# forwarded to the reward functions.
RolloutFunc = Callable[
    [list[str], "GRPOTrainer"],
    dict[str, Any],
]
|
|
|
|
| class _SupportsReset(Protocol): |
| def reset(self, **kwargs) -> str | None: ... |
|
|
|
|
| EnvironmentFactory = Callable[[], _SupportsReset] |
|
|
|
|
| class GRPOTrainer(_BaseTrainer): |
| """ |
| Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the |
| paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language |
| Models](https://huggingface.co/papers/2402.03300). |
| |
| Example: |
| |
| ```python |
| from trl import GRPOTrainer |
| from trl.rewards import accuracy_reward |
| from datasets import load_dataset |
| |
| dataset = load_dataset("trl-lib/DeepMath-103K", split="train") |
| |
| trainer = GRPOTrainer( |
| model="Qwen/Qwen2.5-0.5B-Instruct", |
| reward_funcs=accuracy_reward, |
| train_dataset=dataset, |
| ) |
| trainer.train() |
| ``` |
| |
| Args: |
| model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): |
| Model to be trained. Can be either: |
| |
| - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a |
| path to a *directory* containing model weights saved using |
| [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded |
| using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model |
| config) with the keyword arguments in `args.model_init_kwargs`. |
| - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. |
| - A [`~peft.PeftModel`] object. Only causal language models are supported. |
| reward_funcs (`RewardFunc | list[RewardFunc]`): |
| Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward |
| functions with the prompts and completions and sum the rewards. Can be either: |
| |
| - A single reward function, such as: |
| - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a |
| path to a *directory* containing model weights saved using |
| [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded |
| using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the |
| keyword arguments in `args.model_init_kwargs`. |
| - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. |
| - A custom reward function: The function is provided with the prompts and the generated completions, |
| plus any additional columns in the dataset. It should return a list of rewards. Custom reward |
| functions can be either synchronous or asynchronous and can also return `None` when the reward is |
| not applicable to those samples. This is useful for multi-task training where different reward |
| functions apply to different types of samples. When a reward function returns `None` for a sample, |
| that reward function is excluded from the reward calculation for that sample. For more details, see |
| [Using a custom reward |
| function](#using-a-custom-reward-function). |
| |
| The trainer's state is also passed to the reward function. The trainer's state is an instance of |
| [`~transformers.TrainerState`] and can be accessed by adding a `trainer_state` argument to the |
| reward function's signature. |
| - A list of reward functions, where each item can independently be any of the above types. Mixing different |
| types within the list (e.g., a string model ID and a custom reward function) is allowed. |
| args ([`GRPOConfig`], *optional*): |
| Configuration for this trainer. If `None`, a default configuration is used. |
| train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): |
| Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are |
| forwarded to custom reward functions (see `reward_funcs`). The format of the samples can be either: |
| |
| - [Standard](dataset_formats#standard): Each sample contains plain text. |
| - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role |
| and content). |
| eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): |
| Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. |
| processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): |
| Processing class used to process the data. The padding side must be set to "left". If `None`, the |
| processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A |
| padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, |
| `tokenizer.eos_token` will be used as the default. |
| reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): |
| Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: |
| |
| - A single processing class: Used when `reward_funcs` contains only one reward function. |
| - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. |
| If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is |
| `None`, the tokenizer for the model is automatically loaded using |
| [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward |
| functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes` |
| are ignored. |
| callbacks (list of [`~transformers.TrainerCallback`], *optional*): |
| List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed |
| [here](https://huggingface.co/docs/transformers/main_classes/callback). |
| |
| If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] |
| method. |
| optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): |
| A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your |
| model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. |
| peft_config ([`~peft.PeftConfig`], *optional*): |
| PEFT configuration used to wrap the model. If `None`, the model is not wrapped. |
| tools (list of `Callable`, *optional*): |
| A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool |
| should be a standard Python function with properly type-hinted arguments and return values, and a |
| Google-style docstring describing its purpose, arguments, and return value. For more details, see: |
| https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, |
| type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool |
| use and that it has been fine-tuned for tool calling. |
| rollout_func (`RolloutFunc`, *optional*): |
| Function to use for generating completions. It receives the list of prompts allocated to the current |
| process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and |
| `"logprobs"` fields, and can optionally return `"logprob_token_ids"` (same shape as `"logprobs"`). Any |
| other fields are forwarded to the reward functions. The function receives the raw per-process prompt slice |
| with no duplication; it is responsible for returning the correct number of completions per prompt (see |
| `num_generations` / `num_generations_eval` on the trainer). This feature is experimental and may change or |
| be removed at any time without prior notice. |
| environment_factory (`EnvironmentFactory`, *optional*): |
| A callable that creates and returns an environment instance. The environment class should define methods |
| that can be invoked as tools during generation. Each method should comply with the same requirements as the |
| `tools` described above. If `environment_factory` is provided, an instance of the environment is created |
| for each generation in the batch, allowing for parallel and independent interactions. The environment must |
| also implement a callable `reset` method that can be used to reset state between generations. The `reset` |
| method should return either `None` or a string: when it returns a string, that string is appended to the |
| last user message before generation. This feature is experimental and may change or be removed at any time |
| without prior notice. |
| """ |
|
|
| _tag_names = ["trl", "grpo"] |
| _name = "GRPO" |
| _paper = { |
| "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", |
| "id": "2402.03300", |
| |
| "citation": textwrap.dedent("""\ |
| @article{shao2024deepseekmath, |
| title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, |
| author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, |
| year = 2024, |
| eprint = {arXiv:2402.03300}, |
| }"""), |
| } |
|
|
| def __init__( |
| self, |
| model: "str | PreTrainedModel | PeftModel", |
| reward_funcs: RewardFunc | list[RewardFunc], |
| args: GRPOConfig | None = None, |
| train_dataset: Dataset | IterableDataset | None = None, |
| eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, |
| processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, |
| reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, |
| callbacks: list[TrainerCallback] | None = None, |
| optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), |
| peft_config: "PeftConfig | None" = None, |
| tools: list[Callable] | None = None, |
| rollout_func: RolloutFunc | None = None, |
| environment_factory: EnvironmentFactory | None = None, |
| ): |
| |
| if args is None: |
| model_name = model if isinstance(model, str) else get_config_model_id(model.config) |
| model_name = model_name.split("/")[-1] |
| args = GRPOConfig(f"{model_name}-GRPO") |
|
|
| |
| if isinstance(model, str): |
| model_init_kwargs = args.model_init_kwargs or {} |
| |
| if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: |
| model_init_kwargs["device_map"] = None |
| model = create_model_from_path(model, **model_init_kwargs) |
| else: |
| if args.model_init_kwargs is not None: |
| logger.warning( |
| "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " |
| "The `model_init_kwargs` will be ignored." |
| ) |
|
|
| |
| |
| self.model_kwarg_keys = ( |
| inspect.signature(model.forward).parameters.keys() |
| if not hasattr(model, "get_base_model") |
| else inspect.signature(model.get_base_model().forward).parameters.keys() |
| ) |
|
|
| |
| if processing_class is None: |
| processing_class = AutoProcessor.from_pretrained( |
| get_config_model_id(model.config), truncation_side="left", padding_side="left" |
| ) |
|
|
| |
| if isinstance(processing_class, ProcessorMixin): |
| tokenizer = processing_class.tokenizer |
| elif isinstance(processing_class, PreTrainedTokenizerBase): |
| tokenizer = processing_class |
| else: |
| raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") |
|
|
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
| self.pad_token = tokenizer.pad_token |
| self.pad_token_id = tokenizer.pad_token_id |
| self.eos_token_id = tokenizer.eos_token_id |
|
|
| if is_peft_available() and is_peft_model(model) and peft_config is not None: |
| raise ValueError( |
| "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " |
| "and unload the existing adapter, save the resulting base model, and then pass that base model along " |
| "with the new `peft_config` to the trainer." |
| ) |
| if is_peft_available() and is_peft_model(model) and args.beta != 0.0: |
| |
| |
| model.add_adapter("ref", model.peft_config["default"]) |
| for name, param in model.named_parameters(): |
| if ".default." in name: |
| ref_name = name.replace(".default.", ".ref.") |
| ref_param = model.get_parameter(ref_name) |
| ref_param.data.copy_(param.data) |
|
|
| |
| if peft_config is not None: |
| model = get_peft_model(model, peft_config) |
|
|
| |
| |
| if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: |
| model.enable_input_require_grads() |
|
|
| |
| |
| |
| |
| |
| if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): |
| for param in model.parameters(): |
| if param.requires_grad: |
| param.data = param.data.to(torch.bfloat16) |
|
|
| |
| if not isinstance(reward_funcs, list): |
| reward_funcs = [reward_funcs] |
| self.reward_func_names = [] |
| for i, reward_func in enumerate(reward_funcs): |
| if isinstance(reward_func, str): |
| model_init_kwargs = args.model_init_kwargs or {} |
| |
| if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: |
| model_init_kwargs["device_map"] = None |
| reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( |
| reward_func, num_labels=1, **model_init_kwargs |
| ) |
| if isinstance(reward_funcs[i], nn.Module): |
| self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) |
| else: |
| self.reward_func_names.append(reward_funcs[i].__name__) |
| self.reward_funcs = reward_funcs |
|
|
| |
| if args.reward_weights is not None: |
| if len(args.reward_weights) != len(reward_funcs): |
| raise ValueError( |
| f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " |
| f"functions ({len(reward_funcs)})" |
| ) |
| self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) |
| else: |
| self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) |
|
|
| |
| if reward_processing_classes is None: |
| reward_processing_classes = [None] * len(reward_funcs) |
| elif not isinstance(reward_processing_classes, list): |
| reward_processing_classes = [reward_processing_classes] |
| if len(reward_processing_classes) != len(reward_funcs): |
| raise ValueError( |
| f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of " |
| f"reward functions ({len(reward_funcs)})." |
| ) |
|
|
| for i, (reward_processing_class, reward_func) in enumerate( |
| zip(reward_processing_classes, reward_funcs, strict=True) |
| ): |
| if isinstance(reward_func, PreTrainedModel): |
| if reward_processing_class is None: |
| reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) |
| if reward_processing_class.pad_token_id is None: |
| reward_processing_class.pad_token = reward_processing_class.eos_token |
| |
| |
| reward_func.config.pad_token_id = reward_processing_class.pad_token_id |
| reward_processing_classes[i] = reward_processing_class |
|
|
| self.reward_processing_classes = reward_processing_classes |
|
|
| |
| if rollout_func is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1": |
| warnings.warn( |
| "You are using 'rollout_func', which is an experimental feature. This API may change or be removed at " |
| "any time without prior notice. Silence this warning by setting environment variable " |
| "TRL_EXPERIMENTAL_SILENCE=1.", |
| UserWarning, |
| stacklevel=2, |
| ) |
| self.rollout_func = rollout_func |
| if environment_factory is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1": |
| warnings.warn( |
| "You are using 'environment_factory', which is an experimental feature. This API may change or be " |
| "removed at any time without prior notice. Silence this warning by setting environment variable " |
| "TRL_EXPERIMENTAL_SILENCE=1.", |
| UserWarning, |
| stacklevel=2, |
| ) |
|
|
| |
| if tools: |
| if not Version(transformers.__version__) >= Version("5.0.0"): |
| raise ImportError( |
| "Using tools with GRPOTrainer requires transformers version 5.0.0 or higher. Please upgrade " |
| "transformers with `pip install --upgrade transformers` to use this feature." |
| ) |
| if environment_factory: |
| if not Version(transformers.__version__) >= Version("5.2.0"): |
| raise ImportError( |
| "Using `environment_factory` with GRPOTrainer requires transformers version 5.2.0 or higher. " |
| "Please install transformers from the main branch with `pip install " |
| "git+https://github.com/huggingface/transformers.git@main` to use this feature." |
| ) |
| if tools or environment_factory: |
| if not is_jmespath_available(): |
| raise ImportError( |
| "Using tools with GRPOTrainer requires the jmespath library for response parsing. Please install " |
| "it with `pip install jmespath` to use this feature." |
| ) |
|
|
| |
| generation_batch_size = args.per_device_train_batch_size * args.steps_per_generation |
| if environment_factory is not None: |
| self.environments = [environment_factory() for _ in range(generation_batch_size)] |
| environment_methods = [[] for _ in range(generation_batch_size)] |
| for i, environment in enumerate(self.environments): |
| has_reset = False |
| for name, member in inspect.getmembers(environment, predicate=inspect.ismethod): |
| if name == "reset": |
| has_reset = True |
| elif not name.startswith("_"): |
| environment_methods[i].append(member) |
| if not has_reset: |
| raise ValueError( |
| "Each environment instance returned by `environment_factory` must define a callable `reset` " |
| ) |
| else: |
| self.environments = None |
|
|
| tools = tools or [] |
| self._sync_tool_dicts = [{} for _ in range(generation_batch_size)] |
| self._async_tool_dicts = [{} for _ in range(generation_batch_size)] |
| for i in range(generation_batch_size): |
| for tool in tools + (environment_methods[i] if self.environments is not None else []): |
| if inspect.iscoroutinefunction(tool): |
| self._async_tool_dicts[i][tool.__name__] = tool |
| else: |
| self._sync_tool_dicts[i][tool.__name__] = tool |
|
|
| self.tools = tools + (environment_methods[0] if self.environments is not None else []) |
|
|
| |
| self._has_async_funcs = any(inspect.iscoroutinefunction(func) for func in self.reward_funcs + self.tools) |
|
|
| if self._has_async_funcs: |
| self.async_loop_thread, self.async_loop, self.async_loop_ready_event = start_event_loop_in_daemon( |
| name="GRPOTrainer-AsyncLoop" |
| ) |
| |
| self.async_loop_ready_event.wait() |
| atexit.register(shutdown_event_loop_in_daemon, self.async_loop_thread, self.async_loop) |
|
|
| |
| |
| |
| |
| if self.tools and not getattr(processing_class, "response_schema", None): |
| processing_class = add_response_schema(processing_class) |
| |
| |
| if self.tools: |
| self.chat_template = get_training_chat_template(processing_class) |
| else: |
| self.chat_template = None |
|
|
| |
| self.max_completion_length = args.max_completion_length |
| self.num_generations = args.num_generations |
| self.max_tool_calling_iterations = args.max_tool_calling_iterations or sys.maxsize |
| self.num_generations_eval = args.num_generations_eval or self.num_generations |
| self.chat_template_kwargs = args.chat_template_kwargs or {} |
| self.temperature = args.temperature |
| self.top_p = args.top_p |
| self.top_k = args.top_k |
| self.min_p = args.min_p |
| self.repetition_penalty = args.repetition_penalty |
| self.use_transformers_paged = args.use_transformers_paged |
| self.pad_to_multiple_of = args.pad_to_multiple_of |
| self.use_vllm = args.use_vllm |
| self.vllm_mode = args.vllm_mode |
| self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization |
| self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size |
| self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction |
| self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode |
| self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap |
| self.use_liger_kernel = args.use_liger_kernel |
| self.loss_type = args.loss_type |
| self.multi_objective_aggregation = args.multi_objective_aggregation |
| self.scale_rewards = args.scale_rewards |
| self.importance_sampling_level = args.importance_sampling_level |
| self.off_policy_mask_threshold = args.off_policy_mask_threshold |
| if self.use_liger_kernel and self.off_policy_mask_threshold is not None: |
| raise ValueError("Liger kernel does not support off-policy sequence masking yet.") |
| self.mask_truncated_completions = args.mask_truncated_completions |
| self.top_entropy_quantile = args.top_entropy_quantile |
| if self.use_liger_kernel and self.top_entropy_quantile < 1.0: |
| raise NotImplementedError( |
| "Liger Kernels don't currently support masking token positions based on entropy." |
| ) |
| if self.use_liger_kernel and self.importance_sampling_level not in ("token", "sequence"): |
| raise ValueError( |
| f"Unknown importance sampling level: {self.importance_sampling_level}. " |
| "Possible values are 'token' and 'sequence'." |
| ) |
|
|
| |
| self.shuffle_dataset = args.shuffle_dataset |
|
|
| if train_dataset is None: |
| raise ValueError("`train_dataset` is required") |
| elif ( |
| isinstance(train_dataset, IterableDataset) |
| or isinstance(eval_dataset, IterableDataset) |
| or ( |
| isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) |
| ) |
| ): |
| |
| raise NotImplementedError( |
| "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead." |
| ) |
|
|
| if args.loss_type == "luspo" and args.importance_sampling_level != "sequence": |
| logger.warning( |
| "When using `'luspo'` loss, `importance_sampling_level` should be set to `'sequence'` to mirror the " |
| "paper's setup." |
| ) |
|
|
| if args.loss_type == "vespo" and args.importance_sampling_level != "token": |
| logger.warning( |
| "VESPO computes sequence-level importance weights internally. `importance_sampling_level` should be " |
| "set to `'token'` (the default)." |
| ) |
|
|
| if self.loss_type == "vespo" and self.use_vllm and self.vllm_importance_sampling_correction: |
| if self.vllm_importance_sampling_mode not in ["token_truncate", "token_mask"]: |
| raise ValueError( |
| f"VESPO loss requires `vllm_importance_sampling_mode` to be either 'token_truncate' or " |
| f"'token_mask'. Got: {self.vllm_importance_sampling_mode}." |
| ) |
|
|
| |
| self.num_iterations = args.num_iterations |
| self.epsilon_low = args.epsilon |
| self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon |
| |
| self._step = 0 |
| |
| |
| self._buffered_inputs = None |
|
|
| |
| |
| |
| |
| if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): |
| args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} |
| args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) |
|
|
| super().__init__( |
| model=model, |
| args=args, |
| data_collator=identity, |
| train_dataset=train_dataset, |
| eval_dataset=eval_dataset, |
| processing_class=processing_class, |
| callbacks=callbacks, |
| optimizers=optimizers, |
| |
| |
| |
| |
| |
| compute_loss_func="non-None value to disable scaling", |
| ) |
|
|
| |
| self.beta = args.beta |
| if self.beta == 0.0: |
| |
| self.ref_model = None |
| elif is_peft_model(model): |
| |
| |
| self.ref_model = None |
| else: |
| |
| model_init_kwargs = args.model_init_kwargs or {} |
| |
| if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: |
| model_init_kwargs["device_map"] = None |
| self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) |
|
|
| |
| if args.disable_dropout: |
| disable_dropout_in_model(model) |
| if self.ref_model is not None: |
| disable_dropout_in_model(self.ref_model) |
|
|
| |
| if args.cast_lm_head_to_fp32: |
|
|
| def _cast_lm_head_to_fp32(target_model: PreTrainedModel): |
| """Cast lm_head to fp32 while preserving embedding output dtype if tied.""" |
|
|
| def cast_inputs_to_fp32(module, inputs): |
| |
| if not inputs: |
| return inputs |
| return (inputs[0].to(torch.float32),) + inputs[1:] |
|
|
| original_dtype_local = target_model.lm_head.weight.dtype |
| target_model.lm_head = target_model.lm_head.float() |
| target_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32) |
|
|
| if target_model.config.tie_word_embeddings: |
|
|
| def cast_outputs_to_original_dtype(module, args, output): |
| return output.to(original_dtype_local) |
|
|
| |
| target_model.model.embed_tokens.register_forward_hook(cast_outputs_to_original_dtype) |
|
|
| _cast_lm_head_to_fp32(model) |
| if self.ref_model is not None: |
| _cast_lm_head_to_fp32(self.ref_model) |
|
|
| |
| if self.use_liger_kernel: |
| if not is_liger_kernel_available(): |
| raise ImportError( |
| "Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`." |
| ) |
| |
| self._forward_redirection = _ForwardRedirection() |
|
|
| self.liger_grpo_loss = LigerFusedLinearGRPOLoss( |
| beta=self.beta, |
| epsilon_low=self.epsilon_low, |
| epsilon_high=self.epsilon_high, |
| temperature=self.temperature, |
| use_ref_model=self.beta != 0.0, |
| loss_type=self.loss_type, |
| max_completion_length=self.max_completion_length, |
| importance_sampling_level=self.importance_sampling_level, |
| ) |
|
|
| |
| self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} |
| self._total_train_tokens = 0 |
| self._current_train_step_time = 0.0 |
| self.log_completions = args.log_completions |
| self.log_unique_prompts = args.log_unique_prompts |
| self.num_completions_to_print = args.num_completions_to_print |
| |
| self._logs = { |
| "images": deque(maxlen=args.generation_batch_size), |
| "prompt": deque(maxlen=args.generation_batch_size), |
| "completion": deque(maxlen=args.generation_batch_size), |
| "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), |
| "advantages": deque(maxlen=args.generation_batch_size), |
| "extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), |
| } |
| |
| self._pending_extra_logs = defaultdict(list) |
| self._pending_metrics = defaultdict(list) |
|
|
| |
| |
| |
| set_seed(args.seed, device_specific=True) |
|
|
| if self.use_vllm: |
| |
| self.vllm_generation = VLLMGeneration( |
| model=self.model, |
| accelerator=self.accelerator, |
| is_fsdp_enabled=self.is_fsdp_enabled, |
| processing_class=self.processing_class, |
| |
| mode=args.vllm_mode, |
| structured_outputs_regex=args.vllm_structured_outputs_regex, |
| |
| server_base_url=args.vllm_server_base_url, |
| server_host=args.vllm_server_host, |
| server_port=args.vllm_server_port, |
| group_port=args.vllm_group_port, |
| server_timeout=args.vllm_server_timeout, |
| |
| tensor_parallel_size=args.vllm_tensor_parallel_size, |
| gpu_memory_utilization=args.vllm_gpu_memory_utilization, |
| max_model_length=args.vllm_max_model_length, |
| max_num_seqs=args.per_device_train_batch_size |
| * args.vllm_tensor_parallel_size |
| * args.steps_per_generation, |
| enable_sleep_mode=args.vllm_enable_sleep_mode, |
| model_impl=args.vllm_model_impl, |
| |
| repetition_penalty=self.repetition_penalty, |
| temperature=self.temperature, |
| top_p=self.top_p, |
| top_k=self.top_k, |
| min_p=self.min_p, |
| max_completion_length=self.max_completion_length, |
| logprobs=0, |
| generation_kwargs=args.generation_kwargs, |
| ) |
| self._last_loaded_step = -1 |
| else: |
| generation_kwargs = { |
| "max_new_tokens": self.max_completion_length, |
| "do_sample": True, |
| "pad_token_id": tokenizer.pad_token_id, |
| "bos_token_id": tokenizer.bos_token_id, |
| "eos_token_id": tokenizer.eos_token_id, |
| "temperature": self.temperature, |
| "top_p": self.top_p, |
| "top_k": self.top_k, |
| "min_p": self.min_p, |
| "repetition_penalty": self.repetition_penalty, |
| "cache_implementation": args.cache_implementation, |
| } |
| if args.generation_kwargs is not None: |
| generation_kwargs.update(args.generation_kwargs) |
| self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True) |
| |
| self.generation_kwargs = generation_kwargs |
|
|
| |
| |
| |
| self.model_accepts_loss_kwargs = False |
|
|
| |
| self.model.add_model_tags(self._tag_names) |
|
|
| if self.ref_model is not None: |
| if self.is_deepspeed_enabled: |
| self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) |
| elif self.is_fsdp_enabled: |
| self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) |
| else: |
| self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) |
|
|
| if args.sync_ref_model: |
| if self.beta == 0.0: |
| raise ValueError( |
| "You passed `sync_ref_model=True` while `beta=0.0`, which means the reference model is not used " |
| "during training. Consequently, GRPOTrainer does not create a `ref_model` instance, and there is " |
| "nothing to synchronize. Please set `sync_ref_model=False`, or set `beta` to a non-zero value." |
| ) |
| if is_peft_model(model): |
| raise NotImplementedError( |
| "You passed `sync_ref_model=True` while using a PEFT model, which is currently not supported. " |
| "With PEFT, GRPOTrainer does not keep a separate reference model in memory; instead, it recovers " |
| "reference behavior by temporarily disabling the adapter. As a result, there is no standalone " |
| "`ref_model` instance to synchronize. Use `sync_ref_model=False`, or opt for full fine-tuning if " |
| "you need a synced reference model. If you need `sync_ref_model` to work with PEFT, please open a " |
| "feature request at https://github.com/huggingface/trl/issues." |
| ) |
| self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) |
|
|
| for i, reward_func in enumerate(self.reward_funcs): |
| if isinstance(reward_func, PreTrainedModel): |
| if self.is_deepspeed_enabled: |
| self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) |
| else: |
| |
| self.reward_funcs[i] = self.accelerator.prepare_model( |
| reward_func, evaluation_mode=True, device_placement=True |
| ) |
|
|
| if self.accelerator.is_main_process and self.log_completions: |
| os.makedirs(os.path.join(self.args.output_dir, "completions"), exist_ok=True) |
| if self.args.log_completions_hub_repo is not None: |
| repo_id = self.args.log_completions_hub_repo |
| create_repo(repo_id, private=self.args.hub_private_repo, repo_type="dataset", exist_ok=True) |
| template_path = pkg_resources.files("trl").joinpath("templates/completions_dataset_card.md") |
| card_data = DatasetCardData( |
| pretty_name="TRL Completion logs", |
| tags=["trl", "trl-logs", "completions"], |
| ) |
| card = DatasetCard.from_template( |
| card_data=card_data, |
| template_path=str(template_path), |
| repo_id=repo_id, |
| hub_model_id=self.args.hub_model_id, |
| ) |
| card.push_to_hub(repo_id) |
| self.commit_scheduler = CommitScheduler( |
| repo_id=repo_id, |
| repo_type="dataset", |
| folder_path=f"{self.args.output_dir}/completions", |
| every=2, |
| allow_patterns=["*.parquet"], |
| ) |
|
|
def _set_signature_columns_if_needed(self):
    """
    Set the default list of dataset columns to keep to the prompt and image fields.

    An already-populated `_signature_columns` is left untouched; only an unset (`None`) value is replaced.
    """
    if self._signature_columns is not None:
        return
    self._signature_columns = ["prompt", "image", "images"]
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
def get_train_dataloader(self):
    """
    Build the training dataloader.

    A single dataloader batch feeds `args.steps_per_generation` training steps (see `_prepare_inputs`), so the
    batch size is the per-step train batch size multiplied by `steps_per_generation`. Index sampling is delegated
    to `_get_train_sampler`.
    """
    generation_batch_size = self._train_batch_size * self.args.steps_per_generation
    return self._get_dataloader(
        dataset=self.train_dataset,
        description="Training",
        batch_size=generation_batch_size,
        sampler_fn=self._get_train_sampler,
        is_training=True,
    )
|
|
def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
    """
    Return the index sampler used for training.

    Args:
        dataset (`Dataset`, *optional*): Dataset to sample from. Defaults to `self.train_dataset`.

    Returns:
        A `RepeatSampler` configured with:
        - `mini_repeat_count=num_generations`: presumably each prompt index is repeated consecutively once per
          completion, so all completions of a prompt are generated together (see `RepeatSampler`).
        - `batch_size=generation_batch_size // num_generations`: the number of distinct prompts per generation batch.
        - `repeat_count=num_iterations * steps_per_generation`: presumably how many times each such batch is
          replayed, matching how often `_prepare_inputs` reuses buffered generations.
    """
    source = self.train_dataset if dataset is None else dataset
    prompts_per_generation_batch = self.args.generation_batch_size // self.num_generations
    replays = self.num_iterations * self.args.steps_per_generation
    return RepeatSampler(
        data_source=source,
        mini_repeat_count=self.num_generations,
        batch_size=prompts_per_generation_batch,
        repeat_count=replays,
        shuffle=self.shuffle_dataset,
        seed=self.args.seed,
    )
|
|
def _get_eval_sampler(self, eval_dataset) -> Sampler:
    """
    Return the index sampler used for evaluation.

    Uses `mini_repeat_count=num_generations_eval`. Presumably each evaluation prompt index is then repeated that
    many times, one per completion (see `RepeatSampler`).
    """
    repeats_per_prompt = self.num_generations_eval
    return RepeatSampler(
        data_source=eval_dataset,
        mini_repeat_count=repeats_per_prompt,
        seed=self.args.seed,
    )
|
|
@profiling_decorator
def _get_last_hidden_state(
    self,
    unwrapped_model,
    input_ids,
    attention_mask,
    logits_to_keep,
    pixel_values=None,
    image_grid_thw=None,
    pixel_attention_mask=None,
    image_sizes=None,
    pixel_position_ids=None,
):
    """
    Run the inner `.model` of the policy and return its last hidden state at the completion positions.

    Args:
        unwrapped_model: Unwrapped policy model. For PEFT models, `base_model.model` is used instead.
        input_ids: Prompt+completion token ids.
        attention_mask: Attention mask aligned with `input_ids`.
        logits_to_keep (`int`): Number of trailing positions to return.
        pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids: Optional multimodal
            inputs forwarded unchanged when not `None`. `image_grid_thw` is only forwarded together with
            `pixel_values`.

    Returns:
        The `last_hidden_state` with its final position dropped (it predicts the token after the sequence),
        restricted to the trailing `logits_to_keep` positions.
    """
    target = unwrapped_model.base_model.model if is_peft_model(unwrapped_model) else unwrapped_model

    model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
    if pixel_values is not None:
        # The grid description is only meaningful alongside the pixels it describes.
        if image_grid_thw is not None:
            model_inputs["image_grid_thw"] = image_grid_thw
        model_inputs["pixel_values"] = pixel_values
    optional_inputs = {
        "pixel_attention_mask": pixel_attention_mask,
        "image_sizes": image_sizes,
        "pixel_position_ids": pixel_position_ids,
    }
    model_inputs.update({key: value for key, value in optional_inputs.items() if value is not None})

    if "logits_to_keep" in self.model_kwarg_keys:
        # Ask for one extra position because the final one is discarded below.
        model_inputs["logits_to_keep"] = logits_to_keep + 1
    model_inputs["use_cache"] = False

    hidden_states = target.model(**model_inputs).last_hidden_state
    hidden_states = hidden_states[:, :-1, :]
    return hidden_states[:, -logits_to_keep:, :]
|
|
def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
    """
    Flag valid tokens whose entropy reaches the `threshold` quantile of all valid entropies across processes.

    Args:
        entropies (`torch.Tensor`):
            Per-token entropies, shape `(batch_size, seq_len)`.
        mask (`torch.Tensor`):
            Same shape as `entropies`; non-zero marks a valid token, zero marks padding.
        threshold (`float`):
            Quantile in `[0.0, 1.0]`. The cutoff is the `threshold` quantile of the valid entropies gathered from
            every process.

    Returns:
        `torch.Tensor`:
            Boolean tensor shaped like `entropies`. It is `True` where the token is valid and its entropy is
            `>=` the cutoff, and all `False` when no process has any valid token.
    """
    valid = mask.bool()
    local_values = entropies[valid].float()

    # Processes can hold different numbers of valid tokens. Pad them to a common length with a sentinel so the
    # gather works, then drop the sentinel. This assumes entropies are never negative, so -1e9 cannot collide
    # with a real value.
    sentinel = -1e9
    padded = self.accelerator.pad_across_processes(local_values, dim=0, pad_index=sentinel)
    pooled = self.accelerator.gather(padded)
    pooled = pooled[pooled != sentinel]

    if pooled.numel() == 0:
        return torch.zeros_like(entropies, dtype=torch.bool)

    cutoff = torch.quantile(pooled, threshold)
    reaches_cutoff = (entropies * mask.float()) >= cutoff
    return reaches_cutoff & valid
|
|
@profiling_decorator
def _get_per_token_logps_and_entropies(
    self,
    model,
    input_ids,
    attention_mask,
    logits_to_keep,
    batch_size=None,
    compute_entropy=False,
    pixel_values=None,
    image_grid_thw=None,
    num_images=None,
    pixel_attention_mask=None,
    image_sizes=None,
    token_type_ids=None,
    mm_token_type_ids=None,
    pixel_position_ids=None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Compute per-token log-probabilities and, optionally, entropies of the completion tokens.

    The forward pass runs in chunks of `batch_size` rows. The last chunk may be smaller.

    Args:
        model: Model called as `model(**inputs)`; its output must expose `.logits`.
        input_ids: Prompt+completion token ids; the last `logits_to_keep` columns are scored.
        attention_mask: Attention mask aligned with `input_ids`.
        logits_to_keep (`int`): Number of trailing (completion) positions to score.
        batch_size (`int`, *optional*): Rows per forward chunk. Defaults to all rows at once.
        compute_entropy (`bool`): Also compute per-token entropies (without gradients).
        pixel_values, image_grid_thw, num_images: Optional image inputs.
            - When both `pixel_values` and `image_grid_thw` are given, `pixel_values` rows are sliced per chunk
              using `prod(image_grid_thw)` rows per image and `num_images[i]` images for sample `i`.
            - Otherwise `pixel_values` is sliced per sample row.
        pixel_attention_mask, image_sizes, token_type_ids, mm_token_type_ids, pixel_position_ids: Optional
            per-sample inputs sliced like `input_ids`.

    Returns:
        `tuple[torch.Tensor, torch.Tensor | None]`: `(logps, entropies)`, concatenated over all chunks along the
        batch dimension. `entropies` is `None` unless `compute_entropy=True`.
    """
    total_rows = input_ids.size(0)
    batch_size = batch_size or total_rows

    # Offsets of each sample's rows in `pixel_values` / `image_grid_thw`. These do not depend on the chunk, so
    # they are computed once instead of on every iteration.
    cum_rows = cum_imgs = None
    if image_grid_thw is not None and pixel_values is not None:
        rows_per_image = image_grid_thw.prod(dim=-1)
        rows_per_sample = torch.stack([s.sum() for s in torch.split(rows_per_image, num_images)])
        cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)])
        cum_imgs = torch.tensor([0] + num_images).cumsum(0)

    all_logps = []
    all_entropies = []
    for start in range(0, total_rows, batch_size):
        # Clamp so a smaller final chunk does not index past the end of the offset tables.
        end = min(start + batch_size, total_rows)
        input_ids_batch = input_ids[start:end]
        model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask[start:end]}
        if cum_rows is not None:
            row_start, row_end = cum_rows[start].item(), cum_rows[end].item()
            model_inputs["pixel_values"] = pixel_values[row_start:row_end]
            img_start, img_end = cum_imgs[start], cum_imgs[end]
            model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end]
        elif pixel_values is not None:
            model_inputs["pixel_values"] = pixel_values[start:end]
        if pixel_attention_mask is not None:
            model_inputs["pixel_attention_mask"] = pixel_attention_mask[start:end]
        if image_sizes is not None:
            model_inputs["image_sizes"] = image_sizes[start:end]
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids[start:end]
        if mm_token_type_ids is not None:
            model_inputs["mm_token_type_ids"] = mm_token_type_ids[start:end]
        if pixel_position_ids is not None:
            model_inputs["pixel_position_ids"] = pixel_position_ids[start:end]

        if "logits_to_keep" in self.model_kwarg_keys:
            # One extra position, because the last logit is dropped below.
            model_inputs["logits_to_keep"] = logits_to_keep + 1
        model_inputs["use_cache"] = False

        logits = model(**model_inputs).logits
        # Drop the last position: it predicts the token after the sequence.
        logits = logits[:, :-1, :]
        # Keep only the completion positions (a no-op when the model honoured `logits_to_keep`).
        logits = logits[:, -logits_to_keep:, :]
        # Scale by the sampling temperature so log-probs match the sampling distribution.
        logits.div_(self.temperature)
        completion_ids = input_ids_batch[:, -logits_to_keep:]
        all_logps.append(selective_log_softmax(logits, completion_ids))

        if compute_entropy:
            with torch.no_grad():
                all_entropies.append(entropy_from_logits(logits))

    logps = torch.cat(all_logps, dim=0)
    entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
    return logps, entropies
|
|
def training_step(self, model, inputs, num_items_in_batch):
    """
    Run one training step via the parent implementation and record how long it took.

    Wall-clock durations accumulate across steps. Every `current_gradient_accumulation_steps` steps, the
    accumulated time is appended to the `step_time` train metric and the accumulator is reset.
    """
    started_at = time.perf_counter()
    output = super().training_step(model, inputs, num_items_in_batch)
    self._step += 1
    self._current_train_step_time += time.perf_counter() - started_at

    finished_accumulation = self._step % self.current_gradient_accumulation_steps == 0
    if finished_accumulation:
        self._metrics["train"]["step_time"].append(self._current_train_step_time)
        self._current_train_step_time = 0.0
    return output
|
|
@profiling_decorator
def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
    """
    Turn a dataloader batch into model inputs, generating and scoring completions when needed.

    - Evaluation (`self.model.training` is False): completions are generated and scored on every call.
    - Training: a dataloader batch holds `steps_per_generation` steps' worth of prompts.
      - Completions are generated for it whenever `self._step` is a multiple of
        `steps_per_generation * num_iterations`, or when nothing is buffered yet.
      - The result is shuffled, split into `steps_per_generation` chunks and buffered.
      - Each step returns chunk `self._step % steps_per_generation`, so every chunk is reused `num_iterations`
        times before the next generation.
    """
    if not self.model.training:
        return self._generate_and_score_completions(generation_batch)

    generate_every = self.args.steps_per_generation * self.num_iterations
    if self._step % generate_every == 0 or self._buffered_inputs is None:
        scored_batch = self._generate_and_score_completions(generation_batch)
        scored_batch = split_pixel_values_by_grid(scored_batch)
        scored_batch = shuffle_sequence_dict(scored_batch)
        chunks = split_tensor_dict(scored_batch, self.args.steps_per_generation)
        self._buffered_inputs = [unsplit_pixel_values_by_grid(chunk) for chunk in chunks]
    return self._buffered_inputs[self._step % self.args.steps_per_generation]
|
|
def _log_completion_extra(self, column: str, values: list):
    """
    Queue values for an extra column of the completions table.

    Reward functions reach this through the `log_extra` keyword argument.

    Args:
        column (`str`):
            Column name.
        values (`list`):
            One value per sample in the batch. They are appended after any values already pending for `column`.
    """
    pending_column = self._pending_extra_logs[column]
    pending_column.extend(values)
|
|
def _log_metric(self, name: str, value: float):
    """
    Queue a scalar metric reported by a reward function through the `log_metric` keyword argument.

    Values queued under one name are averaged over each logging step and reported next to built-in metrics such
    as `kl` and `entropy`.

    Args:
        name (`str`):
            Metric name.
        value (`float`):
            Scalar value for the current batch.
    """
    pending_values = self._pending_metrics[name]
    pending_values.append(value)
|
|
@profiling_decorator
def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
    """
    Score every completion with every reward function and gather the scores from all processes.

    Reward functions are handled in three ways:
    - `nn.Module` reward models score the decoded prompt+completion text (first logit).
    - Coroutine functions are run concurrently on `self.async_loop` after the others.
    - Plain callables are called directly.

    Args:
        inputs (`list[dict]`): Per-sample examples. Every key except `prompt`, `completion` and `completion_ids`
            is forwarded to callable reward functions as a list-valued keyword argument.
        prompts (`list`): One prompt per sample.
        completions (`list`): Completions aligned with `prompts`.
        completion_ids_list (`list`): Completion token ids aligned with `prompts`.

    Returns:
        `torch.Tensor`: Rewards of shape `(num_samples, num_reward_funcs)` after `gather` across processes. A
        `None` returned by a reward function is stored as `NaN`.
    """
    device = self.accelerator.device
    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)

    # Forward every dataset column except the prompt/completion fields themselves.
    keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
    # Expose the trainer state and the logging hooks to reward functions.
    reward_kwargs["trainer_state"] = self.state
    reward_kwargs["log_extra"] = self._log_completion_extra
    reward_kwargs["log_metric"] = self._log_metric
    # Set once, before any reward function runs, so sync and async reward functions receive the same kwargs.
    # Previously this was only set while handling sync functions, so async-only setups never received it.
    if self.environments is not None:
        reward_kwargs["environments"] = self.environments

    async_funcs_info = []
    for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True)
    ):
        if isinstance(reward_func, nn.Module):
            with profiling_context(self, reward_func_name):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
                    texts = [
                        apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
                        for x in messages
                    ]
                else:
                    texts = [p + c for p, c in zip(prompts, completions, strict=True)]
                reward_inputs = reward_processing_class(
                    text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    # The first logit of the reward model is used as the reward.
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
        elif inspect.iscoroutinefunction(reward_func):
            # Deferred: all async reward functions run concurrently below.
            async_funcs_info.append((i, reward_func, reward_func_name))
        else:
            with profiling_context(self, reward_func_name):
                output_reward_func = reward_func(
                    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                )
                # Store `None` rewards as NaN.
                output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    if async_funcs_info:

        async def _invoke_async(index, func, func_name):
            with profiling_context(self, func_name):
                output = await func(
                    prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                )
                output = [r if r is not None else torch.nan for r in output]
            return index, output

        async def _run_async_funcs():
            coros = [_invoke_async(i, func, func_name) for (i, func, func_name) in async_funcs_info]
            return await asyncio.gather(*coros)

        async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_loop).result()
        for idx, output_reward_func in async_results:
            rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # Warn (once, for the first such row) when no reward function produced a reward for a sample.
    all_nan_rows = torch.isnan(rewards_per_func).all(dim=1)
    if all_nan_rows.any():
        nan_row_idx = all_nan_rows.nonzero(as_tuple=True)[0][0]
        # Only the per-sample list kwargs are indexed by row. The state, the hooks and `environments` are skipped;
        # `environments` is not assumed to be indexable per sample.
        row_reward_kwargs = {
            key: value[nan_row_idx]
            for key, value in reward_kwargs.items()
            if key not in ("trainer_state", "log_extra", "log_metric", "environments")
        }
        row_reward_kwargs["prompt"] = prompts[nan_row_idx]
        row_reward_kwargs["completion"] = completions[nan_row_idx]
        logger.warning(
            f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n"
            "Please ensure that at least one reward function returns a valid reward."
        )

    # Gather the rewards from all processes.
    rewards_per_func = gather(rewards_per_func)
    return rewards_per_func
|
|
def _tokenize_prompts(self, prompts: list):
    """
    Tokenize prompts for generation and collect their images and extra processor outputs.

    Args:
        prompts (`list`): Conversational prompts (lists of messages) or plain-text prompts.

    Returns:
        `tuple`: `(prompt_ids, images, multimodal_fields)`.
        - `prompt_ids`: one unpadded list of token ids per prompt.
        - `images`: `None` when no prompt contains an image. Otherwise one entry per prompt: that prompt's image
          list, or `None` if it has none.
        - `multimodal_fields`: every chat-template output except `input_ids` and `attention_mask` (empty for
          plain-text prompts).
    """
    if not is_conversational({"prompt": prompts[0]}):
        prompt_ids = self.processing_class(text=prompts)["input_ids"]
        return prompt_ids, None, {}

    # Collect the images embedded in each conversation (message content given as a list of typed parts).
    images = []
    found_any_image = False
    for conversation in prompts:
        conversation_images = [
            part["image"]
            for message in conversation
            if isinstance(message["content"], list)
            for part in message["content"]
            if part["type"] == "image"
        ]
        found_any_image = found_any_image or bool(conversation_images)
        images.append(conversation_images or None)
    if not found_any_image:
        images = None

    tokenized = self.processing_class.apply_chat_template(
        conversation=prompts,
        tools=self.tools or None,
        chat_template=self.chat_template,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        padding=True,
        **self.chat_template_kwargs,
    )
    # Strip the batch padding again: keep only positions whose attention mask is set.
    prompt_ids = [
        [token for token, keep in zip(row_ids, row_mask, strict=True) if keep]
        for row_ids, row_mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True)
    ]
    multimodal_fields = {key: value for key, value in tokenized.items() if key not in ("input_ids", "attention_mask")}
    return prompt_ids, images, multimodal_fields
|
|
def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
    """
    Generate one completion turn per prompt with the configured backend.

    Backends are tried in order: vLLM (`use_vllm`), transformers continuous batching (`use_transformers_paged`),
    then plain `generate`.

    Args:
        prompt_ids (`list[list[int]]`): Unpadded prompt token ids.
        images: Per-prompt images or `None`; only passed to vLLM here.
        multimodal_fields (`dict`): Extra processor outputs; only passed to plain `generate`.

    Returns:
        `tuple`: `(completion_ids, logprobs)`.
        - `completion_ids`: one token-id sequence per generated completion.
        - `logprobs`: per-token log-probability lists (vLLM only), otherwise `None`.
    """
    device = self.accelerator.device
    mode = "train" if self.model.training else "eval"

    if self.use_vllm:
        # Push the current policy weights to vLLM at most once per global step.
        if self.state.global_step != self._last_loaded_step:
            with profiling_context(self, "sync_weights"):
                self.vllm_generation.sync_weights()
            self._last_loaded_step = self.state.global_step

        num_generations = self.num_generations if mode == "train" else self.num_generations_eval
        _, completion_ids, logprobs, _ = self.vllm_generation.generate(
            prompts=prompt_ids,
            images=images,
            num_generations=num_generations,
            profiler=profiling_context(self, "vLLM.generate"),
        )
        # Keep the first log-prob entry per token (presumably the sampled token's, given `logprobs=0`).
        logprobs = [[lp[0] for lp in seq] for seq in logprobs]

    elif self.use_transformers_paged:
        with (
            profiling_context(self, "transformers.generate_batch"),
            unwrap_model_for_generation(
                self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            # Cast to the dtype used for training.
            if self.args.bf16:
                unwrapped_model.to(torch.bfloat16)
            elif self.args.fp16:
                unwrapped_model.to(torch.float16)
            if self.args.cast_lm_head_to_fp32:
                unwrapped_model.lm_head.to(torch.float32)
            # `generate_batch` may switch the model to eval mode. Restore the mode the model had before the call
            # rather than forcing `.train()`: forcing it during evaluation flips `self.model.training`, and
            # `_prepare_inputs` would then treat the following evaluation batches as training batches.
            was_training = unwrapped_model.training
            with torch.inference_mode():
                all_outputs = unwrapped_model.generate_batch(
                    prompt_ids, generation_config=self.generation_config, progress_bar=False
                )
            unwrapped_model.train(was_training)
        completion_ids = [output.generated_tokens for output in all_outputs.values()]
        logprobs = None

    else:
        # Left-pad the prompts into a batch for `generate`.
        prompt_tensors = [torch.tensor(ids) for ids in prompt_ids]
        padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left")
        attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left")
        generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask}
        for k, v in multimodal_fields.items():
            if isinstance(v, torch.Tensor):
                generate_inputs[k] = v
            elif isinstance(v, list) and v and isinstance(v[0], list):
                # Ragged per-sample lists are left-padded like the prompt ids.
                generate_inputs[k] = pad([torch.tensor(x) for x in v], padding_value=0, padding_side="left")
            else:
                generate_inputs[k] = torch.tensor(np.array(v))
        generate_inputs = super()._prepare_inputs(generate_inputs)

        with (
            profiling_context(self, "transformers.generate"),
            unwrap_model_for_generation(
                self.model_wrapped,
                self.accelerator,
                gather_deepspeed3_params=self.args.ds3_gather_for_generation,
                generation_kwargs=self.generation_kwargs,
            ) as unwrapped_model,
            torch.no_grad(),
            FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
        ):
            prompt_completion_ids = unwrapped_model.generate(
                **generate_inputs, generation_config=self.generation_config
            )
        # Everything after the (padded) prompt is the completion.
        prompt_length = generate_inputs["input_ids"].size(1)
        completion_ids = prompt_completion_ids[:, prompt_length:]

        # Keep each completion up to and including its first EOS (the whole row when there is none).
        is_eos = completion_ids == self.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
        completion_ids = [
            c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True)
        ]
        logprobs = None

    return completion_ids, logprobs
|
|
def _get_tool_suffix_ids(self, tool_messages):
    """
    Return the token ids the chat template emits for `tool_messages` (plus the generation prompt) after an
    assistant turn.

    How it works:
    - A minimal dummy user/assistant conversation is tokenized with and without the tool messages.
    - The prefix rendering is cut right after its last EOS token, so anything the template emits after the
      assistant's EOS ends up in the returned suffix.
    - The suffix is appended after a completion, and generated completions are expected to end at EOS.

    Args:
        tool_messages (`list[dict]`): Tool messages to render.

    Returns:
        `list[int]`: Token ids to append after the assistant completion.

    Raises:
        ValueError: If the tokenized dummy conversation contains no EOS token, or if the EOS-trimmed prefix is not
            a prefix of the full tokenization.
    """
    dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}]
    prefix_ids = self.processing_class.apply_chat_template(
        dummy_messages,
        add_generation_prompt=False,
        chat_template=self.chat_template,
        return_dict=False,
        **self.chat_template_kwargs,
    )
    full_ids = self.processing_class.apply_chat_template(
        dummy_messages + tool_messages,
        add_generation_prompt=True,
        chat_template=self.chat_template,
        return_dict=False,
        **self.chat_template_kwargs,
    )

    # Trim the prefix after its last EOS. Fail with an explicit message when the template emits no EOS at all;
    # previously `max()` over an empty sequence raised an opaque ValueError.
    eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id]
    if not eos_positions:
        raise ValueError(
            "Unexpected tokenization: no EOS token found in the tokenized dummy conversation, so the tool suffix "
            "cannot be located."
        )
    prefix_ids = prefix_ids[: eos_positions[-1] + 1]

    if full_ids[: len(prefix_ids)] != prefix_ids:
        raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")

    return full_ids[len(prefix_ids) :]
|
|
| def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): |
| |
| tool_calls = [completion[0].get("tool_calls") for completion in completions] |
| idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] |
| tool_calls = [tool_calls[idx] for idx in idxs_with_tool] |
| tool_mask = [[1] * len(ids) for ids in completion_ids] |
| tool_call_count = 0 |
| tool_failure_count = 0 |
| iteration_num = 0 |
| while idxs_with_tool and iteration_num < self.max_tool_calling_iterations: |
| prompt_completion_tools = [prompts[i] for i in idxs_with_tool] |
|
|
| |
| for idx in range(len(idxs_with_tool)): |
| idx_with_tool = idxs_with_tool[idx] |
| tool_call_list = tool_calls[idx] |
| prompt_completion_tool = prompt_completion_tools[idx] |
| sync_tool_dict = self._sync_tool_dicts[idx_with_tool] |
| async_tool_dict = self._async_tool_dicts[idx_with_tool] |
| |
| prompt_completion_tool.append(completions[idx_with_tool][-1]) |
| async_coros = [] |
| tool_call_results = [] |
| for tool_call in tool_call_list: |
| tool_call_count += 1 |
| if tool_call["type"] == "function": |
| function = tool_call["function"] |
| name = function["name"] |
| try: |
| if name in sync_tool_dict: |
| tool_call_results.append((name, sync_tool_dict[name](**function["arguments"]))) |
| elif name in async_tool_dict: |
| async_coros.append((name, async_tool_dict[name](**function["arguments"]))) |
| else: |
| raise ValueError(f"Tool {name} not found.") |
| except Exception as e: |
| tool_failure_count += 1 |
| result = {"error": str(e)} |
| tool_call_results.append((name, result)) |
| else: |
| tool_failure_count += 1 |
| name = tool_call.get("name", "unknown") |
| tool_call_results.append((name, {"error": f"Unsupported tool call type: {tool_call['type']}"})) |
|
|
| if async_coros: |
|
|
| async def _run_async_tools(async_coros): |
| coros = [coro for _, coro in async_coros] |
| results = await asyncio.gather(*coros, return_exceptions=True) |
| return [(name, result) for (name, _), result in zip(async_coros, results, strict=False)] |
|
|
| async_results = asyncio.run_coroutine_threadsafe( |
| _run_async_tools(async_coros), self.async_loop |
| ).result() |
|
|
| for name, result in async_results: |
| if isinstance(result, Exception): |
| tool_failure_count += 1 |
| tool_call_results.append((name, {"error": str(result)})) |
| else: |
| tool_call_results.append((name, result)) |
|
|
| for name, result in tool_call_results: |
| tool_message = {"role": "tool", "name": name, "content": str(result)} |
| prompt_completion_tool.append(tool_message) |
| completions[idx_with_tool].append(tool_message) |
|
|
| |
| prompt_completion_tool_ids = [] |
| for idx in range(len(idxs_with_tool)): |
| idx_with_tool = idxs_with_tool[idx] |
| |
| tool_messages = [] |
| for message in reversed(completions[idx_with_tool]): |
| if message["role"] == "tool": |
| tool_messages.insert(0, message) |
| else: |
| break |
| suffix_ids = self._get_tool_suffix_ids(tool_messages) |
| prompt_completion_tool_ids.append( |
| prompt_ids[idx_with_tool] + completion_ids[idx_with_tool] + suffix_ids |
| ) |
|
|
| |
| |
| if self.use_vllm and self.vllm_mode == "colocate": |
| max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len |
| elif self.use_vllm and self.vllm_mode == "server": |
| max_model_len = self.model.config.max_position_embeddings |
| elif not self.use_vllm: |
| max_model_len = self.model.config.max_position_embeddings |
| else: |
| raise NotImplementedError( |
| f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" |
| ) |
| overlong = [len(pct) >= max_model_len for pct in prompt_completion_tool_ids] |
| for idx in range(len(idxs_with_tool)): |
| idx_with_tool = idxs_with_tool[idx] |
| if overlong[idx]: |
| prompt_length = len(prompt_ids[idx_with_tool]) |
| ct = prompt_completion_tool_ids[idx][prompt_length : prompt_length + self.max_completion_length] |
| completion_ids[idx_with_tool] = ct |
| tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) |
| if logprobs is not None: |
| logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) |
| |
| idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] |
| prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] |
| prompt_completion_tool_ids = [ |
| pct for pct, o in zip(prompt_completion_tool_ids, overlong, strict=True) if not o |
| ] |
| if not idxs_with_tool: |
| break |
|
|
| |
| loop_images = [images[i] for i in idxs_with_tool] if images else None |
| loop_multimodal_fields = ( |
| {k: [v[i] for i in idxs_with_tool] for k, v in multimodal_fields.items()} if multimodal_fields else {} |
| ) |
|
|
| |
| post_tool_ids, post_tool_logprobs = self._generate_single_turn( |
| prompt_completion_tool_ids, loop_images, loop_multimodal_fields |
| ) |
|
|
| |
| for idx in range(len(idxs_with_tool)): |
| idx_with_tool = idxs_with_tool[idx] |
| prompt_len = len(prompt_ids[idx_with_tool]) |
| completion_tool_ids = prompt_completion_tool_ids[idx][prompt_len:] |
| excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length |
| if excess_length > 0: |
| |
| post_tool_ids[idx] = post_tool_ids[idx][:-excess_length] |
| if logprobs is not None: |
| post_tool_logprobs[idx] = post_tool_logprobs[idx][:-excess_length] |
| excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length |
| if excess_length > 0: |
| |
| prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length] |
|
|
| |
| for idx in range(len(idxs_with_tool)): |
| idx_with_tool = idxs_with_tool[idx] |
| prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) |
| prompt_length = len(prompt_ids[idx_with_tool]) |
| completion_length = len(completion_ids[idx_with_tool]) |
| post_tool_length = len(post_tool_ids[idx]) |
| tool_length = prompt_completion_tool_length - prompt_length - completion_length |
| tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length |
| if logprobs is not None: |
| logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] |
|
|
| |
| for idx in range(len(idxs_with_tool)): |
| idx_with_tool = idxs_with_tool[idx] |
| prompt_length = len(prompt_ids[idx_with_tool]) |
| pct = prompt_completion_tool_ids[idx] |
| completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] |
|
|
| |
| post_tool_completions = [ |
| parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids |
| ] |
|
|
| |
| for idx in range(len(idxs_with_tool)): |
| idx_with_tool = idxs_with_tool[idx] |
| if post_tool_completions[idx]: |
| completions[idx_with_tool].append(post_tool_completions[idx]) |
|
|
| |
| tool_calls = [completion.get("tool_calls") for completion in post_tool_completions] |
| idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] |
| tool_calls = [tool_call for tool_call in tool_calls if tool_call] |
| iteration_num += 1 |
|
|
| return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count |
|
|
def _generate(self, prompts: list):
    """
    Produce completions for `prompts`, run the tool-calling loop when tools are configured, and record metrics.

    Steps:
    - Generate completions with `rollout_func` when set, otherwise tokenize the prompts and run one generation turn.
    - Decode the completions (as assistant messages for conversational prompts).
    - Run the tool-calling loop if `self.tools` is set.
    - Log token-count, length, clipping and tool-usage metrics, aggregated across processes.

    Args:
        prompts (`list`): Prompts for this batch. They are deep-copied, so the caller's objects are never mutated.

    Returns:
        `tuple`: `(prompt_ids, completion_ids, tool_mask, completions, total_completion_tokens, logprobs,
        extra_fields)`.
        - `tool_mask` is `None` unless tools are used or `rollout_func` returned an `env_mask`.
        - `extra_fields` holds any additional keys returned by `rollout_func`.

    Raises:
        ValueError: If `rollout_func` does not return `prompt_ids`, `completion_ids` and `logprobs`.
    """
    device = self.accelerator.device
    mode = "train" if self.model.training else "eval"

    # Work on a copy so later in-place appends (e.g. tool messages) do not alter the caller's prompts.
    prompts = copy.deepcopy(prompts)

    if self.rollout_func is not None:
        # Make sure vLLM serves the current policy weights before the custom rollout runs.
        if self.use_vllm and self.state.global_step != self._last_loaded_step:
            with profiling_context(self, "sync_weights"):
                self.vllm_generation.sync_weights()
            self._last_loaded_step = self.state.global_step

        output = self.rollout_func(prompts, self)
        required_keys = {"prompt_ids", "completion_ids", "logprobs"}
        missing_keys = required_keys - output.keys()
        if missing_keys:
            missing_keys_list = sorted(missing_keys)
            raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.")
        extra_fields = {k: v for k, v in output.items() if k not in required_keys}
        prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
        # `rollout_func` provides no images / processor outputs. Bind them so the tool-calling loop below does not
        # raise UnboundLocalError when tools are configured as well.
        images, multimodal_fields = None, {}
    else:
        prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts)
        completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields)
        extra_fields = {}

    # Decode completions: assistant messages for conversational prompts, plain strings otherwise.
    if is_conversational({"prompt": prompts[0]}):
        if (
            Version(transformers.__version__) >= Version("5.0.0")
            and isinstance(self.processing_class, PreTrainedTokenizerBase)
            and hasattr(self.processing_class, "response_schema")
            and self.processing_class.response_schema is not None
        ):
            completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
        else:
            contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
            completions = [[{"role": "assistant", "content": content}] for content in contents]
    else:
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)

    if self.tools:
        (
            tool_mask,
            completions,
            completion_ids,
            logprobs,
            tool_call_count,
            tool_failure_count,
        ) = self._tool_call_loop(
            prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields
        )
    else:
        # An `env_mask` returned by `rollout_func` is used as the token mask, like `tool_mask`.
        tool_mask = extra_fields.pop("env_mask", None)

    # Token counts. When a mask is available, only masked-in completion tokens are counted.
    prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
    if tool_mask is not None:
        completion_lengths = torch.tensor([sum(mask) for mask in tool_mask], device=device)
    else:
        completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device)
    agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
    agg_completion_lengths = self.accelerator.gather(completion_lengths)
    total_prompt_tokens = agg_prompt_lengths.sum()
    total_completion_tokens = agg_completion_lengths.sum()

    if mode == "train":
        self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item()
    self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

    self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

    # A completion counts as clipped when its last token is neither EOS nor padding.
    eos_and_pad = [self.eos_token_id, self.pad_token_id]
    is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
    agg_is_truncated = self.accelerator.gather(is_truncated)
    self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
    term_completion_lengths = agg_completion_lengths[~agg_is_truncated]
    if len(term_completion_lengths) == 0:
        term_completion_lengths = torch.zeros(1, device=device)
    self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
    self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
    self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

    if self.tools:
        agg_tool_call_count = self.accelerator.gather(torch.tensor(tool_call_count, device=device)).sum()
        tool_call_frequency = (agg_tool_call_count / len(agg_prompt_lengths)).item()
        self._metrics[mode]["tools/call_frequency"].append(tool_call_frequency)
        agg_tool_failure_count = self.accelerator.gather(torch.tensor(tool_failure_count, device=device)).sum()
        failure_frequency = (
            (agg_tool_failure_count / agg_tool_call_count).item() if agg_tool_call_count > 0 else 0.0
        )
        self._metrics[mode]["tools/failure_frequency"].append(failure_frequency)

    return (
        prompt_ids,
        completion_ids,
        tool_mask,
        completions,
        total_completion_tokens,
        logprobs,
        extra_fields,
    )
|
|
def _generate_and_score_completions(
    self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
    """Generate completions for a batch of prompts, score them, and build the tensors the loss consumes.

    In order, the method:
    1. Optionally resets per-sample environments and appends their observation to the prompt.
    2. Collects images and converts prompts to multimodal messages.
    3. Calls ``self._generate``.
    4. Pads prompt/completion ids and masks, plus sampling log-probs and the tool mask when present.
    5. Optionally zeroes the masks of truncated completions.
    6. Computes old-policy log-probs when needed, vLLM importance-sampling ratios when enabled, and
       reference log-probs when ``beta != 0``.
    7. Computes rewards and advantages, then records metrics and completion logs.

    ``inputs`` is modified in place: environment observations are appended to the prompt messages, and
    extra fields returned by generation are copied onto each input dict.

    Args:
        inputs: One dict per sample. Each dict must have a ``"prompt"`` key. ``"images"`` / ``"image"`` are read
            when present, and each dict is passed as ``**kwargs`` to ``environment.reset`` when environments
            are configured.

    Returns:
        A dict with these keys:
        - Always present: ``prompt_ids``, ``prompt_mask``, ``completion_ids``, ``completion_mask``,
          ``num_items_in_batch``, and this process's slice of ``advantages``.
        - Present only when computed: old/sampling/reference log-probs, the importance-sampling ratio,
          vision-model forward kwargs, ``num_images``, and ``tool_mask``.

    Raises:
        ValueError: If images are used with non-conversational prompts, or if ``vllm_importance_sampling_mode``,
            ``scale_rewards`` or ``multi_objective_aggregation`` holds an unknown value.
    """
    device = self.accelerator.device
    mode = "train" if self.model.training else "eval"

    prompts = [x["prompt"] for x in inputs]

    # With environments configured, reset one environment per sample, passing that sample's fields as keyword
    # arguments. Any returned observation is appended to the content of the prompt's last message. The prompt
    # objects come from `inputs`, so this edits them in place.
    if self.environments:
        for prompt, environment, reset_kwargs in zip(prompts, self.environments, inputs, strict=True):
            observation = environment.reset(**reset_kwargs)
            if observation is None:
                continue
            prompt[-1]["content"] += observation

    # Per-sample images come from one of two columns:
    # - "images": already a list per sample.
    # - "image": wrapped into a one-element list, or None when the sample has no image.
    if "images" in inputs[0]:
        images = [example.get("images") for example in inputs]
    elif "image" in inputs[0]:
        images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
    else:
        images = None
    # A batch whose image lists are all empty is handled as text-only.
    if images is not None and all(img_list == [] for img_list in images):
        images = None

    # Images can only be attached to conversational (message-list) prompts. Each prompt is rebuilt as
    # multimodal messages that carry its images.
    if images is not None:
        if not is_conversational(inputs[0]):
            raise ValueError(
                "Multimodal training requires conversational prompts. It looks like the dataset contains "
                "non-conversational inputs, likely because a chat template was applied before passing the dataset "
                "to the trainer. Please provide the raw conversational prompts and let the trainer apply the chat "
                "template internally."
            )
        prompts = [
            prepare_multimodal_messages(prompt, image_list)
            for prompt, image_list in zip(prompts, images, strict=True)
        ]

    # Generation returns per-sample lists; they are converted to padded tensors below.
    (
        prompt_ids_list,
        completion_ids_list,
        tool_mask_list,
        completions,
        num_items_in_batch,
        sampling_per_token_logps_list,
        extra_fields,
    ) = self._generate(prompts)

    # Pad into rectangular tensors:
    # - Prompts are padded on the left, so each one ends right before its completion once concatenated below.
    # - Completions, sampling log-probs and tool masks are padded on the right.
    # Prompt/completion masks are 1 on real tokens and 0 on padding. The tool mask is padded with 1; those
    # positions are still excluded because the tool mask is always combined with completion_mask.
    prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]
    prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
    prompt_ids = pad(
        prompt_ids,
        padding_value=self.pad_token_id,
        padding_side="left",
        pad_to_multiple_of=self.pad_to_multiple_of,
    ).to(device=device)
    prompt_mask = pad(
        prompt_mask, padding_value=0, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of
    ).to(device=device)
    completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
    completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
    completion_ids = pad(
        completion_ids,
        padding_value=self.pad_token_id,
        padding_side="right",
        pad_to_multiple_of=self.pad_to_multiple_of,
    ).to(device=device)
    completion_mask = pad(
        completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
    ).to(device=device)
    if sampling_per_token_logps_list is not None:
        sampling_per_token_logps = [torch.tensor(logps) for logps in sampling_per_token_logps_list]
        sampling_per_token_logps = pad(
            sampling_per_token_logps,
            padding_value=0.0,
            padding_side="right",
            pad_to_multiple_of=self.pad_to_multiple_of,
        ).to(device=device)
    else:
        sampling_per_token_logps = None
    if tool_mask_list is not None:
        tool_mask = [torch.tensor(mask) for mask in tool_mask_list]
        tool_mask = pad(
            tool_mask, padding_value=1, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
        ).to(device=device)
    else:
        tool_mask = None

    # Optionally drop truncated completions from the loss. A completion whose last generated token is neither
    # EOS nor PAD counts as truncated, and its completion (and tool) mask row is zeroed.
    if self.mask_truncated_completions:
        eos_and_pad = [self.eos_token_id, self.pad_token_id]
        is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
        # Broadcast the per-sequence keep flag over the token dimension.
        completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
        # Keep the tool mask consistent with the completion mask.
        if tool_mask is not None:
            tool_mask = tool_mask * (~is_truncated).unsqueeze(1).int()

    # Full sequences (prompt followed by completion) and their attention mask, used by the forward passes below.
    prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

    # Only the completion positions need logits.
    logits_to_keep = completion_ids.size(1)
    batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

    num_images = [len(img_list) for img_list in images] if images is not None else None

    # For multimodal batches, run the processor on the chat-templated prompts to get the vision inputs.
    # Every processor output except input_ids/attention_mask is forwarded to the model as an extra keyword
    # argument.
    if images is not None:
        prompts_text = [
            apply_chat_template(
                {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
            )["prompt"]
            for prompt in prompts
        ]
        prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
    else:
        forward_kwargs = {}

    # token_type_ids come from the processor run on the prompts only, so they must be aligned with the padded
    # prompt before use. With pad_to_multiple_of, the padded prompt may be wider, so zeros are prepended first.
    # The result is then extended with zeros over the completion positions.
    if "token_type_ids" in forward_kwargs:
        token_type_ids = forward_kwargs["token_type_ids"]
        if self.pad_to_multiple_of is not None:
            # Extra left padding needed to match the padded prompt width.
            padding_size = prompt_ids.size(1) - token_type_ids.size(1)
            if padding_size > 0:
                token_type_ids = torch.cat(
                    [token_type_ids.new_zeros((token_type_ids.size(0), padding_size)), token_type_ids], dim=1
                )
        forward_kwargs["token_type_ids"] = torch.cat(
            [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
        )
    # Same alignment for mm_token_type_ids.
    if "mm_token_type_ids" in forward_kwargs:
        mm_token_type_ids = forward_kwargs["mm_token_type_ids"]
        if self.pad_to_multiple_of is not None:
            # Extra left padding needed to match the padded prompt width.
            padding_size = prompt_ids.size(1) - mm_token_type_ids.size(1)
            if padding_size > 0:
                mm_token_type_ids = torch.cat(
                    [mm_token_type_ids.new_zeros((mm_token_type_ids.size(0), padding_size)), mm_token_type_ids],
                    dim=1,
                )
        forward_kwargs["mm_token_type_ids"] = torch.cat(
            [mm_token_type_ids, mm_token_type_ids.new_zeros(completion_ids.shape)], dim=1
        )

    # All scoring passes below run without gradients and with gradient checkpointing disabled.
    with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
        # Log-probs of the policy that produced the samples (the "old" policy) are computed in two cases:
        # - gradient_accumulation_steps is not a multiple of steps_per_generation * num_iterations; presumably
        #   the samples can then outlive an optimizer step (TODO confirm).
        # - vLLM importance-sampling correction is enabled; they are always needed then.
        # Otherwise None is stored.
        generate_every = self.args.steps_per_generation * self.num_iterations
        if self.args.gradient_accumulation_steps % generate_every != 0 or (
            self.use_vllm and self.vllm_importance_sampling_correction
        ):
            old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                self.model,
                prompt_completion_ids,
                attention_mask,
                logits_to_keep,
                batch_size,
                num_images=num_images,
                **forward_kwargs,
            )
        else:
            old_per_token_logps = None

        # vLLM importance-sampling correction. The ratio is exp(old - sampling), comparing the training model's
        # log-probs with the sampler's log-probs on loss tokens only (completion_mask, times tool_mask if
        # present). Masked tokens contribute a log difference of 0.
        if self.use_vllm and self.vllm_importance_sampling_correction:
            mask = completion_mask if tool_mask is None else completion_mask * tool_mask
            per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask

            # "sequence_*" modes sum the log difference per row (keepdim, so one value per sequence).
            # "token_*" modes keep it per token.
            sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]
            if sequence_level_is:
                per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True)
                logps_diff = per_sequence_logps_diff
            else:
                logps_diff = per_token_logps_diff

            vllm_importance_sampling_ratio = torch.exp(logps_diff)

            # Ratios above vllm_importance_sampling_cap are handled by mode:
            # - "*_truncate": clamped to the cap.
            # - "*_mask": set to 0.
            if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
                vllm_importance_sampling_ratio = torch.clamp(
                    vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
                )
            elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
                vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
                    vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
                )
            else:
                raise ValueError(
                    f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'."
                )

        # Reference-model log-probs for the KL term, needed only when beta != 0.
        if self.beta != 0.0:
            if self.ref_model is not None:
                ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                    self.ref_model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                    batch_size=batch_size,
                    num_images=num_images,
                    **forward_kwargs,
                )
            else:
                # No separate reference model: the policy itself is evaluated inside use_adapter(...).
                # - If the PEFT model has a "ref" adapter, that adapter is used.
                # - Otherwise adapter_name=None (presumably adapters disabled, i.e. the base model; confirm in
                #   use_adapter).
                model = self.accelerator.unwrap_model(self.model)
                with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None):
                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                        batch_size=batch_size,
                        num_images=num_images,
                        **forward_kwargs,
                    )
        else:
            ref_per_token_logps = None

    # Decoded texts (special tokens skipped); in this method they are only used for completion logging below.
    prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)

    # Copy extra fields returned by generation onto each input dict, presumably so reward functions can read
    # them through `inputs`:
    # - A list value is distributed per sample, for as many samples as it covers.
    # - Any other value is broadcast to every sample.
    if extra_fields:
        for i, inp in enumerate(inputs):
            for key, values in extra_fields.items():
                if isinstance(values, list) and i < len(values):
                    inp[key] = values[i]
                elif not isinstance(values, list):
                    inp[key] = values

    # One reward per (sample, reward function). Later code groups rows by num_generations and slices this
    # process's rows out of the result, so it presumably covers the samples of every process; confirm in
    # _calculate_rewards.
    rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
    num_generations = self.num_generations if mode == "train" else self.num_generations_eval

    # "sum_then_normalize": take the weighted sum of rewards (NaNs ignored via nansum), then compute
    # advantage = reward - group mean.
    if self.multi_objective_aggregation == "sum_then_normalize":
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
        mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
        if self.scale_rewards in ["group", "none"]:
            # Per-group std. With "none" it is only used for the zero-std statistic below. A group of one
            # sample has no std, so zeros are used.
            if num_generations > 1:
                std_rewards = rewards.view(-1, num_generations).std(dim=1)
                std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
            else:
                std_rewards = torch.zeros_like(rewards)
        elif self.scale_rewards == "batch":
            # A single std over the whole batch, broadcast to every sample.
            if rewards.numel() > 1:
                std_rewards = rewards.std().expand_as(rewards)
            else:
                std_rewards = torch.zeros_like(rewards)
        else:
            raise ValueError(
                f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
            )

        # Divide by (std + 1e-4) unless scaling is disabled.
        advantages = rewards - mean_grouped_rewards
        if self.scale_rewards != "none":
            advantages = advantages / (std_rewards + 1e-4)
        is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))

    # "normalize_then_sum": standardise each reward function within its group (NaN-aware), take the weighted
    # sum, then standardise that sum over the whole batch.
    elif self.multi_objective_aggregation == "normalize_then_sum":
        grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs))
        mean_k = torch.nanmean(grouped, dim=1, keepdim=True)
        std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k)
        reward_k = (grouped - mean_k) / (std_k + 1e-4)
        reward_k = reward_k.view(-1, len(self.reward_funcs))
        rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
        std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards)
        advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
        is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))

    else:
        raise ValueError(
            f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}. Must be "
            "'sum_then_normalize' or 'normalize_then_sum'."
        )

    # Keep only this process's rows of the advantages. The full tensor is kept for completion logging.
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    all_process_advantages = advantages.clone()
    advantages = advantages[process_slice]

    # Per-function reward statistics ignore NaN entries. The aggregate "reward" metric is the raw weighted sum,
    # recomputed here whatever the aggregation mode (this overwrites `rewards`).
    for i, reward_func_name in enumerate(self.reward_func_names):
        mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
        std_func_rewards = nanstd(rewards_per_func[:, i]).item()
        self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
    rewards = (rewards_per_func * self.reward_weights.to(rewards_per_func.device)).nansum(dim=1) if False else (rewards_per_func * self.reward_weights.to(rewards_per_func.device).unsqueeze(0)).nansum(dim=1)
    self._metrics[mode]["reward"].append(rewards.mean().item())
    self._metrics[mode]["reward_std"].append(rewards.std().item())
    self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())

    # Completion logs: texts are gathered from every process. Rewards and advantages already span all samples.
    self._logs["prompt"].extend(gather_object(prompts_text))
    self._logs["completion"].extend(gather_object(completions_text))
    for i, name in enumerate(self.reward_func_names):
        self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
    self._logs["advantages"].extend(all_process_advantages.tolist())

    # Flush the per-sample extra columns buffered in _pending_extra_logs into the completion log, gathering them
    # across processes. Columns are visited in sorted order.
    for column in sorted(self._pending_extra_logs):
        self._logs["extra"][column].extend(gather_object(self._pending_extra_logs[column]))
    self._pending_extra_logs.clear()

    # Flush the scalar metrics buffered in _pending_metrics: average locally, then average those local means
    # across processes.
    # NOTE(review): gather is a collective call, so every process presumably needs the same set of pending metric
    # names here — confirm.
    for name in sorted(self._pending_metrics):
        values = self._pending_metrics[name]
        local_mean = sum(values) / len(values)
        global_mean = self.accelerator.gather(torch.tensor(local_mean, device=device)).mean().item()
        self._metrics[mode][name].append(global_mean)
    self._pending_metrics.clear()

    if images is not None:
        self._logs["images"].extend(gather_object(images))

    # Sampler-vs-trainer diagnostics, computed on loss tokens only:
    # - mean and max absolute log-prob gap;
    # - min/mean/max importance-sampling ratio, taken over loss tokens in token modes and over sequences in
    #   sequence modes.
    # Empty selections report 0.
    if self.use_vllm and self.vllm_importance_sampling_correction:
        delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
        mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool()
        delta = delta[mask]
        mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
        max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
        self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
            self.accelerator.gather(mean_delta).mean().item()
        )
        self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
            self.accelerator.gather(max_delta).max().item()
        )
        if sequence_level_is:
            flat_is_ratio = vllm_importance_sampling_ratio.flatten()
        else:
            flat_is_ratio = vllm_importance_sampling_ratio[mask]

        min_importance_sampling_ratio = (
            torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
        )
        mean_importance_sampling_ratio = (
            torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
        )
        max_importance_sampling_ratio = (
            torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
        )
        self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
            nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
        )
        self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
            self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
        )
        self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
            nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
        )

    # Assemble the batch for the loss. Optional entries are added only when they were computed or produced by
    # the processor.
    output = {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "advantages": advantages,
        "num_items_in_batch": num_items_in_batch,
    }
    if old_per_token_logps is not None:
        output["old_per_token_logps"] = old_per_token_logps
    if self.use_vllm and self.vllm_importance_sampling_correction:
        output["importance_sampling_ratio"] = vllm_importance_sampling_ratio
    if sampling_per_token_logps is not None:
        output["sampling_per_token_logps"] = sampling_per_token_logps
    if ref_per_token_logps is not None:
        output["ref_per_token_logps"] = ref_per_token_logps
    if "pixel_values" in forward_kwargs:
        output["pixel_values"] = forward_kwargs["pixel_values"]
    if "image_grid_thw" in forward_kwargs:
        output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
    if "pixel_attention_mask" in forward_kwargs:
        output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
    if "image_sizes" in forward_kwargs:
        output["image_sizes"] = forward_kwargs["image_sizes"]
    if "token_type_ids" in forward_kwargs:
        output["token_type_ids"] = forward_kwargs["token_type_ids"]
    if "mm_token_type_ids" in forward_kwargs:
        output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]
    if "pixel_position_ids" in forward_kwargs:
        output["pixel_position_ids"] = forward_kwargs["pixel_position_ids"]
    if images is not None:
        output["num_images"] = num_images
    if tool_mask is not None:
        output["tool_mask"] = tool_mask
    return output
|
|
def compute_liger_loss(self, unwrapped_model, inputs):
    """Compute the GRPO loss with Liger's fused kernel, starting from the model's last hidden state.

    Args:
        unwrapped_model: The bare model. Its ``lm_head.weight`` / ``lm_head.bias`` are handed to the kernel.
        inputs: Batch dict with prompt/completion ids and masks, advantages, and optional old/reference
            log-probs, importance-sampling ratio, tool mask and vision entries.

    Returns:
        The kernel loss, divided by ``current_gradient_accumulation_steps`` in train mode and by 1.0 in eval
        mode.
    """
    completion_ids = inputs["completion_ids"]
    completion_mask = inputs["completion_mask"]
    # The model sees the prompt followed by the completion; only completion positions need hidden states.
    full_ids = torch.cat([inputs["prompt_ids"], completion_ids], dim=1)
    full_attention = torch.cat([inputs["prompt_mask"], completion_mask], dim=1)
    num_completion_positions = completion_ids.size(1)

    hidden = self._get_last_hidden_state(
        unwrapped_model,
        full_ids,
        full_attention,
        num_completion_positions,
        inputs.get("pixel_values"),
        inputs.get("image_grid_thw"),
        inputs.get("pixel_attention_mask"),
        inputs.get("image_sizes"),
        inputs.get("pixel_position_ids"),
    )

    # Tokens outside the tool mask (when one is present) are excluded from the loss.
    token_weights = completion_mask * inputs["tool_mask"] if "tool_mask" in inputs else completion_mask

    lm_head = unwrapped_model.lm_head
    loss, kernel_metrics = self.liger_grpo_loss(
        _input=hidden,
        lin_weight=lm_head.weight,
        selected_token_ids=completion_ids,
        attention_mask=token_weights,
        advantages=inputs["advantages"],
        bias=lm_head.bias,
        old_per_token_logps=inputs.get("old_per_token_logps"),
        ref_per_token_logps=inputs.get("ref_per_token_logps"),
        vllm_is_ratio=inputs.get("importance_sampling_ratio"),
    )

    mode = "eval" if not self.model.training else "train"
    bucket = self._metrics[mode]
    # The kernel's metrics list starts with the KL (read only when beta != 0) and ends with the clip ratio.
    if self.beta != 0.0:
        bucket["kl"].append(self.accelerator.gather(kernel_metrics[0]).mean().item())
    bucket["clip_ratio"].append(self.accelerator.gather(kernel_metrics[-1]).mean().item())

    normalizer = 1.0
    if mode == "train":
        normalizer = self.current_gradient_accumulation_steps
    return loss / normalizer
|
|
@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Dispatch to the Liger fused loss or the plain PyTorch loss.

    Args:
        model: The (possibly wrapped) model being trained.
        inputs: Batch dict built by ``_generate_and_score_completions``.
        return_outputs: Must be falsy; returning model outputs is not supported.
        num_items_in_batch: Accepted for interface compatibility and not used by this method.

    Returns:
        The scalar loss tensor.

    Raises:
        ValueError: If ``return_outputs`` is truthy.
    """
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")
    if not self.use_liger_kernel:
        return self._compute_loss(model, inputs)
    # The Liger path reads lm_head weights from the bare module, but the call is routed through
    # _forward_redirection with both the wrapped and the bare model. Presumably this keeps the wrapper's
    # forward hooks active (TODO confirm against _ForwardRedirection).
    bare_model = self.accelerator.unwrap_model(model)
    return self._forward_redirection(model, bare_model, self.compute_liger_loss, bare_model, inputs)
|
|
@staticmethod
def get_off_policy_mask(
    advantages: torch.Tensor,
    per_token_logps: torch.Tensor,
    sampling_per_token_logps: torch.Tensor,
    mask: torch.Tensor,
    off_policy_threshold: float,
) -> torch.Tensor:
    """
    Build the Off-Policy Sequence Mask described in the DeepSeek-V3.2 paper.

    Each sequence gets 1.0 ("keep") or 0.0 ("drop"). A sequence is kept when either holds:
    - its advantage is non-negative;
    - the mean of ``sampling_per_token_logps - per_token_logps`` over its masked tokens is at most
      ``off_policy_threshold``.
    Otherwise it is dropped. ``per_token_logps`` is detached, so no gradient flows through the mask.

    Returns:
        A (B, 1) keep/drop tensor in ``mask``'s dtype.
    """
    # Per-token log-prob gap between the sampling policy and the (detached) current policy.
    gap = sampling_per_token_logps - per_token_logps.detach()
    # Average the gap over each row's masked tokens. The clamped count avoids dividing by zero on empty rows.
    token_count = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
    mean_gap = (gap * mask).sum(dim=1, keepdim=True) / token_count
    keep = torch.logical_or(advantages >= 0, mean_gap <= off_policy_threshold)
    return keep.to(dtype=mask.dtype)
|
|
| @staticmethod |
| @torch.no_grad() |
| def get_gamma_weights( |
| advantages: torch.Tensor, |
| log_ratio_per_token: torch.Tensor, |
| mask: torch.Tensor, |
| importance_sampling_ratio: torch.Tensor | None, |
| k_pos: float = 2.0, |
| lambda_pos: float = 3.0, |
| k_neg: float = 3.0, |
| lambda_neg: float = 2.0, |
| ) -> torch.Tensor: |
| """ |
| Computes the Gamma weights for the VESPO loss. For reference: |
| φ(w) = e^λ × w^k × e^{-λw} is the gamma weighting (normalized so φ(1)=1) |
| with w = sequence-level importance sampling ratio |
| note: we will compute φ(w) in log space |
| |
| φ(w) is detached via @torch.no_grad(), only acts as gradient scaling coefficient |
| |
| VESPO loss = -φ(w) × A × log_prob, gradient naturally gives φ(w) × A × ∇log π |
| """ |
| |
| |
| lower_clamp = math.log(1e-8) |
|
|
| |
| log_ratio_clamped = torch.clamp(log_ratio_per_token, -20.0, 20.0) |
| seq_log_ratio = torch.sum(log_ratio_clamped * mask, dim=-1, keepdim=True) |
|
|
| |
| if importance_sampling_ratio is not None: |
| log_is_ratio = torch.clamp(torch.log(importance_sampling_ratio), lower_clamp, 20.0) |
| |
| seq_log_ratio += torch.sum(log_is_ratio, dim=-1, keepdim=True) |
|
|
| log_w_seq = torch.clamp(seq_log_ratio, lower_clamp, 20.0) |
| w_seq = torch.exp(log_w_seq) |
|
|
| |
| is_nonneg_adv = advantages >= 0 |
| k_seq = torch.where(is_nonneg_adv, k_pos, k_neg) |
| lambda_seq = torch.where(is_nonneg_adv, lambda_pos, lambda_neg).clamp(min=1e-4) |
|
|
| |
| log_phi = lambda_seq + k_seq * log_w_seq - lambda_seq * w_seq |
| phi_seq = torch.exp(log_phi).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) |
|
|
| return phi_seq |
|
|
def _compute_loss(self, model, inputs):
    """Compute the policy-gradient loss for a batch and record KL, entropy and clipping metrics.

    The loss variant is chosen by ``self.loss_type``: "grpo", "bnpo", "dr_grpo", "dapo", "luspo", "cispo",
    "sapo" or "vespo". Optional extras, each enabled by its own setting:
    - a KL penalty when ``beta != 0``;
    - the off-policy sequence mask;
    - the high-entropy token mask;
    - the vLLM importance-sampling ratio.
    KL, entropy and clipping statistics are appended to ``self._metrics``.

    Args:
        model: Model whose per-token log-probabilities and entropies are computed.
        inputs: Batch dict containing prompt/completion ids and masks, advantages and ``num_items_in_batch``,
            plus optional old/sampling/reference log-probs, importance-sampling ratio, tool mask and vision
            entries.

    Returns:
        The scalar loss tensor.

    Raises:
        ValueError: For an unknown ``importance_sampling_level`` or ``loss_type``.
    """
    # Rebuild the full prompt + completion sequence. Only completion positions need logits.
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    logits_to_keep = completion_ids.size(1)
    # Tokens that count toward the loss and metrics: completion tokens, further restricted by the tool mask if
    # one is present.
    mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]

    # Current-policy log-probs (with gradients) and entropies over the completion positions.
    per_token_logps, entropies = self._get_per_token_logps_and_entropies(
        model,
        input_ids,
        attention_mask,
        logits_to_keep,
        compute_entropy=True,
        pixel_values=inputs.get("pixel_values"),
        image_grid_thw=inputs.get("image_grid_thw"),
        num_images=inputs.get("num_images"),
        pixel_attention_mask=inputs.get("pixel_attention_mask"),
        image_sizes=inputs.get("image_sizes"),
        token_type_ids=inputs.get("token_type_ids"),
        mm_token_type_ids=inputs.get("mm_token_type_ids"),
        pixel_position_ids=inputs.get("pixel_position_ids"),
    )

    # Optional entropy-based token selection. Presumably this keeps only the top `top_entropy_quantile` fraction
    # of tokens by entropy (see get_high_entropy_mask).
    if self.top_entropy_quantile < 1.0:
        entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile)
    else:
        entropy_mask = None

    advantages = inputs["advantages"]
    # Per-sequence advantages arrive 1-D. Lift them to (B, 1) so they broadcast over tokens. Already 2-D
    # advantages are used as they are.
    if advantages.dim() == 1:
        advantages = advantages.unsqueeze(1)
    # When no old-policy log-probs were stored, fall back to the detached current ones. The ratio below is then
    # exactly 1 (on-policy).
    old_per_token_logps = inputs.get("old_per_token_logps")
    old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps

    # Off-policy sequence mask. It compares against the sampler's log-probs when the batch carries them, and
    # otherwise against the old-policy log-probs.
    if self.off_policy_mask_threshold is not None:
        sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps)

        off_policy_mask = self.get_off_policy_mask(
            advantages=advantages,
            per_token_logps=per_token_logps,
            sampling_per_token_logps=sampling_per_token_logps,
            mask=mask,
            off_policy_threshold=self.off_policy_mask_threshold,
        )

    # Importance weights, in one of two forms:
    # - "token": per-token log ratios.
    # - "sequence": one masked-mean log ratio per sequence (trailing dim of 1).
    log_ratio = per_token_logps - old_per_token_logps
    if self.importance_sampling_level == "token":
        log_importance_weights = log_ratio
    elif self.importance_sampling_level == "sequence":
        log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
        log_importance_weights = log_importance_weights.unsqueeze(-1)
    else:
        raise ValueError(
            f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
            "and 'sequence'."
        )

    coef_1 = torch.exp(log_importance_weights)

    # Per-token KL to the reference model, estimated as exp(d) - d - 1 with d = ref - current. This is always
    # >= 0.
    if self.beta != 0.0:
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = (
            torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
        )
        # Optional bias correction: weight the KL by the importance ratio.
        if self.args.use_bias_correction_kl:
            per_token_kl = per_token_kl * coef_1

    # Per-token surrogate loss. Its trailing dim follows coef_1: per token, or 1 at sequence level.
    if self.loss_type == "cispo":
        # CISPO: the ratio is capped at epsilon_high and detached, then used as a weight on the log-prob.
        clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
        per_token_loss = -clamped_ratios * advantages * per_token_logps
    elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
        # PPO-style clipped surrogate.
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
        # With `delta` set, the unclipped ratio is also capped at delta (two-sided clipping).
        if self.args.delta is not None:
            coef_1 = torch.clamp(coef_1, max=self.args.delta)

        per_token_loss1 = coef_1 * advantages
        per_token_loss2 = coef_2 * advantages
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
    elif self.loss_type == "sapo":
        # SAPO: a smooth sigmoid gate on the ratio. The temperature is chosen by the sign of the advantage.
        temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg)
        soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures
        per_token_loss = -soft_coef_1 * advantages
    elif self.loss_type == "vespo":
        # VESPO: the detached gamma weight phi times advantage times log-prob. The vLLM IS ratio is folded
        # into phi here.
        phi_seq = self.get_gamma_weights(
            advantages=advantages,
            log_ratio_per_token=log_ratio,
            mask=mask,
            importance_sampling_ratio=inputs.get("importance_sampling_ratio"),
            k_pos=self.args.vespo_k_pos,
            lambda_pos=self.args.vespo_lambda_pos,
            k_neg=self.args.vespo_k_neg,
            lambda_neg=self.args.vespo_lambda_neg,
        )
        per_token_loss = -phi_seq * advantages * per_token_logps
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}")

    if self.off_policy_mask_threshold is not None:
        per_token_loss = per_token_loss * off_policy_mask

    if entropy_mask is not None:
        per_token_loss = per_token_loss * entropy_mask

    # The vLLM importance-sampling ratio scales the loss. VESPO is excluded because it already applied the
    # ratio inside get_gamma_weights.
    if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
        per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]

    if self.beta != 0.0:
        per_token_loss = per_token_loss + self.beta * per_token_kl

    # Reduce to a scalar. Except for cispo/dapo/vespo, the result is further divided by the gradient
    # accumulation steps in training.
    mode = "train" if self.model.training else "eval"
    if self.loss_type in ["grpo", "sapo"]:
        # Mean over sequences of each sequence's masked token mean.
        loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
        normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
        loss = loss / normalizer
    elif self.loss_type == "bnpo":
        # Mean over all loss tokens in the local batch.
        loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
        normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
        loss = loss / normalizer
    elif self.loss_type == "dr_grpo":
        # Sum divided by a constant: batch size times max_completion_length.
        loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
        normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
        loss = loss / normalizer
    elif self.loss_type in ["cispo", "dapo", "vespo"]:
        # Sum divided by this process's share of num_items_in_batch.
        normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
        loss = (per_token_loss * mask).sum() / normalizer
    elif self.loss_type == "luspo":
        # Each sequence's loss is scaled by its number of loss tokens, then everything is averaged.
        loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
        normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
        loss = loss / normalizer
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}")

    # Metrics. Token-level tensors are averaged over loss tokens only.
    completion_token_count = mask.sum().clamp(min=1.0)

    def masked_batch_mean(x):
        # Tensors with a trailing dim of 1 (sequence level) use a plain mean; others are averaged over loss
        # tokens.
        if x.shape[1] == 1:
            return x.mean()
        else:
            return (x * mask).sum() / completion_token_count

    if self.beta != 0.0:
        mean_kl = masked_batch_mean(per_token_kl)
        self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())

    mean_entropy = masked_batch_mean(entropies)
    self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

    if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
        # A token counts as clipped only when the clip is active for the sign of its advantage. Note that
        # coef_1 here may already be capped by `delta`.
        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
        is_region_clipped = is_low_clipped | is_high_clipped

        low_clip = masked_batch_mean(is_low_clipped.float())
        high_clip = masked_batch_mean(is_high_clipped.float())
        clip_ratio = masked_batch_mean(is_region_clipped.float())

        gathered_low_clip = self.accelerator.gather(low_clip)
        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
        gathered_high_clip = self.accelerator.gather(high_clip)
        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
        gathered_clip_ratio = self.accelerator.gather(clip_ratio)
        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
    elif self.loss_type == "cispo":
        is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
        cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
        gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
        self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
    elif self.loss_type == "vespo":
        gathered_phi_seq = self.accelerator.gather(phi_seq)
        self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item())

    return loss
|
|
| |
| |
| def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): |
| inputs = self._prepare_inputs(inputs) |
| with torch.no_grad(): |
| with self.compute_loss_context_manager(): |
| loss = self.compute_loss(model, inputs) |
| loss = loss.mean().detach() |
| return loss, None, None |
|
|
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
    """Flush buffered metrics to the parent logger and, if enabled, log the recent completions.

    Metrics buffered in ``self._metrics[mode]`` are handled as follows:
    - Each is averaged over its values, ignoring NaN entries; a metric with no valid value is logged as
      ``None``.
    - In eval mode, keys are prefixed with ``eval_``.
    - The result is merged into ``logs`` and passed to the parent ``log``, and the buffer for the current mode
      is cleared.

    On the main process, when ``self.log_completions`` is set, the buffered prompts, completions, rewards and
    advantages are:
    - printed, if ``rich`` is available;
    - written to ``<output_dir>/completions/completions_<step>.parquet``; the directory is created if missing;
    - sent as a table to wandb and/or trackio when they appear in ``args.report_to``.

    Args:
        logs: Values to log for this step. Buffered metrics are added on top of them.
        start_time: Forwarded unchanged to the parent ``log``.
    """
    mode = "train" if self.model.training else "eval"
    # Average every metric over the values buffered since the last call. NaN entries are skipped, and a metric
    # with only NaNs is logged as None.
    metrics = {}
    for key, val in self._metrics[mode].items():
        valid = [v for v in val if not math.isnan(v)]
        metrics[key] = sum(valid) / len(valid) if valid else None

    # Eval metrics are namespaced so they do not collide with training metrics of the same name.
    if mode == "eval":
        metrics = {f"eval_{key}": val for key, val in metrics.items()}

    logs = {**logs, **metrics}
    super().log(logs, start_time)
    self._metrics[mode].clear()

    if self.accelerator.is_main_process and self.log_completions:
        if is_rich_available():
            print_prompt_completions_sample(
                self._logs["prompt"],
                self._logs["completion"],
                self._logs["rewards"],
                self._logs["advantages"],
                self.state.global_step,
                self.num_completions_to_print,
            )

        logging_backends = []
        if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
            logging_backends.append(wandb)
        if self.args.report_to and "trackio" in self.args.report_to:
            logging_backends.append(trackio)

        table = {
            "step": [self.state.global_step] * len(self._logs["prompt"]),
            "prompt": self._logs["prompt"],
            "completion": self._logs["completion"],
            **self._logs["rewards"],
            **self._logs["extra"],
            "advantage": self._logs["advantages"],
        }

        df_base = pd.DataFrame(table)
        # pandas opens local parquet paths with a plain file open and does not create missing parent
        # directories. Ensure the target folder exists so writing cannot fail with FileNotFoundError.
        completions_dir = os.path.join(self.args.output_dir, "completions")
        os.makedirs(completions_dir, exist_ok=True)
        df_base.to_parquet(
            os.path.join(
                completions_dir,
                f"completions_{self.state.global_step:05d}.parquet",
            )
        )

        images_raw = self._logs["images"] or []

        for logging_backend in logging_backends:
            if images_raw:
                # Each backend has its own Image wrapper, so the images are wrapped once per backend.
                images = [[logging_backend.Image(image) for image in image_list] for image_list in images_raw]
                df = pd.concat(
                    [df_base, pd.Series(images, name="image")],
                    axis=1,
                    copy=False,
                )
            else:
                df = df_base

            if self.log_unique_prompts:
                df = df.drop_duplicates(subset=["prompt"])

            logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
|
|
| |
def _save_checkpoint(self, model, trial):
    """Regenerate the model card, then let the parent trainer save the checkpoint.

    The card is named after the last path component of ``args.hub_model_id`` when one is set. Otherwise it is
    named after the final component of ``args.output_dir``.
    """
    hub_model_id = self.args.hub_model_id
    model_name = Path(self.args.output_dir).name if hub_model_id is None else hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
|
|