# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch.nn as nn
from packaging import version
from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
SUPPORTED_ARCHITECTURES = (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
)
if TYPE_CHECKING:
from accelerate import Accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
"""Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""
bos_token: str = "<|im_start|>"
eos_token: str = "<|im_end|>"
pad_token: str = "<|im_end|>"
@property
def system(self):
return f"{self.bos_token}system"
@property
def user(self):
return f"{self.bos_token}user"
@property
def assistant(self):
return f"{self.bos_token}assistant"
@property
def chat_template(self):
return (
"{% for message in messages %}"
f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
f"{{{{ '{self.assistant}\n' }}}}"
"{% endif %}"
)
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
def setup_chat_format(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
format: Optional[Literal["chatml"]] = "chatml",
resize_to_multiple_of: Optional[int] = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
embedding layer of the model based on the new special tokens.
<Tip warning="true"> We recommend using [`clone_chat_template`] instead of this function.
</Tip>
    If the tokenizer already has a chat template, this will throw an error. If you want to overwrite it, please set
    `tokenizer.chat_template` to `None`.
Args:
model (`~transformers.PreTrainedModel`): The model to be modified.
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
        resize_to_multiple_of (`int` or `None`): If not `None`, pad the resized embedding layer to a multiple of this
            value. Defaults to `None`.
Returns:
model (`~transformers.PreTrainedModel`):
The modified model.
tokenizer (`~transformers.PreTrainedTokenizer`):
The modified tokenizer.
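    Example (illustrative; the checkpoint name is only a placeholder for a base model without a chat template):
    ```python
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import setup_chat_format
    # "facebook/opt-350m" is just an example checkpoint; any base model/tokenizer without a chat template works
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    model, tokenizer = setup_chat_format(model, tokenizer)
    ```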
"""
    # check if the tokenizer already has a chat template
if tokenizer.chat_template is not None:
raise ValueError(
"Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
)
# check if format available and retrieve
if format not in FORMAT_MAPPING:
raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
chat_format = FORMAT_MAPPING[format]()
    # set special tokens in the tokenizer
tokenizer.eos_token = chat_format.eos_token
tokenizer.pad_token = chat_format.pad_token
tokenizer.bos_token = chat_format.bos_token
tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
# set chat format for tokenizer
tokenizer.chat_template = chat_format.chat_template
    # resize the embedding layer, optionally padding it to a multiple (e.g. 64), see https://x.com/karpathy/status/1621578354024677377
model.resize_token_embeddings(
# After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
# size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
# as handling of special and added tokens varies across tokenizers.
new_num_tokens=len(tokenizer.vocab),
pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
)
# Update the model config to use the new eos & bos tokens
if getattr(model, "config", None) is not None:
model.config.pad_token_id = tokenizer.pad_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# Update the generation config to use the new eos & bos token
if getattr(model, "generation_config", None) is not None:
model.generation_config.bos_token_id = tokenizer.bos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
def clone_chat_template(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
source_tokenizer_path: str,
resize_to_multiple_of: Optional[int] = 64,
) -> tuple[PreTrainedModel, PreTrainedTokenizer, list[int]]:
"""
Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly.
This function:
- Copies the chat template from a source tokenizer to the target tokenizer.
- Adds any new tokens from the source tokenizer to the target tokenizer.
- Sets and synchronizes the EOS token across the tokenizer and model.
- Resizes the model's token embeddings to match the new vocabulary size, optionally rounding it up to a multiple of
a specified value. In such cases, dummy tokens are added to the tokenizer to ensure the vocabulary size matches
the embedding dimensions.
Args:
model (`PreTrainedModel`):
Model to update.
tokenizer (`PreTrainedTokenizer`):
Tokenizer to update.
source_tokenizer_path (`str`):
Path or identifier of the pretrained tokenizer to clone from.
resize_to_multiple_of (`int` or `None`, *optional*, defaults to `64`):
The embedding layer will be resized to the new vocabulary size. If this is not `None`, it will round up the
new vocabulary size to the nearest multiple of this value.
Returns:
model (`PreTrainedModel`):
Updated model with resized token embeddings and EOS token configured.
tokenizer (`~transformers.PreTrainedTokenizer`):
Updated tokenizer with the chat template and special tokens applied.
        added_tokens (`list[int]`):
            List of IDs of the tokens that were added to the tokenizer (tokens cloned from the source tokenizer plus
            any dummy padding tokens).
Example:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import clone_chat_template
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model, tokenizer, added_tokens = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
```
"""
# Load the source tokenizer containing the desired chat template
tokenizer_source = AutoTokenizer.from_pretrained(source_tokenizer_path)
# Copy the chat template from the source tokenizer
tokenizer.chat_template = tokenizer_source.get_chat_template()
# Ensure all added tokens from the source are available in the target tokenizer
added_tokens = [
token for token in tokenizer_source.added_tokens_decoder.values() if token.content not in tokenizer.vocab
]
tokenizer.add_tokens(added_tokens)
# Set the EOS token from the source tokenizer (important for generation)
tokenizer.eos_token = tokenizer_source.eos_token
model.config.eos_token_id = tokenizer.eos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
# Resize model embeddings to include any new tokens, optionally rounding up to a multiple
model.resize_token_embeddings(
# After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
# size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
# as handling of special and added tokens varies across tokenizers.
new_num_tokens=len(tokenizer.vocab),
pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
)
# After resizing, the embedding matrix size may exceed the vocabulary size. Add dummy tokens to the tokenizer to
# ensure vocabulary size matches the embedding matrix dimensions.
idx = 0
while model.vocab_size > len(tokenizer.vocab):
dummy_token = AddedToken(f"<extra_id_{idx}>")
is_added = tokenizer.add_tokens(dummy_token)
idx += 1
if is_added == 1:
added_tokens.append(dummy_token)
# Verify that vocabulary size now matches embedding dimensions
if len(tokenizer.vocab) != model.vocab_size:
raise RuntimeError(
f"Vocabulary size mismatch after resizing: tokenizer vocab size is {len(tokenizer.vocab)}, but model "
f"embedding size is {model.vocab_size}. This indicates an internal error in the token alignment process."
)
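    # Convert the added tokens to their IDs so the function returns `list[int]`, as documented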
added_tokens = [token.content for token in added_tokens]
added_tokens = tokenizer.convert_tokens_to_ids(added_tokens)
return model, tokenizer, added_tokens
def remove_hooks(model: "DeepSpeedEngine") -> None:
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer
else:
raise RuntimeError("The model optimizer is None, which is not yet supported.")
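    # Clear the bookkeeping of active sub-modules and detach the forward/backward hooks that ZeRO-3 registered, so the
    # parameter coordinator stops gathering/partitioning parameters on each module call.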
for param in iter_params(optimizer_offload.module, recurse=True):
param.ds_active_sub_modules.clear()
for hook in optimizer_offload.forward_hooks:
hook.remove()
for hook in optimizer_offload.backward_hooks:
hook.remove()
optimizer_offload.forward_hooks = []
optimizer_offload.backward_hooks = []
def get_all_parameters(sub_module, recurse=False):
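    """Returns an iterator over the named parameters of `sub_module` and its DeepSpeed ZeRO-3 external parameters."""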
return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
def iter_params(module, recurse=False):
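    """Returns the parameters yielded by `get_all_parameters` as a flat list."""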
return [param for _, param in get_all_parameters(module, recurse)]
def add_hooks(model: "DeepSpeedEngine") -> None:
    """Adds back the optimizer hooks to a DeepSpeed ZeRO-3 model."""
import deepspeed
if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer
else:
raise RuntimeError("The model optimizer is None, which is not yet supported.")
if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
# Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
optimizer_offload._register_deepspeed_module(optimizer_offload.module)
else:
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
@contextmanager
def unwrap_model_for_generation(
model: Union["DistributedDataParallel", "DeepSpeedEngine"],
accelerator: "Accelerator",
gather_deepspeed3_params: bool = True,
):
"""
Context manager to unwrap distributed or accelerated models for generation tasks.
Args:
model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
Model to be unwrapped.
accelerator (`~accelerate.Accelerator`):
Accelerator instance managing the model.
gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
can be more memory-efficient but may lead to slower generation times.
Yields:
Unwrapped model.
Example:
```python
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
generated_outputs = unwrapped_model.generate(input_ids)
```
"""
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
else:
import deepspeed
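            # Gather the full (unsharded) ZeRO-3 parameters for the duration of generation and temporarily remove the
            # DeepSpeed hooks so the gathered weights are not re-partitioned mid-generation; the hooks are restored
            # afterwards.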
with deepspeed.zero.GatheredParameters(model.parameters()):
remove_hooks(model)
yield accelerator.unwrap_model(model)
add_hooks(model)
else:
yield unwrapped_model
def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
"""Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.
Adapted from accelerate:
https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
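    Example (illustrative sketch; `ref_model` and `accelerator` are assumed to be defined by the surrounding training
    code, e.g. inside a trainer):
    ```python
    ref_model = prepare_deepspeed(ref_model, accelerator)
    ```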
"""
import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252
deepspeed_plugin = accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
stage = config_kwargs["zero_optimization"]["stage"]
if model is not None:
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and stage == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
# @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
# disabled (stage 0)
if stage != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def prepare_fsdp(model, accelerator):
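    """Prepares the model for FSDP inference or evaluation by wrapping it with FSDP (unless it is already wrapped),
    using the accelerator's FSDP plugin configuration, then switches it to eval mode.
    """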
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1421
from torch.distributed.fsdp import FSDPModule
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
    # Check if the model is already an FSDP model due to `Manual Wrapping` and if so,
    # don't wrap it again
if not (isinstance(model, FSDP) or isinstance(model, FSDPModule)):
accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = accelerator.state.fsdp_plugin
kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
"cpu_offload": fsdp_plugin.cpu_offload,
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
"mixed_precision": fsdp_plugin.mixed_precision_policy,
"sync_module_states": fsdp_plugin.sync_module_states,
"backward_prefetch": fsdp_plugin.backward_prefetch,
"forward_prefetch": fsdp_plugin.forward_prefetch,
"use_orig_params": fsdp_plugin.use_orig_params,
"param_init_fn": fsdp_plugin.param_init_fn,
"ignored_modules": fsdp_plugin.ignored_modules,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": accelerator.device,
}
model = FSDP(model, **kwargs)
model.eval()
return model
class _ForwardRedirection:
"""Implements the `forward-redirection`.
Taken from Pytorch-lightning:
https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602
A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
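    Example (illustrative sketch; `wrapper`, `inner`, `compute_loss`, and `batch` are hypothetical objects):
    ```python
    redirect = _ForwardRedirection()
    # Calls `inner.compute_loss(batch)`, but routes the inputs through `wrapper.forward` so that wrappers such as
    # FSDP/DDP still run their pre- and post-forward logic.
    out = redirect(wrapper, inner, inner.compute_loss, batch)
    ```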
"""
def __call__(
self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any, **kwargs: Any
):
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
Args:
wrapper_module: The module that has `original_module` wrapped.
original_module: The module that was wrapped inside `wrapper_module`.
            method: The method of `original_module` to be called after the inputs get redirected through the
                `wrapper_module`'s `forward` method.
            *args: The positional arguments to `method`. They will get passed to a patched `forward` method instead.
            **kwargs: The keyword arguments to `method`. They will get passed to a patched `forward` method instead.
"""
original_forward = original_module.forward
def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling the method, because the method itself may want to call the
            # real `forward`
original_module.forward = original_forward # type: ignore[method-assign]
# Call the actual method e.g. `.training_step(...)`
out = method(*_args, **_kwargs)
self.on_after_inner_forward(wrapper_module, original_module)
return out
# Patch the original_module's forward so we can redirect the arguments back to the real method
original_module.forward = wrapped_forward # type: ignore[method-assign]
wrapper_output = wrapper_module(*args, **kwargs)
self.on_after_outer_forward(wrapper_module, original_module)
return wrapper_output
def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass
def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass