| | import collections.abc |
| | import math |
| | import sys |
| | from itertools import repeat |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import timm |
| | import torch |
| | from torch import nn |
| | from torchvision.models.vision_transformer import Encoder |
| |
|
| |
|
| | from typing import Tuple |
| | from functools import partial |
| | from collections.abc import Iterable |
| |
|
| |
|
def plot_fbank(fbank, title=None, save_path=None, **kwargs):
    """Plot up to the first four channels of a filter bank as stacked images.

    Args:
        fbank: Array-like with a ``shape`` attribute, indexed as
            ``fbank[channel]``; each channel must expose ``.T``.
            Presumably laid out as (channel, time, mel) given the axis
            labels used below -- confirm against callers.
        title: Text appended to every subplot title.
        save_path: If truthy, the figure is saved to this path.
        **kwargs: Only ``vmin`` and ``vmax`` are read; both are forwarded to
            ``imshow`` so every channel uses the same color limits.

    Returns:
        The matplotlib ``Figure`` that was drawn. It has been closed in
        pyplot by the time it is returned.
    """
    n_channels = min(4, fbank.shape[0])
    fig, axs = plt.subplots(n_channels, 1, sharex=True, sharey=True)
    # For a single subplot plt.subplots returns a bare Axes, not an array.
    if not isinstance(axs, Iterable):
        axs = np.array([axs])
    vmin, vmax = kwargs.get("vmin"), kwargs.get("vmax")

    for channel in range(n_channels):
        ax = axs[channel]
        ax.set_title(f"Filter bank channel {channel}, {title}")
        im = ax.imshow(fbank[channel].T, aspect="auto", vmin=vmin, vmax=vmax)
        ax.set_ylabel("mel")
        ax.set_xlabel("time")

    # The y axes are shared, so flipping one subplot flips all of them.
    # Address the axes explicitly instead of relying on pyplot's "current" one.
    axs[-1].invert_yaxis()
    fig.tight_layout()
    fig.colorbar(im, ax=axs.ravel().tolist())
    plt.show()
    if save_path:
        fig.savefig(save_path)
    # Close this specific figure; a bare plt.close() closes whichever figure
    # happens to be current, which may no longer be ours after plt.show().
    plt.close(fig)
    return fig
| |
|
| |
|
| | |
| | def _ntuple(n): |
| | def parse(x): |
| | |
| | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): |
| | return tuple(x) |
| | |
| | return tuple(repeat(x, n)) |
| |
|
| | return parse |
| |
|
| |
|
class PatchEmbed(nn.Module):
    """Embed an image as a sequence of non-overlapping patch vectors.

    A single strided convolution (kernel size == stride == patch size)
    projects every patch to ``embed_dim`` channels.

    Attributes are plain values set in ``__init__``:
        img_size: 2-tuple produced from the ``img_size`` argument.
        patch_size: 2-tuple produced from the ``patch_size`` argument.
        num_patches: Product of the per-axis ``img_size // patch_size``.
        proj: The ``nn.Conv2d`` performing the projection.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        as_pair = _ntuple(2)
        self.img_size = as_pair(img_size)
        self.patch_size = as_pair(patch_size)

        # Number of whole patches that fit along each of the two axes.
        patches_axis0 = self.img_size[0] // self.patch_size[0]
        patches_axis1 = self.img_size[1] // self.patch_size[1]
        self.num_patches = patches_axis1 * patches_axis0

        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        """Project ``x`` with the conv and return (batch, patches, embed_dim)."""
        feature_map = self.proj(x)
        # Collapse everything from dim 2 onward into one patch axis, then
        # move the channel axis last.
        patch_sequence = feature_map.flatten(2)
        return patch_sequence.transpose(1, 2)
| |
|
| |
|
def get_sinusoid_encoding(n_position, d_hid):
    """Build a fixed sinusoidal position-encoding table.

    Entry ``(pos, j)`` starts as ``pos / 10000 ** (2 * (j // 2) / d_hid)``;
    even channels then take its sine and odd channels its cosine.

    Args:
        n_position: Number of positions (rows) in the table. May be 0, in
            which case an empty table is returned.
        d_hid: Number of channels per position.

    Returns:
        A float32 tensor of shape (1, n_position, d_hid).
    """
    # The divisor depends only on the channel index, so compute it once per
    # channel instead of once per (position, channel) pair.
    divisors = np.array(
        [np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)],
        dtype=np.float64,
    )
    positions = np.arange(n_position, dtype=np.float64)[:, None]

    # Broadcasting yields a 2-D (n_position, d_hid) table even when either
    # size is 0, so the slicing below never sees a collapsed 1-D array.
    sinusoid_table = positions / divisors
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])

    return torch.from_numpy(sinusoid_table).float().unsqueeze(0)
| |
|
| |
|
def create_pretrained_model(
    model_size,
    encoder_num_layers=12,
    encoder_num_heads=12,
    encoder_hidden_dim=768,
    encoder_mlp_dim=3072,
    encoder_dropout=0.0,
    encoder_attention_dropout=0.0,
    encoder_norm_layer_eps=1e-6,
):
    """Instantiate the backbone for ``model_size`` and report its width.

    Despite the function name, no pretrained weights are loaded here: the
    timm models are created with ``pretrained=False`` and the ``"base"``
    encoder is freshly constructed.

    Args:
        model_size: One of ``"tiny"``, ``"small"``, ``"base"`` or
            ``"base_nokd"``.
        encoder_num_layers: ``num_layers`` of the ``"base"`` Encoder.
        encoder_num_heads: ``num_heads`` of the ``"base"`` Encoder.
        encoder_hidden_dim: ``hidden_dim`` of the ``"base"`` Encoder; also
            the width returned for that size.
        encoder_mlp_dim: ``mlp_dim`` of the ``"base"`` Encoder.
        encoder_dropout: ``dropout`` of the ``"base"`` Encoder.
        encoder_attention_dropout: ``attention_dropout`` of the ``"base"``
            Encoder.
        encoder_norm_layer_eps: ``eps`` of the LayerNorm used by the
            ``"base"`` Encoder.

        The ``encoder_*`` arguments are ignored for the timm-backed sizes.

    Returns:
        A ``(model, hidden_dim)`` tuple.

    Raises:
        ValueError: If ``model_size`` is not one of the supported names.
    """
    # timm architecture name and embedding width for each timm-backed size.
    # DeiT-tiny uses a 192-wide embedding (the previous value 182 was a typo).
    timm_variants = {
        "tiny": ("deit_tiny_distilled_patch16_224", 192),
        "small": ("deit_small_distilled_patch16_224", 384),
        "base_nokd": ("deit_base_patch16_384", 768),
    }

    if model_size in timm_variants:
        arch_name, hidden_dim = timm_variants[model_size]
        return timm.create_model(arch_name, pretrained=False), hidden_dim

    if model_size == "base":
        v = Encoder(
            seq_length=0,
            num_layers=encoder_num_layers,
            num_heads=encoder_num_heads,
            hidden_dim=encoder_hidden_dim,
            mlp_dim=encoder_mlp_dim,
            dropout=encoder_dropout,
            attention_dropout=encoder_attention_dropout,
            norm_layer=partial(nn.LayerNorm, eps=encoder_norm_layer_eps),
        )
        return v, encoder_hidden_dim

    # Raise instead of print + sys.exit(0): exiting with status 0 reports
    # success to the shell and gives callers no chance to recover.
    raise ValueError(
        f"Wrong model size: {model_size!r}; expected one of "
        "'tiny', 'small', 'base', 'base_nokd'."
    )
| |
|
| |
|
| | def _trunc_normal_(tensor, mean, std, a, b): |
| | |
| | |
| | def norm_cdf(x): |
| | |
| | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 |
| |
|
| | |
| | |
| | |
| | left = norm_cdf((a - mean) / std) |
| | up = norm_cdf((b - mean) / std) |
| |
|
| | |
| | |
| | tensor.uniform_(2 * left - 1, 2 * up - 1) |
| |
|
| | |
| | |
| | tensor.erfinv_() |
| |
|
| | |
| | tensor.mul_(std * math.sqrt(2.0)) |
| | tensor.add_(mean) |
| |
|
| | |
| | tensor.clamp_(min=a, max=b) |
| | return tensor |
| |
|
| |
|
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place from a truncated normal distribution.

    Values follow a normal distribution with the given ``mean`` and ``std``,
    restricted to the interval ``[a, b]``. The sampling method works best
    when ``a <= mean <= b``.

    NOTE: the bounds ``a`` and ``b`` are applied to the already shifted and
    scaled distribution, so they must be expressed in the same range as
    ``mean`` and ``std`` (they are not measured in standard deviations).

    The fill runs with gradient tracking disabled.

    Args:
        tensor: an n-dimensional ``torch.Tensor`` to overwrite
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, now filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled
| |
|
| |
|
def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Broadcast a per-token index across the feature axis of ``tokens``.

    Args:
        index:
            Index tensor with shape (batch_size, idx_length).
        tokens:
            Tokens tensor; only the size of its last dimension is used.

    Returns:
        A view of ``index`` with shape (batch_size, idx_length, dim), where
        dim is the last dimension of ``tokens`` and every index value is
        repeated along that new axis.

    """
    feature_dim = tokens.size(-1)
    with_feature_axis = torch.unsqueeze(index, dim=-1)
    return with_feature_axis.expand(-1, -1, feature_dim)
| | |
def set_at_index(
    tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    """Return a copy of ``tokens`` with selected positions overwritten.

    Args:
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).
        index:
            Index tensor with shape (batch_size, index_length) selecting
            positions along the sequence axis.
        value:
            Value tensor with shape (batch_size, index_length, dim) holding
            the replacement vectors.

    Returns:
        A new tensor shaped like ``tokens`` with ``value`` written at the
        indexed sequence positions; ``tokens`` itself is not modified.

    """
    # Repeat each index over the feature axis so that whole token vectors
    # are scattered, not single elements.
    scatter_index = index.unsqueeze(-1).expand(-1, -1, tokens.shape[-1])
    return tokens.scatter(1, scatter_index, value)
| |
|
| |
|
| |
|
| |
|
def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Tile a single token over a batch and a sequence axis.

    Args:
        token:
            Token tensor with shape (1, 1, dim).
        size:
            (batch_size, sequence_length) tuple giving the repeat counts
            for the first two axes.

    Returns:
        Tensor with shape (batch_size, sequence_length, dim) in which every
        position holds a copy of the input token.

    """
    n_batch, n_sequence = size
    # The feature axis is repeated once, i.e. left unchanged.
    repeats = (n_batch, n_sequence, 1)
    return token.repeat(repeats)