diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..9d6bb2c6c165eabaf24604396cac0d4ab2a2224f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +Assets/ExampleInput/music.wav filter=lfs diff=lfs merge=lfs -text +Assets/ExampleInput/soundeffects.wav filter=lfs diff=lfs merge=lfs -text +Assets/ExampleInput/speech.wav filter=lfs diff=lfs merge=lfs -text +Assets/Figure.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f5ce32c448f943505687920a52cb9d566c82117f --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +*.pyo +.eggs/ +*.egg-info/ +dist/ +build/ +.env diff --git a/Assets/ExampleInput/music.wav b/Assets/ExampleInput/music.wav new file mode 100644 index 0000000000000000000000000000000000000000..5088ffd7e19bd52d042e59f4eab43826d213e3f5 --- /dev/null +++ b/Assets/ExampleInput/music.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57031bac5cdc14b2be1c50811eca257070bf1fa9fb98d2aebeb7eda86c67ceaa +size 491564 diff --git a/Assets/ExampleInput/soundeffects.wav b/Assets/ExampleInput/soundeffects.wav new file mode 100644 index 0000000000000000000000000000000000000000..5627ac6ec5c098d1f0d7f66af10dcd8247220851 --- /dev/null +++ b/Assets/ExampleInput/soundeffects.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96b31750ddb93397789224ddc047ee21cb34760f536818e9ea4cd2328d8dfe69 +size 491564 diff --git a/Assets/ExampleInput/speech.wav b/Assets/ExampleInput/speech.wav new file mode 100644 index 0000000000000000000000000000000000000000..09523e201abdfdb851f26da8ff2849c90b47e50e --- /dev/null +++ b/Assets/ExampleInput/speech.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:046907c4a7a5bcb9f1b1bb49623fdc4d01be215d6c2783ff5157150938ed21a2 +size 491564 diff --git a/Assets/Figure.png b/Assets/Figure.png new file mode 100644 index 0000000000000000000000000000000000000000..c532d9c7ba9e3d6cdea158dc178541d127b5c203 --- /dev/null +++ b/Assets/Figure.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70639d9a3d6fd61ac73b97104cc082e816a81b4b211ac9092cfee557d8aff6c7 +size 1090177 diff --git a/FlashSR/AudioSR/AudioSRUnet.py b/FlashSR/AudioSR/AudioSRUnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee04b5f0187223fd16439f4e5dbacbec43027aa --- /dev/null +++ b/FlashSR/AudioSR/AudioSRUnet.py @@ -0,0 +1,1127 @@ +''' +from TorchJaekwon.Util.Util import Util +Util.set_sys_path_to_parent_dir(__file__,3) +import sys, os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +''' +################################################################################ +#just copy from audiosr/latent_diffusion/modules/diffusionmodules/openaimodel.py +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from FlashSR.AudioSR.latent_diffusion.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context_list=None, mask_list=None): + # The first spatial transformer block does not have context + spatial_transformer_id = 0 + context_list = [None] + context_list + mask_list = [None] + mask_list + + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + if spatial_transformer_id >= len(context_list): + context, mask = None, None + else: + context, mask = ( + context_list[spatial_transformer_id], + mask_list[spatial_transformer_id], + ) + + x = layer(x, context, mask=mask) + spatial_transformer_id += 1 + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class AudioSRUnet(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size:int = 64, + in_channels:int = 32, + model_channels:int = 128, + out_channels:int = 16, + num_res_blocks:int = 2, + attention_resolutions:list = [8, 4, 2], + dropout=0, + channel_mult=[1, 2, 3, 5], + conv_resample=True, + dims=2, + extra_sa_layer=True, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=32, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=True, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + # assert not ( + # self.num_classes is not None and self.extra_film_condition_dim is not None + # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = self.extra_film_condition_dim is not None + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + print( + "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " + % self.extra_film_condition_dim + ) + + if context_dim is not None and not use_spatial_transformer: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] + elif context_dim is None: + context_dim = [None] # At least use one spatial transformer + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + middle_layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + if extra_sa_layer: + middle_layers.append( + SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=None + ) + ) + for context_dim_id in range(len(context_dim)): + middle_layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + middle_layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + self.middle_block = TimestepEmbedSequential(*middle_layers) + + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward( + self, + x, + timesteps=None, + y=None, + context_list=list(), + context_attn_mask_list=list(), + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. + """ + x = th.concat([x,y], dim=1) #jakeoneijk added + y = None + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + # if self.num_classes is not None: + # assert y.shape == (x.shape[0],) + # emb = emb + self.label_emb(y) + + if self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context_list, context_attn_mask_list) + hs.append(h) + h = self.middle_block(h, emb, context_list, context_attn_mask_list) + for module in self.output_blocks: + concate_tensor = hs.pop() + h = th.cat([h, concate_tensor], dim=1) + h = module(h, emb, context_list, context_attn_mask_list) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + +if __name__ == '__main__': + ''' + + args = { + 'in_channels': 2, + 'model_channels': 64, + 'out_channels': 1 + } + audio_sr = AudioSRUnet(**args) + audio_sr(x = th.randn(1, 1, 128, 256), timesteps = th.tensor([30]), y = th.randn(1, 1, 128, 256)) #jakeoneijk added +''' + audio_sr = AudioSRUnet() + audio_sr(x = th.randn(1, 16, 64, 32), timesteps = th.tensor([30]), y = th.randn(1, 16, 64, 32)) #jakeoneijk added \ No newline at end of file diff --git a/FlashSR/AudioSR/EncoderDecoder.py b/FlashSR/AudioSR/EncoderDecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..aa54efac86ba56bccd4f43052411910c306cc78d --- /dev/null +++ b/FlashSR/AudioSR/EncoderDecoder.py @@ -0,0 +1,1010 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print( + # "Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape) + # ) + # ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x \ No newline at end of file diff --git a/FlashSR/AudioSR/Vocoder.py b/FlashSR/AudioSR/Vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac265761a9522c60afe3fdbece63b05f0c42484 --- /dev/null +++ b/FlashSR/AudioSR/Vocoder.py @@ -0,0 +1,167 @@ +import torch + +import FlashSR.AudioSR.hifigan as hifigan + + +def get_vocoder_config(): + return { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, + } + + +def get_vocoder_config_48k(): + return { + "resblock": "1", + "num_gpus": 8, + "batch_size": 128, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [6, 5, 4, 2, 2], + "upsample_kernel_sizes": [12, 10, 8, 4, 4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3, 7, 11, 15], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 15360, + "num_mels": 256, + "n_fft": 2048, + "hop_size": 480, + "win_size": 2048, + "sampling_rate": 48000, + "fmin": 20, + "fmax": 24000, + "fmax_for_loss": None, + "num_workers": 8, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:18273", + "world_size": 1, + }, + } + + +def get_available_checkpoint_keys(model, ckpt): + state_dict = torch.load(ckpt)["state_dict"] + current_state_dict = model.state_dict() + new_state_dict = {} + for k in state_dict.keys(): + if ( + k in current_state_dict.keys() + and current_state_dict[k].size() == state_dict[k].size() + ): + new_state_dict[k] = state_dict[k] + else: + print("==> WARNING: Skipping %s" % k) + print( + "%s out of %s keys are matched" + % (len(new_state_dict.keys()), len(state_dict.keys())) + ) + return new_state_dict + + +def get_param_num(model): + num_param = sum(param.numel() for param in model.parameters()) + return num_param + + +def torch_version_orig_mod_remove(state_dict): + new_state_dict = {} + new_state_dict["generator"] = {} + for key in state_dict["generator"].keys(): + if "_orig_mod." in key: + new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ + "generator" + ][key] + else: + new_state_dict["generator"][key] = state_dict["generator"][key] + return new_state_dict + + +def get_vocoder(config, device, mel_bins): + name = "HiFi-GAN" + speaker = "" + if name == "MelGAN": + if speaker == "LJSpeech": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" + ) + elif speaker == "universal": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" + ) + vocoder.mel2wav.eval() + vocoder.mel2wav.to(device) + elif name == "HiFi-GAN": + if mel_bins == 64: + config = get_vocoder_config() + config = hifigan.AttrDict(config) + vocoder = hifigan.Generator_old(config) + # print("Load hifigan/g_01080000") + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) + # ckpt = torch_version_orig_mod_remove(ckpt) + # vocoder.load_state_dict(ckpt["generator"]) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + else: + config = get_vocoder_config_48k() + config = hifigan.AttrDict(config) + vocoder = hifigan.Generator_old(config) + # print("Load hifigan/g_01080000") + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) + # ckpt = torch_version_orig_mod_remove(ckpt) + # vocoder.load_state_dict(ckpt["generator"]) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + with torch.no_grad(): + wavs = vocoder(mels).squeeze(1) + + wavs = (wavs.cpu().numpy() * 32768).astype("int16") + + if lengths is not None: + wavs = wavs[:, :lengths] + + # wavs = [wav for wav in wavs] + + # for i in range(len(mels)): + # if lengths is not None: + # wavs[i] = wavs[i][: lengths[i]] + + return wavs diff --git a/FlashSR/AudioSR/args/mel_argument.yaml b/FlashSR/AudioSR/args/mel_argument.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a50510087283d519c6fe59873cbfeedc3ac25b8f --- /dev/null +++ b/FlashSR/AudioSR/args/mel_argument.yaml @@ -0,0 +1,6 @@ +nfft: 2048 +hop_size: 480 +sample_rate: 48000 +mel_size: 256 +frequency_min: 20 +frequency_max: 24000 \ No newline at end of file diff --git a/FlashSR/AudioSR/args/model_argument.yaml b/FlashSR/AudioSR/args/model_argument.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a83d6780d29d3017995eeabfaa5973ae0e9bdae --- /dev/null +++ b/FlashSR/AudioSR/args/model_argument.yaml @@ -0,0 +1,25 @@ +batchsize: 4 +ddconfig: + attn_resolutions: [] + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 8 + double_z: true + downsample_time: false + dropout: 0.1 + in_channels: 1 + mel_bins: 256 + num_res_blocks: 2 + out_ch: 1 + resolution: 256 + z_channels: 16 +embed_dim: 16 +image_key: fbank +monitor: val/rec_loss +reload_from_ckpt: /mnt/bn/lqhaoheliu/project/audio_generation_diffusion/log/vae/vae_48k_256/ds_8_kl_1/checkpoints/ckpt-checkpoint-484999.ckpt +sampling_rate: 48000 +subband: 1 +time_shuffle: 1 diff --git a/FlashSR/AudioSR/autoencoder.py b/FlashSR/AudioSR/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9c502fb76f1818427736ed62d3287de3beb6d6 --- /dev/null +++ b/FlashSR/AudioSR/autoencoder.py @@ -0,0 +1,370 @@ +import os +import soundfile as sf +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from FlashSR.AudioSR.EncoderDecoder import Encoder, Decoder +from FlashSR.AudioSR.Vocoder import get_vocoder + + +class AutoencoderKL(nn.Module): + def __init__( + self, + ddconfig=None, + lossconfig=None, + batchsize=None, + embed_dim=None, + time_shuffle=1, + subband=1, + sampling_rate=16000, + ckpt_path=None, + reload_from_ckpt=None, + ignore_keys=[], + image_key="fbank", + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + ): + super().__init__() + self.automatic_optimization = False + assert ( + "mel_bins" in ddconfig.keys() + ), "mel_bins is not specified in the Autoencoder config" + num_mel = ddconfig["mel_bins"] + self.image_key = image_key + self.sampling_rate = sampling_rate + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + self.loss = None + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + if self.image_key == "fbank": + self.vocoder = get_vocoder(None, "cpu", num_mel) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.learning_rate = float(base_learning_rate) + # print("Initial learning rate %s" % self.learning_rate) + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + self.feature_cache = None + self.flag_first_run = True + self.train_step = 0 + + self.logger_save_dir = None + self.logger_exp_name = None + + def get_log_dir(self): + if self.logger_save_dir is None and self.logger_exp_name is None: + return os.path.join(self.logger.save_dir, self.logger._project) + else: + return os.path.join(self.logger_save_dir, self.logger_exp_name) + + def set_log_dir(self, save_dir, exp_name): + self.logger_save_dir = save_dir + self.logger_exp_name = exp_name + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + # x = self.time_shuffle_operation(x) + # x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + # bs, ch, shuffled_timesteps, fbins = dec.size() + # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins) + # dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + from audiosr.utilities.model import vocoder_infer + + if self.image_key == "fbank": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_infer(dec, self.vocoder) + elif self.image_key == "stft": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = self.wave_decoder(dec) + return wav_reconstruction + + def visualize_latent(self, input): + import matplotlib.pyplot as plt + + # for i in range(10): + # zero_input = torch.zeros_like(input) - 11.59 + # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59 + + # posterior = self.encode(zero_input) + # latent = posterior.sample() + # avg_latent = torch.mean(latent, dim=1)[0] + # plt.imshow(avg_latent.cpu().detach().numpy().T) + # plt.savefig("%s.png" % i) + # plt.close() + + np.save("input.npy", input.cpu().detach().numpy()) + # zero_input = torch.zeros_like(input) - 11.59 + time_input = input.clone() + time_input[:, :, :, :32] *= 0 + time_input[:, :, :, :32] -= 11.59 + + np.save("time_input.npy", time_input.cpu().detach().numpy()) + + posterior = self.encode(time_input) + latent = posterior.sample() + np.save("time_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("freq_%s.png" % i) + plt.close() + + freq_input = input.clone() + freq_input[:, :, :512, :] *= 0 + freq_input[:, :, :512, :] -= 11.59 + + np.save("freq_input.npy", freq_input.cpu().detach().numpy()) + + posterior = self.encode(freq_input) + latent = posterior.sample() + np.save("freq_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("time_%s.png" % i) + plt.close() + + def get_input(self, batch): + fname, text, label_indices, waveform, stft, fbank = ( + batch["fname"], + batch["text"], + batch["label_vector"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + ) + # if(self.time_shuffle != 1): + # if(fbank.size(1) % self.time_shuffle != 0): + # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle) + # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len)) + + ret = {} + + ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = ( + fbank.unsqueeze(1), + stft.unsqueeze(1), + fname, + waveform.unsqueeze(1), + ) + + return ret + + def save_wave(self, batch_wav, fname, save_dir): + os.makedirs(save_dir, exist_ok=True) + + for wav, name in zip(batch_wav, fname): + name = os.path.basename(name) + + sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate) + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs): + log = dict() + x = batch.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + log["samples"] = self.decode(posterior.sample()) + log["reconstructions"] = xrec + + log["inputs"] = x + wavs = self._log_img(log, train=train, index=0, waveform=waveform) + return wavs + + def _log_img(self, log, train=True, index=0, waveform=None): + images_input = self.tensor2numpy(log["inputs"][index, 0]).T + images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T + images_samples = self.tensor2numpy(log["samples"][index, 0]).T + + if train: + name = "train" + else: + name = "val" + + if self.logger is not None: + self.logger.log_image( + "img_%s" % name, + [images_input, images_reconstruct, images_samples], + caption=["input", "reconstruct", "samples"], + ) + + inputs, reconstructions, samples = ( + log["inputs"], + log["reconstructions"], + log["samples"], + ) + + if self.image_key == "fbank": + wav_original, wav_prediction = synth_one_sample( + inputs[index], + reconstructions[index], + labels="validation", + vocoder=self.vocoder, + ) + wav_original, wav_samples = synth_one_sample( + inputs[index], samples[index], labels="validation", vocoder=self.vocoder + ) + wav_original, wav_samples, wav_prediction = ( + wav_original[0], + wav_samples[0], + wav_prediction[0], + ) + elif self.image_key == "stft": + wav_prediction = ( + self.decode_to_waveform(reconstructions)[index, 0] + .cpu() + .detach() + .numpy() + ) + wav_samples = ( + self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy() + ) + wav_original = waveform[index, 0].cpu().detach().numpy() + + if self.logger is not None: + self.logger.experiment.log( + { + "original_%s" + % name: wandb.Audio( + wav_original, caption="original", sample_rate=self.sampling_rate + ), + "reconstruct_%s" + % name: wandb.Audio( + wav_prediction, + caption="reconstruct", + sample_rate=self.sampling_rate, + ), + "samples_%s" + % name: wandb.Audio( + wav_samples, caption="samples", sample_rate=self.sampling_rate + ), + } + ) + + return wav_original, wav_prediction, wav_samples + + def tensor2numpy(self, tensor): + return tensor.cpu().detach().numpy() + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean \ No newline at end of file diff --git a/FlashSR/AudioSR/hifigan/LICENSE b/FlashSR/AudioSR/hifigan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..5afae394d6b37da0e12ba6b290d2512687f421ac --- /dev/null +++ b/FlashSR/AudioSR/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/FlashSR/AudioSR/hifigan/__init__.py b/FlashSR/AudioSR/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34e055557bf2ecb457376663b67390543c71fb1f --- /dev/null +++ b/FlashSR/AudioSR/hifigan/__init__.py @@ -0,0 +1,8 @@ +from .models_v2 import Generator +from .models import Generator as Generator_old + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/FlashSR/AudioSR/hifigan/models.py b/FlashSR/AudioSR/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36 --- /dev/null +++ b/FlashSR/AudioSR/hifigan/models.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/FlashSR/AudioSR/hifigan/models_v2.py b/FlashSR/AudioSR/hifigan/models_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..27a2df6b54bdd3a5b259645442624800ac0e8afe --- /dev/null +++ b/FlashSR/AudioSR/hifigan/models_v2.py @@ -0,0 +1,395 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + u * 2, + u, + padding=u // 2 + u % 2, + output_padding=u % 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + # import ipdb; ipdb.set_trace() + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +################################################################################################## + +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from torch.nn import Conv1d, ConvTranspose1d +# from torch.nn.utils import weight_norm, remove_weight_norm + +# LRELU_SLOPE = 0.1 + + +# def init_weights(m, mean=0.0, std=0.01): +# classname = m.__class__.__name__ +# if classname.find("Conv") != -1: +# m.weight.data.normal_(mean, std) + + +# def get_padding(kernel_size, dilation=1): +# return int((kernel_size * dilation - dilation) / 2) + + +# class ResBlock(torch.nn.Module): +# def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): +# super(ResBlock, self).__init__() +# self.h = h +# self.convs1 = nn.ModuleList( +# [ +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[0], +# padding=get_padding(kernel_size, dilation[0]), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[1], +# padding=get_padding(kernel_size, dilation[1]), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=dilation[2], +# padding=get_padding(kernel_size, dilation[2]), +# ) +# ), +# ] +# ) +# self.convs1.apply(init_weights) + +# self.convs2 = nn.ModuleList( +# [ +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# weight_norm( +# Conv1d( +# channels, +# channels, +# kernel_size, +# 1, +# dilation=1, +# padding=get_padding(kernel_size, 1), +# ) +# ), +# ] +# ) +# self.convs2.apply(init_weights) + +# def forward(self, x): +# for c1, c2 in zip(self.convs1, self.convs2): +# xt = F.leaky_relu(x, LRELU_SLOPE) +# xt = c1(xt) +# xt = F.leaky_relu(xt, LRELU_SLOPE) +# xt = c2(xt) +# x = xt + x +# return x + +# def remove_weight_norm(self): +# for l in self.convs1: +# remove_weight_norm(l) +# for l in self.convs2: +# remove_weight_norm(l) + +# class Generator(torch.nn.Module): +# def __init__(self, h): +# super(Generator, self).__init__() +# self.h = h +# self.num_kernels = len(h.resblock_kernel_sizes) +# self.num_upsamples = len(h.upsample_rates) +# self.conv_pre = weight_norm( +# Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) +# ) +# resblock = ResBlock + +# self.ups = nn.ModuleList() +# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): +# self.ups.append( +# weight_norm( +# ConvTranspose1d( +# h.upsample_initial_channel // (2**i), +# h.upsample_initial_channel // (2 ** (i + 1)), +# k, +# u, +# padding=(k - u) // 2, +# ) +# ) +# ) + +# self.resblocks = nn.ModuleList() +# for i in range(len(self.ups)): +# ch = h.upsample_initial_channel // (2 ** (i + 1)) +# for j, (k, d) in enumerate( +# zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) +# ): +# self.resblocks.append(resblock(h, ch, k, d)) + +# self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) +# self.ups.apply(init_weights) +# self.conv_post.apply(init_weights) + +# def forward(self, x): +# x = self.conv_pre(x) +# for i in range(self.num_upsamples): +# x = F.leaky_relu(x, LRELU_SLOPE) +# x = self.ups[i](x) +# xs = None +# for j in range(self.num_kernels): +# if xs is None: +# xs = self.resblocks[i * self.num_kernels + j](x) +# else: +# xs += self.resblocks[i * self.num_kernels + j](x) +# x = xs / self.num_kernels +# x = F.leaky_relu(x) +# x = self.conv_post(x) +# x = torch.tanh(x) + +# return x + +# def remove_weight_norm(self): +# print("Removing weight norm...") +# for l in self.ups: +# remove_weight_norm(l) +# for l in self.resblocks: +# l.remove_weight_norm() +# remove_weight_norm(self.conv_pre) +# remove_weight_norm(self.conv_post) diff --git a/FlashSR/AudioSR/latent_diffusion/__init__.py b/FlashSR/AudioSR/latent_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/attention.py b/FlashSR/AudioSR/latent_diffusion/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1132144bd7699b8c91800f84cc05d5e061ca3b0e --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/attention.py @@ -0,0 +1,467 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +# class CrossAttention(nn.Module): +# """ +# ### Cross Attention Layer +# This falls-back to self-attention when conditional embeddings are not specified. +# """ + +# use_flash_attention: bool = True + +# # use_flash_attention: bool = False +# def __init__( +# self, +# query_dim, +# context_dim=None, +# heads=8, +# dim_head=64, +# dropout=0.0, +# is_inplace: bool = True, +# ): +# # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): +# """ +# :param d_model: is the input embedding size +# :param n_heads: is the number of attention heads +# :param d_head: is the size of a attention head +# :param d_cond: is the size of the conditional embeddings +# :param is_inplace: specifies whether to perform the attention softmax computation inplace to +# save memory +# """ +# super().__init__() + +# self.is_inplace = is_inplace +# self.n_heads = heads +# self.d_head = dim_head + +# # Attention scaling factor +# self.scale = dim_head**-0.5 + +# # The normal self-attention layer +# if context_dim is None: +# context_dim = query_dim + +# # Query, key and value mappings +# d_attn = dim_head * heads +# self.to_q = nn.Linear(query_dim, d_attn, bias=False) +# self.to_k = nn.Linear(context_dim, d_attn, bias=False) +# self.to_v = nn.Linear(context_dim, d_attn, bias=False) + +# # Final linear layer +# self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) + +# # Setup [flash attention](https://github.com/HazyResearch/flash-attention). +# # Flash attention is only used if it's installed +# # and `CrossAttention.use_flash_attention` is set to `True`. +# try: +# # You can install flash attention by cloning their Github repo, +# # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) +# # and then running `python setup.py install` +# from flash_attn.flash_attention import FlashAttention + +# self.flash = FlashAttention() +# # Set the scale for scaled dot-product attention. +# self.flash.softmax_scale = self.scale +# # Set to `None` if it's not installed +# except ImportError: +# self.flash = None + +# def forward(self, x, context=None, mask=None): +# """ +# :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` +# :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` +# """ + +# # If `cond` is `None` we perform self attention +# has_cond = context is not None +# if not has_cond: +# context = x + +# # Get query, key and value vectors +# q = self.to_q(x) +# k = self.to_k(context) +# v = self.to_v(context) + +# # Use flash attention if it's available and the head size is less than or equal to `128` +# if ( +# CrossAttention.use_flash_attention +# and self.flash is not None +# and not has_cond +# and self.d_head <= 128 +# ): +# return self.flash_attention(q, k, v) +# # Otherwise, fallback to normal attention +# else: +# return self.normal_attention(q, k, v) + +# def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): +# """ +# #### Flash Attention +# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# """ + +# # Get batch size and number of elements along sequence axis (`width * height`) +# batch_size, seq_len, _ = q.shape + +# # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of +# # shape `[batch_size, seq_len, 3, n_heads * d_head]` +# qkv = torch.stack((q, k, v), dim=2) +# # Split the heads +# qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + +# # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to +# # fit this size. +# if self.d_head <= 32: +# pad = 32 - self.d_head +# elif self.d_head <= 64: +# pad = 64 - self.d_head +# elif self.d_head <= 128: +# pad = 128 - self.d_head +# else: +# raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") + +# # Pad the heads +# if pad: +# qkv = torch.cat( +# (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 +# ) + +# # Compute attention +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ +# # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` +# # TODO here I add the dtype changing +# out, _ = self.flash(qkv.type(torch.float16)) +# # Truncate the extra head size +# out = out[:, :, :, : self.d_head].float() +# # Reshape to `[batch_size, seq_len, n_heads * d_head]` +# out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + +# # Map to `[batch_size, height * width, d_model]` with a linear layer +# return self.to_out(out) + +# def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): +# """ +# #### Normal Attention + +# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` +# """ + +# # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` +# q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] +# k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] +# v = v.view(*v.shape[:2], self.n_heads, -1) + +# # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ +# attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale + +# # Compute softmax +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ +# if self.is_inplace: +# half = attn.shape[0] // 2 +# attn[half:] = attn[half:].softmax(dim=-1) +# attn[:half] = attn[:half].softmax(dim=-1) +# else: +# attn = attn.softmax(dim=-1) + +# # Compute attention output +# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ +# # attn: [bs, 20, 64, 1] +# # v: [bs, 1, 20, 32] +# out = torch.einsum("bhij,bjhd->bihd", attn, v) +# # Reshape to `[batch_size, height * width, n_heads * d_head]` +# out = out.reshape(*out.shape[:2], -1) +# # Map to `[batch_size, height * width, d_model]` with a linear layer +# return self.to_out(out) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~(mask == 1), max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None): + if context is None: + return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) + else: + return checkpoint( + self._forward, (x, context, mask), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + ): + super().__init__() + + context_dim = context_dim + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None, mask=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context, mask=mask) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py new file mode 100644 index 0000000000000000000000000000000000000000..855f3782a35542083013df256815564df8d34655 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py @@ -0,0 +1,149 @@ +""" +Reference Repo: https://github.com/facebookresearch/AudioMAE +""" + +import torch +import torch.nn as nn +#from timm.models.layers import to_2tuple +import audiosr.latent_diffusion.modules.audiomae.models_vit as models_vit +import audiosr.latent_diffusion.modules.audiomae.models_mae as models_mae + +# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128)) + + +class PatchEmbed_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) # with overlapped patches + # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h * w + + def get_output_shape(self, img_size): + # todo: don't be lazy.. + return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class AudioMAE(nn.Module): + """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)""" + + def __init__( + self, + ): + super().__init__() + model = models_vit.__dict__["vit_base_patch16"]( + num_classes=527, + drop_path_rate=0.1, + global_pool=True, + mask_2d=True, + use_custom_patch=False, + ) + + img_size = (1024, 128) + emb_dim = 768 + + model.patch_embed = PatchEmbed_new( + img_size=img_size, + patch_size=(16, 16), + in_chans=1, + embed_dim=emb_dim, + stride=16, + ) + num_patches = model.patch_embed.num_patches + # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8 + model.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False + ) # fixed sin-cos embedding + + # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth' + # checkpoint = torch.load(checkpoint_path, map_location='cpu') + # msg = model.load_state_dict(checkpoint['model'], strict=False) + # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') + + self.model = model + + def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0): + """ + x: mel fbank [Batch, 1, T, F] + mask_t_prob: 'T masking ratio (percentage of removed patches).' + mask_f_prob: 'F masking ratio (percentage of removed patches).' + """ + return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob) + + +class Vanilla_AudioMAE(nn.Module): + """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)""" + + def __init__( + self, + ): + super().__init__() + model = models_mae.__dict__["mae_vit_base_patch16"]( + in_chans=1, audio_exp=True, img_size=(1024, 128) + ) + + # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth' + # checkpoint = torch.load(checkpoint_path, map_location='cpu') + # msg = model.load_state_dict(checkpoint['model'], strict=False) + + # Skip the missing keys of decoder modules (not required) + # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') + + self.model = model.eval() + + def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False): + """ + x: mel fbank [Batch, 1, 1024 (T), 128 (F)] + mask_ratio: 'masking ratio (percentage of removed patches).' + """ + with torch.no_grad(): + # embed: [B, 513, 768] for mask_ratio=0.0 + if no_mask: + if no_average: + raise RuntimeError("This function is deprecated") + embed = self.model.forward_encoder_no_random_mask_no_average( + x + ) # mask_ratio + else: + embed = self.model.forward_encoder_no_mask(x) # mask_ratio + else: + raise RuntimeError("This function is deprecated") + embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio) + return embed + + +if __name__ == "__main__": + model = Vanilla_AudioMAE().cuda() + input = torch.randn(4, 1, 1024, 128).cuda() + print("The first run") + embed = model(input, mask_ratio=0.0, no_mask=True) + print(embed) + print("The second run") + embed = model(input, mask_ratio=0.0) + print(embed) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py new file mode 100644 index 0000000000000000000000000000000000000000..84395f2ddfd31ad81a420121dd24ca73b10f59e9 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py @@ -0,0 +1,613 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.vision_transformer import Block +from audiosr.latent_diffusion.modules.audiomae.util.pos_embed import ( + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_flexible, +) +from audiosr.latent_diffusion.modules.audiomae.util.patch_embed import ( + PatchEmbed_new, + PatchEmbed_org, +) + + +class MaskedAutoencoderViT(nn.Module): + """Masked Autoencoder with VisionTransformer backbone""" + + def __init__( + self, + img_size=224, + patch_size=16, + stride=10, + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4.0, + norm_layer=nn.LayerNorm, + norm_pix_loss=False, + audio_exp=False, + alpha=0.0, + temperature=0.2, + mode=0, + contextual_depth=8, + use_custom_patch=False, + split_pos=False, + pos_trainable=False, + use_nce=False, + beta=4.0, + decoder_mode=0, + mask_t_prob=0.6, + mask_f_prob=0.5, + mask_2d=False, + epoch=0, + no_shift=False, + ): + super().__init__() + + self.audio_exp = audio_exp + self.embed_dim = embed_dim + self.decoder_embed_dim = decoder_embed_dim + # -------------------------------------------------------------------------- + # MAE encoder specifics + if use_custom_patch: + print( + f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}" + ) + self.patch_embed = PatchEmbed_new( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + stride=stride, + ) + else: + self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim) + self.use_custom_patch = use_custom_patch + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + # self.split_pos = split_pos # not useful + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable + ) # fixed sin-cos embedding + + self.encoder_depth = depth + self.contextual_depth = contextual_depth + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, decoder_embed_dim), + requires_grad=pos_trainable, + ) # fixed sin-cos embedding + + self.no_shift = no_shift + + self.decoder_mode = decoder_mode + if ( + self.use_custom_patch + ): # overlapped patches as in AST. Similar performance yet compute heavy + window_size = (6, 6) + feat_size = (102, 12) + else: + window_size = (4, 4) + feat_size = (64, 8) + if self.decoder_mode == 1: + decoder_modules = [] + for index in range(16): + if self.no_shift: + shift_size = (0, 0) + else: + if (index % 2) == 0: + shift_size = (0, 0) + else: + shift_size = (2, 0) + # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]) + decoder_modules.append( + SwinTransformerBlock( + dim=decoder_embed_dim, + num_heads=16, + feat_size=feat_size, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + drop=0.0, + drop_attn=0.0, + drop_path=0.0, + extra_norm=False, + sequential_attn=False, + norm_layer=norm_layer, # nn.LayerNorm, + ) + ) + self.decoder_blocks = nn.ModuleList(decoder_modules) + else: + # Transfomer + self.decoder_blocks = nn.ModuleList( + [ + Block( + decoder_embed_dim, + decoder_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None, + for i in range(decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear( + decoder_embed_dim, patch_size**2 * in_chans, bias=True + ) # decoder to patch + + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.patch_size = patch_size + self.stride = stride + + # audio exps + self.alpha = alpha + self.T = temperature + self.mode = mode + self.use_nce = use_nce + self.beta = beta + + self.log_softmax = nn.LogSoftmax(dim=-1) + + self.mask_t_prob = mask_t_prob + self.mask_f_prob = mask_f_prob + self.mask_2d = mask_2d + + self.epoch = epoch + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + if self.audio_exp: + pos_embed = get_2d_sincos_pos_embed_flexible( + self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True + ) + else: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + if self.audio_exp: + decoder_pos_embed = get_2d_sincos_pos_embed_flexible( + self.decoder_pos_embed.shape[-1], + self.patch_embed.patch_hw, + cls_token=True, + ) + else: + decoder_pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.decoder_pos_embed.data.copy_( + torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) + ) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=0.02) + torch.nn.init.normal_(self.mask_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + L = (H/p)*(W/p) + """ + p = self.patch_embed.patch_size[0] + # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + if self.audio_exp: + if self.use_custom_patch: # overlapped patch + h, w = self.patch_embed.patch_hw + # todo: fixed h/w patch size and stride size. Make hw custom in the future + x = imgs.unfold(2, self.patch_size, self.stride).unfold( + 3, self.patch_size, self.stride + ) # n,1,H,W -> n,1,h,w,p,p + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) + # x = torch.einsum('nchpwq->nhwpqc', x) + # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + else: + h = imgs.shape[2] // p + w = imgs.shape[3] // p + # h,w = self.patch_embed.patch_hw + x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + else: + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + specs: (N, 1, H, W) + """ + p = self.patch_embed.patch_size[0] + h = 1024 // p + w = 128 // p + x = x.reshape(shape=(x.shape[0], h, w, p, p, 1)) + x = torch.einsum("nhwpqc->nchpwq", x) + specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p)) + return specs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def random_masking_2d(self, x, mask_t_prob, mask_f_prob): + """ + 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + if self.use_custom_patch: # overlapped patch + T = 101 + F = 12 + else: + T = 64 + F = 8 + # x = x.reshape(N, T, F, D) + len_keep_t = int(T * (1 - mask_t_prob)) + len_keep_f = int(F * (1 - mask_f_prob)) + + # noise for mask in time + noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1] + # sort noise for each sample aling time + ids_shuffle_t = torch.argsort( + noise_t, dim=1 + ) # ascend: small is keep, large is remove + ids_restore_t = torch.argsort(ids_shuffle_t, dim=1) + ids_keep_t = ids_shuffle_t[:, :len_keep_t] + # noise mask in freq + noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1] + ids_shuffle_f = torch.argsort( + noise_f, dim=1 + ) # ascend: small is keep, large is remove + ids_restore_f = torch.argsort(ids_shuffle_f, dim=1) + ids_keep_f = ids_shuffle_f[:, :len_keep_f] # + + # generate the binary mask: 0 is keep, 1 is remove + # mask in freq + mask_f = torch.ones(N, F, device=x.device) + mask_f[:, :len_keep_f] = 0 + mask_f = ( + torch.gather(mask_f, dim=1, index=ids_restore_f) + .unsqueeze(1) + .repeat(1, T, 1) + ) # N,T,F + # mask in time + mask_t = torch.ones(N, T, device=x.device) + mask_t[:, :len_keep_t] = 0 + mask_t = ( + torch.gather(mask_t, dim=1, index=ids_restore_t) + .unsqueeze(1) + .repeat(1, F, 1) + .permute(0, 2, 1) + ) # N,T,F + mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F + + # get masked x + id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device) + id2res = id2res + 999 * mask # add a large value for masked elements + id2res2 = torch.argsort(id2res.flatten(start_dim=1)) + ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + ids_restore = torch.argsort(id2res2.flatten(start_dim=1)) + mask = mask.flatten(start_dim=1) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, mask_ratio, mask_2d=False): + # embed patches + x = self.patch_embed(x) + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + if mask_2d: + x, mask, ids_restore = self.random_masking_2d( + x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob + ) + else: + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore, None + + def forward_encoder_no_random_mask_no_average(self, x): + # embed patches + x = self.patch_embed(x) + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + # if mask_2d: + # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob) + # else: + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x + + def forward_encoder_no_mask(self, x): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + contextual_embs = [] + for n, blk in enumerate(self.blocks): + x = blk(x) + if n > self.contextual_depth: + contextual_embs.append(self.norm(x)) + # x = self.norm(x) + contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0) + + return contextual_emb + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 + ) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather( + x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + if self.decoder_mode != 0: + B, L, D = x.shape + x = x[:, 1:, :] + if self.use_custom_patch: + x = x.reshape(B, 101, 12, D) + x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack + x = x.reshape(B, 1224, D) + if self.decoder_mode > 3: # mvit + x = self.decoder_blocks(x) + else: + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + pred = self.decoder_pred(x) + + # remove cls token + if self.decoder_mode != 0: + if self.use_custom_patch: + pred = pred.reshape(B, 102, 12, 256) + pred = pred[:, :101, :, :] + pred = pred.reshape(B, 1212, 256) + else: + pred = pred + else: + pred = pred[:, 1:, :] + return pred, None, None # emb, emb_pixel + + def forward_loss(self, imgs, pred, mask, norm_pix_loss=False): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward(self, imgs, mask_ratio=0.8): + emb_enc, mask, ids_restore, _ = self.forward_encoder( + imgs, mask_ratio, mask_2d=self.mask_2d + ) + pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3] + loss_recon = self.forward_loss( + imgs, pred, mask, norm_pix_loss=self.norm_pix_loss + ) + loss_contrastive = torch.FloatTensor([0.0]).cuda() + return loss_recon, pred, mask, loss_contrastive + + +def mae_vit_small_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_large_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +def mae_vit_huge_patch14_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model + + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..cb37adbc16cfb9a232493c473c9400f199655b6c --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn +import timm.models.vision_transformer + + +class VisionTransformer(timm.models.vision_transformer.VisionTransformer): + """Vision Transformer with support for global average pooling""" + + def __init__( + self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs + ): + super(VisionTransformer, self).__init__(**kwargs) + + self.global_pool = global_pool + if self.global_pool: + norm_layer = kwargs["norm_layer"] + embed_dim = kwargs["embed_dim"] + self.fc_norm = norm_layer(embed_dim) + del self.norm # remove the original norm + self.mask_2d = mask_2d + self.use_custom_patch = use_custom_patch + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed[:, 1:, :] + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def random_masking_2d(self, x, mask_t_prob, mask_f_prob): + """ + 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + + N, L, D = x.shape # batch, length, dim + if self.use_custom_patch: + # # for AS + T = 101 # 64,101 + F = 12 # 8,12 + # # for ESC + # T=50 + # F=12 + # for SPC + # T=12 + # F=12 + else: + # ## for AS + T = 64 + F = 8 + # ## for ESC + # T=32 + # F=8 + ## for SPC + # T=8 + # F=8 + + # mask T + x = x.reshape(N, T, F, D) + len_keep_T = int(T * (1 - mask_t_prob)) + noise = torch.rand(N, T, device=x.device) # noise in [0, 1] + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_keep = ids_shuffle[:, :len_keep_T] + index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) + # x_masked = torch.gather(x, dim=1, index=index) + # x_masked = x_masked.reshape(N,len_keep_T*F,D) + x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D + + # mask F + # x = x.reshape(N, T, F, D) + x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D + len_keep_F = int(F * (1 - mask_f_prob)) + noise = torch.rand(N, F, device=x.device) # noise in [0, 1] + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_keep = ids_shuffle[:, :len_keep_F] + # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D) + index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) + x_masked = torch.gather(x, dim=1, index=index) + x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D + # x_masked = x_masked.reshape(N,len_keep*T,D) + x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D) + + return x_masked, None, None + + def forward_features_mask(self, x, mask_t_prob, mask_f_prob): + B = x.shape[0] # 4,1,1024,128 + x = self.patch_embed(x) # 4, 512, 768 + + x = x + self.pos_embed[:, 1:, :] + if self.random_masking_2d: + x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob) + else: + x, mask, ids_restore = self.random_masking(x, mask_t_prob) + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self.pos_drop(x) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + # overwrite original timm + def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0): + if mask_t_prob > 0.0 or mask_f_prob > 0.0: + x = self.forward_features_mask( + x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob + ) + else: + x = self.forward_features(x) + x = self.head(x) + return x + + +def vit_small_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_base_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_large_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_huge_patch14(**kwargs): + model = VisionTransformer( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..525e3c783c3d348e593dc89c2b5fb8520918e9ea --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from torchvision import transforms +from torchvision.transforms import functional as F + + +class RandomResizedCrop(transforms.RandomResizedCrop): + """ + RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 + """ + + @staticmethod + def get_params(img, scale, ratio): + width, height = F._get_image_size(img) + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + + return i, j, h, w diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b90f89a7d5f78c31bc9113dd88b632b0c234f10a --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +import os +import PIL + +from torchvision import datasets, transforms + +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def build_dataset(is_train, args): + transform = build_transform(is_train, args) + + root = os.path.join(args.data_path, "train" if is_train else "val") + dataset = datasets.ImageFolder(root, transform=transform) + + print(dataset) + + return dataset + + +def build_transform(is_train, args): + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + # train transform + if is_train: + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=args.input_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + interpolation="bicubic", + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + mean=mean, + std=std, + ) + return transform + + # eval transform + t = [] + if args.input_size <= 224: + crop_pct = 224 / 256 + else: + crop_pct = 1.0 + size = int(args.input_size / crop_pct) + t.append( + transforms.Resize( + size, interpolation=PIL.Image.BICUBIC + ), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..fc43923d22cf2c9af4ae9166612c3f3477faf254 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# LARS optimizer, implementation from MoCo v3: +# https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + +import torch + + +class LARS(torch.optim.Optimizer): + """ + LARS optimizer, no rate scaling or weight decay for parameters <= 1D. + """ + + def __init__( + self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 + ): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + trust_coefficient=trust_coefficient, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g["params"]: + dp = p.grad + + if dp is None: + continue + + if p.ndim > 1: # if not normalization gamma/beta or bias + dp = dp.add(p, alpha=g["weight_decay"]) + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where( + param_norm > 0.0, + torch.where( + update_norm > 0, + (g["trust_coefficient"] * param_norm / update_norm), + one, + ), + one, + ) + dp = dp.mul(q) + + param_state = self.state[p] + if "mu" not in param_state: + param_state["mu"] = torch.zeros_like(p) + mu = param_state["mu"] + mu.mul_(g["momentum"]).add_(dp) + p.add_(mu, alpha=-g["lr"]) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..e90ed69d7b8d019dbf5d90571541668e2bd8efe8 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ELECTRA https://github.com/google-research/electra +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + + +def param_groups_lrd( + model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 +): + """ + Parameter groups for layer-wise lr decay + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + param_group_names = {} + param_groups = {} + + num_layers = len(model.blocks) + 1 + + layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if p.ndim == 1 or n in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0.0 + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = get_layer_id_for_vit(n, num_layers) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_group_names: + this_scale = layer_scales[layer_id] + + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["params"].append(n) + param_groups[group_name]["params"].append(p) + + # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ["cls_token", "pos_embed"]: + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("blocks"): + return int(name.split(".")[1]) + 1 + else: + return num_layers diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py new file mode 100644 index 0000000000000000000000000000000000000000..efe184d8e3fb63ec6b4f83375b6ea719985900de --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( + 1.0 + + math.cos( + math.pi + * (epoch - args.warmup_epochs) + / (args.epochs - args.warmup_epochs) + ) + ) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..74184e09e23e0e174350b894b0cff29600c18b71 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch._six import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] ".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.dist_url = "tcp://%s:%s" % ( + os.environ["MASTER_ADDR"], + os.environ["MASTER_PORT"], + ) + os.environ["LOCAL_RANK"] = str(args.gpu) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + "scaler": loss_scaler.state_dict(), + "args": args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {"epoch": epoch} + model.save_checkpoint( + save_dir=args.output_dir, + tag="checkpoint-%s" % epoch_name, + client_state=client_state, + ) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + print("Resume checkpoint %s" % args.resume) + if ( + "optimizer" in checkpoint + and "epoch" in checkpoint + and not (hasattr(args, "eval") and args.eval) + ): + optimizer.load_state_dict(checkpoint["optimizer"]) + args.start_epoch = checkpoint["epoch"] + 1 + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt): + # keys_to_copy=['pos_embed','patch_embed'] + # replaced=0 + + vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"] + vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"] + + # pos_emb % not trainable, use default + pos_embed_v = vmae_ckpt["pos_embed_v"] # 1,589,768 + pos_embed = pos_embed_v[:, 1:, :] # 1,588,768 + cls_embed = pos_embed_v[:, 0, :].unsqueeze(1) + pos_embed = pos_embed.reshape(1, 2, 14, 14, 768).sum(dim=1) # 1, 14, 14, 768 + print("Position interpolate from 14,14 to 64,8") + pos_embed = pos_embed.permute(0, 3, 1, 2) # 1, 14,14,768 -> 1,768,14,14 + pos_embed = torch.nn.functional.interpolate( + pos_embed, size=(64, 8), mode="bicubic", align_corners=False + ) + pos_embed = pos_embed.permute(0, 2, 3, 1).flatten( + 1, 2 + ) # 1, 14, 14, 768 => 1, 196,768 + pos_embed = torch.cat((cls_embed, pos_embed), dim=1) + assert vmae_ckpt["pos_embed"].shape == pos_embed.shape + vmae_ckpt["pos_embed"] = pos_embed + # patch_emb + # aggregate 3 channels in video-rgb ckpt to 1 channel for audio + v_weight = vmae_ckpt["patch_embed_v.proj.weight"] # 768,3,2,16,16 + new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1)) + assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape + vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight + vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"] + + # hack + vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"] + vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"] + + # replace transformer encoder + for k, v in vmae_ckpt.items(): + if k.startswith("blocks."): + kk = k.replace("blocks.", "blocks_v.") + vmae_ckpt[k] = vmae_ckpt[kk] + elif k.startswith("blocks_v."): + pass + else: + print(k) + print(k) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1e4d436c6f79aef9bf1de32cdac5d4f037c775 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +from timm.models.layers import to_2tuple + + +class PatchEmbed_org(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + y = x.flatten(2).transpose(1, 2) + return y + + +class PatchEmbed_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) # with overlapped patches + # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h * w + + def get_output_shape(self, img_size): + # todo: don't be lazy.. + return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # x = self.proj(x).flatten(2).transpose(1, 2) + x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 + x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 + x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 + return x + + +class PatchEmbed3D_new(nn.Module): + """Flexible Image to Patch Embedding""" + + def __init__( + self, + video_size=(16, 224, 224), + patch_size=(2, 16, 16), + in_chans=3, + embed_dim=768, + stride=(2, 16, 16), + ): + super().__init__() + + self.video_size = video_size + self.patch_size = patch_size + self.in_chans = in_chans + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride + ) + _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w + self.patch_thw = (t, h, w) + self.num_patches = t * h * w + + def get_output_shape(self, video_size): + # todo: don't be lazy.. + return self.proj( + torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2]) + ).shape + + def forward(self, x): + B, C, T, H, W = x.shape + x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 + x = x.flatten(2) # 32, 768, 1568 + x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768 + return x + + +if __name__ == "__main__": + # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16)) + # input = torch.rand(8,1,1024,128) + # output = patch_emb(input) + # print(output.shape) # (8,512,64) + + patch_emb = PatchEmbed3D_new( + video_size=(6, 224, 224), + patch_size=(2, 16, 16), + in_chans=3, + embed_dim=768, + stride=(2, 16, 16), + ) + input = torch.rand(8, 3, 6, 224, 224) + output = patch_emb(input) + print(output.shape) # (8,64) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9177ed98dffcf35264f38aff94e7f00fb50abf --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size[0], dtype=np.float32) + grid_w = np.arange(grid_size[1], dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + # omega = np.arange(embed_dim // 2, dtype=np.float) + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + # new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size[0], orig_size[1], new_size[0], new_size[1]) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size[0], orig_size[1], embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size[0], new_size[1]), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + model.pos_embed.shape[-2] - num_patches + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size[0], orig_size[1], new_size[0], new_size[1]) + ) + # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1) + pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove + pos_tokens = pos_tokens.reshape( + -1, orig_size[0], orig_size[1], embedding_size + ) # .permute(0, 3, 1, 2) + # pos_tokens = torch.nn.functional.interpolate( + # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False) + + # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff + pos_tokens = pos_tokens.flatten(1, 2) + new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +def interpolate_patch_embed_audio( + model, + checkpoint_model, + orig_channel, + new_channel=1, + kernel_size=(16, 16), + stride=(16, 16), + padding=(0, 0), +): + if orig_channel != new_channel: + if "patch_embed.proj.weight" in checkpoint_model: + # aggregate 3 channels in rgb ckpt to 1 channel for audio + new_proj_weight = torch.nn.Parameter( + torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze( + 1 + ) + ) + checkpoint_model["patch_embed.proj.weight"] = new_proj_weight diff --git a/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py new file mode 100644 index 0000000000000000000000000000000000000000..3f8137249503f6eaa25c3170fe5ef6b87f187347 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py @@ -0,0 +1,76 @@ +import numpy as np +from scipy import stats +from sklearn import metrics +import torch + + +def d_prime(auc): + standard_normal = stats.norm() + d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) + return d_prime + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def calculate_stats(output, target): + """Calculate statistics including mAP, AUC, etc. + + Args: + output: 2d array, (samples_num, classes_num) + target: 2d array, (samples_num, classes_num) + + Returns: + stats: list of statistic of each class. + """ + + classes_num = target.shape[-1] + stats = [] + + # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet + acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) + + # Class-wise statistics + for k in range(classes_num): + # Average precision + avg_precision = metrics.average_precision_score( + target[:, k], output[:, k], average=None + ) + + # AUC + # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) + + # Precisions, recalls + (precisions, recalls, thresholds) = metrics.precision_recall_curve( + target[:, k], output[:, k] + ) + + # FPR, TPR + (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) + + save_every_steps = 1000 # Sample statistics to reduce size + dict = { + "precisions": precisions[0::save_every_steps], + "recalls": recalls[0::save_every_steps], + "AP": avg_precision, + "fpr": fpr[0::save_every_steps], + "fnr": 1.0 - tpr[0::save_every_steps], + # 'auc': auc, + # note acc is not class-wise, this is just to keep consistent with other metrics + "acc": acc, + } + stats.append(dict) + + return stats diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7d49e3e2cda157afcc5b925734f459e4a7d2a6c8 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py @@ -0,0 +1,1069 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from audiosr.latent_diffusion.util import instantiate_from_config +from audiosr.latent_diffusion.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print( + # "Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape) + # ) + # ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..3712857caf29672de9a1dea007ac98f4aaed1a0b --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1103 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from audiosr.latent_diffusion.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from audiosr.latent_diffusion.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context_list=None, mask_list=None): + # The first spatial transformer block does not have context + spatial_transformer_id = 0 + context_list = [None] + context_list + mask_list = [None] + mask_list + + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + if spatial_transformer_id >= len(context_list): + context, mask = None, None + else: + context, mask = ( + context_list[spatial_transformer_id], + mask_list[spatial_transformer_id], + ) + + x = layer(x, context, mask=mask) + spatial_transformer_id += 1 + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + extra_sa_layer=True, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=True, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + # assert not ( + # self.num_classes is not None and self.extra_film_condition_dim is not None + # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = self.extra_film_condition_dim is not None + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + print( + "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " + % self.extra_film_condition_dim + ) + + if context_dim is not None and not use_spatial_transformer: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] + elif context_dim is None: + context_dim = [None] # At least use one spatial transformer + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + middle_layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + if extra_sa_layer: + middle_layers.append( + SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=None + ) + ) + for context_dim_id in range(len(context_dim)): + middle_layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + middle_layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + self.middle_block = TimestepEmbedSequential(*middle_layers) + + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward( + self, + x, + timesteps=None, + y=None, + context_list=None, + context_attn_mask_list=None, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. + """ + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + # if self.num_classes is not None: + # assert y.shape == (x.shape[0],) + # emb = emb + self.label_emb(y) + + if self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context_list, context_attn_mask_list) + hs.append(h) + h = self.middle_block(h, emb, context_list, context_attn_mask_list) + for module in self.output_blocks: + concate_tensor = hs.pop() + h = th.cat([h, concate_tensor], dim=1) + h = module(h, emb, context_list, context_attn_mask_list) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b1fdc33082e4eed6bba217d15ab09cbe952c0116 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py @@ -0,0 +1,294 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from FlashSR.AudioSR.latent_diffusion.util import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + # betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t).contiguous() + return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/FlashSR/AudioSR/latent_diffusion/modules/distributions/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py b/FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..58eb535e7769f402169ddff77ee45c96ba3650d9 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/ema.py b/FlashSR/AudioSR/latent_diffusion/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..880ca3d205d9b4d7450e146930a93f2e63c58b70 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/FlashSR/AudioSR/latent_diffusion/modules/encoders/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py b/FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..ebec8699ae3591d067a6687c754a9b292c701d05 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py @@ -0,0 +1,682 @@ +import torch +import logging +import torch.nn as nn +#from audiosr.clap.open_clip import create_model +#from audiosr.clap.training.data import get_audio_features +import torchaudio +#from transformers import RobertaTokenizer, AutoTokenizer, T5EncoderModel +import torch.nn.functional as F +from audiosr.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE +from audiosr.latent_diffusion.modules.phoneme_encoder.encoder import TextEncoder +from audiosr.latent_diffusion.util import instantiate_from_config + +from transformers import AutoTokenizer, T5Config + + +import numpy as np + +""" +The model forward function can return three types of data: +1. tensor: used directly as conditioning signal +2. dict: where there is a main key as condition, there are also other key that you can use to pass loss function and itermediate result. etc. +3. list: the length is 2, in which the first element is tensor, the second element is attntion mask. + +The output shape for the cross attention condition should be: +x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len] + +All the returned data, in which will be used as diffusion input, will need to be in float type +""" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class PhonemeEncoder(nn.Module): + def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None): + super().__init__() + """ + encoder = PhonemeEncoder(40) + data = torch.randint(0, 39, (2, 250)) + output = encoder(data) + import ipdb;ipdb.set_trace() + """ + assert pad_token_id is not None + + self.device = None + self.PAD_LENGTH = int(pad_length) + self.pad_token_id = pad_token_id + self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH) + + self.text_encoder = TextEncoder( + n_vocab=vocabs_size, + out_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + ) + + self.learnable_positional_embedding = torch.nn.Parameter( + torch.zeros((1, 192, self.PAD_LENGTH)) + ) # [batchsize, seqlen, padlen] + self.learnable_positional_embedding.requires_grad = True + + # Required + def get_unconditional_condition(self, batchsize): + unconditional_tokens = self.pad_token_sequence.expand( + batchsize, self.PAD_LENGTH + ) + return self(unconditional_tokens) # Need to return float type + + # def get_unconditional_condition(self, batchsize): + + # hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device) + # attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device) + # return [hidden_state, attention_mask] # Need to return float type + + def _get_src_mask(self, phoneme): + src_mask = phoneme != self.pad_token_id + return src_mask + + def _get_src_length(self, phoneme): + src_mask = self._get_src_mask(phoneme) + length = torch.sum(src_mask, dim=-1) + return length + + # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask): + # # src_length: [bs] + # # text_emb: [bs, 192, pad_length] + # # attention_mask: [bs, pad_length] + # mask = src_length[..., None, None] > 1 + # text_emb = text_emb * mask + + # attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0 + # return text_emb, attention_mask + + def forward(self, phoneme_idx): + if self.device is None: + self.device = self.learnable_positional_embedding.device + self.pad_token_sequence = self.pad_token_sequence.to(self.device) + + phoneme_idx = phoneme_idx.to(self.device) + + src_length = self._get_src_length(phoneme_idx) + text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length) + text_emb = text_emb + self.learnable_positional_embedding + + # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask) + + return [ + text_emb.permute(0, 2, 1), + text_emb_mask.squeeze(1), + ] # [2, 250, 192], [2, 250] + + +class VAEFeatureExtract(nn.Module): + def __init__(self, first_stage_config): + super().__init__() + # self.tokenizer = AutoTokenizer.from_pretrained("gpt2") + self.vae = None + self.instantiate_first_stage(first_stage_config) + self.device = None + self.unconditional_cond = None + + def get_unconditional_condition(self, batchsize): + return self.unconditional_cond.unsqueeze(0).expand(batchsize, -1, -1, -1) + + def instantiate_first_stage(self, config): + self.vae = instantiate_from_config(config) + self.vae.eval() + for p in self.vae.parameters(): + p.requires_grad = False + self.vae.train = disabled_train + + def forward(self, batch): + assert self.vae.training == False + if self.device is None: + self.device = next(self.vae.parameters()).device + + with torch.no_grad(): + vae_embed = self.vae.encode(batch.unsqueeze(1)).sample() + + self.unconditional_cond = -11.4981 + vae_embed[0].clone() * 0.0 + + return vae_embed.detach() + + +class FlanT5HiddenState(nn.Module): + """ + llama = FlanT5HiddenState() + data = ["","this is not an empty sentence"] + encoder_hidden_states = llama(data) + import ipdb;ipdb.set_trace() + """ + + def __init__( + self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True + ): + super().__init__() + self.freeze_text_encoder = freeze_text_encoder + self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name) + #self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name)) + if freeze_text_encoder: + self.model.eval() + for p in self.model.parameters(): + p.requires_grad = False + else: + print("=> The text encoder is learnable") + + self.empty_hidden_state_cfg = None + self.device = None + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.model.parameters()) + if self.freeze_text_encoder: + assert param.requires_grad == False + + # device = param.device + if self.empty_hidden_state_cfg is None: + self.empty_hidden_state_cfg, _ = self([""]) + + hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float() + attention_mask = ( + torch.ones((batchsize, hidden_state.size(1))) + .to(hidden_state.device) + .float() + ) + return [hidden_state, attention_mask] # Need to return float type + + def forward(self, batch): + param = next(self.model.parameters()) + if self.freeze_text_encoder: + assert param.requires_grad == False + + if self.device is None: + self.device = param.device + + # print("Manually change text") + # for i in range(len(batch)): + # batch[i] = "dog barking" + try: + return self.encode_text(batch) + except Exception as e: + print(e, batch) + logging.exception("An error occurred: %s", str(e)) + + def encode_text(self, prompt): + device = self.model.device + batch = self.tokenizer( + prompt, + max_length=128, # self.tokenizer.model_max_length + padding=True, + truncation=True, + return_tensors="pt", + ) + input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to( + device + ) + # Get text encoding + if self.freeze_text_encoder: + with torch.no_grad(): + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + else: + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + return [ + encoder_hidden_states.detach(), + attention_mask.float(), + ] + + +class AudioMAEConditionCTPoolRandTFSeparated(nn.Module): + """ + audiomae = AudioMAEConditionCTPool2x2() + data = torch.randn((4, 1024, 128)) + output = audiomae(data) + import ipdb;ipdb.set_trace() + exit(0) + """ + + def __init__( + self, + time_pooling_factors=[1, 2, 4, 8], + freq_pooling_factors=[1, 2, 4, 8], + eval_time_pooling=None, + eval_freq_pooling=None, + mask_ratio=0.0, + regularization=False, + no_audiomae_mask=True, + no_audiomae_average=False, + ): + super().__init__() + self.device = None + self.time_pooling_factors = time_pooling_factors + self.freq_pooling_factors = freq_pooling_factors + self.no_audiomae_mask = no_audiomae_mask + self.no_audiomae_average = no_audiomae_average + + self.eval_freq_pooling = eval_freq_pooling + self.eval_time_pooling = eval_time_pooling + self.mask_ratio = mask_ratio + self.use_reg = regularization + + self.audiomae = Vanilla_AudioMAE() + self.audiomae.eval() + for p in self.audiomae.parameters(): + p.requires_grad = False + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.audiomae.parameters()) + assert param.requires_grad == False + device = param.device + # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + token_num = int(512 / (time_pool * freq_pool)) + return [ + torch.zeros((batchsize, token_num, 768)).to(device).float(), + torch.ones((batchsize, token_num)).to(device).float(), + ] + + def pool(self, representation, time_pool=None, freq_pool=None): + assert representation.size(-1) == 768 + representation = representation[:, 1:, :].transpose(1, 2) + bs, embedding_dim, token_num = representation.size() + representation = representation.reshape(bs, embedding_dim, 64, 8) + + if self.training: + if time_pool is None and freq_pool is None: + time_pool = min( + 64, + self.time_pooling_factors[ + np.random.choice(list(range(len(self.time_pooling_factors)))) + ], + ) + freq_pool = min( + 8, + self.freq_pooling_factors[ + np.random.choice(list(range(len(self.freq_pooling_factors)))) + ], + ) + # freq_pool = min(8, time_pool) # TODO here I make some modification. + else: + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + self.avgpooling = nn.AvgPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + self.maxpooling = nn.MaxPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + + pooled = ( + self.avgpooling(representation) + self.maxpooling(representation) + ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] + pooled = pooled.flatten(2).transpose(1, 2) + return pooled # [bs, token_num, embedding_dim] + + def regularization(self, x): + assert x.size(-1) == 768 + x = F.normalize(x, p=2, dim=-1) + return x + + # Required + def forward(self, batch, time_pool=None, freq_pool=None): + assert batch.size(-2) == 1024 and batch.size(-1) == 128 + + if self.device is None: + self.device = batch.device + + batch = batch.unsqueeze(1) + with torch.no_grad(): + representation = self.audiomae( + batch, + mask_ratio=self.mask_ratio, + no_mask=self.no_audiomae_mask, + no_average=self.no_audiomae_average, + ) + representation = self.pool(representation, time_pool, freq_pool) + if self.use_reg: + representation = self.regularization(representation) + return [ + representation, + torch.ones((representation.size(0), representation.size(1))) + .to(representation.device) + .float(), + ] + + +class AudioMAEConditionCTPoolRand(nn.Module): + """ + audiomae = AudioMAEConditionCTPool2x2() + data = torch.randn((4, 1024, 128)) + output = audiomae(data) + import ipdb;ipdb.set_trace() + exit(0) + """ + + def __init__( + self, + time_pooling_factors=[1, 2, 4, 8], + freq_pooling_factors=[1, 2, 4, 8], + eval_time_pooling=None, + eval_freq_pooling=None, + mask_ratio=0.0, + regularization=False, + no_audiomae_mask=True, + no_audiomae_average=False, + ): + super().__init__() + self.device = None + self.time_pooling_factors = time_pooling_factors + self.freq_pooling_factors = freq_pooling_factors + self.no_audiomae_mask = no_audiomae_mask + self.no_audiomae_average = no_audiomae_average + + self.eval_freq_pooling = eval_freq_pooling + self.eval_time_pooling = eval_time_pooling + self.mask_ratio = mask_ratio + self.use_reg = regularization + + self.audiomae = Vanilla_AudioMAE() + self.audiomae.eval() + for p in self.audiomae.parameters(): + p.requires_grad = False + + # Required + def get_unconditional_condition(self, batchsize): + param = next(self.audiomae.parameters()) + assert param.requires_grad == False + device = param.device + # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + token_num = int(512 / (time_pool * freq_pool)) + return [ + torch.zeros((batchsize, token_num, 768)).to(device).float(), + torch.ones((batchsize, token_num)).to(device).float(), + ] + + def pool(self, representation, time_pool=None, freq_pool=None): + assert representation.size(-1) == 768 + representation = representation[:, 1:, :].transpose(1, 2) + bs, embedding_dim, token_num = representation.size() + representation = representation.reshape(bs, embedding_dim, 64, 8) + + if self.training: + if time_pool is None and freq_pool is None: + time_pool = min( + 64, + self.time_pooling_factors[ + np.random.choice(list(range(len(self.time_pooling_factors)))) + ], + ) + # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] + freq_pool = min(8, time_pool) # TODO here I make some modification. + else: + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + self.avgpooling = nn.AvgPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + self.maxpooling = nn.MaxPool2d( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + + pooled = ( + self.avgpooling(representation) + self.maxpooling(representation) + ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] + pooled = pooled.flatten(2).transpose(1, 2) + return pooled # [bs, token_num, embedding_dim] + + def regularization(self, x): + assert x.size(-1) == 768 + x = F.normalize(x, p=2, dim=-1) + return x + + # Required + def forward(self, batch, time_pool=None, freq_pool=None): + assert batch.size(-2) == 1024 and batch.size(-1) == 128 + + if self.device is None: + self.device = next(self.audiomae.parameters()).device + + batch = batch.unsqueeze(1).to(self.device) + with torch.no_grad(): + representation = self.audiomae( + batch, + mask_ratio=self.mask_ratio, + no_mask=self.no_audiomae_mask, + no_average=self.no_audiomae_average, + ) + representation = self.pool(representation, time_pool, freq_pool) + if self.use_reg: + representation = self.regularization(representation) + return [ + representation, + torch.ones((representation.size(0), representation.size(1))) + .to(representation.device) + .float(), + ] + + +class CLAPAudioEmbeddingClassifierFreev2(nn.Module): + def __init__( + self, + pretrained_path="", + enable_cuda=False, + sampling_rate=16000, + embed_mode="audio", + amodel="HTSAT-base", + unconditional_prob=0.1, + random_mute=False, + max_random_mute_portion=0.5, + training_mode=True, + ): + super().__init__() + self.device = "cpu" # The model itself is on cpu + self.cuda = enable_cuda + self.precision = "fp32" + self.amodel = amodel # or 'PANN-14' + self.tmodel = "roberta" # the best text encoder in our training + self.enable_fusion = False # False if you do not want to use the fusion model + self.fusion_type = "aff_2d" + self.pretrained = pretrained_path + self.embed_mode = embed_mode + self.embed_mode_orig = embed_mode + self.sampling_rate = sampling_rate + self.unconditional_prob = unconditional_prob + self.random_mute = random_mute + #self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") + self.max_random_mute_portion = max_random_mute_portion + self.training_mode = training_mode + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device=self.device, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + self.model = self.model.to(self.device) + audio_cfg = self.model_cfg["audio_cfg"] + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ) + for p in self.model.parameters(): + p.requires_grad = False + self.unconditional_token = None + self.model.eval() + + def get_unconditional_condition(self, batchsize): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) + + def batch_to_list(self, batch): + ret = [] + for i in range(batch.size(0)): + ret.append(batch[i]) + return ret + + def make_decision(self, probability): + if float(torch.rand(1)) < probability: + return True + else: + return False + + def random_uniform(self, start, end): + val = torch.rand(1).item() + return start + (end - start) * val + + def _random_mute(self, waveform): + # waveform: [bs, t-steps] + t_steps = waveform.size(-1) + for i in range(waveform.size(0)): + mute_size = int( + self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) + ) + mute_start = int(self.random_uniform(0, t_steps - mute_size)) + waveform[i, mute_start : mute_start + mute_size] = 0 + return waveform + + def cos_similarity(self, waveform, text): + # waveform: [bs, t_steps] + original_embed_mode = self.embed_mode + with torch.no_grad(): + self.embed_mode = "audio" + # MPS currently does not support ComplexFloat dtype and operator 'aten::_fft_r2c' + if self.cuda: + audio_emb = self(waveform.cuda()) + else: + audio_emb = self(waveform.to("cpu")) + self.embed_mode = "text" + text_emb = self(text) + similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) + self.embed_mode = original_embed_mode + return similarity.squeeze() + + def build_unconditional_emb(self): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + + def forward(self, batch): + # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 + # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 + if self.model.training == True and not self.training_mode: + print( + "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." + ) + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device="cuda" if self.cuda else "cpu", + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + p.requires_grad = False + self.model.eval() + + if self.unconditional_token is None: + self.build_unconditional_emb() + + # if(self.training_mode): + # assert self.model.training == True + # else: + # assert self.model.training == False + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + if self.embed_mode == "audio": + if not self.training: + print("INFO: clap model calculate the audio embedding as condition") + with torch.no_grad(): + # assert ( + # self.sampling_rate == 16000 + # ), "We only support 16000 sampling rate" + + # if self.random_mute: + # batch = self._random_mute(batch) + # batch: [bs, 1, t-samples] + if self.sampling_rate != 48000: + batch = torchaudio.functional.resample( + batch, orig_freq=self.sampling_rate, new_freq=48000 + ) + audio_data = batch.squeeze(1).to("cpu") + self.mel_transform = self.mel_transform.to(audio_data.device) + mel = self.mel_transform(audio_data) + audio_dict = get_audio_features( + audio_data, + mel, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=self.model_cfg["audio_cfg"], + ) + # [bs, 512] + embed = self.model.get_audio_embedding(audio_dict) + elif self.embed_mode == "text": + with torch.no_grad(): + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + text_data = self.tokenizer(batch) + + if isinstance(batch, str) or ( + isinstance(batch, list) and len(batch) == 1 + ): + for key in text_data.keys(): + text_data[key] = text_data[key].unsqueeze(0) + + embed = self.model.get_text_embedding(text_data) + + embed = embed.unsqueeze(1) + for i in range(embed.size(0)): + if self.make_decision(self.unconditional_prob): + embed[i] = self.unconditional_token + # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch) + return embed.detach() + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..287c409dbfe535d95c65a2c69d013965fce97348 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py @@ -0,0 +1,430 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +import audiosr.latent_diffusion.modules.phoneme_encoder.commons as commons + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..9515724c12ab2f856b9a2ec14e38cc63df9b85d6 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py @@ -0,0 +1,161 @@ +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3085c72c98627eb66e7c944b18b218f5fcee5321 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py @@ -0,0 +1,50 @@ +import math +import torch +from torch import nn + +import audiosr.latent_diffusion.modules.phoneme_encoder.commons as commons +import audiosr.latent_diffusion.modules.phoneme_encoder.attentions as attentions + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab, + out_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4ad4ed1d5e34d95c8380768ec16405d789cc6de4 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2017 Keith Ito + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cea4352df38cbedbf7f8e30cacd2a14eb0eb17f2 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py @@ -0,0 +1,52 @@ +""" from https://github.com/keithito/tacotron """ +from audiosr.latent_diffusion.modules.phoneme_encoder.text import cleaners +from audiosr.latent_diffusion.modules.phoneme_encoder.text.symbols import symbols + + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +cleaner = getattr(cleaners, "english_cleaners2") + + +def text_to_sequence(text, cleaner_names): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + return sequence + + +def cleaned_text_to_sequence(cleaned_text): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + text = cleaner(text) + return text diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..18fd8193fb280f6e10688dc78d61486683f2e41b --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py @@ -0,0 +1,110 @@ +""" from https://github.com/keithito/tacotron """ + +""" +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +""" + +import re +#from unidecode import unidecode +#from phonemizer import phonemize + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + """Pipeline for English text, including abbreviation expansion.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize(text, language="en-us", backend="espeak", strip=True) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +def english_cleaners2(text): + """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize( + text, + language="en-us", + backend="espeak", + strip=True, + preserve_punctuation=True, + with_stress=True, + ) + phonemes = collapse_whitespace(phonemes) + return phonemes diff --git a/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..b419a4e6bcdbe663617c2edebff9100ab09baa6c --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py @@ -0,0 +1,16 @@ +""" from https://github.com/keithito/tacotron """ + +""" +Defines the set of symbols used in text input to the model. +""" +_pad = "_" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/FlashSR/AudioSR/latent_diffusion/util.py b/FlashSR/AudioSR/latent_diffusion/util.py new file mode 100644 index 0000000000000000000000000000000000000000..47225418b8d35a6a010a4a44a7f8bfd013c0ac89 --- /dev/null +++ b/FlashSR/AudioSR/latent_diffusion/util.py @@ -0,0 +1,267 @@ +import importlib + +import torch +import numpy as np +from collections import abc + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + +CACHE = { + "get_vits_phoneme_ids": { + "PAD_LENGTH": 310, + "_pad": "_", + "_punctuation": ';:,.!?¡¿—…"«»“” ', + "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ", + "_special": "♪☎☒☝⚠", + } +} + +CACHE["get_vits_phoneme_ids"]["symbols"] = ( + [CACHE["get_vits_phoneme_ids"]["_pad"]] + + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) + + list(CACHE["get_vits_phoneme_ids"]["_letters"]) + + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) + + list(CACHE["get_vits_phoneme_ids"]["_special"]) +) +CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = { + s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"]) +} + + +def get_vits_phoneme_ids_no_padding(phonemes): + pad_token_id = 0 + pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] + _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] + batchsize = len(phonemes) + + clean_text = phonemes[0] + "⚠" + sequence = [] + + for symbol in clean_text: + if symbol not in _symbol_to_id.keys(): + print("%s is not in the vocabulary. %s" % (symbol, clean_text)) + symbol = "_" + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + + def _pad_phonemes(phonemes_list): + return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) + + sequence = sequence[:pad_length] + + return { + "phoneme_idx": torch.LongTensor(_pad_phonemes(sequence)) + .unsqueeze(0) + .expand(batchsize, -1) + } + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, + data, + n_proc, + target_data_type="ndarray", + cpu_intensive=True, + use_worker_id=False, +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i : i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == "ndarray": + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == "list": + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/FlashSR/BigVGAN/LICENSE b/FlashSR/BigVGAN/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e9663595cc28938f88d6299acd3ba791542e4c0c --- /dev/null +++ b/FlashSR/BigVGAN/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 NVIDIA CORPORATION. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/FlashSR/BigVGAN/LibriTTS/dev-clean.txt b/FlashSR/BigVGAN/LibriTTS/dev-clean.txt new file mode 100644 index 0000000000000000000000000000000000000000..563b86e601c12604b511548fe50a21d57524e438 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/dev-clean.txt @@ -0,0 +1,115 @@ +dev-clean/1272/128104/1272_128104_000001_000000|A 'JOLLY' ART CRITIC +dev-clean/1272/141231/1272_141231_000007_000003|And when he attacked, it was always there to beat him aside. +dev-clean/1272/141231/1272_141231_000033_000002|If anything, he was pressing the attack. +dev-clean/1462/170138/1462_170138_000012_000002|Dear me, Mac, the girl couldn't possibly be better, you know." +dev-clean/1462/170142/1462_170142_000002_000005|Alexander did not sit down. +dev-clean/1462/170142/1462_170142_000029_000001|"I meant to, but somehow I couldn't. +dev-clean/1462/170142/1462_170142_000046_000004|The sight of you, Bartley, to see you living and happy and successful-can I never make you understand what that means to me?" She pressed his shoulders gently. +dev-clean/1462/170145/1462_170145_000012_000003|There is a letter for you there, in my desk drawer. +dev-clean/1462/170145/1462_170145_000033_000000|She felt the strength leap in the arms that held her so lightly. +dev-clean/1673/143397/1673_143397_000031_000007|He attempted to remove or intimidate the leaders by a common sentence, of acquittal or condemnation; he invested his representatives at Ephesus with ample power and military force; he summoned from either party eight chosen deputies to a free and candid conference in the neighborhood of the capital, far from the contagion of popular frenzy. +dev-clean/174/168635/174_168635_000040_000000|To teach Cosette to read, and to let her play, this constituted nearly the whole of Jean Valjean's existence. +dev-clean/174/50561/174_50561_000058_000001|They have the end of the game to themselves.) +dev-clean/174/84280/174_84280_000015_000000|And perhaps in this story I have said enough for you to understand why Mary has identified herself with something world-wide, has added to herself a symbolical value, and why it is I find in the whole crowded spectacle of mankind, a quality that is also hers, a sense of fine things entangled and stifled and unable to free themselves from the ancient limiting jealousies which law and custom embody. +dev-clean/1919/142785/1919_142785_000063_000000|[Illustration: SHALOT.] +dev-clean/1919/142785/1919_142785_000131_000001|Cut the bread into thin slices, place them in a cool oven overnight, and when thoroughly dry and crisp, roll them down into fine crumbs. +dev-clean/1988/147956/1988_147956_000016_000009|He was neatly dressed. +dev-clean/1988/148538/1988_148538_000015_000007|These persons then displayed towards each other precisely the same puerile jealousies which animate the men of democracies, the same eagerness to snatch the smallest advantages which their equals contested, and the same desire to parade ostentatiously those of which they were in possession. +dev-clean/1988/24833/1988_24833_000028_000003|He's taking the kid for a walk when a thunderstorm blows up. +dev-clean/1988/24833/1988_24833_000059_000000|"Doesn't pay enough?" Pop asks. +dev-clean/1993/147149/1993_147149_000051_000002|So leaving kind messages to George and Jane Wilson, and hesitating whether she might dare to send a few kind words to Jem, and deciding that she had better not, she stepped out into the bright morning light, so fresh a contrast to the darkened room where death had been. +dev-clean/1993/147965/1993_147965_000003_000004|I suppose, in the crowded clutter of their cave, the old man had come to believe that peace and order had vanished from the earth, or existed only in the old world he had left so far behind. +dev-clean/1993/147966/1993_147966_000020_000003|We found the chickens asleep; perhaps they thought night had come to stay. +dev-clean/2035/147960/2035_147960_000019_000001|He is all over Jimmy's boots. I scream for him to run, but he just hit and hit that snake like he was crazy." +dev-clean/2035/147961/2035_147961_000011_000002|He grew more and more excited, and kept pointing all around his bed, as if there were things there and he wanted mr Shimerda to see them. +dev-clean/2035/147961/2035_147961_000025_000002|Beside a frozen pond something happened to the other sledge; peter saw it plainly. +dev-clean/2035/152373/2035_152373_000010_000007|saint Aidan, the Apostle of Northumbria, had refused the late Egfrid's father absolution, on one occasion, until he solemnly promised to restore their freedom to certain captives of this description. +dev-clean/2086/149214/2086_149214_000005_000002|It is a legend prolonging itself, from an epoch now gray in the distance, down into our own broad daylight, and bringing along with it some of its legendary mist, which the reader, according to his pleasure, may either disregard, or allow it to float almost imperceptibly about the characters and events for the sake of a picturesque effect. +dev-clean/2086/149220/2086_149220_000016_000003|In short, I make pictures out of sunshine; and, not to be too much dazzled with my own trade, I have prevailed with Miss Hepzibah to let me lodge in one of these dusky gables. +dev-clean/2086/149220/2086_149220_000028_000000|Phoebe was on the point of retreating, but turned back, with some hesitation; for she did not exactly comprehend his manner, although, on better observation, its feature seemed rather to be lack of ceremony than any approach to offensive rudeness. +dev-clean/2277/149874/2277_149874_000007_000001|Her husband asked a few questions and sat down to read the evening paper. +dev-clean/2277/149896/2277_149896_000007_000006|He saw only her pretty face and neat figure and wondered why life was not arranged so that such joy as he found with her could be steadily maintained. +dev-clean/2277/149896/2277_149896_000025_000008|He jangled it fiercely several times in succession, but without avail. +dev-clean/2277/149897/2277_149897_000023_000000|"Well?" said Hurstwood. +dev-clean/2277/149897/2277_149897_000046_000002|He troubled over many little details and talked perfunctorily to everybody. +dev-clean/2412/153954/2412_153954_000004_000005|Even in middle age they were still comely, and the old grey haired women at their cottage doors had a dignity, not to say majesty, of their own. +dev-clean/2428/83699/2428_83699_000009_000000|Now it is a remarkable thing that I have always had an extraordinary predilection for the name Madge. +dev-clean/2428/83699/2428_83699_000024_000004|I had long been wishing that an old-fashioned Christmas had been completely extinct before I had thought of adventuring in quest of one. +dev-clean/2428/83699/2428_83699_000047_000000|"Perhaps you had better come inside." +dev-clean/2428/83705/2428_83705_000015_000004|I did not want any unpleasantness; and I am quite sure there would have been unpleasantness had I demurred. +dev-clean/2428/83705/2428_83705_000034_000002|"And what," inquired mrs Macpherson, "has Mary Ann given you?" +dev-clean/251/118436/251_118436_000017_000001|This man was clad in a brown camel hair robe and sandals, and a green turban was on his head. His expression was tranquil, his gaze impersonal. +dev-clean/251/136532/251_136532_000000_000003|Fitzgerald was still trying to find out how the germ had been transmitted. +dev-clean/251/136532/251_136532_000020_000004|Without question, he had become, overnight, the most widely known archaeologist in history. +dev-clean/251/137823/251_137823_000025_000001|Or grazed, at least," Tom added thankfully. +dev-clean/251/137823/251_137823_000054_000002|The two girls were as much upset as Tom's mother. +dev-clean/2803/154320/2803_154320_000017_000004|Think of Lady Glenarvan; think of Mary Grant!" +dev-clean/2803/154328/2803_154328_000028_000000|Wilson and Olbinett joined their companions, and all united to dig through the wall-john with his dagger, the others with stones taken from the ground, or with their nails, while Mulrady, stretched along the ground, watched the native guard through a crevice of the matting. +dev-clean/2803/154328/2803_154328_000080_000003|Where chance led them, but at any rate they were free. +dev-clean/2803/161169/2803_161169_000011_000019|What do you think of that from the coal tar. +dev-clean/2902/9008/2902_9008_000009_000001|He was a Greek, also, but of a more common, and, perhaps, lower type; dark and fiery, thin and graceful; his delicate figure and cheeks, wasted by meditation, harmonised well with the staid and simple philosophic cloak which he wore as a sign of his profession. +dev-clean/2902/9008/2902_9008_000048_000003|For aught I know or care, the plot may be an exactly opposite one, and the Christians intend to murder all the Jews. +dev-clean/3000/15664/3000_15664_000013_000004|These volcanic caves are not wanting in interest, and it is well to light a pitch pine torch and take a walk in these dark ways of the underworld whenever opportunity offers, if for no other reason to see with new appreciation on returning to the sunshine the beauties that lie so thick about us. +dev-clean/3000/15664/3000_15664_000029_000002|Thus the Shasta River issues from a large lake like spring in Shasta Valley, and about two thirds of the volume of the McCloud gushes forth in a grand spring on the east side of the mountain, a few miles back from its immediate base. +dev-clean/3170/137482/3170_137482_000010_000004|The nobility, the merchants, even workmen in good circumstances, are never seen in the 'magazzino', for cleanliness is not exactly worshipped in such places. +dev-clean/3170/137482/3170_137482_000037_000001|He was celebrated in Venice not only for his eloquence and his great talents as a statesman, but also for the gallantries of his youth. +dev-clean/3536/23268/3536_23268_000028_000000|"It is not the first time, I believe, you have acted contrary to that, Miss Milner," replied mrs Horton, and affected a tenderness of voice, to soften the harshness of her words. +dev-clean/3576/138058/3576_138058_000019_000003|He wondered to see the lance leaning against the tree, the shield on the ground, and Don Quixote in armour and dejected, with the saddest and most melancholy face that sadness itself could produce; and going up to him he said, "Be not so cast down, good man, for you have not fallen into the hands of any inhuman Busiris, but into Roque Guinart's, which are more merciful than cruel." +dev-clean/3752/4943/3752_4943_000026_000002|Lie quiet!" +dev-clean/3752/4943/3752_4943_000056_000002|His flogging wouldn't have killed a flea." +dev-clean/3752/4944/3752_4944_000031_000000|"Well now!" said Meekin, with asperity, "I don't agree with you. Everybody seems to be against that poor fellow-Captain Frere tried to make me think that his letters contained a hidden meaning, but I don't believe they did. +dev-clean/3752/4944/3752_4944_000063_000003|He'd rather kill himself." +dev-clean/3752/4944/3752_4944_000094_000000|"The Government may go to----, and you, too!" roared Burgess. +dev-clean/3853/163249/3853_163249_000058_000000|"I've done it, mother: tell me you're not sorry." +dev-clean/3853/163249/3853_163249_000125_000004|Help me to be brave and strong, David: don't let me complain or regret, but show me what lies beyond, and teach me to believe that simply doing the right is reward and happiness enough." +dev-clean/5338/24615/5338_24615_000004_000003|It had been built at a period when castles were no longer necessary, and when the Scottish architects had not yet acquired the art of designing a domestic residence. +dev-clean/5338/284437/5338_284437_000031_000001|A powerful ruler ought to be rich and to live in a splendid palace. +dev-clean/5536/43358/5536_43358_000012_000001|Being a natural man, the Indian was intensely poetical. +dev-clean/5536/43359/5536_43359_000015_000000|The family was not only the social unit, but also the unit of government. +dev-clean/5694/64025/5694_64025_000004_000006|Our regiment was the advance guard on Saturday evening, and did a little skirmishing; but General Gladden's brigade passed us and assumed a position in our immediate front. +dev-clean/5694/64029/5694_64029_000006_000005|I read it, and looked up to hand it back to him, when I discovered that he had a pistol cocked and leveled in my face, and says he, "Drop that gun; you are my prisoner." I saw there was no use in fooling about it. +dev-clean/5694/64029/5694_64029_000024_000002|The ground was literally covered with blue coats dead; and, if I remember correctly, there were eighty dead horses. +dev-clean/5694/64038/5694_64038_000015_000002|I could not imagine what had become of him. +dev-clean/5895/34615/5895_34615_000013_000003|Man can do nothing to create beauty, but everything to produce ugliness. +dev-clean/5895/34615/5895_34615_000025_000000|With this exception, Gwynplaine's laugh was everlasting. +dev-clean/5895/34622/5895_34622_000029_000002|In the opposite corner was the kitchen. +dev-clean/5895/34629/5895_34629_000021_000005|The sea is a wall; and if Voltaire-a thing which he very much regretted when it was too late-had not thrown a bridge over to Shakespeare, Shakespeare might still be in England, on the other side of the wall, a captive in insular glory. +dev-clean/6241/61943/6241_61943_000020_000000|My uncle came out of his cabin pale, haggard, thin, but full of enthusiasm, his eyes dilated with pleasure and satisfaction. +dev-clean/6241/61946/6241_61946_000014_000000|The rugged summits of the rocky hills were dimly visible on the edge of the horizon, through the misty fogs; every now and then some heavy flakes of snow showed conspicuous in the morning light, while certain lofty and pointed rocks were first lost in the grey low clouds, their summits clearly visible above, like jagged reefs rising from a troublous sea. +dev-clean/6241/61946/6241_61946_000051_000001|Then my uncle, myself, and guide, two boatmen and the four horses got into a very awkward flat bottom boat. +dev-clean/6295/64301/6295_64301_000010_000002|The music was broken, and Joseph left alone with the dumb instruments. +dev-clean/6313/66125/6313_66125_000020_000002|"Are you hurt?" +dev-clean/6313/66125/6313_66125_000053_000000|"Are you ready?" +dev-clean/6313/66129/6313_66129_000011_000001|"Cold water is the most nourishing thing we've touched since last night." +dev-clean/6313/66129/6313_66129_000045_000004|Of course, dogs can't follow the trail of an animal as well, now, as they could with snow on the ground. +dev-clean/6313/66129/6313_66129_000081_000000|Stacy dismounted and removed the hat carefully to one side. +dev-clean/6313/76958/6313_76958_000029_000000|Instantly there was a chorus of yells and snarls from the disturbed cowpunchers, accompanied by dire threats as to what they would do to the gopher did he ever disturb their rest in that way again. +dev-clean/6313/76958/6313_76958_000073_000001|"Those fellows have to go out. +dev-clean/6319/275224/6319_275224_000014_000001|And what is the matter with the beautiful straggling branches, that they are to be cut off as fast as they appear? +dev-clean/6319/57405/6319_57405_000019_000000|"It is rather a silly thing to do," said Deucalion; "and yet there can be no harm in it, and we shall see what will happen." +dev-clean/6319/64726/6319_64726_000017_000002|Then the prince took the princess by the hand; she was dressed in great splendour, but he did not hint that she looked as he had seen pictures of his great grandmother look; he thought her all the more charming for that. +dev-clean/6345/93302/6345_93302_000000_000001|All LibriVox recordings are in the public domain. +dev-clean/6345/93302/6345_93302_000049_000000|The fine tact of a noble woman seemed to have deserted her. +dev-clean/6345/93302/6345_93302_000073_000000|So she said- +dev-clean/6345/93306/6345_93306_000024_000002|What is it? +dev-clean/652/130737/652_130737_000031_000001|Good aroma. +dev-clean/7850/111771/7850_111771_000009_000001|After various flanking movements and costly assaults, the problem of taking Lee narrowed itself down to a siege of Petersburg. +dev-clean/7850/281318/7850_281318_000012_000000|She began to show them how to weave the bits of things together into nests, as they should be made. +dev-clean/7850/286674/7850_286674_000006_000001|You would think that, with six legs apiece and three joints in each leg, they might walk quite fast, yet they never did. +dev-clean/7850/73752/7850_73752_000006_000003|What a Neapolitan ball was his career then! +dev-clean/7976/105575/7976_105575_000009_000000|The burying party the next morning found nineteen dead Rebels lying together at one place. +dev-clean/7976/105575/7976_105575_000017_000000|Our regiment now pursued the flying Rebels with great vigor. +dev-clean/7976/110124/7976_110124_000021_000001|"We two are older and wiser than you are. It is for us to determine what shall be done. +dev-clean/7976/110124/7976_110124_000053_000002|The doors were strong and held securely. +dev-clean/7976/110523/7976_110523_000027_000000|"We will go in here," said Hansel, "and have a glorious feast. +dev-clean/8297/275154/8297_275154_000008_000000|Was this man-haggard, pallid, shabby, looking at him piteously with bloodshot eyes-the handsome, pleasant, prosperous brother whom he remembered? +dev-clean/8297/275154/8297_275154_000024_000011|Tell me where my wife is living now?" +dev-clean/8297/275155/8297_275155_000013_000006|What a perfect gentleman!" +dev-clean/8297/275155/8297_275155_000037_000000|"Say thoroughly worthy of the course forced upon me and my daughter by your brother's infamous conduct-and you will be nearer the mark!" +dev-clean/8297/275156/8297_275156_000013_000005|No more of it now. +dev-clean/84/121123/84_121123_000009_000000|But in less than five minutes the staircase groaned beneath an extraordinary weight. +dev-clean/84/121123/84_121123_000054_000000|It was something terrible to witness the silent agony, the mute despair of Noirtier, whose tears silently rolled down his cheeks. +dev-clean/84/121550/84_121550_000064_000000|And lo! a sudden lustre ran across On every side athwart the spacious forest, Such that it made me doubt if it were lightning. +dev-clean/84/121550/84_121550_000156_000000|Nor prayer for inspiration me availed, By means of which in dreams and otherwise I called him back, so little did he heed them. +dev-clean/84/121550/84_121550_000247_000000|Thus Beatrice; and I, who at the feet Of her commandments all devoted was, My mind and eyes directed where she willed. +dev-clean/8842/302203/8842_302203_000001_000001|And I remember that on the ninth day, being overcome with intolerable pain, a thought came into my mind concerning my lady: but when it had a little nourished this thought, my mind returned to its brooding over mine enfeebled body. diff --git a/FlashSR/BigVGAN/LibriTTS/dev-other.txt b/FlashSR/BigVGAN/LibriTTS/dev-other.txt new file mode 100644 index 0000000000000000000000000000000000000000..9952353bccd0d62f22da2c409ba6764df09bc005 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/dev-other.txt @@ -0,0 +1,93 @@ +dev-other/116/288045/116_288045_000003_000000|PART one +dev-other/116/288045/116_288045_000034_000001|He was only an idol. +dev-other/116/288047/116_288047_000002_000002|Observing the sun, the moon, and the stars overhead, the primitive man wished to account for them. +dev-other/116/288048/116_288048_000001_000000|Let me now give an idea of the method I propose to follow in the study of this subject. +dev-other/116/288048/116_288048_000020_000003|Leaving out Judas, and counting Matthias, who was elected in his place, we have thirteen apostles. +dev-other/1255/138279/1255_138279_000012_000000|"One. +dev-other/1255/138279/1255_138279_000049_000001|Will it be by banns or license?" +dev-other/1255/74899/1255_74899_000020_000000|"Pardon me. +dev-other/1255/90407/1255_90407_000006_000001|But, as the rain gave not the least sign of cessation, he observed: 'I think we shall have to go back.' +dev-other/1255/90407/1255_90407_000039_000002|Into it they plodded without pause, crossing the harbour bridge about midnight, wet to the skin. +dev-other/1255/90413/1255_90413_000023_000001|'Now what the devil this means I cannot tell,' he said to himself, reflecting stock still for a moment on the stairs. +dev-other/1585/131718/1585_131718_000025_000009|Edison marginalized documents extensively. +dev-other/1630/141772/1630_141772_000000_000002|Suddenly he again felt that he was alive and suffering from a burning, lacerating pain in his head. +dev-other/1630/141772/1630_141772_000039_000000|The quiet home life and peaceful happiness of Bald Hills presented itself to him. +dev-other/1630/73710/1630_73710_000019_000003|I almost wish papa would return, though I dread to see him. +dev-other/1630/96099/1630_96099_000033_000001|Why did you not follow him? +dev-other/1650/157641/1650_157641_000037_000001|mr w m +dev-other/1650/173551/1650_173551_000025_000000|Pierre went into that gloomy study which he had entered with such trepidation in his benefactor's lifetime. +dev-other/1651/136854/1651_136854_000046_000005|I have, however, this of gratitude, that I think of you with regard, when I do not, perhaps, give the proofs which I ought, of being, Sir, +dev-other/1686/142278/1686_142278_000015_000000|'No! not doubts as to religion; not the slightest injury to that.' He paused. +dev-other/1686/142278/1686_142278_000042_000001|Margaret was nearly upset again into a burst of crying. +dev-other/1701/141759/1701_141759_000001_000001|Not till midwinter was the count at last handed a letter addressed in his son's handwriting. +dev-other/1701/141759/1701_141759_000048_000000|"Why should you be ashamed?" +dev-other/1701/141760/1701_141760_000013_000003|"I only sent you the note yesterday by Bolkonski-an adjutant of Kutuzov's, who's a friend of mine. +dev-other/1701/141760/1701_141760_000056_000000|In spite of Prince Andrew's disagreeable, ironical tone, in spite of the contempt with which Rostov, from his fighting army point of view, regarded all these little adjutants on the staff of whom the newcomer was evidently one, Rostov felt confused, blushed, and became silent. +dev-other/2506/13150/2506_13150_000022_000000|--Nay-if you don't believe me, you may read the chapter for your pains. +dev-other/3660/172182/3660_172182_000012_000007|And a year, and a second, and a third, he proceeded thus, until his fame had flown over the face of the kingdom. +dev-other/3660/172183/3660_172183_000011_000000|So the maiden went forward, keeping in advance of Geraint, as he had desired her; and it grieved him as much as his wrath would permit, to see a maiden so illustrious as she having so much trouble with the care of the horses. +dev-other/3660/172183/3660_172183_000019_000040|Come with me to the court of a son in law of my sister, which is near here, and thou shalt have the best medical assistance in the kingdom." +dev-other/3660/6517/3660_6517_000036_000002|Bright sunshine. +dev-other/3660/6517/3660_6517_000059_000005|Not a single one has lost his good spirits. +dev-other/3663/172005/3663_172005_000022_000000|She must cross the Slide Brook valley, if possible, and gain the mountain opposite. +dev-other/3663/172528/3663_172528_000016_000008|He had been brought by my very dear friend Luca Martini, who passed the larger portion of the day with me. +dev-other/3915/57461/3915_57461_000018_000001|In a fit of madness I was tempted to kill and rob you. +dev-other/3915/98647/3915_98647_000018_000006|Thus the old custom is passing away. +dev-other/4323/13259/4323_13259_000009_000011|What would Jesus do? +dev-other/4323/13259/4323_13259_000020_000003|It seems she had been recently converted during the evangelist's meetings, and was killed while returning from one of the meetings in company with other converts and some of her friends. +dev-other/4323/18416/4323_18416_000019_000001|So she was asked to sing at musicales and receptions without end, until Alexia exclaimed at last, "They are all raving, stark mad over her, and it's all Polly's own fault, the whole of it." +dev-other/4323/18416/4323_18416_000050_000000|"I know, child; you think your old Grandpapa does just about right," said mr King soothingly, and highly gratified. +dev-other/4323/18416/4323_18416_000079_000002|"And I can't tolerate any thoughts I cannot speak." +dev-other/4323/55228/4323_55228_000028_000000|"Pete told you that I didn't care for any girl, only to paint?" demanded Bertram, angry and mystified. +dev-other/4323/55228/4323_55228_000071_000000|There was another silence. +dev-other/4570/102353/4570_102353_000001_000000|CHAPTER four. +dev-other/4570/14911/4570_14911_000009_000002|EYES-Brown, dark hazel or hazel, not deep set nor bulgy, and with a mild expression. +dev-other/4570/56594/4570_56594_000012_000000|"'No,' says the gentleman. +dev-other/4831/18525/4831_18525_000028_000000|"Oh! isn't it 'Oats, Peas, Beans, and Barley grow'?" cried Polly, as they watched them intently. +dev-other/4831/18525/4831_18525_000078_000001|"I want to write, too, I do," she cried, very much excited. +dev-other/4831/18525/4831_18525_000122_000000|"O dear me!" exclaimed Polly, softly, for she couldn't even yet get over that dreadful beginning. +dev-other/4831/25894/4831_25894_000022_000003|The other days were very much like this; sometimes they made more, sometimes less, but Tommo always 'went halves;' and Tessa kept on, in spite of cold and weariness, for her plans grew as her earnings increased, and now she hoped to get useful things, instead of candy and toys alone. +dev-other/4831/29134/4831_29134_000001_000000|The session was drawing toward its close. +dev-other/4831/29134/4831_29134_000018_000000|"So this poor little boy grew up to be a man, and had to go out in the world, far from home and friends to earn his living. +dev-other/5543/27761/5543_27761_000019_000000|Her mother went to hide. +dev-other/5543/27761/5543_27761_000065_000000|"Agathya says so, madam," answered Fedosya; "it's she that knows." +dev-other/5543/27761/5543_27761_000107_000000|"Sima, my dear, don't agitate yourself," said Sergey Modestovich in a whisper. +dev-other/5849/50873/5849_50873_000026_000000|"He has promised to do so." +dev-other/5849/50873/5849_50873_000074_000000|"The boy did it! +dev-other/5849/50962/5849_50962_000010_000000|"It's a schooner," said mr Bingham to mr Minturn, "and she has a very heavy cargo." +dev-other/5849/50963/5849_50963_000009_000003|Well, it was a long, slow job to drag those heavy logs around that point, and just when we were making headway, along comes a storm that drove the schooner and canoes out of business." +dev-other/5849/50964/5849_50964_000018_000001|There were the shells to be looked after, the fish nets, besides Downy, the duck, and Snoop, the cat. +dev-other/6123/59150/6123_59150_000016_000001|He kicked him two or three times with his heel in the face. +dev-other/6123/59186/6123_59186_000008_000000|"Catering care" is an appalling phrase. +dev-other/6267/53049/6267_53049_000007_000001|"I'd better be putting my grey matter into that algebra instead of wasting it plotting for a party dress that I certainly can't get. +dev-other/6267/53049/6267_53049_000045_000001|I am named after her." +dev-other/6267/65525/6267_65525_000018_000000|Dear mr Lincoln: +dev-other/6267/65525/6267_65525_000045_000006|You can't mistake it." +dev-other/6455/66379/6455_66379_000020_000002|(Deal, sir, if you please; better luck next time.)" +dev-other/6455/67803/6455_67803_000038_000000|"Yes," he answered. +dev-other/6467/56885/6467_56885_000012_000001|As you are so generously taking her on trust, may she never cause you a moment's regret. +dev-other/6467/97061/6467_97061_000010_000000|A terrible battle ensued, in which both kings performed prodigies of valour. +dev-other/6841/88291/6841_88291_000006_000006|One stood waiting for them to finish, a sheaf of long j h stamping irons in his hand. +dev-other/6841/88291/6841_88291_000019_000006|Cries arose in a confusion: "Marker" "Hot iron!" "Tally one!" Dust eddied and dissipated. +dev-other/6841/88294/6841_88294_000010_000003|Usually I didn't bother with his talk, for it didn't mean anything, but something in his voice made me turn. +dev-other/6841/88294/6841_88294_000048_000000|He stood there looking straight at me without winking or offering to move. +dev-other/700/122866/700_122866_000006_000003|You've been thirteen for a month, so I suppose it doesn't seem such a novelty to you as it does to me. +dev-other/700/122866/700_122866_000023_000006|Ruby Gillis is rather sentimental. +dev-other/700/122867/700_122867_000012_000004|My career is closed. +dev-other/700/122867/700_122867_000033_000003|At the end of the week Marilla said decidedly: +dev-other/700/122868/700_122868_000015_000003|mrs Lynde says that all play acting is abominably wicked." +dev-other/700/122868/700_122868_000038_000001|And Ruby is in hysterics-oh, Anne, how did you escape?" +dev-other/7601/101622/7601_101622_000018_000002|The very girls themselves set them on: +dev-other/7601/175351/7601_175351_000031_000008|Still, during the nights which followed the fifteenth of August, darkness was never profound; although the sun set, he still gave sufficient light by refraction. +dev-other/7641/96252/7641_96252_000003_000006|For these are careful only for themselves, for their own egoism, just like the bandit, from whom they are only distinguished by the absurdity of their means. +dev-other/7641/96670/7641_96670_000013_000001|The mist lifted suddenly and she saw three strangers in the palace courtyard. +dev-other/7641/96684/7641_96684_000009_000000|"What years of happiness have been mine, O Apollo, through your friendship for me," said Admetus. +dev-other/7641/96684/7641_96684_000031_000002|How noble it was of Admetus to bring him into his house and give entertainment to him while such sorrow was upon him. +dev-other/7697/105815/7697_105815_000048_000002|And they brought out the jaw bone of an ass with which Samson did such great feats, and the sling and stone with which David slew Goliath of Gath. +dev-other/8173/294714/8173_294714_000006_000001|"Don't spoil my pleasure in seeing you again by speaking of what can never be! Have you still to be told how it is that you find me here alone with my child?" +dev-other/8173/294714/8173_294714_000027_000001|What was there to prevent her from insuring her life, if she pleased, and from so disposing of the insurance as to give Van Brandt a direct interest in her death? +dev-other/8254/115543/8254_115543_000034_000000|"Yes, and how he orders every one about him. +dev-other/8254/84205/8254_84205_000029_000000|"I'm not afraid of them hitting me, my lad," said Griggs confidently. "Being shot at by fellows with bows and arrows sounds bad enough, but there's not much risk here." +dev-other/8254/84205/8254_84205_000073_000000|"Right; I do, neighbour, and it's very handsome of you to offer me the chance to back out. +dev-other/8288/274162/8288_274162_000023_000000|"Exactly. +dev-other/8288/274162/8288_274162_000078_000000|"So much the worse. diff --git a/FlashSR/BigVGAN/LibriTTS/test-clean.txt b/FlashSR/BigVGAN/LibriTTS/test-clean.txt new file mode 100644 index 0000000000000000000000000000000000000000..5bbfab4da31088b3f42dcf2e15238e046d1781dc --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/test-clean.txt @@ -0,0 +1,97 @@ +test-clean/1089/134686/1089_134686_000001_000001|He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce. Stuff it into you, his belly counselled him. +test-clean/1089/134686/1089_134686_000020_000001|We can scut the whole hour. +test-clean/1089/134691/1089_134691_000004_000001|Yet her mistrust pricked him more keenly than his father's pride and he thought coldly how he had watched the faith which was fading down in his soul ageing and strengthening in her eyes. +test-clean/1089/134691/1089_134691_000027_000004|Now, at the name of the fabulous artificer, he seemed to hear the noise of dim waves and to see a winged form flying above the waves and slowly climbing the air. +test-clean/1188/133604/1188_133604_000018_000002|There are just four touches-fine as the finest penmanship-to do that beak; and yet you will find that in the peculiar paroquettish mumbling and nibbling action of it, and all the character in which this nibbling beak differs from the tearing beak of the eagle, it is impossible to go farther or be more precise. +test-clean/121/121726/121_121726_000046_000003|Tied to a woman. +test-clean/121/127105/121_127105_000024_000000|He laughed for the first time. +test-clean/1284/1180/1284_1180_000001_000000|The Crooked Magician +test-clean/1284/1181/1284_1181_000005_000000|The head of the Patchwork Girl was the most curious part of her. +test-clean/1320/122612/1320_122612_000019_000005|It is true that the horses are here, but the Hurons are gone; let us, then, hunt for the path by which they parted." +test-clean/1320/122612/1320_122612_000056_000002|Then he reappeared, creeping along the earth, from which his dress was hardly distinguishable, directly in the rear of his intended captive. +test-clean/1580/141083/1580_141083_000012_000000|"The first page on the floor, the second in the window, the third where you left it," said he. +test-clean/1580/141083/1580_141083_000041_000003|Above were three students, one on each story. +test-clean/1580/141083/1580_141083_000063_000001|Holmes held it out on his open palm in the glare of the electric light. +test-clean/1580/141083/1580_141083_000110_000001|Where were you when you began to feel bad?" +test-clean/1580/141084/1580_141084_000024_000002|Pencils, too, and knives-all was satisfactory. +test-clean/1580/141084/1580_141084_000060_000001|"I frankly admit that I am unable to prove it. +test-clean/1580/141084/1580_141084_000085_000000|"Good heavens! have you nothing to add?" cried Soames. +test-clean/1995/1826/1995_1826_000022_000001|Miss Taylor did not know much about cotton, but at least one more remark seemed called for. +test-clean/1995/1836/1995_1836_000016_000001|No, of course there was no immediate danger; but when people were suddenly thrust beyond their natural station, filled with wild ideas and impossible ambitions, it meant terrible danger to Southern white women. +test-clean/1995/1837/1995_1837_000024_000000|He heard that she was down stairs and ran to meet her with beating heart. +test-clean/2300/131720/2300_131720_000016_000005|Having travelled around the world, I had cultivated an indifference to any special difficulties of that kind. +test-clean/2300/131720/2300_131720_000030_000005|I telephoned again, and felt something would happen, but fortunately it did not. +test-clean/237/126133/237_126133_000002_000004|It got to be noticed finally; and one and all redoubled their exertions to make everything twice as pleasant as ever! +test-clean/237/126133/237_126133_000049_000000|But the chubby face didn't look up brightly, as usual: and the next moment, without a bit of warning, Phronsie sprang past them all, even Polly, and flung herself into mr King's arms, in a perfect torrent of sobs. +test-clean/237/134493/237_134493_000008_000003|Alexandra lets you sleep late. +test-clean/237/134500/237_134500_000001_000001|Frank sat up until a late hour reading the Sunday newspapers. +test-clean/237/134500/237_134500_000014_000000|"I don't know all of them, but I know lindens are. +test-clean/237/134500/237_134500_000034_000000|She sighed despondently. +test-clean/260/123286/260_123286_000019_000002|Therefore don't talk to me about views and prospects." +test-clean/260/123286/260_123286_000049_000005|He shakes his head negatively. +test-clean/260/123288/260_123288_000016_000002|It rushes on from the farthest recesses of the vast cavern. +test-clean/260/123288/260_123288_000043_000001|I could just see my uncle at full length on the raft, and Hans still at his helm and spitting fire under the action of the electricity which has saturated him. +test-clean/2830/3979/2830_3979_000007_000000|PREFACE +test-clean/2830/3980/2830_3980_000018_000001|Humble man that he was, he will not now take a back seat. +test-clean/2961/961/2961_961_000004_000037|Then your city did bravely, and won renown over the whole earth. +test-clean/2961/961/2961_961_000023_000003|But violent as were the internal and alimentary fluids, the tide became still more violent when the body came into contact with flaming fire, or the solid earth, or gliding waters, or the stormy wind; the motions produced by these impulses pass through the body to the soul and have the name of sensations. +test-clean/3570/5694/3570_5694_000009_000003|The canon of reputability is at hand and seizes upon such innovations as are, according to its standard, fit to survive. +test-clean/3570/5695/3570_5695_000001_000003|But the middle class wife still carries on the business of vicarious leisure, for the good name of the household and its master. +test-clean/3570/5695/3570_5695_000009_000005|Considered by itself simply-taken in the first degree-this added provocation to which the artisan and the urban laboring classes are exposed may not very seriously decrease the amount of savings; but in its cumulative action, through raising the standard of decent expenditure, its deterrent effect on the tendency to save cannot but be very great. +test-clean/3570/5696/3570_5696_000011_000006|For this is the basis of award of the instinct of workmanship, and that instinct is the court of final appeal in any question of economic truth or adequacy. +test-clean/3729/6852/3729_6852_000004_000003|In order to please her, I spoke to her of the Abbe Conti, and I had occasion to quote two lines of that profound writer. +test-clean/4077/13754/4077_13754_000002_000000|The troops, once in Utah, had to be provisioned; and everything the settlers could spare was eagerly bought at an unusual price. The gold changed hands. +test-clean/4446/2271/4446_2271_000003_000004|There's everything in seeing Hilda while she's fresh in a part. +test-clean/4446/2271/4446_2271_000020_000001|Lady Westmere is very fond of Hilda." +test-clean/4446/2273/4446_2273_000008_000002|I've no need for fine clothes in Mac's play this time, so I can afford a few duddies for myself. +test-clean/4446/2273/4446_2273_000027_000004|She did my blouses beautifully the last time I was there, and was so delighted to see me again. +test-clean/4446/2273/4446_2273_000046_000001|"Aren't you afraid to let the wind low like that on your neck? +test-clean/4446/2275/4446_2275_000013_000000|Hilda was pale by this time, and her eyes were wide with fright. +test-clean/4446/2275/4446_2275_000038_000006|"You want to tell me that you can only see me like this, as old friends do, or out in the world among people? +test-clean/4507/16021/4507_16021_000011_000000|It engenders a whole world, la pegre, for which read theft, and a hell, la pegrenne, for which read hunger. +test-clean/4507/16021/4507_16021_000030_000001|Facts form one of these, and ideas the other. +test-clean/4970/29093/4970_29093_000010_000000|Delightful illusion of paint and tinsel and silk attire, of cheap sentiment and high and mighty dialogue! +test-clean/4970/29093/4970_29093_000047_000000|"Never mind the map. +test-clean/4970/29095/4970_29095_000021_000000|"I will practice it." +test-clean/4970/29095/4970_29095_000055_000002|He took it with him from the Southern Hotel, when he went to walk, and read it over and again in an unfrequented street as he stumbled along. +test-clean/4992/41797/4992_41797_000014_000002|He keeps the thou shalt not commandments first rate, Hen Lord does! +test-clean/4992/41806/4992_41806_000020_000001|Thou who settest the solitary in families, bless the life that is sheltered here. +test-clean/5105/28241/5105_28241_000004_000004|The late astounding events, however, had rendered Procope manifestly uneasy, and not the less so from his consciousness that the count secretly partook of his own anxiety. +test-clean/5142/33396/5142_33396_000004_000004|At the prow I carved the head with open mouth and forked tongue thrust out. +test-clean/5142/33396/5142_33396_000039_000000|"The thralls were bringing in a great pot of meat. +test-clean/5142/36377/5142_36377_000013_000003|I liked Naomi Colebrook at first sight; liked her pleasant smile; liked her hearty shake of the hand when we were presented to each other. +test-clean/5639/40744/5639_40744_000003_000006|Mother! dear father! do you hear me? +test-clean/5639/40744/5639_40744_000022_000000|Just then Leocadia came to herself, and embracing the cross seemed changed into a sea of tears, and the gentleman remained in utter bewilderment, until his wife had repeated to him, from beginning to end, Leocadia's whole story; and he believed it, through the blessed dispensation of Heaven, which had confirmed it by so many convincing testimonies. +test-clean/5683/32865/5683_32865_000018_000000|Well, it was pretty-French, I dare say-a little set of tablets-a toy-the cover of enamel, studded in small jewels, with a slender border of symbolic flowers, and with a heart in the centre, a mosaic of little carbuncles, rubies, and other red and crimson stones, placed with a view to light and shade. +test-clean/5683/32866/5683_32866_000005_000000|'Did you see that?' said Wylder in my ear, with a chuckle; and, wagging his head, he added, rather loftily for him, 'Miss Brandon, I reckon, has taken your measure, Master Stanley, as well as i I wonder what the deuce the old dowager sees in him. +test-clean/5683/32866/5683_32866_000047_000002|I was not a bit afraid of being found out. +test-clean/5683/32879/5683_32879_000036_000002|Be he near, or be he far, I regard his very name with horror.' +test-clean/6829/68769/6829_68769_000011_000000|So as soon as breakfast was over the next morning Beth and Kenneth took one of the automobiles, the boy consenting unwillingly to this sort of locomotion because it would save much time. +test-clean/6829/68769/6829_68769_000051_000001|One morning she tried to light the fire with kerosene, and lost her sight. +test-clean/6829/68769/6829_68769_000089_000001|Why should you do all this?" +test-clean/6829/68771/6829_68771_000018_000003|A speakers' stand, profusely decorated, had been erected on the lawn, and hundreds of folding chairs provided for seats. +test-clean/6930/75918/6930_75918_000000_000001|Night. +test-clean/6930/81414/6930_81414_000041_000001|Here is his scarf, which has evidently been strained, and on it are spots of blood, while all around are marks indicating a struggle. +test-clean/7021/79740/7021_79740_000010_000006|I observe that, when you both wish for the same thing, you don't quarrel for it and try to pull it away from one another; but one waits like a lady until the other has done with it. +test-clean/7021/85628/7021_85628_000017_000000|"I am going to the court ball," answered Anders. +test-clean/7127/75946/7127_75946_000022_000002|It is necessary, therefore, that he should comply." +test-clean/7127/75946/7127_75946_000061_000001|Disdainful of a success of which Madame showed no acknowledgement, he thought of nothing but boldly regaining the marked preference of the princess. +test-clean/7127/75947/7127_75947_000035_000000|"Quite true, and I believe you are right. +test-clean/7176/88083/7176_88083_000002_000003|He was too imposing in appearance, too gorgeous in apparel, too bold and vigilant in demeanor to be so misunderstood. +test-clean/7176/88083/7176_88083_000017_000000|Immediately over his outstretched gleaming head flew the hawk. +test-clean/7176/92135/7176_92135_000011_000000|And, so on in the same vein for some thirty lines. +test-clean/7176/92135/7176_92135_000074_000001|Tea, please, Matthews. +test-clean/7729/102255/7729_102255_000011_000003|The Free State Hotel served as barracks. +test-clean/7729/102255/7729_102255_000028_000009|They were squads of Kansas militia, companies of "peaceful emigrants," or gangs of irresponsible outlaws, to suit the chance, the whim, or the need of the moment. +test-clean/8230/279154/8230_279154_000003_000002|In the present lecture I shall attempt the analysis of memory knowledge, both as an introduction to the problem of knowledge in general, and because memory, in some form, is presupposed in almost all other knowledge. +test-clean/8230/279154/8230_279154_000013_000003|One of these is context. +test-clean/8230/279154/8230_279154_000027_000000|A further stage is RECOGNITION. +test-clean/8455/210777/8455_210777_000022_000003|And immediately on his sitting down, there got up a gentleman to whom I had not been introduced before this day, and gave the health of Mrs Neverbend and the ladies of Britannula. +test-clean/8455/210777/8455_210777_000064_000001|Government that he shall be treated with all respect, and that those honours shall be paid to him which are due to the President of a friendly republic. +test-clean/8463/287645/8463_287645_000023_000001|For instance, Jacob Taylor was noticed on the record book as being twenty three years of age, and the name of his master was entered as "William Pollit;" but as Jacob had never been allowed to learn to read, he might have failed in giving a correct pronunciation of the name. +test-clean/8463/294825/8463_294825_000048_000000|- CENTIMETER Roughly two fifths of an inch +test-clean/8463/294828/8463_294828_000046_000001|Conseil did them in a flash, and I was sure the lad hadn't missed a thing, because he classified shirts and suits as expertly as birds and mammals. +test-clean/8555/284447/8555_284447_000018_000002|The poor Queen, by the way, was seldom seen, as she passed all her time playing solitaire with a deck that was one card short, hoping that before she had lived her entire six hundred years she would win the game. +test-clean/8555/284447/8555_284447_000049_000000|Now, indeed, the Boolooroo was as angry as he was amazed. +test-clean/8555/284449/8555_284449_000039_000000|When the courtiers and the people assembled saw the goat they gave a great cheer, for the beast had helped to dethrone their wicked Ruler. +test-clean/8555/292519/8555_292519_000041_000000|She was alone that night. He had broken into her courtyard. Above the gurgling gutters he heard- surely- a door unchained? diff --git a/FlashSR/BigVGAN/LibriTTS/test-other.txt b/FlashSR/BigVGAN/LibriTTS/test-other.txt new file mode 100644 index 0000000000000000000000000000000000000000..930b528eb7a3e92597cd1d5ad3d2519143ad2051 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/test-other.txt @@ -0,0 +1,103 @@ +test-other/1688/142285/1688_142285_000000_000000|'Margaret!' said mr Hale, as he returned from showing his guest downstairs; 'I could not help watching your face with some anxiety, when mr Thornton made his confession of having been a shop boy. +test-other/1688/142285/1688_142285_000046_000000|'No, mamma; that Anne Buckley would never have done.' +test-other/1998/15444/1998_15444_000012_000000|Simple filtration will sometimes suffice to separate the required substance; in other cases dialysis will be necessary, in order that crystalloid substances may be separated from colloid bodies. +test-other/1998/29454/1998_29454_000021_000001|Fried eggs and bacon-he had one egg and the man had three-bread and butter-and if the bread was thick, so was the butter-and as many cups of tea as you liked to say thank you for. +test-other/1998/29454/1998_29454_000053_000001|It almost looked, Dickie thought, as though he had brought them out for some special purpose. +test-other/1998/29455/1998_29455_000022_000000|It was a wonderful day. +test-other/1998/29455/1998_29455_000082_000003|But 'e's never let it out." +test-other/2414/128292/2414_128292_000003_000000|"What!" said he, "have not the most ludicrous things always happened to us old anchorites and saints? +test-other/2609/156975/2609_156975_000036_000004|The cruel fate of his people and the painful experience in Egypt that had driven him into the wilderness prepared his mind to receive this training. +test-other/3005/163389/3005_163389_000017_000001|And they laughed all the time, and that made the duke mad; and everybody left, anyway, before the show was over, but one boy which was asleep. +test-other/3005/163390/3005_163390_000023_000021|S'pose people left money laying around where he was what did he do? +test-other/3005/163391/3005_163391_000021_000000|"It's a pretty long journey. +test-other/3005/163399/3005_163399_000013_000002|When we got there she set me down in a split bottomed chair, and set herself down on a little low stool in front of me, holding both of my hands, and says: +test-other/3005/163399/3005_163399_000045_000000|He sprung to the window at the head of the bed, and that give mrs Phelps the chance she wanted. +test-other/3080/5040/3080_5040_000000_000010|You have no such ladies in Ireland? +test-other/3331/159605/3331_159605_000006_000002|I could do so much for all at home how I should enjoy that!" And Polly let her thoughts revel in the luxurious future her fancy painted. +test-other/3331/159605/3331_159605_000082_000000|"Who got up that nice idea, I should like to know?" demanded Polly, as Fanny stopped for breath. +test-other/3528/168656/3528_168656_000003_000003|She told wonders of the Abbey of Fontevrault,--that it was like a city, and that there were streets in the monastery. +test-other/3528/168669/3528_168669_000030_000000|A silence ensued. +test-other/3528/168669/3528_168669_000075_000000|"Like yourself, reverend Mother." +test-other/3528/168669/3528_168669_000123_000000|"But the commissary of police-" +test-other/3528/168669/3528_168669_000137_000000|"That is well." +test-other/3528/168669/3528_168669_000164_000008|I shall have my lever. +test-other/3538/142836/3538_142836_000021_000003|However, as late as the reigns of our two last Georges, fabulous sums were often expended upon fanciful desserts. +test-other/3538/163619/3538_163619_000054_000000|'Now he says that you are to make haste and throw yourself overboard,' answered the step mother. +test-other/3538/163622/3538_163622_000069_000000|So they travelled onwards again, for many and many a mile, over hill and dale. +test-other/3538/163624/3538_163624_000038_000000|Then Sigurd went down into that deep place, and dug many pits in it, and in one of the pits he lay hidden with his sword drawn. +test-other/367/130732/367_130732_000002_000001|Probably nowhere in San Francisco could one get lobster better served than in the Old Delmonico restaurant of the days before the fire. +test-other/3764/168670/3764_168670_000003_000000|"But you, Father Madeleine?" +test-other/3764/168670/3764_168670_000043_000000|"Yes." +test-other/3764/168670/3764_168670_000083_000005|He grumbled:-- +test-other/3764/168671/3764_168671_000012_000003|He did what he liked with him. +test-other/3764/168671/3764_168671_000046_000000|"Comrade!" cried Fauchelevent. +test-other/3997/180294/3997_180294_000023_000000|Then, when God allows love to a courtesan, that love, which at first seems like a pardon, becomes for her almost without penitence. +test-other/3997/180294/3997_180294_000065_000001|The count will be coming back, and there is nothing to be gained by his finding you here." +test-other/3997/180297/3997_180297_000034_000004|For these people we have to be merry when they are merry, well when they want to sup, sceptics like themselves. +test-other/3997/182399/3997_182399_000014_000003|Oh, my, no! +test-other/4198/61336/4198_61336_000000_000003|It is significant to note in this connection that the new king was an unswerving adherent of the cult of Ashur, by the adherents of which he was probably strongly supported. +test-other/4198/61336/4198_61336_000033_000001|Nabonassar had died and was succeeded by his son Nabu nadin zeri, who, after reigning for two years, was slain in a rebellion. +test-other/4294/14317/4294_14317_000022_000011|I do not condescend to smite you. He looked at me submissively and said nothing. +test-other/4294/35475/4294_35475_000018_000001|At last they reached a wide chasm that bounded the Ogre's domain. +test-other/4294/35475/4294_35475_000050_000002|They said, "We are only waiting to lay some wily plan to capture the Ogre." +test-other/4294/9934/4294_9934_000025_000000|"Gold; here it is." +test-other/4350/10919/4350_10919_000006_000000|"Immediately, princess. +test-other/4350/9170/4350_9170_000005_000001|Authority, in the sense in which the word is ordinarily understood, is a means of forcing a man to act in opposition to his desires. +test-other/4350/9170/4350_9170_000056_000000|But the fatal significance of universal military service, as the manifestation of the contradiction inherent in the social conception of life, is not only apparent in that. +test-other/4852/28311/4852_28311_000031_000001|After a step or two, not finding his friend beside him, he turned. +test-other/4852/28319/4852_28319_000013_000002|mr Wicker waited patiently beside him for a few moments for Chris to get up his courage. +test-other/533/1066/533_1066_000008_000000|"I mean," he persisted, "do you feel as though you could go through with something rather unusual?" +test-other/533/131562/533_131562_000018_000000|mr Huntingdon then went up stairs. +test-other/5442/41168/5442_41168_000002_000001|Sergey Ivanovitch, waiting till the malignant gentleman had finished speaking, said that he thought the best solution would be to refer to the act itself, and asked the secretary to find the act. +test-other/5442/41169/5442_41169_000003_000000|"He's such a blackguard! +test-other/5442/41169/5442_41169_000030_000000|"And with what he made he'd increase his stock, or buy some land for a trifle, and let it out in lots to the peasants," Levin added, smiling. He had evidently more than once come across those commercial calculations. +test-other/5484/24317/5484_24317_000040_000006|Let us hope that you will make this three leaved clover the luck promising four leaved one. +test-other/5484/24318/5484_24318_000015_000002|The blood of these innocent men would be on his head if he did not listen to her representations. +test-other/5484/24318/5484_24318_000068_000001|He was appearing before his companions only to give truth its just due. +test-other/5764/299665/5764_299665_000041_000004|He saw the seeds that man had planted wither and perish, but he sent no rain. +test-other/5764/299665/5764_299665_000070_000000|Think of the egotism of a man who believes that an infinite being wants his praise! +test-other/5764/299665/5764_299665_000102_000000|The first stone is that matter-substance-cannot be destroyed, cannot be annihilated. +test-other/5764/299665/5764_299665_000134_000000|You cannot reform these people with tracts and talk. +test-other/6070/63485/6070_63485_000025_000003|Hand me the cash, and I will hand you the pocketbook." +test-other/6070/86744/6070_86744_000027_000000|"Have you bachelor's apartments there? +test-other/6070/86745/6070_86745_000001_000002|Two windows only of the pavilion faced the street; three other windows looked into the court, and two at the back into the garden. +test-other/6128/63240/6128_63240_000012_000002|Neither five nor fifteen, and yet not ten exactly, but either nine or eleven. +test-other/6128/63240/6128_63240_000042_000002|mrs Luna explained to her sister that her freedom of speech was caused by his being a relation-though, indeed, he didn't seem to know much about them. +test-other/6128/63244/6128_63244_000002_000000|"I can't talk to those people, I can't!" said Olive Chancellor, with a face which seemed to plead for a remission of responsibility. +test-other/6432/63722/6432_63722_000026_000000|"Not the least in the world-not as much as you do," was the cool answer. +test-other/6432/63722/6432_63722_000050_000004|Queen Elizabeth was very fond of watches and clocks, and her friends, knowing that, used to present her with beautiful specimens. Some of the watches of her day were made in the form of crosses, purses, little books, and even skulls." +test-other/6432/63722/6432_63722_000080_000003|When it does it will create a sensation." +test-other/6432/63723/6432_63723_000026_000000|"No; but he will, or I'll sue him and get judgment. +test-other/6432/63723/6432_63723_000057_000000|"Then for the love of-" +test-other/6432/63723/6432_63723_000080_000000|"Hello, Harry! +test-other/6938/70848/6938_70848_000046_000003|Show me the source!" +test-other/6938/70848/6938_70848_000104_000000|With biting sarcasm he went on to speak of the Allied diplomats, till then contemptuous of Russia's invitation to an armistice, which had been accepted by the Central Powers. +test-other/7105/2330/7105_2330_000021_000000|"He won't go unless he has a brass band. +test-other/7105/2340/7105_2340_000015_000001|We feel that we must live on cream for the rest of our lives. +test-other/7902/96591/7902_96591_000008_000001|I did not come to frighten you; you frightened me." +test-other/7902/96591/7902_96591_000048_000000|"No," he thought to himself, "I don't believe they would kill me, but they would knock me about." +test-other/7902/96592/7902_96592_000024_000001|Once out of that room he could ran, and by daylight the smugglers dare not hunt him down. +test-other/7902/96592/7902_96592_000063_000000|"What for?" cried Ram. +test-other/7902/96594/7902_96594_000014_000001|These fellows are very cunning, but we shall be too many for them one of these days." +test-other/7902/96594/7902_96594_000062_000001|Keep a sharp look out on the cliff to see if Mr Raystoke is making signals for a boat. +test-other/7902/96595/7902_96595_000039_000000|The man shook his head, and stared as if he didn't half understand the drift of what was said. +test-other/7975/280057/7975_280057_000009_000000|Naturally we were Southerners in sympathy and in fact. +test-other/7975/280057/7975_280057_000025_000004|On reaching the camp the first person I saw whom I knew was Cole Younger. +test-other/7975/280076/7975_280076_000013_000001|I will give you this outline and sketch of my whereabouts and actions at the time of certain robberies with which I am charged. +test-other/7975/280084/7975_280084_000007_000000|But between the time we broke camp and the time they reached the bridge the three who went ahead drank a quart of whisky, and there was the initial blunder at Northfield. +test-other/7975/280085/7975_280085_000005_000002|Some of the boys wanted to kill him, on the theory that "dead men tell no tales," while others urged binding him and leaving him in the woods. +test-other/8131/117016/8131_117016_000005_000000|The Stonewall gang numbered perhaps five hundred. +test-other/8131/117016/8131_117016_000025_000001|"And don't let them get away!" +test-other/8131/117016/8131_117016_000047_000006|I can always go back to Earth, and I'll try to take you along. +test-other/8131/117017/8131_117017_000005_000000|Gordon hit the signal switch, and the Marspeaker let out a shrill whistle. +test-other/8131/117017/8131_117017_000020_000003|There's no graft out here." +test-other/8131/117029/8131_117029_000007_000002|Wrecks were being broken up, with salvageable material used for newer homes. Gordon came to a row of temporary bubbles, individual dwellings built like the dome, but opaque for privacy. +test-other/8131/117029/8131_117029_000023_000004|But there'll be pushers as long as weak men turn to drugs, and graft as long as voters allow the thing to get out of their hands. +test-other/8188/269288/8188_269288_000018_000000|A few moments later there came a tap at the door. +test-other/8188/269288/8188_269288_000053_000001|"Do you want to kill me? +test-other/8188/269290/8188_269290_000035_000001|"But now, Leslie, what is the trouble? +test-other/8188/269290/8188_269290_000065_000000|"I don't think she is quite well," replied Leslie. +test-other/8280/266249/8280_266249_000030_000000|The ladies were weary, and retired to their state rooms shortly after tea, but the gentlemen sought the open air again and paced the deck for some time. +test-other/8280/266249/8280_266249_000113_000000|It was the last game of cards for that trip. +test-other/8461/278226/8461_278226_000026_000000|Laura thanked the French artist and then took her husband's arm and walked away with him. +test-other/8461/281231/8461_281231_000029_000002|Before long the towering flames had surmounted every obstruction, and rose to the evening skies one huge and burning beacon, seen far and wide through the adjacent country; tower after tower crashed down, with blazing roof and rafter. diff --git a/FlashSR/BigVGAN/LibriTTS/val-full.txt b/FlashSR/BigVGAN/LibriTTS/val-full.txt new file mode 100644 index 0000000000000000000000000000000000000000..e06cd55f99d1b7835af616420d8f026eab221134 --- /dev/null +++ b/FlashSR/BigVGAN/LibriTTS/val-full.txt @@ -0,0 +1,119 @@ +train-clean-100/103/1241/103_1241_000000_000001|matthew Cuthbert is surprised +train-clean-100/1594/135914/1594_135914_000033_000001|He told them, that having taken refuge in a small village, he there fell sick; that some charitable peasants had taken care of him, but finding he did not recover, a camel driver had undertaken to carry him to the hospital at Bagdad. +train-clean-100/233/155990/233_155990_000018_000002|I did, however, receive aid from the Emperor of Germany. +train-clean-100/3240/131231/3240_131231_000041_000003|Some persons, thinking them to be sea fishes, placed them in salt water, according to mr Roberts. +train-clean-100/40/222/40_222_000026_000000|"No, read it yourself," cried Catherine, whose second thoughts were clearer. +train-clean-100/4406/16882/4406_16882_000014_000002|Then they set me upon a horse with my wounded child in my lap, and there being no furniture upon the horse's back, as we were going down a steep hill we both fell over the horse's head, at which they, like inhumane creatures, laughed, and rejoiced to see it, though I thought we should there have ended our days, as overcome with so many difficulties. +train-clean-100/5393/19218/5393_19218_000115_000000|"Where is it going then?" +train-clean-100/6147/34606/6147_34606_000013_000008|One was "a dancing master;" that is to say he made the rustics frisk about by pricking the calves of their legs with the point of his sword. +train-clean-100/6848/76049/6848_76049_000003_000007|But suppose she was not all ordinary female person.... +train-clean-100/7505/258964/7505_258964_000026_000007|During the Boer War horses and mules rose in price in the United States on account of British purchases. +train-clean-100/831/130739/831_130739_000015_000000|But enough of these revelations. +train-clean-100/887/123291/887_123291_000028_000000|Here the Professor laid hold of the fossil skeleton, and handled it with the skill of a dexterous showman. +train-clean-360/112/123216/112_123216_000035_000009|The wonderful day had come and Roy's violets had no place in it. +train-clean-360/1323/149236/1323_149236_000007_000004|It was vain to hope that mere words would quiet a nation which had not, in any age, been very amenable to control, and which was now agitated by hopes and resentments, such as great revolutions, following great oppressions, naturally engender. +train-clean-360/1463/134465/1463_134465_000058_000000|Both Sandy and I began to laugh. +train-clean-360/1748/1562/1748_1562_000067_000000|"Oh, Pocket, Pocket," said I; but by this time the party which had gone towards the house, rushed out again, shouting and screaming with laughter. +train-clean-360/1914/133440/1914_133440_000014_000001|With the last twenty or thirty feet of it a deadly nausea came upon me. +train-clean-360/207/143321/207_143321_000070_000002|The canoes were not on the river bank. +train-clean-360/2272/152267/2272_152267_000003_000001|After supper the knight shared his own bed with the leper. +train-clean-360/2517/135227/2517_135227_000006_000005|As I was anxious to witness some of their purely religious ceremonies, I wished to go. +train-clean-360/2709/158074/2709_158074_000054_000000|Meanwhile the women continued to protest. +train-clean-360/2929/86777/2929_86777_000009_000000|A long silence followed; the peach, like the grapes, fell to the ground. +train-clean-360/318/124224/318_124224_000022_000010|In spite of his prejudice against Edward, he could put himself into Mr Waller's place, and see the thing from his point of view. +train-clean-360/3368/170952/3368_170952_000006_000000|And can he be fearless of death, or will he choose death in battle rather than defeat and slavery, who believes the world below to be real and terrible? +train-clean-360/3549/9203/3549_9203_000005_000004|We must hope so. There are examples. +train-clean-360/3835/178028/3835_178028_000007_000001|That day Prince Vasili no longer boasted of his protege Kutuzov, but remained silent when the commander in chief was mentioned. +train-clean-360/3994/149798/3994_149798_000005_000002|Afterward we can visit the mountain and punish the cruel magician of the Flatheads." +train-clean-360/4257/6397/4257_6397_000009_000000|At that time Nostromo had been already long enough in the country to raise to the highest pitch Captain Mitchell's opinion of the extraordinary value of his discovery. +train-clean-360/454/134728/454_134728_000133_000000|After a week of physical anguish, Unrest and pain, and feverish heat, Toward the ending day a calm and lull comes on, Three hours of peace and soothing rest of brain. +train-clean-360/4848/28247/4848_28247_000026_000002|Had he gained this arduous height only to behold the rocks carpeted with ice and snow, and reaching interminably to the far off horizon? +train-clean-360/5039/1189/5039_1189_000091_000000|The Shaggy Man sat down again and seemed well pleased. +train-clean-360/5261/19373/5261_19373_000011_000001|Some cause was evidently at work on this distant planet, causing it to disagree with its motion as calculated according to the law of gravitation. +train-clean-360/5538/70919/5538_70919_000032_000001|Only one person in the world could have laid those discoloured pearls at his door in the dead of night. The black figure in the garden, with the chiffon fluttering about its head, was Evelina Grey-or what was left of her. +train-clean-360/5712/48848/5712_48848_000060_000003|Lily for the time had been raised to a pinnacle,--a pinnacle which might be dangerous, but which was, at any rate, lofty. +train-clean-360/5935/43322/5935_43322_000050_000002|I think too-yes, I think that on the whole the ritual is impressive. +train-clean-360/6115/58433/6115_58433_000007_000002|We must run the risk." +train-clean-360/6341/64956/6341_64956_000040_000000|"Why, papa, I thought we were going to have such a nice time, and she just spoiled it all." +train-clean-360/6509/67147/6509_67147_000028_000003|It "was n't done" in England. +train-clean-360/6694/70837/6694_70837_000027_000002|There an enormous smiling sailor stopped me, and when I showed my pass, just said, "If you were Saint Michael himself, comrade, you couldn't pass here!" Through the glass of the door I made out the distorted face and gesticulating arms of a French correspondent, locked in.... +train-clean-360/6956/76046/6956_76046_000055_000001|Twelve hundred, fifteen hundred millions perhaps." +train-clean-360/7145/87280/7145_87280_000004_000003|This modern Ulysses made a masterful effort, but alas! had no ships to carry him away, and no wax with which to fill his ears. +train-clean-360/7314/77782/7314_77782_000011_000000|"Well, then, what in thunder is the matter with you?" cried the Lawyer, irritated. +train-clean-360/7525/92915/7525_92915_000034_000001|It was desperate, too, and lasted nearly all day-and it was one of the important battles of the world, although the numbers engaged in it were not large. +train-clean-360/7754/108640/7754_108640_000001_000004|Was I aware-was I fully aware of the discrepancy between us? +train-clean-360/7909/106369/7909_106369_000006_000002|And Colchian Aea lies at the edge of Pontus and of the world." +train-clean-360/8011/280922/8011_280922_000009_000000|He stretched out his hand, and all at once stroked my cheek. +train-clean-360/8176/115046/8176_115046_000027_000001|"Bless my soul, I never can understand it!" +train-clean-360/8459/292347/8459_292347_000015_000000|A woman near Gort, in Galway, says: 'There is a boy, now, of the Cloran's; but I wouldn't for the world let them think I spoke of him; it's two years since he came from America, and since that time he never went to Mass, or to church, or to fairs, or to market, or to stand on the cross roads, or to hurling, or to nothing. +train-clean-360/8699/291107/8699_291107_000003_000005|He leaned closer over it, regardless of the thin choking haze that spread about his face. In his attitude there was a rigidity of controlled excitement out of keeping with the seeming harmlessness of the experiment. +train-clean-360/8855/283242/8855_283242_000061_000000|"That couldn't be helped, grannie. +train-other-500/102/129232/102_129232_000050_000000|Is it otherwise in the newest romance? +train-other-500/1124/134775/1124_134775_000087_000001|Some of them are enclosed only by hedges, which lends a cheerful aspect to the street. +train-other-500/1239/138254/1239_138254_000010_000001|It was past twelve when all preparations were finished. +train-other-500/1373/132103/1373_132103_000056_000000|So they moved on. +train-other-500/1566/153036/1566_153036_000087_000003|You enter the river close by the trees, and then keep straight for the pile of stones, which is some fifty yards higher up, for the ford crosses the river at an angle." +train-other-500/1653/142352/1653_142352_000005_000002|If he should not come! +train-other-500/1710/133294/1710_133294_000023_000000|When the Indians were the sole inhabitants of the wilds from whence they have since been expelled, their wants were few. +train-other-500/1773/139602/1773_139602_000032_000001|When the rabbit saw that the badger was getting well, he thought of another plan by which he could compass the creature's death. +train-other-500/1920/1793/1920_1793_000037_000001|She has a little Blenheim lapdog, that she loves a thousand times more than she ever will me!" +train-other-500/2067/143535/2067_143535_000009_000002|Indeed, there, to the left, was a stone shelf with a little ledge to it three inches or so high, and on the shelf lay what I took to be a corpse; at any rate, it looked like one, with something white thrown over it. +train-other-500/2208/11020/2208_11020_000037_000001|It's at my place over there.' +train-other-500/2312/157868/2312_157868_000019_000002|I am the manager of the theatre, and I'm thundering glad that your first play has been produced at the 'New York,' sir. +train-other-500/2485/151992/2485_151992_000028_000005|At last he looked up at his wife and said, in a gentle tone: +train-other-500/2587/54186/2587_54186_000015_000000|Concerning the work as a whole he wrote to Clara while in the throes of composition: "This music now in me, and always such beautiful melodies! +train-other-500/2740/288813/2740_288813_000018_000003|But Philip had kept him apart, had banked him off, and yet drained him to the dregs. +train-other-500/2943/171001/2943_171001_000122_000000|The sound of his voice pronouncing her name aroused her. +train-other-500/3063/138651/3063_138651_000028_000000|But, as may be imagined, the unfortunate john was as much surprised by this rencounter as the other two. +train-other-500/3172/166439/3172_166439_000050_000000|And now at last was clear a thing that had puzzled greatly-the mechanism of that opening process by which sphere became oval disk, pyramid a four pointed star and-as I had glimpsed in the play of the Little Things about Norhala, could see now so plainly in the Keeper-the blocks took this inverted cruciform shape. +train-other-500/331/132019/331_132019_000038_000000|"I say, this is folly! +train-other-500/3467/166570/3467_166570_000054_000001|Does he never mention Orlando?" +train-other-500/3587/140711/3587_140711_000015_000001|O fie, mrs Jervis, said I, how could you serve me so? Besides, it looks too free both in me, and to him. +train-other-500/3675/187020/3675_187020_000026_000001|"I wonder what would be suitable? +train-other-500/3819/134146/3819_134146_000019_000001|Also the figure half hidden by the cupboard door-was a female figure, massive, and in flowing robes. +train-other-500/3912/77626/3912_77626_000003_000004|You may almost distinguish the figures on the clock that has just told the hour. +train-other-500/4015/63729/4015_63729_000058_000000|"It does." +train-other-500/413/22436/413_22436_000035_000003|I conjecture, the French squadron is bound for Malta and Alexandria, and the Spanish fleet for the attack of Minorca." +train-other-500/4218/41159/4218_41159_000028_000002|Yes? That worries Alexey. +train-other-500/4352/10940/4352_10940_000037_000002|He doesn't exist." +train-other-500/4463/26871/4463_26871_000023_000000|"I did not notice him following me," she said timidly. +train-other-500/4591/14356/4591_14356_000019_000000|"Within three days," cried the enchanter, loudly, "bring Rinaldo and Ricciardetto into the pass of Ronces Valles. +train-other-500/4738/291957/4738_291957_000000_000001|ODE ON THE SPRING. +train-other-500/4824/36029/4824_36029_000045_000003|And indeed Janet herself had taken no part in the politics, content merely to advise the combatants upon their demeanour. +train-other-500/4936/65528/4936_65528_000014_000007|I immediately responded, "Yes, they are most terrible struck on each other," and I said it in a tone that indicated I thought it a most beautiful and lovely thing that they should be so. +train-other-500/5019/38670/5019_38670_000017_000000|"Let me make you a present of the gloves," she said, with her irresistible smile. +train-other-500/5132/33409/5132_33409_000016_000001|They waited on the table in Valhalla. +train-other-500/52/121057/52_121057_000019_000000|"I," cried the steward with a strange expression. +train-other-500/5321/53046/5321_53046_000025_000003|I gather from what mrs joel said that she's rather touched in her mind too, and has an awful hankering to get home here-to this very house. +train-other-500/5429/210770/5429_210770_000029_000006|But this was not all. +train-other-500/557/129797/557_129797_000072_000001|The guns were manned, the gunners already kindling fuses, when the buccaneer fleet, whilst still heading for Palomas, was observed to bear away to the west. +train-other-500/572/128861/572_128861_000016_000002|My home was desolate. +train-other-500/5826/53497/5826_53497_000044_000001|If it be as you say, he will have shown himself noble, and his nobility will have consisted in this, that he has been willing to take that which he does not want, in order that he may succour one whom he loves. +train-other-500/5906/52158/5906_52158_000055_000000|The impression that he gets this knowledge or suspicion from the outside is due, the scientists say, to the fact that his thinking has proceeded at such lightning like speed that he was unable to watch the wheels go round. +train-other-500/6009/57639/6009_57639_000038_000000|This, friendly reader, is my only motive. +train-other-500/6106/58196/6106_58196_000007_000001|I tell you that you must make the dress. +train-other-500/6178/86034/6178_86034_000079_000004|Then she will grow calmer, and will know you again. +train-other-500/6284/63091/6284_63091_000133_000001|I don't want to go anywhere where anybody'll see me." +train-other-500/6436/104980/6436_104980_000009_000002|I guess you never heard about this house." +train-other-500/6540/232291/6540_232291_000017_000003|The girl was not wholly a savage. +train-other-500/6627/67844/6627_67844_000046_000002|The other girls had stopped talking, and now looked at Sylvia as if wondering what she would say. +train-other-500/6707/77351/6707_77351_000002_000006|But our first words I may give you, because though they conveyed nothing to me at the time, afterwards they meant much. +train-other-500/6777/76694/6777_76694_000013_000011|When they are forcibly put out of Garraway's on Saturday night-which they must be, for they never would go out of their own accord-where do they vanish until Monday morning? +train-other-500/690/133452/690_133452_000011_000000|Campany lifted his quill pen and pointed to a case of big leather bound volumes in a far corner of the room. +train-other-500/7008/34667/7008_34667_000032_000002|What had happened? +train-other-500/7131/92815/7131_92815_000039_000001|The cabman tried to pass to the left, but a heavy express wagon cut him off. +train-other-500/7220/77911/7220_77911_000005_000000|"Do? +train-other-500/7326/245693/7326_245693_000008_000000|Whether the Appetite Is a Special Power of the Soul? +train-other-500/7392/105672/7392_105672_000013_000005|Whoever, being required, refused to answer upon oath to any article of this act of settlement, was declared to be guilty of treason; and by this clause a species of political inquisition was established in the kingdom, as well as the accusations of treason multiplied to an unreasonable degree. +train-other-500/7512/98636/7512_98636_000017_000002|A man thus rarely makes provision for the future, and looks with scorn on foreign customs which seem to betoken a fear lest, in old age, ungrateful children may neglect their parents and cast them aside. +train-other-500/7654/258963/7654_258963_000007_000007|Egypt, for a time reduced to a semi desert condition, has only in the past century been restored to a certain extent by the use of new methods and a return to the old ones. +train-other-500/7769/99396/7769_99396_000020_000002|I had to go out once a day in search of food. +train-other-500/791/127519/791_127519_000086_000000|This was how it came about. +train-other-500/8042/113769/8042_113769_000021_000000|House the second. +train-other-500/8180/274725/8180_274725_000010_000000|"What fools men are in love matters," quoth Patty to herself-"at least most men!" with a thought backward to Mark's sensible choosing. +train-other-500/8291/282929/8291_282929_000031_000006|He's in a devil of a-Well, he needs the money, and I've got to get it for him. You know my word's good, Cooper." +train-other-500/8389/120181/8389_120181_000022_000000|"No," I answered. +train-other-500/8476/269293/8476_269293_000078_000001|Annie, in some wonder, went downstairs alone. +train-other-500/8675/295195/8675_295195_000004_000004|Everything had gone on prosperously with them, and they had reared many successive families of young Nutcrackers, who went forth to assume their places in the forest of life, and to reflect credit on their bringing up,--so that naturally enough they began to have a very easy way of considering themselves models of wisdom. +train-other-500/9000/282381/9000_282381_000016_000008|Bank facings seemed to indicate that the richest pay dirt lay at bed rock. +train-other-500/978/132494/978_132494_000017_000001|And what made you come at that very minute? diff --git a/FlashSR/BigVGAN/README.md b/FlashSR/BigVGAN/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a6cff37786a486deb55bc070254027aa492c2e92 --- /dev/null +++ b/FlashSR/BigVGAN/README.md @@ -0,0 +1,95 @@ +## BigVGAN: A Universal Neural Vocoder with Large-Scale Training +#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon + +
