from typing import List, Optional, Tuple
import torch
from torch import nn
from modules.wenet_extractor.utils.common import get_activation, get_rnn
def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
    """Replace padded positions of `input` with `pad_value`.

    Args:
        input: [bs, max_time_step, dim]
        padding: [bs, max_time_step], 1 marks a padded position
        pad_value: tensor broadcastable to `input`, used at padded positions
    """
    return padding * pad_value + input * (1 - padding)
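# Illustrative note (not part of the original file): in the `forward_step`
# methods below, ApplyPadding is used to freeze the state of padded positions,
# e.g. `ApplyPadding(m, padding.unsqueeze(0), state_m)` keeps the previous
# state_m wherever padding == 1 and takes the new `m` elsewhere.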
class PredictorBase(torch.nn.Module):
    # NOTE(Mddct): we could use an abstract base class (ABC) here, but
    # we keep this class simple for now
def __init__(self) -> None:
super().__init__()
def init_state(
self, batch_size: int, device: torch.device, method: str = "zero"
) -> List[torch.Tensor]:
_, _, _ = batch_size, method, device
        raise NotImplementedError("this is a base predictor")
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
_ = cache
        raise NotImplementedError("this is a base predictor")
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
_ = cache
        raise NotImplementedError("this is a base predictor")
    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ):
        _, _ = input, cache
        raise NotImplementedError("this is a base predictor")
    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        _, _, _ = input, padding, cache
        raise NotImplementedError("this is a base predictor")
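# The intended contract, as inferred from the subclasses below: `forward` runs
# the predictor over a full label sequence during training, `forward_step`
# advances it by a single token during inference, and `batch_to_cache` /
# `cache_to_batch` convert between one batched cache and a list of
# per-utterance caches (presumably for hypothesis-level bookkeeping in search).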
class RNNPredictor(PredictorBase):
def __init__(
self,
voca_size: int,
embed_size: int,
output_size: int,
embed_dropout: float,
hidden_size: int,
num_layers: int,
bias: bool = True,
rnn_type: str = "lstm",
dropout: float = 0.1,
) -> None:
super().__init__()
self.n_layers = num_layers
self.hidden_size = hidden_size
        # the RNN's built-in output projection is not used; see self.projection below
self.embed = nn.Embedding(voca_size, embed_size)
self.dropout = nn.Dropout(embed_dropout)
        # NOTE(Mddct): torch's built-in RNNs do not support layer norm;
        # layer norm and value pruning in the cell/layer may be added later
        # ref: https://github.com/Mddct/neural-lm/blob/main/models/gru_cell.py
self.rnn = get_rnn(rnn_type=rnn_type)(
input_size=embed_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
batch_first=True,
dropout=dropout,
)
self.projection = nn.Linear(hidden_size, output_size)
def forward(
self,
input: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): [batch, max_time).
padding (torch.Tensor): [batch, max_time]
cache : rnn predictor cache[0] == state_m
cache[1] == state_c
Returns:
output: [batch, max_time, output_size]
"""
# NOTE(Mddct): we don't use pack input format
embed = self.embed(input) # [batch, max_time, emb_size]
embed = self.dropout(embed)
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
if cache is None:
state = self.init_state(batch_size=input.size(0), device=input.device)
states = (state[0], state[1])
else:
assert len(cache) == 2
states = (cache[0], cache[1])
out, (m, c) = self.rnn(embed, states)
out = self.projection(out)
        # NOTE(Mddct): although the state is not used in transducer training
        # forward, it must be handled correctly with respect to padding, so
        # forward_step is provided for inference and forward for training.
_, _ = m, c
return out
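    # Illustrative training-time call (sketch; `predictor` and `ys_in` are
    # assumed names, not defined in this file):
    #   out = predictor(ys_in)  # ys_in: [batch, max_time] -> [batch, max_time, output_size]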
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
"""
Args:
cache: [state_m, state_c]
state_ms: [1*n_layers, bs, ...]
state_cs: [1*n_layers, bs, ...]
Returns:
new_cache: [[state_m_1, state_c_1], [state_m_2, state_c_2]...]
"""
assert len(cache) == 2
state_ms = cache[0]
state_cs = cache[1]
assert state_ms.size(1) == state_cs.size(1)
new_cache: List[List[torch.Tensor]] = []
for state_m, state_c in zip(
torch.split(state_ms, 1, dim=1), torch.split(state_cs, 1, dim=1)
):
new_cache.append([state_m, state_c])
return new_cache
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""
Args:
            cache : [[state_m_1, state_c_1], [state_m_2, state_c_2]...]
        Returns:
            new_cache: [state_ms, state_cs],
state_ms: [1*n_layers, bs, ...]
state_cs: [1*n_layers, bs, ...]
"""
state_ms = torch.cat([states[0] for states in cache], dim=1)
state_cs = torch.cat([states[1] for states in cache], dim=1)
return [state_ms, state_cs]
def init_state(
self,
batch_size: int,
device: torch.device,
method: str = "zero",
) -> List[torch.Tensor]:
assert batch_size > 0
# TODO(Mddct): xavier init method
_ = method
return [
torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
]
def forward_step(
self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], where 1 marks a padded position
cache : rnn predictor cache[0] == state_m
cache[1] == state_c
"""
assert len(cache) == 2
state_m, state_c = cache[0], cache[1]
embed = self.embed(input) # [batch, 1, emb_size]
embed = self.dropout(embed)
out, (m, c) = self.rnn(embed, (state_m, state_c))
out = self.projection(out)
m = ApplyPadding(m, padding.unsqueeze(0), state_m)
c = ApplyPadding(c, padding.unsqueeze(0), state_c)
return (out, [m, c])
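    # Illustrative decoding sketch (assumed hyper-parameters, not part of the
    # original file):
    #   predictor = RNNPredictor(voca_size=1000, embed_size=256, output_size=256,
    #                            embed_dropout=0.1, hidden_size=256, num_layers=2)
    #   cache = predictor.init_state(batch_size=1, device=torch.device("cpu"))
    #   token = torch.zeros(1, 1, dtype=torch.long)   # last emitted label
    #   padding = torch.zeros(1, 1)                   # 1 would freeze the state
    #   out, cache = predictor.forward_step(token, padding, cache)  # out: [1, 1, 256]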
class EmbeddingPredictor(PredictorBase):
"""Embedding predictor
Described in:
https://arxiv.org/pdf/2109.07513.pdf
embed-> proj -> layer norm -> swish
"""
def __init__(
self,
voca_size: int,
embed_size: int,
embed_dropout: float,
n_head: int,
history_size: int = 2,
activation: str = "swish",
bias: bool = False,
layer_norm_epsilon: float = 1e-5,
) -> None:
super().__init__()
# multi head
self.num_heads = n_head
self.embed_size = embed_size
self.context_size = history_size + 1
self.pos_embed = torch.nn.Linear(
embed_size * self.context_size, self.num_heads, bias=bias
)
self.embed = nn.Embedding(voca_size, self.embed_size)
self.embed_dropout = nn.Dropout(p=embed_dropout)
self.ffn = nn.Linear(self.embed_size, self.embed_size)
self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
        self.activation = get_activation(activation)
def init_state(
self, batch_size: int, device: torch.device, method: str = "zero"
) -> List[torch.Tensor]:
assert batch_size > 0
_ = method
return [
torch.zeros(
batch_size, self.context_size - 1, self.embed_size, device=device
),
]
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
"""
Args:
cache : [history]
history: [bs, ...]
Returns:
            new_cache : [[history_1], [history_2], [history_3]...]
"""
assert len(cache) == 1
cache_0 = cache[0]
history: List[List[torch.Tensor]] = []
for h in torch.split(cache_0, 1, dim=0):
history.append([h])
return history
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""
Args:
            cache : [[history_1], [history_2], [history_3]...]
        Returns:
            new_cache: [history],
history: [bs, ...]
"""
history = torch.cat([h[0] for h in cache], dim=0)
return [history]
def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
"""forward for training"""
input = self.embed(input) # [bs, seq_len, embed]
input = self.embed_dropout(input)
if cache is None:
zeros = self.init_state(input.size(0), device=input.device)[0]
else:
assert len(cache) == 1
zeros = cache[0]
input = torch.cat(
(zeros, input), dim=1
) # [bs, context_size-1 + seq_len, embed]
input = input.unfold(1, self.context_size, 1).permute(
0, 1, 3, 2
) # [bs, seq_len, context_size, embed]
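        # Each head owns one learned positional vector per context position
        # (pos_embed.weight, reshaped below). Dotting them with the context
        # embeddings gives per-position scores, which weight a sum over the
        # context window; the weighted sums are then accumulated over heads
        # and scaled by 1 / (num_heads * context_size).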
# multi head pos: [n_head, embed, context_size]
multi_head_pos = self.pos_embed.weight.view(
self.num_heads, self.embed_size, self.context_size
)
        # broadcast dot attention
input_expand = input.unsqueeze(2) # [bs, seq_len, 1, context_size, embed]
multi_head_pos = multi_head_pos.permute(
0, 2, 1
) # [num_heads, context_size, embed]
# [bs, seq_len, num_heads, context_size, embed]
weight = input_expand * multi_head_pos
weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
3
) # [bs, seq_len, num_heads, 1, context_size]
output = weight.matmul(input_expand).squeeze(
dim=3
) # [bs, seq_len, num_heads, embed]
output = output.sum(dim=2) # [bs, seq_len, embed]
output = output / (self.num_heads * self.context_size)
output = self.ffn(output)
output = self.norm(output)
        output = self.activation(output)
return output
def forward_step(
self,
input: torch.Tensor,
padding: torch.Tensor,
cache: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""forward step for inference
Args:
input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], where 1 marks a padded position
cache: for embedding predictor, cache[0] == history
"""
assert input.size(1) == 1
assert len(cache) == 1
history = cache[0]
assert history.size(1) == self.context_size - 1
input = self.embed(input) # [bs, 1, embed]
input = self.embed_dropout(input)
context_input = torch.cat((history, input), dim=1)
input_expand = context_input.unsqueeze(1).unsqueeze(
2
) # [bs, 1, 1, context_size, embed]
# multi head pos: [n_head, embed, context_size]
multi_head_pos = self.pos_embed.weight.view(
self.num_heads, self.embed_size, self.context_size
)
multi_head_pos = multi_head_pos.permute(
0, 2, 1
) # [num_heads, context_size, embed]
# [bs, 1, num_heads, context_size, embed]
weight = input_expand * multi_head_pos
weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
3
) # [bs, 1, num_heads, 1, context_size]
output = weight.matmul(input_expand).squeeze(dim=3) # [bs, 1, num_heads, embed]
output = output.sum(dim=2) # [bs, 1, embed]
output = output / (self.num_heads * self.context_size)
output = self.ffn(output)
output = self.norm(output)
        output = self.activation(output)
new_cache = context_input[:, 1:, :]
# TODO(Mddct): we need padding new_cache in future
# new_cache = ApplyPadding(history, padding, new_cache)
return (output, [new_cache])
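    # Illustrative decoding sketch (assumed hyper-parameters, not part of the
    # original file):
    #   predictor = EmbeddingPredictor(voca_size=1000, embed_size=256,
    #                                  embed_dropout=0.1, n_head=4, history_size=2)
    #   cache = predictor.init_state(batch_size=1, device=torch.device("cpu"))
    #   out, cache = predictor.forward_step(
    #       torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 1), cache)
    #   # out: [1, 1, 256]; cache[0] now holds the most recent history_size embeddings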
class ConvPredictor(PredictorBase):
def __init__(
self,
voca_size: int,
embed_size: int,
embed_dropout: float,
history_size: int = 2,
activation: str = "relu",
bias: bool = False,
layer_norm_epsilon: float = 1e-5,
) -> None:
super().__init__()
assert history_size >= 0
self.embed_size = embed_size
self.context_size = history_size + 1
self.embed = nn.Embedding(voca_size, self.embed_size)
self.embed_dropout = nn.Dropout(p=embed_dropout)
self.conv = nn.Conv1d(
in_channels=embed_size,
out_channels=embed_size,
kernel_size=self.context_size,
padding=0,
groups=embed_size,
bias=bias,
)
self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
        self.activation = get_activation(activation)
def init_state(
self, batch_size: int, device: torch.device, method: str = "zero"
) -> List[torch.Tensor]:
assert batch_size > 0
assert method == "zero"
return [
torch.zeros(
batch_size, self.context_size - 1, self.embed_size, device=device
)
]
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""
Args:
            cache : [[history_1], [history_2], [history_3]...]
        Returns:
            new_cache: [history],
history: [bs, ...]
"""
history = torch.cat([h[0] for h in cache], dim=0)
return [history]
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
"""
Args:
cache : [history]
history: [bs, ...]
Returns:
            new_cache : [[history_1], [history_2], [history_3]...]
"""
assert len(cache) == 1
cache_0 = cache[0]
history: List[List[torch.Tensor]] = []
for h in torch.split(cache_0, 1, dim=0):
history.append([h])
return history
def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
"""forward for training"""
input = self.embed(input) # [bs, seq_len, embed]
input = self.embed_dropout(input)
if cache is None:
zeros = self.init_state(input.size(0), device=input.device)[0]
else:
assert len(cache) == 1
zeros = cache[0]
input = torch.cat(
(zeros, input), dim=1
) # [bs, context_size-1 + seq_len, embed]
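        # Prepending context_size-1 history frames makes the depthwise conv
        # causal: with kernel_size == context_size and no padding, the output
        # length matches the original seq_len.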
input = input.permute(0, 2, 1)
out = self.conv(input).permute(0, 2, 1)
        out = self.activation(self.norm(out))
return out
def forward_step(
self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""forward step for inference
Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], where 1 marks a padded position
            cache: cache[0] == history
"""
assert input.size(1) == 1
assert len(cache) == 1
history = cache[0]
assert history.size(1) == self.context_size - 1
input = self.embed(input) # [bs, 1, embed]
input = self.embed_dropout(input)
context_input = torch.cat((history, input), dim=1)
input = context_input.permute(0, 2, 1)
out = self.conv(input).permute(0, 2, 1)
        out = self.activation(self.norm(out))
new_cache = context_input[:, 1:, :]
# TODO(Mddct): apply padding in future
return (out, [new_cache])
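    # Illustrative decoding sketch (assumed hyper-parameters, not part of the
    # original file):
    #   predictor = ConvPredictor(voca_size=1000, embed_size=256,
    #                             embed_dropout=0.1, history_size=2)
    #   cache = predictor.init_state(batch_size=1, device=torch.device("cpu"))
    #   out, cache = predictor.forward_step(
    #       torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 1), cache)
    #   # out: [1, 1, 256]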