from typing import Optional, Tuple
import torch
from wenet.ssl.bestrq.mask import compute_mask_indices
from wenet.utils.mask import make_pad_mask


class BestRQModel(torch.nn.Module):
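    """BEST-RQ style self-supervised pre-training model.

    A frozen random projection maps subsampled features onto one or more
    frozen random codebooks, and the encoder is trained to predict the
    nearest codebook entry at masked positions. See "Self-Supervised
    Learning with Random-Projection Quantizer for Speech Recognition"
    (Chiu et al., 2022).
    """
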
def __init__(
self,
encoder: torch.nn.Module,
input_dim: int = 256,
embedding_dim: int = 256,
num_embeddings: int = 8192,
num_codebooks: int = 1,
dropout_rate: float = 0.1,
mask_prob: float = 0.01,
mask_length: int = 10,
min_masks: int = 2,
        layer_norm_epsilon: float = 1e-5,
) -> None:
super().__init__()
assert mask_prob > 0.0
self.mask_prob = mask_prob
        # NOTE: inputs shorter than mask_length frames should be filtered out upstream
self.mask_length = mask_length
self.min_masks = min_masks
self.input_dropout = torch.nn.Dropout(dropout_rate)
        # frozen random codebooks: [num_codebooks, embedding_dim, num_embeddings]
random_embedding_weight = torch.empty(
num_codebooks, embedding_dim, num_embeddings, requires_grad=False
)
self.embeddings = torch.nn.init.normal_(random_embedding_weight)
        # frozen random projection: [input_dim, embedding_dim]
        random_projection_weight = torch.empty(
            input_dim, embedding_dim, requires_grad=False
        )
        self.projection = torch.nn.init.xavier_normal_(random_projection_weight)
        # learnable embedding that replaces masked frames; wrapped in a
        # Parameter so it is registered and actually trained (a bare tensor
        # with requires_grad=True is not returned by .parameters())
        self.mask_emb = torch.nn.parameter.Parameter(torch.empty(input_dim))
        torch.nn.init.normal_(self.mask_emb, mean=0, std=0.1)
self.input_layer_norm = torch.nn.LayerNorm(input_dim, layer_norm_epsilon)
self.encoder = encoder
        # per-codebook output projection to codebook logits; initialized
        # explicitly (torch.Tensor() would leave the weights uninitialized)
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(), num_embeddings)
        )
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)

    def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
text: Optional[torch.Tensor] = None,
text_length: Optional[torch.Tensor] = None,
):
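        """Compute the BestRQ masked-prediction loss for a batch.

        Args:
            xs: padded input features, [B, T, feat_dim]
            xs_lens: number of valid frames per utterance, [B]
            text, text_length: unused here; kept for interface compatibility
        Returns:
            dict containing the cross-entropy loss under key "loss"
        """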
        # should support both non-streaming and streaming
        # TODO(Mddct): streaming in the future,
        # e.g. full attention and chunk or dynamic chunk training
# 1 forward subsampling
xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
unmasked_xs = xs
# 2 mask features
# 2.0 apply mask
masked_xs, masked_masks = self._apply_mask(xs)
# 2.1 get nearest embedding
target_ids = self._nearest_embedding_idx(unmasked_xs)
        # 3 forward the encoder (transformer/conformer) blocks
out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb, masks)
# 4 get logits
out = out.unsqueeze(1) # [B, 1, T', dim]
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0
        )  # [1, num_codebooks, dim, num_embeddings]
out = torch.matmul(out, top_n_out) # [B, num_codebooks, T', num_embeddings]
# 5 compute loss
loss = self._compute_loss(out, target_ids, out_mask.squeeze(1) * masked_masks)
return {"loss": loss}

    def _compute_loss(
self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
):
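        """Cross entropy over codebook ids, averaged over masked frames.

        input: logits, [B, num_codebooks, T', num_embeddings]
        target: nearest-codebook-entry ids, [B, T', num_codebooks]
        mask: True for frames that are both masked and non-padded, [B, T']
        """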
        input = input.transpose(1, 3)  # [B, num_embeddings, T', num_codebooks]
entropy = torch.nn.functional.cross_entropy(
input, target, reduction="none"
) # [B, T', num_codebooks]
        # only masked (non-padded) positions contribute to the loss
loss = entropy * mask.unsqueeze(2)
return loss.sum() / (mask.sum() * loss.size(2))

    def _forward_encoder_blocks(
self,
xs: torch.Tensor,
xs_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor,
):
        masks = xs_masks
        for layer in self.encoder.encoders:
            xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
if self.encoder.normalize_before:
xs = self.encoder.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so the
        # masks from before the encoder layers are returned as-is.
return xs, masks

    def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
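        """Return the id of the nearest codebook entry for every frame.

        The squared Euclidean distance is expanded as
        ||x||^2 + ||e||^2 - 2 * x . e so all codebooks and frames can be
        compared with a single batched matmul.
        """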
xs = self.input_layer_norm(xs)
xs = self.input_dropout(xs)
xs = torch.matmul(xs, self.projection.to(xs.device))
B, T, C = xs.size()
flattened_input = xs.view(-1, C)
embeddings = self.embeddings.to(
xs.device
) # [num_codebooks, embedding_dim, num_embeddings]
# [num_codebooks, B*T, num_embeddings]
distance = (
torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0)
+ torch.sum(embeddings**2, dim=1, keepdim=True)
- 2 * torch.matmul(flattened_input.unsqueeze(0), embeddings)
)
out = torch.argmin(distance, dim=-1) # [num_codebooks, B*T]
out = out.transpose(0, 1) # [B*T, num_codebooks]
return out.reshape(B, T, -1) # [B, T, num_codebooks]

    def _apply_mask(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
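        """Replace randomly chosen spans of frames with the learnable mask
        embedding; returns the masked features and the boolean span mask."""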
masks = compute_mask_indices(
xs.size()[:-1],
self.mask_prob,
self.mask_length,
self.min_masks,
device=xs.device,
)
masks_expand = masks.unsqueeze(-1) # [B, T, 1]
mask_emb = self.mask_emb.to(xs.device).view(1, 1, -1)
xs = torch.where(masks_expand, mask_emb, xs)
return xs, masks

    def _forward_subsampling(
self, xs: torch.Tensor, xs_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
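        """Apply global CMVN (if configured) and the encoder's subsampling
        frontend; returns subsampled features, positional embeddings and
        the padding mask at the subsampled frame rate."""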
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.encoder.global_cmvn is not None:
xs = self.encoder.global_cmvn(xs)
xs, pos_emb, masks = self.encoder.embed(xs, masks)
return xs, pos_emb, masks
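

# Minimal usage sketch (not part of the original file): it assumes wenet's
# ConformerEncoder and illustrative hyperparameters; constructor arguments
# may differ across wenet versions.
if __name__ == "__main__":
    from wenet.transformer.encoder import ConformerEncoder

    # 80-dim fbank in, 256-dim encoder out (matches input_dim of BestRQModel)
    encoder = ConformerEncoder(input_size=80, output_size=256)
    model = BestRQModel(encoder=encoder, input_dim=256)
    xs = torch.randn(2, 1000, 80)        # [B, T, feat_dim] dummy features
    xs_lens = torch.tensor([1000, 800])  # valid frame counts per utterance
    print(model(xs, xs_lens)["loss"])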