test-flex-gpt / padding.py
oweller2
add breakpoint
d4cfde8
raw
history blame
3.23 kB
import torch
from torch import Tensor
from typing import Optional, Tuple
import torch.nn.functional as F
def unpad_input(
    inputs: Tensor,
    attention_mask: Tensor,
    position_ids: Optional[Tensor] = None,
    labels: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, int, Optional[Tensor], Optional[Tensor]]:
    """
    Remove padding from input sequences.

    Tokens whose ``attention_mask`` entry is nonzero are gathered, in row-major
    (batch, then position) order, into a single flat leading dimension.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen). May be non-contiguous.
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        position_ids: (batch, seqlen), int, position ids
        labels: (batch, seqlen), int, labels

    Returns:
        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), flat indices of the selected tokens in the (batch * seqlen) layout.
        cu_seqlens: (batch + 1), int32, the cumulative sequence lengths starting at 0.
        max_seqlen_in_batch: int, longest valid sequence (0 for an empty batch).
        unpadded_position_ids: (total_nnz) or None
        unpadded_labels: (total_nnz) or None
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # ``max()`` on an empty tensor raises, so report 0 for an empty batch.
    if seqlens_in_batch.numel() > 0:
        max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    else:
        max_seqlen_in_batch = 0
    # Prepend a 0 so cu_seqlens[i] is the start offset of sequence i.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    batch, seqlen, *rest = inputs.shape
    # ``reshape`` (not ``view``) so non-contiguous inputs such as transposed or
    # sliced tensors are accepted; for 2-D inputs this equals ``flatten()``.
    unpadded_inputs = inputs.reshape(batch * seqlen, *rest)[indices]

    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
    unpadded_labels = labels.flatten()[indices] if labels is not None else None
    return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
def pad_input(
    inputs: Tensor,
    indices: Tensor,
    batch: int,
    seqlen: int,
    labels: Optional[Tensor] = None,
    ignore_index: int = -100,
) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Add padding to sequences.

    Scatters flat (unpadded) rows back into a zero-initialised
    (batch * seqlen, ...) buffer at ``indices`` and reshapes to (batch, seqlen, ...).

    Args:
        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), flat positions in the (batch * seqlen) layout.
        batch: int, batch size
        seqlen: int, max sequence length
        labels: (total_nnz) or None
        ignore_index: value written to padded label positions.

    Returns:
        padded_inputs: (batch, seqlen, ...) or (batch, seqlen); padded positions are 0.
        padded_labels: (batch, seqlen) or None; padded positions are ``ignore_index``.
    """
    # Trailing feature dims; empty for 1-D inputs, so one path covers both cases.
    rest = inputs.shape[1:]
    output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
    output[indices] = inputs
    padded_inputs = output.view(batch, seqlen, *rest)

    padded_labels = None
    if labels is not None:
        padded_labels = torch.full((batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device)
        padded_labels[indices] = labels
        padded_labels = padded_labels.view(batch, seqlen)
    return padded_inputs, padded_labels