Instructions to use wsntxxn/effb2-trm-audiocaps-captioning with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use wsntxxn/effb2-trm-audiocaps-captioning with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="wsntxxn/effb2-trm-audiocaps-captioning", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("wsntxxn/effb2-trm-audiocaps-captioning", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
| from typing import Dict, Callable, Union, List | |
| import random | |
| import math | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |
| from torchaudio import transforms | |
| from efficientnet_pytorch import EfficientNet | |
| from efficientnet_pytorch import utils as efficientnet_utils | |
| from einops import rearrange, reduce | |
| from transformers import PretrainedConfig, PreTrainedModel | |
def sort_pack_padded_sequence(input, lengths):
    """Sort a padded batch by decreasing length and pack it.

    Args:
        input: batch-first padded tensor; it is indexed along dim 0 with the
            sort order.
        lengths: 1-D tensor with the valid length of each batch item.

    Returns:
        A tuple ``(packed, inv_ix)`` where ``packed`` is the ``PackedSequence``
        built from the length-sorted batch and ``inv_ix`` is the permutation
        that restores the original batch order.
    """
    lengths_desc, order = torch.sort(lengths, descending=True)
    # pack_padded_sequence is given the lengths on the CPU.
    packed = pack_padded_sequence(
        input[order], lengths_desc.cpu(), batch_first=True)
    # Invert the permutation: restore[order[k]] = k.
    restore = order.clone()
    restore[order] = torch.arange(0, len(order)).type_as(restore)
    return packed, restore
def pad_unsort_packed_sequence(input, inv_ix):
    """Unpack a ``PackedSequence`` and undo the length-based sorting.

    Args:
        input: packed sequence to pad back into a batch-first tensor.
        inv_ix: index tensor used to reorder the padded batch along dim 0.

    Returns:
        The padded tensor, reindexed with ``inv_ix``.
    """
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]
def pack_wrapper(module, attn_feats, attn_feat_lens):
    """Run ``module`` on the non-padded part of a variable-length batch.

    The batch is sorted and packed, passed through ``module``, then padded
    and restored to its original order.

    Args:
        module: an ``nn.RNNBase`` instance (called on the packed sequence) or
            any other callable (called on the packed data tensor only).
        attn_feats: batch-first padded features.
        attn_feat_lens: valid length of each batch item.

    Returns:
        Padded, batch-first output in the original batch order.
    """
    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
    if not isinstance(module, torch.nn.RNNBase):
        # Non-recurrent modules see the flat data tensor; the batch sizes are
        # re-attached afterwards.
        processed = PackedSequence(module(packed[0]), packed[1])
    else:
        # Recurrent modules return (output, state); only the output is kept.
        processed = module(packed)[0]
    return pad_unsort_packed_sequence(processed, inv_ix)
def embedding_pooling(x, lens, pooling="mean"):
    """Pool a padded sequence of embeddings over its time dimension.

    Args:
        x: tensor indexed as ``[N, T, hidden]`` (dim 1 is pooled).
        lens: valid length of each batch item.
        pooling: one of ``"max"``, ``"mean"``, ``"mean+max"`` (sum of both)
            or ``"last"`` (embedding at position ``lens - 1``).

    Returns:
        The pooled embeddings, one per batch item.

    Raises:
        Exception: if ``pooling`` is not one of the supported methods.
    """
    if pooling == "max":
        return max_with_lens(x, lens)
    if pooling == "mean":
        return mean_with_lens(x, lens)
    if pooling == "mean+max":
        pooled_mean = mean_with_lens(x, lens)
        pooled_max = max_with_lens(x, lens)
        return pooled_mean + pooled_max
    if pooling == "last":
        # last_idx: [N, 1, hidden], pointing at the final valid step per item
        last_idx = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
        return torch.gather(x, 1, last_idx).squeeze(1)
    raise Exception(f"pooling method {pooling} not support")
def interpolate(x, ratio):
    """Upsample along the time axis by repeating every frame ``ratio`` times.

    Used to compensate for the resolution reduction caused by downsampling
    in a CNN.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, number of copies of each frame

    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    n_batch, n_steps, n_classes = x.shape
    # Insert a repeat axis right after time, then fold it into the time axis.
    repeated = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return repeated.reshape(n_batch, n_steps * ratio, n_classes)
def pad_framewise_output(framewise_output, frames_num):
    """Extend a framewise output to ``frames_num`` frames.

    The padding repeats the value of the last available frame.

    Args:
        framewise_output: (batch_size, frames, classes_num)
        frames_num: int, target number of frames

    Returns:
        output: (batch_size, frames_num, classes_num)
    """
    missing = frames_num - framewise_output.shape[1]
    # Copies of the final frame, one per missing step.
    tail = framewise_output[:, -1:, :].repeat(1, missing, 1)
    return torch.cat((framewise_output, tail), dim=1)
def find_contiguous_regions(activity_array):
    """Locate runs of True values in a 1-D boolean numpy array.

    Adapted from dcase_util's DecisionEncoder
    (https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder)
    so that it can be used as a plain function without importing dcase_util.

    Args:
        activity_array: boolean array of frame activity.

    Returns:
        Array of shape (n_regions, 2); each row is ``[onset, offset]`` with
        the offset pointing one past the last active frame.
    """
    # A change happens wherever neighbouring frames differ; +1 moves the
    # index to the frame after the change.
    flips = np.logical_xor(activity_array[1:],
                           activity_array[:-1]).nonzero()[0] + 1
    if activity_array[0]:
        # Active from the very first frame: the first region starts at 0.
        flips = np.r_[0, flips]
    if activity_array[-1]:
        # Still active at the end: close the last region at the array size.
        flips = np.r_[flips, activity_array.size]
    # Consecutive change points pair up as (onset, offset).
    return flips.reshape((-1, 2))
def double_threshold(x, high_thres, low_thres, n_connect=1):
    """Apply double thresholding along the time axis of an n-dim array.

    The time axis is taken to be axis 1 for 3-D input (batch, time, dim)
    and axis 0 for 2-D (time, dim) or 1-D (time) input.

    :param x: input array with at most 3 dimensions
    :param high_thres: high threshold value
    :param low_thres: low threshold value
    :param n_connect: clusters at a distance of <= n_connect are merged
    :return: result of applying ``_double_threshold`` to every 1-D slice
        along the time axis
    """
    assert x.ndim <= 3, "Whoops something went wrong with the input ({}), check if its <= 3 dims".format(
        x.shape)
    if x.ndim == 3:
        time_axis = 1
    elif x.ndim < 3:
        time_axis = 0

    def _threshold_slice(seq):
        return _double_threshold(seq, high_thres, low_thres,
                                 n_connect=n_connect)

    return np.apply_along_axis(_threshold_slice, axis=time_axis, arr=x)
def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True):
    """Double thresholding of a 1-D array.

    Regions where ``x > low_thres`` are kept only if they contain at least
    one position with ``x > high_thres``; surviving regions are then merged
    with ``connect_``.

    :param x: input array, needs to be 1d
    :param high_thres: high threshold over the array
    :param low_thres: low threshold over the array
    :param n_connect: maximal distance between regions to connect
    :param return_arr: if True (the default) return an integer array with
        the shape of ``x`` holding 1 inside kept regions and 0 elsewhere;
        otherwise return the list of kept (onset, offset) pairs
    """
    assert x.ndim == 1, "Input needs to be 1d"
    peaks = np.where(x > high_thres)[0]
    candidate_regions = find_contiguous_regions(x > low_thres)
    # Keep a candidate only when some peak index lies within its bounds.
    kept = [
        region for region in candidate_regions
        if ((region[0] <= peaks) & (peaks <= region[1])).any()
    ]
    kept = connect_(kept, n_connect)
    if not return_arr:
        return kept
    binary = np.zeros_like(x, dtype=int)
    for region in kept:
        binary[region[0]:region[1]] = 1
    return binary
def connect_(pairs, n=1):
    """Merge neighbouring clusters whose gap is at most ``n``.

    :param pairs: sequence of (start, end) clusters, e.g. [(1,5),(7,10)]
    :param n: maximal distance between two clusters to be merged
    :return: list of merged (start, end) tuples; empty list for empty input
    """
    if len(pairs) == 0:
        return []
    merged = []
    run_start, run_end = pairs[0]
    for previous, current in zip(pairs, pairs[1:]):
        run_end = current[1]
        if current[0] - previous[1] > n:
            # Gap too large: close the running cluster and open a new one.
            merged.append((run_start, previous[1]))
            run_start = current[0]
    merged.append((run_start, run_end))
    return merged
def segments_to_temporal_tag(segments, thre=0.5):
    """Summarise the temporal relation between segments as an integer tag.

    Every ordered pair of segments whose first elements differ is examined.
    Element 1 and element 2 of a segment are used as its start and end.

    Args:
        segments: sequence of indexable ``(id, start, end)`` items.
        thre: fraction of the shorter duration used as the overlap threshold.

    Returns:
        ``after_flag + while_flag`` where ``after_flag`` is 2 if any pair
        has ``end_j - start_k`` below the threshold, and ``while_flag`` is 1
        if any pair has ``start_j < start_k`` with that value above the
        threshold; hence a value in {0, 1, 2, 3}.
    """
    after_flag = while_flag = 0
    for first in segments:
        for second in segments:
            if first[0] == second[0]:
                continue
            shortest = min(first[2] - first[1], second[2] - second[1])
            overlap = first[2] - second[1]
            limit = thre * shortest
            if overlap < limit:
                after_flag = 2
            if first[1] < second[1] and overlap > limit:
                while_flag = 1
    return after_flag + while_flag
def decode_with_timestamps(labels, time_resolution):
    """Convert framewise labels of a batch into one temporal tag per item.

    Args:
        labels: iterable of 2-D arrays; each array is transposed and every
            resulting row is scanned for contiguous active regions.
        time_resolution: factor that converts frame indices to time.

    Returns:
        List with one tag (see ``segments_to_temporal_tag``) per batch item.
    """
    batch_results = []
    for clip_labels in labels:
        # One (column index, onset, offset) triple per active region.
        segments = [
            (column_idx, region[0] * time_resolution, region[1] * time_resolution)
            for column_idx, column in enumerate(clip_labels.T)
            for region in find_contiguous_regions(column)
        ]
        batch_results.append(segments_to_temporal_tag(segments))
    return batch_results
class _EffiNet(nn.Module):
    """Thin wrapper around an ``EfficientNet`` used as a feature extractor."""

    def __init__(self, blocks_args=None, global_params=None) -> None:
        """Build the wrapped ``EfficientNet`` from the given configuration."""
        super().__init__()
        self.eff_net = EfficientNet(
            blocks_args=blocks_args,
            global_params=global_params,
        )

    def forward(self, x: torch.Tensor):
        """Extract features from a 3-D input matched as ``'b f t'``.

        A singleton channel axis is added, features are extracted with the
        wrapped network, and the ``f`` axis is averaged out so the result
        is laid out as ``'b t c'``.
        """
        image = rearrange(x, 'b f t -> b 1 f t')
        feature_map = self.eff_net.extract_features(image)
        return reduce(feature_map, 'b c f t -> b t c', 'mean')
def get_effb2_model() -> _EffiNet:
    """Create an ``_EffiNet`` configured as 'efficientnet-b2' without top.

    The wrapped network's input is switched to a single channel via
    ``_change_in_channels(1)``.
    """
    b2_blocks, b2_globals = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    effb2 = _EffiNet(blocks_args=b2_blocks, global_params=b2_globals)
    effb2.eff_net._change_in_channels(1)
    return effb2
def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load the compatible subset of ``state_dict`` into ``model``.

    An entry is loaded only if its key exists in the model and the shapes
    match; every other key is reported through ``output_fn``. Parameters
    without a compatible entry keep their current values.

    Args:
        state_dict: mapping from parameter name to tensor.
        model: module to be updated in place.
        output_fn: callable receiving the report message.

    Returns:
        The keys of the entries that were actually loaded.
    """
    target = model.state_dict()
    accepted = {}
    rejected = []
    for name, tensor in state_dict.items():
        usable = name in target and target[name].shape == tensor.shape
        if usable:
            accepted[name] = tensor
        else:
            rejected.append(name)
    output_fn(f"Loading pre-trained model, with mismatched keys {rejected}\n")
    # Overlay the accepted entries on the model's own state so that a
    # strict load succeeds.
    target.update(accepted)
    model.load_state_dict(target, strict=True)
    return accepted.keys()
class EfficientNetB2(nn.Module):
    """Audio encoder: log-mel spectrogram front end + EfficientNet-B2.

    The sample rate is fixed to 16 kHz; ``win_length`` and ``hop_length``
    are given in milliseconds and converted to samples.
    """

    def __init__(self,
                 n_mels: int = 64,
                 win_length: int = 32,
                 hop_length: int = 10,
                 f_min: int = 0,
                 freeze: bool = False,):
        """
        Args:
            n_mels: number of mel bins.
            win_length: analysis window (also used as FFT size), in ms.
            hop_length: hop between frames, in ms.
            f_min: lowest frequency of the mel filter bank.
            freeze: if True, disable gradients for all parameters.
        """
        super().__init__()
        sample_rate = 16000
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_length * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=hop_length * sample_rate // 1000,
            f_min=f_min,
            n_mels=n_mels,
        )
        # Hop size in samples. It must follow the `hop_length` argument
        # (it was hard-coded to 10 ms before) so that the lengths computed
        # in forward() match the spectrogram actually produced.
        self.hop_length = hop_length * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB(top_db=120)
        self.backbone = get_effb2_model()
        self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
        # Ratio used to convert spectrogram frame counts to embedding counts.
        self.downsample_ratio = 32
        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input_dict):
        """Encode waveforms into clip-level and frame-level embeddings.

        Args:
            input_dict: dict with keys ``"wav"`` (waveform batch),
                ``"wav_len"`` (length of each waveform in samples) and
                ``"specaug"`` (whether to augment during training).

        Returns:
            dict with ``fc_emb`` (length-aware mean of the frame
            embeddings), ``attn_emb`` (frame embeddings) and
            ``attn_emb_len`` (valid number of frame embeddings per item).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)  # (batch_size, mel_bins, time_steps)
        x = rearrange(x, 'b f t -> b 1 t f')
        if self.training and specaug:
            # NOTE(review): `spec_augmenter` is not assigned anywhere in
            # this class as visible here; this call presumably relies on it
            # being attached elsewhere -- confirm before training with
            # specaug enabled.
            x = self.spec_augmenter(x)
        x = rearrange(x, 'b 1 t f -> b f t')
        x = self.backbone(x)
        attn_emb = x
        wave_length = torch.as_tensor(wave_length)
        # Number of spectrogram frames, then number of embeddings after
        # the backbone's temporal downsampling.
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        fc_emb = mean_with_lens(attn_emb, feat_length)
        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict
def generate_length_mask(lens, max_length=None):
    """Build a boolean validity mask from sequence lengths.

    Args:
        lens: lengths of the N sequences (anything ``torch.as_tensor``
            accepts).
        max_length: number of mask columns; defaults to the largest length.
            May be an int or a single-element tensor.

    Returns:
        Bool tensor of shape [N, max_length] on the device of ``lens``;
        entry (i, t) is True iff ``t < lens[i]``.
    """
    lens = torch.as_tensor(lens)
    batch = lens.size(0)
    if max_length is None:
        max_length = max(lens)
    if isinstance(max_length, torch.Tensor):
        max_length = max_length.item()
    # One row of positions 0..max_length-1 per sequence.
    positions = torch.arange(max_length).repeat(batch, 1).to(lens.device)
    return positions < lens.view(-1, 1)
def mean_with_lens(features, lens):
    """Length-aware mean over the second dimension.

    features: [N, T, ...] (assume the second dimension represents length)
    lens: [N,]

    Positions at or beyond each item's length are excluded from the sum,
    and the sum is divided by the item's length.
    """
    lens = torch.as_tensor(lens)
    time_steps = features.size(1)
    # Pin the mask width to T whenever the longest item is shorter than T.
    forced_width = time_steps if max(lens) != time_steps else None
    mask = generate_length_mask(lens, forced_width).to(features.device)  # [N, T]
    # Append trailing singleton axes so the mask broadcasts over features.
    for _ in range(features.ndim - mask.ndim):
        mask = mask.unsqueeze(-1)
    summed = (features * mask).sum(1)
    # Same for the divisor, which has to line up with the summed tensor.
    for _ in range(summed.ndim - lens.ndim):
        lens = lens.unsqueeze(1)
    return summed / lens.to(features.device)
def max_with_lens(features, lens):
    """Length-aware maximum over the second dimension.

    features: [N, T, ...] (assume the second dimension represents length)
    lens: [N,]

    Positions at or beyond each item's length are set to -inf on a copy of
    ``features`` before taking the maximum, so the input is not modified.
    """
    lens = torch.as_tensor(lens)
    time_steps = features.size(1)
    # Pin the mask width to T whenever the longest item is shorter than T.
    forced_width = time_steps if max(lens) != time_steps else None
    mask = generate_length_mask(lens, forced_width).to(features.device)  # [N, T]
    masked = features.clone()
    masked[~mask] = float("-inf")
    return masked.max(1).values
def repeat_tensor(x, n):
    """Stack ``n`` copies of ``x`` along a new leading dimension."""
    tile_shape = [n] + [1] * x.dim()
    return x.unsqueeze(0).repeat(*tile_shape)
class CaptionMetaMixin:
    """Mixin holding the special token indices and default caption length."""

    # Default special-token indices and maximum decoding length.
    pad_idx = 0
    start_idx = 1
    end_idx = 2
    max_length = 20

    def set_index(cls, start_idx, end_idx, pad_idx):
        """Override the start / end / padding token indices.

        NOTE(review): despite the ``cls`` parameter name, this is a plain
        method here (no ``@classmethod``), so calling it on an instance
        sets the attributes on that instance only.
        """
        cls.start_idx, cls.end_idx, cls.pad_idx = start_idx, end_idx, pad_idx
class CaptionModel(nn.Module, CaptionMetaMixin):
    """
    Encoder-decoder captioning model.

    This base class implements the mode dispatch (train / inference) and
    the decoding loops (step-by-step sampling, beam search, diverse beam
    search). Subclasses provide ``compatible_decoders`` and the
    ``prepare_*_decoder_input`` / ``seq_forward`` hooks.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        """
        Args:
            encoder: module called with the input dict; its output dict is
                merged into the decoder inputs and the final output.
            decoder: module exposing ``vocab_size`` (and ``d_model``, used
                by ``prepare_output``).
            **kwargs: ``freeze_encoder`` (bool, default False) disables
                gradients for all encoder parameters.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size
        self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
        self.inference_forward_keys = ["sample_method", "max_length", "temp"]
        freeze_encoder = kwargs.get("freeze_encoder", False)
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.check_decoder_compatibility()

    def check_decoder_compatibility(self):
        """Assert that the decoder is an instance of a compatible class.

        ``self.compatible_decoders`` is expected to be set by the subclass
        before ``__init__`` of this class runs.
        """
        # Use the class names themselves for the message; going through
        # `__class__` would report the metaclass name ("type") instead.
        compatible_decoders = [
            getattr(x, "__name__", repr(x)) for x in self.compatible_decoders
        ]
        assert isinstance(self.decoder, self.compatible_decoders), \
            f"{self.decoder.__class__.__name__} is incompatible with " \
            f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "

    def forward(self, input_dict: Dict):
        """
        input_dict: {
            (required)
            mode: train/inference,
            [spec, spec_len],
            [fc],
            [attn, attn_len],
            [wav, wav_len],
            [sample_method: greedy],
            [temp: 1.0] (in case of no teacher forcing)
            (optional, mode=train)
            cap,
            cap_len,
            ss_ratio,
            (optional, mode=inference)
            sample_method: greedy/beam,
            max_length,
            temp,
            beam_size (optional, sample_method=beam),
            n_best (optional, sample_method=beam),
        }
        """
        encoder_output_dict = self.encoder(input_dict)
        output = self.forward_decoder(input_dict, encoder_output_dict)
        return output

    def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
        """Assemble the decoding arguments for the mode and run decoding.

        Missing inference options fall back to defaults (greedy sampling,
        ``self.max_length``, temperature 1.0, plus beam-search specific
        defaults). The encoder outputs are merged into the returned dict.

        Raises:
            Exception: if ``input_dict["mode"]`` is neither "train" nor
                "inference".
        """
        if input_dict["mode"] == "train":
            forward_dict = {
                "mode": "train", "sample_method": "greedy", "temp": 1.0
            }
            for key in self.train_forward_keys:
                forward_dict[key] = input_dict[key]
            forward_dict.update(encoder_output_dict)
            output = self.train_forward(forward_dict)
        elif input_dict["mode"] == "inference":
            forward_dict = {"mode": "inference"}
            default_args = {
                "sample_method": "greedy",
                "max_length": self.max_length,
                "temp": 1.0,
            }
            for key in self.inference_forward_keys:
                if key in input_dict:
                    forward_dict[key] = input_dict[key]
                else:
                    forward_dict[key] = default_args[key]

            if forward_dict["sample_method"] == "beam":
                forward_dict["beam_size"] = input_dict.get("beam_size", 3)
                forward_dict["n_best"] = input_dict.get("n_best", False)
                forward_dict["n_best_size"] = input_dict.get(
                    "n_best_size", forward_dict["beam_size"])
            elif forward_dict["sample_method"] == "dbs":
                forward_dict["beam_size"] = input_dict.get("beam_size", 6)
                forward_dict["group_size"] = input_dict.get("group_size", 3)
                forward_dict["diversity_lambda"] = input_dict.get(
                    "diversity_lambda", 0.5)
                forward_dict["group_nbest"] = input_dict.get(
                    "group_nbest", True)

            forward_dict.update(encoder_output_dict)
            output = self.inference_forward(forward_dict)
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        output.update(encoder_output_dict)
        return output

    def prepare_output(self, input_dict):
        """Allocate the output buffers for step-by-step decoding.

        ``seq`` is pre-filled with ``end_idx``. ``logit`` and ``embed`` are
        moved to the device of ``fc_emb``; ``seq`` and ``sampled_logprob``
        are created without a device argument.

        Raises:
            Exception: if the mode is neither "train" nor "inference".
        """
        output = {}
        batch_size = input_dict["fc_emb"].size(0)
        if input_dict["mode"] == "train":
            # The last caption token is never fed to the decoder.
            max_length = input_dict["cap"].size(1) - 1
        elif input_dict["mode"] == "inference":
            max_length = input_dict["max_length"]
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        device = input_dict["fc_emb"].device
        output["seq"] = torch.full((batch_size, max_length), self.end_idx,
                                   dtype=torch.long)
        output["logit"] = torch.empty(batch_size, max_length,
                                      self.vocab_size).to(device)
        output["sampled_logprob"] = torch.zeros(batch_size, max_length)
        output["embed"] = torch.empty(batch_size, max_length,
                                      self.decoder.d_model).to(device)
        return output

    def train_forward(self, input_dict):
        """Training pass.

        When ``ss_ratio`` is not 1 decoding is done step by step (scheduled
        sampling); otherwise the whole sequence is decoded at once with
        ``seq_forward``.
        """
        if input_dict["ss_ratio"] != 1:  # scheduled sampling training
            input_dict["mode"] = "train"
            return self.stepwise_forward(input_dict)
        output = self.seq_forward(input_dict)
        self.train_process(output, input_dict)
        return output

    def seq_forward(self, input_dict):
        """Whole-sequence (teacher-forced) decoding; subclass hook."""
        raise NotImplementedError

    def train_process(self, output, input_dict):
        """Hook called after ``seq_forward`` during training (no-op)."""
        pass

    def inference_forward(self, input_dict):
        """Dispatch inference to beam search, diverse beam search or
        step-by-step sampling according to ``sample_method``."""
        if input_dict["sample_method"] == "beam":
            return self.beam_search(input_dict)
        elif input_dict["sample_method"] == "dbs":
            return self.diverse_beam_search(input_dict)
        return self.stepwise_forward(input_dict)

    def stepwise_forward(self, input_dict):
        """Step-by-step decoding"""
        output = self.prepare_output(input_dict)
        max_length = output["seq"].size(1)
        # start sampling
        for t in range(max_length):
            input_dict["t"] = t
            self.decode_step(input_dict, output)
            if input_dict["mode"] == "inference":  # decide whether to stop when sampling
                unfinished_t = output["seq"][:, t] != self.end_idx
                if t == 0:
                    unfinished = unfinished_t
                else:
                    # A sequence stays unfinished only while it has never
                    # produced the end token.
                    unfinished *= unfinished_t
                output["seq"][:, t][~unfinished] = self.end_idx
                if unfinished.sum() == 0:
                    break
        self.stepwise_process(output)
        return output

    def decode_step(self, input_dict, output):
        """Decoding operation of timestep t"""
        decoder_input = self.prepare_decoder_input(input_dict, output)
        # feed to the decoder to get logit
        output_t = self.decoder(decoder_input)
        logit_t = output_t["logit"]
        # assert logit_t.ndim == 3
        if logit_t.size(1) == 1:
            logit_t = logit_t.squeeze(1)
            embed_t = output_t["embed"].squeeze(1)
        elif logit_t.size(1) > 1:
            # Decoder returned the whole prefix: keep the newest position.
            logit_t = logit_t[:, -1, :]
            embed_t = output_t["embed"][:, -1, :]
        else:
            raise Exception("no logit output")
        # sample the next input word and get the corresponding logit
        sampled = self.sample_next_word(logit_t,
                                        method=input_dict["sample_method"],
                                        temp=input_dict["temp"])

        output_t.update(sampled)
        output_t["t"] = input_dict["t"]
        output_t["logit"] = logit_t
        output_t["embed"] = embed_t
        self.stepwise_process_step(output, output_t)

    def prepare_decoder_input(self, input_dict, output):
        """Prepare the input dict for the decoder; subclass hook."""
        raise NotImplementedError

    def stepwise_process_step(self, output, output_t):
        """Postprocessing (save output values) after each timestep t"""
        t = output_t["t"]
        output["logit"][:, t, :] = output_t["logit"]
        output["seq"][:, t] = output_t["word"]
        output["sampled_logprob"][:, t] = output_t["probs"]
        output["embed"][:, t, :] = output_t["embed"]

    def stepwise_process(self, output):
        """Postprocessing after the whole step-by-step autoregressive decoding"""
        pass

    def sample_next_word(self, logit, method, temp):
        """Sample the next word, given the logits output by the decoder.

        Args:
            logit: [N, vocab] logits of the current step.
            method: "greedy", "gumbel", "top<k>" (k >= 1, top-k sampling),
                "top<p>" (0 < p < 1, top-p sampling); any other value
                samples from the temperature-scaled distribution.
            temp: temperature (not used by "greedy").

        Returns:
            dict with ``word`` ([N,], long) and ``probs`` ([N,], the
            log-probability gathered for the chosen word).
        """
        logprob = torch.log_softmax(logit, dim=1)
        if method == "greedy":
            sampled_logprob, word = torch.max(logprob.detach(), 1)
        elif method == "gumbel":
            def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape).to(logprob.device)
                return -torch.log(-torch.log(U + eps) + eps)

            def gumbel_softmax_sample(logit, temperature):
                y = logit + sample_gumbel(logit.size())
                return torch.log_softmax(y / temperature, dim=-1)

            _logprob = gumbel_softmax_sample(logprob, temp)
            _, word = torch.max(_logprob.data, 1)
            # Squeeze to [N,] like every other branch; an [N, 1] result
            # cannot be written into the [N,] slice of `sampled_logprob`.
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
        else:
            logprob = logprob / temp
            if method.startswith("top"):
                top_num = float(method[3:])
                if 0 < top_num < 1:  # top-p sampling
                    probs = torch.softmax(logit, dim=1)
                    sorted_probs, sorted_indices = torch.sort(
                        probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    # Shift by one so the word crossing the threshold is
                    # kept and the most likely word is always allowed.
                    mask = torch.cat(
                        [torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
                    sorted_probs = sorted_probs * mask.to(sorted_probs)
                    sorted_probs = sorted_probs / sorted_probs.sum(
                        1, keepdim=True)
                    logprob.scatter_(1, sorted_indices, sorted_probs.log())
                else:  # top-k sampling
                    k = int(top_num)
                    tmp = torch.empty_like(logprob).fill_(float('-inf'))
                    topk, indices = torch.topk(logprob, k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprob = tmp
            word = torch.distributions.Categorical(
                logits=logprob.detach()).sample()
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
        word = word.detach().long()
        # sampled_logprob: [N,], word: [N,]
        return {"word": word, "probs": sampled_logprob}

    def beam_search(self, input_dict):
        """Beam search decoding, run separately for each batch item.

        With ``n_best`` set, ``output["seq"]`` has shape
        [batch, n_best_size, max_length]; otherwise [batch, max_length].
        Finished beams are scored by their length-normalised log-prob.
        """
        output = self.prepare_output(input_dict)
        max_length = input_dict["max_length"]
        beam_size = input_dict["beam_size"]
        if input_dict["n_best"]:
            n_best_size = input_dict["n_best_size"]
            batch_size, max_length = output["seq"].size()
            output["seq"] = torch.full((batch_size, n_best_size, max_length),
                                       self.end_idx, dtype=torch.long)

        temp = input_dict["temp"]
        # instance by instance beam search
        for i in range(output["seq"].size(0)):
            output_i = self.prepare_beamsearch_output(input_dict)
            input_dict["sample_idx"] = i
            for t in range(max_length):
                input_dict["t"] = t
                output_t = self.beamsearch_step(input_dict, output_i)
                #######################################
                # merge with previous beam and select the current max prob beam
                #######################################
                logit_t = output_t["logit"]
                if logit_t.size(1) == 1:
                    logit_t = logit_t.squeeze(1)
                elif logit_t.size(1) > 1:
                    logit_t = logit_t[:, -1, :]
                else:
                    raise Exception("no logit output")
                logprob_t = torch.log_softmax(logit_t, dim=1)
                logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
                if t == 0:  # for the first step, all k seq will have the same probs
                    topk_logprob, topk_words = logprob_t[0].topk(
                        beam_size, 0, True, True)
                else:  # unroll and find top logprob, and their unrolled indices
                    topk_logprob, topk_words = logprob_t.view(-1).topk(
                        beam_size, 0, True, True)
                topk_words = topk_words.cpu()
                output_i["topk_logprob"] = topk_logprob
                # Flat index -> (source beam, word id).
                output_i["prev_words_beam"] = torch.div(
                    topk_words, self.vocab_size,
                    rounding_mode='trunc')  # [beam_size,]
                output_i["next_word"] = topk_words % self.vocab_size  # [beam_size,]
                if t == 0:
                    output_i["seq"] = output_i["next_word"].unsqueeze(1)
                else:
                    output_i["seq"] = torch.cat([
                        output_i["seq"][output_i["prev_words_beam"]],
                        output_i["next_word"].unsqueeze(1)], dim=1)

                # add finished beams to results
                is_end = output_i["next_word"] == self.end_idx
                if t == max_length - 1:
                    # Last step: every surviving beam counts as finished.
                    is_end.fill_(1)

                for beam_idx in range(beam_size):
                    if is_end[beam_idx]:
                        final_beam = {
                            "seq": output_i["seq"][beam_idx].clone(),
                            "score": output_i["topk_logprob"][beam_idx].item()
                        }
                        final_beam["score"] = final_beam["score"] / (t + 1)
                        output_i["done_beams"].append(final_beam)
                # Penalise finished beams so they are not expanded further.
                output_i["topk_logprob"][is_end] -= 1000

                self.beamsearch_process_step(output_i, output_t)

                if len(output_i["done_beams"]) == beam_size:
                    break

            self.beamsearch_process(output, output_i, input_dict)
        return output

    def prepare_beamsearch_output(self, input_dict):
        """Create the per-sample beam search state."""
        beam_size = input_dict["beam_size"]
        device = input_dict["fc_emb"].device
        output = {
            "topk_logprob": torch.zeros(beam_size).to(device),
            "seq": None,
            "prev_words_beam": None,
            "next_word": None,
            "done_beams": [],
        }
        return output

    def beamsearch_step(self, input_dict, output_i):
        """Run the decoder for one beam search step and tag it with ``t``."""
        decoder_input = self.prepare_beamsearch_decoder_input(
            input_dict, output_i)
        output_t = self.decoder(decoder_input)
        output_t["t"] = input_dict["t"]
        return output_t

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Prepare the decoder input for beam search; subclass hook."""
        raise NotImplementedError

    def beamsearch_process_step(self, output_i, output_t):
        """Hook called after each beam search step (no-op)."""
        pass

    def beamsearch_process(self, output, output_i, input_dict):
        """Write the best finished beam(s) of sample ``sample_idx`` into
        ``output["seq"]``, ordered by decreasing score."""
        i = input_dict["sample_idx"]
        done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
        if input_dict["n_best"]:
            done_beams = done_beams[:input_dict["n_best_size"]]
            for out_idx, done_beam in enumerate(done_beams):
                seq = done_beam["seq"]
                output["seq"][i][out_idx, :len(seq)] = seq
        else:
            seq = done_beams[0]["seq"]
            output["seq"][i][:len(seq)] = seq

    def diverse_beam_search(self, input_dict):
        """Diverse beam search, run separately for each batch item.

        The beam is split into ``group_size`` groups of
        ``bdash = beam_size // group_size`` beams. Group ``divm`` starts
        ``divm`` steps after group 0 and is penalised (scaled by
        ``diversity_lambda``) for words already chosen by earlier groups at
        the same local time step. With ``group_nbest`` the output holds all
        kept beams of all groups, otherwise the best beam of each group.
        """

        def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprob = logprob.clone()

            if divm > 0:
                # Count how often each word was picked by previous groups
                # at this local time step.
                change = torch.zeros(logprob.size(-1))
                for prev_choice in range(divm):
                    prev_decisions = seq_table[prev_choice][..., local_time]
                    for prev_labels in range(bdash):
                        change.scatter_add_(0, prev_decisions[prev_labels],
                                            change.new_ones(1))

                change = change.to(logprob.device)
                logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda

            return logprob, unaug_logprob

        output = self.prepare_output(input_dict)
        group_size = input_dict["group_size"]
        beam_size = input_dict["beam_size"]
        bdash = beam_size // group_size
        input_dict["bdash"] = bdash
        diversity_lambda = input_dict["diversity_lambda"]
        device = input_dict["fc_emb"].device
        temp = input_dict["temp"]
        group_nbest = input_dict["group_nbest"]
        batch_size, max_length = output["seq"].size()
        if group_nbest:
            output["seq"] = torch.full((batch_size, beam_size, max_length),
                                       self.end_idx, dtype=torch.long)
        else:
            output["seq"] = torch.full((batch_size, group_size, max_length),
                                       self.end_idx, dtype=torch.long)

        for i in range(batch_size):
            input_dict["sample_idx"] = i
            seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)]  # group_size x [bdash, 0]
            logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
            done_beams_table = [[] for _ in range(group_size)]

            output_i = {
                "prev_words_beam": [None for _ in range(group_size)],
                "next_word": [None for _ in range(group_size)],
                "state": [None for _ in range(group_size)]
            }

            for t in range(max_length + group_size - 1):
                input_dict["t"] = t
                for divm in range(group_size):
                    input_dict["divm"] = divm
                    # Group `divm` is active for local steps 0..max_length-1.
                    if t >= divm and t <= max_length + divm - 1:
                        local_time = t - divm
                        decoder_input = self.prepare_dbs_decoder_input(
                            input_dict, output_i)
                        output_t = self.decoder(decoder_input)
                        output_t["divm"] = divm
                        logit_t = output_t["logit"]
                        if logit_t.size(1) == 1:
                            logit_t = logit_t.squeeze(1)
                        elif logit_t.size(1) > 1:
                            logit_t = logit_t[:, -1, :]
                        else:
                            raise Exception("no logit output")
                        logprob_t = torch.log_softmax(logit_t, dim=1)
                        logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                        logprob_t, unaug_logprob_t = add_diversity(
                            seq_table, logprob_t, t, divm,
                            diversity_lambda, bdash)
                        logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
                        if local_time == 0:  # for the first step, all k seq will have the same probs
                            topk_logprob, topk_words = logprob_t[0].topk(
                                bdash, 0, True, True)
                        else:  # unroll and find top logprob, and their unrolled indices
                            topk_logprob, topk_words = logprob_t.view(-1).topk(
                                bdash, 0, True, True)
                        topk_words = topk_words.cpu()
                        logprob_table[divm] = topk_logprob
                        # Same flat-index decomposition as in beam_search.
                        output_i["prev_words_beam"][divm] = torch.div(
                            topk_words, self.vocab_size,
                            rounding_mode='trunc')  # [bdash,]
                        output_i["next_word"][divm] = topk_words % self.vocab_size  # [bdash,]
                        if local_time > 0:
                            seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
                        seq_table[divm] = torch.cat([
                            seq_table[divm],
                            output_i["next_word"][divm].unsqueeze(-1)], -1)

                        is_end = seq_table[divm][:, t-divm] == self.end_idx
                        assert seq_table[divm].shape[-1] == t - divm + 1
                        if t == max_length + divm - 1:
                            # Last local step: all beams count as finished.
                            is_end.fill_(1)
                        for beam_idx in range(bdash):
                            if is_end[beam_idx]:
                                final_beam = {
                                    "seq": seq_table[divm][beam_idx].clone(),
                                    "score": logprob_table[divm][beam_idx].item()
                                }
                                final_beam["score"] = final_beam["score"] / (t - divm + 1)
                                done_beams_table[divm].append(final_beam)
                        # Penalise finished beams so they are not expanded.
                        logprob_table[divm][is_end] -= 1000
                        self.dbs_process_step(output_i, output_t)
            done_beams_table = [
                sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash]
                for divm in range(group_size)
            ]
            if group_nbest:
                done_beams = sum(done_beams_table, [])
            else:
                done_beams = [group_beam[0] for group_beam in done_beams_table]
            for out_idx, done_beam in enumerate(done_beams):
                output["seq"][i, out_idx, :len(done_beam["seq"])] = done_beam["seq"]

        return output

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        """Prepare the decoder input for diverse beam search; subclass hook."""
        raise NotImplementedError

    def dbs_process_step(self, output_i, output_t):
        """Hook called after each diverse beam search step (no-op)."""
        pass
class TransformerModel(CaptionModel):
    """Caption model whose decoder is a Transformer.

    At every decoding step the decoder is given the whole word prefix built
    so far, together with the attention embeddings and their lengths.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        # Only install the default ``compatible_decoders`` tuple when nothing
        # has set that attribute on the instance yet.
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (TransformerDecoder,)
        super().__init__(encoder, decoder, **kwargs)

    def _start_tokens(self, num_rows):
        """Return a ``[num_rows, 1]`` long tensor filled with ``start_idx``."""
        return torch.tensor([self.start_idx,] * num_rows).unsqueeze(1).long()

    def seq_forward(self, input_dict):
        """Run the decoder once over the caption with its last token dropped.

        Reads ``cap``, ``attn_emb`` and ``attn_emb_len`` from ``input_dict``
        and returns whatever the decoder returns.
        """
        cap = input_dict["cap"]
        # True where the caption holds the padding index; trimmed to match
        # the shortened word input.
        padding_mask = (cap == self.pad_idx).to(cap.device)[:, :-1]
        return self.decoder({
            "word": cap[:, :-1],
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"],
            "cap_padding_mask": padding_mask,
        })

    def prepare_decoder_input(self, input_dict, output):
        """Assemble the decoder input for step ``input_dict["t"]``.

        In "train" mode the ground-truth prefix ``cap[:, :t+1]`` is used with
        probability ``ss_ratio``; otherwise the prefix is the start token
        followed by the first ``t`` tokens already stored in ``output["seq"]``.
        """
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        t = input_dict["t"]

        # The random draw only happens in "train" mode (short-circuit).
        take_ground_truth = (
            input_dict["mode"] == "train"
            and random.random() < input_dict["ss_ratio"]
        )
        if take_ground_truth:
            word = input_dict["cap"][:, :t+1]
        else:
            start_word = self._start_tokens(attn_emb.size(0))
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)

        # word: [N, T]
        return {
            "attn_emb": attn_emb,
            "attn_emb_len": attn_emb_len,
            "word": word,
            "cap_padding_mask": (word == self.pad_idx).to(attn_emb.device),
        }

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Assemble the decoder input for one beam-search step of one sample.

        At ``t == 0`` the sample's attention embeddings and lengths are
        repeated ``beam_size`` times and cached in ``output_i``; later steps
        reuse the cached copies.
        """
        t = input_dict["t"]
        sample_idx = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        if t == 0:
            repeated_emb = repeat_tensor(
                input_dict["attn_emb"][sample_idx], beam_size)
            repeated_len = repeat_tensor(
                input_dict["attn_emb_len"][sample_idx], beam_size)
            output_i["attn_emb"] = repeated_emb
            output_i["attn_emb_len"] = repeated_len

        decoder_input = {
            "attn_emb": output_i["attn_emb"],
            "attn_emb_len": output_i["attn_emb_len"],
        }

        start_word = self._start_tokens(beam_size)
        if t == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word
        decoder_input["cap_padding_mask"] = (
            word == self.pad_idx).to(input_dict["attn_emb"].device)
        return decoder_input
class BaseDecoder(nn.Module):
    """
    Take word/audio embeddings and output the next word probs.

    This base class only records the dimensions and creates the word
    embedding table and an input dropout layer; ``forward`` must be supplied
    by subclasses.
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim,
                 attn_emb_dim, dropout=0.2, tie_weights=False):
        """
        Args:
            emb_dim: width of the word embedding.
            vocab_size: number of rows in the word embedding table.
            fc_emb_dim: stored as ``self.fc_emb_dim`` for subclasses.
            attn_emb_dim: stored as ``self.attn_emb_dim`` for subclasses.
            dropout: probability of the ``in_dropout`` layer.
            tie_weights: stored as ``self.tie_weights``; not acted on here.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.vocab_size = vocab_size
        self.fc_emb_dim = fc_emb_dim
        self.attn_emb_dim = attn_emb_dim
        self.tie_weights = tie_weights
        self.word_embedding = nn.Embedding(vocab_size, emb_dim)
        self.in_dropout = nn.Dropout(dropout)

    def forward(self, x):
        raise NotImplementedError

    def load_word_embedding(self, weight, freeze=True):
        """Replace the word embedding with a matrix loaded via ``np.load``.

        Args:
            weight: file passed to ``np.load``; the array must have shape
                ``(vocab_size, emb_dim)``.
            freeze: forwarded to ``nn.Embedding.from_pretrained``.

        Raises:
            AssertionError: if the array shape does not match this decoder.
        """
        embedding = np.load(weight)
        assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
        assert embedding.shape[1] == self.emb_dim, "embed size mismatch"

        # from_pretrained needs a float tensor; a raw ndarray fails inside it.
        embedding = torch.as_tensor(embedding).float()
        # NOTE(review): this installs a new nn.Embedding, so anything that
        # shared the old weight (e.g. a tied classifier) is no longer tied.
        self.word_embedding = nn.Embedding.from_pretrained(embedding,
                                                           freeze=freeze)
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to the input, then apply dropout.

    The table is stored as a non-trainable parameter named ``pe`` of shape
    ``[max_len, 1, d_model]``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        """
        Args:
            d_model: size of the last input dimension (even or odd).
            dropout: dropout probability applied after adding the table.
            max_len: number of positions in the table.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns; trim so the assignment shapes agree.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Kept as a frozen parameter (not a buffer) so the attribute name and
        # kind stay exactly as before.
        self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        # x: [T, N, E]
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
class TransformerDecoder(BaseDecoder):
    """Transformer decoder over word prefixes and projected audio embeddings.

    Optional keyword arguments: ``nhead`` (default ``emb_dim // 64``),
    ``nlayers`` (default 2) and ``dim_feedforward`` (default ``4 * emb_dim``).
    """

    def __init__(self,
                 emb_dim,
                 vocab_size,
                 fc_emb_dim,
                 attn_emb_dim,
                 dropout,
                 freeze=False,
                 tie_weights=False,
                 **kwargs):
        """
        Args:
            emb_dim: word embedding size, also used as the model width.
            vocab_size: size of the output vocabulary.
            fc_emb_dim: forwarded to ``BaseDecoder``.
            attn_emb_dim: input width of the attention-embedding projection.
            dropout: dropout probability used throughout the decoder.
            freeze: when true, every parameter is frozen after initialisation.
            tie_weights: share the classifier weight with the word embedding.
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout, tie_weights=tie_weights)
        self.d_model = emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        layer = nn.TransformerDecoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerDecoder(layer, self.nlayers)
        self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
        if tie_weights:
            self.classifier.weight = self.word_embedding.weight
        self.attn_proj = nn.Sequential(
            nn.Linear(self.attn_emb_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        self.init_params()

        self.freeze = freeze
        if freeze:
            for p in self.parameters():
                p.requires_grad = False

    def init_params(self):
        """Xavier-initialise every trainable parameter with more than one dim.

        Non-trainable parameters are skipped: the positional-encoding table is
        registered as a frozen parameter and would otherwise be overwritten
        with random values.
        """
        for p in self.parameters():
            if p.requires_grad and p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_pretrained(self, pretrained, output_fn):
        """Load decoder weights from a checkpoint file.

        Args:
            pretrained: path handed to ``torch.load``.
            output_fn: forwarded to ``merge_load_state_dict``.

        If the checkpoint has a ``"model"`` entry, that entry is used.  Keys
        starting with ``"decoder."`` are selected and stripped of the prefix;
        a checkpoint without such keys is used as is.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(pretrained, map_location="cpu")

        if "model" in checkpoint:
            checkpoint = checkpoint["model"]

        prefix = "decoder."
        if any(key.startswith(prefix) for key in checkpoint):
            state_dict = {
                key[len(prefix):]: value
                for key, value in checkpoint.items()
                if key.startswith(prefix)
            }
        else:
            # Previously this case left ``state_dict`` unassigned.
            state_dict = checkpoint

        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            # Freeze exactly what came from the checkpoint; the rest trains.
            for name, param in self.named_parameters():
                param.requires_grad = name not in loaded_keys

    def generate_square_subsequent_mask(self, max_length):
        """Return a ``[max_length, max_length]`` float mask.

        Entries on and below the diagonal are 0.0; entries above it are -inf.
        """
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, input_dict):
        """Decode a batch of word prefixes.

        Reads ``word``, ``attn_emb``, ``attn_emb_len`` and
        ``cap_padding_mask`` from ``input_dict``.

        Returns:
            dict with ``"embed"`` (decoder output, batch first) and
            ``"logit"`` (classifier applied to that output).
        """
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]
        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        # Inverted length mask, passed to the decoder as the memory padding mask.
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output
class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Wrap a model and add a contrastive loss against teacher embeddings.

    The wrapped model's ``fc_emb`` output and the teacher embedding are each
    projected to ``shared_dim``; the loss is stored as ``enc_kd_loss``.
    """

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 ):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        # The student width comes from the wrapped model's encoder when it
        # has one, otherwise from the wrapped model itself.
        emb_owner = model.encoder if hasattr(model, "encoder") else model
        self.stdnt_proj = nn.Linear(emb_owner.fc_emb_size, shared_dim)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        """Run the wrapped model and, if a teacher output is given, add the loss.

        With ``input_dict["unsup"]`` absent or ``False`` the whole wrapped
        model is called; otherwise only its encoder.  The loss is added only
        when ``"tchr_output"`` is present in ``input_dict``.
        """
        if input_dict.get("unsup", False) is False:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)

        if "tchr_output" not in input_dict:
            return output_dict

        student = F.normalize(self.stdnt_proj(output_dict["fc_emb"]), dim=-1)
        teacher = F.normalize(
            self.tchr_proj(input_dict["tchr_output"]["embedding"]), dim=-1)

        logit = self.logit_scale * (student @ teacher.transpose(0, 1))
        # Row k is matched with column k.
        label = torch.arange(logit.shape[0]).to(logit.device)
        student_to_teacher = F.cross_entropy(logit, label)
        teacher_to_student = F.cross_entropy(logit.transpose(0, 1), label)
        output_dict["enc_kd_loss"] = (student_to_teacher + teacher_to_student) / 2
        return output_dict
class Effb2TrmConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters of the captioning model.

    Every named argument is stored as an attribute of the same name; extra
    keyword arguments are forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        tchr_dim: int = 768,
        shared_dim: int = 1024,
        fc_emb_dim: int = 1408,
        attn_emb_dim: int = 1408,
        decoder_n_layers: int = 2,
        decoder_we_tie_weights: bool = True,
        decoder_emb_dim: int = 256,
        decoder_dropout: float = 0.2,
        vocab_size: int = 4981,
        **kwargs
    ):
        # Audio input.
        self.sample_rate = sample_rate

        # Teacher / shared projection sizes.
        self.tchr_dim = tchr_dim
        self.shared_dim = shared_dim

        # Embedding sizes handed to the decoder.
        self.fc_emb_dim = fc_emb_dim
        self.attn_emb_dim = attn_emb_dim

        # Decoder settings.
        self.decoder_n_layers = decoder_n_layers
        self.decoder_we_tie_weights = decoder_we_tie_weights
        self.decoder_emb_dim = decoder_emb_dim
        self.decoder_dropout = decoder_dropout
        self.vocab_size = vocab_size

        # Base-class initialisation runs last, after all fields are set.
        super().__init__(**kwargs)
class Effb2TrmCaptioningModel(PreTrainedModel):
    """Hugging Face wrapper around the EfficientNet-B2 / Transformer captioner."""

    config_class = Effb2TrmConfig

    def __init__(self, config):
        """Build encoder, decoder and the wrapping modules from ``config``."""
        super().__init__(config)
        encoder = EfficientNetB2()
        decoder = TransformerDecoder(
            emb_dim=config.decoder_emb_dim,
            vocab_size=config.vocab_size,
            fc_emb_dim=config.fc_emb_dim,
            attn_emb_dim=config.attn_emb_dim,
            dropout=config.decoder_dropout,
            nlayers=config.decoder_n_layers,
            tie_weights=config.decoder_we_tie_weights
        )
        model = TransformerModel(encoder, decoder)
        self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the model, then point the word embedding at the classifier weight.

        Must be a classmethod: it takes ``cls`` and is called on the class,
        like the base-class method it overrides.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )
        decoder = model.model.model.decoder
        # NOTE(review): the weights are tied unconditionally, regardless of
        # config.decoder_we_tie_weights — confirm this is intended.
        decoder.word_embedding.weight = decoder.classifier.weight
        return model

    def forward(self,
                audio: torch.Tensor,
                audio_length: Union[List, np.ndarray, torch.Tensor],
                sample_method: str = "beam",
                beam_size: int = 3,
                max_length: int = 20,
                temp: float = 1.0,):
        """Generate captions for a batch of waveforms.

        Args:
            audio: waveform tensor; moved to the model's device.
            audio_length: lengths passed on as ``wav_len``.
            sample_method: decoding method name; ``"beam"`` also passes
                ``beam_size``.
            beam_size: beam width, used only when ``sample_method == "beam"``.
            max_length: passed on as ``max_length``.
            temp: passed on as ``temp``.

        Returns:
            The ``"seq"`` entry of the wrapped model's output, moved to CPU.
        """
        device = self.device
        input_dict = {
            "wav": audio.to(device),
            "wav_len": audio_length,
            "specaug": False,
            "mode": "inference",
            "sample_method": sample_method,
            "max_length": max_length,
            "temp": temp,
        }
        if sample_method == "beam":
            input_dict["beam_size"] = beam_size
        return self.model(input_dict)["seq"].cpu()
class ConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm + ReLU stages followed by 2-D pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        shared = dict(kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels, **shared)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels, **shared)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply both conv stages, then pool.

        Args:
            input: tensor fed to the first convolution.
            pool_size: pooling kernel size.
            pool_type: ``'max'``, ``'avg'`` or ``'avg+max'`` (sum of both).

        Raises:
            Exception: for any other ``pool_type`` (raised after the
                convolutions have run).
        """
        hidden = F.relu_(self.bn1(self.conv1(input)))
        hidden = F.relu_(self.bn2(self.conv2(hidden)))

        if pool_type == 'avg':
            return F.avg_pool2d(hidden, kernel_size=pool_size)
        if pool_type == 'max':
            return F.max_pool2d(hidden, kernel_size=pool_size)
        if pool_type == 'avg+max':
            averaged = F.avg_pool2d(hidden, kernel_size=pool_size)
            maxed = F.max_pool2d(hidden, kernel_size=pool_size)
            return averaged + maxed
        raise Exception('Incorrect argument!')
class Cnn14Encoder(nn.Module):
    """Six-block CNN encoder producing frame-level and clip-level embeddings."""

    def __init__(self, sample_rate=32000):
        """
        Args:
            sample_rate: 32000 or 16000; any other value raises ``KeyError``
                when the mel upper frequency is looked up.
        """
        super().__init__()
        f_max = {32000: 14000, 16000: 8000}[sample_rate]
        window = 32 * sample_rate // 1000
        hop = 10 * sample_rate // 1000

        # Logmel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=window,
            win_length=window,
            hop_length=hop,
            f_min=50,
            f_max=f_max,
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = hop
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)

        # Registers conv_block1 ... conv_block6 with doubling channel counts.
        widths = (1, 64, 128, 256, 512, 1024, 2048)
        for index, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv_block{index}",
                    ConvBlock(in_channels=c_in, out_channels=c_out))

        self.downsample_ratio = 32

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_emb_size = 2048

    def forward(self, input_dict):
        """Encode ``input_dict["lms"]``; lengths come from ``input_dict["wav_len"]``.

        Returns:
            dict with ``fc_emb``, ``attn_emb`` and ``attn_emb_len``.
        """
        lms = input_dict["lms"]
        wave_length = input_dict["wav_len"]

        # lms: (batch_size, mel_bins, time_steps)
        x = lms.transpose(1, 2).unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)
        # bn0 runs with the last axis swapped into the channel position.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)

        halving_blocks = (self.conv_block1, self.conv_block2, self.conv_block3,
                          self.conv_block4, self.conv_block5)
        for block in halving_blocks:
            x = block(x, pool_size=(2, 2), pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)

        attn_emb = torch.mean(x, dim=3).transpose(1, 2)

        # Frame count from the waveform length, then divided by the
        # downsample ratio of the CNN.
        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")

        pooled = max_with_lens(attn_emb, feat_length) + \
            mean_with_lens(attn_emb, feat_length)
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        pooled = F.relu_(self.fc1(pooled))
        fc_emb = F.dropout(pooled, p=0.5, training=self.training)

        return {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
class RnnEncoder(nn.Module):
    """Recurrent encoder over a padded batch of feature sequences.

    Optional keyword arguments: ``hidden_size`` (512), ``bidirectional``
    (False), ``num_layers`` (1), ``dropout`` (0.2), ``rnn_type`` ("GRU") and
    ``in_bn`` (False, batch-normalise the input features first).
    """

    def __init__(self,
                 attn_feat_dim,
                 pooling="mean",
                 **kwargs):
        """
        Args:
            attn_feat_dim: size of each input feature vector.
            pooling: pooling name handed to ``embedding_pooling``.
        """
        super().__init__()
        self.pooling = pooling
        self.hidden_size = kwargs.get('hidden_size', 512)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.num_layers = kwargs.get('num_layers', 1)
        self.dropout = kwargs.get('dropout', 0.2)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        self.in_bn = kwargs.get('in_bn', False)
        self.embed_dim = self.hidden_size * (self.bidirectional + 1)
        self.network = getattr(nn, self.rnn_type)(
            attn_feat_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            dropout=self.dropout,
            batch_first=True)
        if self.in_bn:
            # The batch norm runs on the *input* features (see forward), so
            # it must be sized by attn_feat_dim, not by the RNN output width.
            self.bn = nn.BatchNorm1d(attn_feat_dim)

    def forward(self, input_dict):
        """Encode ``input_dict["attn"]`` using lengths ``input_dict["attn_len"]``.

        Returns:
            dict with ``attn_emb`` (RNN outputs), ``fc_emb`` (pooled outputs)
            and ``attn_emb_len`` (the lengths as a tensor).
        """
        x = input_dict["attn"]
        lens = input_dict["attn_len"]
        lens = torch.as_tensor(lens)
        # x: [N, T, E]
        if self.in_bn:
            x = pack_wrapper(self.bn, x, lens)
        out = pack_wrapper(self.network, x, lens)
        # out: [N, T, hidden]
        attn_emb = out
        fc_emb = embedding_pooling(out, lens, self.pooling)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }
class Cnn14RnnEncoder(nn.Module):
    """Chain a ``Cnn14Encoder`` into an ``RnnEncoder``."""

    def __init__(self,
                 sample_rate,
                 rnn_bidirectional,
                 rnn_hidden_size,
                 rnn_dropout,
                 rnn_num_layers):
        super().__init__()
        self.cnn = Cnn14Encoder(sample_rate=sample_rate)
        # 2048 is the feature size handed to the RNN as its input width.
        rnn_options = dict(
            bidirectional=rnn_bidirectional,
            hidden_size=rnn_hidden_size,
            dropout=rnn_dropout,
            num_layers=rnn_num_layers,
        )
        self.rnn = RnnEncoder(2048, **rnn_options)

    def forward(self, input_dict):
        """Run the CNN, rename its sequence outputs, then run the RNN."""
        features = self.cnn(input_dict)
        # The RNN encoder reads "attn" / "attn_len" rather than the CNN's
        # "attn_emb" / "attn_emb_len" keys.
        features["attn"] = features.pop("attn_emb")
        features["attn_len"] = features.pop("attn_emb_len")
        return self.rnn(features)
class Seq2SeqAttention(nn.Module):
    """Additive attention: score each encoder frame against a decoder query."""

    def __init__(self, hs_enc, hs_dec, attn_size):
        """
        Args:
            hs_enc: encoder hidden size
            hs_dec: decoder hidden size
            attn_size: attention vector size
        """
        super().__init__()
        self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
        self.v = nn.Parameter(torch.randn(attn_size))

    def forward(self, h_dec, h_enc, src_lens):
        """
        Args:
            h_dec: decoder hidden (query), [N, hs_dec]
            h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
            src_lens: source (encoder memory) lengths, [N, ]

        Returns:
            ``(ctx, weights)``: the weighted sum of ``h_enc`` ([N, hs_enc])
            and the attention weights ([N, src_max_len]).
        """
        batch, n_frames = h_enc.size(0), h_enc.size(1)

        # Pair the query with every encoder frame.
        query = h_dec.unsqueeze(1).expand(-1, n_frames, -1)  # [N, src_max_len, hs_dec]
        hidden = torch.tanh(
            self.h2attn(torch.cat((query, h_enc), dim=-1)))  # [N, src_max_len, attn_size]

        scorer = self.v.expand(batch, -1).unsqueeze(1)  # [N, 1, attn_size]
        score = torch.bmm(scorer, hidden.transpose(1, 2)).squeeze(1)  # [N, src_max_len]

        # Frames at or beyond each sequence's length get a very low score.
        positions = torch.arange(n_frames).expand(batch, n_frames)
        valid = (positions < src_lens.view(-1, 1)).to(h_dec.device)
        score = score.masked_fill(~valid, -1e10)

        weights = torch.softmax(score, dim=-1)  # [N, src_max_len]
        ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1)  # [N, hs_enc]
        return ctx, weights
class RnnDecoder(BaseDecoder):
    """Common base for RNN decoders: classifier head and zero initial state.

    Optional keyword arguments: ``num_layers`` (1), ``bidirectional`` (False)
    and ``rnn_type`` ("GRU").
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout,)
        self.d_model = d_model
        self.num_layers = kwargs.get('num_layers', 1)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        classifier_in = self.d_model * (self.bidirectional + 1)
        self.classifier = nn.Linear(classifier_in, vocab_size)

    def forward(self, x):
        raise NotImplementedError

    def init_hidden(self, bs, device):
        """Return an all-zero initial state for a batch of size ``bs``.

        For ``rnn_type == "LSTM"`` this is a pair of tensors, otherwise a
        single tensor, each ``[num_directions * num_layers, bs, d_model]``.
        """
        state_shape = (
            (self.bidirectional + 1) * self.num_layers, bs, self.d_model)
        if self.rnn_type != "LSTM":
            return torch.zeros(*state_shape).to(device)
        hidden = torch.zeros(*state_shape).to(device)
        cell = torch.zeros(*state_shape).to(device)
        return (hidden, cell)
class BahAttnCatFcDecoder(RnnDecoder):
    """RNN decoder fed with word, attention context and clip embedding."""

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        concatenate fc, attn, word to feed to the rnn
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        rnn_class = getattr(nn, self.rnn_type)
        # Three emb_dim-wide pieces are concatenated as the RNN input.
        self.model = rnn_class(
            input_size=self.emb_dim * 3,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        # The attention query is the flattened RNN state (all layers and
        # directions), hence the multiplied query width.
        query_dim = self.d_model * (self.bidirectional + 1) * self.num_layers
        self.attn = Seq2SeqAttention(self.attn_emb_dim, query_dim, attn_size)
        self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
        self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)

    def forward(self, input_dict):
        """Run one decoding step.

        Reads ``word``, ``fc_emb``, ``attn_emb``, ``attn_emb_len`` and the
        optional ``state`` from ``input_dict``.

        Returns:
            dict with ``state``, ``embed``, ``logit`` and ``attn_weight``.
        """
        fc_emb = input_dict["fc_emb"]
        device = fc_emb.device
        word = input_dict["word"].to(device)
        # embed: [N, 1, embed_size]
        embed = self.in_dropout(self.word_embedding(word))

        state = input_dict.get("state", None)  # [n_layer * n_dire, bs, d_model]
        if state is None:
            state = self.init_hidden(word.size(0), device)
        # For an LSTM only the first element of the state forms the query.
        hidden = state[0] if self.rnn_type == "LSTM" else state
        query = hidden.transpose(0, 1).flatten(1)

        ctx, attn_weight = self.attn(
            query, input_dict["attn_emb"], input_dict["attn_emb_len"])

        pieces = (
            embed,
            self.ctx_proj(ctx).unsqueeze(1),
            self.fc_proj(fc_emb).unsqueeze(1),
        )
        out, state = self.model(torch.cat(pieces, dim=-1), state)

        return {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
class TemporalBahAttnDecoder(BahAttnCatFcDecoder):
    """Attention RNN decoder whose first input is a temporal-tag embedding."""

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        concatenate fc, attn, word to feed to the rnn
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        # Embedding table with four rows, indexed by the temporal tag.
        self.temporal_embedding = nn.Embedding(4, emb_dim)

    def forward(self, input_dict):
        """Run one decoding step.

        Reads ``word``, ``fc_emb``, ``attn_emb``, ``attn_emb_len``,
        ``temporal_tag``, ``t`` and the optional ``state`` from
        ``input_dict``.  At ``t == 0`` the step input is the temporal-tag
        embedding; afterwards it is ``word`` itself when its last dimension
        equals ``fc_emb_dim``, or the word embedding when that dimension is 1.

        Returns:
            dict with ``state``, ``embed``, ``logit`` and ``attn_weight``.

        Raises:
            Exception: if ``word`` has an unsupported last dimension.
        """
        word = input_dict["word"]
        state = input_dict.get("state", None)  # [n_layer * n_dire, bs, d_model]
        fc_embs = input_dict["fc_emb"]
        attn_embs = input_dict["attn_emb"]
        attn_emb_lens = input_dict["attn_emb_len"]
        temporal_tag = input_dict["temporal_tag"]

        if input_dict["t"] == 0:
            embed = self.in_dropout(
                self.temporal_embedding(temporal_tag)).unsqueeze(1)
        elif word.size(-1) == self.fc_emb_dim:  # fc_embs
            embed = word.unsqueeze(1)
        elif word.size(-1) == 1:  # word
            word = word.to(fc_embs.device)
            embed = self.in_dropout(self.word_embedding(word))
        else:
            raise Exception(f"problem with word input size {word.size()}")

        # embed: [N, 1, embed_size]
        if state is None:
            state = self.init_hidden(word.size(0), fc_embs.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_embs, attn_emb_lens)

        # The context projection is computed once (it used to be computed
        # twice with identical inputs).
        p_ctx = self.ctx_proj(c)
        p_fc_embs = self.fc_proj(fc_embs)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_embs.unsqueeze(1)), dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output
class Seq2SeqAttnModel(CaptionModel):
    """Caption model driving a step-by-step attention RNN decoder.

    Besides building the decoder inputs, it carries the RNN state from step
    to step and stores the attention weights of each step.
    """

    def __init__(self, encoder, decoder, **kwargs):
        # Only install the default ``compatible_decoders`` tuple when nothing
        # has set that attribute on the instance yet.
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                BahAttnCatFcDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        # Bahdanau attention only supports step-by-step implementation, so we implement forward in
        # step-by-step manner whether in training or evaluation
        return self.stepwise_forward(input_dict)

    def prepare_output(self, input_dict):
        """Extend the base output with an uninitialised ``attn_weight`` tensor.

        Its size is (rows of ``seq``, ``attn_emb.size(1)``, columns of ``seq``).
        """
        output = super().prepare_output(input_dict)
        attn_weight = torch.empty(output["seq"].size(0),
                                  input_dict["attn_emb"].size(1), output["seq"].size(1))
        output["attn_weight"] = attn_weight
        return output

    def prepare_decoder_input(self, input_dict, output):
        """Assemble the decoder input for step ``input_dict["t"]``."""
        decoder_input = {
            "fc_emb": input_dict["fc_emb"],
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"]
        }
        t = input_dict["t"]

        ###############
        # determine input word
        ################
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:  # training, scheduled sampling
            word = input_dict["cap"][:, t]
        else:
            if t == 0:
                word = torch.tensor([self.start_idx,] * input_dict["fc_emb"].size(0)).long()
            else:
                word = output["seq"][:, t-1]
        # word: [N,]
        decoder_input["word"] = word.unsqueeze(1)

        ################
        # prepare rnn state
        ################
        if t > 0:
            decoder_input["state"] = output["state"]
        return decoder_input

    def stepwise_process_step(self, output, output_t):
        """Store the new RNN state and this step's attention weights."""
        super().stepwise_process_step(output, output_t)
        output["state"] = output_t["state"]
        t = output_t["t"]
        output["attn_weight"][:, :, t] = output_t["attn_weight"]

    def prepare_beamsearch_output(self, input_dict):
        """Extend the base beam-search output with an ``attn_weight`` tensor."""
        output = super().prepare_beamsearch_output(input_dict)
        beam_size = input_dict["beam_size"]
        max_length = input_dict["max_length"]
        output["attn_weight"] = torch.empty(beam_size,
                                            max(input_dict["attn_emb_len"]), max_length)
        return output

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Assemble the decoder input for one beam-search step of one sample.

        Embeddings are repeated ``beam_size`` times at ``t == 0`` and cached
        in ``output_i``; from ``t > 0`` the RNN state is re-ordered by
        ``output_i["prev_words_beam"]``.
        """
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        ###############
        # prepare fc embeds
        ################
        if t == 0:
            fc_emb = repeat_tensor(input_dict["fc_emb"][i], beam_size)
            output_i["fc_emb"] = fc_emb
        decoder_input["fc_emb"] = output_i["fc_emb"]

        ###############
        # prepare attn embeds
        ################
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_len"] = attn_emb_len
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_len"] = output_i["attn_emb_len"]

        ###############
        # determine input word
        ################
        if t == 0:
            word = torch.tensor([self.start_idx,] * beam_size).long()
        else:
            word = output_i["next_word"]
        decoder_input["word"] = word.unsqueeze(1)

        ################
        # prepare rnn state
        ################
        if t > 0:
            if self.decoder.rnn_type == "LSTM":
                decoder_input["state"] = (output_i["state"][0][:, output_i["prev_words_beam"], :].contiguous(),
                                          output_i["state"][1][:, output_i["prev_words_beam"], :].contiguous())
            else:
                decoder_input["state"] = output_i["state"][:, output_i["prev_words_beam"], :].contiguous()

        return decoder_input

    def beamsearch_process_step(self, output_i, output_t):
        """Store the new state and attention weights, re-ordered by beam."""
        t = output_t["t"]
        output_i["state"] = output_t["state"]
        output_i["attn_weight"][..., t] = output_t["attn_weight"]
        output_i["attn_weight"] = output_i["attn_weight"][output_i["prev_words_beam"], ...]

    def beamsearch_process(self, output, output_i, input_dict):
        """Copy the first beam's attention weights into the batch output."""
        super().beamsearch_process(output, output_i, input_dict)
        i = input_dict["sample_idx"]
        output["attn_weight"][i] = output_i["attn_weight"][0]

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        """Assemble the decoder input for one diverse-beam-search group step.

        ``divm`` is the group index and ``t - divm`` the group's local time.
        Per-group state is read from ``output_i["state"][divm]``, which is
        where ``dbs_process_step`` stores it.
        """
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        bdash = input_dict["bdash"]
        divm = input_dict["divm"]

        local_time = t - divm
        ###############
        # prepare fc embeds
        ################
        # repeat only at the first timestep to save consumption
        if t == 0:
            # Same layout as the plain beam-search path.  No extra
            # unsqueeze: the decoder already adds the step axis to the
            # projected fc embedding, and a second one breaks the concat.
            fc_emb = repeat_tensor(input_dict["fc_emb"][i], bdash)
            output_i["fc_emb"] = fc_emb
        decoder_input["fc_emb"] = output_i["fc_emb"]

        ###############
        # prepare attn embeds
        ################
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], bdash)
            attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], bdash)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_len"] = attn_emb_len
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_len"] = output_i["attn_emb_len"]

        ###############
        # determine input word
        ################
        if local_time == 0:
            word = torch.tensor([self.start_idx,] * bdash).long()
        else:
            word = output_i["next_word"][divm]
        decoder_input["word"] = word.unsqueeze(1)

        ################
        # prepare rnn state
        ################
        if local_time > 0:
            prev_beam = output_i["prev_words_beam"][divm]
            group_state = output_i["state"][divm]
            if self.decoder.rnn_type == "LSTM":
                # group_state is the (h, c) pair of this group, so index the
                # group first and the pair second.
                decoder_input["state"] = (
                    group_state[0][:, prev_beam, :].contiguous(),
                    group_state[1][:, prev_beam, :].contiguous()
                )
            else:
                decoder_input["state"] = group_state[
                    :, prev_beam, :].contiguous()

        return decoder_input

    def dbs_process_step(self, output_i, output_t):
        """Store the decoder state of group ``output_t["divm"]``."""
        divm = output_t["divm"]
        output_i["state"][divm] = output_t["state"]
        # TODO attention weight
class TemporalSeq2SeqAttnModel(Seq2SeqAttnModel):
    """Attention caption model that also feeds a temporal tag to the decoder."""

    def __init__(self, encoder, decoder, **kwargs):
        # Only install the default ``compatible_decoders`` tuple when nothing
        # has set that attribute on the instance yet.
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                TemporalBahAttnDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.train_forward_keys = ["cap", "cap_len", "ss_ratio", "temporal_tag"]
        self.inference_forward_keys = ["sample_method", "max_length", "temp", "temporal_tag"]

    def prepare_decoder_input(self, input_dict, output):
        """Add ``temporal_tag`` and ``t`` to the parent's decoder input."""
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["temporal_tag"] = input_dict["temporal_tag"]
        decoder_input["t"] = input_dict["t"]

        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Add the repeated ``temporal_tag`` and ``t`` for beam search."""
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        ###############
        # prepare temporal_tag
        ################
        if t == 0:
            temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], beam_size)
            output_i["temporal_tag"] = temporal_tag
        decoder_input["temporal_tag"] = output_i["temporal_tag"]
        decoder_input["t"] = input_dict["t"]

        return decoder_input

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        """Add the repeated ``temporal_tag`` and ``t`` for diverse beam search."""
        # ``super()`` must be called; the bare ``super`` type has no such
        # attribute and raised AttributeError.
        decoder_input = super().prepare_dbs_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        bdash = input_dict["bdash"]

        ###############
        # prepare temporal tag
        ################
        # repeat only at the first timestep to save consumption
        if t == 0:
            temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], bdash)
            output_i["temporal_tag"] = temporal_tag
        decoder_input["temporal_tag"] = output_i["temporal_tag"]
        decoder_input["t"] = input_dict["t"]

        return decoder_input
class Cnn8rnnSedModel(nn.Module):
    """Sound event detection model: four conv blocks, a linear layer and a bidirectional GRU.

    ``forward_prob`` returns per-segment and per-frame class probabilities;
    ``forward`` post-processes the per-frame probabilities with
    ``double_threshold`` and ``decode_with_timestamps``.
    """

    def __init__(self, classes_num):
        """Create the layers.

        Args:
            classes_num: output size of the final classification layer.
        """
        super().__init__()
        self.time_resolution = 0.01
        self.interpolate_ratio = 4  # Downsampled ratio
        # Applied over the mel axis in ``forward_prob``, so inputs need 64 mel bins.
        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.fc1 = nn.Linear(512, 512, bias=True)
        self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True)
        self.fc_audioset = nn.Linear(512, classes_num, bias=True)

    def forward(self, lms):
        """Return the decoded tags for a batch of log-mel spectrograms.

        Args:
            lms: tensor of shape (batch_size, mel_bins, time_steps).

        Returns:
            Whatever ``decode_with_timestamps`` produces from the thresholded
            frame-wise predictions and ``self.time_resolution``.
        """
        output = self.forward_prob(lms)
        # detach() first: calling numpy() on a tensor that requires grad
        # raises RuntimeError when the model is run outside torch.no_grad().
        framewise_output = output["framewise_output"].detach().cpu().numpy()
        # 0.75 / 0.25 are presumably the high / low thresholds -- confirm
        # against double_threshold's signature.
        thresholded_predictions = double_threshold(
            framewise_output, 0.75, 0.25)
        decoded_tags = decode_with_timestamps(
            thresholded_predictions, self.time_resolution
        )
        return decoded_tags

    def forward_prob(self, lms):
        """Compute segment-wise and frame-wise class probabilities.

        Args:
            lms: tensor of shape (batch_size, mel_bins, time_steps).

        Returns:
            Dict with ``"segmentwise_output"`` (sigmoid outputs clamped to
            [1e-7, 1]) and ``"framewise_output"`` (the former passed through
            ``interpolate`` with ``self.interpolate_ratio`` and then
            ``pad_framewise_output`` with the input frame count).
        """
        x = lms
        x = x.transpose(1, 2)  # (batch_size, time_steps, mel_bins)
        x = x.unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)
        frames_num = x.shape[2]

        # Move mel bins to the channel axis so bn0 normalises per mel bin.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg+max')
        # Assuming ConvBlock pools by ``pool_size``:
        # (batch_size, 512, time_steps / 4, mel_bins / 16)
        x = F.dropout(x, p=0.2, training=self.training)

        x = torch.mean(x, dim=3)  # average over the last (mel) axis
        x = x.transpose(1, 2)  # channels last, as expected by fc1 / rnn
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        x, _ = self.rnn(x)

        segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.)
        framewise_output = interpolate(segmentwise_output,
                                       self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        output_dict = {
            "segmentwise_output": segmentwise_output,
            'framewise_output': framewise_output,
        }
        return output_dict
class Cnn14RnnTempAttnGruConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters read by ``Cnn14RnnTempAttnGruModel``.

    Each named constructor argument is stored as an attribute of the same
    name; any remaining keyword arguments are passed on to
    ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        sample_rate: int = 32000,
        encoder_rnn_bidirectional: bool = True,
        encoder_rnn_hidden_size: int = 256,
        encoder_rnn_dropout: float = 0.5,
        encoder_rnn_num_layers: int = 3,
        decoder_emb_dim: int = 512,
        vocab_size: int = 4981,
        fc_emb_dim: int = 512,
        attn_emb_dim: int = 512,
        decoder_rnn_type: str = "GRU",
        decoder_num_layers: int = 1,
        decoder_d_model: int = 512,
        decoder_dropout: float = 0.5,
        **kwargs
    ):
        # Insertion order mirrors the signature so attributes are created in
        # the same order as the arguments are declared.
        hyper_params = {
            "sample_rate": sample_rate,
            "encoder_rnn_bidirectional": encoder_rnn_bidirectional,
            "encoder_rnn_hidden_size": encoder_rnn_hidden_size,
            "encoder_rnn_dropout": encoder_rnn_dropout,
            "encoder_rnn_num_layers": encoder_rnn_num_layers,
            "decoder_emb_dim": decoder_emb_dim,
            "vocab_size": vocab_size,
            "fc_emb_dim": fc_emb_dim,
            "attn_emb_dim": attn_emb_dim,
            "decoder_rnn_type": decoder_rnn_type,
            "decoder_num_layers": decoder_num_layers,
            "decoder_d_model": decoder_d_model,
            "decoder_dropout": decoder_dropout,
        }
        for attr_name, attr_value in hyper_params.items():
            setattr(self, attr_name, attr_value)
        # The base class consumes the remaining generic options.
        super().__init__(**kwargs)
class Cnn14RnnTempAttnGruModel(PreTrainedModel):
    """Captioning model combining a mel front-end, a SED tagger and a captioner.

    ``forward`` turns waveforms into log-mel spectrograms, derives a temporal
    tag from ``Cnn8rnnSedModel`` and feeds both to a
    ``TemporalSeq2SeqAttnModel`` run in inference mode.
    """

    config_class = Cnn14RnnTempAttnGruConfig

    def __init__(self, config):
        """Instantiate the feature extractors and both sub-models from ``config``."""
        super().__init__(config)
        sr = config.sample_rate

        # Window and hop expressed in samples: 32 ms and 10 ms of audio.
        window_samples = 32 * sr // 1000
        hop_samples = 10 * sr // 1000
        # Only these sample rates have an f_max entry; others raise KeyError.
        fmax_by_rate = {32000: 14000, 16000: 8000}
        upper_freq = fmax_by_rate[sr]

        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sr,
            n_fft=window_samples,
            win_length=window_samples,
            hop_length=hop_samples,
            f_min=50,
            f_max=upper_freq,
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.db_transform = transforms.AmplitudeToDB()

        audio_encoder = Cnn14RnnEncoder(
            sample_rate=config.sample_rate,
            rnn_bidirectional=config.encoder_rnn_bidirectional,
            rnn_hidden_size=config.encoder_rnn_hidden_size,
            rnn_dropout=config.encoder_rnn_dropout,
            rnn_num_layers=config.encoder_rnn_num_layers
        )
        text_decoder = TemporalBahAttnDecoder(
            emb_dim=config.decoder_emb_dim,
            vocab_size=config.vocab_size,
            fc_emb_dim=config.fc_emb_dim,
            attn_emb_dim=config.attn_emb_dim,
            rnn_type=config.decoder_rnn_type,
            num_layers=config.decoder_num_layers,
            d_model=config.decoder_d_model,
            dropout=config.decoder_dropout,
        )
        # Captioner is built and registered before the SED model.
        self.cap_model = TemporalSeq2SeqAttnModel(audio_encoder, text_decoder)
        self.sed_model = Cnn8rnnSedModel(classes_num=447)

    def forward(self,
                audio: torch.Tensor,
                audio_length: Union[List, np.ndarray, torch.Tensor],
                temporal_tag: Union[List, np.ndarray, torch.Tensor] = None,
                sample_method: str = "beam",
                beam_size: int = 3,
                max_length: int = 20,
                temp: float = 1.0,):
        """Generate caption token sequences for a batch of waveforms.

        Args:
            audio: waveform tensor; moved to the model's device.
            audio_length: passed to the captioner as ``"wav_len"``.
            temporal_tag: optional tag; when given, the element-wise minimum
                of it and the SED-derived tag is used, otherwise the
                SED-derived tag alone.
            sample_method: decoding strategy name forwarded to the captioner.
            beam_size: forwarded only when ``sample_method == "beam"``.
            max_length: forwarded to the captioner as ``"max_length"``.
            temp: forwarded to the captioner as ``"temp"``.

        Returns:
            The captioner's ``"seq"`` output moved to the CPU.
        """
        device = self.device
        log_mel_spec = self.db_transform(self.melspec_extractor(audio.to(device)))

        sed_tag = torch.as_tensor(self.sed_model(log_mel_spec)).to(device)
        if temporal_tag is None:
            temporal_tag = sed_tag
        else:
            given_tag = torch.as_tensor(temporal_tag).to(device)
            stacked = torch.stack([given_tag, sed_tag], dim=0)
            temporal_tag = stacked.min(dim=0).values

        input_dict = {
            "lms": log_mel_spec,
            "wav_len": audio_length,
            "temporal_tag": temporal_tag,
            "mode": "inference",
            "sample_method": sample_method,
            "max_length": max_length,
            "temp": temp,
        }
        if sample_method == "beam":
            input_dict["beam_size"] = beam_size

        captions = self.cap_model(input_dict)
        return captions["seq"].cpu()