# reborn-uasr_ls100h_iter2-stage1 / modeling_reborn.py
from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_reborn import RebornUASRConfig

class RebornSegmenter(nn.Module):
    """Frame-level boundary predictor: a three-layer CNN that scores each
    frame as boundary vs. non-boundary for segmentation."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # conv1 reads the input features; conv3 projects down to 2 logits per frame
        self.conv1 = nn.Conv1d(config.segmenter_input_dim, config.segmenter_hidden_dim, config.segmenter_kernel_size, padding=config.segmenter_kernel_size // 2)
        self.conv2 = nn.Conv1d(config.segmenter_hidden_dim, config.segmenter_hidden_dim, 3, padding=1)
        self.conv3 = nn.Conv1d(config.segmenter_hidden_dim, 2, 1)
        self.dropout = nn.Dropout(config.segmenter_dropout)
        self.relu = nn.ReLU()
    def forward(self, x):
        """
        Input:
            x: (B, T, C)
        Output:
            boundary_logits: (B, T, 2) # [..., 0]: not boundary; [..., 1]: boundary
        """
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T) for Conv1d
        x = self.dropout(self.relu(self.conv1(x)))
        x = self.dropout(self.relu(self.conv2(x)))
        x = self.conv3(x)
        x = x.transpose(1, 2)  # back to (B, T, 2)
        return x
    def boundary_predict(self, x, padding_mask, deterministic=False):
        """
        Input:
            x: (B, T, C)
            padding_mask: (B, T) # bool, True at padded positions
        Output:
            boundary: (B, T) # 0: not boundary; 1: boundary; -1: padding
            boundary_logits: (B, T, 2)
        """
        boundary_logits = self.forward(x)
        if deterministic:
            boundary = boundary_logits.argmax(-1)
        else:
            boundary = torch.distributions.Categorical(logits=boundary_logits).sample()
        boundary[padding_mask] = -1
        return boundary, boundary_logits
    def pre_segment(self, logits, padding_mask, return_boundary=False, deterministic=True):
        """
        Mean-pool frame features within each predicted segment (mean_pool_join).
        Input:
            logits: (B, T, C)
            padding_mask: (B, T)
        Output:
            new_logits: (B, T', C)
            new_padding_mask: (B, T')
        """
        bsz, tsz, csz = logits.size()
        boundary, boundary_logits = self.boundary_predict(logits, padding_mask, deterministic=deterministic)
        # largest per-item boundary count in the batch, plus one leading segment (<bos>)
        new_tsz = int(torch.max(torch.sum(boundary == 1, dim=1)).item()) + 1
        new_logits = logits.new_zeros(bsz, new_tsz, csz)
        new_pad = padding_mask.new_zeros(bsz, new_tsz)
        for b in range(bsz):
            # merge consecutive frames into one segment at each boundary (mean_pool_join)
            new_idx = 0
            count = 0
            for t in range(tsz):
                if padding_mask[b, t] == 1:
                    break
                if boundary[b, t] == 1 and count > 0:
                    # close the current segment; the count > 0 guard avoids a
                    # division by zero when a boundary lands on the first frame
                    new_logits[b, new_idx] /= count
                    new_idx += 1
                    count = 0
                new_logits[b, new_idx] += logits[b, t]
                count += 1
            if count > 0:
                # close the last (unterminated) segment
                new_logits[b, new_idx] /= count
                new_idx += 1
                count = 0
            if new_idx < new_tsz:
                # zero out and mask the unused tail positions
                pad = new_tsz - new_idx
                new_logits[b, -pad:] = 0
                new_pad[b, -pad:] = True
        if return_boundary:
            return new_logits, new_pad, boundary, boundary_logits
        return new_logits, new_pad
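
# The sketch below is illustrative only, not part of the released model code:
# it builds a RebornSegmenter from a stand-in config (a SimpleNamespace carrying
# the attribute names used above, with made-up sizes) and runs pre_segment on
# random features to show the shape contract.
def _demo_segmenter():
    import types

    config = types.SimpleNamespace(
        segmenter_input_dim=16,   # illustrative; the real values come from RebornUASRConfig
        segmenter_hidden_dim=32,
        segmenter_kernel_size=7,
        segmenter_dropout=0.1,
    )
    segmenter = RebornSegmenter(config).eval()

    x = torch.randn(2, 12, 16)                      # (B, T, C)
    padding_mask = torch.zeros(2, 12, dtype=torch.bool)
    padding_mask[1, 9:] = True                      # second item has 3 padded frames

    new_logits, new_pad = segmenter.pre_segment(x, padding_mask, deterministic=True)
    # T' is the largest per-item segment count in the batch; shorter items are masked
    print(new_logits.shape, new_pad.shape)          # (2, T', 16), (2, T')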

class RebornGenerator(nn.Module):
    """Phoneme generator: a single Conv1d projection from segment features to
    phoneme-vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.output_dim = config.generator_output_dim
        self.stride = config.generator_stride
        self.dropout = nn.Dropout(config.generator_dropout)

        cnn_input_dim = config.generator_input_dim
        cnn_output_dim = config.generator_output_dim
        padding = config.generator_kernel // 2
        self.proj = nn.Sequential(
            nn.Conv1d(
                cnn_input_dim,
                cnn_output_dim,
                kernel_size=config.generator_kernel,
                stride=config.generator_stride,
                dilation=config.generator_dilation,
                padding=padding,
                bias=config.generator_bias,
            ),
        )
    def forward(self, dense_x, tokens, dense_padding_mask):
        dense_x = self.dropout(dense_x)

        dense_x = dense_x.transpose(-2, -1)  # (B, T, C) -> (B, C, T)
        dense_x = self.proj(dense_x)
        dense_x = dense_x.transpose(-2, -1)  # (B, C, T) -> (B, T, C)

        if self.stride > 1:
            dense_padding_mask = dense_padding_mask[:, :: self.stride]

        if dense_padding_mask.size(1) != dense_x.size(1):
            new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1])
            diff = new_padding.size(1) - dense_padding_mask.size(1)
            # with `kernel // 2` padding the projection is expected to lengthen
            # the sequence relative to the (possibly strided) mask
            assert (
                diff > 0
            ), f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}"
            # mark the extra leading output frames as valid and shift the original mask right
            new_padding[:, diff:] = dense_padding_mask
            dense_padding_mask = new_padding

        result = {}

        token_x = None
        if tokens is not None:
            # one-hot encode the target tokens in the generator's output space
            token_x = dense_x.new_zeros(tokens.numel(), self.output_dim)
            token_x.scatter_(1, tokens.view(-1, 1).long(), 1)
            token_x = token_x.view(tokens.shape + (self.output_dim,))

        result["dense_x"] = dense_x
        result["token_x"] = token_x
        result["dense_padding_mask"] = dense_padding_mask

        return result
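
# Another illustrative sketch (not part of the model code): with the even
# kernel below, the Conv1d output is one frame longer than its input, which
# exercises the padding-mask realignment branch in forward. All config values
# here are made-up stand-ins.
def _demo_generator():
    import types

    config = types.SimpleNamespace(
        generator_input_dim=16,
        generator_output_dim=40,  # e.g. a phoneme-sized vocabulary; illustrative
        generator_kernel=4,       # even kernel + kernel // 2 padding -> output has T+1 frames
        generator_stride=1,
        generator_dilation=1,
        generator_dropout=0.0,
        generator_bias=False,
    )
    generator = RebornGenerator(config).eval()

    dense_x = torch.randn(2, 8, 16)                 # (B, T', C) segment features
    dense_padding_mask = torch.zeros(2, 8, dtype=torch.bool)
    tokens = torch.randint(0, 40, (2, 8))           # optional targets, one-hot encoded

    out = generator(dense_x, tokens, dense_padding_mask)
    print(out["dense_x"].shape)                     # (2, 9, 40)
    print(out["token_x"].shape)                     # (2, 8, 40)
    print(out["dense_padding_mask"].shape)          # (2, 9)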

class RebornUASRModel(PreTrainedModel):
    config_class = RebornUASRConfig

    def __init__(self, config):
        super().__init__(config)
        # project 1024-d speech features (e.g. from a large wav2vec 2.0 encoder) down to 512-d
        self.pca = nn.Linear(1024, 512)
        self.segmenter = RebornSegmenter(config)
        self.generator = RebornGenerator(config)

    def forward(
        self,
        x: Optional[torch.Tensor],  # (B, T, C)
        padding_mask: Optional[torch.Tensor],  # (B, T), bool, True at padded positions
    ):
        x_reduced = self.pca(x)
        x_segmented, segmented_padding_mask = self.segmenter.pre_segment(x_reduced, padding_mask, deterministic=True)
        x_generated = self.generator(x_segmented, None, segmented_padding_mask)
        return {
            'x_reduced': x_reduced,
            'x_segmented': x_segmented,
            'x_generated': x_generated
        }
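
# End-to-end sketch, again illustrative rather than authoritative: the real
# field names and defaults live in RebornUASRConfig (configuration_reborn.py),
# which is not shown here, so the keyword arguments below are assumptions
# inferred from how the config is consumed above. In practice the released
# checkpoint would typically be loaded with
# AutoModel.from_pretrained(<repo_id>, trust_remote_code=True) instead of
# being built from scratch.
if __name__ == "__main__":
    config = RebornUASRConfig(
        segmenter_input_dim=512,   # must match the 512-d output of self.pca
        segmenter_hidden_dim=256,  # illustrative
        segmenter_kernel_size=7,
        segmenter_dropout=0.1,
        generator_input_dim=512,   # pre_segment preserves the feature dimension
        generator_output_dim=40,   # illustrative phoneme vocabulary size
        generator_kernel=4,
        generator_stride=1,
        generator_dilation=1,
        generator_dropout=0.1,
        generator_bias=False,
    )
    model = RebornUASRModel(config).eval()

    x = torch.randn(1, 50, 1024)                    # (B, T, 1024) speech features
    padding_mask = torch.zeros(1, 50, dtype=torch.bool)
    with torch.no_grad():
        out = model(x, padding_mask)
    print(out["x_reduced"].shape)                   # (1, 50, 512)
    print(out["x_segmented"].shape)                 # (1, T', 512)
    print(out["x_generated"]["dense_x"].shape)      # (1, T'+1, 40) with the even kernel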