from transformers import PreTrainedModel, PretrainedConfig, Wav2Vec2ForCTC
import json
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import math
from typing import Optional

# x: torch.FloatTensor [T, B, D]
# mask: torch.BoolTensor [B, T], where True indicates padding
# returns: torch.LongTensor [B]
def get_lengths(x, mask=None):
    if mask is not None:
        return (~mask).long().sum(dim=1)
    else:
        return torch.LongTensor([x.size(0)] * x.size(1)).to(x.device)

# lens: torch.LongTensor [B]
# returns: torch.BoolTensor [B, max_lens], where True indicates padding
def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask
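
# Illustrative example (not part of the model): lengths_to_padding_mask(torch.tensor([2, 4]))
# returns
#   [[False, False,  True,  True],
#    [False, False, False, False]]
# i.e. True marks the padded positions beyond each sequence's length.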

# input_lengths: torch.LongTensor [B]
# Maps raw waveform lengths to the number of frames produced by the wav2vec2
# convolutional feature extractor (the layer spec below is the default wav2vec2
# conv stack: seven layers with a total downsampling factor of 320).
def get_output_lengths(input_lengths):
    conv_feature_layers = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]"
    conv_cfg_list = eval(conv_feature_layers)

    def _conv_out_length(input_length, kernel_size, stride):
        return torch.floor((input_length - kernel_size) / stride + 1)

    for i in range(len(conv_cfg_list)):
        input_lengths = _conv_out_length(
            input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
        )
    return input_lengths.to(torch.long)
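
# Sanity check (illustrative): a 1-second input at 16 kHz has 16000 samples;
# get_output_lengths(torch.tensor([16000])) gives 49 frames, consistent with
# wav2vec2's ~20 ms frame rate.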

class ZeroSwotEncoderConfig(PretrainedConfig):
    model_type = "zero_swot_encoder"

    def __init__(
        self,
        wav2vec2_model_name_or_path="",
        compression_adapter=None,
        embed_dim=1024,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
        self.compression_adapter = compression_adapter
        self.embed_dim = embed_dim

    @classmethod
    def from_json_file(cls, json_file):
        with open(json_file, "r") as reader:
            text = reader.read()
        config_dict = json.loads(text)
        return cls(**config_dict)
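
# Example (illustrative only; the file name is a placeholder, not taken from this repository):
#
#   config = ZeroSwotEncoderConfig.from_json_file("config.json")
#
# where the JSON file would contain at least "wav2vec2_model_name_or_path", "embed_dim",
# and a "compression_adapter" dict with the keys read by CompressionAdapter below
# ("embed_dim", "transformer_layers", "dropout", "blank_idx", "sep_idx").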

class ZeroSwotEncoderModel(PreTrainedModel):
    config_class = ZeroSwotEncoderConfig
    model_type = "zero_swot_encoder"

    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2ForCTC.from_pretrained(config.wav2vec2_model_name_or_path)
        self.compression_adapter = CompressionAdapter(config.compression_adapter)
        self.speech_embedder = SpeechEmbedder(config.embed_dim)

    def forward(self, input_values, attention_mask=None):
        # attention_mask follows the HF convention (1/True for real samples, 0/False for padding),
        # while get_lengths expects True to mark padding, hence the inversion.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values, dtype=torch.long)
        input_lens = get_lengths(input_values, ~attention_mask.bool())
        # Forward pass through wav2vec2 encoder
        x = self.wav2vec2.wav2vec2(input_values, attention_mask)[0]  # [B, T, D]
        # CTC predictions
        preds = self.wav2vec2.lm_head(x).argmax(-1)  # [B, T]
        # Get output lengths for x
        output_lens = get_output_lengths(input_lens)
        # Compression
        x, mask, _ = self.compression_adapter(x, preds, output_lens)  # [B, N, D] with N << T
        # BOS and EOS embeddings
        x, mask = self.speech_embedder(x, mask)  # [B, N+2, D]
        # Return an attention-style mask (True for valid positions) alongside the embeddings.
        return x, ~mask
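
# Illustrative shape flow (assuming a 1024-dim wav2vec2-large backbone): a batch of two
# 1-second 16 kHz waveforms [2, 16000] gives wav2vec2 features [2, 49, 1024]; the compression
# adapter collapses the frame axis to roughly one vector per separator-delimited word, and
# SpeechEmbedder adds BOS/EOS, giving [2, N + 2, 1024] for N predicted words.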

class SpeechEmbedder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.bos_emb = nn.Parameter(torch.empty(embed_dim))
        self.eos_emb = nn.Parameter(torch.empty(embed_dim))
        self.scale = self.embed_dim ** 0.5

    def forward(self, x, padding_mask=None):
        """Prepend a BOS embedding, append an EOS embedding, and scale by sqrt(embed_dim).
        Args:
            x (FloatTensor): (B, T, C)
            padding_mask (BoolTensor): (B, T)
        Outputs:
            x (FloatTensor): (B, T+2, C)
            padding_mask (BoolTensor): (B, T+2)
        """
        B = x.size(0)
        lengths = get_lengths(x.transpose(0, 1), padding_mask)
        assert B == len(lengths)

        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        # prepend bos
        x = torch.cat([self.bos_emb.view(1, 1, -1).expand(B, 1, -1), x], dim=1)
        lengths += 1

        # append padding (zeros) and then convert first padding to eos
        x = torch.cat([x, torch.zeros(B, 1, x.size(-1), device=x.device, dtype=x.dtype)], dim=1)
        for i in range(B):
            x[i, lengths[i], :] = self.eos_emb
        lengths += 1

        padding_mask = lengths_to_padding_mask(lengths)
        x = x * self.scale
        return x, padding_mask
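
# Illustrative example: an example whose compressed representation holds 3 valid vectors
# [v1, v2, v3] leaves SpeechEmbedder as sqrt(embed_dim) * [bos, v1, v2, v3, eos], together
# with the padding mask for the new lengths.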

# Fairseq-style sinusoidal positional embedding; the row at padding_idx is all zeros,
# so padded positions receive a zero positional vector.
class PositionalEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, padding_idx):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        num_embeddings += padding_idx + 1
        self.weights = PositionalEmbedding.get_embedding(
            num_embeddings, embedding_dim, padding_idx
        )
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def make_positions(self, x, padding_idx: int):
        mask = x.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = PositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)
        positions = self.make_positions(input, self.padding_idx)
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
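
# Illustrative example: with padding_idx = 1, the input
#   [[0, 0, 1],
#    [0, 1, 1]]
# is mapped by make_positions to
#   [[2, 3, 1],
#    [2, 1, 1]]
# so real positions start at padding_idx + 1 and padded entries stay at padding_idx.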

class CLSPooling(nn.Module):
    def __init__(self, embed_dim, num_transformer_layers, dropout_rate):
        super().__init__()
        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim))
        nn.init.normal_(self.cls_token, mean=0.0, std=0.25)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                embed_dim,
                nhead=16 if embed_dim == 1024 else 8,
                dim_feedforward=4 * embed_dim,
                dropout=dropout_rate,
                activation="relu",
                batch_first=True,
                norm_first=True
            ),
            num_layers=num_transformer_layers,
        )
        self.pos_emb = PositionalEmbedding(512, embed_dim, 1)
        self.scale = math.sqrt(embed_dim)

    def forward(self, x, lens):
        # x: [B, N, D]
        # lens: [B]
        # prepend cls token
        x = torch.cat(
            [
                self.cls_token.to(dtype=x.dtype, device=x.device).repeat(x.size(0), 1, 1),  # B x 1 x D
                x
            ],
            dim=1)  # [B, N+1, D]
        # positions are derived from the long-cast padding mask, whose padding value (1)
        # coincides with the positional embedding's padding_idx
        mask = lengths_to_padding_mask(lens + 1)
        x = x + self.pos_emb(mask.long()) / self.scale
        x = self.transformer(x, src_key_padding_mask=mask)  # [B, N+1, D]
        x = x[:, 0]  # [B, D]
        return x
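
# Note: each row of the tensor passed to CLSPooling is the (padded) sequence of
# character-level vectors belonging to one separator-delimited token; the CLS state
# after the small transformer acts as that token's pooled representation.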

class CompressionAdapter(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.embed_dim = cfg["embed_dim"]
        self.transformer_layers = cfg["transformer_layers"]
        self.dropout = cfg["dropout"]
        self.blank_idx = cfg["blank_idx"]
        self.sep_idx = cfg["sep_idx"]
        self.token_pooling_module = CLSPooling(
            self.embed_dim, self.transformer_layers, self.dropout
        )

    def char_compression(self, x, preds, lens):
        # x: B x T x D
        # preds: B x T
        # lens: B
        B, T, D = x.size()
        device = x.device
        dtype = x.dtype

        # zero-out the padding
        mask = lengths_to_padding_mask(lens)  # B x T
        x = x.masked_fill(mask.unsqueeze(-1), 0)
        preds = preds.masked_fill(mask, self.blank_idx)

        # prepend a -1 sentinel to each example so its boundary can be recovered after flattening the batch
        preds = torch.cat([-torch.ones(B, 1, device=device, dtype=torch.long), preds], dim=1).view(-1)
        x = torch.cat([torch.zeros(B, 1, D, device=device, dtype=dtype), x], dim=1).view(-1, D)

        # get runs of consecutive identical predictions
        preds, counts = preds.unique_consecutive(return_counts=True)

        # split into groups of frames predicted as the same character
        x = torch.split(x, counts.tolist())

        # remove blanks
        valid_mask = preds != self.blank_idx
        preds = preds[valid_mask]
        counts = counts[valid_mask]  # [N]
        x = [x_i for x_i, v_i in zip(x, valid_mask) if v_i]

        # pack into tensor
        x = pad_sequence(x, batch_first=True, padding_value=0)

        # char pooling: mean over the frames collapsed into each character
        # [num_groups, max_count, D] -> [num_groups, D]
        x = torch.sum(x, dim=1) / counts.to(dtype=x.dtype).unsqueeze(1)

        # find split points for retrieving the examples
        split_points = (preds == -1).nonzero(as_tuple=True)[0]
        split_points = torch.cat([split_points, torch.tensor([len(preds)], device=device)])
        split_points = (split_points[1:] - split_points[:-1]).tolist()

        # split into examples
        x = torch.split(x, split_points)
        preds = torch.split(preds, split_points)
        lens = torch.tensor([len(x_i) for x_i in x], device=device)

        # pack into tensors
        x = pad_sequence(x, batch_first=True, padding_value=0)
        preds = pad_sequence(preds, batch_first=True, padding_value=self.blank_idx)

        # remove the sentinel positions that marked the bounds of each example
        x = x[:, 1:]
        preds = preds[:, 1:]
        lens -= 1
        mask = lengths_to_padding_mask(lens)

        # account for empty examples (just a sep token)
        empty_examples = lens == 0
        num_empty_examples = empty_examples.sum()
        if num_empty_examples > 0:
            mask[empty_examples, 0] = True
            lens[empty_examples] = 1
            preds[empty_examples, 0] = self.sep_idx

        return x, mask, lens, preds, num_empty_examples
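
    # Illustrative example of char_compression (assuming blank_idx = 0, as in the standard
    # wav2vec2 CTC vocabulary): frame-level predictions [5, 5, 0, 6, 6, 6, 0, 0, 7] are
    # collapsed by unique_consecutive to [5, 0, 6, 0, 7]; dropping blanks leaves the character
    # sequence [5, 6, 7], and each character's vector is the mean of the 2, 3 and 1 frame
    # vectors that were merged into it.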

    def token_compression(self, x, preds, lens):
        # x: B x T x D
        # preds: B x T
        # lens: B
        B, T, D = x.size()
        device = x.device
        dtype = x.dtype

        # new lengths after compression: one pooled vector per separator-terminated token
        new_lens = preds.eq(self.sep_idx).sum(dim=1)

        # unpad and unpack to list of tensors
        preds = [preds[i, :lens[i]] for i in range(B)]
        x = [x[i, :lens[i]] for i in range(B)]

        # make sure every example ends with a separator
        num_examples_without_ending_sep = torch.tensor(0, device=device, dtype=torch.long)
        for i in range(B):
            if preds[i][-1] != self.sep_idx:
                preds[i] = torch.cat([preds[i], torch.tensor([self.sep_idx], device=device, dtype=torch.long)])
                x[i] = torch.cat([x[i], torch.zeros(1, D, device=device, dtype=dtype)])
                new_lens[i] += 1
                num_examples_without_ending_sep += 1

        # flatten
        preds = torch.cat(preds)
        x = torch.cat(x)

        # split points according to separators
        split_points = preds.eq(self.sep_idx).nonzero(as_tuple=True)[0] + 1
        split_points = torch.cat([torch.tensor([0], device=device, dtype=torch.long), split_points])
        split_points = (split_points[1:] - split_points[:-1]).tolist()

        # re-arrange in 3d [total_num_tokens x max(count) x D]
        x = torch.split(x, split_points)  # Tuple[2d tensor]
        counts = torch.tensor([len(x_i) for x_i in x], device=device, dtype=torch.long)
        x = pad_sequence(x, batch_first=True, padding_value=0)

        # reduce dim 1
        x = self.token_pooling_module(x, counts)

        # reconstruct the batch
        split_points = new_lens.cumsum(dim=0)
        split_points = torch.cat([torch.tensor([0], device=device, dtype=torch.long), split_points])
        split_points = (split_points[1:] - split_points[:-1]).tolist()
        x = torch.split(x, split_points)
        x = pad_sequence(x, batch_first=True, padding_value=0)  # B x ? x D
        mask = lengths_to_padding_mask(new_lens)

        return x, mask, new_lens, num_examples_without_ending_sep
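
    # Illustrative example of token_compression (assuming sep_idx is the CTC word delimiter
    # "|"): the character sequence "t h e | c a t |" is split at the separators into the
    # groups "t h e |" and "c a t |", and CLSPooling reduces each group to a single vector,
    # leaving 2 token embeddings for this example.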

    def forward(self, x, preds, lens):
        x, mask, lens, preds, _ = self.char_compression(x, preds, lens)
        x, mask, lens, _ = self.token_compression(x, preds, lens)
        return x, mask, lens
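

# Self-contained smoke test of the compression pipeline on synthetic features and synthetic
# CTC predictions (illustrative sketch only, not from the original repository; the indices
# below are placeholders: 0 = blank, 4 = word separator, 5-7 = characters, and the BOS/EOS
# parameters are randomly initialized here, whereas in practice they come from a checkpoint).
if __name__ == "__main__":
    torch.manual_seed(0)
    adapter = CompressionAdapter({
        "embed_dim": 16,
        "transformer_layers": 1,
        "dropout": 0.0,
        "blank_idx": 0,
        "sep_idx": 4,
    })
    embedder = SpeechEmbedder(16)
    nn.init.normal_(embedder.bos_emb)
    nn.init.normal_(embedder.eos_emb)

    # 1 example, 10 frames: three characters followed by a word separator.
    preds = torch.tensor([[5, 5, 0, 6, 0, 0, 7, 7, 4, 0]])
    feats = torch.randn(1, 10, 16)
    lens = torch.tensor([10])

    with torch.no_grad():
        x, mask, new_lens = adapter(feats, preds, lens)
        print(x.shape, new_lens)   # torch.Size([1, 1, 16]) tensor([1]): one pooled word vector
        x, pad_mask = embedder(x, mask)
        print(x.shape)             # torch.Size([1, 3, 16]): BOS + 1 word + EOS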