from typing import List, Optional, Tuple
import torch
from torch import nn
from modules.wenet_extractor.utils.common import get_activation, get_rnn
def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
"""
Args:
input: [bs, max_time_step, dim]
padding: [bs, max_time_step]
"""
return padding * pad_value + input * (1 - padding)
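# Illustrative example of ApplyPadding's semantics: positions where
# padding == 1 keep pad_value, positions where padding == 0 keep input, e.g.
#   ApplyPadding(torch.tensor([2.0]), torch.tensor([1.0]), torch.tensor([5.0]))
#   -> tensor([5.]), while a padding of torch.tensor([0.0]) returns tensor([2.]).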
class PredictorBase(torch.nn.Module):
# NOTE(Mddct): We can use ABC abstract here, but
# keep this class simple enough for now
def __init__(self) -> None:
super().__init__()
def init_state(
self, batch_size: int, device: torch.device, method: str = "zero"
) -> List[torch.Tensor]:
_, _, _ = batch_size, method, device
raise NotImplementedError("this is a base precictor")
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
_ = cache
raise NotImplementedError("this is a base precictor")
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
_ = cache
raise NotImplementedError("this is a base precictor")
def forward(
self,
input: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
):
        _, _ = input, cache
        raise NotImplementedError("this is a base predictor")
def forward_step(
self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        _, _, _ = input, padding, cache
        raise NotImplementedError("this is a base predictor")
class RNNPredictor(PredictorBase):
def __init__(
self,
voca_size: int,
embed_size: int,
output_size: int,
embed_dropout: float,
hidden_size: int,
num_layers: int,
bias: bool = True,
rnn_type: str = "lstm",
dropout: float = 0.1,
) -> None:
super().__init__()
self.n_layers = num_layers
self.hidden_size = hidden_size
# disable rnn base out projection
self.embed = nn.Embedding(voca_size, embed_size)
self.dropout = nn.Dropout(embed_dropout)
        # NOTE(Mddct): torch's built-in RNNs do not support layer norm.
        # Layer norm and value pruning in the cell and layer will be added later.
        # ref: https://github.com/Mddct/neural-lm/blob/main/models/gru_cell.py
self.rnn = get_rnn(rnn_type=rnn_type)(
input_size=embed_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
batch_first=True,
dropout=dropout,
)
self.projection = nn.Linear(hidden_size, output_size)
def forward(
self,
input: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): [batch, max_time).
padding (torch.Tensor): [batch, max_time]
cache : rnn predictor cache[0] == state_m
cache[1] == state_c
Returns:
output: [batch, max_time, output_size]
"""
# NOTE(Mddct): we don't use pack input format
embed = self.embed(input) # [batch, max_time, emb_size]
embed = self.dropout(embed)
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
if cache is None:
state = self.init_state(batch_size=input.size(0), device=input.device)
states = (state[0], state[1])
else:
assert len(cache) == 2
states = (cache[0], cache[1])
out, (m, c) = self.rnn(embed, states)
out = self.projection(out)
        # NOTE(Mddct): Although we don't use the state in transducer
        # training forward, it must be handled correctly for padded values,
        # so forward_step is used for inference and forward for training.
_, _ = m, c
return out
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
"""
Args:
cache: [state_m, state_c]
state_ms: [1*n_layers, bs, ...]
state_cs: [1*n_layers, bs, ...]
Returns:
new_cache: [[state_m_1, state_c_1], [state_m_2, state_c_2]...]
"""
assert len(cache) == 2
state_ms = cache[0]
state_cs = cache[1]
assert state_ms.size(1) == state_cs.size(1)
new_cache: List[List[torch.Tensor]] = []
for state_m, state_c in zip(
torch.split(state_ms, 1, dim=1), torch.split(state_cs, 1, dim=1)
):
new_cache.append([state_m, state_c])
return new_cache
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""
Args:
            cache: [[state_m_1, state_c_1], [state_m_2, state_c_2]...]
        Returns:
            new_cache: [state_ms, state_cs],
state_ms: [1*n_layers, bs, ...]
state_cs: [1*n_layers, bs, ...]
"""
state_ms = torch.cat([states[0] for states in cache], dim=1)
state_cs = torch.cat([states[1] for states in cache], dim=1)
return [state_ms, state_cs]
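    # NOTE (assumption about intended usage, inferred from the docstrings
    # above): batch_to_cache splits a batched LSTM state along the batch
    # dimension into one [state_m, state_c] pair per utterance, e.g. for
    # per-hypothesis bookkeeping during search; cache_to_batch is its inverse
    # and concatenates those pairs back into a single batched state.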
def init_state(
self,
batch_size: int,
device: torch.device,
method: str = "zero",
) -> List[torch.Tensor]:
assert batch_size > 0
# TODO(Mddct): xavier init method
_ = method
return [
torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
]
def forward_step(
self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
input (torch.Tensor): [batch_size, time_step=1]
padding (torch.Tensor): [batch_size,1], 1 is padding value
cache : rnn predictor cache[0] == state_m
cache[1] == state_c
"""
assert len(cache) == 2
state_m, state_c = cache[0], cache[1]
embed = self.embed(input) # [batch, 1, emb_size]
embed = self.dropout(embed)
out, (m, c) = self.rnn(embed, (state_m, state_c))
out = self.projection(out)
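        # Keep the previous state at padded positions (padding == 1) so that
        # padded hypotheses are not advanced by this step.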
m = ApplyPadding(m, padding.unsqueeze(0), state_m)
c = ApplyPadding(c, padding.unsqueeze(0), state_c)
return (out, [m, c])
class EmbeddingPredictor(PredictorBase):
"""Embedding predictor
Described in:
https://arxiv.org/pdf/2109.07513.pdf
    embed -> proj -> layer norm -> swish
"""
def __init__(
self,
voca_size: int,
embed_size: int,
embed_dropout: float,
n_head: int,
history_size: int = 2,
activation: str = "swish",
bias: bool = False,
layer_norm_epsilon: float = 1e-5,
) -> None:
super().__init__()
# multi head
self.num_heads = n_head
self.embed_size = embed_size
self.context_size = history_size + 1
self.pos_embed = torch.nn.Linear(
embed_size * self.context_size, self.num_heads, bias=bias
)
self.embed = nn.Embedding(voca_size, self.embed_size)
self.embed_dropout = nn.Dropout(p=embed_dropout)
self.ffn = nn.Linear(self.embed_size, self.embed_size)
self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
self.activatoin = get_activation(activation)
def init_state(
self, batch_size: int, device: torch.device, method: str = "zero"
) -> List[torch.Tensor]:
assert batch_size > 0
_ = method
return [
torch.zeros(
batch_size, self.context_size - 1, self.embed_size, device=device
),
]
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
"""
Args:
cache : [history]
history: [bs, ...]
Returns:
            new_cache: [[history_1], [history_2], [history_3]...]
"""
assert len(cache) == 1
cache_0 = cache[0]
history: List[List[torch.Tensor]] = []
for h in torch.split(cache_0, 1, dim=0):
history.append([h])
return history
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""
Args:
            cache: [[history_1], [history_2], [history_3]...]
        Returns:
            new_cache: [history],
history: [bs, ...]
"""
history = torch.cat([h[0] for h in cache], dim=0)
return [history]
def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
"""forward for training"""
input = self.embed(input) # [bs, seq_len, embed]
input = self.embed_dropout(input)
if cache is None:
zeros = self.init_state(input.size(0), device=input.device)[0]
else:
assert len(cache) == 1
zeros = cache[0]
input = torch.cat(
(zeros, input), dim=1
) # [bs, context_size-1 + seq_len, embed]
input = input.unfold(1, self.context_size, 1).permute(
0, 1, 3, 2
) # [bs, seq_len, context_size, embed]
# multi head pos: [n_head, embed, context_size]
multi_head_pos = self.pos_embed.weight.view(
self.num_heads, self.embed_size, self.context_size
)
        # broadcast dot attention
input_expand = input.unsqueeze(2) # [bs, seq_len, 1, context_size, embed]
multi_head_pos = multi_head_pos.permute(
0, 2, 1
) # [num_heads, context_size, embed]
# [bs, seq_len, num_heads, context_size, embed]
weight = input_expand * multi_head_pos
weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
3
) # [bs, seq_len, num_heads, 1, context_size]
output = weight.matmul(input_expand).squeeze(
dim=3
) # [bs, seq_len, num_heads, embed]
output = output.sum(dim=2) # [bs, seq_len, embed]
output = output / (self.num_heads * self.context_size)
output = self.ffn(output)
output = self.norm(output)
output = self.activatoin(output)
return output
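    # Summary of the forward pass above: for every output position, each head
    # scores the `context_size` most recent embeddings with its learned
    # positional weights, takes the weighted sum of those embeddings, then the
    # heads are summed and scaled by 1 / (num_heads * context_size) before the
    # feed-forward layer, layer norm and activation.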
def forward_step(
self,
input: torch.Tensor,
padding: torch.Tensor,
cache: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""forward step for inference
Args:
input (torch.Tensor): [batch_size, time_step=1]
padding (torch.Tensor): [batch_size,1], 1 is padding value
cache: for embedding predictor, cache[0] == history
"""
assert input.size(1) == 1
assert len(cache) == 1
history = cache[0]
assert history.size(1) == self.context_size - 1
input = self.embed(input) # [bs, 1, embed]
input = self.embed_dropout(input)
context_input = torch.cat((history, input), dim=1)
input_expand = context_input.unsqueeze(1).unsqueeze(
2
) # [bs, 1, 1, context_size, embed]
# multi head pos: [n_head, embed, context_size]
multi_head_pos = self.pos_embed.weight.view(
self.num_heads, self.embed_size, self.context_size
)
multi_head_pos = multi_head_pos.permute(
0, 2, 1
) # [num_heads, context_size, embed]
# [bs, 1, num_heads, context_size, embed]
weight = input_expand * multi_head_pos
weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
3
) # [bs, 1, num_heads, 1, context_size]
output = weight.matmul(input_expand).squeeze(dim=3) # [bs, 1, num_heads, embed]
output = output.sum(dim=2) # [bs, 1, embed]
output = output / (self.num_heads * self.context_size)
output = self.ffn(output)
output = self.norm(output)
output = self.activatoin(output)
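        # Slide the history window: drop the oldest embedding and keep the
        # most recent `context_size - 1` embeddings as the new cache.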
new_cache = context_input[:, 1:, :]
# TODO(Mddct): we need padding new_cache in future
# new_cache = ApplyPadding(history, padding, new_cache)
return (output, [new_cache])
class ConvPredictor(PredictorBase):
def __init__(
self,
voca_size: int,
embed_size: int,
embed_dropout: float,
history_size: int = 2,
activation: str = "relu",
bias: bool = False,
layer_norm_epsilon: float = 1e-5,
) -> None:
super().__init__()
assert history_size >= 0
self.embed_size = embed_size
self.context_size = history_size + 1
self.embed = nn.Embedding(voca_size, self.embed_size)
self.embed_dropout = nn.Dropout(p=embed_dropout)
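        # Depthwise 1-D convolution (groups == embed_size) over the embedding
        # sequence; with kernel_size == context_size and no padding, each
        # output frame sees the current token plus `history_size` previous ones.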
self.conv = nn.Conv1d(
in_channels=embed_size,
out_channels=embed_size,
kernel_size=self.context_size,
padding=0,
groups=embed_size,
bias=bias,
)
self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
self.activatoin = get_activation(activation)
def init_state(
self, batch_size: int, device: torch.device, method: str = "zero"
) -> List[torch.Tensor]:
assert batch_size > 0
assert method == "zero"
return [
torch.zeros(
batch_size, self.context_size - 1, self.embed_size, device=device
)
]
def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""
Args:
            cache: [[history_1], [history_2], [history_3]...]
        Returns:
            new_cache: [history],
history: [bs, ...]
"""
history = torch.cat([h[0] for h in cache], dim=0)
return [history]
def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
"""
Args:
cache : [history]
history: [bs, ...]
Returns:
            new_cache: [[history_1], [history_2], [history_3]...]
"""
assert len(cache) == 1
cache_0 = cache[0]
history: List[List[torch.Tensor]] = []
for h in torch.split(cache_0, 1, dim=0):
history.append([h])
return history
def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
"""forward for training"""
input = self.embed(input) # [bs, seq_len, embed]
input = self.embed_dropout(input)
if cache is None:
zeros = self.init_state(input.size(0), device=input.device)[0]
else:
assert len(cache) == 1
zeros = cache[0]
input = torch.cat(
(zeros, input), dim=1
) # [bs, context_size-1 + seq_len, embed]
input = input.permute(0, 2, 1)
out = self.conv(input).permute(0, 2, 1)
out = self.activatoin(self.norm(out))
return out
def forward_step(
self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""forward step for inference
Args:
input (torch.Tensor): [batch_size, time_step=1]
padding (torch.Tensor): [batch_size,1], 1 is padding value
cache: for embedding predictor, cache[0] == history
"""
assert input.size(1) == 1
assert len(cache) == 1
history = cache[0]
assert history.size(1) == self.context_size - 1
input = self.embed(input) # [bs, 1, embed]
input = self.embed_dropout(input)
context_input = torch.cat((history, input), dim=1)
input = context_input.permute(0, 2, 1)
out = self.conv(input).permute(0, 2, 1)
out = self.activatoin(self.norm(out))
new_cache = context_input[:, 1:, :]
# TODO(Mddct): apply padding in future
return (out, [new_cache])
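# Illustrative smoke test: a minimal sketch with arbitrary hyper-parameters,
# assuming `modules.wenet_extractor.utils.common` is importable from the
# working directory. It only checks output shapes of the three predictors.
if __name__ == "__main__":
    bs, seq_len, vocab = 2, 5, 16
    tokens = torch.randint(0, vocab, (bs, seq_len))

    rnn_predictor = RNNPredictor(
        voca_size=vocab,
        embed_size=8,
        output_size=8,
        embed_dropout=0.1,
        hidden_size=16,
        num_layers=2,
    )
    print(rnn_predictor(tokens).shape)  # expected: torch.Size([2, 5, 8])

    # Step-by-step inference with an explicit cache (no padded positions here).
    cache = rnn_predictor.init_state(batch_size=bs, device=tokens.device)
    step_out, cache = rnn_predictor.forward_step(
        tokens[:, :1], padding=torch.zeros(bs, 1), cache=cache
    )
    print(step_out.shape)  # expected: torch.Size([2, 1, 8])

    embedding_predictor = EmbeddingPredictor(
        voca_size=vocab, embed_size=8, embed_dropout=0.1, n_head=2
    )
    print(embedding_predictor(tokens).shape)  # expected: torch.Size([2, 5, 8])

    conv_predictor = ConvPredictor(voca_size=vocab, embed_size=8, embed_dropout=0.1)
    print(conv_predictor(tokens).shape)  # expected: torch.Size([2, 5, 8])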