from typing import List, Optional, Tuple

import torch
from torch import nn

from modules.wenet_extractor.utils.common import get_activation, get_rnn


def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
    """Replace padded positions of `input` with `pad_value`.

    Args:
        input: [bs, max_time_step, dim]
        padding: [bs, max_time_step], 1 marks a padded position
        pad_value: tensor broadcastable to `input`, used at padded positions
    Returns:
        tensor of the same shape as `input`
    """
    return padding * pad_value + input * (1 - padding)
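

# Illustrative example (an addition, not part of the original module): where
# padding == 1 the previous value is kept, where padding == 0 the new value
# passes through. All names below are hypothetical.
def _example_apply_padding() -> torch.Tensor:
    pad = torch.tensor([[1.0], [0.0]])  # sample 0 is padded, sample 1 is real
    new_value = torch.ones(2, 3)
    old_value = torch.zeros(2, 3)
    # row 0 stays at old_value (zeros); row 1 takes new_value (ones)
    return ApplyPadding(new_value, pad, old_value)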


class PredictorBase(torch.nn.Module):
    # NOTE(Mddct): we could use an ABC abstract class here, but
    # keep this class simple enough for now
    def __init__(self) -> None:
        super().__init__()

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        _, _, _ = batch_size, method, device
        raise NotImplementedError("this is a base predictor")

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        _ = cache
        raise NotImplementedError("this is a base predictor")

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        _ = cache
        raise NotImplementedError("this is a base predictor")

    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ):
        _, _ = input, cache
        raise NotImplementedError("this is a base predictor")

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        _, _, _ = input, padding, cache
        raise NotImplementedError("this is a base predictor")


class RNNPredictor(PredictorBase):
    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        output_size: int,
        embed_dropout: float,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        rnn_type: str = "lstm",
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_layers = num_layers
        self.hidden_size = hidden_size
        # disable rnn base out projection
        self.embed = nn.Embedding(voca_size, embed_size)
        self.dropout = nn.Dropout(embed_dropout)
        # NOTE(Mddct): the rnn modules shipped with torch do not support
        # layer norm; layer norm and pruning in cells and layers will be
        # added later
        # ref: https://github.com/Mddct/neural-lm/blob/main/models/gru_cell.py
        self.rnn = get_rnn(rnn_type=rnn_type)(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=True,
            dropout=dropout,
        )
        self.projection = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): [batch, max_time]
            cache: rnn predictor cache; cache[0] == state_m,
                cache[1] == state_c
        Returns:
            output: [batch, max_time, output_size]
        """
        # NOTE(Mddct): we don't use the packed input format
        embed = self.embed(input)  # [batch, max_time, emb_size]
        embed = self.dropout(embed)
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        if cache is None:
            state = self.init_state(batch_size=input.size(0), device=input.device)
            states = (state[0], state[1])
        else:
            assert len(cache) == 2
            states = (cache[0], cache[1])
        out, (m, c) = self.rnn(embed, states)
        out = self.projection(out)
        # NOTE(Mddct): although we don't use the state in the transducer
        # training forward, it must be handled correctly for padded values,
        # so forward_step exists for inference and forward for training
        _, _ = m, c
        return out

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Args:
            cache: [state_m, state_c]
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        Returns:
            new_cache: [[state_m_1, state_c_1], [state_m_2, state_c_2], ...]
        """
        assert len(cache) == 2
        state_ms = cache[0]
        state_cs = cache[1]
        assert state_ms.size(1) == state_cs.size(1)
        new_cache: List[List[torch.Tensor]] = []
        for state_m, state_c in zip(
            torch.split(state_ms, 1, dim=1), torch.split(state_cs, 1, dim=1)
        ):
            new_cache.append([state_m, state_c])
        return new_cache

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """
        Args:
            cache: [[state_m_1, state_c_1], [state_m_2, state_c_2], ...]
        Returns:
            new_cache: [state_ms, state_cs]
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        """
        state_ms = torch.cat([states[0] for states in cache], dim=1)
        state_cs = torch.cat([states[1] for states in cache], dim=1)
        return [state_ms, state_cs]
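
    # Illustrative round trip (an addition, not part of the original module;
    # `predictor` is assumed to be an RNNPredictor instance):
    #
    #   state_m, state_c = predictor.init_state(4, torch.device("cpu"))
    #   per_utt = predictor.batch_to_cache([state_m, state_c])  # 4 entries
    #   rejoined = predictor.cache_to_batch(per_utt)
    #   assert torch.equal(rejoined[0], state_m)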

    def init_state(
        self,
        batch_size: int,
        device: torch.device,
        method: str = "zero",
    ) -> List[torch.Tensor]:
        assert batch_size > 0
        # TODO(Mddct): xavier init method
        _ = method
        return [
            torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
            torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
        ]

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], 1 marks a padded position
            cache: rnn predictor cache; cache[0] == state_m,
                cache[1] == state_c
        """
        assert len(cache) == 2
        state_m, state_c = cache[0], cache[1]
        embed = self.embed(input)  # [batch, 1, emb_size]
        embed = self.dropout(embed)
        out, (m, c) = self.rnn(embed, (state_m, state_c))
        out = self.projection(out)
        # keep the previous state for padded positions
        m = ApplyPadding(m, padding.unsqueeze(0), state_m)
        c = ApplyPadding(c, padding.unsqueeze(0), state_c)
        return (out, [m, c])
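

# Illustrative streaming step (an addition, not part of the original module):
# decode one label at a time with forward_step, carrying the cache between
# calls. The hyperparameters and names are hypothetical.
def _example_rnn_predictor_step() -> torch.Tensor:
    predictor = RNNPredictor(
        voca_size=10, embed_size=8, output_size=8, embed_dropout=0.1,
        hidden_size=8, num_layers=2,
    )
    cache = predictor.init_state(batch_size=2, device=torch.device("cpu"))
    token = torch.zeros(2, 1, dtype=torch.long)    # current label, [bs, 1]
    padding = torch.zeros(2, 1, dtype=torch.long)  # nothing padded
    out, cache = predictor.forward_step(token, padding, cache)
    return out  # [bs, 1, output_size]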


class EmbeddingPredictor(PredictorBase):
    """Embedding predictor

    Described in:
    https://arxiv.org/pdf/2109.07513.pdf

    embed -> proj -> layer norm -> swish
    """

    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        embed_dropout: float,
        n_head: int,
        history_size: int = 2,
        activation: str = "swish",
        bias: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        # multi head
        self.num_heads = n_head
        self.embed_size = embed_size
        self.context_size = history_size + 1
        self.pos_embed = torch.nn.Linear(
            embed_size * self.context_size, self.num_heads, bias=bias
        )
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        self.ffn = nn.Linear(self.embed_size, self.embed_size)
        self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
        self.activation = get_activation(activation)

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        assert batch_size > 0
        _ = method
        return [
            torch.zeros(
                batch_size, self.context_size - 1, self.embed_size, device=device
            ),
        ]

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Args:
            cache: [history]
                history: [bs, ...]
        Returns:
            new_cache: [[history_1], [history_2], [history_3], ...]
        """
        assert len(cache) == 1
        cache_0 = cache[0]
        history: List[List[torch.Tensor]] = []
        for h in torch.split(cache_0, 1, dim=0):
            history.append([h])
        return history

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """
        Args:
            cache: [[history_1], [history_2], [history_3], ...]
        Returns:
            new_cache: [history]
                history: [bs, ...]
        """
        history = torch.cat([h[0] for h in cache], dim=0)
        return [history]

    def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
        """forward for training"""
        input = self.embed(input)  # [bs, seq_len, embed]
        input = self.embed_dropout(input)
        if cache is None:
            zeros = self.init_state(input.size(0), device=input.device)[0]
        else:
            assert len(cache) == 1
            zeros = cache[0]

        input = torch.cat(
            (zeros, input), dim=1
        )  # [bs, context_size-1 + seq_len, embed]
        input = input.unfold(1, self.context_size, 1).permute(
            0, 1, 3, 2
        )  # [bs, seq_len, context_size, embed]
        # multi head pos: [n_head, embed, context_size]
        multi_head_pos = self.pos_embed.weight.view(
            self.num_heads, self.embed_size, self.context_size
        )

        # broadcast dot attention
        input_expand = input.unsqueeze(2)  # [bs, seq_len, 1, context_size, embed]
        multi_head_pos = multi_head_pos.permute(
            0, 2, 1
        )  # [num_heads, context_size, embed]

        # [bs, seq_len, num_heads, context_size, embed]
        weight = input_expand * multi_head_pos
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3
        )  # [bs, seq_len, num_heads, 1, context_size]
        output = weight.matmul(input_expand).squeeze(
            dim=3
        )  # [bs, seq_len, num_heads, embed]
        output = output.sum(dim=2)  # [bs, seq_len, embed]
        output = output / (self.num_heads * self.context_size)

        output = self.ffn(output)
        output = self.norm(output)
        output = self.activation(output)
        return output
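
    # Illustrative shape walk-through (an addition, not part of the original
    # module), assuming embed_size=8, n_head=2, history_size=2 (so
    # context_size=3) and an input of shape [bs, seq_len]:
    #
    #   embed        -> [bs, seq_len, 8]
    #   cat + unfold -> [bs, seq_len, 3, 8]    each step sees its 2-label history
    #   weight       -> [bs, seq_len, 2, 1, 3] per-head score per context slot
    #   weighted sum -> [bs, seq_len, 8], then ffn -> norm -> swish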

    def forward_step(
        self,
        input: torch.Tensor,
        padding: torch.Tensor,
        cache: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """forward step for inference
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], 1 marks a padded position
            cache: for the embedding predictor, cache[0] == history
        """
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        input = self.embed(input)  # [bs, 1, embed]
        input = self.embed_dropout(input)
        context_input = torch.cat((history, input), dim=1)
        input_expand = context_input.unsqueeze(1).unsqueeze(
            2
        )  # [bs, 1, 1, context_size, embed]

        # multi head pos: [n_head, embed, context_size]
        multi_head_pos = self.pos_embed.weight.view(
            self.num_heads, self.embed_size, self.context_size
        )

        multi_head_pos = multi_head_pos.permute(
            0, 2, 1
        )  # [num_heads, context_size, embed]
        # [bs, 1, num_heads, context_size, embed]
        weight = input_expand * multi_head_pos
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3
        )  # [bs, 1, num_heads, 1, context_size]
        output = weight.matmul(input_expand).squeeze(dim=3)  # [bs, 1, num_heads, embed]
        output = output.sum(dim=2)  # [bs, 1, embed]
        output = output / (self.num_heads * self.context_size)

        output = self.ffn(output)
        output = self.norm(output)
        output = self.activation(output)
        new_cache = context_input[:, 1:, :]
        # TODO(Mddct): we need to pad new_cache in the future
        # new_cache = ApplyPadding(history, padding, new_cache)
        return (output, [new_cache])
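

# Illustrative single step (an addition, not part of the original module):
# the history cache simply shifts forward by one label per step.
# Hyperparameters are hypothetical.
def _example_embedding_predictor_step() -> torch.Tensor:
    predictor = EmbeddingPredictor(
        voca_size=10, embed_size=8, embed_dropout=0.1, n_head=2
    )
    cache = predictor.init_state(batch_size=2, device=torch.device("cpu"))
    token = torch.ones(2, 1, dtype=torch.long)
    padding = torch.zeros(2, 1, dtype=torch.long)
    out, cache = predictor.forward_step(token, padding, cache)
    return out  # [bs, 1, embed_size]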


class ConvPredictor(PredictorBase):
    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        embed_dropout: float,
        history_size: int = 2,
        activation: str = "relu",
        bias: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        assert history_size >= 0
        self.embed_size = embed_size
        self.context_size = history_size + 1
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        self.conv = nn.Conv1d(
            in_channels=embed_size,
            out_channels=embed_size,
            kernel_size=self.context_size,
            padding=0,
            groups=embed_size,
            bias=bias,
        )
        self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
        self.activation = get_activation(activation)

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        assert batch_size > 0
        assert method == "zero"
        return [
            torch.zeros(
                batch_size, self.context_size - 1, self.embed_size, device=device
            )
        ]

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """
        Args:
            cache: [[history_1], [history_2], [history_3], ...]
        Returns:
            new_cache: [history]
                history: [bs, ...]
        """
        history = torch.cat([h[0] for h in cache], dim=0)
        return [history]

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Args:
            cache: [history]
                history: [bs, ...]
        Returns:
            new_cache: [[history_1], [history_2], [history_3], ...]
        """
        assert len(cache) == 1
        cache_0 = cache[0]
        history: List[List[torch.Tensor]] = []
        for h in torch.split(cache_0, 1, dim=0):
            history.append([h])
        return history

    def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
        """forward for training"""
        input = self.embed(input)  # [bs, seq_len, embed]
        input = self.embed_dropout(input)
        if cache is None:
            zeros = self.init_state(input.size(0), device=input.device)[0]
        else:
            assert len(cache) == 1
            zeros = cache[0]

        input = torch.cat(
            (zeros, input), dim=1
        )  # [bs, context_size-1 + seq_len, embed]
        input = input.permute(0, 2, 1)
        out = self.conv(input).permute(0, 2, 1)
        out = self.activation(self.norm(out))
        return out

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """forward step for inference
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], 1 marks a padded position
            cache: for the conv predictor, cache[0] == history
        """
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        input = self.embed(input)  # [bs, 1, embed]
        input = self.embed_dropout(input)
        context_input = torch.cat((history, input), dim=1)
        input = context_input.permute(0, 2, 1)
        out = self.conv(input).permute(0, 2, 1)
        out = self.activation(self.norm(out))
        new_cache = context_input[:, 1:, :]
        # TODO(Mddct): apply padding in the future
        return (out, [new_cache])
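

# Minimal smoke test (an addition, not part of the original module; it assumes
# the modules.wenet_extractor imports above resolve): one training forward and
# one streaming step per predictor. Hyperparameters are arbitrary; shapes
# follow the docstrings above.
if __name__ == "__main__":
    bs, seq_len, vocab = 2, 5, 10
    tokens = torch.randint(0, vocab, (bs, seq_len))
    step = tokens[:, :1]
    pad = torch.zeros(bs, 1, dtype=torch.long)

    rnn = RNNPredictor(vocab, 8, 8, 0.1, hidden_size=8, num_layers=2)
    assert rnn(tokens).shape == (bs, seq_len, 8)
    out, cache = rnn.forward_step(step, pad, rnn.init_state(bs, tokens.device))
    assert out.shape == (bs, 1, 8)

    emb = EmbeddingPredictor(vocab, 8, 0.1, n_head=2)
    assert emb(tokens).shape == (bs, seq_len, 8)

    conv = ConvPredictor(vocab, 8, 0.1)
    assert conv(tokens).shape == (bs, seq_len, 8)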