|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel |
|
from .configuration_reborn import RebornUASRConfig |
|
|
|
|
class RebornSegmenter(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
self.conv1 = nn.Conv1d(config.segmenter_input_dim, config.segmenter_hidden_dim, config.segmenter_kernel_size, padding=config.segmenter_kernel_size//2) |
|
self.conv2 = nn.Conv1d(config.segmenter_hidden_dim, config.segmenter_hidden_dim, 3, padding=1) |
|
        self.conv3 = nn.Conv1d(config.segmenter_hidden_dim, 2, 1)  # 1x1 conv to per-frame [not boundary, boundary] logits
|
self.dropout = nn.Dropout(config.segmenter_dropout) |
|
self.relu = nn.ReLU() |
|
|
|
def forward(self, x): |
|
""" |
|
Input: |
|
x: (B, T, C) |
|
padding_mask: (B, T) # 0: not padding; 1: padding |
|
Output: |
|
boundary: (B, T, 2) # 0: not boundary; 1: boundary |
|
""" |
|
x = x.transpose(1, 2) |
|
x = self.dropout(self.relu(self.conv1(x))) |
|
x = self.dropout(self.relu(self.conv2(x))) |
|
x = self.conv3(x) |
|
x = x.transpose(1, 2) |
|
return x |
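    # Example shape flow (hypothetical sizes): x of (B=2, T=100, C=512)
    # -> transpose -> (2, 512, 100) -> conv stack -> (2, 2, 100)
    # -> transpose -> boundary logits of shape (2, 100, 2).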
|
|
|
def boundary_predict(self, x, padding_mask, deterministic=False): |
|
""" |
|
Input: |
|
x: (B, T, C) |
|
padding_mask: (B, T) |
|
Output: |
|
            boundary: (B, T)  # 1: boundary; 0: not boundary; -1: padding
            boundary_logits: (B, T, 2)  # last dim: [not boundary, boundary]
|
""" |
|
        boundary_logits = self.forward(x)
        if deterministic:
            boundary = boundary_logits.argmax(-1)
        else:
            boundary = torch.distributions.Categorical(logits=boundary_logits).sample()
        # padding_mask is expected to be boolean; padded positions are marked -1
        boundary[padding_mask] = -1
        return boundary, boundary_logits
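    # Illustrative result (hypothetical values): for one utterance with T=5
    # frames whose last frame is padding, a possible deterministic output is
    #     boundary        -> tensor([[1, 0, 1, 0, -1]])  # -1 marks padding
    #     boundary_logits -> a tensor of shape (1, 5, 2)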
|
|
|
def pre_segment(self, logits, padding_mask, return_boundary=False, deterministic=True): |
|
""" |
|
Input: |
|
logits: (B, T, C) |
|
padding_mask: (B, T) |
|
        Output:
            new_logits: (B, T', C)
            new_padding_mask: (B, T')
            plus (boundary, boundary_logits) when return_boundary is True
        """
|
|
|
bsz, tsz, csz = logits.size() |
|
|
|
boundary, boundary_logits = self.boundary_predict(logits, padding_mask, deterministic=deterministic) |
|
|
|
|
|
|
|
|
|
        new_tsz = int(torch.max(torch.sum(boundary == 1, dim=1)).item()) + 1  # +1: frames before the first boundary form their own segment
|
new_logits = logits.new_zeros(bsz, new_tsz, csz) |
|
new_pad = padding_mask.new_zeros(bsz, new_tsz) |
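        # Mean-pool logits within each predicted segment: a 1 at frame t opens
        # a new segment, and each output step averages the frames it covers.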
|
|
|
for b in range(bsz): |
|
|
|
new_idx = 0 |
|
count = 0 |
|
            for t in range(tsz):
                if padding_mask[b, t] == 1:  # stop at the first padded frame
                    break
                if boundary[b, t] == 1:
                    if count > 0:  # guard: a boundary at t == 0 would otherwise divide zero by zero
                        new_logits[b, new_idx] /= count
                        new_idx += 1
                        count = 0
                new_logits[b, new_idx] += logits[b, t]
                count += 1
|
            if count > 0:  # flush the trailing segment
                new_logits[b, new_idx] /= count
                new_idx += 1
|
            if new_idx < new_tsz:  # zero-fill and mask the unused tail
                pad = new_tsz - new_idx
                new_logits[b, -pad:] = 0
                new_pad[b, -pad:] = True
|
|
|
if return_boundary: |
|
return new_logits, new_pad, boundary, boundary_logits |
|
return new_logits, new_pad |
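# Worked example of RebornSegmenter.pre_segment (hypothetical values): with
# boundary = [1, 0, 1, 0] over frames f0..f3, the frames group into [f0, f1]
# and [f2, f3], so new_logits holds [(f0 + f1) / 2, (f2 + f3) / 2] and any
# remaining positions up to T' are zero-filled and marked as padding.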
|
|
|
class RebornGenerator(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
|
|
self.config = config |
|
self.output_dim = config.generator_output_dim |
|
self.stride = config.generator_stride |
|
self.dropout = nn.Dropout(config.generator_dropout) |
|
cnn_input_dim = config.generator_input_dim |
|
cnn_output_dim = config.generator_output_dim |
|
|
|
        padding = config.generator_kernel // 2  # "same" padding for odd kernel sizes (length preserved when stride == 1)
|
self.proj = nn.Sequential( |
|
nn.Conv1d( |
|
cnn_input_dim, |
|
cnn_output_dim, |
|
kernel_size=config.generator_kernel, |
|
stride=config.generator_stride, |
|
dilation=config.generator_dilation, |
|
padding=padding, |
|
bias=config.generator_bias, |
|
), |
|
) |
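    # The single Conv1d projects segment features to logits over the output
    # vocabulary (e.g. phone units); with stride > 1 it also downsamples in time.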
|
|
|
def forward(self, dense_x, tokens, dense_padding_mask): |
|
dense_x = self.dropout(dense_x) |
|
|
|
dense_x = dense_x.transpose(-2, -1) |
|
|
|
dense_x = self.proj(dense_x) |
|
|
|
dense_x = dense_x.transpose(-2, -1) |
|
        if self.stride > 1:
            dense_padding_mask = dense_padding_mask[:, :: self.stride]  # subsample the mask to track the strided conv output
|
|
|
        if dense_padding_mask.size(1) != dense_x.size(1):
            new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1])
            diff = new_padding.size(1) - dense_padding_mask.size(1)
            if diff > 0:
                # the conv produced extra frames; left-pad the mask to match
                new_padding[:, diff:] = dense_padding_mask
            else:
                # the conv produced fewer frames; truncate the mask
                assert diff < 0, f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}"
                new_padding = dense_padding_mask[:, :diff]
            dense_padding_mask = new_padding
|
|
|
result = {} |
|
|
|
token_x = None |
|
if tokens is not None: |
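            # one-hot encode the reference token ids over the output vocabulary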
|
token_x = dense_x.new_zeros(tokens.numel(), self.output_dim) |
|
token_x.scatter_(1, tokens.view(-1, 1).long(), 1) |
|
token_x = token_x.view(tokens.shape + (self.output_dim,)) |
|
|
|
result["dense_x"] = dense_x |
|
result["token_x"] = token_x |
|
result["dense_padding_mask"] = dense_padding_mask |
|
|
|
return result |
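# Note: token_x one-hot encodes reference tokens over the generator's output
# vocabulary; in adversarial (wav2vec-U style) training it would serve as the
# "real" input to a discriminator, while dense_x is the generated
# distribution. It remains None when tokens is None, as in inference below.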
|
|
|
class RebornUASRModel(PreTrainedModel): |
|
config_class = RebornUASRConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
        self.pca = nn.Linear(1024, 512)  # PCA-style reduction of 1024-dim speech features to 512, stored as a linear layer
|
self.segmenter = RebornSegmenter(config) |
|
self.generator = RebornGenerator(config) |
|
|
|
def forward( |
|
self, |
|
        x: torch.Tensor,
        padding_mask: torch.Tensor,
|
): |
|
x_reduced = self.pca(x) |
|
x_segmented, segmented_padding_mask = self.segmenter.pre_segment(x_reduced, padding_mask, deterministic=True) |
|
x_generated = self.generator(x_segmented, None, segmented_padding_mask) |
|
|
|
return { |
|
'x_reduced': x_reduced, |
|
'x_segmented': x_segmented, |
|
'x_generated': x_generated |
|
} |
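

# Minimal end-to-end smoke test: a sketch, assuming `RebornUASRConfig()`
# provides usable defaults for every segmenter_*/generator_* field referenced
# above, with segmenter_input_dim == generator_input_dim == 512. Because of
# the relative import at the top, run it as a module, e.g.
# `python -m <package>.modeling_reborn`.
if __name__ == "__main__":
    config = RebornUASRConfig()
    model = RebornUASRModel(config)

    feats = torch.randn(2, 50, 1024)                # e.g. 1024-dim speech features
    padding = torch.zeros(2, 50, dtype=torch.bool)  # False: real frame; True: padding
    padding[1, 40:] = True                          # second utterance is shorter

    out = model(feats, padding)
    print(out["x_reduced"].shape)                   # torch.Size([2, 50, 512])
    print(out["x_generated"]["dense_x"].shape)      # (2, T', generator_output_dim)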
|
|
|
|