from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def unpad_input(
    inputs: Tensor,
    attention_mask: Tensor,
    position_ids: Optional[Tensor] = None,
    labels: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, int, Optional[Tensor], Optional[Tensor]]:
    """
    Remove padding from input sequences.

    The batch and sequence dimensions are merged into one, and only the
    positions selected by ``attention_mask`` are kept.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        position_ids: (batch, seqlen), int, position ids
        labels: (batch, seqlen), int, labels

    Returns:
        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), positions of the selected tokens in the flattened (batch * seqlen) layout
        cu_seqlens: (batch + 1), int32, the cumulative sequence lengths with a leading 0
        max_seqlen_in_batch: int
        unpadded_position_ids: (total_nnz) or None
        unpadded_labels: (total_nnz) or None
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # Prepend a zero so that cu_seqlens[i]:cu_seqlens[i + 1] delimits sequence i.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    # reshape (rather than view) also accepts non-contiguous inputs, and a
    # single path covers both the 2-D case (rest is empty) and the N-D case.
    batch, seqlen, *rest = inputs.shape
    unpadded_inputs = inputs.reshape(batch * seqlen, *rest)[indices]

    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
    unpadded_labels = labels.flatten()[indices] if labels is not None else None

    return (
        unpadded_inputs,
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        unpadded_position_ids,
        unpadded_labels,
    )


def pad_input(
    inputs: Tensor,
    indices: Tensor,
    batch: int,
    seqlen: int,
    labels: Optional[Tensor] = None,
    ignore_index: int = -100,
) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Add padding to sequences.

    Scatters the unpadded rows back into a zero-initialised
    (batch * seqlen, ...) buffer at ``indices`` and reshapes the result to
    (batch, seqlen, ...).

    Args:
        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), positions in the flattened (batch * seqlen) layout
        batch: int, batch size
        seqlen: int, max sequence length
        labels: (total_nnz) or None
        ignore_index: int, value written to label positions not listed in indices

    Returns:
        padded_inputs: (batch, seqlen, ...) or (batch, seqlen); positions not in indices are 0
        padded_labels: (batch, seqlen) or None; positions not in indices are ignore_index
    """
    total_positions = batch * seqlen

    # Trailing feature dims (empty for 1-D inputs) are carried through unchanged.
    _, *rest = inputs.shape
    output = torch.zeros((total_positions, *rest), dtype=inputs.dtype, device=inputs.device)
    output[indices] = inputs
    padded_inputs = output.view(batch, seqlen, *rest)

    padded_labels = None
    if labels is not None:
        padded_labels = torch.full(
            (total_positions,),
            fill_value=ignore_index,
            dtype=labels.dtype,
            device=labels.device,
        )
        padded_labels[indices] = labels
        padded_labels = padded_labels.view(batch, seqlen)

    return padded_inputs, padded_labels