LeTue09's picture
initial clean commit
1faccd4
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from verl.utils import tensordict_utils as tu
from verl.utils.attention_utils import index_first_axis, unpad_input
def left_right_2_no_padding(data: TensorDict) -> TensorDict:
    """
    Convert TensorDict from left-right padding to no-padding format.

    The conversion is done in place: ``data`` is modified and also returned.

    Args:
        data: TensorDict with "input_ids", "attention_mask", "response_mask", "position_ids",
            and optionally "routed_experts".

    Returns:
        data: TensorDict with
            - "input_ids" and "position_ids" replaced by jagged NestedTensors
              (and "routed_experts" too, if present and not already nested)
            - "loss_mask" set to the existing "response_mask" tensor
            - NonTensorData "max_seq_len", "max_response_len", "indices"

    Note:
        1. the returned input_ids/position_ids are nested (jagged) tensors.
        2. "attention_mask" and "response_mask" are kept unchanged. "loss_mask" is the same
           tensor object as "response_mask"; it is not copied and not converted to nested.
        3. "routed_experts" is downcast to uint8 only when every kept (non-padding) value
           lies in [0, 255]; otherwise its dtype is preserved.
    """
    for key in ("input_ids", "attention_mask", "response_mask", "position_ids"):
        assert key in data, f"{key} is required in left-right padding data"

    input_ids = data.pop("input_ids")
    attention_mask = data["attention_mask"]
    response_mask = data["response_mask"]
    position_ids = data["position_ids"]  # (bs, seq_len) or # (bs, 4, seq_len)

    max_seq_len, max_response_len = input_ids.shape[1], response_mask.shape[1]
    tu.assign_non_tensor_data(data, "max_seq_len", max_seq_len)
    tu.assign_non_tensor_data(data, "max_response_len", max_response_len)

    # unpad_input is given a trailing singleton feature dim, which is squeezed back below.
    input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
    tu.assign_non_tensor_data(data, "indices", indices)
    input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens)

    # Keep only attended positions of each sample. `...` masks the last (sequence) axis, so
    # this covers both (seq_len,) rows and (4, seq_len) rows without branching.
    valid_mask = attention_mask.bool()
    position_ids_list = [pos[..., mask] for pos, mask in zip(position_ids, valid_mask, strict=True)]
    position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged)

    data["input_ids"] = input_ids_nested
    data["position_ids"] = position_ids_nested
    data["loss_mask"] = data["response_mask"]

    routed_experts = data.get("routed_experts", None)
    if routed_experts is not None and not routed_experts.is_nested:
        # NOTE(review): assumes routed_experts is (bs, seq_len, ...) aligned with attention_mask,
        # so flatten(0, 1) lines its first axis up with `indices`.
        routed_experts_rmpad = index_first_axis(routed_experts.unsqueeze(-1).flatten(0, 1), indices).squeeze(-1)
        # Downcast to save memory only when every kept value fits in uint8. Checking the kept
        # tokens (rather than the padded tensor) means padding values cannot affect the decision,
        # negative ids are never wrapped around, and an empty selection does not crash min/max.
        if (
            routed_experts_rmpad.numel() > 0
            and routed_experts_rmpad.min() >= 0
            and routed_experts_rmpad.max() <= 255
        ):
            routed_experts_rmpad = routed_experts_rmpad.to(torch.uint8)
        data["routed_experts"] = torch.nested.nested_tensor_from_jagged(routed_experts_rmpad, offsets=cu_seqlens)
    return data
def no_padding_2_padding(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
    """Slice the response part out of an unpadded model output and pad it per sample.

    Args:
        tensor: a nested tensor or a 1D tensor in shape (total_nnz,),
            total_nnz is the total number of tokens across all sequences in the batch
        data: TensorDict with "prompts", "responses", "attention_mask", and optionally the
            non-tensor entry "max_response_len" (only consulted when "prompts" is nested;
            computed from the response lengths when absent)

    Returns:
        tensor: response tensor of shape [bsz, max_response_len], right-padded with zeros

    Note:
        The model output is shifted left by one token: the window for each sample starts at
        the token just before its response, i.e. at its last prompt token. This presumes
        every prompt is non-empty -- TODO confirm with callers.
    """
    values = tensor.values() if tensor.is_nested else tensor

    prompt_ids = data["prompts"]
    response_ids = data["responses"]
    attention_mask = data["attention_mask"]
    max_response_len = tu.get_non_tensor_data(data=data, key="max_response_len", default=-1)

    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        if max_response_len < 0:
            max_response_len = response_lens.max().item()
    else:
        assert not attention_mask.is_nested
        prompt_width = prompt_ids.shape[1]
        prompt_lens = attention_mask[:, :prompt_width].sum(dim=1)
        response_lens = attention_mask[:, prompt_width:].sum(dim=1)
        max_response_len = response_ids.shape[1]

    sequence_offsets = (prompt_lens + response_lens).cumsum(dim=0)
    assert sequence_offsets[-1].item() == values.shape[0]

    # Convert lengths/offsets to Python ints once. Slicing and F.pad with 0-d tensors would
    # otherwise convert each value individually, forcing a host sync per sample when these
    # tensors live on an accelerator.
    response_lens_list = response_lens.tolist()
    sequence_offsets_list = sequence_offsets.tolist()

    response_list = [
        # left-shift model output by one token for log_probs/values
        F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (0, max_response_len - resp_len))
        for resp_len, seq_offset in zip(response_lens_list, sequence_offsets_list, strict=True)
    ]
    return torch.stack(response_list, dim=0)