import torch

from espnet2.layers.label_aggregation import LabelAggregate


class LabelProcessor(torch.nn.Module):
    """Label aggregator for speaker diarization.

    Thin ``torch.nn.Module`` wrapper that owns a single ``LabelAggregate``
    instance and delegates the forward pass to it.

    Args:
        win_length: Forwarded unchanged as the first positional argument
            of ``LabelAggregate``.
        hop_length: Forwarded unchanged as the second positional argument
            of ``LabelAggregate``.
        center: Forwarded unchanged as the third positional argument of
            ``LabelAggregate``.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        center: bool = True,
    ):
        super().__init__()
        # Arguments are passed positionally, so their order must match the
        # LabelAggregate constructor signature.
        self.label_aggregator = LabelAggregate(win_length, hop_length, center)

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Run the wrapped label aggregator on a batch of labels.

        Args:
            input: Sample-level labels, (Batch, Nsamples, Label_dim).
            ilens: Per-utterance input lengths, (Batch).

        Returns:
            output: Aggregated labels, (Batch, Frames, Label_dim).
            olens: Per-utterance output lengths, (Batch).
        """
        # The aggregator yields exactly two values; they are re-packed into
        # a tuple for the caller.
        aggregated, frame_lens = self.label_aggregator(input, ilens)
        return aggregated, frame_lens