import torch
class MaskedMeanPooling(torch.nn.Module):
    """
    Mean pooling layer with explicit masking support.

    This layer computes the mean over the sequence dimension while
    ignoring padded elements according to a boolean mask. It supports
    both PyTorch-style padding masks and valid-position masks.
    """

    def __init__(self, valid_pad: bool = True, eps: float = 1e-6):
        """
        Initialize the masked mean pooling layer.

        Args:
            valid_pad (bool, optional): Mask interpretation mode. If True,
                `True` values in the mask indicate valid (non-padded) positions.
                If False, `True` values indicate padded positions, following
                PyTorch-style padding conventions. Defaults to True.
            eps (float, optional): Small constant to avoid division by zero
                when all positions are masked. Defaults to 1e-6.
        """
        super().__init__()
        self.valid_pad = valid_pad
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masked mean pooling over the sequence dimension.

        Args:
            x (torch.Tensor): Input tensor of shape (..., S, D), where
                S is the sequence length and D the feature dimension.
            mask (torch.Tensor): Boolean mask tensor of shape (..., S),
                or None to treat every position as valid. The
                interpretation of True/False depends on `valid_pad`.

        Returns:
            tuple:
                torch.Tensor: Pooled tensor of shape (..., D).
                torch.Tensor: Boolean mask of shape (...,) that is True
                    where at least one sequence position was valid.
        """
        if mask is None:
            # No mask supplied: every position is valid. Shape must match
            # x without the feature dimension, i.e. (..., S).
            valid_mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
        else:
            # Only a caller-supplied mask carries the valid_pad convention;
            # a synthesized all-valid mask must never be inverted.
            valid_mask = mask if self.valid_pad else torch.logical_not(mask)

        # Zero out padded positions, then divide by the count of valid
        # positions (clamped by eps so fully-masked rows don't divide by 0).
        weights = valid_mask.unsqueeze(-1).to(x.dtype)
        summed = torch.sum(x * weights, dim=-2)
        denom = weights.sum(dim=-2).clamp(min=self.eps)

        # A pooled row is valid iff any of its sequence positions was valid.
        any_valid = valid_mask.any(dim=-1)

        return summed / denom, any_valid
| | |
| | |
| | |
| |
|