from typing import Optional, Tuple

import torch

from wenet.ssl.bestrq.mask import compute_mask_indices
from wenet.utils.mask import make_pad_mask


class BestRQModel(torch.nn.Module):
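    """BEST-RQ style self-supervised pre-training model.

    Roughly follows "Self-supervised Learning with Random-projection
    Quantizer for Speech Recognition" (Chiu et al., 2022): input features
    are quantized into discrete target ids by a frozen random projection
    plus random codebooks, spans of the input are masked, and the encoder
    is trained to predict the codebook ids of the masked frames.
    """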

    def __init__(
        self,
        encoder: torch.nn.Module,
        input_dim: int = 256,
        embedding_dim: int = 256,
        num_embeddings: int = 8192,
        num_codebooks: int = 1,
        dropout_rate: float = 0.1,
        mask_prob: float = 0.01,
        mask_length: int = 10,
        min_masks: int = 2,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        assert mask_prob > 0.0
        self.mask_prob = mask_prob
        # NOTE: audio shorter than mask_length should be filtered out upstream
        self.mask_length = mask_length
        self.min_masks = min_masks

        self.input_dropout = torch.nn.Dropout(dropout_rate)
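
        # Frozen random-projection quantizer that produces the training
        # targets: both the codebooks and the projection are created with
        # requires_grad=False and are never updated during training.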
        # [num_codebooks, embedding_dim, num_embeddings]
        random_embedding_weight = torch.empty(
            num_codebooks, embedding_dim, num_embeddings, requires_grad=False
        )
        self.embeddings = torch.nn.init.normal_(random_embedding_weight)

        # [input_dim, embedding_dim]
        random_projection_weight = torch.empty(
            input_dim, embedding_dim, requires_grad=False
        )
        self.projection = torch.nn.init.xavier_normal_(random_projection_weight)

        # learnable embedding used to replace masked frames
        mask_emb_weight = torch.empty(input_dim)
        mask_emb_weight.requires_grad = True
        self.mask_emb = torch.nn.init.normal_(mask_emb_weight, mean=0, std=0.1)

        self.input_layer_norm = torch.nn.LayerNorm(input_dim, layer_norm_epsilon)

        self.encoder = encoder
        # per-codebook prediction head on top of the encoder output
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(), num_embeddings)
        )
        # torch.empty leaves the parameter uninitialized, so give it an
        # explicit init before training
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_length: Optional[torch.Tensor] = None,
    ):
        # should support both non-streaming and streaming,
        # e.g. full attention and chunk or dynamic chunk training
        # TODO(Mddct): streaming in the future

        # 1 forward subsampling
        xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
        unmasked_xs = xs
        # 2 mask features
        # 2.0 apply mask
        masked_xs, masked_masks = self._apply_mask(xs)
        # 2.1 get nearest embedding ids as prediction targets
        target_ids = self._nearest_embedding_idx(unmasked_xs)
        # 3 forward xxx-former blocks
        out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb, masks)
        # 4 get logits
        out = out.unsqueeze(1)  # [B, 1, T', dim]
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0
        )  # [1, num_codebooks, dim, num_embeddings]
        out = torch.matmul(out, top_n_out)  # [B, num_codebooks, T', num_embeddings]
        # 5 compute loss only on masked, non-padded positions
        loss = self._compute_loss(out, target_ids, out_mask.squeeze(1) * masked_masks)
        return {"loss": loss}

    def _compute_loss(
        self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ):
        # cross_entropy expects the class dim at position 1:
        # [B, num_codebooks, T', num_embeddings] -> [B, num_embeddings, T', num_codebooks]
        input = input.transpose(1, 3)
        entropy = torch.nn.functional.cross_entropy(
            input, target, reduction="none"
        )  # [B, T', num_codebooks]
        # zero the loss (and hence the gradient) outside the masked area
        loss = entropy * mask.unsqueeze(2)
        # average over masked positions and codebooks
        return loss.sum() / (mask.sum() * loss.size(2))

    def _forward_encoder_blocks(
        self,
        xs: torch.Tensor,
        xs_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ):
        masks = xs_masks
        for layer in self.encoder.encoders:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
        xs = self.input_layer_norm(xs)
        xs = self.input_dropout(xs)
        # project input features into the codebook space
        xs = torch.matmul(xs, self.projection.to(xs.device))
        B, T, C = xs.size()
        flattened_input = xs.view(-1, C)
        embeddings = self.embeddings.to(
            xs.device
        )  # [num_codebooks, embedding_dim, num_embeddings]
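        # squared Euclidean distance via the expansion
        #   ||x - e||^2 = ||x||^2 - 2 * x . e + ||e||^2,
        # broadcast over all codebooks at once so no per-codebook loop is
        # needed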
        # [num_codebooks, B*T, num_embeddings]
        distance = (
            torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0)
            + torch.sum(embeddings**2, dim=1, keepdim=True)
            - 2 * torch.matmul(flattened_input.unsqueeze(0), embeddings)
        )
        out = torch.argmin(distance, dim=-1)  # [num_codebooks, B*T]
        out = out.transpose(0, 1)  # [B*T, num_codebooks]
        return out.reshape(B, T, -1)  # [B, T, num_codebooks]

    def _apply_mask(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        masks = compute_mask_indices(
            xs.size()[:-1],
            self.mask_prob,
            self.mask_length,
            self.min_masks,
            device=xs.device,
        )
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]
        # replace masked frames with the learnable mask embedding
        mask_emb = self.mask_emb.to(xs.device).view(1, 1, -1)
        xs = torch.where(masks_expand, mask_emb, xs)
        return xs, masks

    def _forward_subsampling(
        self, xs: torch.Tensor, xs_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.encoder.global_cmvn is not None:
            xs = self.encoder.global_cmvn(xs)
        xs, pos_emb, masks = self.encoder.embed(xs, masks)
        return xs, pos_emb, masks
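

if __name__ == "__main__":
    # Minimal smoke test with a stub encoder. The stub is NOT a WeNet
    # encoder; it only mimics the attributes BestRQModel actually touches
    # (global_cmvn, embed, encoders, normalize_before, after_norm,
    # output_size) and performs no subsampling, so the shapes below are
    # illustrative only.

    class _StubLayer(torch.nn.Module):

        def __init__(self, dim: int):
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, xs, masks, pos_emb, mask_pad):
            # match the (xs, masks, _, _) return convention used above
            return self.linear(xs), masks, None, None

    class _StubEncoder(torch.nn.Module):

        def __init__(self, idim: int = 80, odim: int = 256):
            super().__init__()
            self.global_cmvn = None
            self.normalize_before = False
            self.after_norm = torch.nn.Identity()
            self.proj = torch.nn.Linear(idim, odim)
            self.encoders = torch.nn.ModuleList([_StubLayer(odim)])
            self._odim = odim

        def output_size(self) -> int:
            return self._odim

        def embed(self, xs, masks):
            # real WeNet encoders subsample here; the stub keeps T as is
            return self.proj(xs), torch.zeros(1, xs.size(1), self._odim), masks

    model = BestRQModel(_StubEncoder(), input_dim=256, mask_prob=0.1)
    xs = torch.randn(2, 100, 80)
    xs_lens = torch.tensor([100, 80])
    print(model(xs, xs_lens)["loss"])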